diff --git a/application/api/internal/routes.py b/application/api/internal/routes.py index c8e32d11..f0ad042f 100755 --- a/application/api/internal/routes.py +++ b/application/api/internal/routes.py @@ -3,10 +3,13 @@ import datetime from flask import Blueprint, request, send_from_directory from werkzeug.utils import secure_filename from bson.objectid import ObjectId - +import logging from application.core.mongo_db import MongoDB from application.core.settings import settings +from application.storage.storage_creator import StorageCreator + +logger = logging.getLogger(__name__) mongo = MongoDB.get_client() db = mongo["docsgpt"] conversations_collection = db["conversations"] @@ -45,26 +48,26 @@ def upload_index_files(): remote_data = request.form["remote_data"] if "remote_data" in request.form else None sync_frequency = secure_filename(request.form["sync_frequency"]) if "sync_frequency" in request.form else None - save_dir = os.path.join(current_dir, "indexes", str(id)) + storage = StorageCreator.get_storage() + index_base_path = f"indexes/{id}" + if settings.VECTOR_STORE == "faiss": if "file_faiss" not in request.files: - print("No file part") + logger.error("No file_faiss part") return {"status": "no file"} file_faiss = request.files["file_faiss"] if file_faiss.filename == "": return {"status": "no file name"} if "file_pkl" not in request.files: - print("No file part") + logger.error("No file_pkl part") return {"status": "no file"} file_pkl = request.files["file_pkl"] if file_pkl.filename == "": return {"status": "no file name"} - # saves index files - - if not os.path.exists(save_dir): - os.makedirs(save_dir) - file_faiss.save(os.path.join(save_dir, "index.faiss")) - file_pkl.save(os.path.join(save_dir, "index.pkl")) + + # Save index files to storage + storage.save_file(file_faiss, f"{index_base_path}/index.faiss") + storage.save_file(file_pkl, f"{index_base_path}/index.pkl") existing_entry = sources_collection.find_one({"_id": ObjectId(id)}) if existing_entry: diff --git a/application/vectorstore/faiss.py b/application/vectorstore/faiss.py index 87ffcccb..ce455bd8 100644 --- a/application/vectorstore/faiss.py +++ b/application/vectorstore/faiss.py @@ -1,17 +1,19 @@ import os +import tempfile from langchain_community.vectorstores import FAISS from application.core.settings import settings from application.parser.schema.base import Document from application.vectorstore.base import BaseVectorStore +from application.storage.storage_creator import StorageCreator def get_vectorstore(path: str) -> str: if path: - vectorstore = os.path.join("application", "indexes", path) + vectorstore = f"indexes/{path}" else: - vectorstore = os.path.join("application") + vectorstore = "indexes" return vectorstore @@ -21,16 +23,36 @@ class FaissStore(BaseVectorStore): self.source_id = source_id self.path = get_vectorstore(source_id) self.embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key) + self.storage = StorageCreator.get_storage() try: if docs_init: self.docsearch = FAISS.from_documents(docs_init, self.embeddings) else: - self.docsearch = FAISS.load_local( - self.path, self.embeddings, allow_dangerous_deserialization=True - ) - except Exception: - raise + with tempfile.TemporaryDirectory() as temp_dir: + faiss_path = f"{self.path}/index.faiss" + pkl_path = f"{self.path}/index.pkl" + + if not self.storage.file_exists(faiss_path) or not self.storage.file_exists(pkl_path): + raise FileNotFoundError(f"Index files not found in storage at {self.path}") + + faiss_file = self.storage.get_file(faiss_path) + pkl_file = self.storage.get_file(pkl_path) + + local_faiss_path = os.path.join(temp_dir, "index.faiss") + local_pkl_path = os.path.join(temp_dir, "index.pkl") + + with open(local_faiss_path, 'wb') as f: + f.write(faiss_file.read()) + + with open(local_pkl_path, 'wb') as f: + f.write(pkl_file.read()) + + self.docsearch = FAISS.load_local( + temp_dir, self.embeddings, allow_dangerous_deserialization=True + ) + except Exception as e: + raise Exception(f"Error loading FAISS index: {str(e)}") self.assert_embedding_dimensions(self.embeddings)