From faa583864d1f1783aaf72c219c20365f486fce32 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Mon, 10 Mar 2025 14:19:43 +0530 Subject: [PATCH] feat: enhance conversation saving and response streaming with source handling --- application/agents/classic_agent.py | 1 + application/api/answer/routes.py | 121 ++++++++++++++++++++-------- 2 files changed, 88 insertions(+), 34 deletions(-) diff --git a/application/agents/classic_agent.py b/application/agents/classic_agent.py index 79d9e37f..8848c6f6 100644 --- a/application/agents/classic_agent.py +++ b/application/agents/classic_agent.py @@ -107,6 +107,7 @@ class ClassicAgent(BaseAgent): if isinstance(line, str): yield {"answer": line} + yield {"sources": retrieved_data} yield {"tool_calls": self.tool_calls.copy()} def _retriever_search(self, retriever, query, log_context): diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index c8c9708f..5a221e8d 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -116,7 +116,14 @@ def is_azure_configured(): def save_conversation( - conversation_id, question, response, source_log_docs, tool_calls, llm, index=None, api_key=None + conversation_id, + question, + response, + source_log_docs, + tool_calls, + llm, + index=None, + api_key=None, ): current_time = datetime.datetime.now(datetime.timezone.utc) if conversation_id is not None and index is not None: @@ -128,7 +135,7 @@ def save_conversation( f"queries.{index}.response": response, f"queries.{index}.sources": source_log_docs, f"queries.{index}.tool_calls": tool_calls, - f"queries.{index}.timestamp": current_time + f"queries.{index}.timestamp": current_time, } }, ) @@ -147,7 +154,7 @@ def save_conversation( "response": response, "sources": source_log_docs, "tool_calls": tool_calls, - "timestamp": current_time + "timestamp": current_time, } } }, @@ -182,7 +189,7 @@ def save_conversation( "response": response, "sources": source_log_docs, "tool_calls": tool_calls, - "timestamp": current_time + "timestamp": current_time, } ], } @@ -190,7 +197,9 @@ def save_conversation( api_key_doc = api_key_collection.find_one({"key": api_key}) if api_key_doc: conversation_data["api_key"] = api_key_doc["key"] - conversation_id = conversations_collection.insert_one(conversation_data).inserted_id + conversation_id = conversations_collection.insert_one( + conversation_data + ).inserted_id return conversation_id @@ -205,36 +214,42 @@ def get_prompt(prompt_id): prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"] return prompt + def complete_stream( - question, + question, agent, - retriever, - conversation_id, - user_api_key, - isNoneDoc=False, + retriever, + conversation_id, + user_api_key, + isNoneDoc=False, index=None, - should_save_conversation=True + should_save_conversation=True, ): try: response_full = "" source_log_docs = [] tool_calls = [] + answer = agent.gen(query=question, retriever=retriever) - sources = retriever.search(question) - for source in sources: - if "text" in source: - source["text"] = source["text"][:100].strip() + "..." - if len(sources) > 0: - data = json.dumps({"type": "source", "source": sources}) - yield f"data: {data}\n\n" for line in answer: if "answer" in line: response_full += str(line["answer"]) - data = json.dumps(line) + data = json.dumps({"type": "answer", "answer": line["answer"]}) yield f"data: {data}\n\n" - elif "source" in line: - source_log_docs.append(line["source"]) + elif "sources" in line: + truncated_sources = [] + source_log_docs = line["sources"] + for source in line["sources"]: + truncated_source = source.copy() + if "text" in truncated_source: + truncated_source["text"] = ( + truncated_source["text"][:100].strip() + "..." + ) + truncated_sources.append(truncated_source) + if len(truncated_sources) > 0: + data = json.dumps({"type": "source", "source": truncated_sources}) + yield f"data: {data}\n\n" elif "tool_calls" in line: tool_calls = line["tool_calls"] data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls}) @@ -245,11 +260,9 @@ def complete_stream( doc["source"] = "None" llm = LLMCreator.create_llm( - settings.LLM_NAME, - api_key=settings.API_KEY, - user_api_key=user_api_key + settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key ) - + if should_save_conversation: conversation_id = save_conversation( conversation_id, @@ -259,7 +272,7 @@ def complete_stream( tool_calls, llm, index, - api_key=user_api_key + api_key=user_api_key, ) else: conversation_id = None @@ -523,9 +536,19 @@ class Answer(Resource): extra={"data": json.dumps({"request_data": data, "source": source})}, ) + agent = AgentCreator.create_agent( + settings.AGENT_NAME, + endpoint="api/answer", + llm_name=settings.LLM_NAME, + gpt_model=gpt_model, + api_key=settings.API_KEY, + user_api_key=user_api_key, + prompt=prompt, + chat_history=history, + ) + retriever = RetrieverCreator.create_retriever( retriever_name, - question=question, source=source, chat_history=history, prompt=prompt, @@ -538,13 +561,41 @@ class Answer(Resource): response_full = "" source_log_docs = [] tool_calls = [] - for line in retriever.gen(): - if "source" in line: - source_log_docs.append(line["source"]) - elif "answer" in line: - response_full += line["answer"] - elif "tool_calls" in line: - tool_calls.append(line["tool_calls"]) + stream_ended = False + + for line in complete_stream( + question=question, + agent=agent, + retriever=retriever, + conversation_id=conversation_id, + user_api_key=user_api_key, + isNoneDoc=data.get("isNoneDoc"), + index=None, + should_save_conversation=False, + ): + try: + event_data = line.replace("data: ", "").strip() + event = json.loads(event_data) + + if event["type"] == "answer": + response_full += event["answer"] + elif event["type"] == "source": + source_log_docs = event["source"] + elif event["type"] == "tool_calls": + tool_calls = event["tool_calls"] + elif event["type"] == "error": + logger.error(f"Error from stream: {event['error']}") + return bad_request(500, event["error"]) + elif event["type"] == "end": + stream_ended = True + + except (json.JSONDecodeError, KeyError) as e: + logger.warning(f"Error parsing stream event: {e}, line: {line}") + continue + + if not stream_ended: + logger.error("Stream ended unexpectedly without an 'end' event.") + return bad_request(500, "Stream ended unexpectedly.") if data.get("isNoneDoc"): for doc in source_log_docs: @@ -563,8 +614,10 @@ class Answer(Resource): source_log_docs, tool_calls, llm, + api_key=user_api_key, ) ) + retriever_params = retriever.get_params() user_logs_collection.insert_one( {