mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 00:23:17 +00:00
feat: implement OAuth 2.1 integration with custom handlers for fastmcp
This commit is contained in:
@@ -8,6 +8,7 @@ from application.security.encryption import decrypt_credentials
|
||||
|
||||
from fastmcp import Client
|
||||
from fastmcp.client.auth import BearerAuth
|
||||
from fastmcp.client.auth.oauth import OAuth as FastMCPOAuth
|
||||
from fastmcp.client.transports import (
|
||||
SSETransport,
|
||||
StdioTransport,
|
||||
@@ -17,6 +18,30 @@ from fastmcp.client.transports import (
|
||||
_mcp_clients_cache = {}
|
||||
|
||||
|
||||
class DocsGPTOAuth(FastMCPOAuth):
|
||||
"""Custom OAuth handler that integrates with DocsGPT frontend instead of opening browser."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.auth_url_callback = kwargs.pop("auth_url_callback", None)
|
||||
self.auth_code_callback = kwargs.pop("auth_code_callback", None)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def redirect_handler(self, authorization_url: str) -> None:
|
||||
"""Override to send auth URL to frontend instead of opening browser."""
|
||||
if self.auth_url_callback:
|
||||
self.auth_url_callback(authorization_url)
|
||||
else:
|
||||
raise Exception("OAuth authorization URL callback not configured")
|
||||
|
||||
async def callback_handler(self) -> tuple[str, str | None]:
|
||||
"""Override to wait for auth code from frontend instead of local server."""
|
||||
if self.auth_code_callback:
|
||||
auth_code, state = await self.auth_code_callback()
|
||||
return auth_code, state
|
||||
else:
|
||||
raise Exception("OAuth callback handler not configured")
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
"""
|
||||
MCP Tool
|
||||
@@ -37,6 +62,8 @@ class MCPTool(Tool):
|
||||
- headers: Custom headers for requests
|
||||
- command: Command for STDIO transport
|
||||
- args: Arguments for STDIO transport
|
||||
- oauth_scopes: OAuth scopes for oauth auth type
|
||||
- oauth_client_name: OAuth client name for oauth auth type
|
||||
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
|
||||
"""
|
||||
self.config = config
|
||||
@@ -53,19 +80,35 @@ class MCPTool(Tool):
|
||||
)
|
||||
else:
|
||||
self.auth_credentials = config.get("auth_credentials", {})
|
||||
# OAuth specific configuration
|
||||
|
||||
self.oauth_scopes = config.get("oauth_scopes", [])
|
||||
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
|
||||
|
||||
# OAuth callback handlers (to be set by frontend)
|
||||
|
||||
self.oauth_auth_url_callback = None
|
||||
self.oauth_auth_code_callback = None
|
||||
|
||||
self.available_tools = []
|
||||
self._cache_key = self._generate_cache_key()
|
||||
self._client = None
|
||||
|
||||
# Only validate and setup if server_url is provided
|
||||
# Only validate and setup if server_url is provided and not OAuth
|
||||
# OAuth setup will happen after callbacks are set
|
||||
|
||||
if self.server_url:
|
||||
if self.server_url and self.auth_type != "oauth":
|
||||
self._setup_client()
|
||||
|
||||
def _generate_cache_key(self) -> str:
|
||||
"""Generate a unique cache key for this MCP server configuration."""
|
||||
auth_key = ""
|
||||
if self.auth_type in ["bearer", "oauth"]:
|
||||
if self.auth_type == "oauth":
|
||||
# For OAuth, use scopes and client name as part of the key
|
||||
|
||||
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
|
||||
auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}"
|
||||
elif self.auth_type in ["bearer"]:
|
||||
token = self.auth_credentials.get(
|
||||
"bearer_token", ""
|
||||
) or self.auth_credentials.get("access_token", "")
|
||||
@@ -91,17 +134,31 @@ class MCPTool(Tool):
|
||||
else:
|
||||
del _mcp_clients_cache[self._cache_key]
|
||||
transport = self._create_transport()
|
||||
auth = None
|
||||
|
||||
if self.auth_type in ["bearer", "oauth"]:
|
||||
if self.auth_type == "oauth":
|
||||
# Ensure callbacks are configured before creating OAuth instance
|
||||
|
||||
if not self.oauth_auth_url_callback or not self.oauth_auth_code_callback:
|
||||
raise Exception(
|
||||
"OAuth callbacks not configured. Call set_oauth_callbacks() first."
|
||||
)
|
||||
# Use custom OAuth handler for frontend integration
|
||||
|
||||
auth = DocsGPTOAuth(
|
||||
mcp_url=self.server_url,
|
||||
scopes=self.oauth_scopes,
|
||||
client_name=self.oauth_client_name,
|
||||
auth_url_callback=self.oauth_auth_url_callback,
|
||||
auth_code_callback=self.oauth_auth_code_callback,
|
||||
)
|
||||
elif self.auth_type in ["bearer"]:
|
||||
token = self.auth_credentials.get(
|
||||
"bearer_token", ""
|
||||
) or self.auth_credentials.get("access_token", "")
|
||||
if token:
|
||||
self._client = Client(transport, auth=BearerAuth(token))
|
||||
else:
|
||||
self._client = Client(transport)
|
||||
else:
|
||||
self._client = Client(transport)
|
||||
auth = BearerAuth(token)
|
||||
self._client = Client(transport, auth=auth)
|
||||
_mcp_clients_cache[self._cache_key] = {
|
||||
"client": self._client,
|
||||
"created_at": time.time(),
|
||||
@@ -194,8 +251,12 @@ class MCPTool(Tool):
|
||||
def _run_async_operation(self, operation: str, *args, **kwargs):
|
||||
"""Run async operation in sync context."""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# Check if there's already a running event loop
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we're in an async context, we need to run in a separate thread
|
||||
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_thread():
|
||||
@@ -211,19 +272,25 @@ class MCPTool(Tool):
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_thread)
|
||||
return future.result(timeout=self.timeout)
|
||||
else:
|
||||
return loop.run_until_complete(
|
||||
self._execute_with_client(operation, *args, **kwargs)
|
||||
except RuntimeError:
|
||||
# No running loop, we can create one
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(
|
||||
self._execute_with_client(operation, *args, **kwargs)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except Exception as e:
|
||||
# If async fails, try to give a better error message for OAuth
|
||||
|
||||
if self.auth_type == "oauth" and "callback not configured" in str(e):
|
||||
raise Exception(
|
||||
"OAuth callbacks not configured. Call set_oauth_callbacks() first."
|
||||
)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(
|
||||
self._execute_with_client(operation, *args, **kwargs)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
raise e
|
||||
|
||||
def discover_tools(self) -> List[Dict]:
|
||||
"""
|
||||
@@ -310,25 +377,12 @@ class MCPTool(Tool):
|
||||
if not self._client:
|
||||
self._setup_client()
|
||||
try:
|
||||
try:
|
||||
self._run_async_operation("ping")
|
||||
ping_success = True
|
||||
except Exception:
|
||||
ping_success = False
|
||||
tools = self.discover_tools()
|
||||
# For OAuth, we need to handle async operations differently
|
||||
|
||||
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
|
||||
if not ping_success:
|
||||
message += " (Ping not supported, but tool discovery worked)"
|
||||
return {
|
||||
"success": True,
|
||||
"message": message,
|
||||
"tools_count": len(tools),
|
||||
"transport_type": self.transport_type,
|
||||
"auth_type": self.auth_type,
|
||||
"ping_supported": ping_success,
|
||||
"tools": [tool.get("name", "unknown") for tool in tools[:5]],
|
||||
}
|
||||
if self.auth_type == "oauth":
|
||||
return self._test_oauth_connection()
|
||||
else:
|
||||
return self._test_regular_connection()
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
@@ -339,6 +393,131 @@ class MCPTool(Tool):
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
|
||||
def _test_regular_connection(self) -> Dict:
|
||||
"""Test connection for non-OAuth auth types."""
|
||||
try:
|
||||
self._run_async_operation("ping")
|
||||
ping_success = True
|
||||
except Exception:
|
||||
ping_success = False
|
||||
tools = self.discover_tools()
|
||||
|
||||
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
|
||||
if not ping_success:
|
||||
message += " (Ping not supported, but tool discovery worked)"
|
||||
return {
|
||||
"success": True,
|
||||
"message": message,
|
||||
"tools_count": len(tools),
|
||||
"transport_type": self.transport_type,
|
||||
"auth_type": self.auth_type,
|
||||
"ping_supported": ping_success,
|
||||
"tools": [tool.get("name", "unknown") for tool in tools[:5]],
|
||||
}
|
||||
|
||||
def _test_oauth_connection(self) -> Dict:
|
||||
"""Test connection for OAuth auth type with proper async handling."""
|
||||
try:
|
||||
# Ensure callbacks are configured before proceeding
|
||||
|
||||
if not self.oauth_auth_url_callback or not self.oauth_auth_code_callback:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "OAuth callbacks not configured. Call set_oauth_callbacks() first.",
|
||||
"tools_count": 0,
|
||||
"transport_type": self.transport_type,
|
||||
"auth_type": self.auth_type,
|
||||
"error_type": "ConfigurationError",
|
||||
}
|
||||
# Ensure client is set up with proper callbacks
|
||||
|
||||
if not self._client:
|
||||
self._setup_client()
|
||||
# For OAuth, we use a simpler approach - just try to discover tools
|
||||
# This will trigger the OAuth flow if needed
|
||||
|
||||
tools = self.discover_tools()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Successfully connected to OAuth MCP server. Found {len(tools)} tools.",
|
||||
"tools_count": len(tools),
|
||||
"transport_type": self.transport_type,
|
||||
"auth_type": self.auth_type,
|
||||
"ping_supported": False, # Skip ping for OAuth to avoid complexity
|
||||
"tools": [tool.get("name", "unknown") for tool in tools[:5]],
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"OAuth connection failed: {str(e)}",
|
||||
"tools_count": 0,
|
||||
"transport_type": self.transport_type,
|
||||
"auth_type": self.auth_type,
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
|
||||
def set_oauth_callbacks(self, auth_url_callback, auth_code_callback):
|
||||
"""
|
||||
Set OAuth callback handlers for frontend integration.
|
||||
|
||||
Args:
|
||||
auth_url_callback: Function to call with authorization URL
|
||||
auth_code_callback: Async function that returns (auth_code, state) tuple
|
||||
"""
|
||||
self.oauth_auth_url_callback = auth_url_callback
|
||||
self.oauth_auth_code_callback = auth_code_callback
|
||||
|
||||
# Clear the client so it gets recreated with the new callbacks
|
||||
|
||||
self._client = None
|
||||
|
||||
# Also clear from cache to ensure fresh creation
|
||||
|
||||
global _mcp_clients_cache
|
||||
if self._cache_key in _mcp_clients_cache:
|
||||
del _mcp_clients_cache[self._cache_key]
|
||||
|
||||
def clear_oauth_cache(self):
|
||||
"""
|
||||
Clear OAuth cache to force fresh authentication.
|
||||
This will remove stored tokens and client info for the server.
|
||||
"""
|
||||
if self.auth_type == "oauth":
|
||||
try:
|
||||
from fastmcp.client.auth.oauth import FileTokenStorage
|
||||
|
||||
storage = FileTokenStorage(server_url=self.server_url)
|
||||
storage.clear()
|
||||
print(f"✅ Cleared OAuth cache for {self.server_url}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to clear OAuth cache: {e}")
|
||||
# Also clear our internal client cache
|
||||
|
||||
global _mcp_clients_cache
|
||||
if self._cache_key in _mcp_clients_cache:
|
||||
del _mcp_clients_cache[self._cache_key]
|
||||
print(f"✅ Cleared internal client cache")
|
||||
|
||||
@staticmethod
|
||||
def clear_all_oauth_cache():
|
||||
"""
|
||||
Clear all OAuth cache for all servers.
|
||||
This will remove all stored tokens and client info.
|
||||
"""
|
||||
try:
|
||||
from fastmcp.client.auth.oauth import FileTokenStorage
|
||||
|
||||
FileTokenStorage.clear_all()
|
||||
print(f"✅ Cleared all OAuth cache")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to clear all OAuth cache: {e}")
|
||||
# Also clear all internal client cache
|
||||
|
||||
global _mcp_clients_cache
|
||||
_mcp_clients_cache.clear()
|
||||
print(f"✅ Cleared all internal client cache")
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict]:
|
||||
"""
|
||||
Get metadata for all available actions.
|
||||
@@ -413,7 +592,7 @@ class MCPTool(Tool):
|
||||
"help": {
|
||||
"none": "No authentication",
|
||||
"bearer": "Bearer token authentication",
|
||||
"oauth": "OAuth 2.0 authentication",
|
||||
"oauth": "OAuth 2.1 authentication (with frontend integration)",
|
||||
"api_key": "API key authentication",
|
||||
"basic": "Basic authentication",
|
||||
},
|
||||
@@ -425,11 +604,11 @@ class MCPTool(Tool):
|
||||
"properties": {
|
||||
"bearer_token": {
|
||||
"type": "string",
|
||||
"description": "Bearer token for bearer/oauth auth",
|
||||
"description": "Bearer token for bearer auth",
|
||||
},
|
||||
"access_token": {
|
||||
"type": "string",
|
||||
"description": "Access token for OAuth",
|
||||
"description": "Access token for OAuth (if pre-obtained)",
|
||||
},
|
||||
"api_key": {
|
||||
"type": "string",
|
||||
@@ -449,6 +628,19 @@ class MCPTool(Tool):
|
||||
},
|
||||
},
|
||||
},
|
||||
"oauth_scopes": {
|
||||
"type": "array",
|
||||
"description": "OAuth scopes to request (for oauth auth_type)",
|
||||
"items": {"type": "string"},
|
||||
"required": False,
|
||||
"default": [],
|
||||
},
|
||||
"oauth_client_name": {
|
||||
"type": "string",
|
||||
"description": "Client name for OAuth registration (for oauth auth_type)",
|
||||
"default": "DocsGPT-MCP",
|
||||
"required": False,
|
||||
},
|
||||
"headers": {
|
||||
"type": "object",
|
||||
"description": "Custom headers to send with requests",
|
||||
|
||||
Reference in New Issue
Block a user