mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Merge branch 'tester' of https://github.com/manishmadan2882/docsgpt into tester
This commit is contained in:
@@ -264,7 +264,15 @@ class BaseAgent(ABC):
|
||||
query: str,
|
||||
retrieved_data: List[Dict],
|
||||
) -> List[Dict]:
|
||||
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
|
||||
docs_with_filenames = []
|
||||
for doc in retrieved_data:
|
||||
filename = doc.get("filename") or doc.get("title") or doc.get("source")
|
||||
if filename:
|
||||
chunk_header = str(filename)
|
||||
docs_with_filenames.append(f"{chunk_header}\n{doc['text']}")
|
||||
else:
|
||||
docs_with_filenames.append(doc["text"])
|
||||
docs_together = "\n\n".join(docs_with_filenames)
|
||||
p_chat_combine = system_prompt.replace("{summaries}", docs_together)
|
||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,7 @@
|
||||
import base64
|
||||
import datetime
|
||||
import json
|
||||
import uuid
|
||||
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
@@ -13,8 +15,6 @@ from flask import (
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
|
||||
|
||||
|
||||
from application.api.user.tasks import (
|
||||
ingest_connector_task,
|
||||
)
|
||||
@@ -234,8 +234,24 @@ class ConnectorAuth(Resource):
|
||||
if not ConnectorCreator.is_supported(provider):
|
||||
return make_response(jsonify({"success": False, "error": f"Unsupported provider: {provider}"}), 400)
|
||||
|
||||
import uuid
|
||||
state = str(uuid.uuid4())
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user_id = decoded_token.get('sub')
|
||||
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
result = sessions_collection.insert_one({
|
||||
"provider": provider,
|
||||
"user": user_id,
|
||||
"status": "pending",
|
||||
"created_at": now
|
||||
})
|
||||
state_dict = {
|
||||
"provider": provider,
|
||||
"object_id": str(result.inserted_id)
|
||||
}
|
||||
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
|
||||
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
authorization_url = auth.get_authorization_url(state=state)
|
||||
return make_response(jsonify({
|
||||
@@ -256,18 +272,24 @@ class ConnectorsCallback(Resource):
|
||||
try:
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from flask import request, redirect
|
||||
import uuid
|
||||
|
||||
provider = request.args.get('provider', 'google_drive')
|
||||
authorization_code = request.args.get('code')
|
||||
_ = request.args.get('state')
|
||||
state = request.args.get('state')
|
||||
error = request.args.get('error')
|
||||
|
||||
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
||||
provider = state_dict["provider"]
|
||||
state_object_id = state_dict["object_id"]
|
||||
|
||||
if error:
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider={provider}")
|
||||
if error == "access_denied":
|
||||
return redirect(f"/api/connectors/callback-status?status=cancelled&message=Authentication+was+cancelled.+You+can+try+again+if+you'd+like+to+connect+your+account.&provider={provider}")
|
||||
else:
|
||||
current_app.logger.warning(f"OAuth error in callback: {error}")
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
||||
|
||||
if not authorization_code:
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authorization+code+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider={provider}")
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
||||
|
||||
try:
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
@@ -275,7 +297,6 @@ class ConnectorsCallback(Resource):
|
||||
|
||||
session_token = str(uuid.uuid4())
|
||||
|
||||
|
||||
try:
|
||||
credentials = auth.create_credentials_from_token_info(token_info)
|
||||
service = auth.build_drive_service(credentials)
|
||||
@@ -292,26 +313,28 @@ class ConnectorsCallback(Resource):
|
||||
"expiry": token_info.get("expiry")
|
||||
}
|
||||
|
||||
user_id = request.decoded_token.get("sub") if getattr(request, "decoded_token", None) else None
|
||||
sessions_collection.insert_one({
|
||||
sessions_collection.find_one_and_update(
|
||||
{"_id": ObjectId(state_object_id), "provider": provider},
|
||||
{
|
||||
"$set": {
|
||||
"session_token": session_token,
|
||||
"user": user_id,
|
||||
"token_info": sanitized_token_info,
|
||||
"created_at": datetime.datetime.now(datetime.timezone.utc),
|
||||
"user_email": user_email,
|
||||
"provider": provider
|
||||
})
|
||||
"status": "authorized"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Redirect to success page with session token and user email
|
||||
return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}")
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+exchange+authorization+code+for+tokens:+{str(e)}&provider={provider}")
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error handling connector callback: {e}")
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+complete+connector+authentication:+{str(e)}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.")
|
||||
return redirect("/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.")
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/refresh")
|
||||
@@ -629,6 +652,7 @@ class ConnectorCallbackStatus(Resource):
|
||||
.container {{ max-width: 600px; margin: 0 auto; }}
|
||||
.success {{ color: #4CAF50; }}
|
||||
.error {{ color: #F44336; }}
|
||||
.cancelled {{ color: #FF9800; }}
|
||||
</style>
|
||||
<script>
|
||||
window.onload = function() {{
|
||||
@@ -643,6 +667,8 @@ class ConnectorCallbackStatus(Resource):
|
||||
user_email: userEmail
|
||||
}}, '*');
|
||||
|
||||
setTimeout(() => window.close(), 3000);
|
||||
}} else if (status === "cancelled" || status === "error") {{
|
||||
setTimeout(() => window.close(), 3000);
|
||||
}}
|
||||
}};
|
||||
@@ -655,7 +681,7 @@ class ConnectorCallbackStatus(Resource):
|
||||
<p>{message}</p>
|
||||
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
|
||||
</div>
|
||||
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else ''}</small></p>
|
||||
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
@@ -8,6 +8,7 @@ import uuid
|
||||
import zipfile
|
||||
from functools import wraps
|
||||
from typing import Optional, Tuple
|
||||
from urllib.parse import unquote
|
||||
|
||||
from bson.binary import Binary, UuidRepresentation
|
||||
from bson.dbref import DBRef
|
||||
@@ -25,7 +26,7 @@ from flask_restx import fields, inputs, Namespace, Resource
|
||||
from pymongo import ReturnDocument
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
|
||||
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.api import api
|
||||
@@ -37,6 +38,8 @@ from application.api.user.tasks import (
|
||||
process_agent_webhook,
|
||||
store_attachment,
|
||||
)
|
||||
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
@@ -494,7 +497,6 @@ class DeleteOldIndexes(Resource):
|
||||
)
|
||||
if not doc:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
|
||||
storage = StorageCreator.get_storage()
|
||||
|
||||
try:
|
||||
@@ -511,7 +513,6 @@ class DeleteOldIndexes(Resource):
|
||||
settings.VECTOR_STORE, source_id=str(doc["_id"])
|
||||
)
|
||||
vectorstore.delete_index()
|
||||
|
||||
if "file_path" in doc and doc["file_path"]:
|
||||
file_path = doc["file_path"]
|
||||
if storage.is_directory(file_path):
|
||||
@@ -520,7 +521,6 @@ class DeleteOldIndexes(Resource):
|
||||
storage.delete_file(f)
|
||||
else:
|
||||
storage.delete_file(file_path)
|
||||
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as err:
|
||||
@@ -528,7 +528,6 @@ class DeleteOldIndexes(Resource):
|
||||
f"Error deleting files and indexes: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
sources_collection.delete_one({"_id": ObjectId(source_id)})
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
@@ -600,7 +599,6 @@ class UploadFile(Resource):
|
||||
== temp_file_path
|
||||
):
|
||||
continue
|
||||
|
||||
rel_path = os.path.relpath(
|
||||
os.path.join(root, extracted_file), temp_dir
|
||||
)
|
||||
@@ -625,7 +623,6 @@ class UploadFile(Resource):
|
||||
file_path = f"{base_path}/{safe_file}"
|
||||
with open(temp_file_path, "rb") as f:
|
||||
storage.save_file(f, file_path)
|
||||
|
||||
task = ingest.delay(
|
||||
settings.UPLOAD_FOLDER,
|
||||
[
|
||||
@@ -697,7 +694,6 @@ class ManageSourceFiles(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Unauthorized"}), 401
|
||||
)
|
||||
|
||||
user = decoded_token.get("sub")
|
||||
source_id = request.form.get("source_id")
|
||||
operation = request.form.get("operation")
|
||||
@@ -747,7 +743,6 @@ class ManageSourceFiles(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Database error"}), 500
|
||||
)
|
||||
|
||||
try:
|
||||
storage = StorageCreator.get_storage()
|
||||
source_file_path = source.get("file_path", "")
|
||||
@@ -804,7 +799,6 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
elif operation == "remove":
|
||||
file_paths_str = request.form.get("file_paths")
|
||||
if not file_paths_str:
|
||||
@@ -858,7 +852,6 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
elif operation == "remove_directory":
|
||||
directory_path = request.form.get("directory_path")
|
||||
if not directory_path:
|
||||
@@ -884,7 +877,6 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
full_directory_path = (
|
||||
f"{source_file_path}/{directory_path}"
|
||||
if directory_path
|
||||
@@ -943,7 +935,6 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
except Exception as err:
|
||||
error_context = f"operation={operation}, user={user}, source_id={source_id}"
|
||||
if operation == "remove_directory":
|
||||
@@ -955,7 +946,6 @@ class ManageSourceFiles(Resource):
|
||||
elif operation == "add":
|
||||
parent_dir = request.form.get("parent_dir", "")
|
||||
error_context += f", parent_dir={parent_dir}"
|
||||
|
||||
current_app.logger.error(
|
||||
f"Error managing source files: {err} ({error_context})", exc_info=True
|
||||
)
|
||||
@@ -1632,7 +1622,6 @@ class CreateAgent(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
# Validate that it has either a 'schema' property or is itself a schema
|
||||
|
||||
if "schema" not in json_schema and "type" not in json_schema:
|
||||
@@ -3476,7 +3465,6 @@ class AvailableTools(Resource):
|
||||
"displayName": name,
|
||||
"description": description,
|
||||
"configRequirements": tool_instance.get_config_requirements(),
|
||||
"actions": tool_instance.get_actions_metadata(),
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
@@ -3527,11 +3515,6 @@ class CreateTool(Resource):
|
||||
"customName": fields.String(
|
||||
required=False, description="Custom name for the tool"
|
||||
),
|
||||
"actions": fields.List(
|
||||
fields.Raw,
|
||||
required=True,
|
||||
description="Actions the tool can perform",
|
||||
),
|
||||
"status": fields.Boolean(
|
||||
required=True, description="Status of the tool"
|
||||
),
|
||||
@@ -3549,15 +3532,21 @@ class CreateTool(Resource):
|
||||
"name",
|
||||
"displayName",
|
||||
"description",
|
||||
"actions",
|
||||
"config",
|
||||
"status",
|
||||
]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
tool_instance = tool_manager.tools.get(data["name"])
|
||||
if not tool_instance:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Tool not found"}), 404
|
||||
)
|
||||
actions_metadata = tool_instance.get_actions_metadata()
|
||||
transformed_actions = []
|
||||
for action in data["actions"]:
|
||||
for action in actions_metadata:
|
||||
action["active"] = True
|
||||
if "parameters" in action:
|
||||
if "properties" in action["parameters"]:
|
||||
@@ -3567,6 +3556,11 @@ class CreateTool(Resource):
|
||||
param_details["filled_by_llm"] = True
|
||||
param_details["value"] = ""
|
||||
transformed_actions.append(action)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting tool actions: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
try:
|
||||
new_tool = {
|
||||
"user": user,
|
||||
@@ -3907,7 +3901,6 @@ class GetChunks(Resource):
|
||||
if not (text_match or title_match):
|
||||
continue
|
||||
filtered_chunks.append(chunk)
|
||||
|
||||
chunks = filtered_chunks
|
||||
|
||||
total_chunks = len(chunks)
|
||||
@@ -4098,7 +4091,6 @@ class UpdateChunk(Resource):
|
||||
current_app.logger.warning(
|
||||
f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created"
|
||||
)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -4226,23 +4218,19 @@ class DirectoryStructure(Resource):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
|
||||
user = decoded_token.get("sub")
|
||||
doc_id = request.args.get("id")
|
||||
|
||||
if not doc_id:
|
||||
return make_response(jsonify({"error": "Document ID is required"}), 400)
|
||||
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid document ID"}), 400)
|
||||
|
||||
try:
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
|
||||
directory_structure = doc.get("directory_structure", {})
|
||||
base_path = doc.get("file_path", "")
|
||||
|
||||
@@ -4315,11 +4303,10 @@ class TestMCPServerConfig(Resource):
|
||||
auth_credentials["username"] = config["username"]
|
||||
if "password" in config:
|
||||
auth_credentials["password"] = config["password"]
|
||||
|
||||
test_config = config.copy()
|
||||
test_config["auth_credentials"] = auth_credentials
|
||||
|
||||
mcp_tool = MCPTool(test_config, user)
|
||||
mcp_tool = MCPTool(config=test_config, user_id=user)
|
||||
result = mcp_tool.test_connection()
|
||||
|
||||
return make_response(jsonify(result), 200)
|
||||
@@ -4387,22 +4374,45 @@ class MCPServerSave(Resource):
|
||||
mcp_config = config.copy()
|
||||
mcp_config["auth_credentials"] = auth_credentials
|
||||
|
||||
if auth_type == "none" or auth_credentials:
|
||||
mcp_tool = MCPTool(mcp_config, user)
|
||||
if auth_type == "oauth":
|
||||
if not config.get("oauth_task_id"):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Connection not authorized. Please complete the OAuth authorization first.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
redis_client = get_redis_instance()
|
||||
manager = MCPOAuthManager(redis_client)
|
||||
result = manager.get_oauth_status(config["oauth_task_id"])
|
||||
if not result.get("status") == "completed":
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "OAuth failed or not completed. Please try authorizing again.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
actions_metadata = result.get("tools", [])
|
||||
elif auth_type == "none" or auth_credentials:
|
||||
mcp_tool = MCPTool(config=mcp_config, user_id=user)
|
||||
mcp_tool.discover_tools()
|
||||
actions_metadata = mcp_tool.get_actions_metadata()
|
||||
else:
|
||||
raise Exception(
|
||||
"No valid credentials provided for the selected authentication type"
|
||||
)
|
||||
|
||||
storage_config = config.copy()
|
||||
if auth_credentials:
|
||||
encrypted_credentials_string = encrypt_credentials(
|
||||
auth_credentials, user
|
||||
)
|
||||
storage_config["encrypted_credentials"] = encrypted_credentials_string
|
||||
|
||||
for field in [
|
||||
"api_key",
|
||||
"bearer_token",
|
||||
@@ -4473,3 +4483,96 @@ class MCPServerSave(Resource):
|
||||
),
|
||||
500,
|
||||
)
|
||||
|
||||
|
||||
@user_ns.route("/api/mcp_server/callback")
class MCPOAuthCallback(Resource):
    """OAuth redirect endpoint for MCP server authorization.

    Reads ``code``, ``state`` and ``error`` from the query string, hands the
    code to ``MCPOAuthManager.handle_oauth_callback`` and always answers with
    a redirect to ``/api/connectors/callback-status`` carrying a ``status``
    and a human-readable ``message``.
    """

    @api.expect(
        api.model(
            "MCPServerCallbackModel",
            {
                "code": fields.String(required=True, description="Authorization code"),
                "state": fields.String(required=True, description="State parameter"),
                "error": fields.String(
                    required=False, description="Error message (if any)"
                ),
            },
        )
    )
    @api.doc(
        description="Handle OAuth callback by providing the authorization code and state"
    )
    def get(self):
        """Process the OAuth callback and redirect to the status page.

        Returns:
            A redirect response to the callback-status page with
            ``status=success`` when the manager reports success and
            ``status=error`` in every other case.
        """
        # Local import keeps this block self-contained; the module itself only
        # imports ``unquote`` from urllib.parse.
        from urllib.parse import quote_plus

        code = request.args.get("code")
        state = request.args.get("state")
        error = request.args.get("error")

        if error:
            # ``error`` comes from the query string, so it must be URL-encoded
            # before being embedded in another URL; otherwise characters such
            # as '&' or '#' would truncate the message or inject parameters.
            safe_error = quote_plus(str(error))
            return redirect(
                f"/api/connectors/callback-status?status=error&message=OAuth+error:+{safe_error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider=mcp_tool"
            )
        if not code or not state:
            return redirect(
                "/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool"
            )
        try:
            redis_client = get_redis_instance()
            if not redis_client:
                return redirect(
                    "/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool"
                )
            # NOTE(review): request.args values are normally already
            # percent-decoded by the framework; this extra unquote is kept
            # from the original -- confirm it is still required.
            code = unquote(code)
            manager = MCPOAuthManager(redis_client)
            # ``error`` is always falsy here because of the early return above.
            success = manager.handle_oauth_callback(state, code, error)
            if success:
                return redirect(
                    "/api/connectors/callback-status?status=success&message=Authorization+code+received+successfully.+You+can+close+this+window.&provider=mcp_tool"
                )
            return redirect(
                "/api/connectors/callback-status?status=error&message=OAuth+callback+failed.&provider=mcp_tool"
            )
        except Exception as e:
            # Full details go to the server log only; the raw exception text is
            # no longer placed in the redirect URL, where it was both unencoded
            # and exposed to the browser.
            current_app.logger.error(
                f"Error handling MCP OAuth callback: {str(e)}", exc_info=True
            )
            return redirect(
                "/api/connectors/callback-status?status=error&message=Internal+server+error.&provider=mcp_tool"
            )
|
||||
|
||||
|
||||
@user_ns.route("/api/mcp_server/oauth_status/<string:task_id>")
class MCPOAuthStatus(Resource):
    """Polling endpoint exposing the stored status of an MCP OAuth flow."""

    def get(self, task_id):
        """
        Get current status of OAuth flow.
        Frontend should poll this endpoint periodically.

        Args:
            task_id: Identifier used to build the ``mcp_oauth_status:<task_id>``
                Redis key.

        Returns:
            A JSON response: the stored status merged with ``success`` and
            ``task_id`` when the key exists, a 404 payload when it does not,
            and a 500 payload when Redis is unavailable or an error occurs.
        """
        try:
            redis_client = get_redis_instance()
            if not redis_client:
                # Without this guard a missing client surfaced as an opaque
                # "'NoneType' object has no attribute 'get'" error.
                return make_response(
                    jsonify(
                        {
                            "success": False,
                            "error": "Redis not available",
                            "task_id": task_id,
                        }
                    ),
                    500,
                )
            status_key = f"mcp_oauth_status:{task_id}"
            status_data = redis_client.get(status_key)

            if status_data:
                status = json.loads(status_data)
                return make_response(
                    jsonify({"success": True, "task_id": task_id, **status})
                )
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "error": "Task not found or expired",
                        "task_id": task_id,
                    }
                ),
                404,
            )
        except Exception as e:
            # exc_info keeps the traceback in the log for diagnosis.
            current_app.logger.error(
                f"Error getting OAuth status for task {task_id}: {str(e)}",
                exc_info=True,
            )
            return make_response(
                jsonify({"success": False, "error": str(e), "task_id": task_id}), 500
            )
|
||||
|
||||
@@ -5,6 +5,8 @@ from application.worker import (
|
||||
agent_webhook_worker,
|
||||
attachment_worker,
|
||||
ingest_worker,
|
||||
mcp_oauth,
|
||||
mcp_oauth_status,
|
||||
remote_worker,
|
||||
sync_worker,
|
||||
)
|
||||
@@ -25,6 +27,7 @@ def ingest_remote(self, source_data, job_name, user, loader):
|
||||
@celery.task(bind=True)
def reingest_source_task(self, source_id, user):
    """Celery task that forwards to ``reingest_source_worker``.

    The task is bound (``bind=True``), so ``self`` is passed through to the
    worker together with ``source_id`` and ``user``; the worker's result is
    returned unchanged.
    """
    # Imported at call time, as in the original implementation.
    from application.worker import reingest_source_worker

    return reingest_source_worker(self, source_id, user)
|
||||
|
||||
@@ -60,9 +63,10 @@ def ingest_connector_task(
|
||||
retriever="classic",
|
||||
operation_mode="upload",
|
||||
doc_id=None,
|
||||
sync_frequency="never"
|
||||
sync_frequency="never",
|
||||
):
|
||||
from application.worker import ingest_connector
|
||||
|
||||
resp = ingest_connector(
|
||||
self,
|
||||
job_name,
|
||||
@@ -75,7 +79,7 @@ def ingest_connector_task(
|
||||
retriever=retriever,
|
||||
operation_mode=operation_mode,
|
||||
doc_id=doc_id,
|
||||
sync_frequency=sync_frequency
|
||||
sync_frequency=sync_frequency,
|
||||
)
|
||||
return resp
|
||||
|
||||
@@ -94,3 +98,15 @@ def setup_periodic_tasks(sender, **kwargs):
|
||||
timedelta(days=30),
|
||||
schedule_syncs.s("monthly"),
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
def mcp_oauth_task(self, config, user):
    """Celery task that forwards to ``mcp_oauth``.

    Passes the bound task instance, ``config`` and ``user`` straight to the
    worker function and returns whatever it returns.
    """
    return mcp_oauth(self, config, user)
|
||||
|
||||
|
||||
@celery.task(bind=True)
def mcp_oauth_status_task(self, task_id):
    """Celery task that forwards to ``mcp_oauth_status``.

    Passes the bound task instance and ``task_id`` to the worker function
    and returns its result unchanged.
    """
    return mcp_oauth_status(self, task_id)
|
||||
|
||||
@@ -43,8 +43,7 @@ class Settings(BaseSettings):
|
||||
# Google Drive integration
|
||||
GOOGLE_CLIENT_ID: Optional[str] = None # Replace with your actual Google OAuth client ID
|
||||
GOOGLE_CLIENT_SECRET: Optional[str] = None# Replace with your actual Google OAuth client secret
|
||||
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = "http://127.0.0.1:7091/api/connectors/callback"
|
||||
##append ?provider={provider_name} in your Provider console like http://127.0.0.1:7091/api/connectors/callback?provider=google_drive
|
||||
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = "http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
|
||||
|
||||
|
||||
# LLM Cache
|
||||
|
||||
@@ -23,7 +23,7 @@ class GoogleDriveAuth(BaseConnectorAuth):
|
||||
def __init__(self):
|
||||
self.client_id = settings.GOOGLE_CLIENT_ID
|
||||
self.client_secret = settings.GOOGLE_CLIENT_SECRET
|
||||
self.redirect_uri = f"{settings.CONNECTOR_REDIRECT_BASE_URI}?provider=google_drive"
|
||||
self.redirect_uri = f"{settings.CONNECTOR_REDIRECT_BASE_URI}"
|
||||
|
||||
if not self.client_id or not self.client_secret:
|
||||
raise ValueError("Google OAuth credentials not configured. Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET in settings.")
|
||||
|
||||
@@ -12,6 +12,7 @@ esprima==4.0.1
|
||||
esutils==1.0.1
|
||||
Flask==3.1.1
|
||||
faiss-cpu==1.9.0.post1
|
||||
fastmcp==2.11.0
|
||||
flask-restx==1.3.0
|
||||
google-genai==1.3.0
|
||||
google-api-python-client==2.179.0
|
||||
@@ -56,13 +57,13 @@ prompt-toolkit==3.0.51
|
||||
protobuf==5.29.3
|
||||
psycopg2-binary==2.9.10
|
||||
py==1.11.0
|
||||
pydantic==2.10.6
|
||||
pydantic-core==2.27.2
|
||||
pydantic-settings==2.7.1
|
||||
pydantic
|
||||
pydantic-core
|
||||
pydantic-settings
|
||||
pymongo==4.11.3
|
||||
pypdf==5.5.0
|
||||
python-dateutil==2.9.0.post0
|
||||
python-dotenv==1.0.1
|
||||
python-dotenv
|
||||
python-jose==3.4.0
|
||||
python-pptx==1.0.2
|
||||
redis==5.2.1
|
||||
@@ -82,7 +83,7 @@ tzdata==2024.2
|
||||
urllib3==2.3.0
|
||||
vine==5.1.0
|
||||
wcwidth==0.2.13
|
||||
werkzeug==3.1.3
|
||||
werkzeug>=3.1.0,<3.1.2
|
||||
yarl==1.20.0
|
||||
markdownify==1.1.0
|
||||
tldextract==5.1.3
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
@@ -141,15 +142,28 @@ class ClassicRAG(BaseRetriever):
|
||||
title = metadata.get(
|
||||
"title", metadata.get("post_title", page_content)
|
||||
)
|
||||
if isinstance(title, str):
|
||||
if not isinstance(title, str):
|
||||
title = str(title)
|
||||
title = title.split("/")[-1]
|
||||
|
||||
filename = (
|
||||
metadata.get("filename")
|
||||
or metadata.get("file_name")
|
||||
or metadata.get("source")
|
||||
)
|
||||
if isinstance(filename, str):
|
||||
filename = os.path.basename(filename) or filename
|
||||
else:
|
||||
title = str(title).split("/")[-1]
|
||||
filename = title
|
||||
if not filename:
|
||||
filename = title
|
||||
source_path = metadata.get("source") or vectorstore_id
|
||||
all_docs.append(
|
||||
{
|
||||
"title": title,
|
||||
"text": page_content,
|
||||
"source": metadata.get("source") or vectorstore_id,
|
||||
"source": source_path,
|
||||
"filename": filename,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -19,6 +19,7 @@ from bson.objectid import ObjectId
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.api.answer.services.stream_processor import get_prompt
|
||||
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.parser.chunking import Chunker
|
||||
@@ -214,8 +215,7 @@ def run_agent_logic(agent_config, input_data):
|
||||
|
||||
|
||||
def ingest_worker(
|
||||
self, directory, formats, job_name, file_path, filename, user,
|
||||
retriever="classic"
|
||||
self, directory, formats, job_name, file_path, filename, user, retriever="classic"
|
||||
):
|
||||
"""
|
||||
Ingest and process documents.
|
||||
@@ -270,7 +270,9 @@ def ingest_worker(
|
||||
with open(local_file_path, "wb") as f:
|
||||
f.write(file_data.read())
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading file {storage_file_path}: {e}")
|
||||
logging.error(
|
||||
f"Error downloading file {storage_file_path}: {e}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# Handle single file case
|
||||
@@ -285,7 +287,10 @@ def ingest_worker(
|
||||
if temp_filename.endswith(".zip"):
|
||||
logging.info(f"Extracting zip file: {temp_filename}")
|
||||
extract_zip_recursive(
|
||||
temp_file_path, temp_dir, current_depth=0, max_depth=RECURSION_DEPTH
|
||||
temp_file_path,
|
||||
temp_dir,
|
||||
current_depth=0,
|
||||
max_depth=RECURSION_DEPTH,
|
||||
)
|
||||
|
||||
self.update_state(state="PROGRESS", meta={"current": 1})
|
||||
@@ -301,7 +306,7 @@ def ingest_worker(
|
||||
)
|
||||
raw_docs = reader.load_data()
|
||||
|
||||
directory_structure = getattr(reader, 'directory_structure', {})
|
||||
directory_structure = getattr(reader, "directory_structure", {})
|
||||
logging.info(f"Directory structure from reader: {directory_structure}")
|
||||
|
||||
chunker = Chunker(
|
||||
@@ -371,7 +376,10 @@ def reingest_source_worker(self, source_id, user):
|
||||
try:
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
self.update_state(state="PROGRESS", meta={"current": 10, "status": "Initializing re-ingestion scan"})
|
||||
self.update_state(
|
||||
state="PROGRESS",
|
||||
meta={"current": 10, "status": "Initializing re-ingestion scan"},
|
||||
)
|
||||
|
||||
source = sources_collection.find_one({"_id": ObjectId(source_id), "user": user})
|
||||
if not source:
|
||||
@@ -380,7 +388,9 @@ def reingest_source_worker(self, source_id, user):
|
||||
storage = StorageCreator.get_storage()
|
||||
source_file_path = source.get("file_path", "")
|
||||
|
||||
self.update_state(state="PROGRESS", meta={"current": 20, "status": "Scanning current files"})
|
||||
self.update_state(
|
||||
state="PROGRESS", meta={"current": 20, "status": "Scanning current files"}
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Download all files from storage to temp directory, preserving directory structure
|
||||
@@ -391,7 +401,6 @@ def reingest_source_worker(self, source_id, user):
|
||||
if storage.is_directory(storage_file_path):
|
||||
continue
|
||||
|
||||
|
||||
rel_path = os.path.relpath(storage_file_path, source_file_path)
|
||||
local_file_path = os.path.join(temp_dir, rel_path)
|
||||
|
||||
@@ -403,23 +412,39 @@ def reingest_source_worker(self, source_id, user):
|
||||
with open(local_file_path, "wb") as f:
|
||||
f.write(file_data.read())
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading file {storage_file_path}: {e}")
|
||||
logging.error(
|
||||
f"Error downloading file {storage_file_path}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
reader = SimpleDirectoryReader(
|
||||
input_dir=temp_dir,
|
||||
recursive=True,
|
||||
required_exts=[
|
||||
".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub",
|
||||
".html", ".mdx", ".json", ".xlsx", ".pptx", ".png",
|
||||
".jpg", ".jpeg",
|
||||
".rst",
|
||||
".md",
|
||||
".pdf",
|
||||
".txt",
|
||||
".docx",
|
||||
".csv",
|
||||
".epub",
|
||||
".html",
|
||||
".mdx",
|
||||
".json",
|
||||
".xlsx",
|
||||
".pptx",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
],
|
||||
exclude_hidden=True,
|
||||
file_metadata=metadata_from_filename,
|
||||
)
|
||||
reader.load_data()
|
||||
directory_structure = reader.directory_structure
|
||||
logging.info(f"Directory structure built with token counts: {directory_structure}")
|
||||
logging.info(
|
||||
f"Directory structure built with token counts: {directory_structure}"
|
||||
)
|
||||
|
||||
try:
|
||||
old_directory_structure = source.get("directory_structure") or {}
|
||||
@@ -433,11 +458,17 @@ def reingest_source_worker(self, source_id, user):
|
||||
files = set()
|
||||
if isinstance(struct, dict):
|
||||
for name, meta in struct.items():
|
||||
current_path = os.path.join(prefix, name) if prefix else name
|
||||
if isinstance(meta, dict) and ("type" in meta and "size_bytes" in meta):
|
||||
current_path = (
|
||||
os.path.join(prefix, name) if prefix else name
|
||||
)
|
||||
if isinstance(meta, dict) and (
|
||||
"type" in meta and "size_bytes" in meta
|
||||
):
|
||||
files.add(current_path)
|
||||
elif isinstance(meta, dict):
|
||||
files |= _flatten_directory_structure(meta, current_path)
|
||||
files |= _flatten_directory_structure(
|
||||
meta, current_path
|
||||
)
|
||||
return files
|
||||
|
||||
old_files = _flatten_directory_structure(old_directory_structure)
|
||||
@@ -457,7 +488,9 @@ def reingest_source_worker(self, source_id, user):
|
||||
logging.info("No files removed since last ingest.")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error comparing directory structures: {e}", exc_info=True)
|
||||
logging.error(
|
||||
f"Error comparing directory structures: {e}", exc_info=True
|
||||
)
|
||||
added_files = []
|
||||
removed_files = []
|
||||
try:
|
||||
@@ -477,14 +510,21 @@ def reingest_source_worker(self, source_id, user):
|
||||
settings.EMBEDDINGS_KEY,
|
||||
)
|
||||
|
||||
self.update_state(state="PROGRESS", meta={"current": 40, "status": "Processing file changes"})
|
||||
self.update_state(
|
||||
state="PROGRESS",
|
||||
meta={"current": 40, "status": "Processing file changes"},
|
||||
)
|
||||
|
||||
# 1) Delete chunks from removed files
|
||||
deleted = 0
|
||||
if removed_files:
|
||||
try:
|
||||
for ch in vector_store.get_chunks() or []:
|
||||
metadata = ch.get("metadata", {}) if isinstance(ch, dict) else getattr(ch, "metadata", {})
|
||||
metadata = (
|
||||
ch.get("metadata", {})
|
||||
if isinstance(ch, dict)
|
||||
else getattr(ch, "metadata", {})
|
||||
)
|
||||
raw_source = metadata.get("source")
|
||||
|
||||
source_file = str(raw_source) if raw_source else ""
|
||||
@@ -496,10 +536,17 @@ def reingest_source_worker(self, source_id, user):
|
||||
vector_store.delete_chunk(cid)
|
||||
deleted += 1
|
||||
except Exception as de:
|
||||
logging.error(f"Failed deleting chunk {cid}: {de}")
|
||||
logging.info(f"Deleted {deleted} chunks from {len(removed_files)} removed files")
|
||||
logging.error(
|
||||
f"Failed deleting chunk {cid}: {de}"
|
||||
)
|
||||
logging.info(
|
||||
f"Deleted {deleted} chunks from {len(removed_files)} removed files"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during deletion of removed file chunks: {e}", exc_info=True)
|
||||
logging.error(
|
||||
f"Error during deletion of removed file chunks: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# 2) Add chunks from new files
|
||||
added = 0
|
||||
@@ -528,39 +575,62 @@ def reingest_source_worker(self, source_id, user):
|
||||
)
|
||||
chunked_new = chunker_new.chunk(documents=raw_docs_new)
|
||||
|
||||
for file_path, token_count in reader_new.file_token_counts.items():
|
||||
for (
|
||||
file_path,
|
||||
token_count,
|
||||
) in reader_new.file_token_counts.items():
|
||||
try:
|
||||
rel_path = os.path.relpath(file_path, start=temp_dir)
|
||||
rel_path = os.path.relpath(
|
||||
file_path, start=temp_dir
|
||||
)
|
||||
path_parts = rel_path.split(os.sep)
|
||||
current_dir = directory_structure
|
||||
|
||||
for part in path_parts[:-1]:
|
||||
if part in current_dir and isinstance(current_dir[part], dict):
|
||||
if part in current_dir and isinstance(
|
||||
current_dir[part], dict
|
||||
):
|
||||
current_dir = current_dir[part]
|
||||
else:
|
||||
break
|
||||
|
||||
filename = path_parts[-1]
|
||||
if filename in current_dir and isinstance(current_dir[filename], dict):
|
||||
current_dir[filename]["token_count"] = token_count
|
||||
logging.info(f"Updated token count for {rel_path}: {token_count}")
|
||||
if filename in current_dir and isinstance(
|
||||
current_dir[filename], dict
|
||||
):
|
||||
current_dir[filename][
|
||||
"token_count"
|
||||
] = token_count
|
||||
logging.info(
|
||||
f"Updated token count for {rel_path}: {token_count}"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"Could not update token count for {file_path}: {e}")
|
||||
logging.warning(
|
||||
f"Could not update token count for {file_path}: {e}"
|
||||
)
|
||||
|
||||
for d in chunked_new:
|
||||
meta = dict(d.extra_info or {})
|
||||
try:
|
||||
raw_src = meta.get("source")
|
||||
if isinstance(raw_src, str) and os.path.isabs(raw_src):
|
||||
meta["source"] = os.path.relpath(raw_src, start=temp_dir)
|
||||
if isinstance(raw_src, str) and os.path.isabs(
|
||||
raw_src
|
||||
):
|
||||
meta["source"] = os.path.relpath(
|
||||
raw_src, start=temp_dir
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
vector_store.add_chunk(d.text, metadata=meta)
|
||||
added += 1
|
||||
logging.info(f"Added {added} chunks from {len(added_files)} new files")
|
||||
logging.info(
|
||||
f"Added {added} chunks from {len(added_files)} new files"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during ingestion of new files: {e}", exc_info=True)
|
||||
logging.error(
|
||||
f"Error during ingestion of new files: {e}", exc_info=True
|
||||
)
|
||||
|
||||
# 3) Update source directory structure timestamp
|
||||
try:
|
||||
@@ -572,14 +642,19 @@ def reingest_source_worker(self, source_id, user):
|
||||
"$set": {
|
||||
"directory_structure": directory_structure,
|
||||
"date": datetime.datetime.now(),
|
||||
"tokens": total_tokens
|
||||
"tokens": total_tokens,
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error updating directory_structure in DB: {e}", exc_info=True)
|
||||
logging.error(
|
||||
f"Error updating directory_structure in DB: {e}", exc_info=True
|
||||
)
|
||||
|
||||
self.update_state(state="PROGRESS", meta={"current": 100, "status": "Re-ingestion completed"})
|
||||
self.update_state(
|
||||
state="PROGRESS",
|
||||
meta={"current": 100, "status": "Re-ingestion completed"},
|
||||
)
|
||||
|
||||
return {
|
||||
"source_id": source_id,
|
||||
@@ -591,15 +666,16 @@ def reingest_source_worker(self, source_id, user):
|
||||
"chunks_deleted": deleted,
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(f"Error while processing file changes: {e}", exc_info=True)
|
||||
logging.error(
|
||||
f"Error while processing file changes: {e}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in reingest_source_worker: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def remote_worker(
|
||||
self,
|
||||
source_data,
|
||||
@@ -651,7 +727,7 @@ def remote_worker(
|
||||
"id": str(id),
|
||||
"type": loader,
|
||||
"remote_data": source_data,
|
||||
"sync_frequency": sync_frequency
|
||||
"sync_frequency": sync_frequency,
|
||||
}
|
||||
|
||||
if operation_mode == "sync":
|
||||
@@ -752,7 +828,6 @@ def attachment_worker(self, file_info, user):
|
||||
.text,
|
||||
)
|
||||
|
||||
|
||||
token_count = num_tokens_from_string(content)
|
||||
if token_count > 100000:
|
||||
content = content[:250000]
|
||||
@@ -872,37 +947,49 @@ def ingest_connector(
|
||||
doc_id: Document ID for sync operations (required when operation_mode="sync")
|
||||
sync_frequency: How often to sync ("never", "daily", "weekly", "monthly")
|
||||
"""
|
||||
logging.info(f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}")
|
||||
logging.info(
|
||||
f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}"
|
||||
)
|
||||
self.update_state(state="PROGRESS", meta={"current": 1})
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
try:
|
||||
# Step 1: Initialize the appropriate loader
|
||||
self.update_state(state="PROGRESS", meta={"current": 10, "status": "Initializing connector"})
|
||||
self.update_state(
|
||||
state="PROGRESS",
|
||||
meta={"current": 10, "status": "Initializing connector"},
|
||||
)
|
||||
|
||||
if not session_token:
|
||||
raise ValueError(f"{source_type} connector requires session_token")
|
||||
|
||||
if not ConnectorCreator.is_supported(source_type):
|
||||
raise ValueError(f"Unsupported connector type: {source_type}. Supported types: {ConnectorCreator.get_supported_connectors()}")
|
||||
raise ValueError(
|
||||
f"Unsupported connector type: {source_type}. Supported types: {ConnectorCreator.get_supported_connectors()}"
|
||||
)
|
||||
|
||||
remote_loader = ConnectorCreator.create_connector(source_type, session_token)
|
||||
remote_loader = ConnectorCreator.create_connector(
|
||||
source_type, session_token
|
||||
)
|
||||
|
||||
# Create a clean config for storage
|
||||
api_source_config = {
|
||||
"file_ids": file_ids or [],
|
||||
"folder_ids": folder_ids or [],
|
||||
"recursive": recursive
|
||||
"recursive": recursive,
|
||||
}
|
||||
|
||||
# Step 2: Download files to temp directory
|
||||
self.update_state(state="PROGRESS", meta={"current": 20, "status": "Downloading files"})
|
||||
self.update_state(
|
||||
state="PROGRESS", meta={"current": 20, "status": "Downloading files"}
|
||||
)
|
||||
download_info = remote_loader.download_to_directory(
|
||||
temp_dir,
|
||||
api_source_config
|
||||
temp_dir, api_source_config
|
||||
)
|
||||
|
||||
if download_info.get("empty_result", False) or not download_info.get("files_downloaded", 0):
|
||||
if download_info.get("empty_result", False) or not download_info.get(
|
||||
"files_downloaded", 0
|
||||
):
|
||||
logging.warning(f"No files were downloaded from {source_type}")
|
||||
# Create empty result directly instead of calling a separate method
|
||||
return {
|
||||
@@ -915,25 +1002,39 @@ def ingest_connector(
|
||||
}
|
||||
|
||||
# Step 3: Use SimpleDirectoryReader to process downloaded files
|
||||
self.update_state(state="PROGRESS", meta={"current": 40, "status": "Processing files"})
|
||||
self.update_state(
|
||||
state="PROGRESS", meta={"current": 40, "status": "Processing files"}
|
||||
)
|
||||
reader = SimpleDirectoryReader(
|
||||
input_dir=temp_dir,
|
||||
recursive=True,
|
||||
required_exts=[
|
||||
".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub",
|
||||
".html", ".mdx", ".json", ".xlsx", ".pptx", ".png",
|
||||
".jpg", ".jpeg",
|
||||
".rst",
|
||||
".md",
|
||||
".pdf",
|
||||
".txt",
|
||||
".docx",
|
||||
".csv",
|
||||
".epub",
|
||||
".html",
|
||||
".mdx",
|
||||
".json",
|
||||
".xlsx",
|
||||
".pptx",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
],
|
||||
exclude_hidden=True,
|
||||
file_metadata=metadata_from_filename,
|
||||
)
|
||||
raw_docs = reader.load_data()
|
||||
directory_structure = getattr(reader, 'directory_structure', {})
|
||||
|
||||
|
||||
directory_structure = getattr(reader, "directory_structure", {})
|
||||
|
||||
# Step 4: Process documents (chunking, embedding, etc.)
|
||||
self.update_state(state="PROGRESS", meta={"current": 60, "status": "Processing documents"})
|
||||
self.update_state(
|
||||
state="PROGRESS", meta={"current": 60, "status": "Processing documents"}
|
||||
)
|
||||
|
||||
chunker = Chunker(
|
||||
chunking_strategy="classic_chunk",
|
||||
@@ -945,11 +1046,13 @@ def ingest_connector(
|
||||
|
||||
# Preserve source information in document metadata
|
||||
for doc in raw_docs:
|
||||
if hasattr(doc, 'extra_info') and doc.extra_info:
|
||||
source = doc.extra_info.get('source')
|
||||
if hasattr(doc, "extra_info") and doc.extra_info:
|
||||
source = doc.extra_info.get("source")
|
||||
if source and os.path.isabs(source):
|
||||
# Convert absolute path to relative path
|
||||
doc.extra_info['source'] = os.path.relpath(source, start=temp_dir)
|
||||
doc.extra_info["source"] = os.path.relpath(
|
||||
source, start=temp_dir
|
||||
)
|
||||
|
||||
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
|
||||
|
||||
@@ -957,7 +1060,9 @@ def ingest_connector(
|
||||
id = ObjectId()
|
||||
elif operation_mode == "sync":
|
||||
if not doc_id or not ObjectId.is_valid(doc_id):
|
||||
logging.error("Invalid doc_id provided for sync operation: %s", doc_id)
|
||||
logging.error(
|
||||
"Invalid doc_id provided for sync operation: %s", doc_id
|
||||
)
|
||||
raise ValueError("doc_id must be provided for sync operation.")
|
||||
id = ObjectId(doc_id)
|
||||
else:
|
||||
@@ -966,7 +1071,9 @@ def ingest_connector(
|
||||
vector_store_path = os.path.join(temp_dir, "vector_store")
|
||||
os.makedirs(vector_store_path, exist_ok=True)
|
||||
|
||||
self.update_state(state="PROGRESS", meta={"current": 80, "status": "Storing documents"})
|
||||
self.update_state(
|
||||
state="PROGRESS", meta={"current": 80, "status": "Storing documents"}
|
||||
)
|
||||
embed_and_store_documents(docs, vector_store_path, id, self)
|
||||
|
||||
tokens = count_tokens_docs(docs)
|
||||
@@ -979,12 +1086,11 @@ def ingest_connector(
|
||||
"retriever": retriever,
|
||||
"id": str(id),
|
||||
"type": "connector:file",
|
||||
"remote_data": json.dumps({
|
||||
"provider": source_type,
|
||||
**api_source_config
|
||||
}),
|
||||
"remote_data": json.dumps(
|
||||
{"provider": source_type, **api_source_config}
|
||||
),
|
||||
"directory_structure": json.dumps(directory_structure),
|
||||
"sync_frequency": sync_frequency
|
||||
"sync_frequency": sync_frequency,
|
||||
}
|
||||
|
||||
if operation_mode == "sync":
|
||||
@@ -995,7 +1101,9 @@ def ingest_connector(
|
||||
upload_index(vector_store_path, file_data)
|
||||
|
||||
# Ensure we mark the task as complete
|
||||
self.update_state(state="PROGRESS", meta={"current": 100, "status": "Complete"})
|
||||
self.update_state(
|
||||
state="PROGRESS", meta={"current": 100, "status": "Complete"}
|
||||
)
|
||||
|
||||
logging.info(f"Remote ingestion completed: {job_name}")
|
||||
|
||||
@@ -1005,9 +1113,136 @@ def ingest_connector(
|
||||
"tokens": tokens,
|
||||
"type": source_type,
|
||||
"id": str(id),
|
||||
"status": "complete"
|
||||
"status": "complete",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error during remote ingestion: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
    """Worker to handle MCP OAuth flow asynchronously.

    Progress is published to Redis under ``mcp_oauth_status:<task_id>`` with a
    600 second expiry, as JSON objects carrying a ``status`` of
    ``in_progress``, ``awaiting_redirect``, ``completed`` or ``error``.

    Args:
        self: The bound task; only ``self.request.id`` is read.
        config: MCP tool configuration. A shallow copy is extended with
            ``oauth_task_id`` before being handed to ``MCPTool``.
        user_id: Identifier forwarded to ``MCPTool``.

    Returns:
        ``{"success": True, "tools": [...], "tools_count": n}`` on success,
        otherwise ``{"success": False, "error": "<message>"}``. Failures are
        reported through the return value rather than raised.
    """
    # NOTE(review): the config is logged verbatim here and again below; confirm
    # it cannot carry credentials before keeping these at INFO level.
    logging.info(
        "[MCP OAuth] Worker started for user_id=%s, config=%s", user_id, config
    )

    # Bound before the try block so the outer error handler can always refer
    # to them, even when setup fails on the very first statement.
    task_id = None
    redis_client = None

    def update_status(status_data: Dict[str, Any]):
        logging.info("[MCP OAuth] Updating status: %s", status_data)
        status_key = f"mcp_oauth_status:{task_id}"
        redis_client.setex(status_key, 600, json.dumps(status_data))

    try:
        task_id = self.request.id
        logging.info("[MCP OAuth] Task ID: %s", task_id)
        redis_client = get_redis_instance()

        # Imported after the Redis client exists so that an import failure can
        # still be published as an "error" status.
        import asyncio

        from application.agents.tools.mcp_tool import MCPTool

        update_status(
            {
                "status": "in_progress",
                "message": "Starting OAuth flow...",
                "task_id": task_id,
            }
        )

        tool_config = config.copy()
        tool_config["oauth_task_id"] = task_id
        logging.info("[MCP OAuth] Initializing MCPTool with config: %s", tool_config)
        mcp_tool = MCPTool(tool_config, user_id)

        async def run_oauth_discovery():
            if not mcp_tool._client:
                mcp_tool._setup_client()
            return await mcp_tool._execute_with_client("list_tools")

        update_status(
            {
                "status": "awaiting_redirect",
                "message": "Waiting for OAuth redirect...",
                "task_id": task_id,
            }
        )

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            logging.info("[MCP OAuth] Starting event loop for OAuth discovery...")
            tools_response = loop.run_until_complete(run_oauth_discovery())
            logging.info(
                "[MCP OAuth] Tools response after async call: %s", tools_response
            )

            # Diagnostic only: record whatever status is stored after the
            # discovery call returned.
            status_key = f"mcp_oauth_status:{task_id}"
            redis_status = redis_client.get(status_key)
            if redis_status:
                logging.info(
                    "[MCP OAuth] Redis status after async call: %s", redis_status
                )
            else:
                logging.warning(
                    "[MCP OAuth] No Redis status found after async call for key: %s",
                    status_key,
                )
            tools = mcp_tool.get_actions_metadata()

            update_status(
                {
                    "status": "completed",
                    "message": f"OAuth completed successfully. Found {len(tools)} tools.",
                    "tools": tools,
                    "tools_count": len(tools),
                    "task_id": task_id,
                }
            )

            logging.info(
                "[MCP OAuth] OAuth flow completed successfully for task_id=%s", task_id
            )
            return {"success": True, "tools": tools, "tools_count": len(tools)}
        except Exception as e:
            error_msg = f"OAuth flow failed: {str(e)}"
            logging.error(
                "[MCP OAuth] Exception in OAuth discovery: %s", error_msg, exc_info=True
            )
            update_status(
                {
                    "status": "error",
                    "message": error_msg,
                    "error": str(e),
                    "task_id": task_id,
                }
            )
            return {"success": False, "error": error_msg}
        finally:
            logging.info("[MCP OAuth] Closing event loop for task_id=%s", task_id)
            loop.close()
    except Exception as e:
        error_msg = f"Failed to initialize OAuth flow: {str(e)}"
        logging.error(
            "[MCP OAuth] Exception during initialization: %s", error_msg, exc_info=True
        )
        # Publishing the error is best-effort: without a Redis client there is
        # nowhere to write, and a failing write must not mask the real error.
        if redis_client is not None:
            try:
                update_status(
                    {
                        "status": "error",
                        "message": error_msg,
                        "error": str(e),
                        "task_id": task_id,
                    }
                )
            except Exception:
                logging.error(
                    "[MCP OAuth] Could not publish error status for task_id=%s",
                    task_id,
                    exc_info=True,
                )
        return {"success": False, "error": error_msg}
|
||||
|
||||
|
||||
def mcp_oauth_status(self, task_id: str) -> Dict[str, Any]:
    """Look up the stored status of an MCP OAuth flow.

    Args:
        self: The bound task; not used by this function.
        task_id: Identifier whose status is stored under
            ``mcp_oauth_status:<task_id>``.

    Returns:
        The JSON-decoded status object when the key holds a value, otherwise
        ``{"status": "not_found", "message": "Status not found"}``.
    """
    raw_status = get_redis_instance().get(f"mcp_oauth_status:{task_id}")

    # A missing (or empty) value means there is nothing to report.
    if not raw_status:
        return {"status": "not_found", "message": "Status not found"}
    return json.loads(raw_status)
|
||||
|
||||
@@ -59,6 +59,8 @@ const endpoints = {
|
||||
MANAGE_SOURCE_FILES: '/api/manage_source_files',
|
||||
MCP_TEST_CONNECTION: '/api/mcp_server/test',
|
||||
MCP_SAVE_SERVER: '/api/mcp_server/save',
|
||||
MCP_OAUTH_STATUS: (task_id: string) =>
|
||||
`/api/mcp_server/oauth_status/${task_id}`,
|
||||
},
|
||||
CONVERSATION: {
|
||||
ANSWER: '/api/answer',
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { getSessionToken } from '../../utils/providerUtils';
|
||||
import apiClient from '../client';
|
||||
import endpoints from '../endpoints';
|
||||
import { getSessionToken } from '../../utils/providerUtils';
|
||||
|
||||
const userService = {
|
||||
getConfig: (): Promise<any> => apiClient.get(endpoints.USER.CONFIG, null),
|
||||
@@ -112,6 +112,8 @@ const userService = {
|
||||
apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token),
|
||||
saveMCPServer: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token),
|
||||
getMCPOAuthStatus: (task_id: string, token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.MCP_OAUTH_STATUS(task_id), token),
|
||||
syncConnector: (
|
||||
docId: string,
|
||||
provider: string,
|
||||
|
||||
@@ -210,7 +210,7 @@ export default function ConversationMessages({
|
||||
)}
|
||||
|
||||
<div className="w-full max-w-[1300px] px-2 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12">
|
||||
{headerContent && headerContent}
|
||||
{headerContent}
|
||||
|
||||
{queries.length > 0 ? (
|
||||
queries.map((query, index) => (
|
||||
|
||||
@@ -194,17 +194,20 @@
|
||||
"headerName": "Header Name",
|
||||
"timeout": "Timeout (seconds)",
|
||||
"testConnection": "Test Connection",
|
||||
"testing": "Testing...",
|
||||
"saving": "Saving...",
|
||||
"testing": "Testing",
|
||||
"saving": "Saving",
|
||||
"save": "Save",
|
||||
"cancel": "Cancel",
|
||||
"noAuth": "No Authentication",
|
||||
"oauthInProgress": "Waiting for OAuth completion...",
|
||||
"oauthCompleted": "OAuth completed successfully",
|
||||
"placeholders": {
|
||||
"serverUrl": "https://api.example.com",
|
||||
"apiKey": "Your secret API key",
|
||||
"bearerToken": "Your secret token",
|
||||
"username": "Your username",
|
||||
"password": "Your password"
|
||||
"password": "Your password",
|
||||
"oauthScopes": "OAuth scopes (comma separated)"
|
||||
},
|
||||
"errors": {
|
||||
"nameRequired": "Server name is required",
|
||||
@@ -215,7 +218,9 @@
|
||||
"usernameRequired": "Username is required",
|
||||
"passwordRequired": "Password is required",
|
||||
"testFailed": "Connection test failed",
|
||||
"saveFailed": "Failed to save MCP server"
|
||||
"saveFailed": "Failed to save MCP server",
|
||||
"oauthFailed": "OAuth process failed or was cancelled",
|
||||
"oauthTimeout": "OAuth process timed out, please try again"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ const authTypes = [
|
||||
{ label: 'No Authentication', value: 'none' },
|
||||
{ label: 'API Key', value: 'api_key' },
|
||||
{ label: 'Bearer Token', value: 'bearer' },
|
||||
{ label: 'OAuth', value: 'oauth' },
|
||||
// { label: 'Basic Authentication', value: 'basic' },
|
||||
];
|
||||
|
||||
@@ -45,6 +46,8 @@ export default function MCPServerModal({
|
||||
username: '',
|
||||
password: '',
|
||||
timeout: server?.timeout || 30,
|
||||
oauth_scopes: '',
|
||||
oauth_task_id: '',
|
||||
});
|
||||
|
||||
const [loading, setLoading] = useState(false);
|
||||
@@ -52,8 +55,13 @@ export default function MCPServerModal({
|
||||
const [testResult, setTestResult] = useState<{
|
||||
success: boolean;
|
||||
message: string;
|
||||
status?: string;
|
||||
authorization_url?: string;
|
||||
} | null>(null);
|
||||
const [errors, setErrors] = useState<{ [key: string]: string }>({});
|
||||
const oauthPopupRef = useRef<Window | null>(null);
|
||||
const [oauthCompleted, setOAuthCompleted] = useState(false);
|
||||
const [saveActive, setSaveActive] = useState(false);
|
||||
|
||||
useOutsideAlerter(modalRef, () => {
|
||||
if (modalState === 'ACTIVE') {
|
||||
@@ -73,9 +81,12 @@ export default function MCPServerModal({
|
||||
username: '',
|
||||
password: '',
|
||||
timeout: 30,
|
||||
oauth_scopes: '',
|
||||
oauth_task_id: '',
|
||||
});
|
||||
setErrors({});
|
||||
setTestResult(null);
|
||||
setSaveActive(false);
|
||||
};
|
||||
|
||||
const validateForm = () => {
|
||||
@@ -154,10 +165,81 @@ export default function MCPServerModal({
|
||||
} else if (formData.auth_type === 'basic') {
|
||||
config.username = formData.username.trim();
|
||||
config.password = formData.password.trim();
|
||||
} else if (formData.auth_type === 'oauth') {
|
||||
config.oauth_scopes = formData.oauth_scopes
|
||||
.split(',')
|
||||
.map((s) => s.trim())
|
||||
.filter(Boolean);
|
||||
config.oauth_task_id = formData.oauth_task_id.trim();
|
||||
}
|
||||
return config;
|
||||
};
|
||||
|
||||
// Polls the OAuth status endpoint for `taskId` once per second, for at most
// 60 attempts. Opens the authorization popup the first time the response
// carries an `authorization_url`, and hands the final outcome (completed,
// failed or timed out) to `onComplete`.
const pollOAuthStatus = async (
  taskId: string,
  onComplete: (result: any) => void,
) => {
  const maxAttempts = 60;
  let attempts = 0;
  let popupOpened = false;

  // Close the OAuth popup if one exists and is still open.
  const closePopup = () => {
    if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
      oauthPopupRef.current.close();
    }
  };

  // Report that the attempt budget was used up without a final status.
  const reportTimeout = () => {
    onComplete({
      success: false,
      message: t('settings.tools.mcp.errors.oauthTimeout'),
    });
  };

  const poll = async () => {
    try {
      const resp = await userService.getMCPOAuthStatus(taskId, token);
      const data = await resp.json();

      // Open the authorization window only once per polling session.
      if (data.authorization_url && !popupOpened) {
        closePopup();
        oauthPopupRef.current = window.open(
          data.authorization_url,
          'oauthPopup',
          'width=600,height=700',
        );
        popupOpened = true;
      }

      if (data.status === 'completed') {
        setOAuthCompleted(true);
        setSaveActive(true);
        onComplete({
          ...data,
          success: true,
          message: t('settings.tools.mcp.oauthCompleted'),
        });
        closePopup();
        return;
      }

      if (data.status === 'error' || data.success === false) {
        setSaveActive(false);
        onComplete({
          ...data,
          success: false,
          message: t('settings.tools.mcp.errors.oauthFailed'),
        });
        closePopup();
        return;
      }

      // Still pending: schedule another attempt or give up.
      attempts += 1;
      if (attempts < maxAttempts) {
        setTimeout(poll, 1000);
        return;
      }
      setSaveActive(false);
      reportTimeout();
    } catch {
      // Any failure inside the attempt counts against the same budget.
      attempts += 1;
      if (attempts < maxAttempts) {
        setTimeout(poll, 1000);
      } else {
        reportTimeout();
      }
    }
  };

  poll();
};
|
||||
|
||||
const testConnection = async () => {
|
||||
if (!validateForm()) return;
|
||||
setTesting(true);
|
||||
@@ -167,13 +249,37 @@ export default function MCPServerModal({
|
||||
const response = await userService.testMCPConnection({ config }, token);
|
||||
const result = await response.json();
|
||||
|
||||
if (
|
||||
formData.auth_type === 'oauth' &&
|
||||
result.requires_oauth &&
|
||||
result.task_id
|
||||
) {
|
||||
setTestResult({
|
||||
success: true,
|
||||
message: t('settings.tools.mcp.oauthInProgress'),
|
||||
});
|
||||
setOAuthCompleted(false);
|
||||
setSaveActive(false);
|
||||
pollOAuthStatus(result.task_id, (finalResult) => {
|
||||
setTestResult(finalResult);
|
||||
setFormData((prev) => ({
|
||||
...prev,
|
||||
oauth_task_id: result.task_id || '',
|
||||
}));
|
||||
setTesting(false);
|
||||
});
|
||||
} else {
|
||||
setTestResult(result);
|
||||
setSaveActive(result.success === true);
|
||||
setTesting(false);
|
||||
}
|
||||
} catch (error) {
|
||||
setTestResult({
|
||||
success: false,
|
||||
message: t('settings.tools.mcp.errors.testFailed'),
|
||||
});
|
||||
} finally {
|
||||
setOAuthCompleted(false);
|
||||
setSaveActive(false);
|
||||
setTesting(false);
|
||||
}
|
||||
};
|
||||
@@ -305,6 +411,28 @@ export default function MCPServerModal({
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
case 'oauth':
|
||||
return (
|
||||
<div className="mb-10">
|
||||
<div className="mt-6">
|
||||
<Input
|
||||
name="oauth_scopes"
|
||||
type="text"
|
||||
className="rounded-md"
|
||||
value={formData.oauth_scopes}
|
||||
onChange={(e) =>
|
||||
handleInputChange('oauth_scopes', e.target.value)
|
||||
}
|
||||
placeholder={
|
||||
t('settings.tools.mcp.placeholders.oauthScopes') ||
|
||||
'Scopes (comma separated)'
|
||||
}
|
||||
borderVariant="thin"
|
||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
@@ -331,7 +459,6 @@ export default function MCPServerModal({
|
||||
<div className="space-y-6 py-6">
|
||||
<div>
|
||||
<Input
|
||||
name="name"
|
||||
type="text"
|
||||
className="rounded-md"
|
||||
value={formData.name}
|
||||
@@ -410,7 +537,7 @@ export default function MCPServerModal({
|
||||
|
||||
{testResult && (
|
||||
<div
|
||||
className={`rounded-md p-5 ${
|
||||
className={`rounded-2xl p-5 ${
|
||||
testResult.success
|
||||
? 'bg-green-50 text-green-700 dark:bg-green-900/40 dark:text-green-300'
|
||||
: 'bg-red-50 text-red-700 dark:bg-red-900 dark:text-red-300'
|
||||
@@ -458,7 +585,7 @@ export default function MCPServerModal({
|
||||
</button>
|
||||
<button
|
||||
onClick={handleSave}
|
||||
disabled={loading}
|
||||
disabled={loading || !saveActive}
|
||||
className="bg-purple-30 hover:bg-violets-are-blue w-full rounded-3xl px-6 py-2 text-sm font-medium text-white transition-all disabled:opacity-50 sm:w-auto"
|
||||
>
|
||||
{loading ? (
|
||||
|
||||
Reference in New Issue
Block a user