diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 4c393714..4c86cf4b 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -28,12 +28,15 @@ vectors_collection = db["vectors"] prompts_collection = db["prompts"] answer = Blueprint('answer', __name__) -if settings.LLM_NAME == "gpt4": - gpt_model = 'gpt-4' +gpt_model = "" +# to have some kind of default behaviour +if settings.LLM_NAME == "openai": + gpt_model = 'gpt-3.5-turbo' elif settings.LLM_NAME == "anthropic": gpt_model = 'claude-2' + +if settings.MODEL_NAME: # in case a particular model name is configured + gpt_model = settings.MODEL_NAME # load the prompts current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) diff --git a/application/core/settings.py b/application/core/settings.py index 84073b7d..0e1909e6 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -9,6 +9,7 @@ current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__ class Settings(BaseSettings): LLM_NAME: str = "docsgpt" + MODEL_NAME: Optional[str] = None # when LLM_NAME is openai, MODEL_NAME can be e.g. gpt-4-turbo-preview or gpt-3.5-turbo EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2" CELERY_BROKER_URL: str = "redis://localhost:6379/0" CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"