feat: Implement OAuth flow for MCP server integration

- Added MCPOAuthManager to handle OAuth authorization.
- Updated MCPServerSave resource to manage OAuth status and callback.
- Introduced new endpoints for OAuth status and callback handling.
- Enhanced user interface to support OAuth authentication type.
- Implemented polling mechanism for OAuth status in MCPServerModal.
- Updated frontend services and endpoints to accommodate new OAuth features.
- Improved error handling and user feedback for OAuth processes.
This commit is contained in:
Siddhant Rai
2025-09-26 02:44:08 +05:30
parent 00b4e133d4
commit 3b27db36f2
9 changed files with 991 additions and 289 deletions

View File

@@ -1,47 +1,39 @@
import asyncio
import base64
import json
import logging
import time
from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs, urlparse
from application.agents.tools.base import Tool
from application.security.encryption import decrypt_credentials
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials
from fastmcp import Client
from fastmcp.client.auth import BearerAuth
from fastmcp.client.auth.oauth import OAuth as FastMCPOAuth
from fastmcp.client.transports import (
SSETransport,
StdioTransport,
StreamableHttpTransport,
)
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
from pydantic import AnyHttpUrl, ValidationError
from redis import Redis
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
_mcp_clients_cache = {}
class DocsGPTOAuth(FastMCPOAuth):
"""Custom OAuth handler that integrates with DocsGPT frontend instead of opening browser."""
def __init__(self, *args, **kwargs):
self.auth_url_callback = kwargs.pop("auth_url_callback", None)
self.auth_code_callback = kwargs.pop("auth_code_callback", None)
super().__init__(*args, **kwargs)
async def redirect_handler(self, authorization_url: str) -> None:
"""Override to send auth URL to frontend instead of opening browser."""
if self.auth_url_callback:
self.auth_url_callback(authorization_url)
else:
raise Exception("OAuth authorization URL callback not configured")
async def callback_handler(self) -> tuple[str, str | None]:
"""Override to wait for auth code from frontend instead of local server."""
if self.auth_code_callback:
auth_code, state = await self.auth_code_callback()
return auth_code, state
else:
raise Exception("OAuth callback handler not configured")
class MCPTool(Tool):
"""
MCP Tool
@@ -67,6 +59,7 @@ class MCPTool(Tool):
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
"""
self.config = config
self.user_id = user_id
self.server_url = config.get("server_url", "")
self.transport_type = config.get("transport_type", "auto")
self.auth_type = config.get("auth_type", "none")
@@ -80,22 +73,16 @@ class MCPTool(Tool):
)
else:
self.auth_credentials = config.get("auth_credentials", {})
# OAuth specific configuration
self.oauth_scopes = config.get("oauth_scopes", [])
self.oauth_task_id = config.get("oauth_task_id", None)
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
# OAuth callback handlers (to be set by frontend)
self.oauth_auth_url_callback = None
self.oauth_auth_code_callback = None
self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback"
self.available_tools = []
self._cache_key = self._generate_cache_key()
self._client = None
# Only validate and setup if server_url is provided and not OAuth
# OAuth setup will happen after callbacks are set
if self.server_url and self.auth_type != "oauth":
self._setup_client()
@@ -104,8 +91,6 @@ class MCPTool(Tool):
"""Generate a unique cache key for this MCP server configuration."""
auth_key = ""
if self.auth_type == "oauth":
# For OAuth, use scopes and client name as part of the key
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}"
elif self.auth_type in ["bearer"]:
@@ -137,22 +122,17 @@ class MCPTool(Tool):
auth = None
if self.auth_type == "oauth":
# Ensure callbacks are configured before creating OAuth instance
if not self.oauth_auth_url_callback or not self.oauth_auth_code_callback:
raise Exception(
"OAuth callbacks not configured. Call set_oauth_callbacks() first."
)
# Use custom OAuth handler for frontend integration
redis_client = get_redis_instance()
auth = DocsGPTOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
client_name=self.oauth_client_name,
auth_url_callback=self.oauth_auth_url_callback,
auth_code_callback=self.oauth_auth_code_callback,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
db=db,
user_id=self.user_id,
)
elif self.auth_type in ["bearer"]:
elif self.auth_type == "bearer":
token = self.auth_credentials.get(
"bearer_token", ""
) or self.auth_credentials.get("access_token", "")
@@ -202,6 +182,33 @@ class MCPTool(Tool):
else:
return StreamableHttpTransport(url=self.server_url, headers=headers)
def _format_tools(self, tools_response) -> List[Dict]:
"""Format tools response to match expected format."""
if hasattr(tools_response, "tools"):
tools = tools_response.tools
elif isinstance(tools_response, list):
tools = tools_response
else:
tools = []
tools_dict = []
for tool in tools:
if hasattr(tool, "name"):
tool_dict = {
"name": tool.name,
"description": tool.description,
}
if hasattr(tool, "inputSchema"):
tool_dict["inputSchema"] = tool.inputSchema
tools_dict.append(tool_dict)
elif isinstance(tool, dict):
tools_dict.append(tool)
else:
if hasattr(tool, "model_dump"):
tools_dict.append(tool.model_dump())
else:
tools_dict.append({"name": str(tool), "description": ""})
return tools_dict
async def _execute_with_client(self, operation: str, *args, **kwargs):
"""Execute operation with FastMCP client."""
if not self._client:
@@ -211,32 +218,8 @@ class MCPTool(Tool):
return await self._client.ping()
elif operation == "list_tools":
tools_response = await self._client.list_tools()
if hasattr(tools_response, "tools"):
tools = tools_response.tools
elif isinstance(tools_response, list):
tools = tools_response
else:
tools = []
tools_dict = []
for tool in tools:
if hasattr(tool, "name"):
tool_dict = {
"name": tool.name,
"description": tool.description,
}
if hasattr(tool, "inputSchema"):
tool_dict["inputSchema"] = tool.inputSchema
tools_dict.append(tool_dict)
elif isinstance(tool, dict):
tools_dict.append(tool)
else:
if hasattr(tool, "model_dump"):
tools_dict.append(tool.model_dump())
else:
tools_dict.append({"name": str(tool), "description": ""})
return tools_dict
self.available_tools = self._format_tools(tools_response)
return self.available_tools
elif operation == "call_tool":
tool_name = args[0]
tool_args = kwargs
@@ -251,12 +234,8 @@ class MCPTool(Tool):
def _run_async_operation(self, operation: str, *args, **kwargs):
"""Run async operation in sync context."""
try:
# Check if there's already a running event loop
try:
loop = asyncio.get_running_loop()
# If we're in an async context, we need to run in a separate thread
import concurrent.futures
def run_in_thread():
@@ -273,8 +252,6 @@ class MCPTool(Tool):
future = executor.submit(run_in_thread)
return future.result(timeout=self.timeout)
except RuntimeError:
# No running loop, we can create one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
@@ -284,13 +261,8 @@ class MCPTool(Tool):
finally:
loop.close()
except Exception as e:
# If async fails, try to give a better error message for OAuth
if self.auth_type == "oauth" and "callback not configured" in str(e):
raise Exception(
"OAuth callbacks not configured. Call set_oauth_callbacks() first."
)
raise e
print(f"Error occurred while running async operation: {e}")
raise
def discover_tools(self) -> List[Dict]:
"""
@@ -377,8 +349,6 @@ class MCPTool(Tool):
if not self._client:
self._setup_client()
try:
# For OAuth, we need to handle async operations differently
if self.auth_type == "oauth":
return self._test_oauth_connection()
else:
@@ -412,40 +382,21 @@ class MCPTool(Tool):
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"ping_supported": ping_success,
"tools": [tool.get("name", "unknown") for tool in tools[:5]],
"tools": [tool.get("name", "unknown") for tool in tools],
}
def _test_oauth_connection(self) -> Dict:
"""Test connection for OAuth auth type with proper async handling."""
try:
# Ensure callbacks are configured before proceeding
if not self.oauth_auth_url_callback or not self.oauth_auth_code_callback:
return {
"success": False,
"message": "OAuth callbacks not configured. Call set_oauth_callbacks() first.",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": "ConfigurationError",
}
# Ensure client is set up with proper callbacks
if not self._client:
self._setup_client()
# For OAuth, we use a simpler approach - just try to discover tools
# This will trigger the OAuth flow if needed
tools = self.discover_tools()
task = mcp_oauth_task.delay(config=self.config, user=self.user_id)
if not task:
raise Exception("Failed to start OAuth authentication")
return {
"success": True,
"message": f"Successfully connected to OAuth MCP server. Found {len(tools)} tools.",
"tools_count": len(tools),
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"ping_supported": False, # Skip ping for OAuth to avoid complexity
"tools": [tool.get("name", "unknown") for tool in tools[:5]],
"requires_oauth": True,
"task_id": task.id,
"status": "pending",
"message": "OAuth flow started",
}
except Exception as e:
return {
@@ -457,67 +408,6 @@ class MCPTool(Tool):
"error_type": type(e).__name__,
}
def set_oauth_callbacks(self, auth_url_callback, auth_code_callback):
"""
Set OAuth callback handlers for frontend integration.
Args:
auth_url_callback: Function to call with authorization URL
auth_code_callback: Async function that returns (auth_code, state) tuple
"""
self.oauth_auth_url_callback = auth_url_callback
self.oauth_auth_code_callback = auth_code_callback
# Clear the client so it gets recreated with the new callbacks
self._client = None
# Also clear from cache to ensure fresh creation
global _mcp_clients_cache
if self._cache_key in _mcp_clients_cache:
del _mcp_clients_cache[self._cache_key]
def clear_oauth_cache(self):
"""
Clear OAuth cache to force fresh authentication.
This will remove stored tokens and client info for the server.
"""
if self.auth_type == "oauth":
try:
from fastmcp.client.auth.oauth import FileTokenStorage
storage = FileTokenStorage(server_url=self.server_url)
storage.clear()
print(f"✅ Cleared OAuth cache for {self.server_url}")
except Exception as e:
print(f"⚠️ Failed to clear OAuth cache: {e}")
# Also clear our internal client cache
global _mcp_clients_cache
if self._cache_key in _mcp_clients_cache:
del _mcp_clients_cache[self._cache_key]
print(f"✅ Cleared internal client cache")
@staticmethod
def clear_all_oauth_cache():
"""
Clear all OAuth cache for all servers.
This will remove all stored tokens and client info.
"""
try:
from fastmcp.client.auth.oauth import FileTokenStorage
FileTokenStorage.clear_all()
print(f"✅ Cleared all OAuth cache")
except Exception as e:
print(f"⚠️ Failed to clear all OAuth cache: {e}")
# Also clear all internal client cache
global _mcp_clients_cache
_mcp_clients_cache.clear()
print(f"✅ Cleared all internal client cache")
def get_actions_metadata(self) -> List[Dict]:
"""
Get metadata for all available actions.
@@ -666,3 +556,306 @@ class MCPTool(Tool):
"required": False,
},
}
class DocsGPTOAuth(OAuthClientProvider):
"""
Custom OAuth handler for DocsGPT that uses frontend redirect instead of browser.
"""
def __init__(
self,
mcp_url: str,
redirect_uri: str,
redis_client: Redis | None = None,
redis_prefix: str = "mcp_oauth:",
task_id: str = None,
scopes: str | list[str] | None = None,
client_name: str = "DocsGPT-MCP",
user_id=None,
db=None,
additional_client_metadata: dict[str, Any] | None = None,
):
"""
Initialize custom OAuth client provider for DocsGPT.
Args:
mcp_url: Full URL to the MCP endpoint
redirect_uri: Custom redirect URI for DocsGPT frontend
redis_client: Redis client for storing auth state
redis_prefix: Prefix for Redis keys
task_id: Task ID for tracking auth status
scopes: OAuth scopes to request
client_name: Name for this client during registration
user_id: User ID for token storage
db: Database instance for token storage
additional_client_metadata: Extra fields for OAuthClientMetadata
"""
self.redirect_uri = redirect_uri
self.redis_client = redis_client
self.redis_prefix = redis_prefix
self.task_id = task_id
self.user_id = user_id
self.db = db
parsed_url = urlparse(mcp_url)
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
if isinstance(scopes, list):
scopes = " ".join(scopes)
client_metadata = OAuthClientMetadata(
client_name=client_name,
redirect_uris=[AnyHttpUrl(redirect_uri)],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
scope=scopes,
**(additional_client_metadata or {}),
)
storage = DBTokenStorage(
server_url=self.server_base_url, user_id=self.user_id, db_client=self.db
)
super().__init__(
server_url=self.server_base_url,
client_metadata=client_metadata,
storage=storage,
redirect_handler=self.redirect_handler,
callback_handler=self.callback_handler,
)
self.auth_url = None
self.extracted_state = None
def _process_auth_url(self, authorization_url: str) -> tuple[str, str]:
"""Process authorization URL to extract state"""
try:
parsed_url = urlparse(authorization_url)
query_params = parse_qs(parsed_url.query)
state_params = query_params.get("state", [])
if state_params:
state = state_params[0]
else:
raise ValueError("No state in auth URL")
return authorization_url, state
except Exception as e:
raise Exception(f"Failed to process auth URL: {e}")
async def redirect_handler(self, authorization_url: str) -> None:
"""Store auth URL and state in Redis for frontend to use."""
auth_url, state = self._process_auth_url(authorization_url)
logging.info(
"[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state
)
self.auth_url = auth_url
self.extracted_state = state
if self.redis_client and self.extracted_state:
key = f"{self.redis_prefix}auth_url:{self.extracted_state}"
self.redis_client.setex(key, 600, auth_url)
logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key)
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "requires_redirect",
"message": "OAuth authorization required",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
async def callback_handler(self) -> tuple[str, str | None]:
"""Wait for auth code from Redis using the state value."""
if not self.redis_client or not self.extracted_state:
raise Exception("Redis client or state not configured for OAuth")
poll_interval = 1
max_wait_time = 300
code_key = f"{self.redis_prefix}code:{self.extracted_state}"
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "awaiting_callback",
"message": "Waiting for OAuth callback...",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
start_time = time.time()
while time.time() - start_time < max_wait_time:
code_data = self.redis_client.get(code_key)
if code_data:
code = code_data.decode()
returned_state = self.extracted_state
self.redis_client.delete(code_key)
self.redis_client.delete(
f"{self.redis_prefix}auth_url:{self.extracted_state}"
)
self.redis_client.delete(
f"{self.redis_prefix}state:{self.extracted_state}"
)
if self.task_id:
status_data = {
"status": "callback_received",
"message": "OAuth callback received, completing authentication...",
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
return code, returned_state
error_key = f"{self.redis_prefix}error:{self.extracted_state}"
error_data = self.redis_client.get(error_key)
if error_data:
error_msg = error_data.decode()
self.redis_client.delete(error_key)
self.redis_client.delete(
f"{self.redis_prefix}auth_url:{self.extracted_state}"
)
self.redis_client.delete(
f"{self.redis_prefix}state:{self.extracted_state}"
)
raise Exception(f"OAuth error: {error_msg}")
await asyncio.sleep(poll_interval)
self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}")
self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}")
raise Exception("OAuth callback timeout: no code received within 5 minutes")
class DBTokenStorage(TokenStorage):
def __init__(self, server_url: str, user_id: str, db_client):
self.server_url = server_url
self.user_id = user_id
self.db_client = db_client
self.collection = db_client["connector_sessions"]
@staticmethod
def get_base_url(url: str) -> str:
parsed = urlparse(url)
return f"{parsed.scheme}://{parsed.netloc}"
def get_db_key(self) -> dict:
return {
"server_url": self.get_base_url(self.server_url),
"user_id": self.user_id,
}
async def get_tokens(self) -> OAuthToken | None:
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
if not doc or "tokens" not in doc:
return None
try:
tokens = OAuthToken.model_validate(doc["tokens"])
return tokens
except ValidationError as e:
logging.error(f"Could not load tokens: {e}")
return None
async def set_tokens(self, tokens: OAuthToken) -> None:
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$set": {"tokens": tokens.model_dump()}},
True,
)
logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}")
async def get_client_info(self) -> OAuthClientInformationFull | None:
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
if not doc or "client_info" not in doc:
return None
try:
client_info = OAuthClientInformationFull.model_validate(doc["client_info"])
tokens = await self.get_tokens()
if tokens is None:
logging.debug(
"No tokens found, clearing client info to force fresh registration."
)
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$unset": {"client_info": ""}},
)
return None
return client_info
except ValidationError as e:
logging.error(f"Could not load client info: {e}")
return None
def _serialize_client_info(self, info: dict) -> dict:
if "redirect_uris" in info and isinstance(info["redirect_uris"], list):
info["redirect_uris"] = [str(u) for u in info["redirect_uris"]]
return info
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
serialized_info = self._serialize_client_info(client_info.model_dump())
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$set": {"client_info": serialized_info}},
True,
)
logging.info(f"Saved client info for {self.get_base_url(self.server_url)}")
async def clear(self) -> None:
await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")
@classmethod
async def clear_all(cls, db_client) -> None:
collection = db_client["connector_sessions"]
await asyncio.to_thread(collection.delete_many, {})
logging.info("Cleared all OAuth client cache data.")
class MCPOAuthManager:
"""Manager for handling MCP OAuth callbacks."""
def __init__(self, redis_client: Redis | None, redis_prefix: str = "mcp_oauth:"):
self.redis_client = redis_client
self.redis_prefix = redis_prefix
def handle_oauth_callback(
self, state: str, code: str, error: Optional[str] = None
) -> bool:
"""
Handle OAuth callback from provider.
Args:
state: The state parameter from OAuth callback
code: The authorization code from OAuth callback
error: Error message if OAuth failed
Returns:
True if successful, False otherwise
"""
try:
if not self.redis_client or not state:
raise Exception("Redis client or state not provided")
if error:
error_key = f"{self.redis_prefix}error:{state}"
self.redis_client.setex(error_key, 300, error)
raise Exception(f"OAuth error received: {error}")
code_key = f"{self.redis_prefix}code:{state}"
self.redis_client.setex(code_key, 300, code)
state_key = f"{self.redis_prefix}state:{state}"
self.redis_client.setex(state_key, 300, "completed")
return True
except Exception as e:
logging.error(f"Error handling OAuth callback: {e}")
return False
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
"""Get current status of OAuth flow using provided task_id."""
if not task_id:
return {"status": "not_started", "message": "OAuth flow not started"}
return mcp_oauth_status_task(task_id)