feat: Implement OAuth flow for MCP server integration

- Added MCPOAuthManager to handle OAuth authorization.
- Updated MCPServerSave resource to manage OAuth status and callback.
- Introduced new endpoints for OAuth status and callback handling.
- Enhanced user interface to support OAuth authentication type.
- Implemented polling mechanism for OAuth status in MCPServerModal.
- Updated frontend services and endpoints to accommodate new OAuth features.
- Improved error handling and user feedback for OAuth processes.
This commit is contained in:
Siddhant Rai
2025-09-26 02:44:08 +05:30
parent 00b4e133d4
commit 3b27db36f2
9 changed files with 991 additions and 289 deletions

View File

@@ -1,47 +1,39 @@
import asyncio
import base64
import json
import logging
import time
from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs, urlparse
from application.agents.tools.base import Tool
from application.security.encryption import decrypt_credentials
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials
from fastmcp import Client
from fastmcp.client.auth import BearerAuth
from fastmcp.client.auth.oauth import OAuth as FastMCPOAuth
from fastmcp.client.transports import (
SSETransport,
StdioTransport,
StreamableHttpTransport,
)
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
from pydantic import AnyHttpUrl, ValidationError
from redis import Redis
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
_mcp_clients_cache = {}
class DocsGPTOAuth(FastMCPOAuth):
"""Custom OAuth handler that integrates with DocsGPT frontend instead of opening browser."""
def __init__(self, *args, **kwargs):
self.auth_url_callback = kwargs.pop("auth_url_callback", None)
self.auth_code_callback = kwargs.pop("auth_code_callback", None)
super().__init__(*args, **kwargs)
async def redirect_handler(self, authorization_url: str) -> None:
"""Override to send auth URL to frontend instead of opening browser."""
if self.auth_url_callback:
self.auth_url_callback(authorization_url)
else:
raise Exception("OAuth authorization URL callback not configured")
async def callback_handler(self) -> tuple[str, str | None]:
"""Override to wait for auth code from frontend instead of local server."""
if self.auth_code_callback:
auth_code, state = await self.auth_code_callback()
return auth_code, state
else:
raise Exception("OAuth callback handler not configured")
class MCPTool(Tool):
"""
MCP Tool
@@ -67,6 +59,7 @@ class MCPTool(Tool):
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
"""
self.config = config
self.user_id = user_id
self.server_url = config.get("server_url", "")
self.transport_type = config.get("transport_type", "auto")
self.auth_type = config.get("auth_type", "none")
@@ -80,22 +73,16 @@ class MCPTool(Tool):
)
else:
self.auth_credentials = config.get("auth_credentials", {})
# OAuth specific configuration
self.oauth_scopes = config.get("oauth_scopes", [])
self.oauth_task_id = config.get("oauth_task_id", None)
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
# OAuth callback handlers (to be set by frontend)
self.oauth_auth_url_callback = None
self.oauth_auth_code_callback = None
self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback"
self.available_tools = []
self._cache_key = self._generate_cache_key()
self._client = None
# Only validate and setup if server_url is provided and not OAuth
# OAuth setup will happen after callbacks are set
if self.server_url and self.auth_type != "oauth":
self._setup_client()
@@ -104,8 +91,6 @@ class MCPTool(Tool):
"""Generate a unique cache key for this MCP server configuration."""
auth_key = ""
if self.auth_type == "oauth":
# For OAuth, use scopes and client name as part of the key
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}"
elif self.auth_type in ["bearer"]:
@@ -137,22 +122,17 @@ class MCPTool(Tool):
auth = None
if self.auth_type == "oauth":
# Ensure callbacks are configured before creating OAuth instance
if not self.oauth_auth_url_callback or not self.oauth_auth_code_callback:
raise Exception(
"OAuth callbacks not configured. Call set_oauth_callbacks() first."
)
# Use custom OAuth handler for frontend integration
redis_client = get_redis_instance()
auth = DocsGPTOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
client_name=self.oauth_client_name,
auth_url_callback=self.oauth_auth_url_callback,
auth_code_callback=self.oauth_auth_code_callback,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
db=db,
user_id=self.user_id,
)
elif self.auth_type in ["bearer"]:
elif self.auth_type == "bearer":
token = self.auth_credentials.get(
"bearer_token", ""
) or self.auth_credentials.get("access_token", "")
@@ -202,6 +182,33 @@ class MCPTool(Tool):
else:
return StreamableHttpTransport(url=self.server_url, headers=headers)
def _format_tools(self, tools_response) -> List[Dict]:
"""Format tools response to match expected format."""
if hasattr(tools_response, "tools"):
tools = tools_response.tools
elif isinstance(tools_response, list):
tools = tools_response
else:
tools = []
tools_dict = []
for tool in tools:
if hasattr(tool, "name"):
tool_dict = {
"name": tool.name,
"description": tool.description,
}
if hasattr(tool, "inputSchema"):
tool_dict["inputSchema"] = tool.inputSchema
tools_dict.append(tool_dict)
elif isinstance(tool, dict):
tools_dict.append(tool)
else:
if hasattr(tool, "model_dump"):
tools_dict.append(tool.model_dump())
else:
tools_dict.append({"name": str(tool), "description": ""})
return tools_dict
async def _execute_with_client(self, operation: str, *args, **kwargs):
"""Execute operation with FastMCP client."""
if not self._client:
@@ -211,32 +218,8 @@ class MCPTool(Tool):
return await self._client.ping()
elif operation == "list_tools":
tools_response = await self._client.list_tools()
if hasattr(tools_response, "tools"):
tools = tools_response.tools
elif isinstance(tools_response, list):
tools = tools_response
else:
tools = []
tools_dict = []
for tool in tools:
if hasattr(tool, "name"):
tool_dict = {
"name": tool.name,
"description": tool.description,
}
if hasattr(tool, "inputSchema"):
tool_dict["inputSchema"] = tool.inputSchema
tools_dict.append(tool_dict)
elif isinstance(tool, dict):
tools_dict.append(tool)
else:
if hasattr(tool, "model_dump"):
tools_dict.append(tool.model_dump())
else:
tools_dict.append({"name": str(tool), "description": ""})
return tools_dict
self.available_tools = self._format_tools(tools_response)
return self.available_tools
elif operation == "call_tool":
tool_name = args[0]
tool_args = kwargs
@@ -251,12 +234,8 @@ class MCPTool(Tool):
def _run_async_operation(self, operation: str, *args, **kwargs):
"""Run async operation in sync context."""
try:
# Check if there's already a running event loop
try:
loop = asyncio.get_running_loop()
# If we're in an async context, we need to run in a separate thread
import concurrent.futures
def run_in_thread():
@@ -273,8 +252,6 @@ class MCPTool(Tool):
future = executor.submit(run_in_thread)
return future.result(timeout=self.timeout)
except RuntimeError:
# No running loop, we can create one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
@@ -284,13 +261,8 @@ class MCPTool(Tool):
finally:
loop.close()
except Exception as e:
# If async fails, try to give a better error message for OAuth
if self.auth_type == "oauth" and "callback not configured" in str(e):
raise Exception(
"OAuth callbacks not configured. Call set_oauth_callbacks() first."
)
raise e
print(f"Error occurred while running async operation: {e}")
raise
def discover_tools(self) -> List[Dict]:
"""
@@ -377,8 +349,6 @@ class MCPTool(Tool):
if not self._client:
self._setup_client()
try:
# For OAuth, we need to handle async operations differently
if self.auth_type == "oauth":
return self._test_oauth_connection()
else:
@@ -412,40 +382,21 @@ class MCPTool(Tool):
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"ping_supported": ping_success,
"tools": [tool.get("name", "unknown") for tool in tools[:5]],
"tools": [tool.get("name", "unknown") for tool in tools],
}
def _test_oauth_connection(self) -> Dict:
"""Test connection for OAuth auth type with proper async handling."""
try:
# Ensure callbacks are configured before proceeding
if not self.oauth_auth_url_callback or not self.oauth_auth_code_callback:
return {
"success": False,
"message": "OAuth callbacks not configured. Call set_oauth_callbacks() first.",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": "ConfigurationError",
}
# Ensure client is set up with proper callbacks
if not self._client:
self._setup_client()
# For OAuth, we use a simpler approach - just try to discover tools
# This will trigger the OAuth flow if needed
tools = self.discover_tools()
task = mcp_oauth_task.delay(config=self.config, user=self.user_id)
if not task:
raise Exception("Failed to start OAuth authentication")
return {
"success": True,
"message": f"Successfully connected to OAuth MCP server. Found {len(tools)} tools.",
"tools_count": len(tools),
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"ping_supported": False, # Skip ping for OAuth to avoid complexity
"tools": [tool.get("name", "unknown") for tool in tools[:5]],
"requires_oauth": True,
"task_id": task.id,
"status": "pending",
"message": "OAuth flow started",
}
except Exception as e:
return {
@@ -457,67 +408,6 @@ class MCPTool(Tool):
"error_type": type(e).__name__,
}
def set_oauth_callbacks(self, auth_url_callback, auth_code_callback):
"""
Set OAuth callback handlers for frontend integration.
Args:
auth_url_callback: Function to call with authorization URL
auth_code_callback: Async function that returns (auth_code, state) tuple
"""
self.oauth_auth_url_callback = auth_url_callback
self.oauth_auth_code_callback = auth_code_callback
# Clear the client so it gets recreated with the new callbacks
self._client = None
# Also clear from cache to ensure fresh creation
global _mcp_clients_cache
if self._cache_key in _mcp_clients_cache:
del _mcp_clients_cache[self._cache_key]
def clear_oauth_cache(self):
"""
Clear OAuth cache to force fresh authentication.
This will remove stored tokens and client info for the server.
"""
if self.auth_type == "oauth":
try:
from fastmcp.client.auth.oauth import FileTokenStorage
storage = FileTokenStorage(server_url=self.server_url)
storage.clear()
print(f"✅ Cleared OAuth cache for {self.server_url}")
except Exception as e:
print(f"⚠️ Failed to clear OAuth cache: {e}")
# Also clear our internal client cache
global _mcp_clients_cache
if self._cache_key in _mcp_clients_cache:
del _mcp_clients_cache[self._cache_key]
print(f"✅ Cleared internal client cache")
@staticmethod
def clear_all_oauth_cache():
"""
Clear all OAuth cache for all servers.
This will remove all stored tokens and client info.
"""
try:
from fastmcp.client.auth.oauth import FileTokenStorage
FileTokenStorage.clear_all()
print(f"✅ Cleared all OAuth cache")
except Exception as e:
print(f"⚠️ Failed to clear all OAuth cache: {e}")
# Also clear all internal client cache
global _mcp_clients_cache
_mcp_clients_cache.clear()
print(f"✅ Cleared all internal client cache")
def get_actions_metadata(self) -> List[Dict]:
"""
Get metadata for all available actions.
@@ -666,3 +556,306 @@ class MCPTool(Tool):
"required": False,
},
}
class DocsGPTOAuth(OAuthClientProvider):
"""
Custom OAuth handler for DocsGPT that uses frontend redirect instead of browser.
"""
def __init__(
self,
mcp_url: str,
redirect_uri: str,
redis_client: Redis | None = None,
redis_prefix: str = "mcp_oauth:",
task_id: str = None,
scopes: str | list[str] | None = None,
client_name: str = "DocsGPT-MCP",
user_id=None,
db=None,
additional_client_metadata: dict[str, Any] | None = None,
):
"""
Initialize custom OAuth client provider for DocsGPT.
Args:
mcp_url: Full URL to the MCP endpoint
redirect_uri: Custom redirect URI for DocsGPT frontend
redis_client: Redis client for storing auth state
redis_prefix: Prefix for Redis keys
task_id: Task ID for tracking auth status
scopes: OAuth scopes to request
client_name: Name for this client during registration
user_id: User ID for token storage
db: Database instance for token storage
additional_client_metadata: Extra fields for OAuthClientMetadata
"""
self.redirect_uri = redirect_uri
self.redis_client = redis_client
self.redis_prefix = redis_prefix
self.task_id = task_id
self.user_id = user_id
self.db = db
parsed_url = urlparse(mcp_url)
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
if isinstance(scopes, list):
scopes = " ".join(scopes)
client_metadata = OAuthClientMetadata(
client_name=client_name,
redirect_uris=[AnyHttpUrl(redirect_uri)],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
scope=scopes,
**(additional_client_metadata or {}),
)
storage = DBTokenStorage(
server_url=self.server_base_url, user_id=self.user_id, db_client=self.db
)
super().__init__(
server_url=self.server_base_url,
client_metadata=client_metadata,
storage=storage,
redirect_handler=self.redirect_handler,
callback_handler=self.callback_handler,
)
self.auth_url = None
self.extracted_state = None
def _process_auth_url(self, authorization_url: str) -> tuple[str, str]:
"""Process authorization URL to extract state"""
try:
parsed_url = urlparse(authorization_url)
query_params = parse_qs(parsed_url.query)
state_params = query_params.get("state", [])
if state_params:
state = state_params[0]
else:
raise ValueError("No state in auth URL")
return authorization_url, state
except Exception as e:
raise Exception(f"Failed to process auth URL: {e}")
async def redirect_handler(self, authorization_url: str) -> None:
"""Store auth URL and state in Redis for frontend to use."""
auth_url, state = self._process_auth_url(authorization_url)
logging.info(
"[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state
)
self.auth_url = auth_url
self.extracted_state = state
if self.redis_client and self.extracted_state:
key = f"{self.redis_prefix}auth_url:{self.extracted_state}"
self.redis_client.setex(key, 600, auth_url)
logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key)
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "requires_redirect",
"message": "OAuth authorization required",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
async def callback_handler(self) -> tuple[str, str | None]:
"""Wait for auth code from Redis using the state value."""
if not self.redis_client or not self.extracted_state:
raise Exception("Redis client or state not configured for OAuth")
poll_interval = 1
max_wait_time = 300
code_key = f"{self.redis_prefix}code:{self.extracted_state}"
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "awaiting_callback",
"message": "Waiting for OAuth callback...",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
start_time = time.time()
while time.time() - start_time < max_wait_time:
code_data = self.redis_client.get(code_key)
if code_data:
code = code_data.decode()
returned_state = self.extracted_state
self.redis_client.delete(code_key)
self.redis_client.delete(
f"{self.redis_prefix}auth_url:{self.extracted_state}"
)
self.redis_client.delete(
f"{self.redis_prefix}state:{self.extracted_state}"
)
if self.task_id:
status_data = {
"status": "callback_received",
"message": "OAuth callback received, completing authentication...",
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
return code, returned_state
error_key = f"{self.redis_prefix}error:{self.extracted_state}"
error_data = self.redis_client.get(error_key)
if error_data:
error_msg = error_data.decode()
self.redis_client.delete(error_key)
self.redis_client.delete(
f"{self.redis_prefix}auth_url:{self.extracted_state}"
)
self.redis_client.delete(
f"{self.redis_prefix}state:{self.extracted_state}"
)
raise Exception(f"OAuth error: {error_msg}")
await asyncio.sleep(poll_interval)
self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}")
self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}")
raise Exception("OAuth callback timeout: no code received within 5 minutes")
class DBTokenStorage(TokenStorage):
def __init__(self, server_url: str, user_id: str, db_client):
self.server_url = server_url
self.user_id = user_id
self.db_client = db_client
self.collection = db_client["connector_sessions"]
@staticmethod
def get_base_url(url: str) -> str:
parsed = urlparse(url)
return f"{parsed.scheme}://{parsed.netloc}"
def get_db_key(self) -> dict:
return {
"server_url": self.get_base_url(self.server_url),
"user_id": self.user_id,
}
async def get_tokens(self) -> OAuthToken | None:
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
if not doc or "tokens" not in doc:
return None
try:
tokens = OAuthToken.model_validate(doc["tokens"])
return tokens
except ValidationError as e:
logging.error(f"Could not load tokens: {e}")
return None
async def set_tokens(self, tokens: OAuthToken) -> None:
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$set": {"tokens": tokens.model_dump()}},
True,
)
logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}")
async def get_client_info(self) -> OAuthClientInformationFull | None:
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
if not doc or "client_info" not in doc:
return None
try:
client_info = OAuthClientInformationFull.model_validate(doc["client_info"])
tokens = await self.get_tokens()
if tokens is None:
logging.debug(
"No tokens found, clearing client info to force fresh registration."
)
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$unset": {"client_info": ""}},
)
return None
return client_info
except ValidationError as e:
logging.error(f"Could not load client info: {e}")
return None
def _serialize_client_info(self, info: dict) -> dict:
if "redirect_uris" in info and isinstance(info["redirect_uris"], list):
info["redirect_uris"] = [str(u) for u in info["redirect_uris"]]
return info
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
serialized_info = self._serialize_client_info(client_info.model_dump())
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$set": {"client_info": serialized_info}},
True,
)
logging.info(f"Saved client info for {self.get_base_url(self.server_url)}")
async def clear(self) -> None:
await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")
@classmethod
async def clear_all(cls, db_client) -> None:
collection = db_client["connector_sessions"]
await asyncio.to_thread(collection.delete_many, {})
logging.info("Cleared all OAuth client cache data.")
class MCPOAuthManager:
"""Manager for handling MCP OAuth callbacks."""
def __init__(self, redis_client: Redis | None, redis_prefix: str = "mcp_oauth:"):
self.redis_client = redis_client
self.redis_prefix = redis_prefix
def handle_oauth_callback(
self, state: str, code: str, error: Optional[str] = None
) -> bool:
"""
Handle OAuth callback from provider.
Args:
state: The state parameter from OAuth callback
code: The authorization code from OAuth callback
error: Error message if OAuth failed
Returns:
True if successful, False otherwise
"""
try:
if not self.redis_client or not state:
raise Exception("Redis client or state not provided")
if error:
error_key = f"{self.redis_prefix}error:{state}"
self.redis_client.setex(error_key, 300, error)
raise Exception(f"OAuth error received: {error}")
code_key = f"{self.redis_prefix}code:{state}"
self.redis_client.setex(code_key, 300, code)
state_key = f"{self.redis_prefix}state:{state}"
self.redis_client.setex(state_key, 300, "completed")
return True
except Exception as e:
logging.error(f"Error handling OAuth callback: {e}")
return False
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
"""Get current status of OAuth flow using provided task_id."""
if not task_id:
return {"status": "not_started", "message": "OAuth flow not started"}
return mcp_oauth_status_task(task_id)

View File

@@ -8,6 +8,7 @@ import uuid
import zipfile
from functools import wraps
from typing import Optional, Tuple
from urllib.parse import unquote
from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef
@@ -25,7 +26,7 @@ from flask_restx import fields, inputs, Namespace, Resource
from pymongo import ReturnDocument
from werkzeug.utils import secure_filename
from application.agents.tools.mcp_tool import MCPTool
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
from application.agents.tools.tool_manager import ToolManager
from application.api import api
@@ -37,6 +38,8 @@ from application.api.user.tasks import (
process_agent_webhook,
store_attachment,
)
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.parser.connectors.connector_creator import ConnectorCreator
@@ -4303,7 +4306,7 @@ class TestMCPServerConfig(Resource):
test_config = config.copy()
test_config["auth_credentials"] = auth_credentials
mcp_tool = MCPTool(test_config, user)
mcp_tool = MCPTool(config=test_config, user_id=user)
result = mcp_tool.test_connection()
return make_response(jsonify(result), 200)
@@ -4371,8 +4374,33 @@ class MCPServerSave(Resource):
mcp_config = config.copy()
mcp_config["auth_credentials"] = auth_credentials
if auth_type == "none" or auth_credentials:
mcp_tool = MCPTool(mcp_config, user)
if auth_type == "oauth":
if not config.get("oauth_task_id"):
return make_response(
jsonify(
{
"success": False,
"error": "Connection not authorized. Please complete the OAuth authorization first.",
}
),
400,
)
redis_client = get_redis_instance()
manager = MCPOAuthManager(redis_client)
result = manager.get_oauth_status(config["oauth_task_id"])
if not result.get("status") == "completed":
return make_response(
jsonify(
{
"success": False,
"error": "OAuth failed or not completed. Please try authorizing again.",
}
),
400,
)
actions_metadata = result.get("tools", [])
elif auth_type == "none" or auth_credentials:
mcp_tool = MCPTool(config=mcp_config, user_id=user)
mcp_tool.discover_tools()
actions_metadata = mcp_tool.get_actions_metadata()
else:
@@ -4455,3 +4483,96 @@ class MCPServerSave(Resource):
),
500,
)
@user_ns.route("/api/mcp_server/callback")
class MCPOAuthCallback(Resource):
@api.expect(
api.model(
"MCPServerCallbackModel",
{
"code": fields.String(required=True, description="Authorization code"),
"state": fields.String(required=True, description="State parameter"),
"error": fields.String(
required=False, description="Error message (if any)"
),
},
)
)
@api.doc(
description="Handle OAuth callback by providing the authorization code and state"
)
def get(self):
code = request.args.get("code")
state = request.args.get("state")
error = request.args.get("error")
if error:
return redirect(
f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider=mcp_tool"
)
if not code or not state:
return redirect(
f"/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool"
)
try:
redis_client = get_redis_instance()
if not redis_client:
return redirect(
f"/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool"
)
code = unquote(code)
manager = MCPOAuthManager(redis_client)
success = manager.handle_oauth_callback(state, code, error)
if success:
return redirect(
f"/api/connectors/callback-status?status=success&message=Authorization+code+received+successfully.+You+can+close+this+window.&provider=mcp_tool"
)
else:
return redirect(
f"/api/connectors/callback-status?status=error&message=OAuth+callback+failed.&provider=mcp_tool"
)
except Exception as e:
current_app.logger.error(
f"Error handling MCP OAuth callback: {str(e)}", exc_info=True
)
return redirect(
f"/api/connectors/callback-status?status=error&message=Internal+server+error:+{str(e)}.&provider=mcp_tool"
)
@user_ns.route("/api/mcp_server/oauth_status/<string:task_id>")
class MCPOAuthStatus(Resource):
def get(self, task_id):
"""
Get current status of OAuth flow.
Frontend should poll this endpoint periodically.
"""
try:
redis_client = get_redis_instance()
status_key = f"mcp_oauth_status:{task_id}"
status_data = redis_client.get(status_key)
if status_data:
status = json.loads(status_data)
return make_response(
jsonify({"success": True, "task_id": task_id, **status})
)
else:
return make_response(
jsonify(
{
"success": False,
"error": "Task not found or expired",
"task_id": task_id,
}
),
404,
)
except Exception as e:
current_app.logger.error(
f"Error getting OAuth status for task {task_id}: {str(e)}"
)
return make_response(
jsonify({"success": False, "error": str(e), "task_id": task_id}), 500
)

View File

@@ -5,6 +5,8 @@ from application.worker import (
agent_webhook_worker,
attachment_worker,
ingest_worker,
mcp_oauth,
mcp_oauth_status,
remote_worker,
sync_worker,
)
@@ -25,6 +27,7 @@ def ingest_remote(self, source_data, job_name, user, loader):
@celery.task(bind=True)
def reingest_source_task(self, source_id, user):
from application.worker import reingest_source_worker
resp = reingest_source_worker(self, source_id, user)
return resp
@@ -60,9 +63,10 @@ def ingest_connector_task(
retriever="classic",
operation_mode="upload",
doc_id=None,
sync_frequency="never"
sync_frequency="never",
):
from application.worker import ingest_connector
resp = ingest_connector(
self,
job_name,
@@ -75,7 +79,7 @@ def ingest_connector_task(
retriever=retriever,
operation_mode=operation_mode,
doc_id=doc_id,
sync_frequency=sync_frequency
sync_frequency=sync_frequency,
)
return resp
@@ -94,3 +98,15 @@ def setup_periodic_tasks(sender, **kwargs):
timedelta(days=30),
schedule_syncs.s("monthly"),
)
@celery.task(bind=True)
def mcp_oauth_task(self, config, user):
resp = mcp_oauth(self, config, user)
return resp
@celery.task(bind=True)
def mcp_oauth_status_task(self, task_id):
resp = mcp_oauth_status(self, task_id)
return resp

View File

@@ -12,6 +12,7 @@ esprima==4.0.1
esutils==1.0.1
Flask==3.1.1
faiss-cpu==1.9.0.post1
fastmcp==2.11.0
flask-restx==1.3.0
google-genai==1.3.0
google-api-python-client==2.179.0
@@ -56,13 +57,13 @@ prompt-toolkit==3.0.51
protobuf==5.29.3
psycopg2-binary==2.9.10
py==1.11.0
pydantic==2.10.6
pydantic-core==2.27.2
pydantic-settings==2.7.1
pydantic
pydantic-core
pydantic-settings
pymongo==4.11.3
pypdf==5.5.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-dotenv
python-jose==3.4.0
python-pptx==1.0.2
redis==5.2.1
@@ -82,7 +83,7 @@ tzdata==2024.2
urllib3==2.3.0
vine==5.1.0
wcwidth==0.2.13
werkzeug==3.1.3
werkzeug>=3.1.0,<3.1.2
yarl==1.20.0
markdownify==1.1.0
tldextract==5.1.3

View File

@@ -19,6 +19,7 @@ from bson.objectid import ObjectId
from application.agents.agent_creator import AgentCreator
from application.api.answer.services.stream_processor import get_prompt
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.parser.chunking import Chunker
@@ -214,8 +215,7 @@ def run_agent_logic(agent_config, input_data):
def ingest_worker(
self, directory, formats, job_name, file_path, filename, user,
retriever="classic"
self, directory, formats, job_name, file_path, filename, user, retriever="classic"
):
"""
Ingest and process documents.
@@ -270,7 +270,9 @@ def ingest_worker(
with open(local_file_path, "wb") as f:
f.write(file_data.read())
except Exception as e:
logging.error(f"Error downloading file {storage_file_path}: {e}")
logging.error(
f"Error downloading file {storage_file_path}: {e}"
)
continue
else:
# Handle single file case
@@ -285,7 +287,10 @@ def ingest_worker(
if temp_filename.endswith(".zip"):
logging.info(f"Extracting zip file: {temp_filename}")
extract_zip_recursive(
temp_file_path, temp_dir, current_depth=0, max_depth=RECURSION_DEPTH
temp_file_path,
temp_dir,
current_depth=0,
max_depth=RECURSION_DEPTH,
)
self.update_state(state="PROGRESS", meta={"current": 1})
@@ -301,7 +306,7 @@ def ingest_worker(
)
raw_docs = reader.load_data()
directory_structure = getattr(reader, 'directory_structure', {})
directory_structure = getattr(reader, "directory_structure", {})
logging.info(f"Directory structure from reader: {directory_structure}")
chunker = Chunker(
@@ -371,7 +376,10 @@ def reingest_source_worker(self, source_id, user):
try:
from application.vectorstore.vector_creator import VectorCreator
self.update_state(state="PROGRESS", meta={"current": 10, "status": "Initializing re-ingestion scan"})
self.update_state(
state="PROGRESS",
meta={"current": 10, "status": "Initializing re-ingestion scan"},
)
source = sources_collection.find_one({"_id": ObjectId(source_id), "user": user})
if not source:
@@ -380,7 +388,9 @@ def reingest_source_worker(self, source_id, user):
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
self.update_state(state="PROGRESS", meta={"current": 20, "status": "Scanning current files"})
self.update_state(
state="PROGRESS", meta={"current": 20, "status": "Scanning current files"}
)
with tempfile.TemporaryDirectory() as temp_dir:
# Download all files from storage to temp directory, preserving directory structure
@@ -391,7 +401,6 @@ def reingest_source_worker(self, source_id, user):
if storage.is_directory(storage_file_path):
continue
rel_path = os.path.relpath(storage_file_path, source_file_path)
local_file_path = os.path.join(temp_dir, rel_path)
@@ -403,23 +412,39 @@ def reingest_source_worker(self, source_id, user):
with open(local_file_path, "wb") as f:
f.write(file_data.read())
except Exception as e:
logging.error(f"Error downloading file {storage_file_path}: {e}")
logging.error(
f"Error downloading file {storage_file_path}: {e}"
)
continue
reader = SimpleDirectoryReader(
input_dir=temp_dir,
recursive=True,
required_exts=[
".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub",
".html", ".mdx", ".json", ".xlsx", ".pptx", ".png",
".jpg", ".jpeg",
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
exclude_hidden=True,
file_metadata=metadata_from_filename,
)
reader.load_data()
directory_structure = reader.directory_structure
logging.info(f"Directory structure built with token counts: {directory_structure}")
logging.info(
f"Directory structure built with token counts: {directory_structure}"
)
try:
old_directory_structure = source.get("directory_structure") or {}
@@ -433,11 +458,17 @@ def reingest_source_worker(self, source_id, user):
files = set()
if isinstance(struct, dict):
for name, meta in struct.items():
current_path = os.path.join(prefix, name) if prefix else name
if isinstance(meta, dict) and ("type" in meta and "size_bytes" in meta):
current_path = (
os.path.join(prefix, name) if prefix else name
)
if isinstance(meta, dict) and (
"type" in meta and "size_bytes" in meta
):
files.add(current_path)
elif isinstance(meta, dict):
files |= _flatten_directory_structure(meta, current_path)
files |= _flatten_directory_structure(
meta, current_path
)
return files
old_files = _flatten_directory_structure(old_directory_structure)
@@ -457,7 +488,9 @@ def reingest_source_worker(self, source_id, user):
logging.info("No files removed since last ingest.")
except Exception as e:
logging.error(f"Error comparing directory structures: {e}", exc_info=True)
logging.error(
f"Error comparing directory structures: {e}", exc_info=True
)
added_files = []
removed_files = []
try:
@@ -477,14 +510,21 @@ def reingest_source_worker(self, source_id, user):
settings.EMBEDDINGS_KEY,
)
self.update_state(state="PROGRESS", meta={"current": 40, "status": "Processing file changes"})
self.update_state(
state="PROGRESS",
meta={"current": 40, "status": "Processing file changes"},
)
# 1) Delete chunks from removed files
deleted = 0
if removed_files:
try:
for ch in vector_store.get_chunks() or []:
metadata = ch.get("metadata", {}) if isinstance(ch, dict) else getattr(ch, "metadata", {})
metadata = (
ch.get("metadata", {})
if isinstance(ch, dict)
else getattr(ch, "metadata", {})
)
raw_source = metadata.get("source")
source_file = str(raw_source) if raw_source else ""
@@ -496,10 +536,17 @@ def reingest_source_worker(self, source_id, user):
vector_store.delete_chunk(cid)
deleted += 1
except Exception as de:
logging.error(f"Failed deleting chunk {cid}: {de}")
logging.info(f"Deleted {deleted} chunks from {len(removed_files)} removed files")
logging.error(
f"Failed deleting chunk {cid}: {de}"
)
logging.info(
f"Deleted {deleted} chunks from {len(removed_files)} removed files"
)
except Exception as e:
logging.error(f"Error during deletion of removed file chunks: {e}", exc_info=True)
logging.error(
f"Error during deletion of removed file chunks: {e}",
exc_info=True,
)
# 2) Add chunks from new files
added = 0
@@ -528,39 +575,62 @@ def reingest_source_worker(self, source_id, user):
)
chunked_new = chunker_new.chunk(documents=raw_docs_new)
for file_path, token_count in reader_new.file_token_counts.items():
for (
file_path,
token_count,
) in reader_new.file_token_counts.items():
try:
rel_path = os.path.relpath(file_path, start=temp_dir)
rel_path = os.path.relpath(
file_path, start=temp_dir
)
path_parts = rel_path.split(os.sep)
current_dir = directory_structure
for part in path_parts[:-1]:
if part in current_dir and isinstance(current_dir[part], dict):
if part in current_dir and isinstance(
current_dir[part], dict
):
current_dir = current_dir[part]
else:
break
filename = path_parts[-1]
if filename in current_dir and isinstance(current_dir[filename], dict):
current_dir[filename]["token_count"] = token_count
logging.info(f"Updated token count for {rel_path}: {token_count}")
if filename in current_dir and isinstance(
current_dir[filename], dict
):
current_dir[filename][
"token_count"
] = token_count
logging.info(
f"Updated token count for {rel_path}: {token_count}"
)
except Exception as e:
logging.warning(f"Could not update token count for {file_path}: {e}")
logging.warning(
f"Could not update token count for {file_path}: {e}"
)
for d in chunked_new:
meta = dict(d.extra_info or {})
try:
raw_src = meta.get("source")
if isinstance(raw_src, str) and os.path.isabs(raw_src):
meta["source"] = os.path.relpath(raw_src, start=temp_dir)
if isinstance(raw_src, str) and os.path.isabs(
raw_src
):
meta["source"] = os.path.relpath(
raw_src, start=temp_dir
)
except Exception:
pass
vector_store.add_chunk(d.text, metadata=meta)
added += 1
logging.info(f"Added {added} chunks from {len(added_files)} new files")
logging.info(
f"Added {added} chunks from {len(added_files)} new files"
)
except Exception as e:
logging.error(f"Error during ingestion of new files: {e}", exc_info=True)
logging.error(
f"Error during ingestion of new files: {e}", exc_info=True
)
# 3) Update source directory structure timestamp
try:
@@ -572,14 +642,19 @@ def reingest_source_worker(self, source_id, user):
"$set": {
"directory_structure": directory_structure,
"date": datetime.datetime.now(),
"tokens": total_tokens
"tokens": total_tokens,
}
},
)
except Exception as e:
logging.error(f"Error updating directory_structure in DB: {e}", exc_info=True)
logging.error(
f"Error updating directory_structure in DB: {e}", exc_info=True
)
self.update_state(state="PROGRESS", meta={"current": 100, "status": "Re-ingestion completed"})
self.update_state(
state="PROGRESS",
meta={"current": 100, "status": "Re-ingestion completed"},
)
return {
"source_id": source_id,
@@ -591,15 +666,16 @@ def reingest_source_worker(self, source_id, user):
"chunks_deleted": deleted,
}
except Exception as e:
logging.error(f"Error while processing file changes: {e}", exc_info=True)
logging.error(
f"Error while processing file changes: {e}", exc_info=True
)
raise
except Exception as e:
logging.error(f"Error in reingest_source_worker: {e}", exc_info=True)
raise
def remote_worker(
self,
source_data,
@@ -651,7 +727,7 @@ def remote_worker(
"id": str(id),
"type": loader,
"remote_data": source_data,
"sync_frequency": sync_frequency
"sync_frequency": sync_frequency,
}
if operation_mode == "sync":
@@ -752,7 +828,6 @@ def attachment_worker(self, file_info, user):
.text,
)
token_count = num_tokens_from_string(content)
if token_count > 100000:
content = content[:250000]
@@ -872,37 +947,49 @@ def ingest_connector(
doc_id: Document ID for sync operations (required when operation_mode="sync")
sync_frequency: How often to sync ("never", "daily", "weekly", "monthly")
"""
logging.info(f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}")
logging.info(
f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}"
)
self.update_state(state="PROGRESS", meta={"current": 1})
with tempfile.TemporaryDirectory() as temp_dir:
try:
# Step 1: Initialize the appropriate loader
self.update_state(state="PROGRESS", meta={"current": 10, "status": "Initializing connector"})
self.update_state(
state="PROGRESS",
meta={"current": 10, "status": "Initializing connector"},
)
if not session_token:
raise ValueError(f"{source_type} connector requires session_token")
if not ConnectorCreator.is_supported(source_type):
raise ValueError(f"Unsupported connector type: {source_type}. Supported types: {ConnectorCreator.get_supported_connectors()}")
raise ValueError(
f"Unsupported connector type: {source_type}. Supported types: {ConnectorCreator.get_supported_connectors()}"
)
remote_loader = ConnectorCreator.create_connector(source_type, session_token)
remote_loader = ConnectorCreator.create_connector(
source_type, session_token
)
# Create a clean config for storage
api_source_config = {
"file_ids": file_ids or [],
"folder_ids": folder_ids or [],
"recursive": recursive
"recursive": recursive,
}
# Step 2: Download files to temp directory
self.update_state(state="PROGRESS", meta={"current": 20, "status": "Downloading files"})
self.update_state(
state="PROGRESS", meta={"current": 20, "status": "Downloading files"}
)
download_info = remote_loader.download_to_directory(
temp_dir,
api_source_config
temp_dir, api_source_config
)
if download_info.get("empty_result", False) or not download_info.get("files_downloaded", 0):
if download_info.get("empty_result", False) or not download_info.get(
"files_downloaded", 0
):
logging.warning(f"No files were downloaded from {source_type}")
# Create empty result directly instead of calling a separate method
return {
@@ -915,25 +1002,39 @@ def ingest_connector(
}
# Step 3: Use SimpleDirectoryReader to process downloaded files
self.update_state(state="PROGRESS", meta={"current": 40, "status": "Processing files"})
self.update_state(
state="PROGRESS", meta={"current": 40, "status": "Processing files"}
)
reader = SimpleDirectoryReader(
input_dir=temp_dir,
recursive=True,
required_exts=[
".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub",
".html", ".mdx", ".json", ".xlsx", ".pptx", ".png",
".jpg", ".jpeg",
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
exclude_hidden=True,
file_metadata=metadata_from_filename,
)
raw_docs = reader.load_data()
directory_structure = getattr(reader, 'directory_structure', {})
directory_structure = getattr(reader, "directory_structure", {})
# Step 4: Process documents (chunking, embedding, etc.)
self.update_state(state="PROGRESS", meta={"current": 60, "status": "Processing documents"})
self.update_state(
state="PROGRESS", meta={"current": 60, "status": "Processing documents"}
)
chunker = Chunker(
chunking_strategy="classic_chunk",
@@ -945,11 +1046,13 @@ def ingest_connector(
# Preserve source information in document metadata
for doc in raw_docs:
if hasattr(doc, 'extra_info') and doc.extra_info:
source = doc.extra_info.get('source')
if hasattr(doc, "extra_info") and doc.extra_info:
source = doc.extra_info.get("source")
if source and os.path.isabs(source):
# Convert absolute path to relative path
doc.extra_info['source'] = os.path.relpath(source, start=temp_dir)
doc.extra_info["source"] = os.path.relpath(
source, start=temp_dir
)
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
@@ -957,7 +1060,9 @@ def ingest_connector(
id = ObjectId()
elif operation_mode == "sync":
if not doc_id or not ObjectId.is_valid(doc_id):
logging.error("Invalid doc_id provided for sync operation: %s", doc_id)
logging.error(
"Invalid doc_id provided for sync operation: %s", doc_id
)
raise ValueError("doc_id must be provided for sync operation.")
id = ObjectId(doc_id)
else:
@@ -966,7 +1071,9 @@ def ingest_connector(
vector_store_path = os.path.join(temp_dir, "vector_store")
os.makedirs(vector_store_path, exist_ok=True)
self.update_state(state="PROGRESS", meta={"current": 80, "status": "Storing documents"})
self.update_state(
state="PROGRESS", meta={"current": 80, "status": "Storing documents"}
)
embed_and_store_documents(docs, vector_store_path, id, self)
tokens = count_tokens_docs(docs)
@@ -979,12 +1086,11 @@ def ingest_connector(
"retriever": retriever,
"id": str(id),
"type": "connector",
"remote_data": json.dumps({
"provider": source_type,
**api_source_config
}),
"remote_data": json.dumps(
{"provider": source_type, **api_source_config}
),
"directory_structure": json.dumps(directory_structure),
"sync_frequency": sync_frequency
"sync_frequency": sync_frequency,
}
if operation_mode == "sync":
@@ -995,7 +1101,9 @@ def ingest_connector(
upload_index(vector_store_path, file_data)
# Ensure we mark the task as complete
self.update_state(state="PROGRESS", meta={"current": 100, "status": "Complete"})
self.update_state(
state="PROGRESS", meta={"current": 100, "status": "Complete"}
)
logging.info(f"Remote ingestion completed: {job_name}")
@@ -1005,9 +1113,136 @@ def ingest_connector(
"tokens": tokens,
"type": source_type,
"id": str(id),
"status": "complete"
"status": "complete",
}
except Exception as e:
logging.error(f"Error during remote ingestion: {e}", exc_info=True)
raise
def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
"""Worker to handle MCP OAuth flow asynchronously."""
logging.info(
"[MCP OAuth] Worker started for user_id=%s, config=%s", user_id, config
)
try:
import asyncio
from application.agents.tools.mcp_tool import MCPTool
task_id = self.request.id
logging.info("[MCP OAuth] Task ID: %s", task_id)
redis_client = get_redis_instance()
def update_status(status_data: Dict[str, Any]):
logging.info("[MCP OAuth] Updating status: %s", status_data)
status_key = f"mcp_oauth_status:{task_id}"
redis_client.setex(status_key, 600, json.dumps(status_data))
update_status(
{
"status": "in_progress",
"message": "Starting OAuth flow...",
"task_id": task_id,
}
)
tool_config = config.copy()
tool_config["oauth_task_id"] = task_id
logging.info("[MCP OAuth] Initializing MCPTool with config: %s", tool_config)
mcp_tool = MCPTool(tool_config, user_id)
async def run_oauth_discovery():
if not mcp_tool._client:
mcp_tool._setup_client()
return await mcp_tool._execute_with_client("list_tools")
update_status(
{
"status": "awaiting_redirect",
"message": "Waiting for OAuth redirect...",
"task_id": task_id,
}
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
logging.info("[MCP OAuth] Starting event loop for OAuth discovery...")
tools_response = loop.run_until_complete(run_oauth_discovery())
logging.info(
"[MCP OAuth] Tools response after async call: %s", tools_response
)
status_key = f"mcp_oauth_status:{task_id}"
redis_status = redis_client.get(status_key)
if redis_status:
logging.info(
"[MCP OAuth] Redis status after async call: %s", redis_status
)
else:
logging.warning(
"[MCP OAuth] No Redis status found after async call for key: %s",
status_key,
)
tools = mcp_tool.get_actions_metadata()
update_status(
{
"status": "completed",
"message": f"OAuth completed successfully. Found {len(tools)} tools.",
"tools": tools,
"tools_count": len(tools),
"task_id": task_id,
}
)
logging.info(
"[MCP OAuth] OAuth flow completed successfully for task_id=%s", task_id
)
return {"success": True, "tools": tools, "tools_count": len(tools)}
except Exception as e:
error_msg = f"OAuth flow failed: {str(e)}"
logging.error(
"[MCP OAuth] Exception in OAuth discovery: %s", error_msg, exc_info=True
)
update_status(
{
"status": "error",
"message": error_msg,
"error": str(e),
"task_id": task_id,
}
)
return {"success": False, "error": error_msg}
finally:
logging.info("[MCP OAuth] Closing event loop for task_id=%s", task_id)
loop.close()
except Exception as e:
error_msg = f"Failed to initialize OAuth flow: {str(e)}"
logging.error(
"[MCP OAuth] Exception during initialization: %s", error_msg, exc_info=True
)
update_status(
{
"status": "error",
"message": error_msg,
"error": str(e),
"task_id": task_id,
}
)
return {"success": False, "error": error_msg}
def mcp_oauth_status(self, task_id: str) -> Dict[str, Any]:
"""Check the status of an MCP OAuth flow."""
redis_client = get_redis_instance()
status_key = f"mcp_oauth_status:{task_id}"
status_data = redis_client.get(status_key)
if status_data:
return json.loads(status_data)
return {"status": "not_found", "message": "Status not found"}

View File

@@ -59,6 +59,8 @@ const endpoints = {
MANAGE_SOURCE_FILES: '/api/manage_source_files',
MCP_TEST_CONNECTION: '/api/mcp_server/test',
MCP_SAVE_SERVER: '/api/mcp_server/save',
MCP_OAUTH_STATUS: (task_id: string) =>
`/api/mcp_server/oauth_status/${task_id}`,
},
CONVERSATION: {
ANSWER: '/api/answer',

View File

@@ -1,6 +1,6 @@
import { getSessionToken } from '../../utils/providerUtils';
import apiClient from '../client';
import endpoints from '../endpoints';
import { getSessionToken } from '../../utils/providerUtils';
const userService = {
getConfig: (): Promise<any> => apiClient.get(endpoints.USER.CONFIG, null),
@@ -112,6 +112,8 @@ const userService = {
apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token),
saveMCPServer: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token),
getMCPOAuthStatus: (task_id: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.MCP_OAUTH_STATUS(task_id), token),
syncConnector: (
docId: string,
provider: string,

View File

@@ -193,17 +193,20 @@
"headerName": "Header Name",
"timeout": "Timeout (seconds)",
"testConnection": "Test Connection",
"testing": "Testing...",
"saving": "Saving...",
"testing": "Testing",
"saving": "Saving",
"save": "Save",
"cancel": "Cancel",
"noAuth": "No Authentication",
"oauthInProgress": "Waiting for OAuth completion...",
"oauthCompleted": "OAuth completed successfully",
"placeholders": {
"serverUrl": "https://api.example.com",
"apiKey": "Your secret API key",
"bearerToken": "Your secret token",
"username": "Your username",
"password": "Your password"
"password": "Your password",
"oauthScopes": "OAuth scopes (comma separated)"
},
"errors": {
"nameRequired": "Server name is required",
@@ -214,7 +217,9 @@
"usernameRequired": "Username is required",
"passwordRequired": "Password is required",
"testFailed": "Connection test failed",
"saveFailed": "Failed to save MCP server"
"saveFailed": "Failed to save MCP server",
"oauthFailed": "OAuth process failed or was cancelled",
"oauthTimeout": "OAuth process timed out, please try again"
}
}
}

View File

@@ -22,6 +22,7 @@ const authTypes = [
{ label: 'No Authentication', value: 'none' },
{ label: 'API Key', value: 'api_key' },
{ label: 'Bearer Token', value: 'bearer' },
{ label: 'OAuth', value: 'oauth' },
// { label: 'Basic Authentication', value: 'basic' },
];
@@ -45,6 +46,8 @@ export default function MCPServerModal({
username: '',
password: '',
timeout: server?.timeout || 30,
oauth_scopes: '',
oauth_task_id: '',
});
const [loading, setLoading] = useState(false);
@@ -52,8 +55,13 @@ export default function MCPServerModal({
const [testResult, setTestResult] = useState<{
success: boolean;
message: string;
status?: string;
authorization_url?: string;
} | null>(null);
const [errors, setErrors] = useState<{ [key: string]: string }>({});
const oauthPopupRef = useRef<Window | null>(null);
const [oauthCompleted, setOAuthCompleted] = useState(false);
const [saveActive, setSaveActive] = useState(false);
useOutsideAlerter(modalRef, () => {
if (modalState === 'ACTIVE') {
@@ -73,9 +81,12 @@ export default function MCPServerModal({
username: '',
password: '',
timeout: 30,
oauth_scopes: '',
oauth_task_id: '',
});
setErrors({});
setTestResult(null);
setSaveActive(false);
};
const validateForm = () => {
@@ -154,10 +165,81 @@ export default function MCPServerModal({
} else if (formData.auth_type === 'basic') {
config.username = formData.username.trim();
config.password = formData.password.trim();
} else if (formData.auth_type === 'oauth') {
config.oauth_scopes = formData.oauth_scopes
.split(',')
.map((s) => s.trim())
.filter(Boolean);
config.oauth_task_id = formData.oauth_task_id.trim();
}
return config;
};
const pollOAuthStatus = async (
taskId: string,
onComplete: (result: any) => void,
) => {
let attempts = 0;
const maxAttempts = 60;
let popupOpened = false;
const poll = async () => {
try {
const resp = await userService.getMCPOAuthStatus(taskId, token);
const data = await resp.json();
if (data.authorization_url && !popupOpened) {
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
oauthPopupRef.current.close();
}
oauthPopupRef.current = window.open(
data.authorization_url,
'oauthPopup',
'width=600,height=700',
);
popupOpened = true;
}
if (data.status === 'completed') {
setOAuthCompleted(true);
setSaveActive(true);
onComplete({
...data,
success: true,
message: t('settings.tools.mcp.oauthCompleted'),
});
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
oauthPopupRef.current.close();
}
} else if (data.status === 'error' || data.success === false) {
setSaveActive(false);
onComplete({
...data,
success: false,
message: t('settings.tools.mcp.errors.oauthFailed'),
});
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
oauthPopupRef.current.close();
}
} else {
if (++attempts < maxAttempts) setTimeout(poll, 1000);
else {
setSaveActive(false);
onComplete({
success: false,
message: t('settings.tools.mcp.errors.oauthTimeout'),
});
}
}
} catch {
if (++attempts < maxAttempts) setTimeout(poll, 1000);
else
onComplete({
success: false,
message: t('settings.tools.mcp.errors.oauthTimeout'),
});
}
};
poll();
};
const testConnection = async () => {
if (!validateForm()) return;
setTesting(true);
@@ -167,13 +249,37 @@ export default function MCPServerModal({
const response = await userService.testMCPConnection({ config }, token);
const result = await response.json();
setTestResult(result);
if (
formData.auth_type === 'oauth' &&
result.requires_oauth &&
result.task_id
) {
setTestResult({
success: true,
message: t('settings.tools.mcp.oauthInProgress'),
});
setOAuthCompleted(false);
setSaveActive(false);
pollOAuthStatus(result.task_id, (finalResult) => {
setTestResult(finalResult);
setFormData((prev) => ({
...prev,
oauth_task_id: result.task_id || '',
}));
setTesting(false);
});
} else {
setTestResult(result);
setSaveActive(result.success === true);
setTesting(false);
}
} catch (error) {
setTestResult({
success: false,
message: t('settings.tools.mcp.errors.testFailed'),
});
} finally {
setOAuthCompleted(false);
setSaveActive(false);
setTesting(false);
}
};
@@ -305,6 +411,28 @@ export default function MCPServerModal({
</div>
</div>
);
case 'oauth':
return (
<div className="mb-10">
<div className="mt-6">
<Input
name="oauth_scopes"
type="text"
className="rounded-md"
value={formData.oauth_scopes}
onChange={(e) =>
handleInputChange('oauth_scopes', e.target.value)
}
placeholder={
t('settings.tools.mcp.placeholders.oauthScopes') ||
'Scopes (comma separated)'
}
borderVariant="thin"
labelBgClassName="bg-white dark:bg-charleston-green-2"
/>
</div>
</div>
);
default:
return null;
}
@@ -331,7 +459,6 @@ export default function MCPServerModal({
<div className="space-y-6 py-6">
<div>
<Input
name="name"
type="text"
className="rounded-md"
value={formData.name}
@@ -410,7 +537,7 @@ export default function MCPServerModal({
{testResult && (
<div
className={`rounded-md p-5 ${
className={`rounded-2xl p-5 ${
testResult.success
? 'bg-green-50 text-green-700 dark:bg-green-900/40 dark:text-green-300'
: 'bg-red-50 text-red-700 dark:bg-red-900 dark:text-red-300'
@@ -458,7 +585,7 @@ export default function MCPServerModal({
</button>
<button
onClick={handleSave}
disabled={loading}
disabled={loading || !saveActive}
className="bg-purple-30 hover:bg-violets-are-blue w-full rounded-3xl px-6 py-2 text-sm font-medium text-white transition-all disabled:opacity-50 sm:w-auto"
>
{loading ? (