Merge pull request #1964 from siiddhantt/refine/mcp-tool

refactor: oauth + use fastmcp client for handling SSE and different transports in remote mcp
This commit is contained in:
Alex
2025-09-26 13:54:40 +01:00
committed by GitHub
9 changed files with 1342 additions and 395 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -8,6 +8,7 @@ import uuid
import zipfile import zipfile
from functools import wraps from functools import wraps
from typing import Optional, Tuple from typing import Optional, Tuple
from urllib.parse import unquote
from bson.binary import Binary, UuidRepresentation from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef from bson.dbref import DBRef
@@ -25,7 +26,7 @@ from flask_restx import fields, inputs, Namespace, Resource
from pymongo import ReturnDocument from pymongo import ReturnDocument
from werkzeug.utils import secure_filename from werkzeug.utils import secure_filename
from application.agents.tools.mcp_tool import MCPTool from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
from application.agents.tools.tool_manager import ToolManager from application.agents.tools.tool_manager import ToolManager
from application.api import api from application.api import api
@@ -37,6 +38,8 @@ from application.api.user.tasks import (
process_agent_webhook, process_agent_webhook,
store_attachment, store_attachment,
) )
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB from application.core.mongo_db import MongoDB
from application.core.settings import settings from application.core.settings import settings
from application.parser.connectors.connector_creator import ConnectorCreator from application.parser.connectors.connector_creator import ConnectorCreator
@@ -494,7 +497,6 @@ class DeleteOldIndexes(Resource):
) )
if not doc: if not doc:
return make_response(jsonify({"status": "not found"}), 404) return make_response(jsonify({"status": "not found"}), 404)
storage = StorageCreator.get_storage() storage = StorageCreator.get_storage()
try: try:
@@ -511,7 +513,6 @@ class DeleteOldIndexes(Resource):
settings.VECTOR_STORE, source_id=str(doc["_id"]) settings.VECTOR_STORE, source_id=str(doc["_id"])
) )
vectorstore.delete_index() vectorstore.delete_index()
if "file_path" in doc and doc["file_path"]: if "file_path" in doc and doc["file_path"]:
file_path = doc["file_path"] file_path = doc["file_path"]
if storage.is_directory(file_path): if storage.is_directory(file_path):
@@ -520,7 +521,6 @@ class DeleteOldIndexes(Resource):
storage.delete_file(f) storage.delete_file(f)
else: else:
storage.delete_file(file_path) storage.delete_file(file_path)
except FileNotFoundError: except FileNotFoundError:
pass pass
except Exception as err: except Exception as err:
@@ -528,7 +528,6 @@ class DeleteOldIndexes(Resource):
f"Error deleting files and indexes: {err}", exc_info=True f"Error deleting files and indexes: {err}", exc_info=True
) )
return make_response(jsonify({"success": False}), 400) return make_response(jsonify({"success": False}), 400)
sources_collection.delete_one({"_id": ObjectId(source_id)}) sources_collection.delete_one({"_id": ObjectId(source_id)})
return make_response(jsonify({"success": True}), 200) return make_response(jsonify({"success": True}), 200)
@@ -600,7 +599,6 @@ class UploadFile(Resource):
== temp_file_path == temp_file_path
): ):
continue continue
rel_path = os.path.relpath( rel_path = os.path.relpath(
os.path.join(root, extracted_file), temp_dir os.path.join(root, extracted_file), temp_dir
) )
@@ -625,7 +623,6 @@ class UploadFile(Resource):
file_path = f"{base_path}/{safe_file}" file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, "rb") as f: with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path) storage.save_file(f, file_path)
task = ingest.delay( task = ingest.delay(
settings.UPLOAD_FOLDER, settings.UPLOAD_FOLDER,
[ [
@@ -697,7 +694,6 @@ class ManageSourceFiles(Resource):
return make_response( return make_response(
jsonify({"success": False, "message": "Unauthorized"}), 401 jsonify({"success": False, "message": "Unauthorized"}), 401
) )
user = decoded_token.get("sub") user = decoded_token.get("sub")
source_id = request.form.get("source_id") source_id = request.form.get("source_id")
operation = request.form.get("operation") operation = request.form.get("operation")
@@ -747,7 +743,6 @@ class ManageSourceFiles(Resource):
return make_response( return make_response(
jsonify({"success": False, "message": "Database error"}), 500 jsonify({"success": False, "message": "Database error"}), 500
) )
try: try:
storage = StorageCreator.get_storage() storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "") source_file_path = source.get("file_path", "")
@@ -804,7 +799,6 @@ class ManageSourceFiles(Resource):
), ),
200, 200,
) )
elif operation == "remove": elif operation == "remove":
file_paths_str = request.form.get("file_paths") file_paths_str = request.form.get("file_paths")
if not file_paths_str: if not file_paths_str:
@@ -858,7 +852,6 @@ class ManageSourceFiles(Resource):
), ),
200, 200,
) )
elif operation == "remove_directory": elif operation == "remove_directory":
directory_path = request.form.get("directory_path") directory_path = request.form.get("directory_path")
if not directory_path: if not directory_path:
@@ -884,7 +877,6 @@ class ManageSourceFiles(Resource):
), ),
400, 400,
) )
full_directory_path = ( full_directory_path = (
f"{source_file_path}/{directory_path}" f"{source_file_path}/{directory_path}"
if directory_path if directory_path
@@ -943,7 +935,6 @@ class ManageSourceFiles(Resource):
), ),
200, 200,
) )
except Exception as err: except Exception as err:
error_context = f"operation={operation}, user={user}, source_id={source_id}" error_context = f"operation={operation}, user={user}, source_id={source_id}"
if operation == "remove_directory": if operation == "remove_directory":
@@ -955,7 +946,6 @@ class ManageSourceFiles(Resource):
elif operation == "add": elif operation == "add":
parent_dir = request.form.get("parent_dir", "") parent_dir = request.form.get("parent_dir", "")
error_context += f", parent_dir={parent_dir}" error_context += f", parent_dir={parent_dir}"
current_app.logger.error( current_app.logger.error(
f"Error managing source files: {err} ({error_context})", exc_info=True f"Error managing source files: {err} ({error_context})", exc_info=True
) )
@@ -1632,7 +1622,6 @@ class CreateAgent(Resource):
), ),
400, 400,
) )
# Validate that it has either a 'schema' property or is itself a schema # Validate that it has either a 'schema' property or is itself a schema
if "schema" not in json_schema and "type" not in json_schema: if "schema" not in json_schema and "type" not in json_schema:
@@ -3476,7 +3465,6 @@ class AvailableTools(Resource):
"displayName": name, "displayName": name,
"description": description, "description": description,
"configRequirements": tool_instance.get_config_requirements(), "configRequirements": tool_instance.get_config_requirements(),
"actions": tool_instance.get_actions_metadata(),
} }
) )
except Exception as err: except Exception as err:
@@ -3527,11 +3515,6 @@ class CreateTool(Resource):
"customName": fields.String( "customName": fields.String(
required=False, description="Custom name for the tool" required=False, description="Custom name for the tool"
), ),
"actions": fields.List(
fields.Raw,
required=True,
description="Actions the tool can perform",
),
"status": fields.Boolean( "status": fields.Boolean(
required=True, description="Status of the tool" required=True, description="Status of the tool"
), ),
@@ -3549,24 +3532,35 @@ class CreateTool(Resource):
"name", "name",
"displayName", "displayName",
"description", "description",
"actions",
"config", "config",
"status", "status",
] ]
missing_fields = check_required_fields(data, required_fields) missing_fields = check_required_fields(data, required_fields)
if missing_fields: if missing_fields:
return missing_fields return missing_fields
transformed_actions = [] try:
for action in data["actions"]: tool_instance = tool_manager.tools.get(data["name"])
action["active"] = True if not tool_instance:
if "parameters" in action: return make_response(
if "properties" in action["parameters"]: jsonify({"success": False, "message": "Tool not found"}), 404
for param_name, param_details in action["parameters"][ )
"properties" actions_metadata = tool_instance.get_actions_metadata()
].items(): transformed_actions = []
param_details["filled_by_llm"] = True for action in actions_metadata:
param_details["value"] = "" action["active"] = True
transformed_actions.append(action) if "parameters" in action:
if "properties" in action["parameters"]:
for param_name, param_details in action["parameters"][
"properties"
].items():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed_actions.append(action)
except Exception as err:
current_app.logger.error(
f"Error getting tool actions: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
try: try:
new_tool = { new_tool = {
"user": user, "user": user,
@@ -3907,7 +3901,6 @@ class GetChunks(Resource):
if not (text_match or title_match): if not (text_match or title_match):
continue continue
filtered_chunks.append(chunk) filtered_chunks.append(chunk)
chunks = filtered_chunks chunks = filtered_chunks
total_chunks = len(chunks) total_chunks = len(chunks)
@@ -4098,7 +4091,6 @@ class UpdateChunk(Resource):
current_app.logger.warning( current_app.logger.warning(
f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created" f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created"
) )
return make_response( return make_response(
jsonify( jsonify(
{ {
@@ -4226,23 +4218,19 @@ class DirectoryStructure(Resource):
decoded_token = request.decoded_token decoded_token = request.decoded_token
if not decoded_token: if not decoded_token:
return make_response(jsonify({"success": False}), 401) return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub") user = decoded_token.get("sub")
doc_id = request.args.get("id") doc_id = request.args.get("id")
if not doc_id: if not doc_id:
return make_response(jsonify({"error": "Document ID is required"}), 400) return make_response(jsonify({"error": "Document ID is required"}), 400)
if not ObjectId.is_valid(doc_id): if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid document ID"}), 400) return make_response(jsonify({"error": "Invalid document ID"}), 400)
try: try:
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user}) doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc: if not doc:
return make_response( return make_response(
jsonify({"error": "Document not found or access denied"}), 404 jsonify({"error": "Document not found or access denied"}), 404
) )
directory_structure = doc.get("directory_structure", {}) directory_structure = doc.get("directory_structure", {})
base_path = doc.get("file_path", "") base_path = doc.get("file_path", "")
@@ -4315,11 +4303,10 @@ class TestMCPServerConfig(Resource):
auth_credentials["username"] = config["username"] auth_credentials["username"] = config["username"]
if "password" in config: if "password" in config:
auth_credentials["password"] = config["password"] auth_credentials["password"] = config["password"]
test_config = config.copy() test_config = config.copy()
test_config["auth_credentials"] = auth_credentials test_config["auth_credentials"] = auth_credentials
mcp_tool = MCPTool(test_config, user) mcp_tool = MCPTool(config=test_config, user_id=user)
result = mcp_tool.test_connection() result = mcp_tool.test_connection()
return make_response(jsonify(result), 200) return make_response(jsonify(result), 200)
@@ -4387,22 +4374,45 @@ class MCPServerSave(Resource):
mcp_config = config.copy() mcp_config = config.copy()
mcp_config["auth_credentials"] = auth_credentials mcp_config["auth_credentials"] = auth_credentials
if auth_type == "none" or auth_credentials: if auth_type == "oauth":
mcp_tool = MCPTool(mcp_config, user) if not config.get("oauth_task_id"):
return make_response(
jsonify(
{
"success": False,
"error": "Connection not authorized. Please complete the OAuth authorization first.",
}
),
400,
)
redis_client = get_redis_instance()
manager = MCPOAuthManager(redis_client)
result = manager.get_oauth_status(config["oauth_task_id"])
if not result.get("status") == "completed":
return make_response(
jsonify(
{
"success": False,
"error": "OAuth failed or not completed. Please try authorizing again.",
}
),
400,
)
actions_metadata = result.get("tools", [])
elif auth_type == "none" or auth_credentials:
mcp_tool = MCPTool(config=mcp_config, user_id=user)
mcp_tool.discover_tools() mcp_tool.discover_tools()
actions_metadata = mcp_tool.get_actions_metadata() actions_metadata = mcp_tool.get_actions_metadata()
else: else:
raise Exception( raise Exception(
"No valid credentials provided for the selected authentication type" "No valid credentials provided for the selected authentication type"
) )
storage_config = config.copy() storage_config = config.copy()
if auth_credentials: if auth_credentials:
encrypted_credentials_string = encrypt_credentials( encrypted_credentials_string = encrypt_credentials(
auth_credentials, user auth_credentials, user
) )
storage_config["encrypted_credentials"] = encrypted_credentials_string storage_config["encrypted_credentials"] = encrypted_credentials_string
for field in [ for field in [
"api_key", "api_key",
"bearer_token", "bearer_token",
@@ -4473,3 +4483,96 @@ class MCPServerSave(Resource):
), ),
500, 500,
) )
@user_ns.route("/api/mcp_server/callback")
class MCPOAuthCallback(Resource):
    """OAuth redirect target for the MCP tool authorization flow.

    Every outcome is reported by redirecting the browser to
    ``/api/connectors/callback-status`` with ``status``, ``message`` and
    ``provider=mcp_tool`` query parameters.
    """

    @api.expect(
        api.model(
            "MCPServerCallbackModel",
            {
                "code": fields.String(required=True, description="Authorization code"),
                "state": fields.String(required=True, description="State parameter"),
                "error": fields.String(
                    required=False, description="Error message (if any)"
                ),
            },
        )
    )
    @api.doc(
        description="Handle OAuth callback by providing the authorization code and state"
    )
    def get(self):
        """Process the ``code``/``state``/``error`` query parameters.

        Returns:
            A redirect response to the callback-status page. ``status=error``
            is used when the provider reported an error, when ``code`` or
            ``state`` is missing, when Redis is unavailable, when
            ``handle_oauth_callback`` returns a falsy value, or when an
            exception is raised; ``status=success`` otherwise.
        """
        # Imported locally so this handler does not depend on the module's
        # top-level import list.
        from urllib.parse import quote_plus

        code = request.args.get("code")
        state = request.args.get("state")
        error = request.args.get("error")

        if error:
            # `error` is attacker-controllable input: percent-encode it so it
            # cannot inject extra query parameters or a fragment into the
            # redirect target.
            return redirect(
                f"/api/connectors/callback-status?status=error&message=OAuth+error:+{quote_plus(error)}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider=mcp_tool"
            )
        if not code or not state:
            return redirect(
                "/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool"
            )
        try:
            redis_client = get_redis_instance()
            if not redis_client:
                return redirect(
                    "/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool"
                )
            # NOTE(review): values from request.args are normally already
            # percent-decoded; this extra unquote is kept unchanged - confirm
            # it is intended for the providers in use.
            code = unquote(code)
            manager = MCPOAuthManager(redis_client)
            # `error` is always falsy here (the truthy case returned above);
            # it is still forwarded to keep the call unchanged.
            success = manager.handle_oauth_callback(state, code, error)
            if success:
                return redirect(
                    "/api/connectors/callback-status?status=success&message=Authorization+code+received+successfully.+You+can+close+this+window.&provider=mcp_tool"
                )
            else:
                return redirect(
                    "/api/connectors/callback-status?status=error&message=OAuth+callback+failed.&provider=mcp_tool"
                )
        except Exception as e:
            current_app.logger.error(
                f"Error handling MCP OAuth callback: {str(e)}", exc_info=True
            )
            # Exception text may contain characters that are not URL-safe
            # (spaces, quotes, '&'), so encode it before embedding it.
            return redirect(
                f"/api/connectors/callback-status?status=error&message=Internal+server+error:+{quote_plus(str(e))}.&provider=mcp_tool"
            )
@user_ns.route("/api/mcp_server/oauth_status/<string:task_id>")
class MCPOAuthStatus(Resource):
    def get(self, task_id):
        """
        Get current status of OAuth flow.
        Frontend should poll this endpoint periodically.

        The status is read from the Redis key ``mcp_oauth_status:<task_id>``
        and is expected to be a JSON document; its fields are merged into
        the response next to ``success`` and ``task_id``.

        Returns:
            200 with the stored status when the key exists, 404 when it does
            not, and 500 when Redis is unavailable or any other error occurs.
        """
        try:
            redis_client = get_redis_instance()
            if not redis_client:
                # Report the missing client explicitly instead of surfacing
                # the AttributeError raised by calling .get() on it.
                return make_response(
                    jsonify(
                        {
                            "success": False,
                            "error": "Redis not available",
                            "task_id": task_id,
                        }
                    ),
                    500,
                )
            status_key = f"mcp_oauth_status:{task_id}"
            status_data = redis_client.get(status_key)
            if status_data:
                status = json.loads(status_data)
                return make_response(
                    jsonify({"success": True, "task_id": task_id, **status})
                )
            else:
                return make_response(
                    jsonify(
                        {
                            "success": False,
                            "error": "Task not found or expired",
                            "task_id": task_id,
                        }
                    ),
                    404,
                )
        except Exception as e:
            # exc_info keeps the traceback in the log.
            current_app.logger.error(
                f"Error getting OAuth status for task {task_id}: {str(e)}",
                exc_info=True,
            )
            return make_response(
                jsonify({"success": False, "error": str(e), "task_id": task_id}), 500
            )

View File

@@ -5,6 +5,8 @@ from application.worker import (
agent_webhook_worker, agent_webhook_worker,
attachment_worker, attachment_worker,
ingest_worker, ingest_worker,
mcp_oauth,
mcp_oauth_status,
remote_worker, remote_worker,
sync_worker, sync_worker,
) )
@@ -25,6 +27,7 @@ def ingest_remote(self, source_data, job_name, user, loader):
@celery.task(bind=True) @celery.task(bind=True)
def reingest_source_task(self, source_id, user): def reingest_source_task(self, source_id, user):
from application.worker import reingest_source_worker from application.worker import reingest_source_worker
resp = reingest_source_worker(self, source_id, user) resp = reingest_source_worker(self, source_id, user)
return resp return resp
@@ -60,9 +63,10 @@ def ingest_connector_task(
retriever="classic", retriever="classic",
operation_mode="upload", operation_mode="upload",
doc_id=None, doc_id=None,
sync_frequency="never" sync_frequency="never",
): ):
from application.worker import ingest_connector from application.worker import ingest_connector
resp = ingest_connector( resp = ingest_connector(
self, self,
job_name, job_name,
@@ -75,7 +79,7 @@ def ingest_connector_task(
retriever=retriever, retriever=retriever,
operation_mode=operation_mode, operation_mode=operation_mode,
doc_id=doc_id, doc_id=doc_id,
sync_frequency=sync_frequency sync_frequency=sync_frequency,
) )
return resp return resp
@@ -94,3 +98,15 @@ def setup_periodic_tasks(sender, **kwargs):
timedelta(days=30), timedelta(days=30),
schedule_syncs.s("monthly"), schedule_syncs.s("monthly"),
) )
@celery.task(bind=True)
def mcp_oauth_task(self, config, user):
    """Celery task that forwards to ``mcp_oauth`` and returns its result.

    Args:
        self: The bound task instance (``bind=True``), passed through as the
            first argument of ``mcp_oauth``.
        config: Passed unchanged to ``mcp_oauth``.
        user: Passed unchanged to ``mcp_oauth``.
    """
    return mcp_oauth(self, config, user)
@celery.task(bind=True)
def mcp_oauth_status_task(self, task_id):
    """Celery task that forwards to ``mcp_oauth_status`` and returns its result.

    Args:
        self: The bound task instance (``bind=True``), passed through as the
            first argument of ``mcp_oauth_status``.
        task_id: Passed unchanged to ``mcp_oauth_status``.
    """
    return mcp_oauth_status(self, task_id)

View File

@@ -12,6 +12,7 @@ esprima==4.0.1
esutils==1.0.1 esutils==1.0.1
Flask==3.1.1 Flask==3.1.1
faiss-cpu==1.9.0.post1 faiss-cpu==1.9.0.post1
fastmcp==2.11.0
flask-restx==1.3.0 flask-restx==1.3.0
google-genai==1.3.0 google-genai==1.3.0
google-api-python-client==2.179.0 google-api-python-client==2.179.0
@@ -56,13 +57,13 @@ prompt-toolkit==3.0.51
protobuf==5.29.3 protobuf==5.29.3
psycopg2-binary==2.9.10 psycopg2-binary==2.9.10
py==1.11.0 py==1.11.0
pydantic==2.10.6 pydantic
pydantic-core==2.27.2 pydantic-core
pydantic-settings==2.7.1 pydantic-settings
pymongo==4.11.3 pymongo==4.11.3
pypdf==5.5.0 pypdf==5.5.0
python-dateutil==2.9.0.post0 python-dateutil==2.9.0.post0
python-dotenv==1.0.1 python-dotenv
python-jose==3.4.0 python-jose==3.4.0
python-pptx==1.0.2 python-pptx==1.0.2
redis==5.2.1 redis==5.2.1
@@ -82,7 +83,7 @@ tzdata==2024.2
urllib3==2.3.0 urllib3==2.3.0
vine==5.1.0 vine==5.1.0
wcwidth==0.2.13 wcwidth==0.2.13
werkzeug==3.1.3 werkzeug>=3.1.0,<3.1.2
yarl==1.20.0 yarl==1.20.0
markdownify==1.1.0 markdownify==1.1.0
tldextract==5.1.3 tldextract==5.1.3

View File

@@ -19,6 +19,7 @@ from bson.objectid import ObjectId
from application.agents.agent_creator import AgentCreator from application.agents.agent_creator import AgentCreator
from application.api.answer.services.stream_processor import get_prompt from application.api.answer.services.stream_processor import get_prompt
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB from application.core.mongo_db import MongoDB
from application.core.settings import settings from application.core.settings import settings
from application.parser.chunking import Chunker from application.parser.chunking import Chunker
@@ -214,8 +215,7 @@ def run_agent_logic(agent_config, input_data):
def ingest_worker( def ingest_worker(
self, directory, formats, job_name, file_path, filename, user, self, directory, formats, job_name, file_path, filename, user, retriever="classic"
retriever="classic"
): ):
""" """
Ingest and process documents. Ingest and process documents.
@@ -240,7 +240,7 @@ def ingest_worker(
sample = False sample = False
storage = StorageCreator.get_storage() storage = StorageCreator.get_storage()
logging.info(f"Ingest path: {file_path}", extra={"user": user, "job": job_name}) logging.info(f"Ingest path: {file_path}", extra={"user": user, "job": job_name})
# Create temporary working directory # Create temporary working directory
@@ -253,30 +253,32 @@ def ingest_worker(
# Handle directory case # Handle directory case
logging.info(f"Processing directory: {file_path}") logging.info(f"Processing directory: {file_path}")
files_list = storage.list_files(file_path) files_list = storage.list_files(file_path)
for storage_file_path in files_list: for storage_file_path in files_list:
if storage.is_directory(storage_file_path): if storage.is_directory(storage_file_path):
continue continue
# Create relative path structure in temp directory # Create relative path structure in temp directory
rel_path = os.path.relpath(storage_file_path, file_path) rel_path = os.path.relpath(storage_file_path, file_path)
local_file_path = os.path.join(temp_dir, rel_path) local_file_path = os.path.join(temp_dir, rel_path)
os.makedirs(os.path.dirname(local_file_path), exist_ok=True) os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
# Download file # Download file
try: try:
file_data = storage.get_file(storage_file_path) file_data = storage.get_file(storage_file_path)
with open(local_file_path, "wb") as f: with open(local_file_path, "wb") as f:
f.write(file_data.read()) f.write(file_data.read())
except Exception as e: except Exception as e:
logging.error(f"Error downloading file {storage_file_path}: {e}") logging.error(
f"Error downloading file {storage_file_path}: {e}"
)
continue continue
else: else:
# Handle single file case # Handle single file case
temp_filename = os.path.basename(file_path) temp_filename = os.path.basename(file_path)
temp_file_path = os.path.join(temp_dir, temp_filename) temp_file_path = os.path.join(temp_dir, temp_filename)
file_data = storage.get_file(file_path) file_data = storage.get_file(file_path)
with open(temp_file_path, "wb") as f: with open(temp_file_path, "wb") as f:
f.write(file_data.read()) f.write(file_data.read())
@@ -285,7 +287,10 @@ def ingest_worker(
if temp_filename.endswith(".zip"): if temp_filename.endswith(".zip"):
logging.info(f"Extracting zip file: {temp_filename}") logging.info(f"Extracting zip file: {temp_filename}")
extract_zip_recursive( extract_zip_recursive(
temp_file_path, temp_dir, current_depth=0, max_depth=RECURSION_DEPTH temp_file_path,
temp_dir,
current_depth=0,
max_depth=RECURSION_DEPTH,
) )
self.update_state(state="PROGRESS", meta={"current": 1}) self.update_state(state="PROGRESS", meta={"current": 1})
@@ -300,8 +305,8 @@ def ingest_worker(
file_metadata=metadata_from_filename, file_metadata=metadata_from_filename,
) )
raw_docs = reader.load_data() raw_docs = reader.load_data()
directory_structure = getattr(reader, 'directory_structure', {}) directory_structure = getattr(reader, "directory_structure", {})
logging.info(f"Directory structure from reader: {directory_structure}") logging.info(f"Directory structure from reader: {directory_structure}")
chunker = Chunker( chunker = Chunker(
@@ -371,7 +376,10 @@ def reingest_source_worker(self, source_id, user):
try: try:
from application.vectorstore.vector_creator import VectorCreator from application.vectorstore.vector_creator import VectorCreator
self.update_state(state="PROGRESS", meta={"current": 10, "status": "Initializing re-ingestion scan"}) self.update_state(
state="PROGRESS",
meta={"current": 10, "status": "Initializing re-ingestion scan"},
)
source = sources_collection.find_one({"_id": ObjectId(source_id), "user": user}) source = sources_collection.find_one({"_id": ObjectId(source_id), "user": user})
if not source: if not source:
@@ -380,7 +388,9 @@ def reingest_source_worker(self, source_id, user):
storage = StorageCreator.get_storage() storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "") source_file_path = source.get("file_path", "")
self.update_state(state="PROGRESS", meta={"current": 20, "status": "Scanning current files"}) self.update_state(
state="PROGRESS", meta={"current": 20, "status": "Scanning current files"}
)
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
# Download all files from storage to temp directory, preserving directory structure # Download all files from storage to temp directory, preserving directory structure
@@ -391,7 +401,6 @@ def reingest_source_worker(self, source_id, user):
if storage.is_directory(storage_file_path): if storage.is_directory(storage_file_path):
continue continue
rel_path = os.path.relpath(storage_file_path, source_file_path) rel_path = os.path.relpath(storage_file_path, source_file_path)
local_file_path = os.path.join(temp_dir, rel_path) local_file_path = os.path.join(temp_dir, rel_path)
@@ -403,23 +412,39 @@ def reingest_source_worker(self, source_id, user):
with open(local_file_path, "wb") as f: with open(local_file_path, "wb") as f:
f.write(file_data.read()) f.write(file_data.read())
except Exception as e: except Exception as e:
logging.error(f"Error downloading file {storage_file_path}: {e}") logging.error(
f"Error downloading file {storage_file_path}: {e}"
)
continue continue
reader = SimpleDirectoryReader( reader = SimpleDirectoryReader(
input_dir=temp_dir, input_dir=temp_dir,
recursive=True, recursive=True,
required_exts=[ required_exts=[
".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub", ".rst",
".html", ".mdx", ".json", ".xlsx", ".pptx", ".png", ".md",
".jpg", ".jpeg", ".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
], ],
exclude_hidden=True, exclude_hidden=True,
file_metadata=metadata_from_filename, file_metadata=metadata_from_filename,
) )
reader.load_data() reader.load_data()
directory_structure = reader.directory_structure directory_structure = reader.directory_structure
logging.info(f"Directory structure built with token counts: {directory_structure}") logging.info(
f"Directory structure built with token counts: {directory_structure}"
)
try: try:
old_directory_structure = source.get("directory_structure") or {} old_directory_structure = source.get("directory_structure") or {}
@@ -433,11 +458,17 @@ def reingest_source_worker(self, source_id, user):
files = set() files = set()
if isinstance(struct, dict): if isinstance(struct, dict):
for name, meta in struct.items(): for name, meta in struct.items():
current_path = os.path.join(prefix, name) if prefix else name current_path = (
if isinstance(meta, dict) and ("type" in meta and "size_bytes" in meta): os.path.join(prefix, name) if prefix else name
)
if isinstance(meta, dict) and (
"type" in meta and "size_bytes" in meta
):
files.add(current_path) files.add(current_path)
elif isinstance(meta, dict): elif isinstance(meta, dict):
files |= _flatten_directory_structure(meta, current_path) files |= _flatten_directory_structure(
meta, current_path
)
return files return files
old_files = _flatten_directory_structure(old_directory_structure) old_files = _flatten_directory_structure(old_directory_structure)
@@ -457,7 +488,9 @@ def reingest_source_worker(self, source_id, user):
logging.info("No files removed since last ingest.") logging.info("No files removed since last ingest.")
except Exception as e: except Exception as e:
logging.error(f"Error comparing directory structures: {e}", exc_info=True) logging.error(
f"Error comparing directory structures: {e}", exc_info=True
)
added_files = [] added_files = []
removed_files = [] removed_files = []
try: try:
@@ -477,14 +510,21 @@ def reingest_source_worker(self, source_id, user):
settings.EMBEDDINGS_KEY, settings.EMBEDDINGS_KEY,
) )
self.update_state(state="PROGRESS", meta={"current": 40, "status": "Processing file changes"}) self.update_state(
state="PROGRESS",
meta={"current": 40, "status": "Processing file changes"},
)
# 1) Delete chunks from removed files # 1) Delete chunks from removed files
deleted = 0 deleted = 0
if removed_files: if removed_files:
try: try:
for ch in vector_store.get_chunks() or []: for ch in vector_store.get_chunks() or []:
metadata = ch.get("metadata", {}) if isinstance(ch, dict) else getattr(ch, "metadata", {}) metadata = (
ch.get("metadata", {})
if isinstance(ch, dict)
else getattr(ch, "metadata", {})
)
raw_source = metadata.get("source") raw_source = metadata.get("source")
source_file = str(raw_source) if raw_source else "" source_file = str(raw_source) if raw_source else ""
@@ -496,10 +536,17 @@ def reingest_source_worker(self, source_id, user):
vector_store.delete_chunk(cid) vector_store.delete_chunk(cid)
deleted += 1 deleted += 1
except Exception as de: except Exception as de:
logging.error(f"Failed deleting chunk {cid}: {de}") logging.error(
logging.info(f"Deleted {deleted} chunks from {len(removed_files)} removed files") f"Failed deleting chunk {cid}: {de}"
)
logging.info(
f"Deleted {deleted} chunks from {len(removed_files)} removed files"
)
except Exception as e: except Exception as e:
logging.error(f"Error during deletion of removed file chunks: {e}", exc_info=True) logging.error(
f"Error during deletion of removed file chunks: {e}",
exc_info=True,
)
# 2) Add chunks from new files # 2) Add chunks from new files
added = 0 added = 0
@@ -528,58 +575,86 @@ def reingest_source_worker(self, source_id, user):
) )
chunked_new = chunker_new.chunk(documents=raw_docs_new) chunked_new = chunker_new.chunk(documents=raw_docs_new)
for file_path, token_count in reader_new.file_token_counts.items(): for (
file_path,
token_count,
) in reader_new.file_token_counts.items():
try: try:
rel_path = os.path.relpath(file_path, start=temp_dir) rel_path = os.path.relpath(
file_path, start=temp_dir
)
path_parts = rel_path.split(os.sep) path_parts = rel_path.split(os.sep)
current_dir = directory_structure current_dir = directory_structure
for part in path_parts[:-1]: for part in path_parts[:-1]:
if part in current_dir and isinstance(current_dir[part], dict): if part in current_dir and isinstance(
current_dir[part], dict
):
current_dir = current_dir[part] current_dir = current_dir[part]
else: else:
break break
filename = path_parts[-1] filename = path_parts[-1]
if filename in current_dir and isinstance(current_dir[filename], dict): if filename in current_dir and isinstance(
current_dir[filename]["token_count"] = token_count current_dir[filename], dict
logging.info(f"Updated token count for {rel_path}: {token_count}") ):
current_dir[filename][
"token_count"
] = token_count
logging.info(
f"Updated token count for {rel_path}: {token_count}"
)
except Exception as e: except Exception as e:
logging.warning(f"Could not update token count for {file_path}: {e}") logging.warning(
f"Could not update token count for {file_path}: {e}"
)
for d in chunked_new: for d in chunked_new:
meta = dict(d.extra_info or {}) meta = dict(d.extra_info or {})
try: try:
raw_src = meta.get("source") raw_src = meta.get("source")
if isinstance(raw_src, str) and os.path.isabs(raw_src): if isinstance(raw_src, str) and os.path.isabs(
meta["source"] = os.path.relpath(raw_src, start=temp_dir) raw_src
):
meta["source"] = os.path.relpath(
raw_src, start=temp_dir
)
except Exception: except Exception:
pass pass
vector_store.add_chunk(d.text, metadata=meta) vector_store.add_chunk(d.text, metadata=meta)
added += 1 added += 1
logging.info(f"Added {added} chunks from {len(added_files)} new files") logging.info(
f"Added {added} chunks from {len(added_files)} new files"
)
except Exception as e: except Exception as e:
logging.error(f"Error during ingestion of new files: {e}", exc_info=True) logging.error(
f"Error during ingestion of new files: {e}", exc_info=True
)
# 3) Update source directory structure timestamp # 3) Update source directory structure timestamp
try: try:
total_tokens = sum(reader.file_token_counts.values()) total_tokens = sum(reader.file_token_counts.values())
sources_collection.update_one( sources_collection.update_one(
{"_id": ObjectId(source_id)}, {"_id": ObjectId(source_id)},
{ {
"$set": { "$set": {
"directory_structure": directory_structure, "directory_structure": directory_structure,
"date": datetime.datetime.now(), "date": datetime.datetime.now(),
"tokens": total_tokens "tokens": total_tokens,
} }
}, },
) )
except Exception as e: except Exception as e:
logging.error(f"Error updating directory_structure in DB: {e}", exc_info=True) logging.error(
f"Error updating directory_structure in DB: {e}", exc_info=True
)
self.update_state(state="PROGRESS", meta={"current": 100, "status": "Re-ingestion completed"}) self.update_state(
state="PROGRESS",
meta={"current": 100, "status": "Re-ingestion completed"},
)
return { return {
"source_id": source_id, "source_id": source_id,
@@ -591,15 +666,16 @@ def reingest_source_worker(self, source_id, user):
"chunks_deleted": deleted, "chunks_deleted": deleted,
} }
except Exception as e: except Exception as e:
logging.error(f"Error while processing file changes: {e}", exc_info=True) logging.error(
f"Error while processing file changes: {e}", exc_info=True
)
raise raise
except Exception as e: except Exception as e:
logging.error(f"Error in reingest_source_worker: {e}", exc_info=True) logging.error(f"Error in reingest_source_worker: {e}", exc_info=True)
raise raise
def remote_worker( def remote_worker(
self, self,
source_data, source_data,
@@ -651,7 +727,7 @@ def remote_worker(
"id": str(id), "id": str(id),
"type": loader, "type": loader,
"remote_data": source_data, "remote_data": source_data,
"sync_frequency": sync_frequency "sync_frequency": sync_frequency,
} }
if operation_mode == "sync": if operation_mode == "sync":
@@ -712,7 +788,7 @@ def sync_worker(self, frequency):
self, source_data, name, user, source_type, frequency, retriever, doc_id self, source_data, name, user, source_type, frequency, retriever, doc_id
) )
sync_counts["total_sync_count"] += 1 sync_counts["total_sync_count"] += 1
sync_counts[ sync_counts[
"sync_success" if resp["status"] == "success" else "sync_failure" "sync_success" if resp["status"] == "success" else "sync_failure"
] += 1 ] += 1
return { return {
@@ -749,15 +825,14 @@ def attachment_worker(self, file_info, user):
input_files=[local_path], exclude_hidden=True, errors="ignore" input_files=[local_path], exclude_hidden=True, errors="ignore"
) )
.load_data()[0] .load_data()[0]
.text, .text,
) )
token_count = num_tokens_from_string(content) token_count = num_tokens_from_string(content)
if token_count > 100000: if token_count > 100000:
content = content[:250000] content = content[:250000]
token_count = num_tokens_from_string(content) token_count = num_tokens_from_string(content)
self.update_state( self.update_state(
state="PROGRESS", meta={"current": 80, "status": "Storing in database"} state="PROGRESS", meta={"current": 80, "status": "Storing in database"}
) )
@@ -872,37 +947,49 @@ def ingest_connector(
doc_id: Document ID for sync operations (required when operation_mode="sync") doc_id: Document ID for sync operations (required when operation_mode="sync")
sync_frequency: How often to sync ("never", "daily", "weekly", "monthly") sync_frequency: How often to sync ("never", "daily", "weekly", "monthly")
""" """
logging.info(f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}") logging.info(
f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}"
)
self.update_state(state="PROGRESS", meta={"current": 1}) self.update_state(state="PROGRESS", meta={"current": 1})
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
try: try:
# Step 1: Initialize the appropriate loader # Step 1: Initialize the appropriate loader
self.update_state(state="PROGRESS", meta={"current": 10, "status": "Initializing connector"}) self.update_state(
state="PROGRESS",
meta={"current": 10, "status": "Initializing connector"},
)
if not session_token: if not session_token:
raise ValueError(f"{source_type} connector requires session_token") raise ValueError(f"{source_type} connector requires session_token")
if not ConnectorCreator.is_supported(source_type): if not ConnectorCreator.is_supported(source_type):
raise ValueError(f"Unsupported connector type: {source_type}. Supported types: {ConnectorCreator.get_supported_connectors()}") raise ValueError(
f"Unsupported connector type: {source_type}. Supported types: {ConnectorCreator.get_supported_connectors()}"
)
remote_loader = ConnectorCreator.create_connector(source_type, session_token) remote_loader = ConnectorCreator.create_connector(
source_type, session_token
)
# Create a clean config for storage # Create a clean config for storage
api_source_config = { api_source_config = {
"file_ids": file_ids or [], "file_ids": file_ids or [],
"folder_ids": folder_ids or [], "folder_ids": folder_ids or [],
"recursive": recursive "recursive": recursive,
} }
# Step 2: Download files to temp directory # Step 2: Download files to temp directory
self.update_state(state="PROGRESS", meta={"current": 20, "status": "Downloading files"}) self.update_state(
download_info = remote_loader.download_to_directory( state="PROGRESS", meta={"current": 20, "status": "Downloading files"}
temp_dir,
api_source_config
) )
download_info = remote_loader.download_to_directory(
if download_info.get("empty_result", False) or not download_info.get("files_downloaded", 0): temp_dir, api_source_config
)
if download_info.get("empty_result", False) or not download_info.get(
"files_downloaded", 0
):
logging.warning(f"No files were downloaded from {source_type}") logging.warning(f"No files were downloaded from {source_type}")
# Create empty result directly instead of calling a separate method # Create empty result directly instead of calling a separate method
return { return {
@@ -913,28 +1000,42 @@ def ingest_connector(
"source_config": api_source_config, "source_config": api_source_config,
"directory_structure": "{}", "directory_structure": "{}",
} }
# Step 3: Use SimpleDirectoryReader to process downloaded files # Step 3: Use SimpleDirectoryReader to process downloaded files
self.update_state(state="PROGRESS", meta={"current": 40, "status": "Processing files"}) self.update_state(
state="PROGRESS", meta={"current": 40, "status": "Processing files"}
)
reader = SimpleDirectoryReader( reader = SimpleDirectoryReader(
input_dir=temp_dir, input_dir=temp_dir,
recursive=True, recursive=True,
required_exts=[ required_exts=[
".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub", ".rst",
".html", ".mdx", ".json", ".xlsx", ".pptx", ".png", ".md",
".jpg", ".jpeg", ".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
], ],
exclude_hidden=True, exclude_hidden=True,
file_metadata=metadata_from_filename, file_metadata=metadata_from_filename,
) )
raw_docs = reader.load_data() raw_docs = reader.load_data()
directory_structure = getattr(reader, 'directory_structure', {}) directory_structure = getattr(reader, "directory_structure", {})
# Step 4: Process documents (chunking, embedding, etc.) # Step 4: Process documents (chunking, embedding, etc.)
self.update_state(state="PROGRESS", meta={"current": 60, "status": "Processing documents"}) self.update_state(
state="PROGRESS", meta={"current": 60, "status": "Processing documents"}
)
chunker = Chunker( chunker = Chunker(
chunking_strategy="classic_chunk", chunking_strategy="classic_chunk",
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
@@ -942,22 +1043,26 @@ def ingest_connector(
duplicate_headers=False, duplicate_headers=False,
) )
raw_docs = chunker.chunk(documents=raw_docs) raw_docs = chunker.chunk(documents=raw_docs)
# Preserve source information in document metadata # Preserve source information in document metadata
for doc in raw_docs: for doc in raw_docs:
if hasattr(doc, 'extra_info') and doc.extra_info: if hasattr(doc, "extra_info") and doc.extra_info:
source = doc.extra_info.get('source') source = doc.extra_info.get("source")
if source and os.path.isabs(source): if source and os.path.isabs(source):
# Convert absolute path to relative path # Convert absolute path to relative path
doc.extra_info['source'] = os.path.relpath(source, start=temp_dir) doc.extra_info["source"] = os.path.relpath(
source, start=temp_dir
)
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs] docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
if operation_mode == "upload": if operation_mode == "upload":
id = ObjectId() id = ObjectId()
elif operation_mode == "sync": elif operation_mode == "sync":
if not doc_id or not ObjectId.is_valid(doc_id): if not doc_id or not ObjectId.is_valid(doc_id):
logging.error("Invalid doc_id provided for sync operation: %s", doc_id) logging.error(
"Invalid doc_id provided for sync operation: %s", doc_id
)
raise ValueError("doc_id must be provided for sync operation.") raise ValueError("doc_id must be provided for sync operation.")
id = ObjectId(doc_id) id = ObjectId(doc_id)
else: else:
@@ -966,7 +1071,9 @@ def ingest_connector(
vector_store_path = os.path.join(temp_dir, "vector_store") vector_store_path = os.path.join(temp_dir, "vector_store")
os.makedirs(vector_store_path, exist_ok=True) os.makedirs(vector_store_path, exist_ok=True)
self.update_state(state="PROGRESS", meta={"current": 80, "status": "Storing documents"}) self.update_state(
state="PROGRESS", meta={"current": 80, "status": "Storing documents"}
)
embed_and_store_documents(docs, vector_store_path, id, self) embed_and_store_documents(docs, vector_store_path, id, self)
tokens = count_tokens_docs(docs) tokens = count_tokens_docs(docs)
@@ -979,12 +1086,11 @@ def ingest_connector(
"retriever": retriever, "retriever": retriever,
"id": str(id), "id": str(id),
"type": "connector:file", "type": "connector:file",
"remote_data": json.dumps({ "remote_data": json.dumps(
"provider": source_type, {"provider": source_type, **api_source_config}
**api_source_config ),
}),
"directory_structure": json.dumps(directory_structure), "directory_structure": json.dumps(directory_structure),
"sync_frequency": sync_frequency "sync_frequency": sync_frequency,
} }
if operation_mode == "sync": if operation_mode == "sync":
@@ -995,7 +1101,9 @@ def ingest_connector(
upload_index(vector_store_path, file_data) upload_index(vector_store_path, file_data)
# Ensure we mark the task as complete # Ensure we mark the task as complete
self.update_state(state="PROGRESS", meta={"current": 100, "status": "Complete"}) self.update_state(
state="PROGRESS", meta={"current": 100, "status": "Complete"}
)
logging.info(f"Remote ingestion completed: {job_name}") logging.info(f"Remote ingestion completed: {job_name}")
@@ -1005,9 +1113,136 @@ def ingest_connector(
"tokens": tokens, "tokens": tokens,
"type": source_type, "type": source_type,
"id": str(id), "id": str(id),
"status": "complete" "status": "complete",
} }
except Exception as e: except Exception as e:
logging.error(f"Error during remote ingestion: {e}", exc_info=True) logging.error(f"Error during remote ingestion: {e}", exc_info=True)
raise raise
def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
    """Worker to handle MCP OAuth flow asynchronously.

    Builds an ``MCPTool`` from a copy of ``config`` (with the task id added
    as ``oauth_task_id``), runs a ``list_tools`` call on a dedicated asyncio
    event loop, and publishes progress under the Redis key
    ``mcp_oauth_status:<task_id>`` with a 600 second expiry.

    Args:
        self: Task object; only ``self.request.id`` is read from it.
        config: MCP server configuration handed to ``MCPTool``. The caller's
            dict is not mutated (a shallow copy is modified instead).
        user_id: Forwarded unchanged to ``MCPTool``.

    Returns:
        ``{"success": True, "tools": [...], "tools_count": n}`` when the
        flow completes, otherwise ``{"success": False, "error": message}``.
        Failures are reported through the return value, never raised.
    """
    # NOTE(review): config is logged verbatim here and again below; if it can
    # carry credentials (tokens, client secrets) consider redacting - confirm.
    logging.info(
        "[MCP OAuth] Worker started for user_id=%s, config=%s", user_id, config
    )
    # Bound before the try block so the failure paths can always reference
    # them, even when setup fails on its very first statement.
    task_id = None
    redis_client = None

    def update_status(status_data: Dict[str, Any]):
        logging.info("[MCP OAuth] Updating status: %s", status_data)
        status_key = f"mcp_oauth_status:{task_id}"
        redis_client.setex(status_key, 600, json.dumps(status_data))

    def report_failure(error_msg: str, exc: Exception) -> Dict[str, Any]:
        # Publishing the error is best-effort: Redis may be the very thing
        # that failed, and the handler must still return the error result
        # instead of raising a second exception that hides the first one.
        try:
            update_status(
                {
                    "status": "error",
                    "message": error_msg,
                    "error": str(exc),
                    "task_id": task_id,
                }
            )
        except Exception:
            logging.error(
                "[MCP OAuth] Could not publish error status for task_id=%s",
                task_id,
                exc_info=True,
            )
        return {"success": False, "error": error_msg}

    try:
        import asyncio

        from application.agents.tools.mcp_tool import MCPTool

        task_id = self.request.id
        logging.info("[MCP OAuth] Task ID: %s", task_id)
        redis_client = get_redis_instance()

        update_status(
            {
                "status": "in_progress",
                "message": "Starting OAuth flow...",
                "task_id": task_id,
            }
        )

        tool_config = config.copy()
        tool_config["oauth_task_id"] = task_id
        logging.info("[MCP OAuth] Initializing MCPTool with config: %s", tool_config)
        mcp_tool = MCPTool(tool_config, user_id)

        async def run_oauth_discovery():
            if not mcp_tool._client:
                mcp_tool._setup_client()
            return await mcp_tool._execute_with_client("list_tools")

        update_status(
            {
                "status": "awaiting_redirect",
                "message": "Waiting for OAuth redirect...",
                "task_id": task_id,
            }
        )

        # A fresh loop is created for this task and closed in the finally
        # clause below.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            logging.info("[MCP OAuth] Starting event loop for OAuth discovery...")
            tools_response = loop.run_until_complete(run_oauth_discovery())
            logging.info(
                "[MCP OAuth] Tools response after async call: %s", tools_response
            )

            # Diagnostic only: log whatever status is stored at this point.
            status_key = f"mcp_oauth_status:{task_id}"
            redis_status = redis_client.get(status_key)
            if redis_status:
                logging.info(
                    "[MCP OAuth] Redis status after async call: %s", redis_status
                )
            else:
                logging.warning(
                    "[MCP OAuth] No Redis status found after async call for key: %s",
                    status_key,
                )

            tools = mcp_tool.get_actions_metadata()
            update_status(
                {
                    "status": "completed",
                    "message": f"OAuth completed successfully. Found {len(tools)} tools.",
                    "tools": tools,
                    "tools_count": len(tools),
                    "task_id": task_id,
                }
            )
            logging.info(
                "[MCP OAuth] OAuth flow completed successfully for task_id=%s", task_id
            )
            return {"success": True, "tools": tools, "tools_count": len(tools)}
        except Exception as e:
            error_msg = f"OAuth flow failed: {str(e)}"
            logging.error(
                "[MCP OAuth] Exception in OAuth discovery: %s", error_msg, exc_info=True
            )
            return report_failure(error_msg, e)
        finally:
            logging.info("[MCP OAuth] Closing event loop for task_id=%s", task_id)
            loop.close()
    except Exception as e:
        error_msg = f"Failed to initialize OAuth flow: {str(e)}"
        logging.error(
            "[MCP OAuth] Exception during initialization: %s", error_msg, exc_info=True
        )
        return report_failure(error_msg, e)
def mcp_oauth_status(self, task_id: str) -> Dict[str, Any]:
    """Look up the stored status of an MCP OAuth flow.

    Reads the Redis key ``mcp_oauth_status:<task_id>`` and returns its
    JSON-decoded value. When the key holds nothing, a ``not_found``
    placeholder dict is returned instead.
    """
    raw_status = get_redis_instance().get(f"mcp_oauth_status:{task_id}")
    if not raw_status:
        return {"status": "not_found", "message": "Status not found"}
    return json.loads(raw_status)

View File

@@ -59,6 +59,8 @@ const endpoints = {
MANAGE_SOURCE_FILES: '/api/manage_source_files', MANAGE_SOURCE_FILES: '/api/manage_source_files',
MCP_TEST_CONNECTION: '/api/mcp_server/test', MCP_TEST_CONNECTION: '/api/mcp_server/test',
MCP_SAVE_SERVER: '/api/mcp_server/save', MCP_SAVE_SERVER: '/api/mcp_server/save',
MCP_OAUTH_STATUS: (task_id: string) =>
`/api/mcp_server/oauth_status/${task_id}`,
}, },
CONVERSATION: { CONVERSATION: {
ANSWER: '/api/answer', ANSWER: '/api/answer',

View File

@@ -1,6 +1,6 @@
import { getSessionToken } from '../../utils/providerUtils';
import apiClient from '../client'; import apiClient from '../client';
import endpoints from '../endpoints'; import endpoints from '../endpoints';
import { getSessionToken } from '../../utils/providerUtils';
const userService = { const userService = {
getConfig: (): Promise<any> => apiClient.get(endpoints.USER.CONFIG, null), getConfig: (): Promise<any> => apiClient.get(endpoints.USER.CONFIG, null),
@@ -112,6 +112,8 @@ const userService = {
apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token), apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token),
saveMCPServer: (data: any, token: string | null): Promise<any> => saveMCPServer: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token), apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token),
getMCPOAuthStatus: (task_id: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.MCP_OAUTH_STATUS(task_id), token),
syncConnector: ( syncConnector: (
docId: string, docId: string,
provider: string, provider: string,

View File

@@ -194,17 +194,20 @@
"headerName": "Header Name", "headerName": "Header Name",
"timeout": "Timeout (seconds)", "timeout": "Timeout (seconds)",
"testConnection": "Test Connection", "testConnection": "Test Connection",
"testing": "Testing...", "testing": "Testing",
"saving": "Saving...", "saving": "Saving",
"save": "Save", "save": "Save",
"cancel": "Cancel", "cancel": "Cancel",
"noAuth": "No Authentication", "noAuth": "No Authentication",
"oauthInProgress": "Waiting for OAuth completion...",
"oauthCompleted": "OAuth completed successfully",
"placeholders": { "placeholders": {
"serverUrl": "https://api.example.com", "serverUrl": "https://api.example.com",
"apiKey": "Your secret API key", "apiKey": "Your secret API key",
"bearerToken": "Your secret token", "bearerToken": "Your secret token",
"username": "Your username", "username": "Your username",
"password": "Your password" "password": "Your password",
"oauthScopes": "OAuth scopes (comma separated)"
}, },
"errors": { "errors": {
"nameRequired": "Server name is required", "nameRequired": "Server name is required",
@@ -215,7 +218,9 @@
"usernameRequired": "Username is required", "usernameRequired": "Username is required",
"passwordRequired": "Password is required", "passwordRequired": "Password is required",
"testFailed": "Connection test failed", "testFailed": "Connection test failed",
"saveFailed": "Failed to save MCP server" "saveFailed": "Failed to save MCP server",
"oauthFailed": "OAuth process failed or was cancelled",
"oauthTimeout": "OAuth process timed out, please try again"
} }
} }
} }

View File

@@ -22,6 +22,7 @@ const authTypes = [
{ label: 'No Authentication', value: 'none' }, { label: 'No Authentication', value: 'none' },
{ label: 'API Key', value: 'api_key' }, { label: 'API Key', value: 'api_key' },
{ label: 'Bearer Token', value: 'bearer' }, { label: 'Bearer Token', value: 'bearer' },
{ label: 'OAuth', value: 'oauth' },
// { label: 'Basic Authentication', value: 'basic' }, // { label: 'Basic Authentication', value: 'basic' },
]; ];
@@ -45,6 +46,8 @@ export default function MCPServerModal({
username: '', username: '',
password: '', password: '',
timeout: server?.timeout || 30, timeout: server?.timeout || 30,
oauth_scopes: '',
oauth_task_id: '',
}); });
const [loading, setLoading] = useState(false); const [loading, setLoading] = useState(false);
@@ -52,8 +55,13 @@ export default function MCPServerModal({
const [testResult, setTestResult] = useState<{ const [testResult, setTestResult] = useState<{
success: boolean; success: boolean;
message: string; message: string;
status?: string;
authorization_url?: string;
} | null>(null); } | null>(null);
const [errors, setErrors] = useState<{ [key: string]: string }>({}); const [errors, setErrors] = useState<{ [key: string]: string }>({});
const oauthPopupRef = useRef<Window | null>(null);
const [oauthCompleted, setOAuthCompleted] = useState(false);
const [saveActive, setSaveActive] = useState(false);
useOutsideAlerter(modalRef, () => { useOutsideAlerter(modalRef, () => {
if (modalState === 'ACTIVE') { if (modalState === 'ACTIVE') {
@@ -73,9 +81,12 @@ export default function MCPServerModal({
username: '', username: '',
password: '', password: '',
timeout: 30, timeout: 30,
oauth_scopes: '',
oauth_task_id: '',
}); });
setErrors({}); setErrors({});
setTestResult(null); setTestResult(null);
setSaveActive(false);
}; };
const validateForm = () => { const validateForm = () => {
@@ -154,10 +165,81 @@ export default function MCPServerModal({
} else if (formData.auth_type === 'basic') { } else if (formData.auth_type === 'basic') {
config.username = formData.username.trim(); config.username = formData.username.trim();
config.password = formData.password.trim(); config.password = formData.password.trim();
} else if (formData.auth_type === 'oauth') {
config.oauth_scopes = formData.oauth_scopes
.split(',')
.map((s) => s.trim())
.filter(Boolean);
config.oauth_task_id = formData.oauth_task_id.trim();
} }
return config; return config;
}; };
// Polls the OAuth status endpoint for `taskId` once per second, for at most
// 60 attempts. Opens the authorization popup the first time a response carries
// an `authorization_url`, and calls `onComplete` exactly once with the final
// outcome (completed, failed, or timed out).
const pollOAuthStatus = async (
  taskId: string,
  onComplete: (result: any) => void,
) => {
  const maxAttempts = 60;
  const pollIntervalMs = 1000;
  let attempts = 0;
  let popupOpened = false;

  // Close the OAuth popup if one is still open.
  const closePopup = () => {
    const popup = oauthPopupRef.current;
    if (popup && !popup.closed) {
      popup.close();
    }
  };

  // Result handed to `onComplete` once the attempt budget is exhausted.
  const timeoutResult = () => ({
    success: false,
    message: t('settings.tools.mcp.errors.oauthTimeout'),
  });

  const poll = async () => {
    try {
      const resp = await userService.getMCPOAuthStatus(taskId, token);
      const data = await resp.json();

      // Only the first authorization URL seen opens a popup.
      if (data.authorization_url && !popupOpened) {
        closePopup();
        oauthPopupRef.current = window.open(
          data.authorization_url,
          'oauthPopup',
          'width=600,height=700',
        );
        popupOpened = true;
      }

      if (data.status === 'completed') {
        setOAuthCompleted(true);
        setSaveActive(true);
        onComplete({
          ...data,
          success: true,
          message: t('settings.tools.mcp.oauthCompleted'),
        });
        closePopup();
        return;
      }

      if (data.status === 'error' || data.success === false) {
        setSaveActive(false);
        onComplete({
          ...data,
          success: false,
          message: t('settings.tools.mcp.errors.oauthFailed'),
        });
        closePopup();
        return;
      }

      // Still pending: schedule another poll or give up.
      attempts += 1;
      if (attempts < maxAttempts) {
        setTimeout(poll, pollIntervalMs);
        return;
      }
      setSaveActive(false);
      onComplete(timeoutResult());
    } catch {
      // Any failure inside the try counts as one attempt and is retried.
      attempts += 1;
      if (attempts < maxAttempts) {
        setTimeout(poll, pollIntervalMs);
      } else {
        onComplete(timeoutResult());
      }
    }
  };

  poll();
};
const testConnection = async () => { const testConnection = async () => {
if (!validateForm()) return; if (!validateForm()) return;
setTesting(true); setTesting(true);
@@ -167,13 +249,37 @@ export default function MCPServerModal({
const response = await userService.testMCPConnection({ config }, token); const response = await userService.testMCPConnection({ config }, token);
const result = await response.json(); const result = await response.json();
setTestResult(result); if (
formData.auth_type === 'oauth' &&
result.requires_oauth &&
result.task_id
) {
setTestResult({
success: true,
message: t('settings.tools.mcp.oauthInProgress'),
});
setOAuthCompleted(false);
setSaveActive(false);
pollOAuthStatus(result.task_id, (finalResult) => {
setTestResult(finalResult);
setFormData((prev) => ({
...prev,
oauth_task_id: result.task_id || '',
}));
setTesting(false);
});
} else {
setTestResult(result);
setSaveActive(result.success === true);
setTesting(false);
}
} catch (error) { } catch (error) {
setTestResult({ setTestResult({
success: false, success: false,
message: t('settings.tools.mcp.errors.testFailed'), message: t('settings.tools.mcp.errors.testFailed'),
}); });
} finally { setOAuthCompleted(false);
setSaveActive(false);
setTesting(false); setTesting(false);
} }
}; };
@@ -305,6 +411,28 @@ export default function MCPServerModal({
</div> </div>
</div> </div>
); );
case 'oauth':
return (
<div className="mb-10">
<div className="mt-6">
<Input
name="oauth_scopes"
type="text"
className="rounded-md"
value={formData.oauth_scopes}
onChange={(e) =>
handleInputChange('oauth_scopes', e.target.value)
}
placeholder={
t('settings.tools.mcp.placeholders.oauthScopes') ||
'Scopes (comma separated)'
}
borderVariant="thin"
labelBgClassName="bg-white dark:bg-charleston-green-2"
/>
</div>
</div>
);
default: default:
return null; return null;
} }
@@ -331,7 +459,6 @@ export default function MCPServerModal({
<div className="space-y-6 py-6"> <div className="space-y-6 py-6">
<div> <div>
<Input <Input
name="name"
type="text" type="text"
className="rounded-md" className="rounded-md"
value={formData.name} value={formData.name}
@@ -410,7 +537,7 @@ export default function MCPServerModal({
{testResult && ( {testResult && (
<div <div
className={`rounded-md p-5 ${ className={`rounded-2xl p-5 ${
testResult.success testResult.success
? 'bg-green-50 text-green-700 dark:bg-green-900/40 dark:text-green-300' ? 'bg-green-50 text-green-700 dark:bg-green-900/40 dark:text-green-300'
: 'bg-red-50 text-red-700 dark:bg-red-900 dark:text-red-300' : 'bg-red-50 text-red-700 dark:bg-red-900 dark:text-red-300'
@@ -458,7 +585,7 @@ export default function MCPServerModal({
</button> </button>
<button <button
onClick={handleSave} onClick={handleSave}
disabled={loading} disabled={loading || !saveActive}
className="bg-purple-30 hover:bg-violets-are-blue w-full rounded-3xl px-6 py-2 text-sm font-medium text-white transition-all disabled:opacity-50 sm:w-auto" className="bg-purple-30 hover:bg-violets-are-blue w-full rounded-3xl px-6 py-2 text-sm font-medium text-white transition-all disabled:opacity-50 sm:w-auto"
> >
{loading ? ( {loading ? (