Merge branch 'main' into feat/agent-refactor-and-logging

This commit is contained in:
Alex
2025-03-05 16:04:09 -05:00
committed by GitHub
28 changed files with 624 additions and 1072 deletions

View File

@@ -116,8 +116,9 @@ def is_azure_configured():
def save_conversation(
conversation_id, question, response, source_log_docs, tool_calls, llm, index=None
conversation_id, question, response, source_log_docs, tool_calls, llm, index=None, api_key=None
):
current_time = datetime.datetime.now(datetime.timezone.utc)
if conversation_id is not None and index is not None:
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
@@ -127,6 +128,7 @@ def save_conversation(
f"queries.{index}.response": response,
f"queries.{index}.sources": source_log_docs,
f"queries.{index}.tool_calls": tool_calls,
f"queries.{index}.timestamp": current_time
}
},
)
@@ -145,6 +147,7 @@ def save_conversation(
"response": response,
"sources": source_log_docs,
"tool_calls": tool_calls,
"timestamp": current_time
}
}
},
@@ -169,21 +172,25 @@ def save_conversation(
]
completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30)
conversation_id = conversations_collection.insert_one(
{
"user": "local",
"date": datetime.datetime.utcnow(),
"name": completion,
"queries": [
{
"prompt": question,
"response": response,
"sources": source_log_docs,
"tool_calls": tool_calls,
}
],
}
).inserted_id
conversation_data = {
"user": "local",
"date": datetime.datetime.utcnow(),
"name": completion,
"queries": [
{
"prompt": question,
"response": response,
"sources": source_log_docs,
"tool_calls": tool_calls,
"timestamp": current_time
}
],
}
if api_key:
api_key_doc = api_key_collection.find_one({"key": api_key})
if api_key_doc:
conversation_data["api_key"] = api_key_doc["key"]
conversation_id = conversations_collection.insert_one(conversation_data).inserted_id
return conversation_id
@@ -198,7 +205,6 @@ def get_prompt(prompt_id):
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
return prompt
def complete_stream(
question,
agent,
@@ -207,8 +213,14 @@ def complete_stream(
user_api_key,
isNoneDoc=False,
index=None,
question,
retriever,
conversation_id,
user_api_key,
isNoneDoc=False,
index=None,
should_save_conversation=True
):
try:
response_full = ""
source_log_docs = []
@@ -239,9 +251,12 @@ def complete_stream(
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
settings.LLM_NAME,
api_key=settings.API_KEY,
user_api_key=user_api_key
)
if user_api_key is None:
if should_save_conversation:
conversation_id = save_conversation(
conversation_id,
question,
@@ -250,10 +265,14 @@ def complete_stream(
tool_calls,
llm,
index,
api_key=user_api_key
)
# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)})
yield f"data: {data}\n\n"
else:
conversation_id = None
# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)})
yield f"data: {data}\n\n"
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
@@ -316,6 +335,9 @@ class Stream(Resource):
"index": fields.Integer(
required=False, description="The position where query is to be updated"
),
"save_conversation": fields.Boolean(
required=False, default=True, description="Flag to save conversation"
),
},
)
@@ -330,6 +352,8 @@ class Stream(Resource):
if missing_fields:
return missing_fields
save_conv = data.get("save_conversation", True)
try:
question = data["question"]
history = limit_chat_history(
@@ -400,6 +424,7 @@ class Stream(Resource):
user_api_key=user_api_key,
isNoneDoc=data.get("isNoneDoc"),
index=index,
should_save_conversation=save_conv,
),
mimetype="text/event-stream",
)