Merge branch 'main' into tool-use

Alex authored on 2024-12-20 17:29:41 +00:00; committed by GitHub
84 changed files with 4317 additions and 5522 deletions


@@ -18,7 +18,7 @@ from application.error import bad_request
from application.extensions import api
from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import check_required_fields
from application.utils import check_required_fields, limit_chat_history
logger = logging.getLogger(__name__)
@@ -118,8 +118,31 @@ def is_azure_configured():
)
def save_conversation(conversation_id, question, response, source_log_docs, llm):
if conversation_id is not None and conversation_id != "None":
def save_conversation(conversation_id, question, response, source_log_docs, llm, index=None):
if conversation_id is not None and index is not None:
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
{
"$set": {
f"queries.{index}.prompt": question,
f"queries.{index}.response": response,
f"queries.{index}.sources": source_log_docs,
}
}
)
# Remove the queries that follow the edited one from the array
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
{
"$push":{
"queries":{
"$each":[],
"$slice":index+1
}
}
}
)
elif conversation_id is not None and conversation_id != "None":
conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
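The index branch above relies on two MongoDB updates: a positional $set that overwrites the edited query in place, and a $push that uses an empty $each with a positive $slice to drop every query stored after it. A minimal standalone sketch of that pattern, assuming a pymongo collection; the database and collection names below are placeholders:

from bson.objectid import ObjectId
from pymongo import MongoClient

# Placeholder connection; the application uses its own conversations_collection.
conversations_collection = MongoClient()["docsgpt"]["conversations"]

def overwrite_query(conversation_id, index, question, response, sources):
    selector = {"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}}
    # Overwrite the query at `index` in place.
    conversations_collection.update_one(
        selector,
        {"$set": {
            f"queries.{index}.prompt": question,
            f"queries.{index}.response": response,
            f"queries.{index}.sources": sources,
        }},
    )
    # $push with an empty $each adds nothing; the positive $slice keeps only the
    # first index + 1 elements, discarding every query after the edited one.
    conversations_collection.update_one(
        selector,
        {"$push": {"queries": {"$each": [], "$slice": index + 1}}},
    )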
@@ -141,17 +164,17 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm)
"role": "assistant",
"content": "Summarise following conversation in no more than 3 "
"words, respond ONLY with the summary, use the same "
"language as the system \n\nUser: "
+ question
+ "\n\n"
+ "AI: "
+ response,
"language as the system",
},
{
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"system",
"system \n\nUser: "
+ question
+ "\n\n"
+ "AI: "
+ response,
},
]
@@ -186,7 +209,7 @@ def get_prompt(prompt_id):
def complete_stream(
question, retriever, conversation_id, user_api_key, isNoneDoc=False
question, retriever, conversation_id, user_api_key, isNoneDoc=False, index=None
):
try:
@@ -217,7 +240,7 @@ def complete_stream(
)
if user_api_key is None:
conversation_id = save_conversation(
conversation_id, question, response_full, source_log_docs, llm
conversation_id, question, response_full, source_log_docs, llm, index
)
# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)})
@@ -282,6 +305,9 @@ class Stream(Resource):
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"index":fields.Integer(
required=False, description="The position where query is to be updated"
),
},
)
@@ -290,23 +316,23 @@ class Stream(Resource):
def post(self):
data = request.get_json()
required_fields = ["question"]
if "index" in data:
required_fields = ["question","conversation_id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
history = data.get("history", [])
history = json.loads(history)
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
index = data.get("index", None)
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
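limit_chat_history is imported from application.utils but its body is not part of this diff; the call site only shows that it takes the parsed history plus a gpt_model keyword. A rough sketch of what such a token-budgeted truncation could look like; the tiktoken-based counting, the prompt/response keys, and the default budget are assumptions, not the project's actual implementation:

import tiktoken  # token counting is an assumption; the real helper may count differently

def limit_chat_history(history, gpt_model="gpt-4o", max_tokens=2000):
    """Keep the most recent turns whose combined size fits a token budget."""
    try:
        encoding = tiktoken.encoding_for_model(gpt_model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    trimmed, used = [], 0
    for turn in reversed(history):  # walk from the newest turn backwards
        cost = len(encoding.encode(str(turn.get("prompt", "")) + str(turn.get("response", ""))))
        if used + cost > max_tokens:
            break
        trimmed.insert(0, turn)
        used += cost
    return trimmed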
@@ -343,7 +369,7 @@ class Stream(Resource):
gpt_model=gpt_model,
user_api_key=user_api_key,
)
return Response(
complete_stream(
question=question,
@@ -351,6 +377,7 @@ class Stream(Resource):
conversation_id=conversation_id,
user_api_key=user_api_key,
isNoneDoc=data.get("isNoneDoc"),
index=index,
),
mimetype="text/event-stream",
)
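With the index field wired through to complete_stream, regenerating an earlier answer amounts to resending the question with the position to overwrite. A hypothetical client call; the URL, port, and example ObjectId are placeholders, and the field names follow the Api model shown above:

import json
import requests

payload = {
    "question": "Can you expand on that last point?",
    "history": json.dumps([]),                      # history travels as a JSON-encoded string
    "conversation_id": "6763f1a2e4b0c2a1d9e87f01",  # placeholder ObjectId of an existing conversation
    "index": 2,                                     # overwrite the third query and drop the ones after it
}
# Placeholder URL; adjust host, port, and route to the running backend.
with requests.post("http://localhost:7091/stream", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if line:
            print(line.decode())  # server-sent events, ending with a {"type": "end"} message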
@@ -428,7 +455,7 @@ class Answer(Resource):
try:
question = data["question"]
history = data.get("history", [])
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
chunks = int(data.get("chunks", 2))


@@ -181,10 +181,12 @@ class SubmitFeedback(Resource):
"FeedbackModel",
{
"question": fields.String(
required=True, description="The user question"
required=False, description="The user question"
),
"answer": fields.String(required=True, description="The AI answer"),
"answer": fields.String(required=False, description="The AI answer"),
"feedback": fields.String(required=True, description="User feedback"),
"question_index":fields.Integer(required=True, description="The question number in that particular conversation"),
"conversation_id":fields.String(required=True, description="id of the particular conversation"),
"api_key": fields.String(description="Optional API key"),
},
)
@@ -194,23 +196,21 @@ class SubmitFeedback(Resource):
)
def post(self):
data = request.get_json()
required_fields = ["question", "answer", "feedback"]
required_fields = [ "feedback","conversation_id","question_index"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
new_doc = {
"question": data["question"],
"answer": data["answer"],
"feedback": data["feedback"],
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
if "api_key" in data:
new_doc["api_key"] = data["api_key"]
try:
feedback_collection.insert_one(new_doc)
conversations_collection.update_one(
{"_id": ObjectId(data["conversation_id"]), f"queries.{data['question_index']}": {"$exists": True}},
{
"$set": {
f"queries.{data['question_index']}.feedback": data["feedback"]
}
}
)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
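Feedback is now also attached to the matching query inside the conversation document, located by conversation_id and question_index. A hypothetical request body under the updated model; the route, the id, and the feedback value are placeholders:

import requests

feedback_payload = {
    "feedback": "LIKE",                                 # required
    "conversation_id": "6763f1a2e4b0c2a1d9e87f01",      # required: which conversation to update
    "question_index": 0,                                # required: zero-based position of the query
    "question": "What does DocsGPT do?",                # now optional
    "answer": "It answers questions about your docs.",  # now optional
}
# Placeholder URL and route; adjust to the running backend.
requests.post("http://localhost:7091/api/feedback", json=feedback_payload, timeout=30)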
@@ -253,13 +253,12 @@ class DeleteOldIndexes(Resource):
jsonify({"success": False, "message": "Missing required fields"}), 400
)
try:
doc = sources_collection.find_one(
doc = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": "local"}
)
if not doc:
)
if not doc:
return make_response(jsonify({"status": "not found"}), 404)
try:
if settings.VECTOR_STORE == "faiss":
shutil.rmtree(os.path.join(current_dir, "indexes", str(doc["_id"])))
else:
@@ -268,12 +267,12 @@ class DeleteOldIndexes(Resource):
)
vectorstore.delete_index()
sources_collection.delete_one({"_id": ObjectId(source_id)})
except FileNotFoundError:
pass
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
sources_collection.delete_one({"_id": ObjectId(source_id)})
return make_response(jsonify({"success": True}), 200)
@@ -344,6 +343,9 @@ class UploadFile(Resource):
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
job_name,
final_filename,
@@ -370,6 +372,9 @@ class UploadFile(Resource):
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
job_name,
final_filename,
@@ -478,11 +483,22 @@ class PaginatedSources(Resource):
sort_order = request.args.get("order", "desc") # Default to 'desc'
page = int(request.args.get("page", 1)) # Default to 1
rows_per_page = int(request.args.get("rows", 10)) # Default to 10
# Strip leading and trailing whitespace from the search term
search_term = request.args.get(
"search", ""
).strip()  # search term used to filter documents
# Prepare
# Prepare query for filtering
query = {"user": user}
if search_term:
query["name"] = {
"$regex": search_term,
"$options": "i", # using case-insensitive search
}
total_documents = sources_collection.count_documents(query)
total_pages = max(1, math.ceil(total_documents / rows_per_page))
page = min(max(1, page), total_pages)  # clamp the requested page so it stays within range
sort_order = 1 if sort_order == "asc" else -1
skip = (page - 1) * rows_per_page
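The clamp added before computing skip keeps an out-of-range page request (or an empty collection) from producing a skip value past the end of the results. A small standalone illustration of the arithmetic as written above:

import math

def clamp_page(page, total_documents, rows_per_page=10):
    total_pages = max(1, math.ceil(total_documents / rows_per_page))
    page = min(max(1, page), total_pages)  # keep the requested page within range
    skip = (page - 1) * rows_per_page
    return page, skip

print(clamp_page(99, total_documents=35))  # -> (4, 30): falls back to the last real page
print(clamp_page(5, total_documents=0))    # -> (1, 0): empty collection still resolves to page 1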
@@ -2082,3 +2098,4 @@ class DeleteTool(Resource):
return {"success": False, "error": str(err)}, 400
return {"success": True}, 200