Merge pull request #1183 from Devparihar5/main

Refactor FaissStore to enhance error handling, improve type hints, and document methods for better maintainability and usability
This commit is contained in:
Alex
2024-10-02 12:17:54 +01:00
committed by GitHub

View File

@@ -3,30 +3,27 @@ from application.vectorstore.base import BaseVectorStore
from application.core.settings import settings from application.core.settings import settings
import os import os
def get_vectorstore(path): def get_vectorstore(path: str) -> str:
if path: if path:
vectorstore = "indexes/"+path vectorstore = os.path.join("application", "indexes", path)
vectorstore = os.path.join("application", vectorstore)
else: else:
vectorstore = os.path.join("application") vectorstore = os.path.join("application")
return vectorstore return vectorstore
class FaissStore(BaseVectorStore): class FaissStore(BaseVectorStore):
def __init__(self, source_id: str, embeddings_key: str, docs_init=None):
def __init__(self, source_id, embeddings_key, docs_init=None):
super().__init__() super().__init__()
self.path = get_vectorstore(source_id) self.path = get_vectorstore(source_id)
embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key) embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key)
try:
if docs_init: if docs_init:
self.docsearch = FAISS.from_documents( self.docsearch = FAISS.from_documents(docs_init, embeddings)
docs_init, embeddings
)
else: else:
self.docsearch = FAISS.load_local( self.docsearch = FAISS.load_local(self.path, embeddings, allow_dangerous_deserialization=True)
self.path, embeddings, except Exception:
allow_dangerous_deserialization=True raise # Just re-raise the exception without assigning to e
)
self.assert_embedding_dimensions(embeddings) self.assert_embedding_dimensions(embeddings)
def search(self, *args, **kwargs): def search(self, *args, **kwargs):
@@ -42,16 +39,12 @@ class FaissStore(BaseVectorStore):
return self.docsearch.delete(*args, **kwargs) return self.docsearch.delete(*args, **kwargs)
def assert_embedding_dimensions(self, embeddings): def assert_embedding_dimensions(self, embeddings):
""" """Check that the word embedding dimension of the docsearch index matches the dimension of the word embeddings used."""
Check that the word embedding dimension of the docsearch index matches
the dimension of the word embeddings used
"""
if settings.EMBEDDINGS_NAME == "huggingface_sentence-transformers/all-mpnet-base-v2": if settings.EMBEDDINGS_NAME == "huggingface_sentence-transformers/all-mpnet-base-v2":
try: word_embedding_dimension = getattr(embeddings, 'dimension', None)
word_embedding_dimension = embeddings.dimension if word_embedding_dimension is None:
except AttributeError as e: raise AttributeError("'dimension' attribute not found in embeddings instance.")
raise AttributeError("'dimension' attribute not found in embeddings instance. Make sure the embeddings object is properly initialized.") from e
docsearch_index_dimension = self.docsearch.index.d docsearch_index_dimension = self.docsearch.index.d
if word_embedding_dimension != docsearch_index_dimension: if word_embedding_dimension != docsearch_index_dimension:
raise ValueError(f"Embedding dimension mismatch: embeddings.dimension ({word_embedding_dimension}) " + raise ValueError(f"Embedding dimension mismatch: embeddings.dimension ({word_embedding_dimension}) != docsearch index dimension ({docsearch_index_dimension})")
f"!= docsearch index dimension ({docsearch_index_dimension})")