From 006897f1c0571454351ddffb3da542d3f7ddfe7c Mon Sep 17 00:00:00 2001 From: Anton Larin Date: Sat, 17 Jun 2023 13:20:29 +0200 Subject: [PATCH] Azure support for streaming output. --- application/.env_sample | 3 ++- application/app.py | 19 +++++++++++++++++-- application/core/settings.py | 3 ++- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/application/.env_sample b/application/.env_sample index 24c6495f..a9d58627 100644 --- a/application/.env_sample +++ b/application/.env_sample @@ -8,4 +8,5 @@ API_URL=http://localhost:5001 #For OPENAI on Azure OPENAI_API_BASE= OPENAI_API_VERSION= -AZURE_DEPLOYMENT_NAME= \ No newline at end of file +AZURE_DEPLOYMENT_NAME= +AZURE_EMBEDDINGS_DEPLOYMENT_NAME= \ No newline at end of file diff --git a/application/app.py b/application/app.py index 93055342..b987d4ab 100644 --- a/application/app.py +++ b/application/app.py @@ -127,7 +127,12 @@ def get_vectorstore(data): def get_docsearch(vectorstore, embeddings_key): if settings.EMBEDDINGS_NAME == "openai_text-embedding-ada-002": - docsearch = FAISS.load_local(vectorstore, OpenAIEmbeddings(openai_api_key=embeddings_key)) + if is_azure_configured(): + os.environ["OPENAI_API_TYPE"] = "azure" + openai_embeddings = OpenAIEmbeddings(model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME) + else: + openai_embeddings = OpenAIEmbeddings(openai_api_key=embeddings_key) + docsearch = FAISS.load_local(vectorstore, openai_embeddings) elif settings.EMBEDDINGS_NAME == "huggingface_sentence-transformers/all-mpnet-base-v2": docsearch = FAISS.load_local(vectorstore, HuggingFaceHubEmbeddings()) elif settings.EMBEDDINGS_NAME == "huggingface_hkunlp/instructor-large": @@ -152,7 +157,17 @@ def home(): def complete_stream(question, docsearch, chat_history, api_key): openai.api_key = api_key - llm = ChatOpenAI(openai_api_key=api_key) + if is_azure_configured(): + logger.debug("in Azure") + llm = AzureChatOpenAI( + openai_api_key=api_key, + openai_api_base=settings.OPENAI_API_BASE, + openai_api_version=settings.OPENAI_API_VERSION, + deployment_name=settings.AZURE_DEPLOYMENT_NAME, + ) + else: + logger.debug("plain OpenAI") + llm = ChatOpenAI(openai_api_key=api_key) docs = docsearch.similarity_search(question, k=2) # join all page_content together with a newline docs_together = "\n".join([doc.page_content for doc in docs]) diff --git a/application/core/settings.py b/application/core/settings.py index ed621bb8..853f1526 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -18,7 +18,8 @@ class Settings(BaseSettings): EMBEDDINGS_KEY: str = None # api key for embeddings (if using openai, just copy API_KEY OPENAI_API_BASE: str = None # azure openai api base url OPENAI_API_VERSION: str = None # azure openai api version - AZURE_DEPLOYMENT_NAME: str = None # azure deployment name + AZURE_DEPLOYMENT_NAME: str = None # azure deployment name for answering + AZURE_EMBEDDINGS_DEPLOYMENT_NAME: str = None # azure deployment name for embeddings path = Path(__file__).parent.parent.absolute()