mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
feat: enhance conversation saving and response streaming with source handling
This commit is contained in:
@@ -107,6 +107,7 @@ class ClassicAgent(BaseAgent):
|
||||
if isinstance(line, str):
|
||||
yield {"answer": line}
|
||||
|
||||
yield {"sources": retrieved_data}
|
||||
yield {"tool_calls": self.tool_calls.copy()}
|
||||
|
||||
def _retriever_search(self, retriever, query, log_context):
|
||||
|
||||
@@ -116,7 +116,14 @@ def is_azure_configured():
|
||||
|
||||
|
||||
def save_conversation(
|
||||
conversation_id, question, response, source_log_docs, tool_calls, llm, index=None, api_key=None
|
||||
conversation_id,
|
||||
question,
|
||||
response,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
index=None,
|
||||
api_key=None,
|
||||
):
|
||||
current_time = datetime.datetime.now(datetime.timezone.utc)
|
||||
if conversation_id is not None and index is not None:
|
||||
@@ -128,7 +135,7 @@ def save_conversation(
|
||||
f"queries.{index}.response": response,
|
||||
f"queries.{index}.sources": source_log_docs,
|
||||
f"queries.{index}.tool_calls": tool_calls,
|
||||
f"queries.{index}.timestamp": current_time
|
||||
f"queries.{index}.timestamp": current_time,
|
||||
}
|
||||
},
|
||||
)
|
||||
@@ -147,7 +154,7 @@ def save_conversation(
|
||||
"response": response,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time
|
||||
"timestamp": current_time,
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -182,7 +189,7 @@ def save_conversation(
|
||||
"response": response,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time
|
||||
"timestamp": current_time,
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -190,7 +197,9 @@ def save_conversation(
|
||||
api_key_doc = api_key_collection.find_one({"key": api_key})
|
||||
if api_key_doc:
|
||||
conversation_data["api_key"] = api_key_doc["key"]
|
||||
conversation_id = conversations_collection.insert_one(conversation_data).inserted_id
|
||||
conversation_id = conversations_collection.insert_one(
|
||||
conversation_data
|
||||
).inserted_id
|
||||
return conversation_id
|
||||
|
||||
|
||||
@@ -205,36 +214,42 @@ def get_prompt(prompt_id):
|
||||
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
|
||||
return prompt
|
||||
|
||||
|
||||
def complete_stream(
|
||||
question,
|
||||
question,
|
||||
agent,
|
||||
retriever,
|
||||
conversation_id,
|
||||
user_api_key,
|
||||
isNoneDoc=False,
|
||||
retriever,
|
||||
conversation_id,
|
||||
user_api_key,
|
||||
isNoneDoc=False,
|
||||
index=None,
|
||||
should_save_conversation=True
|
||||
should_save_conversation=True,
|
||||
):
|
||||
try:
|
||||
response_full = ""
|
||||
source_log_docs = []
|
||||
tool_calls = []
|
||||
|
||||
answer = agent.gen(query=question, retriever=retriever)
|
||||
sources = retriever.search(question)
|
||||
for source in sources:
|
||||
if "text" in source:
|
||||
source["text"] = source["text"][:100].strip() + "..."
|
||||
if len(sources) > 0:
|
||||
data = json.dumps({"type": "source", "source": sources})
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
for line in answer:
|
||||
if "answer" in line:
|
||||
response_full += str(line["answer"])
|
||||
data = json.dumps(line)
|
||||
data = json.dumps({"type": "answer", "answer": line["answer"]})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "source" in line:
|
||||
source_log_docs.append(line["source"])
|
||||
elif "sources" in line:
|
||||
truncated_sources = []
|
||||
source_log_docs = line["sources"]
|
||||
for source in line["sources"]:
|
||||
truncated_source = source.copy()
|
||||
if "text" in truncated_source:
|
||||
truncated_source["text"] = (
|
||||
truncated_source["text"][:100].strip() + "..."
|
||||
)
|
||||
truncated_sources.append(truncated_source)
|
||||
if len(truncated_sources) > 0:
|
||||
data = json.dumps({"type": "source", "source": truncated_sources})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "tool_calls" in line:
|
||||
tool_calls = line["tool_calls"]
|
||||
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
@@ -245,11 +260,9 @@ def complete_stream(
|
||||
doc["source"] = "None"
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
|
||||
)
|
||||
|
||||
|
||||
if should_save_conversation:
|
||||
conversation_id = save_conversation(
|
||||
conversation_id,
|
||||
@@ -259,7 +272,7 @@ def complete_stream(
|
||||
tool_calls,
|
||||
llm,
|
||||
index,
|
||||
api_key=user_api_key
|
||||
api_key=user_api_key,
|
||||
)
|
||||
else:
|
||||
conversation_id = None
|
||||
@@ -523,9 +536,19 @@ class Answer(Resource):
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
agent = AgentCreator.create_agent(
|
||||
settings.AGENT_NAME,
|
||||
endpoint="api/answer",
|
||||
llm_name=settings.LLM_NAME,
|
||||
gpt_model=gpt_model,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
prompt=prompt,
|
||||
chat_history=history,
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
question=question,
|
||||
source=source,
|
||||
chat_history=history,
|
||||
prompt=prompt,
|
||||
@@ -538,13 +561,41 @@ class Answer(Resource):
|
||||
response_full = ""
|
||||
source_log_docs = []
|
||||
tool_calls = []
|
||||
for line in retriever.gen():
|
||||
if "source" in line:
|
||||
source_log_docs.append(line["source"])
|
||||
elif "answer" in line:
|
||||
response_full += line["answer"]
|
||||
elif "tool_calls" in line:
|
||||
tool_calls.append(line["tool_calls"])
|
||||
stream_ended = False
|
||||
|
||||
for line in complete_stream(
|
||||
question=question,
|
||||
agent=agent,
|
||||
retriever=retriever,
|
||||
conversation_id=conversation_id,
|
||||
user_api_key=user_api_key,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=None,
|
||||
should_save_conversation=False,
|
||||
):
|
||||
try:
|
||||
event_data = line.replace("data: ", "").strip()
|
||||
event = json.loads(event_data)
|
||||
|
||||
if event["type"] == "answer":
|
||||
response_full += event["answer"]
|
||||
elif event["type"] == "source":
|
||||
source_log_docs = event["source"]
|
||||
elif event["type"] == "tool_calls":
|
||||
tool_calls = event["tool_calls"]
|
||||
elif event["type"] == "error":
|
||||
logger.error(f"Error from stream: {event['error']}")
|
||||
return bad_request(500, event["error"])
|
||||
elif event["type"] == "end":
|
||||
stream_ended = True
|
||||
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.warning(f"Error parsing stream event: {e}, line: {line}")
|
||||
continue
|
||||
|
||||
if not stream_ended:
|
||||
logger.error("Stream ended unexpectedly without an 'end' event.")
|
||||
return bad_request(500, "Stream ended unexpectedly.")
|
||||
|
||||
if data.get("isNoneDoc"):
|
||||
for doc in source_log_docs:
|
||||
@@ -563,8 +614,10 @@ class Answer(Resource):
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
api_key=user_api_key,
|
||||
)
|
||||
)
|
||||
|
||||
retriever_params = retriever.get_params()
|
||||
user_logs_collection.insert_one(
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user