diff --git a/application/core/settings.py b/application/core/settings.py index 7030022a..d1b0e737 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -16,6 +16,7 @@ class Settings(BaseSettings): None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo ) EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2" + EMBEDDINGS_PATH: Optional[str] = "./models/all-mpnet-base-v2" # Set None for SentenceTransformer to manage download CELERY_BROKER_URL: str = "redis://localhost:6379/0" CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1" MONGO_URI: str = "mongodb://localhost:27017/docsgpt" diff --git a/application/vectorstore/base.py b/application/vectorstore/base.py index a6b206c9..2179a629 100644 --- a/application/vectorstore/base.py +++ b/application/vectorstore/base.py @@ -74,17 +74,11 @@ class BaseVectorStore(ABC): embeddings_name, openai_api_key=embeddings_key ) - elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2": - if os.path.exists("./models/all-mpnet-base-v2"): - embedding_instance = EmbeddingsSingleton.get_instance( - embeddings_name = "./models/all-mpnet-base-v2", - ) - else: - embedding_instance = EmbeddingsSingleton.get_instance( - embeddings_name, - ) else: - embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name) + model_identifier = embeddings_name + if settings.EMBEDDINGS_PATH and os.path.exists(settings.EMBEDDINGS_PATH): + model_identifier = settings.EMBEDDINGS_PATH + embedding_instance = EmbeddingsSingleton.get_instance(model_identifier) return embedding_instance