From 3369b910b4b4020651ac5c2fff572d1b6883dc93 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Tue, 16 Sep 2025 20:43:04 +0530 Subject: [PATCH 01/11] feat: update MCP request ID generation and enhance response handling for event streams --- application/agents/tools/mcp_tool.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/application/agents/tools/mcp_tool.py b/application/agents/tools/mcp_tool.py index dc689367..27d0b5f5 100644 --- a/application/agents/tools/mcp_tool.py +++ b/application/agents/tools/mcp_tool.py @@ -146,7 +146,7 @@ class MCPTool(Tool): mcp_message = {"jsonrpc": "2.0", "method": method} if not method.startswith("notifications/"): - mcp_message["id"] = 1 + mcp_message["id"] = int(time.time() * 1000000) if params: mcp_message["params"] = params return self._execute_mcp_request(mcp_message, method) @@ -181,7 +181,11 @@ class MCPTool(Tool): if method.startswith("notifications/"): return {} response_text = response.text.strip() - if response_text.startswith("event:") and "data:" in response_text: + if ( + response.headers.get("content-type", "").startswith("text/event-stream") + or response_text.startswith("event:") + and "data:" in response_text + ): lines = response_text.split("\n") data_line = None for line in lines: @@ -200,12 +204,15 @@ class MCPTool(Tool): result = response.json() except json.JSONDecodeError: raise Exception(f"Invalid JSON response: {response.text}") + if "error" in result: error_msg = result["error"] if isinstance(error_msg, dict): error_msg = error_msg.get("message", str(error_msg)) raise Exception(f"MCP server error: {error_msg}") + return result.get("result", result) + except requests.exceptions.RequestException as e: if not is_retry and self._should_retry_with_new_session(e): self._invalidate_and_refresh_session() From 82b47b5673a6438471dc7d0e7fca2ddbd3d12357 Mon Sep 17 00:00:00 2001 From: jane Date: Tue, 16 Sep 2025 23:53:06 +0530 Subject: [PATCH 02/11] Added fix in frontend/src/conversation/ConversationMessages.tsx line 213 --- frontend/src/conversation/ConversationMessages.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/src/conversation/ConversationMessages.tsx b/frontend/src/conversation/ConversationMessages.tsx index 4bc2bb08..717023a4 100644 --- a/frontend/src/conversation/ConversationMessages.tsx +++ b/frontend/src/conversation/ConversationMessages.tsx @@ -210,7 +210,7 @@ export default function ConversationMessages({ )}
- {headerContent && headerContent} + {headerContent} {queries.length > 0 ? ( queries.map((query, index) => ( From 8ce345cd94c5c53316df179087f7d8fa3e99cde9 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Wed, 17 Sep 2025 20:51:32 +0530 Subject: [PATCH 03/11] feat: refactor MCPTool to use FastMCP client and improve async handling, update transport and authentication configurations --- application/agents/tools/mcp_tool.py | 552 +++++++++++++++------------ application/api/user/routes.py | 64 ++-- 2 files changed, 331 insertions(+), 285 deletions(-) diff --git a/application/agents/tools/mcp_tool.py b/application/agents/tools/mcp_tool.py index 27d0b5f5..7cb32633 100644 --- a/application/agents/tools/mcp_tool.py +++ b/application/agents/tools/mcp_tool.py @@ -1,14 +1,20 @@ -import json +import asyncio +import base64 import time from typing import Any, Dict, List, Optional -import requests - from application.agents.tools.base import Tool from application.security.encryption import decrypt_credentials +from fastmcp import Client +from fastmcp.client.auth import BearerAuth +from fastmcp.client.transports import ( + SSETransport, + StdioTransport, + StreamableHttpTransport, +) -_mcp_session_cache = {} +_mcp_clients_cache = {} class MCPTool(Tool): @@ -24,15 +30,21 @@ class MCPTool(Tool): Args: config: Dictionary containing MCP server configuration: - server_url: URL of the remote MCP server - - auth_type: Type of authentication (api_key, bearer, basic, none) + - transport_type: Transport type (auto, sse, http, stdio) + - auth_type: Type of authentication (bearer, oauth, api_key, basic, none) - encrypted_credentials: Encrypted credentials (if available) - timeout: Request timeout in seconds (default: 30) + - headers: Custom headers for requests + - command: Command for STDIO transport + - args: Arguments for STDIO transport user_id: User ID for decrypting credentials (required if encrypted_credentials exist) """ self.config = config self.server_url = config.get("server_url", "") + self.transport_type = config.get("transport_type", "auto") self.auth_type = config.get("auth_type", "none") self.timeout = config.get("timeout", 30) + self.custom_headers = config.get("headers", {}) self.auth_credentials = {} if config.get("encrypted_credentials") and user_id: @@ -42,33 +54,21 @@ class MCPTool(Tool): else: self.auth_credentials = config.get("auth_credentials", {}) self.available_tools = [] - self._session = requests.Session() - self._mcp_session_id = None - self._setup_authentication() self._cache_key = self._generate_cache_key() + self._client = None - def _setup_authentication(self): - """Setup authentication for the MCP server connection.""" - if self.auth_type == "api_key": - api_key = self.auth_credentials.get("api_key", "") - header_name = self.auth_credentials.get("api_key_header", "X-API-Key") - if api_key: - self._session.headers.update({header_name: api_key}) - elif self.auth_type == "bearer": - token = self.auth_credentials.get("bearer_token", "") - if token: - self._session.headers.update({"Authorization": f"Bearer {token}"}) - elif self.auth_type == "basic": - username = self.auth_credentials.get("username", "") - password = self.auth_credentials.get("password", "") - if username and password: - self._session.auth = (username, password) + # Only validate and setup if server_url is provided + + if self.server_url: + self._setup_client() def _generate_cache_key(self) -> str: """Generate a unique cache key for this MCP server configuration.""" auth_key = "" - if self.auth_type == "bearer": - token = self.auth_credentials.get("bearer_token", "") + if self.auth_type in ["bearer", "oauth"]: + token = self.auth_credentials.get( + "bearer_token", "" + ) or self.auth_credentials.get("access_token", "") auth_key = f"bearer:{token[:10]}..." if token else "bearer:none" elif self.auth_type == "api_key": api_key = self.auth_credentials.get("api_key", "") @@ -78,208 +78,174 @@ class MCPTool(Tool): auth_key = f"basic:{username}" else: auth_key = "none" - return f"{self.server_url}#{auth_key}" + return f"{self.server_url}#{self.transport_type}#{auth_key}" - def _get_cached_session(self) -> Optional[str]: - """Get cached session ID if available and not expired.""" - global _mcp_session_cache - - if self._cache_key in _mcp_session_cache: - session_data = _mcp_session_cache[self._cache_key] - if time.time() - session_data["created_at"] < 1800: - return session_data["session_id"] + def _setup_client(self): + """Setup FastMCP client with proper transport and authentication.""" + global _mcp_clients_cache + if self._cache_key in _mcp_clients_cache: + cached_data = _mcp_clients_cache[self._cache_key] + if time.time() - cached_data["created_at"] < 1800: + self._client = cached_data["client"] + return else: - del _mcp_session_cache[self._cache_key] - return None + del _mcp_clients_cache[self._cache_key] + transport = self._create_transport() - def _cache_session(self, session_id: str): - """Cache the session ID for reuse.""" - global _mcp_session_cache - _mcp_session_cache[self._cache_key] = { - "session_id": session_id, + if self.auth_type in ["bearer", "oauth"]: + token = self.auth_credentials.get( + "bearer_token", "" + ) or self.auth_credentials.get("access_token", "") + if token: + self._client = Client(transport, auth=BearerAuth(token)) + else: + self._client = Client(transport) + else: + self._client = Client(transport) + _mcp_clients_cache[self._cache_key] = { + "client": self._client, "created_at": time.time(), } - def _initialize_mcp_connection(self) -> Dict: - """ - Initialize MCP connection with the server, using cached session if available. + def _create_transport(self): + """Create appropriate transport based on configuration.""" + headers = {"Content-Type": "application/json", "User-Agent": "DocsGPT-MCP/1.0"} + headers.update(self.custom_headers) - Returns: - Server capabilities and information - """ - cached_session = self._get_cached_session() - if cached_session: - self._mcp_session_id = cached_session - return {"cached": True} - try: - init_params = { - "protocolVersion": "2024-11-05", - "capabilities": {"roots": {"listChanged": True}, "sampling": {}}, - "clientInfo": {"name": "DocsGPT", "version": "1.0.0"}, - } - response = self._make_mcp_request("initialize", init_params) - self._make_mcp_request("notifications/initialized") - - return response - except Exception as e: - return {"error": str(e), "fallback": True} - - def _ensure_valid_session(self): - """Ensure we have a valid MCP session, reinitializing if needed.""" - if not self._mcp_session_id: - self._initialize_mcp_connection() - - def _make_mcp_request(self, method: str, params: Optional[Dict] = None) -> Dict: - """ - Make an MCP protocol request to the server with automatic session recovery. - - Args: - method: MCP method name (e.g., "tools/list", "tools/call") - params: Parameters for the MCP method - - Returns: - Response data as dictionary - - Raises: - Exception: If request fails after retry - """ - mcp_message = {"jsonrpc": "2.0", "method": method} - - if not method.startswith("notifications/"): - mcp_message["id"] = int(time.time() * 1000000) - if params: - mcp_message["params"] = params - return self._execute_mcp_request(mcp_message, method) - - def _execute_mcp_request( - self, mcp_message: Dict, method: str, is_retry: bool = False - ) -> Dict: - """Execute MCP request with optional retry on session failure.""" - try: - final_headers = self._session.headers.copy() - final_headers.update( - { - "Content-Type": "application/json", - "Accept": "application/json, text/event-stream", - } - ) - - if self._mcp_session_id: - final_headers["Mcp-Session-Id"] = self._mcp_session_id - response = self._session.post( - self.server_url.rstrip("/"), - json=mcp_message, - headers=final_headers, - timeout=self.timeout, - ) - - if "mcp-session-id" in response.headers: - self._mcp_session_id = response.headers["mcp-session-id"] - self._cache_session(self._mcp_session_id) - response.raise_for_status() - - if method.startswith("notifications/"): - return {} - response_text = response.text.strip() - if ( - response.headers.get("content-type", "").startswith("text/event-stream") - or response_text.startswith("event:") - and "data:" in response_text - ): - lines = response_text.split("\n") - data_line = None - for line in lines: - if line.startswith("data:"): - data_line = line[5:].strip() - break - if data_line: - try: - result = json.loads(data_line) - except json.JSONDecodeError: - raise Exception(f"Invalid JSON in SSE data: {data_line}") - else: - raise Exception(f"No data found in SSE response: {response_text}") + if self.auth_type == "api_key": + api_key = self.auth_credentials.get("api_key", "") + header_name = self.auth_credentials.get("api_key_header", "X-API-Key") + if api_key: + headers[header_name] = api_key + elif self.auth_type == "basic": + username = self.auth_credentials.get("username", "") + password = self.auth_credentials.get("password", "") + if username and password: + credentials = base64.b64encode( + f"{username}:{password}".encode() + ).decode() + headers["Authorization"] = f"Basic {credentials}" + if self.transport_type == "auto": + if "sse" in self.server_url.lower() or self.server_url.endswith("/sse"): + transport_type = "sse" else: - try: - result = response.json() - except json.JSONDecodeError: - raise Exception(f"Invalid JSON response: {response.text}") + transport_type = "http" + else: + transport_type = self.transport_type + if transport_type == "sse": + headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"}) + return SSETransport(url=self.server_url, headers=headers) + elif transport_type == "http": + return StreamableHttpTransport(url=self.server_url, headers=headers) + elif transport_type == "stdio": + command = self.config.get("command", "python") + args = self.config.get("args", []) + env = self.auth_credentials if self.auth_credentials else None + return StdioTransport(command=command, args=args, env=env) + else: + return StreamableHttpTransport(url=self.server_url, headers=headers) - if "error" in result: - error_msg = result["error"] - if isinstance(error_msg, dict): - error_msg = error_msg.get("message", str(error_msg)) - raise Exception(f"MCP server error: {error_msg}") + async def _execute_with_client(self, operation: str, *args, **kwargs): + """Execute operation with FastMCP client.""" + if not self._client: + raise Exception("FastMCP client not initialized") + async with self._client: + if operation == "ping": + return await self._client.ping() + elif operation == "list_tools": + tools_response = await self._client.list_tools() - return result.get("result", result) + if hasattr(tools_response, "tools"): + tools = tools_response.tools + elif isinstance(tools_response, list): + tools = tools_response + else: + tools = [] + tools_dict = [] + for tool in tools: + if hasattr(tool, "name"): + tool_dict = { + "name": tool.name, + "description": tool.description, + } + if hasattr(tool, "inputSchema"): + tool_dict["inputSchema"] = tool.inputSchema + tools_dict.append(tool_dict) + elif isinstance(tool, dict): + tools_dict.append(tool) + else: - except requests.exceptions.RequestException as e: - if not is_retry and self._should_retry_with_new_session(e): - self._invalidate_and_refresh_session() - return self._execute_mcp_request(mcp_message, method, is_retry=True) - raise Exception(f"MCP server request failed: {str(e)}") + if hasattr(tool, "model_dump"): + tools_dict.append(tool.model_dump()) + else: + tools_dict.append({"name": str(tool), "description": ""}) + return tools_dict + elif operation == "call_tool": + tool_name = args[0] + tool_args = kwargs + return await self._client.call_tool(tool_name, tool_args) + elif operation == "list_resources": + return await self._client.list_resources() + elif operation == "list_prompts": + return await self._client.list_prompts() + else: + raise Exception(f"Unknown operation: {operation}") - def _should_retry_with_new_session(self, error: Exception) -> bool: - """Check if error indicates session invalidation and retry is warranted.""" - error_str = str(error).lower() - return ( - any( - indicator in error_str - for indicator in [ - "invalid session", - "session expired", - "unauthorized", - "401", - "403", - ] - ) - and self._mcp_session_id is not None - ) + def _run_async_operation(self, operation: str, *args, **kwargs): + """Run async operation in sync context.""" + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + import concurrent.futures - def _invalidate_and_refresh_session(self) -> None: - """Invalidate current session and create a new one.""" - global _mcp_session_cache - if self._cache_key in _mcp_session_cache: - del _mcp_session_cache[self._cache_key] - self._mcp_session_id = None - self._initialize_mcp_connection() + def run_in_thread(): + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + try: + return new_loop.run_until_complete( + self._execute_with_client(operation, *args, **kwargs) + ) + finally: + new_loop.close() + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_in_thread) + return future.result(timeout=self.timeout) + else: + return loop.run_until_complete( + self._execute_with_client(operation, *args, **kwargs) + ) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete( + self._execute_with_client(operation, *args, **kwargs) + ) + finally: + loop.close() def discover_tools(self) -> List[Dict]: """ - Discover available tools from the MCP server using MCP protocol. + Discover available tools from the MCP server using FastMCP. Returns: List of tool definitions from the server """ + if not self.server_url: + return [] + if not self._client: + self._setup_client() try: - self._ensure_valid_session() - - response = self._make_mcp_request("tools/list") - - # Handle both formats: response with 'tools' key or response that IS the tools list - - if isinstance(response, dict): - if "tools" in response: - self.available_tools = response["tools"] - elif ( - "result" in response - and isinstance(response["result"], dict) - and "tools" in response["result"] - ): - self.available_tools = response["result"]["tools"] - else: - self.available_tools = [response] if response else [] - elif isinstance(response, list): - self.available_tools = response - else: - self.available_tools = [] + tools = self._run_async_operation("list_tools") + self.available_tools = tools return self.available_tools except Exception as e: raise Exception(f"Failed to discover tools from MCP server: {str(e)}") def execute_action(self, action_name: str, **kwargs) -> Any: """ - Execute an action on the remote MCP server using MCP protocol. + Execute an action on the remote MCP server using FastMCP. Args: action_name: Name of the action to execute @@ -288,22 +254,91 @@ class MCPTool(Tool): Returns: Result from the MCP server """ - self._ensure_valid_session() - - # Skipping empty/None values - letting the server use defaults - + if not self.server_url: + raise Exception("No MCP server configured") + if not self._client: + self._setup_client() cleaned_kwargs = {} for key, value in kwargs.items(): if value == "" or value is None: continue cleaned_kwargs[key] = value - call_params = {"name": action_name, "arguments": cleaned_kwargs} try: - result = self._make_mcp_request("tools/call", call_params) - return result + result = self._run_async_operation( + "call_tool", action_name, **cleaned_kwargs + ) + return self._format_result(result) except Exception as e: raise Exception(f"Failed to execute action '{action_name}': {str(e)}") + def _format_result(self, result) -> Dict: + """Format FastMCP result to match expected format.""" + if hasattr(result, "content"): + content_list = [] + for content_item in result.content: + if hasattr(content_item, "text"): + content_list.append({"type": "text", "text": content_item.text}) + elif hasattr(content_item, "data"): + content_list.append({"type": "data", "data": content_item.data}) + else: + content_list.append( + {"type": "unknown", "content": str(content_item)} + ) + return { + "content": content_list, + "isError": getattr(result, "isError", False), + } + else: + return result + + def test_connection(self) -> Dict: + """ + Test the connection to the MCP server and validate functionality. + + Returns: + Dictionary with connection test results including tool count + """ + if not self.server_url: + return { + "success": False, + "message": "No MCP server URL configured", + "tools_count": 0, + "transport_type": self.transport_type, + "auth_type": self.auth_type, + "error_type": "ConfigurationError", + } + if not self._client: + self._setup_client() + try: + try: + self._run_async_operation("ping") + ping_success = True + except Exception: + ping_success = False + tools = self.discover_tools() + + message = f"Successfully connected to MCP server. Found {len(tools)} tools." + if not ping_success: + message += " (Ping not supported, but tool discovery worked)" + return { + "success": True, + "message": message, + "tools_count": len(tools), + "transport_type": self.transport_type, + "auth_type": self.auth_type, + "ping_supported": ping_success, + "tools": [tool.get("name", "unknown") for tool in tools[:5]], + } + except Exception as e: + return { + "success": False, + "message": f"Connection failed: {str(e)}", + "tools_count": 0, + "transport_type": self.transport_type, + "auth_type": self.auth_type, + "error_type": type(e).__name__, + } + def get_actions_metadata(self) -> List[Dict]: """ Get metadata for all available actions. @@ -348,58 +383,76 @@ class MCPTool(Tool): actions.append(action) return actions - def test_connection(self) -> Dict: - """ - Test the connection to the MCP server and validate functionality. - - Returns: - Dictionary with connection test results including tool count - """ - try: - self._mcp_session_id = None - - init_result = self._initialize_mcp_connection() - - tools = self.discover_tools() - - message = f"Successfully connected to MCP server. Found {len(tools)} tools." - if init_result.get("cached"): - message += " (Using cached session)" - elif init_result.get("fallback"): - message += " (No formal initialization required)" - return { - "success": True, - "message": message, - "tools_count": len(tools), - "session_id": self._mcp_session_id, - "tools": [tool.get("name", "unknown") for tool in tools[:5]], - } - except Exception as e: - return { - "success": False, - "message": f"Connection failed: {str(e)}", - "tools_count": 0, - "error_type": type(e).__name__, - } - def get_config_requirements(self) -> Dict: + """Get configuration requirements for the MCP tool.""" return { "server_url": { "type": "string", - "description": "URL of the remote MCP server (e.g., https://api.example.com)", + "description": "URL of the remote MCP server (e.g., https://api.example.com/mcp or https://docs.mcp.cloudflare.com/sse)", "required": True, }, + "transport_type": { + "type": "string", + "description": "Transport type for connection", + "enum": ["auto", "sse", "http", "stdio"], + "default": "auto", + "required": False, + "help": { + "auto": "Automatically detect best transport", + "sse": "Server-Sent Events (for real-time streaming)", + "http": "HTTP streaming (recommended for production)", + "stdio": "Standard I/O (for local servers)", + }, + }, "auth_type": { "type": "string", "description": "Authentication type", - "enum": ["none", "api_key", "bearer", "basic"], + "enum": ["none", "bearer", "oauth", "api_key", "basic"], "default": "none", "required": True, + "help": { + "none": "No authentication", + "bearer": "Bearer token authentication", + "oauth": "OAuth 2.0 authentication", + "api_key": "API key authentication", + "basic": "Basic authentication", + }, }, "auth_credentials": { "type": "object", "description": "Authentication credentials (varies by auth_type)", "required": False, + "properties": { + "bearer_token": { + "type": "string", + "description": "Bearer token for bearer/oauth auth", + }, + "access_token": { + "type": "string", + "description": "Access token for OAuth", + }, + "api_key": { + "type": "string", + "description": "API key for api_key auth", + }, + "api_key_header": { + "type": "string", + "description": "Header name for API key (default: X-API-Key)", + }, + "username": { + "type": "string", + "description": "Username for basic auth", + }, + "password": { + "type": "string", + "description": "Password for basic auth", + }, + }, + }, + "headers": { + "type": "object", + "description": "Custom headers to send with requests", + "required": False, }, "timeout": { "type": "integer", @@ -409,4 +462,15 @@ class MCPTool(Tool): "maximum": 300, "required": False, }, + "command": { + "type": "string", + "description": "Command to run for STDIO transport (e.g., 'python')", + "required": False, + }, + "args": { + "type": "array", + "description": "Arguments for STDIO command", + "items": {"type": "string"}, + "required": False, + }, } diff --git a/application/api/user/routes.py b/application/api/user/routes.py index f0493c7c..293a7c26 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -494,7 +494,6 @@ class DeleteOldIndexes(Resource): ) if not doc: return make_response(jsonify({"status": "not found"}), 404) - storage = StorageCreator.get_storage() try: @@ -511,7 +510,6 @@ class DeleteOldIndexes(Resource): settings.VECTOR_STORE, source_id=str(doc["_id"]) ) vectorstore.delete_index() - if "file_path" in doc and doc["file_path"]: file_path = doc["file_path"] if storage.is_directory(file_path): @@ -520,7 +518,6 @@ class DeleteOldIndexes(Resource): storage.delete_file(f) else: storage.delete_file(file_path) - except FileNotFoundError: pass except Exception as err: @@ -528,7 +525,6 @@ class DeleteOldIndexes(Resource): f"Error deleting files and indexes: {err}", exc_info=True ) return make_response(jsonify({"success": False}), 400) - sources_collection.delete_one({"_id": ObjectId(source_id)}) return make_response(jsonify({"success": True}), 200) @@ -600,7 +596,6 @@ class UploadFile(Resource): == temp_file_path ): continue - rel_path = os.path.relpath( os.path.join(root, extracted_file), temp_dir ) @@ -625,7 +620,6 @@ class UploadFile(Resource): file_path = f"{base_path}/{safe_file}" with open(temp_file_path, "rb") as f: storage.save_file(f, file_path) - task = ingest.delay( settings.UPLOAD_FOLDER, [ @@ -697,7 +691,6 @@ class ManageSourceFiles(Resource): return make_response( jsonify({"success": False, "message": "Unauthorized"}), 401 ) - user = decoded_token.get("sub") source_id = request.form.get("source_id") operation = request.form.get("operation") @@ -747,7 +740,6 @@ class ManageSourceFiles(Resource): return make_response( jsonify({"success": False, "message": "Database error"}), 500 ) - try: storage = StorageCreator.get_storage() source_file_path = source.get("file_path", "") @@ -804,7 +796,6 @@ class ManageSourceFiles(Resource): ), 200, ) - elif operation == "remove": file_paths_str = request.form.get("file_paths") if not file_paths_str: @@ -858,7 +849,6 @@ class ManageSourceFiles(Resource): ), 200, ) - elif operation == "remove_directory": directory_path = request.form.get("directory_path") if not directory_path: @@ -884,7 +874,6 @@ class ManageSourceFiles(Resource): ), 400, ) - full_directory_path = ( f"{source_file_path}/{directory_path}" if directory_path @@ -943,7 +932,6 @@ class ManageSourceFiles(Resource): ), 200, ) - except Exception as err: error_context = f"operation={operation}, user={user}, source_id={source_id}" if operation == "remove_directory": @@ -955,7 +943,6 @@ class ManageSourceFiles(Resource): elif operation == "add": parent_dir = request.form.get("parent_dir", "") error_context += f", parent_dir={parent_dir}" - current_app.logger.error( f"Error managing source files: {err} ({error_context})", exc_info=True ) @@ -1632,7 +1619,6 @@ class CreateAgent(Resource): ), 400, ) - # Validate that it has either a 'schema' property or is itself a schema if "schema" not in json_schema and "type" not in json_schema: @@ -3476,7 +3462,6 @@ class AvailableTools(Resource): "displayName": name, "description": description, "configRequirements": tool_instance.get_config_requirements(), - "actions": tool_instance.get_actions_metadata(), } ) except Exception as err: @@ -3527,11 +3512,6 @@ class CreateTool(Resource): "customName": fields.String( required=False, description="Custom name for the tool" ), - "actions": fields.List( - fields.Raw, - required=True, - description="Actions the tool can perform", - ), "status": fields.Boolean( required=True, description="Status of the tool" ), @@ -3549,24 +3529,35 @@ class CreateTool(Resource): "name", "displayName", "description", - "actions", "config", "status", ] missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields - transformed_actions = [] - for action in data["actions"]: - action["active"] = True - if "parameters" in action: - if "properties" in action["parameters"]: - for param_name, param_details in action["parameters"][ - "properties" - ].items(): - param_details["filled_by_llm"] = True - param_details["value"] = "" - transformed_actions.append(action) + try: + tool_instance = tool_manager.tools.get(data["name"]) + if not tool_instance: + return make_response( + jsonify({"success": False, "message": "Tool not found"}), 404 + ) + actions_metadata = tool_instance.get_actions_metadata() + transformed_actions = [] + for action in actions_metadata: + action["active"] = True + if "parameters" in action: + if "properties" in action["parameters"]: + for param_name, param_details in action["parameters"][ + "properties" + ].items(): + param_details["filled_by_llm"] = True + param_details["value"] = "" + transformed_actions.append(action) + except Exception as err: + current_app.logger.error( + f"Error getting tool actions: {err}", exc_info=True + ) + return make_response(jsonify({"success": False}), 400) try: new_tool = { "user": user, @@ -3907,7 +3898,6 @@ class GetChunks(Resource): if not (text_match or title_match): continue filtered_chunks.append(chunk) - chunks = filtered_chunks total_chunks = len(chunks) @@ -4098,7 +4088,6 @@ class UpdateChunk(Resource): current_app.logger.warning( f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created" ) - return make_response( jsonify( { @@ -4226,23 +4215,19 @@ class DirectoryStructure(Resource): decoded_token = request.decoded_token if not decoded_token: return make_response(jsonify({"success": False}), 401) - user = decoded_token.get("sub") doc_id = request.args.get("id") if not doc_id: return make_response(jsonify({"error": "Document ID is required"}), 400) - if not ObjectId.is_valid(doc_id): return make_response(jsonify({"error": "Invalid document ID"}), 400) - try: doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user}) if not doc: return make_response( jsonify({"error": "Document not found or access denied"}), 404 ) - directory_structure = doc.get("directory_structure", {}) base_path = doc.get("file_path", "") @@ -4315,7 +4300,6 @@ class TestMCPServerConfig(Resource): auth_credentials["username"] = config["username"] if "password" in config: auth_credentials["password"] = config["password"] - test_config = config.copy() test_config["auth_credentials"] = auth_credentials @@ -4395,14 +4379,12 @@ class MCPServerSave(Resource): raise Exception( "No valid credentials provided for the selected authentication type" ) - storage_config = config.copy() if auth_credentials: encrypted_credentials_string = encrypt_credentials( auth_credentials, user ) storage_config["encrypted_credentials"] = encrypted_credentials_string - for field in [ "api_key", "bearer_token", From 00b4e133d432b57e7eaee6f06f807f06703b9cf2 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Mon, 22 Sep 2025 01:31:09 +0530 Subject: [PATCH 04/11] feat: implement OAuth 2.1 integration with custom handlers for fastmcp --- application/agents/tools/mcp_tool.py | 280 ++++++++++++++++++++++----- 1 file changed, 236 insertions(+), 44 deletions(-) diff --git a/application/agents/tools/mcp_tool.py b/application/agents/tools/mcp_tool.py index 7cb32633..fd0767ec 100644 --- a/application/agents/tools/mcp_tool.py +++ b/application/agents/tools/mcp_tool.py @@ -8,6 +8,7 @@ from application.security.encryption import decrypt_credentials from fastmcp import Client from fastmcp.client.auth import BearerAuth +from fastmcp.client.auth.oauth import OAuth as FastMCPOAuth from fastmcp.client.transports import ( SSETransport, StdioTransport, @@ -17,6 +18,30 @@ from fastmcp.client.transports import ( _mcp_clients_cache = {} +class DocsGPTOAuth(FastMCPOAuth): + """Custom OAuth handler that integrates with DocsGPT frontend instead of opening browser.""" + + def __init__(self, *args, **kwargs): + self.auth_url_callback = kwargs.pop("auth_url_callback", None) + self.auth_code_callback = kwargs.pop("auth_code_callback", None) + super().__init__(*args, **kwargs) + + async def redirect_handler(self, authorization_url: str) -> None: + """Override to send auth URL to frontend instead of opening browser.""" + if self.auth_url_callback: + self.auth_url_callback(authorization_url) + else: + raise Exception("OAuth authorization URL callback not configured") + + async def callback_handler(self) -> tuple[str, str | None]: + """Override to wait for auth code from frontend instead of local server.""" + if self.auth_code_callback: + auth_code, state = await self.auth_code_callback() + return auth_code, state + else: + raise Exception("OAuth callback handler not configured") + + class MCPTool(Tool): """ MCP Tool @@ -37,6 +62,8 @@ class MCPTool(Tool): - headers: Custom headers for requests - command: Command for STDIO transport - args: Arguments for STDIO transport + - oauth_scopes: OAuth scopes for oauth auth type + - oauth_client_name: OAuth client name for oauth auth type user_id: User ID for decrypting credentials (required if encrypted_credentials exist) """ self.config = config @@ -53,19 +80,35 @@ class MCPTool(Tool): ) else: self.auth_credentials = config.get("auth_credentials", {}) + # OAuth specific configuration + + self.oauth_scopes = config.get("oauth_scopes", []) + self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP") + + # OAuth callback handlers (to be set by frontend) + + self.oauth_auth_url_callback = None + self.oauth_auth_code_callback = None + self.available_tools = [] self._cache_key = self._generate_cache_key() self._client = None - # Only validate and setup if server_url is provided + # Only validate and setup if server_url is provided and not OAuth + # OAuth setup will happen after callbacks are set - if self.server_url: + if self.server_url and self.auth_type != "oauth": self._setup_client() def _generate_cache_key(self) -> str: """Generate a unique cache key for this MCP server configuration.""" auth_key = "" - if self.auth_type in ["bearer", "oauth"]: + if self.auth_type == "oauth": + # For OAuth, use scopes and client name as part of the key + + scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none" + auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}" + elif self.auth_type in ["bearer"]: token = self.auth_credentials.get( "bearer_token", "" ) or self.auth_credentials.get("access_token", "") @@ -91,17 +134,31 @@ class MCPTool(Tool): else: del _mcp_clients_cache[self._cache_key] transport = self._create_transport() + auth = None - if self.auth_type in ["bearer", "oauth"]: + if self.auth_type == "oauth": + # Ensure callbacks are configured before creating OAuth instance + + if not self.oauth_auth_url_callback or not self.oauth_auth_code_callback: + raise Exception( + "OAuth callbacks not configured. Call set_oauth_callbacks() first." + ) + # Use custom OAuth handler for frontend integration + + auth = DocsGPTOAuth( + mcp_url=self.server_url, + scopes=self.oauth_scopes, + client_name=self.oauth_client_name, + auth_url_callback=self.oauth_auth_url_callback, + auth_code_callback=self.oauth_auth_code_callback, + ) + elif self.auth_type in ["bearer"]: token = self.auth_credentials.get( "bearer_token", "" ) or self.auth_credentials.get("access_token", "") if token: - self._client = Client(transport, auth=BearerAuth(token)) - else: - self._client = Client(transport) - else: - self._client = Client(transport) + auth = BearerAuth(token) + self._client = Client(transport, auth=auth) _mcp_clients_cache[self._cache_key] = { "client": self._client, "created_at": time.time(), @@ -194,8 +251,12 @@ class MCPTool(Tool): def _run_async_operation(self, operation: str, *args, **kwargs): """Run async operation in sync context.""" try: - loop = asyncio.get_event_loop() - if loop.is_running(): + # Check if there's already a running event loop + + try: + loop = asyncio.get_running_loop() + # If we're in an async context, we need to run in a separate thread + import concurrent.futures def run_in_thread(): @@ -211,19 +272,25 @@ class MCPTool(Tool): with concurrent.futures.ThreadPoolExecutor() as executor: future = executor.submit(run_in_thread) return future.result(timeout=self.timeout) - else: - return loop.run_until_complete( - self._execute_with_client(operation, *args, **kwargs) + except RuntimeError: + # No running loop, we can create one + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete( + self._execute_with_client(operation, *args, **kwargs) + ) + finally: + loop.close() + except Exception as e: + # If async fails, try to give a better error message for OAuth + + if self.auth_type == "oauth" and "callback not configured" in str(e): + raise Exception( + "OAuth callbacks not configured. Call set_oauth_callbacks() first." ) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - return loop.run_until_complete( - self._execute_with_client(operation, *args, **kwargs) - ) - finally: - loop.close() + raise e def discover_tools(self) -> List[Dict]: """ @@ -310,25 +377,12 @@ class MCPTool(Tool): if not self._client: self._setup_client() try: - try: - self._run_async_operation("ping") - ping_success = True - except Exception: - ping_success = False - tools = self.discover_tools() + # For OAuth, we need to handle async operations differently - message = f"Successfully connected to MCP server. Found {len(tools)} tools." - if not ping_success: - message += " (Ping not supported, but tool discovery worked)" - return { - "success": True, - "message": message, - "tools_count": len(tools), - "transport_type": self.transport_type, - "auth_type": self.auth_type, - "ping_supported": ping_success, - "tools": [tool.get("name", "unknown") for tool in tools[:5]], - } + if self.auth_type == "oauth": + return self._test_oauth_connection() + else: + return self._test_regular_connection() except Exception as e: return { "success": False, @@ -339,6 +393,131 @@ class MCPTool(Tool): "error_type": type(e).__name__, } + def _test_regular_connection(self) -> Dict: + """Test connection for non-OAuth auth types.""" + try: + self._run_async_operation("ping") + ping_success = True + except Exception: + ping_success = False + tools = self.discover_tools() + + message = f"Successfully connected to MCP server. Found {len(tools)} tools." + if not ping_success: + message += " (Ping not supported, but tool discovery worked)" + return { + "success": True, + "message": message, + "tools_count": len(tools), + "transport_type": self.transport_type, + "auth_type": self.auth_type, + "ping_supported": ping_success, + "tools": [tool.get("name", "unknown") for tool in tools[:5]], + } + + def _test_oauth_connection(self) -> Dict: + """Test connection for OAuth auth type with proper async handling.""" + try: + # Ensure callbacks are configured before proceeding + + if not self.oauth_auth_url_callback or not self.oauth_auth_code_callback: + return { + "success": False, + "message": "OAuth callbacks not configured. Call set_oauth_callbacks() first.", + "tools_count": 0, + "transport_type": self.transport_type, + "auth_type": self.auth_type, + "error_type": "ConfigurationError", + } + # Ensure client is set up with proper callbacks + + if not self._client: + self._setup_client() + # For OAuth, we use a simpler approach - just try to discover tools + # This will trigger the OAuth flow if needed + + tools = self.discover_tools() + + return { + "success": True, + "message": f"Successfully connected to OAuth MCP server. Found {len(tools)} tools.", + "tools_count": len(tools), + "transport_type": self.transport_type, + "auth_type": self.auth_type, + "ping_supported": False, # Skip ping for OAuth to avoid complexity + "tools": [tool.get("name", "unknown") for tool in tools[:5]], + } + except Exception as e: + return { + "success": False, + "message": f"OAuth connection failed: {str(e)}", + "tools_count": 0, + "transport_type": self.transport_type, + "auth_type": self.auth_type, + "error_type": type(e).__name__, + } + + def set_oauth_callbacks(self, auth_url_callback, auth_code_callback): + """ + Set OAuth callback handlers for frontend integration. + + Args: + auth_url_callback: Function to call with authorization URL + auth_code_callback: Async function that returns (auth_code, state) tuple + """ + self.oauth_auth_url_callback = auth_url_callback + self.oauth_auth_code_callback = auth_code_callback + + # Clear the client so it gets recreated with the new callbacks + + self._client = None + + # Also clear from cache to ensure fresh creation + + global _mcp_clients_cache + if self._cache_key in _mcp_clients_cache: + del _mcp_clients_cache[self._cache_key] + + def clear_oauth_cache(self): + """ + Clear OAuth cache to force fresh authentication. + This will remove stored tokens and client info for the server. + """ + if self.auth_type == "oauth": + try: + from fastmcp.client.auth.oauth import FileTokenStorage + + storage = FileTokenStorage(server_url=self.server_url) + storage.clear() + print(f"✅ Cleared OAuth cache for {self.server_url}") + except Exception as e: + print(f"⚠️ Failed to clear OAuth cache: {e}") + # Also clear our internal client cache + + global _mcp_clients_cache + if self._cache_key in _mcp_clients_cache: + del _mcp_clients_cache[self._cache_key] + print(f"✅ Cleared internal client cache") + + @staticmethod + def clear_all_oauth_cache(): + """ + Clear all OAuth cache for all servers. + This will remove all stored tokens and client info. + """ + try: + from fastmcp.client.auth.oauth import FileTokenStorage + + FileTokenStorage.clear_all() + print(f"✅ Cleared all OAuth cache") + except Exception as e: + print(f"⚠️ Failed to clear all OAuth cache: {e}") + # Also clear all internal client cache + + global _mcp_clients_cache + _mcp_clients_cache.clear() + print(f"✅ Cleared all internal client cache") + def get_actions_metadata(self) -> List[Dict]: """ Get metadata for all available actions. @@ -413,7 +592,7 @@ class MCPTool(Tool): "help": { "none": "No authentication", "bearer": "Bearer token authentication", - "oauth": "OAuth 2.0 authentication", + "oauth": "OAuth 2.1 authentication (with frontend integration)", "api_key": "API key authentication", "basic": "Basic authentication", }, @@ -425,11 +604,11 @@ class MCPTool(Tool): "properties": { "bearer_token": { "type": "string", - "description": "Bearer token for bearer/oauth auth", + "description": "Bearer token for bearer auth", }, "access_token": { "type": "string", - "description": "Access token for OAuth", + "description": "Access token for OAuth (if pre-obtained)", }, "api_key": { "type": "string", @@ -449,6 +628,19 @@ class MCPTool(Tool): }, }, }, + "oauth_scopes": { + "type": "array", + "description": "OAuth scopes to request (for oauth auth_type)", + "items": {"type": "string"}, + "required": False, + "default": [], + }, + "oauth_client_name": { + "type": "string", + "description": "Client name for OAuth registration (for oauth auth_type)", + "default": "DocsGPT-MCP", + "required": False, + }, "headers": { "type": "object", "description": "Custom headers to send with requests", From c0361ff03d9f8cb82f13cc48b15e057b59f83ebe Mon Sep 17 00:00:00 2001 From: ManishMadan2882 Date: Thu, 25 Sep 2025 03:27:16 +0530 Subject: [PATCH 05/11] (fix:oauth) user field null --- application/api/connector/routes.py | 49 +++++++++++++++++++---------- 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/application/api/connector/routes.py b/application/api/connector/routes.py index 09b6c0c9..e65bee55 100644 --- a/application/api/connector/routes.py +++ b/application/api/connector/routes.py @@ -234,8 +234,20 @@ class ConnectorAuth(Resource): if not ConnectorCreator.is_supported(provider): return make_response(jsonify({"success": False, "error": f"Unsupported provider: {provider}"}), 400) - import uuid - state = str(uuid.uuid4()) + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401) + user_id = decoded_token.get('sub') + + now = datetime.datetime.now(datetime.timezone.utc) + result = sessions_collection.insert_one({ + "provider": provider, + "user": user_id, + "status": "pending", + "created_at": now + }) + state = str(result.inserted_id) + auth = ConnectorCreator.create_auth(provider) authorization_url = auth.get_authorization_url(state=state) return make_response(jsonify({ @@ -260,21 +272,22 @@ class ConnectorsCallback(Resource): provider = request.args.get('provider', 'google_drive') authorization_code = request.args.get('code') - _ = request.args.get('state') + state = request.args.get('state') error = request.args.get('error') if error: - return redirect(f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider={provider}") + return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}") if not authorization_code: - return redirect(f"/api/connectors/callback-status?status=error&message=Authorization+code+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider={provider}") + return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}") + + state_object_id = ObjectId(state) try: auth = ConnectorCreator.create_auth(provider) token_info = auth.exchange_code_for_tokens(authorization_code) session_token = str(uuid.uuid4()) - try: credentials = auth.create_credentials_from_token_info(token_info) @@ -292,26 +305,28 @@ class ConnectorsCallback(Resource): "expiry": token_info.get("expiry") } - user_id = request.decoded_token.get("sub") if getattr(request, "decoded_token", None) else None - sessions_collection.insert_one({ - "session_token": session_token, - "user": user_id, - "token_info": sanitized_token_info, - "created_at": datetime.datetime.now(datetime.timezone.utc), - "user_email": user_email, - "provider": provider - }) + sessions_collection.find_one_and_update( + {"_id": state_object_id, "provider": provider}, + { + "$set": { + "session_token": session_token, + "token_info": sanitized_token_info, + "user_email": user_email, + "status": "authorized" + } + } + ) # Redirect to success page with session token and user email return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}") except Exception as e: current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True) - return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+exchange+authorization+code+for+tokens:+{str(e)}&provider={provider}") + return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}") except Exception as e: current_app.logger.error(f"Error handling connector callback: {e}") - return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+complete+connector+authentication:+{str(e)}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.") + return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.") @connectors_ns.route("/api/connectors/refresh") From 56256051d2a47f356c42fe97b7758c4f3b706a56 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 24 Sep 2025 22:59:53 +0100 Subject: [PATCH 06/11] fix: chunking --- application/agents/base.py | 10 +++++++++- application/retriever/classic_rag.py | 22 ++++++++++++++++++---- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/application/agents/base.py b/application/agents/base.py index 77729fe6..134de1c3 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -264,7 +264,15 @@ class BaseAgent(ABC): query: str, retrieved_data: List[Dict], ) -> List[Dict]: - docs_together = "\n".join([doc["text"] for doc in retrieved_data]) + docs_with_filenames = [] + for doc in retrieved_data: + filename = doc.get("filename") or doc.get("title") or doc.get("source") + if filename: + chunk_header = str(filename) + docs_with_filenames.append(f"{chunk_header}\n{doc['text']}") + else: + docs_with_filenames.append(doc["text"]) + docs_together = "\n\n".join(docs_with_filenames) p_chat_combine = system_prompt.replace("{summaries}", docs_together) messages_combine = [{"role": "system", "content": p_chat_combine}] diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 2ce863c2..f90a751c 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -1,4 +1,5 @@ import logging +import os from application.core.settings import settings from application.llm.llm_creator import LLMCreator @@ -141,15 +142,28 @@ class ClassicRAG(BaseRetriever): title = metadata.get( "title", metadata.get("post_title", page_content) ) - if isinstance(title, str): - title = title.split("/")[-1] + if not isinstance(title, str): + title = str(title) + title = title.split("/")[-1] + + filename = ( + metadata.get("filename") + or metadata.get("file_name") + or metadata.get("source") + ) + if isinstance(filename, str): + filename = os.path.basename(filename) or filename else: - title = str(title).split("/")[-1] + filename = title + if not filename: + filename = title + source_path = metadata.get("source") or vectorstore_id all_docs.append( { "title": title, "text": page_content, - "source": metadata.get("source") or vectorstore_id, + "source": source_path, + "filename": filename, } ) except Exception as e: From ac66d77512ca512cc5b21fa44d0e4f55d9a420cd Mon Sep 17 00:00:00 2001 From: ManishMadan2882 Date: Thu, 25 Sep 2025 03:45:12 +0530 Subject: [PATCH 07/11] (fix:oauth) handle access denied --- application/api/connector/routes.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/application/api/connector/routes.py b/application/api/connector/routes.py index e65bee55..fb22fe90 100644 --- a/application/api/connector/routes.py +++ b/application/api/connector/routes.py @@ -276,7 +276,11 @@ class ConnectorsCallback(Resource): error = request.args.get('error') if error: - return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}") + if error == "access_denied": + return redirect(f"/api/connectors/callback-status?status=cancelled&message=Authentication+was+cancelled.+You+can+try+again+if+you'd+like+to+connect+your+account.&provider={provider}") + else: + current_app.logger.warning(f"OAuth error in callback: {error}") + return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}") if not authorization_code: return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}") @@ -644,20 +648,23 @@ class ConnectorCallbackStatus(Resource): .container {{ max-width: 600px; margin: 0 auto; }} .success {{ color: #4CAF50; }} .error {{ color: #F44336; }} + .cancelled {{ color: #FF9800; }}