better prompts

This commit is contained in:
Alex
2023-03-03 17:48:37 +00:00
parent 0fd39dd91c
commit 17047b6201
4 changed files with 47 additions and 25 deletions

View File

@@ -1,18 +1,21 @@
import os
import json
import traceback
import pprint
import dotenv
import requests
from flask import Flask, request, render_template
from langchain import FAISS
from langchain.llms import OpenAIChat
from langchain import VectorDBQA, HuggingFaceHub, Cohere
from langchain import VectorDBQA, HuggingFaceHub, Cohere, OpenAI
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceHubEmbeddings, CohereEmbeddings, HuggingFaceInstructEmbeddings
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceHubEmbeddings, CohereEmbeddings, \
HuggingFaceInstructEmbeddings
from langchain.prompts import PromptTemplate
from error import bad_request
# os.environ["LANGCHAIN_HANDLER"] = "langchain"
os.environ["LANGCHAIN_HANDLER"] = "langchain"
if os.getenv("LLM_NAME") is not None:
llm_choice = os.getenv("LLM_NAME")
@@ -24,8 +27,6 @@ if os.getenv("EMBEDDINGS_NAME") is not None:
else:
embeddings_choice = "openai_text-embedding-ada-002"
if llm_choice == "manifest":
from manifest import Manifest
from langchain.llms.manifest import ManifestWrapper
@@ -53,6 +54,9 @@ with open("combine_prompt.txt", "r") as f:
with open("combine_prompt_hist.txt", "r") as f:
template_hist = f.read()
with open("question_prompt.txt", "r") as f:
template_quest = f.read()
if os.getenv("API_KEY") is not None:
api_key_set = True
else:
@@ -76,7 +80,7 @@ def api_answer():
data = request.get_json()
question = data["question"]
history = data["history"]
print('-'*5)
print('-' * 5)
if not api_key_set:
api_key = data["api_key"]
else:
@@ -95,7 +99,7 @@ def api_answer():
vectorstore = ""
else:
vectorstore = ""
#vectorstore = "outputs/inputs/"
# loading the index and the store and the prompt template
# Note if you have used other embeddings than OpenAI, you need to change the embeddings
if embeddings_choice == "openai_text-embedding-ada-002":
@@ -110,13 +114,19 @@ def api_answer():
# create a prompt template
if history:
history = json.loads(history)
template_temp = template_hist.replace("{historyquestion}", history[0]).replace("{historyanswer}", history[1])
c_prompt = PromptTemplate(input_variables=["summaries", "question"], template=template_temp, template_format="jinja2")
template_temp = template_hist.replace("{historyquestion}", history[0]).replace("{historyanswer}",
history[1])
c_prompt = PromptTemplate(input_variables=["summaries", "question"], template=template_temp,
template_format="jinja2")
else:
c_prompt = PromptTemplate(input_variables=["summaries", "question"], template=template, template_format="jinja2")
c_prompt = PromptTemplate(input_variables=["summaries", "question"], template=template,
template_format="jinja2")
q_prompt = PromptTemplate(input_variables=["context", "question"], template=template_quest,
template_format="jinja2")
if llm_choice == "openai":
llm = OpenAIChat(openai_api_key=api_key, temperature=0)
#llm = OpenAI(openai_api_key=api_key, temperature=0)
elif llm_choice == "manifest":
llm = ManifestWrapper(client=manifest, llm_kwargs={"temperature": 0.001, "max_tokens": 2048})
elif llm_choice == "huggingface":
@@ -125,14 +135,17 @@ def api_answer():
llm = Cohere(model="command-xlarge-nightly", cohere_api_key=api_key)
qa_chain = load_qa_chain(llm=llm, chain_type="map_reduce",
combine_prompt=c_prompt)
combine_prompt=c_prompt, question_prompt=q_prompt)
chain = VectorDBQA(combine_documents_chain=qa_chain, vectorstore=docsearch, k=4)
chain = VectorDBQA(combine_documents_chain=qa_chain, vectorstore=docsearch, k=25, return_source_documents=True)
# fetch the answer
result = chain({"query": question})
print(result)
# pprint.pprint(result)
# docs = docsearch.similarity_search(question, k=8)
for i in result['source_documents']:
print(i.page_content)
# some formatting for the frontend
result['answer'] = result['result']
@@ -141,6 +154,7 @@ def api_answer():
result['answer'] = result['answer'].split("SOURCES:")[0]
except:
pass
del result['source_documents']
# mock result
# result = {
@@ -152,7 +166,7 @@ def api_answer():
# print whole traceback
traceback.print_exc()
print(str(e))
return bad_request(500,str(e))
return bad_request(500, str(e))
@app.route("/api/docs_check", methods=["POST"])