Merge branch 'main' into 1059-migrating-database-to-new-model

This commit is contained in:
Alex
2024-09-09 23:55:25 +01:00
64 changed files with 3517 additions and 4971 deletions

View File

@@ -1,7 +1,7 @@
import asyncio
import os
import sys
from flask import Blueprint, request, Response
from flask import Blueprint, request, Response, current_app
import json
import datetime
import logging
@@ -126,7 +126,11 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm)
"role": "assistant",
"content": "Summarise following conversation in no more than 3 "
"words, respond ONLY with the summary, use the same "
"language as the system \n\nUser: " + question + "\n\n" + "AI: " + response,
"language as the system \n\nUser: "
+ question
+ "\n\n"
+ "AI: "
+ response,
},
{
"role": "user",
@@ -166,7 +170,10 @@ def get_prompt(prompt_id):
return prompt
def complete_stream(question, retriever, conversation_id, user_api_key):
def complete_stream(
question, retriever, conversation_id, user_api_key, isNoneDoc=False
):
try:
response_full = ""
source_log_docs = []
@@ -179,9 +186,17 @@ def complete_stream(question, retriever, conversation_id, user_api_key):
elif "source" in line:
source_log_docs.append(line["source"])
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key)
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
)
if user_api_key is None:
conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm)
conversation_id = save_conversation(
conversation_id, question, response_full, source_log_docs, llm
)
# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)})
yield f"data: {data}\n\n"
@@ -205,7 +220,6 @@ def complete_stream(question, retriever, conversation_id, user_api_key):
def stream():
try:
data = request.get_json()
# get parameter from url question
question = data["question"]
if "history" not in data:
history = []
@@ -252,10 +266,9 @@ def stream():
source = {}
user_api_key = None
""" if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
retriever_name = "classic"
else:
retriever_name = source["active_docs"] """
current_app.logger.info(f"/stream - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})}
)
prompt = get_prompt(prompt_id)
@@ -277,20 +290,23 @@ def stream():
retriever=retriever,
conversation_id=conversation_id,
user_api_key=user_api_key,
isNoneDoc=data.get("isNoneDoc"),
),
mimetype="text/event-stream",
)
except ValueError as err:
except ValueError:
message = "Malformed request body"
print("\033[91merr", str(err), file=sys.stderr)
print("\033[91merr", str(message), file=sys.stderr)
return Response(
error_stream_generate(message),
status=400,
mimetype="text/event-stream",
)
except Exception as e:
print("\033[91merr", str(e), file=sys.stderr)
current_app.logger.error(f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()}
)
message = e.args[0]
status_code = 400
# # Custom exceptions with two arguments, index 1 as status code
@@ -357,6 +373,10 @@ def api_answer():
prompt = get_prompt(prompt_id)
current_app.logger.info(f"/api/answer - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})}
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
@@ -376,7 +396,13 @@ def api_answer():
elif "answer" in line:
response_full += line["answer"]
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key)
if data.get("isNoneDoc"):
for doc in source_log_docs:
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
)
result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = str(
@@ -385,16 +411,15 @@ def api_answer():
return result
except Exception as e:
# print whole traceback
traceback.print_exc()
print(str(e))
current_app.logger.error(f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()}
)
return bad_request(500, str(e))
@answer.route("/api/search", methods=["POST"])
def api_search():
data = request.get_json()
# get parameter from url question
question = data["question"]
if "chunks" in data:
chunks = int(data["chunks"])
@@ -420,6 +445,10 @@ def api_search():
token_limit = data["token_limit"]
else:
token_limit = settings.DEFAULT_MAX_HISTORY
current_app.logger.info(f"/api/answer - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})}
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
@@ -433,4 +462,9 @@ def api_search():
user_api_key=user_api_key,
)
docs = retriever.search()
if data.get("isNoneDoc"):
for doc in docs:
doc["source"] = "None"
return docs