mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-12-02 18:13:13 +00:00
custom prompts
This commit is contained in:
@@ -25,6 +25,7 @@ mongo = MongoClient(settings.MONGO_URI)
|
||||
db = mongo["docsgpt"]
|
||||
conversations_collection = db["conversations"]
|
||||
vectors_collection = db["vectors"]
|
||||
prompts_collection = db["prompts"]
|
||||
answer = Blueprint('answer', __name__)
|
||||
|
||||
if settings.LLM_NAME == "gpt4":
|
||||
@@ -43,10 +44,10 @@ with open(os.path.join(current_dir, "prompts", "chat_reduce_prompt.txt"), "r") a
|
||||
chat_reduce_template = f.read()
|
||||
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
|
||||
chat_reduce_creative = f.read()
|
||||
chat_combine_creative = f.read()
|
||||
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
|
||||
chat_reduce_strict = f.read()
|
||||
chat_combine_strict = f.read()
|
||||
|
||||
api_key_set = settings.API_KEY is not None
|
||||
embeddings_key_set = settings.EMBEDDINGS_KEY is not None
|
||||
@@ -90,23 +91,6 @@ def get_vectorstore(data):
|
||||
return vectorstore
|
||||
|
||||
|
||||
# def get_docsearch(vectorstore, embeddings_key):
|
||||
# if settings.EMBEDDINGS_NAME == "openai_text-embedding-ada-002":
|
||||
# if is_azure_configured():
|
||||
# os.environ["OPENAI_API_TYPE"] = "azure"
|
||||
# openai_embeddings = OpenAIEmbeddings(model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME)
|
||||
# else:
|
||||
# openai_embeddings = OpenAIEmbeddings(openai_api_key=embeddings_key)
|
||||
# docsearch = FAISS.load_local(vectorstore, openai_embeddings)
|
||||
# elif settings.EMBEDDINGS_NAME == "huggingface_sentence-transformers/all-mpnet-base-v2":
|
||||
# docsearch = FAISS.load_local(vectorstore, HuggingFaceHubEmbeddings())
|
||||
# elif settings.EMBEDDINGS_NAME == "huggingface_hkunlp/instructor-large":
|
||||
# docsearch = FAISS.load_local(vectorstore, HuggingFaceInstructEmbeddings())
|
||||
# elif settings.EMBEDDINGS_NAME == "cohere_medium":
|
||||
# docsearch = FAISS.load_local(vectorstore, CohereEmbeddings(cohere_api_key=embeddings_key))
|
||||
# return docsearch
|
||||
|
||||
|
||||
def is_azure_configured():
|
||||
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
|
||||
|
||||
@@ -115,13 +99,16 @@ def complete_stream(question, docsearch, chat_history, api_key, prompt_id, conve
|
||||
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)
|
||||
|
||||
if prompt_id == 'default':
|
||||
prompt = chat_reduce_template
|
||||
prompt = chat_combine_template
|
||||
elif prompt_id == 'creative':
|
||||
prompt = chat_reduce_creative
|
||||
prompt = chat_combine_creative
|
||||
elif prompt_id == 'strict':
|
||||
prompt = chat_reduce_strict
|
||||
prompt = chat_combine_strict
|
||||
else:
|
||||
prompt = chat_reduce_template
|
||||
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
|
||||
import sys
|
||||
print(prompt_id, file=sys.stderr)
|
||||
print(prompt, file=sys.stderr)
|
||||
|
||||
|
||||
docs = docsearch.search(question, k=2)
|
||||
@@ -253,6 +240,19 @@ def api_answer():
|
||||
embeddings_key = data["embeddings_key"]
|
||||
else:
|
||||
embeddings_key = settings.EMBEDDINGS_KEY
|
||||
if 'prompt_id' in data:
|
||||
prompt_id = data["prompt_id"]
|
||||
else:
|
||||
prompt_id = 'default'
|
||||
|
||||
if prompt_id == 'default':
|
||||
prompt = chat_combine_template
|
||||
elif prompt_id == 'creative':
|
||||
prompt = chat_combine_creative
|
||||
elif prompt_id == 'strict':
|
||||
prompt = chat_combine_strict
|
||||
else:
|
||||
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
|
||||
|
||||
# use try and except to check for exception
|
||||
try:
|
||||
@@ -270,7 +270,7 @@ def api_answer():
|
||||
docs = docsearch.search(question, k=2)
|
||||
# join all page_content together with a newline
|
||||
docs_together = "\n".join([doc.page_content for doc in docs])
|
||||
p_chat_combine = chat_combine_template.replace("{summaries}", docs_together)
|
||||
p_chat_combine = prompt.replace("{summaries}", docs_together)
|
||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||
source_log_docs = []
|
||||
for doc in docs:
|
||||
|
||||
@@ -246,18 +246,21 @@ def check_docs():
|
||||
@user.route("/api/create_prompt", methods=["POST"])
|
||||
def create_prompt():
|
||||
data = request.get_json()
|
||||
prompt = data["prompt"]
|
||||
content = data["content"]
|
||||
name = data["name"]
|
||||
if name == "":
|
||||
return {"status": "error"}
|
||||
user = "local"
|
||||
# write to mongodb
|
||||
prompts_collection.insert_one(
|
||||
resp = prompts_collection.insert_one(
|
||||
{
|
||||
"name": name,
|
||||
"prompt": prompt,
|
||||
"content": content,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
return {"status": "ok"}
|
||||
new_id = str(resp.inserted_id)
|
||||
return {"id": new_id, "name": name, "content": content}
|
||||
|
||||
@user.route("/api/get_prompts", methods=["GET"])
|
||||
def get_prompts():
|
||||
@@ -268,32 +271,51 @@ def get_prompts():
|
||||
list_prompts.append({"id": "creative", "name": "creative", "type": "public"})
|
||||
list_prompts.append({"id": "precise", "name": "precise", "type": "public"})
|
||||
for prompt in prompts:
|
||||
list_prompts.append({"id": str(prompt["_id"]), "name": prompt["name"], type: "private"})
|
||||
list_prompts.append({"id": str(prompt["_id"]), "name": prompt["name"], "type": "private"})
|
||||
|
||||
return jsonify(list_prompts)
|
||||
|
||||
@user.route("/api/get_single_prompt", methods=["GET"])
|
||||
def get_single_prompt():
|
||||
prompt_id = request.args.get("id")
|
||||
if prompt_id == 'default':
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
|
||||
chat_combine_template = f.read()
|
||||
return jsonify({"content": chat_combine_template})
|
||||
elif prompt_id == 'creative':
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
|
||||
chat_reduce_creative = f.read()
|
||||
return jsonify({"content": chat_reduce_creative})
|
||||
elif prompt_id == 'strict':
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
|
||||
chat_reduce_strict = f.read()
|
||||
return jsonify({"content": chat_reduce_strict})
|
||||
|
||||
|
||||
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})
|
||||
return jsonify(prompt['prompt'])
|
||||
return jsonify({"content": prompt["content"]})
|
||||
|
||||
@user.route("/api/delete_prompt", methods=["POST"])
|
||||
def delete_prompt():
|
||||
prompt_id = request.args.get("id")
|
||||
data = request.get_json()
|
||||
id = data["id"]
|
||||
prompts_collection.delete_one(
|
||||
{
|
||||
"_id": ObjectId(prompt_id),
|
||||
"_id": ObjectId(id),
|
||||
}
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
@user.route("/api/update_prompt_name", methods=["POST"])
|
||||
@user.route("/api/update_prompt", methods=["POST"])
|
||||
def update_prompt_name():
|
||||
data = request.get_json()
|
||||
id = data["id"]
|
||||
name = data["name"]
|
||||
prompts_collection.update_one({"_id": ObjectId(id)},{"$set":{"name":name}})
|
||||
content = data["content"]
|
||||
# check if name is null
|
||||
if name == "":
|
||||
return {"status": "error"}
|
||||
prompts_collection.update_one({"_id": ObjectId(id)},{"$set":{"name":name, "content": content}})
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user