mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-02-03 04:44:10 +00:00
Merge branch 'main' into tool-use
This commit is contained in:
@@ -18,7 +18,7 @@ from application.error import bad_request
|
||||
from application.extensions import api
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.utils import check_required_fields
|
||||
from application.utils import check_required_fields, limit_chat_history
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -118,8 +118,31 @@ def is_azure_configured():
|
||||
)
|
||||
|
||||
|
||||
def save_conversation(conversation_id, question, response, source_log_docs, llm):
|
||||
if conversation_id is not None and conversation_id != "None":
|
||||
def save_conversation(conversation_id, question, response, source_log_docs, llm,index=None):
|
||||
if conversation_id is not None and index is not None:
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
|
||||
{
|
||||
"$set": {
|
||||
f"queries.{index}.prompt": question,
|
||||
f"queries.{index}.response": response,
|
||||
f"queries.{index}.sources": source_log_docs,
|
||||
}
|
||||
}
|
||||
)
|
||||
##remove following queries from the array
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
|
||||
{
|
||||
"$push":{
|
||||
"queries":{
|
||||
"$each":[],
|
||||
"$slice":index+1
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
elif conversation_id is not None and conversation_id != "None":
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{
|
||||
@@ -141,17 +164,17 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm)
|
||||
"role": "assistant",
|
||||
"content": "Summarise following conversation in no more than 3 "
|
||||
"words, respond ONLY with the summary, use the same "
|
||||
"language as the system \n\nUser: "
|
||||
+ question
|
||||
+ "\n\n"
|
||||
+ "AI: "
|
||||
+ response,
|
||||
"language as the system",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Summarise following conversation in no more than 3 words, "
|
||||
"respond ONLY with the summary, use the same language as the "
|
||||
"system",
|
||||
"system \n\nUser: "
|
||||
+ question
|
||||
+ "\n\n"
|
||||
+ "AI: "
|
||||
+ response,
|
||||
},
|
||||
]
|
||||
|
||||
@@ -186,7 +209,7 @@ def get_prompt(prompt_id):
|
||||
|
||||
|
||||
def complete_stream(
|
||||
question, retriever, conversation_id, user_api_key, isNoneDoc=False
|
||||
question, retriever, conversation_id, user_api_key, isNoneDoc=False,index=None
|
||||
):
|
||||
|
||||
try:
|
||||
@@ -217,7 +240,7 @@ def complete_stream(
|
||||
)
|
||||
if user_api_key is None:
|
||||
conversation_id = save_conversation(
|
||||
conversation_id, question, response_full, source_log_docs, llm
|
||||
conversation_id, question, response_full, source_log_docs, llm,index
|
||||
)
|
||||
# send data.type = "end" to indicate that the stream has ended as json
|
||||
data = json.dumps({"type": "id", "id": str(conversation_id)})
|
||||
@@ -282,6 +305,9 @@ class Stream(Resource):
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
"index":fields.Integer(
|
||||
required=False, description="The position where query is to be updated"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -290,23 +316,23 @@ class Stream(Resource):
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["question"]
|
||||
|
||||
if "index" in data:
|
||||
required_fields = ["question","conversation_id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
history = data.get("history", [])
|
||||
history = json.loads(history)
|
||||
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
|
||||
conversation_id = data.get("conversation_id")
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
|
||||
|
||||
index=data.get("index",None)
|
||||
chunks = int(data.get("chunks", 2))
|
||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
||||
retriever_name = data.get("retriever", "classic")
|
||||
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key.get("chunks", 2))
|
||||
@@ -343,7 +369,7 @@ class Stream(Resource):
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
)
|
||||
|
||||
|
||||
return Response(
|
||||
complete_stream(
|
||||
question=question,
|
||||
@@ -351,6 +377,7 @@ class Stream(Resource):
|
||||
conversation_id=conversation_id,
|
||||
user_api_key=user_api_key,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=index,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
@@ -428,7 +455,7 @@ class Answer(Resource):
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
history = data.get("history", [])
|
||||
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
|
||||
conversation_id = data.get("conversation_id")
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
chunks = int(data.get("chunks", 2))
|
||||
|
||||
@@ -181,10 +181,12 @@ class SubmitFeedback(Resource):
|
||||
"FeedbackModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="The user question"
|
||||
required=False, description="The user question"
|
||||
),
|
||||
"answer": fields.String(required=True, description="The AI answer"),
|
||||
"answer": fields.String(required=False, description="The AI answer"),
|
||||
"feedback": fields.String(required=True, description="User feedback"),
|
||||
"question_index":fields.Integer(required=True, description="The question number in that particular conversation"),
|
||||
"conversation_id":fields.String(required=True, description="id of the particular conversation"),
|
||||
"api_key": fields.String(description="Optional API key"),
|
||||
},
|
||||
)
|
||||
@@ -194,23 +196,21 @@ class SubmitFeedback(Resource):
|
||||
)
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["question", "answer", "feedback"]
|
||||
required_fields = [ "feedback","conversation_id","question_index"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
new_doc = {
|
||||
"question": data["question"],
|
||||
"answer": data["answer"],
|
||||
"feedback": data["feedback"],
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
|
||||
if "api_key" in data:
|
||||
new_doc["api_key"] = data["api_key"]
|
||||
|
||||
try:
|
||||
feedback_collection.insert_one(new_doc)
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(data["conversation_id"]), f"queries.{data['question_index']}": {"$exists": True}},
|
||||
{
|
||||
"$set": {
|
||||
f"queries.{data['question_index']}.feedback": data["feedback"]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as err:
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
|
||||
@@ -253,13 +253,12 @@ class DeleteOldIndexes(Resource):
|
||||
jsonify({"success": False, "message": "Missing required fields"}), 400
|
||||
)
|
||||
|
||||
try:
|
||||
doc = sources_collection.find_one(
|
||||
doc = sources_collection.find_one(
|
||||
{"_id": ObjectId(source_id), "user": "local"}
|
||||
)
|
||||
if not doc:
|
||||
)
|
||||
if not doc:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
|
||||
try:
|
||||
if settings.VECTOR_STORE == "faiss":
|
||||
shutil.rmtree(os.path.join(current_dir, "indexes", str(doc["_id"])))
|
||||
else:
|
||||
@@ -268,12 +267,12 @@ class DeleteOldIndexes(Resource):
|
||||
)
|
||||
vectorstore.delete_index()
|
||||
|
||||
sources_collection.delete_one({"_id": ObjectId(source_id)})
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as err:
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
|
||||
|
||||
sources_collection.delete_one({"_id": ObjectId(source_id)})
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@@ -344,6 +343,9 @@ class UploadFile(Resource):
|
||||
".json",
|
||||
".xlsx",
|
||||
".pptx",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
],
|
||||
job_name,
|
||||
final_filename,
|
||||
@@ -370,6 +372,9 @@ class UploadFile(Resource):
|
||||
".json",
|
||||
".xlsx",
|
||||
".pptx",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
],
|
||||
job_name,
|
||||
final_filename,
|
||||
@@ -478,11 +483,22 @@ class PaginatedSources(Resource):
|
||||
sort_order = request.args.get("order", "desc") # Default to 'desc'
|
||||
page = int(request.args.get("page", 1)) # Default to 1
|
||||
rows_per_page = int(request.args.get("rows", 10)) # Default to 10
|
||||
# add .strip() to remove leading and trailing whitespaces
|
||||
search_term = request.args.get(
|
||||
"search", ""
|
||||
).strip() # add search for filter documents
|
||||
|
||||
# Prepare
|
||||
# Prepare query for filtering
|
||||
query = {"user": user}
|
||||
if search_term:
|
||||
query["name"] = {
|
||||
"$regex": search_term,
|
||||
"$options": "i", # using case-insensitive search
|
||||
}
|
||||
|
||||
total_documents = sources_collection.count_documents(query)
|
||||
total_pages = max(1, math.ceil(total_documents / rows_per_page))
|
||||
page = min(max(1, page), total_pages) # add this to make sure page inbound is within the range
|
||||
sort_order = 1 if sort_order == "asc" else -1
|
||||
skip = (page - 1) * rows_per_page
|
||||
|
||||
@@ -2082,3 +2098,4 @@ class DeleteTool(Resource):
|
||||
return {"success": False, "error": str(err)}, 400
|
||||
|
||||
return {"success": True}, 200
|
||||
|
||||
Reference in New Issue
Block a user