From c9d24b8f42053ceebf9d42a9da7d8233a60a53cf Mon Sep 17 00:00:00 2001 From: Serj Date: Sat, 29 Apr 2023 15:44:47 +0100 Subject: [PATCH] Added llm model variable --- application/app.py | 22 +++++++++------------- application/core/settings.py | 1 + 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/application/app.py b/application/app.py index 5defcc7d..8419cf52 100644 --- a/application/app.py +++ b/application/app.py @@ -28,21 +28,17 @@ from werkzeug.utils import secure_filename from error import bad_request from worker import ingest_worker +from core.settings import settings import celeryconfig # os.environ["LANGCHAIN_HANDLER"] = "langchain" -if os.getenv("LLM_NAME") is not None: - llm_choice = os.getenv("LLM_NAME") -else: - llm_choice = "openai_chat" - if os.getenv("EMBEDDINGS_NAME") is not None: embeddings_choice = os.getenv("EMBEDDINGS_NAME") else: embeddings_choice = "openai_text-embedding-ada-002" -if llm_choice == "manifest": +if settings.LLM_NAME == "manifest": from manifest import Manifest from langchain.llms.manifest import ManifestWrapper @@ -122,7 +118,7 @@ def ingest(self, directory, formats, name_job, filename, user): @app.route("/") def home(): - return render_template("index.html", api_key_set=api_key_set, llm_choice=llm_choice, + return render_template("index.html", api_key_set=api_key_set, llm_choice=settings.LLM_NAME, embeddings_choice=embeddings_choice) @@ -182,7 +178,7 @@ def api_answer(): q_prompt = PromptTemplate(input_variables=["context", "question"], template=template_quest, template_format="jinja2") - if llm_choice == "openai_chat": + if settings.LLM_NAME == "openai_chat": # llm = ChatOpenAI(openai_api_key=api_key, model_name="gpt-4") llm = ChatOpenAI(openai_api_key=api_key) messages_combine = [ @@ -195,16 +191,16 @@ def api_answer(): HumanMessagePromptTemplate.from_template("{question}") ] p_chat_reduce = ChatPromptTemplate.from_messages(messages_reduce) - elif llm_choice == "openai": + elif settings.LLM_NAME == "openai": llm = OpenAI(openai_api_key=api_key, temperature=0) - elif llm_choice == "manifest": + elif settings.LLM_NAME == "manifest": llm = ManifestWrapper(client=manifest, llm_kwargs={"temperature": 0.001, "max_tokens": 2048}) - elif llm_choice == "huggingface": + elif settings.LLM_NAME == "huggingface": llm = HuggingFaceHub(repo_id="bigscience/bloom", huggingfacehub_api_token=api_key) - elif llm_choice == "cohere": + elif settings.LLM_NAME == "cohere": llm = Cohere(model="command-xlarge-nightly", cohere_api_key=api_key) - if llm_choice == "openai_chat": + if settings.LLM_NAME == "openai_chat": question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT) doc_chain = load_qa_chain(llm, chain_type="map_reduce", combine_prompt=p_chat_combine) chain = ConversationalRetrievalChain( diff --git a/application/core/settings.py b/application/core/settings.py index 416b903d..bb5063f5 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -3,6 +3,7 @@ from pathlib import Path class Settings(BaseSettings): + LLM_NAME: str = "openai_chat" openai_token: str