This commit is contained in:
Alex
2023-11-14 01:16:06 +00:00
parent a3de360878
commit 0974085c6f
13 changed files with 345 additions and 198 deletions

View File

@@ -36,21 +36,18 @@ else:
# load the prompts
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
with open(os.path.join(current_dir, "prompts", "combine_prompt.txt"), "r") as f:
template = f.read()
with open(os.path.join(current_dir, "prompts", "combine_prompt_hist.txt"), "r") as f:
template_hist = f.read()
with open(os.path.join(current_dir, "prompts", "question_prompt.txt"), "r") as f:
template_quest = f.read()
with open(os.path.join(current_dir, "prompts", "chat_combine_prompt.txt"), "r") as f:
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
chat_combine_template = f.read()
with open(os.path.join(current_dir, "prompts", "chat_reduce_prompt.txt"), "r") as f:
chat_reduce_template = f.read()
with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
chat_reduce_creative = f.read()
with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
chat_reduce_strict = f.read()
api_key_set = settings.API_KEY is not None
embeddings_key_set = settings.EMBEDDINGS_KEY is not None
@@ -115,8 +112,17 @@ def is_azure_configured():
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
def complete_stream(question, docsearch, chat_history, api_key, conversation_id):
def complete_stream(question, docsearch, chat_history, api_key, prompt_id, conversation_id):
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)
if prompt_id == 'default':
prompt = chat_reduce_template
elif prompt_id == 'creative':
prompt = chat_reduce_creative
elif prompt_id == 'strict':
prompt = chat_reduce_strict
else:
prompt = chat_reduce_template
docs = docsearch.search(question, k=2)
@@ -124,7 +130,7 @@ def complete_stream(question, docsearch, chat_history, api_key, conversation_id)
docs = [docs[0]]
# join all page_content together with a newline
docs_together = "\n".join([doc.page_content for doc in docs])
p_chat_combine = chat_combine_template.replace("{summaries}", docs_together)
p_chat_combine = prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
source_log_docs = []
for doc in docs:
@@ -201,6 +207,10 @@ def stream():
# history to json object from string
history = json.loads(history)
conversation_id = data["conversation_id"]
if 'prompt_id' in data:
prompt_id = data["prompt_id"]
else:
prompt_id = 'default'
# check if active_docs is set
@@ -221,6 +231,7 @@ def stream():
return Response(
complete_stream(question, docsearch,
chat_history=history, api_key=api_key,
prompt_id=prompt_id,
conversation_id=conversation_id), mimetype="text/event-stream"
)

View File

@@ -16,6 +16,7 @@ mongo = MongoClient(settings.MONGO_URI)
db = mongo["docsgpt"]
conversations_collection = db["conversations"]
vectors_collection = db["vectors"]
prompts_collection = db["prompts"]
user = Blueprint('user', __name__)
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -188,7 +189,7 @@ def combined_json():
"date": "default",
"docLink": "default",
"model": settings.EMBEDDINGS_NAME,
"location": "local",
"location": "remote",
}
]
# structure: name, language, version, description, fullName, date, docLink
@@ -245,6 +246,59 @@ def check_docs():
return {"status": "loaded"}
@user.route("/api/create_prompt", methods=["POST"])
def create_prompt():
data = request.get_json()
prompt = data["prompt"]
name = data["name"]
user = "local"
# write to mongodb
prompts_collection.insert_one(
{
"name": name,
"prompt": prompt,
"user": user,
}
)
return {"status": "ok"}
@user.route("/api/get_prompts", methods=["GET"])
def get_prompts():
user = "local"
prompts = prompts_collection.find({"user": user})
list_prompts = []
list_prompts.append({"id": "default", "name": "default", "type": "public"})
list_prompts.append({"id": "creative", "name": "creative", "type": "public"})
list_prompts.append({"id": "precise", "name": "precise", "type": "public"})
for prompt in prompts:
list_prompts.append({"id": str(prompt["_id"]), "name": prompt["name"], type: "private"})
return jsonify(list_prompts)
@user.route("/api/get_single_prompt", methods=["GET"])
def get_single_prompt():
prompt_id = request.args.get("id")
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})
return jsonify(prompt['prompt'])
@user.route("/api/delete_prompt", methods=["POST"])
def delete_prompt():
prompt_id = request.args.get("id")
prompts_collection.delete_one(
{
"_id": ObjectId(prompt_id),
}
)
return {"status": "ok"}
@user.route("/api/update_prompt_name", methods=["POST"])
def update_prompt_name():
data = request.get_json()
id = data["id"]
name = data["name"]
prompts_collection.update_one({"_id": ObjectId(id)},{"$set":{"name":name}})
return {"status": "ok"}