diff --git a/application/agents/base.py b/application/agents/base.py index 32d860b8..dff191a3 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -227,6 +227,7 @@ class BaseAgent(ABC): if tool_data["name"] == "api_tool" else tool_data["config"] ), + user_id=self.user, # Pass user ID for MCP tools credential decryption ) if tool_data["name"] == "api_tool": print( diff --git a/application/agents/tools/mcp_tool.py b/application/agents/tools/mcp_tool.py new file mode 100644 index 00000000..fb47d0ed --- /dev/null +++ b/application/agents/tools/mcp_tool.py @@ -0,0 +1,424 @@ +import json +import time +from typing import Any, Dict, List, Optional + +import requests + +from application.agents.tools.base import Tool + + +_mcp_session_cache = {} + + +class MCPTool(Tool): + """ + MCP Tool + Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources. Supports various authentication methods and provides secure access to external services through the MCP protocol. + """ + + def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None): + """ + Initialize the MCP Tool with configuration. + + Args: + config: Dictionary containing MCP server configuration: + - server_url: URL of the remote MCP server + - auth_type: Type of authentication (api_key, bearer, basic, none) + - encrypted_credentials: Encrypted credentials (if available) + - timeout: Request timeout in seconds (default: 30) + user_id: User ID for decrypting credentials (required if encrypted_credentials exist) + """ + self.config = config + self.server_url = config.get("server_url", "") + self.auth_type = config.get("auth_type", "none") + self.timeout = config.get("timeout", 30) + + # Decrypt credentials if they are encrypted + + self.auth_credentials = {} + if config.get("encrypted_credentials") and user_id: + from application.security.encryption import decrypt_credentials + + self.auth_credentials = decrypt_credentials( + config["encrypted_credentials"], user_id + ) + else: + # Fallback to unencrypted credentials (for backward compatibility) + + self.auth_credentials = config.get("auth_credentials", {}) + self.available_tools = [] + self._session = requests.Session() + self._mcp_session_id = None + self._setup_authentication() + self._cache_key = self._generate_cache_key() + + def _generate_cache_key(self) -> str: + """Generate a unique cache key for this MCP server configuration.""" + # Use server URL + auth info to create unique key + + auth_key = "" + if self.auth_type == "bearer": + token = self.auth_credentials.get("bearer_token", "") + auth_key = f"bearer:{token[:10]}..." if token else "bearer:none" + elif self.auth_type == "api_key": + api_key = self.auth_credentials.get("api_key", "") + auth_key = f"apikey:{api_key[:10]}..." if api_key else "apikey:none" + elif self.auth_type == "basic": + username = self.auth_credentials.get("username", "") + auth_key = f"basic:{username}" + else: + auth_key = "none" + return f"{self.server_url}#{auth_key}" + + def _get_cached_session(self) -> Optional[str]: + """Get cached session ID if available and not expired.""" + global _mcp_session_cache + + if self._cache_key in _mcp_session_cache: + session_data = _mcp_session_cache[self._cache_key] + # Check if session is less than 30 minutes old + + if time.time() - session_data["created_at"] < 1800: # 30 minutes + return session_data["session_id"] + else: + # Remove expired session + + del _mcp_session_cache[self._cache_key] + return None + + def _cache_session(self, session_id: str): + """Cache the session ID for reuse.""" + global _mcp_session_cache + _mcp_session_cache[self._cache_key] = { + "session_id": session_id, + "created_at": time.time(), + } + + def _setup_authentication(self): + """Setup authentication for the MCP server connection.""" + if self.auth_type == "api_key": + api_key = self.auth_credentials.get("api_key", "") + header_name = self.auth_credentials.get("api_key_header", "X-API-Key") + if api_key: + self._session.headers.update({header_name: api_key}) + elif self.auth_type == "bearer": + token = self.auth_credentials.get("bearer_token", "") + if token: + self._session.headers.update({"Authorization": f"Bearer {token}"}) + elif self.auth_type == "basic": + username = self.auth_credentials.get("username", "") + password = self.auth_credentials.get("password", "") + if username and password: + self._session.auth = (username, password) + + def _initialize_mcp_connection(self) -> Dict: + """ + Initialize MCP connection with the server, using cached session if available. + + Returns: + Server capabilities and information + """ + cached_session = self._get_cached_session() + if cached_session: + self._mcp_session_id = cached_session + return {"cached": True} + try: + init_params = { + "protocolVersion": "2024-11-05", + "capabilities": {"roots": {"listChanged": True}, "sampling": {}}, + "clientInfo": {"name": "DocsGPT", "version": "1.0.0"}, + } + response = self._make_mcp_request("initialize", init_params) + self._make_mcp_request("notifications/initialized") + + return response + except Exception as e: + return {"error": str(e), "fallback": True} + + def _ensure_valid_session(self): + """Ensure we have a valid MCP session, reinitializing if needed.""" + if not self._mcp_session_id: + self._initialize_mcp_connection() + + def _make_mcp_request(self, method: str, params: Optional[Dict] = None) -> Dict: + """ + Make an MCP protocol request to the server with automatic session recovery. + + Args: + method: MCP method name (e.g., "tools/list", "tools/call") + params: Parameters for the MCP method + + Returns: + Response data as dictionary + + Raises: + Exception: If request fails after retry + """ + mcp_message = {"jsonrpc": "2.0", "method": method} + + if not method.startswith("notifications/"): + mcp_message["id"] = 1 + if params: + mcp_message["params"] = params + return self._execute_mcp_request(mcp_message, method) + + def _execute_mcp_request( + self, mcp_message: Dict, method: str, is_retry: bool = False + ) -> Dict: + """Execute MCP request with optional retry on session failure.""" + try: + headers = {"Content-Type": "application/json", "Accept": "application/json"} + headers.update(self._session.headers) + + if self._mcp_session_id: + headers["Mcp-Session-Id"] = self._mcp_session_id + response = self._session.post( + self.server_url.rstrip("/"), + json=mcp_message, + headers=headers, + timeout=self.timeout, + ) + + if "mcp-session-id" in response.headers: + self._mcp_session_id = response.headers["mcp-session-id"] + self._cache_session(self._mcp_session_id) + response.raise_for_status() + + if method.startswith("notifications/"): + return {} + try: + result = response.json() + except json.JSONDecodeError: + raise Exception(f"Invalid JSON response: {response.text}") + if "error" in result: + error_msg = result["error"] + if isinstance(error_msg, dict): + error_msg = error_msg.get("message", str(error_msg)) + raise Exception(f"MCP server error: {error_msg}") + return result.get("result", result) + except requests.exceptions.RequestException as e: + if not is_retry and self._should_retry_with_new_session(e): + self._invalidate_and_refresh_session() + return self._execute_mcp_request(mcp_message, method, is_retry=True) + raise Exception(f"MCP server request failed: {str(e)}") + + def _should_retry_with_new_session(self, error: Exception) -> bool: + """Check if error indicates session invalidation and retry is warranted.""" + error_str = str(error).lower() + return ( + any( + indicator in error_str + for indicator in [ + "invalid session", + "session expired", + "unauthorized", + "401", + "403", + ] + ) + and self._mcp_session_id is not None + ) + + def _invalidate_and_refresh_session(self) -> None: + """Invalidate current session and create a new one.""" + global _mcp_session_cache + if self._cache_key in _mcp_session_cache: + del _mcp_session_cache[self._cache_key] + self._mcp_session_id = None + self._initialize_mcp_connection() + + def discover_tools(self) -> List[Dict]: + """ + Discover available tools from the MCP server using MCP protocol. + + Returns: + List of tool definitions from the server + """ + try: + self._ensure_valid_session() + + response = self._make_mcp_request("tools/list") + + if isinstance(response, dict) and "tools" in response: + self.available_tools = response["tools"] + return self.available_tools + elif isinstance(response, list): + self.available_tools = response + return self.available_tools + else: + self.available_tools = [] + return self.available_tools + except Exception as e: + raise Exception(f"Failed to discover tools from MCP server: {str(e)}") + + def execute_action(self, action_name: str, **kwargs) -> Any: + """ + Execute an action on the remote MCP server using MCP protocol. + + Args: + action_name: Name of the action to execute + **kwargs: Parameters for the action + + Returns: + Result from the MCP server + """ + self._ensure_valid_session() + + # Prepare call parameters for MCP protocol + + call_params = {"name": action_name, "arguments": kwargs} + + try: + result = self._make_mcp_request("tools/call", call_params) + return result + except Exception as e: + raise Exception(f"Failed to execute action '{action_name}': {str(e)}") + + def get_actions_metadata(self) -> List[Dict]: + """ + Get metadata for all available actions. + + Returns: + List of action metadata dictionaries + """ + actions = [] + for tool in self.available_tools: + # Parse MCP tool schema according to MCP specification + # Check multiple possible schema locations for compatibility + + input_schema = ( + tool.get("inputSchema") + or tool.get("input_schema") + or tool.get("schema") + or tool.get("parameters") + ) + + # Default empty schema if no inputSchema provided + + parameters_schema = { + "type": "object", + "properties": {}, + "required": [], + } + + # Parse the inputSchema if it exists + + if input_schema: + if isinstance(input_schema, dict): + # Handle standard JSON Schema format + + if "properties" in input_schema: + parameters_schema = { + "type": input_schema.get("type", "object"), + "properties": input_schema.get("properties", {}), + "required": input_schema.get("required", []), + } + + # Add additional schema properties if they exist + + for key in ["additionalProperties", "description"]: + if key in input_schema: + parameters_schema[key] = input_schema[key] + else: + # Might be properties directly at root level + + parameters_schema["properties"] = input_schema + action = { + "name": tool.get("name", ""), + "description": tool.get("description", ""), + "parameters": parameters_schema, + } + actions.append(action) + return actions + + def get_config_requirements(self) -> Dict: + """ + Get configuration requirements for the MCP tool. + + Returns: + Dictionary describing required configuration + """ + return { + "server_url": { + "type": "string", + "description": "URL of the remote MCP server (e.g., https://api.example.com)", + "required": True, + }, + "auth_type": { + "type": "string", + "description": "Authentication type", + "enum": ["none", "api_key", "bearer", "basic"], + "default": "none", + "required": True, + }, + "auth_credentials": { + "type": "object", + "description": "Authentication credentials (varies by auth_type)", + "properties": { + "api_key": { + "type": "string", + "description": "API key for api_key auth", + }, + "header_name": { + "type": "string", + "description": "Header name for API key (default: X-API-Key)", + "default": "X-API-Key", + }, + "token": { + "type": "string", + "description": "Bearer token for bearer auth", + }, + "username": { + "type": "string", + "description": "Username for basic auth", + }, + "password": { + "type": "string", + "description": "Password for basic auth", + }, + }, + "required": False, + }, + "timeout": { + "type": "integer", + "description": "Request timeout in seconds", + "default": 30, + "minimum": 1, + "maximum": 300, + "required": False, + }, + } + + def test_connection(self) -> Dict: + """ + Test the connection to the MCP server and validate functionality. + + Returns: + Dictionary with connection test results including tool count + """ + try: + init_result = self._initialize_mcp_connection() + + tools = self.discover_tools() + + message = f"Successfully connected to MCP server. Found {len(tools)} tools." + if init_result.get("cached"): + message += " (Using cached session)" + elif init_result.get("fallback"): + message += " (No formal initialization required)" + return { + "success": True, + "message": message, + "tools_count": len(tools), + "session_id": self._mcp_session_id, + "tools": [ + tool.get("name", "unknown") for tool in tools[:5] + ], # First 5 tool names + } + except Exception as e: + return { + "success": False, + "message": f"Connection failed: {str(e)}", + "tools_count": 0, + "error_type": type(e).__name__, + } diff --git a/application/agents/tools/tool_manager.py b/application/agents/tools/tool_manager.py index ad71db28..890262bc 100644 --- a/application/agents/tools/tool_manager.py +++ b/application/agents/tools/tool_manager.py @@ -23,16 +23,31 @@ class ToolManager: tool_config = self.config.get(name, {}) self.tools[name] = obj(tool_config) - def load_tool(self, tool_name, tool_config): + def load_tool(self, tool_name, tool_config, user_id=None): self.config[tool_name] = tool_config module = importlib.import_module(f"application.agents.tools.{tool_name}") for member_name, obj in inspect.getmembers(module, inspect.isclass): if issubclass(obj, Tool) and obj is not Tool: - return obj(tool_config) + # For MCP tools, pass the user_id for credential decryption + if tool_name == "mcp_tool" and user_id: + return obj(tool_config, user_id) + else: + return obj(tool_config) - def execute_action(self, tool_name, action_name, **kwargs): + def execute_action(self, tool_name, action_name, user_id=None, **kwargs): if tool_name not in self.tools: + # For MCP tools, they might not be pre-loaded, so load dynamically + if tool_name == "mcp_tool": + raise ValueError(f"Tool '{tool_name}' not loaded and no config provided for dynamic loading") raise ValueError(f"Tool '{tool_name}' not loaded") + + # For MCP tools, if user_id is provided, create a new instance with user context + if tool_name == "mcp_tool" and user_id: + # Load tool dynamically with user context for proper credential access + tool_config = self.config.get(tool_name, {}) + tool = self.load_tool(tool_name, tool_config, user_id) + return tool.execute_action(action_name, **kwargs) + return self.tools[tool_name].execute_action(action_name, **kwargs) def get_all_actions_metadata(self): diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 486690fb..8309a984 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -492,9 +492,9 @@ class DeleteOldIndexes(Resource): ) if not doc: return make_response(jsonify({"status": "not found"}), 404) - + storage = StorageCreator.get_storage() - + try: # Delete vector index if settings.VECTOR_STORE == "faiss": @@ -508,7 +508,7 @@ class DeleteOldIndexes(Resource): settings.VECTOR_STORE, source_id=str(doc["_id"]) ) vectorstore.delete_index() - + if "file_path" in doc and doc["file_path"]: file_path = doc["file_path"] if storage.is_directory(file_path): @@ -517,7 +517,7 @@ class DeleteOldIndexes(Resource): storage.delete_file(f) else: storage.delete_file(file_path) - + except FileNotFoundError: pass except Exception as err: @@ -525,7 +525,7 @@ class DeleteOldIndexes(Resource): f"Error deleting files and indexes: {err}", exc_info=True ) return make_response(jsonify({"success": False}), 400) - + sources_collection.delete_one({"_id": ObjectId(source_id)}) return make_response(jsonify({"success": True}), 200) @@ -573,55 +573,75 @@ class UploadFile(Resource): try: storage = StorageCreator.get_storage() - - + for file in files: original_filename = file.filename safe_file = safe_filename(original_filename) - + with tempfile.TemporaryDirectory() as temp_dir: temp_file_path = os.path.join(temp_dir, safe_file) file.save(temp_file_path) - + if zipfile.is_zipfile(temp_file_path): try: - with zipfile.ZipFile(temp_file_path, 'r') as zip_ref: + with zipfile.ZipFile(temp_file_path, "r") as zip_ref: zip_ref.extractall(path=temp_dir) - + # Walk through extracted files and upload them for root, _, files in os.walk(temp_dir): for extracted_file in files: - if os.path.join(root, extracted_file) == temp_file_path: + if ( + os.path.join(root, extracted_file) + == temp_file_path + ): continue - - rel_path = os.path.relpath(os.path.join(root, extracted_file), temp_dir) + + rel_path = os.path.relpath( + os.path.join(root, extracted_file), temp_dir + ) storage_path = f"{base_path}/{rel_path}" - - with open(os.path.join(root, extracted_file), 'rb') as f: + + with open( + os.path.join(root, extracted_file), "rb" + ) as f: storage.save_file(f, storage_path) except Exception as e: - current_app.logger.error(f"Error extracting zip: {e}", exc_info=True) + current_app.logger.error( + f"Error extracting zip: {e}", exc_info=True + ) # If zip extraction fails, save the original zip file file_path = f"{base_path}/{safe_file}" - with open(temp_file_path, 'rb') as f: + with open(temp_file_path, "rb") as f: storage.save_file(f, file_path) else: # For non-zip files, save directly file_path = f"{base_path}/{safe_file}" - with open(temp_file_path, 'rb') as f: + with open(temp_file_path, "rb") as f: storage.save_file(f, file_path) - + task = ingest.delay( settings.UPLOAD_FOLDER, [ - ".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub", - ".html", ".mdx", ".json", ".xlsx", ".pptx", ".png", - ".jpg", ".jpeg", + ".rst", + ".md", + ".pdf", + ".txt", + ".docx", + ".csv", + ".epub", + ".html", + ".mdx", + ".json", + ".xlsx", + ".pptx", + ".png", + ".jpg", + ".jpeg", ], job_name, user, file_path=base_path, - filename=dir_name + filename=dir_name, ) except Exception as err: current_app.logger.error(f"Error uploading file: {err}", exc_info=True) @@ -635,12 +655,29 @@ class ManageSourceFiles(Resource): api.model( "ManageSourceFilesModel", { - "source_id": fields.String(required=True, description="Source ID to modify"), - "operation": fields.String(required=True, description="Operation: 'add', 'remove', or 'remove_directory'"), - "file_paths": fields.List(fields.String, required=False, description="File paths to remove (for remove operation)"), - "directory_path": fields.String(required=False, description="Directory path to remove (for remove_directory operation)"), - "file": fields.Raw(required=False, description="Files to add (for add operation)"), - "parent_dir": fields.String(required=False, description="Parent directory path relative to source root"), + "source_id": fields.String( + required=True, description="Source ID to modify" + ), + "operation": fields.String( + required=True, + description="Operation: 'add', 'remove', or 'remove_directory'", + ), + "file_paths": fields.List( + fields.String, + required=False, + description="File paths to remove (for remove operation)", + ), + "directory_path": fields.String( + required=False, + description="Directory path to remove (for remove_directory operation)", + ), + "file": fields.Raw( + required=False, description="Files to add (for add operation)" + ), + "parent_dir": fields.String( + required=False, + description="Parent directory path relative to source root", + ), }, ) ) @@ -650,7 +687,9 @@ class ManageSourceFiles(Resource): def post(self): decoded_token = request.decoded_token if not decoded_token: - return make_response(jsonify({"success": False, "message": "Unauthorized"}), 401) + return make_response( + jsonify({"success": False, "message": "Unauthorized"}), 401 + ) user = decoded_token.get("sub") source_id = request.form.get("source_id") @@ -658,12 +697,24 @@ class ManageSourceFiles(Resource): if not source_id or not operation: return make_response( - jsonify({"success": False, "message": "source_id and operation are required"}), 400 + jsonify( + { + "success": False, + "message": "source_id and operation are required", + } + ), + 400, ) if operation not in ["add", "remove", "remove_directory"]: return make_response( - jsonify({"success": False, "message": "operation must be 'add', 'remove', or 'remove_directory'"}), 400 + jsonify( + { + "success": False, + "message": "operation must be 'add', 'remove', or 'remove_directory'", + } + ), + 400, ) try: @@ -674,34 +725,53 @@ class ManageSourceFiles(Resource): ) try: - source = sources_collection.find_one({"_id": ObjectId(source_id), "user": user}) + source = sources_collection.find_one( + {"_id": ObjectId(source_id), "user": user} + ) if not source: return make_response( - jsonify({"success": False, "message": "Source not found or access denied"}), 404 + jsonify( + { + "success": False, + "message": "Source not found or access denied", + } + ), + 404, ) except Exception as err: current_app.logger.error(f"Error finding source: {err}", exc_info=True) - return make_response(jsonify({"success": False, "message": "Database error"}), 500) + return make_response( + jsonify({"success": False, "message": "Database error"}), 500 + ) try: storage = StorageCreator.get_storage() source_file_path = source.get("file_path", "") - parent_dir = request.form.get("parent_dir", "") - + parent_dir = request.form.get("parent_dir", "") + if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir): return make_response( - jsonify({"success": False, "message": "Invalid parent directory path"}), 400 + jsonify( + {"success": False, "message": "Invalid parent directory path"} + ), + 400, ) if operation == "add": files = request.files.getlist("file") if not files or all(file.filename == "" for file in files): return make_response( - jsonify({"success": False, "message": "No files provided for add operation"}), 400 + jsonify( + { + "success": False, + "message": "No files provided for add operation", + } + ), + 400, ) added_files = [] - + target_dir = source_file_path if parent_dir: target_dir = f"{source_file_path}/{parent_dir}" @@ -720,26 +790,44 @@ class ManageSourceFiles(Resource): task = reingest_source_task.delay(source_id=source_id, user=user) - return make_response(jsonify({ - "success": True, - "message": f"Added {len(added_files)} files", - "added_files": added_files, - "parent_dir": parent_dir, - "reingest_task_id": task.id - }), 200) + return make_response( + jsonify( + { + "success": True, + "message": f"Added {len(added_files)} files", + "added_files": added_files, + "parent_dir": parent_dir, + "reingest_task_id": task.id, + } + ), + 200, + ) elif operation == "remove": file_paths_str = request.form.get("file_paths") if not file_paths_str: return make_response( - jsonify({"success": False, "message": "file_paths required for remove operation"}), 400 + jsonify( + { + "success": False, + "message": "file_paths required for remove operation", + } + ), + 400, ) try: - file_paths = json.loads(file_paths_str) if isinstance(file_paths_str, str) else file_paths_str + file_paths = ( + json.loads(file_paths_str) + if isinstance(file_paths_str, str) + else file_paths_str + ) except Exception: return make_response( - jsonify({"success": False, "message": "Invalid file_paths format"}), 400 + jsonify( + {"success": False, "message": "Invalid file_paths format"} + ), + 400, ) # Remove files from storage and directory structure @@ -757,18 +845,29 @@ class ManageSourceFiles(Resource): task = reingest_source_task.delay(source_id=source_id, user=user) - return make_response(jsonify({ - "success": True, - "message": f"Removed {len(removed_files)} files", - "removed_files": removed_files, - "reingest_task_id": task.id - }), 200) + return make_response( + jsonify( + { + "success": True, + "message": f"Removed {len(removed_files)} files", + "removed_files": removed_files, + "reingest_task_id": task.id, + } + ), + 200, + ) elif operation == "remove_directory": directory_path = request.form.get("directory_path") if not directory_path: return make_response( - jsonify({"success": False, "message": "directory_path required for remove_directory operation"}), 400 + jsonify( + { + "success": False, + "message": "directory_path required for remove_directory operation", + } + ), + 400, ) # Validate directory path (prevent path traversal) @@ -778,10 +877,17 @@ class ManageSourceFiles(Resource): f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}" ) return make_response( - jsonify({"success": False, "message": "Invalid directory path"}), 400 + jsonify( + {"success": False, "message": "Invalid directory path"} + ), + 400, ) - full_directory_path = f"{source_file_path}/{directory_path}" if directory_path else source_file_path + full_directory_path = ( + f"{source_file_path}/{directory_path}" + if directory_path + else source_file_path + ) if not storage.is_directory(full_directory_path): current_app.logger.warning( @@ -790,7 +896,13 @@ class ManageSourceFiles(Resource): f"Full path: {full_directory_path}" ) return make_response( - jsonify({"success": False, "message": "Directory not found or is not a directory"}), 404 + jsonify( + { + "success": False, + "message": "Directory not found or is not a directory", + } + ), + 404, ) success = storage.remove_directory(full_directory_path) @@ -802,7 +914,10 @@ class ManageSourceFiles(Resource): f"Full path: {full_directory_path}" ) return make_response( - jsonify({"success": False, "message": "Failed to remove directory"}), 500 + jsonify( + {"success": False, "message": "Failed to remove directory"} + ), + 500, ) current_app.logger.info( @@ -816,12 +931,17 @@ class ManageSourceFiles(Resource): task = reingest_source_task.delay(source_id=source_id, user=user) - return make_response(jsonify({ - "success": True, - "message": f"Successfully removed directory: {directory_path}", - "removed_directory": directory_path, - "reingest_task_id": task.id - }), 200) + return make_response( + jsonify( + { + "success": True, + "message": f"Successfully removed directory: {directory_path}", + "removed_directory": directory_path, + "reingest_task_id": task.id, + } + ), + 200, + ) except Exception as err: error_context = f"operation={operation}, user={user}, source_id={source_id}" @@ -835,8 +955,12 @@ class ManageSourceFiles(Resource): parent_dir = request.form.get("parent_dir", "") error_context += f", parent_dir={parent_dir}" - current_app.logger.error(f"Error managing source files: {err} ({error_context})", exc_info=True) - return make_response(jsonify({"success": False, "message": "Operation failed"}), 500) + current_app.logger.error( + f"Error managing source files: {err} ({error_context})", exc_info=True + ) + return make_response( + jsonify({"success": False, "message": "Operation failed"}), 500 + ) @user_ns.route("/api/remote") @@ -984,7 +1108,7 @@ class PaginatedSources(Resource): "tokens": doc.get("tokens", ""), "retriever": doc.get("retriever", "classic"), "syncFrequency": doc.get("sync_frequency", ""), - "isNested": bool(doc.get("directory_structure")) + "isNested": bool(doc.get("directory_structure")), } paginated_docs.append(doc_data) response = { @@ -1032,7 +1156,7 @@ class CombinedJson(Resource): "tokens": index.get("tokens", ""), "retriever": index.get("retriever", "classic"), "syncFrequency": index.get("sync_frequency", ""), - "is_nested": bool(index.get("directory_structure")) + "is_nested": bool(index.get("directory_structure")), } ) except Exception as err: @@ -1381,7 +1505,8 @@ class CreateAgent(Resource): required=True, description="Status of the agent (draft or published)" ), "json_schema": fields.Raw( - required=False, description="JSON schema for enforcing structured output format" + required=False, + description="JSON schema for enforcing structured output format", ), }, ) @@ -1407,7 +1532,7 @@ class CreateAgent(Resource): except json.JSONDecodeError: data["json_schema"] = None print(f"Received data: {data}") - + # Validate JSON schema if provided if data.get("json_schema"): try: @@ -1415,20 +1540,32 @@ class CreateAgent(Resource): json_schema = data.get("json_schema") if not isinstance(json_schema, dict): return make_response( - jsonify({"success": False, "message": "JSON schema must be a valid JSON object"}), - 400 + jsonify( + { + "success": False, + "message": "JSON schema must be a valid JSON object", + } + ), + 400, ) - + # Validate that it has either a 'schema' property or is itself a schema if "schema" not in json_schema and "type" not in json_schema: return make_response( - jsonify({"success": False, "message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property"}), - 400 + jsonify( + { + "success": False, + "message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property", + } + ), + 400, ) except Exception as e: return make_response( - jsonify({"success": False, "message": f"Invalid JSON schema: {str(e)}"}), - 400 + jsonify( + {"success": False, "message": f"Invalid JSON schema: {str(e)}"} + ), + 400, ) if data.get("status") not in ["draft", "published"]: @@ -1529,7 +1666,8 @@ class UpdateAgent(Resource): required=True, description="Status of the agent (draft or published)" ), "json_schema": fields.Raw( - required=False, description="JSON schema for enforcing structured output format" + required=False, + description="JSON schema for enforcing structured output format", ), }, ) @@ -3297,6 +3435,31 @@ class CreateTool(Resource): param_details["value"] = "" transformed_actions.append(action) try: + # Process config to encrypt credentials for MCP tools + config = data["config"] + if data["name"] == "mcp_tool": + from application.security.encryption import encrypt_credentials + + # Extract credentials from config + credentials = {} + if config.get("auth_type") == "bearer": + credentials["bearer_token"] = config.get("bearer_token", "") + elif config.get("auth_type") == "api_key": + credentials["api_key"] = config.get("api_key", "") + credentials["api_key_header"] = config.get("api_key_header", "") + elif config.get("auth_type") == "basic": + credentials["username"] = config.get("username", "") + credentials["password"] = config.get("password", "") + + # Encrypt credentials if any exist + if credentials: + config["encrypted_credentials"] = encrypt_credentials( + credentials, user + ) + # Remove plaintext credentials from config + for key in credentials.keys(): + config.pop(key, None) + new_tool = { "user": user, "name": data["name"], @@ -3304,7 +3467,7 @@ class CreateTool(Resource): "description": data["description"], "customName": data.get("customName", ""), "actions": transformed_actions, - "config": data["config"], + "config": config, "status": data["status"], } resp = user_tools_collection.insert_one(new_tool) @@ -3371,7 +3534,41 @@ class UpdateTool(Resource): ), 400, ) - update_data["config"] = data["config"] + + # Handle MCP tool credential encryption + config = data["config"] + tool_name = data.get("name") + if not tool_name: + # Get the tool name from the database + existing_tool = user_tools_collection.find_one( + {"_id": ObjectId(data["id"]), "user": user} + ) + tool_name = existing_tool.get("name") if existing_tool else None + + if tool_name == "mcp_tool": + from application.security.encryption import encrypt_credentials + + # Extract credentials from config + credentials = {} + if config.get("auth_type") == "bearer": + credentials["bearer_token"] = config.get("bearer_token", "") + elif config.get("auth_type") == "api_key": + credentials["api_key"] = config.get("api_key", "") + credentials["api_key_header"] = config.get("api_key_header", "") + elif config.get("auth_type") == "basic": + credentials["username"] = config.get("username", "") + credentials["password"] = config.get("password", "") + + # Encrypt credentials if any exist + if credentials: + config["encrypted_credentials"] = encrypt_credentials( + credentials, user + ) + # Remove plaintext credentials from config + for key in credentials.keys(): + config.pop(key, None) + + update_data["config"] = config if "status" in data: update_data["status"] = data["status"] user_tools_collection.update_one( @@ -3537,7 +3734,7 @@ class GetChunks(Resource): "page": "Page number for pagination", "per_page": "Number of chunks per page", "path": "Optional: Filter chunks by relative file path", - "search": "Optional: Search term to filter chunks by title or content" + "search": "Optional: Search term to filter chunks by title or content", }, ) def get(self): @@ -3561,7 +3758,7 @@ class GetChunks(Resource): try: store = get_vector_store(doc_id) chunks = store.get_chunks() - + filtered_chunks = [] for chunk in chunks: metadata = chunk.get("metadata", {}) @@ -3582,9 +3779,9 @@ class GetChunks(Resource): continue filtered_chunks.append(chunk) - + chunks = filtered_chunks - + total_chunks = len(chunks) start = (page - 1) * per_page end = start + per_page @@ -3598,7 +3795,7 @@ class GetChunks(Resource): "total": total_chunks, "chunks": paginated_chunks, "path": path if path else None, - "search": search_term if search_term else None + "search": search_term if search_term else None, } ), 200, @@ -3607,6 +3804,7 @@ class GetChunks(Resource): current_app.logger.error(f"Error getting chunks: {e}", exc_info=True) return make_response(jsonify({"success": False}), 500) + @user_ns.route("/api/add_chunk") class AddChunk(Resource): @api.expect( @@ -3773,7 +3971,9 @@ class UpdateChunk(Resource): deleted = store.delete_chunk(chunk_id) if not deleted: - current_app.logger.warning(f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created") + current_app.logger.warning( + f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created" + ) return make_response( jsonify( @@ -3905,39 +4105,233 @@ class DirectoryStructure(Resource): decoded_token = request.decoded_token if not decoded_token: return make_response(jsonify({"success": False}), 401) - + user = decoded_token.get("sub") doc_id = request.args.get("id") - + if not doc_id: - return make_response( - jsonify({"error": "Document ID is required"}), 400 - ) - + return make_response(jsonify({"error": "Document ID is required"}), 400) + if not ObjectId.is_valid(doc_id): return make_response(jsonify({"error": "Invalid document ID"}), 400) - + try: doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user}) if not doc: return make_response( jsonify({"error": "Document not found or access denied"}), 404 ) - + directory_structure = doc.get("directory_structure", {}) - + return make_response( - jsonify({ - "success": True, - "directory_structure": directory_structure, - "base_path": doc.get("file_path", "") - }), 200 + jsonify( + { + "success": True, + "directory_structure": directory_structure, + "base_path": doc.get("file_path", ""), + } + ), + 200, ) - + except Exception as e: current_app.logger.error( f"Error retrieving directory structure: {e}", exc_info=True ) - return make_response( - jsonify({"success": False, "error": str(e)}), 500 + return make_response(jsonify({"success": False, "error": str(e)}), 500) + + +@user_ns.route("/api/mcp_servers") +class MCPServers(Resource): + @api.doc(description="Get all MCP servers configured by the user") + def get(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + + user = decoded_token.get("sub") + try: + # Find all MCP tools for this user + mcp_tools = user_tools_collection.find({"user": user, "name": "mcp_tool"}) + + servers = [] + for tool in mcp_tools: + config = tool.get("config", {}) + servers.append( + { + "id": str(tool["_id"]), + "name": tool.get("displayName", "MCP Server"), + "server_url": config.get("server_url", ""), + "auth_type": config.get("auth_type", "none"), + "status": tool.get("status", False), + "created_at": ( + tool.get("_id").generation_time.isoformat() + if tool.get("_id") + else None + ), + } + ) + + return make_response(jsonify({"success": True, "servers": servers}), 200) + + except Exception as e: + current_app.logger.error( + f"Error retrieving MCP servers: {e}", exc_info=True + ) + return make_response(jsonify({"success": False, "error": str(e)}), 500) + + +@user_ns.route("/api/mcp_server//test") +class TestMCPServer(Resource): + @api.doc(description="Test connection to an MCP server") + def post(self, server_id): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + + user = decoded_token.get("sub") + try: + # Find the MCP tool + mcp_tool_doc = user_tools_collection.find_one( + {"_id": ObjectId(server_id), "user": user, "name": "mcp_tool"} + ) + + if not mcp_tool_doc: + return make_response( + jsonify({"success": False, "error": "MCP server not found"}), 404 + ) + + # Load the tool and test connection + from application.agents.tools.mcp_tool import MCPTool + + mcp_tool = MCPTool(mcp_tool_doc.get("config", {}), user) + result = mcp_tool.test_connection() + + return make_response(jsonify(result), 200) + + except Exception as e: + current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True) + return make_response( + jsonify( + {"success": False, "error": f"Connection test failed: {str(e)}"} + ), + 500, + ) + + +@user_ns.route("/api/mcp_server//tools") +class MCPServerTools(Resource): + @api.doc(description="Discover and get tools from an MCP server") + def get(self, server_id): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + + user = decoded_token.get("sub") + try: + # Find the MCP tool + mcp_tool_doc = user_tools_collection.find_one( + {"_id": ObjectId(server_id), "user": user, "name": "mcp_tool"} + ) + + if not mcp_tool_doc: + return make_response( + jsonify({"success": False, "error": "MCP server not found"}), 404 + ) + + # Load the tool and discover tools + from application.agents.tools.mcp_tool import MCPTool + + mcp_tool = MCPTool(mcp_tool_doc.get("config", {}), user) + tools = mcp_tool.discover_tools() + + # Get actions metadata and transform to match other tools format + actions_metadata = mcp_tool.get_actions_metadata() + transformed_actions = [] + + for action in actions_metadata: + # Add active flag and transform parameters + action["active"] = True + if "parameters" in action: + if "properties" in action["parameters"]: + for param_name, param_details in action["parameters"][ + "properties" + ].items(): + param_details["filled_by_llm"] = True + param_details["value"] = "" + transformed_actions.append(action) + + # Update the stored actions in the database + user_tools_collection.update_one( + {"_id": ObjectId(server_id)}, {"$set": {"actions": transformed_actions}} + ) + + return make_response( + jsonify( + {"success": True, "tools": tools, "actions": transformed_actions} + ), + 200, + ) + + except Exception as e: + current_app.logger.error(f"Error discovering MCP tools: {e}", exc_info=True) + return make_response( + jsonify( + {"success": False, "error": f"Tool discovery failed: {str(e)}"} + ), + 500, + ) + + +@user_ns.route("/api/mcp_server//tools/") +class MCPServerToolAction(Resource): + @api.expect( + api.model( + "MCPToolActionModel", + { + "parameters": fields.Raw( + required=False, description="Parameters for the tool action" + ) + }, + ) + ) + @api.doc(description="Execute a specific tool action on an MCP server") + def post(self, server_id, action_name): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + + user = decoded_token.get("sub") + data = request.get_json() or {} + parameters = data.get("parameters", {}) + + try: + # Find the MCP tool + mcp_tool_doc = user_tools_collection.find_one( + {"_id": ObjectId(server_id), "user": user, "name": "mcp_tool"} + ) + + if not mcp_tool_doc: + return make_response( + jsonify({"success": False, "error": "MCP server not found"}), 404 + ) + + # Load the tool and execute action + from application.agents.tools.mcp_tool import MCPTool + + mcp_tool = MCPTool(mcp_tool_doc.get("config", {}), user) + result = mcp_tool.execute_action(action_name, **parameters) + + return make_response(jsonify({"success": True, "result": result}), 200) + + except Exception as e: + current_app.logger.error( + f"Error executing MCP tool action: {e}", exc_info=True + ) + return make_response( + jsonify( + {"success": False, "error": f"Action execution failed: {str(e)}"} + ), + 500, ) diff --git a/application/core/settings.py b/application/core/settings.py index 9303b996..7c25084e 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -89,7 +89,7 @@ class Settings(BaseSettings): QDRANT_HOST: Optional[str] = None QDRANT_PATH: Optional[str] = None QDRANT_DISTANCE_FUNC: str = "Cosine" - + # PGVector vectorstore config PGVECTOR_CONNECTION_STRING: Optional[str] = None # Milvus vectorstore config diff --git a/application/security/__init__.py b/application/security/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/application/security/encryption.py b/application/security/encryption.py new file mode 100644 index 00000000..5cc891f6 --- /dev/null +++ b/application/security/encryption.py @@ -0,0 +1,97 @@ +""" +Simple encryption utility for securely storing sensitive credentials. +Uses XOR encryption with a key derived from app secret and user ID. +Note: This is basic obfuscation. For production, consider using cryptography library. +""" + +import base64 +import hashlib +import os +import json + + +def _get_encryption_key(user_id: str) -> bytes: + """ + Generate a consistent encryption key for a specific user. + Uses app secret + user ID to create a unique key per user. + """ + # Get app secret from environment or use a default (in production, always use env) + app_secret = os.environ.get( + "APP_SECRET_KEY", "default-docsgpt-secret-key-change-in-production" + ) + + # Combine app secret with user ID for user-specific encryption + combined = f"{app_secret}#{user_id}" + + # Create a 32-byte key + key_material = hashlib.sha256(combined.encode()).digest() + + return key_material + + +def _xor_encrypt_decrypt(data: bytes, key: bytes) -> bytes: + """Simple XOR encryption/decryption.""" + result = bytearray() + for i, byte in enumerate(data): + result.append(byte ^ key[i % len(key)]) + return bytes(result) + + +def encrypt_credentials(credentials: dict, user_id: str) -> str: + """ + Encrypt credentials dictionary for secure storage. + + Args: + credentials: Dictionary containing sensitive data + user_id: User ID for creating user-specific encryption key + + Returns: + Base64 encoded encrypted string + """ + if not credentials: + return "" + + try: + key = _get_encryption_key(user_id) + + # Convert dict to JSON string and encrypt + json_str = json.dumps(credentials) + encrypted_data = _xor_encrypt_decrypt(json_str.encode(), key) + + # Return base64 encoded for storage + return base64.b64encode(encrypted_data).decode() + + except Exception as e: + # If encryption fails, store empty string (will require re-auth) + print(f"Warning: Failed to encrypt credentials: {e}") + return "" + + +def decrypt_credentials(encrypted_data: str, user_id: str) -> dict: + """ + Decrypt credentials from storage. + + Args: + encrypted_data: Base64 encoded encrypted string + user_id: User ID for creating user-specific encryption key + + Returns: + Dictionary containing decrypted credentials + """ + if not encrypted_data: + return {} + + try: + key = _get_encryption_key(user_id) + + # Decode and decrypt + encrypted_bytes = base64.b64decode(encrypted_data.encode()) + decrypted_data = _xor_encrypt_decrypt(encrypted_bytes, key) + + # Parse JSON back to dict + return json.loads(decrypted_data.decode()) + + except Exception as e: + # If decryption fails, return empty dict (will require re-auth) + print(f"Warning: Failed to decrypt credentials: {e}") + return {} diff --git a/frontend/public/toolIcons/tool_mcp_tool.svg b/frontend/public/toolIcons/tool_mcp_tool.svg new file mode 100644 index 00000000..e69de29b diff --git a/frontend/src/assets/server.svg b/frontend/src/assets/server.svg new file mode 100644 index 00000000..e69de29b diff --git a/frontend/src/locale/en.json b/frontend/src/locale/en.json index d0d1b4b3..1b3067c9 100644 --- a/frontend/src/locale/en.json +++ b/frontend/src/locale/en.json @@ -183,7 +183,64 @@ "cancel": "Cancel", "addNew": "Add New", "name": "Name", - "type": "Type" + "type": "Type", + "regularTools": "Regular Tools", + "mcpTools": "MCP Tools", + "mcp": { + "title": "MCP (Model Context Protocol) Servers", + "description": "Connect to remote MCP servers to access their tools and capabilities. Only remote servers are supported.", + "addServer": "Add MCP Server", + "editServer": "Edit Server", + "deleteServer": "Delete Server", + "delete": "Delete", + "serverName": "Server Name", + "serverUrl": "Server URL", + "authType": "Authentication Type", + "apiKey": "API Key", + "headerName": "Header Name", + "bearerToken": "Bearer Token", + "username": "Username", + "password": "Password", + "timeout": "Timeout (seconds)", + "testConnection": "Test Connection", + "testing": "Testing...", + "saving": "Saving...", + "save": "Save", + "cancel": "Cancel", + "backToServers": "← Back to Servers", + "availableTools": "Available Tools", + "refreshTools": "Refresh Tools", + "refreshing": "Refreshing...", + "serverDisabled": "Server is disabled. Enable it to view available tools.", + "noToolsFound": "No tools found on this server.", + "noServersFound": "No MCP servers configured.", + "addFirstServer": "Add your first MCP server to get started.", + "parameters": "Parameters", + "active": "Active", + "inactive": "Inactive", + "noAuth": "No Authentication", + "toggleServer": "Toggle {{serverName}}", + "deleteWarning": "Are you sure you want to delete the MCP server \"{{serverName}}\"? This action cannot be undone.", + "placeholders": { + "serverName": "My MCP Server", + "serverUrl": "https://api.example.com", + "apiKey": "Enter your API key", + "bearerToken": "Enter your bearer token", + "username": "Enter username", + "password": "Enter password" + }, + "errors": { + "nameRequired": "Server name is required", + "urlRequired": "Server URL is required", + "invalidUrl": "Please enter a valid URL", + "apiKeyRequired": "API key is required", + "tokenRequired": "Bearer token is required", + "usernameRequired": "Username is required", + "passwordRequired": "Password is required", + "testFailed": "Connection test failed", + "saveFailed": "Failed to save MCP server" + } + } } }, "modals": { diff --git a/frontend/src/modals/AddToolModal.tsx b/frontend/src/modals/AddToolModal.tsx index 885ef467..28766bd1 100644 --- a/frontend/src/modals/AddToolModal.tsx +++ b/frontend/src/modals/AddToolModal.tsx @@ -8,6 +8,7 @@ import { useOutsideAlerter } from '../hooks'; import { ActiveState } from '../models/misc'; import { selectToken } from '../preferences/preferenceSlice'; import ConfigToolModal from './ConfigToolModal'; +import MCPServerModal from './MCPServerModal'; import { AvailableToolType } from './types'; import WrapperComponent from './WrapperModal'; @@ -34,6 +35,8 @@ export default function AddToolModal({ React.useState(null); const [configModalState, setConfigModalState] = React.useState('INACTIVE'); + const [mcpModalState, setMcpModalState] = + React.useState('INACTIVE'); const [loading, setLoading] = React.useState(false); useOutsideAlerter(modalRef, () => { @@ -86,6 +89,9 @@ export default function AddToolModal({ .catch((error) => { console.error('Failed to create tool:', error); }); + } else if (tool.name === 'mcp_tool') { + setModalState('INACTIVE'); + setMcpModalState('ACTIVE'); } else { setModalState('INACTIVE'); setConfigModalState('ACTIVE'); @@ -95,6 +101,12 @@ export default function AddToolModal({ React.useEffect(() => { if (modalState === 'ACTIVE') getAvailableTools(); }, [modalState]); + + const handleMcpServerAdded = () => { + getUserTools(); + setMcpModalState('INACTIVE'); + }; + return ( <> {modalState === 'ACTIVE' && ( @@ -166,6 +178,11 @@ export default function AddToolModal({ tool={selectedTool} getUserTools={getUserTools} /> + ); } diff --git a/frontend/src/modals/MCPServerModal.tsx b/frontend/src/modals/MCPServerModal.tsx new file mode 100644 index 00000000..32710712 --- /dev/null +++ b/frontend/src/modals/MCPServerModal.tsx @@ -0,0 +1,541 @@ +import { useRef, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useSelector } from 'react-redux'; + +import apiClient from '../api/client'; +import userService from '../api/services/userService'; +import Input from '../components/Input'; +import Spinner from '../components/Spinner'; +import { useOutsideAlerter } from '../hooks'; +import { ActiveState } from '../models/misc'; +import { selectToken } from '../preferences/preferenceSlice'; +import WrapperComponent from './WrapperModal'; + +interface MCPServerModalProps { + modalState: ActiveState; + setModalState: (state: ActiveState) => void; + server?: any; + onServerSaved: () => void; +} + +const authTypes = [ + { value: 'none', label: 'No Authentication' }, + { value: 'api_key', label: 'API Key' }, + { value: 'bearer', label: 'Bearer Token' }, + { value: 'basic', label: 'Basic Authentication' }, +]; + +export default function MCPServerModal({ + modalState, + setModalState, + server, + onServerSaved, +}: MCPServerModalProps) { + const { t } = useTranslation(); + const token = useSelector(selectToken); + const modalRef = useRef(null); + + const [formData, setFormData] = useState({ + name: server?.name || 'My MCP Server', + server_url: server?.server_url || '', + auth_type: server?.auth_type || 'none', + api_key: '', + header_name: 'X-API-Key', + bearer_token: '', + username: '', + password: '', + timeout: 30, + }); + + const [loading, setLoading] = useState(false); + const [testing, setTesting] = useState(false); + const [testResult, setTestResult] = useState<{ + success: boolean; + message: string; + } | null>(null); + const [errors, setErrors] = useState<{ [key: string]: string }>({}); + + useOutsideAlerter(modalRef, () => { + if (modalState === 'ACTIVE') { + setModalState('INACTIVE'); + resetForm(); + } + }, [modalState]); + + const resetForm = () => { + setFormData({ + name: 'My MCP Server', + server_url: '', + auth_type: 'none', + api_key: '', + header_name: 'X-API-Key', + bearer_token: '', + username: '', + password: '', + timeout: 30, + }); + setErrors({}); + setTestResult(null); + }; + + const validateForm = () => { + const newErrors: { [key: string]: string } = {}; + + if (!formData.name.trim()) { + newErrors.name = t('settings.tools.mcp.errors.nameRequired'); + } + + if (!formData.server_url.trim()) { + newErrors.server_url = t('settings.tools.mcp.errors.urlRequired'); + } else { + try { + new URL(formData.server_url); + } catch { + newErrors.server_url = t('settings.tools.mcp.errors.invalidUrl'); + } + } + + if (formData.auth_type === 'api_key' && !formData.api_key.trim()) { + newErrors.api_key = t('settings.tools.mcp.errors.apiKeyRequired'); + } + + if (formData.auth_type === 'bearer' && !formData.bearer_token.trim()) { + newErrors.bearer_token = t('settings.tools.mcp.errors.tokenRequired'); + } + + if (formData.auth_type === 'basic') { + if (!formData.username.trim()) { + newErrors.username = t('settings.tools.mcp.errors.usernameRequired'); + } + if (!formData.password.trim()) { + newErrors.password = t('settings.tools.mcp.errors.passwordRequired'); + } + } + + setErrors(newErrors); + return Object.keys(newErrors).length === 0; + }; + + const handleInputChange = (name: string, value: string | number) => { + setFormData((prev) => ({ ...prev, [name]: value })); + if (errors[name]) { + setErrors((prev) => ({ ...prev, [name]: '' })); + } + setTestResult(null); + }; + + const buildToolConfig = () => { + const config: any = { + server_url: formData.server_url.trim(), + auth_type: formData.auth_type, + timeout: formData.timeout, + }; + + // Add credentials directly to config for encryption + if (formData.auth_type === 'api_key') { + config.api_key = formData.api_key.trim(); + config.api_key_header = formData.header_name.trim() || 'X-API-Key'; + } else if (formData.auth_type === 'bearer') { + config.bearer_token = formData.bearer_token.trim(); + } else if (formData.auth_type === 'basic') { + config.username = formData.username.trim(); + config.password = formData.password.trim(); + } + + return config; + }; + + const testConnection = async () => { + if (!validateForm()) return; + + setTesting(true); + setTestResult(null); + + try { + // Create a temporary tool to test + const config = buildToolConfig(); + + const testData = { + name: 'mcp_tool', + displayName: formData.name, + description: 'MCP Server Connection', + config, + actions: [], + status: false, + }; + + const response = await userService.createTool(testData, token); + const result = await response.json(); + + if (response.ok && result.id) { + // Test the connection + try { + const testResponse = await apiClient.post( + `/api/mcp_server/${result.id}/test`, + {}, + token, + ); + const testData = await testResponse.json(); + setTestResult(testData); + + // Clean up the temporary tool + await userService.deleteTool({ id: result.id }, token); + } catch (error) { + setTestResult({ + success: false, + message: t('settings.tools.mcp.errors.testFailed'), + }); + // Clean up the temporary tool + await userService.deleteTool({ id: result.id }, token); + } + } else { + setTestResult({ + success: false, + message: t('settings.tools.mcp.errors.testFailed'), + }); + } + } catch (error) { + setTestResult({ + success: false, + message: t('settings.tools.mcp.errors.testFailed'), + }); + } finally { + setTesting(false); + } + }; + + const handleSave = async () => { + if (!validateForm()) return; + + setLoading(true); + + try { + const config = buildToolConfig(); + + const toolData = { + name: 'mcp_tool', + displayName: formData.name, + description: `MCP Server: ${formData.server_url}`, + config, + actions: [], // Will be populated after tool creation + status: true, + }; + + let toolId: string; + + if (server) { + // Update existing server + await userService.updateTool({ id: server.id, ...toolData }, token); + toolId = server.id; + } else { + // Create new server + const response = await userService.createTool(toolData, token); + const result = await response.json(); + toolId = result.id; + } + + // Now fetch the MCP tools and update the actions + try { + const toolsResponse = await apiClient.get( + `/api/mcp_server/${toolId}/tools`, + token, + ); + + if (toolsResponse.success && toolsResponse.actions) { + // Update the tool with discovered actions (already formatted by backend) + await userService.updateTool( + { + id: toolId, + ...toolData, + actions: toolsResponse.actions, + }, + token, + ); + + console.log( + `Successfully discovered and saved ${toolsResponse.actions.length} MCP tools`, + ); + + // Show success message with tool count + setTestResult({ + success: true, + message: `MCP server saved successfully! Discovered ${toolsResponse.actions.length} tools.`, + }); + } + } catch (error) { + console.warn( + 'Warning: Could not fetch MCP tools immediately after creation:', + error, + ); + // Don't fail the save operation if tool discovery fails + } + + onServerSaved(); + setModalState('INACTIVE'); + resetForm(); + } catch (error) { + console.error('Error saving MCP server:', error); + setErrors({ general: t('settings.tools.mcp.errors.saveFailed') }); + } finally { + setLoading(false); + } + }; + + const renderAuthFields = () => { + switch (formData.auth_type) { + case 'api_key': + return ( +
+
+ + handleInputChange('api_key', e.target.value)} + placeholder={t('settings.tools.mcp.placeholders.apiKey')} + /> + {errors.api_key && ( +

{errors.api_key}

+ )} +
+
+ + + handleInputChange('header_name', e.target.value) + } + placeholder="X-API-Key" + /> +
+
+ ); + case 'bearer': + return ( +
+ + + handleInputChange('bearer_token', e.target.value) + } + placeholder={t('settings.tools.mcp.placeholders.bearerToken')} + /> + {errors.bearer_token && ( +

{errors.bearer_token}

+ )} +
+ ); + case 'basic': + return ( +
+
+ + handleInputChange('username', e.target.value)} + placeholder={t('settings.tools.mcp.placeholders.username')} + /> + {errors.username && ( +

{errors.username}

+ )} +
+
+ + handleInputChange('password', e.target.value)} + placeholder={t('settings.tools.mcp.placeholders.password')} + /> + {errors.password && ( +

{errors.password}

+ )} +
+
+ ); + default: + return null; + } + }; + + return ( + modalState === 'ACTIVE' && ( + { + setModalState('INACTIVE'); + resetForm(); + }} + className="max-w-[600px] md:w-[80vw] lg:w-[60vw]" + > +
+
+

+ {server + ? t('settings.tools.mcp.editServer') + : t('settings.tools.mcp.addServer')} +

+
+ +
+
+
+ handleInputChange('name', e.target.value)} + borderVariant="thin" + placeholder={t('settings.tools.mcp.placeholders.serverName')} + labelBgClassName="bg-white dark:bg-charleston-green-2" + /> + {errors.name && ( +

{errors.name}

+ )} +
+ +
+ + + handleInputChange('server_url', e.target.value) + } + placeholder={t('settings.tools.mcp.placeholders.serverUrl')} + /> + {errors.server_url && ( +

+ {errors.server_url} +

+ )} +
+ +
+ + +
+ + {renderAuthFields()} + +
+ + + handleInputChange('timeout', parseInt(e.target.value) || 30) + } + placeholder="30" + /> +
+ + {testResult && ( +
+ {testResult.message} +
+ )} + + {errors.general && ( +
+ {errors.general} +
+ )} +
+
+ +
+ + +
+ + +
+
+
+
+ ) + ); +}