code readability, formatting, minor version bump

This commit is contained in:
Anton Larin
2023-06-17 12:40:28 +02:00
parent 8bee47dc50
commit 968849e52b
2 changed files with 13 additions and 2 deletions

View File

@@ -4,6 +4,7 @@ import http.client
import json
import os
import traceback
import logging
import dotenv
import openai
@@ -40,6 +41,8 @@ from worker import ingest_worker
# os.environ["LANGCHAIN_HANDLER"] = "langchain"
logger = logging.getLogger(__name__)
if settings.LLM_NAME == "manifest":
from manifest import Manifest
from langchain.llms.manifest import ManifestWrapper
@@ -176,7 +179,7 @@ def complete_stream(question, docsearch, chat_history, api_key):
messages_combine.append({"role": "user", "content": question})
completion = openai.ChatCompletion.create(model="gpt-3.5-turbo",
messages=messages_combine, stream=True, max_tokens=500, temperature=0)
for line in completion:
if "content" in line["choices"][0]["delta"]:
# check if the delta contains content
@@ -217,6 +220,10 @@ def stream():
)
def is_azure_configured():
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
@app.route("/api/answer", methods=["POST"])
def api_answer():
data = request.get_json()
@@ -244,7 +251,8 @@ def api_answer():
input_variables=["context", "question"], template=template_quest, template_format="jinja2"
)
if settings.LLM_NAME == "openai_chat":
if settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME: # azure
if is_azure_configured():
logger.debug("in Azure")
llm = AzureChatOpenAI(
openai_api_key=api_key,
openai_api_base=settings.OPENAI_API_BASE,
@@ -252,6 +260,7 @@ def api_answer():
deployment_name=settings.AZURE_DEPLOYMENT_NAME,
)
else:
logger.debug("plain OpenAI")
llm = ChatOpenAI(openai_api_key=api_key) # optional parameter: model_name="gpt-4"
messages_combine = [SystemMessagePromptTemplate.from_template(chat_combine_template)]
if history: