import asyncio import base64 import json import logging import time from typing import Any, Dict, List, Optional from urllib.parse import parse_qs, urlparse from application.agents.tools.base import Tool from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task from application.cache import get_redis_instance from application.core.mongo_db import MongoDB from application.core.settings import settings from application.security.encryption import decrypt_credentials from fastmcp import Client from fastmcp.client.auth import BearerAuth from fastmcp.client.transports import ( SSETransport, StdioTransport, StreamableHttpTransport, ) from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken from pydantic import AnyHttpUrl, ValidationError from redis import Redis mongo = MongoDB.get_client() db = mongo[settings.MONGO_DB_NAME] _mcp_clients_cache = {} class MCPTool(Tool): """ MCP Tool Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources. """ def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None): """ Initialize the MCP Tool with configuration. Args: config: Dictionary containing MCP server configuration: - server_url: URL of the remote MCP server - transport_type: Transport type (auto, sse, http, stdio) - auth_type: Type of authentication (bearer, oauth, api_key, basic, none) - encrypted_credentials: Encrypted credentials (if available) - timeout: Request timeout in seconds (default: 30) - headers: Custom headers for requests - command: Command for STDIO transport - args: Arguments for STDIO transport - oauth_scopes: OAuth scopes for oauth auth type - oauth_client_name: OAuth client name for oauth auth type user_id: User ID for decrypting credentials (required if encrypted_credentials exist) """ self.config = config self.user_id = user_id self.server_url = config.get("server_url", "") self.transport_type = config.get("transport_type", "auto") self.auth_type = config.get("auth_type", "none") self.timeout = config.get("timeout", 30) self.custom_headers = config.get("headers", {}) self.auth_credentials = {} if config.get("encrypted_credentials") and user_id: self.auth_credentials = decrypt_credentials( config["encrypted_credentials"], user_id ) else: self.auth_credentials = config.get("auth_credentials", {}) self.oauth_scopes = config.get("oauth_scopes", []) self.oauth_task_id = config.get("oauth_task_id", None) self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP") self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback" self.available_tools = [] self._cache_key = self._generate_cache_key() self._client = None # Only validate and setup if server_url is provided and not OAuth if self.server_url and self.auth_type != "oauth": self._setup_client() def _generate_cache_key(self) -> str: """Generate a unique cache key for this MCP server configuration.""" auth_key = "" if self.auth_type == "oauth": scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none" auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}" elif self.auth_type in ["bearer"]: token = self.auth_credentials.get( "bearer_token", "" ) or self.auth_credentials.get("access_token", "") auth_key = f"bearer:{token[:10]}..." if token else "bearer:none" elif self.auth_type == "api_key": api_key = self.auth_credentials.get("api_key", "") auth_key = f"apikey:{api_key[:10]}..." if api_key else "apikey:none" elif self.auth_type == "basic": username = self.auth_credentials.get("username", "") auth_key = f"basic:{username}" else: auth_key = "none" return f"{self.server_url}#{self.transport_type}#{auth_key}" def _setup_client(self): """Setup FastMCP client with proper transport and authentication.""" global _mcp_clients_cache if self._cache_key in _mcp_clients_cache: cached_data = _mcp_clients_cache[self._cache_key] if time.time() - cached_data["created_at"] < 1800: self._client = cached_data["client"] return else: del _mcp_clients_cache[self._cache_key] transport = self._create_transport() auth = None if self.auth_type == "oauth": redis_client = get_redis_instance() auth = DocsGPTOAuth( mcp_url=self.server_url, scopes=self.oauth_scopes, redis_client=redis_client, redirect_uri=self.redirect_uri, task_id=self.oauth_task_id, db=db, user_id=self.user_id, ) elif self.auth_type == "bearer": token = self.auth_credentials.get( "bearer_token", "" ) or self.auth_credentials.get("access_token", "") if token: auth = BearerAuth(token) self._client = Client(transport, auth=auth) _mcp_clients_cache[self._cache_key] = { "client": self._client, "created_at": time.time(), } def _create_transport(self): """Create appropriate transport based on configuration.""" headers = {"Content-Type": "application/json", "User-Agent": "DocsGPT-MCP/1.0"} headers.update(self.custom_headers) if self.auth_type == "api_key": api_key = self.auth_credentials.get("api_key", "") header_name = self.auth_credentials.get("api_key_header", "X-API-Key") if api_key: headers[header_name] = api_key elif self.auth_type == "basic": username = self.auth_credentials.get("username", "") password = self.auth_credentials.get("password", "") if username and password: credentials = base64.b64encode( f"{username}:{password}".encode() ).decode() headers["Authorization"] = f"Basic {credentials}" if self.transport_type == "auto": if "sse" in self.server_url.lower() or self.server_url.endswith("/sse"): transport_type = "sse" else: transport_type = "http" else: transport_type = self.transport_type if transport_type == "sse": headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"}) return SSETransport(url=self.server_url, headers=headers) elif transport_type == "http": return StreamableHttpTransport(url=self.server_url, headers=headers) elif transport_type == "stdio": command = self.config.get("command", "python") args = self.config.get("args", []) env = self.auth_credentials if self.auth_credentials else None return StdioTransport(command=command, args=args, env=env) else: return StreamableHttpTransport(url=self.server_url, headers=headers) def _format_tools(self, tools_response) -> List[Dict]: """Format tools response to match expected format.""" if hasattr(tools_response, "tools"): tools = tools_response.tools elif isinstance(tools_response, list): tools = tools_response else: tools = [] tools_dict = [] for tool in tools: if hasattr(tool, "name"): tool_dict = { "name": tool.name, "description": tool.description, } if hasattr(tool, "inputSchema"): tool_dict["inputSchema"] = tool.inputSchema tools_dict.append(tool_dict) elif isinstance(tool, dict): tools_dict.append(tool) else: if hasattr(tool, "model_dump"): tools_dict.append(tool.model_dump()) else: tools_dict.append({"name": str(tool), "description": ""}) return tools_dict async def _execute_with_client(self, operation: str, *args, **kwargs): """Execute operation with FastMCP client.""" if not self._client: raise Exception("FastMCP client not initialized") async with self._client: if operation == "ping": return await self._client.ping() elif operation == "list_tools": tools_response = await self._client.list_tools() self.available_tools = self._format_tools(tools_response) return self.available_tools elif operation == "call_tool": tool_name = args[0] tool_args = kwargs return await self._client.call_tool(tool_name, tool_args) elif operation == "list_resources": return await self._client.list_resources() elif operation == "list_prompts": return await self._client.list_prompts() else: raise Exception(f"Unknown operation: {operation}") def _run_async_operation(self, operation: str, *args, **kwargs): """Run async operation in sync context.""" try: try: loop = asyncio.get_running_loop() import concurrent.futures def run_in_thread(): new_loop = asyncio.new_event_loop() asyncio.set_event_loop(new_loop) try: return new_loop.run_until_complete( self._execute_with_client(operation, *args, **kwargs) ) finally: new_loop.close() with concurrent.futures.ThreadPoolExecutor() as executor: future = executor.submit(run_in_thread) return future.result(timeout=self.timeout) except RuntimeError: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: return loop.run_until_complete( self._execute_with_client(operation, *args, **kwargs) ) finally: loop.close() except Exception as e: print(f"Error occurred while running async operation: {e}") raise def discover_tools(self) -> List[Dict]: """ Discover available tools from the MCP server using FastMCP. Returns: List of tool definitions from the server """ if not self.server_url: return [] if not self._client: self._setup_client() try: tools = self._run_async_operation("list_tools") self.available_tools = tools return self.available_tools except Exception as e: raise Exception(f"Failed to discover tools from MCP server: {str(e)}") def execute_action(self, action_name: str, **kwargs) -> Any: """ Execute an action on the remote MCP server using FastMCP. Args: action_name: Name of the action to execute **kwargs: Parameters for the action Returns: Result from the MCP server """ if not self.server_url: raise Exception("No MCP server configured") if not self._client: self._setup_client() cleaned_kwargs = {} for key, value in kwargs.items(): if value == "" or value is None: continue cleaned_kwargs[key] = value try: result = self._run_async_operation( "call_tool", action_name, **cleaned_kwargs ) return self._format_result(result) except Exception as e: raise Exception(f"Failed to execute action '{action_name}': {str(e)}") def _format_result(self, result) -> Dict: """Format FastMCP result to match expected format.""" if hasattr(result, "content"): content_list = [] for content_item in result.content: if hasattr(content_item, "text"): content_list.append({"type": "text", "text": content_item.text}) elif hasattr(content_item, "data"): content_list.append({"type": "data", "data": content_item.data}) else: content_list.append( {"type": "unknown", "content": str(content_item)} ) return { "content": content_list, "isError": getattr(result, "isError", False), } else: return result def test_connection(self) -> Dict: """ Test the connection to the MCP server and validate functionality. Returns: Dictionary with connection test results including tool count """ if not self.server_url: return { "success": False, "message": "No MCP server URL configured", "tools_count": 0, "transport_type": self.transport_type, "auth_type": self.auth_type, "error_type": "ConfigurationError", } if not self._client: self._setup_client() try: if self.auth_type == "oauth": return self._test_oauth_connection() else: return self._test_regular_connection() except Exception as e: return { "success": False, "message": f"Connection failed: {str(e)}", "tools_count": 0, "transport_type": self.transport_type, "auth_type": self.auth_type, "error_type": type(e).__name__, } def _test_regular_connection(self) -> Dict: """Test connection for non-OAuth auth types.""" try: self._run_async_operation("ping") ping_success = True except Exception: ping_success = False tools = self.discover_tools() message = f"Successfully connected to MCP server. Found {len(tools)} tools." if not ping_success: message += " (Ping not supported, but tool discovery worked)" return { "success": True, "message": message, "tools_count": len(tools), "transport_type": self.transport_type, "auth_type": self.auth_type, "ping_supported": ping_success, "tools": [tool.get("name", "unknown") for tool in tools], } def _test_oauth_connection(self) -> Dict: """Test connection for OAuth auth type with proper async handling.""" try: task = mcp_oauth_task.delay(config=self.config, user=self.user_id) if not task: raise Exception("Failed to start OAuth authentication") return { "success": True, "requires_oauth": True, "task_id": task.id, "status": "pending", "message": "OAuth flow started", } except Exception as e: return { "success": False, "message": f"OAuth connection failed: {str(e)}", "tools_count": 0, "transport_type": self.transport_type, "auth_type": self.auth_type, "error_type": type(e).__name__, } def get_actions_metadata(self) -> List[Dict]: """ Get metadata for all available actions. Returns: List of action metadata dictionaries """ actions = [] for tool in self.available_tools: input_schema = ( tool.get("inputSchema") or tool.get("input_schema") or tool.get("schema") or tool.get("parameters") ) parameters_schema = { "type": "object", "properties": {}, "required": [], } if input_schema: if isinstance(input_schema, dict): if "properties" in input_schema: parameters_schema = { "type": input_schema.get("type", "object"), "properties": input_schema.get("properties", {}), "required": input_schema.get("required", []), } for key in ["additionalProperties", "description"]: if key in input_schema: parameters_schema[key] = input_schema[key] else: parameters_schema["properties"] = input_schema action = { "name": tool.get("name", ""), "description": tool.get("description", ""), "parameters": parameters_schema, } actions.append(action) return actions def get_config_requirements(self) -> Dict: """Get configuration requirements for the MCP tool.""" return { "server_url": { "type": "string", "description": "URL of the remote MCP server (e.g., https://api.example.com/mcp or https://docs.mcp.cloudflare.com/sse)", "required": True, }, "transport_type": { "type": "string", "description": "Transport type for connection", "enum": ["auto", "sse", "http", "stdio"], "default": "auto", "required": False, "help": { "auto": "Automatically detect best transport", "sse": "Server-Sent Events (for real-time streaming)", "http": "HTTP streaming (recommended for production)", "stdio": "Standard I/O (for local servers)", }, }, "auth_type": { "type": "string", "description": "Authentication type", "enum": ["none", "bearer", "oauth", "api_key", "basic"], "default": "none", "required": True, "help": { "none": "No authentication", "bearer": "Bearer token authentication", "oauth": "OAuth 2.1 authentication (with frontend integration)", "api_key": "API key authentication", "basic": "Basic authentication", }, }, "auth_credentials": { "type": "object", "description": "Authentication credentials (varies by auth_type)", "required": False, "properties": { "bearer_token": { "type": "string", "description": "Bearer token for bearer auth", }, "access_token": { "type": "string", "description": "Access token for OAuth (if pre-obtained)", }, "api_key": { "type": "string", "description": "API key for api_key auth", }, "api_key_header": { "type": "string", "description": "Header name for API key (default: X-API-Key)", }, "username": { "type": "string", "description": "Username for basic auth", }, "password": { "type": "string", "description": "Password for basic auth", }, }, }, "oauth_scopes": { "type": "array", "description": "OAuth scopes to request (for oauth auth_type)", "items": {"type": "string"}, "required": False, "default": [], }, "oauth_client_name": { "type": "string", "description": "Client name for OAuth registration (for oauth auth_type)", "default": "DocsGPT-MCP", "required": False, }, "headers": { "type": "object", "description": "Custom headers to send with requests", "required": False, }, "timeout": { "type": "integer", "description": "Request timeout in seconds", "default": 30, "minimum": 1, "maximum": 300, "required": False, }, "command": { "type": "string", "description": "Command to run for STDIO transport (e.g., 'python')", "required": False, }, "args": { "type": "array", "description": "Arguments for STDIO command", "items": {"type": "string"}, "required": False, }, } class DocsGPTOAuth(OAuthClientProvider): """ Custom OAuth handler for DocsGPT that uses frontend redirect instead of browser. """ def __init__( self, mcp_url: str, redirect_uri: str, redis_client: Redis | None = None, redis_prefix: str = "mcp_oauth:", task_id: str = None, scopes: str | list[str] | None = None, client_name: str = "DocsGPT-MCP", user_id=None, db=None, additional_client_metadata: dict[str, Any] | None = None, ): """ Initialize custom OAuth client provider for DocsGPT. Args: mcp_url: Full URL to the MCP endpoint redirect_uri: Custom redirect URI for DocsGPT frontend redis_client: Redis client for storing auth state redis_prefix: Prefix for Redis keys task_id: Task ID for tracking auth status scopes: OAuth scopes to request client_name: Name for this client during registration user_id: User ID for token storage db: Database instance for token storage additional_client_metadata: Extra fields for OAuthClientMetadata """ self.redirect_uri = redirect_uri self.redis_client = redis_client self.redis_prefix = redis_prefix self.task_id = task_id self.user_id = user_id self.db = db parsed_url = urlparse(mcp_url) self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" if isinstance(scopes, list): scopes = " ".join(scopes) client_metadata = OAuthClientMetadata( client_name=client_name, redirect_uris=[AnyHttpUrl(redirect_uri)], grant_types=["authorization_code", "refresh_token"], response_types=["code"], scope=scopes, **(additional_client_metadata or {}), ) storage = DBTokenStorage( server_url=self.server_base_url, user_id=self.user_id, db_client=self.db ) super().__init__( server_url=self.server_base_url, client_metadata=client_metadata, storage=storage, redirect_handler=self.redirect_handler, callback_handler=self.callback_handler, ) self.auth_url = None self.extracted_state = None def _process_auth_url(self, authorization_url: str) -> tuple[str, str]: """Process authorization URL to extract state""" try: parsed_url = urlparse(authorization_url) query_params = parse_qs(parsed_url.query) state_params = query_params.get("state", []) if state_params: state = state_params[0] else: raise ValueError("No state in auth URL") return authorization_url, state except Exception as e: raise Exception(f"Failed to process auth URL: {e}") async def redirect_handler(self, authorization_url: str) -> None: """Store auth URL and state in Redis for frontend to use.""" auth_url, state = self._process_auth_url(authorization_url) logging.info( "[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state ) self.auth_url = auth_url self.extracted_state = state if self.redis_client and self.extracted_state: key = f"{self.redis_prefix}auth_url:{self.extracted_state}" self.redis_client.setex(key, 600, auth_url) logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key) if self.task_id: status_key = f"mcp_oauth_status:{self.task_id}" status_data = { "status": "requires_redirect", "message": "OAuth authorization required", "authorization_url": self.auth_url, "state": self.extracted_state, "requires_oauth": True, "task_id": self.task_id, } self.redis_client.setex(status_key, 600, json.dumps(status_data)) async def callback_handler(self) -> tuple[str, str | None]: """Wait for auth code from Redis using the state value.""" if not self.redis_client or not self.extracted_state: raise Exception("Redis client or state not configured for OAuth") poll_interval = 1 max_wait_time = 300 code_key = f"{self.redis_prefix}code:{self.extracted_state}" if self.task_id: status_key = f"mcp_oauth_status:{self.task_id}" status_data = { "status": "awaiting_callback", "message": "Waiting for OAuth callback...", "authorization_url": self.auth_url, "state": self.extracted_state, "requires_oauth": True, "task_id": self.task_id, } self.redis_client.setex(status_key, 600, json.dumps(status_data)) start_time = time.time() while time.time() - start_time < max_wait_time: code_data = self.redis_client.get(code_key) if code_data: code = code_data.decode() returned_state = self.extracted_state self.redis_client.delete(code_key) self.redis_client.delete( f"{self.redis_prefix}auth_url:{self.extracted_state}" ) self.redis_client.delete( f"{self.redis_prefix}state:{self.extracted_state}" ) if self.task_id: status_data = { "status": "callback_received", "message": "OAuth callback received, completing authentication...", "task_id": self.task_id, } self.redis_client.setex(status_key, 600, json.dumps(status_data)) return code, returned_state error_key = f"{self.redis_prefix}error:{self.extracted_state}" error_data = self.redis_client.get(error_key) if error_data: error_msg = error_data.decode() self.redis_client.delete(error_key) self.redis_client.delete( f"{self.redis_prefix}auth_url:{self.extracted_state}" ) self.redis_client.delete( f"{self.redis_prefix}state:{self.extracted_state}" ) raise Exception(f"OAuth error: {error_msg}") await asyncio.sleep(poll_interval) self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}") self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}") raise Exception("OAuth callback timeout: no code received within 5 minutes") class DBTokenStorage(TokenStorage): def __init__(self, server_url: str, user_id: str, db_client): self.server_url = server_url self.user_id = user_id self.db_client = db_client self.collection = db_client["connector_sessions"] @staticmethod def get_base_url(url: str) -> str: parsed = urlparse(url) return f"{parsed.scheme}://{parsed.netloc}" def get_db_key(self) -> dict: return { "server_url": self.get_base_url(self.server_url), "user_id": self.user_id, } async def get_tokens(self) -> OAuthToken | None: doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key()) if not doc or "tokens" not in doc: return None try: tokens = OAuthToken.model_validate(doc["tokens"]) return tokens except ValidationError as e: logging.error(f"Could not load tokens: {e}") return None async def set_tokens(self, tokens: OAuthToken) -> None: await asyncio.to_thread( self.collection.update_one, self.get_db_key(), {"$set": {"tokens": tokens.model_dump()}}, True, ) logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}") async def get_client_info(self) -> OAuthClientInformationFull | None: doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key()) if not doc or "client_info" not in doc: return None try: client_info = OAuthClientInformationFull.model_validate(doc["client_info"]) tokens = await self.get_tokens() if tokens is None: logging.debug( "No tokens found, clearing client info to force fresh registration." ) await asyncio.to_thread( self.collection.update_one, self.get_db_key(), {"$unset": {"client_info": ""}}, ) return None return client_info except ValidationError as e: logging.error(f"Could not load client info: {e}") return None def _serialize_client_info(self, info: dict) -> dict: if "redirect_uris" in info and isinstance(info["redirect_uris"], list): info["redirect_uris"] = [str(u) for u in info["redirect_uris"]] return info async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: serialized_info = self._serialize_client_info(client_info.model_dump()) await asyncio.to_thread( self.collection.update_one, self.get_db_key(), {"$set": {"client_info": serialized_info}}, True, ) logging.info(f"Saved client info for {self.get_base_url(self.server_url)}") async def clear(self) -> None: await asyncio.to_thread(self.collection.delete_one, self.get_db_key()) logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}") @classmethod async def clear_all(cls, db_client) -> None: collection = db_client["connector_sessions"] await asyncio.to_thread(collection.delete_many, {}) logging.info("Cleared all OAuth client cache data.") class MCPOAuthManager: """Manager for handling MCP OAuth callbacks.""" def __init__(self, redis_client: Redis | None, redis_prefix: str = "mcp_oauth:"): self.redis_client = redis_client self.redis_prefix = redis_prefix def handle_oauth_callback( self, state: str, code: str, error: Optional[str] = None ) -> bool: """ Handle OAuth callback from provider. Args: state: The state parameter from OAuth callback code: The authorization code from OAuth callback error: Error message if OAuth failed Returns: True if successful, False otherwise """ try: if not self.redis_client or not state: raise Exception("Redis client or state not provided") if error: error_key = f"{self.redis_prefix}error:{state}" self.redis_client.setex(error_key, 300, error) raise Exception(f"OAuth error received: {error}") code_key = f"{self.redis_prefix}code:{state}" self.redis_client.setex(code_key, 300, code) state_key = f"{self.redis_prefix}state:{state}" self.redis_client.setex(state_key, 300, "completed") return True except Exception as e: logging.error(f"Error handling OAuth callback: {e}") return False def get_oauth_status(self, task_id: str) -> Dict[str, Any]: """Get current status of OAuth flow using provided task_id.""" if not task_id: return {"status": "not_started", "message": "OAuth flow not started"} return mcp_oauth_status_task(task_id)