mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
feat: Implement OAuth flow for MCP server integration
- Added MCPOAuthManager to handle OAuth authorization.
- Updated MCPServerSave resource to manage OAuth status and callback.
- Introduced new endpoints for OAuth status and callback handling.
- Enhanced user interface to support OAuth authentication type.
- Implemented polling mechanism for OAuth status in MCPServerModal.
- Updated frontend services and endpoints to accommodate new OAuth features.
- Improved error handling and user feedback for OAuth processes.
This commit is contained in:
@@ -1,47 +1,39 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
from application.security.encryption import decrypt_credentials
|
||||
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
|
||||
from application.cache import get_redis_instance
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
from application.security.encryption import decrypt_credentials
|
||||
from fastmcp import Client
|
||||
from fastmcp.client.auth import BearerAuth
|
||||
from fastmcp.client.auth.oauth import OAuth as FastMCPOAuth
|
||||
from fastmcp.client.transports import (
|
||||
SSETransport,
|
||||
StdioTransport,
|
||||
StreamableHttpTransport,
|
||||
)
|
||||
from mcp.client.auth import OAuthClientProvider, TokenStorage
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
|
||||
|
||||
from pydantic import AnyHttpUrl, ValidationError
|
||||
from redis import Redis
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
_mcp_clients_cache = {}
|
||||
|
||||
|
||||
class DocsGPTOAuth(FastMCPOAuth):
    """Custom OAuth handler that integrates with DocsGPT frontend instead of opening browser."""

    def __init__(self, *args, **kwargs):
        # Pull the DocsGPT-specific callbacks out of kwargs before delegating
        # to the FastMCP OAuth base class, which does not know about them.
        self.auth_url_callback = kwargs.pop("auth_url_callback", None)
        self.auth_code_callback = kwargs.pop("auth_code_callback", None)
        super().__init__(*args, **kwargs)

    async def redirect_handler(self, authorization_url: str) -> None:
        """Override to send auth URL to frontend instead of opening browser."""
        if not self.auth_url_callback:
            raise Exception("OAuth authorization URL callback not configured")
        self.auth_url_callback(authorization_url)

    async def callback_handler(self) -> tuple[str, str | None]:
        """Override to wait for auth code from frontend instead of local server."""
        if not self.auth_code_callback:
            raise Exception("OAuth callback handler not configured")
        code, returned_state = await self.auth_code_callback()
        return code, returned_state
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
"""
|
||||
MCP Tool
|
||||
@@ -67,6 +59,7 @@ class MCPTool(Tool):
|
||||
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
|
||||
"""
|
||||
self.config = config
|
||||
self.user_id = user_id
|
||||
self.server_url = config.get("server_url", "")
|
||||
self.transport_type = config.get("transport_type", "auto")
|
||||
self.auth_type = config.get("auth_type", "none")
|
||||
@@ -80,22 +73,16 @@ class MCPTool(Tool):
|
||||
)
|
||||
else:
|
||||
self.auth_credentials = config.get("auth_credentials", {})
|
||||
# OAuth specific configuration
|
||||
|
||||
self.oauth_scopes = config.get("oauth_scopes", [])
|
||||
self.oauth_task_id = config.get("oauth_task_id", None)
|
||||
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
|
||||
|
||||
# OAuth callback handlers (to be set by frontend)
|
||||
|
||||
self.oauth_auth_url_callback = None
|
||||
self.oauth_auth_code_callback = None
|
||||
self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback"
|
||||
|
||||
self.available_tools = []
|
||||
self._cache_key = self._generate_cache_key()
|
||||
self._client = None
|
||||
|
||||
# Only validate and setup if server_url is provided and not OAuth
|
||||
# OAuth setup will happen after callbacks are set
|
||||
|
||||
if self.server_url and self.auth_type != "oauth":
|
||||
self._setup_client()
|
||||
@@ -104,8 +91,6 @@ class MCPTool(Tool):
|
||||
"""Generate a unique cache key for this MCP server configuration."""
|
||||
auth_key = ""
|
||||
if self.auth_type == "oauth":
|
||||
# For OAuth, use scopes and client name as part of the key
|
||||
|
||||
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
|
||||
auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}"
|
||||
elif self.auth_type in ["bearer"]:
|
||||
@@ -137,22 +122,17 @@ class MCPTool(Tool):
|
||||
auth = None
|
||||
|
||||
if self.auth_type == "oauth":
|
||||
# Ensure callbacks are configured before creating OAuth instance
|
||||
|
||||
if not self.oauth_auth_url_callback or not self.oauth_auth_code_callback:
|
||||
raise Exception(
|
||||
"OAuth callbacks not configured. Call set_oauth_callbacks() first."
|
||||
)
|
||||
# Use custom OAuth handler for frontend integration
|
||||
|
||||
redis_client = get_redis_instance()
|
||||
auth = DocsGPTOAuth(
|
||||
mcp_url=self.server_url,
|
||||
scopes=self.oauth_scopes,
|
||||
client_name=self.oauth_client_name,
|
||||
auth_url_callback=self.oauth_auth_url_callback,
|
||||
auth_code_callback=self.oauth_auth_code_callback,
|
||||
redis_client=redis_client,
|
||||
redirect_uri=self.redirect_uri,
|
||||
task_id=self.oauth_task_id,
|
||||
db=db,
|
||||
user_id=self.user_id,
|
||||
)
|
||||
elif self.auth_type in ["bearer"]:
|
||||
elif self.auth_type == "bearer":
|
||||
token = self.auth_credentials.get(
|
||||
"bearer_token", ""
|
||||
) or self.auth_credentials.get("access_token", "")
|
||||
@@ -202,6 +182,33 @@ class MCPTool(Tool):
|
||||
else:
|
||||
return StreamableHttpTransport(url=self.server_url, headers=headers)
|
||||
|
||||
def _format_tools(self, tools_response) -> List[Dict]:
    """Format tools response to match expected format."""

    def _to_entry(tool):
        # Typed tool objects expose .name/.description (and maybe .inputSchema).
        if hasattr(tool, "name"):
            entry = {
                "name": tool.name,
                "description": tool.description,
            }
            if hasattr(tool, "inputSchema"):
                entry["inputSchema"] = tool.inputSchema
            return entry
        # Plain dicts are passed through unchanged.
        if isinstance(tool, dict):
            return tool
        # Pydantic-style objects serialize themselves; anything else is
        # reduced to a name-only stub.
        if hasattr(tool, "model_dump"):
            return tool.model_dump()
        return {"name": str(tool), "description": ""}

    if hasattr(tools_response, "tools"):
        raw_tools = tools_response.tools
    elif isinstance(tools_response, list):
        raw_tools = tools_response
    else:
        raw_tools = []
    return [_to_entry(tool) for tool in raw_tools]
|
||||
|
||||
async def _execute_with_client(self, operation: str, *args, **kwargs):
|
||||
"""Execute operation with FastMCP client."""
|
||||
if not self._client:
|
||||
@@ -211,32 +218,8 @@ class MCPTool(Tool):
|
||||
return await self._client.ping()
|
||||
elif operation == "list_tools":
|
||||
tools_response = await self._client.list_tools()
|
||||
|
||||
if hasattr(tools_response, "tools"):
|
||||
tools = tools_response.tools
|
||||
elif isinstance(tools_response, list):
|
||||
tools = tools_response
|
||||
else:
|
||||
tools = []
|
||||
tools_dict = []
|
||||
for tool in tools:
|
||||
if hasattr(tool, "name"):
|
||||
tool_dict = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
}
|
||||
if hasattr(tool, "inputSchema"):
|
||||
tool_dict["inputSchema"] = tool.inputSchema
|
||||
tools_dict.append(tool_dict)
|
||||
elif isinstance(tool, dict):
|
||||
tools_dict.append(tool)
|
||||
else:
|
||||
|
||||
if hasattr(tool, "model_dump"):
|
||||
tools_dict.append(tool.model_dump())
|
||||
else:
|
||||
tools_dict.append({"name": str(tool), "description": ""})
|
||||
return tools_dict
|
||||
self.available_tools = self._format_tools(tools_response)
|
||||
return self.available_tools
|
||||
elif operation == "call_tool":
|
||||
tool_name = args[0]
|
||||
tool_args = kwargs
|
||||
@@ -251,12 +234,8 @@ class MCPTool(Tool):
|
||||
def _run_async_operation(self, operation: str, *args, **kwargs):
|
||||
"""Run async operation in sync context."""
|
||||
try:
|
||||
# Check if there's already a running event loop
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we're in an async context, we need to run in a separate thread
|
||||
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_thread():
|
||||
@@ -273,8 +252,6 @@ class MCPTool(Tool):
|
||||
future = executor.submit(run_in_thread)
|
||||
return future.result(timeout=self.timeout)
|
||||
except RuntimeError:
|
||||
# No running loop, we can create one
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
@@ -284,13 +261,8 @@ class MCPTool(Tool):
|
||||
finally:
|
||||
loop.close()
|
||||
except Exception as e:
|
||||
# If async fails, try to give a better error message for OAuth
|
||||
|
||||
if self.auth_type == "oauth" and "callback not configured" in str(e):
|
||||
raise Exception(
|
||||
"OAuth callbacks not configured. Call set_oauth_callbacks() first."
|
||||
)
|
||||
raise e
|
||||
print(f"Error occurred while running async operation: {e}")
|
||||
raise
|
||||
|
||||
def discover_tools(self) -> List[Dict]:
|
||||
"""
|
||||
@@ -377,8 +349,6 @@ class MCPTool(Tool):
|
||||
if not self._client:
|
||||
self._setup_client()
|
||||
try:
|
||||
# For OAuth, we need to handle async operations differently
|
||||
|
||||
if self.auth_type == "oauth":
|
||||
return self._test_oauth_connection()
|
||||
else:
|
||||
@@ -412,40 +382,21 @@ class MCPTool(Tool):
|
||||
"transport_type": self.transport_type,
|
||||
"auth_type": self.auth_type,
|
||||
"ping_supported": ping_success,
|
||||
"tools": [tool.get("name", "unknown") for tool in tools[:5]],
|
||||
"tools": [tool.get("name", "unknown") for tool in tools],
|
||||
}
|
||||
|
||||
def _test_oauth_connection(self) -> Dict:
|
||||
"""Test connection for OAuth auth type with proper async handling."""
|
||||
try:
|
||||
# Ensure callbacks are configured before proceeding
|
||||
|
||||
if not self.oauth_auth_url_callback or not self.oauth_auth_code_callback:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "OAuth callbacks not configured. Call set_oauth_callbacks() first.",
|
||||
"tools_count": 0,
|
||||
"transport_type": self.transport_type,
|
||||
"auth_type": self.auth_type,
|
||||
"error_type": "ConfigurationError",
|
||||
}
|
||||
# Ensure client is set up with proper callbacks
|
||||
|
||||
if not self._client:
|
||||
self._setup_client()
|
||||
# For OAuth, we use a simpler approach - just try to discover tools
|
||||
# This will trigger the OAuth flow if needed
|
||||
|
||||
tools = self.discover_tools()
|
||||
|
||||
task = mcp_oauth_task.delay(config=self.config, user=self.user_id)
|
||||
if not task:
|
||||
raise Exception("Failed to start OAuth authentication")
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Successfully connected to OAuth MCP server. Found {len(tools)} tools.",
|
||||
"tools_count": len(tools),
|
||||
"transport_type": self.transport_type,
|
||||
"auth_type": self.auth_type,
|
||||
"ping_supported": False, # Skip ping for OAuth to avoid complexity
|
||||
"tools": [tool.get("name", "unknown") for tool in tools[:5]],
|
||||
"requires_oauth": True,
|
||||
"task_id": task.id,
|
||||
"status": "pending",
|
||||
"message": "OAuth flow started",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
@@ -457,67 +408,6 @@ class MCPTool(Tool):
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
|
||||
def set_oauth_callbacks(self, auth_url_callback, auth_code_callback):
|
||||
"""
|
||||
Set OAuth callback handlers for frontend integration.
|
||||
|
||||
Args:
|
||||
auth_url_callback: Function to call with authorization URL
|
||||
auth_code_callback: Async function that returns (auth_code, state) tuple
|
||||
"""
|
||||
self.oauth_auth_url_callback = auth_url_callback
|
||||
self.oauth_auth_code_callback = auth_code_callback
|
||||
|
||||
# Clear the client so it gets recreated with the new callbacks
|
||||
|
||||
self._client = None
|
||||
|
||||
# Also clear from cache to ensure fresh creation
|
||||
|
||||
global _mcp_clients_cache
|
||||
if self._cache_key in _mcp_clients_cache:
|
||||
del _mcp_clients_cache[self._cache_key]
|
||||
|
||||
def clear_oauth_cache(self):
|
||||
"""
|
||||
Clear OAuth cache to force fresh authentication.
|
||||
This will remove stored tokens and client info for the server.
|
||||
"""
|
||||
if self.auth_type == "oauth":
|
||||
try:
|
||||
from fastmcp.client.auth.oauth import FileTokenStorage
|
||||
|
||||
storage = FileTokenStorage(server_url=self.server_url)
|
||||
storage.clear()
|
||||
print(f"✅ Cleared OAuth cache for {self.server_url}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to clear OAuth cache: {e}")
|
||||
# Also clear our internal client cache
|
||||
|
||||
global _mcp_clients_cache
|
||||
if self._cache_key in _mcp_clients_cache:
|
||||
del _mcp_clients_cache[self._cache_key]
|
||||
print(f"✅ Cleared internal client cache")
|
||||
|
||||
@staticmethod
|
||||
def clear_all_oauth_cache():
|
||||
"""
|
||||
Clear all OAuth cache for all servers.
|
||||
This will remove all stored tokens and client info.
|
||||
"""
|
||||
try:
|
||||
from fastmcp.client.auth.oauth import FileTokenStorage
|
||||
|
||||
FileTokenStorage.clear_all()
|
||||
print(f"✅ Cleared all OAuth cache")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to clear all OAuth cache: {e}")
|
||||
# Also clear all internal client cache
|
||||
|
||||
global _mcp_clients_cache
|
||||
_mcp_clients_cache.clear()
|
||||
print(f"✅ Cleared all internal client cache")
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict]:
|
||||
"""
|
||||
Get metadata for all available actions.
|
||||
@@ -666,3 +556,306 @@ class MCPTool(Tool):
|
||||
"required": False,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class DocsGPTOAuth(OAuthClientProvider):
|
||||
"""
|
||||
Custom OAuth handler for DocsGPT that uses frontend redirect instead of browser.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    mcp_url: str,
    redirect_uri: str,
    redis_client: Redis | None = None,
    redis_prefix: str = "mcp_oauth:",
    task_id: str = None,
    scopes: str | list[str] | None = None,
    client_name: str = "DocsGPT-MCP",
    user_id=None,
    db=None,
    additional_client_metadata: dict[str, Any] | None = None,
):
    """
    Initialize custom OAuth client provider for DocsGPT.

    Args:
        mcp_url: Full URL to the MCP endpoint
        redirect_uri: Custom redirect URI for DocsGPT frontend
        redis_client: Redis client for storing auth state
        redis_prefix: Prefix for Redis keys
        task_id: Task ID for tracking auth status
        scopes: OAuth scopes to request
        client_name: Name for this client during registration
        user_id: User ID for token storage
        db: Database instance for token storage
        additional_client_metadata: Extra fields for OAuthClientMetadata
    """
    self.redirect_uri = redirect_uri
    self.redis_client = redis_client
    self.redis_prefix = redis_prefix
    self.task_id = task_id
    self.user_id = user_id
    self.db = db

    # Only scheme + host identify the OAuth server; drop path/query.
    parts = urlparse(mcp_url)
    self.server_base_url = f"{parts.scheme}://{parts.netloc}"

    # The metadata model expects a single space-separated scope string.
    scope_str = " ".join(scopes) if isinstance(scopes, list) else scopes
    client_metadata = OAuthClientMetadata(
        client_name=client_name,
        redirect_uris=[AnyHttpUrl(redirect_uri)],
        grant_types=["authorization_code", "refresh_token"],
        response_types=["code"],
        scope=scope_str,
        **(additional_client_metadata or {}),
    )

    # Tokens and client registration are persisted per (server, user) in Mongo.
    token_storage = DBTokenStorage(
        server_url=self.server_base_url, user_id=self.user_id, db_client=self.db
    )

    super().__init__(
        server_url=self.server_base_url,
        client_metadata=client_metadata,
        storage=token_storage,
        redirect_handler=self.redirect_handler,
        callback_handler=self.callback_handler,
    )

    # Populated by redirect_handler once the provider hands us an auth URL.
    self.auth_url = None
    self.extracted_state = None
|
||||
|
||||
def _process_auth_url(self, authorization_url: str) -> tuple[str, str]:
    """Process authorization URL to extract state"""
    try:
        # The state parameter is the correlation key between this flow
        # and the provider's redirect back to our callback endpoint.
        query = parse_qs(urlparse(authorization_url).query)
        states = query.get("state", [])
        if not states:
            raise ValueError("No state in auth URL")
        return authorization_url, states[0]
    except Exception as e:
        raise Exception(f"Failed to process auth URL: {e}")
|
||||
|
||||
async def redirect_handler(self, authorization_url: str) -> None:
    """Store auth URL and state in Redis for frontend to use.

    Args:
        authorization_url: Provider authorization URL containing the
            ``state`` query parameter.

    Raises:
        Exception: if the URL cannot be parsed or contains no state.
    """
    auth_url, state = self._process_auth_url(authorization_url)
    logging.info(
        "[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state
    )
    self.auth_url = auth_url
    self.extracted_state = state

    if self.redis_client and self.extracted_state:
        key = f"{self.redis_prefix}auth_url:{self.extracted_state}"
        self.redis_client.setex(key, 600, auth_url)
        logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key)

    # BUG FIX: the status write previously ran whenever task_id was set,
    # even with redis_client=None, raising AttributeError. Guard on both.
    if self.task_id and self.redis_client:
        status_key = f"mcp_oauth_status:{self.task_id}"
        status_data = {
            "status": "requires_redirect",
            "message": "OAuth authorization required",
            "authorization_url": self.auth_url,
            "state": self.extracted_state,
            "requires_oauth": True,
            "task_id": self.task_id,
        }
        self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
|
||||
async def callback_handler(self) -> tuple[str, str | None]:
|
||||
"""Wait for auth code from Redis using the state value."""
|
||||
if not self.redis_client or not self.extracted_state:
|
||||
raise Exception("Redis client or state not configured for OAuth")
|
||||
poll_interval = 1
|
||||
max_wait_time = 300
|
||||
code_key = f"{self.redis_prefix}code:{self.extracted_state}"
|
||||
|
||||
if self.task_id:
|
||||
status_key = f"mcp_oauth_status:{self.task_id}"
|
||||
status_data = {
|
||||
"status": "awaiting_callback",
|
||||
"message": "Waiting for OAuth callback...",
|
||||
"authorization_url": self.auth_url,
|
||||
"state": self.extracted_state,
|
||||
"requires_oauth": True,
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
code_data = self.redis_client.get(code_key)
|
||||
if code_data:
|
||||
code = code_data.decode()
|
||||
returned_state = self.extracted_state
|
||||
|
||||
self.redis_client.delete(code_key)
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}auth_url:{self.extracted_state}"
|
||||
)
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}state:{self.extracted_state}"
|
||||
)
|
||||
|
||||
if self.task_id:
|
||||
status_data = {
|
||||
"status": "callback_received",
|
||||
"message": "OAuth callback received, completing authentication...",
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
return code, returned_state
|
||||
error_key = f"{self.redis_prefix}error:{self.extracted_state}"
|
||||
error_data = self.redis_client.get(error_key)
|
||||
if error_data:
|
||||
error_msg = error_data.decode()
|
||||
self.redis_client.delete(error_key)
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}auth_url:{self.extracted_state}"
|
||||
)
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}state:{self.extracted_state}"
|
||||
)
|
||||
raise Exception(f"OAuth error: {error_msg}")
|
||||
await asyncio.sleep(poll_interval)
|
||||
self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}")
|
||||
self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}")
|
||||
raise Exception("OAuth callback timeout: no code received within 5 minutes")
|
||||
|
||||
|
||||
class DBTokenStorage(TokenStorage):
    """Mongo-backed OAuth token storage, keyed by (server base URL, user)."""

    def __init__(self, server_url: str, user_id: str, db_client):
        self.server_url = server_url
        self.user_id = user_id
        self.db_client = db_client
        self.collection = db_client["connector_sessions"]

    @staticmethod
    def get_base_url(url: str) -> str:
        # Normalize to scheme://host so path variants share one record.
        parts = urlparse(url)
        return f"{parts.scheme}://{parts.netloc}"

    def get_db_key(self) -> dict:
        # One document per (server, user) pair.
        return {
            "server_url": self.get_base_url(self.server_url),
            "user_id": self.user_id,
        }

    async def get_tokens(self) -> OAuthToken | None:
        record = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
        if not record or "tokens" not in record:
            return None
        try:
            return OAuthToken.model_validate(record["tokens"])
        except ValidationError as e:
            logging.error(f"Could not load tokens: {e}")
            return None

    async def set_tokens(self, tokens: OAuthToken) -> None:
        await asyncio.to_thread(
            self.collection.update_one,
            self.get_db_key(),
            {"$set": {"tokens": tokens.model_dump()}},
            True,
        )
        logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}")

    async def get_client_info(self) -> OAuthClientInformationFull | None:
        record = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
        if not record or "client_info" not in record:
            return None
        try:
            info = OAuthClientInformationFull.model_validate(record["client_info"])
            if await self.get_tokens() is None:
                # A registration without tokens would block a fresh OAuth
                # dance, so drop it and force re-registration.
                logging.debug(
                    "No tokens found, clearing client info to force fresh registration."
                )
                await asyncio.to_thread(
                    self.collection.update_one,
                    self.get_db_key(),
                    {"$unset": {"client_info": ""}},
                )
                return None
            return info
        except ValidationError as e:
            logging.error(f"Could not load client info: {e}")
            return None

    def _serialize_client_info(self, info: dict) -> dict:
        # URL objects are not BSON-serializable; stringify them first.
        if "redirect_uris" in info and isinstance(info["redirect_uris"], list):
            info["redirect_uris"] = [str(u) for u in info["redirect_uris"]]
        return info

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        payload = self._serialize_client_info(client_info.model_dump())
        await asyncio.to_thread(
            self.collection.update_one,
            self.get_db_key(),
            {"$set": {"client_info": payload}},
            True,
        )
        logging.info(f"Saved client info for {self.get_base_url(self.server_url)}")

    async def clear(self) -> None:
        await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
        logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")

    @classmethod
    async def clear_all(cls, db_client) -> None:
        collection = db_client["connector_sessions"]
        await asyncio.to_thread(collection.delete_many, {})
        logging.info("Cleared all OAuth client cache data.")
|
||||
|
||||
|
||||
class MCPOAuthManager:
    """Manager for handling MCP OAuth callbacks."""

    def __init__(self, redis_client: Redis | None, redis_prefix: str = "mcp_oauth:"):
        self.redis_client = redis_client
        self.redis_prefix = redis_prefix

    def handle_oauth_callback(
        self, state: str, code: str, error: Optional[str] = None
    ) -> bool:
        """
        Handle OAuth callback from provider.

        Args:
            state: The state parameter from OAuth callback
            code: The authorization code from OAuth callback
            error: Error message if OAuth failed

        Returns:
            True if successful, False otherwise
        """
        try:
            if not self.redis_client or not state:
                raise Exception("Redis client or state not provided")
            if error:
                # Surface the provider error to the waiting callback_handler.
                self.redis_client.setex(
                    f"{self.redis_prefix}error:{state}", 300, error
                )
                raise Exception(f"OAuth error received: {error}")
            # Hand the code to the polling flow and mark this state done.
            self.redis_client.setex(f"{self.redis_prefix}code:{state}", 300, code)
            self.redis_client.setex(
                f"{self.redis_prefix}state:{state}", 300, "completed"
            )
            return True
        except Exception as e:
            logging.error(f"Error handling OAuth callback: {e}")
            return False

    def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
        """Get current status of OAuth flow using provided task_id."""
        if not task_id:
            return {"status": "not_started", "message": "OAuth flow not started"}
        return mcp_oauth_status_task(task_id)
|
||||
|
||||
@@ -8,6 +8,7 @@ import uuid
|
||||
import zipfile
|
||||
from functools import wraps
|
||||
from typing import Optional, Tuple
|
||||
from urllib.parse import unquote
|
||||
|
||||
from bson.binary import Binary, UuidRepresentation
|
||||
from bson.dbref import DBRef
|
||||
@@ -25,7 +26,7 @@ from flask_restx import fields, inputs, Namespace, Resource
|
||||
from pymongo import ReturnDocument
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
|
||||
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.api import api
|
||||
@@ -37,6 +38,8 @@ from application.api.user.tasks import (
|
||||
process_agent_webhook,
|
||||
store_attachment,
|
||||
)
|
||||
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
@@ -4303,7 +4306,7 @@ class TestMCPServerConfig(Resource):
|
||||
test_config = config.copy()
|
||||
test_config["auth_credentials"] = auth_credentials
|
||||
|
||||
mcp_tool = MCPTool(test_config, user)
|
||||
mcp_tool = MCPTool(config=test_config, user_id=user)
|
||||
result = mcp_tool.test_connection()
|
||||
|
||||
return make_response(jsonify(result), 200)
|
||||
@@ -4371,8 +4374,33 @@ class MCPServerSave(Resource):
|
||||
mcp_config = config.copy()
|
||||
mcp_config["auth_credentials"] = auth_credentials
|
||||
|
||||
if auth_type == "none" or auth_credentials:
|
||||
mcp_tool = MCPTool(mcp_config, user)
|
||||
if auth_type == "oauth":
|
||||
if not config.get("oauth_task_id"):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Connection not authorized. Please complete the OAuth authorization first.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
redis_client = get_redis_instance()
|
||||
manager = MCPOAuthManager(redis_client)
|
||||
result = manager.get_oauth_status(config["oauth_task_id"])
|
||||
if not result.get("status") == "completed":
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "OAuth failed or not completed. Please try authorizing again.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
actions_metadata = result.get("tools", [])
|
||||
elif auth_type == "none" or auth_credentials:
|
||||
mcp_tool = MCPTool(config=mcp_config, user_id=user)
|
||||
mcp_tool.discover_tools()
|
||||
actions_metadata = mcp_tool.get_actions_metadata()
|
||||
else:
|
||||
@@ -4455,3 +4483,96 @@ class MCPServerSave(Resource):
|
||||
),
|
||||
500,
|
||||
)
|
||||
|
||||
|
||||
@user_ns.route("/api/mcp_server/callback")
class MCPOAuthCallback(Resource):
    @api.expect(
        api.model(
            "MCPServerCallbackModel",
            {
                "code": fields.String(required=True, description="Authorization code"),
                "state": fields.String(required=True, description="State parameter"),
                "error": fields.String(
                    required=False, description="Error message (if any)"
                ),
            },
        )
    )
    @api.doc(
        description="Handle OAuth callback by providing the authorization code and state"
    )
    def get(self):
        """Receive the provider redirect and stash the auth code for the waiting flow."""
        params = request.args
        code = params.get("code")
        state = params.get("state")
        error = params.get("error")

        if error:
            return redirect(
                f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider=mcp_tool"
            )
        if not code or not state:
            return redirect(
                f"/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool"
            )
        try:
            redis_client = get_redis_instance()
            if not redis_client:
                return redirect(
                    f"/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool"
                )
            # The provider URL-encodes the code; decode before storing.
            handled = MCPOAuthManager(redis_client).handle_oauth_callback(
                state, unquote(code), error
            )
            if handled:
                return redirect(
                    f"/api/connectors/callback-status?status=success&message=Authorization+code+received+successfully.+You+can+close+this+window.&provider=mcp_tool"
                )
            return redirect(
                f"/api/connectors/callback-status?status=error&message=OAuth+callback+failed.&provider=mcp_tool"
            )
        except Exception as e:
            current_app.logger.error(
                f"Error handling MCP OAuth callback: {str(e)}", exc_info=True
            )
            return redirect(
                f"/api/connectors/callback-status?status=error&message=Internal+server+error:+{str(e)}.&provider=mcp_tool"
            )
|
||||
|
||||
|
||||
@user_ns.route("/api/mcp_server/oauth_status/<string:task_id>")
class MCPOAuthStatus(Resource):
    def get(self, task_id):
        """
        Get current status of OAuth flow.
        Frontend should poll this endpoint periodically.
        """
        try:
            # Status payloads are written by the OAuth worker under this key.
            raw = get_redis_instance().get(f"mcp_oauth_status:{task_id}")
            if not raw:
                return make_response(
                    jsonify(
                        {
                            "success": False,
                            "error": "Task not found or expired",
                            "task_id": task_id,
                        }
                    ),
                    404,
                )
            return make_response(
                jsonify({"success": True, "task_id": task_id, **json.loads(raw)})
            )
        except Exception as e:
            current_app.logger.error(
                f"Error getting OAuth status for task {task_id}: {str(e)}"
            )
            return make_response(
                jsonify({"success": False, "error": str(e), "task_id": task_id}), 500
            )
|
||||
|
||||
@@ -5,6 +5,8 @@ from application.worker import (
|
||||
agent_webhook_worker,
|
||||
attachment_worker,
|
||||
ingest_worker,
|
||||
mcp_oauth,
|
||||
mcp_oauth_status,
|
||||
remote_worker,
|
||||
sync_worker,
|
||||
)
|
||||
@@ -25,6 +27,7 @@ def ingest_remote(self, source_data, job_name, user, loader):
|
||||
@celery.task(bind=True)
|
||||
def reingest_source_task(self, source_id, user):
|
||||
from application.worker import reingest_source_worker
|
||||
|
||||
resp = reingest_source_worker(self, source_id, user)
|
||||
return resp
|
||||
|
||||
@@ -60,9 +63,10 @@ def ingest_connector_task(
|
||||
retriever="classic",
|
||||
operation_mode="upload",
|
||||
doc_id=None,
|
||||
sync_frequency="never"
|
||||
sync_frequency="never",
|
||||
):
|
||||
from application.worker import ingest_connector
|
||||
|
||||
resp = ingest_connector(
|
||||
self,
|
||||
job_name,
|
||||
@@ -75,7 +79,7 @@ def ingest_connector_task(
|
||||
retriever=retriever,
|
||||
operation_mode=operation_mode,
|
||||
doc_id=doc_id,
|
||||
sync_frequency=sync_frequency
|
||||
sync_frequency=sync_frequency,
|
||||
)
|
||||
return resp
|
||||
|
||||
@@ -94,3 +98,15 @@ def setup_periodic_tasks(sender, **kwargs):
|
||||
timedelta(days=30),
|
||||
schedule_syncs.s("monthly"),
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def mcp_oauth_task(self, config, user):
|
||||
resp = mcp_oauth(self, config, user)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def mcp_oauth_status_task(self, task_id):
|
||||
resp = mcp_oauth_status(self, task_id)
|
||||
return resp
|
||||
|
||||
@@ -12,6 +12,7 @@ esprima==4.0.1
|
||||
esutils==1.0.1
|
||||
Flask==3.1.1
|
||||
faiss-cpu==1.9.0.post1
|
||||
fastmcp==2.11.0
|
||||
flask-restx==1.3.0
|
||||
google-genai==1.3.0
|
||||
google-api-python-client==2.179.0
|
||||
@@ -56,13 +57,13 @@ prompt-toolkit==3.0.51
|
||||
protobuf==5.29.3
|
||||
psycopg2-binary==2.9.10
|
||||
py==1.11.0
|
||||
pydantic==2.10.6
|
||||
pydantic-core==2.27.2
|
||||
pydantic-settings==2.7.1
|
||||
pydantic
|
||||
pydantic-core
|
||||
pydantic-settings
|
||||
pymongo==4.11.3
|
||||
pypdf==5.5.0
|
||||
python-dateutil==2.9.0.post0
|
||||
python-dotenv==1.0.1
|
||||
python-dotenv
|
||||
python-jose==3.4.0
|
||||
python-pptx==1.0.2
|
||||
redis==5.2.1
|
||||
@@ -82,7 +83,7 @@ tzdata==2024.2
|
||||
urllib3==2.3.0
|
||||
vine==5.1.0
|
||||
wcwidth==0.2.13
|
||||
werkzeug==3.1.3
|
||||
werkzeug>=3.1.0,<3.1.2
|
||||
yarl==1.20.0
|
||||
markdownify==1.1.0
|
||||
tldextract==5.1.3
|
||||
|
||||
@@ -19,6 +19,7 @@ from bson.objectid import ObjectId
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.api.answer.services.stream_processor import get_prompt
|
||||
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.parser.chunking import Chunker
|
||||
@@ -214,8 +215,7 @@ def run_agent_logic(agent_config, input_data):
|
||||
|
||||
|
||||
def ingest_worker(
|
||||
self, directory, formats, job_name, file_path, filename, user,
|
||||
retriever="classic"
|
||||
self, directory, formats, job_name, file_path, filename, user, retriever="classic"
|
||||
):
|
||||
"""
|
||||
Ingest and process documents.
|
||||
@@ -240,7 +240,7 @@ def ingest_worker(
|
||||
sample = False
|
||||
|
||||
storage = StorageCreator.get_storage()
|
||||
|
||||
|
||||
logging.info(f"Ingest path: {file_path}", extra={"user": user, "job": job_name})
|
||||
|
||||
# Create temporary working directory
|
||||
@@ -253,30 +253,32 @@ def ingest_worker(
|
||||
# Handle directory case
|
||||
logging.info(f"Processing directory: {file_path}")
|
||||
files_list = storage.list_files(file_path)
|
||||
|
||||
|
||||
for storage_file_path in files_list:
|
||||
if storage.is_directory(storage_file_path):
|
||||
continue
|
||||
|
||||
|
||||
# Create relative path structure in temp directory
|
||||
rel_path = os.path.relpath(storage_file_path, file_path)
|
||||
local_file_path = os.path.join(temp_dir, rel_path)
|
||||
|
||||
|
||||
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
|
||||
|
||||
|
||||
# Download file
|
||||
try:
|
||||
file_data = storage.get_file(storage_file_path)
|
||||
with open(local_file_path, "wb") as f:
|
||||
f.write(file_data.read())
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading file {storage_file_path}: {e}")
|
||||
logging.error(
|
||||
f"Error downloading file {storage_file_path}: {e}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# Handle single file case
|
||||
temp_filename = os.path.basename(file_path)
|
||||
temp_file_path = os.path.join(temp_dir, temp_filename)
|
||||
|
||||
|
||||
file_data = storage.get_file(file_path)
|
||||
with open(temp_file_path, "wb") as f:
|
||||
f.write(file_data.read())
|
||||
@@ -285,7 +287,10 @@ def ingest_worker(
|
||||
if temp_filename.endswith(".zip"):
|
||||
logging.info(f"Extracting zip file: {temp_filename}")
|
||||
extract_zip_recursive(
|
||||
temp_file_path, temp_dir, current_depth=0, max_depth=RECURSION_DEPTH
|
||||
temp_file_path,
|
||||
temp_dir,
|
||||
current_depth=0,
|
||||
max_depth=RECURSION_DEPTH,
|
||||
)
|
||||
|
||||
self.update_state(state="PROGRESS", meta={"current": 1})
|
||||
@@ -300,8 +305,8 @@ def ingest_worker(
|
||||
file_metadata=metadata_from_filename,
|
||||
)
|
||||
raw_docs = reader.load_data()
|
||||
|
||||
directory_structure = getattr(reader, 'directory_structure', {})
|
||||
|
||||
directory_structure = getattr(reader, "directory_structure", {})
|
||||
logging.info(f"Directory structure from reader: {directory_structure}")
|
||||
|
||||
chunker = Chunker(
|
||||
@@ -371,7 +376,10 @@ def reingest_source_worker(self, source_id, user):
|
||||
try:
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
self.update_state(state="PROGRESS", meta={"current": 10, "status": "Initializing re-ingestion scan"})
|
||||
self.update_state(
|
||||
state="PROGRESS",
|
||||
meta={"current": 10, "status": "Initializing re-ingestion scan"},
|
||||
)
|
||||
|
||||
source = sources_collection.find_one({"_id": ObjectId(source_id), "user": user})
|
||||
if not source:
|
||||
@@ -380,7 +388,9 @@ def reingest_source_worker(self, source_id, user):
|
||||
storage = StorageCreator.get_storage()
|
||||
source_file_path = source.get("file_path", "")
|
||||
|
||||
self.update_state(state="PROGRESS", meta={"current": 20, "status": "Scanning current files"})
|
||||
self.update_state(
|
||||
state="PROGRESS", meta={"current": 20, "status": "Scanning current files"}
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Download all files from storage to temp directory, preserving directory structure
|
||||
@@ -391,7 +401,6 @@ def reingest_source_worker(self, source_id, user):
|
||||
if storage.is_directory(storage_file_path):
|
||||
continue
|
||||
|
||||
|
||||
rel_path = os.path.relpath(storage_file_path, source_file_path)
|
||||
local_file_path = os.path.join(temp_dir, rel_path)
|
||||
|
||||
@@ -403,23 +412,39 @@ def reingest_source_worker(self, source_id, user):
|
||||
with open(local_file_path, "wb") as f:
|
||||
f.write(file_data.read())
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading file {storage_file_path}: {e}")
|
||||
logging.error(
|
||||
f"Error downloading file {storage_file_path}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
reader = SimpleDirectoryReader(
|
||||
input_dir=temp_dir,
|
||||
recursive=True,
|
||||
required_exts=[
|
||||
".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub",
|
||||
".html", ".mdx", ".json", ".xlsx", ".pptx", ".png",
|
||||
".jpg", ".jpeg",
|
||||
".rst",
|
||||
".md",
|
||||
".pdf",
|
||||
".txt",
|
||||
".docx",
|
||||
".csv",
|
||||
".epub",
|
||||
".html",
|
||||
".mdx",
|
||||
".json",
|
||||
".xlsx",
|
||||
".pptx",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
],
|
||||
exclude_hidden=True,
|
||||
file_metadata=metadata_from_filename,
|
||||
)
|
||||
reader.load_data()
|
||||
directory_structure = reader.directory_structure
|
||||
logging.info(f"Directory structure built with token counts: {directory_structure}")
|
||||
logging.info(
|
||||
f"Directory structure built with token counts: {directory_structure}"
|
||||
)
|
||||
|
||||
try:
|
||||
old_directory_structure = source.get("directory_structure") or {}
|
||||
@@ -433,11 +458,17 @@ def reingest_source_worker(self, source_id, user):
|
||||
files = set()
|
||||
if isinstance(struct, dict):
|
||||
for name, meta in struct.items():
|
||||
current_path = os.path.join(prefix, name) if prefix else name
|
||||
if isinstance(meta, dict) and ("type" in meta and "size_bytes" in meta):
|
||||
current_path = (
|
||||
os.path.join(prefix, name) if prefix else name
|
||||
)
|
||||
if isinstance(meta, dict) and (
|
||||
"type" in meta and "size_bytes" in meta
|
||||
):
|
||||
files.add(current_path)
|
||||
elif isinstance(meta, dict):
|
||||
files |= _flatten_directory_structure(meta, current_path)
|
||||
files |= _flatten_directory_structure(
|
||||
meta, current_path
|
||||
)
|
||||
return files
|
||||
|
||||
old_files = _flatten_directory_structure(old_directory_structure)
|
||||
@@ -457,7 +488,9 @@ def reingest_source_worker(self, source_id, user):
|
||||
logging.info("No files removed since last ingest.")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error comparing directory structures: {e}", exc_info=True)
|
||||
logging.error(
|
||||
f"Error comparing directory structures: {e}", exc_info=True
|
||||
)
|
||||
added_files = []
|
||||
removed_files = []
|
||||
try:
|
||||
@@ -477,14 +510,21 @@ def reingest_source_worker(self, source_id, user):
|
||||
settings.EMBEDDINGS_KEY,
|
||||
)
|
||||
|
||||
self.update_state(state="PROGRESS", meta={"current": 40, "status": "Processing file changes"})
|
||||
self.update_state(
|
||||
state="PROGRESS",
|
||||
meta={"current": 40, "status": "Processing file changes"},
|
||||
)
|
||||
|
||||
# 1) Delete chunks from removed files
|
||||
deleted = 0
|
||||
if removed_files:
|
||||
try:
|
||||
for ch in vector_store.get_chunks() or []:
|
||||
metadata = ch.get("metadata", {}) if isinstance(ch, dict) else getattr(ch, "metadata", {})
|
||||
metadata = (
|
||||
ch.get("metadata", {})
|
||||
if isinstance(ch, dict)
|
||||
else getattr(ch, "metadata", {})
|
||||
)
|
||||
raw_source = metadata.get("source")
|
||||
|
||||
source_file = str(raw_source) if raw_source else ""
|
||||
@@ -496,10 +536,17 @@ def reingest_source_worker(self, source_id, user):
|
||||
vector_store.delete_chunk(cid)
|
||||
deleted += 1
|
||||
except Exception as de:
|
||||
logging.error(f"Failed deleting chunk {cid}: {de}")
|
||||
logging.info(f"Deleted {deleted} chunks from {len(removed_files)} removed files")
|
||||
logging.error(
|
||||
f"Failed deleting chunk {cid}: {de}"
|
||||
)
|
||||
logging.info(
|
||||
f"Deleted {deleted} chunks from {len(removed_files)} removed files"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during deletion of removed file chunks: {e}", exc_info=True)
|
||||
logging.error(
|
||||
f"Error during deletion of removed file chunks: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# 2) Add chunks from new files
|
||||
added = 0
|
||||
@@ -528,58 +575,86 @@ def reingest_source_worker(self, source_id, user):
|
||||
)
|
||||
chunked_new = chunker_new.chunk(documents=raw_docs_new)
|
||||
|
||||
for file_path, token_count in reader_new.file_token_counts.items():
|
||||
for (
|
||||
file_path,
|
||||
token_count,
|
||||
) in reader_new.file_token_counts.items():
|
||||
try:
|
||||
rel_path = os.path.relpath(file_path, start=temp_dir)
|
||||
rel_path = os.path.relpath(
|
||||
file_path, start=temp_dir
|
||||
)
|
||||
path_parts = rel_path.split(os.sep)
|
||||
current_dir = directory_structure
|
||||
|
||||
for part in path_parts[:-1]:
|
||||
if part in current_dir and isinstance(current_dir[part], dict):
|
||||
if part in current_dir and isinstance(
|
||||
current_dir[part], dict
|
||||
):
|
||||
current_dir = current_dir[part]
|
||||
else:
|
||||
break
|
||||
|
||||
filename = path_parts[-1]
|
||||
if filename in current_dir and isinstance(current_dir[filename], dict):
|
||||
current_dir[filename]["token_count"] = token_count
|
||||
logging.info(f"Updated token count for {rel_path}: {token_count}")
|
||||
if filename in current_dir and isinstance(
|
||||
current_dir[filename], dict
|
||||
):
|
||||
current_dir[filename][
|
||||
"token_count"
|
||||
] = token_count
|
||||
logging.info(
|
||||
f"Updated token count for {rel_path}: {token_count}"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"Could not update token count for {file_path}: {e}")
|
||||
logging.warning(
|
||||
f"Could not update token count for {file_path}: {e}"
|
||||
)
|
||||
|
||||
for d in chunked_new:
|
||||
meta = dict(d.extra_info or {})
|
||||
try:
|
||||
raw_src = meta.get("source")
|
||||
if isinstance(raw_src, str) and os.path.isabs(raw_src):
|
||||
meta["source"] = os.path.relpath(raw_src, start=temp_dir)
|
||||
if isinstance(raw_src, str) and os.path.isabs(
|
||||
raw_src
|
||||
):
|
||||
meta["source"] = os.path.relpath(
|
||||
raw_src, start=temp_dir
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
vector_store.add_chunk(d.text, metadata=meta)
|
||||
added += 1
|
||||
logging.info(f"Added {added} chunks from {len(added_files)} new files")
|
||||
logging.info(
|
||||
f"Added {added} chunks from {len(added_files)} new files"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during ingestion of new files: {e}", exc_info=True)
|
||||
logging.error(
|
||||
f"Error during ingestion of new files: {e}", exc_info=True
|
||||
)
|
||||
|
||||
# 3) Update source directory structure timestamp
|
||||
try:
|
||||
total_tokens = sum(reader.file_token_counts.values())
|
||||
|
||||
|
||||
sources_collection.update_one(
|
||||
{"_id": ObjectId(source_id)},
|
||||
{
|
||||
"$set": {
|
||||
"directory_structure": directory_structure,
|
||||
"date": datetime.datetime.now(),
|
||||
"tokens": total_tokens
|
||||
"tokens": total_tokens,
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error updating directory_structure in DB: {e}", exc_info=True)
|
||||
logging.error(
|
||||
f"Error updating directory_structure in DB: {e}", exc_info=True
|
||||
)
|
||||
|
||||
self.update_state(state="PROGRESS", meta={"current": 100, "status": "Re-ingestion completed"})
|
||||
self.update_state(
|
||||
state="PROGRESS",
|
||||
meta={"current": 100, "status": "Re-ingestion completed"},
|
||||
)
|
||||
|
||||
return {
|
||||
"source_id": source_id,
|
||||
@@ -591,15 +666,16 @@ def reingest_source_worker(self, source_id, user):
|
||||
"chunks_deleted": deleted,
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(f"Error while processing file changes: {e}", exc_info=True)
|
||||
logging.error(
|
||||
f"Error while processing file changes: {e}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in reingest_source_worker: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def remote_worker(
|
||||
self,
|
||||
source_data,
|
||||
@@ -651,7 +727,7 @@ def remote_worker(
|
||||
"id": str(id),
|
||||
"type": loader,
|
||||
"remote_data": source_data,
|
||||
"sync_frequency": sync_frequency
|
||||
"sync_frequency": sync_frequency,
|
||||
}
|
||||
|
||||
if operation_mode == "sync":
|
||||
@@ -712,7 +788,7 @@ def sync_worker(self, frequency):
|
||||
self, source_data, name, user, source_type, frequency, retriever, doc_id
|
||||
)
|
||||
sync_counts["total_sync_count"] += 1
|
||||
sync_counts[
|
||||
sync_counts[
|
||||
"sync_success" if resp["status"] == "success" else "sync_failure"
|
||||
] += 1
|
||||
return {
|
||||
@@ -749,15 +825,14 @@ def attachment_worker(self, file_info, user):
|
||||
input_files=[local_path], exclude_hidden=True, errors="ignore"
|
||||
)
|
||||
.load_data()[0]
|
||||
.text,
|
||||
.text,
|
||||
)
|
||||
|
||||
|
||||
|
||||
token_count = num_tokens_from_string(content)
|
||||
if token_count > 100000:
|
||||
content = content[:250000]
|
||||
token_count = num_tokens_from_string(content)
|
||||
|
||||
|
||||
self.update_state(
|
||||
state="PROGRESS", meta={"current": 80, "status": "Storing in database"}
|
||||
)
|
||||
@@ -872,37 +947,49 @@ def ingest_connector(
|
||||
doc_id: Document ID for sync operations (required when operation_mode="sync")
|
||||
sync_frequency: How often to sync ("never", "daily", "weekly", "monthly")
|
||||
"""
|
||||
logging.info(f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}")
|
||||
logging.info(
|
||||
f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}"
|
||||
)
|
||||
self.update_state(state="PROGRESS", meta={"current": 1})
|
||||
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
try:
|
||||
# Step 1: Initialize the appropriate loader
|
||||
self.update_state(state="PROGRESS", meta={"current": 10, "status": "Initializing connector"})
|
||||
self.update_state(
|
||||
state="PROGRESS",
|
||||
meta={"current": 10, "status": "Initializing connector"},
|
||||
)
|
||||
|
||||
if not session_token:
|
||||
raise ValueError(f"{source_type} connector requires session_token")
|
||||
|
||||
if not ConnectorCreator.is_supported(source_type):
|
||||
raise ValueError(f"Unsupported connector type: {source_type}. Supported types: {ConnectorCreator.get_supported_connectors()}")
|
||||
raise ValueError(
|
||||
f"Unsupported connector type: {source_type}. Supported types: {ConnectorCreator.get_supported_connectors()}"
|
||||
)
|
||||
|
||||
remote_loader = ConnectorCreator.create_connector(source_type, session_token)
|
||||
remote_loader = ConnectorCreator.create_connector(
|
||||
source_type, session_token
|
||||
)
|
||||
|
||||
# Create a clean config for storage
|
||||
api_source_config = {
|
||||
"file_ids": file_ids or [],
|
||||
"folder_ids": folder_ids or [],
|
||||
"recursive": recursive
|
||||
"recursive": recursive,
|
||||
}
|
||||
|
||||
# Step 2: Download files to temp directory
|
||||
self.update_state(state="PROGRESS", meta={"current": 20, "status": "Downloading files"})
|
||||
download_info = remote_loader.download_to_directory(
|
||||
temp_dir,
|
||||
api_source_config
|
||||
self.update_state(
|
||||
state="PROGRESS", meta={"current": 20, "status": "Downloading files"}
|
||||
)
|
||||
|
||||
if download_info.get("empty_result", False) or not download_info.get("files_downloaded", 0):
|
||||
download_info = remote_loader.download_to_directory(
|
||||
temp_dir, api_source_config
|
||||
)
|
||||
|
||||
if download_info.get("empty_result", False) or not download_info.get(
|
||||
"files_downloaded", 0
|
||||
):
|
||||
logging.warning(f"No files were downloaded from {source_type}")
|
||||
# Create empty result directly instead of calling a separate method
|
||||
return {
|
||||
@@ -913,28 +1000,42 @@ def ingest_connector(
|
||||
"source_config": api_source_config,
|
||||
"directory_structure": "{}",
|
||||
}
|
||||
|
||||
|
||||
# Step 3: Use SimpleDirectoryReader to process downloaded files
|
||||
self.update_state(state="PROGRESS", meta={"current": 40, "status": "Processing files"})
|
||||
self.update_state(
|
||||
state="PROGRESS", meta={"current": 40, "status": "Processing files"}
|
||||
)
|
||||
reader = SimpleDirectoryReader(
|
||||
input_dir=temp_dir,
|
||||
recursive=True,
|
||||
required_exts=[
|
||||
".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub",
|
||||
".html", ".mdx", ".json", ".xlsx", ".pptx", ".png",
|
||||
".jpg", ".jpeg",
|
||||
".rst",
|
||||
".md",
|
||||
".pdf",
|
||||
".txt",
|
||||
".docx",
|
||||
".csv",
|
||||
".epub",
|
||||
".html",
|
||||
".mdx",
|
||||
".json",
|
||||
".xlsx",
|
||||
".pptx",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
],
|
||||
exclude_hidden=True,
|
||||
file_metadata=metadata_from_filename,
|
||||
)
|
||||
raw_docs = reader.load_data()
|
||||
directory_structure = getattr(reader, 'directory_structure', {})
|
||||
directory_structure = getattr(reader, "directory_structure", {})
|
||||
|
||||
|
||||
|
||||
# Step 4: Process documents (chunking, embedding, etc.)
|
||||
self.update_state(state="PROGRESS", meta={"current": 60, "status": "Processing documents"})
|
||||
|
||||
self.update_state(
|
||||
state="PROGRESS", meta={"current": 60, "status": "Processing documents"}
|
||||
)
|
||||
|
||||
chunker = Chunker(
|
||||
chunking_strategy="classic_chunk",
|
||||
max_tokens=MAX_TOKENS,
|
||||
@@ -942,22 +1043,26 @@ def ingest_connector(
|
||||
duplicate_headers=False,
|
||||
)
|
||||
raw_docs = chunker.chunk(documents=raw_docs)
|
||||
|
||||
|
||||
# Preserve source information in document metadata
|
||||
for doc in raw_docs:
|
||||
if hasattr(doc, 'extra_info') and doc.extra_info:
|
||||
source = doc.extra_info.get('source')
|
||||
if hasattr(doc, "extra_info") and doc.extra_info:
|
||||
source = doc.extra_info.get("source")
|
||||
if source and os.path.isabs(source):
|
||||
# Convert absolute path to relative path
|
||||
doc.extra_info['source'] = os.path.relpath(source, start=temp_dir)
|
||||
|
||||
doc.extra_info["source"] = os.path.relpath(
|
||||
source, start=temp_dir
|
||||
)
|
||||
|
||||
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
|
||||
|
||||
|
||||
if operation_mode == "upload":
|
||||
id = ObjectId()
|
||||
elif operation_mode == "sync":
|
||||
if not doc_id or not ObjectId.is_valid(doc_id):
|
||||
logging.error("Invalid doc_id provided for sync operation: %s", doc_id)
|
||||
logging.error(
|
||||
"Invalid doc_id provided for sync operation: %s", doc_id
|
||||
)
|
||||
raise ValueError("doc_id must be provided for sync operation.")
|
||||
id = ObjectId(doc_id)
|
||||
else:
|
||||
@@ -966,7 +1071,9 @@ def ingest_connector(
|
||||
vector_store_path = os.path.join(temp_dir, "vector_store")
|
||||
os.makedirs(vector_store_path, exist_ok=True)
|
||||
|
||||
self.update_state(state="PROGRESS", meta={"current": 80, "status": "Storing documents"})
|
||||
self.update_state(
|
||||
state="PROGRESS", meta={"current": 80, "status": "Storing documents"}
|
||||
)
|
||||
embed_and_store_documents(docs, vector_store_path, id, self)
|
||||
|
||||
tokens = count_tokens_docs(docs)
|
||||
@@ -979,12 +1086,11 @@ def ingest_connector(
|
||||
"retriever": retriever,
|
||||
"id": str(id),
|
||||
"type": "connector",
|
||||
"remote_data": json.dumps({
|
||||
"provider": source_type,
|
||||
**api_source_config
|
||||
}),
|
||||
"remote_data": json.dumps(
|
||||
{"provider": source_type, **api_source_config}
|
||||
),
|
||||
"directory_structure": json.dumps(directory_structure),
|
||||
"sync_frequency": sync_frequency
|
||||
"sync_frequency": sync_frequency,
|
||||
}
|
||||
|
||||
if operation_mode == "sync":
|
||||
@@ -995,7 +1101,9 @@ def ingest_connector(
|
||||
upload_index(vector_store_path, file_data)
|
||||
|
||||
# Ensure we mark the task as complete
|
||||
self.update_state(state="PROGRESS", meta={"current": 100, "status": "Complete"})
|
||||
self.update_state(
|
||||
state="PROGRESS", meta={"current": 100, "status": "Complete"}
|
||||
)
|
||||
|
||||
logging.info(f"Remote ingestion completed: {job_name}")
|
||||
|
||||
@@ -1005,9 +1113,136 @@ def ingest_connector(
|
||||
"tokens": tokens,
|
||||
"type": source_type,
|
||||
"id": str(id),
|
||||
"status": "complete"
|
||||
"status": "complete",
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error during remote ingestion: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
|
||||
"""Worker to handle MCP OAuth flow asynchronously."""
|
||||
|
||||
logging.info(
|
||||
"[MCP OAuth] Worker started for user_id=%s, config=%s", user_id, config
|
||||
)
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
task_id = self.request.id
|
||||
logging.info("[MCP OAuth] Task ID: %s", task_id)
|
||||
redis_client = get_redis_instance()
|
||||
|
||||
def update_status(status_data: Dict[str, Any]):
|
||||
logging.info("[MCP OAuth] Updating status: %s", status_data)
|
||||
status_key = f"mcp_oauth_status:{task_id}"
|
||||
redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
|
||||
update_status(
|
||||
{
|
||||
"status": "in_progress",
|
||||
"message": "Starting OAuth flow...",
|
||||
"task_id": task_id,
|
||||
}
|
||||
)
|
||||
|
||||
tool_config = config.copy()
|
||||
tool_config["oauth_task_id"] = task_id
|
||||
logging.info("[MCP OAuth] Initializing MCPTool with config: %s", tool_config)
|
||||
mcp_tool = MCPTool(tool_config, user_id)
|
||||
|
||||
async def run_oauth_discovery():
|
||||
if not mcp_tool._client:
|
||||
mcp_tool._setup_client()
|
||||
return await mcp_tool._execute_with_client("list_tools")
|
||||
|
||||
update_status(
|
||||
{
|
||||
"status": "awaiting_redirect",
|
||||
"message": "Waiting for OAuth redirect...",
|
||||
"task_id": task_id,
|
||||
}
|
||||
)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
logging.info("[MCP OAuth] Starting event loop for OAuth discovery...")
|
||||
tools_response = loop.run_until_complete(run_oauth_discovery())
|
||||
logging.info(
|
||||
"[MCP OAuth] Tools response after async call: %s", tools_response
|
||||
)
|
||||
|
||||
status_key = f"mcp_oauth_status:{task_id}"
|
||||
redis_status = redis_client.get(status_key)
|
||||
if redis_status:
|
||||
logging.info(
|
||||
"[MCP OAuth] Redis status after async call: %s", redis_status
|
||||
)
|
||||
else:
|
||||
logging.warning(
|
||||
"[MCP OAuth] No Redis status found after async call for key: %s",
|
||||
status_key,
|
||||
)
|
||||
tools = mcp_tool.get_actions_metadata()
|
||||
|
||||
update_status(
|
||||
{
|
||||
"status": "completed",
|
||||
"message": f"OAuth completed successfully. Found {len(tools)} tools.",
|
||||
"tools": tools,
|
||||
"tools_count": len(tools),
|
||||
"task_id": task_id,
|
||||
}
|
||||
)
|
||||
|
||||
logging.info(
|
||||
"[MCP OAuth] OAuth flow completed successfully for task_id=%s", task_id
|
||||
)
|
||||
return {"success": True, "tools": tools, "tools_count": len(tools)}
|
||||
except Exception as e:
|
||||
error_msg = f"OAuth flow failed: {str(e)}"
|
||||
logging.error(
|
||||
"[MCP OAuth] Exception in OAuth discovery: %s", error_msg, exc_info=True
|
||||
)
|
||||
update_status(
|
||||
{
|
||||
"status": "error",
|
||||
"message": error_msg,
|
||||
"error": str(e),
|
||||
"task_id": task_id,
|
||||
}
|
||||
)
|
||||
return {"success": False, "error": error_msg}
|
||||
finally:
|
||||
logging.info("[MCP OAuth] Closing event loop for task_id=%s", task_id)
|
||||
loop.close()
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to initialize OAuth flow: {str(e)}"
|
||||
logging.error(
|
||||
"[MCP OAuth] Exception during initialization: %s", error_msg, exc_info=True
|
||||
)
|
||||
update_status(
|
||||
{
|
||||
"status": "error",
|
||||
"message": error_msg,
|
||||
"error": str(e),
|
||||
"task_id": task_id,
|
||||
}
|
||||
)
|
||||
return {"success": False, "error": error_msg}
|
||||
|
||||
|
||||
def mcp_oauth_status(self, task_id: str) -> Dict[str, Any]:
|
||||
"""Check the status of an MCP OAuth flow."""
|
||||
redis_client = get_redis_instance()
|
||||
status_key = f"mcp_oauth_status:{task_id}"
|
||||
|
||||
status_data = redis_client.get(status_key)
|
||||
if status_data:
|
||||
return json.loads(status_data)
|
||||
return {"status": "not_found", "message": "Status not found"}
|
||||
|
||||
@@ -59,6 +59,8 @@ const endpoints = {
|
||||
MANAGE_SOURCE_FILES: '/api/manage_source_files',
|
||||
MCP_TEST_CONNECTION: '/api/mcp_server/test',
|
||||
MCP_SAVE_SERVER: '/api/mcp_server/save',
|
||||
MCP_OAUTH_STATUS: (task_id: string) =>
|
||||
`/api/mcp_server/oauth_status/${task_id}`,
|
||||
},
|
||||
CONVERSATION: {
|
||||
ANSWER: '/api/answer',
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { getSessionToken } from '../../utils/providerUtils';
|
||||
import apiClient from '../client';
|
||||
import endpoints from '../endpoints';
|
||||
import { getSessionToken } from '../../utils/providerUtils';
|
||||
|
||||
const userService = {
|
||||
getConfig: (): Promise<any> => apiClient.get(endpoints.USER.CONFIG, null),
|
||||
@@ -112,6 +112,8 @@ const userService = {
|
||||
apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token),
|
||||
saveMCPServer: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token),
|
||||
getMCPOAuthStatus: (task_id: string, token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.MCP_OAUTH_STATUS(task_id), token),
|
||||
syncConnector: (
|
||||
docId: string,
|
||||
provider: string,
|
||||
|
||||
@@ -193,17 +193,20 @@
|
||||
"headerName": "Header Name",
|
||||
"timeout": "Timeout (seconds)",
|
||||
"testConnection": "Test Connection",
|
||||
"testing": "Testing...",
|
||||
"saving": "Saving...",
|
||||
"testing": "Testing",
|
||||
"saving": "Saving",
|
||||
"save": "Save",
|
||||
"cancel": "Cancel",
|
||||
"noAuth": "No Authentication",
|
||||
"oauthInProgress": "Waiting for OAuth completion...",
|
||||
"oauthCompleted": "OAuth completed successfully",
|
||||
"placeholders": {
|
||||
"serverUrl": "https://api.example.com",
|
||||
"apiKey": "Your secret API key",
|
||||
"bearerToken": "Your secret token",
|
||||
"username": "Your username",
|
||||
"password": "Your password"
|
||||
"password": "Your password",
|
||||
"oauthScopes": "OAuth scopes (comma separated)"
|
||||
},
|
||||
"errors": {
|
||||
"nameRequired": "Server name is required",
|
||||
@@ -214,7 +217,9 @@
|
||||
"usernameRequired": "Username is required",
|
||||
"passwordRequired": "Password is required",
|
||||
"testFailed": "Connection test failed",
|
||||
"saveFailed": "Failed to save MCP server"
|
||||
"saveFailed": "Failed to save MCP server",
|
||||
"oauthFailed": "OAuth process failed or was cancelled",
|
||||
"oauthTimeout": "OAuth process timed out, please try again"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ const authTypes = [
|
||||
{ label: 'No Authentication', value: 'none' },
|
||||
{ label: 'API Key', value: 'api_key' },
|
||||
{ label: 'Bearer Token', value: 'bearer' },
|
||||
{ label: 'OAuth', value: 'oauth' },
|
||||
// { label: 'Basic Authentication', value: 'basic' },
|
||||
];
|
||||
|
||||
@@ -45,6 +46,8 @@ export default function MCPServerModal({
|
||||
username: '',
|
||||
password: '',
|
||||
timeout: server?.timeout || 30,
|
||||
oauth_scopes: '',
|
||||
oauth_task_id: '',
|
||||
});
|
||||
|
||||
const [loading, setLoading] = useState(false);
|
||||
@@ -52,8 +55,13 @@ export default function MCPServerModal({
|
||||
const [testResult, setTestResult] = useState<{
|
||||
success: boolean;
|
||||
message: string;
|
||||
status?: string;
|
||||
authorization_url?: string;
|
||||
} | null>(null);
|
||||
const [errors, setErrors] = useState<{ [key: string]: string }>({});
|
||||
const oauthPopupRef = useRef<Window | null>(null);
|
||||
const [oauthCompleted, setOAuthCompleted] = useState(false);
|
||||
const [saveActive, setSaveActive] = useState(false);
|
||||
|
||||
useOutsideAlerter(modalRef, () => {
|
||||
if (modalState === 'ACTIVE') {
|
||||
@@ -73,9 +81,12 @@ export default function MCPServerModal({
|
||||
username: '',
|
||||
password: '',
|
||||
timeout: 30,
|
||||
oauth_scopes: '',
|
||||
oauth_task_id: '',
|
||||
});
|
||||
setErrors({});
|
||||
setTestResult(null);
|
||||
setSaveActive(false);
|
||||
};
|
||||
|
||||
const validateForm = () => {
|
||||
@@ -154,10 +165,81 @@ export default function MCPServerModal({
|
||||
} else if (formData.auth_type === 'basic') {
|
||||
config.username = formData.username.trim();
|
||||
config.password = formData.password.trim();
|
||||
} else if (formData.auth_type === 'oauth') {
|
||||
config.oauth_scopes = formData.oauth_scopes
|
||||
.split(',')
|
||||
.map((s) => s.trim())
|
||||
.filter(Boolean);
|
||||
config.oauth_task_id = formData.oauth_task_id.trim();
|
||||
}
|
||||
return config;
|
||||
};
|
||||
|
||||
const pollOAuthStatus = async (
|
||||
taskId: string,
|
||||
onComplete: (result: any) => void,
|
||||
) => {
|
||||
let attempts = 0;
|
||||
const maxAttempts = 60;
|
||||
let popupOpened = false;
|
||||
const poll = async () => {
|
||||
try {
|
||||
const resp = await userService.getMCPOAuthStatus(taskId, token);
|
||||
const data = await resp.json();
|
||||
if (data.authorization_url && !popupOpened) {
|
||||
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
|
||||
oauthPopupRef.current.close();
|
||||
}
|
||||
oauthPopupRef.current = window.open(
|
||||
data.authorization_url,
|
||||
'oauthPopup',
|
||||
'width=600,height=700',
|
||||
);
|
||||
popupOpened = true;
|
||||
}
|
||||
if (data.status === 'completed') {
|
||||
setOAuthCompleted(true);
|
||||
setSaveActive(true);
|
||||
onComplete({
|
||||
...data,
|
||||
success: true,
|
||||
message: t('settings.tools.mcp.oauthCompleted'),
|
||||
});
|
||||
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
|
||||
oauthPopupRef.current.close();
|
||||
}
|
||||
} else if (data.status === 'error' || data.success === false) {
|
||||
setSaveActive(false);
|
||||
onComplete({
|
||||
...data,
|
||||
success: false,
|
||||
message: t('settings.tools.mcp.errors.oauthFailed'),
|
||||
});
|
||||
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
|
||||
oauthPopupRef.current.close();
|
||||
}
|
||||
} else {
|
||||
if (++attempts < maxAttempts) setTimeout(poll, 1000);
|
||||
else {
|
||||
setSaveActive(false);
|
||||
onComplete({
|
||||
success: false,
|
||||
message: t('settings.tools.mcp.errors.oauthTimeout'),
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
if (++attempts < maxAttempts) setTimeout(poll, 1000);
|
||||
else
|
||||
onComplete({
|
||||
success: false,
|
||||
message: t('settings.tools.mcp.errors.oauthTimeout'),
|
||||
});
|
||||
}
|
||||
};
|
||||
poll();
|
||||
};
|
||||
|
||||
const testConnection = async () => {
|
||||
if (!validateForm()) return;
|
||||
setTesting(true);
|
||||
@@ -167,13 +249,37 @@ export default function MCPServerModal({
|
||||
const response = await userService.testMCPConnection({ config }, token);
|
||||
const result = await response.json();
|
||||
|
||||
setTestResult(result);
|
||||
if (
|
||||
formData.auth_type === 'oauth' &&
|
||||
result.requires_oauth &&
|
||||
result.task_id
|
||||
) {
|
||||
setTestResult({
|
||||
success: true,
|
||||
message: t('settings.tools.mcp.oauthInProgress'),
|
||||
});
|
||||
setOAuthCompleted(false);
|
||||
setSaveActive(false);
|
||||
pollOAuthStatus(result.task_id, (finalResult) => {
|
||||
setTestResult(finalResult);
|
||||
setFormData((prev) => ({
|
||||
...prev,
|
||||
oauth_task_id: result.task_id || '',
|
||||
}));
|
||||
setTesting(false);
|
||||
});
|
||||
} else {
|
||||
setTestResult(result);
|
||||
setSaveActive(result.success === true);
|
||||
setTesting(false);
|
||||
}
|
||||
} catch (error) {
|
||||
setTestResult({
|
||||
success: false,
|
||||
message: t('settings.tools.mcp.errors.testFailed'),
|
||||
});
|
||||
} finally {
|
||||
setOAuthCompleted(false);
|
||||
setSaveActive(false);
|
||||
setTesting(false);
|
||||
}
|
||||
};
|
||||
@@ -305,6 +411,28 @@ export default function MCPServerModal({
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
case 'oauth':
|
||||
return (
|
||||
<div className="mb-10">
|
||||
<div className="mt-6">
|
||||
<Input
|
||||
name="oauth_scopes"
|
||||
type="text"
|
||||
className="rounded-md"
|
||||
value={formData.oauth_scopes}
|
||||
onChange={(e) =>
|
||||
handleInputChange('oauth_scopes', e.target.value)
|
||||
}
|
||||
placeholder={
|
||||
t('settings.tools.mcp.placeholders.oauthScopes') ||
|
||||
'Scopes (comma separated)'
|
||||
}
|
||||
borderVariant="thin"
|
||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
@@ -331,7 +459,6 @@ export default function MCPServerModal({
|
||||
<div className="space-y-6 py-6">
|
||||
<div>
|
||||
<Input
|
||||
name="name"
|
||||
type="text"
|
||||
className="rounded-md"
|
||||
value={formData.name}
|
||||
@@ -410,7 +537,7 @@ export default function MCPServerModal({
|
||||
|
||||
{testResult && (
|
||||
<div
|
||||
className={`rounded-md p-5 ${
|
||||
className={`rounded-2xl p-5 ${
|
||||
testResult.success
|
||||
? 'bg-green-50 text-green-700 dark:bg-green-900/40 dark:text-green-300'
|
||||
: 'bg-red-50 text-red-700 dark:bg-red-900 dark:text-red-300'
|
||||
@@ -458,7 +585,7 @@ export default function MCPServerModal({
|
||||
</button>
|
||||
<button
|
||||
onClick={handleSave}
|
||||
disabled={loading}
|
||||
disabled={loading || !saveActive}
|
||||
className="bg-purple-30 hover:bg-violets-are-blue w-full rounded-3xl px-6 py-2 text-sm font-medium text-white transition-all disabled:opacity-50 sm:w-auto"
|
||||
>
|
||||
{loading ? (
|
||||
|
||||
Reference in New Issue
Block a user