diff --git a/application/vectorstore/faiss.py b/application/vectorstore/faiss.py index 2c1fcb93..b9c63cc8 100644 --- a/application/vectorstore/faiss.py +++ b/application/vectorstore/faiss.py @@ -1,5 +1,6 @@ import os import tempfile +import io from langchain_community.vectorstores import FAISS @@ -66,8 +67,26 @@ class FaissStore(BaseVectorStore): def add_texts(self, *args, **kwargs): return self.docsearch.add_texts(*args, **kwargs) - def save_local(self, *args, **kwargs): - return self.docsearch.save_local(*args, **kwargs) + def save_local(self, path): + """ + Save the FAISS index to disk and upload to storage. + + Args: + path: Path where the index should be stored + """ + with tempfile.TemporaryDirectory() as temp_dir: + self.docsearch.save_local(temp_dir) + + with open(os.path.join(temp_dir, "index.faiss"), "rb") as f_faiss: + faiss_data = f_faiss.read() + + with open(os.path.join(temp_dir, "index.pkl"), "rb") as f_pkl: + pkl_data = f_pkl.read() + + self.storage.save_file(io.BytesIO(faiss_data), f"{path}/index.faiss") + self.storage.save_file(io.BytesIO(pkl_data), f"{path}/index.pkl") + + return True def delete_index(self, *args, **kwargs): return self.docsearch.delete(*args, **kwargs)