feat: implement OAuth 2.1 integration with custom handlers for fastmcp

This commit is contained in:
Siddhant Rai
2025-09-22 01:31:09 +05:30
parent 8ce345cd94
commit 00b4e133d4

View File

@@ -8,6 +8,7 @@ from application.security.encryption import decrypt_credentials
from fastmcp import Client
from fastmcp.client.auth import BearerAuth
from fastmcp.client.auth.oauth import OAuth as FastMCPOAuth
from fastmcp.client.transports import (
SSETransport,
StdioTransport,
@@ -17,6 +18,30 @@ from fastmcp.client.transports import (
_mcp_clients_cache = {}
class DocsGPTOAuth(FastMCPOAuth):
"""Custom OAuth handler that integrates with DocsGPT frontend instead of opening browser."""
def __init__(self, *args, **kwargs):
self.auth_url_callback = kwargs.pop("auth_url_callback", None)
self.auth_code_callback = kwargs.pop("auth_code_callback", None)
super().__init__(*args, **kwargs)
async def redirect_handler(self, authorization_url: str) -> None:
"""Override to send auth URL to frontend instead of opening browser."""
if self.auth_url_callback:
self.auth_url_callback(authorization_url)
else:
raise Exception("OAuth authorization URL callback not configured")
async def callback_handler(self) -> tuple[str, str | None]:
"""Override to wait for auth code from frontend instead of local server."""
if self.auth_code_callback:
auth_code, state = await self.auth_code_callback()
return auth_code, state
else:
raise Exception("OAuth callback handler not configured")
class MCPTool(Tool):
"""
MCP Tool
@@ -37,6 +62,8 @@ class MCPTool(Tool):
- headers: Custom headers for requests
- command: Command for STDIO transport
- args: Arguments for STDIO transport
- oauth_scopes: OAuth scopes for oauth auth type
- oauth_client_name: OAuth client name for oauth auth type
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
"""
self.config = config
@@ -53,19 +80,35 @@ class MCPTool(Tool):
)
else:
self.auth_credentials = config.get("auth_credentials", {})
# OAuth specific configuration
self.oauth_scopes = config.get("oauth_scopes", [])
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
# OAuth callback handlers (to be set by frontend)
self.oauth_auth_url_callback = None
self.oauth_auth_code_callback = None
self.available_tools = []
self._cache_key = self._generate_cache_key()
self._client = None
# Only validate and setup if server_url is provided
# Only validate and setup if server_url is provided and not OAuth
# OAuth setup will happen after callbacks are set
if self.server_url:
if self.server_url and self.auth_type != "oauth":
self._setup_client()
def _generate_cache_key(self) -> str:
"""Generate a unique cache key for this MCP server configuration."""
auth_key = ""
if self.auth_type in ["bearer", "oauth"]:
if self.auth_type == "oauth":
# For OAuth, use scopes and client name as part of the key
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}"
elif self.auth_type in ["bearer"]:
token = self.auth_credentials.get(
"bearer_token", ""
) or self.auth_credentials.get("access_token", "")
@@ -91,17 +134,31 @@ class MCPTool(Tool):
else:
del _mcp_clients_cache[self._cache_key]
transport = self._create_transport()
auth = None
if self.auth_type in ["bearer", "oauth"]:
if self.auth_type == "oauth":
# Ensure callbacks are configured before creating OAuth instance
if not self.oauth_auth_url_callback or not self.oauth_auth_code_callback:
raise Exception(
"OAuth callbacks not configured. Call set_oauth_callbacks() first."
)
# Use custom OAuth handler for frontend integration
auth = DocsGPTOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
client_name=self.oauth_client_name,
auth_url_callback=self.oauth_auth_url_callback,
auth_code_callback=self.oauth_auth_code_callback,
)
elif self.auth_type in ["bearer"]:
token = self.auth_credentials.get(
"bearer_token", ""
) or self.auth_credentials.get("access_token", "")
if token:
self._client = Client(transport, auth=BearerAuth(token))
else:
self._client = Client(transport)
else:
self._client = Client(transport)
auth = BearerAuth(token)
self._client = Client(transport, auth=auth)
_mcp_clients_cache[self._cache_key] = {
"client": self._client,
"created_at": time.time(),
@@ -194,8 +251,12 @@ class MCPTool(Tool):
def _run_async_operation(self, operation: str, *args, **kwargs):
"""Run async operation in sync context."""
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# Check if there's already a running event loop
try:
loop = asyncio.get_running_loop()
# If we're in an async context, we need to run in a separate thread
import concurrent.futures
def run_in_thread():
@@ -211,19 +272,25 @@ class MCPTool(Tool):
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result(timeout=self.timeout)
else:
return loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
except RuntimeError:
# No running loop, we can create one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
loop.close()
except Exception as e:
# If async fails, try to give a better error message for OAuth
if self.auth_type == "oauth" and "callback not configured" in str(e):
raise Exception(
"OAuth callbacks not configured. Call set_oauth_callbacks() first."
)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
loop.close()
raise e
def discover_tools(self) -> List[Dict]:
"""
@@ -310,25 +377,12 @@ class MCPTool(Tool):
if not self._client:
self._setup_client()
try:
try:
self._run_async_operation("ping")
ping_success = True
except Exception:
ping_success = False
tools = self.discover_tools()
# For OAuth, we need to handle async operations differently
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
if not ping_success:
message += " (Ping not supported, but tool discovery worked)"
return {
"success": True,
"message": message,
"tools_count": len(tools),
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"ping_supported": ping_success,
"tools": [tool.get("name", "unknown") for tool in tools[:5]],
}
if self.auth_type == "oauth":
return self._test_oauth_connection()
else:
return self._test_regular_connection()
except Exception as e:
return {
"success": False,
@@ -339,6 +393,131 @@ class MCPTool(Tool):
"error_type": type(e).__name__,
}
def _test_regular_connection(self) -> Dict:
"""Test connection for non-OAuth auth types."""
try:
self._run_async_operation("ping")
ping_success = True
except Exception:
ping_success = False
tools = self.discover_tools()
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
if not ping_success:
message += " (Ping not supported, but tool discovery worked)"
return {
"success": True,
"message": message,
"tools_count": len(tools),
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"ping_supported": ping_success,
"tools": [tool.get("name", "unknown") for tool in tools[:5]],
}
def _test_oauth_connection(self) -> Dict:
"""Test connection for OAuth auth type with proper async handling."""
try:
# Ensure callbacks are configured before proceeding
if not self.oauth_auth_url_callback or not self.oauth_auth_code_callback:
return {
"success": False,
"message": "OAuth callbacks not configured. Call set_oauth_callbacks() first.",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": "ConfigurationError",
}
# Ensure client is set up with proper callbacks
if not self._client:
self._setup_client()
# For OAuth, we use a simpler approach - just try to discover tools
# This will trigger the OAuth flow if needed
tools = self.discover_tools()
return {
"success": True,
"message": f"Successfully connected to OAuth MCP server. Found {len(tools)} tools.",
"tools_count": len(tools),
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"ping_supported": False, # Skip ping for OAuth to avoid complexity
"tools": [tool.get("name", "unknown") for tool in tools[:5]],
}
except Exception as e:
return {
"success": False,
"message": f"OAuth connection failed: {str(e)}",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": type(e).__name__,
}
def set_oauth_callbacks(self, auth_url_callback, auth_code_callback):
"""
Set OAuth callback handlers for frontend integration.
Args:
auth_url_callback: Function to call with authorization URL
auth_code_callback: Async function that returns (auth_code, state) tuple
"""
self.oauth_auth_url_callback = auth_url_callback
self.oauth_auth_code_callback = auth_code_callback
# Clear the client so it gets recreated with the new callbacks
self._client = None
# Also clear from cache to ensure fresh creation
global _mcp_clients_cache
if self._cache_key in _mcp_clients_cache:
del _mcp_clients_cache[self._cache_key]
def clear_oauth_cache(self):
"""
Clear OAuth cache to force fresh authentication.
This will remove stored tokens and client info for the server.
"""
if self.auth_type == "oauth":
try:
from fastmcp.client.auth.oauth import FileTokenStorage
storage = FileTokenStorage(server_url=self.server_url)
storage.clear()
print(f"✅ Cleared OAuth cache for {self.server_url}")
except Exception as e:
print(f"⚠️ Failed to clear OAuth cache: {e}")
# Also clear our internal client cache
global _mcp_clients_cache
if self._cache_key in _mcp_clients_cache:
del _mcp_clients_cache[self._cache_key]
print(f"✅ Cleared internal client cache")
@staticmethod
def clear_all_oauth_cache():
"""
Clear all OAuth cache for all servers.
This will remove all stored tokens and client info.
"""
try:
from fastmcp.client.auth.oauth import FileTokenStorage
FileTokenStorage.clear_all()
print(f"✅ Cleared all OAuth cache")
except Exception as e:
print(f"⚠️ Failed to clear all OAuth cache: {e}")
# Also clear all internal client cache
global _mcp_clients_cache
_mcp_clients_cache.clear()
print(f"✅ Cleared all internal client cache")
def get_actions_metadata(self) -> List[Dict]:
"""
Get metadata for all available actions.
@@ -413,7 +592,7 @@ class MCPTool(Tool):
"help": {
"none": "No authentication",
"bearer": "Bearer token authentication",
"oauth": "OAuth 2.0 authentication",
"oauth": "OAuth 2.1 authentication (with frontend integration)",
"api_key": "API key authentication",
"basic": "Basic authentication",
},
@@ -425,11 +604,11 @@ class MCPTool(Tool):
"properties": {
"bearer_token": {
"type": "string",
"description": "Bearer token for bearer/oauth auth",
"description": "Bearer token for bearer auth",
},
"access_token": {
"type": "string",
"description": "Access token for OAuth",
"description": "Access token for OAuth (if pre-obtained)",
},
"api_key": {
"type": "string",
@@ -449,6 +628,19 @@ class MCPTool(Tool):
},
},
},
"oauth_scopes": {
"type": "array",
"description": "OAuth scopes to request (for oauth auth_type)",
"items": {"type": "string"},
"required": False,
"default": [],
},
"oauth_client_name": {
"type": "string",
"description": "Client name for OAuth registration (for oauth auth_type)",
"default": "DocsGPT-MCP",
"required": False,
},
"headers": {
"type": "object",
"description": "Custom headers to send with requests",