diff --git a/application/agents/base.py b/application/agents/base.py index 77729fe6..134de1c3 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -264,7 +264,15 @@ class BaseAgent(ABC): query: str, retrieved_data: List[Dict], ) -> List[Dict]: - docs_together = "\n".join([doc["text"] for doc in retrieved_data]) + docs_with_filenames = [] + for doc in retrieved_data: + filename = doc.get("filename") or doc.get("title") or doc.get("source") + if filename: + chunk_header = str(filename) + docs_with_filenames.append(f"{chunk_header}\n{doc['text']}") + else: + docs_with_filenames.append(doc["text"]) + docs_together = "\n\n".join(docs_with_filenames) p_chat_combine = system_prompt.replace("{summaries}", docs_together) messages_combine = [{"role": "system", "content": p_chat_combine}] diff --git a/application/agents/tools/mcp_tool.py b/application/agents/tools/mcp_tool.py index dc689367..e539986c 100644 --- a/application/agents/tools/mcp_tool.py +++ b/application/agents/tools/mcp_tool.py @@ -1,14 +1,37 @@ +import asyncio +import base64 import json +import logging import time from typing import Any, Dict, List, Optional - -import requests +from urllib.parse import parse_qs, urlparse from application.agents.tools.base import Tool +from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task +from application.cache import get_redis_instance + +from application.core.mongo_db import MongoDB + +from application.core.settings import settings + from application.security.encryption import decrypt_credentials +from fastmcp import Client +from fastmcp.client.auth import BearerAuth +from fastmcp.client.transports import ( + SSETransport, + StdioTransport, + StreamableHttpTransport, +) +from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from pydantic import AnyHttpUrl, ValidationError +from redis import Redis -_mcp_session_cache = {} +mongo = MongoDB.get_client() +db = mongo[settings.MONGO_DB_NAME] + +_mcp_clients_cache = {} class MCPTool(Tool): @@ -24,15 +47,24 @@ class MCPTool(Tool): Args: config: Dictionary containing MCP server configuration: - server_url: URL of the remote MCP server - - auth_type: Type of authentication (api_key, bearer, basic, none) + - transport_type: Transport type (auto, sse, http, stdio) + - auth_type: Type of authentication (bearer, oauth, api_key, basic, none) - encrypted_credentials: Encrypted credentials (if available) - timeout: Request timeout in seconds (default: 30) + - headers: Custom headers for requests + - command: Command for STDIO transport + - args: Arguments for STDIO transport + - oauth_scopes: OAuth scopes for oauth auth type + - oauth_client_name: OAuth client name for oauth auth type user_id: User ID for decrypting credentials (required if encrypted_credentials exist) """ self.config = config + self.user_id = user_id self.server_url = config.get("server_url", "") + self.transport_type = config.get("transport_type", "auto") self.auth_type = config.get("auth_type", "none") self.timeout = config.get("timeout", 30) + self.custom_headers = config.get("headers", {}) self.auth_credentials = {} if config.get("encrypted_credentials") and user_id: @@ -41,34 +73,30 @@ class MCPTool(Tool): ) else: self.auth_credentials = config.get("auth_credentials", {}) - self.available_tools = [] - self._session = requests.Session() - self._mcp_session_id = None - self._setup_authentication() - self._cache_key = self._generate_cache_key() + self.oauth_scopes = config.get("oauth_scopes", []) + self.oauth_task_id = config.get("oauth_task_id", None) + self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP") + self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback" - def _setup_authentication(self): - """Setup authentication for the MCP server connection.""" - if self.auth_type == "api_key": - api_key = self.auth_credentials.get("api_key", "") - header_name = self.auth_credentials.get("api_key_header", "X-API-Key") - if api_key: - self._session.headers.update({header_name: api_key}) - elif self.auth_type == "bearer": - token = self.auth_credentials.get("bearer_token", "") - if token: - self._session.headers.update({"Authorization": f"Bearer {token}"}) - elif self.auth_type == "basic": - username = self.auth_credentials.get("username", "") - password = self.auth_credentials.get("password", "") - if username and password: - self._session.auth = (username, password) + self.available_tools = [] + self._cache_key = self._generate_cache_key() + self._client = None + + # Only validate and setup if server_url is provided and not OAuth + + if self.server_url and self.auth_type != "oauth": + self._setup_client() def _generate_cache_key(self) -> str: """Generate a unique cache key for this MCP server configuration.""" auth_key = "" - if self.auth_type == "bearer": - token = self.auth_credentials.get("bearer_token", "") + if self.auth_type == "oauth": + scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none" + auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}" + elif self.auth_type in ["bearer"]: + token = self.auth_credentials.get( + "bearer_token", "" + ) or self.auth_credentials.get("access_token", "") auth_key = f"bearer:{token[:10]}..." if token else "bearer:none" elif self.auth_type == "api_key": api_key = self.auth_credentials.get("api_key", "") @@ -78,201 +106,185 @@ class MCPTool(Tool): auth_key = f"basic:{username}" else: auth_key = "none" - return f"{self.server_url}#{auth_key}" + return f"{self.server_url}#{self.transport_type}#{auth_key}" - def _get_cached_session(self) -> Optional[str]: - """Get cached session ID if available and not expired.""" - global _mcp_session_cache - - if self._cache_key in _mcp_session_cache: - session_data = _mcp_session_cache[self._cache_key] - if time.time() - session_data["created_at"] < 1800: - return session_data["session_id"] + def _setup_client(self): + """Setup FastMCP client with proper transport and authentication.""" + global _mcp_clients_cache + if self._cache_key in _mcp_clients_cache: + cached_data = _mcp_clients_cache[self._cache_key] + if time.time() - cached_data["created_at"] < 1800: + self._client = cached_data["client"] + return else: - del _mcp_session_cache[self._cache_key] - return None + del _mcp_clients_cache[self._cache_key] + transport = self._create_transport() + auth = None - def _cache_session(self, session_id: str): - """Cache the session ID for reuse.""" - global _mcp_session_cache - _mcp_session_cache[self._cache_key] = { - "session_id": session_id, + if self.auth_type == "oauth": + redis_client = get_redis_instance() + auth = DocsGPTOAuth( + mcp_url=self.server_url, + scopes=self.oauth_scopes, + redis_client=redis_client, + redirect_uri=self.redirect_uri, + task_id=self.oauth_task_id, + db=db, + user_id=self.user_id, + ) + elif self.auth_type == "bearer": + token = self.auth_credentials.get( + "bearer_token", "" + ) or self.auth_credentials.get("access_token", "") + if token: + auth = BearerAuth(token) + self._client = Client(transport, auth=auth) + _mcp_clients_cache[self._cache_key] = { + "client": self._client, "created_at": time.time(), } - def _initialize_mcp_connection(self) -> Dict: - """ - Initialize MCP connection with the server, using cached session if available. + def _create_transport(self): + """Create appropriate transport based on configuration.""" + headers = {"Content-Type": "application/json", "User-Agent": "DocsGPT-MCP/1.0"} + headers.update(self.custom_headers) - Returns: - Server capabilities and information - """ - cached_session = self._get_cached_session() - if cached_session: - self._mcp_session_id = cached_session - return {"cached": True} - try: - init_params = { - "protocolVersion": "2024-11-05", - "capabilities": {"roots": {"listChanged": True}, "sampling": {}}, - "clientInfo": {"name": "DocsGPT", "version": "1.0.0"}, - } - response = self._make_mcp_request("initialize", init_params) - self._make_mcp_request("notifications/initialized") - - return response - except Exception as e: - return {"error": str(e), "fallback": True} - - def _ensure_valid_session(self): - """Ensure we have a valid MCP session, reinitializing if needed.""" - if not self._mcp_session_id: - self._initialize_mcp_connection() - - def _make_mcp_request(self, method: str, params: Optional[Dict] = None) -> Dict: - """ - Make an MCP protocol request to the server with automatic session recovery. - - Args: - method: MCP method name (e.g., "tools/list", "tools/call") - params: Parameters for the MCP method - - Returns: - Response data as dictionary - - Raises: - Exception: If request fails after retry - """ - mcp_message = {"jsonrpc": "2.0", "method": method} - - if not method.startswith("notifications/"): - mcp_message["id"] = 1 - if params: - mcp_message["params"] = params - return self._execute_mcp_request(mcp_message, method) - - def _execute_mcp_request( - self, mcp_message: Dict, method: str, is_retry: bool = False - ) -> Dict: - """Execute MCP request with optional retry on session failure.""" - try: - final_headers = self._session.headers.copy() - final_headers.update( - { - "Content-Type": "application/json", - "Accept": "application/json, text/event-stream", - } - ) - - if self._mcp_session_id: - final_headers["Mcp-Session-Id"] = self._mcp_session_id - response = self._session.post( - self.server_url.rstrip("/"), - json=mcp_message, - headers=final_headers, - timeout=self.timeout, - ) - - if "mcp-session-id" in response.headers: - self._mcp_session_id = response.headers["mcp-session-id"] - self._cache_session(self._mcp_session_id) - response.raise_for_status() - - if method.startswith("notifications/"): - return {} - response_text = response.text.strip() - if response_text.startswith("event:") and "data:" in response_text: - lines = response_text.split("\n") - data_line = None - for line in lines: - if line.startswith("data:"): - data_line = line[5:].strip() - break - if data_line: - try: - result = json.loads(data_line) - except json.JSONDecodeError: - raise Exception(f"Invalid JSON in SSE data: {data_line}") - else: - raise Exception(f"No data found in SSE response: {response_text}") + if self.auth_type == "api_key": + api_key = self.auth_credentials.get("api_key", "") + header_name = self.auth_credentials.get("api_key_header", "X-API-Key") + if api_key: + headers[header_name] = api_key + elif self.auth_type == "basic": + username = self.auth_credentials.get("username", "") + password = self.auth_credentials.get("password", "") + if username and password: + credentials = base64.b64encode( + f"{username}:{password}".encode() + ).decode() + headers["Authorization"] = f"Basic {credentials}" + if self.transport_type == "auto": + if "sse" in self.server_url.lower() or self.server_url.endswith("/sse"): + transport_type = "sse" else: + transport_type = "http" + else: + transport_type = self.transport_type + if transport_type == "sse": + headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"}) + return SSETransport(url=self.server_url, headers=headers) + elif transport_type == "http": + return StreamableHttpTransport(url=self.server_url, headers=headers) + elif transport_type == "stdio": + command = self.config.get("command", "python") + args = self.config.get("args", []) + env = self.auth_credentials if self.auth_credentials else None + return StdioTransport(command=command, args=args, env=env) + else: + return StreamableHttpTransport(url=self.server_url, headers=headers) + + def _format_tools(self, tools_response) -> List[Dict]: + """Format tools response to match expected format.""" + if hasattr(tools_response, "tools"): + tools = tools_response.tools + elif isinstance(tools_response, list): + tools = tools_response + else: + tools = [] + tools_dict = [] + for tool in tools: + if hasattr(tool, "name"): + tool_dict = { + "name": tool.name, + "description": tool.description, + } + if hasattr(tool, "inputSchema"): + tool_dict["inputSchema"] = tool.inputSchema + tools_dict.append(tool_dict) + elif isinstance(tool, dict): + tools_dict.append(tool) + else: + if hasattr(tool, "model_dump"): + tools_dict.append(tool.model_dump()) + else: + tools_dict.append({"name": str(tool), "description": ""}) + return tools_dict + + async def _execute_with_client(self, operation: str, *args, **kwargs): + """Execute operation with FastMCP client.""" + if not self._client: + raise Exception("FastMCP client not initialized") + async with self._client: + if operation == "ping": + return await self._client.ping() + elif operation == "list_tools": + tools_response = await self._client.list_tools() + self.available_tools = self._format_tools(tools_response) + return self.available_tools + elif operation == "call_tool": + tool_name = args[0] + tool_args = kwargs + return await self._client.call_tool(tool_name, tool_args) + elif operation == "list_resources": + return await self._client.list_resources() + elif operation == "list_prompts": + return await self._client.list_prompts() + else: + raise Exception(f"Unknown operation: {operation}") + + def _run_async_operation(self, operation: str, *args, **kwargs): + """Run async operation in sync context.""" + try: + try: + loop = asyncio.get_running_loop() + import concurrent.futures + + def run_in_thread(): + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + try: + return new_loop.run_until_complete( + self._execute_with_client(operation, *args, **kwargs) + ) + finally: + new_loop.close() + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_in_thread) + return future.result(timeout=self.timeout) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) try: - result = response.json() - except json.JSONDecodeError: - raise Exception(f"Invalid JSON response: {response.text}") - if "error" in result: - error_msg = result["error"] - if isinstance(error_msg, dict): - error_msg = error_msg.get("message", str(error_msg)) - raise Exception(f"MCP server error: {error_msg}") - return result.get("result", result) - except requests.exceptions.RequestException as e: - if not is_retry and self._should_retry_with_new_session(e): - self._invalidate_and_refresh_session() - return self._execute_mcp_request(mcp_message, method, is_retry=True) - raise Exception(f"MCP server request failed: {str(e)}") - - def _should_retry_with_new_session(self, error: Exception) -> bool: - """Check if error indicates session invalidation and retry is warranted.""" - error_str = str(error).lower() - return ( - any( - indicator in error_str - for indicator in [ - "invalid session", - "session expired", - "unauthorized", - "401", - "403", - ] - ) - and self._mcp_session_id is not None - ) - - def _invalidate_and_refresh_session(self) -> None: - """Invalidate current session and create a new one.""" - global _mcp_session_cache - if self._cache_key in _mcp_session_cache: - del _mcp_session_cache[self._cache_key] - self._mcp_session_id = None - self._initialize_mcp_connection() + return loop.run_until_complete( + self._execute_with_client(operation, *args, **kwargs) + ) + finally: + loop.close() + except Exception as e: + print(f"Error occurred while running async operation: {e}") + raise def discover_tools(self) -> List[Dict]: """ - Discover available tools from the MCP server using MCP protocol. + Discover available tools from the MCP server using FastMCP. Returns: List of tool definitions from the server """ + if not self.server_url: + return [] + if not self._client: + self._setup_client() try: - self._ensure_valid_session() - - response = self._make_mcp_request("tools/list") - - # Handle both formats: response with 'tools' key or response that IS the tools list - - if isinstance(response, dict): - if "tools" in response: - self.available_tools = response["tools"] - elif ( - "result" in response - and isinstance(response["result"], dict) - and "tools" in response["result"] - ): - self.available_tools = response["result"]["tools"] - else: - self.available_tools = [response] if response else [] - elif isinstance(response, list): - self.available_tools = response - else: - self.available_tools = [] + tools = self._run_async_operation("list_tools") + self.available_tools = tools return self.available_tools except Exception as e: raise Exception(f"Failed to discover tools from MCP server: {str(e)}") def execute_action(self, action_name: str, **kwargs) -> Any: """ - Execute an action on the remote MCP server using MCP protocol. + Execute an action on the remote MCP server using FastMCP. Args: action_name: Name of the action to execute @@ -281,22 +293,121 @@ class MCPTool(Tool): Returns: Result from the MCP server """ - self._ensure_valid_session() - - # Skipping empty/None values - letting the server use defaults - + if not self.server_url: + raise Exception("No MCP server configured") + if not self._client: + self._setup_client() cleaned_kwargs = {} for key, value in kwargs.items(): if value == "" or value is None: continue cleaned_kwargs[key] = value - call_params = {"name": action_name, "arguments": cleaned_kwargs} try: - result = self._make_mcp_request("tools/call", call_params) - return result + result = self._run_async_operation( + "call_tool", action_name, **cleaned_kwargs + ) + return self._format_result(result) except Exception as e: raise Exception(f"Failed to execute action '{action_name}': {str(e)}") + def _format_result(self, result) -> Dict: + """Format FastMCP result to match expected format.""" + if hasattr(result, "content"): + content_list = [] + for content_item in result.content: + if hasattr(content_item, "text"): + content_list.append({"type": "text", "text": content_item.text}) + elif hasattr(content_item, "data"): + content_list.append({"type": "data", "data": content_item.data}) + else: + content_list.append( + {"type": "unknown", "content": str(content_item)} + ) + return { + "content": content_list, + "isError": getattr(result, "isError", False), + } + else: + return result + + def test_connection(self) -> Dict: + """ + Test the connection to the MCP server and validate functionality. + + Returns: + Dictionary with connection test results including tool count + """ + if not self.server_url: + return { + "success": False, + "message": "No MCP server URL configured", + "tools_count": 0, + "transport_type": self.transport_type, + "auth_type": self.auth_type, + "error_type": "ConfigurationError", + } + if not self._client: + self._setup_client() + try: + if self.auth_type == "oauth": + return self._test_oauth_connection() + else: + return self._test_regular_connection() + except Exception as e: + return { + "success": False, + "message": f"Connection failed: {str(e)}", + "tools_count": 0, + "transport_type": self.transport_type, + "auth_type": self.auth_type, + "error_type": type(e).__name__, + } + + def _test_regular_connection(self) -> Dict: + """Test connection for non-OAuth auth types.""" + try: + self._run_async_operation("ping") + ping_success = True + except Exception: + ping_success = False + tools = self.discover_tools() + + message = f"Successfully connected to MCP server. Found {len(tools)} tools." + if not ping_success: + message += " (Ping not supported, but tool discovery worked)" + return { + "success": True, + "message": message, + "tools_count": len(tools), + "transport_type": self.transport_type, + "auth_type": self.auth_type, + "ping_supported": ping_success, + "tools": [tool.get("name", "unknown") for tool in tools], + } + + def _test_oauth_connection(self) -> Dict: + """Test connection for OAuth auth type with proper async handling.""" + try: + task = mcp_oauth_task.delay(config=self.config, user=self.user_id) + if not task: + raise Exception("Failed to start OAuth authentication") + return { + "success": True, + "requires_oauth": True, + "task_id": task.id, + "status": "pending", + "message": "OAuth flow started", + } + except Exception as e: + return { + "success": False, + "message": f"OAuth connection failed: {str(e)}", + "tools_count": 0, + "transport_type": self.transport_type, + "auth_type": self.auth_type, + "error_type": type(e).__name__, + } + def get_actions_metadata(self) -> List[Dict]: """ Get metadata for all available actions. @@ -341,58 +452,89 @@ class MCPTool(Tool): actions.append(action) return actions - def test_connection(self) -> Dict: - """ - Test the connection to the MCP server and validate functionality. - - Returns: - Dictionary with connection test results including tool count - """ - try: - self._mcp_session_id = None - - init_result = self._initialize_mcp_connection() - - tools = self.discover_tools() - - message = f"Successfully connected to MCP server. Found {len(tools)} tools." - if init_result.get("cached"): - message += " (Using cached session)" - elif init_result.get("fallback"): - message += " (No formal initialization required)" - return { - "success": True, - "message": message, - "tools_count": len(tools), - "session_id": self._mcp_session_id, - "tools": [tool.get("name", "unknown") for tool in tools[:5]], - } - except Exception as e: - return { - "success": False, - "message": f"Connection failed: {str(e)}", - "tools_count": 0, - "error_type": type(e).__name__, - } - def get_config_requirements(self) -> Dict: + """Get configuration requirements for the MCP tool.""" return { "server_url": { "type": "string", - "description": "URL of the remote MCP server (e.g., https://api.example.com)", + "description": "URL of the remote MCP server (e.g., https://api.example.com/mcp or https://docs.mcp.cloudflare.com/sse)", "required": True, }, + "transport_type": { + "type": "string", + "description": "Transport type for connection", + "enum": ["auto", "sse", "http", "stdio"], + "default": "auto", + "required": False, + "help": { + "auto": "Automatically detect best transport", + "sse": "Server-Sent Events (for real-time streaming)", + "http": "HTTP streaming (recommended for production)", + "stdio": "Standard I/O (for local servers)", + }, + }, "auth_type": { "type": "string", "description": "Authentication type", - "enum": ["none", "api_key", "bearer", "basic"], + "enum": ["none", "bearer", "oauth", "api_key", "basic"], "default": "none", "required": True, + "help": { + "none": "No authentication", + "bearer": "Bearer token authentication", + "oauth": "OAuth 2.1 authentication (with frontend integration)", + "api_key": "API key authentication", + "basic": "Basic authentication", + }, }, "auth_credentials": { "type": "object", "description": "Authentication credentials (varies by auth_type)", "required": False, + "properties": { + "bearer_token": { + "type": "string", + "description": "Bearer token for bearer auth", + }, + "access_token": { + "type": "string", + "description": "Access token for OAuth (if pre-obtained)", + }, + "api_key": { + "type": "string", + "description": "API key for api_key auth", + }, + "api_key_header": { + "type": "string", + "description": "Header name for API key (default: X-API-Key)", + }, + "username": { + "type": "string", + "description": "Username for basic auth", + }, + "password": { + "type": "string", + "description": "Password for basic auth", + }, + }, + }, + "oauth_scopes": { + "type": "array", + "description": "OAuth scopes to request (for oauth auth_type)", + "items": {"type": "string"}, + "required": False, + "default": [], + }, + "oauth_client_name": { + "type": "string", + "description": "Client name for OAuth registration (for oauth auth_type)", + "default": "DocsGPT-MCP", + "required": False, + }, + "headers": { + "type": "object", + "description": "Custom headers to send with requests", + "required": False, }, "timeout": { "type": "integer", @@ -402,4 +544,318 @@ class MCPTool(Tool): "maximum": 300, "required": False, }, + "command": { + "type": "string", + "description": "Command to run for STDIO transport (e.g., 'python')", + "required": False, + }, + "args": { + "type": "array", + "description": "Arguments for STDIO command", + "items": {"type": "string"}, + "required": False, + }, } + + +class DocsGPTOAuth(OAuthClientProvider): + """ + Custom OAuth handler for DocsGPT that uses frontend redirect instead of browser. + """ + + def __init__( + self, + mcp_url: str, + redirect_uri: str, + redis_client: Redis | None = None, + redis_prefix: str = "mcp_oauth:", + task_id: str = None, + scopes: str | list[str] | None = None, + client_name: str = "DocsGPT-MCP", + user_id=None, + db=None, + additional_client_metadata: dict[str, Any] | None = None, + ): + """ + Initialize custom OAuth client provider for DocsGPT. + + Args: + mcp_url: Full URL to the MCP endpoint + redirect_uri: Custom redirect URI for DocsGPT frontend + redis_client: Redis client for storing auth state + redis_prefix: Prefix for Redis keys + task_id: Task ID for tracking auth status + scopes: OAuth scopes to request + client_name: Name for this client during registration + user_id: User ID for token storage + db: Database instance for token storage + additional_client_metadata: Extra fields for OAuthClientMetadata + """ + + self.redirect_uri = redirect_uri + self.redis_client = redis_client + self.redis_prefix = redis_prefix + self.task_id = task_id + self.user_id = user_id + self.db = db + + parsed_url = urlparse(mcp_url) + self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + if isinstance(scopes, list): + scopes = " ".join(scopes) + client_metadata = OAuthClientMetadata( + client_name=client_name, + redirect_uris=[AnyHttpUrl(redirect_uri)], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + scope=scopes, + **(additional_client_metadata or {}), + ) + + storage = DBTokenStorage( + server_url=self.server_base_url, user_id=self.user_id, db_client=self.db + ) + + super().__init__( + server_url=self.server_base_url, + client_metadata=client_metadata, + storage=storage, + redirect_handler=self.redirect_handler, + callback_handler=self.callback_handler, + ) + + self.auth_url = None + self.extracted_state = None + + def _process_auth_url(self, authorization_url: str) -> tuple[str, str]: + """Process authorization URL to extract state""" + try: + parsed_url = urlparse(authorization_url) + query_params = parse_qs(parsed_url.query) + + state_params = query_params.get("state", []) + if state_params: + state = state_params[0] + else: + raise ValueError("No state in auth URL") + return authorization_url, state + except Exception as e: + raise Exception(f"Failed to process auth URL: {e}") + + async def redirect_handler(self, authorization_url: str) -> None: + """Store auth URL and state in Redis for frontend to use.""" + auth_url, state = self._process_auth_url(authorization_url) + logging.info( + "[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state + ) + self.auth_url = auth_url + self.extracted_state = state + + if self.redis_client and self.extracted_state: + key = f"{self.redis_prefix}auth_url:{self.extracted_state}" + self.redis_client.setex(key, 600, auth_url) + logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key) + + if self.task_id: + status_key = f"mcp_oauth_status:{self.task_id}" + status_data = { + "status": "requires_redirect", + "message": "OAuth authorization required", + "authorization_url": self.auth_url, + "state": self.extracted_state, + "requires_oauth": True, + "task_id": self.task_id, + } + self.redis_client.setex(status_key, 600, json.dumps(status_data)) + + async def callback_handler(self) -> tuple[str, str | None]: + """Wait for auth code from Redis using the state value.""" + if not self.redis_client or not self.extracted_state: + raise Exception("Redis client or state not configured for OAuth") + poll_interval = 1 + max_wait_time = 300 + code_key = f"{self.redis_prefix}code:{self.extracted_state}" + + if self.task_id: + status_key = f"mcp_oauth_status:{self.task_id}" + status_data = { + "status": "awaiting_callback", + "message": "Waiting for OAuth callback...", + "authorization_url": self.auth_url, + "state": self.extracted_state, + "requires_oauth": True, + "task_id": self.task_id, + } + self.redis_client.setex(status_key, 600, json.dumps(status_data)) + start_time = time.time() + while time.time() - start_time < max_wait_time: + code_data = self.redis_client.get(code_key) + if code_data: + code = code_data.decode() + returned_state = self.extracted_state + + self.redis_client.delete(code_key) + self.redis_client.delete( + f"{self.redis_prefix}auth_url:{self.extracted_state}" + ) + self.redis_client.delete( + f"{self.redis_prefix}state:{self.extracted_state}" + ) + + if self.task_id: + status_data = { + "status": "callback_received", + "message": "OAuth callback received, completing authentication...", + "task_id": self.task_id, + } + self.redis_client.setex(status_key, 600, json.dumps(status_data)) + return code, returned_state + error_key = f"{self.redis_prefix}error:{self.extracted_state}" + error_data = self.redis_client.get(error_key) + if error_data: + error_msg = error_data.decode() + self.redis_client.delete(error_key) + self.redis_client.delete( + f"{self.redis_prefix}auth_url:{self.extracted_state}" + ) + self.redis_client.delete( + f"{self.redis_prefix}state:{self.extracted_state}" + ) + raise Exception(f"OAuth error: {error_msg}") + await asyncio.sleep(poll_interval) + self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}") + self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}") + raise Exception("OAuth callback timeout: no code received within 5 minutes") + + +class DBTokenStorage(TokenStorage): + def __init__(self, server_url: str, user_id: str, db_client): + self.server_url = server_url + self.user_id = user_id + self.db_client = db_client + self.collection = db_client["connector_sessions"] + + @staticmethod + def get_base_url(url: str) -> str: + parsed = urlparse(url) + return f"{parsed.scheme}://{parsed.netloc}" + + def get_db_key(self) -> dict: + return { + "server_url": self.get_base_url(self.server_url), + "user_id": self.user_id, + } + + async def get_tokens(self) -> OAuthToken | None: + doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key()) + if not doc or "tokens" not in doc: + return None + try: + tokens = OAuthToken.model_validate(doc["tokens"]) + return tokens + except ValidationError as e: + logging.error(f"Could not load tokens: {e}") + return None + + async def set_tokens(self, tokens: OAuthToken) -> None: + await asyncio.to_thread( + self.collection.update_one, + self.get_db_key(), + {"$set": {"tokens": tokens.model_dump()}}, + True, + ) + logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}") + + async def get_client_info(self) -> OAuthClientInformationFull | None: + doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key()) + if not doc or "client_info" not in doc: + return None + try: + client_info = OAuthClientInformationFull.model_validate(doc["client_info"]) + tokens = await self.get_tokens() + if tokens is None: + logging.debug( + "No tokens found, clearing client info to force fresh registration." + ) + await asyncio.to_thread( + self.collection.update_one, + self.get_db_key(), + {"$unset": {"client_info": ""}}, + ) + return None + return client_info + except ValidationError as e: + logging.error(f"Could not load client info: {e}") + return None + + def _serialize_client_info(self, info: dict) -> dict: + if "redirect_uris" in info and isinstance(info["redirect_uris"], list): + info["redirect_uris"] = [str(u) for u in info["redirect_uris"]] + return info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + serialized_info = self._serialize_client_info(client_info.model_dump()) + await asyncio.to_thread( + self.collection.update_one, + self.get_db_key(), + {"$set": {"client_info": serialized_info}}, + True, + ) + logging.info(f"Saved client info for {self.get_base_url(self.server_url)}") + + async def clear(self) -> None: + await asyncio.to_thread(self.collection.delete_one, self.get_db_key()) + logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}") + + @classmethod + async def clear_all(cls, db_client) -> None: + collection = db_client["connector_sessions"] + await asyncio.to_thread(collection.delete_many, {}) + logging.info("Cleared all OAuth client cache data.") + + +class MCPOAuthManager: + """Manager for handling MCP OAuth callbacks.""" + + def __init__(self, redis_client: Redis | None, redis_prefix: str = "mcp_oauth:"): + self.redis_client = redis_client + self.redis_prefix = redis_prefix + + def handle_oauth_callback( + self, state: str, code: str, error: Optional[str] = None + ) -> bool: + """ + Handle OAuth callback from provider. + + Args: + state: The state parameter from OAuth callback + code: The authorization code from OAuth callback + error: Error message if OAuth failed + + Returns: + True if successful, False otherwise + """ + try: + if not self.redis_client or not state: + raise Exception("Redis client or state not provided") + if error: + error_key = f"{self.redis_prefix}error:{state}" + self.redis_client.setex(error_key, 300, error) + raise Exception(f"OAuth error received: {error}") + code_key = f"{self.redis_prefix}code:{state}" + self.redis_client.setex(code_key, 300, code) + + state_key = f"{self.redis_prefix}state:{state}" + self.redis_client.setex(state_key, 300, "completed") + + return True + except Exception as e: + logging.error(f"Error handling OAuth callback: {e}") + return False + + def get_oauth_status(self, task_id: str) -> Dict[str, Any]: + """Get current status of OAuth flow using provided task_id.""" + if not task_id: + return {"status": "not_started", "message": "OAuth flow not started"} + return mcp_oauth_status_task(task_id) diff --git a/application/api/connector/routes.py b/application/api/connector/routes.py index 09b6c0c9..49307058 100644 --- a/application/api/connector/routes.py +++ b/application/api/connector/routes.py @@ -1,5 +1,7 @@ +import base64 import datetime import json +import uuid from bson.objectid import ObjectId @@ -13,8 +15,6 @@ from flask import ( from flask_restx import fields, Namespace, Resource - - from application.api.user.tasks import ( ingest_connector_task, ) @@ -234,8 +234,24 @@ class ConnectorAuth(Resource): if not ConnectorCreator.is_supported(provider): return make_response(jsonify({"success": False, "error": f"Unsupported provider: {provider}"}), 400) - import uuid - state = str(uuid.uuid4()) + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401) + user_id = decoded_token.get('sub') + + now = datetime.datetime.now(datetime.timezone.utc) + result = sessions_collection.insert_one({ + "provider": provider, + "user": user_id, + "status": "pending", + "created_at": now + }) + state_dict = { + "provider": provider, + "object_id": str(result.inserted_id) + } + state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode() + auth = ConnectorCreator.create_auth(provider) authorization_url = auth.get_authorization_url(state=state) return make_response(jsonify({ @@ -256,25 +272,30 @@ class ConnectorsCallback(Resource): try: from application.parser.connectors.connector_creator import ConnectorCreator from flask import request, redirect - import uuid - provider = request.args.get('provider', 'google_drive') authorization_code = request.args.get('code') - _ = request.args.get('state') + state = request.args.get('state') error = request.args.get('error') + state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode()) + provider = state_dict["provider"] + state_object_id = state_dict["object_id"] + if error: - return redirect(f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider={provider}") + if error == "access_denied": + return redirect(f"/api/connectors/callback-status?status=cancelled&message=Authentication+was+cancelled.+You+can+try+again+if+you'd+like+to+connect+your+account.&provider={provider}") + else: + current_app.logger.warning(f"OAuth error in callback: {error}") + return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}") if not authorization_code: - return redirect(f"/api/connectors/callback-status?status=error&message=Authorization+code+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider={provider}") + return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}") try: auth = ConnectorCreator.create_auth(provider) token_info = auth.exchange_code_for_tokens(authorization_code) session_token = str(uuid.uuid4()) - try: credentials = auth.create_credentials_from_token_info(token_info) @@ -292,26 +313,28 @@ class ConnectorsCallback(Resource): "expiry": token_info.get("expiry") } - user_id = request.decoded_token.get("sub") if getattr(request, "decoded_token", None) else None - sessions_collection.insert_one({ - "session_token": session_token, - "user": user_id, - "token_info": sanitized_token_info, - "created_at": datetime.datetime.now(datetime.timezone.utc), - "user_email": user_email, - "provider": provider - }) + sessions_collection.find_one_and_update( + {"_id": ObjectId(state_object_id), "provider": provider}, + { + "$set": { + "session_token": session_token, + "token_info": sanitized_token_info, + "user_email": user_email, + "status": "authorized" + } + } + ) # Redirect to success page with session token and user email return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}") except Exception as e: current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True) - return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+exchange+authorization+code+for+tokens:+{str(e)}&provider={provider}") + return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}") except Exception as e: current_app.logger.error(f"Error handling connector callback: {e}") - return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+complete+connector+authentication:+{str(e)}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.") + return redirect("/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.") @connectors_ns.route("/api/connectors/refresh") @@ -629,20 +652,23 @@ class ConnectorCallbackStatus(Resource): .container {{ max-width: 600px; margin: 0 auto; }} .success {{ color: #4CAF50; }} .error {{ color: #F44336; }} + .cancelled {{ color: #FF9800; }}