mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
Merge branch 'main' into feat/agent-refactor-and-logging
This commit is contained in:
@@ -116,8 +116,9 @@ def is_azure_configured():
|
||||
|
||||
|
||||
def save_conversation(
|
||||
conversation_id, question, response, source_log_docs, tool_calls, llm, index=None
|
||||
conversation_id, question, response, source_log_docs, tool_calls, llm, index=None, api_key=None
|
||||
):
|
||||
current_time = datetime.datetime.now(datetime.timezone.utc)
|
||||
if conversation_id is not None and index is not None:
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
|
||||
@@ -127,6 +128,7 @@ def save_conversation(
|
||||
f"queries.{index}.response": response,
|
||||
f"queries.{index}.sources": source_log_docs,
|
||||
f"queries.{index}.tool_calls": tool_calls,
|
||||
f"queries.{index}.timestamp": current_time
|
||||
}
|
||||
},
|
||||
)
|
||||
@@ -145,6 +147,7 @@ def save_conversation(
|
||||
"response": response,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -169,21 +172,25 @@ def save_conversation(
|
||||
]
|
||||
|
||||
completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30)
|
||||
conversation_id = conversations_collection.insert_one(
|
||||
{
|
||||
"user": "local",
|
||||
"date": datetime.datetime.utcnow(),
|
||||
"name": completion,
|
||||
"queries": [
|
||||
{
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
],
|
||||
}
|
||||
).inserted_id
|
||||
conversation_data = {
|
||||
"user": "local",
|
||||
"date": datetime.datetime.utcnow(),
|
||||
"name": completion,
|
||||
"queries": [
|
||||
{
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time
|
||||
}
|
||||
],
|
||||
}
|
||||
if api_key:
|
||||
api_key_doc = api_key_collection.find_one({"key": api_key})
|
||||
if api_key_doc:
|
||||
conversation_data["api_key"] = api_key_doc["key"]
|
||||
conversation_id = conversations_collection.insert_one(conversation_data).inserted_id
|
||||
return conversation_id
|
||||
|
||||
|
||||
@@ -198,7 +205,6 @@ def get_prompt(prompt_id):
|
||||
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
|
||||
return prompt
|
||||
|
||||
|
||||
def complete_stream(
|
||||
question,
|
||||
agent,
|
||||
@@ -207,8 +213,14 @@ def complete_stream(
|
||||
user_api_key,
|
||||
isNoneDoc=False,
|
||||
index=None,
|
||||
question,
|
||||
retriever,
|
||||
conversation_id,
|
||||
user_api_key,
|
||||
isNoneDoc=False,
|
||||
index=None,
|
||||
should_save_conversation=True
|
||||
):
|
||||
|
||||
try:
|
||||
response_full = ""
|
||||
source_log_docs = []
|
||||
@@ -239,9 +251,12 @@ def complete_stream(
|
||||
doc["source"] = "None"
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
|
||||
settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key
|
||||
)
|
||||
if user_api_key is None:
|
||||
|
||||
if should_save_conversation:
|
||||
conversation_id = save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
@@ -250,10 +265,14 @@ def complete_stream(
|
||||
tool_calls,
|
||||
llm,
|
||||
index,
|
||||
api_key=user_api_key
|
||||
)
|
||||
# send data.type = "end" to indicate that the stream has ended as json
|
||||
data = json.dumps({"type": "id", "id": str(conversation_id)})
|
||||
yield f"data: {data}\n\n"
|
||||
else:
|
||||
conversation_id = None
|
||||
|
||||
# send data.type = "end" to indicate that the stream has ended as json
|
||||
data = json.dumps({"type": "id", "id": str(conversation_id)})
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
retriever_params = retriever.get_params()
|
||||
user_logs_collection.insert_one(
|
||||
@@ -316,6 +335,9 @@ class Stream(Resource):
|
||||
"index": fields.Integer(
|
||||
required=False, description="The position where query is to be updated"
|
||||
),
|
||||
"save_conversation": fields.Boolean(
|
||||
required=False, default=True, description="Flag to save conversation"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -330,6 +352,8 @@ class Stream(Resource):
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
save_conv = data.get("save_conversation", True)
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
history = limit_chat_history(
|
||||
@@ -400,6 +424,7 @@ class Stream(Resource):
|
||||
user_api_key=user_api_key,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=index,
|
||||
should_save_conversation=save_conv,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user