chat prompts

This commit is contained in:
Alex
2023-03-08 17:50:07 +00:00
parent 0799728000
commit 6d959051e2
6 changed files with 53 additions and 18 deletions

View File

@@ -1,25 +1,31 @@
import os
import json
import os
import traceback
import dotenv
import requests
from flask import Flask, request, render_template
from langchain import FAISS
from langchain.llms import OpenAIChat
from langchain import VectorDBQA, HuggingFaceHub, Cohere, OpenAI
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceHubEmbeddings, CohereEmbeddings, \
HuggingFaceInstructEmbeddings
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from error import bad_request
os.environ["LANGCHAIN_HANDLER"] = "langchain"
# os.environ["LANGCHAIN_HANDLER"] = "langchain"
if os.getenv("LLM_NAME") is not None:
llm_choice = os.getenv("LLM_NAME")
else:
llm_choice = "openai"
llm_choice = "openai_chat"
if os.getenv("EMBEDDINGS_NAME") is not None:
embeddings_choice = os.getenv("EMBEDDINGS_NAME")
@@ -47,15 +53,21 @@ if platform.system() == "Windows":
# loading the .env file
dotenv.load_dotenv()
with open("combine_prompt.txt", "r") as f:
with open("prompts/combine_prompt.txt", "r") as f:
template = f.read()
with open("combine_prompt_hist.txt", "r") as f:
with open("prompts/combine_prompt_hist.txt", "r") as f:
template_hist = f.read()
with open("question_prompt.txt", "r") as f:
with open("prompts/question_prompt.txt", "r") as f:
template_quest = f.read()
with open("prompts/chat_combine_prompt.txt", "r") as f:
chat_combine_template = f.read()
with open("prompts/chat_reduce_prompt.txt", "r") as f:
chat_reduce_template = f.read()
if os.getenv("API_KEY") is not None:
api_key_set = True
else:
@@ -98,7 +110,7 @@ def api_answer():
vectorstore = ""
else:
vectorstore = ""
#vectorstore = "outputs/inputs/"
# vectorstore = "outputs/inputs/"
# loading the index and the store and the prompt template
# Note if you have used other embeddings than OpenAI, you need to change the embeddings
if embeddings_choice == "openai_text-embedding-ada-002":
@@ -123,9 +135,20 @@ def api_answer():
q_prompt = PromptTemplate(input_variables=["context", "question"], template=template_quest,
template_format="jinja2")
if llm_choice == "openai":
llm = OpenAIChat(openai_api_key=api_key, temperature=0)
#llm = OpenAI(openai_api_key=api_key, temperature=0)
if llm_choice == "openai_chat":
llm = ChatOpenAI(openai_api_key=api_key)
messages_combine = [
SystemMessagePromptTemplate.from_template(chat_combine_template),
HumanMessagePromptTemplate.from_template("{question}")
]
p_chat_combine = ChatPromptTemplate.from_messages(messages_combine)
messages_reduce = [
SystemMessagePromptTemplate.from_template(chat_reduce_template),
HumanMessagePromptTemplate.from_template("{question}")
]
p_chat_reduce = ChatPromptTemplate.from_messages(messages_reduce)
elif llm_choice == "openai":
llm = OpenAI(openai_api_key=api_key, temperature=0)
elif llm_choice == "manifest":
llm = ManifestWrapper(client=manifest, llm_kwargs={"temperature": 0.001, "max_tokens": 2048})
elif llm_choice == "huggingface":
@@ -133,13 +156,19 @@ def api_answer():
elif llm_choice == "cohere":
llm = Cohere(model="command-xlarge-nightly", cohere_api_key=api_key)
qa_chain = load_qa_chain(llm=llm, chain_type="map_reduce",
combine_prompt=c_prompt, question_prompt=q_prompt)
if llm_choice == "openai_chat":
chain = VectorDBQA.from_chain_type(llm=llm, chain_type="map_reduce", vectorstore=docsearch,
k=4,
chain_type_kwargs={"question_prompt": p_chat_reduce,
"combine_prompt": p_chat_combine})
result = chain({"query": question})
else:
qa_chain = load_qa_chain(llm=llm, chain_type="map_reduce",
combine_prompt=c_prompt, question_prompt=q_prompt)
chain = VectorDBQA(combine_documents_chain=qa_chain, vectorstore=docsearch, k=4)
result = chain({"query": question})
chain = VectorDBQA(combine_documents_chain=qa_chain, vectorstore=docsearch, k=4)
# fetch the answer
result = chain({"query": question})
print(result)
# some formatting for the frontend
result['answer'] = result['result']
@@ -215,7 +244,6 @@ def api_feedback():
return {"status": 'ok'}
# handling CORS
@app.after_request
def after_request(response):