Merge branch 'tester' of https://github.com/manishmadan2882/docsgpt into tester

ManishMadan2882
2025-10-01 01:58:52 +05:30
15 changed files with 1422 additions and 428 deletions

View File

@@ -264,7 +264,15 @@ class BaseAgent(ABC):
query: str,
retrieved_data: List[Dict],
) -> List[Dict]:
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
docs_with_filenames = []
for doc in retrieved_data:
filename = doc.get("filename") or doc.get("title") or doc.get("source")
if filename:
chunk_header = str(filename)
docs_with_filenames.append(f"{chunk_header}\n{doc['text']}")
else:
docs_with_filenames.append(doc["text"])
docs_together = "\n\n".join(docs_with_filenames)
p_chat_combine = system_prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
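
A minimal sketch of the new prompt-assembly behavior above, using hypothetical retrieved chunks: each chunk is prefixed with a filename header when one can be resolved, and chunks are joined with blank lines instead of single newlines.

retrieved_data = [
    {"text": "Install with pip install docsgpt.", "filename": "README.md"},
    {"text": "A chunk with no file metadata."},
]

docs_with_filenames = []
for doc in retrieved_data:
    # Same fallback order as in the diff: filename, then title, then source.
    filename = doc.get("filename") or doc.get("title") or doc.get("source")
    if filename:
        docs_with_filenames.append(f"{filename}\n{doc['text']}")
    else:
        docs_with_filenames.append(doc["text"])

# README.md
# Install with pip install docsgpt.
#
# A chunk with no file metadata.
print("\n\n".join(docs_with_filenames))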

File diff suppressed because it is too large

View File

@@ -1,5 +1,7 @@
import base64
import datetime
import json
import uuid
from bson.objectid import ObjectId
@@ -13,8 +15,6 @@ from flask import (
from flask_restx import fields, Namespace, Resource
from application.api.user.tasks import (
ingest_connector_task,
)
@@ -234,8 +234,24 @@ class ConnectorAuth(Resource):
if not ConnectorCreator.is_supported(provider):
return make_response(jsonify({"success": False, "error": f"Unsupported provider: {provider}"}), 400)
import uuid
state = str(uuid.uuid4())
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
user_id = decoded_token.get('sub')
now = datetime.datetime.now(datetime.timezone.utc)
result = sessions_collection.insert_one({
"provider": provider,
"user": user_id,
"status": "pending",
"created_at": now
})
state_dict = {
"provider": provider,
"object_id": str(result.inserted_id)
}
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
auth = ConnectorCreator.create_auth(provider)
authorization_url = auth.get_authorization_url(state=state)
return make_response(jsonify({
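
The state parameter is no longer a bare UUID: it now carries a base64url-encoded JSON envelope holding the provider and the pending session's ObjectId. A roundtrip sketch (the ObjectId string is a placeholder):

import base64
import json

state_dict = {"provider": "google_drive", "object_id": "65f0c0ffee0123456789abcd"}
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()

# The callback handler below reverses the encoding to recover both fields.
decoded = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
assert decoded == state_dict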
@@ -256,25 +272,30 @@ class ConnectorsCallback(Resource):
try:
from application.parser.connectors.connector_creator import ConnectorCreator
from flask import request, redirect
import uuid
provider = request.args.get('provider', 'google_drive')
authorization_code = request.args.get('code')
_ = request.args.get('state')
state = request.args.get('state')
error = request.args.get('error')
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
provider = state_dict["provider"]
state_object_id = state_dict["object_id"]
if error:
return redirect(f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider={provider}")
if error == "access_denied":
return redirect(f"/api/connectors/callback-status?status=cancelled&message=Authentication+was+cancelled.+You+can+try+again+if+you'd+like+to+connect+your+account.&provider={provider}")
else:
current_app.logger.warning(f"OAuth error in callback: {error}")
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
if not authorization_code:
return redirect(f"/api/connectors/callback-status?status=error&message=Authorization+code+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider={provider}")
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
try:
auth = ConnectorCreator.create_auth(provider)
token_info = auth.exchange_code_for_tokens(authorization_code)
session_token = str(uuid.uuid4())
try:
credentials = auth.create_credentials_from_token_info(token_info)
@@ -292,26 +313,28 @@ class ConnectorsCallback(Resource):
"expiry": token_info.get("expiry")
}
user_id = request.decoded_token.get("sub") if getattr(request, "decoded_token", None) else None
sessions_collection.insert_one({
"session_token": session_token,
"user": user_id,
"token_info": sanitized_token_info,
"created_at": datetime.datetime.now(datetime.timezone.utc),
"user_email": user_email,
"provider": provider
})
sessions_collection.find_one_and_update(
{"_id": ObjectId(state_object_id), "provider": provider},
{
"$set": {
"session_token": session_token,
"token_info": sanitized_token_info,
"user_email": user_email,
"status": "authorized"
}
}
)
# Redirect to success page with session token and user email
return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}")
except Exception as e:
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+exchange+authorization+code+for+tokens:+{str(e)}&provider={provider}")
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
except Exception as e:
current_app.logger.error(f"Error handling connector callback: {e}")
return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+complete+connector+authentication:+{str(e)}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.")
return redirect("/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.")
@connectors_ns.route("/api/connectors/refresh")
@@ -629,20 +652,23 @@ class ConnectorCallbackStatus(Resource):
.container {{ max-width: 600px; margin: 0 auto; }}
.success {{ color: #4CAF50; }}
.error {{ color: #F44336; }}
.cancelled {{ color: #FF9800; }}
</style>
<script>
window.onload = function() {{
const status = "{status}";
const sessionToken = "{session_token}";
const userEmail = "{user_email}";
if (status === "success" && window.opener) {{
window.opener.postMessage({{
type: '{provider}_auth_success',
session_token: sessionToken,
user_email: userEmail
}}, '*');
setTimeout(() => window.close(), 3000);
}} else if (status === "cancelled" || status === "error") {{
setTimeout(() => window.close(), 3000);
}}
}};
@@ -655,7 +681,7 @@ class ConnectorCallbackStatus(Resource):
<p>{message}</p>
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
</div>
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else ''}</small></p>
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
</div>
</body>
</html>
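
The redirects above encode human-readable messages with + for spaces. A sketch of producing equivalent query strings with the standard library (values are illustrative):

from urllib.parse import urlencode

params = {
    "status": "error",
    "message": "Authentication failed. Please try again.",
    "provider": "google_drive",
}
# urlencode defaults to quote_plus, so spaces become '+', matching the
# hand-built redirect URLs in the handlers above.
print(f"/api/connectors/callback-status?{urlencode(params)}")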

View File

@@ -8,6 +8,7 @@ import uuid
import zipfile
from functools import wraps
from typing import Optional, Tuple
from urllib.parse import unquote
from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef
@@ -25,7 +26,7 @@ from flask_restx import fields, inputs, Namespace, Resource
from pymongo import ReturnDocument
from werkzeug.utils import secure_filename
from application.agents.tools.mcp_tool import MCPTool
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
from application.agents.tools.tool_manager import ToolManager
from application.api import api
@@ -37,6 +38,8 @@ from application.api.user.tasks import (
process_agent_webhook,
store_attachment,
)
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.parser.connectors.connector_creator import ConnectorCreator
@@ -494,7 +497,6 @@ class DeleteOldIndexes(Resource):
)
if not doc:
return make_response(jsonify({"status": "not found"}), 404)
storage = StorageCreator.get_storage()
try:
@@ -511,7 +513,6 @@ class DeleteOldIndexes(Resource):
settings.VECTOR_STORE, source_id=str(doc["_id"])
)
vectorstore.delete_index()
if "file_path" in doc and doc["file_path"]:
file_path = doc["file_path"]
if storage.is_directory(file_path):
@@ -520,7 +521,6 @@ class DeleteOldIndexes(Resource):
storage.delete_file(f)
else:
storage.delete_file(file_path)
except FileNotFoundError:
pass
except Exception as err:
@@ -528,7 +528,6 @@ class DeleteOldIndexes(Resource):
f"Error deleting files and indexes: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
sources_collection.delete_one({"_id": ObjectId(source_id)})
return make_response(jsonify({"success": True}), 200)
@@ -600,7 +599,6 @@ class UploadFile(Resource):
== temp_file_path
):
continue
rel_path = os.path.relpath(
os.path.join(root, extracted_file), temp_dir
)
@@ -625,7 +623,6 @@ class UploadFile(Resource):
file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path)
task = ingest.delay(
settings.UPLOAD_FOLDER,
[
@@ -697,7 +694,6 @@ class ManageSourceFiles(Resource):
return make_response(
jsonify({"success": False, "message": "Unauthorized"}), 401
)
user = decoded_token.get("sub")
source_id = request.form.get("source_id")
operation = request.form.get("operation")
@@ -747,7 +743,6 @@ class ManageSourceFiles(Resource):
return make_response(
jsonify({"success": False, "message": "Database error"}), 500
)
try:
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
@@ -804,7 +799,6 @@ class ManageSourceFiles(Resource):
),
200,
)
elif operation == "remove":
file_paths_str = request.form.get("file_paths")
if not file_paths_str:
@@ -858,7 +852,6 @@ class ManageSourceFiles(Resource):
),
200,
)
elif operation == "remove_directory":
directory_path = request.form.get("directory_path")
if not directory_path:
@@ -884,7 +877,6 @@ class ManageSourceFiles(Resource):
),
400,
)
full_directory_path = (
f"{source_file_path}/{directory_path}"
if directory_path
@@ -943,7 +935,6 @@ class ManageSourceFiles(Resource):
),
200,
)
except Exception as err:
error_context = f"operation={operation}, user={user}, source_id={source_id}"
if operation == "remove_directory":
@@ -955,7 +946,6 @@ class ManageSourceFiles(Resource):
elif operation == "add":
parent_dir = request.form.get("parent_dir", "")
error_context += f", parent_dir={parent_dir}"
current_app.logger.error(
f"Error managing source files: {err} ({error_context})", exc_info=True
)
@@ -1632,7 +1622,6 @@ class CreateAgent(Resource):
),
400,
)
# Validate that it has either a 'schema' property or is itself a schema
if "schema" not in json_schema and "type" not in json_schema:
@@ -3476,7 +3465,6 @@ class AvailableTools(Resource):
"displayName": name,
"description": description,
"configRequirements": tool_instance.get_config_requirements(),
"actions": tool_instance.get_actions_metadata(),
}
)
except Exception as err:
@@ -3527,11 +3515,6 @@ class CreateTool(Resource):
"customName": fields.String(
required=False, description="Custom name for the tool"
),
"actions": fields.List(
fields.Raw,
required=True,
description="Actions the tool can perform",
),
"status": fields.Boolean(
required=True, description="Status of the tool"
),
@@ -3549,24 +3532,35 @@ class CreateTool(Resource):
"name",
"displayName",
"description",
"actions",
"config",
"status",
]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
transformed_actions = []
for action in data["actions"]:
action["active"] = True
if "parameters" in action:
if "properties" in action["parameters"]:
for param_name, param_details in action["parameters"][
"properties"
].items():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed_actions.append(action)
try:
tool_instance = tool_manager.tools.get(data["name"])
if not tool_instance:
return make_response(
jsonify({"success": False, "message": "Tool not found"}), 404
)
actions_metadata = tool_instance.get_actions_metadata()
transformed_actions = []
for action in actions_metadata:
action["active"] = True
if "parameters" in action:
if "properties" in action["parameters"]:
for param_name, param_details in action["parameters"][
"properties"
].items():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed_actions.append(action)
except Exception as err:
current_app.logger.error(
f"Error getting tool actions: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
try:
new_tool = {
"user": user,
@@ -3907,7 +3901,6 @@ class GetChunks(Resource):
if not (text_match or title_match):
continue
filtered_chunks.append(chunk)
chunks = filtered_chunks
total_chunks = len(chunks)
@@ -4098,7 +4091,6 @@ class UpdateChunk(Resource):
current_app.logger.warning(
f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created"
)
return make_response(
jsonify(
{
@@ -4226,23 +4218,19 @@ class DirectoryStructure(Resource):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
doc_id = request.args.get("id")
if not doc_id:
return make_response(jsonify({"error": "Document ID is required"}), 400)
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid document ID"}), 400)
try:
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
directory_structure = doc.get("directory_structure", {})
base_path = doc.get("file_path", "")
@@ -4315,11 +4303,10 @@ class TestMCPServerConfig(Resource):
auth_credentials["username"] = config["username"]
if "password" in config:
auth_credentials["password"] = config["password"]
test_config = config.copy()
test_config["auth_credentials"] = auth_credentials
mcp_tool = MCPTool(test_config, user)
mcp_tool = MCPTool(config=test_config, user_id=user)
result = mcp_tool.test_connection()
return make_response(jsonify(result), 200)
@@ -4387,22 +4374,45 @@ class MCPServerSave(Resource):
mcp_config = config.copy()
mcp_config["auth_credentials"] = auth_credentials
if auth_type == "none" or auth_credentials:
mcp_tool = MCPTool(mcp_config, user)
if auth_type == "oauth":
if not config.get("oauth_task_id"):
return make_response(
jsonify(
{
"success": False,
"error": "Connection not authorized. Please complete the OAuth authorization first.",
}
),
400,
)
redis_client = get_redis_instance()
manager = MCPOAuthManager(redis_client)
result = manager.get_oauth_status(config["oauth_task_id"])
if not result.get("status") == "completed":
return make_response(
jsonify(
{
"success": False,
"error": "OAuth failed or not completed. Please try authorizing again.",
}
),
400,
)
actions_metadata = result.get("tools", [])
elif auth_type == "none" or auth_credentials:
mcp_tool = MCPTool(config=mcp_config, user_id=user)
mcp_tool.discover_tools()
actions_metadata = mcp_tool.get_actions_metadata()
else:
raise Exception(
"No valid credentials provided for the selected authentication type"
)
storage_config = config.copy()
if auth_credentials:
encrypted_credentials_string = encrypt_credentials(
auth_credentials, user
)
storage_config["encrypted_credentials"] = encrypted_credentials_string
for field in [
"api_key",
"bearer_token",
@@ -4473,3 +4483,96 @@ class MCPServerSave(Resource):
),
500,
)
@user_ns.route("/api/mcp_server/callback")
class MCPOAuthCallback(Resource):
@api.expect(
api.model(
"MCPServerCallbackModel",
{
"code": fields.String(required=True, description="Authorization code"),
"state": fields.String(required=True, description="State parameter"),
"error": fields.String(
required=False, description="Error message (if any)"
),
},
)
)
@api.doc(
description="Handle OAuth callback by providing the authorization code and state"
)
def get(self):
code = request.args.get("code")
state = request.args.get("state")
error = request.args.get("error")
if error:
return redirect(
f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider=mcp_tool"
)
if not code or not state:
return redirect(
"/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool"
)
try:
redis_client = get_redis_instance()
if not redis_client:
return redirect(
"/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool"
)
code = unquote(code)
manager = MCPOAuthManager(redis_client)
success = manager.handle_oauth_callback(state, code, error)
if success:
return redirect(
"/api/connectors/callback-status?status=success&message=Authorization+code+received+successfully.+You+can+close+this+window.&provider=mcp_tool"
)
else:
return redirect(
"/api/connectors/callback-status?status=error&message=OAuth+callback+failed.&provider=mcp_tool"
)
except Exception as e:
current_app.logger.error(
f"Error handling MCP OAuth callback: {str(e)}", exc_info=True
)
return redirect(
f"/api/connectors/callback-status?status=error&message=Internal+server+error:+{str(e)}.&provider=mcp_tool"
)
@user_ns.route("/api/mcp_server/oauth_status/<string:task_id>")
class MCPOAuthStatus(Resource):
def get(self, task_id):
"""
Get current status of OAuth flow.
Frontend should poll this endpoint periodically.
"""
try:
redis_client = get_redis_instance()
status_key = f"mcp_oauth_status:{task_id}"
status_data = redis_client.get(status_key)
if status_data:
status = json.loads(status_data)
return make_response(
jsonify({"success": True, "task_id": task_id, **status})
)
else:
return make_response(
jsonify(
{
"success": False,
"error": "Task not found or expired",
"task_id": task_id,
}
),
404,
)
except Exception as e:
current_app.logger.error(
f"Error getting OAuth status for task {task_id}: {str(e)}"
)
return make_response(
jsonify({"success": False, "error": str(e), "task_id": task_id}), 500
)
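
The status endpoint reads snapshots that the OAuth worker writes under mcp_oauth_status:<task_id>. A sketch of that key convention (assumes a reachable Redis; the task id is made up):

import json
import redis

r = redis.Redis()
task_id = "d6b8a0c4-example"  # hypothetical Celery task id
status_key = f"mcp_oauth_status:{task_id}"

# The worker stores each status snapshot with a 10-minute TTL...
r.setex(status_key, 600, json.dumps({"status": "in_progress", "task_id": task_id}))

# ...and this endpoint polls it back out.
raw = r.get(status_key)
status = json.loads(raw) if raw else {"status": "not_found"}
print(status["status"])  # "in_progress"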

View File

@@ -5,6 +5,8 @@ from application.worker import (
agent_webhook_worker,
attachment_worker,
ingest_worker,
mcp_oauth,
mcp_oauth_status,
remote_worker,
sync_worker,
)
@@ -25,6 +27,7 @@ def ingest_remote(self, source_data, job_name, user, loader):
@celery.task(bind=True)
def reingest_source_task(self, source_id, user):
from application.worker import reingest_source_worker
resp = reingest_source_worker(self, source_id, user)
return resp
@@ -60,9 +63,10 @@ def ingest_connector_task(
retriever="classic",
operation_mode="upload",
doc_id=None,
sync_frequency="never"
sync_frequency="never",
):
from application.worker import ingest_connector
resp = ingest_connector(
self,
job_name,
@@ -75,7 +79,7 @@ def ingest_connector_task(
retriever=retriever,
operation_mode=operation_mode,
doc_id=doc_id,
sync_frequency=sync_frequency
sync_frequency=sync_frequency,
)
return resp
@@ -94,3 +98,15 @@ def setup_periodic_tasks(sender, **kwargs):
timedelta(days=30),
schedule_syncs.s("monthly"),
)
@celery.task(bind=True)
def mcp_oauth_task(self, config, user):
resp = mcp_oauth(self, config, user)
return resp
@celery.task(bind=True)
def mcp_oauth_status_task(self, task_id):
resp = mcp_oauth_status(self, task_id)
return resp
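
A usage sketch for the new tasks, assuming a configured Celery broker and running worker (the config fields are illustrative, not a documented schema):

from application.api.user.tasks import mcp_oauth_task, mcp_oauth_status_task

# Kick off the OAuth flow in the background...
async_result = mcp_oauth_task.delay(
    {"server_url": "https://mcp.example.com", "auth_type": "oauth"}, "user-123"
)

# ...then check its progress by task id (the API layer does this via Redis instead).
status = mcp_oauth_status_task.delay(async_result.id).get(timeout=10)
print(status.get("status"))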

View File

@@ -43,8 +43,7 @@ class Settings(BaseSettings):
# Google Drive integration
GOOGLE_CLIENT_ID: Optional[str] = None # Replace with your actual Google OAuth client ID
GOOGLE_CLIENT_SECRET: Optional[str] = None  # Replace with your actual Google OAuth client secret
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = "http://127.0.0.1:7091/api/connectors/callback"
##append ?provider={provider_name} in your Provider console like http://127.0.0.1:7091/api/connectors/callback?provider=google_drive
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = "http://127.0.0.1:7091/api/connectors/callback"  # Add this redirect URL as-is to your provider's console (e.g. GCP)
# LLM Cache

View File

@@ -23,7 +23,7 @@ class GoogleDriveAuth(BaseConnectorAuth):
def __init__(self):
self.client_id = settings.GOOGLE_CLIENT_ID
self.client_secret = settings.GOOGLE_CLIENT_SECRET
self.redirect_uri = f"{settings.CONNECTOR_REDIRECT_BASE_URI}?provider=google_drive"
self.redirect_uri = f"{settings.CONNECTOR_REDIRECT_BASE_URI}"
if not self.client_id or not self.client_secret:
raise ValueError("Google OAuth credentials not configured. Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET in settings.")

View File

@@ -12,6 +12,7 @@ esprima==4.0.1
esutils==1.0.1
Flask==3.1.1
faiss-cpu==1.9.0.post1
fastmcp==2.11.0
flask-restx==1.3.0
google-genai==1.3.0
google-api-python-client==2.179.0
@@ -56,13 +57,13 @@ prompt-toolkit==3.0.51
protobuf==5.29.3
psycopg2-binary==2.9.10
py==1.11.0
pydantic==2.10.6
pydantic-core==2.27.2
pydantic-settings==2.7.1
pydantic
pydantic-core
pydantic-settings
pymongo==4.11.3
pypdf==5.5.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-dotenv
python-jose==3.4.0
python-pptx==1.0.2
redis==5.2.1
@@ -82,7 +83,7 @@ tzdata==2024.2
urllib3==2.3.0
vine==5.1.0
wcwidth==0.2.13
werkzeug==3.1.3
werkzeug>=3.1.0,<3.1.2
yarl==1.20.0
markdownify==1.1.0
tldextract==5.1.3

View File

@@ -1,4 +1,5 @@
import logging
import os
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
@@ -141,15 +142,28 @@ class ClassicRAG(BaseRetriever):
title = metadata.get(
"title", metadata.get("post_title", page_content)
)
if isinstance(title, str):
title = title.split("/")[-1]
if not isinstance(title, str):
title = str(title)
title = title.split("/")[-1]
filename = (
metadata.get("filename")
or metadata.get("file_name")
or metadata.get("source")
)
if isinstance(filename, str):
filename = os.path.basename(filename) or filename
else:
title = str(title).split("/")[-1]
filename = title
if not filename:
filename = title
source_path = metadata.get("source") or vectorstore_id
all_docs.append(
{
"title": title,
"text": page_content,
"source": metadata.get("source") or vectorstore_id,
"source": source_path,
"filename": filename,
}
)
except Exception as e:
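
A sketch of the filename fallback the retriever now applies, with made-up metadata: the title keeps only its last path segment, while the filename prefers explicit metadata keys and falls back to the title.

import os

metadata = {"title": "docs/guide.md", "source": "/app/uploads/docs/guide.md"}

title = metadata.get("title", metadata.get("post_title", ""))
if not isinstance(title, str):
    title = str(title)
title = title.split("/")[-1]  # "guide.md"

filename = metadata.get("filename") or metadata.get("file_name") or metadata.get("source")
if isinstance(filename, str):
    # basename() strips the directory; the `or` keeps the original on empty results.
    filename = os.path.basename(filename) or filename
else:
    filename = title
if not filename:
    filename = title
print(title, filename)  # guide.md guide.md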

View File

@@ -19,6 +19,7 @@ from bson.objectid import ObjectId
from application.agents.agent_creator import AgentCreator
from application.api.answer.services.stream_processor import get_prompt
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.parser.chunking import Chunker
@@ -214,8 +215,7 @@ def run_agent_logic(agent_config, input_data):
def ingest_worker(
self, directory, formats, job_name, file_path, filename, user,
retriever="classic"
self, directory, formats, job_name, file_path, filename, user, retriever="classic"
):
"""
Ingest and process documents.
@@ -240,7 +240,7 @@ def ingest_worker(
sample = False
storage = StorageCreator.get_storage()
logging.info(f"Ingest path: {file_path}", extra={"user": user, "job": job_name})
# Create temporary working directory
@@ -253,30 +253,32 @@ def ingest_worker(
# Handle directory case
logging.info(f"Processing directory: {file_path}")
files_list = storage.list_files(file_path)
for storage_file_path in files_list:
if storage.is_directory(storage_file_path):
continue
# Create relative path structure in temp directory
rel_path = os.path.relpath(storage_file_path, file_path)
local_file_path = os.path.join(temp_dir, rel_path)
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
# Download file
try:
file_data = storage.get_file(storage_file_path)
with open(local_file_path, "wb") as f:
f.write(file_data.read())
except Exception as e:
logging.error(f"Error downloading file {storage_file_path}: {e}")
logging.error(
f"Error downloading file {storage_file_path}: {e}"
)
continue
else:
# Handle single file case
temp_filename = os.path.basename(file_path)
temp_file_path = os.path.join(temp_dir, temp_filename)
file_data = storage.get_file(file_path)
with open(temp_file_path, "wb") as f:
f.write(file_data.read())
@@ -285,7 +287,10 @@ def ingest_worker(
if temp_filename.endswith(".zip"):
logging.info(f"Extracting zip file: {temp_filename}")
extract_zip_recursive(
temp_file_path, temp_dir, current_depth=0, max_depth=RECURSION_DEPTH
temp_file_path,
temp_dir,
current_depth=0,
max_depth=RECURSION_DEPTH,
)
self.update_state(state="PROGRESS", meta={"current": 1})
@@ -300,8 +305,8 @@ def ingest_worker(
file_metadata=metadata_from_filename,
)
raw_docs = reader.load_data()
directory_structure = getattr(reader, 'directory_structure', {})
directory_structure = getattr(reader, "directory_structure", {})
logging.info(f"Directory structure from reader: {directory_structure}")
chunker = Chunker(
@@ -371,7 +376,10 @@ def reingest_source_worker(self, source_id, user):
try:
from application.vectorstore.vector_creator import VectorCreator
self.update_state(state="PROGRESS", meta={"current": 10, "status": "Initializing re-ingestion scan"})
self.update_state(
state="PROGRESS",
meta={"current": 10, "status": "Initializing re-ingestion scan"},
)
source = sources_collection.find_one({"_id": ObjectId(source_id), "user": user})
if not source:
@@ -380,7 +388,9 @@ def reingest_source_worker(self, source_id, user):
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
self.update_state(state="PROGRESS", meta={"current": 20, "status": "Scanning current files"})
self.update_state(
state="PROGRESS", meta={"current": 20, "status": "Scanning current files"}
)
with tempfile.TemporaryDirectory() as temp_dir:
# Download all files from storage to temp directory, preserving directory structure
@@ -391,7 +401,6 @@ def reingest_source_worker(self, source_id, user):
if storage.is_directory(storage_file_path):
continue
rel_path = os.path.relpath(storage_file_path, source_file_path)
local_file_path = os.path.join(temp_dir, rel_path)
@@ -403,23 +412,39 @@ def reingest_source_worker(self, source_id, user):
with open(local_file_path, "wb") as f:
f.write(file_data.read())
except Exception as e:
logging.error(f"Error downloading file {storage_file_path}: {e}")
logging.error(
f"Error downloading file {storage_file_path}: {e}"
)
continue
reader = SimpleDirectoryReader(
input_dir=temp_dir,
recursive=True,
required_exts=[
".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub",
".html", ".mdx", ".json", ".xlsx", ".pptx", ".png",
".jpg", ".jpeg",
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
exclude_hidden=True,
file_metadata=metadata_from_filename,
)
reader.load_data()
directory_structure = reader.directory_structure
logging.info(f"Directory structure built with token counts: {directory_structure}")
logging.info(
f"Directory structure built with token counts: {directory_structure}"
)
try:
old_directory_structure = source.get("directory_structure") or {}
@@ -433,11 +458,17 @@ def reingest_source_worker(self, source_id, user):
files = set()
if isinstance(struct, dict):
for name, meta in struct.items():
current_path = os.path.join(prefix, name) if prefix else name
if isinstance(meta, dict) and ("type" in meta and "size_bytes" in meta):
current_path = (
os.path.join(prefix, name) if prefix else name
)
if isinstance(meta, dict) and (
"type" in meta and "size_bytes" in meta
):
files.add(current_path)
elif isinstance(meta, dict):
files |= _flatten_directory_structure(meta, current_path)
files |= _flatten_directory_structure(
meta, current_path
)
return files
old_files = _flatten_directory_structure(old_directory_structure)
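
A worked example of the flattening helper, with the leaf shape inferred from the diff (a dict carrying both "type" and "size_bytes" is a file; any other dict is a sub-directory):

import os

def flatten(struct, prefix=""):
    files = set()
    if isinstance(struct, dict):
        for name, meta in struct.items():
            current_path = os.path.join(prefix, name) if prefix else name
            if isinstance(meta, dict) and ("type" in meta and "size_bytes" in meta):
                files.add(current_path)  # file leaf
            elif isinstance(meta, dict):
                files |= flatten(meta, current_path)  # recurse into sub-directory
    return files

old = {"docs": {"a.md": {"type": "file", "size_bytes": 10}}, "b.txt": {"type": "file", "size_bytes": 5}}
new = {"docs": {"a.md": {"type": "file", "size_bytes": 10}, "c.md": {"type": "file", "size_bytes": 7}}}
print(sorted(flatten(new) - flatten(old)))  # ['docs/c.md']  -> added files
print(sorted(flatten(old) - flatten(new)))  # ['b.txt']      -> removed files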
@@ -457,7 +488,9 @@ def reingest_source_worker(self, source_id, user):
logging.info("No files removed since last ingest.")
except Exception as e:
logging.error(f"Error comparing directory structures: {e}", exc_info=True)
logging.error(
f"Error comparing directory structures: {e}", exc_info=True
)
added_files = []
removed_files = []
try:
@@ -477,14 +510,21 @@ def reingest_source_worker(self, source_id, user):
settings.EMBEDDINGS_KEY,
)
self.update_state(state="PROGRESS", meta={"current": 40, "status": "Processing file changes"})
self.update_state(
state="PROGRESS",
meta={"current": 40, "status": "Processing file changes"},
)
# 1) Delete chunks from removed files
deleted = 0
if removed_files:
try:
for ch in vector_store.get_chunks() or []:
metadata = ch.get("metadata", {}) if isinstance(ch, dict) else getattr(ch, "metadata", {})
metadata = (
ch.get("metadata", {})
if isinstance(ch, dict)
else getattr(ch, "metadata", {})
)
raw_source = metadata.get("source")
source_file = str(raw_source) if raw_source else ""
@@ -496,10 +536,17 @@ def reingest_source_worker(self, source_id, user):
vector_store.delete_chunk(cid)
deleted += 1
except Exception as de:
logging.error(f"Failed deleting chunk {cid}: {de}")
logging.info(f"Deleted {deleted} chunks from {len(removed_files)} removed files")
logging.error(
f"Failed deleting chunk {cid}: {de}"
)
logging.info(
f"Deleted {deleted} chunks from {len(removed_files)} removed files"
)
except Exception as e:
logging.error(f"Error during deletion of removed file chunks: {e}", exc_info=True)
logging.error(
f"Error during deletion of removed file chunks: {e}",
exc_info=True,
)
# 2) Add chunks from new files
added = 0
@@ -528,58 +575,86 @@ def reingest_source_worker(self, source_id, user):
)
chunked_new = chunker_new.chunk(documents=raw_docs_new)
for file_path, token_count in reader_new.file_token_counts.items():
for (
file_path,
token_count,
) in reader_new.file_token_counts.items():
try:
rel_path = os.path.relpath(file_path, start=temp_dir)
rel_path = os.path.relpath(
file_path, start=temp_dir
)
path_parts = rel_path.split(os.sep)
current_dir = directory_structure
for part in path_parts[:-1]:
if part in current_dir and isinstance(current_dir[part], dict):
if part in current_dir and isinstance(
current_dir[part], dict
):
current_dir = current_dir[part]
else:
break
filename = path_parts[-1]
if filename in current_dir and isinstance(current_dir[filename], dict):
current_dir[filename]["token_count"] = token_count
logging.info(f"Updated token count for {rel_path}: {token_count}")
if filename in current_dir and isinstance(
current_dir[filename], dict
):
current_dir[filename][
"token_count"
] = token_count
logging.info(
f"Updated token count for {rel_path}: {token_count}"
)
except Exception as e:
logging.warning(f"Could not update token count for {file_path}: {e}")
logging.warning(
f"Could not update token count for {file_path}: {e}"
)
for d in chunked_new:
meta = dict(d.extra_info or {})
try:
raw_src = meta.get("source")
if isinstance(raw_src, str) and os.path.isabs(raw_src):
meta["source"] = os.path.relpath(raw_src, start=temp_dir)
if isinstance(raw_src, str) and os.path.isabs(
raw_src
):
meta["source"] = os.path.relpath(
raw_src, start=temp_dir
)
except Exception:
pass
vector_store.add_chunk(d.text, metadata=meta)
added += 1
logging.info(f"Added {added} chunks from {len(added_files)} new files")
logging.info(
f"Added {added} chunks from {len(added_files)} new files"
)
except Exception as e:
logging.error(f"Error during ingestion of new files: {e}", exc_info=True)
logging.error(
f"Error during ingestion of new files: {e}", exc_info=True
)
# 3) Update source directory structure timestamp
try:
total_tokens = sum(reader.file_token_counts.values())
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{
"$set": {
"directory_structure": directory_structure,
"date": datetime.datetime.now(),
"tokens": total_tokens
"tokens": total_tokens,
}
},
)
except Exception as e:
logging.error(f"Error updating directory_structure in DB: {e}", exc_info=True)
logging.error(
f"Error updating directory_structure in DB: {e}", exc_info=True
)
self.update_state(state="PROGRESS", meta={"current": 100, "status": "Re-ingestion completed"})
self.update_state(
state="PROGRESS",
meta={"current": 100, "status": "Re-ingestion completed"},
)
return {
"source_id": source_id,
@@ -591,15 +666,16 @@ def reingest_source_worker(self, source_id, user):
"chunks_deleted": deleted,
}
except Exception as e:
logging.error(f"Error while processing file changes: {e}", exc_info=True)
logging.error(
f"Error while processing file changes: {e}", exc_info=True
)
raise
except Exception as e:
logging.error(f"Error in reingest_source_worker: {e}", exc_info=True)
raise
def remote_worker(
self,
source_data,
@@ -651,7 +727,7 @@ def remote_worker(
"id": str(id),
"type": loader,
"remote_data": source_data,
"sync_frequency": sync_frequency
"sync_frequency": sync_frequency,
}
if operation_mode == "sync":
@@ -712,7 +788,7 @@ def sync_worker(self, frequency):
self, source_data, name, user, source_type, frequency, retriever, doc_id
)
sync_counts["total_sync_count"] += 1
sync_counts[
"sync_success" if resp["status"] == "success" else "sync_failure"
] += 1
return {
@@ -749,15 +825,14 @@ def attachment_worker(self, file_info, user):
input_files=[local_path], exclude_hidden=True, errors="ignore"
)
.load_data()[0]
.text,
)
token_count = num_tokens_from_string(content)
if token_count > 100000:
content = content[:250000]
token_count = num_tokens_from_string(content)
self.update_state(
state="PROGRESS", meta={"current": 80, "status": "Storing in database"}
)
@@ -872,37 +947,49 @@ def ingest_connector(
doc_id: Document ID for sync operations (required when operation_mode="sync")
sync_frequency: How often to sync ("never", "daily", "weekly", "monthly")
"""
logging.info(f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}")
logging.info(
f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}"
)
self.update_state(state="PROGRESS", meta={"current": 1})
with tempfile.TemporaryDirectory() as temp_dir:
try:
# Step 1: Initialize the appropriate loader
self.update_state(state="PROGRESS", meta={"current": 10, "status": "Initializing connector"})
self.update_state(
state="PROGRESS",
meta={"current": 10, "status": "Initializing connector"},
)
if not session_token:
raise ValueError(f"{source_type} connector requires session_token")
if not ConnectorCreator.is_supported(source_type):
raise ValueError(f"Unsupported connector type: {source_type}. Supported types: {ConnectorCreator.get_supported_connectors()}")
raise ValueError(
f"Unsupported connector type: {source_type}. Supported types: {ConnectorCreator.get_supported_connectors()}"
)
remote_loader = ConnectorCreator.create_connector(source_type, session_token)
remote_loader = ConnectorCreator.create_connector(
source_type, session_token
)
# Create a clean config for storage
api_source_config = {
"file_ids": file_ids or [],
"folder_ids": folder_ids or [],
"recursive": recursive
"recursive": recursive,
}
# Step 2: Download files to temp directory
self.update_state(state="PROGRESS", meta={"current": 20, "status": "Downloading files"})
download_info = remote_loader.download_to_directory(
temp_dir,
api_source_config
self.update_state(
state="PROGRESS", meta={"current": 20, "status": "Downloading files"}
)
if download_info.get("empty_result", False) or not download_info.get("files_downloaded", 0):
download_info = remote_loader.download_to_directory(
temp_dir, api_source_config
)
if download_info.get("empty_result", False) or not download_info.get(
"files_downloaded", 0
):
logging.warning(f"No files were downloaded from {source_type}")
# Create empty result directly instead of calling a separate method
return {
@@ -913,28 +1000,42 @@ def ingest_connector(
"source_config": api_source_config,
"directory_structure": "{}",
}
# Step 3: Use SimpleDirectoryReader to process downloaded files
self.update_state(state="PROGRESS", meta={"current": 40, "status": "Processing files"})
self.update_state(
state="PROGRESS", meta={"current": 40, "status": "Processing files"}
)
reader = SimpleDirectoryReader(
input_dir=temp_dir,
recursive=True,
required_exts=[
".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub",
".html", ".mdx", ".json", ".xlsx", ".pptx", ".png",
".jpg", ".jpeg",
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
exclude_hidden=True,
file_metadata=metadata_from_filename,
)
raw_docs = reader.load_data()
directory_structure = getattr(reader, 'directory_structure', {})
directory_structure = getattr(reader, "directory_structure", {})
# Step 4: Process documents (chunking, embedding, etc.)
self.update_state(state="PROGRESS", meta={"current": 60, "status": "Processing documents"})
self.update_state(
state="PROGRESS", meta={"current": 60, "status": "Processing documents"}
)
chunker = Chunker(
chunking_strategy="classic_chunk",
max_tokens=MAX_TOKENS,
@@ -942,22 +1043,26 @@ def ingest_connector(
duplicate_headers=False,
)
raw_docs = chunker.chunk(documents=raw_docs)
# Preserve source information in document metadata
for doc in raw_docs:
if hasattr(doc, 'extra_info') and doc.extra_info:
source = doc.extra_info.get('source')
if hasattr(doc, "extra_info") and doc.extra_info:
source = doc.extra_info.get("source")
if source and os.path.isabs(source):
# Convert absolute path to relative path
doc.extra_info['source'] = os.path.relpath(source, start=temp_dir)
doc.extra_info["source"] = os.path.relpath(
source, start=temp_dir
)
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
if operation_mode == "upload":
id = ObjectId()
elif operation_mode == "sync":
if not doc_id or not ObjectId.is_valid(doc_id):
logging.error("Invalid doc_id provided for sync operation: %s", doc_id)
logging.error(
"Invalid doc_id provided for sync operation: %s", doc_id
)
raise ValueError("doc_id must be provided for sync operation.")
id = ObjectId(doc_id)
else:
@@ -966,7 +1071,9 @@ def ingest_connector(
vector_store_path = os.path.join(temp_dir, "vector_store")
os.makedirs(vector_store_path, exist_ok=True)
self.update_state(state="PROGRESS", meta={"current": 80, "status": "Storing documents"})
self.update_state(
state="PROGRESS", meta={"current": 80, "status": "Storing documents"}
)
embed_and_store_documents(docs, vector_store_path, id, self)
tokens = count_tokens_docs(docs)
@@ -979,12 +1086,11 @@ def ingest_connector(
"retriever": retriever,
"id": str(id),
"type": "connector:file",
"remote_data": json.dumps({
"provider": source_type,
**api_source_config
}),
"remote_data": json.dumps(
{"provider": source_type, **api_source_config}
),
"directory_structure": json.dumps(directory_structure),
"sync_frequency": sync_frequency
"sync_frequency": sync_frequency,
}
if operation_mode == "sync":
@@ -995,7 +1101,9 @@ def ingest_connector(
upload_index(vector_store_path, file_data)
# Ensure we mark the task as complete
self.update_state(state="PROGRESS", meta={"current": 100, "status": "Complete"})
self.update_state(
state="PROGRESS", meta={"current": 100, "status": "Complete"}
)
logging.info(f"Remote ingestion completed: {job_name}")
@@ -1005,9 +1113,136 @@ def ingest_connector(
"tokens": tokens,
"type": source_type,
"id": str(id),
"status": "complete"
"status": "complete",
}
except Exception as e:
logging.error(f"Error during remote ingestion: {e}", exc_info=True)
raise
def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
"""Worker to handle MCP OAuth flow asynchronously."""
logging.info(
"[MCP OAuth] Worker started for user_id=%s, config=%s", user_id, config
)
try:
import asyncio
from application.agents.tools.mcp_tool import MCPTool
task_id = self.request.id
logging.info("[MCP OAuth] Task ID: %s", task_id)
redis_client = get_redis_instance()
def update_status(status_data: Dict[str, Any]):
logging.info("[MCP OAuth] Updating status: %s", status_data)
status_key = f"mcp_oauth_status:{task_id}"
redis_client.setex(status_key, 600, json.dumps(status_data))
update_status(
{
"status": "in_progress",
"message": "Starting OAuth flow...",
"task_id": task_id,
}
)
tool_config = config.copy()
tool_config["oauth_task_id"] = task_id
logging.info("[MCP OAuth] Initializing MCPTool with config: %s", tool_config)
mcp_tool = MCPTool(tool_config, user_id)
async def run_oauth_discovery():
if not mcp_tool._client:
mcp_tool._setup_client()
return await mcp_tool._execute_with_client("list_tools")
update_status(
{
"status": "awaiting_redirect",
"message": "Waiting for OAuth redirect...",
"task_id": task_id,
}
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
logging.info("[MCP OAuth] Starting event loop for OAuth discovery...")
tools_response = loop.run_until_complete(run_oauth_discovery())
logging.info(
"[MCP OAuth] Tools response after async call: %s", tools_response
)
status_key = f"mcp_oauth_status:{task_id}"
redis_status = redis_client.get(status_key)
if redis_status:
logging.info(
"[MCP OAuth] Redis status after async call: %s", redis_status
)
else:
logging.warning(
"[MCP OAuth] No Redis status found after async call for key: %s",
status_key,
)
tools = mcp_tool.get_actions_metadata()
update_status(
{
"status": "completed",
"message": f"OAuth completed successfully. Found {len(tools)} tools.",
"tools": tools,
"tools_count": len(tools),
"task_id": task_id,
}
)
logging.info(
"[MCP OAuth] OAuth flow completed successfully for task_id=%s", task_id
)
return {"success": True, "tools": tools, "tools_count": len(tools)}
except Exception as e:
error_msg = f"OAuth flow failed: {str(e)}"
logging.error(
"[MCP OAuth] Exception in OAuth discovery: %s", error_msg, exc_info=True
)
update_status(
{
"status": "error",
"message": error_msg,
"error": str(e),
"task_id": task_id,
}
)
return {"success": False, "error": error_msg}
finally:
logging.info("[MCP OAuth] Closing event loop for task_id=%s", task_id)
loop.close()
except Exception as e:
error_msg = f"Failed to initialize OAuth flow: {str(e)}"
logging.error(
"[MCP OAuth] Exception during initialization: %s", error_msg, exc_info=True
)
update_status(
{
"status": "error",
"message": error_msg,
"error": str(e),
"task_id": task_id,
}
)
return {"success": False, "error": error_msg}
def mcp_oauth_status(self, task_id: str) -> Dict[str, Any]:
"""Check the status of an MCP OAuth flow."""
redis_client = get_redis_instance()
status_key = f"mcp_oauth_status:{task_id}"
status_data = redis_client.get(status_key)
if status_data:
return json.loads(status_data)
return {"status": "not_found", "message": "Status not found"}

View File

@@ -59,6 +59,8 @@ const endpoints = {
MANAGE_SOURCE_FILES: '/api/manage_source_files',
MCP_TEST_CONNECTION: '/api/mcp_server/test',
MCP_SAVE_SERVER: '/api/mcp_server/save',
MCP_OAUTH_STATUS: (task_id: string) =>
`/api/mcp_server/oauth_status/${task_id}`,
},
CONVERSATION: {
ANSWER: '/api/answer',

View File

@@ -1,6 +1,6 @@
import { getSessionToken } from '../../utils/providerUtils';
import apiClient from '../client';
import endpoints from '../endpoints';
import { getSessionToken } from '../../utils/providerUtils';
const userService = {
getConfig: (): Promise<any> => apiClient.get(endpoints.USER.CONFIG, null),
@@ -112,6 +112,8 @@ const userService = {
apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token),
saveMCPServer: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token),
getMCPOAuthStatus: (task_id: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.MCP_OAUTH_STATUS(task_id), token),
syncConnector: (
docId: string,
provider: string,

View File

@@ -210,7 +210,7 @@ export default function ConversationMessages({
)}
<div className="w-full max-w-[1300px] px-2 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12">
{headerContent && headerContent}
{headerContent}
{queries.length > 0 ? (
queries.map((query, index) => (

View File

@@ -194,17 +194,20 @@
"headerName": "Header Name",
"timeout": "Timeout (seconds)",
"testConnection": "Test Connection",
"testing": "Testing...",
"saving": "Saving...",
"testing": "Testing",
"saving": "Saving",
"save": "Save",
"cancel": "Cancel",
"noAuth": "No Authentication",
"oauthInProgress": "Waiting for OAuth completion...",
"oauthCompleted": "OAuth completed successfully",
"placeholders": {
"serverUrl": "https://api.example.com",
"apiKey": "Your secret API key",
"bearerToken": "Your secret token",
"username": "Your username",
"password": "Your password"
"password": "Your password",
"oauthScopes": "OAuth scopes (comma separated)"
},
"errors": {
"nameRequired": "Server name is required",
@@ -215,7 +218,9 @@
"usernameRequired": "Username is required",
"passwordRequired": "Password is required",
"testFailed": "Connection test failed",
"saveFailed": "Failed to save MCP server"
"saveFailed": "Failed to save MCP server",
"oauthFailed": "OAuth process failed or was cancelled",
"oauthTimeout": "OAuth process timed out, please try again"
}
}
}

View File

@@ -22,6 +22,7 @@ const authTypes = [
{ label: 'No Authentication', value: 'none' },
{ label: 'API Key', value: 'api_key' },
{ label: 'Bearer Token', value: 'bearer' },
{ label: 'OAuth', value: 'oauth' },
// { label: 'Basic Authentication', value: 'basic' },
];
@@ -45,6 +46,8 @@ export default function MCPServerModal({
username: '',
password: '',
timeout: server?.timeout || 30,
oauth_scopes: '',
oauth_task_id: '',
});
const [loading, setLoading] = useState(false);
@@ -52,8 +55,13 @@ export default function MCPServerModal({
const [testResult, setTestResult] = useState<{
success: boolean;
message: string;
status?: string;
authorization_url?: string;
} | null>(null);
const [errors, setErrors] = useState<{ [key: string]: string }>({});
const oauthPopupRef = useRef<Window | null>(null);
const [oauthCompleted, setOAuthCompleted] = useState(false);
const [saveActive, setSaveActive] = useState(false);
useOutsideAlerter(modalRef, () => {
if (modalState === 'ACTIVE') {
@@ -73,9 +81,12 @@ export default function MCPServerModal({
username: '',
password: '',
timeout: 30,
oauth_scopes: '',
oauth_task_id: '',
});
setErrors({});
setTestResult(null);
setSaveActive(false);
};
const validateForm = () => {
@@ -154,10 +165,81 @@ export default function MCPServerModal({
} else if (formData.auth_type === 'basic') {
config.username = formData.username.trim();
config.password = formData.password.trim();
} else if (formData.auth_type === 'oauth') {
config.oauth_scopes = formData.oauth_scopes
.split(',')
.map((s) => s.trim())
.filter(Boolean);
config.oauth_task_id = formData.oauth_task_id.trim();
}
return config;
};
const pollOAuthStatus = async (
taskId: string,
onComplete: (result: any) => void,
) => {
let attempts = 0;
const maxAttempts = 60;
let popupOpened = false;
const poll = async () => {
try {
const resp = await userService.getMCPOAuthStatus(taskId, token);
const data = await resp.json();
if (data.authorization_url && !popupOpened) {
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
oauthPopupRef.current.close();
}
oauthPopupRef.current = window.open(
data.authorization_url,
'oauthPopup',
'width=600,height=700',
);
popupOpened = true;
}
if (data.status === 'completed') {
setOAuthCompleted(true);
setSaveActive(true);
onComplete({
...data,
success: true,
message: t('settings.tools.mcp.oauthCompleted'),
});
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
oauthPopupRef.current.close();
}
} else if (data.status === 'error' || data.success === false) {
setSaveActive(false);
onComplete({
...data,
success: false,
message: t('settings.tools.mcp.errors.oauthFailed'),
});
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
oauthPopupRef.current.close();
}
} else {
if (++attempts < maxAttempts) setTimeout(poll, 1000);
else {
setSaveActive(false);
onComplete({
success: false,
message: t('settings.tools.mcp.errors.oauthTimeout'),
});
}
}
} catch {
if (++attempts < maxAttempts) setTimeout(poll, 1000);
else
onComplete({
success: false,
message: t('settings.tools.mcp.errors.oauthTimeout'),
});
}
};
poll();
};
const testConnection = async () => {
if (!validateForm()) return;
setTesting(true);
@@ -167,13 +249,37 @@ export default function MCPServerModal({
const response = await userService.testMCPConnection({ config }, token);
const result = await response.json();
setTestResult(result);
if (
formData.auth_type === 'oauth' &&
result.requires_oauth &&
result.task_id
) {
setTestResult({
success: true,
message: t('settings.tools.mcp.oauthInProgress'),
});
setOAuthCompleted(false);
setSaveActive(false);
pollOAuthStatus(result.task_id, (finalResult) => {
setTestResult(finalResult);
setFormData((prev) => ({
...prev,
oauth_task_id: result.task_id || '',
}));
setTesting(false);
});
} else {
setTestResult(result);
setSaveActive(result.success === true);
setTesting(false);
}
} catch (error) {
setTestResult({
success: false,
message: t('settings.tools.mcp.errors.testFailed'),
});
} finally {
setOAuthCompleted(false);
setSaveActive(false);
setTesting(false);
}
};
@@ -305,6 +411,28 @@ export default function MCPServerModal({
</div>
</div>
);
case 'oauth':
return (
<div className="mb-10">
<div className="mt-6">
<Input
name="oauth_scopes"
type="text"
className="rounded-md"
value={formData.oauth_scopes}
onChange={(e) =>
handleInputChange('oauth_scopes', e.target.value)
}
placeholder={
t('settings.tools.mcp.placeholders.oauthScopes') ||
'Scopes (comma separated)'
}
borderVariant="thin"
labelBgClassName="bg-white dark:bg-charleston-green-2"
/>
</div>
</div>
);
default:
return null;
}
@@ -331,7 +459,6 @@ export default function MCPServerModal({
<div className="space-y-6 py-6">
<div>
<Input
name="name"
type="text"
className="rounded-md"
value={formData.name}
@@ -410,7 +537,7 @@ export default function MCPServerModal({
{testResult && (
<div
className={`rounded-md p-5 ${
className={`rounded-2xl p-5 ${
testResult.success
? 'bg-green-50 text-green-700 dark:bg-green-900/40 dark:text-green-300'
: 'bg-red-50 text-red-700 dark:bg-red-900 dark:text-red-300'
@@ -458,7 +585,7 @@ export default function MCPServerModal({
</button>
<button
onClick={handleSave}
disabled={loading}
disabled={loading || !saveActive}
className="bg-purple-30 hover:bg-violets-are-blue w-full rounded-3xl px-6 py-2 text-sm font-medium text-white transition-all disabled:opacity-50 sm:w-auto"
>
{loading ? (