From 65e57be4ddfd273aea572c15fbc6d09d4e9c6845 Mon Sep 17 00:00:00 2001 From: Siddhant Rai <47355538+siiddhantt@users.noreply.github.com> Date: Fri, 13 Mar 2026 21:28:50 +0530 Subject: [PATCH] feat: dynamic config rendering + mcp tool enhancement (#2286) * feat: enhance modal functionality and configuration handling - Updated WrapperModal to improve click outside detection for closing the modal. - Refactored ToolConfig to utilize ConfigFieldSpec for better configuration management. - Added validation and dynamic handling of configuration fields in ToolConfig. - Introduced reconnect functionality for MCP tools in the Tools component. - Enhanced user experience with improved error handling and loading states. - Updated types for better type safety and clarity in configuration requirements. * refactor: reorganize imports and improve conditional formatting * fix: revert API_URL to use backend service name in docker-compose * feat: add MCP auth status endpoint and integrate into user service and tools * feat: implement logging for Brave, Postgres, and Telegram tools; add transport sanitization and credential extraction for MCP --------- Co-authored-by: Alex --- application/agents/base.py | 13 +- application/agents/tools/brave.py | 8 +- application/agents/tools/mcp_tool.py | 591 ++++++++++++++--------- application/agents/tools/ntfy.py | 15 +- application/agents/tools/postgres.py | 45 +- application/agents/tools/telegram.py | 13 +- application/agents/tools/tool_manager.py | 2 +- application/api/user/tools/mcp.py | 253 +++++++--- application/api/user/tools/routes.py | 284 ++++++++--- application/core/settings.py | 1 + application/worker.py | 48 +- frontend/package.json | 1 + frontend/src/api/endpoints.ts | 1 + frontend/src/api/services/userService.ts | 2 + frontend/src/components/ConfigFields.tsx | 149 ++++++ frontend/src/components/ContextMenu.tsx | 34 +- frontend/src/components/ui/input.tsx | 23 + frontend/src/components/ui/label.tsx | 22 + frontend/src/components/ui/select.tsx | 12 +- frontend/src/index.css | 36 +- frontend/src/locale/en.json | 12 + frontend/src/modals/ConfigToolModal.tsx | 155 ++++-- frontend/src/modals/MCPServerModal.tsx | 422 +++++++++++----- frontend/src/modals/WrapperModal.tsx | 12 +- frontend/src/modals/types/index.ts | 18 +- frontend/src/settings/ToolConfig.tsx | 280 ++++++----- frontend/src/settings/Tools.tsx | 142 ++++-- frontend/src/settings/types/index.ts | 6 +- tests/requirements.txt | 1 + 29 files changed, 1805 insertions(+), 796 deletions(-) create mode 100644 frontend/src/components/ConfigFields.tsx create mode 100644 frontend/src/components/ui/input.tsx create mode 100644 frontend/src/components/ui/label.tsx diff --git a/application/agents/base.py b/application/agents/base.py index 2795847b..ee55a449 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -16,6 +16,7 @@ from application.core.settings import settings from application.llm.handlers.handler_creator import LLMHandlerCreator from application.llm.llm_creator import LLMCreator from application.logging import build_stack_data, log_activity, LogContext +from application.security.encryption import decrypt_credentials logger = logging.getLogger(__name__) @@ -264,12 +265,18 @@ class BaseAgent(ABC): ) else: tool_config = tool_data["config"].copy() if tool_data["config"] else {} - # Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool) - # Use MongoDB _id if available, otherwise fall back to enumerated tool_id - + if tool_config.get("encrypted_credentials") and self.user: + decrypted = decrypt_credentials( + tool_config["encrypted_credentials"], self.user + ) + tool_config.update(decrypted) + tool_config["auth_credentials"] = decrypted + tool_config.pop("encrypted_credentials", None) tool_config["tool_id"] = str(tool_data.get("_id", tool_id)) if hasattr(self, "conversation_id") and self.conversation_id: tool_config["conversation_id"] = self.conversation_id + if tool_data["name"] == "mcp_tool": + tool_config["query_mode"] = True tool = tm.load_tool( tool_data["name"], tool_config=tool_config, diff --git a/application/agents/tools/brave.py b/application/agents/tools/brave.py index 4888fd44..66b21b10 100644 --- a/application/agents/tools/brave.py +++ b/application/agents/tools/brave.py @@ -46,7 +46,7 @@ class BraveSearchTool(Tool): """ Performs a web search using the Brave Search API. """ - logger.info(f"Brave web search: {query}") + logger.debug("Performing Brave web search for: %s", query) url = f"{self.base_url}/web/search" @@ -99,7 +99,7 @@ class BraveSearchTool(Tool): """ Performs an image search using the Brave Search API. """ - logger.info(f"Brave image search: {query}") + logger.debug("Performing Brave image search for: %s", query) url = f"{self.base_url}/images/search" @@ -182,6 +182,10 @@ class BraveSearchTool(Tool): return { "token": { "type": "string", + "label": "API Key", "description": "Brave Search API key for authentication", + "required": True, + "secret": True, + "order": 1, }, } diff --git a/application/agents/tools/mcp_tool.py b/application/agents/tools/mcp_tool.py index f70dbe80..265688ea 100644 --- a/application/agents/tools/mcp_tool.py +++ b/application/agents/tools/mcp_tool.py @@ -1,20 +1,12 @@ import asyncio import base64 +import concurrent.futures import json import logging import time from typing import Any, Dict, List, Optional from urllib.parse import parse_qs, urlparse -from application.agents.tools.base import Tool -from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task -from application.cache import get_redis_instance - -from application.core.mongo_db import MongoDB - -from application.core.settings import settings - -from application.security.encryption import decrypt_credentials from fastmcp import Client from fastmcp.client.auth import BearerAuth from fastmcp.client.transports import ( @@ -24,10 +16,16 @@ from fastmcp.client.transports import ( ) from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken - from pydantic import AnyHttpUrl, ValidationError from redis import Redis +from application.agents.tools.base import Tool +from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task +from application.cache import get_redis_instance +from application.core.mongo_db import MongoDB +from application.core.settings import settings +from application.security.encryption import decrypt_credentials + logger = logging.getLogger(__name__) mongo = MongoDB.get_client() @@ -58,6 +56,7 @@ class MCPTool(Tool): - args: Arguments for STDIO transport - oauth_scopes: OAuth scopes for oauth auth type - oauth_client_name: OAuth client name for oauth auth type + - query_mode: If True, use non-interactive OAuth (fail-fast on 401) user_id: User ID for decrypting credentials (required if encrypted_credentials exist) """ self.config = config @@ -78,23 +77,40 @@ class MCPTool(Tool): self.oauth_scopes = config.get("oauth_scopes", []) self.oauth_task_id = config.get("oauth_task_id", None) self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP") - self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback" + self.redirect_uri = self._resolve_redirect_uri(config.get("redirect_uri")) self.available_tools = [] self._cache_key = self._generate_cache_key() self._client = None - - # Only validate and setup if server_url is provided and not OAuth + self.query_mode = config.get("query_mode", False) if self.server_url and self.auth_type != "oauth": self._setup_client() + def _resolve_redirect_uri(self, configured_redirect_uri: Optional[str]) -> str: + if configured_redirect_uri: + return configured_redirect_uri.rstrip("/") + + explicit = getattr(settings, "MCP_OAUTH_REDIRECT_URI", None) + if explicit: + return explicit.rstrip("/") + + connector_base = getattr(settings, "CONNECTOR_REDIRECT_BASE_URI", None) + if connector_base: + parsed = urlparse(connector_base) + if parsed.scheme and parsed.netloc: + return f"{parsed.scheme}://{parsed.netloc}/api/mcp_server/callback" + + return f"{settings.API_URL.rstrip('/')}/api/mcp_server/callback" + def _generate_cache_key(self) -> str: """Generate a unique cache key for this MCP server configuration.""" auth_key = "" if self.auth_type == "oauth": scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none" - auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}" + auth_key = ( + f"oauth:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}" + ) elif self.auth_type in ["bearer"]: token = self.auth_credentials.get( "bearer_token", "" @@ -111,11 +127,10 @@ class MCPTool(Tool): return f"{self.server_url}#{self.transport_type}#{auth_key}" def _setup_client(self): - """Setup FastMCP client with proper transport and authentication.""" global _mcp_clients_cache if self._cache_key in _mcp_clients_cache: cached_data = _mcp_clients_cache[self._cache_key] - if time.time() - cached_data["created_at"] < 1800: + if time.time() - cached_data["created_at"] < 300: self._client = cached_data["client"] return else: @@ -125,15 +140,25 @@ class MCPTool(Tool): if self.auth_type == "oauth": redis_client = get_redis_instance() - auth = DocsGPTOAuth( - mcp_url=self.server_url, - scopes=self.oauth_scopes, - redis_client=redis_client, - redirect_uri=self.redirect_uri, - task_id=self.oauth_task_id, - db=db, - user_id=self.user_id, - ) + if self.query_mode: + auth = NonInteractiveOAuth( + mcp_url=self.server_url, + scopes=self.oauth_scopes, + redis_client=redis_client, + redirect_uri=self.redirect_uri, + db=db, + user_id=self.user_id, + ) + else: + auth = DocsGPTOAuth( + mcp_url=self.server_url, + scopes=self.oauth_scopes, + redis_client=redis_client, + redirect_uri=self.redirect_uri, + task_id=self.oauth_task_id, + db=db, + user_id=self.user_id, + ) elif self.auth_type == "bearer": token = self.auth_credentials.get( "bearer_token", "" @@ -235,38 +260,53 @@ class MCPTool(Tool): else: raise Exception(f"Unknown operation: {operation}") + _ERROR_MAP = [ + (concurrent.futures.TimeoutError, lambda op, t, _: f"Timed out after {t}s"), + (ConnectionRefusedError, lambda *_: "Connection refused"), + ] + + _ERROR_PATTERNS = { + ("403", "Forbidden"): "Access denied (403 Forbidden)", + ("401", "Unauthorized"): "Authentication failed (401 Unauthorized)", + ("ECONNREFUSED",): "Connection refused", + ("SSL", "certificate"): "SSL/TLS error", + } + def _run_async_operation(self, operation: str, *args, **kwargs): - """Run async operation in sync context.""" try: try: - loop = asyncio.get_running_loop() - import concurrent.futures - - def run_in_thread(): - new_loop = asyncio.new_event_loop() - asyncio.set_event_loop(new_loop) - try: - return new_loop.run_until_complete( - self._execute_with_client(operation, *args, **kwargs) - ) - finally: - new_loop.close() - + asyncio.get_running_loop() with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(run_in_thread) + future = executor.submit( + self._run_in_new_loop, operation, *args, **kwargs + ) return future.result(timeout=self.timeout) except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - return loop.run_until_complete( - self._execute_with_client(operation, *args, **kwargs) - ) - finally: - loop.close() + return self._run_in_new_loop(operation, *args, **kwargs) except Exception as e: - logger.error(f"Error occurred while running async operation: {e}") - raise + raise self._map_error(operation, e) from e + raise self._map_error(operation, e) from e + + def _run_in_new_loop(self, operation, *args, **kwargs): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete( + self._execute_with_client(operation, *args, **kwargs) + ) + finally: + loop.close() + + def _map_error(self, operation: str, exc: Exception) -> Exception: + for exc_type, msg_fn in self._ERROR_MAP: + if isinstance(exc, exc_type): + return Exception(msg_fn(operation, self.timeout, exc)) + error_msg = str(exc) + for patterns, friendly in self._ERROR_PATTERNS.items(): + if any(p.lower() in error_msg.lower() for p in patterns): + return Exception(friendly) + logger.error("MCP %s failed: %s", operation, exc) + return exc def discover_tools(self) -> List[Dict]: """ @@ -287,16 +327,6 @@ class MCPTool(Tool): raise Exception(f"Failed to discover tools from MCP server: {str(e)}") def execute_action(self, action_name: str, **kwargs) -> Any: - """ - Execute an action on the remote MCP server using FastMCP. - - Args: - action_name: Name of the action to execute - **kwargs: Parameters for the action - - Returns: - Result from the MCP server - """ if not self.server_url: raise Exception("No MCP server configured") if not self._client: @@ -312,7 +342,37 @@ class MCPTool(Tool): ) return self._format_result(result) except Exception as e: - raise Exception(f"Failed to execute action '{action_name}': {str(e)}") + error_msg = str(e) + lower_msg = error_msg.lower() + is_auth_error = ( + "401" in error_msg + or "unauthorized" in lower_msg + or "session expired" in lower_msg + or "re-authorize" in lower_msg + ) + if is_auth_error: + if self.auth_type == "oauth": + raise Exception( + f"Action '{action_name}' failed: OAuth session expired. " + "Please re-authorize this MCP server in tool settings." + ) from e + global _mcp_clients_cache + _mcp_clients_cache.pop(self._cache_key, None) + self._client = None + self._setup_client() + try: + result = self._run_async_operation( + "call_tool", action_name, **cleaned_kwargs + ) + return self._format_result(result) + except Exception as retry_e: + raise Exception( + f"Action '{action_name}' failed after re-auth attempt: {retry_e}. " + "Your credentials may have expired — please re-authorize in tool settings." + ) from retry_e + raise Exception( + f"Failed to execute action '{action_name}': {error_msg}" + ) from e def _format_result(self, result) -> Dict: """Format FastMCP result to match expected format.""" @@ -335,23 +395,35 @@ class MCPTool(Tool): return result def test_connection(self) -> Dict: - """ - Test the connection to the MCP server and validate functionality. - - Returns: - Dictionary with connection test results including tool count - """ if not self.server_url: return { "success": False, - "message": "No MCP server URL configured", + "message": "No server URL configured", + "tools_count": 0, + } + try: + parsed = urlparse(self.server_url) + if parsed.scheme not in ("http", "https"): + return { + "success": False, + "message": f"Invalid URL scheme '{parsed.scheme}' — use http:// or https://", + "tools_count": 0, + } + except Exception: + return { + "success": False, + "message": "Invalid URL format", "tools_count": 0, - "transport_type": self.transport_type, - "auth_type": self.auth_type, - "error_type": "ConfigurationError", } if not self._client: - self._setup_client() + try: + self._setup_client() + except Exception as e: + return { + "success": False, + "message": f"Client init failed: {str(e)}", + "tools_count": 0, + } try: if self.auth_type == "oauth": return self._test_oauth_connection() @@ -362,56 +434,94 @@ class MCPTool(Tool): "success": False, "message": f"Connection failed: {str(e)}", "tools_count": 0, - "transport_type": self.transport_type, - "auth_type": self.auth_type, - "error_type": type(e).__name__, } def _test_regular_connection(self) -> Dict: - """Test connection for non-OAuth auth types.""" + ping_ok = False + ping_error = None try: self._run_async_operation("ping") - ping_success = True - except Exception: - ping_success = False - tools = self.discover_tools() + ping_ok = True + except Exception as e: + ping_error = str(e) - message = f"Successfully connected to MCP server. Found {len(tools)} tools." - if not ping_success: - message += " (Ping not supported, but tool discovery worked)" - return { - "success": True, - "message": message, - "tools_count": len(tools), - "transport_type": self.transport_type, - "auth_type": self.auth_type, - "ping_supported": ping_success, - "tools": [tool.get("name", "unknown") for tool in tools], - } - - def _test_oauth_connection(self) -> Dict: - """Test connection for OAuth auth type with proper async handling.""" try: - task = mcp_oauth_task.delay(config=self.config, user=self.user_id) - if not task: - raise Exception("Failed to start OAuth authentication") - return { - "success": True, - "requires_oauth": True, - "task_id": task.id, - "status": "pending", - "message": "OAuth flow started", - } + tools = self.discover_tools() except Exception as e: return { "success": False, - "message": f"OAuth connection failed: {str(e)}", + "message": f"Connection failed: {ping_error or str(e)}", "tools_count": 0, - "transport_type": self.transport_type, - "auth_type": self.auth_type, - "error_type": type(e).__name__, } + if not tools and not ping_ok: + return { + "success": False, + "message": f"Connection failed: {ping_error or 'No tools found'}", + "tools_count": 0, + } + + return { + "success": True, + "message": f"Connected — found {len(tools)} tool{'s' if len(tools) != 1 else ''}.", + "tools_count": len(tools), + "tools": [ + { + "name": tool.get("name", "unknown"), + "description": tool.get("description", ""), + } + for tool in tools + ], + } + + def _test_oauth_connection(self) -> Dict: + storage = DBTokenStorage( + server_url=self.server_url, user_id=self.user_id, db_client=db + ) + loop = asyncio.new_event_loop() + try: + tokens = loop.run_until_complete(storage.get_tokens()) + finally: + loop.close() + + if tokens and tokens.access_token: + self.query_mode = True + _mcp_clients_cache.pop(self._cache_key, None) + self._client = None + self._setup_client() + try: + tools = self.discover_tools() + return { + "success": True, + "message": f"Connected — found {len(tools)} tool{'s' if len(tools) != 1 else ''}.", + "tools_count": len(tools), + "tools": [ + { + "name": t.get("name", "unknown"), + "description": t.get("description", ""), + } + for t in tools + ], + } + except Exception as e: + logger.warning("OAuth token validation failed: %s", e) + _mcp_clients_cache.pop(self._cache_key, None) + self._client = None + + return self._start_oauth_task() + + def _start_oauth_task(self) -> Dict: + task_config = self.config.copy() + task_config.pop("query_mode", None) + result = mcp_oauth_task.delay(task_config, self.user_id) + return { + "success": False, + "requires_oauth": True, + "task_id": result.id, + "message": "OAuth authorization required.", + "tools_count": 0, + } + def get_actions_metadata(self) -> List[Dict]: """ Get metadata for all available actions. @@ -457,110 +567,88 @@ class MCPTool(Tool): return actions def get_config_requirements(self) -> Dict: - """Get configuration requirements for the MCP tool.""" - transport_enum = ["auto", "sse", "http"] - transport_help = { - "auto": "Automatically detect best transport", - "sse": "Server-Sent Events (for real-time streaming)", - "http": "HTTP streaming (recommended for production)", - } return { "server_url": { "type": "string", - "description": "URL of the remote MCP server (e.g., https://api.example.com/mcp or https://docs.mcp.cloudflare.com/sse)", + "label": "Server URL", + "description": "URL of the remote MCP server", "required": True, - }, - "transport_type": { - "type": "string", - "description": "Transport type for connection", - "enum": transport_enum, - "default": "auto", - "required": False, - "help": { - **transport_help, - }, + "secret": False, + "order": 1, }, "auth_type": { "type": "string", - "description": "Authentication type", + "label": "Authentication Type", + "description": "Authentication method for the MCP server", "enum": ["none", "bearer", "oauth", "api_key", "basic"], "default": "none", "required": True, - "help": { - "none": "No authentication", - "bearer": "Bearer token authentication", - "oauth": "OAuth 2.1 authentication (with frontend integration)", - "api_key": "API key authentication", - "basic": "Basic authentication", - }, + "secret": False, + "order": 2, }, - "auth_credentials": { - "type": "object", - "description": "Authentication credentials (varies by auth_type)", + "api_key": { + "type": "string", + "label": "API Key", + "description": "API key for authentication", "required": False, - "properties": { - "bearer_token": { - "type": "string", - "description": "Bearer token for bearer auth", - }, - "access_token": { - "type": "string", - "description": "Access token for OAuth (if pre-obtained)", - }, - "api_key": { - "type": "string", - "description": "API key for api_key auth", - }, - "api_key_header": { - "type": "string", - "description": "Header name for API key (default: X-API-Key)", - }, - "username": { - "type": "string", - "description": "Username for basic auth", - }, - "password": { - "type": "string", - "description": "Password for basic auth", - }, - }, + "secret": True, + "order": 3, + "depends_on": {"auth_type": "api_key"}, + }, + "api_key_header": { + "type": "string", + "label": "API Key Header", + "description": "Header name for API key (default: X-API-Key)", + "default": "X-API-Key", + "required": False, + "secret": False, + "order": 4, + "depends_on": {"auth_type": "api_key"}, + }, + "bearer_token": { + "type": "string", + "label": "Bearer Token", + "description": "Bearer token for authentication", + "required": False, + "secret": True, + "order": 3, + "depends_on": {"auth_type": "bearer"}, + }, + "username": { + "type": "string", + "label": "Username", + "description": "Username for basic authentication", + "required": False, + "secret": False, + "order": 3, + "depends_on": {"auth_type": "basic"}, + }, + "password": { + "type": "string", + "label": "Password", + "description": "Password for basic authentication", + "required": False, + "secret": True, + "order": 4, + "depends_on": {"auth_type": "basic"}, }, "oauth_scopes": { - "type": "array", - "description": "OAuth scopes to request (for oauth auth_type)", - "items": {"type": "string"}, - "required": False, - "default": [], - }, - "oauth_client_name": { "type": "string", - "description": "Client name for OAuth registration (for oauth auth_type)", - "default": "DocsGPT-MCP", - "required": False, - }, - "headers": { - "type": "object", - "description": "Custom headers to send with requests", + "label": "OAuth Scopes", + "description": "Comma-separated OAuth scopes to request", "required": False, + "secret": False, + "order": 3, + "depends_on": {"auth_type": "oauth"}, }, "timeout": { - "type": "integer", - "description": "Request timeout in seconds", + "type": "number", + "label": "Timeout (seconds)", + "description": "Request timeout in seconds (1-300)", "default": 30, - "minimum": 1, - "maximum": 300, - "required": False, - }, - "command": { - "type": "string", - "description": "Command to run for STDIO transport (e.g., 'python')", - "required": False, - }, - "args": { - "type": "array", - "description": "Arguments for STDIO command", - "items": {"type": "string"}, "required": False, + "secret": False, + "order": 10, }, } @@ -582,23 +670,8 @@ class DocsGPTOAuth(OAuthClientProvider): user_id=None, db=None, additional_client_metadata: dict[str, Any] | None = None, + skip_redirect_validation: bool = False, ): - """ - Initialize custom OAuth client provider for DocsGPT. - - Args: - mcp_url: Full URL to the MCP endpoint - redirect_uri: Custom redirect URI for DocsGPT frontend - redis_client: Redis client for storing auth state - redis_prefix: Prefix for Redis keys - task_id: Task ID for tracking auth status - scopes: OAuth scopes to request - client_name: Name for this client during registration - user_id: User ID for token storage - db: Database instance for token storage - additional_client_metadata: Extra fields for OAuthClientMetadata - """ - self.redirect_uri = redirect_uri self.redis_client = redis_client self.redis_prefix = redis_prefix @@ -621,7 +694,10 @@ class DocsGPTOAuth(OAuthClientProvider): ) storage = DBTokenStorage( - server_url=self.server_base_url, user_id=self.user_id, db_client=self.db + server_url=self.server_base_url, + user_id=self.user_id, + db_client=self.db, + expected_redirect_uri=None if skip_redirect_validation else redirect_uri, ) super().__init__( @@ -653,22 +729,20 @@ class DocsGPTOAuth(OAuthClientProvider): async def redirect_handler(self, authorization_url: str) -> None: """Store auth URL and state in Redis for frontend to use.""" auth_url, state = self._process_auth_url(authorization_url) - logging.info( - "[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state - ) + logger.info("Processed auth_url: %s, state: %s", auth_url, state) self.auth_url = auth_url self.extracted_state = state if self.redis_client and self.extracted_state: key = f"{self.redis_prefix}auth_url:{self.extracted_state}" self.redis_client.setex(key, 600, auth_url) - logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key) + logger.info("Stored auth_url in Redis: %s", key) if self.task_id: status_key = f"mcp_oauth_status:{self.task_id}" status_data = { "status": "requires_redirect", - "message": "OAuth authorization required", + "message": "Authorization required", "authorization_url": self.auth_url, "state": self.extracted_state, "requires_oauth": True, @@ -688,7 +762,7 @@ class DocsGPTOAuth(OAuthClientProvider): status_key = f"mcp_oauth_status:{self.task_id}" status_data = { "status": "awaiting_callback", - "message": "Waiting for OAuth callback...", + "message": "Waiting for authorization...", "authorization_url": self.auth_url, "state": self.extracted_state, "requires_oauth": True, @@ -713,7 +787,7 @@ class DocsGPTOAuth(OAuthClientProvider): if self.task_id: status_data = { "status": "callback_received", - "message": "OAuth callback received, completing authentication...", + "message": "Completing authentication...", "task_id": self.task_id, } self.redis_client.setex(status_key, 600, json.dumps(status_data)) @@ -733,14 +807,44 @@ class DocsGPTOAuth(OAuthClientProvider): await asyncio.sleep(poll_interval) self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}") self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}") - raise Exception("OAuth callback timeout: no code received within 5 minutes") + raise Exception("OAuth timeout: no code received within 5 minutes") + + +class NonInteractiveOAuth(DocsGPTOAuth): + """OAuth provider that fails fast on 401 instead of starting interactive auth. + + Used during query execution to prevent the streaming response from blocking + while waiting for user authorization that will never come. + """ + + def __init__(self, **kwargs): + kwargs.setdefault("task_id", None) + kwargs["skip_redirect_validation"] = True + super().__init__(**kwargs) + + async def redirect_handler(self, authorization_url: str) -> None: + raise Exception( + "OAuth session expired — please re-authorize this MCP server in tool settings." + ) + + async def callback_handler(self) -> tuple[str, str | None]: + raise Exception( + "OAuth session expired — please re-authorize this MCP server in tool settings." + ) class DBTokenStorage(TokenStorage): - def __init__(self, server_url: str, user_id: str, db_client): + def __init__( + self, + server_url: str, + user_id: str, + db_client, + expected_redirect_uri: Optional[str] = None, + ): self.server_url = server_url self.user_id = user_id self.db_client = db_client + self.expected_redirect_uri = expected_redirect_uri self.collection = db_client["connector_sessions"] @staticmethod @@ -759,10 +863,9 @@ class DBTokenStorage(TokenStorage): if not doc or "tokens" not in doc: return None try: - tokens = OAuthToken.model_validate(doc["tokens"]) - return tokens + return OAuthToken.model_validate(doc["tokens"]) except ValidationError as e: - logging.error(f"Could not load tokens: {e}") + logger.error("Could not load tokens: %s", e) return None async def set_tokens(self, tokens: OAuthToken) -> None: @@ -772,28 +875,38 @@ class DBTokenStorage(TokenStorage): {"$set": {"tokens": tokens.model_dump()}}, True, ) - logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}") + logger.info("Saved tokens for %s", self.get_base_url(self.server_url)) async def get_client_info(self) -> OAuthClientInformationFull | None: doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key()) if not doc or "client_info" not in doc: + logger.debug( + "No client_info in DB for %s", self.get_base_url(self.server_url) + ) return None try: client_info = OAuthClientInformationFull.model_validate(doc["client_info"]) - tokens = await self.get_tokens() - if tokens is None: - logging.debug( - "No tokens found, clearing client info to force fresh registration." - ) - await asyncio.to_thread( - self.collection.update_one, - self.get_db_key(), - {"$unset": {"client_info": ""}}, - ) - return None + if self.expected_redirect_uri: + stored_uris = [ + str(uri).rstrip("/") for uri in client_info.redirect_uris + ] + expected_uri = self.expected_redirect_uri.rstrip("/") + if expected_uri not in stored_uris: + logger.warning( + "Redirect URI mismatch for %s: expected=%s stored=%s — clearing.", + self.get_base_url(self.server_url), + expected_uri, + stored_uris, + ) + await asyncio.to_thread( + self.collection.update_one, + self.get_db_key(), + {"$unset": {"client_info": "", "tokens": ""}}, + ) + return None return client_info except ValidationError as e: - logging.error(f"Could not load client info: {e}") + logger.error("Could not load client info: %s", e) return None def _serialize_client_info(self, info: dict) -> dict: @@ -809,17 +922,17 @@ class DBTokenStorage(TokenStorage): {"$set": {"client_info": serialized_info}}, True, ) - logging.info(f"Saved client info for {self.get_base_url(self.server_url)}") + logger.info("Saved client info for %s", self.get_base_url(self.server_url)) async def clear(self) -> None: await asyncio.to_thread(self.collection.delete_one, self.get_db_key()) - logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}") + logger.info("Cleared OAuth cache for %s", self.get_base_url(self.server_url)) @classmethod async def clear_all(cls, db_client) -> None: collection = db_client["connector_sessions"] await asyncio.to_thread(collection.delete_many, {}) - logging.info("Cleared all OAuth client cache data.") + logger.info("Cleared all OAuth client cache data.") class MCPOAuthManager: @@ -858,7 +971,7 @@ class MCPOAuthManager: return True except Exception as e: - logging.error(f"Error handling OAuth callback: {e}") + logger.error("Error handling OAuth callback: %s", e) return False def get_oauth_status(self, task_id: str) -> Dict[str, Any]: diff --git a/application/agents/tools/ntfy.py b/application/agents/tools/ntfy.py index e968dfc4..9a2d44ca 100644 --- a/application/agents/tools/ntfy.py +++ b/application/agents/tools/ntfy.py @@ -116,12 +116,13 @@ class NtfyTool(Tool): ] def get_config_requirements(self): - """ - Specify the configuration requirements. - - Returns: - dict: Dictionary describing required config parameters. - """ return { - "token": {"type": "string", "description": "Access token for authentication"}, + "token": { + "type": "string", + "label": "Access Token", + "description": "Ntfy access token for authentication", + "required": True, + "secret": True, + "order": 1, + }, } \ No newline at end of file diff --git a/application/agents/tools/postgres.py b/application/agents/tools/postgres.py index 045bec62..d9d5a2b4 100644 --- a/application/agents/tools/postgres.py +++ b/application/agents/tools/postgres.py @@ -28,6 +28,9 @@ class PostgresTool(Tool): return actions[action_name](**kwargs) def _execute_sql(self, sql_query): + """ + Executes an SQL query against the PostgreSQL database using a connection string. + """ conn = None try: conn = psycopg2.connect(self.connection_string) @@ -36,7 +39,9 @@ class PostgresTool(Tool): conn.commit() if sql_query.strip().lower().startswith("select"): - column_names = [desc[0] for desc in cur.description] if cur.description else [] + column_names = ( + [desc[0] for desc in cur.description] if cur.description else [] + ) results = [] rows = cur.fetchall() for row in rows: @@ -44,7 +49,9 @@ class PostgresTool(Tool): response_data = {"data": results, "column_names": column_names} else: row_count = cur.rowcount - response_data = {"message": f"Query executed successfully, {row_count} rows affected."} + response_data = { + "message": f"Query executed successfully, {row_count} rows affected." + } cur.close() return { @@ -55,7 +62,7 @@ class PostgresTool(Tool): except psycopg2.Error as e: error_message = f"Database error: {e}" - logger.error(error_message) + logger.error("PostgreSQL execute_sql error: %s", e) return { "status_code": 500, "message": "Failed to execute SQL query.", @@ -69,12 +76,13 @@ class PostgresTool(Tool): """ Retrieves the schema of the PostgreSQL database using a connection string. """ - conn = None # Initialize conn to None for error handling + conn = None try: conn = psycopg2.connect(self.connection_string) cur = conn.cursor() - cur.execute(""" + cur.execute( + """ SELECT table_name, column_name, @@ -88,19 +96,22 @@ class PostgresTool(Tool): ORDER BY table_name, ordinal_position; - """) + """ + ) schema_data = {} for row in cur.fetchall(): table_name, column_name, data_type, column_default, is_nullable = row if table_name not in schema_data: schema_data[table_name] = [] - schema_data[table_name].append({ - "column_name": column_name, - "data_type": data_type, - "column_default": column_default, - "is_nullable": is_nullable - }) + schema_data[table_name].append( + { + "column_name": column_name, + "data_type": data_type, + "column_default": column_default, + "is_nullable": is_nullable, + } + ) cur.close() return { @@ -111,7 +122,7 @@ class PostgresTool(Tool): except psycopg2.Error as e: error_message = f"Database error: {e}" - logger.error(error_message) + logger.error("PostgreSQL get_schema error: %s", e) return { "status_code": 500, "message": "Failed to retrieve database schema.", @@ -159,6 +170,10 @@ class PostgresTool(Tool): return { "token": { "type": "string", - "description": "PostgreSQL database connection string (e.g., 'postgresql://user:password@host:port/dbname')", + "label": "Connection String", + "description": "PostgreSQL database connection string", + "required": True, + "secret": True, + "order": 1, }, - } \ No newline at end of file + } diff --git a/application/agents/tools/telegram.py b/application/agents/tools/telegram.py index a06eed4b..d4381370 100644 --- a/application/agents/tools/telegram.py +++ b/application/agents/tools/telegram.py @@ -28,14 +28,14 @@ class TelegramTool(Tool): return actions[action_name](**kwargs) def _send_message(self, text, chat_id): - logger.info(f"Telegram: sending message to {chat_id}") + logger.debug("Sending Telegram message to chat_id=%s", chat_id) url = f"https://api.telegram.org/bot{self.token}/sendMessage" payload = {"chat_id": chat_id, "text": text} response = requests.post(url, data=payload) return {"status_code": response.status_code, "message": "Message sent"} def _send_image(self, image_url, chat_id): - logger.info(f"Telegram: sending image to {chat_id}") + logger.debug("Sending Telegram image to chat_id=%s", chat_id) url = f"https://api.telegram.org/bot{self.token}/sendPhoto" payload = {"chat_id": chat_id, "photo": image_url} response = requests.post(url, data=payload) @@ -85,5 +85,12 @@ class TelegramTool(Tool): def get_config_requirements(self): return { - "token": {"type": "string", "description": "Bot token for authentication"}, + "token": { + "type": "string", + "label": "Bot Token", + "description": "Telegram bot token for authentication", + "required": True, + "secret": True, + "order": 1, + }, } diff --git a/application/agents/tools/tool_manager.py b/application/agents/tools/tool_manager.py index 855f1b53..08ef30a4 100644 --- a/application/agents/tools/tool_manager.py +++ b/application/agents/tools/tool_manager.py @@ -36,7 +36,7 @@ class ToolManager: def execute_action(self, tool_name, action_name, user_id=None, **kwargs): if tool_name not in self.tools: raise ValueError(f"Tool '{tool_name}' not loaded") - if tool_name in {"mcp_tool", "memory", "todo_list"} and user_id: + if tool_name in {"mcp_tool", "memory", "todo_list", "notes"} and user_id: tool_config = self.config.get(tool_name, {}) tool = self.load_tool(tool_name, tool_config, user_id) return tool.execute_action(action_name, **kwargs) diff --git a/application/api/user/tools/mcp.py b/application/api/user/tools/mcp.py index 0db2700f..3f9b2873 100644 --- a/application/api/user/tools/mcp.py +++ b/application/api/user/tools/mcp.py @@ -1,21 +1,67 @@ """Tool management MCP server integration.""" import json -from urllib.parse import unquote, urlencode +from urllib.parse import urlencode, urlparse from bson.objectid import ObjectId from flask import current_app, jsonify, make_response, redirect, request -from flask_restx import fields, Namespace, Resource +from flask_restx import Namespace, Resource, fields from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool from application.api import api from application.api.user.base import user_tools_collection +from application.api.user.tools.routes import transform_actions from application.cache import get_redis_instance -from application.security.encryption import encrypt_credentials +from application.core.mongo_db import MongoDB +from application.core.settings import settings +from application.security.encryption import decrypt_credentials, encrypt_credentials from application.utils import check_required_fields tools_mcp_ns = Namespace("tools", description="Tool management operations", path="/api") +_mongo = MongoDB.get_client() +_db = _mongo[settings.MONGO_DB_NAME] +_connector_sessions = _db["connector_sessions"] + +_ALLOWED_TRANSPORTS = {"auto", "sse", "http"} + + +def _sanitize_mcp_transport(config): + """Normalise and validate the transport_type field. + + Strips ``command`` / ``args`` keys that are only valid for local STDIO + transports and returns the cleaned transport type string. + """ + transport_type = (config.get("transport_type") or "auto").lower() + if transport_type not in _ALLOWED_TRANSPORTS: + raise ValueError(f"Unsupported transport_type: {transport_type}") + config.pop("command", None) + config.pop("args", None) + config["transport_type"] = transport_type + return transport_type + + +def _extract_auth_credentials(config): + """Build an ``auth_credentials`` dict from the raw MCP config.""" + auth_credentials = {} + auth_type = config.get("auth_type", "none") + + if auth_type == "api_key": + if config.get("api_key"): + auth_credentials["api_key"] = config["api_key"] + if config.get("api_key_header"): + auth_credentials["api_key_header"] = config["api_key_header"] + elif auth_type == "bearer": + if config.get("bearer_token"): + auth_credentials["bearer_token"] = config["bearer_token"] + elif auth_type == "basic": + if config.get("username"): + auth_credentials["username"] = config["username"] + if config.get("password"): + auth_credentials["password"] = config["password"] + + return auth_credentials + @tools_mcp_ns.route("/mcp_server/test") class TestMCPServerConfig(Resource): @@ -43,49 +89,35 @@ class TestMCPServerConfig(Resource): return missing_fields try: config = data["config"] - transport_type = (config.get("transport_type") or "auto").lower() - allowed_transports = {"auto", "sse", "http"} - if transport_type not in allowed_transports: + try: + _sanitize_mcp_transport(config) + except ValueError: return make_response( jsonify({"success": False, "error": "Unsupported transport_type"}), 400, ) - config.pop("command", None) - config.pop("args", None) - config["transport_type"] = transport_type - auth_credentials = {} - auth_type = config.get("auth_type", "none") - - if auth_type == "api_key" and "api_key" in config: - auth_credentials["api_key"] = config["api_key"] - if "api_key_header" in config: - auth_credentials["api_key_header"] = config["api_key_header"] - elif auth_type == "bearer" and "bearer_token" in config: - auth_credentials["bearer_token"] = config["bearer_token"] - elif auth_type == "basic": - if "username" in config: - auth_credentials["username"] = config["username"] - if "password" in config: - auth_credentials["password"] = config["password"] + auth_credentials = _extract_auth_credentials(config) test_config = config.copy() test_config["auth_credentials"] = auth_credentials mcp_tool = MCPTool(config=test_config, user_id=user) result = mcp_tool.test_connection() - # Sanitize the response to avoid exposing internal error details + if result.get("requires_oauth"): + return make_response(jsonify(result), 200) + if not result.get("success") and "message" in result: - current_app.logger.error(f"MCP connection test failed: {result.get('message')}") + current_app.logger.error( + f"MCP connection test failed: {result.get('message')}" + ) result["message"] = "Connection test failed" return make_response(jsonify(result), 200) except Exception as e: current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True) return make_response( - jsonify( - {"success": False, "error": "Connection test failed"} - ), + jsonify({"success": False, "error": "Connection test failed"}), 500, ) @@ -125,32 +157,16 @@ class MCPServerSave(Resource): return missing_fields try: config = data["config"] - transport_type = (config.get("transport_type") or "auto").lower() - allowed_transports = {"auto", "sse", "http"} - if transport_type not in allowed_transports: + try: + _sanitize_mcp_transport(config) + except ValueError: return make_response( jsonify({"success": False, "error": "Unsupported transport_type"}), 400, ) - config.pop("command", None) - config.pop("args", None) - config["transport_type"] = transport_type - auth_credentials = {} + auth_credentials = _extract_auth_credentials(config) auth_type = config.get("auth_type", "none") - if auth_type == "api_key": - if "api_key" in config and config["api_key"]: - auth_credentials["api_key"] = config["api_key"] - if "api_key_header" in config: - auth_credentials["api_key_header"] = config["api_key_header"] - elif auth_type == "bearer": - if "bearer_token" in config and config["bearer_token"]: - auth_credentials["bearer_token"] = config["bearer_token"] - elif auth_type == "basic": - if "username" in config and config["username"]: - auth_credentials["username"] = config["username"] - if "password" in config and config["password"]: - auth_credentials["password"] = config["password"] mcp_config = config.copy() mcp_config["auth_credentials"] = auth_credentials @@ -188,30 +204,39 @@ class MCPServerSave(Resource): "No valid credentials provided for the selected authentication type" ) storage_config = config.copy() + + tool_id = data.get("id") + existing_encrypted = None + if tool_id: + existing_doc = user_tools_collection.find_one( + {"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"} + ) + if existing_doc: + existing_encrypted = existing_doc.get("config", {}).get( + "encrypted_credentials" + ) + if auth_credentials: - encrypted_credentials_string = encrypt_credentials( + if existing_encrypted: + existing_secrets = decrypt_credentials(existing_encrypted, user) + existing_secrets.update(auth_credentials) + auth_credentials = existing_secrets + storage_config["encrypted_credentials"] = encrypt_credentials( auth_credentials, user ) - storage_config["encrypted_credentials"] = encrypted_credentials_string + elif existing_encrypted: + storage_config["encrypted_credentials"] = existing_encrypted + for field in [ "api_key", "bearer_token", "username", "password", "api_key_header", + "redirect_uri", ]: storage_config.pop(field, None) - transformed_actions = [] - for action in actions_metadata: - action["active"] = True - if "parameters" in action: - if "properties" in action["parameters"]: - for param_name, param_details in action["parameters"][ - "properties" - ].items(): - param_details["filled_by_llm"] = True - param_details["value"] = "" - transformed_actions.append(action) + transformed_actions = transform_actions(actions_metadata) tool_data = { "name": "mcp_tool", "displayName": data["displayName"], @@ -223,7 +248,6 @@ class MCPServerSave(Resource): "user": user, } - tool_id = data.get("id") if tool_id: result = user_tools_collection.update_one( {"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"}, @@ -258,9 +282,7 @@ class MCPServerSave(Resource): except Exception as e: current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True) return make_response( - jsonify( - {"success": False, "error": "Failed to save MCP server"} - ), + jsonify({"success": False, "error": "Failed to save MCP server"}), 500, ) @@ -291,7 +313,7 @@ class MCPOAuthCallback(Resource): params = { "status": "error", "message": f"OAuth error: {error}. Please try again and make sure to grant all requested permissions, including offline access.", - "provider": "mcp_tool" + "provider": "mcp_tool", } return redirect(f"/api/connectors/callback-status?{urlencode(params)}") if not code or not state: @@ -304,7 +326,6 @@ class MCPOAuthCallback(Resource): return redirect( "/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool" ) - code = unquote(code) manager = MCPOAuthManager(redis_client) success = manager.handle_oauth_callback(state, code, error) if success: @@ -327,10 +348,6 @@ class MCPOAuthCallback(Resource): @tools_mcp_ns.route("/mcp_server/oauth_status/") class MCPOAuthStatus(Resource): def get(self, task_id): - """ - Get current status of OAuth flow. - Frontend should poll this endpoint periodically. - """ try: redis_client = get_redis_instance() status_key = f"mcp_oauth_status:{task_id}" @@ -338,6 +355,14 @@ class MCPOAuthStatus(Resource): if status_data: status = json.loads(status_data) + if "tools" in status and isinstance(status["tools"], list): + status["tools"] = [ + { + "name": t.get("name", "unknown"), + "description": t.get("description", ""), + } + for t in status["tools"] + ] return make_response( jsonify({"success": True, "task_id": task_id, **status}) ) @@ -345,17 +370,93 @@ class MCPOAuthStatus(Resource): return make_response( jsonify( { - "success": False, - "error": "Task not found or expired", + "success": True, "task_id": task_id, + "status": "pending", + "message": "Waiting for OAuth to start...", } ), - 404, + 200, ) except Exception as e: current_app.logger.error( - f"Error getting OAuth status for task {task_id}: {str(e)}", exc_info=True + f"Error getting OAuth status for task {task_id}: {str(e)}", + exc_info=True, ) return make_response( - jsonify({"success": False, "error": "Failed to get OAuth status", "task_id": task_id}), 500 + jsonify( + { + "success": False, + "error": "Failed to get OAuth status", + "task_id": task_id, + } + ), + 500, + ) + + +@tools_mcp_ns.route("/mcp_server/auth_status") +class MCPAuthStatus(Resource): + @api.doc( + description="Batch check auth status for all MCP tools. " + "Lightweight DB-only check — no network calls to MCP servers." + ) + def get(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") + try: + mcp_tools = list( + user_tools_collection.find( + {"user": user, "name": "mcp_tool"}, + {"_id": 1, "config": 1}, + ) + ) + if not mcp_tools: + return make_response(jsonify({"success": True, "statuses": {}}), 200) + + oauth_server_urls = {} + statuses = {} + for tool in mcp_tools: + tool_id = str(tool["_id"]) + config = tool.get("config", {}) + auth_type = config.get("auth_type", "none") + if auth_type == "oauth": + server_url = config.get("server_url", "") + if server_url: + parsed = urlparse(server_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + oauth_server_urls[tool_id] = base_url + else: + statuses[tool_id] = "needs_auth" + else: + statuses[tool_id] = "configured" + + if oauth_server_urls: + unique_urls = list(set(oauth_server_urls.values())) + sessions = list( + _connector_sessions.find( + {"user_id": user, "server_url": {"$in": unique_urls}}, + {"server_url": 1, "tokens": 1}, + ) + ) + url_has_tokens = { + doc["server_url"]: bool(doc.get("tokens", {}).get("access_token")) + for doc in sessions + } + for tool_id, base_url in oauth_server_urls.items(): + if url_has_tokens.get(base_url): + statuses[tool_id] = "connected" + else: + statuses[tool_id] = "needs_auth" + + return make_response(jsonify({"success": True, "statuses": statuses}), 200) + except Exception as e: + current_app.logger.error( + "Error checking MCP auth status: %s", e, exc_info=True + ) + return make_response( + jsonify({"success": False, "error": "Failed to check auth status"}), + 500, ) diff --git a/application/api/user/tools/routes.py b/application/api/user/tools/routes.py index 760f0120..170b6e3c 100644 --- a/application/api/user/tools/routes.py +++ b/application/api/user/tools/routes.py @@ -15,6 +15,114 @@ tool_config = {} tool_manager = ToolManager(config=tool_config) +def _encrypt_secret_fields(config, config_requirements, user_id): + secret_keys = [ + key for key, spec in config_requirements.items() + if spec.get("secret") and key in config and config[key] + ] + if not secret_keys: + return config + + storage_config = config.copy() + secret_values = {k: config[k] for k in secret_keys} + storage_config["encrypted_credentials"] = encrypt_credentials(secret_values, user_id) + for key in secret_keys: + storage_config.pop(key, None) + return storage_config + + +def _validate_config(config, config_requirements, has_existing_secrets=False): + errors = {} + for key, spec in config_requirements.items(): + depends_on = spec.get("depends_on") + if depends_on: + if not all(config.get(dk) == dv for dk, dv in depends_on.items()): + continue + if spec.get("required") and not config.get(key): + if has_existing_secrets and spec.get("secret"): + continue + errors[key] = f"{spec.get('label', key)} is required" + value = config.get(key) + if value is not None and value != "": + if spec.get("type") == "number": + try: + num = float(value) + if key == "timeout" and (num < 1 or num > 300): + errors[key] = "Timeout must be between 1 and 300" + except (ValueError, TypeError): + errors[key] = f"{spec.get('label', key)} must be a number" + if spec.get("enum") and value not in spec["enum"]: + errors[key] = f"Invalid value for {spec.get('label', key)}" + return errors + + +def _merge_secrets_on_update(new_config, existing_config, config_requirements, user_id): + """Merge incoming config with existing encrypted secrets and re-encrypt. + + For updates, the client may omit unchanged secret values. This helper + decrypts any previously stored secrets, overlays whatever the client *did* + send, strips plain-text secrets from the stored config, and re-encrypts + the merged result. + + Returns the final ``config`` dict ready for persistence. + """ + secret_keys = [ + key for key, spec in config_requirements.items() + if spec.get("secret") + ] + + if not secret_keys: + return new_config + + existing_secrets = {} + if "encrypted_credentials" in existing_config: + existing_secrets = decrypt_credentials( + existing_config["encrypted_credentials"], user_id + ) + + merged_secrets = existing_secrets.copy() + for key in secret_keys: + if key in new_config and new_config[key]: + merged_secrets[key] = new_config[key] + + # Start from existing non-secret values, then overlay incoming non-secrets + storage_config = { + k: v for k, v in existing_config.items() + if k not in secret_keys and k != "encrypted_credentials" + } + storage_config.update( + {k: v for k, v in new_config.items() if k not in secret_keys} + ) + + if merged_secrets: + storage_config["encrypted_credentials"] = encrypt_credentials( + merged_secrets, user_id + ) + else: + storage_config.pop("encrypted_credentials", None) + + storage_config.pop("has_encrypted_credentials", None) + return storage_config + + +def transform_actions(actions_metadata): + """Set default flags on action metadata for storage. + + Marks each action as active, sets ``filled_by_llm`` and ``value`` on every + parameter property. Used by both the generic create_tool and MCP save routes. + """ + transformed = [] + for action in actions_metadata: + action["active"] = True + if "parameters" in action: + props = action["parameters"].get("properties", {}) + for param_details in props.values(): + param_details["filled_by_llm"] = True + param_details["value"] = "" + transformed.append(action) + return transformed + + tools_ns = Namespace("tools", description="Tool management operations", path="/api") @@ -29,12 +137,15 @@ class AvailableTools(Resource): lines = doc.split("\n", 1) name = lines[0].strip() description = lines[1].strip() if len(lines) > 1 else "" + config_req = tool_instance.get_config_requirements() + actions = tool_instance.get_actions_metadata() tools_metadata.append( { "name": tool_name, "displayName": name, "description": description, - "configRequirements": tool_instance.get_config_requirements(), + "configRequirements": config_req, + "actions": actions, } ) except Exception as err: @@ -60,6 +171,21 @@ class GetTools(Resource): tool_copy = {**tool} tool_copy["id"] = str(tool["_id"]) tool_copy.pop("_id", None) + + config_req = tool_copy.get("configRequirements", {}) + if not config_req: + tool_instance = tool_manager.tools.get(tool_copy.get("name")) + if tool_instance: + config_req = tool_instance.get_config_requirements() + tool_copy["configRequirements"] = config_req + + has_secrets = any( + spec.get("secret") for spec in config_req.values() + ) if config_req else False + if has_secrets and "encrypted_credentials" in tool_copy.get("config", {}): + tool_copy["config"]["has_encrypted_credentials"] = True + tool_copy["config"].pop("encrypted_credentials", None) + user_tools.append(tool_copy) except Exception as err: current_app.logger.error(f"Error getting user tools: {err}", exc_info=True) @@ -116,23 +242,32 @@ class CreateTool(Resource): jsonify({"success": False, "message": "Tool not found"}), 404 ) actions_metadata = tool_instance.get_actions_metadata() - transformed_actions = [] - for action in actions_metadata: - action["active"] = True - if "parameters" in action: - if "properties" in action["parameters"]: - for param_name, param_details in action["parameters"][ - "properties" - ].items(): - param_details["filled_by_llm"] = True - param_details["value"] = "" - transformed_actions.append(action) + transformed_actions = transform_actions(actions_metadata) except Exception as err: current_app.logger.error( f"Error getting tool actions: {err}", exc_info=True ) return make_response(jsonify({"success": False}), 400) try: + config_requirements = tool_instance.get_config_requirements() + if config_requirements: + validation_errors = _validate_config( + data["config"], config_requirements + ) + if validation_errors: + return make_response( + jsonify( + { + "success": False, + "message": "Validation failed", + "errors": validation_errors, + } + ), + 400, + ) + storage_config = _encrypt_secret_fields( + data["config"], config_requirements, user + ) new_tool = { "user": user, "name": data["name"], @@ -140,7 +275,8 @@ class CreateTool(Resource): "description": data["description"], "customName": data.get("customName", ""), "actions": transformed_actions, - "config": data["config"], + "config": storage_config, + "configRequirements": config_requirements, "status": data["status"], } resp = user_tools_collection.insert_one(new_tool) @@ -210,57 +346,37 @@ class UpdateTool(Resource): tool_doc = user_tools_collection.find_one( {"_id": ObjectId(data["id"]), "user": user} ) - if tool_doc and tool_doc.get("name") == "mcp_tool": - config = data["config"] - existing_config = tool_doc.get("config", {}) - storage_config = existing_config.copy() + if not tool_doc: + return make_response( + jsonify({"success": False, "message": "Tool not found"}), + 404, + ) + tool_name = tool_doc.get("name", data.get("name")) + tool_instance = tool_manager.tools.get(tool_name) + config_requirements = ( + tool_instance.get_config_requirements() if tool_instance else {} + ) + existing_config = tool_doc.get("config", {}) + has_existing_secrets = "encrypted_credentials" in existing_config - storage_config.update(config) - existing_credentials = {} - if "encrypted_credentials" in existing_config: - existing_credentials = decrypt_credentials( - existing_config["encrypted_credentials"], user + if config_requirements: + validation_errors = _validate_config( + data["config"], config_requirements, + has_existing_secrets=has_existing_secrets, + ) + if validation_errors: + return make_response( + jsonify({ + "success": False, + "message": "Validation failed", + "errors": validation_errors, + }), + 400, ) - auth_credentials = existing_credentials.copy() - auth_type = storage_config.get("auth_type", "none") - if auth_type == "api_key": - if "api_key" in config and config["api_key"]: - auth_credentials["api_key"] = config["api_key"] - if "api_key_header" in config: - auth_credentials["api_key_header"] = config[ - "api_key_header" - ] - elif auth_type == "bearer": - if "bearer_token" in config and config["bearer_token"]: - auth_credentials["bearer_token"] = config["bearer_token"] - elif "encrypted_token" in config and config["encrypted_token"]: - auth_credentials["bearer_token"] = config["encrypted_token"] - elif auth_type == "basic": - if "username" in config and config["username"]: - auth_credentials["username"] = config["username"] - if "password" in config and config["password"]: - auth_credentials["password"] = config["password"] - if auth_type != "none" and auth_credentials: - encrypted_credentials_string = encrypt_credentials( - auth_credentials, user - ) - storage_config["encrypted_credentials"] = ( - encrypted_credentials_string - ) - elif auth_type == "none": - storage_config.pop("encrypted_credentials", None) - for field in [ - "api_key", - "bearer_token", - "encrypted_token", - "username", - "password", - "api_key_header", - ]: - storage_config.pop(field, None) - update_data["config"] = storage_config - else: - update_data["config"] = data["config"] + + update_data["config"] = _merge_secrets_on_update( + data["config"], existing_config, config_requirements, user + ) if "status" in data: update_data["status"] = data["status"] user_tools_collection.update_one( @@ -298,9 +414,42 @@ class UpdateToolConfig(Resource): if missing_fields: return missing_fields try: + tool_doc = user_tools_collection.find_one( + {"_id": ObjectId(data["id"]), "user": user} + ) + if not tool_doc: + return make_response(jsonify({"success": False}), 404) + + tool_name = tool_doc.get("name") + tool_instance = tool_manager.tools.get(tool_name) + config_requirements = ( + tool_instance.get_config_requirements() if tool_instance else {} + ) + existing_config = tool_doc.get("config", {}) + has_existing_secrets = "encrypted_credentials" in existing_config + + if config_requirements: + validation_errors = _validate_config( + data["config"], config_requirements, + has_existing_secrets=has_existing_secrets, + ) + if validation_errors: + return make_response( + jsonify({ + "success": False, + "message": "Validation failed", + "errors": validation_errors, + }), + 400, + ) + + final_config = _merge_secrets_on_update( + data["config"], existing_config, config_requirements, user + ) + user_tools_collection.update_one( {"_id": ObjectId(data["id"]), "user": user}, - {"$set": {"config": data["config"]}}, + {"$set": {"config": final_config}}, ) except Exception as err: current_app.logger.error( @@ -410,11 +559,13 @@ class DeleteTool(Resource): {"_id": ObjectId(data["id"]), "user": user} ) if result.deleted_count == 0: - return {"success": False, "message": "Tool not found"}, 404 + return make_response( + jsonify({"success": False, "message": "Tool not found"}), 404 + ) except Exception as err: current_app.logger.error(f"Error deleting tool: {err}", exc_info=True) - return {"success": False}, 400 - return {"success": True}, 200 + return make_response(jsonify({"success": False}), 400) + return make_response(jsonify({"success": True}), 200) @tools_ns.route("/parse_spec") @@ -511,7 +662,6 @@ class GetArtifact(Resource): todo_doc = db["todos"].find_one({"_id": obj_id, "user_id": user_id}) if todo_doc: tool_id = todo_doc.get("tool_id") - # Return all todos for the tool query = {"user_id": user_id, "tool_id": tool_id} all_todos = list(db["todos"].find(query)) items = [] diff --git a/application/core/settings.py b/application/core/settings.py index 5cdf7f09..50c30a80 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -78,6 +78,7 @@ class Settings(BaseSettings): CACHE_REDIS_URL: str = "redis://localhost:6379/2" API_URL: str = "http://localhost:7091" # backend url for celery worker + MCP_OAUTH_REDIRECT_URI: Optional[str] = None # public callback URL for MCP OAuth INTERNAL_KEY: Optional[str] = None # internal api key for worker-to-backend auth API_KEY: Optional[str] = None # LLM api key (used by LLM_PROVIDER) diff --git a/application/worker.py b/application/worker.py index e1f6a733..69b22ea8 100755 --- a/application/worker.py +++ b/application/worker.py @@ -1449,34 +1449,28 @@ def ingest_connector( def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, Any]: """Worker to handle MCP OAuth flow asynchronously.""" - logging.info( - "[MCP OAuth] Worker started for user_id=%s, config=%s", user_id, config - ) try: import asyncio from application.agents.tools.mcp_tool import MCPTool task_id = self.request.id - logging.info("[MCP OAuth] Task ID: %s", task_id) redis_client = get_redis_instance() def update_status(status_data: Dict[str, Any]): - logging.info("[MCP OAuth] Updating status: %s", status_data) status_key = f"mcp_oauth_status:{task_id}" redis_client.setex(status_key, 600, json.dumps(status_data)) update_status( { "status": "in_progress", - "message": "Starting OAuth flow...", + "message": "Starting OAuth...", "task_id": task_id, } ) tool_config = config.copy() tool_config["oauth_task_id"] = task_id - logging.info("[MCP OAuth] Initializing MCPTool with config: %s", tool_config) mcp_tool = MCPTool(tool_config, user_id) async def run_oauth_discovery(): @@ -1487,7 +1481,7 @@ def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, An update_status( { "status": "awaiting_redirect", - "message": "Waiting for OAuth redirect...", + "message": "Awaiting OAuth redirect...", "task_id": task_id, } ) @@ -1496,66 +1490,40 @@ def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, An asyncio.set_event_loop(loop) try: - logging.info("[MCP OAuth] Starting event loop for OAuth discovery...") - tools_response = loop.run_until_complete(run_oauth_discovery()) - logging.info( - "[MCP OAuth] Tools response after async call: %s", tools_response - ) - - status_key = f"mcp_oauth_status:{task_id}" - redis_status = redis_client.get(status_key) - if redis_status: - logging.info( - "[MCP OAuth] Redis status after async call: %s", redis_status - ) - else: - logging.warning( - "[MCP OAuth] No Redis status found after async call for key: %s", - status_key, - ) + loop.run_until_complete(run_oauth_discovery()) tools = mcp_tool.get_actions_metadata() update_status( { "status": "completed", - "message": f"OAuth completed successfully. Found {len(tools)} tools.", + "message": f"Connected \u2014 found {len(tools)} tool{'s' if len(tools) != 1 else ''}.", "tools": tools, "tools_count": len(tools), "task_id": task_id, } ) - logging.info( - "[MCP OAuth] OAuth flow completed successfully for task_id=%s", task_id - ) return {"success": True, "tools": tools, "tools_count": len(tools)} except Exception as e: - error_msg = f"OAuth flow failed: {str(e)}" - logging.error( - "[MCP OAuth] Exception in OAuth discovery: %s", error_msg, exc_info=True - ) + error_msg = f"OAuth failed: {str(e)}" + logging.error("MCP OAuth discovery failed: %s", error_msg, exc_info=True) update_status( { "status": "error", "message": error_msg, - "error": str(e), "task_id": task_id, } ) return {"success": False, "error": error_msg} finally: - logging.info("[MCP OAuth] Closing event loop for task_id=%s", task_id) loop.close() except Exception as e: - error_msg = f"Failed to initialize OAuth flow: {str(e)}" - logging.error( - "[MCP OAuth] Exception during initialization: %s", error_msg, exc_info=True - ) + error_msg = f"OAuth init failed: {str(e)}" + logging.error("MCP OAuth init failed: %s", error_msg, exc_info=True) update_status( { "status": "error", "message": error_msg, - "error": str(e), "task_id": task_id, } ) diff --git a/frontend/package.json b/frontend/package.json index 87a2b454..1b952184 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -35,6 +35,7 @@ "lucide-react": "^0.562.0", "mermaid": "^11.12.1", "prop-types": "^15.8.1", + "radix-ui": "^1.4.3", "react": "^19.1.0", "react-chartjs-2": "^5.3.0", "react-dom": "^19.1.1", diff --git a/frontend/src/api/endpoints.ts b/frontend/src/api/endpoints.ts index eec77508..317b8b9c 100644 --- a/frontend/src/api/endpoints.ts +++ b/frontend/src/api/endpoints.ts @@ -65,6 +65,7 @@ const endpoints = { MCP_SAVE_SERVER: '/api/mcp_server/save', MCP_OAUTH_STATUS: (task_id: string) => `/api/mcp_server/oauth_status/${task_id}`, + MCP_AUTH_STATUS: '/api/mcp_server/auth_status', AGENT_FOLDERS: '/api/agents/folders/', AGENT_FOLDER: (id: string) => `/api/agents/folders/${id}`, MOVE_AGENT_TO_FOLDER: '/api/agents/folders/move_agent', diff --git a/frontend/src/api/services/userService.ts b/frontend/src/api/services/userService.ts index 06382435..d309c9ba 100644 --- a/frontend/src/api/services/userService.ts +++ b/frontend/src/api/services/userService.ts @@ -123,6 +123,8 @@ const userService = { apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token), getMCPOAuthStatus: (task_id: string, token: string | null): Promise => apiClient.get(endpoints.USER.MCP_OAUTH_STATUS(task_id), token), + getMCPAuthStatus: (token: string | null): Promise => + apiClient.get(endpoints.USER.MCP_AUTH_STATUS, token), syncConnector: ( docId: string, provider: string, diff --git a/frontend/src/components/ConfigFields.tsx b/frontend/src/components/ConfigFields.tsx new file mode 100644 index 00000000..87a520d5 --- /dev/null +++ b/frontend/src/components/ConfigFields.tsx @@ -0,0 +1,149 @@ +import { useMemo } from 'react'; + +import { cn } from '@/lib/utils'; + +import { ConfigRequirements } from '../modals/types'; +import { Input } from './ui/input'; +import { Label } from './ui/label'; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from './ui/select'; + +type ConfigValues = { [key: string]: any }; + +interface ConfigFieldsProps { + configRequirements: ConfigRequirements; + values: ConfigValues; + onChange: (key: string, value: any) => void; + errors?: { [key: string]: string }; + isEditing?: boolean; + hasEncryptedCredentials?: boolean; +} + +function shouldShowField( + spec: ConfigRequirements[string], + values: ConfigValues, +): boolean { + if (!spec.depends_on) return true; + return Object.entries(spec.depends_on).every( + ([depKey, depValue]) => values[depKey] === depValue, + ); +} + +export default function ConfigFields({ + configRequirements, + values, + onChange, + errors = {}, + isEditing = false, + hasEncryptedCredentials = false, +}: ConfigFieldsProps) { + const sortedFields = useMemo( + () => + Object.entries(configRequirements).sort( + ([, a], [, b]) => (a.order ?? 99) - (b.order ?? 99), + ), + [configRequirements], + ); + + if (sortedFields.length === 0) return null; + + return ( +
+ {sortedFields.map(([key, spec]) => { + if (!shouldShowField(spec, values)) return null; + + const value = values[key] ?? spec.default ?? ''; + const hasEncrypted = + isEditing && spec.secret && hasEncryptedCredentials; + const placeholder = hasEncrypted ? '••••••••' : ''; + const hasError = !!errors[key]; + + if (spec.enum) { + return ( +
+ + + {hasError && ( +

{errors[key]}

+ )} +
+ ); + } + + return ( +
+ + { + const v = e.target.value; + if (spec.type === 'number') { + if (v === '') onChange(key, ''); + else { + const num = parseInt(v, 10); + if (!isNaN(num)) onChange(key, num); + } + } else { + onChange(key, v); + } + }} + placeholder={placeholder || spec.description || ''} + min={spec.type === 'number' ? 1 : undefined} + max={spec.type === 'number' && key === 'timeout' ? 300 : undefined} + aria-invalid={hasError || undefined} + className={cn('rounded-xl', hasError && 'border-destructive')} + /> + {hasError && ( +

{errors[key]}

+ )} +
+ ); + })} +
+ ); +} diff --git a/frontend/src/components/ContextMenu.tsx b/frontend/src/components/ContextMenu.tsx index d9482e3b..9348a45c 100644 --- a/frontend/src/components/ContextMenu.tsx +++ b/frontend/src/components/ContextMenu.tsx @@ -1,7 +1,9 @@ -import { SyntheticEvent, useRef, useEffect, CSSProperties } from 'react'; +import { CSSProperties, SyntheticEvent, useEffect, useRef } from 'react'; + +import type { LucideIcon } from 'lucide-react'; export interface MenuOption { - icon?: string; + icon?: string | LucideIcon; label: string; onClick: (event: SyntheticEvent) => void; variant?: 'primary' | 'danger'; @@ -145,16 +147,28 @@ export default function ContextMenu({ > {option.icon && (
- {option.label} + {typeof option.icon === 'string' ? ( + {option.label} + ) : ( +
)} - {option.label} + {option.label} ))} diff --git a/frontend/src/components/ui/input.tsx b/frontend/src/components/ui/input.tsx new file mode 100644 index 00000000..cdfcb987 --- /dev/null +++ b/frontend/src/components/ui/input.tsx @@ -0,0 +1,23 @@ +import * as React from 'react'; + +import { cn } from '@/lib/utils'; + +function Input({ className, type, ...props }: React.ComponentProps<'input'>) { + return ( + + ); +} + +export { Input }; diff --git a/frontend/src/components/ui/label.tsx b/frontend/src/components/ui/label.tsx new file mode 100644 index 00000000..af666c99 --- /dev/null +++ b/frontend/src/components/ui/label.tsx @@ -0,0 +1,22 @@ +import { Label as LabelPrimitive } from 'radix-ui'; +import * as React from 'react'; + +import { cn } from '@/lib/utils'; + +function Label({ + className, + ...props +}: React.ComponentProps) { + return ( + + ); +} + +export { Label }; diff --git a/frontend/src/components/ui/select.tsx b/frontend/src/components/ui/select.tsx index 74581f39..3be303d5 100644 --- a/frontend/src/components/ui/select.tsx +++ b/frontend/src/components/ui/select.tsx @@ -25,17 +25,23 @@ function SelectValue({ function SelectTrigger({ className, size = 'default', + variant = 'default', children, ...props }: React.ComponentProps & { - size?: 'sm' | 'default'; + size?: 'sm' | 'default' | 'lg'; + variant?: 'default' | 'ghost'; }) { return ( (''); - const [customName, setCustomName] = React.useState(''); + const [configValues, setConfigValues] = useState<{ [key: string]: any }>({}); + const [customName, setCustomName] = useState(''); + const [errors, setErrors] = useState<{ [key: string]: string }>({}); + const [saving, setSaving] = useState(false); - const handleAddTool = (tool: AvailableToolType) => { + const configRequirements = useMemo( + () => tool?.configRequirements ?? {}, + [tool], + ); + + const hasConfig = Object.keys(configRequirements).length > 0; + + const handleFieldChange = (key: string, value: any) => { + setConfigValues((prev) => ({ ...prev, [key]: value })); + if (errors[key]) setErrors((prev) => ({ ...prev, [key]: '' })); + }; + + const validate = () => { + const newErrors: { [key: string]: string } = {}; + Object.entries(configRequirements).forEach(([key, spec]) => { + if (spec.depends_on) { + const visible = Object.entries(spec.depends_on).every( + ([dk, dv]) => configValues[dk] === dv, + ); + if (!visible) return; + } + if (spec.required && !configValues[key]?.toString().trim()) { + newErrors[key] = `${spec.label || key} is required`; + } + if (spec.type === 'number' && configValues[key] !== undefined) { + const num = Number(configValues[key]); + if (isNaN(num) || num < 1) { + newErrors[key] = 'Must be a positive number'; + } + if (key === 'timeout' && num > 300) { + newErrors[key] = 'Maximum timeout is 300 seconds'; + } + } + }); + setErrors(newErrors); + return Object.keys(newErrors).length === 0; + }; + + const handleClose = () => { + setModalState('INACTIVE'); + setConfigValues({}); + setCustomName(''); + setErrors({}); + }; + + const handleAddTool = () => { + if (!tool || !validate()) return; + + const config: { [key: string]: any } = {}; + Object.entries(configRequirements).forEach(([key, spec]) => { + const val = configValues[key]; + if (val !== undefined && val !== '') { + config[key] = val; + } else if (spec.default !== undefined) { + config[key] = spec.default; + } + }); + + setSaving(true); userService .createTool( { name: tool.name, displayName: tool.displayName, description: tool.description, - config: { token: authKey }, - customName: customName, + config, + customName, actions: tool.actions, status: true, }, token, ) .then(() => { - setModalState('INACTIVE'); + handleClose(); getUserTools(); - }); + }) + .finally(() => setSaving(false)); }; - // Only render when modal is active - if (modalState !== 'ACTIVE') return null; + if (modalState !== 'ACTIVE' || !tool) return null; return ( - setModalState('INACTIVE')}> -
-

+ +
+

{t('modals.configTool.title')}

-

+

{t('modals.configTool.type')}:{' '} - {tool?.name} + + {tool.displayName} +

-
- setCustomName(e.target.value)} - borderVariant="thin" - placeholder={t('modals.configTool.customNamePlaceholder')} - labelBgClassName="bg-white dark:bg-charleston-green-2" - /> + +
+
+ + setCustomName(e.target.value)} + placeholder={tool.displayName} + className="rounded-xl" + /> +
+ + {hasConfig && }
-
- setAuthKey(e.target.value)} - borderVariant="thin" - placeholder={t('modals.configTool.apiKeyPlaceholder')} - labelBgClassName="bg-white dark:bg-charleston-green-2" - /> -
-
+ +
diff --git a/frontend/src/modals/MCPServerModal.tsx b/frontend/src/modals/MCPServerModal.tsx index e4042dbe..6fcac371 100644 --- a/frontend/src/modals/MCPServerModal.tsx +++ b/frontend/src/modals/MCPServerModal.tsx @@ -1,12 +1,19 @@ -import { useRef, useState } from 'react'; +import { useCallback, useEffect, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { useSelector } from 'react-redux'; +import { baseURL } from '../api/client'; import userService from '../api/services/userService'; -import Dropdown from '../components/Dropdown'; -import Input from '../components/Input'; import Spinner from '../components/Spinner'; -import { useOutsideAlerter } from '../hooks'; +import { Input } from '../components/ui/input'; +import { Label } from '../components/ui/label'; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from '../components/ui/select'; import { ActiveState } from '../models/misc'; import { selectToken } from '../preferences/preferenceSlice'; import WrapperComponent from './WrapperModal'; @@ -26,7 +33,6 @@ export default function MCPServerModal({ }: MCPServerModalProps) { const { t } = useTranslation(); const token = useSelector(selectToken); - const modalRef = useRef(null); const authTypes = [ { label: t('settings.tools.mcp.authTypes.none'), value: 'none' }, @@ -41,12 +47,12 @@ export default function MCPServerModal({ server_url: server?.server_url || '', auth_type: server?.auth_type || 'none', api_key: '', - header_name: 'X-API-Key', + header_name: server?.api_key_header || 'X-API-Key', bearer_token: '', username: '', password: '', timeout: server?.timeout || 30, - oauth_scopes: '', + oauth_scopes: server?.oauth_scopes || '', oauth_task_id: '', }); @@ -57,20 +63,63 @@ export default function MCPServerModal({ message: string; status?: string; authorization_url?: string; + tools?: { name: string; description?: string }[]; + tools_count?: number; } | null>(null); + const [discoveredTools, setDiscoveredTools] = useState< + { name: string; description?: string }[] + >([]); const [errors, setErrors] = useState<{ [key: string]: string }>({}); const oauthPopupRef = useRef(null); + const pollingCancelledRef = useRef(false); + const pollTimerRef = useRef | null>(null); const [oauthCompleted, setOAuthCompleted] = useState(false); const [saveActive, setSaveActive] = useState(false); - useOutsideAlerter(modalRef, () => { - if (modalState === 'ACTIVE') { - setModalState('INACTIVE'); - resetForm(); + const cleanupPolling = useCallback(() => { + pollingCancelledRef.current = true; + if (pollTimerRef.current) { + clearTimeout(pollTimerRef.current); + pollTimerRef.current = null; } - }, [modalState]); + if (oauthPopupRef.current && !oauthPopupRef.current.closed) { + oauthPopupRef.current.close(); + } + oauthPopupRef.current = null; + }, []); + + useEffect(() => { + return cleanupPolling; + }, [cleanupPolling]); + + useEffect(() => { + if (modalState === 'ACTIVE' && server) { + const oauthScopes = Array.isArray(server.oauth_scopes) + ? server.oauth_scopes.join(', ') + : server.oauth_scopes || ''; + setFormData({ + name: server.displayName || t('settings.tools.mcp.defaultServerName'), + server_url: server.server_url || '', + auth_type: server.auth_type || 'none', + api_key: '', + header_name: server.api_key_header || 'X-API-Key', + bearer_token: '', + username: '', + password: '', + timeout: server.timeout || 30, + oauth_scopes: oauthScopes, + oauth_task_id: '', + }); + setErrors({}); + setTestResult(null); + setDiscoveredTools([]); + setSaveActive(false); + setOAuthCompleted(false); + } + }, [modalState, server]); const resetForm = () => { + cleanupPolling(); setFormData({ name: t('settings.tools.mcp.defaultServerName'), server_url: '', @@ -86,7 +135,10 @@ export default function MCPServerModal({ }); setErrors({}); setTestResult(null); + setDiscoveredTools([]); setSaveActive(false); + setTesting(false); + setOAuthCompleted(false); }; const validateForm = () => { @@ -168,9 +220,10 @@ export default function MCPServerModal({ } else if (formData.auth_type === 'oauth') { config.oauth_scopes = formData.oauth_scopes .split(',') - .map((s) => s.trim()) + .map((s: string) => s.trim()) .filter(Boolean); config.oauth_task_id = formData.oauth_task_id.trim(); + config.redirect_uri = `${baseURL.replace(/\/$/, '')}/api/mcp_server/callback`; } return config; }; @@ -182,10 +235,16 @@ export default function MCPServerModal({ let attempts = 0; const maxAttempts = 60; let popupOpened = false; + pollingCancelledRef.current = false; + const poll = async () => { + if (pollingCancelledRef.current) return; try { const resp = await userService.getMCPOAuthStatus(taskId, token); + if (pollingCancelledRef.current) return; const data = await resp.json(); + if (pollingCancelledRef.current) return; + if (data.authorization_url && !popupOpened) { if (oauthPopupRef.current && !oauthPopupRef.current.closed) { oauthPopupRef.current.close(); @@ -196,7 +255,22 @@ export default function MCPServerModal({ 'width=600,height=700', ); popupOpened = true; + + if (!oauthPopupRef.current) { + setTestResult({ + success: true, + message: t('settings.tools.mcp.oauthPopupBlocked', { + defaultValue: + 'Popup blocked by browser. Click below to authorize:', + }), + authorization_url: data.authorization_url, + }); + } } + + const callbackReceived = + data.status === 'callback_received' || data.status === 'completed'; + if (data.status === 'completed') { setOAuthCompleted(true); setSaveActive(true); @@ -213,15 +287,30 @@ export default function MCPServerModal({ onComplete({ ...data, success: false, - message: t('settings.tools.mcp.errors.oauthFailed'), + message: data.message || t('settings.tools.mcp.errors.oauthFailed'), }); if (oauthPopupRef.current && !oauthPopupRef.current.closed) { oauthPopupRef.current.close(); } } else { - if (++attempts < maxAttempts) setTimeout(poll, 1000); - else { + if (++attempts < maxAttempts) { + if ( + oauthPopupRef.current && + oauthPopupRef.current.closed && + popupOpened && + !callbackReceived + ) { + setSaveActive(false); + onComplete({ + success: false, + message: t('settings.tools.mcp.errors.oauthFailed'), + }); + return; + } + pollTimerRef.current = setTimeout(poll, 1000); + } else { setSaveActive(false); + cleanupPolling(); onComplete({ success: false, message: t('settings.tools.mcp.errors.oauthTimeout'), @@ -229,12 +318,16 @@ export default function MCPServerModal({ } } } catch { - if (++attempts < maxAttempts) setTimeout(poll, 1000); - else + if (pollingCancelledRef.current) return; + if (++attempts < maxAttempts) { + pollTimerRef.current = setTimeout(poll, 1000); + } else { + cleanupPolling(); onComplete({ success: false, message: t('settings.tools.mcp.errors.oauthTimeout'), }); + } } }; poll(); @@ -242,8 +335,11 @@ export default function MCPServerModal({ const testConnection = async () => { if (!validateForm()) return; + cleanupPolling(); setTesting(true); setTestResult(null); + setDiscoveredTools([]); + setOAuthCompleted(false); try { const config = buildToolConfig(); const response = await userService.testMCPConnection({ config }, token); @@ -258,10 +354,12 @@ export default function MCPServerModal({ success: true, message: t('settings.tools.mcp.oauthInProgress'), }); - setOAuthCompleted(false); setSaveActive(false); pollOAuthStatus(result.task_id, (finalResult) => { setTestResult(finalResult); + if (finalResult.tools && Array.isArray(finalResult.tools)) { + setDiscoveredTools(finalResult.tools); + } setFormData((prev) => ({ ...prev, oauth_task_id: result.task_id || '', @@ -270,6 +368,9 @@ export default function MCPServerModal({ }); } else { setTestResult(result); + if (result.success && result.tools && Array.isArray(result.tools)) { + setDiscoveredTools(result.tools); + } setSaveActive(result.success === true); setTesting(false); } @@ -312,8 +413,7 @@ export default function MCPServerModal({ general: result.error || t('settings.tools.mcp.errors.saveFailed'), }); } - } catch (error) { - console.error('Error saving MCP server:', error); + } catch { setErrors({ general: t('settings.tools.mcp.errors.saveFailed') }); } finally { setLoading(false); @@ -324,113 +424,123 @@ export default function MCPServerModal({ switch (formData.auth_type) { case 'api_key': return ( -
-
+
+
+ handleInputChange('api_key', e.target.value)} placeholder={t('settings.tools.mcp.placeholders.apiKey')} - borderVariant="thin" - labelBgClassName="bg-white dark:bg-charleston-green-2" + aria-invalid={!!errors.api_key || undefined} + className="rounded-xl" /> {errors.api_key && ( -

{errors.api_key}

+

{errors.api_key}

)}
-
+
+ handleInputChange('header_name', e.target.value) } - placeholder={t('settings.tools.mcp.headerName')} - borderVariant="thin" - labelBgClassName="bg-white dark:bg-charleston-green-2" + placeholder="X-API-Key" + className="rounded-xl" />
); case 'bearer': return ( -
+
+ handleInputChange('bearer_token', e.target.value) } placeholder={t('settings.tools.mcp.placeholders.bearerToken')} - borderVariant="thin" - labelBgClassName="bg-white dark:bg-charleston-green-2" + aria-invalid={!!errors.bearer_token || undefined} + className="rounded-xl" /> {errors.bearer_token && ( -

{errors.bearer_token}

+

{errors.bearer_token}

)}
); case 'basic': return ( -
-
+
+
+ handleInputChange('username', e.target.value)} placeholder={t('settings.tools.mcp.username')} - borderVariant="thin" - labelBgClassName="bg-white dark:bg-charleston-green-2" + aria-invalid={!!errors.username || undefined} + className="rounded-xl" /> {errors.username && ( -

{errors.username}

+

{errors.username}

)}
-
+
+ handleInputChange('password', e.target.value)} placeholder={t('settings.tools.mcp.password')} - borderVariant="thin" - labelBgClassName="bg-white dark:bg-charleston-green-2" + aria-invalid={!!errors.password || undefined} + className="rounded-xl" /> {errors.password && ( -

{errors.password}

+

{errors.password}

)}
); case 'oauth': return ( -
-
- - handleInputChange('oauth_scopes', e.target.value) - } - placeholder={ - t('settings.tools.mcp.placeholders.oauthScopes') || - 'Scopes (comma separated)' - } - borderVariant="thin" - labelBgClassName="bg-white dark:bg-charleston-green-2" - /> -
+
+ + + handleInputChange('oauth_scopes', e.target.value) + } + placeholder="read, write" + className="rounded-xl" + />
); default: @@ -451,69 +561,99 @@ export default function MCPServerModal({

{server - ? t('settings.tools.mcp.editServer') + ? t('settings.tools.mcp.reconnectServer', { + defaultValue: 'Reconnect Server', + }) : t('settings.tools.mcp.addServer')}

-
-
+
+ {server?.has_encrypted_credentials && + formData.auth_type !== 'oauth' && ( +
+ {t('settings.tools.mcp.reenterCredentials', { + defaultValue: + 'Re-enter your credentials to test and update the connection.', + })} +
+ )} +
+ handleInputChange('name', e.target.value)} - borderVariant="thin" placeholder={t('settings.tools.mcp.serverName')} - labelBgClassName="bg-white dark:bg-charleston-green-2" + aria-invalid={!!errors.name || undefined} + className="rounded-xl" /> {errors.name && ( -

{errors.name}

+

{errors.name}

)}
-
+
+ handleInputChange('server_url', e.target.value) } - placeholder={t('settings.tools.mcp.serverUrl')} - borderVariant="thin" - labelBgClassName="bg-white dark:bg-charleston-green-2" + placeholder="https://example.com/mcp" + aria-invalid={!!errors.server_url || undefined} + className="rounded-xl" /> {errors.server_url && ( -

+

{errors.server_url}

)}
- type.value === formData.auth_type) - ?.label || null - } - onSelect={(selection: { label: string; value: string }) => { - handleInputChange('auth_type', selection.value); - }} - options={authTypes} - size="w-full" - rounded="3xl" - border="border" - /> +
+ + +
{renderAuthFields()} -
+
+ { const value = e.target.value; @@ -526,40 +666,94 @@ export default function MCPServerModal({ } } }} - placeholder={t('settings.tools.mcp.timeout')} - borderVariant="thin" - labelBgClassName="bg-white dark:bg-charleston-green-2" + placeholder="30" + min={1} + max={300} + aria-invalid={!!errors.timeout || undefined} + className="rounded-xl" /> {errors.timeout && ( -

{errors.timeout}

+

{errors.timeout}

)}
{testResult && ( + )} + + {discoveredTools.length > 0 && testResult?.success && ( +
+

+ {t('settings.tools.mcp.discoveredTools', { + count: discoveredTools.length, + defaultValue: `Discovered Actions (${discoveredTools.length})`, + })} +

+
    + {discoveredTools.map((tool) => ( +
  • + +
    + + {tool.name} + + {tool.description && ( +

    + {tool.description} +

    + )} +
    +
  • + ))} +
)} {errors.general && ( -
+
{errors.general}
)}
-
-
+
+
- {/* Custom name section */} + {saveError && ( +
+ {saveError} +
+ )}

{t('settings.tools.customName')}

- setCustomName(e.target.value)} - borderVariant="thin" placeholder={t('settings.tools.customNamePlaceholder')} + className="rounded-xl" />
- {Object.keys(tool?.config).length !== 0 && tool.name !== 'api_tool' && ( -

- {tool.name === 'mcp_tool' - ? (tool.config as any)?.auth_type === 'bearer' - ? 'Bearer Token' - : (tool.config as any)?.auth_type === 'api_key' - ? 'API Key' - : (tool.config as any)?.auth_type === 'basic' - ? 'Password' - : t('settings.tools.authentication') - : t('settings.tools.authentication')} -

- )} -
- {Object.keys(tool?.config).length !== 0 && - tool.name !== 'api_tool' && ( -
- setAuthKey(e.target.value)} - borderVariant="thin" - placeholder={ - tool.name === 'mcp_tool' - ? (tool.config as any)?.auth_type === 'bearer' - ? 'Bearer Token' - : (tool.config as any)?.auth_type === 'api_key' - ? 'API Key' - : (tool.config as any)?.auth_type === 'basic' - ? 'Password' - : t('modals.configTool.apiKeyPlaceholder') - : t('modals.configTool.apiKeyPlaceholder') + {tool.name !== 'api_tool' && + Object.keys(configRequirements).length > 0 && ( +
+

+ {t('settings.tools.authentication')} +

+
+
- )} -
+
+ )}
@@ -522,7 +565,7 @@ export default function ToolConfig({