diff --git a/application/agents/tools/mcp_tool.py b/application/agents/tools/mcp_tool.py index dc689367..e539986c 100644 --- a/application/agents/tools/mcp_tool.py +++ b/application/agents/tools/mcp_tool.py @@ -1,14 +1,37 @@ +import asyncio +import base64 import json +import logging import time from typing import Any, Dict, List, Optional - -import requests +from urllib.parse import parse_qs, urlparse from application.agents.tools.base import Tool +from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task +from application.cache import get_redis_instance + +from application.core.mongo_db import MongoDB + +from application.core.settings import settings + from application.security.encryption import decrypt_credentials +from fastmcp import Client +from fastmcp.client.auth import BearerAuth +from fastmcp.client.transports import ( + SSETransport, + StdioTransport, + StreamableHttpTransport, +) +from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from pydantic import AnyHttpUrl, ValidationError +from redis import Redis -_mcp_session_cache = {} +mongo = MongoDB.get_client() +db = mongo[settings.MONGO_DB_NAME] + +_mcp_clients_cache = {} class MCPTool(Tool): @@ -24,15 +47,24 @@ class MCPTool(Tool): Args: config: Dictionary containing MCP server configuration: - server_url: URL of the remote MCP server - - auth_type: Type of authentication (api_key, bearer, basic, none) + - transport_type: Transport type (auto, sse, http, stdio) + - auth_type: Type of authentication (bearer, oauth, api_key, basic, none) - encrypted_credentials: Encrypted credentials (if available) - timeout: Request timeout in seconds (default: 30) + - headers: Custom headers for requests + - command: Command for STDIO transport + - args: Arguments for STDIO transport + - oauth_scopes: OAuth scopes for oauth auth type + - oauth_client_name: OAuth client name for oauth auth type user_id: User ID for decrypting credentials (required if encrypted_credentials exist) """ self.config = config + self.user_id = user_id self.server_url = config.get("server_url", "") + self.transport_type = config.get("transport_type", "auto") self.auth_type = config.get("auth_type", "none") self.timeout = config.get("timeout", 30) + self.custom_headers = config.get("headers", {}) self.auth_credentials = {} if config.get("encrypted_credentials") and user_id: @@ -41,34 +73,30 @@ class MCPTool(Tool): ) else: self.auth_credentials = config.get("auth_credentials", {}) - self.available_tools = [] - self._session = requests.Session() - self._mcp_session_id = None - self._setup_authentication() - self._cache_key = self._generate_cache_key() + self.oauth_scopes = config.get("oauth_scopes", []) + self.oauth_task_id = config.get("oauth_task_id", None) + self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP") + self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback" - def _setup_authentication(self): - """Setup authentication for the MCP server connection.""" - if self.auth_type == "api_key": - api_key = self.auth_credentials.get("api_key", "") - header_name = self.auth_credentials.get("api_key_header", "X-API-Key") - if api_key: - self._session.headers.update({header_name: api_key}) - elif self.auth_type == "bearer": - token = self.auth_credentials.get("bearer_token", "") - if token: - self._session.headers.update({"Authorization": f"Bearer {token}"}) - elif self.auth_type == "basic": - username = self.auth_credentials.get("username", "") - password = self.auth_credentials.get("password", "") - if username and password: - self._session.auth = (username, password) + self.available_tools = [] + self._cache_key = self._generate_cache_key() + self._client = None + + # Only validate and setup if server_url is provided and not OAuth + + if self.server_url and self.auth_type != "oauth": + self._setup_client() def _generate_cache_key(self) -> str: """Generate a unique cache key for this MCP server configuration.""" auth_key = "" - if self.auth_type == "bearer": - token = self.auth_credentials.get("bearer_token", "") + if self.auth_type == "oauth": + scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none" + auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}" + elif self.auth_type in ["bearer"]: + token = self.auth_credentials.get( + "bearer_token", "" + ) or self.auth_credentials.get("access_token", "") auth_key = f"bearer:{token[:10]}..." if token else "bearer:none" elif self.auth_type == "api_key": api_key = self.auth_credentials.get("api_key", "") @@ -78,201 +106,185 @@ class MCPTool(Tool): auth_key = f"basic:{username}" else: auth_key = "none" - return f"{self.server_url}#{auth_key}" + return f"{self.server_url}#{self.transport_type}#{auth_key}" - def _get_cached_session(self) -> Optional[str]: - """Get cached session ID if available and not expired.""" - global _mcp_session_cache - - if self._cache_key in _mcp_session_cache: - session_data = _mcp_session_cache[self._cache_key] - if time.time() - session_data["created_at"] < 1800: - return session_data["session_id"] + def _setup_client(self): + """Setup FastMCP client with proper transport and authentication.""" + global _mcp_clients_cache + if self._cache_key in _mcp_clients_cache: + cached_data = _mcp_clients_cache[self._cache_key] + if time.time() - cached_data["created_at"] < 1800: + self._client = cached_data["client"] + return else: - del _mcp_session_cache[self._cache_key] - return None + del _mcp_clients_cache[self._cache_key] + transport = self._create_transport() + auth = None - def _cache_session(self, session_id: str): - """Cache the session ID for reuse.""" - global _mcp_session_cache - _mcp_session_cache[self._cache_key] = { - "session_id": session_id, + if self.auth_type == "oauth": + redis_client = get_redis_instance() + auth = DocsGPTOAuth( + mcp_url=self.server_url, + scopes=self.oauth_scopes, + redis_client=redis_client, + redirect_uri=self.redirect_uri, + task_id=self.oauth_task_id, + db=db, + user_id=self.user_id, + ) + elif self.auth_type == "bearer": + token = self.auth_credentials.get( + "bearer_token", "" + ) or self.auth_credentials.get("access_token", "") + if token: + auth = BearerAuth(token) + self._client = Client(transport, auth=auth) + _mcp_clients_cache[self._cache_key] = { + "client": self._client, "created_at": time.time(), } - def _initialize_mcp_connection(self) -> Dict: - """ - Initialize MCP connection with the server, using cached session if available. + def _create_transport(self): + """Create appropriate transport based on configuration.""" + headers = {"Content-Type": "application/json", "User-Agent": "DocsGPT-MCP/1.0"} + headers.update(self.custom_headers) - Returns: - Server capabilities and information - """ - cached_session = self._get_cached_session() - if cached_session: - self._mcp_session_id = cached_session - return {"cached": True} - try: - init_params = { - "protocolVersion": "2024-11-05", - "capabilities": {"roots": {"listChanged": True}, "sampling": {}}, - "clientInfo": {"name": "DocsGPT", "version": "1.0.0"}, - } - response = self._make_mcp_request("initialize", init_params) - self._make_mcp_request("notifications/initialized") - - return response - except Exception as e: - return {"error": str(e), "fallback": True} - - def _ensure_valid_session(self): - """Ensure we have a valid MCP session, reinitializing if needed.""" - if not self._mcp_session_id: - self._initialize_mcp_connection() - - def _make_mcp_request(self, method: str, params: Optional[Dict] = None) -> Dict: - """ - Make an MCP protocol request to the server with automatic session recovery. - - Args: - method: MCP method name (e.g., "tools/list", "tools/call") - params: Parameters for the MCP method - - Returns: - Response data as dictionary - - Raises: - Exception: If request fails after retry - """ - mcp_message = {"jsonrpc": "2.0", "method": method} - - if not method.startswith("notifications/"): - mcp_message["id"] = 1 - if params: - mcp_message["params"] = params - return self._execute_mcp_request(mcp_message, method) - - def _execute_mcp_request( - self, mcp_message: Dict, method: str, is_retry: bool = False - ) -> Dict: - """Execute MCP request with optional retry on session failure.""" - try: - final_headers = self._session.headers.copy() - final_headers.update( - { - "Content-Type": "application/json", - "Accept": "application/json, text/event-stream", - } - ) - - if self._mcp_session_id: - final_headers["Mcp-Session-Id"] = self._mcp_session_id - response = self._session.post( - self.server_url.rstrip("/"), - json=mcp_message, - headers=final_headers, - timeout=self.timeout, - ) - - if "mcp-session-id" in response.headers: - self._mcp_session_id = response.headers["mcp-session-id"] - self._cache_session(self._mcp_session_id) - response.raise_for_status() - - if method.startswith("notifications/"): - return {} - response_text = response.text.strip() - if response_text.startswith("event:") and "data:" in response_text: - lines = response_text.split("\n") - data_line = None - for line in lines: - if line.startswith("data:"): - data_line = line[5:].strip() - break - if data_line: - try: - result = json.loads(data_line) - except json.JSONDecodeError: - raise Exception(f"Invalid JSON in SSE data: {data_line}") - else: - raise Exception(f"No data found in SSE response: {response_text}") + if self.auth_type == "api_key": + api_key = self.auth_credentials.get("api_key", "") + header_name = self.auth_credentials.get("api_key_header", "X-API-Key") + if api_key: + headers[header_name] = api_key + elif self.auth_type == "basic": + username = self.auth_credentials.get("username", "") + password = self.auth_credentials.get("password", "") + if username and password: + credentials = base64.b64encode( + f"{username}:{password}".encode() + ).decode() + headers["Authorization"] = f"Basic {credentials}" + if self.transport_type == "auto": + if "sse" in self.server_url.lower() or self.server_url.endswith("/sse"): + transport_type = "sse" else: + transport_type = "http" + else: + transport_type = self.transport_type + if transport_type == "sse": + headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"}) + return SSETransport(url=self.server_url, headers=headers) + elif transport_type == "http": + return StreamableHttpTransport(url=self.server_url, headers=headers) + elif transport_type == "stdio": + command = self.config.get("command", "python") + args = self.config.get("args", []) + env = self.auth_credentials if self.auth_credentials else None + return StdioTransport(command=command, args=args, env=env) + else: + return StreamableHttpTransport(url=self.server_url, headers=headers) + + def _format_tools(self, tools_response) -> List[Dict]: + """Format tools response to match expected format.""" + if hasattr(tools_response, "tools"): + tools = tools_response.tools + elif isinstance(tools_response, list): + tools = tools_response + else: + tools = [] + tools_dict = [] + for tool in tools: + if hasattr(tool, "name"): + tool_dict = { + "name": tool.name, + "description": tool.description, + } + if hasattr(tool, "inputSchema"): + tool_dict["inputSchema"] = tool.inputSchema + tools_dict.append(tool_dict) + elif isinstance(tool, dict): + tools_dict.append(tool) + else: + if hasattr(tool, "model_dump"): + tools_dict.append(tool.model_dump()) + else: + tools_dict.append({"name": str(tool), "description": ""}) + return tools_dict + + async def _execute_with_client(self, operation: str, *args, **kwargs): + """Execute operation with FastMCP client.""" + if not self._client: + raise Exception("FastMCP client not initialized") + async with self._client: + if operation == "ping": + return await self._client.ping() + elif operation == "list_tools": + tools_response = await self._client.list_tools() + self.available_tools = self._format_tools(tools_response) + return self.available_tools + elif operation == "call_tool": + tool_name = args[0] + tool_args = kwargs + return await self._client.call_tool(tool_name, tool_args) + elif operation == "list_resources": + return await self._client.list_resources() + elif operation == "list_prompts": + return await self._client.list_prompts() + else: + raise Exception(f"Unknown operation: {operation}") + + def _run_async_operation(self, operation: str, *args, **kwargs): + """Run async operation in sync context.""" + try: + try: + loop = asyncio.get_running_loop() + import concurrent.futures + + def run_in_thread(): + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + try: + return new_loop.run_until_complete( + self._execute_with_client(operation, *args, **kwargs) + ) + finally: + new_loop.close() + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_in_thread) + return future.result(timeout=self.timeout) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) try: - result = response.json() - except json.JSONDecodeError: - raise Exception(f"Invalid JSON response: {response.text}") - if "error" in result: - error_msg = result["error"] - if isinstance(error_msg, dict): - error_msg = error_msg.get("message", str(error_msg)) - raise Exception(f"MCP server error: {error_msg}") - return result.get("result", result) - except requests.exceptions.RequestException as e: - if not is_retry and self._should_retry_with_new_session(e): - self._invalidate_and_refresh_session() - return self._execute_mcp_request(mcp_message, method, is_retry=True) - raise Exception(f"MCP server request failed: {str(e)}") - - def _should_retry_with_new_session(self, error: Exception) -> bool: - """Check if error indicates session invalidation and retry is warranted.""" - error_str = str(error).lower() - return ( - any( - indicator in error_str - for indicator in [ - "invalid session", - "session expired", - "unauthorized", - "401", - "403", - ] - ) - and self._mcp_session_id is not None - ) - - def _invalidate_and_refresh_session(self) -> None: - """Invalidate current session and create a new one.""" - global _mcp_session_cache - if self._cache_key in _mcp_session_cache: - del _mcp_session_cache[self._cache_key] - self._mcp_session_id = None - self._initialize_mcp_connection() + return loop.run_until_complete( + self._execute_with_client(operation, *args, **kwargs) + ) + finally: + loop.close() + except Exception as e: + print(f"Error occurred while running async operation: {e}") + raise def discover_tools(self) -> List[Dict]: """ - Discover available tools from the MCP server using MCP protocol. + Discover available tools from the MCP server using FastMCP. Returns: List of tool definitions from the server """ + if not self.server_url: + return [] + if not self._client: + self._setup_client() try: - self._ensure_valid_session() - - response = self._make_mcp_request("tools/list") - - # Handle both formats: response with 'tools' key or response that IS the tools list - - if isinstance(response, dict): - if "tools" in response: - self.available_tools = response["tools"] - elif ( - "result" in response - and isinstance(response["result"], dict) - and "tools" in response["result"] - ): - self.available_tools = response["result"]["tools"] - else: - self.available_tools = [response] if response else [] - elif isinstance(response, list): - self.available_tools = response - else: - self.available_tools = [] + tools = self._run_async_operation("list_tools") + self.available_tools = tools return self.available_tools except Exception as e: raise Exception(f"Failed to discover tools from MCP server: {str(e)}") def execute_action(self, action_name: str, **kwargs) -> Any: """ - Execute an action on the remote MCP server using MCP protocol. + Execute an action on the remote MCP server using FastMCP. Args: action_name: Name of the action to execute @@ -281,22 +293,121 @@ class MCPTool(Tool): Returns: Result from the MCP server """ - self._ensure_valid_session() - - # Skipping empty/None values - letting the server use defaults - + if not self.server_url: + raise Exception("No MCP server configured") + if not self._client: + self._setup_client() cleaned_kwargs = {} for key, value in kwargs.items(): if value == "" or value is None: continue cleaned_kwargs[key] = value - call_params = {"name": action_name, "arguments": cleaned_kwargs} try: - result = self._make_mcp_request("tools/call", call_params) - return result + result = self._run_async_operation( + "call_tool", action_name, **cleaned_kwargs + ) + return self._format_result(result) except Exception as e: raise Exception(f"Failed to execute action '{action_name}': {str(e)}") + def _format_result(self, result) -> Dict: + """Format FastMCP result to match expected format.""" + if hasattr(result, "content"): + content_list = [] + for content_item in result.content: + if hasattr(content_item, "text"): + content_list.append({"type": "text", "text": content_item.text}) + elif hasattr(content_item, "data"): + content_list.append({"type": "data", "data": content_item.data}) + else: + content_list.append( + {"type": "unknown", "content": str(content_item)} + ) + return { + "content": content_list, + "isError": getattr(result, "isError", False), + } + else: + return result + + def test_connection(self) -> Dict: + """ + Test the connection to the MCP server and validate functionality. + + Returns: + Dictionary with connection test results including tool count + """ + if not self.server_url: + return { + "success": False, + "message": "No MCP server URL configured", + "tools_count": 0, + "transport_type": self.transport_type, + "auth_type": self.auth_type, + "error_type": "ConfigurationError", + } + if not self._client: + self._setup_client() + try: + if self.auth_type == "oauth": + return self._test_oauth_connection() + else: + return self._test_regular_connection() + except Exception as e: + return { + "success": False, + "message": f"Connection failed: {str(e)}", + "tools_count": 0, + "transport_type": self.transport_type, + "auth_type": self.auth_type, + "error_type": type(e).__name__, + } + + def _test_regular_connection(self) -> Dict: + """Test connection for non-OAuth auth types.""" + try: + self._run_async_operation("ping") + ping_success = True + except Exception: + ping_success = False + tools = self.discover_tools() + + message = f"Successfully connected to MCP server. Found {len(tools)} tools." + if not ping_success: + message += " (Ping not supported, but tool discovery worked)" + return { + "success": True, + "message": message, + "tools_count": len(tools), + "transport_type": self.transport_type, + "auth_type": self.auth_type, + "ping_supported": ping_success, + "tools": [tool.get("name", "unknown") for tool in tools], + } + + def _test_oauth_connection(self) -> Dict: + """Test connection for OAuth auth type with proper async handling.""" + try: + task = mcp_oauth_task.delay(config=self.config, user=self.user_id) + if not task: + raise Exception("Failed to start OAuth authentication") + return { + "success": True, + "requires_oauth": True, + "task_id": task.id, + "status": "pending", + "message": "OAuth flow started", + } + except Exception as e: + return { + "success": False, + "message": f"OAuth connection failed: {str(e)}", + "tools_count": 0, + "transport_type": self.transport_type, + "auth_type": self.auth_type, + "error_type": type(e).__name__, + } + def get_actions_metadata(self) -> List[Dict]: """ Get metadata for all available actions. @@ -341,58 +452,89 @@ class MCPTool(Tool): actions.append(action) return actions - def test_connection(self) -> Dict: - """ - Test the connection to the MCP server and validate functionality. - - Returns: - Dictionary with connection test results including tool count - """ - try: - self._mcp_session_id = None - - init_result = self._initialize_mcp_connection() - - tools = self.discover_tools() - - message = f"Successfully connected to MCP server. Found {len(tools)} tools." - if init_result.get("cached"): - message += " (Using cached session)" - elif init_result.get("fallback"): - message += " (No formal initialization required)" - return { - "success": True, - "message": message, - "tools_count": len(tools), - "session_id": self._mcp_session_id, - "tools": [tool.get("name", "unknown") for tool in tools[:5]], - } - except Exception as e: - return { - "success": False, - "message": f"Connection failed: {str(e)}", - "tools_count": 0, - "error_type": type(e).__name__, - } - def get_config_requirements(self) -> Dict: + """Get configuration requirements for the MCP tool.""" return { "server_url": { "type": "string", - "description": "URL of the remote MCP server (e.g., https://api.example.com)", + "description": "URL of the remote MCP server (e.g., https://api.example.com/mcp or https://docs.mcp.cloudflare.com/sse)", "required": True, }, + "transport_type": { + "type": "string", + "description": "Transport type for connection", + "enum": ["auto", "sse", "http", "stdio"], + "default": "auto", + "required": False, + "help": { + "auto": "Automatically detect best transport", + "sse": "Server-Sent Events (for real-time streaming)", + "http": "HTTP streaming (recommended for production)", + "stdio": "Standard I/O (for local servers)", + }, + }, "auth_type": { "type": "string", "description": "Authentication type", - "enum": ["none", "api_key", "bearer", "basic"], + "enum": ["none", "bearer", "oauth", "api_key", "basic"], "default": "none", "required": True, + "help": { + "none": "No authentication", + "bearer": "Bearer token authentication", + "oauth": "OAuth 2.1 authentication (with frontend integration)", + "api_key": "API key authentication", + "basic": "Basic authentication", + }, }, "auth_credentials": { "type": "object", "description": "Authentication credentials (varies by auth_type)", "required": False, + "properties": { + "bearer_token": { + "type": "string", + "description": "Bearer token for bearer auth", + }, + "access_token": { + "type": "string", + "description": "Access token for OAuth (if pre-obtained)", + }, + "api_key": { + "type": "string", + "description": "API key for api_key auth", + }, + "api_key_header": { + "type": "string", + "description": "Header name for API key (default: X-API-Key)", + }, + "username": { + "type": "string", + "description": "Username for basic auth", + }, + "password": { + "type": "string", + "description": "Password for basic auth", + }, + }, + }, + "oauth_scopes": { + "type": "array", + "description": "OAuth scopes to request (for oauth auth_type)", + "items": {"type": "string"}, + "required": False, + "default": [], + }, + "oauth_client_name": { + "type": "string", + "description": "Client name for OAuth registration (for oauth auth_type)", + "default": "DocsGPT-MCP", + "required": False, + }, + "headers": { + "type": "object", + "description": "Custom headers to send with requests", + "required": False, }, "timeout": { "type": "integer", @@ -402,4 +544,318 @@ class MCPTool(Tool): "maximum": 300, "required": False, }, + "command": { + "type": "string", + "description": "Command to run for STDIO transport (e.g., 'python')", + "required": False, + }, + "args": { + "type": "array", + "description": "Arguments for STDIO command", + "items": {"type": "string"}, + "required": False, + }, } + + +class DocsGPTOAuth(OAuthClientProvider): + """ + Custom OAuth handler for DocsGPT that uses frontend redirect instead of browser. + """ + + def __init__( + self, + mcp_url: str, + redirect_uri: str, + redis_client: Redis | None = None, + redis_prefix: str = "mcp_oauth:", + task_id: str = None, + scopes: str | list[str] | None = None, + client_name: str = "DocsGPT-MCP", + user_id=None, + db=None, + additional_client_metadata: dict[str, Any] | None = None, + ): + """ + Initialize custom OAuth client provider for DocsGPT. + + Args: + mcp_url: Full URL to the MCP endpoint + redirect_uri: Custom redirect URI for DocsGPT frontend + redis_client: Redis client for storing auth state + redis_prefix: Prefix for Redis keys + task_id: Task ID for tracking auth status + scopes: OAuth scopes to request + client_name: Name for this client during registration + user_id: User ID for token storage + db: Database instance for token storage + additional_client_metadata: Extra fields for OAuthClientMetadata + """ + + self.redirect_uri = redirect_uri + self.redis_client = redis_client + self.redis_prefix = redis_prefix + self.task_id = task_id + self.user_id = user_id + self.db = db + + parsed_url = urlparse(mcp_url) + self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + if isinstance(scopes, list): + scopes = " ".join(scopes) + client_metadata = OAuthClientMetadata( + client_name=client_name, + redirect_uris=[AnyHttpUrl(redirect_uri)], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + scope=scopes, + **(additional_client_metadata or {}), + ) + + storage = DBTokenStorage( + server_url=self.server_base_url, user_id=self.user_id, db_client=self.db + ) + + super().__init__( + server_url=self.server_base_url, + client_metadata=client_metadata, + storage=storage, + redirect_handler=self.redirect_handler, + callback_handler=self.callback_handler, + ) + + self.auth_url = None + self.extracted_state = None + + def _process_auth_url(self, authorization_url: str) -> tuple[str, str]: + """Process authorization URL to extract state""" + try: + parsed_url = urlparse(authorization_url) + query_params = parse_qs(parsed_url.query) + + state_params = query_params.get("state", []) + if state_params: + state = state_params[0] + else: + raise ValueError("No state in auth URL") + return authorization_url, state + except Exception as e: + raise Exception(f"Failed to process auth URL: {e}") + + async def redirect_handler(self, authorization_url: str) -> None: + """Store auth URL and state in Redis for frontend to use.""" + auth_url, state = self._process_auth_url(authorization_url) + logging.info( + "[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state + ) + self.auth_url = auth_url + self.extracted_state = state + + if self.redis_client and self.extracted_state: + key = f"{self.redis_prefix}auth_url:{self.extracted_state}" + self.redis_client.setex(key, 600, auth_url) + logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key) + + if self.task_id: + status_key = f"mcp_oauth_status:{self.task_id}" + status_data = { + "status": "requires_redirect", + "message": "OAuth authorization required", + "authorization_url": self.auth_url, + "state": self.extracted_state, + "requires_oauth": True, + "task_id": self.task_id, + } + self.redis_client.setex(status_key, 600, json.dumps(status_data)) + + async def callback_handler(self) -> tuple[str, str | None]: + """Wait for auth code from Redis using the state value.""" + if not self.redis_client or not self.extracted_state: + raise Exception("Redis client or state not configured for OAuth") + poll_interval = 1 + max_wait_time = 300 + code_key = f"{self.redis_prefix}code:{self.extracted_state}" + + if self.task_id: + status_key = f"mcp_oauth_status:{self.task_id}" + status_data = { + "status": "awaiting_callback", + "message": "Waiting for OAuth callback...", + "authorization_url": self.auth_url, + "state": self.extracted_state, + "requires_oauth": True, + "task_id": self.task_id, + } + self.redis_client.setex(status_key, 600, json.dumps(status_data)) + start_time = time.time() + while time.time() - start_time < max_wait_time: + code_data = self.redis_client.get(code_key) + if code_data: + code = code_data.decode() + returned_state = self.extracted_state + + self.redis_client.delete(code_key) + self.redis_client.delete( + f"{self.redis_prefix}auth_url:{self.extracted_state}" + ) + self.redis_client.delete( + f"{self.redis_prefix}state:{self.extracted_state}" + ) + + if self.task_id: + status_data = { + "status": "callback_received", + "message": "OAuth callback received, completing authentication...", + "task_id": self.task_id, + } + self.redis_client.setex(status_key, 600, json.dumps(status_data)) + return code, returned_state + error_key = f"{self.redis_prefix}error:{self.extracted_state}" + error_data = self.redis_client.get(error_key) + if error_data: + error_msg = error_data.decode() + self.redis_client.delete(error_key) + self.redis_client.delete( + f"{self.redis_prefix}auth_url:{self.extracted_state}" + ) + self.redis_client.delete( + f"{self.redis_prefix}state:{self.extracted_state}" + ) + raise Exception(f"OAuth error: {error_msg}") + await asyncio.sleep(poll_interval) + self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}") + self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}") + raise Exception("OAuth callback timeout: no code received within 5 minutes") + + +class DBTokenStorage(TokenStorage): + def __init__(self, server_url: str, user_id: str, db_client): + self.server_url = server_url + self.user_id = user_id + self.db_client = db_client + self.collection = db_client["connector_sessions"] + + @staticmethod + def get_base_url(url: str) -> str: + parsed = urlparse(url) + return f"{parsed.scheme}://{parsed.netloc}" + + def get_db_key(self) -> dict: + return { + "server_url": self.get_base_url(self.server_url), + "user_id": self.user_id, + } + + async def get_tokens(self) -> OAuthToken | None: + doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key()) + if not doc or "tokens" not in doc: + return None + try: + tokens = OAuthToken.model_validate(doc["tokens"]) + return tokens + except ValidationError as e: + logging.error(f"Could not load tokens: {e}") + return None + + async def set_tokens(self, tokens: OAuthToken) -> None: + await asyncio.to_thread( + self.collection.update_one, + self.get_db_key(), + {"$set": {"tokens": tokens.model_dump()}}, + True, + ) + logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}") + + async def get_client_info(self) -> OAuthClientInformationFull | None: + doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key()) + if not doc or "client_info" not in doc: + return None + try: + client_info = OAuthClientInformationFull.model_validate(doc["client_info"]) + tokens = await self.get_tokens() + if tokens is None: + logging.debug( + "No tokens found, clearing client info to force fresh registration." + ) + await asyncio.to_thread( + self.collection.update_one, + self.get_db_key(), + {"$unset": {"client_info": ""}}, + ) + return None + return client_info + except ValidationError as e: + logging.error(f"Could not load client info: {e}") + return None + + def _serialize_client_info(self, info: dict) -> dict: + if "redirect_uris" in info and isinstance(info["redirect_uris"], list): + info["redirect_uris"] = [str(u) for u in info["redirect_uris"]] + return info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + serialized_info = self._serialize_client_info(client_info.model_dump()) + await asyncio.to_thread( + self.collection.update_one, + self.get_db_key(), + {"$set": {"client_info": serialized_info}}, + True, + ) + logging.info(f"Saved client info for {self.get_base_url(self.server_url)}") + + async def clear(self) -> None: + await asyncio.to_thread(self.collection.delete_one, self.get_db_key()) + logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}") + + @classmethod + async def clear_all(cls, db_client) -> None: + collection = db_client["connector_sessions"] + await asyncio.to_thread(collection.delete_many, {}) + logging.info("Cleared all OAuth client cache data.") + + +class MCPOAuthManager: + """Manager for handling MCP OAuth callbacks.""" + + def __init__(self, redis_client: Redis | None, redis_prefix: str = "mcp_oauth:"): + self.redis_client = redis_client + self.redis_prefix = redis_prefix + + def handle_oauth_callback( + self, state: str, code: str, error: Optional[str] = None + ) -> bool: + """ + Handle OAuth callback from provider. + + Args: + state: The state parameter from OAuth callback + code: The authorization code from OAuth callback + error: Error message if OAuth failed + + Returns: + True if successful, False otherwise + """ + try: + if not self.redis_client or not state: + raise Exception("Redis client or state not provided") + if error: + error_key = f"{self.redis_prefix}error:{state}" + self.redis_client.setex(error_key, 300, error) + raise Exception(f"OAuth error received: {error}") + code_key = f"{self.redis_prefix}code:{state}" + self.redis_client.setex(code_key, 300, code) + + state_key = f"{self.redis_prefix}state:{state}" + self.redis_client.setex(state_key, 300, "completed") + + return True + except Exception as e: + logging.error(f"Error handling OAuth callback: {e}") + return False + + def get_oauth_status(self, task_id: str) -> Dict[str, Any]: + """Get current status of OAuth flow using provided task_id.""" + if not task_id: + return {"status": "not_started", "message": "OAuth flow not started"} + return mcp_oauth_status_task(task_id) diff --git a/application/api/user/routes.py b/application/api/user/routes.py index f0493c7c..281664d3 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -8,6 +8,7 @@ import uuid import zipfile from functools import wraps from typing import Optional, Tuple +from urllib.parse import unquote from bson.binary import Binary, UuidRepresentation from bson.dbref import DBRef @@ -25,7 +26,7 @@ from flask_restx import fields, inputs, Namespace, Resource from pymongo import ReturnDocument from werkzeug.utils import secure_filename -from application.agents.tools.mcp_tool import MCPTool +from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool from application.agents.tools.tool_manager import ToolManager from application.api import api @@ -37,6 +38,8 @@ from application.api.user.tasks import ( process_agent_webhook, store_attachment, ) + +from application.cache import get_redis_instance from application.core.mongo_db import MongoDB from application.core.settings import settings from application.parser.connectors.connector_creator import ConnectorCreator @@ -494,7 +497,6 @@ class DeleteOldIndexes(Resource): ) if not doc: return make_response(jsonify({"status": "not found"}), 404) - storage = StorageCreator.get_storage() try: @@ -511,7 +513,6 @@ class DeleteOldIndexes(Resource): settings.VECTOR_STORE, source_id=str(doc["_id"]) ) vectorstore.delete_index() - if "file_path" in doc and doc["file_path"]: file_path = doc["file_path"] if storage.is_directory(file_path): @@ -520,7 +521,6 @@ class DeleteOldIndexes(Resource): storage.delete_file(f) else: storage.delete_file(file_path) - except FileNotFoundError: pass except Exception as err: @@ -528,7 +528,6 @@ class DeleteOldIndexes(Resource): f"Error deleting files and indexes: {err}", exc_info=True ) return make_response(jsonify({"success": False}), 400) - sources_collection.delete_one({"_id": ObjectId(source_id)}) return make_response(jsonify({"success": True}), 200) @@ -600,7 +599,6 @@ class UploadFile(Resource): == temp_file_path ): continue - rel_path = os.path.relpath( os.path.join(root, extracted_file), temp_dir ) @@ -625,7 +623,6 @@ class UploadFile(Resource): file_path = f"{base_path}/{safe_file}" with open(temp_file_path, "rb") as f: storage.save_file(f, file_path) - task = ingest.delay( settings.UPLOAD_FOLDER, [ @@ -697,7 +694,6 @@ class ManageSourceFiles(Resource): return make_response( jsonify({"success": False, "message": "Unauthorized"}), 401 ) - user = decoded_token.get("sub") source_id = request.form.get("source_id") operation = request.form.get("operation") @@ -747,7 +743,6 @@ class ManageSourceFiles(Resource): return make_response( jsonify({"success": False, "message": "Database error"}), 500 ) - try: storage = StorageCreator.get_storage() source_file_path = source.get("file_path", "") @@ -804,7 +799,6 @@ class ManageSourceFiles(Resource): ), 200, ) - elif operation == "remove": file_paths_str = request.form.get("file_paths") if not file_paths_str: @@ -858,7 +852,6 @@ class ManageSourceFiles(Resource): ), 200, ) - elif operation == "remove_directory": directory_path = request.form.get("directory_path") if not directory_path: @@ -884,7 +877,6 @@ class ManageSourceFiles(Resource): ), 400, ) - full_directory_path = ( f"{source_file_path}/{directory_path}" if directory_path @@ -943,7 +935,6 @@ class ManageSourceFiles(Resource): ), 200, ) - except Exception as err: error_context = f"operation={operation}, user={user}, source_id={source_id}" if operation == "remove_directory": @@ -955,7 +946,6 @@ class ManageSourceFiles(Resource): elif operation == "add": parent_dir = request.form.get("parent_dir", "") error_context += f", parent_dir={parent_dir}" - current_app.logger.error( f"Error managing source files: {err} ({error_context})", exc_info=True ) @@ -1632,7 +1622,6 @@ class CreateAgent(Resource): ), 400, ) - # Validate that it has either a 'schema' property or is itself a schema if "schema" not in json_schema and "type" not in json_schema: @@ -3476,7 +3465,6 @@ class AvailableTools(Resource): "displayName": name, "description": description, "configRequirements": tool_instance.get_config_requirements(), - "actions": tool_instance.get_actions_metadata(), } ) except Exception as err: @@ -3527,11 +3515,6 @@ class CreateTool(Resource): "customName": fields.String( required=False, description="Custom name for the tool" ), - "actions": fields.List( - fields.Raw, - required=True, - description="Actions the tool can perform", - ), "status": fields.Boolean( required=True, description="Status of the tool" ), @@ -3549,24 +3532,35 @@ class CreateTool(Resource): "name", "displayName", "description", - "actions", "config", "status", ] missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields - transformed_actions = [] - for action in data["actions"]: - action["active"] = True - if "parameters" in action: - if "properties" in action["parameters"]: - for param_name, param_details in action["parameters"][ - "properties" - ].items(): - param_details["filled_by_llm"] = True - param_details["value"] = "" - transformed_actions.append(action) + try: + tool_instance = tool_manager.tools.get(data["name"]) + if not tool_instance: + return make_response( + jsonify({"success": False, "message": "Tool not found"}), 404 + ) + actions_metadata = tool_instance.get_actions_metadata() + transformed_actions = [] + for action in actions_metadata: + action["active"] = True + if "parameters" in action: + if "properties" in action["parameters"]: + for param_name, param_details in action["parameters"][ + "properties" + ].items(): + param_details["filled_by_llm"] = True + param_details["value"] = "" + transformed_actions.append(action) + except Exception as err: + current_app.logger.error( + f"Error getting tool actions: {err}", exc_info=True + ) + return make_response(jsonify({"success": False}), 400) try: new_tool = { "user": user, @@ -3907,7 +3901,6 @@ class GetChunks(Resource): if not (text_match or title_match): continue filtered_chunks.append(chunk) - chunks = filtered_chunks total_chunks = len(chunks) @@ -4098,7 +4091,6 @@ class UpdateChunk(Resource): current_app.logger.warning( f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created" ) - return make_response( jsonify( { @@ -4226,23 +4218,19 @@ class DirectoryStructure(Resource): decoded_token = request.decoded_token if not decoded_token: return make_response(jsonify({"success": False}), 401) - user = decoded_token.get("sub") doc_id = request.args.get("id") if not doc_id: return make_response(jsonify({"error": "Document ID is required"}), 400) - if not ObjectId.is_valid(doc_id): return make_response(jsonify({"error": "Invalid document ID"}), 400) - try: doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user}) if not doc: return make_response( jsonify({"error": "Document not found or access denied"}), 404 ) - directory_structure = doc.get("directory_structure", {}) base_path = doc.get("file_path", "") @@ -4315,11 +4303,10 @@ class TestMCPServerConfig(Resource): auth_credentials["username"] = config["username"] if "password" in config: auth_credentials["password"] = config["password"] - test_config = config.copy() test_config["auth_credentials"] = auth_credentials - mcp_tool = MCPTool(test_config, user) + mcp_tool = MCPTool(config=test_config, user_id=user) result = mcp_tool.test_connection() return make_response(jsonify(result), 200) @@ -4387,22 +4374,45 @@ class MCPServerSave(Resource): mcp_config = config.copy() mcp_config["auth_credentials"] = auth_credentials - if auth_type == "none" or auth_credentials: - mcp_tool = MCPTool(mcp_config, user) + if auth_type == "oauth": + if not config.get("oauth_task_id"): + return make_response( + jsonify( + { + "success": False, + "error": "Connection not authorized. Please complete the OAuth authorization first.", + } + ), + 400, + ) + redis_client = get_redis_instance() + manager = MCPOAuthManager(redis_client) + result = manager.get_oauth_status(config["oauth_task_id"]) + if not result.get("status") == "completed": + return make_response( + jsonify( + { + "success": False, + "error": "OAuth failed or not completed. Please try authorizing again.", + } + ), + 400, + ) + actions_metadata = result.get("tools", []) + elif auth_type == "none" or auth_credentials: + mcp_tool = MCPTool(config=mcp_config, user_id=user) mcp_tool.discover_tools() actions_metadata = mcp_tool.get_actions_metadata() else: raise Exception( "No valid credentials provided for the selected authentication type" ) - storage_config = config.copy() if auth_credentials: encrypted_credentials_string = encrypt_credentials( auth_credentials, user ) storage_config["encrypted_credentials"] = encrypted_credentials_string - for field in [ "api_key", "bearer_token", @@ -4473,3 +4483,96 @@ class MCPServerSave(Resource): ), 500, ) + + +@user_ns.route("/api/mcp_server/callback") +class MCPOAuthCallback(Resource): + @api.expect( + api.model( + "MCPServerCallbackModel", + { + "code": fields.String(required=True, description="Authorization code"), + "state": fields.String(required=True, description="State parameter"), + "error": fields.String( + required=False, description="Error message (if any)" + ), + }, + ) + ) + @api.doc( + description="Handle OAuth callback by providing the authorization code and state" + ) + def get(self): + code = request.args.get("code") + state = request.args.get("state") + error = request.args.get("error") + + if error: + return redirect( + f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider=mcp_tool" + ) + if not code or not state: + return redirect( + "/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool" + ) + try: + redis_client = get_redis_instance() + if not redis_client: + return redirect( + "/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool" + ) + code = unquote(code) + manager = MCPOAuthManager(redis_client) + success = manager.handle_oauth_callback(state, code, error) + if success: + return redirect( + "/api/connectors/callback-status?status=success&message=Authorization+code+received+successfully.+You+can+close+this+window.&provider=mcp_tool" + ) + else: + return redirect( + "/api/connectors/callback-status?status=error&message=OAuth+callback+failed.&provider=mcp_tool" + ) + except Exception as e: + current_app.logger.error( + f"Error handling MCP OAuth callback: {str(e)}", exc_info=True + ) + return redirect( + f"/api/connectors/callback-status?status=error&message=Internal+server+error:+{str(e)}.&provider=mcp_tool" + ) + + +@user_ns.route("/api/mcp_server/oauth_status/") +class MCPOAuthStatus(Resource): + def get(self, task_id): + """ + Get current status of OAuth flow. + Frontend should poll this endpoint periodically. + """ + try: + redis_client = get_redis_instance() + status_key = f"mcp_oauth_status:{task_id}" + status_data = redis_client.get(status_key) + + if status_data: + status = json.loads(status_data) + return make_response( + jsonify({"success": True, "task_id": task_id, **status}) + ) + else: + return make_response( + jsonify( + { + "success": False, + "error": "Task not found or expired", + "task_id": task_id, + } + ), + 404, + ) + except Exception as e: + current_app.logger.error( + f"Error getting OAuth status for task {task_id}: {str(e)}" + ) + return make_response( + jsonify({"success": False, "error": str(e), "task_id": task_id}), 500 + ) diff --git a/application/api/user/tasks.py b/application/api/user/tasks.py index 3519b701..c7414b9f 100644 --- a/application/api/user/tasks.py +++ b/application/api/user/tasks.py @@ -5,6 +5,8 @@ from application.worker import ( agent_webhook_worker, attachment_worker, ingest_worker, + mcp_oauth, + mcp_oauth_status, remote_worker, sync_worker, ) @@ -25,6 +27,7 @@ def ingest_remote(self, source_data, job_name, user, loader): @celery.task(bind=True) def reingest_source_task(self, source_id, user): from application.worker import reingest_source_worker + resp = reingest_source_worker(self, source_id, user) return resp @@ -60,9 +63,10 @@ def ingest_connector_task( retriever="classic", operation_mode="upload", doc_id=None, - sync_frequency="never" + sync_frequency="never", ): from application.worker import ingest_connector + resp = ingest_connector( self, job_name, @@ -75,7 +79,7 @@ def ingest_connector_task( retriever=retriever, operation_mode=operation_mode, doc_id=doc_id, - sync_frequency=sync_frequency + sync_frequency=sync_frequency, ) return resp @@ -94,3 +98,15 @@ def setup_periodic_tasks(sender, **kwargs): timedelta(days=30), schedule_syncs.s("monthly"), ) + + +@celery.task(bind=True) +def mcp_oauth_task(self, config, user): + resp = mcp_oauth(self, config, user) + return resp + + +@celery.task(bind=True) +def mcp_oauth_status_task(self, task_id): + resp = mcp_oauth_status(self, task_id) + return resp diff --git a/application/requirements.txt b/application/requirements.txt index 80564689..3882bd6d 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -12,6 +12,7 @@ esprima==4.0.1 esutils==1.0.1 Flask==3.1.1 faiss-cpu==1.9.0.post1 +fastmcp==2.11.0 flask-restx==1.3.0 google-genai==1.3.0 google-api-python-client==2.179.0 @@ -56,13 +57,13 @@ prompt-toolkit==3.0.51 protobuf==5.29.3 psycopg2-binary==2.9.10 py==1.11.0 -pydantic==2.10.6 -pydantic-core==2.27.2 -pydantic-settings==2.7.1 +pydantic +pydantic-core +pydantic-settings pymongo==4.11.3 pypdf==5.5.0 python-dateutil==2.9.0.post0 -python-dotenv==1.0.1 +python-dotenv python-jose==3.4.0 python-pptx==1.0.2 redis==5.2.1 @@ -82,7 +83,7 @@ tzdata==2024.2 urllib3==2.3.0 vine==5.1.0 wcwidth==0.2.13 -werkzeug==3.1.3 +werkzeug>=3.1.0,<3.1.2 yarl==1.20.0 markdownify==1.1.0 tldextract==5.1.3 diff --git a/application/worker.py b/application/worker.py index 5a29d00a..81909fc3 100755 --- a/application/worker.py +++ b/application/worker.py @@ -19,6 +19,7 @@ from bson.objectid import ObjectId from application.agents.agent_creator import AgentCreator from application.api.answer.services.stream_processor import get_prompt +from application.cache import get_redis_instance from application.core.mongo_db import MongoDB from application.core.settings import settings from application.parser.chunking import Chunker @@ -214,8 +215,7 @@ def run_agent_logic(agent_config, input_data): def ingest_worker( - self, directory, formats, job_name, file_path, filename, user, - retriever="classic" + self, directory, formats, job_name, file_path, filename, user, retriever="classic" ): """ Ingest and process documents. @@ -240,7 +240,7 @@ def ingest_worker( sample = False storage = StorageCreator.get_storage() - + logging.info(f"Ingest path: {file_path}", extra={"user": user, "job": job_name}) # Create temporary working directory @@ -253,30 +253,32 @@ def ingest_worker( # Handle directory case logging.info(f"Processing directory: {file_path}") files_list = storage.list_files(file_path) - + for storage_file_path in files_list: if storage.is_directory(storage_file_path): continue - + # Create relative path structure in temp directory rel_path = os.path.relpath(storage_file_path, file_path) local_file_path = os.path.join(temp_dir, rel_path) - + os.makedirs(os.path.dirname(local_file_path), exist_ok=True) - + # Download file try: file_data = storage.get_file(storage_file_path) with open(local_file_path, "wb") as f: f.write(file_data.read()) except Exception as e: - logging.error(f"Error downloading file {storage_file_path}: {e}") + logging.error( + f"Error downloading file {storage_file_path}: {e}" + ) continue else: # Handle single file case temp_filename = os.path.basename(file_path) temp_file_path = os.path.join(temp_dir, temp_filename) - + file_data = storage.get_file(file_path) with open(temp_file_path, "wb") as f: f.write(file_data.read()) @@ -285,7 +287,10 @@ def ingest_worker( if temp_filename.endswith(".zip"): logging.info(f"Extracting zip file: {temp_filename}") extract_zip_recursive( - temp_file_path, temp_dir, current_depth=0, max_depth=RECURSION_DEPTH + temp_file_path, + temp_dir, + current_depth=0, + max_depth=RECURSION_DEPTH, ) self.update_state(state="PROGRESS", meta={"current": 1}) @@ -300,8 +305,8 @@ def ingest_worker( file_metadata=metadata_from_filename, ) raw_docs = reader.load_data() - - directory_structure = getattr(reader, 'directory_structure', {}) + + directory_structure = getattr(reader, "directory_structure", {}) logging.info(f"Directory structure from reader: {directory_structure}") chunker = Chunker( @@ -371,7 +376,10 @@ def reingest_source_worker(self, source_id, user): try: from application.vectorstore.vector_creator import VectorCreator - self.update_state(state="PROGRESS", meta={"current": 10, "status": "Initializing re-ingestion scan"}) + self.update_state( + state="PROGRESS", + meta={"current": 10, "status": "Initializing re-ingestion scan"}, + ) source = sources_collection.find_one({"_id": ObjectId(source_id), "user": user}) if not source: @@ -380,7 +388,9 @@ def reingest_source_worker(self, source_id, user): storage = StorageCreator.get_storage() source_file_path = source.get("file_path", "") - self.update_state(state="PROGRESS", meta={"current": 20, "status": "Scanning current files"}) + self.update_state( + state="PROGRESS", meta={"current": 20, "status": "Scanning current files"} + ) with tempfile.TemporaryDirectory() as temp_dir: # Download all files from storage to temp directory, preserving directory structure @@ -391,7 +401,6 @@ def reingest_source_worker(self, source_id, user): if storage.is_directory(storage_file_path): continue - rel_path = os.path.relpath(storage_file_path, source_file_path) local_file_path = os.path.join(temp_dir, rel_path) @@ -403,23 +412,39 @@ def reingest_source_worker(self, source_id, user): with open(local_file_path, "wb") as f: f.write(file_data.read()) except Exception as e: - logging.error(f"Error downloading file {storage_file_path}: {e}") + logging.error( + f"Error downloading file {storage_file_path}: {e}" + ) continue reader = SimpleDirectoryReader( input_dir=temp_dir, recursive=True, required_exts=[ - ".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub", - ".html", ".mdx", ".json", ".xlsx", ".pptx", ".png", - ".jpg", ".jpeg", + ".rst", + ".md", + ".pdf", + ".txt", + ".docx", + ".csv", + ".epub", + ".html", + ".mdx", + ".json", + ".xlsx", + ".pptx", + ".png", + ".jpg", + ".jpeg", ], exclude_hidden=True, file_metadata=metadata_from_filename, ) reader.load_data() directory_structure = reader.directory_structure - logging.info(f"Directory structure built with token counts: {directory_structure}") + logging.info( + f"Directory structure built with token counts: {directory_structure}" + ) try: old_directory_structure = source.get("directory_structure") or {} @@ -433,11 +458,17 @@ def reingest_source_worker(self, source_id, user): files = set() if isinstance(struct, dict): for name, meta in struct.items(): - current_path = os.path.join(prefix, name) if prefix else name - if isinstance(meta, dict) and ("type" in meta and "size_bytes" in meta): + current_path = ( + os.path.join(prefix, name) if prefix else name + ) + if isinstance(meta, dict) and ( + "type" in meta and "size_bytes" in meta + ): files.add(current_path) elif isinstance(meta, dict): - files |= _flatten_directory_structure(meta, current_path) + files |= _flatten_directory_structure( + meta, current_path + ) return files old_files = _flatten_directory_structure(old_directory_structure) @@ -457,7 +488,9 @@ def reingest_source_worker(self, source_id, user): logging.info("No files removed since last ingest.") except Exception as e: - logging.error(f"Error comparing directory structures: {e}", exc_info=True) + logging.error( + f"Error comparing directory structures: {e}", exc_info=True + ) added_files = [] removed_files = [] try: @@ -477,14 +510,21 @@ def reingest_source_worker(self, source_id, user): settings.EMBEDDINGS_KEY, ) - self.update_state(state="PROGRESS", meta={"current": 40, "status": "Processing file changes"}) + self.update_state( + state="PROGRESS", + meta={"current": 40, "status": "Processing file changes"}, + ) # 1) Delete chunks from removed files deleted = 0 if removed_files: try: for ch in vector_store.get_chunks() or []: - metadata = ch.get("metadata", {}) if isinstance(ch, dict) else getattr(ch, "metadata", {}) + metadata = ( + ch.get("metadata", {}) + if isinstance(ch, dict) + else getattr(ch, "metadata", {}) + ) raw_source = metadata.get("source") source_file = str(raw_source) if raw_source else "" @@ -496,10 +536,17 @@ def reingest_source_worker(self, source_id, user): vector_store.delete_chunk(cid) deleted += 1 except Exception as de: - logging.error(f"Failed deleting chunk {cid}: {de}") - logging.info(f"Deleted {deleted} chunks from {len(removed_files)} removed files") + logging.error( + f"Failed deleting chunk {cid}: {de}" + ) + logging.info( + f"Deleted {deleted} chunks from {len(removed_files)} removed files" + ) except Exception as e: - logging.error(f"Error during deletion of removed file chunks: {e}", exc_info=True) + logging.error( + f"Error during deletion of removed file chunks: {e}", + exc_info=True, + ) # 2) Add chunks from new files added = 0 @@ -528,58 +575,86 @@ def reingest_source_worker(self, source_id, user): ) chunked_new = chunker_new.chunk(documents=raw_docs_new) - for file_path, token_count in reader_new.file_token_counts.items(): + for ( + file_path, + token_count, + ) in reader_new.file_token_counts.items(): try: - rel_path = os.path.relpath(file_path, start=temp_dir) + rel_path = os.path.relpath( + file_path, start=temp_dir + ) path_parts = rel_path.split(os.sep) current_dir = directory_structure for part in path_parts[:-1]: - if part in current_dir and isinstance(current_dir[part], dict): + if part in current_dir and isinstance( + current_dir[part], dict + ): current_dir = current_dir[part] else: break filename = path_parts[-1] - if filename in current_dir and isinstance(current_dir[filename], dict): - current_dir[filename]["token_count"] = token_count - logging.info(f"Updated token count for {rel_path}: {token_count}") + if filename in current_dir and isinstance( + current_dir[filename], dict + ): + current_dir[filename][ + "token_count" + ] = token_count + logging.info( + f"Updated token count for {rel_path}: {token_count}" + ) except Exception as e: - logging.warning(f"Could not update token count for {file_path}: {e}") + logging.warning( + f"Could not update token count for {file_path}: {e}" + ) for d in chunked_new: meta = dict(d.extra_info or {}) try: raw_src = meta.get("source") - if isinstance(raw_src, str) and os.path.isabs(raw_src): - meta["source"] = os.path.relpath(raw_src, start=temp_dir) + if isinstance(raw_src, str) and os.path.isabs( + raw_src + ): + meta["source"] = os.path.relpath( + raw_src, start=temp_dir + ) except Exception: pass vector_store.add_chunk(d.text, metadata=meta) added += 1 - logging.info(f"Added {added} chunks from {len(added_files)} new files") + logging.info( + f"Added {added} chunks from {len(added_files)} new files" + ) except Exception as e: - logging.error(f"Error during ingestion of new files: {e}", exc_info=True) + logging.error( + f"Error during ingestion of new files: {e}", exc_info=True + ) # 3) Update source directory structure timestamp try: total_tokens = sum(reader.file_token_counts.values()) - + sources_collection.update_one( {"_id": ObjectId(source_id)}, { "$set": { "directory_structure": directory_structure, "date": datetime.datetime.now(), - "tokens": total_tokens + "tokens": total_tokens, } }, ) except Exception as e: - logging.error(f"Error updating directory_structure in DB: {e}", exc_info=True) + logging.error( + f"Error updating directory_structure in DB: {e}", exc_info=True + ) - self.update_state(state="PROGRESS", meta={"current": 100, "status": "Re-ingestion completed"}) + self.update_state( + state="PROGRESS", + meta={"current": 100, "status": "Re-ingestion completed"}, + ) return { "source_id": source_id, @@ -591,15 +666,16 @@ def reingest_source_worker(self, source_id, user): "chunks_deleted": deleted, } except Exception as e: - logging.error(f"Error while processing file changes: {e}", exc_info=True) + logging.error( + f"Error while processing file changes: {e}", exc_info=True + ) raise - - except Exception as e: logging.error(f"Error in reingest_source_worker: {e}", exc_info=True) raise + def remote_worker( self, source_data, @@ -651,7 +727,7 @@ def remote_worker( "id": str(id), "type": loader, "remote_data": source_data, - "sync_frequency": sync_frequency + "sync_frequency": sync_frequency, } if operation_mode == "sync": @@ -712,7 +788,7 @@ def sync_worker(self, frequency): self, source_data, name, user, source_type, frequency, retriever, doc_id ) sync_counts["total_sync_count"] += 1 - sync_counts[ + sync_counts[ "sync_success" if resp["status"] == "success" else "sync_failure" ] += 1 return { @@ -749,15 +825,14 @@ def attachment_worker(self, file_info, user): input_files=[local_path], exclude_hidden=True, errors="ignore" ) .load_data()[0] - .text, + .text, ) - - + token_count = num_tokens_from_string(content) if token_count > 100000: content = content[:250000] token_count = num_tokens_from_string(content) - + self.update_state( state="PROGRESS", meta={"current": 80, "status": "Storing in database"} ) @@ -872,37 +947,49 @@ def ingest_connector( doc_id: Document ID for sync operations (required when operation_mode="sync") sync_frequency: How often to sync ("never", "daily", "weekly", "monthly") """ - logging.info(f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}") + logging.info( + f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}" + ) self.update_state(state="PROGRESS", meta={"current": 1}) - + with tempfile.TemporaryDirectory() as temp_dir: try: # Step 1: Initialize the appropriate loader - self.update_state(state="PROGRESS", meta={"current": 10, "status": "Initializing connector"}) + self.update_state( + state="PROGRESS", + meta={"current": 10, "status": "Initializing connector"}, + ) if not session_token: raise ValueError(f"{source_type} connector requires session_token") if not ConnectorCreator.is_supported(source_type): - raise ValueError(f"Unsupported connector type: {source_type}. Supported types: {ConnectorCreator.get_supported_connectors()}") + raise ValueError( + f"Unsupported connector type: {source_type}. Supported types: {ConnectorCreator.get_supported_connectors()}" + ) - remote_loader = ConnectorCreator.create_connector(source_type, session_token) + remote_loader = ConnectorCreator.create_connector( + source_type, session_token + ) # Create a clean config for storage api_source_config = { "file_ids": file_ids or [], "folder_ids": folder_ids or [], - "recursive": recursive + "recursive": recursive, } # Step 2: Download files to temp directory - self.update_state(state="PROGRESS", meta={"current": 20, "status": "Downloading files"}) - download_info = remote_loader.download_to_directory( - temp_dir, - api_source_config + self.update_state( + state="PROGRESS", meta={"current": 20, "status": "Downloading files"} ) - - if download_info.get("empty_result", False) or not download_info.get("files_downloaded", 0): + download_info = remote_loader.download_to_directory( + temp_dir, api_source_config + ) + + if download_info.get("empty_result", False) or not download_info.get( + "files_downloaded", 0 + ): logging.warning(f"No files were downloaded from {source_type}") # Create empty result directly instead of calling a separate method return { @@ -913,28 +1000,42 @@ def ingest_connector( "source_config": api_source_config, "directory_structure": "{}", } - + # Step 3: Use SimpleDirectoryReader to process downloaded files - self.update_state(state="PROGRESS", meta={"current": 40, "status": "Processing files"}) + self.update_state( + state="PROGRESS", meta={"current": 40, "status": "Processing files"} + ) reader = SimpleDirectoryReader( input_dir=temp_dir, recursive=True, required_exts=[ - ".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub", - ".html", ".mdx", ".json", ".xlsx", ".pptx", ".png", - ".jpg", ".jpeg", + ".rst", + ".md", + ".pdf", + ".txt", + ".docx", + ".csv", + ".epub", + ".html", + ".mdx", + ".json", + ".xlsx", + ".pptx", + ".png", + ".jpg", + ".jpeg", ], exclude_hidden=True, file_metadata=metadata_from_filename, ) raw_docs = reader.load_data() - directory_structure = getattr(reader, 'directory_structure', {}) + directory_structure = getattr(reader, "directory_structure", {}) - - # Step 4: Process documents (chunking, embedding, etc.) - self.update_state(state="PROGRESS", meta={"current": 60, "status": "Processing documents"}) - + self.update_state( + state="PROGRESS", meta={"current": 60, "status": "Processing documents"} + ) + chunker = Chunker( chunking_strategy="classic_chunk", max_tokens=MAX_TOKENS, @@ -942,22 +1043,26 @@ def ingest_connector( duplicate_headers=False, ) raw_docs = chunker.chunk(documents=raw_docs) - + # Preserve source information in document metadata for doc in raw_docs: - if hasattr(doc, 'extra_info') and doc.extra_info: - source = doc.extra_info.get('source') + if hasattr(doc, "extra_info") and doc.extra_info: + source = doc.extra_info.get("source") if source and os.path.isabs(source): # Convert absolute path to relative path - doc.extra_info['source'] = os.path.relpath(source, start=temp_dir) - + doc.extra_info["source"] = os.path.relpath( + source, start=temp_dir + ) + docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs] - + if operation_mode == "upload": id = ObjectId() elif operation_mode == "sync": if not doc_id or not ObjectId.is_valid(doc_id): - logging.error("Invalid doc_id provided for sync operation: %s", doc_id) + logging.error( + "Invalid doc_id provided for sync operation: %s", doc_id + ) raise ValueError("doc_id must be provided for sync operation.") id = ObjectId(doc_id) else: @@ -966,7 +1071,9 @@ def ingest_connector( vector_store_path = os.path.join(temp_dir, "vector_store") os.makedirs(vector_store_path, exist_ok=True) - self.update_state(state="PROGRESS", meta={"current": 80, "status": "Storing documents"}) + self.update_state( + state="PROGRESS", meta={"current": 80, "status": "Storing documents"} + ) embed_and_store_documents(docs, vector_store_path, id, self) tokens = count_tokens_docs(docs) @@ -979,12 +1086,11 @@ def ingest_connector( "retriever": retriever, "id": str(id), "type": "connector:file", - "remote_data": json.dumps({ - "provider": source_type, - **api_source_config - }), + "remote_data": json.dumps( + {"provider": source_type, **api_source_config} + ), "directory_structure": json.dumps(directory_structure), - "sync_frequency": sync_frequency + "sync_frequency": sync_frequency, } if operation_mode == "sync": @@ -995,7 +1101,9 @@ def ingest_connector( upload_index(vector_store_path, file_data) # Ensure we mark the task as complete - self.update_state(state="PROGRESS", meta={"current": 100, "status": "Complete"}) + self.update_state( + state="PROGRESS", meta={"current": 100, "status": "Complete"} + ) logging.info(f"Remote ingestion completed: {job_name}") @@ -1005,9 +1113,136 @@ def ingest_connector( "tokens": tokens, "type": source_type, "id": str(id), - "status": "complete" + "status": "complete", } - + except Exception as e: logging.error(f"Error during remote ingestion: {e}", exc_info=True) raise + + +def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, Any]: + """Worker to handle MCP OAuth flow asynchronously.""" + + logging.info( + "[MCP OAuth] Worker started for user_id=%s, config=%s", user_id, config + ) + try: + import asyncio + + from application.agents.tools.mcp_tool import MCPTool + + task_id = self.request.id + logging.info("[MCP OAuth] Task ID: %s", task_id) + redis_client = get_redis_instance() + + def update_status(status_data: Dict[str, Any]): + logging.info("[MCP OAuth] Updating status: %s", status_data) + status_key = f"mcp_oauth_status:{task_id}" + redis_client.setex(status_key, 600, json.dumps(status_data)) + + update_status( + { + "status": "in_progress", + "message": "Starting OAuth flow...", + "task_id": task_id, + } + ) + + tool_config = config.copy() + tool_config["oauth_task_id"] = task_id + logging.info("[MCP OAuth] Initializing MCPTool with config: %s", tool_config) + mcp_tool = MCPTool(tool_config, user_id) + + async def run_oauth_discovery(): + if not mcp_tool._client: + mcp_tool._setup_client() + return await mcp_tool._execute_with_client("list_tools") + + update_status( + { + "status": "awaiting_redirect", + "message": "Waiting for OAuth redirect...", + "task_id": task_id, + } + ) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + logging.info("[MCP OAuth] Starting event loop for OAuth discovery...") + tools_response = loop.run_until_complete(run_oauth_discovery()) + logging.info( + "[MCP OAuth] Tools response after async call: %s", tools_response + ) + + status_key = f"mcp_oauth_status:{task_id}" + redis_status = redis_client.get(status_key) + if redis_status: + logging.info( + "[MCP OAuth] Redis status after async call: %s", redis_status + ) + else: + logging.warning( + "[MCP OAuth] No Redis status found after async call for key: %s", + status_key, + ) + tools = mcp_tool.get_actions_metadata() + + update_status( + { + "status": "completed", + "message": f"OAuth completed successfully. Found {len(tools)} tools.", + "tools": tools, + "tools_count": len(tools), + "task_id": task_id, + } + ) + + logging.info( + "[MCP OAuth] OAuth flow completed successfully for task_id=%s", task_id + ) + return {"success": True, "tools": tools, "tools_count": len(tools)} + except Exception as e: + error_msg = f"OAuth flow failed: {str(e)}" + logging.error( + "[MCP OAuth] Exception in OAuth discovery: %s", error_msg, exc_info=True + ) + update_status( + { + "status": "error", + "message": error_msg, + "error": str(e), + "task_id": task_id, + } + ) + return {"success": False, "error": error_msg} + finally: + logging.info("[MCP OAuth] Closing event loop for task_id=%s", task_id) + loop.close() + except Exception as e: + error_msg = f"Failed to initialize OAuth flow: {str(e)}" + logging.error( + "[MCP OAuth] Exception during initialization: %s", error_msg, exc_info=True + ) + update_status( + { + "status": "error", + "message": error_msg, + "error": str(e), + "task_id": task_id, + } + ) + return {"success": False, "error": error_msg} + + +def mcp_oauth_status(self, task_id: str) -> Dict[str, Any]: + """Check the status of an MCP OAuth flow.""" + redis_client = get_redis_instance() + status_key = f"mcp_oauth_status:{task_id}" + + status_data = redis_client.get(status_key) + if status_data: + return json.loads(status_data) + return {"status": "not_found", "message": "Status not found"} diff --git a/frontend/src/api/endpoints.ts b/frontend/src/api/endpoints.ts index dad008da..d2fb1518 100644 --- a/frontend/src/api/endpoints.ts +++ b/frontend/src/api/endpoints.ts @@ -59,6 +59,8 @@ const endpoints = { MANAGE_SOURCE_FILES: '/api/manage_source_files', MCP_TEST_CONNECTION: '/api/mcp_server/test', MCP_SAVE_SERVER: '/api/mcp_server/save', + MCP_OAUTH_STATUS: (task_id: string) => + `/api/mcp_server/oauth_status/${task_id}`, }, CONVERSATION: { ANSWER: '/api/answer', diff --git a/frontend/src/api/services/userService.ts b/frontend/src/api/services/userService.ts index 5dda8ddf..4e31317d 100644 --- a/frontend/src/api/services/userService.ts +++ b/frontend/src/api/services/userService.ts @@ -1,6 +1,6 @@ +import { getSessionToken } from '../../utils/providerUtils'; import apiClient from '../client'; import endpoints from '../endpoints'; -import { getSessionToken } from '../../utils/providerUtils'; const userService = { getConfig: (): Promise => apiClient.get(endpoints.USER.CONFIG, null), @@ -112,6 +112,8 @@ const userService = { apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token), saveMCPServer: (data: any, token: string | null): Promise => apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token), + getMCPOAuthStatus: (task_id: string, token: string | null): Promise => + apiClient.get(endpoints.USER.MCP_OAUTH_STATUS(task_id), token), syncConnector: ( docId: string, provider: string, diff --git a/frontend/src/locale/en.json b/frontend/src/locale/en.json index 23e9296f..35419534 100644 --- a/frontend/src/locale/en.json +++ b/frontend/src/locale/en.json @@ -194,17 +194,20 @@ "headerName": "Header Name", "timeout": "Timeout (seconds)", "testConnection": "Test Connection", - "testing": "Testing...", - "saving": "Saving...", + "testing": "Testing", + "saving": "Saving", "save": "Save", "cancel": "Cancel", "noAuth": "No Authentication", + "oauthInProgress": "Waiting for OAuth completion...", + "oauthCompleted": "OAuth completed successfully", "placeholders": { "serverUrl": "https://api.example.com", "apiKey": "Your secret API key", "bearerToken": "Your secret token", "username": "Your username", - "password": "Your password" + "password": "Your password", + "oauthScopes": "OAuth scopes (comma separated)" }, "errors": { "nameRequired": "Server name is required", @@ -215,7 +218,9 @@ "usernameRequired": "Username is required", "passwordRequired": "Password is required", "testFailed": "Connection test failed", - "saveFailed": "Failed to save MCP server" + "saveFailed": "Failed to save MCP server", + "oauthFailed": "OAuth process failed or was cancelled", + "oauthTimeout": "OAuth process timed out, please try again" } } } diff --git a/frontend/src/modals/MCPServerModal.tsx b/frontend/src/modals/MCPServerModal.tsx index 5e916210..9430de88 100644 --- a/frontend/src/modals/MCPServerModal.tsx +++ b/frontend/src/modals/MCPServerModal.tsx @@ -22,6 +22,7 @@ const authTypes = [ { label: 'No Authentication', value: 'none' }, { label: 'API Key', value: 'api_key' }, { label: 'Bearer Token', value: 'bearer' }, + { label: 'OAuth', value: 'oauth' }, // { label: 'Basic Authentication', value: 'basic' }, ]; @@ -45,6 +46,8 @@ export default function MCPServerModal({ username: '', password: '', timeout: server?.timeout || 30, + oauth_scopes: '', + oauth_task_id: '', }); const [loading, setLoading] = useState(false); @@ -52,8 +55,13 @@ export default function MCPServerModal({ const [testResult, setTestResult] = useState<{ success: boolean; message: string; + status?: string; + authorization_url?: string; } | null>(null); const [errors, setErrors] = useState<{ [key: string]: string }>({}); + const oauthPopupRef = useRef(null); + const [oauthCompleted, setOAuthCompleted] = useState(false); + const [saveActive, setSaveActive] = useState(false); useOutsideAlerter(modalRef, () => { if (modalState === 'ACTIVE') { @@ -73,9 +81,12 @@ export default function MCPServerModal({ username: '', password: '', timeout: 30, + oauth_scopes: '', + oauth_task_id: '', }); setErrors({}); setTestResult(null); + setSaveActive(false); }; const validateForm = () => { @@ -154,10 +165,81 @@ export default function MCPServerModal({ } else if (formData.auth_type === 'basic') { config.username = formData.username.trim(); config.password = formData.password.trim(); + } else if (formData.auth_type === 'oauth') { + config.oauth_scopes = formData.oauth_scopes + .split(',') + .map((s) => s.trim()) + .filter(Boolean); + config.oauth_task_id = formData.oauth_task_id.trim(); } return config; }; + const pollOAuthStatus = async ( + taskId: string, + onComplete: (result: any) => void, + ) => { + let attempts = 0; + const maxAttempts = 60; + let popupOpened = false; + const poll = async () => { + try { + const resp = await userService.getMCPOAuthStatus(taskId, token); + const data = await resp.json(); + if (data.authorization_url && !popupOpened) { + if (oauthPopupRef.current && !oauthPopupRef.current.closed) { + oauthPopupRef.current.close(); + } + oauthPopupRef.current = window.open( + data.authorization_url, + 'oauthPopup', + 'width=600,height=700', + ); + popupOpened = true; + } + if (data.status === 'completed') { + setOAuthCompleted(true); + setSaveActive(true); + onComplete({ + ...data, + success: true, + message: t('settings.tools.mcp.oauthCompleted'), + }); + if (oauthPopupRef.current && !oauthPopupRef.current.closed) { + oauthPopupRef.current.close(); + } + } else if (data.status === 'error' || data.success === false) { + setSaveActive(false); + onComplete({ + ...data, + success: false, + message: t('settings.tools.mcp.errors.oauthFailed'), + }); + if (oauthPopupRef.current && !oauthPopupRef.current.closed) { + oauthPopupRef.current.close(); + } + } else { + if (++attempts < maxAttempts) setTimeout(poll, 1000); + else { + setSaveActive(false); + onComplete({ + success: false, + message: t('settings.tools.mcp.errors.oauthTimeout'), + }); + } + } + } catch { + if (++attempts < maxAttempts) setTimeout(poll, 1000); + else + onComplete({ + success: false, + message: t('settings.tools.mcp.errors.oauthTimeout'), + }); + } + }; + poll(); + }; + const testConnection = async () => { if (!validateForm()) return; setTesting(true); @@ -167,13 +249,37 @@ export default function MCPServerModal({ const response = await userService.testMCPConnection({ config }, token); const result = await response.json(); - setTestResult(result); + if ( + formData.auth_type === 'oauth' && + result.requires_oauth && + result.task_id + ) { + setTestResult({ + success: true, + message: t('settings.tools.mcp.oauthInProgress'), + }); + setOAuthCompleted(false); + setSaveActive(false); + pollOAuthStatus(result.task_id, (finalResult) => { + setTestResult(finalResult); + setFormData((prev) => ({ + ...prev, + oauth_task_id: result.task_id || '', + })); + setTesting(false); + }); + } else { + setTestResult(result); + setSaveActive(result.success === true); + setTesting(false); + } } catch (error) { setTestResult({ success: false, message: t('settings.tools.mcp.errors.testFailed'), }); - } finally { + setOAuthCompleted(false); + setSaveActive(false); setTesting(false); } }; @@ -305,6 +411,28 @@ export default function MCPServerModal({ ); + case 'oauth': + return ( +
+
+ + handleInputChange('oauth_scopes', e.target.value) + } + placeholder={ + t('settings.tools.mcp.placeholders.oauthScopes') || + 'Scopes (comma separated)' + } + borderVariant="thin" + labelBgClassName="bg-white dark:bg-charleston-green-2" + /> +
+
+ ); default: return null; } @@ -331,7 +459,6 @@ export default function MCPServerModal({