From 7140f2cd70d7b31e8416497fb86e03baadba014f Mon Sep 17 00:00:00 2001 From: Pavel Date: Sat, 21 Jun 2025 16:19:42 +0200 Subject: [PATCH] Backend I have removed the hardcoded part for the all-mpnet-base-v2 and substituted it with a separate EMBEDDINGS_PATH variable. Now the user is able to pre-download the embeddings (similarly to the recomended way with all-mpnet-base-v2) and also leave it to sentence-transformers without needing to touch vectorstore/base.py --- application/core/settings.py | 1 + application/vectorstore/base.py | 14 ++++---------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/application/core/settings.py b/application/core/settings.py index 7030022a..d1b0e737 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -16,6 +16,7 @@ class Settings(BaseSettings): None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo ) EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2" + EMBEDDINGS_PATH: Optional[str] = "./models/all-mpnet-base-v2" # Set None for SentenceTransformer to manage download CELERY_BROKER_URL: str = "redis://localhost:6379/0" CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1" MONGO_URI: str = "mongodb://localhost:27017/docsgpt" diff --git a/application/vectorstore/base.py b/application/vectorstore/base.py index a6b206c9..2179a629 100644 --- a/application/vectorstore/base.py +++ b/application/vectorstore/base.py @@ -74,17 +74,11 @@ class BaseVectorStore(ABC): embeddings_name, openai_api_key=embeddings_key ) - elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2": - if os.path.exists("./models/all-mpnet-base-v2"): - embedding_instance = EmbeddingsSingleton.get_instance( - embeddings_name = "./models/all-mpnet-base-v2", - ) - else: - embedding_instance = EmbeddingsSingleton.get_instance( - embeddings_name, - ) else: - embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name) + model_identifier = embeddings_name + if settings.EMBEDDINGS_PATH and os.path.exists(settings.EMBEDDINGS_PATH): + model_identifier = settings.EMBEDDINGS_PATH + embedding_instance = EmbeddingsSingleton.get_instance(model_identifier) return embedding_instance