diff --git a/application/agents/base.py b/application/agents/base.py index 068b2a3c..77729fe6 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -5,19 +5,17 @@ from typing import Dict, Generator, List, Optional from bson.objectid import ObjectId -logger = logging.getLogger(__name__) - from application.agents.tools.tool_action_parser import ToolActionParser from application.agents.tools.tool_manager import ToolManager - from application.core.mongo_db import MongoDB from application.core.settings import settings - from application.llm.handlers.handler_creator import LLMHandlerCreator from application.llm.llm_creator import LLMCreator from application.logging import build_stack_data, log_activity, LogContext from application.retriever.base import BaseRetriever +logger = logging.getLogger(__name__) + class BaseAgent(ABC): def __init__( @@ -157,7 +155,7 @@ class BaseAgent(ABC): } yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}} self.tool_calls.append(tool_call_data) - return f"Failed to parse tool call.", call_id + return "Failed to parse tool call.", call_id # Check if tool_id exists in available tools if tool_id not in tools_dict: diff --git a/application/api/answer/services/stream_processor.py b/application/api/answer/services/stream_processor.py index 648d24f5..f6e639ef 100644 --- a/application/api/answer/services/stream_processor.py +++ b/application/api/answer/services/stream_processor.py @@ -69,11 +69,8 @@ class StreamProcessor: self.decoded_token.get("sub") if self.decoded_token is not None else None ) self.conversation_id = self.data.get("conversation_id") - self.source = ( - {"active_docs": self.data["active_docs"]} - if "active_docs" in self.data - else {} - ) + self.source = {} + self.all_sources = [] self.attachments = [] self.history = [] self.agent_config = {} @@ -85,6 +82,8 @@ class StreamProcessor: def initialize(self): """Initialize all required components for processing""" + self._configure_agent() + self._configure_source() self._configure_retriever() self._configure_agent() self._load_conversation_history() @@ -171,13 +170,77 @@ class StreamProcessor: source = data.get("source") if isinstance(source, DBRef): source_doc = self.db.dereference(source) - data["source"] = str(source_doc["_id"]) - data["retriever"] = source_doc.get("retriever", data.get("retriever")) - data["chunks"] = source_doc.get("chunks", data.get("chunks")) + if source_doc: + data["source"] = str(source_doc["_id"]) + data["retriever"] = source_doc.get("retriever", data.get("retriever")) + data["chunks"] = source_doc.get("chunks", data.get("chunks")) + else: + data["source"] = None + elif source == "default": + data["source"] = "default" else: data["source"] = None + # Handle multiple sources + + sources = data.get("sources", []) + if sources and isinstance(sources, list): + sources_list = [] + for i, source_ref in enumerate(sources): + if source_ref == "default": + processed_source = { + "id": "default", + "retriever": "classic", + "chunks": data.get("chunks", "2"), + } + sources_list.append(processed_source) + elif isinstance(source_ref, DBRef): + source_doc = self.db.dereference(source_ref) + if source_doc: + processed_source = { + "id": str(source_doc["_id"]), + "retriever": source_doc.get("retriever", "classic"), + "chunks": source_doc.get("chunks", data.get("chunks", "2")), + } + sources_list.append(processed_source) + data["sources"] = sources_list + else: + data["sources"] = [] return data + def _configure_source(self): + """Configure the source 
based on agent data""" + api_key = self.data.get("api_key") or self.agent_key + + if api_key: + agent_data = self._get_data_from_api_key(api_key) + + if agent_data.get("sources") and len(agent_data["sources"]) > 0: + source_ids = [ + source["id"] for source in agent_data["sources"] if source.get("id") + ] + if source_ids: + self.source = {"active_docs": source_ids} + else: + self.source = {} + self.all_sources = agent_data["sources"] + elif agent_data.get("source"): + self.source = {"active_docs": agent_data["source"]} + self.all_sources = [ + { + "id": agent_data["source"], + "retriever": agent_data.get("retriever", "classic"), + } + ] + else: + self.source = {} + self.all_sources = [] + return + if "active_docs" in self.data: + self.source = {"active_docs": self.data["active_docs"]} + return + self.source = {} + self.all_sources = [] + def _configure_agent(self): """Configure the agent based on request data""" agent_id = self.data.get("agent_id") @@ -203,7 +266,13 @@ class StreamProcessor: if data_key.get("retriever"): self.retriever_config["retriever_name"] = data_key["retriever"] if data_key.get("chunks") is not None: - self.retriever_config["chunks"] = data_key["chunks"] + try: + self.retriever_config["chunks"] = int(data_key["chunks"]) + except (ValueError, TypeError): + logger.warning( + f"Invalid chunks value: {data_key['chunks']}, using default value 2" + ) + self.retriever_config["chunks"] = 2 elif self.agent_key: data_key = self._get_data_from_api_key(self.agent_key) self.agent_config.update( @@ -224,7 +293,13 @@ class StreamProcessor: if data_key.get("retriever"): self.retriever_config["retriever_name"] = data_key["retriever"] if data_key.get("chunks") is not None: - self.retriever_config["chunks"] = data_key["chunks"] + try: + self.retriever_config["chunks"] = int(data_key["chunks"]) + except (ValueError, TypeError): + logger.warning( + f"Invalid chunks value: {data_key['chunks']}, using default value 2" + ) + self.retriever_config["chunks"] = 2 else: self.agent_config.update( { @@ -243,7 +318,8 @@ class StreamProcessor: "token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY), } - if "isNoneDoc" in self.data and self.data["isNoneDoc"]: + api_key = self.data.get("api_key") or self.agent_key + if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]: self.retriever_config["chunks"] = 0 def create_agent(self): diff --git a/application/api/connector/routes.py b/application/api/connector/routes.py new file mode 100644 index 00000000..1647aa78 --- /dev/null +++ b/application/api/connector/routes.py @@ -0,0 +1,626 @@ +import datetime +import json + + +from bson.objectid import ObjectId +from flask import ( + Blueprint, + current_app, + jsonify, + make_response, + request +) +from flask_restx import fields, Namespace, Resource + + + + +from application.api.user.tasks import ( + ingest_connector_task, +) +from application.core.mongo_db import MongoDB +from application.core.settings import settings +from application.api import api + + +from application.utils import ( + check_required_fields +) + + +from application.parser.connectors.connector_creator import ConnectorCreator + + + +mongo = MongoDB.get_client() +db = mongo[settings.MONGO_DB_NAME] +sources_collection = db["sources"] +sessions_collection = db["connector_sessions"] + +connector = Blueprint("connector", __name__) +connectors_ns = Namespace("connectors", description="Connector operations", path="/") +api.add_namespace(connectors_ns) + + + +@connectors_ns.route("/api/connectors/upload") +class 
UploadConnector(Resource): + @api.expect( + api.model( + "ConnectorUploadModel", + { + "user": fields.String(required=True, description="User ID"), + "source": fields.String( + required=True, description="Source type (google_drive, github, etc.)" + ), + "name": fields.String(required=True, description="Job name"), + "data": fields.String(required=True, description="Configuration data"), + "repo_url": fields.String(description="GitHub repository URL"), + }, + ) + ) + @api.doc( + description="Uploads connector source for vectorization", + ) + def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + data = request.form + required_fields = ["user", "source", "name", "data"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + try: + config = json.loads(data["data"]) + source_data = None + sync_frequency = config.get("sync_frequency", "never") + + if data["source"] == "github": + source_data = config.get("repo_url") + elif data["source"] in ["crawler", "url"]: + source_data = config.get("url") + elif data["source"] == "reddit": + source_data = config + elif data["source"] in ConnectorCreator.get_supported_connectors(): + session_token = config.get("session_token") + if not session_token: + return make_response(jsonify({ + "success": False, + "error": f"Missing session_token in {data['source']} configuration" + }), 400) + + file_ids = config.get("file_ids", []) + if isinstance(file_ids, str): + file_ids = [id.strip() for id in file_ids.split(',') if id.strip()] + elif not isinstance(file_ids, list): + file_ids = [] + + folder_ids = config.get("folder_ids", []) + if isinstance(folder_ids, str): + folder_ids = [id.strip() for id in folder_ids.split(',') if id.strip()] + elif not isinstance(folder_ids, list): + folder_ids = [] + + config["file_ids"] = file_ids + config["folder_ids"] = folder_ids + + task = ingest_connector_task.delay( + job_name=data["name"], + user=decoded_token.get("sub"), + source_type=data["source"], + session_token=session_token, + file_ids=file_ids, + folder_ids=folder_ids, + recursive=config.get("recursive", False), + retriever=config.get("retriever", "classic"), + sync_frequency=sync_frequency + ) + return make_response(jsonify({"success": True, "task_id": task.id}), 200) + # Remote loaders (github, crawler, url, reddit) go through ingest_remote, + # mirroring UploadRemote; ingest_connector_task does not accept + # source_data/loader kwargs. + from application.api.user.tasks import ingest_remote + + task = ingest_remote.delay( + source_data=source_data, + job_name=data["name"], + user=decoded_token.get("sub"), + loader=data["source"], + sync_frequency=sync_frequency + ) + except Exception as err: + current_app.logger.error( + f"Error uploading connector source: {err}", exc_info=True + ) + return make_response(jsonify({"success": False}), 400) + return make_response(jsonify({"success": True, "task_id": task.id}), 200) + + +@connectors_ns.route("/api/connectors/task_status") +class ConnectorTaskStatus(Resource): + task_status_model = api.model( + "ConnectorTaskStatusModel", + {"task_id": fields.String(required=True, description="Task ID")}, + ) + + @api.expect(task_status_model) + @api.doc(description="Get connector task status") + def get(self): + task_id = request.args.get("task_id") + if not task_id: + return make_response( + jsonify({"success": False, "message": "Task ID is required"}), 400 + ) + try: + from application.celery_init import celery + + task = celery.AsyncResult(task_id) + task_meta = task.info + current_app.logger.info(f"Task status: {task.status}") + if not isinstance( + task_meta, (dict, list, str, int, float, bool, type(None)) + ): + task_meta = 
str(task_meta) + except Exception as err: + current_app.logger.error(f"Error getting task status: {err}", exc_info=True) + return make_response(jsonify({"success": False}), 400) + return make_response(jsonify({"status": task.status, "result": task_meta}), 200) + + +@connectors_ns.route("/api/connectors/sources") +class ConnectorSources(Resource): + @api.doc(description="Get connector sources") + def get(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") + try: + sources = sources_collection.find({"user": user, "type": "connector"}).sort("date", -1) + connector_sources = [] + for source in sources: + connector_sources.append({ + "id": str(source["_id"]), + "name": source.get("name"), + "date": source.get("date"), + "type": source.get("type"), + "source": source.get("source"), + "tokens": source.get("tokens", ""), + "retriever": source.get("retriever", "classic"), + "syncFrequency": source.get("sync_frequency", ""), + }) + except Exception as err: + current_app.logger.error(f"Error retrieving connector sources: {err}", exc_info=True) + return make_response(jsonify({"success": False}), 400) + return make_response(jsonify(connector_sources), 200) + + +@connectors_ns.route("/api/connectors/delete") +class DeleteConnectorSource(Resource): + @api.doc( + description="Delete a connector source", + params={"source_id": "The source ID to delete"}, + ) + def delete(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + source_id = request.args.get("source_id") + if not source_id: + return make_response( + jsonify({"success": False, "message": "source_id is required"}), 400 + ) + try: + result = sources_collection.delete_one( + {"_id": ObjectId(source_id), "user": decoded_token.get("sub")} + ) + if result.deleted_count == 0: + return make_response( + jsonify({"success": False, "message": "Source not found"}), 404 + ) + except Exception as err: + current_app.logger.error( + f"Error deleting connector source: {err}", exc_info=True + ) + return make_response(jsonify({"success": False}), 400) + return make_response(jsonify({"success": True}), 200) + + +@connectors_ns.route("/api/connectors/auth") +class ConnectorAuth(Resource): + @api.doc(description="Get connector OAuth authorization URL", params={"provider": "Connector provider (e.g., google_drive)"}) + def get(self): + try: + provider = request.args.get('provider') or request.args.get('source') + if not provider: + return make_response(jsonify({"success": False, "error": "Missing provider"}), 400) + + if not ConnectorCreator.is_supported(provider): + return make_response(jsonify({"success": False, "error": f"Unsupported provider: {provider}"}), 400) + + import uuid + state = str(uuid.uuid4()) + auth = ConnectorCreator.create_auth(provider) + authorization_url = auth.get_authorization_url(state=state) + return make_response(jsonify({ + "success": True, + "authorization_url": authorization_url, + "state": state + }), 200) + except Exception as e: + current_app.logger.error(f"Error generating connector auth URL: {e}") + return make_response(jsonify({"success": False, "error": str(e)}), 500) + + +@connectors_ns.route("/api/connectors/callback") +class ConnectorsCallback(Resource): + @api.doc(description="Handle OAuth callback for external connectors") + def get(self): + """Handle OAuth callback for external connectors""" + try: + from 
application.parser.connectors.connector_creator import ConnectorCreator + from flask import request, redirect + import uuid + + provider = request.args.get('provider', 'google_drive') + authorization_code = request.args.get('code') + _ = request.args.get('state') + error = request.args.get('error') + + if error: + return redirect(f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider={provider}") + + if not authorization_code: + return redirect(f"/api/connectors/callback-status?status=error&message=Authorization+code+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider={provider}") + + try: + auth = ConnectorCreator.create_auth(provider) + token_info = auth.exchange_code_for_tokens(authorization_code) + + session_token = str(uuid.uuid4()) + + + try: + credentials = auth.create_credentials_from_token_info(token_info) + service = auth.build_drive_service(credentials) + user_info = service.about().get(fields="user").execute() + user_email = user_info.get('user', {}).get('emailAddress', 'Connected User') + except Exception as e: + current_app.logger.warning(f"Could not get user info: {e}") + user_email = 'Connected User' + + sanitized_token_info = { + "access_token": token_info.get("access_token"), + "refresh_token": token_info.get("refresh_token"), + "token_uri": token_info.get("token_uri"), + "expiry": token_info.get("expiry"), + "scopes": token_info.get("scopes") + } + + user_id = request.decoded_token.get("sub") if getattr(request, "decoded_token", None) else None + sessions_collection.insert_one({ + "session_token": session_token, + "user": user_id, + "token_info": sanitized_token_info, + "created_at": datetime.datetime.now(datetime.timezone.utc), + "user_email": user_email, + "provider": provider + }) + + # Redirect to success page with session token and user email + return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}") + + except Exception as e: + current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True) + return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+exchange+authorization+code+for+tokens:+{str(e)}&provider={provider}") + + except Exception as e: + current_app.logger.error(f"Error handling connector callback: {e}") + return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+complete+connector+authentication:+{str(e)}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.") + + +@connectors_ns.route("/api/connectors/refresh") +class ConnectorRefresh(Resource): + @api.expect(api.model("ConnectorRefreshModel", {"provider": fields.String(required=True), "refresh_token": fields.String(required=True)})) + @api.doc(description="Refresh connector access token") + def post(self): + try: + data = request.get_json() + provider = data.get('provider') + refresh_token = data.get('refresh_token') + + if not provider or not refresh_token: + return make_response(jsonify({"success": False, "error": "provider and refresh_token are required"}), 400) + + auth = ConnectorCreator.create_auth(provider) + token_info = auth.refresh_access_token(refresh_token) + return make_response(jsonify({"success": True, "token_info": token_info}), 200) + except Exception as e: + 
current_app.logger.error(f"Error refreshing token for connector: {e}") + return make_response(jsonify({"success": False, "error": str(e)}), 500) + + +@connectors_ns.route("/api/connectors/files") +class ConnectorFiles(Resource): + @api.expect(api.model("ConnectorFilesModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True), "folder_id": fields.String(required=False), "limit": fields.Integer(required=False), "page_token": fields.String(required=False)})) + @api.doc(description="List files from a connector provider (supports pagination)") + def post(self): + try: + data = request.get_json() + provider = data.get('provider') + session_token = data.get('session_token') + folder_id = data.get('folder_id') + limit = data.get('limit', 10) + page_token = data.get('page_token') + if not provider or not session_token: + return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400) + + + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401) + user = decoded_token.get('sub') + session = sessions_collection.find_one({"session_token": session_token, "user": user}) + if not session: + return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401) + + loader = ConnectorCreator.create_connector(provider, session_token) + documents = loader.load_data({ + 'limit': limit, + 'list_only': True, + 'session_token': session_token, + 'folder_id': folder_id, + 'page_token': page_token + }) + + files = [] + for doc in documents[:limit]: + metadata = doc.extra_info + modified_time = metadata.get('modified_time') + if modified_time: + date_part = modified_time.split('T')[0] + time_part = modified_time.split('T')[1].split('.')[0].split('Z')[0] + formatted_time = f"{date_part} {time_part}" + else: + formatted_time = None + + files.append({ + 'id': doc.doc_id, + 'name': metadata.get('file_name', 'Unknown File'), + 'type': metadata.get('mime_type', 'unknown'), + 'size': metadata.get('size', None), + 'modifiedTime': formatted_time + }) + + next_token = getattr(loader, 'next_page_token', None) + has_more = bool(next_token) + + return make_response(jsonify({"success": True, "files": files, "total": len(files), "next_page_token": next_token, "has_more": has_more}), 200) + except Exception as e: + current_app.logger.error(f"Error loading connector files: {e}") + return make_response(jsonify({"success": False, "error": f"Failed to load files: {str(e)}"}), 500) + + +@connectors_ns.route("/api/connectors/validate-session") +class ConnectorValidateSession(Resource): + @api.expect(api.model("ConnectorValidateSessionModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True)})) + @api.doc(description="Validate connector session token and return user info") + def post(self): + try: + data = request.get_json() + provider = data.get('provider') + session_token = data.get('session_token') + if not provider or not session_token: + return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400) + + + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401) + user = decoded_token.get('sub') + + session = sessions_collection.find_one({"session_token": session_token, "user": user}) + if not session or "token_info" not in session: + return make_response(jsonify({"success": 
False, "error": "Invalid or expired session"}), 401) + + token_info = session["token_info"] + auth = ConnectorCreator.create_auth(provider) + is_expired = auth.is_token_expired(token_info) + + return make_response(jsonify({ + "success": True, + "expired": is_expired, + "user_email": session.get('user_email', 'Connected User') + }), 200) + except Exception as e: + current_app.logger.error(f"Error validating connector session: {e}") + return make_response(jsonify({"success": False, "error": str(e)}), 500) + + +@connectors_ns.route("/api/connectors/disconnect") +class ConnectorDisconnect(Resource): + @api.expect(api.model("ConnectorDisconnectModel", {"provider": fields.String(required=True), "session_token": fields.String(required=False)})) + @api.doc(description="Disconnect a connector session") + def post(self): + try: + data = request.get_json() + provider = data.get('provider') + session_token = data.get('session_token') + if not provider: + return make_response(jsonify({"success": False, "error": "provider is required"}), 400) + + + if session_token: + sessions_collection.delete_one({"session_token": session_token}) + + return make_response(jsonify({"success": True}), 200) + except Exception as e: + current_app.logger.error(f"Error disconnecting connector session: {e}") + return make_response(jsonify({"success": False, "error": str(e)}), 500) + + +@connectors_ns.route("/api/connectors/sync") +class ConnectorSync(Resource): + @api.expect( + api.model( + "ConnectorSyncModel", + { + "source_id": fields.String(required=True, description="Source ID to sync"), + "session_token": fields.String(required=True, description="Authentication token") + }, + ) + ) + @api.doc(description="Sync connector source to check for modifications") + def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + + try: + data = request.get_json() + source_id = data.get('source_id') + session_token = data.get('session_token') + + if not all([source_id, session_token]): + return make_response( + jsonify({ + "success": False, + "error": "source_id and session_token are required" + }), + 400 + ) + source = sources_collection.find_one({"_id": ObjectId(source_id)}) + if not source: + return make_response( + jsonify({ + "success": False, + "error": "Source not found" + }), + 404 + ) + + if source.get('user') != decoded_token.get('sub'): + return make_response( + jsonify({ + "success": False, + "error": "Unauthorized access to source" + }), + 403 + ) + + remote_data = {} + try: + if source.get('remote_data'): + remote_data = json.loads(source.get('remote_data')) + except json.JSONDecodeError: + current_app.logger.error(f"Invalid remote_data format for source {source_id}") + remote_data = {} + + source_type = remote_data.get('provider') + if not source_type: + return make_response( + jsonify({ + "success": False, + "error": "Source provider not found in remote_data" + }), + 400 + ) + + # Extract configuration from remote_data + file_ids = remote_data.get('file_ids', []) + folder_ids = remote_data.get('folder_ids', []) + recursive = remote_data.get('recursive', True) + + # Start the sync task + task = ingest_connector_task.delay( + job_name=source.get('name'), + user=decoded_token.get('sub'), + source_type=source_type, + session_token=session_token, + file_ids=file_ids, + folder_ids=folder_ids, + recursive=recursive, + retriever=source.get('retriever', 'classic'), + operation_mode="sync", + doc_id=source_id, + 
sync_frequency=source.get('sync_frequency', 'never') + ) + + return make_response( + jsonify({ + "success": True, + "task_id": task.id + }), + 200 + ) + + except Exception as err: + current_app.logger.error( + f"Error syncing connector source: {err}", + exc_info=True + ) + return make_response( + jsonify({ + "success": False, + "error": str(err) + }), + 400 + ) + + +@connectors_ns.route("/api/connectors/callback-status") +class ConnectorCallbackStatus(Resource): + @api.doc(description="Return HTML page with connector authentication status") + def get(self): + """Return HTML page with connector authentication status""" + try: + status = request.args.get('status', 'error') + message = request.args.get('message', '') + provider = request.args.get('provider', 'connector') + session_token = request.args.get('session_token', '') + user_email = request.args.get('user_email', '') + + html_content = f""" + + + + {provider.replace('_', ' ').title()} Authentication + + + + +
+                <div class="container">
+                    <h2>{provider.replace('_', ' ').title()} Authentication</h2>
+                    <p class="message {status}">
+                        {message}
+                    </p>
+                    {f'<p class="user-email">Connected as: {user_email}</p>' if status == 'success' else ''}
+                    <p class="note">
+                        You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else ''}
+                    </p>
+                </div>
+ + + """ + + return make_response(html_content, 200, {'Content-Type': 'text/html'}) + except Exception as e: + current_app.logger.error(f"Error rendering callback status page: {e}") + return make_response("Authentication error occurred", 500, {'Content-Type': 'text/html'}) + + diff --git a/application/api/user/routes.py b/application/api/user/routes.py index b0554461..f0493c7c 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -32,13 +32,15 @@ from application.api import api from application.api.user.tasks import ( ingest, + ingest_connector_task, ingest_remote, process_agent_webhook, store_attachment, ) from application.core.mongo_db import MongoDB from application.core.settings import settings -from application.security.encryption import encrypt_credentials, decrypt_credentials +from application.parser.connectors.connector_creator import ConnectorCreator +from application.security.encryption import decrypt_credentials, encrypt_credentials from application.storage.storage_creator import StorageCreator from application.tts.google_tts import GoogleTTS from application.utils import ( @@ -76,7 +78,6 @@ try: users_collection.create_index("user_id", unique=True) except Exception as e: print("Error creating indexes:", e) - user = Blueprint("user", __name__) user_ns = Namespace("user", description="User related operations", path="/") api.add_namespace(user_ns) @@ -129,11 +130,9 @@ def ensure_user_doc(user_id): updates["agent_preferences.pinned"] = [] if "shared_with_me" not in prefs: updates["agent_preferences.shared_with_me"] = [] - if updates: users_collection.update_one({"user_id": user_id}, {"$set": updates}) user_doc = users_collection.find_one({"user_id": user_id}) - return user_doc @@ -185,7 +184,6 @@ def handle_image_upload( jsonify({"success": False, "message": "Image upload failed"}), 400, ) - return image_url, None @@ -299,8 +297,8 @@ class GetSingleConversation(Resource): ) if not conversation: return make_response(jsonify({"status": "not found"}), 404) - # Process queries to include attachment names + queries = conversation["queries"] for query in queries: if "attachments" in query and query["attachments"]: @@ -501,6 +499,7 @@ class DeleteOldIndexes(Resource): try: # Delete vector index + if settings.VECTOR_STORE == "faiss": index_path = f"indexes/{str(doc['_id'])}" if storage.file_exists(f"{index_path}/index.faiss"): @@ -571,6 +570,7 @@ class UploadFile(Resource): job_name = request.form["name"] # Create safe versions for filesystem operations + safe_user = safe_filename(user) dir_name = safe_filename(job_name) base_path = f"{settings.UPLOAD_FOLDER}/{safe_user}/{dir_name}" @@ -592,6 +592,7 @@ class UploadFile(Resource): zip_ref.extractall(path=temp_dir) # Walk through extracted files and upload them + for root, _, files in os.walk(temp_dir): for extracted_file in files: if ( @@ -614,11 +615,13 @@ class UploadFile(Resource): f"Error extracting zip: {e}", exc_info=True ) # If zip extraction fails, save the original zip file + file_path = f"{base_path}/{safe_file}" with open(temp_file_path, "rb") as f: storage.save_file(f, file_path) else: # For non-zip files, save directly + file_path = f"{base_path}/{safe_file}" with open(temp_file_path, "rb") as f: storage.save_file(f, file_path) @@ -709,7 +712,6 @@ class ManageSourceFiles(Resource): ), 400, ) - if operation not in ["add", "remove", "remove_directory"]: return make_response( jsonify( @@ -720,14 +722,12 @@ class ManageSourceFiles(Resource): ), 400, ) - try: ObjectId(source_id) except Exception: 
return make_response( jsonify({"success": False, "message": "Invalid source ID format"}), 400 ) - try: source = sources_collection.find_one( {"_id": ObjectId(source_id), "user": user} @@ -760,7 +760,6 @@ class ManageSourceFiles(Resource): ), 400, ) - if operation == "add": files = request.files.getlist("file") if not files or all(file.filename == "" for file in files): @@ -773,23 +772,22 @@ class ManageSourceFiles(Resource): ), 400, ) - added_files = [] target_dir = source_file_path if parent_dir: target_dir = f"{source_file_path}/{parent_dir}" - for file in files: if file.filename: safe_filename_str = safe_filename(file.filename) file_path = f"{target_dir}/{safe_filename_str}" # Save file to storage + storage.save_file(file, file_path) added_files.append(safe_filename_str) - # Trigger re-ingestion pipeline + from application.api.user.tasks import reingest_source_task task = reingest_source_task.delay(source_id=source_id, user=user) @@ -819,7 +817,6 @@ class ManageSourceFiles(Resource): ), 400, ) - try: file_paths = ( json.loads(file_paths_str) @@ -833,18 +830,19 @@ class ManageSourceFiles(Resource): ), 400, ) - # Remove files from storage and directory structure + removed_files = [] for file_path in file_paths: full_path = f"{source_file_path}/{file_path}" # Remove from storage + if storage.file_exists(full_path): storage.delete_file(full_path) removed_files.append(file_path) - # Trigger re-ingestion pipeline + from application.api.user.tasks import reingest_source_task task = reingest_source_task.delay(source_id=source_id, user=user) @@ -873,8 +871,8 @@ class ManageSourceFiles(Resource): ), 400, ) - # Validate directory path (prevent path traversal) + if directory_path.startswith("/") or ".." in directory_path: current_app.logger.warning( f"Invalid directory path attempted for removal. " @@ -908,7 +906,6 @@ class ManageSourceFiles(Resource): ), 404, ) - success = storage.remove_directory(full_directory_path) if not success: @@ -923,7 +920,6 @@ class ManageSourceFiles(Resource): ), 500, ) - current_app.logger.info( f"Successfully removed directory. 
" f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, " @@ -931,6 +927,7 @@ class ManageSourceFiles(Resource): ) # Trigger re-ingestion pipeline + from application.api.user.tasks import reingest_source_task task = reingest_source_task.delay(source_id=source_id, user=user) @@ -1005,6 +1002,50 @@ class UploadRemote(Resource): source_data = config.get("url") elif data["source"] == "reddit": source_data = config + elif data["source"] in ConnectorCreator.get_supported_connectors(): + session_token = config.get("session_token") + if not session_token: + return make_response( + jsonify( + { + "success": False, + "error": f"Missing session_token in {data['source']} configuration", + } + ), + 400, + ) + # Process file_ids + + file_ids = config.get("file_ids", []) + if isinstance(file_ids, str): + file_ids = [id.strip() for id in file_ids.split(",") if id.strip()] + elif not isinstance(file_ids, list): + file_ids = [] + # Process folder_ids + + folder_ids = config.get("folder_ids", []) + if isinstance(folder_ids, str): + folder_ids = [ + id.strip() for id in folder_ids.split(",") if id.strip() + ] + elif not isinstance(folder_ids, list): + folder_ids = [] + config["file_ids"] = file_ids + config["folder_ids"] = folder_ids + + task = ingest_connector_task.delay( + job_name=data["name"], + user=decoded_token.get("sub"), + source_type=data["source"], + session_token=session_token, + file_ids=file_ids, + folder_ids=folder_ids, + recursive=config.get("recursive", False), + retriever=config.get("retriever", "classic"), + ) + return make_response( + jsonify({"success": True, "task_id": task.id}), 200 + ) task = ingest_remote.delay( source_data=source_data, job_name=data["name"], @@ -1113,6 +1154,7 @@ class PaginatedSources(Resource): "retriever": doc.get("retriever", "classic"), "syncFrequency": doc.get("sync_frequency", ""), "isNested": bool(doc.get("directory_structure")), + "type": doc.get("type", "file"), } paginated_docs.append(doc_data) response = { @@ -1161,6 +1203,9 @@ class CombinedJson(Resource): "retriever": index.get("retriever", "classic"), "syncFrequency": index.get("sync_frequency", ""), "is_nested": bool(index.get("directory_structure")), + "type": index.get( + "type", "file" + ), # Add type field with default "file" } ) except Exception as err: @@ -1376,17 +1421,14 @@ class GetAgent(Resource): def get(self): if not (decoded_token := request.decoded_token): return {"success": False}, 401 - if not (agent_id := request.args.get("id")): return {"success": False, "message": "ID required"}, 400 - try: agent = agents_collection.find_one( {"_id": ObjectId(agent_id), "user": decoded_token["sub"]} ) if not agent: return {"status": "Not found"}, 404 - data = { "id": str(agent["_id"]), "name": agent["name"], @@ -1400,6 +1442,16 @@ class GetAgent(Resource): and (source_doc := db.dereference(agent.get("source"))) else "" ), + "sources": [ + ( + str(db.dereference(source_ref)["_id"]) + if isinstance(source_ref, DBRef) and db.dereference(source_ref) + else source_ref + ) + for source_ref in agent.get("sources", []) + if (isinstance(source_ref, DBRef) and db.dereference(source_ref)) + or source_ref == "default" + ], "chunks": agent["chunks"], "retriever": agent.get("retriever", ""), "prompt_id": agent.get("prompt_id", ""), @@ -1422,7 +1474,6 @@ class GetAgent(Resource): "shared_token": agent.get("shared_token", ""), } return make_response(jsonify(data), 200) - except Exception as e: current_app.logger.error(f"Agent fetch error: {e}", exc_info=True) return {"success": False}, 400 @@ 
-1434,7 +1485,6 @@ class GetAgents(Resource): def get(self): if not (decoded_token := request.decoded_token): return {"success": False}, 401 - user = decoded_token.get("sub") try: user_doc = ensure_user_doc(user) @@ -1453,8 +1503,24 @@ class GetAgents(Resource): str(source_doc["_id"]) if isinstance(agent.get("source"), DBRef) and (source_doc := db.dereference(agent.get("source"))) - else "" + else ( + agent.get("source", "") + if agent.get("source") == "default" + else "" + ) ), + "sources": [ + ( + source_ref + if source_ref == "default" + else str(db.dereference(source_ref)["_id"]) + ) + for source_ref in agent.get("sources", []) + if source_ref == "default" + or ( + isinstance(source_ref, DBRef) and db.dereference(source_ref) + ) + ], "chunks": agent["chunks"], "retriever": agent.get("retriever", ""), "prompt_id": agent.get("prompt_id", ""), @@ -1497,7 +1563,14 @@ class CreateAgent(Resource): "image": fields.Raw( required=False, description="Image file upload", type="file" ), - "source": fields.String(required=True, description="Source ID"), + "source": fields.String( + required=False, description="Source ID (legacy single source)" + ), + "sources": fields.List( + fields.String, + required=False, + description="List of source identifiers for multiple sources", + ), "chunks": fields.Integer(required=True, description="Chunks count"), "retriever": fields.String(required=True, description="Retriever ID"), "prompt_id": fields.String(required=True, description="Prompt ID"), @@ -1530,6 +1603,11 @@ class CreateAgent(Resource): data["tools"] = json.loads(data["tools"]) except json.JSONDecodeError: data["tools"] = [] + if "sources" in data: + try: + data["sources"] = json.loads(data["sources"]) + except json.JSONDecodeError: + data["sources"] = [] if "json_schema" in data: try: data["json_schema"] = json.loads(data["json_schema"]) @@ -1538,9 +1616,11 @@ class CreateAgent(Resource): print(f"Received data: {data}") # Validate JSON schema if provided + if data.get("json_schema"): try: # Basic validation - ensure it's a valid JSON structure + json_schema = data.get("json_schema") if not isinstance(json_schema, dict): return make_response( @@ -1554,6 +1634,7 @@ class CreateAgent(Resource): ) # Validate that it has either a 'schema' property or is itself a schema + if "schema" not in json_schema and "type" not in json_schema: return make_response( jsonify( @@ -1571,7 +1652,6 @@ class CreateAgent(Resource): ), 400, ) - if data.get("status") not in ["draft", "published"]: return make_response( jsonify( @@ -1582,17 +1662,27 @@ class CreateAgent(Resource): ), 400, ) - if data.get("status") == "published": required_fields = [ "name", "description", - "source", "chunks", "retriever", "prompt_id", "agent_type", ] + # Require either source or sources (but not both) + + if not data.get("source") and not data.get("sources"): + return make_response( + jsonify( + { + "success": False, + "message": "Either 'source' or 'sources' field is required for published agents", + } + ), + 400, + ) validate_fields = ["name", "description", "prompt_id", "agent_type"] else: required_fields = ["name"] @@ -1603,25 +1693,37 @@ class CreateAgent(Resource): return missing_fields if invalid_fields: return invalid_fields - image_url, error = handle_image_upload(request, "", user, storage) if error: return make_response( jsonify({"success": False, "message": "Image upload failed"}), 400 ) - try: key = str(uuid.uuid4()) if data.get("status") == "published" else "" + + sources_list = [] + if data.get("sources") and 
len(data.get("sources", [])) > 0: + for source_id in data.get("sources", []): + if source_id == "default": + sources_list.append("default") + elif ObjectId.is_valid(source_id): + sources_list.append(DBRef("sources", ObjectId(source_id))) + source_field = "" + else: + source_value = data.get("source", "") + if source_value == "default": + source_field = "default" + elif ObjectId.is_valid(source_value): + source_field = DBRef("sources", ObjectId(source_value)) + else: + source_field = "" new_agent = { "user": user, "name": data.get("name"), "description": data.get("description", ""), "image": image_url, - "source": ( - DBRef("sources", ObjectId(data.get("source"))) - if ObjectId.is_valid(data.get("source")) - else "" - ), + "source": source_field, + "sources": sources_list, "chunks": data.get("chunks", ""), "retriever": data.get("retriever", ""), "prompt_id": data.get("prompt_id", ""), @@ -1636,7 +1738,11 @@ class CreateAgent(Resource): } if new_agent["chunks"] == "": new_agent["chunks"] = "0" - if new_agent["source"] == "" and new_agent["retriever"] == "": + if ( + new_agent["source"] == "" + and new_agent["retriever"] == "" + and not new_agent["sources"] + ): new_agent["retriever"] = "classic" resp = agents_collection.insert_one(new_agent) new_id = str(resp.inserted_id) @@ -1658,7 +1764,14 @@ class UpdateAgent(Resource): "image": fields.String( required=False, description="New image URL or identifier" ), - "source": fields.String(required=True, description="Source ID"), + "source": fields.String( + required=False, description="Source ID (legacy single source)" + ), + "sources": fields.List( + fields.String, + required=False, + description="List of source identifiers for multiple sources", + ), "chunks": fields.Integer(required=True, description="Chunks count"), "retriever": fields.String(required=True, description="Retriever ID"), "prompt_id": fields.String(required=True, description="Prompt ID"), @@ -1691,12 +1804,16 @@ class UpdateAgent(Resource): data["tools"] = json.loads(data["tools"]) except json.JSONDecodeError: data["tools"] = [] + if "sources" in data: + try: + data["sources"] = json.loads(data["sources"]) + except json.JSONDecodeError: + data["sources"] = [] if "json_schema" in data: try: data["json_schema"] = json.loads(data["json_schema"]) except json.JSONDecodeError: data["json_schema"] = None - if not ObjectId.is_valid(agent_id): return make_response( jsonify({"success": False, "message": "Invalid agent ID format"}), 400 @@ -1720,7 +1837,6 @@ class UpdateAgent(Resource): ), 404, ) - image_url, error = handle_image_upload( request, existing_agent.get("image", ""), user, storage ) @@ -1728,13 +1844,13 @@ class UpdateAgent(Resource): return make_response( jsonify({"success": False, "message": "Image upload failed"}), 400 ) - update_fields = {} allowed_fields = [ "name", "description", "image", "source", + "sources", "chunks", "retriever", "prompt_id", @@ -1758,7 +1874,11 @@ class UpdateAgent(Resource): update_fields[field] = new_status elif field == "source": source_id = data.get("source") - if source_id and ObjectId.is_valid(source_id): + if source_id == "default": + # Handle special "default" source + + update_fields[field] = "default" + elif source_id and ObjectId.is_valid(source_id): update_fields[field] = DBRef("sources", ObjectId(source_id)) elif source_id: return make_response( @@ -1772,6 +1892,30 @@ class UpdateAgent(Resource): ) else: update_fields[field] = "" + elif field == "sources": + sources_list = data.get("sources", []) + if sources_list and 
isinstance(sources_list, list): + valid_sources = [] + for source_id in sources_list: + if source_id == "default": + valid_sources.append("default") + elif ObjectId.is_valid(source_id): + valid_sources.append( + DBRef("sources", ObjectId(source_id)) + ) + else: + return make_response( + jsonify( + { + "success": False, + "message": f"Invalid source ID format: {source_id}", + } + ), + 400, + ) + update_fields[field] = valid_sources + else: + update_fields[field] = [] elif field == "chunks": chunks_value = data.get("chunks") if chunks_value == "": @@ -1837,7 +1981,6 @@ class UpdateAgent(Resource): ), 400, ) - if not existing_agent.get("key"): newly_generated_key = str(uuid.uuid4()) update_fields["key"] = newly_generated_key @@ -1924,7 +2067,6 @@ class PinnedAgents(Resource): decoded_token = request.decoded_token if not decoded_token: return make_response(jsonify({"success": False}), 401) - user_id = decoded_token.get("sub") try: @@ -1933,7 +2075,6 @@ class PinnedAgents(Resource): if not pinned_ids: return make_response(jsonify([]), 200) - pinned_object_ids = [ObjectId(agent_id) for agent_id in pinned_ids] pinned_agents_cursor = agents_collection.find( @@ -1943,6 +2084,7 @@ class PinnedAgents(Resource): existing_ids = {str(agent["_id"]) for agent in pinned_agents} # Clean up any stale pinned IDs + stale_ids = [ agent_id for agent_id in pinned_ids if agent_id not in existing_ids ] @@ -1951,7 +2093,6 @@ class PinnedAgents(Resource): {"user_id": user_id}, {"$pullAll": {"agent_preferences.pinned": stale_ids}}, ) - list_pinned_agents = [ { "id": str(agent["_id"]), @@ -1988,11 +2129,9 @@ class PinnedAgents(Resource): for agent in pinned_agents if "source" in agent or "retriever" in agent ] - except Exception as err: current_app.logger.error(f"Error retrieving pinned agents: {err}") return make_response(jsonify({"success": False}), 400) - return make_response(jsonify(list_pinned_agents), 200) @@ -2056,7 +2195,6 @@ class RemoveSharedAgent(Resource): return make_response( jsonify({"success": False, "message": "ID is required"}), 400 ) - try: agent = agents_collection.find_one( {"_id": ObjectId(agent_id), "shared_publicly": True} @@ -2066,7 +2204,6 @@ class RemoveSharedAgent(Resource): jsonify({"success": False, "message": "Shared agent not found"}), 404, ) - ensure_user_doc(user_id) users_collection.update_one( {"user_id": user_id}, @@ -2079,7 +2216,6 @@ class RemoveSharedAgent(Resource): ) return make_response(jsonify({"success": True, "action": "removed"}), 200) - except Exception as err: current_app.logger.error(f"Error removing shared agent: {err}") return make_response( @@ -2102,7 +2238,6 @@ class SharedAgent(Resource): return make_response( jsonify({"success": False, "message": "Token or ID is required"}), 400 ) - try: query = { "shared_publicly": True, @@ -2114,7 +2249,6 @@ class SharedAgent(Resource): jsonify({"success": False, "message": "Shared agent not found"}), 404, ) - agent_id = str(shared_agent["_id"]) data = { "id": agent_id, @@ -2154,7 +2288,6 @@ class SharedAgent(Resource): if tool_data: enriched_tools.append(tool_data.get("name", "")) data["tools"] = enriched_tools - decoded_token = getattr(request, "decoded_token", None) if decoded_token: user_id = decoded_token.get("sub") @@ -2166,9 +2299,7 @@ class SharedAgent(Resource): {"user_id": user_id}, {"$addToSet": {"agent_preferences.shared_with_me": agent_id}}, ) - return make_response(jsonify(data), 200) - except Exception as err: current_app.logger.error(f"Error retrieving shared agent: {err}") return 
make_response(jsonify({"success": False}), 400) @@ -2202,7 +2333,6 @@ class SharedAgents(Resource): {"user_id": user_id}, {"$pullAll": {"agent_preferences.shared_with_me": stale_ids}}, ) - pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", [])) list_shared_agents = [ @@ -2229,7 +2359,6 @@ class SharedAgents(Resource): ] return make_response(jsonify(list_shared_agents), 200) - except Exception as err: current_app.logger.error(f"Error retrieving shared agents: {err}") return make_response(jsonify({"success": False}), 400) @@ -3762,20 +3891,21 @@ class GetChunks(Resource): metadata = chunk.get("metadata", {}) # Filter by path if provided + if path: chunk_source = metadata.get("source", "") # Check if the chunk's source matches the requested path + if not chunk_source or not chunk_source.endswith(path): continue - # Filter by search term if provided + if search_term: text_match = search_term in chunk.get("text", "").lower() title_match = search_term in metadata.get("title", "").lower() if not (text_match or title_match): continue - filtered_chunks.append(chunk) chunks = filtered_chunks @@ -3937,7 +4067,6 @@ class UpdateChunk(Resource): if metadata is None: metadata = {} metadata["token_count"] = token_count - if not ObjectId.is_valid(doc_id): return make_response(jsonify({"error": "Invalid doc_id"}), 400) doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user}) @@ -3952,7 +4081,6 @@ class UpdateChunk(Resource): existing_chunk = next((c for c in chunks if c["doc_id"] == chunk_id), None) if not existing_chunk: return make_response(jsonify({"error": "Chunk not found"}), 404) - new_text = text if text is not None else existing_chunk["text"] if metadata is not None: @@ -3960,10 +4088,8 @@ class UpdateChunk(Resource): new_metadata.update(metadata) else: new_metadata = existing_chunk["metadata"].copy() - if text is not None: new_metadata["token_count"] = num_tokens_from_string(new_text) - try: new_chunk_id = store.add_chunk(new_text, new_metadata) @@ -4019,7 +4145,6 @@ class StoreAttachment(Resource): jsonify({"status": "error", "message": "Missing file"}), 400, ) - user = None if decoded_token: user = safe_filename(decoded_token.get("sub")) @@ -4034,7 +4159,6 @@ class StoreAttachment(Resource): return make_response( jsonify({"success": False, "message": "Authentication required"}), 401 ) - try: attachment_id = ObjectId() original_filename = safe_filename(os.path.basename(file.filename)) @@ -4076,7 +4200,6 @@ class ServeImage(Resource): content_type = f"image/{extension}" if extension == "jpg": content_type = "image/jpeg" - response = make_response(file_obj.read()) response.headers.set("Content-Type", content_type) response.headers.set("Cache-Control", "max-age=86400") @@ -4121,18 +4244,29 @@ class DirectoryStructure(Resource): ) directory_structure = doc.get("directory_structure", {}) + base_path = doc.get("file_path", "") + provider = None + remote_data = doc.get("remote_data") + try: + if isinstance(remote_data, str) and remote_data: + remote_data_obj = json.loads(remote_data) + provider = remote_data_obj.get("provider") + except Exception as e: + current_app.logger.warning( + f"Failed to parse remote_data for doc {doc_id}: {e}" + ) return make_response( jsonify( { "success": True, "directory_structure": directory_structure, - "base_path": doc.get("file_path", ""), + "base_path": base_path, + "provider": provider, } ), 200, ) - except Exception as e: current_app.logger.error( f"Error retrieving directory structure: {e}", exc_info=True diff --git 
a/application/api/user/tasks.py b/application/api/user/tasks.py index 28a78c0d..3519b701 100644 --- a/application/api/user/tasks.py +++ b/application/api/user/tasks.py @@ -47,6 +47,39 @@ def process_agent_webhook(self, agent_id, payload): return resp +@celery.task(bind=True) +def ingest_connector_task( + self, + job_name, + user, + source_type, + session_token=None, + file_ids=None, + folder_ids=None, + recursive=True, + retriever="classic", + operation_mode="upload", + doc_id=None, + sync_frequency="never" +): + from application.worker import ingest_connector + resp = ingest_connector( + self, + job_name, + user, + source_type, + session_token=session_token, + file_ids=file_ids, + folder_ids=folder_ids, + recursive=recursive, + retriever=retriever, + operation_mode=operation_mode, + doc_id=doc_id, + sync_frequency=sync_frequency + ) + return resp + + @celery.on_after_configure.connect def setup_periodic_tasks(sender, **kwargs): sender.add_periodic_task( diff --git a/application/app.py b/application/app.py index 4159a2bb..489ec840 100644 --- a/application/app.py +++ b/application/app.py @@ -16,6 +16,7 @@ from application.api import api  # noqa: E402 from application.api.answer import answer  # noqa: E402 from application.api.internal.routes import internal  # noqa: E402 from application.api.user.routes import user  # noqa: E402 +from application.api.connector.routes import connector  # noqa: E402 from application.celery_init import celery  # noqa: E402 from application.core.settings import settings  # noqa: E402 @@ -30,6 +31,7 @@ app = Flask(__name__) app.register_blueprint(user) app.register_blueprint(answer) app.register_blueprint(internal) +app.register_blueprint(connector) app.config.update( UPLOAD_FOLDER="inputs", CELERY_BROKER_URL=settings.CELERY_BROKER_URL, diff --git a/application/core/settings.py b/application/core/settings.py index f1563569..7ede4e86 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -40,6 +40,13 @@ class Settings(BaseSettings): FALLBACK_LLM_NAME: Optional[str] = None  # model name for fallback llm FALLBACK_LLM_API_KEY: Optional[str] = None  # api key for fallback llm + # Google Drive integration + GOOGLE_CLIENT_ID: Optional[str] = None  # Replace with your actual Google OAuth client ID + GOOGLE_CLIENT_SECRET: Optional[str] = None  # Replace with your actual Google OAuth client secret + CONNECTOR_REDIRECT_BASE_URI: Optional[str] = "http://127.0.0.1:7091/api/connectors/callback" + # Append ?provider={provider_name} when registering the redirect URI in your provider console, e.g. http://127.0.0.1:7091/api/connectors/callback?provider=google_drive + + # LLM Cache CACHE_REDIS_URL: str = "redis://localhost:6379/2" diff --git a/application/parser/connectors/__init__.py b/application/parser/connectors/__init__.py new file mode 100644 index 00000000..c9add3d7 --- /dev/null +++ b/application/parser/connectors/__init__.py @@ -0,0 +1,18 @@ +""" +External knowledge base connectors for DocsGPT. + +This module contains connectors for external knowledge bases and document storage systems +that require authentication and specialized handling, separate from simple web scrapers. 
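+
+Illustrative usage (a sketch only; the session token is a placeholder that the
+OAuth callback would normally issue, and the inputs dict mirrors what the
+/api/connectors/files route passes):
+
+    from application.parser.connectors import ConnectorCreator
+
+    if ConnectorCreator.is_supported("google_drive"):
+        auth = ConnectorCreator.create_auth("google_drive")
+        url = auth.get_authorization_url(state="csrf-state")  # send the user here
+        loader = ConnectorCreator.create_connector("google_drive", "session-token")
+        files = loader.load_data({"list_only": True, "limit": 10})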
+""" + +from .base import BaseConnectorAuth, BaseConnectorLoader +from .connector_creator import ConnectorCreator +from .google_drive import GoogleDriveAuth, GoogleDriveLoader + +__all__ = [ + 'BaseConnectorAuth', + 'BaseConnectorLoader', + 'ConnectorCreator', + 'GoogleDriveAuth', + 'GoogleDriveLoader' +] diff --git a/application/parser/connectors/base.py b/application/parser/connectors/base.py new file mode 100644 index 00000000..dfb6de87 --- /dev/null +++ b/application/parser/connectors/base.py @@ -0,0 +1,129 @@ +""" +Base classes for external knowledge base connectors. + +This module provides minimal abstract base classes that define the essential +interface for external knowledge base connectors. +""" + +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +from application.parser.schema.base import Document + + +class BaseConnectorAuth(ABC): + """ + Abstract base class for connector authentication. + + Defines the minimal interface that all connector authentication + implementations must follow. + """ + + @abstractmethod + def get_authorization_url(self, state: Optional[str] = None) -> str: + """ + Generate authorization URL for OAuth flows. + + Args: + state: Optional state parameter for CSRF protection + + Returns: + Authorization URL + """ + pass + + @abstractmethod + def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]: + """ + Exchange authorization code for access tokens. + + Args: + authorization_code: Authorization code from OAuth callback + + Returns: + Dictionary containing token information + """ + pass + + @abstractmethod + def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]: + """ + Refresh an expired access token. + + Args: + refresh_token: Refresh token + + Returns: + Dictionary containing refreshed token information + """ + pass + + @abstractmethod + def is_token_expired(self, token_info: Dict[str, Any]) -> bool: + """ + Check if a token is expired. + + Args: + token_info: Token information dictionary + + Returns: + True if token is expired, False otherwise + """ + pass + + +class BaseConnectorLoader(ABC): + """ + Abstract base class for connector loaders. + + Defines the minimal interface that all connector loader + implementations must follow. + """ + + @abstractmethod + def __init__(self, session_token: str): + """ + Initialize the connector loader. + + Args: + session_token: Authentication session token + """ + pass + + @abstractmethod + def load_data(self, inputs: Dict[str, Any]) -> List[Document]: + """ + Load documents from the external knowledge base. + + Args: + inputs: Configuration dictionary containing: + - file_ids: Optional list of specific file IDs to load + - folder_ids: Optional list of folder IDs to browse/download + - limit: Maximum number of items to return + - list_only: If True, return metadata without content + - recursive: Whether to recursively process folders + + Returns: + List of Document objects + """ + pass + + @abstractmethod + def download_to_directory(self, local_dir: str, source_config: Dict[str, Any] = None) -> Dict[str, Any]: + """ + Download files/folders to a local directory. 
+ + Args: + local_dir: Local directory path to download files to + source_config: Configuration for what to download + + Returns: + Dictionary containing download results: + - files_downloaded: Number of files downloaded + - directory_path: Path where files were downloaded + - empty_result: Whether no files were downloaded + - source_type: Type of connector + - config_used: Configuration that was used + - error: Error message if download failed (optional) + """ + pass diff --git a/application/parser/connectors/connector_creator.py b/application/parser/connectors/connector_creator.py new file mode 100644 index 00000000..bf4456ca --- /dev/null +++ b/application/parser/connectors/connector_creator.py @@ -0,0 +1,81 @@ +from application.parser.connectors.google_drive.loader import GoogleDriveLoader +from application.parser.connectors.google_drive.auth import GoogleDriveAuth + + +class ConnectorCreator: + """ + Factory class for creating external knowledge base connectors and auth providers. + + These are different from remote loaders as they typically require + authentication and connect to external document storage systems. + """ + + connectors = { + "google_drive": GoogleDriveLoader, + } + + auth_providers = { + "google_drive": GoogleDriveAuth, + } + + @classmethod + def create_connector(cls, connector_type, *args, **kwargs): + """ + Create a connector instance for the specified type. + + Args: + connector_type: Type of connector to create (e.g., 'google_drive') + *args, **kwargs: Arguments to pass to the connector constructor + + Returns: + Connector instance + + Raises: + ValueError: If connector type is not supported + """ + connector_class = cls.connectors.get(connector_type.lower()) + if not connector_class: + raise ValueError(f"No connector class found for type {connector_type}") + return connector_class(*args, **kwargs) + + @classmethod + def create_auth(cls, connector_type): + """ + Create an auth provider instance for the specified connector type. + + Args: + connector_type: Type of connector auth to create (e.g., 'google_drive') + + Returns: + Auth provider instance + + Raises: + ValueError: If connector type is not supported for auth + """ + auth_class = cls.auth_providers.get(connector_type.lower()) + if not auth_class: + raise ValueError(f"No auth class found for type {connector_type}") + return auth_class() + + @classmethod + def get_supported_connectors(cls): + """ + Get list of supported connector types. + + Returns: + List of supported connector type strings + """ + return list(cls.connectors.keys()) + + @classmethod + def is_supported(cls, connector_type): + """ + Check if a connector type is supported. + + Args: + connector_type: Type of connector to check + + Returns: + True if supported, False otherwise + """ + return connector_type.lower() in cls.connectors diff --git a/application/parser/connectors/google_drive/__init__.py b/application/parser/connectors/google_drive/__init__.py new file mode 100644 index 00000000..18abeec1 --- /dev/null +++ b/application/parser/connectors/google_drive/__init__.py @@ -0,0 +1,10 @@ +""" +Google Drive connector for DocsGPT. + +This module provides authentication and document loading capabilities for Google Drive. 
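+
+A sketch of the OAuth round-trip these classes implement (requires
+GOOGLE_CLIENT_ID / GOOGLE_CLIENT_SECRET in settings; the authorization code
+and session token below are placeholders supplied by Google's redirect):
+
+    auth = GoogleDriveAuth()
+    url = auth.get_authorization_url(state="random-state")   # 1. user consents here
+    tokens = auth.exchange_code_for_tokens("<auth-code>")    # 2. callback exchanges the code
+    if auth.is_token_expired(tokens):                        # 3. refresh when needed
+        tokens = auth.refresh_access_token(tokens["refresh_token"])
+    creds = auth.create_credentials_from_token_info(tokens)
+    service = auth.build_drive_service(creds)                # Drive v3 service
+    docs = GoogleDriveLoader("session-token").load_data({"list_only": True, "limit": 5})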
+""" + +from .auth import GoogleDriveAuth +from .loader import GoogleDriveLoader + +__all__ = ['GoogleDriveAuth', 'GoogleDriveLoader'] diff --git a/application/parser/connectors/google_drive/auth.py b/application/parser/connectors/google_drive/auth.py new file mode 100644 index 00000000..37d55dcc --- /dev/null +++ b/application/parser/connectors/google_drive/auth.py @@ -0,0 +1,268 @@ +import logging +import datetime +from typing import Optional, Dict, Any + +from google.oauth2.credentials import Credentials +from google_auth_oauthlib.flow import Flow +from googleapiclient.discovery import build +from googleapiclient.errors import HttpError + +from application.core.settings import settings +from application.parser.connectors.base import BaseConnectorAuth + + +class GoogleDriveAuth(BaseConnectorAuth): + """ + Handles Google OAuth 2.0 authentication for Google Drive access. + """ + + SCOPES = [ + 'https://www.googleapis.com/auth/drive.readonly', + 'https://www.googleapis.com/auth/drive.metadata.readonly' + ] + + def __init__(self): + self.client_id = settings.GOOGLE_CLIENT_ID + self.client_secret = settings.GOOGLE_CLIENT_SECRET + self.redirect_uri = f"{settings.CONNECTOR_REDIRECT_BASE_URI}?provider=google_drive" + + if not self.client_id or not self.client_secret: + raise ValueError("Google OAuth credentials not configured. Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET in settings.") + + + + def get_authorization_url(self, state: Optional[str] = None) -> str: + try: + flow = Flow.from_client_config( + { + "web": { + "client_id": self.client_id, + "client_secret": self.client_secret, + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "redirect_uris": [self.redirect_uri] + } + }, + scopes=self.SCOPES + ) + flow.redirect_uri = self.redirect_uri + + authorization_url, _ = flow.authorization_url( + access_type='offline', + prompt='consent', + include_granted_scopes='true', + state=state + ) + + return authorization_url + + except Exception as e: + logging.error(f"Error generating authorization URL: {e}") + raise + + def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]: + try: + if not authorization_code: + raise ValueError("Authorization code is required") + + flow = Flow.from_client_config( + { + "web": { + "client_id": self.client_id, + "client_secret": self.client_secret, + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "redirect_uris": [self.redirect_uri] + } + }, + scopes=self.SCOPES + ) + flow.redirect_uri = self.redirect_uri + + flow.fetch_token(code=authorization_code) + + credentials = flow.credentials + + if not credentials.refresh_token: + logging.warning("OAuth flow did not return a refresh_token.") + if not credentials.token: + raise ValueError("OAuth flow did not return an access token") + + if not credentials.token_uri: + credentials.token_uri = "https://oauth2.googleapis.com/token" + + if not credentials.client_id: + credentials.client_id = self.client_id + + if not credentials.client_secret: + credentials.client_secret = self.client_secret + + if not credentials.refresh_token: + raise ValueError( + "No refresh token received. This typically happens when offline access wasn't granted. 
" + ) + + return { + 'access_token': credentials.token, + 'refresh_token': credentials.refresh_token, + 'token_uri': credentials.token_uri, + 'client_id': credentials.client_id, + 'client_secret': credentials.client_secret, + 'scopes': credentials.scopes, + 'expiry': credentials.expiry.isoformat() if credentials.expiry else None + } + + except Exception as e: + logging.error(f"Error exchanging code for tokens: {e}") + raise + + def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]: + try: + if not refresh_token: + raise ValueError("Refresh token is required") + + credentials = Credentials( + token=None, + refresh_token=refresh_token, + token_uri="https://oauth2.googleapis.com/token", + client_id=self.client_id, + client_secret=self.client_secret + ) + + from google.auth.transport.requests import Request + credentials.refresh(Request()) + + return { + 'access_token': credentials.token, + 'refresh_token': refresh_token, + 'token_uri': credentials.token_uri, + 'client_id': credentials.client_id, + 'client_secret': credentials.client_secret, + 'scopes': credentials.scopes, + 'expiry': credentials.expiry.isoformat() if credentials.expiry else None + } + except Exception as e: + logging.error(f"Error refreshing access token: {e}", exc_info=True) + raise + + def create_credentials_from_token_info(self, token_info: Dict[str, Any]) -> Credentials: + from application.core.settings import settings + + access_token = token_info.get('access_token') + if not access_token: + raise ValueError("No access token found in token_info") + + credentials = Credentials( + token=access_token, + refresh_token=token_info.get('refresh_token'), + token_uri= 'https://oauth2.googleapis.com/token', + client_id=settings.GOOGLE_CLIENT_ID, + client_secret=settings.GOOGLE_CLIENT_SECRET, + scopes=token_info.get('scopes', ['https://www.googleapis.com/auth/drive.readonly']) + ) + + if not credentials.token: + raise ValueError("Credentials created without valid access token") + + return credentials + + def build_drive_service(self, credentials: Credentials): + try: + if not credentials: + raise ValueError("No credentials provided") + + if not credentials.token and not credentials.refresh_token: + raise ValueError("No access token or refresh token available. User must re-authorize with offline access.") + + needs_refresh = credentials.expired or not credentials.token + if needs_refresh: + if credentials.refresh_token: + try: + from google.auth.transport.requests import Request + credentials.refresh(Request()) + except Exception as refresh_error: + raise ValueError(f"Failed to refresh credentials: {refresh_error}") + else: + raise ValueError("No access token or refresh token available. 
User must re-authorize with offline access.") + + return build('drive', 'v3', credentials=credentials) + + except HttpError as e: + raise ValueError(f"Failed to build Google Drive service: HTTP {e.resp.status}") + except Exception as e: + raise ValueError(f"Failed to build Google Drive service: {str(e)}") + + def is_token_expired(self, token_info): + if 'expiry' in token_info and token_info['expiry']: + try: + from dateutil import parser + # Google Drive provides timezone-aware ISO8601 dates + expiry_dt = parser.parse(token_info['expiry']) + current_time = datetime.datetime.now(datetime.timezone.utc) + return current_time >= expiry_dt - datetime.timedelta(seconds=60) + except Exception: + return True + + if 'access_token' in token_info and token_info['access_token']: + return False + + return True + + def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]: + try: + from application.core.mongo_db import MongoDB + from application.core.settings import settings + + mongo = MongoDB.get_client() + db = mongo[settings.MONGO_DB_NAME] + + sessions_collection = db["connector_sessions"] + session = sessions_collection.find_one({"session_token": session_token}) + if not session: + raise ValueError(f"Invalid session token: {session_token}") + + if "token_info" not in session: + raise ValueError("Session missing token information") + + token_info = session["token_info"] + if not token_info: + raise ValueError("Invalid token information") + + required_fields = ["access_token", "refresh_token"] + missing_fields = [field for field in required_fields if field not in token_info or not token_info.get(field)] + if missing_fields: + raise ValueError(f"Missing required token fields: {missing_fields}") + + if 'client_id' not in token_info: + token_info['client_id'] = settings.GOOGLE_CLIENT_ID + if 'client_secret' not in token_info: + token_info['client_secret'] = settings.GOOGLE_CLIENT_SECRET + if 'token_uri' not in token_info: + token_info['token_uri'] = 'https://oauth2.googleapis.com/token' + + return token_info + + except Exception as e: + raise ValueError(f"Failed to retrieve Google Drive token information: {str(e)}") + + def validate_credentials(self, credentials: Credentials) -> bool: + """ + Validate Google Drive credentials by making a test API call. + + Args: + credentials: Google credentials object + + Returns: + True if credentials are valid, False otherwise + """ + try: + service = self.build_drive_service(credentials) + service.about().get(fields="user").execute() + return True + + except HttpError as e: + logging.error(f"HTTP error validating credentials: {e}") + return False + except Exception as e: + logging.error(f"Error validating credentials: {e}") + return False diff --git a/application/parser/connectors/google_drive/loader.py b/application/parser/connectors/google_drive/loader.py new file mode 100644 index 00000000..07219344 --- /dev/null +++ b/application/parser/connectors/google_drive/loader.py @@ -0,0 +1,536 @@ +""" +Google Drive loader for DocsGPT. +Loads documents from Google Drive using Google Drive API. 
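
For reference, a sketch of the full OAuth round trip against the auth class above. The authorization code and state values are placeholders, and `GOOGLE_CLIENT_ID`/`GOOGLE_CLIENT_SECRET` must be configured in settings for the constructor to succeed.

    auth = GoogleDriveAuth()
    url = auth.get_authorization_url(state="csrf-state")             # 1. redirect the user here
    tokens = auth.exchange_code_for_tokens("4/0Aplaceholder")        # 2. code from the OAuth callback
    if auth.is_token_expired(tokens):                                # 3. refresh when needed
        tokens = auth.refresh_access_token(tokens["refresh_token"])
    creds = auth.create_credentials_from_token_info(tokens)
    service = auth.build_drive_service(creds)                        # Drive v3 client
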
+""" + +import io +import logging +import os +from typing import List, Dict, Any, Optional + +from googleapiclient.http import MediaIoBaseDownload +from googleapiclient.errors import HttpError + +from application.parser.connectors.base import BaseConnectorLoader +from application.parser.connectors.google_drive.auth import GoogleDriveAuth +from application.parser.schema.base import Document + + +class GoogleDriveLoader(BaseConnectorLoader): + + SUPPORTED_MIME_TYPES = { + 'application/pdf': '.pdf', + 'application/vnd.google-apps.document': '.docx', + 'application/vnd.google-apps.presentation': '.pptx', + 'application/vnd.google-apps.spreadsheet': '.xlsx', + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx', + 'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx', + 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx', + 'application/msword': '.doc', + 'application/vnd.ms-powerpoint': '.ppt', + 'application/vnd.ms-excel': '.xls', + 'text/plain': '.txt', + 'text/csv': '.csv', + 'text/html': '.html', + 'application/rtf': '.rtf', + 'image/jpeg': '.jpg', + 'image/jpg': '.jpg', + 'image/png': '.png', + } + + EXPORT_FORMATS = { + 'application/vnd.google-apps.document': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', + 'application/vnd.google-apps.presentation': 'application/vnd.openxmlformats-officedocument.presentationml.presentation', + 'application/vnd.google-apps.spreadsheet': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' + } + + def __init__(self, session_token: str): + self.auth = GoogleDriveAuth() + self.session_token = session_token + + token_info = self.auth.get_token_info_from_session(session_token) + self.credentials = self.auth.create_credentials_from_token_info(token_info) + + try: + self.service = self.auth.build_drive_service(self.credentials) + except Exception as e: + logging.warning(f"Could not build Google Drive service: {e}") + self.service = None + + self.next_page_token = None + + + + def _process_file(self, file_metadata: Dict[str, Any], load_content: bool = True) -> Optional[Document]: + try: + file_id = file_metadata.get('id') + file_name = file_metadata.get('name', 'Unknown') + mime_type = file_metadata.get('mimeType', 'application/octet-stream') + + if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'): + return None + if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'): + logging.info(f"Skipping unsupported file type: {mime_type} for file {file_name}") + return None + # Google Drive provides timezone-aware ISO8601 dates + doc_metadata = { + 'file_name': file_name, + 'mime_type': mime_type, + 'size': file_metadata.get('size', None), + 'created_time': file_metadata.get('createdTime'), + 'modified_time': file_metadata.get('modifiedTime'), + 'parents': file_metadata.get('parents', []), + 'source': 'google_drive' + } + + if not load_content: + return Document( + text="", + doc_id=file_id, + extra_info=doc_metadata + ) + + content = self._download_file_content(file_id, mime_type) + if content is None: + logging.warning(f"Could not load content for file {file_name} ({file_id})") + return None + + return Document( + text=content, + doc_id=file_id, + extra_info=doc_metadata + ) + + except Exception as e: + logging.error(f"Error processing file: {e}") + return None + + def load_data(self, inputs: Dict[str, Any]) -> List[Document]: + 
session_token = inputs.get('session_token') + if session_token and session_token != self.session_token: + logging.warning("Session token in inputs differs from loader's session token. Using loader's session token.") + self.config = inputs + + try: + documents: List[Document] = [] + + folder_id = inputs.get('folder_id') + file_ids = inputs.get('file_ids', []) + limit = inputs.get('limit', 100) + list_only = inputs.get('list_only', False) + load_content = not list_only + page_token = inputs.get('page_token') + self.next_page_token = None + + if file_ids: + # Specific files requested: load them + for file_id in file_ids: + try: + doc = self._load_file_by_id(file_id, load_content=load_content) + if doc: + documents.append(doc) + elif hasattr(self, '_credential_refreshed') and self._credential_refreshed: + self._credential_refreshed = False + logging.info(f"Retrying load of file {file_id} after credential refresh") + doc = self._load_file_by_id(file_id, load_content=load_content) + if doc: + documents.append(doc) + except Exception as e: + logging.error(f"Error loading file {file_id}: {e}") + continue + else: + # Browsing mode: list immediate children of provided folder or root + parent_id = folder_id if folder_id else 'root' + documents = self._list_items_in_parent(parent_id, limit=limit, load_content=load_content, page_token=page_token) + + logging.info(f"Loaded {len(documents)} documents from Google Drive") + return documents + + except Exception as e: + logging.error(f"Error loading data from Google Drive: {e}", exc_info=True) + raise + + + + def _load_file_by_id(self, file_id: str, load_content: bool = True) -> Optional[Document]: + self._ensure_service() + + try: + file_metadata = self.service.files().get( + fileId=file_id, + fields='id,name,mimeType,size,createdTime,modifiedTime,parents' + ).execute() + + return self._process_file(file_metadata, load_content=load_content) + + except HttpError as e: + logging.error(f"HTTP error loading file {file_id}: {e.resp.status} - {e.content}") + + if e.resp.status in [401, 403]: + if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token: + try: + from google.auth.transport.requests import Request + self.credentials.refresh(Request()) + self._ensure_service() + return None + except Exception as refresh_error: + raise ValueError(f"Authentication failed and could not be refreshed: {refresh_error}") + else: + raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token") + + return None + except Exception as e: + logging.error(f"Error loading file {file_id}: {e}") + return None + + + def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None) -> List[Document]: + self._ensure_service() + + documents: List[Document] = [] + + try: + query = f"'{parent_id}' in parents and trashed=false" + next_token_out: Optional[str] = None + + while True: + page_size = 100 + if limit: + remaining = max(0, limit - len(documents)) + if remaining == 0: + break + page_size = min(100, remaining) + + results = self.service.files().list( + q=query, + fields='nextPageToken,files(id,name,mimeType,size,createdTime,modifiedTime,parents)', + pageToken=page_token, + pageSize=page_size + ).execute() + + items = results.get('files', []) + for item in items: + mime_type = item.get('mimeType') + if mime_type == 'application/vnd.google-apps.folder': + doc_metadata = { + 'file_name': item.get('name', 'Unknown'), + 'mime_type': mime_type, + 'size': item.get('size', None), + 
'created_time': item.get('createdTime'), + 'modified_time': item.get('modifiedTime'), + 'parents': item.get('parents', []), + 'source': 'google_drive', + 'is_folder': True + } + documents.append(Document(text="", doc_id=item.get('id'), extra_info=doc_metadata)) + else: + doc = self._process_file(item, load_content=load_content) + if doc: + documents.append(doc) + + if limit and len(documents) >= limit: + self.next_page_token = results.get('nextPageToken') + return documents + + page_token = results.get('nextPageToken') + next_token_out = page_token + if not page_token: + break + + self.next_page_token = next_token_out + return documents + except Exception as e: + logging.error(f"Error listing items under parent {parent_id}: {e}") + return documents + + + + + def _download_file_content(self, file_id: str, mime_type: str) -> Optional[str]: + if not self.credentials.token: + logging.warning("No access token in credentials, attempting to refresh") + if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token: + try: + from google.auth.transport.requests import Request + self.credentials.refresh(Request()) + logging.info("Credentials refreshed successfully") + self._ensure_service() + except Exception as e: + logging.error(f"Failed to refresh credentials: {e}") + raise ValueError("Authentication failed and cannot be refreshed: missing or invalid refresh_token") + else: + logging.error("No access token and no refresh_token available") + raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token") + + if self.credentials.expired: + logging.warning("Credentials are expired, attempting to refresh") + if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token: + try: + from google.auth.transport.requests import Request + self.credentials.refresh(Request()) + logging.info("Credentials refreshed successfully") + self._ensure_service() + except Exception as e: + logging.error(f"Failed to refresh expired credentials: {e}") + raise ValueError("Authentication failed and cannot be refreshed: expired credentials") + else: + logging.error("Credentials expired and no refresh_token available") + raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token") + + try: + if mime_type in self.EXPORT_FORMATS: + export_mime_type = self.EXPORT_FORMATS[mime_type] + request = self.service.files().export_media( + fileId=file_id, + mimeType=export_mime_type + ) + else: + request = self.service.files().get_media(fileId=file_id) + + file_io = io.BytesIO() + downloader = MediaIoBaseDownload(file_io, request) + + done = False + while done is False: + try: + _, done = downloader.next_chunk() + except HttpError as e: + logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}") + return None + except Exception as e: + logging.error(f"Error during download of file {file_id}: {e}") + return None + + content_bytes = file_io.getvalue() + + try: + content = content_bytes.decode('utf-8') + except UnicodeDecodeError: + try: + content = content_bytes.decode('latin-1') + except UnicodeDecodeError: + logging.error(f"Could not decode file {file_id} as text") + return None + + return content + + except HttpError as e: + logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}") + + if e.resp.status in [401, 403]: + logging.error(f"Authentication error downloading file {file_id}") + + if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token: + 
logging.info(f"Attempting to refresh credentials for file {file_id}") + try: + from google.auth.transport.requests import Request + self.credentials.refresh(Request()) + logging.info("Credentials refreshed successfully") + self._credential_refreshed = True + self._ensure_service() + return None + except Exception as refresh_error: + logging.error(f"Error refreshing credentials: {refresh_error}") + raise ValueError(f"Authentication failed and could not be refreshed: {refresh_error}") + else: + logging.error("Cannot refresh credentials: missing refresh_token") + raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token") + + return None + except Exception as e: + logging.error(f"Error downloading file {file_id}: {e}") + return None + + + def _download_file_to_directory(self, file_id: str, local_dir: str) -> bool: + try: + self._ensure_service() + return self._download_single_file(file_id, local_dir) + except Exception as e: + logging.error(f"Error downloading file {file_id}: {e}", exc_info=True) + return False + + def _ensure_service(self): + if not self.service: + try: + self.service = self.auth.build_drive_service(self.credentials) + except Exception as e: + raise ValueError(f"Cannot access Google Drive: {e}") + + def _download_single_file(self, file_id: str, local_dir: str) -> bool: + file_metadata = self.service.files().get( + fileId=file_id, + fields='name,mimeType' + ).execute() + + file_name = file_metadata['name'] + mime_type = file_metadata['mimeType'] + + if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'): + return False + + os.makedirs(local_dir, exist_ok=True) + full_path = os.path.join(local_dir, file_name) + + if mime_type in self.EXPORT_FORMATS: + export_mime_type = self.EXPORT_FORMATS[mime_type] + request = self.service.files().export_media( + fileId=file_id, + mimeType=export_mime_type + ) + extension = self._get_extension_for_mime_type(export_mime_type) + if not full_path.endswith(extension): + full_path += extension + else: + request = self.service.files().get_media(fileId=file_id) + + with open(full_path, 'wb') as f: + downloader = MediaIoBaseDownload(f, request) + done = False + while not done: + _, done = downloader.next_chunk() + + return True + + def _download_folder_recursive(self, folder_id: str, local_dir: str, recursive: bool = True) -> int: + files_downloaded = 0 + try: + os.makedirs(local_dir, exist_ok=True) + + query = f"'{folder_id}' in parents and trashed=false" + page_token = None + + while True: + results = self.service.files().list( + q=query, + fields='nextPageToken, files(id, name, mimeType)', + pageToken=page_token, + pageSize=1000 + ).execute() + + items = results.get('files', []) + logging.info(f"Found {len(items)} items in folder {folder_id}") + + for item in items: + item_name = item['name'] + item_id = item['id'] + mime_type = item['mimeType'] + + if mime_type == 'application/vnd.google-apps.folder': + if recursive: + # Create subfolder and recurse + subfolder_path = os.path.join(local_dir, item_name) + os.makedirs(subfolder_path, exist_ok=True) + subfolder_files = self._download_folder_recursive( + item_id, + subfolder_path, + recursive + ) + files_downloaded += subfolder_files + logging.info(f"Downloaded {subfolder_files} files from subfolder {item_name}") + else: + # Download file + success = self._download_single_file(item_id, local_dir) + if success: + files_downloaded += 1 + logging.info(f"Downloaded file: {item_name}") + else: + logging.warning(f"Failed to 
download file: {item_name}") + + page_token = results.get('nextPageToken') + if not page_token: + break + + return files_downloaded + + except Exception as e: + logging.error(f"Error in _download_folder_recursive for folder {folder_id}: {e}", exc_info=True) + return files_downloaded + + def _get_extension_for_mime_type(self, mime_type: str) -> str: + extensions = { + 'application/pdf': '.pdf', + 'text/plain': '.txt', + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx', + 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx', + 'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx', + 'text/html': '.html', + 'text/markdown': '.md', + } + return extensions.get(mime_type, '.bin') + + def _download_folder_contents(self, folder_id: str, local_dir: str, recursive: bool = True) -> int: + try: + self._ensure_service() + return self._download_folder_recursive(folder_id, local_dir, recursive) + except Exception as e: + logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True) + return 0 + + def download_to_directory(self, local_dir: str, source_config: dict = None) -> dict: + if source_config is None: + source_config = {} + + config = source_config if source_config else getattr(self, 'config', {}) + files_downloaded = 0 + + try: + folder_ids = config.get('folder_ids', []) + file_ids = config.get('file_ids', []) + recursive = config.get('recursive', True) + + self._ensure_service() + + if file_ids: + if isinstance(file_ids, str): + file_ids = [file_ids] + + for file_id in file_ids: + if self._download_file_to_directory(file_id, local_dir): + files_downloaded += 1 + + # Process folders + if folder_ids: + if isinstance(folder_ids, str): + folder_ids = [folder_ids] + + for folder_id in folder_ids: + try: + folder_metadata = self.service.files().get( + fileId=folder_id, + fields='name' + ).execute() + folder_name = folder_metadata.get('name', '') + folder_path = os.path.join(local_dir, folder_name) + os.makedirs(folder_path, exist_ok=True) + + folder_files = self._download_folder_recursive( + folder_id, + folder_path, + recursive + ) + files_downloaded += folder_files + logging.info(f"Downloaded {folder_files} files from folder {folder_name}") + except Exception as e: + logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True) + + if not file_ids and not folder_ids: + raise ValueError("No folder_ids or file_ids provided for download") + + return { + "files_downloaded": files_downloaded, + "directory_path": local_dir, + "empty_result": files_downloaded == 0, + "source_type": "google_drive", + "config_used": config + } + + except Exception as e: + return { + "files_downloaded": files_downloaded, + "directory_path": local_dir, + "empty_result": True, + "source_type": "google_drive", + "config_used": config, + "error": str(e) + } diff --git a/application/parser/remote/remote_creator.py b/application/parser/remote/remote_creator.py index 026abd76..a47b186a 100644 --- a/application/parser/remote/remote_creator.py +++ b/application/parser/remote/remote_creator.py @@ -6,6 +6,16 @@ from application.parser.remote.github_loader import GitHubLoader class RemoteCreator: + """ + Factory class for creating remote content loaders. + + These loaders fetch content from remote web sources like URLs, + sitemaps, web crawlers, social media platforms, etc. + + For external knowledge base connectors (like Google Drive), + use ConnectorCreator instead. 
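
Putting the loader above to work, a minimal sketch (the folder ID is a placeholder; `loader` is a GoogleDriveLoader built from a valid session token):

    # Browse one folder without downloading content.
    listing = loader.load_data({
        "folder_id": "1AbCdEfG",      # omit to browse the Drive root
        "limit": 50,
        "list_only": True,            # metadata only
    })
    print(loader.next_page_token)     # set when more pages remain

    # Pull the same folder into a scratch directory for ingestion.
    result = loader.download_to_directory(
        "/tmp/drive_ingest",
        {"folder_ids": ["1AbCdEfG"], "file_ids": [], "recursive": True},
    )
    if result["empty_result"]:
        print(result.get("error", "nothing downloaded"))
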
+ """ + loaders = { "url": WebLoader, "sitemap": SitemapLoader, @@ -18,5 +28,5 @@ class RemoteCreator: def create_loader(cls, type, *args, **kwargs): loader_class = cls.loaders.get(type.lower()) if not loader_class: - raise ValueError(f"No LLM class found for type {type}") + raise ValueError(f"No loader class found for type {type}") return loader_class(*args, **kwargs) diff --git a/application/requirements.txt b/application/requirements.txt index f922a2cb..80564689 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -14,6 +14,9 @@ Flask==3.1.1 faiss-cpu==1.9.0.post1 flask-restx==1.3.0 google-genai==1.3.0 +google-api-python-client==2.179.0 +google-auth-httplib2==0.2.0 +google-auth-oauthlib==1.2.2 gTTS==2.5.4 gunicorn==23.0.0 javalang==0.13.0 diff --git a/application/retriever/base.py b/application/retriever/base.py index fd99dbdd..36ac2e93 100644 --- a/application/retriever/base.py +++ b/application/retriever/base.py @@ -5,10 +5,6 @@ class BaseRetriever(ABC): def __init__(self): pass - @abstractmethod - def gen(self, *args, **kwargs): - pass - @abstractmethod def search(self, *args, **kwargs): pass diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 9416b4f7..2ce863c2 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -1,4 +1,5 @@ import logging + from application.core.settings import settings from application.llm.llm_creator import LLMCreator from application.retriever.base import BaseRetriever @@ -20,10 +21,20 @@ class ClassicRAG(BaseRetriever): api_key=settings.API_KEY, decoded_token=None, ): - self.original_question = "" + """Initialize ClassicRAG retriever with vectorstore sources and LLM configuration""" + self.original_question = source.get("question", "") self.chat_history = chat_history if chat_history is not None else [] self.prompt = prompt - self.chunks = chunks + if isinstance(chunks, str): + try: + self.chunks = int(chunks) + except ValueError: + logging.warning( + f"Invalid chunks value '{chunks}', using default value 2" + ) + self.chunks = 2 + else: + self.chunks = chunks self.gpt_model = gpt_model self.token_limit = ( token_limit @@ -44,25 +55,52 @@ class ClassicRAG(BaseRetriever): user_api_key=self.user_api_key, decoded_token=decoded_token, ) - self.vectorstore = source["active_docs"] if "active_docs" in source else None + + if "active_docs" in source and source["active_docs"] is not None: + if isinstance(source["active_docs"], list): + self.vectorstores = source["active_docs"] + else: + self.vectorstores = [source["active_docs"]] + else: + self.vectorstores = [] self.question = self._rephrase_query() self.decoded_token = decoded_token + self._validate_vectorstore_config() + + def _validate_vectorstore_config(self): + """Validate vectorstore IDs and remove any empty/invalid entries""" + if not self.vectorstores: + logging.warning("No vectorstores configured for retrieval") + return + invalid_ids = [ + vs_id for vs_id in self.vectorstores if not vs_id or not vs_id.strip() + ] + if invalid_ids: + logging.warning(f"Found invalid vectorstore IDs: {invalid_ids}") + self.vectorstores = [ + vs_id for vs_id in self.vectorstores if vs_id and vs_id.strip() + ] def _rephrase_query(self): + """Rephrase user query with chat history context for better retrieval""" if ( not self.original_question or not self.chat_history or self.chat_history == [] or self.chunks == 0 - or self.vectorstore is None + or not self.vectorstores ): return self.original_question - prompt = 
f"""Given the following conversation history: + {self.chat_history} + + Rephrase the following user question to be a standalone search query + that captures all relevant context from the conversation: + """ messages = [ @@ -79,44 +117,62 @@ class ClassicRAG(BaseRetriever): return self.original_question def _get_data(self): - if self.chunks == 0 or self.vectorstore is None: - docs = [] - else: - docsearch = VectorCreator.create_vectorstore( - settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY - ) - docs_temp = docsearch.search(self.question, k=self.chunks) - docs = [ - { - "title": i.metadata.get( - "title", i.metadata.get("post_title", i.page_content) - ).split("/")[-1], - "text": i.page_content, - "source": ( - i.metadata.get("source") - if i.metadata.get("source") - else "local" - ), - } - for i in docs_temp - ] + """Retrieve relevant documents from configured vectorstores""" + if self.chunks == 0 or not self.vectorstores: + return [] + all_docs = [] + chunks_per_source = max(1, self.chunks // len(self.vectorstores)) - return docs + for vectorstore_id in self.vectorstores: + if vectorstore_id: + try: + docsearch = VectorCreator.create_vectorstore( + settings.VECTOR_STORE, vectorstore_id, settings.EMBEDDINGS_KEY + ) + docs_temp = docsearch.search(self.question, k=chunks_per_source) - def gen(): - pass + for doc in docs_temp: + if hasattr(doc, "page_content") and hasattr(doc, "metadata"): + page_content = doc.page_content + metadata = doc.metadata + else: + page_content = doc.get("text", doc.get("page_content", "")) + metadata = doc.get("metadata", {}) + title = metadata.get( + "title", metadata.get("post_title", page_content) + ) + if isinstance(title, str): + title = title.split("/")[-1] + else: + title = str(title).split("/")[-1] + all_docs.append( + { + "title": title, + "text": page_content, + "source": metadata.get("source") or vectorstore_id, + } + ) + except Exception as e: + logging.error( + f"Error searching vectorstore {vectorstore_id}: {e}", + exc_info=True, + ) + continue + return all_docs def search(self, query: str = ""): + """Search for documents using optional query override""" if query: self.original_question = query self.question = self._rephrase_query() return self._get_data() def get_params(self): + """Return current retriever configuration parameters""" return { "question": self.original_question, "rephrased_question": self.question, - "source": self.vectorstore, + "sources": self.vectorstores, "chunks": self.chunks, "token_limit": self.token_limit, "gpt_model": self.gpt_model, diff --git a/application/vectorstore/base.py b/application/vectorstore/base.py index a6b206c9..ea4885cd 100644 --- a/application/vectorstore/base.py +++ b/application/vectorstore/base.py @@ -1,20 +1,28 @@ -from abc import ABC, abstractmethod import os -from sentence_transformers import SentenceTransformer +from abc import ABC, abstractmethod + from langchain_openai import OpenAIEmbeddings +from sentence_transformers import SentenceTransformer + from application.core.settings import settings + class EmbeddingsWrapper: def __init__(self, model_name, *args, **kwargs): - self.model = SentenceTransformer(model_name, config_kwargs={'allow_dangerous_deserialization': True}, *args, **kwargs) + self.model = SentenceTransformer( + model_name, + config_kwargs={"allow_dangerous_deserialization": True}, + *args, + **kwargs + ) self.dimension = self.model.get_sentence_embedding_dimension() def embed_query(self, query: str): return self.model.encode(query).tolist() - + def embed_documents(self, 
documents: list): return self.model.encode(documents).tolist() - + def __call__(self, text): if isinstance(text, str): return self.embed_query(text) @@ -24,15 +32,14 @@ class EmbeddingsWrapper: raise ValueError("Input must be a string or a list of strings") - class EmbeddingsSingleton: _instances = {} @staticmethod def get_instance(embeddings_name, *args, **kwargs): if embeddings_name not in EmbeddingsSingleton._instances: - EmbeddingsSingleton._instances[embeddings_name] = EmbeddingsSingleton._create_instance( - embeddings_name, *args, **kwargs + EmbeddingsSingleton._instances[embeddings_name] = ( + EmbeddingsSingleton._create_instance(embeddings_name, *args, **kwargs) ) return EmbeddingsSingleton._instances[embeddings_name] @@ -40,9 +47,15 @@ class EmbeddingsSingleton: def _create_instance(embeddings_name, *args, **kwargs): embeddings_factory = { "openai_text-embedding-ada-002": OpenAIEmbeddings, - "huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"), - "huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"), - "huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper("hkunlp/instructor-large"), + "huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper( + "sentence-transformers/all-mpnet-base-v2" + ), + "huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper( + "sentence-transformers/all-mpnet-base-v2" + ), + "huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper( + "hkunlp/instructor-large" + ), } if embeddings_name in embeddings_factory: @@ -50,34 +63,63 @@ class EmbeddingsSingleton: else: return EmbeddingsWrapper(embeddings_name, *args, **kwargs) + class BaseVectorStore(ABC): def __init__(self): pass @abstractmethod def search(self, *args, **kwargs): + """Search for similar documents/chunks in the vectorstore""" + pass + + @abstractmethod + def add_texts(self, texts, metadatas=None, *args, **kwargs): + """Add texts with their embeddings to the vectorstore""" + pass + + def delete_index(self, *args, **kwargs): + """Delete the entire index/collection""" + pass + + def save_local(self, *args, **kwargs): + """Save vectorstore to local storage""" + pass + + def get_chunks(self, *args, **kwargs): + """Get all chunks from the vectorstore""" + pass + + def add_chunk(self, text, metadata=None, *args, **kwargs): + """Add a single chunk to the vectorstore""" + pass + + def delete_chunk(self, chunk_id, *args, **kwargs): + """Delete a specific chunk from the vectorstore""" pass def is_azure_configured(self): - return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME + return ( + settings.OPENAI_API_BASE + and settings.OPENAI_API_VERSION + and settings.AZURE_DEPLOYMENT_NAME + ) def _get_embeddings(self, embeddings_name, embeddings_key=None): if embeddings_name == "openai_text-embedding-ada-002": if self.is_azure_configured(): os.environ["OPENAI_API_TYPE"] = "azure" embedding_instance = EmbeddingsSingleton.get_instance( - embeddings_name, - model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME + embeddings_name, model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME ) else: embedding_instance = EmbeddingsSingleton.get_instance( - embeddings_name, - openai_api_key=embeddings_key + embeddings_name, openai_api_key=embeddings_key ) elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2": if os.path.exists("./models/all-mpnet-base-v2"): 
                embedding_instance = EmbeddingsSingleton.get_instance(
-                    embeddings_name = "./models/all-mpnet-base-v2",
+                    embeddings_name="./models/all-mpnet-base-v2",
                )
            else:
                embedding_instance = EmbeddingsSingleton.get_instance(
@@ -87,4 +129,3 @@ class BaseVectorStore(ABC):
         embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
 
         return embedding_instance
-
diff --git a/application/worker.py b/application/worker.py
index 7309806d..10fb6c2b 100755
--- a/application/worker.py
+++ b/application/worker.py
@@ -6,6 +6,7 @@ import os
 import shutil
 import string
 import tempfile
+from typing import Any, Dict
 import zipfile
 
 from collections import Counter
@@ -21,6 +22,7 @@ from application.api.answer.services.stream_processor import get_prompt
 from application.core.mongo_db import MongoDB
 from application.core.settings import settings
 from application.parser.chunking import Chunker
+from application.parser.connectors.connector_creator import ConnectorCreator
 from application.parser.embedding_pipeline import embed_and_store_documents
 from application.parser.file.bulk import SimpleDirectoryReader
 from application.parser.remote.remote_creator import RemoteCreator
@@ -649,8 +651,11 @@ def remote_worker(
             "id": str(id),
             "type": loader,
             "remote_data": source_data,
-            "sync_frequency": sync_frequency,
+            "sync_frequency": sync_frequency
         }
+
+        if operation_mode == "sync":
+            file_data["last_sync"] = datetime.datetime.now()
         upload_index(full_path, file_data)
     except Exception as e:
         logging.error("Error in remote_worker task: %s", str(e), exc_info=True)
@@ -707,7 +712,7 @@ def sync_worker(self, frequency):
                 self, source_data, name, user, source_type, frequency, retriever, doc_id
             )
             sync_counts["total_sync_count"] += 1
-            sync_counts[
+            sync_counts[
                 "sync_success" if resp["status"] == "success" else "sync_failure"
             ] += 1
     return {
@@ -744,7 +749,7 @@ def attachment_worker(self, file_info, user):
                     input_files=[local_path], exclude_hidden=True, errors="ignore"
                 )
                 .load_data()[0]
-                .text,
+                .text,
             )
@@ -835,3 +840,174 @@ def agent_webhook_worker(self, agent_id, payload):
         f"Webhook processed for agent {agent_id}", extra={"agent_id": agent_id}
     )
     return {"status": "success", "result": result}
+
+
+def ingest_connector(
+    self,
+    job_name: str,
+    user: str,
+    source_type: str,
+    session_token=None,
+    file_ids=None,
+    folder_ids=None,
+    recursive=True,
+    retriever: str = "classic",
+    operation_mode: str = "upload",
+    doc_id=None,
+    sync_frequency: str = "never",
+) -> Dict[str, Any]:
+    """
+    Ingestion for external knowledge bases (Google Drive, etc.).
+
+    Args:
+        job_name: Name of the ingestion job
+        user: User identifier
+        source_type: Type of remote source ("google_drive", "dropbox", etc.)
+ session_token: Authentication token for the service + file_ids: List of file IDs to download + folder_ids: List of folder IDs to download + recursive: Whether to recursively download folders + retriever: Type of retriever to use + operation_mode: "upload" for initial ingestion, "sync" for incremental sync + doc_id: Document ID for sync operations (required when operation_mode="sync") + sync_frequency: How often to sync ("never", "daily", "weekly", "monthly") + """ + logging.info(f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}") + self.update_state(state="PROGRESS", meta={"current": 1}) + + with tempfile.TemporaryDirectory() as temp_dir: + try: + # Step 1: Initialize the appropriate loader + self.update_state(state="PROGRESS", meta={"current": 10, "status": "Initializing connector"}) + + if not session_token: + raise ValueError(f"{source_type} connector requires session_token") + + if not ConnectorCreator.is_supported(source_type): + raise ValueError(f"Unsupported connector type: {source_type}. Supported types: {ConnectorCreator.get_supported_connectors()}") + + remote_loader = ConnectorCreator.create_connector(source_type, session_token) + + # Create a clean config for storage + api_source_config = { + "file_ids": file_ids or [], + "folder_ids": folder_ids or [], + "recursive": recursive + } + + # Step 2: Download files to temp directory + self.update_state(state="PROGRESS", meta={"current": 20, "status": "Downloading files"}) + download_info = remote_loader.download_to_directory( + temp_dir, + api_source_config + ) + + if download_info.get("empty_result", False) or not download_info.get("files_downloaded", 0): + logging.warning(f"No files were downloaded from {source_type}") + # Create empty result directly instead of calling a separate method + return { + "name": job_name, + "user": user, + "tokens": 0, + "type": source_type, + "source_config": api_source_config, + "directory_structure": "{}", + } + + # Step 3: Use SimpleDirectoryReader to process downloaded files + self.update_state(state="PROGRESS", meta={"current": 40, "status": "Processing files"}) + reader = SimpleDirectoryReader( + input_dir=temp_dir, + recursive=True, + required_exts=[ + ".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub", + ".html", ".mdx", ".json", ".xlsx", ".pptx", ".png", + ".jpg", ".jpeg", + ], + exclude_hidden=True, + file_metadata=metadata_from_filename, + ) + raw_docs = reader.load_data() + directory_structure = getattr(reader, 'directory_structure', {}) + + + + # Step 4: Process documents (chunking, embedding, etc.) 
+ self.update_state(state="PROGRESS", meta={"current": 60, "status": "Processing documents"}) + + chunker = Chunker( + chunking_strategy="classic_chunk", + max_tokens=MAX_TOKENS, + min_tokens=MIN_TOKENS, + duplicate_headers=False, + ) + raw_docs = chunker.chunk(documents=raw_docs) + + # Preserve source information in document metadata + for doc in raw_docs: + if hasattr(doc, 'extra_info') and doc.extra_info: + source = doc.extra_info.get('source') + if source and os.path.isabs(source): + # Convert absolute path to relative path + doc.extra_info['source'] = os.path.relpath(source, start=temp_dir) + + docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs] + + if operation_mode == "upload": + id = ObjectId() + elif operation_mode == "sync": + if not doc_id or not ObjectId.is_valid(doc_id): + logging.error("Invalid doc_id provided for sync operation: %s", doc_id) + raise ValueError("doc_id must be provided for sync operation.") + id = ObjectId(doc_id) + else: + raise ValueError(f"Invalid operation_mode: {operation_mode}") + + vector_store_path = os.path.join(temp_dir, "vector_store") + os.makedirs(vector_store_path, exist_ok=True) + + self.update_state(state="PROGRESS", meta={"current": 80, "status": "Storing documents"}) + embed_and_store_documents(docs, vector_store_path, id, self) + + tokens = count_tokens_docs(docs) + + # Step 6: Upload index files + file_data = { + "user": user, + "name": job_name, + "tokens": tokens, + "retriever": retriever, + "id": str(id), + "type": "connector", + "remote_data": json.dumps({ + "provider": source_type, + **api_source_config + }), + "directory_structure": json.dumps(directory_structure), + "sync_frequency": sync_frequency + } + + if operation_mode == "sync": + file_data["last_sync"] = datetime.datetime.now() + else: + file_data["last_sync"] = datetime.datetime.now() + + upload_index(vector_store_path, file_data) + + # Ensure we mark the task as complete + self.update_state(state="PROGRESS", meta={"current": 100, "status": "Complete"}) + + logging.info(f"Remote ingestion completed: {job_name}") + + return { + "user": user, + "name": job_name, + "tokens": tokens, + "type": source_type, + "id": str(id), + "status": "complete" + } + + except Exception as e: + logging.error(f"Error during remote ingestion: {e}", exc_info=True) + raise diff --git a/frontend/src/Hero.tsx b/frontend/src/Hero.tsx index 01695a19..9b17c10f 100644 --- a/frontend/src/Hero.tsx +++ b/frontend/src/Hero.tsx @@ -29,7 +29,7 @@ export default function Hero({ {/* Demo Buttons Section */} -
+
{demos?.map( (demo: { header: string; query: string }, key: number) => diff --git a/frontend/src/agents/NewAgent.tsx b/frontend/src/agents/NewAgent.tsx index da8cef5d..92e8b961 100644 --- a/frontend/src/agents/NewAgent.tsx +++ b/frontend/src/agents/NewAgent.tsx @@ -45,6 +45,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { description: '', image: '', source: '', + sources: [], chunks: '', retriever: '', prompt_id: 'default', @@ -150,7 +151,41 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { const formData = new FormData(); formData.append('name', agent.name); formData.append('description', agent.description); - formData.append('source', agent.source); + + if (selectedSourceIds.size > 1) { + const sourcesArray = Array.from(selectedSourceIds) + .map((id) => { + const sourceDoc = sourceDocs?.find( + (source) => + source.id === id || source.retriever === id || source.name === id, + ); + if (sourceDoc?.name === 'Default' && !sourceDoc?.id) { + return 'default'; + } + return sourceDoc?.id || id; + }) + .filter(Boolean); + formData.append('sources', JSON.stringify(sourcesArray)); + formData.append('source', ''); + } else if (selectedSourceIds.size === 1) { + const singleSourceId = Array.from(selectedSourceIds)[0]; + const sourceDoc = sourceDocs?.find( + (source) => + source.id === singleSourceId || + source.retriever === singleSourceId || + source.name === singleSourceId, + ); + let finalSourceId; + if (sourceDoc?.name === 'Default' && !sourceDoc?.id) + finalSourceId = 'default'; + else finalSourceId = sourceDoc?.id || singleSourceId; + formData.append('source', String(finalSourceId)); + formData.append('sources', JSON.stringify([])); + } else { + formData.append('source', ''); + formData.append('sources', JSON.stringify([])); + } + formData.append('chunks', agent.chunks); formData.append('retriever', agent.retriever); formData.append('prompt_id', agent.prompt_id); @@ -196,7 +231,41 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { const formData = new FormData(); formData.append('name', agent.name); formData.append('description', agent.description); - formData.append('source', agent.source); + + if (selectedSourceIds.size > 1) { + const sourcesArray = Array.from(selectedSourceIds) + .map((id) => { + const sourceDoc = sourceDocs?.find( + (source) => + source.id === id || source.retriever === id || source.name === id, + ); + if (sourceDoc?.name === 'Default' && !sourceDoc?.id) { + return 'default'; + } + return sourceDoc?.id || id; + }) + .filter(Boolean); + formData.append('sources', JSON.stringify(sourcesArray)); + formData.append('source', ''); + } else if (selectedSourceIds.size === 1) { + const singleSourceId = Array.from(selectedSourceIds)[0]; + const sourceDoc = sourceDocs?.find( + (source) => + source.id === singleSourceId || + source.retriever === singleSourceId || + source.name === singleSourceId, + ); + let finalSourceId; + if (sourceDoc?.name === 'Default' && !sourceDoc?.id) + finalSourceId = 'default'; + else finalSourceId = sourceDoc?.id || singleSourceId; + formData.append('source', String(finalSourceId)); + formData.append('sources', JSON.stringify([])); + } else { + formData.append('source', ''); + formData.append('sources', JSON.stringify([])); + } + formData.append('chunks', agent.chunks); formData.append('retriever', agent.retriever); formData.append('prompt_id', agent.prompt_id); @@ -293,9 +362,33 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 
'draft' }) { throw new Error('Failed to fetch agent'); } const data = await response.json(); - if (data.source) setSelectedSourceIds(new Set([data.source])); - else if (data.retriever) + + if (data.sources && data.sources.length > 0) { + const mappedSources = data.sources.map((sourceId: string) => { + if (sourceId === 'default') { + const defaultSource = sourceDocs?.find( + (source) => source.name === 'Default', + ); + return defaultSource?.retriever || 'classic'; + } + return sourceId; + }); + setSelectedSourceIds(new Set(mappedSources)); + } else if (data.source) { + if (data.source === 'default') { + const defaultSource = sourceDocs?.find( + (source) => source.name === 'Default', + ); + setSelectedSourceIds( + new Set([defaultSource?.retriever || 'classic']), + ); + } else { + setSelectedSourceIds(new Set([data.source])); + } + } else if (data.retriever) { setSelectedSourceIds(new Set([data.retriever])); + } + if (data.tools) setSelectedToolIds(new Set(data.tools)); if (data.status === 'draft') setEffectiveMode('draft'); if (data.json_schema) { @@ -311,25 +404,57 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { }, [agentId, mode, token]); useEffect(() => { - const selectedSource = Array.from(selectedSourceIds).map((id) => - sourceDocs?.find( - (source) => - source.id === id || source.retriever === id || source.name === id, - ), - ); - if (selectedSource[0]?.model === embeddingsName) { - if (selectedSource[0] && 'id' in selectedSource[0]) { + const selectedSources = Array.from(selectedSourceIds) + .map((id) => + sourceDocs?.find( + (source) => + source.id === id || source.retriever === id || source.name === id, + ), + ) + .filter(Boolean); + + if (selectedSources.length > 0) { + // Handle multiple sources + if (selectedSources.length > 1) { + // Multiple sources selected - store in sources array + const sourceIds = selectedSources + .map((source) => source?.id) + .filter((id): id is string => Boolean(id)); setAgent((prev) => ({ ...prev, - source: selectedSource[0]?.id || 'default', + sources: sourceIds, + source: '', // Clear single source for multiple sources retriever: '', })); - } else - setAgent((prev) => ({ - ...prev, - source: '', - retriever: selectedSource[0]?.retriever || 'classic', - })); + } else { + // Single source selected - maintain backward compatibility + const selectedSource = selectedSources[0]; + if (selectedSource?.model === embeddingsName) { + if (selectedSource && 'id' in selectedSource) { + setAgent((prev) => ({ + ...prev, + source: selectedSource?.id || 'default', + sources: [], // Clear sources array for single source + retriever: '', + })); + } else { + setAgent((prev) => ({ + ...prev, + source: '', + sources: [], // Clear sources array + retriever: selectedSource?.retriever || 'classic', + })); + } + } + } + } else { + // No sources selected + setAgent((prev) => ({ + ...prev, + source: '', + sources: [], + retriever: '', + })); } }, [selectedSourceIds]); @@ -461,7 +586,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { onChange={(e) => setAgent({ ...agent, name: e.target.value })} />
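
To close the loop, a sketch of the pipeline end to end. `ingest_connector` takes a bound task as its first argument like the other workers in worker.py; the task handle, user, and IDs below are placeholders, and the Celery wrapper itself is assumed rather than shown in this diff.

    from application.worker import ingest_connector

    result = ingest_connector(
        task,                               # bound Celery task (hypothetical handle)
        job_name="drive-docs",
        user="user-123",
        source_type="google_drive",
        session_token="session-token-123",  # issued by the OAuth callback
        folder_ids=["1AbCdEfG"],
        recursive=True,
        operation_mode="upload",            # "sync" additionally requires doc_id
        sync_frequency="weekly",
    )
    print(result["tokens"], result["status"])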