feat: enhance conversation saving and response streaming with source handling

This commit is contained in:
Siddhant Rai
2025-03-10 14:19:43 +05:30
parent 46d32b4072
commit faa583864d
2 changed files with 88 additions and 34 deletions

View File

@@ -107,6 +107,7 @@ class ClassicAgent(BaseAgent):
if isinstance(line, str):
yield {"answer": line}
yield {"sources": retrieved_data}
yield {"tool_calls": self.tool_calls.copy()}
def _retriever_search(self, retriever, query, log_context):

View File

@@ -116,7 +116,14 @@ def is_azure_configured():
def save_conversation(
conversation_id, question, response, source_log_docs, tool_calls, llm, index=None, api_key=None
conversation_id,
question,
response,
source_log_docs,
tool_calls,
llm,
index=None,
api_key=None,
):
current_time = datetime.datetime.now(datetime.timezone.utc)
if conversation_id is not None and index is not None:
@@ -128,7 +135,7 @@ def save_conversation(
f"queries.{index}.response": response,
f"queries.{index}.sources": source_log_docs,
f"queries.{index}.tool_calls": tool_calls,
f"queries.{index}.timestamp": current_time
f"queries.{index}.timestamp": current_time,
}
},
)
@@ -147,7 +154,7 @@ def save_conversation(
"response": response,
"sources": source_log_docs,
"tool_calls": tool_calls,
"timestamp": current_time
"timestamp": current_time,
}
}
},
@@ -182,7 +189,7 @@ def save_conversation(
"response": response,
"sources": source_log_docs,
"tool_calls": tool_calls,
"timestamp": current_time
"timestamp": current_time,
}
],
}
@@ -190,7 +197,9 @@ def save_conversation(
api_key_doc = api_key_collection.find_one({"key": api_key})
if api_key_doc:
conversation_data["api_key"] = api_key_doc["key"]
conversation_id = conversations_collection.insert_one(conversation_data).inserted_id
conversation_id = conversations_collection.insert_one(
conversation_data
).inserted_id
return conversation_id
@@ -205,36 +214,42 @@ def get_prompt(prompt_id):
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
return prompt
def complete_stream(
question,
question,
agent,
retriever,
conversation_id,
user_api_key,
isNoneDoc=False,
retriever,
conversation_id,
user_api_key,
isNoneDoc=False,
index=None,
should_save_conversation=True
should_save_conversation=True,
):
try:
response_full = ""
source_log_docs = []
tool_calls = []
answer = agent.gen(query=question, retriever=retriever)
sources = retriever.search(question)
for source in sources:
if "text" in source:
source["text"] = source["text"][:100].strip() + "..."
if len(sources) > 0:
data = json.dumps({"type": "source", "source": sources})
yield f"data: {data}\n\n"
for line in answer:
if "answer" in line:
response_full += str(line["answer"])
data = json.dumps(line)
data = json.dumps({"type": "answer", "answer": line["answer"]})
yield f"data: {data}\n\n"
elif "source" in line:
source_log_docs.append(line["source"])
elif "sources" in line:
truncated_sources = []
source_log_docs = line["sources"]
for source in line["sources"]:
truncated_source = source.copy()
if "text" in truncated_source:
truncated_source["text"] = (
truncated_source["text"][:100].strip() + "..."
)
truncated_sources.append(truncated_source)
if len(truncated_sources) > 0:
data = json.dumps({"type": "source", "source": truncated_sources})
yield f"data: {data}\n\n"
elif "tool_calls" in line:
tool_calls = line["tool_calls"]
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
@@ -245,11 +260,9 @@ def complete_stream(
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_NAME,
api_key=settings.API_KEY,
user_api_key=user_api_key
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
)
if should_save_conversation:
conversation_id = save_conversation(
conversation_id,
@@ -259,7 +272,7 @@ def complete_stream(
tool_calls,
llm,
index,
api_key=user_api_key
api_key=user_api_key,
)
else:
conversation_id = None
@@ -523,9 +536,19 @@ class Answer(Resource):
extra={"data": json.dumps({"request_data": data, "source": source})},
)
agent = AgentCreator.create_agent(
settings.AGENT_NAME,
endpoint="api/answer",
llm_name=settings.LLM_NAME,
gpt_model=gpt_model,
api_key=settings.API_KEY,
user_api_key=user_api_key,
prompt=prompt,
chat_history=history,
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
source=source,
chat_history=history,
prompt=prompt,
@@ -538,13 +561,41 @@ class Answer(Resource):
response_full = ""
source_log_docs = []
tool_calls = []
for line in retriever.gen():
if "source" in line:
source_log_docs.append(line["source"])
elif "answer" in line:
response_full += line["answer"]
elif "tool_calls" in line:
tool_calls.append(line["tool_calls"])
stream_ended = False
for line in complete_stream(
question=question,
agent=agent,
retriever=retriever,
conversation_id=conversation_id,
user_api_key=user_api_key,
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=False,
):
try:
event_data = line.replace("data: ", "").strip()
event = json.loads(event_data)
if event["type"] == "answer":
response_full += event["answer"]
elif event["type"] == "source":
source_log_docs = event["source"]
elif event["type"] == "tool_calls":
tool_calls = event["tool_calls"]
elif event["type"] == "error":
logger.error(f"Error from stream: {event['error']}")
return bad_request(500, event["error"])
elif event["type"] == "end":
stream_ended = True
except (json.JSONDecodeError, KeyError) as e:
logger.warning(f"Error parsing stream event: {e}, line: {line}")
continue
if not stream_ended:
logger.error("Stream ended unexpectedly without an 'end' event.")
return bad_request(500, "Stream ended unexpectedly.")
if data.get("isNoneDoc"):
for doc in source_log_docs:
@@ -563,8 +614,10 @@ class Answer(Resource):
source_log_docs,
tool_calls,
llm,
api_key=user_api_key,
)
)
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
{