diff --git a/application/api/answer/services/stream_processor.py b/application/api/answer/services/stream_processor.py index dfcfcdd2..6f57c2fc 100644 --- a/application/api/answer/services/stream_processor.py +++ b/application/api/answer/services/stream_processor.py @@ -69,11 +69,8 @@ class StreamProcessor: self.decoded_token.get("sub") if self.decoded_token is not None else None ) self.conversation_id = self.data.get("conversation_id") - self.source = ( - {"active_docs": self.data["active_docs"]} - if "active_docs" in self.data - else {} - ) + self.source = {} + self.all_sources = [] self.attachments = [] self.history = [] self.agent_config = {} @@ -86,6 +83,7 @@ class StreamProcessor: def initialize(self): """Initialize all required components for processing""" self._configure_agent() + self._configure_source() self._configure_retriever() self._load_conversation_history() self._process_attachments() @@ -171,12 +169,77 @@ class StreamProcessor: source = data.get("source") if isinstance(source, DBRef): source_doc = self.db.dereference(source) - data["source"] = str(source_doc["_id"]) - data["retriever"] = source_doc.get("retriever", data.get("retriever")) + if source_doc: + data["source"] = str(source_doc["_id"]) + data["retriever"] = source_doc.get("retriever", data.get("retriever")) + data["chunks"] = source_doc.get("chunks", data.get("chunks")) + else: + data["source"] = None + elif source == "default": + data["source"] = "default" else: - data["source"] = {} + data["source"] = None + # Handle multiple sources + + sources = data.get("sources", []) + if sources and isinstance(sources, list): + sources_list = [] + for i, source_ref in enumerate(sources): + if source_ref == "default": + processed_source = { + "id": "default", + "retriever": "classic", + "chunks": data.get("chunks", "2"), + } + sources_list.append(processed_source) + elif isinstance(source_ref, DBRef): + source_doc = self.db.dereference(source_ref) + if source_doc: + processed_source = { + "id": str(source_doc["_id"]), + "retriever": source_doc.get("retriever", "classic"), + "chunks": source_doc.get("chunks", data.get("chunks", "2")), + } + sources_list.append(processed_source) + data["sources"] = sources_list + else: + data["sources"] = [] return data + def _configure_source(self): + """Configure the source based on agent data""" + api_key = self.data.get("api_key") or self.agent_key + + if api_key: + agent_data = self._get_data_from_api_key(api_key) + + if agent_data.get("sources") and len(agent_data["sources"]) > 0: + source_ids = [ + source["id"] for source in agent_data["sources"] if source.get("id") + ] + if source_ids: + self.source = {"active_docs": source_ids} + else: + self.source = {} + self.all_sources = agent_data["sources"] + elif agent_data.get("source"): + self.source = {"active_docs": agent_data["source"]} + self.all_sources = [ + { + "id": agent_data["source"], + "retriever": agent_data.get("retriever", "classic"), + } + ] + else: + self.source = {} + self.all_sources = [] + return + if "active_docs" in self.data: + self.source = {"active_docs": self.data["active_docs"]} + return + self.source = {} + self.all_sources = [] + def _configure_agent(self): """Configure the agent based on request data""" agent_id = self.data.get("agent_id") @@ -230,7 +293,8 @@ class StreamProcessor: "token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY), } - if "isNoneDoc" in self.data and self.data["isNoneDoc"]: + api_key = self.data.get("api_key") or self.agent_key + if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]: self.retriever_config["chunks"] = 0 def create_agent(self): diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 9a2febbc..2e9bae81 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -492,9 +492,9 @@ class DeleteOldIndexes(Resource): ) if not doc: return make_response(jsonify({"status": "not found"}), 404) - + storage = StorageCreator.get_storage() - + try: # Delete vector index if settings.VECTOR_STORE == "faiss": @@ -508,7 +508,7 @@ class DeleteOldIndexes(Resource): settings.VECTOR_STORE, source_id=str(doc["_id"]) ) vectorstore.delete_index() - + if "file_path" in doc and doc["file_path"]: file_path = doc["file_path"] if storage.is_directory(file_path): @@ -517,7 +517,7 @@ class DeleteOldIndexes(Resource): storage.delete_file(f) else: storage.delete_file(file_path) - + except FileNotFoundError: pass except Exception as err: @@ -525,7 +525,7 @@ class DeleteOldIndexes(Resource): f"Error deleting files and indexes: {err}", exc_info=True ) return make_response(jsonify({"success": False}), 400) - + sources_collection.delete_one({"_id": ObjectId(source_id)}) return make_response(jsonify({"success": True}), 200) @@ -573,55 +573,75 @@ class UploadFile(Resource): try: storage = StorageCreator.get_storage() - - + for file in files: original_filename = file.filename safe_file = safe_filename(original_filename) - + with tempfile.TemporaryDirectory() as temp_dir: temp_file_path = os.path.join(temp_dir, safe_file) file.save(temp_file_path) - + if zipfile.is_zipfile(temp_file_path): try: - with zipfile.ZipFile(temp_file_path, 'r') as zip_ref: + with zipfile.ZipFile(temp_file_path, "r") as zip_ref: zip_ref.extractall(path=temp_dir) - + # Walk through extracted files and upload them for root, _, files in os.walk(temp_dir): for extracted_file in files: - if os.path.join(root, extracted_file) == temp_file_path: + if ( + os.path.join(root, extracted_file) + == temp_file_path + ): continue - - rel_path = os.path.relpath(os.path.join(root, extracted_file), temp_dir) + + rel_path = os.path.relpath( + os.path.join(root, extracted_file), temp_dir + ) storage_path = f"{base_path}/{rel_path}" - - with open(os.path.join(root, extracted_file), 'rb') as f: + + with open( + os.path.join(root, extracted_file), "rb" + ) as f: storage.save_file(f, storage_path) except Exception as e: - current_app.logger.error(f"Error extracting zip: {e}", exc_info=True) + current_app.logger.error( + f"Error extracting zip: {e}", exc_info=True + ) # If zip extraction fails, save the original zip file file_path = f"{base_path}/{safe_file}" - with open(temp_file_path, 'rb') as f: + with open(temp_file_path, "rb") as f: storage.save_file(f, file_path) else: # For non-zip files, save directly file_path = f"{base_path}/{safe_file}" - with open(temp_file_path, 'rb') as f: + with open(temp_file_path, "rb") as f: storage.save_file(f, file_path) - + task = ingest.delay( settings.UPLOAD_FOLDER, [ - ".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub", - ".html", ".mdx", ".json", ".xlsx", ".pptx", ".png", - ".jpg", ".jpeg", + ".rst", + ".md", + ".pdf", + ".txt", + ".docx", + ".csv", + ".epub", + ".html", + ".mdx", + ".json", + ".xlsx", + ".pptx", + ".png", + ".jpg", + ".jpeg", ], job_name, user, file_path=base_path, - filename=dir_name + filename=dir_name, ) except Exception as err: current_app.logger.error(f"Error uploading file: {err}", exc_info=True) @@ -635,12 +655,29 @@ class ManageSourceFiles(Resource): api.model( "ManageSourceFilesModel", { - "source_id": fields.String(required=True, description="Source ID to modify"), - "operation": fields.String(required=True, description="Operation: 'add', 'remove', or 'remove_directory'"), - "file_paths": fields.List(fields.String, required=False, description="File paths to remove (for remove operation)"), - "directory_path": fields.String(required=False, description="Directory path to remove (for remove_directory operation)"), - "file": fields.Raw(required=False, description="Files to add (for add operation)"), - "parent_dir": fields.String(required=False, description="Parent directory path relative to source root"), + "source_id": fields.String( + required=True, description="Source ID to modify" + ), + "operation": fields.String( + required=True, + description="Operation: 'add', 'remove', or 'remove_directory'", + ), + "file_paths": fields.List( + fields.String, + required=False, + description="File paths to remove (for remove operation)", + ), + "directory_path": fields.String( + required=False, + description="Directory path to remove (for remove_directory operation)", + ), + "file": fields.Raw( + required=False, description="Files to add (for add operation)" + ), + "parent_dir": fields.String( + required=False, + description="Parent directory path relative to source root", + ), }, ) ) @@ -650,7 +687,9 @@ class ManageSourceFiles(Resource): def post(self): decoded_token = request.decoded_token if not decoded_token: - return make_response(jsonify({"success": False, "message": "Unauthorized"}), 401) + return make_response( + jsonify({"success": False, "message": "Unauthorized"}), 401 + ) user = decoded_token.get("sub") source_id = request.form.get("source_id") @@ -658,12 +697,24 @@ class ManageSourceFiles(Resource): if not source_id or not operation: return make_response( - jsonify({"success": False, "message": "source_id and operation are required"}), 400 + jsonify( + { + "success": False, + "message": "source_id and operation are required", + } + ), + 400, ) if operation not in ["add", "remove", "remove_directory"]: return make_response( - jsonify({"success": False, "message": "operation must be 'add', 'remove', or 'remove_directory'"}), 400 + jsonify( + { + "success": False, + "message": "operation must be 'add', 'remove', or 'remove_directory'", + } + ), + 400, ) try: @@ -674,34 +725,53 @@ class ManageSourceFiles(Resource): ) try: - source = sources_collection.find_one({"_id": ObjectId(source_id), "user": user}) + source = sources_collection.find_one( + {"_id": ObjectId(source_id), "user": user} + ) if not source: return make_response( - jsonify({"success": False, "message": "Source not found or access denied"}), 404 + jsonify( + { + "success": False, + "message": "Source not found or access denied", + } + ), + 404, ) except Exception as err: current_app.logger.error(f"Error finding source: {err}", exc_info=True) - return make_response(jsonify({"success": False, "message": "Database error"}), 500) + return make_response( + jsonify({"success": False, "message": "Database error"}), 500 + ) try: storage = StorageCreator.get_storage() source_file_path = source.get("file_path", "") - parent_dir = request.form.get("parent_dir", "") - + parent_dir = request.form.get("parent_dir", "") + if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir): return make_response( - jsonify({"success": False, "message": "Invalid parent directory path"}), 400 + jsonify( + {"success": False, "message": "Invalid parent directory path"} + ), + 400, ) if operation == "add": files = request.files.getlist("file") if not files or all(file.filename == "" for file in files): return make_response( - jsonify({"success": False, "message": "No files provided for add operation"}), 400 + jsonify( + { + "success": False, + "message": "No files provided for add operation", + } + ), + 400, ) added_files = [] - + target_dir = source_file_path if parent_dir: target_dir = f"{source_file_path}/{parent_dir}" @@ -720,26 +790,44 @@ class ManageSourceFiles(Resource): task = reingest_source_task.delay(source_id=source_id, user=user) - return make_response(jsonify({ - "success": True, - "message": f"Added {len(added_files)} files", - "added_files": added_files, - "parent_dir": parent_dir, - "reingest_task_id": task.id - }), 200) + return make_response( + jsonify( + { + "success": True, + "message": f"Added {len(added_files)} files", + "added_files": added_files, + "parent_dir": parent_dir, + "reingest_task_id": task.id, + } + ), + 200, + ) elif operation == "remove": file_paths_str = request.form.get("file_paths") if not file_paths_str: return make_response( - jsonify({"success": False, "message": "file_paths required for remove operation"}), 400 + jsonify( + { + "success": False, + "message": "file_paths required for remove operation", + } + ), + 400, ) try: - file_paths = json.loads(file_paths_str) if isinstance(file_paths_str, str) else file_paths_str + file_paths = ( + json.loads(file_paths_str) + if isinstance(file_paths_str, str) + else file_paths_str + ) except Exception: return make_response( - jsonify({"success": False, "message": "Invalid file_paths format"}), 400 + jsonify( + {"success": False, "message": "Invalid file_paths format"} + ), + 400, ) # Remove files from storage and directory structure @@ -757,18 +845,29 @@ class ManageSourceFiles(Resource): task = reingest_source_task.delay(source_id=source_id, user=user) - return make_response(jsonify({ - "success": True, - "message": f"Removed {len(removed_files)} files", - "removed_files": removed_files, - "reingest_task_id": task.id - }), 200) + return make_response( + jsonify( + { + "success": True, + "message": f"Removed {len(removed_files)} files", + "removed_files": removed_files, + "reingest_task_id": task.id, + } + ), + 200, + ) elif operation == "remove_directory": directory_path = request.form.get("directory_path") if not directory_path: return make_response( - jsonify({"success": False, "message": "directory_path required for remove_directory operation"}), 400 + jsonify( + { + "success": False, + "message": "directory_path required for remove_directory operation", + } + ), + 400, ) # Validate directory path (prevent path traversal) @@ -778,10 +877,17 @@ class ManageSourceFiles(Resource): f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}" ) return make_response( - jsonify({"success": False, "message": "Invalid directory path"}), 400 + jsonify( + {"success": False, "message": "Invalid directory path"} + ), + 400, ) - full_directory_path = f"{source_file_path}/{directory_path}" if directory_path else source_file_path + full_directory_path = ( + f"{source_file_path}/{directory_path}" + if directory_path + else source_file_path + ) if not storage.is_directory(full_directory_path): current_app.logger.warning( @@ -790,7 +896,13 @@ class ManageSourceFiles(Resource): f"Full path: {full_directory_path}" ) return make_response( - jsonify({"success": False, "message": "Directory not found or is not a directory"}), 404 + jsonify( + { + "success": False, + "message": "Directory not found or is not a directory", + } + ), + 404, ) success = storage.remove_directory(full_directory_path) @@ -802,7 +914,10 @@ class ManageSourceFiles(Resource): f"Full path: {full_directory_path}" ) return make_response( - jsonify({"success": False, "message": "Failed to remove directory"}), 500 + jsonify( + {"success": False, "message": "Failed to remove directory"} + ), + 500, ) current_app.logger.info( @@ -816,12 +931,17 @@ class ManageSourceFiles(Resource): task = reingest_source_task.delay(source_id=source_id, user=user) - return make_response(jsonify({ - "success": True, - "message": f"Successfully removed directory: {directory_path}", - "removed_directory": directory_path, - "reingest_task_id": task.id - }), 200) + return make_response( + jsonify( + { + "success": True, + "message": f"Successfully removed directory: {directory_path}", + "removed_directory": directory_path, + "reingest_task_id": task.id, + } + ), + 200, + ) except Exception as err: error_context = f"operation={operation}, user={user}, source_id={source_id}" @@ -835,8 +955,12 @@ class ManageSourceFiles(Resource): parent_dir = request.form.get("parent_dir", "") error_context += f", parent_dir={parent_dir}" - current_app.logger.error(f"Error managing source files: {err} ({error_context})", exc_info=True) - return make_response(jsonify({"success": False, "message": "Operation failed"}), 500) + current_app.logger.error( + f"Error managing source files: {err} ({error_context})", exc_info=True + ) + return make_response( + jsonify({"success": False, "message": "Operation failed"}), 500 + ) @user_ns.route("/api/remote") @@ -984,7 +1108,7 @@ class PaginatedSources(Resource): "tokens": doc.get("tokens", ""), "retriever": doc.get("retriever", "classic"), "syncFrequency": doc.get("sync_frequency", ""), - "isNested": bool(doc.get("directory_structure")) + "isNested": bool(doc.get("directory_structure")), } paginated_docs.append(doc_data) response = { @@ -1032,7 +1156,7 @@ class CombinedJson(Resource): "tokens": index.get("tokens", ""), "retriever": index.get("retriever", "classic"), "syncFrequency": index.get("sync_frequency", ""), - "is_nested": bool(index.get("directory_structure")) + "is_nested": bool(index.get("directory_structure")), } ) except Exception as err: @@ -1272,6 +1396,16 @@ class GetAgent(Resource): and (source_doc := db.dereference(agent.get("source"))) else "" ), + "sources": [ + ( + str(db.dereference(source_ref)["_id"]) + if isinstance(source_ref, DBRef) and db.dereference(source_ref) + else source_ref + ) + for source_ref in agent.get("sources", []) + if (isinstance(source_ref, DBRef) and db.dereference(source_ref)) + or source_ref == "default" + ], "chunks": agent["chunks"], "retriever": agent.get("retriever", ""), "prompt_id": agent.get("prompt_id", ""), @@ -1325,8 +1459,24 @@ class GetAgents(Resource): str(source_doc["_id"]) if isinstance(agent.get("source"), DBRef) and (source_doc := db.dereference(agent.get("source"))) - else "" + else ( + agent.get("source", "") + if agent.get("source") == "default" + else "" + ) ), + "sources": [ + ( + source_ref + if source_ref == "default" + else str(db.dereference(source_ref)["_id"]) + ) + for source_ref in agent.get("sources", []) + if source_ref == "default" + or ( + isinstance(source_ref, DBRef) and db.dereference(source_ref) + ) + ], "chunks": agent["chunks"], "retriever": agent.get("retriever", ""), "prompt_id": agent.get("prompt_id", ""), @@ -1351,6 +1501,7 @@ class GetAgents(Resource): for agent in agents if "source" in agent or "retriever" in agent ] + except Exception as err: current_app.logger.error(f"Error retrieving agents: {err}", exc_info=True) return make_response(jsonify({"success": False}), 400) @@ -1369,7 +1520,14 @@ class CreateAgent(Resource): "image": fields.Raw( required=False, description="Image file upload", type="file" ), - "source": fields.String(required=True, description="Source ID"), + "source": fields.String( + required=False, description="Source ID (legacy single source)" + ), + "sources": fields.List( + fields.String, + required=False, + description="List of source identifiers for multiple sources", + ), "chunks": fields.Integer(required=True, description="Chunks count"), "retriever": fields.String(required=True, description="Retriever ID"), "prompt_id": fields.String(required=True, description="Prompt ID"), @@ -1381,7 +1539,8 @@ class CreateAgent(Resource): required=True, description="Status of the agent (draft or published)" ), "json_schema": fields.Raw( - required=False, description="JSON schema for enforcing structured output format" + required=False, + description="JSON schema for enforcing structured output format", ), }, ) @@ -1401,13 +1560,18 @@ class CreateAgent(Resource): data["tools"] = json.loads(data["tools"]) except json.JSONDecodeError: data["tools"] = [] + if "sources" in data: + try: + data["sources"] = json.loads(data["sources"]) + except json.JSONDecodeError: + data["sources"] = [] if "json_schema" in data: try: data["json_schema"] = json.loads(data["json_schema"]) except json.JSONDecodeError: data["json_schema"] = None print(f"Received data: {data}") - + # Validate JSON schema if provided if data.get("json_schema"): try: @@ -1415,20 +1579,32 @@ class CreateAgent(Resource): json_schema = data.get("json_schema") if not isinstance(json_schema, dict): return make_response( - jsonify({"success": False, "message": "JSON schema must be a valid JSON object"}), - 400 + jsonify( + { + "success": False, + "message": "JSON schema must be a valid JSON object", + } + ), + 400, ) - + # Validate that it has either a 'schema' property or is itself a schema if "schema" not in json_schema and "type" not in json_schema: return make_response( - jsonify({"success": False, "message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property"}), - 400 + jsonify( + { + "success": False, + "message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property", + } + ), + 400, ) except Exception as e: return make_response( - jsonify({"success": False, "message": f"Invalid JSON schema: {str(e)}"}), - 400 + jsonify( + {"success": False, "message": f"Invalid JSON schema: {str(e)}"} + ), + 400, ) if data.get("status") not in ["draft", "published"]: @@ -1446,12 +1622,22 @@ class CreateAgent(Resource): required_fields = [ "name", "description", - "source", "chunks", "retriever", "prompt_id", "agent_type", ] + # Require either source or sources (but not both) + if not data.get("source") and not data.get("sources"): + return make_response( + jsonify( + { + "success": False, + "message": "Either 'source' or 'sources' field is required for published agents", + } + ), + 400, + ) validate_fields = ["name", "description", "prompt_id", "agent_type"] else: required_fields = ["name"] @@ -1471,16 +1657,31 @@ class CreateAgent(Resource): try: key = str(uuid.uuid4()) if data.get("status") == "published" else "" + + sources_list = [] + if data.get("sources") and len(data.get("sources", [])) > 0: + for source_id in data.get("sources", []): + if source_id == "default": + sources_list.append("default") + elif ObjectId.is_valid(source_id): + sources_list.append(DBRef("sources", ObjectId(source_id))) + source_field = "" + else: + source_value = data.get("source", "") + if source_value == "default": + source_field = "default" + elif ObjectId.is_valid(source_value): + source_field = DBRef("sources", ObjectId(source_value)) + else: + source_field = "" + new_agent = { "user": user, "name": data.get("name"), "description": data.get("description", ""), "image": image_url, - "source": ( - DBRef("sources", ObjectId(data.get("source"))) - if ObjectId.is_valid(data.get("source")) - else "" - ), + "source": source_field, + "sources": sources_list, "chunks": data.get("chunks", ""), "retriever": data.get("retriever", ""), "prompt_id": data.get("prompt_id", ""), @@ -1495,7 +1696,11 @@ class CreateAgent(Resource): } if new_agent["chunks"] == "": new_agent["chunks"] = "0" - if new_agent["source"] == "" and new_agent["retriever"] == "": + if ( + new_agent["source"] == "" + and new_agent["retriever"] == "" + and not new_agent["sources"] + ): new_agent["retriever"] = "classic" resp = agents_collection.insert_one(new_agent) new_id = str(resp.inserted_id) @@ -1517,7 +1722,14 @@ class UpdateAgent(Resource): "image": fields.String( required=False, description="New image URL or identifier" ), - "source": fields.String(required=True, description="Source ID"), + "source": fields.String( + required=False, description="Source ID (legacy single source)" + ), + "sources": fields.List( + fields.String, + required=False, + description="List of source identifiers for multiple sources", + ), "chunks": fields.Integer(required=True, description="Chunks count"), "retriever": fields.String(required=True, description="Retriever ID"), "prompt_id": fields.String(required=True, description="Prompt ID"), @@ -1529,7 +1741,8 @@ class UpdateAgent(Resource): required=True, description="Status of the agent (draft or published)" ), "json_schema": fields.Raw( - required=False, description="JSON schema for enforcing structured output format" + required=False, + description="JSON schema for enforcing structured output format", ), }, ) @@ -1549,6 +1762,11 @@ class UpdateAgent(Resource): data["tools"] = json.loads(data["tools"]) except json.JSONDecodeError: data["tools"] = [] + if "sources" in data: + try: + data["sources"] = json.loads(data["sources"]) + except json.JSONDecodeError: + data["sources"] = [] if "json_schema" in data: try: data["json_schema"] = json.loads(data["json_schema"]) @@ -1593,6 +1811,7 @@ class UpdateAgent(Resource): "description", "image", "source", + "sources", "chunks", "retriever", "prompt_id", @@ -1616,7 +1835,10 @@ class UpdateAgent(Resource): update_fields[field] = new_status elif field == "source": source_id = data.get("source") - if source_id and ObjectId.is_valid(source_id): + if source_id == "default": + # Handle special "default" source + update_fields[field] = "default" + elif source_id and ObjectId.is_valid(source_id): update_fields[field] = DBRef("sources", ObjectId(source_id)) elif source_id: return make_response( @@ -1630,6 +1852,30 @@ class UpdateAgent(Resource): ) else: update_fields[field] = "" + elif field == "sources": + sources_list = data.get("sources", []) + if sources_list and isinstance(sources_list, list): + valid_sources = [] + for source_id in sources_list: + if source_id == "default": + valid_sources.append("default") + elif ObjectId.is_valid(source_id): + valid_sources.append( + DBRef("sources", ObjectId(source_id)) + ) + else: + return make_response( + jsonify( + { + "success": False, + "message": f"Invalid source ID format: {source_id}", + } + ), + 400, + ) + update_fields[field] = valid_sources + else: + update_fields[field] = [] elif field == "chunks": chunks_value = data.get("chunks") if chunks_value == "": @@ -3532,7 +3778,7 @@ class GetChunks(Resource): "page": "Page number for pagination", "per_page": "Number of chunks per page", "path": "Optional: Filter chunks by relative file path", - "search": "Optional: Search term to filter chunks by title or content" + "search": "Optional: Search term to filter chunks by title or content", }, ) def get(self): @@ -3556,7 +3802,7 @@ class GetChunks(Resource): try: store = get_vector_store(doc_id) chunks = store.get_chunks() - + filtered_chunks = [] for chunk in chunks: metadata = chunk.get("metadata", {}) @@ -3577,9 +3823,9 @@ class GetChunks(Resource): continue filtered_chunks.append(chunk) - + chunks = filtered_chunks - + total_chunks = len(chunks) start = (page - 1) * per_page end = start + per_page @@ -3593,7 +3839,7 @@ class GetChunks(Resource): "total": total_chunks, "chunks": paginated_chunks, "path": path if path else None, - "search": search_term if search_term else None + "search": search_term if search_term else None, } ), 200, @@ -3602,6 +3848,7 @@ class GetChunks(Resource): current_app.logger.error(f"Error getting chunks: {e}", exc_info=True) return make_response(jsonify({"success": False}), 500) + @user_ns.route("/api/add_chunk") class AddChunk(Resource): @api.expect( @@ -3768,7 +4015,9 @@ class UpdateChunk(Resource): deleted = store.delete_chunk(chunk_id) if not deleted: - current_app.logger.warning(f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created") + current_app.logger.warning( + f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created" + ) return make_response( jsonify( @@ -3900,39 +4149,38 @@ class DirectoryStructure(Resource): decoded_token = request.decoded_token if not decoded_token: return make_response(jsonify({"success": False}), 401) - + user = decoded_token.get("sub") doc_id = request.args.get("id") - + if not doc_id: - return make_response( - jsonify({"error": "Document ID is required"}), 400 - ) - + return make_response(jsonify({"error": "Document ID is required"}), 400) + if not ObjectId.is_valid(doc_id): return make_response(jsonify({"error": "Invalid document ID"}), 400) - + try: doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user}) if not doc: return make_response( jsonify({"error": "Document not found or access denied"}), 404 ) - + directory_structure = doc.get("directory_structure", {}) - + return make_response( - jsonify({ - "success": True, - "directory_structure": directory_structure, - "base_path": doc.get("file_path", "") - }), 200 + jsonify( + { + "success": True, + "directory_structure": directory_structure, + "base_path": doc.get("file_path", ""), + } + ), + 200, ) - + except Exception as e: current_app.logger.error( f"Error retrieving directory structure: {e}", exc_info=True ) - return make_response( - jsonify({"success": False, "error": str(e)}), 500 - ) + return make_response(jsonify({"success": False, "error": str(e)}), 500) diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 82423bb5..ce1b937b 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -46,7 +46,7 @@ class ClassicRAG(BaseRetriever): user_api_key=self.user_api_key, decoded_token=decoded_token, ) - + if "active_docs" in source and source["active_docs"] is not None: if isinstance(source["active_docs"], list): self.vectorstores = source["active_docs"] @@ -54,7 +54,6 @@ class ClassicRAG(BaseRetriever): self.vectorstores = [source["active_docs"]] else: self.vectorstores = [] - self.question = self._rephrase_query() self.decoded_token = decoded_token self._validate_vectorstore_config() @@ -64,7 +63,6 @@ class ClassicRAG(BaseRetriever): if not self.vectorstores: logging.warning("No vectorstores configured for retrieval") return - invalid_ids = [ vs_id for vs_id in self.vectorstores if not vs_id or not vs_id.strip() ] @@ -84,12 +82,16 @@ class ClassicRAG(BaseRetriever): or not self.vectorstores ): return self.original_question - prompt = f"""Given the following conversation history: + {self.chat_history} + + Rephrase the following user question to be a standalone search query + that captures all relevant context from the conversation: + """ messages = [ @@ -109,7 +111,6 @@ class ClassicRAG(BaseRetriever): """Retrieve relevant documents from configured vectorstores""" if self.chunks == 0 or not self.vectorstores: return [] - all_docs = [] chunks_per_source = max(1, self.chunks // len(self.vectorstores)) @@ -128,7 +129,6 @@ class ClassicRAG(BaseRetriever): else: page_content = doc.get("text", doc.get("page_content", "")) metadata = doc.get("metadata", {}) - title = metadata.get( "title", metadata.get("post_title", page_content) ) @@ -136,7 +136,6 @@ class ClassicRAG(BaseRetriever): title = title.split("/")[-1] else: title = str(title).split("/")[-1] - all_docs.append( { "title": title, @@ -150,7 +149,6 @@ class ClassicRAG(BaseRetriever): exc_info=True, ) continue - return all_docs def search(self, query: str = ""): diff --git a/frontend/src/agents/NewAgent.tsx b/frontend/src/agents/NewAgent.tsx index da8cef5d..f1fc5e50 100644 --- a/frontend/src/agents/NewAgent.tsx +++ b/frontend/src/agents/NewAgent.tsx @@ -45,6 +45,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { description: '', image: '', source: '', + sources: [], chunks: '', retriever: '', prompt_id: 'default', @@ -150,7 +151,41 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { const formData = new FormData(); formData.append('name', agent.name); formData.append('description', agent.description); - formData.append('source', agent.source); + + if (selectedSourceIds.size > 1) { + const sourcesArray = Array.from(selectedSourceIds) + .map((id) => { + const sourceDoc = sourceDocs?.find( + (source) => + source.id === id || source.retriever === id || source.name === id, + ); + if (sourceDoc?.name === 'Default' && !sourceDoc?.id) { + return 'default'; + } + return sourceDoc?.id || id; + }) + .filter(Boolean); + formData.append('sources', JSON.stringify(sourcesArray)); + formData.append('source', ''); + } else if (selectedSourceIds.size === 1) { + const singleSourceId = Array.from(selectedSourceIds)[0]; + const sourceDoc = sourceDocs?.find( + (source) => + source.id === singleSourceId || + source.retriever === singleSourceId || + source.name === singleSourceId, + ); + let finalSourceId; + if (sourceDoc?.name === 'Default' && !sourceDoc?.id) + finalSourceId = 'default'; + else finalSourceId = sourceDoc?.id || singleSourceId; + formData.append('source', String(finalSourceId)); + formData.append('sources', JSON.stringify([])); + } else { + formData.append('source', ''); + formData.append('sources', JSON.stringify([])); + } + formData.append('chunks', agent.chunks); formData.append('retriever', agent.retriever); formData.append('prompt_id', agent.prompt_id); @@ -196,7 +231,41 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { const formData = new FormData(); formData.append('name', agent.name); formData.append('description', agent.description); - formData.append('source', agent.source); + + if (selectedSourceIds.size > 1) { + const sourcesArray = Array.from(selectedSourceIds) + .map((id) => { + const sourceDoc = sourceDocs?.find( + (source) => + source.id === id || source.retriever === id || source.name === id, + ); + if (sourceDoc?.name === 'Default' && !sourceDoc?.id) { + return 'default'; + } + return sourceDoc?.id || id; + }) + .filter(Boolean); + formData.append('sources', JSON.stringify(sourcesArray)); + formData.append('source', ''); + } else if (selectedSourceIds.size === 1) { + const singleSourceId = Array.from(selectedSourceIds)[0]; + const sourceDoc = sourceDocs?.find( + (source) => + source.id === singleSourceId || + source.retriever === singleSourceId || + source.name === singleSourceId, + ); + let finalSourceId; + if (sourceDoc?.name === 'Default' && !sourceDoc?.id) + finalSourceId = 'default'; + else finalSourceId = sourceDoc?.id || singleSourceId; + formData.append('source', String(finalSourceId)); + formData.append('sources', JSON.stringify([])); + } else { + formData.append('source', ''); + formData.append('sources', JSON.stringify([])); + } + formData.append('chunks', agent.chunks); formData.append('retriever', agent.retriever); formData.append('prompt_id', agent.prompt_id); @@ -293,9 +362,33 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { throw new Error('Failed to fetch agent'); } const data = await response.json(); - if (data.source) setSelectedSourceIds(new Set([data.source])); - else if (data.retriever) + + if (data.sources && data.sources.length > 0) { + const mappedSources = data.sources.map((sourceId: string) => { + if (sourceId === 'default') { + const defaultSource = sourceDocs?.find( + (source) => source.name === 'Default', + ); + return defaultSource?.retriever || 'classic'; + } + return sourceId; + }); + setSelectedSourceIds(new Set(mappedSources)); + } else if (data.source) { + if (data.source === 'default') { + const defaultSource = sourceDocs?.find( + (source) => source.name === 'Default', + ); + setSelectedSourceIds( + new Set([defaultSource?.retriever || 'classic']), + ); + } else { + setSelectedSourceIds(new Set([data.source])); + } + } else if (data.retriever) { setSelectedSourceIds(new Set([data.retriever])); + } + if (data.tools) setSelectedToolIds(new Set(data.tools)); if (data.status === 'draft') setEffectiveMode('draft'); if (data.json_schema) { @@ -311,25 +404,57 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { }, [agentId, mode, token]); useEffect(() => { - const selectedSource = Array.from(selectedSourceIds).map((id) => - sourceDocs?.find( - (source) => - source.id === id || source.retriever === id || source.name === id, - ), - ); - if (selectedSource[0]?.model === embeddingsName) { - if (selectedSource[0] && 'id' in selectedSource[0]) { + const selectedSources = Array.from(selectedSourceIds) + .map((id) => + sourceDocs?.find( + (source) => + source.id === id || source.retriever === id || source.name === id, + ), + ) + .filter(Boolean); + + if (selectedSources.length > 0) { + // Handle multiple sources + if (selectedSources.length > 1) { + // Multiple sources selected - store in sources array + const sourceIds = selectedSources + .map((source) => source?.id) + .filter((id): id is string => Boolean(id)); setAgent((prev) => ({ ...prev, - source: selectedSource[0]?.id || 'default', + sources: sourceIds, + source: '', // Clear single source for multiple sources retriever: '', })); - } else - setAgent((prev) => ({ - ...prev, - source: '', - retriever: selectedSource[0]?.retriever || 'classic', - })); + } else { + // Single source selected - maintain backward compatibility + const selectedSource = selectedSources[0]; + if (selectedSource?.model === embeddingsName) { + if (selectedSource && 'id' in selectedSource) { + setAgent((prev) => ({ + ...prev, + source: selectedSource?.id || 'default', + sources: [], // Clear sources array for single source + retriever: '', + })); + } else { + setAgent((prev) => ({ + ...prev, + source: '', + sources: [], // Clear sources array + retriever: selectedSource?.retriever || 'classic', + })); + } + } + } + } else { + // No sources selected + setAgent((prev) => ({ + ...prev, + source: '', + sources: [], + retriever: '', + })); } }, [selectedSourceIds]); @@ -510,7 +635,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { ) .filter(Boolean) .join(', ') - : 'Select source'} + : 'Select sources'} ) => { setSelectedSourceIds(newSelectedIds); - setIsSourcePopupOpen(false); }} - title="Select Source" + title="Select Sources" searchPlaceholder="Search sources..." - noOptionsMessage="No source available" - singleSelect={true} + noOptionsMessage="No sources available" />
diff --git a/frontend/src/agents/types/index.ts b/frontend/src/agents/types/index.ts index e841cb0a..442097a1 100644 --- a/frontend/src/agents/types/index.ts +++ b/frontend/src/agents/types/index.ts @@ -10,6 +10,7 @@ export type Agent = { description: string; image: string; source: string; + sources?: string[]; chunks: string; retriever: string; prompt_id: string;