diff --git a/application/vectorstore/faiss.py b/application/vectorstore/faiss.py index a8839cd2..ee74b971 100644 --- a/application/vectorstore/faiss.py +++ b/application/vectorstore/faiss.py @@ -3,30 +3,27 @@ from application.vectorstore.base import BaseVectorStore from application.core.settings import settings import os -def get_vectorstore(path): +def get_vectorstore(path: str) -> str: if path: - vectorstore = "indexes/"+path - vectorstore = os.path.join("application", vectorstore) + vectorstore = os.path.join("application", "indexes", path) else: vectorstore = os.path.join("application") - return vectorstore class FaissStore(BaseVectorStore): - - def __init__(self, source_id, embeddings_key, docs_init=None): + def __init__(self, source_id: str, embeddings_key: str, docs_init=None): super().__init__() self.path = get_vectorstore(source_id) embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key) - if docs_init: - self.docsearch = FAISS.from_documents( - docs_init, embeddings - ) - else: - self.docsearch = FAISS.load_local( - self.path, embeddings, - allow_dangerous_deserialization=True - ) + + try: + if docs_init: + self.docsearch = FAISS.from_documents(docs_init, embeddings) + else: + self.docsearch = FAISS.load_local(self.path, embeddings, allow_dangerous_deserialization=True) + except Exception as e: + raise + self.assert_embedding_dimensions(embeddings) def search(self, *args, **kwargs): @@ -42,16 +39,12 @@ class FaissStore(BaseVectorStore): return self.docsearch.delete(*args, **kwargs) def assert_embedding_dimensions(self, embeddings): - """ - Check that the word embedding dimension of the docsearch index matches - the dimension of the word embeddings used - """ + """Check that the word embedding dimension of the docsearch index matches the dimension of the word embeddings used.""" if settings.EMBEDDINGS_NAME == "huggingface_sentence-transformers/all-mpnet-base-v2": - try: - word_embedding_dimension = embeddings.dimension - except AttributeError as e: - raise AttributeError("'dimension' attribute not found in embeddings instance. Make sure the embeddings object is properly initialized.") from e + word_embedding_dimension = getattr(embeddings, 'dimension', None) + if word_embedding_dimension is None: + raise AttributeError("'dimension' attribute not found in embeddings instance.") + docsearch_index_dimension = self.docsearch.index.d if word_embedding_dimension != docsearch_index_dimension: - raise ValueError(f"Embedding dimension mismatch: embeddings.dimension ({word_embedding_dimension}) " + - f"!= docsearch index dimension ({docsearch_index_dimension})") \ No newline at end of file + raise ValueError(f"Embedding dimension mismatch: embeddings.dimension ({word_embedding_dimension}) != docsearch index dimension ({docsearch_index_dimension})")