mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
better prompts
This commit is contained in:
@@ -1,18 +1,21 @@
|
||||
import os
|
||||
import json
|
||||
import traceback
|
||||
import pprint
|
||||
|
||||
import dotenv
|
||||
import requests
|
||||
from flask import Flask, request, render_template
|
||||
from langchain import FAISS
|
||||
from langchain.llms import OpenAIChat
|
||||
from langchain import VectorDBQA, HuggingFaceHub, Cohere
|
||||
from langchain import VectorDBQA, HuggingFaceHub, Cohere, OpenAI
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceHubEmbeddings, CohereEmbeddings, HuggingFaceInstructEmbeddings
|
||||
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceHubEmbeddings, CohereEmbeddings, \
|
||||
HuggingFaceInstructEmbeddings
|
||||
from langchain.prompts import PromptTemplate
|
||||
from error import bad_request
|
||||
# os.environ["LANGCHAIN_HANDLER"] = "langchain"
|
||||
|
||||
os.environ["LANGCHAIN_HANDLER"] = "langchain"
|
||||
|
||||
if os.getenv("LLM_NAME") is not None:
|
||||
llm_choice = os.getenv("LLM_NAME")
|
||||
@@ -24,8 +27,6 @@ if os.getenv("EMBEDDINGS_NAME") is not None:
|
||||
else:
|
||||
embeddings_choice = "openai_text-embedding-ada-002"
|
||||
|
||||
|
||||
|
||||
if llm_choice == "manifest":
|
||||
from manifest import Manifest
|
||||
from langchain.llms.manifest import ManifestWrapper
|
||||
@@ -53,6 +54,9 @@ with open("combine_prompt.txt", "r") as f:
|
||||
with open("combine_prompt_hist.txt", "r") as f:
|
||||
template_hist = f.read()
|
||||
|
||||
with open("question_prompt.txt", "r") as f:
|
||||
template_quest = f.read()
|
||||
|
||||
if os.getenv("API_KEY") is not None:
|
||||
api_key_set = True
|
||||
else:
|
||||
@@ -76,7 +80,7 @@ def api_answer():
|
||||
data = request.get_json()
|
||||
question = data["question"]
|
||||
history = data["history"]
|
||||
print('-'*5)
|
||||
print('-' * 5)
|
||||
if not api_key_set:
|
||||
api_key = data["api_key"]
|
||||
else:
|
||||
@@ -95,7 +99,7 @@ def api_answer():
|
||||
vectorstore = ""
|
||||
else:
|
||||
vectorstore = ""
|
||||
|
||||
#vectorstore = "outputs/inputs/"
|
||||
# loading the index and the store and the prompt template
|
||||
# Note if you have used other embeddings than OpenAI, you need to change the embeddings
|
||||
if embeddings_choice == "openai_text-embedding-ada-002":
|
||||
@@ -110,13 +114,19 @@ def api_answer():
|
||||
# create a prompt template
|
||||
if history:
|
||||
history = json.loads(history)
|
||||
template_temp = template_hist.replace("{historyquestion}", history[0]).replace("{historyanswer}", history[1])
|
||||
c_prompt = PromptTemplate(input_variables=["summaries", "question"], template=template_temp, template_format="jinja2")
|
||||
template_temp = template_hist.replace("{historyquestion}", history[0]).replace("{historyanswer}",
|
||||
history[1])
|
||||
c_prompt = PromptTemplate(input_variables=["summaries", "question"], template=template_temp,
|
||||
template_format="jinja2")
|
||||
else:
|
||||
c_prompt = PromptTemplate(input_variables=["summaries", "question"], template=template, template_format="jinja2")
|
||||
c_prompt = PromptTemplate(input_variables=["summaries", "question"], template=template,
|
||||
template_format="jinja2")
|
||||
|
||||
q_prompt = PromptTemplate(input_variables=["context", "question"], template=template_quest,
|
||||
template_format="jinja2")
|
||||
if llm_choice == "openai":
|
||||
llm = OpenAIChat(openai_api_key=api_key, temperature=0)
|
||||
#llm = OpenAI(openai_api_key=api_key, temperature=0)
|
||||
elif llm_choice == "manifest":
|
||||
llm = ManifestWrapper(client=manifest, llm_kwargs={"temperature": 0.001, "max_tokens": 2048})
|
||||
elif llm_choice == "huggingface":
|
||||
@@ -125,14 +135,17 @@ def api_answer():
|
||||
llm = Cohere(model="command-xlarge-nightly", cohere_api_key=api_key)
|
||||
|
||||
qa_chain = load_qa_chain(llm=llm, chain_type="map_reduce",
|
||||
combine_prompt=c_prompt)
|
||||
combine_prompt=c_prompt, question_prompt=q_prompt)
|
||||
|
||||
chain = VectorDBQA(combine_documents_chain=qa_chain, vectorstore=docsearch, k=4)
|
||||
chain = VectorDBQA(combine_documents_chain=qa_chain, vectorstore=docsearch, k=25, return_source_documents=True)
|
||||
|
||||
|
||||
# fetch the answer
|
||||
result = chain({"query": question})
|
||||
print(result)
|
||||
# pprint.pprint(result)
|
||||
# docs = docsearch.similarity_search(question, k=8)
|
||||
|
||||
for i in result['source_documents']:
|
||||
print(i.page_content)
|
||||
|
||||
# some formatting for the frontend
|
||||
result['answer'] = result['result']
|
||||
@@ -141,6 +154,7 @@ def api_answer():
|
||||
result['answer'] = result['answer'].split("SOURCES:")[0]
|
||||
except:
|
||||
pass
|
||||
del result['source_documents']
|
||||
|
||||
# mock result
|
||||
# result = {
|
||||
@@ -152,7 +166,7 @@ def api_answer():
|
||||
# print whole traceback
|
||||
traceback.print_exc()
|
||||
print(str(e))
|
||||
return bad_request(500,str(e))
|
||||
return bad_request(500, str(e))
|
||||
|
||||
|
||||
@app.route("/api/docs_check", methods=["POST"])
|
||||
|
||||
Reference in New Issue
Block a user