diff --git a/application/api/answer/services/stream_processor.py b/application/api/answer/services/stream_processor.py index 648d24f5..f6e639ef 100644 --- a/application/api/answer/services/stream_processor.py +++ b/application/api/answer/services/stream_processor.py @@ -69,11 +69,8 @@ class StreamProcessor: self.decoded_token.get("sub") if self.decoded_token is not None else None ) self.conversation_id = self.data.get("conversation_id") - self.source = ( - {"active_docs": self.data["active_docs"]} - if "active_docs" in self.data - else {} - ) + self.source = {} + self.all_sources = [] self.attachments = [] self.history = [] self.agent_config = {} @@ -85,6 +82,8 @@ class StreamProcessor: def initialize(self): """Initialize all required components for processing""" + self._configure_agent() + self._configure_source() self._configure_retriever() self._configure_agent() self._load_conversation_history() @@ -171,13 +170,77 @@ class StreamProcessor: source = data.get("source") if isinstance(source, DBRef): source_doc = self.db.dereference(source) - data["source"] = str(source_doc["_id"]) - data["retriever"] = source_doc.get("retriever", data.get("retriever")) - data["chunks"] = source_doc.get("chunks", data.get("chunks")) + if source_doc: + data["source"] = str(source_doc["_id"]) + data["retriever"] = source_doc.get("retriever", data.get("retriever")) + data["chunks"] = source_doc.get("chunks", data.get("chunks")) + else: + data["source"] = None + elif source == "default": + data["source"] = "default" else: data["source"] = None + # Handle multiple sources + + sources = data.get("sources", []) + if sources and isinstance(sources, list): + sources_list = [] + for i, source_ref in enumerate(sources): + if source_ref == "default": + processed_source = { + "id": "default", + "retriever": "classic", + "chunks": data.get("chunks", "2"), + } + sources_list.append(processed_source) + elif isinstance(source_ref, DBRef): + source_doc = self.db.dereference(source_ref) + if 
source_doc: + processed_source = { + "id": str(source_doc["_id"]), + "retriever": source_doc.get("retriever", "classic"), + "chunks": source_doc.get("chunks", data.get("chunks", "2")), + } + sources_list.append(processed_source) + data["sources"] = sources_list + else: + data["sources"] = [] return data + def _configure_source(self): + """Configure the source based on agent data""" + api_key = self.data.get("api_key") or self.agent_key + + if api_key: + agent_data = self._get_data_from_api_key(api_key) + + if agent_data.get("sources") and len(agent_data["sources"]) > 0: + source_ids = [ + source["id"] for source in agent_data["sources"] if source.get("id") + ] + if source_ids: + self.source = {"active_docs": source_ids} + else: + self.source = {} + self.all_sources = agent_data["sources"] + elif agent_data.get("source"): + self.source = {"active_docs": agent_data["source"]} + self.all_sources = [ + { + "id": agent_data["source"], + "retriever": agent_data.get("retriever", "classic"), + } + ] + else: + self.source = {} + self.all_sources = [] + return + if "active_docs" in self.data: + self.source = {"active_docs": self.data["active_docs"]} + return + self.source = {} + self.all_sources = [] + def _configure_agent(self): """Configure the agent based on request data""" agent_id = self.data.get("agent_id") @@ -203,7 +266,13 @@ class StreamProcessor: if data_key.get("retriever"): self.retriever_config["retriever_name"] = data_key["retriever"] if data_key.get("chunks") is not None: - self.retriever_config["chunks"] = data_key["chunks"] + try: + self.retriever_config["chunks"] = int(data_key["chunks"]) + except (ValueError, TypeError): + logger.warning( + f"Invalid chunks value: {data_key['chunks']}, using default value 2" + ) + self.retriever_config["chunks"] = 2 elif self.agent_key: data_key = self._get_data_from_api_key(self.agent_key) self.agent_config.update( @@ -224,7 +293,13 @@ class StreamProcessor: if data_key.get("retriever"): 
self.retriever_config["retriever_name"] = data_key["retriever"] if data_key.get("chunks") is not None: - self.retriever_config["chunks"] = data_key["chunks"] + try: + self.retriever_config["chunks"] = int(data_key["chunks"]) + except (ValueError, TypeError): + logger.warning( + f"Invalid chunks value: {data_key['chunks']}, using default value 2" + ) + self.retriever_config["chunks"] = 2 else: self.agent_config.update( { @@ -243,7 +318,8 @@ class StreamProcessor: "token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY), } - if "isNoneDoc" in self.data and self.data["isNoneDoc"]: + api_key = self.data.get("api_key") or self.agent_key + if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]: self.retriever_config["chunks"] = 0 def create_agent(self): diff --git a/application/api/connector/routes.py b/application/api/connector/routes.py index f203a703..1647aa78 100644 --- a/application/api/connector/routes.py +++ b/application/api/connector/routes.py @@ -1,6 +1,5 @@ import datetime import json -import logging from bson.objectid import ObjectId diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 7eae66f6..f508b7cf 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -3,11 +3,12 @@ import json import math import os import secrets +import tempfile import uuid +import zipfile from functools import wraps from typing import Optional, Tuple -import tempfile -import zipfile + from bson.binary import Binary, UuidRepresentation from bson.dbref import DBRef from bson.objectid import ObjectId @@ -25,6 +26,7 @@ from pymongo import ReturnDocument from werkzeug.utils import secure_filename from application.agents.tools.tool_manager import ToolManager +from application.api import api from application.api.user.tasks import ( ingest, @@ -35,19 +37,18 @@ from application.api.user.tasks import ( ) from application.core.mongo_db import MongoDB from application.core.settings import settings -from 
application.api import api +from application.parser.connectors.connector_creator import ConnectorCreator from application.storage.storage_creator import StorageCreator from application.tts.google_tts import GoogleTTS from application.utils import ( check_required_fields, generate_image_url, + num_tokens_from_string, safe_filename, validate_function_name, validate_required_fields, ) -from application.utils import num_tokens_from_string from application.vectorstore.vector_creator import VectorCreator -from application.parser.connectors.connector_creator import ConnectorCreator storage = StorageCreator.get_storage() @@ -74,7 +75,6 @@ try: users_collection.create_index("user_id", unique=True) except Exception as e: print("Error creating indexes:", e) - user = Blueprint("user", __name__) user_ns = Namespace("user", description="User related operations", path="/") api.add_namespace(user_ns) @@ -127,11 +127,9 @@ def ensure_user_doc(user_id): updates["agent_preferences.pinned"] = [] if "shared_with_me" not in prefs: updates["agent_preferences.shared_with_me"] = [] - if updates: users_collection.update_one({"user_id": user_id}, {"$set": updates}) user_doc = users_collection.find_one({"user_id": user_id}) - return user_doc @@ -183,7 +181,6 @@ def handle_image_upload( jsonify({"success": False, "message": "Image upload failed"}), 400, ) - return image_url, None @@ -297,8 +294,8 @@ class GetSingleConversation(Resource): ) if not conversation: return make_response(jsonify({"status": "not found"}), 404) - # Process queries to include attachment names + queries = conversation["queries"] for query in queries: if "attachments" in query and query["attachments"]: @@ -494,11 +491,11 @@ class DeleteOldIndexes(Resource): ) if not doc: return make_response(jsonify({"status": "not found"}), 404) - storage = StorageCreator.get_storage() try: # Delete vector index + if settings.VECTOR_STORE == "faiss": index_path = f"indexes/{str(doc['_id'])}" if 
storage.file_exists(f"{index_path}/index.faiss"): @@ -510,7 +507,6 @@ class DeleteOldIndexes(Resource): settings.VECTOR_STORE, source_id=str(doc["_id"]) ) vectorstore.delete_index() - if "file_path" in doc and doc["file_path"]: file_path = doc["file_path"] if storage.is_directory(file_path): @@ -519,7 +515,6 @@ class DeleteOldIndexes(Resource): storage.delete_file(f) else: storage.delete_file(file_path) - except FileNotFoundError: pass except Exception as err: @@ -527,7 +522,6 @@ class DeleteOldIndexes(Resource): f"Error deleting files and indexes: {err}", exc_info=True ) return make_response(jsonify({"success": False}), 400) - sources_collection.delete_one({"_id": ObjectId(source_id)}) return make_response(jsonify({"success": True}), 200) @@ -569,6 +563,7 @@ class UploadFile(Resource): job_name = request.form["name"] # Create safe versions for filesystem operations + safe_user = safe_filename(user) dir_name = safe_filename(job_name) base_path = f"{settings.UPLOAD_FOLDER}/{safe_user}/{dir_name}" @@ -576,7 +571,6 @@ class UploadFile(Resource): try: storage = StorageCreator.get_storage() - for file in files: original_filename = file.filename safe_file = safe_filename(original_filename) @@ -587,43 +581,65 @@ class UploadFile(Resource): if zipfile.is_zipfile(temp_file_path): try: - with zipfile.ZipFile(temp_file_path, 'r') as zip_ref: + with zipfile.ZipFile(temp_file_path, "r") as zip_ref: zip_ref.extractall(path=temp_dir) # Walk through extracted files and upload them + for root, _, files in os.walk(temp_dir): for extracted_file in files: - if os.path.join(root, extracted_file) == temp_file_path: + if ( + os.path.join(root, extracted_file) + == temp_file_path + ): continue - - rel_path = os.path.relpath(os.path.join(root, extracted_file), temp_dir) + rel_path = os.path.relpath( + os.path.join(root, extracted_file), temp_dir + ) storage_path = f"{base_path}/{rel_path}" - with open(os.path.join(root, extracted_file), 'rb') as f: + with open( + os.path.join(root, 
extracted_file), "rb" + ) as f: storage.save_file(f, storage_path) except Exception as e: - current_app.logger.error(f"Error extracting zip: {e}", exc_info=True) + current_app.logger.error( + f"Error extracting zip: {e}", exc_info=True + ) # If zip extraction fails, save the original zip file + file_path = f"{base_path}/{safe_file}" - with open(temp_file_path, 'rb') as f: + with open(temp_file_path, "rb") as f: storage.save_file(f, file_path) else: # For non-zip files, save directly - file_path = f"{base_path}/{safe_file}" - with open(temp_file_path, 'rb') as f: - storage.save_file(f, file_path) + file_path = f"{base_path}/{safe_file}" + with open(temp_file_path, "rb") as f: + storage.save_file(f, file_path) task = ingest.delay( settings.UPLOAD_FOLDER, [ - ".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub", - ".html", ".mdx", ".json", ".xlsx", ".pptx", ".png", - ".jpg", ".jpeg", + ".rst", + ".md", + ".pdf", + ".txt", + ".docx", + ".csv", + ".epub", + ".html", + ".mdx", + ".json", + ".xlsx", + ".pptx", + ".png", + ".jpg", + ".jpeg", ], job_name, user, file_path=base_path, - filename=dir_name + filename=dir_name, ) except Exception as err: current_app.logger.error(f"Error uploading file: {err}", exc_info=True) @@ -637,12 +653,29 @@ class ManageSourceFiles(Resource): api.model( "ManageSourceFilesModel", { - "source_id": fields.String(required=True, description="Source ID to modify"), - "operation": fields.String(required=True, description="Operation: 'add', 'remove', or 'remove_directory'"), - "file_paths": fields.List(fields.String, required=False, description="File paths to remove (for remove operation)"), - "directory_path": fields.String(required=False, description="Directory path to remove (for remove_directory operation)"), - "file": fields.Raw(required=False, description="Files to add (for add operation)"), - "parent_dir": fields.String(required=False, description="Parent directory path relative to source root"), + "source_id": fields.String( + 
required=True, description="Source ID to modify" + ), + "operation": fields.String( + required=True, + description="Operation: 'add', 'remove', or 'remove_directory'", + ), + "file_paths": fields.List( + fields.String, + required=False, + description="File paths to remove (for remove operation)", + ), + "directory_path": fields.String( + required=False, + description="Directory path to remove (for remove_directory operation)", + ), + "file": fields.Raw( + required=False, description="Files to add (for add operation)" + ), + "parent_dir": fields.String( + required=False, + description="Parent directory path relative to source root", + ), }, ) ) @@ -652,39 +685,58 @@ class ManageSourceFiles(Resource): def post(self): decoded_token = request.decoded_token if not decoded_token: - return make_response(jsonify({"success": False, "message": "Unauthorized"}), 401) - + return make_response( + jsonify({"success": False, "message": "Unauthorized"}), 401 + ) user = decoded_token.get("sub") source_id = request.form.get("source_id") operation = request.form.get("operation") if not source_id or not operation: return make_response( - jsonify({"success": False, "message": "source_id and operation are required"}), 400 + jsonify( + { + "success": False, + "message": "source_id and operation are required", + } + ), + 400, ) - if operation not in ["add", "remove", "remove_directory"]: return make_response( - jsonify({"success": False, "message": "operation must be 'add', 'remove', or 'remove_directory'"}), 400 + jsonify( + { + "success": False, + "message": "operation must be 'add', 'remove', or 'remove_directory'", + } + ), + 400, ) - try: ObjectId(source_id) except Exception: return make_response( jsonify({"success": False, "message": "Invalid source ID format"}), 400 ) - try: - source = sources_collection.find_one({"_id": ObjectId(source_id), "user": user}) + source = sources_collection.find_one( + {"_id": ObjectId(source_id), "user": user} + ) if not source: return make_response( - 
jsonify({"success": False, "message": "Source not found or access denied"}), 404 + jsonify( + { + "success": False, + "message": "Source not found or access denied", + } + ), + 404, ) except Exception as err: current_app.logger.error(f"Error finding source: {err}", exc_info=True) - return make_response(jsonify({"success": False, "message": "Database error"}), 500) - + return make_response( + jsonify({"success": False, "message": "Database error"}), 500 + ) try: storage = StorageCreator.get_storage() source_file_path = source.get("file_path", "") @@ -692,98 +744,138 @@ class ManageSourceFiles(Resource): if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir): return make_response( - jsonify({"success": False, "message": "Invalid parent directory path"}), 400 + jsonify( + {"success": False, "message": "Invalid parent directory path"} + ), + 400, ) - if operation == "add": files = request.files.getlist("file") if not files or all(file.filename == "" for file in files): return make_response( - jsonify({"success": False, "message": "No files provided for add operation"}), 400 + jsonify( + { + "success": False, + "message": "No files provided for add operation", + } + ), + 400, ) - added_files = [] target_dir = source_file_path if parent_dir: target_dir = f"{source_file_path}/{parent_dir}" - for file in files: if file.filename: safe_filename_str = safe_filename(file.filename) file_path = f"{target_dir}/{safe_filename_str}" # Save file to storage + storage.save_file(file, file_path) added_files.append(safe_filename_str) - # Trigger re-ingestion pipeline + from application.api.user.tasks import reingest_source_task task = reingest_source_task.delay(source_id=source_id, user=user) - return make_response(jsonify({ - "success": True, - "message": f"Added {len(added_files)} files", - "added_files": added_files, - "parent_dir": parent_dir, - "reingest_task_id": task.id - }), 200) - + return make_response( + jsonify( + { + "success": True, + "message": f"Added 
{len(added_files)} files", + "added_files": added_files, + "parent_dir": parent_dir, + "reingest_task_id": task.id, + } + ), + 200, + ) elif operation == "remove": file_paths_str = request.form.get("file_paths") if not file_paths_str: return make_response( - jsonify({"success": False, "message": "file_paths required for remove operation"}), 400 + jsonify( + { + "success": False, + "message": "file_paths required for remove operation", + } + ), + 400, ) - try: - file_paths = json.loads(file_paths_str) if isinstance(file_paths_str, str) else file_paths_str + file_paths = ( + json.loads(file_paths_str) + if isinstance(file_paths_str, str) + else file_paths_str + ) except Exception: return make_response( - jsonify({"success": False, "message": "Invalid file_paths format"}), 400 + jsonify( + {"success": False, "message": "Invalid file_paths format"} + ), + 400, ) - # Remove files from storage and directory structure + removed_files = [] for file_path in file_paths: full_path = f"{source_file_path}/{file_path}" # Remove from storage + if storage.file_exists(full_path): storage.delete_file(full_path) removed_files.append(file_path) - # Trigger re-ingestion pipeline + from application.api.user.tasks import reingest_source_task task = reingest_source_task.delay(source_id=source_id, user=user) - return make_response(jsonify({ - "success": True, - "message": f"Removed {len(removed_files)} files", - "removed_files": removed_files, - "reingest_task_id": task.id - }), 200) - + return make_response( + jsonify( + { + "success": True, + "message": f"Removed {len(removed_files)} files", + "removed_files": removed_files, + "reingest_task_id": task.id, + } + ), + 200, + ) elif operation == "remove_directory": directory_path = request.form.get("directory_path") if not directory_path: return make_response( - jsonify({"success": False, "message": "directory_path required for remove_directory operation"}), 400 + jsonify( + { + "success": False, + "message": "directory_path required for 
remove_directory operation", + } + ), + 400, ) - # Validate directory path (prevent path traversal) + if directory_path.startswith("/") or ".." in directory_path: current_app.logger.warning( f"Invalid directory path attempted for removal. " f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}" ) return make_response( - jsonify({"success": False, "message": "Invalid directory path"}), 400 + jsonify( + {"success": False, "message": "Invalid directory path"} + ), + 400, ) - - full_directory_path = f"{source_file_path}/{directory_path}" if directory_path else source_file_path + full_directory_path = ( + f"{source_file_path}/{directory_path}" + if directory_path + else source_file_path + ) if not storage.is_directory(full_directory_path): current_app.logger.warning( @@ -792,9 +884,14 @@ class ManageSourceFiles(Resource): f"Full path: {full_directory_path}" ) return make_response( - jsonify({"success": False, "message": "Directory not found or is not a directory"}), 404 + jsonify( + { + "success": False, + "message": "Directory not found or is not a directory", + } + ), + 404, ) - success = storage.remove_directory(full_directory_path) if not success: @@ -804,9 +901,11 @@ class ManageSourceFiles(Resource): f"Full path: {full_directory_path}" ) return make_response( - jsonify({"success": False, "message": "Failed to remove directory"}), 500 + jsonify( + {"success": False, "message": "Failed to remove directory"} + ), + 500, ) - current_app.logger.info( f"Successfully removed directory. 
" f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, " @@ -814,17 +913,22 @@ class ManageSourceFiles(Resource): ) # Trigger re-ingestion pipeline + from application.api.user.tasks import reingest_source_task task = reingest_source_task.delay(source_id=source_id, user=user) - return make_response(jsonify({ - "success": True, - "message": f"Successfully removed directory: {directory_path}", - "removed_directory": directory_path, - "reingest_task_id": task.id - }), 200) - + return make_response( + jsonify( + { + "success": True, + "message": f"Successfully removed directory: {directory_path}", + "removed_directory": directory_path, + "reingest_task_id": task.id, + } + ), + 200, + ) except Exception as err: error_context = f"operation={operation}, user={user}, source_id={source_id}" if operation == "remove_directory": @@ -836,9 +940,12 @@ class ManageSourceFiles(Resource): elif operation == "add": parent_dir = request.form.get("parent_dir", "") error_context += f", parent_dir={parent_dir}" - - current_app.logger.error(f"Error managing source files: {err} ({error_context})", exc_info=True) - return make_response(jsonify({"success": False, "message": "Operation failed"}), 500) + current_app.logger.error( + f"Error managing source files: {err} ({error_context})", exc_info=True + ) + return make_response( + jsonify({"success": False, "message": "Operation failed"}), 500 + ) @user_ns.route("/api/remote") @@ -882,25 +989,31 @@ class UploadRemote(Resource): elif data["source"] in ConnectorCreator.get_supported_connectors(): session_token = config.get("session_token") if not session_token: - return make_response(jsonify({ - "success": False, - "error": f"Missing session_token in {data['source']} configuration" - }), 400) - + return make_response( + jsonify( + { + "success": False, + "error": f"Missing session_token in {data['source']} configuration", + } + ), + 400, + ) # Process file_ids + file_ids = config.get("file_ids", []) if isinstance(file_ids, 
str): - file_ids = [id.strip() for id in file_ids.split(',') if id.strip()] + file_ids = [id.strip() for id in file_ids.split(",") if id.strip()] elif not isinstance(file_ids, list): file_ids = [] - # Process folder_ids + folder_ids = config.get("folder_ids", []) if isinstance(folder_ids, str): - folder_ids = [id.strip() for id in folder_ids.split(',') if id.strip()] + folder_ids = [ + id.strip() for id in folder_ids.split(",") if id.strip() + ] elif not isinstance(folder_ids, list): folder_ids = [] - config["file_ids"] = file_ids config["folder_ids"] = folder_ids @@ -912,9 +1025,11 @@ class UploadRemote(Resource): file_ids=file_ids, folder_ids=folder_ids, recursive=config.get("recursive", False), - retriever=config.get("retriever", "classic") + retriever=config.get("retriever", "classic"), + ) + return make_response( + jsonify({"success": True, "task_id": task.id}), 200 ) - return make_response(jsonify({"success": True, "task_id": task.id}), 200) task = ingest_remote.delay( source_data=source_data, job_name=data["name"], @@ -1023,7 +1138,7 @@ class PaginatedSources(Resource): "retriever": doc.get("retriever", "classic"), "syncFrequency": doc.get("sync_frequency", ""), "isNested": bool(doc.get("directory_structure")), - "type": doc.get("type", "file") + "type": doc.get("type", "file"), } paginated_docs.append(doc_data) response = { @@ -1072,7 +1187,9 @@ class CombinedJson(Resource): "retriever": index.get("retriever", "classic"), "syncFrequency": index.get("sync_frequency", ""), "is_nested": bool(index.get("directory_structure")), - "type": index.get("type", "file") # Add type field with default "file" + "type": index.get( + "type", "file" + ), # Add type field with default "file" } ) except Exception as err: @@ -1288,17 +1405,14 @@ class GetAgent(Resource): def get(self): if not (decoded_token := request.decoded_token): return {"success": False}, 401 - if not (agent_id := request.args.get("id")): return {"success": False, "message": "ID required"}, 400 - try: 
agent = agents_collection.find_one( {"_id": ObjectId(agent_id), "user": decoded_token["sub"]} ) if not agent: return {"status": "Not found"}, 404 - data = { "id": str(agent["_id"]), "name": agent["name"], @@ -1312,6 +1426,16 @@ class GetAgent(Resource): and (source_doc := db.dereference(agent.get("source"))) else "" ), + "sources": [ + ( + str(db.dereference(source_ref)["_id"]) + if isinstance(source_ref, DBRef) and db.dereference(source_ref) + else source_ref + ) + for source_ref in agent.get("sources", []) + if (isinstance(source_ref, DBRef) and db.dereference(source_ref)) + or source_ref == "default" + ], "chunks": agent["chunks"], "retriever": agent.get("retriever", ""), "prompt_id": agent.get("prompt_id", ""), @@ -1334,7 +1458,6 @@ class GetAgent(Resource): "shared_token": agent.get("shared_token", ""), } return make_response(jsonify(data), 200) - except Exception as e: current_app.logger.error(f"Agent fetch error: {e}", exc_info=True) return {"success": False}, 400 @@ -1346,7 +1469,6 @@ class GetAgents(Resource): def get(self): if not (decoded_token := request.decoded_token): return {"success": False}, 401 - user = decoded_token.get("sub") try: user_doc = ensure_user_doc(user) @@ -1365,8 +1487,24 @@ class GetAgents(Resource): str(source_doc["_id"]) if isinstance(agent.get("source"), DBRef) and (source_doc := db.dereference(agent.get("source"))) - else "" + else ( + agent.get("source", "") + if agent.get("source") == "default" + else "" + ) ), + "sources": [ + ( + source_ref + if source_ref == "default" + else str(db.dereference(source_ref)["_id"]) + ) + for source_ref in agent.get("sources", []) + if source_ref == "default" + or ( + isinstance(source_ref, DBRef) and db.dereference(source_ref) + ) + ], "chunks": agent["chunks"], "retriever": agent.get("retriever", ""), "prompt_id": agent.get("prompt_id", ""), @@ -1409,7 +1547,14 @@ class CreateAgent(Resource): "image": fields.Raw( required=False, description="Image file upload", type="file" ), - "source": 
fields.String(required=True, description="Source ID"), + "source": fields.String( + required=False, description="Source ID (legacy single source)" + ), + "sources": fields.List( + fields.String, + required=False, + description="List of source identifiers for multiple sources", + ), "chunks": fields.Integer(required=True, description="Chunks count"), "retriever": fields.String(required=True, description="Retriever ID"), "prompt_id": fields.String(required=True, description="Prompt ID"), @@ -1421,7 +1566,8 @@ class CreateAgent(Resource): required=True, description="Status of the agent (draft or published)" ), "json_schema": fields.Raw( - required=False, description="JSON schema for enforcing structured output format" + required=False, + description="JSON schema for enforcing structured output format", ), }, ) @@ -1441,6 +1587,11 @@ class CreateAgent(Resource): data["tools"] = json.loads(data["tools"]) except json.JSONDecodeError: data["tools"] = [] + if "sources" in data: + try: + data["sources"] = json.loads(data["sources"]) + except json.JSONDecodeError: + data["sources"] = [] if "json_schema" in data: try: data["json_schema"] = json.loads(data["json_schema"]) @@ -1449,28 +1600,41 @@ class CreateAgent(Resource): print(f"Received data: {data}") # Validate JSON schema if provided + if data.get("json_schema"): try: # Basic validation - ensure it's a valid JSON structure + json_schema = data.get("json_schema") if not isinstance(json_schema, dict): return make_response( - jsonify({"success": False, "message": "JSON schema must be a valid JSON object"}), - 400 + jsonify( + { + "success": False, + "message": "JSON schema must be a valid JSON object", + } + ), + 400, ) - # Validate that it has either a 'schema' property or is itself a schema + if "schema" not in json_schema and "type" not in json_schema: return make_response( - jsonify({"success": False, "message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property"}), - 400 
+ jsonify( + { + "success": False, + "message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property", + } + ), + 400, ) except Exception as e: return make_response( - jsonify({"success": False, "message": f"Invalid JSON schema: {str(e)}"}), - 400 + jsonify( + {"success": False, "message": f"Invalid JSON schema: {str(e)}"} + ), + 400, ) - if data.get("status") not in ["draft", "published"]: return make_response( jsonify( @@ -1481,17 +1645,27 @@ class CreateAgent(Resource): ), 400, ) - if data.get("status") == "published": required_fields = [ "name", "description", - "source", "chunks", "retriever", "prompt_id", "agent_type", ] + # Require either source or sources (but not both) + + if not data.get("source") and not data.get("sources"): + return make_response( + jsonify( + { + "success": False, + "message": "Either 'source' or 'sources' field is required for published agents", + } + ), + 400, + ) validate_fields = ["name", "description", "prompt_id", "agent_type"] else: required_fields = ["name"] @@ -1502,25 +1676,37 @@ class CreateAgent(Resource): return missing_fields if invalid_fields: return invalid_fields - image_url, error = handle_image_upload(request, "", user, storage) if error: return make_response( jsonify({"success": False, "message": "Image upload failed"}), 400 ) - try: key = str(uuid.uuid4()) if data.get("status") == "published" else "" + + sources_list = [] + if data.get("sources") and len(data.get("sources", [])) > 0: + for source_id in data.get("sources", []): + if source_id == "default": + sources_list.append("default") + elif ObjectId.is_valid(source_id): + sources_list.append(DBRef("sources", ObjectId(source_id))) + source_field = "" + else: + source_value = data.get("source", "") + if source_value == "default": + source_field = "default" + elif ObjectId.is_valid(source_value): + source_field = DBRef("sources", ObjectId(source_value)) + else: + source_field = "" new_agent = { "user": user, "name": 
data.get("name"), "description": data.get("description", ""), "image": image_url, - "source": ( - DBRef("sources", ObjectId(data.get("source"))) - if ObjectId.is_valid(data.get("source")) - else "" - ), + "source": source_field, + "sources": sources_list, "chunks": data.get("chunks", ""), "retriever": data.get("retriever", ""), "prompt_id": data.get("prompt_id", ""), @@ -1535,7 +1721,11 @@ class CreateAgent(Resource): } if new_agent["chunks"] == "": new_agent["chunks"] = "0" - if new_agent["source"] == "" and new_agent["retriever"] == "": + if ( + new_agent["source"] == "" + and new_agent["retriever"] == "" + and not new_agent["sources"] + ): new_agent["retriever"] = "classic" resp = agents_collection.insert_one(new_agent) new_id = str(resp.inserted_id) @@ -1557,7 +1747,14 @@ class UpdateAgent(Resource): "image": fields.String( required=False, description="New image URL or identifier" ), - "source": fields.String(required=True, description="Source ID"), + "source": fields.String( + required=False, description="Source ID (legacy single source)" + ), + "sources": fields.List( + fields.String, + required=False, + description="List of source identifiers for multiple sources", + ), "chunks": fields.Integer(required=True, description="Chunks count"), "retriever": fields.String(required=True, description="Retriever ID"), "prompt_id": fields.String(required=True, description="Prompt ID"), @@ -1569,7 +1766,8 @@ class UpdateAgent(Resource): required=True, description="Status of the agent (draft or published)" ), "json_schema": fields.Raw( - required=False, description="JSON schema for enforcing structured output format" + required=False, + description="JSON schema for enforcing structured output format", ), }, ) @@ -1589,12 +1787,16 @@ class UpdateAgent(Resource): data["tools"] = json.loads(data["tools"]) except json.JSONDecodeError: data["tools"] = [] + if "sources" in data: + try: + data["sources"] = json.loads(data["sources"]) + except json.JSONDecodeError: + 
data["sources"] = [] if "json_schema" in data: try: data["json_schema"] = json.loads(data["json_schema"]) except json.JSONDecodeError: data["json_schema"] = None - if not ObjectId.is_valid(agent_id): return make_response( jsonify({"success": False, "message": "Invalid agent ID format"}), 400 @@ -1618,7 +1820,6 @@ class UpdateAgent(Resource): ), 404, ) - image_url, error = handle_image_upload( request, existing_agent.get("image", ""), user, storage ) @@ -1626,13 +1827,13 @@ class UpdateAgent(Resource): return make_response( jsonify({"success": False, "message": "Image upload failed"}), 400 ) - update_fields = {} allowed_fields = [ "name", "description", "image", "source", + "sources", "chunks", "retriever", "prompt_id", @@ -1656,7 +1857,11 @@ class UpdateAgent(Resource): update_fields[field] = new_status elif field == "source": source_id = data.get("source") - if source_id and ObjectId.is_valid(source_id): + if source_id == "default": + # Handle special "default" source + + update_fields[field] = "default" + elif source_id and ObjectId.is_valid(source_id): update_fields[field] = DBRef("sources", ObjectId(source_id)) elif source_id: return make_response( @@ -1670,6 +1875,30 @@ class UpdateAgent(Resource): ) else: update_fields[field] = "" + elif field == "sources": + sources_list = data.get("sources", []) + if sources_list and isinstance(sources_list, list): + valid_sources = [] + for source_id in sources_list: + if source_id == "default": + valid_sources.append("default") + elif ObjectId.is_valid(source_id): + valid_sources.append( + DBRef("sources", ObjectId(source_id)) + ) + else: + return make_response( + jsonify( + { + "success": False, + "message": f"Invalid source ID format: {source_id}", + } + ), + 400, + ) + update_fields[field] = valid_sources + else: + update_fields[field] = [] elif field == "chunks": chunks_value = data.get("chunks") if chunks_value == "": @@ -1735,7 +1964,6 @@ class UpdateAgent(Resource): ), 400, ) - if not existing_agent.get("key"): 
newly_generated_key = str(uuid.uuid4()) update_fields["key"] = newly_generated_key @@ -1822,7 +2050,6 @@ class PinnedAgents(Resource): decoded_token = request.decoded_token if not decoded_token: return make_response(jsonify({"success": False}), 401) - user_id = decoded_token.get("sub") try: @@ -1831,7 +2058,6 @@ class PinnedAgents(Resource): if not pinned_ids: return make_response(jsonify([]), 200) - pinned_object_ids = [ObjectId(agent_id) for agent_id in pinned_ids] pinned_agents_cursor = agents_collection.find( @@ -1841,6 +2067,7 @@ class PinnedAgents(Resource): existing_ids = {str(agent["_id"]) for agent in pinned_agents} # Clean up any stale pinned IDs + stale_ids = [ agent_id for agent_id in pinned_ids if agent_id not in existing_ids ] @@ -1849,7 +2076,6 @@ class PinnedAgents(Resource): {"user_id": user_id}, {"$pullAll": {"agent_preferences.pinned": stale_ids}}, ) - list_pinned_agents = [ { "id": str(agent["_id"]), @@ -1886,11 +2112,9 @@ class PinnedAgents(Resource): for agent in pinned_agents if "source" in agent or "retriever" in agent ] - except Exception as err: current_app.logger.error(f"Error retrieving pinned agents: {err}") return make_response(jsonify({"success": False}), 400) - return make_response(jsonify(list_pinned_agents), 200) @@ -1954,7 +2178,6 @@ class RemoveSharedAgent(Resource): return make_response( jsonify({"success": False, "message": "ID is required"}), 400 ) - try: agent = agents_collection.find_one( {"_id": ObjectId(agent_id), "shared_publicly": True} @@ -1964,7 +2187,6 @@ class RemoveSharedAgent(Resource): jsonify({"success": False, "message": "Shared agent not found"}), 404, ) - ensure_user_doc(user_id) users_collection.update_one( {"user_id": user_id}, @@ -1977,7 +2199,6 @@ class RemoveSharedAgent(Resource): ) return make_response(jsonify({"success": True, "action": "removed"}), 200) - except Exception as err: current_app.logger.error(f"Error removing shared agent: {err}") return make_response( @@ -2000,7 +2221,6 @@ class 
SharedAgent(Resource): return make_response( jsonify({"success": False, "message": "Token or ID is required"}), 400 ) - try: query = { "shared_publicly": True, @@ -2012,7 +2232,6 @@ class SharedAgent(Resource): jsonify({"success": False, "message": "Shared agent not found"}), 404, ) - agent_id = str(shared_agent["_id"]) data = { "id": agent_id, @@ -2052,7 +2271,6 @@ class SharedAgent(Resource): if tool_data: enriched_tools.append(tool_data.get("name", "")) data["tools"] = enriched_tools - decoded_token = getattr(request, "decoded_token", None) if decoded_token: user_id = decoded_token.get("sub") @@ -2064,9 +2282,7 @@ class SharedAgent(Resource): {"user_id": user_id}, {"$addToSet": {"agent_preferences.shared_with_me": agent_id}}, ) - return make_response(jsonify(data), 200) - except Exception as err: current_app.logger.error(f"Error retrieving shared agent: {err}") return make_response(jsonify({"success": False}), 400) @@ -2100,7 +2316,6 @@ class SharedAgents(Resource): {"user_id": user_id}, {"$pullAll": {"agent_preferences.shared_with_me": stale_ids}}, ) - pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", [])) list_shared_agents = [ @@ -2127,7 +2342,6 @@ class SharedAgents(Resource): ] return make_response(jsonify(list_shared_agents), 200) - except Exception as err: current_app.logger.error(f"Error retrieving shared agents: {err}") return make_response(jsonify({"success": False}), 400) @@ -3577,7 +3791,7 @@ class GetChunks(Resource): "page": "Page number for pagination", "per_page": "Number of chunks per page", "path": "Optional: Filter chunks by relative file path", - "search": "Optional: Search term to filter chunks by title or content" + "search": "Optional: Search term to filter chunks by title or content", }, ) def get(self): @@ -3607,22 +3821,22 @@ class GetChunks(Resource): metadata = chunk.get("metadata", {}) # Filter by path if provided + if path: chunk_source = metadata.get("source", "") # Check if the chunk's source matches the 
requested path + if not chunk_source or not chunk_source.endswith(path): continue - # Filter by search term if provided + if search_term: text_match = search_term in chunk.get("text", "").lower() title_match = search_term in metadata.get("title", "").lower() if not (text_match or title_match): continue - filtered_chunks.append(chunk) - chunks = filtered_chunks total_chunks = len(chunks) @@ -3638,7 +3852,7 @@ class GetChunks(Resource): "total": total_chunks, "chunks": paginated_chunks, "path": path if path else None, - "search": search_term if search_term else None + "search": search_term if search_term else None, } ), 200, @@ -3647,6 +3861,7 @@ class GetChunks(Resource): current_app.logger.error(f"Error getting chunks: {e}", exc_info=True) return make_response(jsonify({"success": False}), 500) + @user_ns.route("/api/add_chunk") class AddChunk(Resource): @api.expect( @@ -3781,7 +3996,6 @@ class UpdateChunk(Resource): if metadata is None: metadata = {} metadata["token_count"] = token_count - if not ObjectId.is_valid(doc_id): return make_response(jsonify({"error": "Invalid doc_id"}), 400) doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user}) @@ -3796,7 +4010,6 @@ class UpdateChunk(Resource): existing_chunk = next((c for c in chunks if c["doc_id"] == chunk_id), None) if not existing_chunk: return make_response(jsonify({"error": "Chunk not found"}), 404) - new_text = text if text is not None else existing_chunk["text"] if metadata is not None: @@ -3804,17 +4017,16 @@ class UpdateChunk(Resource): new_metadata.update(metadata) else: new_metadata = existing_chunk["metadata"].copy() - if text is not None: new_metadata["token_count"] = num_tokens_from_string(new_text) - try: new_chunk_id = store.add_chunk(new_text, new_metadata) deleted = store.delete_chunk(chunk_id) if not deleted: - current_app.logger.warning(f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created") - + current_app.logger.warning( + f"Failed to delete old 
chunk {chunk_id}, but new chunk {new_chunk_id} was created" + ) return make_response( jsonify( { @@ -3861,7 +4073,6 @@ class StoreAttachment(Resource): jsonify({"status": "error", "message": "Missing file"}), 400, ) - user = None if decoded_token: user = safe_filename(decoded_token.get("sub")) @@ -3876,7 +4087,6 @@ class StoreAttachment(Resource): return make_response( jsonify({"success": False, "message": "Authentication required"}), 401 ) - try: attachment_id = ObjectId() original_filename = safe_filename(os.path.basename(file.filename)) @@ -3918,7 +4128,6 @@ class ServeImage(Resource): content_type = f"image/{extension}" if extension == "jpg": content_type = "image/jpeg" - response = make_response(file_obj.read()) response.headers.set("Content-Type", content_type) response.headers.set("Cache-Control", "max-age=86400") @@ -3945,25 +4154,19 @@ class DirectoryStructure(Resource): decoded_token = request.decoded_token if not decoded_token: return make_response(jsonify({"success": False}), 401) - user = decoded_token.get("sub") doc_id = request.args.get("id") if not doc_id: - return make_response( - jsonify({"error": "Document ID is required"}), 400 - ) - + return make_response(jsonify({"error": "Document ID is required"}), 400) if not ObjectId.is_valid(doc_id): return make_response(jsonify({"error": "Invalid document ID"}), 400) - try: doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user}) if not doc: return make_response( jsonify({"error": "Document not found or access denied"}), 404 ) - directory_structure = doc.get("directory_structure", {}) base_path = doc.get("file_path", "") @@ -3975,24 +4178,21 @@ class DirectoryStructure(Resource): provider = remote_data_obj.get("provider") except Exception as e: current_app.logger.warning( - f"Failed to parse remote_data for doc {doc_id}: {e}") - + f"Failed to parse remote_data for doc {doc_id}: {e}" + ) return make_response( - jsonify({ - "success": True, - "directory_structure": directory_structure, - 
"base_path": base_path, - "provider": provider, - }), 200 + jsonify( + { + "success": True, + "directory_structure": directory_structure, + "base_path": base_path, + "provider": provider, + } + ), + 200, ) - except Exception as e: current_app.logger.error( f"Error retrieving directory structure: {e}", exc_info=True ) - return make_response( - jsonify({"success": False, "error": str(e)}), 500 - ) - - - + return make_response(jsonify({"success": False, "error": str(e)}), 500) diff --git a/application/retriever/base.py b/application/retriever/base.py index fd99dbdd..36ac2e93 100644 --- a/application/retriever/base.py +++ b/application/retriever/base.py @@ -5,10 +5,6 @@ class BaseRetriever(ABC): def __init__(self): pass - @abstractmethod - def gen(self, *args, **kwargs): - pass - @abstractmethod def search(self, *args, **kwargs): pass diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 9416b4f7..2ce863c2 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -1,4 +1,5 @@ import logging + from application.core.settings import settings from application.llm.llm_creator import LLMCreator from application.retriever.base import BaseRetriever @@ -20,10 +21,20 @@ class ClassicRAG(BaseRetriever): api_key=settings.API_KEY, decoded_token=None, ): - self.original_question = "" + """Initialize ClassicRAG retriever with vectorstore sources and LLM configuration""" + self.original_question = source.get("question", "") self.chat_history = chat_history if chat_history is not None else [] self.prompt = prompt - self.chunks = chunks + if isinstance(chunks, str): + try: + self.chunks = int(chunks) + except ValueError: + logging.warning( + f"Invalid chunks value '{chunks}', using default value 2" + ) + self.chunks = 2 + else: + self.chunks = chunks self.gpt_model = gpt_model self.token_limit = ( token_limit @@ -44,25 +55,52 @@ class ClassicRAG(BaseRetriever): user_api_key=self.user_api_key, 
# NOTE(review): this span is a newline-stripped unified diff of
# application/retriever/classic_rag.py. The enclosing
# `class ClassicRAG(BaseRetriever)` header and the head of `__init__` fall
# outside the visible hunks, so they are not reproduced here. The methods
# below are the reconstructed post-patch state; re-indent them under the
# class when restoring the file.


def _validate_vectorstore_config(self):
    """Drop empty/whitespace-only vectorstore IDs, warning about each removal."""
    if not self.vectorstores:
        logging.warning("No vectorstores configured for retrieval")
        return
    invalid_ids = [
        vs_id for vs_id in self.vectorstores if not vs_id or not vs_id.strip()
    ]
    if invalid_ids:
        logging.warning(f"Found invalid vectorstore IDs: {invalid_ids}")
        self.vectorstores = [
            vs_id for vs_id in self.vectorstores if vs_id and vs_id.strip()
        ]


def _rephrase_query(self):
    """Rephrase the user question with chat-history context for retrieval.

    Returns the original question unchanged when there is no question, no
    chat history, chunking is disabled, or no vectorstores are configured.
    """
    if (
        not self.original_question
        or not self.chat_history
        or self.chat_history == []
        or self.chunks == 0
        or not self.vectorstores
    ):
        return self.original_question
    prompt = f"""Given the following conversation history:
{self.chat_history}

Rephrase the following user question to be a standalone search query
that captures all relevant context from the conversation:
"""
    # NOTE(review): the message construction and the LLM round-trip are
    # truncated by a hunk boundary in the source diff; only this fallback
    # return is visible. Restore the LLM call from the repository before
    # shipping — as written this always falls back to the original question.
    return self.original_question


def _get_data(self):
    """Retrieve relevant documents from all configured vectorstores.

    Splits the chunk budget evenly across vectorstores (at least one chunk
    per store) and skips, with a logged error, any store whose search fails.
    """
    if self.chunks == 0 or not self.vectorstores:
        return []
    all_docs = []
    chunks_per_source = max(1, self.chunks // len(self.vectorstores))
    for vectorstore_id in self.vectorstores:
        if vectorstore_id:
            try:
                docsearch = VectorCreator.create_vectorstore(
                    settings.VECTOR_STORE, vectorstore_id, settings.EMBEDDINGS_KEY
                )
                docs_temp = docsearch.search(self.question, k=chunks_per_source)
                for doc in docs_temp:
                    # Results may be LangChain-style Documents or plain dicts.
                    if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
                        page_content = doc.page_content
                        metadata = doc.metadata
                    else:
                        page_content = doc.get("text", doc.get("page_content", ""))
                        metadata = doc.get("metadata", {})
                    title = metadata.get(
                        "title", metadata.get("post_title", page_content)
                    )
                    if isinstance(title, str):
                        title = title.split("/")[-1]
                    else:
                        title = str(title).split("/")[-1]
                    all_docs.append(
                        {
                            "title": title,
                            "text": page_content,
                            "source": metadata.get("source") or vectorstore_id,
                        }
                    )
            except Exception as e:
                logging.error(
                    f"Error searching vectorstore {vectorstore_id}: {e}",
                    exc_info=True,
                )
                continue
    return all_docs


def search(self, query: str = ""):
    """Search for documents using an optional query override."""
    if query:
        self.original_question = query
        self.question = self._rephrase_query()
    return self._get_data()


def get_params(self):
    """Return current retriever configuration parameters."""
    return {
        "question": self.original_question,
        "rephrased_question": self.question,
        "sources": self.vectorstores,
        "chunks": self.chunks,
        "token_limit": self.token_limit,
        "gpt_model": self.gpt_model,
        # NOTE(review): any entries after gpt_model (e.g. user_api_key) are
        # truncated by the end of the visible hunk — confirm against repo.
    }
# NOTE(review): reconstructed post-patch state of
# application/vectorstore/base.py from a newline-stripped unified diff.
# Spots hidden between hunks are marked below and must be confirmed
# against the repository.
import os
from abc import ABC, abstractmethod

from langchain_openai import OpenAIEmbeddings
from sentence_transformers import SentenceTransformer

from application.core.settings import settings


class EmbeddingsWrapper:
    """Adapter giving a SentenceTransformer the embed_query/embed_documents API."""

    def __init__(self, model_name, *args, **kwargs):
        self.model = SentenceTransformer(
            model_name,
            config_kwargs={"allow_dangerous_deserialization": True},
            *args,
            **kwargs
        )
        # Embedding dimensionality of the loaded model.
        self.dimension = self.model.get_sentence_embedding_dimension()

    def embed_query(self, query: str):
        """Embed a single query string; returns a plain list of floats."""
        return self.model.encode(query).tolist()

    def embed_documents(self, documents: list):
        """Embed a batch of documents; returns a list of float lists."""
        return self.model.encode(documents).tolist()

    def __call__(self, text):
        """Dispatch to embed_query or embed_documents based on input type."""
        if isinstance(text, str):
            return self.embed_query(text)
        # NOTE(review): the list branch is hidden between diff hunks in the
        # source; presumably it dispatches to embed_documents — confirm.
        if isinstance(text, list):
            return self.embed_documents(text)
        raise ValueError("Input must be a string or a list of strings")


class EmbeddingsSingleton:
    """Process-wide cache of embedding instances, keyed by embeddings name."""

    _instances = {}

    @staticmethod
    def get_instance(embeddings_name, *args, **kwargs):
        """Return the cached instance for embeddings_name, creating it once."""
        if embeddings_name not in EmbeddingsSingleton._instances:
            EmbeddingsSingleton._instances[embeddings_name] = (
                EmbeddingsSingleton._create_instance(embeddings_name, *args, **kwargs)
            )
        return EmbeddingsSingleton._instances[embeddings_name]

    @staticmethod
    def _create_instance(embeddings_name, *args, **kwargs):
        """Build an embeddings object from the known-name factory, else wrap it."""
        embeddings_factory = {
            "openai_text-embedding-ada-002": OpenAIEmbeddings,
            "huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper(
                "sentence-transformers/all-mpnet-base-v2"
            ),
            "huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper(
                "sentence-transformers/all-mpnet-base-v2"
            ),
            "huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper(
                "hkunlp/instructor-large"
            ),
        }
        if embeddings_name in embeddings_factory:
            # NOTE(review): the factory-hit return is hidden between diff
            # hunks; presumably it invokes the factory entry — confirm.
            return embeddings_factory[embeddings_name](*args, **kwargs)
        else:
            return EmbeddingsWrapper(embeddings_name, *args, **kwargs)


class BaseVectorStore(ABC):
    """Abstract base for vectorstore backends; subclasses override as needed."""

    def __init__(self):
        pass

    @abstractmethod
    def search(self, *args, **kwargs):
        """Search for similar documents/chunks in the vectorstore"""
        pass

    @abstractmethod
    def add_texts(self, texts, metadatas=None, *args, **kwargs):
        """Add texts with their embeddings to the vectorstore"""
        pass

    def delete_index(self, *args, **kwargs):
        """Delete the entire index/collection"""
        pass

    def save_local(self, *args, **kwargs):
        """Save vectorstore to local storage"""
        pass

    def get_chunks(self, *args, **kwargs):
        """Get all chunks from the vectorstore"""
        pass

    def add_chunk(self, text, metadata=None, *args, **kwargs):
        """Add a single chunk to the vectorstore"""
        pass

    def delete_chunk(self, chunk_id, *args, **kwargs):
        """Delete a specific chunk from the vectorstore"""
        pass

    def is_azure_configured(self):
        """True when all Azure OpenAI settings needed for embeddings are set."""
        return (
            settings.OPENAI_API_BASE
            and settings.OPENAI_API_VERSION
            and settings.AZURE_DEPLOYMENT_NAME
        )

    def _get_embeddings(self, embeddings_name, embeddings_key=None):
        """Resolve an embeddings instance for the given embeddings name."""
        if embeddings_name == "openai_text-embedding-ada-002":
            if self.is_azure_configured():
                os.environ["OPENAI_API_TYPE"] = "azure"
                embedding_instance = EmbeddingsSingleton.get_instance(
                    embeddings_name, model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
                )
            else:
                embedding_instance = EmbeddingsSingleton.get_instance(
                    embeddings_name, openai_api_key=embeddings_key
                )
        elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
            if os.path.exists("./models/all-mpnet-base-v2"):
                # Prefer the locally downloaded model when present.
                embedding_instance = EmbeddingsSingleton.get_instance(
                    embeddings_name="./models/all-mpnet-base-v2",
                )
            else:
                # NOTE(review): arguments of this call are truncated by a
                # hunk boundary; presumably the plain embeddings_name.
                embedding_instance = EmbeddingsSingleton.get_instance(
                    embeddings_name
                )
        else:
            embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
        return embedding_instance
const singleSourceId = Array.from(selectedSourceIds)[0]; + const sourceDoc = sourceDocs?.find( + (source) => + source.id === singleSourceId || + source.retriever === singleSourceId || + source.name === singleSourceId, + ); + let finalSourceId; + if (sourceDoc?.name === 'Default' && !sourceDoc?.id) + finalSourceId = 'default'; + else finalSourceId = sourceDoc?.id || singleSourceId; + formData.append('source', String(finalSourceId)); + formData.append('sources', JSON.stringify([])); + } else { + formData.append('source', ''); + formData.append('sources', JSON.stringify([])); + } + formData.append('chunks', agent.chunks); formData.append('retriever', agent.retriever); formData.append('prompt_id', agent.prompt_id); @@ -196,7 +231,41 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { const formData = new FormData(); formData.append('name', agent.name); formData.append('description', agent.description); - formData.append('source', agent.source); + + if (selectedSourceIds.size > 1) { + const sourcesArray = Array.from(selectedSourceIds) + .map((id) => { + const sourceDoc = sourceDocs?.find( + (source) => + source.id === id || source.retriever === id || source.name === id, + ); + if (sourceDoc?.name === 'Default' && !sourceDoc?.id) { + return 'default'; + } + return sourceDoc?.id || id; + }) + .filter(Boolean); + formData.append('sources', JSON.stringify(sourcesArray)); + formData.append('source', ''); + } else if (selectedSourceIds.size === 1) { + const singleSourceId = Array.from(selectedSourceIds)[0]; + const sourceDoc = sourceDocs?.find( + (source) => + source.id === singleSourceId || + source.retriever === singleSourceId || + source.name === singleSourceId, + ); + let finalSourceId; + if (sourceDoc?.name === 'Default' && !sourceDoc?.id) + finalSourceId = 'default'; + else finalSourceId = sourceDoc?.id || singleSourceId; + formData.append('source', String(finalSourceId)); + formData.append('sources', JSON.stringify([])); + } else { + 
formData.append('source', ''); + formData.append('sources', JSON.stringify([])); + } + formData.append('chunks', agent.chunks); formData.append('retriever', agent.retriever); formData.append('prompt_id', agent.prompt_id); @@ -293,9 +362,33 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { throw new Error('Failed to fetch agent'); } const data = await response.json(); - if (data.source) setSelectedSourceIds(new Set([data.source])); - else if (data.retriever) + + if (data.sources && data.sources.length > 0) { + const mappedSources = data.sources.map((sourceId: string) => { + if (sourceId === 'default') { + const defaultSource = sourceDocs?.find( + (source) => source.name === 'Default', + ); + return defaultSource?.retriever || 'classic'; + } + return sourceId; + }); + setSelectedSourceIds(new Set(mappedSources)); + } else if (data.source) { + if (data.source === 'default') { + const defaultSource = sourceDocs?.find( + (source) => source.name === 'Default', + ); + setSelectedSourceIds( + new Set([defaultSource?.retriever || 'classic']), + ); + } else { + setSelectedSourceIds(new Set([data.source])); + } + } else if (data.retriever) { setSelectedSourceIds(new Set([data.retriever])); + } + if (data.tools) setSelectedToolIds(new Set(data.tools)); if (data.status === 'draft') setEffectiveMode('draft'); if (data.json_schema) { @@ -311,25 +404,57 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { }, [agentId, mode, token]); useEffect(() => { - const selectedSource = Array.from(selectedSourceIds).map((id) => - sourceDocs?.find( - (source) => - source.id === id || source.retriever === id || source.name === id, - ), - ); - if (selectedSource[0]?.model === embeddingsName) { - if (selectedSource[0] && 'id' in selectedSource[0]) { + const selectedSources = Array.from(selectedSourceIds) + .map((id) => + sourceDocs?.find( + (source) => + source.id === id || source.retriever === id || source.name === id, + ), + ) + 
.filter(Boolean); + + if (selectedSources.length > 0) { + // Handle multiple sources + if (selectedSources.length > 1) { + // Multiple sources selected - store in sources array + const sourceIds = selectedSources + .map((source) => source?.id) + .filter((id): id is string => Boolean(id)); setAgent((prev) => ({ ...prev, - source: selectedSource[0]?.id || 'default', + sources: sourceIds, + source: '', // Clear single source for multiple sources retriever: '', })); - } else - setAgent((prev) => ({ - ...prev, - source: '', - retriever: selectedSource[0]?.retriever || 'classic', - })); + } else { + // Single source selected - maintain backward compatibility + const selectedSource = selectedSources[0]; + if (selectedSource?.model === embeddingsName) { + if (selectedSource && 'id' in selectedSource) { + setAgent((prev) => ({ + ...prev, + source: selectedSource?.id || 'default', + sources: [], // Clear sources array for single source + retriever: '', + })); + } else { + setAgent((prev) => ({ + ...prev, + source: '', + sources: [], // Clear sources array + retriever: selectedSource?.retriever || 'classic', + })); + } + } + } + } else { + // No sources selected + setAgent((prev) => ({ + ...prev, + source: '', + sources: [], + retriever: '', + })); } }, [selectedSourceIds]); @@ -510,7 +635,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { ) .filter(Boolean) .join(', ') - : 'Select source'} + : 'Select sources'} ) => { setSelectedSourceIds(newSelectedIds); - setIsSourcePopupOpen(false); }} - title="Select Source" + title="Select Sources" searchPlaceholder="Search sources..." - noOptionsMessage="No source available" - singleSelect={true} + noOptionsMessage="No sources available" />
diff --git a/frontend/src/agents/types/index.ts b/frontend/src/agents/types/index.ts index e841cb0a..442097a1 100644 --- a/frontend/src/agents/types/index.ts +++ b/frontend/src/agents/types/index.ts @@ -10,6 +10,7 @@ export type Agent = { description: string; image: string; source: string; + sources?: string[]; chunks: string; retriever: string; prompt_id: string; diff --git a/frontend/src/api/services/userService.ts b/frontend/src/api/services/userService.ts index 6e375951..1cb4bbd6 100644 --- a/frontend/src/api/services/userService.ts +++ b/frontend/src/api/services/userService.ts @@ -90,7 +90,10 @@ const userService = { path?: string, search?: string, ): Promise => - apiClient.get(endpoints.USER.GET_CHUNKS(docId, page, perPage, path, search), token), + apiClient.get( + endpoints.USER.GET_CHUNKS(docId, page, perPage, path, search), + token, + ), addChunk: (data: any, token: string | null): Promise => apiClient.post(endpoints.USER.ADD_CHUNK, data, token), deleteChunk: ( @@ -105,16 +108,20 @@ const userService = { apiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token), manageSourceFiles: (data: FormData, token: string | null): Promise => apiClient.postFormData(endpoints.USER.MANAGE_SOURCE_FILES, data, token), - syncConnector: (docId: string, provider: string, token: string | null): Promise => { + syncConnector: ( + docId: string, + provider: string, + token: string | null, + ): Promise => { const sessionToken = getSessionToken(provider); return apiClient.post( endpoints.USER.SYNC_CONNECTOR, { source_id: docId, session_token: sessionToken, - provider: provider + provider: provider, }, - token + token, ); }, }; diff --git a/frontend/src/components/ConnectorAuth.tsx b/frontend/src/components/ConnectorAuth.tsx index 22566521..61b6e895 100644 --- a/frontend/src/components/ConnectorAuth.tsx +++ b/frontend/src/components/ConnectorAuth.tsx @@ -16,7 +16,12 @@ const providerLabel = (provider: string) => { return map[provider] || provider.replace(/_/g, ' '); }; 
-const ConnectorAuth: React.FC = ({ provider, onSuccess, onError, label }) => { +const ConnectorAuth: React.FC = ({ + provider, + onSuccess, + onError, + label, +}) => { const token = useSelector(selectToken); const completedRef = useRef(false); const intervalRef = useRef(null); @@ -31,8 +36,12 @@ const ConnectorAuth: React.FC = ({ provider, onSuccess, onEr const handleAuthMessage = (event: MessageEvent) => { const successGeneric = event.data?.type === 'connector_auth_success'; - const successProvider = event.data?.type === `${provider}_auth_success` || event.data?.type === 'google_drive_auth_success'; - const errorProvider = event.data?.type === `${provider}_auth_error` || event.data?.type === 'google_drive_auth_error'; + const successProvider = + event.data?.type === `${provider}_auth_success` || + event.data?.type === 'google_drive_auth_success'; + const errorProvider = + event.data?.type === `${provider}_auth_error` || + event.data?.type === 'google_drive_auth_error'; if (successGeneric || successProvider) { completedRef.current = true; @@ -54,12 +63,17 @@ const ConnectorAuth: React.FC = ({ provider, onSuccess, onEr cleanup(); const apiHost = import.meta.env.VITE_API_HOST; - const authResponse = await fetch(`${apiHost}/api/connectors/auth?provider=${provider}`, { - headers: { Authorization: `Bearer ${token}` }, - }); + const authResponse = await fetch( + `${apiHost}/api/connectors/auth?provider=${provider}`, + { + headers: { Authorization: `Bearer ${token}` }, + }, + ); if (!authResponse.ok) { - throw new Error(`Failed to get authorization URL: ${authResponse.status}`); + throw new Error( + `Failed to get authorization URL: ${authResponse.status}`, + ); } const authData = await authResponse.json(); @@ -70,10 +84,12 @@ const ConnectorAuth: React.FC = ({ provider, onSuccess, onEr const authWindow = window.open( authData.authorization_url, `${provider}-auth`, - 'width=500,height=600,scrollbars=yes,resizable=yes' + 
'width=500,height=600,scrollbars=yes,resizable=yes', ); if (!authWindow) { - throw new Error('Failed to open authentication window. Please allow popups.'); + throw new Error( + 'Failed to open authentication window. Please allow popups.', + ); } window.addEventListener('message', handleAuthMessage as any); @@ -98,10 +114,13 @@ const ConnectorAuth: React.FC = ({ provider, onSuccess, onEr return ( @@ -109,4 +128,3 @@ const ConnectorAuth: React.FC = ({ provider, onSuccess, onEr }; export default ConnectorAuth; - diff --git a/frontend/src/components/ConnectorTreeComponent.tsx b/frontend/src/components/ConnectorTreeComponent.tsx index 53900d0f..9249145c 100644 --- a/frontend/src/components/ConnectorTreeComponent.tsx +++ b/frontend/src/components/ConnectorTreeComponent.tsx @@ -227,8 +227,6 @@ const ConnectorTreeComponent: React.FC = ({ return current; }; - - const getMenuRef = (id: string) => { if (!menuRefs.current[id]) { menuRefs.current[id] = React.createRef(); diff --git a/frontend/src/components/FileTreeComponent.tsx b/frontend/src/components/FileTreeComponent.tsx index 3e97d4e9..724ca233 100644 --- a/frontend/src/components/FileTreeComponent.tsx +++ b/frontend/src/components/FileTreeComponent.tsx @@ -11,9 +11,7 @@ import FolderIcon from '../assets/folder.svg'; import ArrowLeft from '../assets/arrow-left.svg'; import ThreeDots from '../assets/three-dots.svg'; import EyeView from '../assets/eye-view.svg'; -import OutlineSource from '../assets/outline-source.svg'; import Trash from '../assets/red-trash.svg'; -import SearchIcon from '../assets/search.svg'; import { useOutsideAlerter } from '../hooks'; import ConfirmationModal from '../modals/ConfirmationModal'; @@ -129,8 +127,6 @@ const FileTreeComponent: React.FC = ({ } }, [docId, token]); - - const navigateToDirectory = (dirName: string) => { setCurrentPath((prev) => [...prev, dirName]); }; @@ -438,18 +434,18 @@ const FileTreeComponent: React.FC = ({ const renderPathNavigation = () => { return ( -
+
{/* Left side with path navigation */}
- + {sourceName} {currentPath.length > 0 && ( @@ -480,8 +476,7 @@ const FileTreeComponent: React.FC = ({
-
- +
{processingRef.current && (
{currentOpRef.current === 'add' @@ -490,13 +485,13 @@ const FileTreeComponent: React.FC = ({
)} - {renderFileSearch()} + {renderFileSearch()} {/* Add file button */} {!processingRef.current && ( ) : (
@@ -336,7 +335,7 @@ export default function Sources({
{loading ? ( -
+
) : !currentDocuments?.length ? ( @@ -351,19 +350,19 @@ export default function Sources({

) : ( -
- {currentDocuments.map((document, index) => { - const docId = document.id ? document.id.toString() : ''; +
+ {currentDocuments.map((document, index) => { + const docId = document.id ? document.id.toString() : ''; - return ( -
-
+ return ( +
+

{document.date ? formatDate(document.date) : ''} @@ -437,7 +436,7 @@ export default function Sources({ {document.tokens diff --git a/frontend/src/upload/Upload.tsx b/frontend/src/upload/Upload.tsx index 46a36f4c..610c61b6 100644 --- a/frontend/src/upload/Upload.tsx +++ b/frontend/src/upload/Upload.tsx @@ -4,7 +4,11 @@ import { useTranslation } from 'react-i18next'; import { useDispatch, useSelector } from 'react-redux'; import userService from '../api/services/userService'; -import { getSessionToken, setSessionToken, removeSessionToken } from '../utils/providerUtils'; +import { + getSessionToken, + setSessionToken, + removeSessionToken, +} from '../utils/providerUtils'; import { formatDate } from '../utils/dateTimeUtils'; import { formatBytes } from '../utils/stringUtils'; import FileUpload from '../assets/file_upload.svg'; @@ -63,7 +67,9 @@ function Upload({ const [userEmail, setUserEmail] = useState(''); const [authError, setAuthError] = useState(''); const [currentFolderId, setCurrentFolderId] = useState(null); - const [folderPath, setFolderPath] = useState>([{id: null, name: 'My Drive'}]); + const [folderPath, setFolderPath] = useState< + Array<{ id: string | null; name: string }> + >([{ id: null, name: 'My Drive' }]); const [nextPageToken, setNextPageToken] = useState(null); const [hasMoreFiles, setHasMoreFiles] = useState(false); @@ -337,7 +343,8 @@ function Upload({ data?.find( (d: Doc) => d.type?.toLowerCase() === 'local', ), - )); + ), + ); }); setProgress( (progress) => @@ -454,23 +461,31 @@ function Upload({ if (ingestor.type === 'google_drive') { const sessionToken = getSessionToken(ingestor.type); - const selectedItems = googleDriveFiles.filter(file => selectedFiles.includes(file.id)); + const selectedItems = googleDriveFiles.filter((file) => + selectedFiles.includes(file.id), + ); const selectedFolderIds = selectedItems - .filter(item => item.type === 'application/vnd.google-apps.folder' || item.isFolder) - .map(folder => folder.id); + .filter( + 
(item) => + item.type === 'application/vnd.google-apps.folder' || item.isFolder, + ) + .map((folder) => folder.id); const selectedFileIds = selectedItems - .filter(item => item.type !== 'application/vnd.google-apps.folder' && !item.isFolder) - .map(file => file.id); + .filter( + (item) => + item.type !== 'application/vnd.google-apps.folder' && + !item.isFolder, + ) + .map((file) => file.id); configData = { file_ids: selectedFileIds, folder_ids: selectedFolderIds, recursive: ingestor.config.recursive, - session_token: sessionToken || null + session_token: sessionToken || null, }; } else { - configData = { ...ingestor.config }; } @@ -522,14 +537,20 @@ function Upload({ try { const apiHost = import.meta.env.VITE_API_HOST; - const validateResponse = await fetch(`${apiHost}/api/connectors/validate-session`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'Authorization': `Bearer ${token}` + const validateResponse = await fetch( + `${apiHost}/api/connectors/validate-session`, + { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + provider: 'google_drive', + session_token: sessionToken, + }), }, - body: JSON.stringify({ provider: 'google_drive', session_token: sessionToken }) - }); + ); if (!validateResponse.ok) { removeSessionToken(ingestor.type); @@ -545,15 +566,16 @@ function Upload({ // reset pagination state and files setGoogleDriveFiles([]); - - setNextPageToken(null); setHasMoreFiles(false); loadGoogleDriveFiles(sessionToken, null, null, false); } else { removeSessionToken(ingestor.type); setIsGoogleDriveConnected(false); - setAuthError(validateData.error || 'Session expired. Please reconnect your Google Drive account and make sure to grant offline access.'); + setAuthError( + validateData.error || + 'Session expired. 
Please reconnect your Google Drive account and make sure to grant offline access.', + ); } } catch (error) { console.error('Error validating Google Drive session:', error); @@ -566,7 +588,7 @@ function Upload({ sessionToken: string, folderId?: string | null, pageToken?: string | null, - append: boolean = false, + append = false, ) => { setIsLoadingFiles(true); @@ -587,9 +609,9 @@ function Upload({ method: 'POST', headers: { 'Content-Type': 'application/json', - 'Authorization': `Bearer ${token}` + Authorization: `Bearer ${token}`, }, - body: JSON.stringify({ ...requestBody, provider: 'google_drive' }) + body: JSON.stringify({ ...requestBody, provider: 'google_drive' }), }); if (!filesResponse.ok) { @@ -599,28 +621,31 @@ function Upload({ const filesData = await filesResponse.json(); if (filesData.success && Array.isArray(filesData.files)) { - setGoogleDriveFiles(prev => append ? [...prev, ...filesData.files] : filesData.files); + setGoogleDriveFiles((prev) => + append ? [...prev, ...filesData.files] : filesData.files, + ); setNextPageToken(filesData.next_page_token || null); setHasMoreFiles(Boolean(filesData.has_more)); } else { throw new Error(filesData.error || 'Failed to load files'); } - } catch (error) { console.error('Error loading Google Drive files:', error); - setAuthError(error instanceof Error ? error.message : 'Failed to load files. Please make sure your Google Drive account is properly connected and you granted offline access during authorization.'); + setAuthError( + error instanceof Error + ? error.message + : 'Failed to load files. 
Please make sure your Google Drive account is properly connected and you granted offline access during authorization.', + ); } finally { setIsLoadingFiles(false); } }; - - // Handle file selection const handleFileSelect = (fileId: string) => { - setSelectedFiles(prev => { + setSelectedFiles((prev) => { if (prev.includes(fileId)) { - return prev.filter(id => id !== fileId); + return prev.filter((id) => id !== fileId); } else { return [...prev, fileId]; } @@ -631,7 +656,7 @@ function Upload({ const sessionToken = getSessionToken(ingestor.type); if (sessionToken) { setCurrentFolderId(folderId); - setFolderPath(prev => [...prev, {id: folderId, name: folderName}]); + setFolderPath((prev) => [...prev, { id: folderId, name: folderName }]); setGoogleDriveFiles([]); setNextPageToken(null); @@ -662,7 +687,7 @@ function Upload({ if (selectedFiles.length === googleDriveFiles.length) { setSelectedFiles([]); } else { - setSelectedFiles(googleDriveFiles.map(file => file.id)); + setSelectedFiles(googleDriveFiles.map((file) => file.id)); } }; @@ -829,7 +854,7 @@ function Upload({ {files.map((file) => (

{file.name} @@ -905,10 +930,13 @@ function Upload({ ) : (

{/* Connection Status */} -
+
- + Connected as {userEmail}
@@ -927,28 +955,41 @@ function Upload({ method: 'POST', headers: { 'Content-Type': 'application/json', - 'Authorization': `Bearer ${token}` + Authorization: `Bearer ${token}`, }, - body: JSON.stringify({ provider: ingestor.type, session_token: getSessionToken(ingestor.type) }) - }).catch(err => console.error('Error disconnecting from Google Drive:', err)); + body: JSON.stringify({ + provider: ingestor.type, + session_token: getSessionToken(ingestor.type), + }), + }).catch((err) => + console.error( + 'Error disconnecting from Google Drive:', + err, + ), + ); }} - className="text-white hover:text-gray-200 text-xs underline" + className="text-xs text-white underline hover:text-gray-200" > Disconnect
{/* File Browser */} -
-
+
+
{/* Breadcrumb navigation */} -
+
{folderPath.map((path, index) => ( -
- {index > 0 && /} +
+ {index > 0 && ( + / + )} )}
{selectedFiles.length > 0 && ( -

- {selectedFiles.length} file{selectedFiles.length !== 1 ? 's' : ''} selected +

+ {selectedFiles.length} file + {selectedFiles.length !== 1 ? 's' : ''} selected

)}
-
+
{isLoadingFiles && googleDriveFiles.length === 0 ? (
@@ -996,47 +1043,76 @@ function Upload({
handleFileSelect(file.id)} - className="h-4 w-4 text-blue-600 rounded border-gray-300 focus:ring-blue-500" + checked={selectedFiles.includes( + file.id, + )} + onChange={() => + handleFileSelect(file.id) + } + className="h-4 w-4 rounded border-gray-300 text-blue-600 focus:ring-blue-500" />
- {file.type === 'application/vnd.google-apps.folder' || file.isFolder ? ( + {file.type === + 'application/vnd.google-apps.folder' || + file.isFolder ? (
handleFolderClick(file.id, file.name)} + className="cursor-pointer text-lg hover:text-blue-600" + onClick={() => + handleFolderClick(file.id, file.name) + } > - Folder + Folder
) : (
- File + File
)} -
+

{ - if (file.type === 'application/vnd.google-apps.folder' || file.isFolder) { - handleFolderClick(file.id, file.name); + if ( + file.type === + 'application/vnd.google-apps.folder' || + file.isFolder + ) { + handleFolderClick( + file.id, + file.name, + ); } }} > {file.name}

- {file.size && `${formatBytes(file.size)} • `}Modified {formatDate(file.modifiedTime)} + {file.size && + `${formatBytes(file.size)} • `} + Modified {formatDate(file.modifiedTime)}

@@ -1044,7 +1120,7 @@ function Upload({ ))}
-
+
{isLoadingFiles && (
@@ -1052,16 +1128,16 @@ function Upload({
)} {!hasMoreFiles && !isLoadingFiles && ( - All files loaded + + All files loaded + )}
)}
- - +
)} @@ -1110,8 +1186,7 @@ function Upload({ > {ingestor.type === 'google_drive' && selectedFiles.length > 0 ? `Train with ${selectedFiles.length} file${selectedFiles.length !== 1 ? 's' : ''}` - : t('modals.uploadDoc.train') - } + : t('modals.uploadDoc.train')} )}
@@ -1121,27 +1196,38 @@ function Upload({ useEffect(() => { const scrollContainer = scrollContainerRef.current; - + const handleScroll = () => { if (!scrollContainer) return; - + const { scrollTop, scrollHeight, clientHeight } = scrollContainer; const isNearBottom = scrollHeight - scrollTop - clientHeight < 50; - + if (isNearBottom && hasMoreFiles && !isLoadingFiles && nextPageToken) { const sessionToken = getSessionToken(ingestor.type); if (sessionToken) { - loadGoogleDriveFiles(sessionToken, currentFolderId, nextPageToken, true); + loadGoogleDriveFiles( + sessionToken, + currentFolderId, + nextPageToken, + true, + ); } } }; - + scrollContainer?.addEventListener('scroll', handleScroll); - + return () => { scrollContainer?.removeEventListener('scroll', handleScroll); }; - }, [hasMoreFiles, isLoadingFiles, nextPageToken, currentFolderId, ingestor.type]); + }, [ + hasMoreFiles, + isLoadingFiles, + nextPageToken, + currentFolderId, + ingestor.type, + ]); return ( { return localStorage.getItem(`${provider}_session_token`); }; @@ -14,4 +13,4 @@ export const setSessionToken = (provider: string, token: string): void => { export const removeSessionToken = (provider: string): void => { localStorage.removeItem(`${provider}_session_token`); -}; \ No newline at end of file +};