diff --git a/application/agents/base.py b/application/agents/base.py
index 068b2a3c..77729fe6 100644
--- a/application/agents/base.py
+++ b/application/agents/base.py
@@ -5,19 +5,17 @@ from typing import Dict, Generator, List, Optional
from bson.objectid import ObjectId
-logger = logging.getLogger(__name__)
-
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
-
from application.core.mongo_db import MongoDB
from application.core.settings import settings
-
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.llm.llm_creator import LLMCreator
from application.logging import build_stack_data, log_activity, LogContext
from application.retriever.base import BaseRetriever
+logger = logging.getLogger(__name__)
+
class BaseAgent(ABC):
def __init__(
@@ -157,7 +155,7 @@ class BaseAgent(ABC):
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
- return f"Failed to parse tool call.", call_id
+ return "Failed to parse tool call.", call_id
# Check if tool_id exists in available tools
if tool_id not in tools_dict:
diff --git a/application/api/answer/services/stream_processor.py b/application/api/answer/services/stream_processor.py
index 648d24f5..f6e639ef 100644
--- a/application/api/answer/services/stream_processor.py
+++ b/application/api/answer/services/stream_processor.py
@@ -69,11 +69,8 @@ class StreamProcessor:
self.decoded_token.get("sub") if self.decoded_token is not None else None
)
self.conversation_id = self.data.get("conversation_id")
- self.source = (
- {"active_docs": self.data["active_docs"]}
- if "active_docs" in self.data
- else {}
- )
+ self.source = {}
+ self.all_sources = []
self.attachments = []
self.history = []
self.agent_config = {}
@@ -85,6 +82,8 @@ class StreamProcessor:
def initialize(self):
"""Initialize all required components for processing"""
+ self._configure_agent()
+ self._configure_source()
self._configure_retriever()
self._configure_agent()
self._load_conversation_history()
@@ -171,13 +170,77 @@ class StreamProcessor:
source = data.get("source")
if isinstance(source, DBRef):
source_doc = self.db.dereference(source)
- data["source"] = str(source_doc["_id"])
- data["retriever"] = source_doc.get("retriever", data.get("retriever"))
- data["chunks"] = source_doc.get("chunks", data.get("chunks"))
+ if source_doc:
+ data["source"] = str(source_doc["_id"])
+ data["retriever"] = source_doc.get("retriever", data.get("retriever"))
+ data["chunks"] = source_doc.get("chunks", data.get("chunks"))
+ else:
+ data["source"] = None
+ elif source == "default":
+ data["source"] = "default"
else:
data["source"] = None
+ # Handle multiple sources
+
+ sources = data.get("sources", [])
+ if sources and isinstance(sources, list):
+ sources_list = []
+ for i, source_ref in enumerate(sources):
+ if source_ref == "default":
+ processed_source = {
+ "id": "default",
+ "retriever": "classic",
+ "chunks": data.get("chunks", "2"),
+ }
+ sources_list.append(processed_source)
+ elif isinstance(source_ref, DBRef):
+ source_doc = self.db.dereference(source_ref)
+ if source_doc:
+ processed_source = {
+ "id": str(source_doc["_id"]),
+ "retriever": source_doc.get("retriever", "classic"),
+ "chunks": source_doc.get("chunks", data.get("chunks", "2")),
+ }
+ sources_list.append(processed_source)
+ data["sources"] = sources_list
+ else:
+ data["sources"] = []
return data
+ def _configure_source(self):
+ """Configure the source based on agent data"""
+ api_key = self.data.get("api_key") or self.agent_key
+
+ if api_key:
+ agent_data = self._get_data_from_api_key(api_key)
+
+ if agent_data.get("sources") and len(agent_data["sources"]) > 0:
+ source_ids = [
+ source["id"] for source in agent_data["sources"] if source.get("id")
+ ]
+ if source_ids:
+ self.source = {"active_docs": source_ids}
+ else:
+ self.source = {}
+ self.all_sources = agent_data["sources"]
+ elif agent_data.get("source"):
+ self.source = {"active_docs": agent_data["source"]}
+ self.all_sources = [
+ {
+ "id": agent_data["source"],
+ "retriever": agent_data.get("retriever", "classic"),
+ }
+ ]
+ else:
+ self.source = {}
+ self.all_sources = []
+ return
+ if "active_docs" in self.data:
+ self.source = {"active_docs": self.data["active_docs"]}
+ return
+ self.source = {}
+ self.all_sources = []
+
def _configure_agent(self):
"""Configure the agent based on request data"""
agent_id = self.data.get("agent_id")
@@ -203,7 +266,13 @@ class StreamProcessor:
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
- self.retriever_config["chunks"] = data_key["chunks"]
+ try:
+ self.retriever_config["chunks"] = int(data_key["chunks"])
+ except (ValueError, TypeError):
+ logger.warning(
+ f"Invalid chunks value: {data_key['chunks']}, using default value 2"
+ )
+ self.retriever_config["chunks"] = 2
elif self.agent_key:
data_key = self._get_data_from_api_key(self.agent_key)
self.agent_config.update(
@@ -224,7 +293,13 @@ class StreamProcessor:
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
- self.retriever_config["chunks"] = data_key["chunks"]
+ try:
+ self.retriever_config["chunks"] = int(data_key["chunks"])
+ except (ValueError, TypeError):
+ logger.warning(
+ f"Invalid chunks value: {data_key['chunks']}, using default value 2"
+ )
+ self.retriever_config["chunks"] = 2
else:
self.agent_config.update(
{
@@ -243,7 +318,8 @@ class StreamProcessor:
"token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY),
}
- if "isNoneDoc" in self.data and self.data["isNoneDoc"]:
+ api_key = self.data.get("api_key") or self.agent_key
+ if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
self.retriever_config["chunks"] = 0
def create_agent(self):
diff --git a/application/api/connector/routes.py b/application/api/connector/routes.py
new file mode 100644
index 00000000..1647aa78
--- /dev/null
+++ b/application/api/connector/routes.py
@@ -0,0 +1,626 @@
+import datetime
+import json
+
+
+from bson.objectid import ObjectId
+from flask import (
+ Blueprint,
+ current_app,
+ jsonify,
+ make_response,
+ request
+)
+from flask_restx import fields, Namespace, Resource
+
+
+
+
+from application.api.user.tasks import (
+ ingest_connector_task,
+)
+from application.core.mongo_db import MongoDB
+from application.core.settings import settings
+from application.api import api
+
+
+from application.utils import (
+ check_required_fields
+)
+
+
+from application.parser.connectors.connector_creator import ConnectorCreator
+
+
+
+mongo = MongoDB.get_client()
+db = mongo[settings.MONGO_DB_NAME]
+sources_collection = db["sources"]
+sessions_collection = db["connector_sessions"]
+
+connector = Blueprint("connector", __name__)
+connectors_ns = Namespace("connectors", description="Connector operations", path="/")
+api.add_namespace(connectors_ns)
+
+
+
+@connectors_ns.route("/api/connectors/upload")
+class UploadConnector(Resource):
+ @api.expect(
+ api.model(
+ "ConnectorUploadModel",
+ {
+ "user": fields.String(required=True, description="User ID"),
+ "source": fields.String(
+ required=True, description="Source type (google_drive, github, etc.)"
+ ),
+ "name": fields.String(required=True, description="Job name"),
+ "data": fields.String(required=True, description="Configuration data"),
+ "repo_url": fields.String(description="GitHub repository URL"),
+ },
+ )
+ )
+ @api.doc(
+ description="Uploads connector source for vectorization",
+ )
+ def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ data = request.form
+ required_fields = ["user", "source", "name", "data"]
+ missing_fields = check_required_fields(data, required_fields)
+ if missing_fields:
+ return missing_fields
+ try:
+ config = json.loads(data["data"])
+ source_data = None
+ sync_frequency = config.get("sync_frequency", "never")
+
+ if data["source"] == "github":
+ source_data = config.get("repo_url")
+ elif data["source"] in ["crawler", "url"]:
+ source_data = config.get("url")
+ elif data["source"] == "reddit":
+ source_data = config
+ elif data["source"] in ConnectorCreator.get_supported_connectors():
+ session_token = config.get("session_token")
+ if not session_token:
+ return make_response(jsonify({
+ "success": False,
+ "error": f"Missing session_token in {data['source']} configuration"
+ }), 400)
+
+ file_ids = config.get("file_ids", [])
+ if isinstance(file_ids, str):
+ file_ids = [id.strip() for id in file_ids.split(',') if id.strip()]
+ elif not isinstance(file_ids, list):
+ file_ids = []
+
+ folder_ids = config.get("folder_ids", [])
+ if isinstance(folder_ids, str):
+ folder_ids = [id.strip() for id in folder_ids.split(',') if id.strip()]
+ elif not isinstance(folder_ids, list):
+ folder_ids = []
+
+ config["file_ids"] = file_ids
+ config["folder_ids"] = folder_ids
+
+ task = ingest_connector_task.delay(
+ job_name=data["name"],
+ user=decoded_token.get("sub"),
+ source_type=data["source"],
+ session_token=session_token,
+ file_ids=file_ids,
+ folder_ids=folder_ids,
+ recursive=config.get("recursive", False),
+ retriever=config.get("retriever", "classic"),
+ sync_frequency=sync_frequency
+ )
+ return make_response(jsonify({"success": True, "task_id": task.id}), 200)
+ task = ingest_connector_task.delay(
+ source_data=source_data,
+ job_name=data["name"],
+ user=decoded_token.get("sub"),
+ loader=data["source"],
+ sync_frequency=sync_frequency
+ )
+ except Exception as err:
+ current_app.logger.error(
+ f"Error uploading connector source: {err}", exc_info=True
+ )
+ return make_response(jsonify({"success": False}), 400)
+ return make_response(jsonify({"success": True, "task_id": task.id}), 200)
+
+
+@connectors_ns.route("/api/connectors/task_status")
+class ConnectorTaskStatus(Resource):
+ task_status_model = api.model(
+ "ConnectorTaskStatusModel",
+ {"task_id": fields.String(required=True, description="Task ID")},
+ )
+
+ @api.expect(task_status_model)
+ @api.doc(description="Get connector task status")
+ def get(self):
+ task_id = request.args.get("task_id")
+ if not task_id:
+ return make_response(
+ jsonify({"success": False, "message": "Task ID is required"}), 400
+ )
+ try:
+ from application.celery_init import celery
+
+ task = celery.AsyncResult(task_id)
+ task_meta = task.info
+ print(f"Task status: {task.status}")
+ if not isinstance(
+ task_meta, (dict, list, str, int, float, bool, type(None))
+ ):
+ task_meta = str(task_meta)
+ except Exception as err:
+ current_app.logger.error(f"Error getting task status: {err}", exc_info=True)
+ return make_response(jsonify({"success": False}), 400)
+ return make_response(jsonify({"status": task.status, "result": task_meta}), 200)
+
+
+@connectors_ns.route("/api/connectors/sources")
+class ConnectorSources(Resource):
+ @api.doc(description="Get connector sources")
+ def get(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
+ try:
+ sources = sources_collection.find({"user": user, "type": "connector"}).sort("date", -1)
+ connector_sources = []
+ for source in sources:
+ connector_sources.append({
+ "id": str(source["_id"]),
+ "name": source.get("name"),
+ "date": source.get("date"),
+ "type": source.get("type"),
+ "source": source.get("source"),
+ "tokens": source.get("tokens", ""),
+ "retriever": source.get("retriever", "classic"),
+ "syncFrequency": source.get("sync_frequency", ""),
+ })
+ except Exception as err:
+ current_app.logger.error(f"Error retrieving connector sources: {err}", exc_info=True)
+ return make_response(jsonify({"success": False}), 400)
+ return make_response(jsonify(connector_sources), 200)
+
+
+@connectors_ns.route("/api/connectors/delete")
+class DeleteConnectorSource(Resource):
+ @api.doc(
+ description="Delete a connector source",
+ params={"source_id": "The source ID to delete"},
+ )
+ def delete(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ source_id = request.args.get("source_id")
+ if not source_id:
+ return make_response(
+ jsonify({"success": False, "message": "source_id is required"}), 400
+ )
+ try:
+ result = sources_collection.delete_one(
+ {"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
+ )
+ if result.deleted_count == 0:
+ return make_response(
+ jsonify({"success": False, "message": "Source not found"}), 404
+ )
+ except Exception as err:
+ current_app.logger.error(
+ f"Error deleting connector source: {err}", exc_info=True
+ )
+ return make_response(jsonify({"success": False}), 400)
+ return make_response(jsonify({"success": True}), 200)
+
+
+@connectors_ns.route("/api/connectors/auth")
+class ConnectorAuth(Resource):
+ @api.doc(description="Get connector OAuth authorization URL", params={"provider": "Connector provider (e.g., google_drive)"})
+ def get(self):
+ try:
+ provider = request.args.get('provider') or request.args.get('source')
+ if not provider:
+ return make_response(jsonify({"success": False, "error": "Missing provider"}), 400)
+
+ if not ConnectorCreator.is_supported(provider):
+ return make_response(jsonify({"success": False, "error": f"Unsupported provider: {provider}"}), 400)
+
+ import uuid
+ state = str(uuid.uuid4())
+ auth = ConnectorCreator.create_auth(provider)
+ authorization_url = auth.get_authorization_url(state=state)
+ return make_response(jsonify({
+ "success": True,
+ "authorization_url": authorization_url,
+ "state": state
+ }), 200)
+ except Exception as e:
+ current_app.logger.error(f"Error generating connector auth URL: {e}")
+ return make_response(jsonify({"success": False, "error": str(e)}), 500)
+
+
+@connectors_ns.route("/api/connectors/callback")
+class ConnectorsCallback(Resource):
+ @api.doc(description="Handle OAuth callback for external connectors")
+ def get(self):
+ """Handle OAuth callback for external connectors"""
+ try:
+ from application.parser.connectors.connector_creator import ConnectorCreator
+ from flask import request, redirect
+ import uuid
+
+ provider = request.args.get('provider', 'google_drive')
+ authorization_code = request.args.get('code')
+ _ = request.args.get('state')
+ error = request.args.get('error')
+
+ if error:
+ return redirect(f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider={provider}")
+
+ if not authorization_code:
+ return redirect(f"/api/connectors/callback-status?status=error&message=Authorization+code+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider={provider}")
+
+ try:
+ auth = ConnectorCreator.create_auth(provider)
+ token_info = auth.exchange_code_for_tokens(authorization_code)
+
+ session_token = str(uuid.uuid4())
+
+
+ try:
+ credentials = auth.create_credentials_from_token_info(token_info)
+ service = auth.build_drive_service(credentials)
+ user_info = service.about().get(fields="user").execute()
+ user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
+ except Exception as e:
+ current_app.logger.warning(f"Could not get user info: {e}")
+ user_email = 'Connected User'
+
+ sanitized_token_info = {
+ "access_token": token_info.get("access_token"),
+ "refresh_token": token_info.get("refresh_token"),
+ "token_uri": token_info.get("token_uri"),
+ "expiry": token_info.get("expiry"),
+ "scopes": token_info.get("scopes")
+ }
+
+ user_id = request.decoded_token.get("sub") if getattr(request, "decoded_token", None) else None
+ sessions_collection.insert_one({
+ "session_token": session_token,
+ "user": user_id,
+ "token_info": sanitized_token_info,
+ "created_at": datetime.datetime.now(datetime.timezone.utc),
+ "user_email": user_email,
+ "provider": provider
+ })
+
+ # Redirect to success page with session token and user email
+ return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}")
+
+ except Exception as e:
+ current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
+ return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+exchange+authorization+code+for+tokens:+{str(e)}&provider={provider}")
+
+ except Exception as e:
+ current_app.logger.error(f"Error handling connector callback: {e}")
+ return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+complete+connector+authentication:+{str(e)}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.")
+
+
+@connectors_ns.route("/api/connectors/refresh")
+class ConnectorRefresh(Resource):
+ @api.expect(api.model("ConnectorRefreshModel", {"provider": fields.String(required=True), "refresh_token": fields.String(required=True)}))
+ @api.doc(description="Refresh connector access token")
+ def post(self):
+ try:
+ data = request.get_json()
+ provider = data.get('provider')
+ refresh_token = data.get('refresh_token')
+
+ if not provider or not refresh_token:
+ return make_response(jsonify({"success": False, "error": "provider and refresh_token are required"}), 400)
+
+ auth = ConnectorCreator.create_auth(provider)
+ token_info = auth.refresh_access_token(refresh_token)
+ return make_response(jsonify({"success": True, "token_info": token_info}), 200)
+ except Exception as e:
+ current_app.logger.error(f"Error refreshing token for connector: {e}")
+ return make_response(jsonify({"success": False, "error": str(e)}), 500)
+
+
+@connectors_ns.route("/api/connectors/files")
+class ConnectorFiles(Resource):
+ @api.expect(api.model("ConnectorFilesModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True), "folder_id": fields.String(required=False), "limit": fields.Integer(required=False), "page_token": fields.String(required=False)}))
+ @api.doc(description="List files from a connector provider (supports pagination)")
+ def post(self):
+ try:
+ data = request.get_json()
+ provider = data.get('provider')
+ session_token = data.get('session_token')
+ folder_id = data.get('folder_id')
+ limit = data.get('limit', 10)
+ page_token = data.get('page_token')
+ if not provider or not session_token:
+ return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
+
+
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
+ user = decoded_token.get('sub')
+ session = sessions_collection.find_one({"session_token": session_token, "user": user})
+ if not session:
+ return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
+
+ loader = ConnectorCreator.create_connector(provider, session_token)
+ documents = loader.load_data({
+ 'limit': limit,
+ 'list_only': True,
+ 'session_token': session_token,
+ 'folder_id': folder_id,
+ 'page_token': page_token
+ })
+
+ files = []
+ for doc in documents[:limit]:
+ metadata = doc.extra_info
+ modified_time = metadata.get('modified_time')
+ if modified_time:
+ date_part = modified_time.split('T')[0]
+ time_part = modified_time.split('T')[1].split('.')[0].split('Z')[0]
+ formatted_time = f"{date_part} {time_part}"
+ else:
+ formatted_time = None
+
+ files.append({
+ 'id': doc.doc_id,
+ 'name': metadata.get('file_name', 'Unknown File'),
+ 'type': metadata.get('mime_type', 'unknown'),
+ 'size': metadata.get('size', None),
+ 'modifiedTime': formatted_time
+ })
+
+ next_token = getattr(loader, 'next_page_token', None)
+ has_more = bool(next_token)
+
+ return make_response(jsonify({"success": True, "files": files, "total": len(files), "next_page_token": next_token, "has_more": has_more}), 200)
+ except Exception as e:
+ current_app.logger.error(f"Error loading connector files: {e}")
+ return make_response(jsonify({"success": False, "error": f"Failed to load files: {str(e)}"}), 500)
+
+
+@connectors_ns.route("/api/connectors/validate-session")
+class ConnectorValidateSession(Resource):
+ @api.expect(api.model("ConnectorValidateSessionModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True)}))
+ @api.doc(description="Validate connector session token and return user info")
+ def post(self):
+ try:
+ data = request.get_json()
+ provider = data.get('provider')
+ session_token = data.get('session_token')
+ if not provider or not session_token:
+ return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
+
+
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
+ user = decoded_token.get('sub')
+
+ session = sessions_collection.find_one({"session_token": session_token, "user": user})
+ if not session or "token_info" not in session:
+ return make_response(jsonify({"success": False, "error": "Invalid or expired session"}), 401)
+
+ token_info = session["token_info"]
+ auth = ConnectorCreator.create_auth(provider)
+ is_expired = auth.is_token_expired(token_info)
+
+ return make_response(jsonify({
+ "success": True,
+ "expired": is_expired,
+ "user_email": session.get('user_email', 'Connected User')
+ }), 200)
+ except Exception as e:
+ current_app.logger.error(f"Error validating connector session: {e}")
+ return make_response(jsonify({"success": False, "error": str(e)}), 500)
+
+
+@connectors_ns.route("/api/connectors/disconnect")
+class ConnectorDisconnect(Resource):
+ @api.expect(api.model("ConnectorDisconnectModel", {"provider": fields.String(required=True), "session_token": fields.String(required=False)}))
+ @api.doc(description="Disconnect a connector session")
+ def post(self):
+ try:
+ data = request.get_json()
+ provider = data.get('provider')
+ session_token = data.get('session_token')
+ if not provider:
+ return make_response(jsonify({"success": False, "error": "provider is required"}), 400)
+
+
+ if session_token:
+ sessions_collection.delete_one({"session_token": session_token})
+
+ return make_response(jsonify({"success": True}), 200)
+ except Exception as e:
+ current_app.logger.error(f"Error disconnecting connector session: {e}")
+ return make_response(jsonify({"success": False, "error": str(e)}), 500)
+
+
+@connectors_ns.route("/api/connectors/sync")
+class ConnectorSync(Resource):
+ @api.expect(
+ api.model(
+ "ConnectorSyncModel",
+ {
+ "source_id": fields.String(required=True, description="Source ID to sync"),
+ "session_token": fields.String(required=True, description="Authentication token")
+ },
+ )
+ )
+ @api.doc(description="Sync connector source to check for modifications")
+ def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+
+ try:
+ data = request.get_json()
+ source_id = data.get('source_id')
+ session_token = data.get('session_token')
+
+ if not all([source_id, session_token]):
+ return make_response(
+ jsonify({
+ "success": False,
+ "error": "source_id and session_token are required"
+ }),
+ 400
+ )
+ source = sources_collection.find_one({"_id": ObjectId(source_id)})
+ if not source:
+ return make_response(
+ jsonify({
+ "success": False,
+ "error": "Source not found"
+ }),
+ 404
+ )
+
+ if source.get('user') != decoded_token.get('sub'):
+ return make_response(
+ jsonify({
+ "success": False,
+ "error": "Unauthorized access to source"
+ }),
+ 403
+ )
+
+ remote_data = {}
+ try:
+ if source.get('remote_data'):
+ remote_data = json.loads(source.get('remote_data'))
+ except json.JSONDecodeError:
+ current_app.logger.error(f"Invalid remote_data format for source {source_id}")
+ remote_data = {}
+
+ source_type = remote_data.get('provider')
+ if not source_type:
+ return make_response(
+ jsonify({
+ "success": False,
+ "error": "Source provider not found in remote_data"
+ }),
+ 400
+ )
+
+ # Extract configuration from remote_data
+ file_ids = remote_data.get('file_ids', [])
+ folder_ids = remote_data.get('folder_ids', [])
+ recursive = remote_data.get('recursive', True)
+
+ # Start the sync task
+ task = ingest_connector_task.delay(
+ job_name=source.get('name'),
+ user=decoded_token.get('sub'),
+ source_type=source_type,
+ session_token=session_token,
+ file_ids=file_ids,
+ folder_ids=folder_ids,
+ recursive=recursive,
+ retriever=source.get('retriever', 'classic'),
+ operation_mode="sync",
+ doc_id=source_id,
+ sync_frequency=source.get('sync_frequency', 'never')
+ )
+
+ return make_response(
+ jsonify({
+ "success": True,
+ "task_id": task.id
+ }),
+ 200
+ )
+
+ except Exception as err:
+ current_app.logger.error(
+ f"Error syncing connector source: {err}",
+ exc_info=True
+ )
+ return make_response(
+ jsonify({
+ "success": False,
+ "error": str(err)
+ }),
+ 400
+ )
+
+
+@connectors_ns.route("/api/connectors/callback-status")
+class ConnectorCallbackStatus(Resource):
+ @api.doc(description="Return HTML page with connector authentication status")
+ def get(self):
+ """Return HTML page with connector authentication status"""
+ try:
+ status = request.args.get('status', 'error')
+ message = request.args.get('message', '')
+ provider = request.args.get('provider', 'connector')
+ session_token = request.args.get('session_token', '')
+ user_email = request.args.get('user_email', '')
+
+ html_content = f"""
+
+
+
+ {provider.replace('_', ' ').title()} Authentication
+
+
+
+
+
+
{provider.replace('_', ' ').title()} Authentication
+
+
{message}
+ {f'
Connected as: {user_email}
' if status == 'success' else ''}
+
+
You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else ''}
+
+
+
+ """
+
+ return make_response(html_content, 200, {'Content-Type': 'text/html'})
+ except Exception as e:
+ current_app.logger.error(f"Error rendering callback status page: {e}")
+ return make_response("Authentication error occurred", 500, {'Content-Type': 'text/html'})
+
+
diff --git a/application/api/user/routes.py b/application/api/user/routes.py
index b0554461..f0493c7c 100644
--- a/application/api/user/routes.py
+++ b/application/api/user/routes.py
@@ -32,13 +32,15 @@ from application.api import api
from application.api.user.tasks import (
ingest,
+ ingest_connector_task,
ingest_remote,
process_agent_webhook,
store_attachment,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
-from application.security.encryption import encrypt_credentials, decrypt_credentials
+from application.parser.connectors.connector_creator import ConnectorCreator
+from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.storage.storage_creator import StorageCreator
from application.tts.google_tts import GoogleTTS
from application.utils import (
@@ -76,7 +78,6 @@ try:
users_collection.create_index("user_id", unique=True)
except Exception as e:
print("Error creating indexes:", e)
-
user = Blueprint("user", __name__)
user_ns = Namespace("user", description="User related operations", path="/")
api.add_namespace(user_ns)
@@ -129,11 +130,9 @@ def ensure_user_doc(user_id):
updates["agent_preferences.pinned"] = []
if "shared_with_me" not in prefs:
updates["agent_preferences.shared_with_me"] = []
-
if updates:
users_collection.update_one({"user_id": user_id}, {"$set": updates})
user_doc = users_collection.find_one({"user_id": user_id})
-
return user_doc
@@ -185,7 +184,6 @@ def handle_image_upload(
jsonify({"success": False, "message": "Image upload failed"}),
400,
)
-
return image_url, None
@@ -299,8 +297,8 @@ class GetSingleConversation(Resource):
)
if not conversation:
return make_response(jsonify({"status": "not found"}), 404)
-
# Process queries to include attachment names
+
queries = conversation["queries"]
for query in queries:
if "attachments" in query and query["attachments"]:
@@ -501,6 +499,7 @@ class DeleteOldIndexes(Resource):
try:
# Delete vector index
+
if settings.VECTOR_STORE == "faiss":
index_path = f"indexes/{str(doc['_id'])}"
if storage.file_exists(f"{index_path}/index.faiss"):
@@ -571,6 +570,7 @@ class UploadFile(Resource):
job_name = request.form["name"]
# Create safe versions for filesystem operations
+
safe_user = safe_filename(user)
dir_name = safe_filename(job_name)
base_path = f"{settings.UPLOAD_FOLDER}/{safe_user}/{dir_name}"
@@ -592,6 +592,7 @@ class UploadFile(Resource):
zip_ref.extractall(path=temp_dir)
# Walk through extracted files and upload them
+
for root, _, files in os.walk(temp_dir):
for extracted_file in files:
if (
@@ -614,11 +615,13 @@ class UploadFile(Resource):
f"Error extracting zip: {e}", exc_info=True
)
# If zip extraction fails, save the original zip file
+
file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path)
else:
# For non-zip files, save directly
+
file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path)
@@ -709,7 +712,6 @@ class ManageSourceFiles(Resource):
),
400,
)
-
if operation not in ["add", "remove", "remove_directory"]:
return make_response(
jsonify(
@@ -720,14 +722,12 @@ class ManageSourceFiles(Resource):
),
400,
)
-
try:
ObjectId(source_id)
except Exception:
return make_response(
jsonify({"success": False, "message": "Invalid source ID format"}), 400
)
-
try:
source = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": user}
@@ -760,7 +760,6 @@ class ManageSourceFiles(Resource):
),
400,
)
-
if operation == "add":
files = request.files.getlist("file")
if not files or all(file.filename == "" for file in files):
@@ -773,23 +772,22 @@ class ManageSourceFiles(Resource):
),
400,
)
-
added_files = []
target_dir = source_file_path
if parent_dir:
target_dir = f"{source_file_path}/{parent_dir}"
-
for file in files:
if file.filename:
safe_filename_str = safe_filename(file.filename)
file_path = f"{target_dir}/{safe_filename_str}"
# Save file to storage
+
storage.save_file(file, file_path)
added_files.append(safe_filename_str)
-
# Trigger re-ingestion pipeline
+
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.delay(source_id=source_id, user=user)
@@ -819,7 +817,6 @@ class ManageSourceFiles(Resource):
),
400,
)
-
try:
file_paths = (
json.loads(file_paths_str)
@@ -833,18 +830,19 @@ class ManageSourceFiles(Resource):
),
400,
)
-
# Remove files from storage and directory structure
+
removed_files = []
for file_path in file_paths:
full_path = f"{source_file_path}/{file_path}"
# Remove from storage
+
if storage.file_exists(full_path):
storage.delete_file(full_path)
removed_files.append(file_path)
-
# Trigger re-ingestion pipeline
+
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.delay(source_id=source_id, user=user)
@@ -873,8 +871,8 @@ class ManageSourceFiles(Resource):
),
400,
)
-
# Validate directory path (prevent path traversal)
+
if directory_path.startswith("/") or ".." in directory_path:
current_app.logger.warning(
f"Invalid directory path attempted for removal. "
@@ -908,7 +906,6 @@ class ManageSourceFiles(Resource):
),
404,
)
-
success = storage.remove_directory(full_directory_path)
if not success:
@@ -923,7 +920,6 @@ class ManageSourceFiles(Resource):
),
500,
)
-
current_app.logger.info(
f"Successfully removed directory. "
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
@@ -931,6 +927,7 @@ class ManageSourceFiles(Resource):
)
# Trigger re-ingestion pipeline
+
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.delay(source_id=source_id, user=user)
@@ -1005,6 +1002,50 @@ class UploadRemote(Resource):
source_data = config.get("url")
elif data["source"] == "reddit":
source_data = config
+ elif data["source"] in ConnectorCreator.get_supported_connectors():
+ session_token = config.get("session_token")
+ if not session_token:
+ return make_response(
+ jsonify(
+ {
+ "success": False,
+ "error": f"Missing session_token in {data['source']} configuration",
+ }
+ ),
+ 400,
+ )
+ # Process file_ids
+
+ file_ids = config.get("file_ids", [])
+ if isinstance(file_ids, str):
+ file_ids = [id.strip() for id in file_ids.split(",") if id.strip()]
+ elif not isinstance(file_ids, list):
+ file_ids = []
+ # Process folder_ids
+
+ folder_ids = config.get("folder_ids", [])
+ if isinstance(folder_ids, str):
+ folder_ids = [
+ id.strip() for id in folder_ids.split(",") if id.strip()
+ ]
+ elif not isinstance(folder_ids, list):
+ folder_ids = []
+ config["file_ids"] = file_ids
+ config["folder_ids"] = folder_ids
+
+ task = ingest_connector_task.delay(
+ job_name=data["name"],
+ user=decoded_token.get("sub"),
+ source_type=data["source"],
+ session_token=session_token,
+ file_ids=file_ids,
+ folder_ids=folder_ids,
+ recursive=config.get("recursive", False),
+ retriever=config.get("retriever", "classic"),
+ )
+ return make_response(
+ jsonify({"success": True, "task_id": task.id}), 200
+ )
task = ingest_remote.delay(
source_data=source_data,
job_name=data["name"],
@@ -1113,6 +1154,7 @@ class PaginatedSources(Resource):
"retriever": doc.get("retriever", "classic"),
"syncFrequency": doc.get("sync_frequency", ""),
"isNested": bool(doc.get("directory_structure")),
+ "type": doc.get("type", "file"),
}
paginated_docs.append(doc_data)
response = {
@@ -1161,6 +1203,9 @@ class CombinedJson(Resource):
"retriever": index.get("retriever", "classic"),
"syncFrequency": index.get("sync_frequency", ""),
"is_nested": bool(index.get("directory_structure")),
+ "type": index.get(
+ "type", "file"
+ ), # Add type field with default "file"
}
)
except Exception as err:
@@ -1376,17 +1421,14 @@ class GetAgent(Resource):
def get(self):
if not (decoded_token := request.decoded_token):
return {"success": False}, 401
-
if not (agent_id := request.args.get("id")):
return {"success": False, "message": "ID required"}, 400
-
try:
agent = agents_collection.find_one(
{"_id": ObjectId(agent_id), "user": decoded_token["sub"]}
)
if not agent:
return {"status": "Not found"}, 404
-
data = {
"id": str(agent["_id"]),
"name": agent["name"],
@@ -1400,6 +1442,16 @@ class GetAgent(Resource):
and (source_doc := db.dereference(agent.get("source")))
else ""
),
+ "sources": [
+ (
+ str(db.dereference(source_ref)["_id"])
+ if isinstance(source_ref, DBRef) and db.dereference(source_ref)
+ else source_ref
+ )
+ for source_ref in agent.get("sources", [])
+ if (isinstance(source_ref, DBRef) and db.dereference(source_ref))
+ or source_ref == "default"
+ ],
"chunks": agent["chunks"],
"retriever": agent.get("retriever", ""),
"prompt_id": agent.get("prompt_id", ""),
@@ -1422,7 +1474,6 @@ class GetAgent(Resource):
"shared_token": agent.get("shared_token", ""),
}
return make_response(jsonify(data), 200)
-
except Exception as e:
current_app.logger.error(f"Agent fetch error: {e}", exc_info=True)
return {"success": False}, 400
@@ -1434,7 +1485,6 @@ class GetAgents(Resource):
def get(self):
if not (decoded_token := request.decoded_token):
return {"success": False}, 401
-
user = decoded_token.get("sub")
try:
user_doc = ensure_user_doc(user)
@@ -1453,8 +1503,24 @@ class GetAgents(Resource):
str(source_doc["_id"])
if isinstance(agent.get("source"), DBRef)
and (source_doc := db.dereference(agent.get("source")))
- else ""
+ else (
+ agent.get("source", "")
+ if agent.get("source") == "default"
+ else ""
+ )
),
+ "sources": [
+ (
+ source_ref
+ if source_ref == "default"
+ else str(db.dereference(source_ref)["_id"])
+ )
+ for source_ref in agent.get("sources", [])
+ if source_ref == "default"
+ or (
+ isinstance(source_ref, DBRef) and db.dereference(source_ref)
+ )
+ ],
"chunks": agent["chunks"],
"retriever": agent.get("retriever", ""),
"prompt_id": agent.get("prompt_id", ""),
@@ -1497,7 +1563,14 @@ class CreateAgent(Resource):
"image": fields.Raw(
required=False, description="Image file upload", type="file"
),
- "source": fields.String(required=True, description="Source ID"),
+ "source": fields.String(
+ required=False, description="Source ID (legacy single source)"
+ ),
+ "sources": fields.List(
+ fields.String,
+ required=False,
+ description="List of source identifiers for multiple sources",
+ ),
"chunks": fields.Integer(required=True, description="Chunks count"),
"retriever": fields.String(required=True, description="Retriever ID"),
"prompt_id": fields.String(required=True, description="Prompt ID"),
@@ -1530,6 +1603,11 @@ class CreateAgent(Resource):
data["tools"] = json.loads(data["tools"])
except json.JSONDecodeError:
data["tools"] = []
+ if "sources" in data:
+ try:
+ data["sources"] = json.loads(data["sources"])
+ except json.JSONDecodeError:
+ data["sources"] = []
if "json_schema" in data:
try:
data["json_schema"] = json.loads(data["json_schema"])
@@ -1538,9 +1616,11 @@ class CreateAgent(Resource):
print(f"Received data: {data}")
# Validate JSON schema if provided
+
if data.get("json_schema"):
try:
# Basic validation - ensure it's a valid JSON structure
+
json_schema = data.get("json_schema")
if not isinstance(json_schema, dict):
return make_response(
@@ -1554,6 +1634,7 @@ class CreateAgent(Resource):
)
# Validate that it has either a 'schema' property or is itself a schema
+
if "schema" not in json_schema and "type" not in json_schema:
return make_response(
jsonify(
@@ -1571,7 +1652,6 @@ class CreateAgent(Resource):
),
400,
)
-
if data.get("status") not in ["draft", "published"]:
return make_response(
jsonify(
@@ -1582,17 +1662,27 @@ class CreateAgent(Resource):
),
400,
)
-
if data.get("status") == "published":
required_fields = [
"name",
"description",
- "source",
"chunks",
"retriever",
"prompt_id",
"agent_type",
]
+ # Require either source or sources (but not both)
+
+ if not data.get("source") and not data.get("sources"):
+ return make_response(
+ jsonify(
+ {
+ "success": False,
+ "message": "Either 'source' or 'sources' field is required for published agents",
+ }
+ ),
+ 400,
+ )
validate_fields = ["name", "description", "prompt_id", "agent_type"]
else:
required_fields = ["name"]
@@ -1603,25 +1693,37 @@ class CreateAgent(Resource):
return missing_fields
if invalid_fields:
return invalid_fields
-
image_url, error = handle_image_upload(request, "", user, storage)
if error:
return make_response(
jsonify({"success": False, "message": "Image upload failed"}), 400
)
-
try:
key = str(uuid.uuid4()) if data.get("status") == "published" else ""
+
+ sources_list = []
+ if data.get("sources") and len(data.get("sources", [])) > 0:
+ for source_id in data.get("sources", []):
+ if source_id == "default":
+ sources_list.append("default")
+ elif ObjectId.is_valid(source_id):
+ sources_list.append(DBRef("sources", ObjectId(source_id)))
+ source_field = ""
+ else:
+ source_value = data.get("source", "")
+ if source_value == "default":
+ source_field = "default"
+ elif ObjectId.is_valid(source_value):
+ source_field = DBRef("sources", ObjectId(source_value))
+ else:
+ source_field = ""
new_agent = {
"user": user,
"name": data.get("name"),
"description": data.get("description", ""),
"image": image_url,
- "source": (
- DBRef("sources", ObjectId(data.get("source")))
- if ObjectId.is_valid(data.get("source"))
- else ""
- ),
+ "source": source_field,
+ "sources": sources_list,
"chunks": data.get("chunks", ""),
"retriever": data.get("retriever", ""),
"prompt_id": data.get("prompt_id", ""),
@@ -1636,7 +1738,11 @@ class CreateAgent(Resource):
}
if new_agent["chunks"] == "":
new_agent["chunks"] = "0"
- if new_agent["source"] == "" and new_agent["retriever"] == "":
+ if (
+ new_agent["source"] == ""
+ and new_agent["retriever"] == ""
+ and not new_agent["sources"]
+ ):
new_agent["retriever"] = "classic"
resp = agents_collection.insert_one(new_agent)
new_id = str(resp.inserted_id)
@@ -1658,7 +1764,14 @@ class UpdateAgent(Resource):
"image": fields.String(
required=False, description="New image URL or identifier"
),
- "source": fields.String(required=True, description="Source ID"),
+ "source": fields.String(
+ required=False, description="Source ID (legacy single source)"
+ ),
+ "sources": fields.List(
+ fields.String,
+ required=False,
+ description="List of source identifiers for multiple sources",
+ ),
"chunks": fields.Integer(required=True, description="Chunks count"),
"retriever": fields.String(required=True, description="Retriever ID"),
"prompt_id": fields.String(required=True, description="Prompt ID"),
@@ -1691,12 +1804,16 @@ class UpdateAgent(Resource):
data["tools"] = json.loads(data["tools"])
except json.JSONDecodeError:
data["tools"] = []
+ if "sources" in data:
+ try:
+ data["sources"] = json.loads(data["sources"])
+ except json.JSONDecodeError:
+ data["sources"] = []
if "json_schema" in data:
try:
data["json_schema"] = json.loads(data["json_schema"])
except json.JSONDecodeError:
data["json_schema"] = None
-
if not ObjectId.is_valid(agent_id):
return make_response(
jsonify({"success": False, "message": "Invalid agent ID format"}), 400
@@ -1720,7 +1837,6 @@ class UpdateAgent(Resource):
),
404,
)
-
image_url, error = handle_image_upload(
request, existing_agent.get("image", ""), user, storage
)
@@ -1728,13 +1844,13 @@ class UpdateAgent(Resource):
return make_response(
jsonify({"success": False, "message": "Image upload failed"}), 400
)
-
update_fields = {}
allowed_fields = [
"name",
"description",
"image",
"source",
+ "sources",
"chunks",
"retriever",
"prompt_id",
@@ -1758,7 +1874,11 @@ class UpdateAgent(Resource):
update_fields[field] = new_status
elif field == "source":
source_id = data.get("source")
- if source_id and ObjectId.is_valid(source_id):
+ if source_id == "default":
+ # Handle special "default" source
+
+ update_fields[field] = "default"
+ elif source_id and ObjectId.is_valid(source_id):
update_fields[field] = DBRef("sources", ObjectId(source_id))
elif source_id:
return make_response(
@@ -1772,6 +1892,30 @@ class UpdateAgent(Resource):
)
else:
update_fields[field] = ""
+ elif field == "sources":
+ sources_list = data.get("sources", [])
+ if sources_list and isinstance(sources_list, list):
+ valid_sources = []
+ for source_id in sources_list:
+ if source_id == "default":
+ valid_sources.append("default")
+ elif ObjectId.is_valid(source_id):
+ valid_sources.append(
+ DBRef("sources", ObjectId(source_id))
+ )
+ else:
+ return make_response(
+ jsonify(
+ {
+ "success": False,
+ "message": f"Invalid source ID format: {source_id}",
+ }
+ ),
+ 400,
+ )
+ update_fields[field] = valid_sources
+ else:
+ update_fields[field] = []
elif field == "chunks":
chunks_value = data.get("chunks")
if chunks_value == "":
@@ -1837,7 +1981,6 @@ class UpdateAgent(Resource):
),
400,
)
-
if not existing_agent.get("key"):
newly_generated_key = str(uuid.uuid4())
update_fields["key"] = newly_generated_key
@@ -1924,7 +2067,6 @@ class PinnedAgents(Resource):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
-
user_id = decoded_token.get("sub")
try:
@@ -1933,7 +2075,6 @@ class PinnedAgents(Resource):
if not pinned_ids:
return make_response(jsonify([]), 200)
-
pinned_object_ids = [ObjectId(agent_id) for agent_id in pinned_ids]
pinned_agents_cursor = agents_collection.find(
@@ -1943,6 +2084,7 @@ class PinnedAgents(Resource):
existing_ids = {str(agent["_id"]) for agent in pinned_agents}
# Clean up any stale pinned IDs
+
stale_ids = [
agent_id for agent_id in pinned_ids if agent_id not in existing_ids
]
@@ -1951,7 +2093,6 @@ class PinnedAgents(Resource):
{"user_id": user_id},
{"$pullAll": {"agent_preferences.pinned": stale_ids}},
)
-
list_pinned_agents = [
{
"id": str(agent["_id"]),
@@ -1988,11 +2129,9 @@ class PinnedAgents(Resource):
for agent in pinned_agents
if "source" in agent or "retriever" in agent
]
-
except Exception as err:
current_app.logger.error(f"Error retrieving pinned agents: {err}")
return make_response(jsonify({"success": False}), 400)
-
return make_response(jsonify(list_pinned_agents), 200)
@@ -2056,7 +2195,6 @@ class RemoveSharedAgent(Resource):
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
-
try:
agent = agents_collection.find_one(
{"_id": ObjectId(agent_id), "shared_publicly": True}
@@ -2066,7 +2204,6 @@ class RemoveSharedAgent(Resource):
jsonify({"success": False, "message": "Shared agent not found"}),
404,
)
-
ensure_user_doc(user_id)
users_collection.update_one(
{"user_id": user_id},
@@ -2079,7 +2216,6 @@ class RemoveSharedAgent(Resource):
)
return make_response(jsonify({"success": True, "action": "removed"}), 200)
-
except Exception as err:
current_app.logger.error(f"Error removing shared agent: {err}")
return make_response(
@@ -2102,7 +2238,6 @@ class SharedAgent(Resource):
return make_response(
jsonify({"success": False, "message": "Token or ID is required"}), 400
)
-
try:
query = {
"shared_publicly": True,
@@ -2114,7 +2249,6 @@ class SharedAgent(Resource):
jsonify({"success": False, "message": "Shared agent not found"}),
404,
)
-
agent_id = str(shared_agent["_id"])
data = {
"id": agent_id,
@@ -2154,7 +2288,6 @@ class SharedAgent(Resource):
if tool_data:
enriched_tools.append(tool_data.get("name", ""))
data["tools"] = enriched_tools
-
decoded_token = getattr(request, "decoded_token", None)
if decoded_token:
user_id = decoded_token.get("sub")
@@ -2166,9 +2299,7 @@ class SharedAgent(Resource):
{"user_id": user_id},
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
)
-
return make_response(jsonify(data), 200)
-
except Exception as err:
current_app.logger.error(f"Error retrieving shared agent: {err}")
return make_response(jsonify({"success": False}), 400)
@@ -2202,7 +2333,6 @@ class SharedAgents(Resource):
{"user_id": user_id},
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
)
-
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
list_shared_agents = [
@@ -2229,7 +2359,6 @@ class SharedAgents(Resource):
]
return make_response(jsonify(list_shared_agents), 200)
-
except Exception as err:
current_app.logger.error(f"Error retrieving shared agents: {err}")
return make_response(jsonify({"success": False}), 400)
@@ -3762,20 +3891,21 @@ class GetChunks(Resource):
metadata = chunk.get("metadata", {})
# Filter by path if provided
+
if path:
chunk_source = metadata.get("source", "")
# Check if the chunk's source matches the requested path
+
if not chunk_source or not chunk_source.endswith(path):
continue
-
# Filter by search term if provided
+
if search_term:
text_match = search_term in chunk.get("text", "").lower()
title_match = search_term in metadata.get("title", "").lower()
if not (text_match or title_match):
continue
-
filtered_chunks.append(chunk)
chunks = filtered_chunks
@@ -3937,7 +4067,6 @@ class UpdateChunk(Resource):
if metadata is None:
metadata = {}
metadata["token_count"] = token_count
-
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
@@ -3952,7 +4081,6 @@ class UpdateChunk(Resource):
existing_chunk = next((c for c in chunks if c["doc_id"] == chunk_id), None)
if not existing_chunk:
return make_response(jsonify({"error": "Chunk not found"}), 404)
-
new_text = text if text is not None else existing_chunk["text"]
if metadata is not None:
@@ -3960,10 +4088,8 @@ class UpdateChunk(Resource):
new_metadata.update(metadata)
else:
new_metadata = existing_chunk["metadata"].copy()
-
if text is not None:
new_metadata["token_count"] = num_tokens_from_string(new_text)
-
try:
new_chunk_id = store.add_chunk(new_text, new_metadata)
@@ -4019,7 +4145,6 @@ class StoreAttachment(Resource):
jsonify({"status": "error", "message": "Missing file"}),
400,
)
-
user = None
if decoded_token:
user = safe_filename(decoded_token.get("sub"))
@@ -4034,7 +4159,6 @@ class StoreAttachment(Resource):
return make_response(
jsonify({"success": False, "message": "Authentication required"}), 401
)
-
try:
attachment_id = ObjectId()
original_filename = safe_filename(os.path.basename(file.filename))
@@ -4076,7 +4200,6 @@ class ServeImage(Resource):
content_type = f"image/{extension}"
if extension == "jpg":
content_type = "image/jpeg"
-
response = make_response(file_obj.read())
response.headers.set("Content-Type", content_type)
response.headers.set("Cache-Control", "max-age=86400")
@@ -4121,18 +4244,29 @@ class DirectoryStructure(Resource):
)
directory_structure = doc.get("directory_structure", {})
+ base_path = doc.get("file_path", "")
+ provider = None
+ remote_data = doc.get("remote_data")
+ try:
+ if isinstance(remote_data, str) and remote_data:
+ remote_data_obj = json.loads(remote_data)
+ provider = remote_data_obj.get("provider")
+ except Exception as e:
+ current_app.logger.warning(
+ f"Failed to parse remote_data for doc {doc_id}: {e}"
+ )
return make_response(
jsonify(
{
"success": True,
"directory_structure": directory_structure,
- "base_path": doc.get("file_path", ""),
+ "base_path": base_path,
+ "provider": provider,
}
),
200,
)
-
except Exception as e:
current_app.logger.error(
f"Error retrieving directory structure: {e}", exc_info=True
diff --git a/application/api/user/tasks.py b/application/api/user/tasks.py
index 28a78c0d..3519b701 100644
--- a/application/api/user/tasks.py
+++ b/application/api/user/tasks.py
@@ -47,6 +47,39 @@ def process_agent_webhook(self, agent_id, payload):
return resp
+@celery.task(bind=True)
+def ingest_connector_task(
+ self,
+ job_name,
+ user,
+ source_type,
+ session_token=None,
+ file_ids=None,
+ folder_ids=None,
+ recursive=True,
+ retriever="classic",
+ operation_mode="upload",
+ doc_id=None,
+ sync_frequency="never"
+):
+ from application.worker import ingest_connector
+ resp = ingest_connector(
+ self,
+ job_name,
+ user,
+ source_type,
+ session_token=session_token,
+ file_ids=file_ids,
+ folder_ids=folder_ids,
+ recursive=recursive,
+ retriever=retriever,
+ operation_mode=operation_mode,
+ doc_id=doc_id,
+ sync_frequency=sync_frequency
+ )
+ return resp
+
+
@celery.on_after_configure.connect
def setup_periodic_tasks(sender, **kwargs):
sender.add_periodic_task(
diff --git a/application/app.py b/application/app.py
index 4159a2bb..489ec840 100644
--- a/application/app.py
+++ b/application/app.py
@@ -16,6 +16,7 @@ from application.api import api # noqa: E402
from application.api.answer import answer # noqa: E402
from application.api.internal.routes import internal # noqa: E402
from application.api.user.routes import user # noqa: E402
+from application.api.connector.routes import connector # noqa: E402
from application.celery_init import celery # noqa: E402
from application.core.settings import settings # noqa: E402
@@ -30,6 +31,7 @@ app = Flask(__name__)
app.register_blueprint(user)
app.register_blueprint(answer)
app.register_blueprint(internal)
+app.register_blueprint(connector)
app.config.update(
UPLOAD_FOLDER="inputs",
CELERY_BROKER_URL=settings.CELERY_BROKER_URL,
diff --git a/application/core/settings.py b/application/core/settings.py
index f1563569..7ede4e86 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -40,6 +40,13 @@ class Settings(BaseSettings):
FALLBACK_LLM_NAME: Optional[str] = None # model name for fallback llm
FALLBACK_LLM_API_KEY: Optional[str] = None # api key for fallback llm
+ # Google Drive integration
+ GOOGLE_CLIENT_ID: Optional[str] = None # Replace with your actual Google OAuth client ID
+ GOOGLE_CLIENT_SECRET: Optional[str] = None# Replace with your actual Google OAuth client secret
+ CONNECTOR_REDIRECT_BASE_URI: Optional[str] = "http://127.0.0.1:7091/api/connectors/callback"
+ ##append ?provider={provider_name} in your Provider console like http://127.0.0.1:7091/api/connectors/callback?provider=google_drive
+
+
# LLM Cache
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
diff --git a/application/parser/connectors/__init__.py b/application/parser/connectors/__init__.py
new file mode 100644
index 00000000..c9add3d7
--- /dev/null
+++ b/application/parser/connectors/__init__.py
@@ -0,0 +1,18 @@
+"""
+External knowledge base connectors for DocsGPT.
+
+This module contains connectors for external knowledge bases and document storage systems
+that require authentication and specialized handling, separate from simple web scrapers.
+"""
+
+from .base import BaseConnectorAuth, BaseConnectorLoader
+from .connector_creator import ConnectorCreator
+from .google_drive import GoogleDriveAuth, GoogleDriveLoader
+
+__all__ = [
+ 'BaseConnectorAuth',
+ 'BaseConnectorLoader',
+ 'ConnectorCreator',
+ 'GoogleDriveAuth',
+ 'GoogleDriveLoader'
+]
diff --git a/application/parser/connectors/base.py b/application/parser/connectors/base.py
new file mode 100644
index 00000000..dfb6de87
--- /dev/null
+++ b/application/parser/connectors/base.py
@@ -0,0 +1,129 @@
+"""
+Base classes for external knowledge base connectors.
+
+This module provides minimal abstract base classes that define the essential
+interface for external knowledge base connectors.
+"""
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional
+
+from application.parser.schema.base import Document
+
+
+class BaseConnectorAuth(ABC):
+ """
+ Abstract base class for connector authentication.
+
+ Defines the minimal interface that all connector authentication
+ implementations must follow.
+ """
+
+ @abstractmethod
+ def get_authorization_url(self, state: Optional[str] = None) -> str:
+ """
+ Generate authorization URL for OAuth flows.
+
+ Args:
+ state: Optional state parameter for CSRF protection
+
+ Returns:
+ Authorization URL
+ """
+ pass
+
+ @abstractmethod
+ def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
+ """
+ Exchange authorization code for access tokens.
+
+ Args:
+ authorization_code: Authorization code from OAuth callback
+
+ Returns:
+ Dictionary containing token information
+ """
+ pass
+
+ @abstractmethod
+ def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
+ """
+ Refresh an expired access token.
+
+ Args:
+ refresh_token: Refresh token
+
+ Returns:
+ Dictionary containing refreshed token information
+ """
+ pass
+
+ @abstractmethod
+ def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
+ """
+ Check if a token is expired.
+
+ Args:
+ token_info: Token information dictionary
+
+ Returns:
+ True if token is expired, False otherwise
+ """
+ pass
+
+
+class BaseConnectorLoader(ABC):
+ """
+ Abstract base class for connector loaders.
+
+ Defines the minimal interface that all connector loader
+ implementations must follow.
+ """
+
+ @abstractmethod
+ def __init__(self, session_token: str):
+ """
+ Initialize the connector loader.
+
+ Args:
+ session_token: Authentication session token
+ """
+ pass
+
+ @abstractmethod
+ def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
+ """
+ Load documents from the external knowledge base.
+
+ Args:
+ inputs: Configuration dictionary containing:
+ - file_ids: Optional list of specific file IDs to load
+ - folder_ids: Optional list of folder IDs to browse/download
+ - limit: Maximum number of items to return
+ - list_only: If True, return metadata without content
+ - recursive: Whether to recursively process folders
+
+ Returns:
+ List of Document objects
+ """
+ pass
+
+ @abstractmethod
+ def download_to_directory(self, local_dir: str, source_config: Dict[str, Any] = None) -> Dict[str, Any]:
+ """
+ Download files/folders to a local directory.
+
+ Args:
+ local_dir: Local directory path to download files to
+ source_config: Configuration for what to download
+
+ Returns:
+ Dictionary containing download results:
+ - files_downloaded: Number of files downloaded
+ - directory_path: Path where files were downloaded
+ - empty_result: Whether no files were downloaded
+ - source_type: Type of connector
+ - config_used: Configuration that was used
+ - error: Error message if download failed (optional)
+ """
+ pass
diff --git a/application/parser/connectors/connector_creator.py b/application/parser/connectors/connector_creator.py
new file mode 100644
index 00000000..bf4456ca
--- /dev/null
+++ b/application/parser/connectors/connector_creator.py
@@ -0,0 +1,81 @@
+from application.parser.connectors.google_drive.loader import GoogleDriveLoader
+from application.parser.connectors.google_drive.auth import GoogleDriveAuth
+
+
+class ConnectorCreator:
+ """
+ Factory class for creating external knowledge base connectors and auth providers.
+
+ These are different from remote loaders as they typically require
+ authentication and connect to external document storage systems.
+ """
+
+ connectors = {
+ "google_drive": GoogleDriveLoader,
+ }
+
+ auth_providers = {
+ "google_drive": GoogleDriveAuth,
+ }
+
+ @classmethod
+ def create_connector(cls, connector_type, *args, **kwargs):
+ """
+ Create a connector instance for the specified type.
+
+ Args:
+ connector_type: Type of connector to create (e.g., 'google_drive')
+ *args, **kwargs: Arguments to pass to the connector constructor
+
+ Returns:
+ Connector instance
+
+ Raises:
+ ValueError: If connector type is not supported
+ """
+ connector_class = cls.connectors.get(connector_type.lower())
+ if not connector_class:
+ raise ValueError(f"No connector class found for type {connector_type}")
+ return connector_class(*args, **kwargs)
+
+ @classmethod
+ def create_auth(cls, connector_type):
+ """
+ Create an auth provider instance for the specified connector type.
+
+ Args:
+ connector_type: Type of connector auth to create (e.g., 'google_drive')
+
+ Returns:
+ Auth provider instance
+
+ Raises:
+ ValueError: If connector type is not supported for auth
+ """
+ auth_class = cls.auth_providers.get(connector_type.lower())
+ if not auth_class:
+ raise ValueError(f"No auth class found for type {connector_type}")
+ return auth_class()
+
+ @classmethod
+ def get_supported_connectors(cls):
+ """
+ Get list of supported connector types.
+
+ Returns:
+ List of supported connector type strings
+ """
+ return list(cls.connectors.keys())
+
+ @classmethod
+ def is_supported(cls, connector_type):
+ """
+ Check if a connector type is supported.
+
+ Args:
+ connector_type: Type of connector to check
+
+ Returns:
+ True if supported, False otherwise
+ """
+ return connector_type.lower() in cls.connectors
diff --git a/application/parser/connectors/google_drive/__init__.py b/application/parser/connectors/google_drive/__init__.py
new file mode 100644
index 00000000..18abeec1
--- /dev/null
+++ b/application/parser/connectors/google_drive/__init__.py
@@ -0,0 +1,10 @@
+"""
+Google Drive connector for DocsGPT.
+
+This module provides authentication and document loading capabilities for Google Drive.
+"""
+
+from .auth import GoogleDriveAuth
+from .loader import GoogleDriveLoader
+
+__all__ = ['GoogleDriveAuth', 'GoogleDriveLoader']
diff --git a/application/parser/connectors/google_drive/auth.py b/application/parser/connectors/google_drive/auth.py
new file mode 100644
index 00000000..37d55dcc
--- /dev/null
+++ b/application/parser/connectors/google_drive/auth.py
@@ -0,0 +1,268 @@
+import logging
+import datetime
+from typing import Optional, Dict, Any
+
+from google.oauth2.credentials import Credentials
+from google_auth_oauthlib.flow import Flow
+from googleapiclient.discovery import build
+from googleapiclient.errors import HttpError
+
+from application.core.settings import settings
+from application.parser.connectors.base import BaseConnectorAuth
+
+
+class GoogleDriveAuth(BaseConnectorAuth):
+ """
+ Handles Google OAuth 2.0 authentication for Google Drive access.
+ """
+
+ SCOPES = [
+ 'https://www.googleapis.com/auth/drive.readonly',
+ 'https://www.googleapis.com/auth/drive.metadata.readonly'
+ ]
+
+ def __init__(self):
+ self.client_id = settings.GOOGLE_CLIENT_ID
+ self.client_secret = settings.GOOGLE_CLIENT_SECRET
+ self.redirect_uri = f"{settings.CONNECTOR_REDIRECT_BASE_URI}?provider=google_drive"
+
+ if not self.client_id or not self.client_secret:
+ raise ValueError("Google OAuth credentials not configured. Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET in settings.")
+
+
+
+ def get_authorization_url(self, state: Optional[str] = None) -> str:
+ try:
+ flow = Flow.from_client_config(
+ {
+ "web": {
+ "client_id": self.client_id,
+ "client_secret": self.client_secret,
+ "auth_uri": "https://accounts.google.com/o/oauth2/auth",
+ "token_uri": "https://oauth2.googleapis.com/token",
+ "redirect_uris": [self.redirect_uri]
+ }
+ },
+ scopes=self.SCOPES
+ )
+ flow.redirect_uri = self.redirect_uri
+
+ authorization_url, _ = flow.authorization_url(
+ access_type='offline',
+ prompt='consent',
+ include_granted_scopes='true',
+ state=state
+ )
+
+ return authorization_url
+
+ except Exception as e:
+ logging.error(f"Error generating authorization URL: {e}")
+ raise
+
+ def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
+ try:
+ if not authorization_code:
+ raise ValueError("Authorization code is required")
+
+ flow = Flow.from_client_config(
+ {
+ "web": {
+ "client_id": self.client_id,
+ "client_secret": self.client_secret,
+ "auth_uri": "https://accounts.google.com/o/oauth2/auth",
+ "token_uri": "https://oauth2.googleapis.com/token",
+ "redirect_uris": [self.redirect_uri]
+ }
+ },
+ scopes=self.SCOPES
+ )
+ flow.redirect_uri = self.redirect_uri
+
+ flow.fetch_token(code=authorization_code)
+
+ credentials = flow.credentials
+
+ if not credentials.refresh_token:
+ logging.warning("OAuth flow did not return a refresh_token.")
+ if not credentials.token:
+ raise ValueError("OAuth flow did not return an access token")
+
+ if not credentials.token_uri:
+ credentials.token_uri = "https://oauth2.googleapis.com/token"
+
+ if not credentials.client_id:
+ credentials.client_id = self.client_id
+
+ if not credentials.client_secret:
+ credentials.client_secret = self.client_secret
+
+ if not credentials.refresh_token:
+ raise ValueError(
+ "No refresh token received. This typically happens when offline access wasn't granted. "
+ )
+
+ return {
+ 'access_token': credentials.token,
+ 'refresh_token': credentials.refresh_token,
+ 'token_uri': credentials.token_uri,
+ 'client_id': credentials.client_id,
+ 'client_secret': credentials.client_secret,
+ 'scopes': credentials.scopes,
+ 'expiry': credentials.expiry.isoformat() if credentials.expiry else None
+ }
+
+ except Exception as e:
+ logging.error(f"Error exchanging code for tokens: {e}")
+ raise
+
+ def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
+ try:
+ if not refresh_token:
+ raise ValueError("Refresh token is required")
+
+ credentials = Credentials(
+ token=None,
+ refresh_token=refresh_token,
+ token_uri="https://oauth2.googleapis.com/token",
+ client_id=self.client_id,
+ client_secret=self.client_secret
+ )
+
+ from google.auth.transport.requests import Request
+ credentials.refresh(Request())
+
+ return {
+ 'access_token': credentials.token,
+ 'refresh_token': refresh_token,
+ 'token_uri': credentials.token_uri,
+ 'client_id': credentials.client_id,
+ 'client_secret': credentials.client_secret,
+ 'scopes': credentials.scopes,
+ 'expiry': credentials.expiry.isoformat() if credentials.expiry else None
+ }
+ except Exception as e:
+ logging.error(f"Error refreshing access token: {e}", exc_info=True)
+ raise
+
+ def create_credentials_from_token_info(self, token_info: Dict[str, Any]) -> Credentials:
+ from application.core.settings import settings
+
+ access_token = token_info.get('access_token')
+ if not access_token:
+ raise ValueError("No access token found in token_info")
+
+ credentials = Credentials(
+ token=access_token,
+ refresh_token=token_info.get('refresh_token'),
+ token_uri= 'https://oauth2.googleapis.com/token',
+ client_id=settings.GOOGLE_CLIENT_ID,
+ client_secret=settings.GOOGLE_CLIENT_SECRET,
+ scopes=token_info.get('scopes', ['https://www.googleapis.com/auth/drive.readonly'])
+ )
+
+ if not credentials.token:
+ raise ValueError("Credentials created without valid access token")
+
+ return credentials
+
+ def build_drive_service(self, credentials: Credentials):
+ try:
+ if not credentials:
+ raise ValueError("No credentials provided")
+
+ if not credentials.token and not credentials.refresh_token:
+ raise ValueError("No access token or refresh token available. User must re-authorize with offline access.")
+
+ needs_refresh = credentials.expired or not credentials.token
+ if needs_refresh:
+ if credentials.refresh_token:
+ try:
+ from google.auth.transport.requests import Request
+ credentials.refresh(Request())
+ except Exception as refresh_error:
+ raise ValueError(f"Failed to refresh credentials: {refresh_error}")
+ else:
+ raise ValueError("No access token or refresh token available. User must re-authorize with offline access.")
+
+ return build('drive', 'v3', credentials=credentials)
+
+ except HttpError as e:
+ raise ValueError(f"Failed to build Google Drive service: HTTP {e.resp.status}")
+ except Exception as e:
+ raise ValueError(f"Failed to build Google Drive service: {str(e)}")
+
+ def is_token_expired(self, token_info):
+ if 'expiry' in token_info and token_info['expiry']:
+ try:
+ from dateutil import parser
+ # Google Drive provides timezone-aware ISO8601 dates
+ expiry_dt = parser.parse(token_info['expiry'])
+ current_time = datetime.datetime.now(datetime.timezone.utc)
+ return current_time >= expiry_dt - datetime.timedelta(seconds=60)
+ except Exception:
+ return True
+
+ if 'access_token' in token_info and token_info['access_token']:
+ return False
+
+ return True
+
+ def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
+ try:
+ from application.core.mongo_db import MongoDB
+ from application.core.settings import settings
+
+ mongo = MongoDB.get_client()
+ db = mongo[settings.MONGO_DB_NAME]
+
+ sessions_collection = db["connector_sessions"]
+ session = sessions_collection.find_one({"session_token": session_token})
+ if not session:
+ raise ValueError(f"Invalid session token: {session_token}")
+
+ if "token_info" not in session:
+ raise ValueError("Session missing token information")
+
+ token_info = session["token_info"]
+ if not token_info:
+ raise ValueError("Invalid token information")
+
+ required_fields = ["access_token", "refresh_token"]
+ missing_fields = [field for field in required_fields if field not in token_info or not token_info.get(field)]
+ if missing_fields:
+ raise ValueError(f"Missing required token fields: {missing_fields}")
+
+ if 'client_id' not in token_info:
+ token_info['client_id'] = settings.GOOGLE_CLIENT_ID
+ if 'client_secret' not in token_info:
+ token_info['client_secret'] = settings.GOOGLE_CLIENT_SECRET
+ if 'token_uri' not in token_info:
+ token_info['token_uri'] = 'https://oauth2.googleapis.com/token'
+
+ return token_info
+
+ except Exception as e:
+ raise ValueError(f"Failed to retrieve Google Drive token information: {str(e)}")
+
+ def validate_credentials(self, credentials: Credentials) -> bool:
+ """
+ Validate Google Drive credentials by making a test API call.
+
+ Args:
+ credentials: Google credentials object
+
+ Returns:
+ True if credentials are valid, False otherwise
+ """
+ try:
+ service = self.build_drive_service(credentials)
+ service.about().get(fields="user").execute()
+ return True
+
+ except HttpError as e:
+ logging.error(f"HTTP error validating credentials: {e}")
+ return False
+ except Exception as e:
+ logging.error(f"Error validating credentials: {e}")
+ return False
diff --git a/application/parser/connectors/google_drive/loader.py b/application/parser/connectors/google_drive/loader.py
new file mode 100644
index 00000000..07219344
--- /dev/null
+++ b/application/parser/connectors/google_drive/loader.py
@@ -0,0 +1,536 @@
+"""
+Google Drive loader for DocsGPT.
+Loads documents from Google Drive using Google Drive API.
+"""
+
+import io
+import logging
+import os
+from typing import List, Dict, Any, Optional
+
+from googleapiclient.http import MediaIoBaseDownload
+from googleapiclient.errors import HttpError
+
+from application.parser.connectors.base import BaseConnectorLoader
+from application.parser.connectors.google_drive.auth import GoogleDriveAuth
+from application.parser.schema.base import Document
+
+
+class GoogleDriveLoader(BaseConnectorLoader):
+
+ SUPPORTED_MIME_TYPES = {
+ 'application/pdf': '.pdf',
+ 'application/vnd.google-apps.document': '.docx',
+ 'application/vnd.google-apps.presentation': '.pptx',
+ 'application/vnd.google-apps.spreadsheet': '.xlsx',
+ 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
+ 'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
+ 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
+ 'application/msword': '.doc',
+ 'application/vnd.ms-powerpoint': '.ppt',
+ 'application/vnd.ms-excel': '.xls',
+ 'text/plain': '.txt',
+ 'text/csv': '.csv',
+ 'text/html': '.html',
+ 'application/rtf': '.rtf',
+ 'image/jpeg': '.jpg',
+ 'image/jpg': '.jpg',
+ 'image/png': '.png',
+ }
+
+ EXPORT_FORMATS = {
+ 'application/vnd.google-apps.document': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
+ 'application/vnd.google-apps.presentation': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
+ 'application/vnd.google-apps.spreadsheet': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
+ }
+
+ def __init__(self, session_token: str):
+ self.auth = GoogleDriveAuth()
+ self.session_token = session_token
+
+ token_info = self.auth.get_token_info_from_session(session_token)
+ self.credentials = self.auth.create_credentials_from_token_info(token_info)
+
+ try:
+ self.service = self.auth.build_drive_service(self.credentials)
+ except Exception as e:
+ logging.warning(f"Could not build Google Drive service: {e}")
+ self.service = None
+
+ self.next_page_token = None
+
+
+
+ def _process_file(self, file_metadata: Dict[str, Any], load_content: bool = True) -> Optional[Document]:
+ try:
+ file_id = file_metadata.get('id')
+ file_name = file_metadata.get('name', 'Unknown')
+ mime_type = file_metadata.get('mimeType', 'application/octet-stream')
+
+ if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'):
+ return None
+ if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'):
+ logging.info(f"Skipping unsupported file type: {mime_type} for file {file_name}")
+ return None
+ # Google Drive provides timezone-aware ISO8601 dates
+ doc_metadata = {
+ 'file_name': file_name,
+ 'mime_type': mime_type,
+ 'size': file_metadata.get('size', None),
+ 'created_time': file_metadata.get('createdTime'),
+ 'modified_time': file_metadata.get('modifiedTime'),
+ 'parents': file_metadata.get('parents', []),
+ 'source': 'google_drive'
+ }
+
+ if not load_content:
+ return Document(
+ text="",
+ doc_id=file_id,
+ extra_info=doc_metadata
+ )
+
+ content = self._download_file_content(file_id, mime_type)
+ if content is None:
+ logging.warning(f"Could not load content for file {file_name} ({file_id})")
+ return None
+
+ return Document(
+ text=content,
+ doc_id=file_id,
+ extra_info=doc_metadata
+ )
+
+ except Exception as e:
+ logging.error(f"Error processing file: {e}")
+ return None
+
+ def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
+ session_token = inputs.get('session_token')
+ if session_token and session_token != self.session_token:
+ logging.warning("Session token in inputs differs from loader's session token. Using loader's session token.")
+ self.config = inputs
+
+ try:
+ documents: List[Document] = []
+
+ folder_id = inputs.get('folder_id')
+ file_ids = inputs.get('file_ids', [])
+ limit = inputs.get('limit', 100)
+ list_only = inputs.get('list_only', False)
+ load_content = not list_only
+ page_token = inputs.get('page_token')
+ self.next_page_token = None
+
+ if file_ids:
+ # Specific files requested: load them
+ for file_id in file_ids:
+ try:
+ doc = self._load_file_by_id(file_id, load_content=load_content)
+ if doc:
+ documents.append(doc)
+ elif hasattr(self, '_credential_refreshed') and self._credential_refreshed:
+ self._credential_refreshed = False
+ logging.info(f"Retrying load of file {file_id} after credential refresh")
+ doc = self._load_file_by_id(file_id, load_content=load_content)
+ if doc:
+ documents.append(doc)
+ except Exception as e:
+ logging.error(f"Error loading file {file_id}: {e}")
+ continue
+ else:
+ # Browsing mode: list immediate children of provided folder or root
+ parent_id = folder_id if folder_id else 'root'
+ documents = self._list_items_in_parent(parent_id, limit=limit, load_content=load_content, page_token=page_token)
+
+ logging.info(f"Loaded {len(documents)} documents from Google Drive")
+ return documents
+
+ except Exception as e:
+ logging.error(f"Error loading data from Google Drive: {e}", exc_info=True)
+ raise
+
+
+
+ def _load_file_by_id(self, file_id: str, load_content: bool = True) -> Optional[Document]:
+ self._ensure_service()
+
+ try:
+ file_metadata = self.service.files().get(
+ fileId=file_id,
+ fields='id,name,mimeType,size,createdTime,modifiedTime,parents'
+ ).execute()
+
+ return self._process_file(file_metadata, load_content=load_content)
+
+ except HttpError as e:
+ logging.error(f"HTTP error loading file {file_id}: {e.resp.status} - {e.content}")
+
+ if e.resp.status in [401, 403]:
+ if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
+ try:
+ from google.auth.transport.requests import Request
+ self.credentials.refresh(Request())
+ self._ensure_service()
+ return None
+ except Exception as refresh_error:
+ raise ValueError(f"Authentication failed and could not be refreshed: {refresh_error}")
+ else:
+ raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
+
+ return None
+ except Exception as e:
+ logging.error(f"Error loading file {file_id}: {e}")
+ return None
+
+
+ def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None) -> List[Document]:
+ self._ensure_service()
+
+ documents: List[Document] = []
+
+ try:
+ query = f"'{parent_id}' in parents and trashed=false"
+ next_token_out: Optional[str] = None
+
+ while True:
+ page_size = 100
+ if limit:
+ remaining = max(0, limit - len(documents))
+ if remaining == 0:
+ break
+ page_size = min(100, remaining)
+
+ results = self.service.files().list(
+ q=query,
+ fields='nextPageToken,files(id,name,mimeType,size,createdTime,modifiedTime,parents)',
+ pageToken=page_token,
+ pageSize=page_size
+ ).execute()
+
+ items = results.get('files', [])
+ for item in items:
+ mime_type = item.get('mimeType')
+ if mime_type == 'application/vnd.google-apps.folder':
+ doc_metadata = {
+ 'file_name': item.get('name', 'Unknown'),
+ 'mime_type': mime_type,
+ 'size': item.get('size', None),
+ 'created_time': item.get('createdTime'),
+ 'modified_time': item.get('modifiedTime'),
+ 'parents': item.get('parents', []),
+ 'source': 'google_drive',
+ 'is_folder': True
+ }
+ documents.append(Document(text="", doc_id=item.get('id'), extra_info=doc_metadata))
+ else:
+ doc = self._process_file(item, load_content=load_content)
+ if doc:
+ documents.append(doc)
+
+ if limit and len(documents) >= limit:
+ self.next_page_token = results.get('nextPageToken')
+ return documents
+
+ page_token = results.get('nextPageToken')
+ next_token_out = page_token
+ if not page_token:
+ break
+
+ self.next_page_token = next_token_out
+ return documents
+ except Exception as e:
+ logging.error(f"Error listing items under parent {parent_id}: {e}")
+ return documents
+
+
+
+
+ def _download_file_content(self, file_id: str, mime_type: str) -> Optional[str]:
+ if not self.credentials.token:
+ logging.warning("No access token in credentials, attempting to refresh")
+ if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
+ try:
+ from google.auth.transport.requests import Request
+ self.credentials.refresh(Request())
+ logging.info("Credentials refreshed successfully")
+ self._ensure_service()
+ except Exception as e:
+ logging.error(f"Failed to refresh credentials: {e}")
+ raise ValueError("Authentication failed and cannot be refreshed: missing or invalid refresh_token")
+ else:
+ logging.error("No access token and no refresh_token available")
+ raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
+
+ if self.credentials.expired:
+ logging.warning("Credentials are expired, attempting to refresh")
+ if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
+ try:
+ from google.auth.transport.requests import Request
+ self.credentials.refresh(Request())
+ logging.info("Credentials refreshed successfully")
+ self._ensure_service()
+ except Exception as e:
+ logging.error(f"Failed to refresh expired credentials: {e}")
+ raise ValueError("Authentication failed and cannot be refreshed: expired credentials")
+ else:
+ logging.error("Credentials expired and no refresh_token available")
+ raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
+
+ try:
+ if mime_type in self.EXPORT_FORMATS:
+ export_mime_type = self.EXPORT_FORMATS[mime_type]
+ request = self.service.files().export_media(
+ fileId=file_id,
+ mimeType=export_mime_type
+ )
+ else:
+ request = self.service.files().get_media(fileId=file_id)
+
+ file_io = io.BytesIO()
+ downloader = MediaIoBaseDownload(file_io, request)
+
+ done = False
+ while done is False:
+ try:
+ _, done = downloader.next_chunk()
+ except HttpError as e:
+ logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}")
+ return None
+ except Exception as e:
+ logging.error(f"Error during download of file {file_id}: {e}")
+ return None
+
+ content_bytes = file_io.getvalue()
+
+ try:
+ content = content_bytes.decode('utf-8')
+ except UnicodeDecodeError:
+ try:
+ content = content_bytes.decode('latin-1')
+ except UnicodeDecodeError:
+ logging.error(f"Could not decode file {file_id} as text")
+ return None
+
+ return content
+
+ except HttpError as e:
+ logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}")
+
+ if e.resp.status in [401, 403]:
+ logging.error(f"Authentication error downloading file {file_id}")
+
+ if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
+ logging.info(f"Attempting to refresh credentials for file {file_id}")
+ try:
+ from google.auth.transport.requests import Request
+ self.credentials.refresh(Request())
+ logging.info("Credentials refreshed successfully")
+ self._credential_refreshed = True
+ self._ensure_service()
+ return None
+ except Exception as refresh_error:
+ logging.error(f"Error refreshing credentials: {refresh_error}")
+ raise ValueError(f"Authentication failed and could not be refreshed: {refresh_error}")
+ else:
+ logging.error("Cannot refresh credentials: missing refresh_token")
+ raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
+
+ return None
+ except Exception as e:
+ logging.error(f"Error downloading file {file_id}: {e}")
+ return None
+
+
+ def _download_file_to_directory(self, file_id: str, local_dir: str) -> bool:
+ try:
+ self._ensure_service()
+ return self._download_single_file(file_id, local_dir)
+ except Exception as e:
+ logging.error(f"Error downloading file {file_id}: {e}", exc_info=True)
+ return False
+
+ def _ensure_service(self):
+ if not self.service:
+ try:
+ self.service = self.auth.build_drive_service(self.credentials)
+ except Exception as e:
+ raise ValueError(f"Cannot access Google Drive: {e}")
+
+ def _download_single_file(self, file_id: str, local_dir: str) -> bool:
+ file_metadata = self.service.files().get(
+ fileId=file_id,
+ fields='name,mimeType'
+ ).execute()
+
+ file_name = file_metadata['name']
+ mime_type = file_metadata['mimeType']
+
+ if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'):
+ return False
+
+ os.makedirs(local_dir, exist_ok=True)
+ full_path = os.path.join(local_dir, file_name)
+
+ if mime_type in self.EXPORT_FORMATS:
+ export_mime_type = self.EXPORT_FORMATS[mime_type]
+ request = self.service.files().export_media(
+ fileId=file_id,
+ mimeType=export_mime_type
+ )
+ extension = self._get_extension_for_mime_type(export_mime_type)
+ if not full_path.endswith(extension):
+ full_path += extension
+ else:
+ request = self.service.files().get_media(fileId=file_id)
+
+ with open(full_path, 'wb') as f:
+ downloader = MediaIoBaseDownload(f, request)
+ done = False
+ while not done:
+ _, done = downloader.next_chunk()
+
+ return True
+
+ def _download_folder_recursive(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
+ files_downloaded = 0
+ try:
+ os.makedirs(local_dir, exist_ok=True)
+
+ query = f"'{folder_id}' in parents and trashed=false"
+ page_token = None
+
+ while True:
+ results = self.service.files().list(
+ q=query,
+ fields='nextPageToken, files(id, name, mimeType)',
+ pageToken=page_token,
+ pageSize=1000
+ ).execute()
+
+ items = results.get('files', [])
+ logging.info(f"Found {len(items)} items in folder {folder_id}")
+
+ for item in items:
+ item_name = item['name']
+ item_id = item['id']
+ mime_type = item['mimeType']
+
+ if mime_type == 'application/vnd.google-apps.folder':
+ if recursive:
+ # Create subfolder and recurse
+ subfolder_path = os.path.join(local_dir, item_name)
+ os.makedirs(subfolder_path, exist_ok=True)
+ subfolder_files = self._download_folder_recursive(
+ item_id,
+ subfolder_path,
+ recursive
+ )
+ files_downloaded += subfolder_files
+ logging.info(f"Downloaded {subfolder_files} files from subfolder {item_name}")
+ else:
+ # Download file
+ success = self._download_single_file(item_id, local_dir)
+ if success:
+ files_downloaded += 1
+ logging.info(f"Downloaded file: {item_name}")
+ else:
+ logging.warning(f"Failed to download file: {item_name}")
+
+ page_token = results.get('nextPageToken')
+ if not page_token:
+ break
+
+ return files_downloaded
+
+ except Exception as e:
+ logging.error(f"Error in _download_folder_recursive for folder {folder_id}: {e}", exc_info=True)
+ return files_downloaded
+
+ def _get_extension_for_mime_type(self, mime_type: str) -> str:
+ extensions = {
+ 'application/pdf': '.pdf',
+ 'text/plain': '.txt',
+ 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
+ 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
+ 'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
+ 'text/html': '.html',
+ 'text/markdown': '.md',
+ }
+ return extensions.get(mime_type, '.bin')
+
+ def _download_folder_contents(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
+ try:
+ self._ensure_service()
+ return self._download_folder_recursive(folder_id, local_dir, recursive)
+ except Exception as e:
+ logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
+ return 0
+
+ def download_to_directory(self, local_dir: str, source_config: dict = None) -> dict:
+ if source_config is None:
+ source_config = {}
+
+ config = source_config if source_config else getattr(self, 'config', {})
+ files_downloaded = 0
+
+ try:
+ folder_ids = config.get('folder_ids', [])
+ file_ids = config.get('file_ids', [])
+ recursive = config.get('recursive', True)
+
+ self._ensure_service()
+
+ if file_ids:
+ if isinstance(file_ids, str):
+ file_ids = [file_ids]
+
+ for file_id in file_ids:
+ if self._download_file_to_directory(file_id, local_dir):
+ files_downloaded += 1
+
+ # Process folders
+ if folder_ids:
+ if isinstance(folder_ids, str):
+ folder_ids = [folder_ids]
+
+ for folder_id in folder_ids:
+ try:
+ folder_metadata = self.service.files().get(
+ fileId=folder_id,
+ fields='name'
+ ).execute()
+ folder_name = folder_metadata.get('name', '')
+ folder_path = os.path.join(local_dir, folder_name)
+ os.makedirs(folder_path, exist_ok=True)
+
+ folder_files = self._download_folder_recursive(
+ folder_id,
+ folder_path,
+ recursive
+ )
+ files_downloaded += folder_files
+ logging.info(f"Downloaded {folder_files} files from folder {folder_name}")
+ except Exception as e:
+ logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
+
+ if not file_ids and not folder_ids:
+ raise ValueError("No folder_ids or file_ids provided for download")
+
+ return {
+ "files_downloaded": files_downloaded,
+ "directory_path": local_dir,
+ "empty_result": files_downloaded == 0,
+ "source_type": "google_drive",
+ "config_used": config
+ }
+
+ except Exception as e:
+ return {
+ "files_downloaded": files_downloaded,
+ "directory_path": local_dir,
+ "empty_result": True,
+ "source_type": "google_drive",
+ "config_used": config,
+ "error": str(e)
+ }
diff --git a/application/parser/remote/remote_creator.py b/application/parser/remote/remote_creator.py
index 026abd76..a47b186a 100644
--- a/application/parser/remote/remote_creator.py
+++ b/application/parser/remote/remote_creator.py
@@ -6,6 +6,16 @@ from application.parser.remote.github_loader import GitHubLoader
class RemoteCreator:
+ """
+ Factory class for creating remote content loaders.
+
+ These loaders fetch content from remote web sources like URLs,
+ sitemaps, web crawlers, social media platforms, etc.
+
+ For external knowledge base connectors (like Google Drive),
+ use ConnectorCreator instead.
+ """
+
loaders = {
"url": WebLoader,
"sitemap": SitemapLoader,
@@ -18,5 +28,5 @@ class RemoteCreator:
def create_loader(cls, type, *args, **kwargs):
loader_class = cls.loaders.get(type.lower())
if not loader_class:
- raise ValueError(f"No LLM class found for type {type}")
+ raise ValueError(f"No loader class found for type {type}")
return loader_class(*args, **kwargs)
diff --git a/application/requirements.txt b/application/requirements.txt
index f922a2cb..80564689 100644
--- a/application/requirements.txt
+++ b/application/requirements.txt
@@ -14,6 +14,9 @@ Flask==3.1.1
faiss-cpu==1.9.0.post1
flask-restx==1.3.0
google-genai==1.3.0
+google-api-python-client==2.179.0
+google-auth-httplib2==0.2.0
+google-auth-oauthlib==1.2.2
gTTS==2.5.4
gunicorn==23.0.0
javalang==0.13.0
diff --git a/application/retriever/base.py b/application/retriever/base.py
index fd99dbdd..36ac2e93 100644
--- a/application/retriever/base.py
+++ b/application/retriever/base.py
@@ -5,10 +5,6 @@ class BaseRetriever(ABC):
def __init__(self):
pass
- @abstractmethod
- def gen(self, *args, **kwargs):
- pass
-
@abstractmethod
def search(self, *args, **kwargs):
pass
diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py
index 9416b4f7..2ce863c2 100644
--- a/application/retriever/classic_rag.py
+++ b/application/retriever/classic_rag.py
@@ -1,4 +1,5 @@
import logging
+
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.retriever.base import BaseRetriever
@@ -20,10 +21,20 @@ class ClassicRAG(BaseRetriever):
api_key=settings.API_KEY,
decoded_token=None,
):
- self.original_question = ""
+ """Initialize ClassicRAG retriever with vectorstore sources and LLM configuration"""
+ self.original_question = source.get("question", "")
self.chat_history = chat_history if chat_history is not None else []
self.prompt = prompt
- self.chunks = chunks
+ if isinstance(chunks, str):
+ try:
+ self.chunks = int(chunks)
+ except ValueError:
+ logging.warning(
+ f"Invalid chunks value '{chunks}', using default value 2"
+ )
+ self.chunks = 2
+ else:
+ self.chunks = chunks
self.gpt_model = gpt_model
self.token_limit = (
token_limit
@@ -44,25 +55,52 @@ class ClassicRAG(BaseRetriever):
user_api_key=self.user_api_key,
decoded_token=decoded_token,
)
- self.vectorstore = source["active_docs"] if "active_docs" in source else None
+
+ if "active_docs" in source and source["active_docs"] is not None:
+ if isinstance(source["active_docs"], list):
+ self.vectorstores = source["active_docs"]
+ else:
+ self.vectorstores = [source["active_docs"]]
+ else:
+ self.vectorstores = []
self.question = self._rephrase_query()
self.decoded_token = decoded_token
+ self._validate_vectorstore_config()
+
+ def _validate_vectorstore_config(self):
+ """Validate vectorstore IDs and remove any empty/invalid entries"""
+ if not self.vectorstores:
+ logging.warning("No vectorstores configured for retrieval")
+ return
+ invalid_ids = [
+ vs_id for vs_id in self.vectorstores if not vs_id or not vs_id.strip()
+ ]
+ if invalid_ids:
+ logging.warning(f"Found invalid vectorstore IDs: {invalid_ids}")
+ self.vectorstores = [
+ vs_id for vs_id in self.vectorstores if vs_id and vs_id.strip()
+ ]
def _rephrase_query(self):
+ """Rephrase user query with chat history context for better retrieval"""
if (
not self.original_question
or not self.chat_history
or self.chat_history == []
or self.chunks == 0
- or self.vectorstore is None
+ or not self.vectorstores
):
return self.original_question
-
prompt = f"""Given the following conversation history:
+
{self.chat_history}
+
+
Rephrase the following user question to be a standalone search query
+
that captures all relevant context from the conversation:
+
"""
messages = [
@@ -79,44 +117,62 @@ class ClassicRAG(BaseRetriever):
return self.original_question
def _get_data(self):
- if self.chunks == 0 or self.vectorstore is None:
- docs = []
- else:
- docsearch = VectorCreator.create_vectorstore(
- settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY
- )
- docs_temp = docsearch.search(self.question, k=self.chunks)
- docs = [
- {
- "title": i.metadata.get(
- "title", i.metadata.get("post_title", i.page_content)
- ).split("/")[-1],
- "text": i.page_content,
- "source": (
- i.metadata.get("source")
- if i.metadata.get("source")
- else "local"
- ),
- }
- for i in docs_temp
- ]
+ """Retrieve relevant documents from configured vectorstores"""
+ if self.chunks == 0 or not self.vectorstores:
+ return []
+ all_docs = []
+ chunks_per_source = max(1, self.chunks // len(self.vectorstores))
- return docs
+ for vectorstore_id in self.vectorstores:
+ if vectorstore_id:
+ try:
+ docsearch = VectorCreator.create_vectorstore(
+ settings.VECTOR_STORE, vectorstore_id, settings.EMBEDDINGS_KEY
+ )
+ docs_temp = docsearch.search(self.question, k=chunks_per_source)
- def gen():
- pass
+ for doc in docs_temp:
+ if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
+ page_content = doc.page_content
+ metadata = doc.metadata
+ else:
+ page_content = doc.get("text", doc.get("page_content", ""))
+ metadata = doc.get("metadata", {})
+ title = metadata.get(
+ "title", metadata.get("post_title", page_content)
+ )
+ if isinstance(title, str):
+ title = title.split("/")[-1]
+ else:
+ title = str(title).split("/")[-1]
+ all_docs.append(
+ {
+ "title": title,
+ "text": page_content,
+ "source": metadata.get("source") or vectorstore_id,
+ }
+ )
+ except Exception as e:
+ logging.error(
+ f"Error searching vectorstore {vectorstore_id}: {e}",
+ exc_info=True,
+ )
+ continue
+ return all_docs
def search(self, query: str = ""):
+ """Search for documents using optional query override"""
if query:
self.original_question = query
self.question = self._rephrase_query()
return self._get_data()
def get_params(self):
+ """Return current retriever configuration parameters"""
return {
"question": self.original_question,
"rephrased_question": self.question,
- "source": self.vectorstore,
+ "sources": self.vectorstores,
"chunks": self.chunks,
"token_limit": self.token_limit,
"gpt_model": self.gpt_model,
diff --git a/application/vectorstore/base.py b/application/vectorstore/base.py
index a6b206c9..ea4885cd 100644
--- a/application/vectorstore/base.py
+++ b/application/vectorstore/base.py
@@ -1,20 +1,28 @@
-from abc import ABC, abstractmethod
import os
-from sentence_transformers import SentenceTransformer
+from abc import ABC, abstractmethod
+
from langchain_openai import OpenAIEmbeddings
+from sentence_transformers import SentenceTransformer
+
from application.core.settings import settings
+
class EmbeddingsWrapper:
def __init__(self, model_name, *args, **kwargs):
- self.model = SentenceTransformer(model_name, config_kwargs={'allow_dangerous_deserialization': True}, *args, **kwargs)
+ self.model = SentenceTransformer(
+ model_name,
+ config_kwargs={"allow_dangerous_deserialization": True},
+ *args,
+ **kwargs
+ )
self.dimension = self.model.get_sentence_embedding_dimension()
def embed_query(self, query: str):
return self.model.encode(query).tolist()
-
+
def embed_documents(self, documents: list):
return self.model.encode(documents).tolist()
-
+
def __call__(self, text):
if isinstance(text, str):
return self.embed_query(text)
@@ -24,15 +32,14 @@ class EmbeddingsWrapper:
raise ValueError("Input must be a string or a list of strings")
-
class EmbeddingsSingleton:
_instances = {}
@staticmethod
def get_instance(embeddings_name, *args, **kwargs):
if embeddings_name not in EmbeddingsSingleton._instances:
- EmbeddingsSingleton._instances[embeddings_name] = EmbeddingsSingleton._create_instance(
- embeddings_name, *args, **kwargs
+ EmbeddingsSingleton._instances[embeddings_name] = (
+ EmbeddingsSingleton._create_instance(embeddings_name, *args, **kwargs)
)
return EmbeddingsSingleton._instances[embeddings_name]
@@ -40,9 +47,15 @@ class EmbeddingsSingleton:
def _create_instance(embeddings_name, *args, **kwargs):
embeddings_factory = {
"openai_text-embedding-ada-002": OpenAIEmbeddings,
- "huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
- "huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
- "huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper("hkunlp/instructor-large"),
+ "huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper(
+ "sentence-transformers/all-mpnet-base-v2"
+ ),
+ "huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper(
+ "sentence-transformers/all-mpnet-base-v2"
+ ),
+ "huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper(
+ "hkunlp/instructor-large"
+ ),
}
if embeddings_name in embeddings_factory:
@@ -50,34 +63,63 @@ class EmbeddingsSingleton:
else:
return EmbeddingsWrapper(embeddings_name, *args, **kwargs)
+
class BaseVectorStore(ABC):
def __init__(self):
pass
@abstractmethod
def search(self, *args, **kwargs):
+ """Search for similar documents/chunks in the vectorstore"""
+ pass
+
+ @abstractmethod
+ def add_texts(self, texts, metadatas=None, *args, **kwargs):
+ """Add texts with their embeddings to the vectorstore"""
+ pass
+
+ def delete_index(self, *args, **kwargs):
+ """Delete the entire index/collection"""
+ pass
+
+ def save_local(self, *args, **kwargs):
+ """Save vectorstore to local storage"""
+ pass
+
+ def get_chunks(self, *args, **kwargs):
+ """Get all chunks from the vectorstore"""
+ pass
+
+ def add_chunk(self, text, metadata=None, *args, **kwargs):
+ """Add a single chunk to the vectorstore"""
+ pass
+
+ def delete_chunk(self, chunk_id, *args, **kwargs):
+ """Delete a specific chunk from the vectorstore"""
pass
def is_azure_configured(self):
- return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
+ return (
+ settings.OPENAI_API_BASE
+ and settings.OPENAI_API_VERSION
+ and settings.AZURE_DEPLOYMENT_NAME
+ )
def _get_embeddings(self, embeddings_name, embeddings_key=None):
if embeddings_name == "openai_text-embedding-ada-002":
if self.is_azure_configured():
os.environ["OPENAI_API_TYPE"] = "azure"
embedding_instance = EmbeddingsSingleton.get_instance(
- embeddings_name,
- model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
+ embeddings_name, model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
)
else:
embedding_instance = EmbeddingsSingleton.get_instance(
- embeddings_name,
- openai_api_key=embeddings_key
+ embeddings_name, openai_api_key=embeddings_key
)
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
if os.path.exists("./models/all-mpnet-base-v2"):
embedding_instance = EmbeddingsSingleton.get_instance(
- embeddings_name = "./models/all-mpnet-base-v2",
+ embeddings_name="./models/all-mpnet-base-v2",
)
else:
embedding_instance = EmbeddingsSingleton.get_instance(
@@ -87,4 +129,3 @@ class BaseVectorStore(ABC):
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
return embedding_instance
-
diff --git a/application/worker.py b/application/worker.py
index 7309806d..10fb6c2b 100755
--- a/application/worker.py
+++ b/application/worker.py
@@ -6,6 +6,7 @@ import os
import shutil
import string
import tempfile
+from typing import Any, Dict
import zipfile
from collections import Counter
@@ -21,6 +22,7 @@ from application.api.answer.services.stream_processor import get_prompt
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.parser.chunking import Chunker
+from application.parser.connectors.connector_creator import ConnectorCreator
from application.parser.embedding_pipeline import embed_and_store_documents
from application.parser.file.bulk import SimpleDirectoryReader
from application.parser.remote.remote_creator import RemoteCreator
@@ -649,8 +651,11 @@ def remote_worker(
"id": str(id),
"type": loader,
"remote_data": source_data,
- "sync_frequency": sync_frequency,
+ "sync_frequency": sync_frequency
}
+
+ if operation_mode == "sync":
+ file_data["last_sync"] = datetime.datetime.now()
upload_index(full_path, file_data)
except Exception as e:
logging.error("Error in remote_worker task: %s", str(e), exc_info=True)
@@ -707,7 +712,7 @@ def sync_worker(self, frequency):
self, source_data, name, user, source_type, frequency, retriever, doc_id
)
sync_counts["total_sync_count"] += 1
- sync_counts[
+ sync_counts[
"sync_success" if resp["status"] == "success" else "sync_failure"
] += 1
return {
@@ -744,7 +749,7 @@ def attachment_worker(self, file_info, user):
input_files=[local_path], exclude_hidden=True, errors="ignore"
)
.load_data()[0]
- .text,
+ .text,
)
@@ -835,3 +840,174 @@ def agent_webhook_worker(self, agent_id, payload):
f"Webhook processed for agent {agent_id}", extra={"agent_id": agent_id}
)
return {"status": "success", "result": result}
+
+
+def ingest_connector(
+ self,
+ job_name: str,
+ user: str,
+ source_type: str,
+ session_token=None,
+ file_ids=None,
+ folder_ids=None,
+ recursive=True,
+ retriever: str = "classic",
+ operation_mode: str = "upload",
+ doc_id=None,
+ sync_frequency: str = "never",
+) -> Dict[str, Any]:
+ """
+ Ingestion for internal knowledge bases (GoogleDrive, etc.).
+
+ Args:
+ job_name: Name of the ingestion job
+ user: User identifier
+ source_type: Type of remote source ("google_drive", "dropbox", etc.)
+ session_token: Authentication token for the service
+ file_ids: List of file IDs to download
+ folder_ids: List of folder IDs to download
+ recursive: Whether to recursively download folders
+ retriever: Type of retriever to use
+ operation_mode: "upload" for initial ingestion, "sync" for incremental sync
+ doc_id: Document ID for sync operations (required when operation_mode="sync")
+ sync_frequency: How often to sync ("never", "daily", "weekly", "monthly")
+ """
+ logging.info(f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}")
+ self.update_state(state="PROGRESS", meta={"current": 1})
+
+ with tempfile.TemporaryDirectory() as temp_dir:
+ try:
+ # Step 1: Initialize the appropriate loader
+ self.update_state(state="PROGRESS", meta={"current": 10, "status": "Initializing connector"})
+
+ if not session_token:
+ raise ValueError(f"{source_type} connector requires session_token")
+
+ if not ConnectorCreator.is_supported(source_type):
+ raise ValueError(f"Unsupported connector type: {source_type}. Supported types: {ConnectorCreator.get_supported_connectors()}")
+
+ remote_loader = ConnectorCreator.create_connector(source_type, session_token)
+
+ # Create a clean config for storage
+ api_source_config = {
+ "file_ids": file_ids or [],
+ "folder_ids": folder_ids or [],
+ "recursive": recursive
+ }
+
+ # Step 2: Download files to temp directory
+ self.update_state(state="PROGRESS", meta={"current": 20, "status": "Downloading files"})
+ download_info = remote_loader.download_to_directory(
+ temp_dir,
+ api_source_config
+ )
+
+ if download_info.get("empty_result", False) or not download_info.get("files_downloaded", 0):
+ logging.warning(f"No files were downloaded from {source_type}")
+ # Create empty result directly instead of calling a separate method
+ return {
+ "name": job_name,
+ "user": user,
+ "tokens": 0,
+ "type": source_type,
+ "source_config": api_source_config,
+ "directory_structure": "{}",
+ }
+
+ # Step 3: Use SimpleDirectoryReader to process downloaded files
+ self.update_state(state="PROGRESS", meta={"current": 40, "status": "Processing files"})
+ reader = SimpleDirectoryReader(
+ input_dir=temp_dir,
+ recursive=True,
+ required_exts=[
+ ".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub",
+ ".html", ".mdx", ".json", ".xlsx", ".pptx", ".png",
+ ".jpg", ".jpeg",
+ ],
+ exclude_hidden=True,
+ file_metadata=metadata_from_filename,
+ )
+ raw_docs = reader.load_data()
+ directory_structure = getattr(reader, 'directory_structure', {})
+
+
+
+ # Step 4: Process documents (chunking, embedding, etc.)
+ self.update_state(state="PROGRESS", meta={"current": 60, "status": "Processing documents"})
+
+ chunker = Chunker(
+ chunking_strategy="classic_chunk",
+ max_tokens=MAX_TOKENS,
+ min_tokens=MIN_TOKENS,
+ duplicate_headers=False,
+ )
+ raw_docs = chunker.chunk(documents=raw_docs)
+
+ # Preserve source information in document metadata
+ for doc in raw_docs:
+ if hasattr(doc, 'extra_info') and doc.extra_info:
+ source = doc.extra_info.get('source')
+ if source and os.path.isabs(source):
+ # Convert absolute path to relative path
+ doc.extra_info['source'] = os.path.relpath(source, start=temp_dir)
+
+ docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
+
+ if operation_mode == "upload":
+ id = ObjectId()
+ elif operation_mode == "sync":
+ if not doc_id or not ObjectId.is_valid(doc_id):
+ logging.error("Invalid doc_id provided for sync operation: %s", doc_id)
+ raise ValueError("doc_id must be provided for sync operation.")
+ id = ObjectId(doc_id)
+ else:
+ raise ValueError(f"Invalid operation_mode: {operation_mode}")
+
+ vector_store_path = os.path.join(temp_dir, "vector_store")
+ os.makedirs(vector_store_path, exist_ok=True)
+
+ self.update_state(state="PROGRESS", meta={"current": 80, "status": "Storing documents"})
+ embed_and_store_documents(docs, vector_store_path, id, self)
+
+ tokens = count_tokens_docs(docs)
+
+ # Step 6: Upload index files
+ file_data = {
+ "user": user,
+ "name": job_name,
+ "tokens": tokens,
+ "retriever": retriever,
+ "id": str(id),
+ "type": "connector",
+ "remote_data": json.dumps({
+ "provider": source_type,
+ **api_source_config
+ }),
+ "directory_structure": json.dumps(directory_structure),
+ "sync_frequency": sync_frequency
+ }
+
+ if operation_mode == "sync":
+ file_data["last_sync"] = datetime.datetime.now()
+ else:
+ file_data["last_sync"] = datetime.datetime.now()
+
+ upload_index(vector_store_path, file_data)
+
+ # Ensure we mark the task as complete
+ self.update_state(state="PROGRESS", meta={"current": 100, "status": "Complete"})
+
+ logging.info(f"Remote ingestion completed: {job_name}")
+
+ return {
+ "user": user,
+ "name": job_name,
+ "tokens": tokens,
+ "type": source_type,
+ "id": str(id),
+ "status": "complete"
+ }
+
+ except Exception as e:
+ logging.error(f"Error during remote ingestion: {e}", exc_info=True)
+ raise
diff --git a/frontend/src/Hero.tsx b/frontend/src/Hero.tsx
index 01695a19..9b17c10f 100644
--- a/frontend/src/Hero.tsx
+++ b/frontend/src/Hero.tsx
@@ -29,7 +29,7 @@ export default function Hero({
{/* Demo Buttons Section */}
-
+
{demos?.map(
(demo: { header: string; query: string }, key: number) =>
diff --git a/frontend/src/agents/NewAgent.tsx b/frontend/src/agents/NewAgent.tsx
index da8cef5d..92e8b961 100644
--- a/frontend/src/agents/NewAgent.tsx
+++ b/frontend/src/agents/NewAgent.tsx
@@ -45,6 +45,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
description: '',
image: '',
source: '',
+ sources: [],
chunks: '',
retriever: '',
prompt_id: 'default',
@@ -150,7 +151,41 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
const formData = new FormData();
formData.append('name', agent.name);
formData.append('description', agent.description);
- formData.append('source', agent.source);
+
+ if (selectedSourceIds.size > 1) {
+ const sourcesArray = Array.from(selectedSourceIds)
+ .map((id) => {
+ const sourceDoc = sourceDocs?.find(
+ (source) =>
+ source.id === id || source.retriever === id || source.name === id,
+ );
+ if (sourceDoc?.name === 'Default' && !sourceDoc?.id) {
+ return 'default';
+ }
+ return sourceDoc?.id || id;
+ })
+ .filter(Boolean);
+ formData.append('sources', JSON.stringify(sourcesArray));
+ formData.append('source', '');
+ } else if (selectedSourceIds.size === 1) {
+ const singleSourceId = Array.from(selectedSourceIds)[0];
+ const sourceDoc = sourceDocs?.find(
+ (source) =>
+ source.id === singleSourceId ||
+ source.retriever === singleSourceId ||
+ source.name === singleSourceId,
+ );
+ let finalSourceId;
+ if (sourceDoc?.name === 'Default' && !sourceDoc?.id)
+ finalSourceId = 'default';
+ else finalSourceId = sourceDoc?.id || singleSourceId;
+ formData.append('source', String(finalSourceId));
+ formData.append('sources', JSON.stringify([]));
+ } else {
+ formData.append('source', '');
+ formData.append('sources', JSON.stringify([]));
+ }
+
formData.append('chunks', agent.chunks);
formData.append('retriever', agent.retriever);
formData.append('prompt_id', agent.prompt_id);
@@ -196,7 +231,41 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
const formData = new FormData();
formData.append('name', agent.name);
formData.append('description', agent.description);
- formData.append('source', agent.source);
+
+ if (selectedSourceIds.size > 1) {
+ const sourcesArray = Array.from(selectedSourceIds)
+ .map((id) => {
+ const sourceDoc = sourceDocs?.find(
+ (source) =>
+ source.id === id || source.retriever === id || source.name === id,
+ );
+ if (sourceDoc?.name === 'Default' && !sourceDoc?.id) {
+ return 'default';
+ }
+ return sourceDoc?.id || id;
+ })
+ .filter(Boolean);
+ formData.append('sources', JSON.stringify(sourcesArray));
+ formData.append('source', '');
+ } else if (selectedSourceIds.size === 1) {
+ const singleSourceId = Array.from(selectedSourceIds)[0];
+ const sourceDoc = sourceDocs?.find(
+ (source) =>
+ source.id === singleSourceId ||
+ source.retriever === singleSourceId ||
+ source.name === singleSourceId,
+ );
+ let finalSourceId;
+ if (sourceDoc?.name === 'Default' && !sourceDoc?.id)
+ finalSourceId = 'default';
+ else finalSourceId = sourceDoc?.id || singleSourceId;
+ formData.append('source', String(finalSourceId));
+ formData.append('sources', JSON.stringify([]));
+ } else {
+ formData.append('source', '');
+ formData.append('sources', JSON.stringify([]));
+ }
+
formData.append('chunks', agent.chunks);
formData.append('retriever', agent.retriever);
formData.append('prompt_id', agent.prompt_id);
@@ -293,9 +362,33 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
throw new Error('Failed to fetch agent');
}
const data = await response.json();
- if (data.source) setSelectedSourceIds(new Set([data.source]));
- else if (data.retriever)
+
+ if (data.sources && data.sources.length > 0) {
+ const mappedSources = data.sources.map((sourceId: string) => {
+ if (sourceId === 'default') {
+ const defaultSource = sourceDocs?.find(
+ (source) => source.name === 'Default',
+ );
+ return defaultSource?.retriever || 'classic';
+ }
+ return sourceId;
+ });
+ setSelectedSourceIds(new Set(mappedSources));
+ } else if (data.source) {
+ if (data.source === 'default') {
+ const defaultSource = sourceDocs?.find(
+ (source) => source.name === 'Default',
+ );
+ setSelectedSourceIds(
+ new Set([defaultSource?.retriever || 'classic']),
+ );
+ } else {
+ setSelectedSourceIds(new Set([data.source]));
+ }
+ } else if (data.retriever) {
setSelectedSourceIds(new Set([data.retriever]));
+ }
+
if (data.tools) setSelectedToolIds(new Set(data.tools));
if (data.status === 'draft') setEffectiveMode('draft');
if (data.json_schema) {
@@ -311,25 +404,57 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
}, [agentId, mode, token]);
useEffect(() => {
- const selectedSource = Array.from(selectedSourceIds).map((id) =>
- sourceDocs?.find(
- (source) =>
- source.id === id || source.retriever === id || source.name === id,
- ),
- );
- if (selectedSource[0]?.model === embeddingsName) {
- if (selectedSource[0] && 'id' in selectedSource[0]) {
+ const selectedSources = Array.from(selectedSourceIds)
+ .map((id) =>
+ sourceDocs?.find(
+ (source) =>
+ source.id === id || source.retriever === id || source.name === id,
+ ),
+ )
+ .filter(Boolean);
+
+ if (selectedSources.length > 0) {
+ // Handle multiple sources
+ if (selectedSources.length > 1) {
+ // Multiple sources selected - store in sources array
+ const sourceIds = selectedSources
+ .map((source) => source?.id)
+ .filter((id): id is string => Boolean(id));
setAgent((prev) => ({
...prev,
- source: selectedSource[0]?.id || 'default',
+ sources: sourceIds,
+ source: '', // Clear single source for multiple sources
retriever: '',
}));
- } else
- setAgent((prev) => ({
- ...prev,
- source: '',
- retriever: selectedSource[0]?.retriever || 'classic',
- }));
+ } else {
+ // Single source selected - maintain backward compatibility
+ const selectedSource = selectedSources[0];
+ if (selectedSource?.model === embeddingsName) {
+ if (selectedSource && 'id' in selectedSource) {
+ setAgent((prev) => ({
+ ...prev,
+ source: selectedSource?.id || 'default',
+ sources: [], // Clear sources array for single source
+ retriever: '',
+ }));
+ } else {
+ setAgent((prev) => ({
+ ...prev,
+ source: '',
+ sources: [], // Clear sources array
+ retriever: selectedSource?.retriever || 'classic',
+ }));
+ }
+ }
+ }
+ } else {
+ // No sources selected
+ setAgent((prev) => ({
+ ...prev,
+ source: '',
+ sources: [],
+ retriever: '',
+ }));
}
}, [selectedSourceIds]);
@@ -461,7 +586,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
onChange={(e) => setAgent({ ...agent, name: e.target.value })}
/>
diff --git a/frontend/src/agents/types/index.ts b/frontend/src/agents/types/index.ts
index e841cb0a..442097a1 100644
--- a/frontend/src/agents/types/index.ts
+++ b/frontend/src/agents/types/index.ts
@@ -10,6 +10,7 @@ export type Agent = {
description: string;
image: string;
source: string;
+ sources?: string[];
chunks: string;
retriever: string;
prompt_id: string;
diff --git a/frontend/src/api/endpoints.ts b/frontend/src/api/endpoints.ts
index 62f8ba92..dad008da 100644
--- a/frontend/src/api/endpoints.ts
+++ b/frontend/src/api/endpoints.ts
@@ -38,6 +38,7 @@ const endpoints = {
UPDATE_TOOL_STATUS: '/api/update_tool_status',
UPDATE_TOOL: '/api/update_tool',
DELETE_TOOL: '/api/delete_tool',
+ SYNC_CONNECTOR: '/api/connectors/sync',
GET_CHUNKS: (
docId: string,
page: number,
diff --git a/frontend/src/api/services/userService.ts b/frontend/src/api/services/userService.ts
index 3f69f719..5dda8ddf 100644
--- a/frontend/src/api/services/userService.ts
+++ b/frontend/src/api/services/userService.ts
@@ -1,5 +1,6 @@
import apiClient from '../client';
import endpoints from '../endpoints';
+import { getSessionToken } from '../../utils/providerUtils';
const userService = {
getConfig: (): Promise
=> apiClient.get(endpoints.USER.CONFIG, null),
@@ -111,6 +112,22 @@ const userService = {
apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token),
saveMCPServer: (data: any, token: string | null): Promise =>
apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token),
+ syncConnector: (
+ docId: string,
+ provider: string,
+ token: string | null,
+ ): Promise => {
+ const sessionToken = getSessionToken(provider);
+ return apiClient.post(
+ endpoints.USER.SYNC_CONNECTOR,
+ {
+ source_id: docId,
+ session_token: sessionToken,
+ provider: provider,
+ },
+ token,
+ );
+ },
};
export default userService;
diff --git a/frontend/src/assets/checkmark.svg b/frontend/src/assets/checkmark.svg
index 499000e8..3923a50a 100644
--- a/frontend/src/assets/checkmark.svg
+++ b/frontend/src/assets/checkmark.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/frontend/src/assets/spinner-dark.svg b/frontend/src/assets/spinner-dark.svg
index b7f1eaed..d2423d8f 100644
--- a/frontend/src/assets/spinner-dark.svg
+++ b/frontend/src/assets/spinner-dark.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/frontend/src/assets/spinner.svg b/frontend/src/assets/spinner.svg
index d2423d8f..b7f1eaed 100644
--- a/frontend/src/assets/spinner.svg
+++ b/frontend/src/assets/spinner.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/frontend/src/components/ConnectorAuth.tsx b/frontend/src/components/ConnectorAuth.tsx
new file mode 100644
index 00000000..61b6e895
--- /dev/null
+++ b/frontend/src/components/ConnectorAuth.tsx
@@ -0,0 +1,130 @@
+import React, { useRef } from 'react';
+import { useSelector } from 'react-redux';
+import { selectToken } from '../preferences/preferenceSlice';
+
+interface ConnectorAuthProps {
+ provider: string;
+ onSuccess: (data: { session_token: string; user_email: string }) => void;
+ onError: (error: string) => void;
+ label?: string;
+}
+
+const providerLabel = (provider: string) => {
+ const map: Record = {
+ google_drive: 'Google Drive',
+ };
+ return map[provider] || provider.replace(/_/g, ' ');
+};
+
+const ConnectorAuth: React.FC = ({
+ provider,
+ onSuccess,
+ onError,
+ label,
+}) => {
+ const token = useSelector(selectToken);
+ const completedRef = useRef(false);
+ const intervalRef = useRef(null);
+
+ const cleanup = () => {
+ if (intervalRef.current) {
+ clearInterval(intervalRef.current);
+ intervalRef.current = null;
+ }
+ window.removeEventListener('message', handleAuthMessage as any);
+ };
+
+ const handleAuthMessage = (event: MessageEvent) => {
+ const successGeneric = event.data?.type === 'connector_auth_success';
+ const successProvider =
+ event.data?.type === `${provider}_auth_success` ||
+ event.data?.type === 'google_drive_auth_success';
+ const errorProvider =
+ event.data?.type === `${provider}_auth_error` ||
+ event.data?.type === 'google_drive_auth_error';
+
+ if (successGeneric || successProvider) {
+ completedRef.current = true;
+ cleanup();
+ onSuccess({
+ session_token: event.data.session_token,
+ user_email: event.data.user_email || 'Connected User',
+ });
+ } else if (errorProvider) {
+ completedRef.current = true;
+ cleanup();
+ onError(event.data.error || 'Authentication failed');
+ }
+ };
+
+ const handleAuth = async () => {
+ try {
+ completedRef.current = false;
+ cleanup();
+
+ const apiHost = import.meta.env.VITE_API_HOST;
+ const authResponse = await fetch(
+ `${apiHost}/api/connectors/auth?provider=${provider}`,
+ {
+ headers: { Authorization: `Bearer ${token}` },
+ },
+ );
+
+ if (!authResponse.ok) {
+ throw new Error(
+ `Failed to get authorization URL: ${authResponse.status}`,
+ );
+ }
+
+ const authData = await authResponse.json();
+ if (!authData.success || !authData.authorization_url) {
+ throw new Error(authData.error || 'Failed to get authorization URL');
+ }
+
+ const authWindow = window.open(
+ authData.authorization_url,
+ `${provider}-auth`,
+ 'width=500,height=600,scrollbars=yes,resizable=yes',
+ );
+ if (!authWindow) {
+ throw new Error(
+ 'Failed to open authentication window. Please allow popups.',
+ );
+ }
+
+ window.addEventListener('message', handleAuthMessage as any);
+
+ const checkClosed = window.setInterval(() => {
+ if (authWindow.closed) {
+ clearInterval(checkClosed);
+ window.removeEventListener('message', handleAuthMessage as any);
+ if (!completedRef.current) {
+ onError('Authentication was cancelled');
+ }
+ }
+ }, 1000);
+ intervalRef.current = checkClosed;
+ } catch (error) {
+ onError(error instanceof Error ? error.message : 'Authentication failed');
+ }
+ };
+
+ const buttonLabel = label || `Connect ${providerLabel(provider)}`;
+
+ return (
+
+ );
+};
+
+export default ConnectorAuth;
diff --git a/frontend/src/components/ConnectorTreeComponent.tsx b/frontend/src/components/ConnectorTreeComponent.tsx
new file mode 100644
index 00000000..9249145c
--- /dev/null
+++ b/frontend/src/components/ConnectorTreeComponent.tsx
@@ -0,0 +1,731 @@
+import React, { useState, useRef, useEffect } from 'react';
+import { useTranslation } from 'react-i18next';
+import { useSelector } from 'react-redux';
+import { formatBytes } from '../utils/stringUtils';
+import { selectToken } from '../preferences/preferenceSlice';
+import Chunks from './Chunks';
+import ContextMenu, { MenuOption } from './ContextMenu';
+import userService from '../api/services/userService';
+import FileIcon from '../assets/file.svg';
+import FolderIcon from '../assets/folder.svg';
+import ArrowLeft from '../assets/arrow-left.svg';
+import ThreeDots from '../assets/three-dots.svg';
+import EyeView from '../assets/eye-view.svg';
+import SyncIcon from '../assets/sync.svg';
+import { useOutsideAlerter } from '../hooks';
+
+interface FileNode {
+ type?: string;
+ token_count?: number;
+ size_bytes?: number;
+ [key: string]: any;
+}
+
+interface DirectoryStructure {
+ [key: string]: FileNode;
+}
+
+interface ConnectorTreeComponentProps {
+ docId: string;
+ sourceName: string;
+ onBackToDocuments: () => void;
+}
+
+interface SearchResult {
+ name: string;
+ path: string;
+ isFile: boolean;
+}
+
+const ConnectorTreeComponent: React.FC = ({
+ docId,
+ sourceName,
+ onBackToDocuments,
+}) => {
+ const { t } = useTranslation();
+ const [loading, setLoading] = useState(true);
+ const [error, setError] = useState(null);
+ const [directoryStructure, setDirectoryStructure] =
+ useState(null);
+ const [currentPath, setCurrentPath] = useState([]);
+ const token = useSelector(selectToken);
+ const [activeMenuId, setActiveMenuId] = useState(null);
+ const menuRefs = useRef<{
+ [key: string]: React.RefObject;
+ }>({});
+ const [selectedFile, setSelectedFile] = useState<{
+ id: string;
+ name: string;
+ } | null>(null);
+ const [searchQuery, setSearchQuery] = useState('');
+ const [searchResults, setSearchResults] = useState([]);
+ const searchDropdownRef = useRef(null);
+ const [isSyncing, setIsSyncing] = useState(false);
+ const [syncProgress, setSyncProgress] = useState(0);
+ const [sourceProvider, setSourceProvider] = useState('');
+ const [syncDone, setSyncDone] = useState(false);
+
+ useOutsideAlerter(
+ searchDropdownRef,
+ () => {
+ setSearchQuery('');
+ setSearchResults([]);
+ },
+ [],
+ false,
+ );
+
+ const handleFileClick = (fileName: string) => {
+ const fullPath = [...currentPath, fileName].join('/');
+ setSelectedFile({
+ id: fullPath,
+ name: fileName,
+ });
+ };
+
+ const handleSync = async () => {
+ if (isSyncing) return;
+
+ const provider = sourceProvider;
+
+ setIsSyncing(true);
+ setSyncProgress(0);
+
+ try {
+ const response = await userService.syncConnector(docId, provider, token);
+ const data = await response.json();
+
+ if (data.success) {
+ console.log('Sync started successfully:', data.task_id);
+ setSyncProgress(10);
+
+ // Poll task status using userService
+ const maxAttempts = 30;
+ const pollInterval = 2000;
+
+ for (let attempt = 0; attempt < maxAttempts; attempt++) {
+ try {
+ const statusResponse = await userService.getTaskStatus(
+ data.task_id,
+ token,
+ );
+ const statusData = await statusResponse.json();
+
+ console.log(
+ `Task status (attempt ${attempt + 1}):`,
+ statusData.status,
+ );
+
+ if (statusData.status === 'SUCCESS') {
+ setSyncProgress(100);
+ console.log('Sync completed successfully');
+
+ // Refresh directory structure
+ try {
+ const refreshResponse = await userService.getDirectoryStructure(
+ docId,
+ token,
+ );
+ const refreshData = await refreshResponse.json();
+ if (refreshData && refreshData.directory_structure) {
+ setDirectoryStructure(refreshData.directory_structure);
+ setCurrentPath([]);
+ }
+ if (refreshData && refreshData.provider) {
+ setSourceProvider(refreshData.provider);
+ }
+
+ setSyncDone(true);
+ setTimeout(() => setSyncDone(false), 5000);
+ } catch (err) {
+ console.error('Error refreshing directory structure:', err);
+ }
+ break;
+ } else if (statusData.status === 'FAILURE') {
+ console.error('Sync task failed:', statusData.result);
+ break;
+ } else if (statusData.status === 'PROGRESS') {
+ const progress = Number(
+ statusData.result && statusData.result.current != null
+ ? statusData.result.current
+ : statusData.meta && statusData.meta.current != null
+ ? statusData.meta.current
+ : 0,
+ );
+ setSyncProgress(Math.max(10, progress));
+ }
+
+ await new Promise((resolve) => setTimeout(resolve, pollInterval));
+ } catch (error) {
+ console.error('Error polling task status:', error);
+ break;
+ }
+ }
+ } else {
+ console.error('Sync failed:', data.error);
+ }
+ } catch (err) {
+ console.error('Error syncing connector:', err);
+ } finally {
+ setIsSyncing(false);
+ setSyncProgress(0);
+ }
+ };
+
+ useEffect(() => {
+ const fetchDirectoryStructure = async () => {
+ try {
+ setLoading(true);
+
+ const directoryResponse = await userService.getDirectoryStructure(
+ docId,
+ token,
+ );
+ const directoryData = await directoryResponse.json();
+
+ if (directoryData && directoryData.directory_structure) {
+ setDirectoryStructure(directoryData.directory_structure);
+ } else {
+ setError('Invalid response format');
+ }
+
+ if (directoryData && directoryData.provider) {
+ setSourceProvider(directoryData.provider);
+ }
+ } catch (err) {
+ setError('Failed to load source information');
+ console.error(err);
+ } finally {
+ setLoading(false);
+ }
+ };
+
+ if (docId) {
+ fetchDirectoryStructure();
+ }
+ }, [docId, token]);
+
+ const navigateToDirectory = (dirName: string) => {
+ setCurrentPath([...currentPath, dirName]);
+ };
+
+ const navigateUp = () => {
+ setCurrentPath(currentPath.slice(0, -1));
+ };
+
+ const getCurrentDirectory = (): DirectoryStructure => {
+ if (!directoryStructure) return {};
+
+ let current = directoryStructure;
+ for (const dir of currentPath) {
+ if (current[dir] && !current[dir].type) {
+ current = current[dir] as DirectoryStructure;
+ } else {
+ return {};
+ }
+ }
+ return current;
+ };
+
+ const getMenuRef = (id: string) => {
+ if (!menuRefs.current[id]) {
+ menuRefs.current[id] = React.createRef();
+ }
+ return menuRefs.current[id];
+ };
+
+ const handleMenuClick = (
+ e: React.MouseEvent,
+ id: string,
+ ) => {
+ e.stopPropagation();
+ setActiveMenuId(activeMenuId === id ? null : id);
+ };
+
+ const getActionOptions = (
+ name: string,
+ isFile: boolean,
+ _itemId: string,
+ ): MenuOption[] => {
+ const options: MenuOption[] = [];
+
+ options.push({
+ icon: EyeView,
+ label: t('settings.sources.view'),
+ onClick: (event: React.SyntheticEvent) => {
+ event.stopPropagation();
+ if (isFile) {
+ handleFileClick(name);
+ } else {
+ navigateToDirectory(name);
+ }
+ },
+ iconWidth: 18,
+ iconHeight: 18,
+ variant: 'primary',
+ });
+
+ return options;
+ };
+
+ const calculateDirectoryStats = (
+ structure: DirectoryStructure,
+ ): { totalSize: number; totalTokens: number } => {
+ let totalSize = 0;
+ let totalTokens = 0;
+
+ Object.entries(structure).forEach(([_, node]) => {
+ if (node.type) {
+ // It's a file
+ totalSize += node.size_bytes || 0;
+ totalTokens += node.token_count || 0;
+ } else {
+ // It's a directory, recurse
+ const stats = calculateDirectoryStats(node);
+ totalSize += stats.totalSize;
+ totalTokens += stats.totalTokens;
+ }
+ });
+
+ return { totalSize, totalTokens };
+ };
+
+ const handleBackNavigation = () => {
+ if (selectedFile) {
+ setSelectedFile(null);
+ } else if (currentPath.length === 0) {
+ if (onBackToDocuments) {
+ onBackToDocuments();
+ }
+ } else {
+ navigateUp();
+ }
+ };
+
+ const renderPathNavigation = () => {
+ return (
+
+ {/* Left side with path navigation */}
+
+
+
+
+
+ {sourceName}
+
+ {currentPath.length > 0 && (
+ <>
+ /
+ {currentPath.map((dir, index) => (
+
+
+ {dir}
+
+ {index < currentPath.length - 1 && (
+
+ /
+
+ )}
+
+ ))}
+ >
+ )}
+
+
+
+
+ {renderFileSearch()}
+
+ {/* Sync button */}
+
+
+
+ );
+ };
+
+ const renderFileTree = (directory: DirectoryStructure) => {
+ if (!directory) return [];
+
+ // Create parent directory row
+ const parentRow =
+ currentPath.length > 0
+ ? [
+
+
+
+ 
+
+ ..
+
+
+ |
+
+ -
+ |
+
+ -
+ |
+ |
+
,
+ ]
+ : [];
+
+ // Sort entries: directories first, then files, both alphabetically
+ const sortedEntries = Object.entries(directory).sort(
+ ([nameA, nodeA], [nameB, nodeB]) => {
+ const isFileA = !!nodeA.type;
+ const isFileB = !!nodeB.type;
+
+ if (isFileA !== isFileB) {
+ return isFileA ? 1 : -1; // Directories first
+ }
+
+ return nameA.localeCompare(nameB); // Alphabetical within each group
+ },
+ );
+
+ // Process directories
+ const directoryRows = sortedEntries
+ .filter(([_, node]) => !node.type)
+ .map(([name, node]) => {
+ const itemId = `dir-${name}`;
+ const menuRef = getMenuRef(itemId);
+
+ // Calculate directory stats
+ const dirStats = calculateDirectoryStats(node as DirectoryStructure);
+
+ return (
+ navigateToDirectory(name)}
+ >
+
+
+ 
+
+ {name}
+
+
+ |
+
+ {dirStats.totalTokens > 0
+ ? dirStats.totalTokens.toLocaleString()
+ : '-'}
+ |
+
+ {dirStats.totalSize > 0 ? formatBytes(dirStats.totalSize) : '-'}
+ |
+
+
+
+
+ setActiveMenuId(isOpen ? itemId : null)
+ }
+ options={getActionOptions(name, false, itemId)}
+ anchorRef={menuRef}
+ position="bottom-left"
+ offset={{ x: -4, y: 4 }}
+ />
+
+ |
+
+ );
+ });
+
+ // Process files
+ const fileRows = sortedEntries
+ .filter(([_, node]) => !!node.type)
+ .map(([name, node]) => {
+ const itemId = `file-${name}`;
+ const menuRef = getMenuRef(itemId);
+
+ return (
+ handleFileClick(name)}
+ >
+
+
+ 
+
+ {name}
+
+
+ |
+
+ {node.token_count?.toLocaleString() || '-'}
+ |
+
+ {node.size_bytes ? formatBytes(node.size_bytes) : '-'}
+ |
+
+
+
+
+ setActiveMenuId(isOpen ? itemId : null)
+ }
+ options={getActionOptions(name, true, itemId)}
+ anchorRef={menuRef}
+ position="bottom-left"
+ offset={{ x: -4, y: 4 }}
+ />
+
+ |
+
+ );
+ });
+
+ return [...parentRow, ...directoryRows, ...fileRows];
+ };
+
+ const searchFiles = (
+ query: string,
+ structure: DirectoryStructure,
+ currentPath: string[] = [],
+ ): SearchResult[] => {
+ let results: SearchResult[] = [];
+
+ Object.entries(structure).forEach(([name, node]) => {
+ const fullPath = [...currentPath, name].join('/');
+
+ if (name.toLowerCase().includes(query.toLowerCase())) {
+ results.push({
+ name,
+ path: fullPath,
+ isFile: !!node.type,
+ });
+ }
+
+ if (!node.type) {
+ // If it's a directory, search recursively
+ results = [
+ ...results,
+ ...searchFiles(query, node as DirectoryStructure, [
+ ...currentPath,
+ name,
+ ]),
+ ];
+ }
+ });
+
+ return results;
+ };
+
+ const handleSearchSelect = (result: SearchResult) => {
+ if (result.isFile) {
+ const pathParts = result.path.split('/');
+ const fileName = pathParts.pop() || '';
+ setCurrentPath(pathParts);
+
+ setSelectedFile({
+ id: result.path,
+ name: fileName,
+ });
+ } else {
+ setCurrentPath(result.path.split('/'));
+ setSelectedFile(null);
+ }
+ setSearchQuery('');
+ setSearchResults([]);
+ };
+
+ const renderFileSearch = () => {
+ return (
+
+
{
+ setSearchQuery(e.target.value);
+ if (directoryStructure) {
+ setSearchResults(searchFiles(e.target.value, directoryStructure));
+ }
+ }}
+ placeholder={t('settings.sources.searchFiles')}
+ className={`h-[38px] w-full border border-[#D1D9E0] px-4 py-2 dark:border-[#6A6A6A] ${searchQuery ? 'rounded-t-[24px]' : 'rounded-[24px]'} bg-transparent focus:outline-none dark:text-[#E0E0E0]`}
+ />
+
+ {searchQuery && (
+
+
+ {searchResults.length === 0 ? (
+
+ {t('settings.sources.noResults')}
+
+ ) : (
+ searchResults.map((result, index) => (
+
handleSearchSelect(result)}
+ title={result.path}
+ className={`flex min-w-0 cursor-pointer items-center px-3 py-2 hover:bg-[#ECEEEF] dark:hover:bg-[#27282D] ${
+ index !== searchResults.length - 1
+ ? 'border-b border-[#D1D9E0] dark:border-[#6A6A6A]'
+ : ''
+ }`}
+ >
+

+
+ {result.path.split('/').pop() || result.path}
+
+
+ ))
+ )}
+
+
+ )}
+
+ );
+ };
+
+ const handleFileSearch = (searchQuery: string) => {
+ if (directoryStructure) {
+ return searchFiles(searchQuery, directoryStructure);
+ }
+ return [];
+ };
+
+ const handleFileSelect = (path: string) => {
+ const pathParts = path.split('/');
+ const fileName = pathParts.pop() || '';
+ setCurrentPath(pathParts);
+ setSelectedFile({
+ id: path,
+ name: fileName,
+ });
+ };
+
+ const currentDirectory = getCurrentDirectory();
+
+ const navigateToPath = (index: number) => {
+ setCurrentPath(currentPath.slice(0, index + 1));
+ };
+
+ return (
+
+ {selectedFile ? (
+
+
+ setSelectedFile(null)}
+ path={selectedFile.id}
+ onFileSearch={handleFileSearch}
+ onFileSelect={handleFileSelect}
+ />
+
+
+ ) : (
+
+
{renderPathNavigation()}
+
+
+
+
+
+
+ |
+ {t('settings.sources.fileName')}
+ |
+
+ {t('settings.sources.tokens')}
+ |
+
+ {t('settings.sources.size')}
+ |
+ |
+
+
+ {renderFileTree(getCurrentDirectory())}
+
+
+
+
+ )}
+
+ );
+};
+
+export default ConnectorTreeComponent;
diff --git a/frontend/src/components/FileTreeComponent.tsx b/frontend/src/components/FileTreeComponent.tsx
index ad714869..724ca233 100644
--- a/frontend/src/components/FileTreeComponent.tsx
+++ b/frontend/src/components/FileTreeComponent.tsx
@@ -2,6 +2,7 @@ import React, { useState, useRef, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelector } from 'react-redux';
import { selectToken } from '../preferences/preferenceSlice';
+import { formatBytes } from '../utils/stringUtils';
import Chunks from './Chunks';
import ContextMenu, { MenuOption } from './ContextMenu';
import userService from '../api/services/userService';
@@ -10,9 +11,7 @@ import FolderIcon from '../assets/folder.svg';
import ArrowLeft from '../assets/arrow-left.svg';
import ThreeDots from '../assets/three-dots.svg';
import EyeView from '../assets/eye-view.svg';
-import OutlineSource from '../assets/outline-source.svg';
import Trash from '../assets/red-trash.svg';
-import SearchIcon from '../assets/search.svg';
import { useOutsideAlerter } from '../hooks';
import ConfirmationModal from '../modals/ConfirmationModal';
@@ -128,14 +127,6 @@ const FileTreeComponent: React.FC = ({
}
}, [docId, token]);
- const formatBytes = (bytes: number): string => {
- if (bytes === 0) return '0 Bytes';
- const k = 1024;
- const sizes = ['Bytes', 'KB', 'MB', 'GB'];
- const i = Math.floor(Math.log(bytes) / Math.log(k));
- return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i];
- };
-
const navigateToDirectory = (dirName: string) => {
setCurrentPath((prev) => [...prev, dirName]);
};
@@ -443,18 +434,18 @@ const FileTreeComponent: React.FC = ({
const renderPathNavigation = () => {
return (
-
+
{/* Left side with path navigation */}
-
+
{sourceName}
{currentPath.length > 0 && (
@@ -485,8 +476,7 @@ const FileTreeComponent: React.FC = ({
-
-
+
{processingRef.current && (
{currentOpRef.current === 'add'
@@ -495,13 +485,13 @@ const FileTreeComponent: React.FC = ({
)}
- {renderFileSearch()}
+ {renderFileSearch()}
{/* Add file button */}
{!processingRef.current && (