From 00b4e133d432b57e7eaee6f06f807f06703b9cf2 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Mon, 22 Sep 2025 01:31:09 +0530 Subject: [PATCH] feat: implement OAuth 2.1 integration with custom handlers for fastmcp --- application/agents/tools/mcp_tool.py | 280 ++++++++++++++++++++++----- 1 file changed, 236 insertions(+), 44 deletions(-) diff --git a/application/agents/tools/mcp_tool.py b/application/agents/tools/mcp_tool.py index 7cb32633..fd0767ec 100644 --- a/application/agents/tools/mcp_tool.py +++ b/application/agents/tools/mcp_tool.py @@ -8,6 +8,7 @@ from application.security.encryption import decrypt_credentials from fastmcp import Client from fastmcp.client.auth import BearerAuth +from fastmcp.client.auth.oauth import OAuth as FastMCPOAuth from fastmcp.client.transports import ( SSETransport, StdioTransport, @@ -17,6 +18,30 @@ from fastmcp.client.transports import ( _mcp_clients_cache = {} +class DocsGPTOAuth(FastMCPOAuth): + """Custom OAuth handler that integrates with DocsGPT frontend instead of opening browser.""" + + def __init__(self, *args, **kwargs): + self.auth_url_callback = kwargs.pop("auth_url_callback", None) + self.auth_code_callback = kwargs.pop("auth_code_callback", None) + super().__init__(*args, **kwargs) + + async def redirect_handler(self, authorization_url: str) -> None: + """Override to send auth URL to frontend instead of opening browser.""" + if self.auth_url_callback: + self.auth_url_callback(authorization_url) + else: + raise Exception("OAuth authorization URL callback not configured") + + async def callback_handler(self) -> tuple[str, str | None]: + """Override to wait for auth code from frontend instead of local server.""" + if self.auth_code_callback: + auth_code, state = await self.auth_code_callback() + return auth_code, state + else: + raise Exception("OAuth callback handler not configured") + + class MCPTool(Tool): """ MCP Tool @@ -37,6 +62,8 @@ class MCPTool(Tool): - headers: Custom headers for requests - command: Command for STDIO transport - args: Arguments for STDIO transport + - oauth_scopes: OAuth scopes for oauth auth type + - oauth_client_name: OAuth client name for oauth auth type user_id: User ID for decrypting credentials (required if encrypted_credentials exist) """ self.config = config @@ -53,19 +80,35 @@ class MCPTool(Tool): ) else: self.auth_credentials = config.get("auth_credentials", {}) + # OAuth specific configuration + + self.oauth_scopes = config.get("oauth_scopes", []) + self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP") + + # OAuth callback handlers (to be set by frontend) + + self.oauth_auth_url_callback = None + self.oauth_auth_code_callback = None + self.available_tools = [] self._cache_key = self._generate_cache_key() self._client = None - # Only validate and setup if server_url is provided + # Only validate and setup if server_url is provided and not OAuth + # OAuth setup will happen after callbacks are set - if self.server_url: + if self.server_url and self.auth_type != "oauth": self._setup_client() def _generate_cache_key(self) -> str: """Generate a unique cache key for this MCP server configuration.""" auth_key = "" - if self.auth_type in ["bearer", "oauth"]: + if self.auth_type == "oauth": + # For OAuth, use scopes and client name as part of the key + + scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none" + auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}" + elif self.auth_type in ["bearer"]: token = self.auth_credentials.get( "bearer_token", "" ) or self.auth_credentials.get("access_token", "") @@ -91,17 +134,31 @@ class MCPTool(Tool): else: del _mcp_clients_cache[self._cache_key] transport = self._create_transport() + auth = None - if self.auth_type in ["bearer", "oauth"]: + if self.auth_type == "oauth": + # Ensure callbacks are configured before creating OAuth instance + + if not self.oauth_auth_url_callback or not self.oauth_auth_code_callback: + raise Exception( + "OAuth callbacks not configured. Call set_oauth_callbacks() first." + ) + # Use custom OAuth handler for frontend integration + + auth = DocsGPTOAuth( + mcp_url=self.server_url, + scopes=self.oauth_scopes, + client_name=self.oauth_client_name, + auth_url_callback=self.oauth_auth_url_callback, + auth_code_callback=self.oauth_auth_code_callback, + ) + elif self.auth_type in ["bearer"]: token = self.auth_credentials.get( "bearer_token", "" ) or self.auth_credentials.get("access_token", "") if token: - self._client = Client(transport, auth=BearerAuth(token)) - else: - self._client = Client(transport) - else: - self._client = Client(transport) + auth = BearerAuth(token) + self._client = Client(transport, auth=auth) _mcp_clients_cache[self._cache_key] = { "client": self._client, "created_at": time.time(), @@ -194,8 +251,12 @@ class MCPTool(Tool): def _run_async_operation(self, operation: str, *args, **kwargs): """Run async operation in sync context.""" try: - loop = asyncio.get_event_loop() - if loop.is_running(): + # Check if there's already a running event loop + + try: + loop = asyncio.get_running_loop() + # If we're in an async context, we need to run in a separate thread + import concurrent.futures def run_in_thread(): @@ -211,19 +272,25 @@ class MCPTool(Tool): with concurrent.futures.ThreadPoolExecutor() as executor: future = executor.submit(run_in_thread) return future.result(timeout=self.timeout) - else: - return loop.run_until_complete( - self._execute_with_client(operation, *args, **kwargs) + except RuntimeError: + # No running loop, we can create one + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete( + self._execute_with_client(operation, *args, **kwargs) + ) + finally: + loop.close() + except Exception as e: + # If async fails, try to give a better error message for OAuth + + if self.auth_type == "oauth" and "callback not configured" in str(e): + raise Exception( + "OAuth callbacks not configured. Call set_oauth_callbacks() first." ) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - return loop.run_until_complete( - self._execute_with_client(operation, *args, **kwargs) - ) - finally: - loop.close() + raise e def discover_tools(self) -> List[Dict]: """ @@ -310,25 +377,12 @@ class MCPTool(Tool): if not self._client: self._setup_client() try: - try: - self._run_async_operation("ping") - ping_success = True - except Exception: - ping_success = False - tools = self.discover_tools() + # For OAuth, we need to handle async operations differently - message = f"Successfully connected to MCP server. Found {len(tools)} tools." - if not ping_success: - message += " (Ping not supported, but tool discovery worked)" - return { - "success": True, - "message": message, - "tools_count": len(tools), - "transport_type": self.transport_type, - "auth_type": self.auth_type, - "ping_supported": ping_success, - "tools": [tool.get("name", "unknown") for tool in tools[:5]], - } + if self.auth_type == "oauth": + return self._test_oauth_connection() + else: + return self._test_regular_connection() except Exception as e: return { "success": False, @@ -339,6 +393,131 @@ class MCPTool(Tool): "error_type": type(e).__name__, } + def _test_regular_connection(self) -> Dict: + """Test connection for non-OAuth auth types.""" + try: + self._run_async_operation("ping") + ping_success = True + except Exception: + ping_success = False + tools = self.discover_tools() + + message = f"Successfully connected to MCP server. Found {len(tools)} tools." + if not ping_success: + message += " (Ping not supported, but tool discovery worked)" + return { + "success": True, + "message": message, + "tools_count": len(tools), + "transport_type": self.transport_type, + "auth_type": self.auth_type, + "ping_supported": ping_success, + "tools": [tool.get("name", "unknown") for tool in tools[:5]], + } + + def _test_oauth_connection(self) -> Dict: + """Test connection for OAuth auth type with proper async handling.""" + try: + # Ensure callbacks are configured before proceeding + + if not self.oauth_auth_url_callback or not self.oauth_auth_code_callback: + return { + "success": False, + "message": "OAuth callbacks not configured. Call set_oauth_callbacks() first.", + "tools_count": 0, + "transport_type": self.transport_type, + "auth_type": self.auth_type, + "error_type": "ConfigurationError", + } + # Ensure client is set up with proper callbacks + + if not self._client: + self._setup_client() + # For OAuth, we use a simpler approach - just try to discover tools + # This will trigger the OAuth flow if needed + + tools = self.discover_tools() + + return { + "success": True, + "message": f"Successfully connected to OAuth MCP server. Found {len(tools)} tools.", + "tools_count": len(tools), + "transport_type": self.transport_type, + "auth_type": self.auth_type, + "ping_supported": False, # Skip ping for OAuth to avoid complexity + "tools": [tool.get("name", "unknown") for tool in tools[:5]], + } + except Exception as e: + return { + "success": False, + "message": f"OAuth connection failed: {str(e)}", + "tools_count": 0, + "transport_type": self.transport_type, + "auth_type": self.auth_type, + "error_type": type(e).__name__, + } + + def set_oauth_callbacks(self, auth_url_callback, auth_code_callback): + """ + Set OAuth callback handlers for frontend integration. + + Args: + auth_url_callback: Function to call with authorization URL + auth_code_callback: Async function that returns (auth_code, state) tuple + """ + self.oauth_auth_url_callback = auth_url_callback + self.oauth_auth_code_callback = auth_code_callback + + # Clear the client so it gets recreated with the new callbacks + + self._client = None + + # Also clear from cache to ensure fresh creation + + global _mcp_clients_cache + if self._cache_key in _mcp_clients_cache: + del _mcp_clients_cache[self._cache_key] + + def clear_oauth_cache(self): + """ + Clear OAuth cache to force fresh authentication. + This will remove stored tokens and client info for the server. + """ + if self.auth_type == "oauth": + try: + from fastmcp.client.auth.oauth import FileTokenStorage + + storage = FileTokenStorage(server_url=self.server_url) + storage.clear() + print(f"✅ Cleared OAuth cache for {self.server_url}") + except Exception as e: + print(f"⚠️ Failed to clear OAuth cache: {e}") + # Also clear our internal client cache + + global _mcp_clients_cache + if self._cache_key in _mcp_clients_cache: + del _mcp_clients_cache[self._cache_key] + print(f"✅ Cleared internal client cache") + + @staticmethod + def clear_all_oauth_cache(): + """ + Clear all OAuth cache for all servers. + This will remove all stored tokens and client info. + """ + try: + from fastmcp.client.auth.oauth import FileTokenStorage + + FileTokenStorage.clear_all() + print(f"✅ Cleared all OAuth cache") + except Exception as e: + print(f"⚠️ Failed to clear all OAuth cache: {e}") + # Also clear all internal client cache + + global _mcp_clients_cache + _mcp_clients_cache.clear() + print(f"✅ Cleared all internal client cache") + def get_actions_metadata(self) -> List[Dict]: """ Get metadata for all available actions. @@ -413,7 +592,7 @@ class MCPTool(Tool): "help": { "none": "No authentication", "bearer": "Bearer token authentication", - "oauth": "OAuth 2.0 authentication", + "oauth": "OAuth 2.1 authentication (with frontend integration)", "api_key": "API key authentication", "basic": "Basic authentication", }, @@ -425,11 +604,11 @@ class MCPTool(Tool): "properties": { "bearer_token": { "type": "string", - "description": "Bearer token for bearer/oauth auth", + "description": "Bearer token for bearer auth", }, "access_token": { "type": "string", - "description": "Access token for OAuth", + "description": "Access token for OAuth (if pre-obtained)", }, "api_key": { "type": "string", @@ -449,6 +628,19 @@ class MCPTool(Tool): }, }, }, + "oauth_scopes": { + "type": "array", + "description": "OAuth scopes to request (for oauth auth_type)", + "items": {"type": "string"}, + "required": False, + "default": [], + }, + "oauth_client_name": { + "type": "string", + "description": "Client name for OAuth registration (for oauth auth_type)", + "default": "DocsGPT-MCP", + "required": False, + }, "headers": { "type": "object", "description": "Custom headers to send with requests",