Files
DocsGPT/application/api/answer/routes/base.py

264 lines
9.8 KiB
Python

import datetime
import json
import logging
from typing import Any, Dict, Generator, List, Optional
from flask import Response
from flask_restx import Namespace
from application.api.answer.services.conversation_service import ConversationService
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.utils import check_required_fields, get_gpt_model
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# flask-restx namespace named "answer", declared with the path "/".
answer_ns = Namespace("answer", description="Answer related operations", path="/")
class BaseAnswerResource:
    """Shared base class for answer endpoints.

    Bundles the pieces the answer endpoints have in common: request
    validation, the server-sent-event (SSE) generator that streams an agent's
    output, and a helper that folds such a stream back into plain values for
    non-streaming responses.
    """

    def __init__(self):
        """Set up the user-log collection, model name and conversation service."""
        mongo = MongoDB.get_client()
        db = mongo[settings.MONGO_DB_NAME]
        self.user_logs_collection = db["user_logs"]
        self.gpt_model = get_gpt_model()
        self.conversation_service = ConversationService()

    def validate_request(
        self, data: Dict[str, Any], require_conversation_id: bool = False
    ) -> Optional[Response]:
        """Check that the request payload carries the mandatory fields.

        Args:
            data: Request payload.
            require_conversation_id: When true, "conversation_id" is required
                in addition to "question".

        Returns:
            The truthy result of ``check_required_fields`` when something is
            missing (presumably an error response -- confirm against
            ``application.utils``), otherwise ``None``.
        """
        required_fields = ["question"]
        if require_conversation_id:
            required_fields.append("conversation_id")
        if missing_fields := check_required_fields(data, required_fields):
            return missing_fields
        return None

    def complete_stream(
        self,
        question: str,
        agent: Any,
        retriever: Any,
        conversation_id: Optional[str],
        user_api_key: Optional[str],
        decoded_token: Dict[str, Any],
        isNoneDoc: bool = False,
        index: Optional[int] = None,
        should_save_conversation: bool = True,
        attachment_ids: Optional[List[str]] = None,
        agent_id: Optional[str] = None,
        is_shared_usage: bool = False,
        shared_token: Optional[str] = None,
    ) -> Generator[str, None, None]:
        """Stream the complete conversation response as server-sent events.

        Chunks produced by ``agent.gen`` are forwarded as ``data: <json>``
        events. Once the agent is exhausted the conversation is optionally
        persisted, an "id" event is emitted, the exchange is written to the
        user-log collection and a final "end" event closes the stream.

        Args:
            question: The user's question
            agent: The agent instance
            retriever: The retriever instance
            conversation_id: Existing conversation ID
            user_api_key: User's API key if any
            decoded_token: Decoded JWT token
            isNoneDoc: Flag for document-less responses
            index: Index of message to update
            should_save_conversation: Whether to persist the conversation
            attachment_ids: List of attachment IDs
            agent_id: ID of agent used
            is_shared_usage: Flag for shared agent usage
            shared_token: Token for shared agent

        Yields:
            Server-sent event strings. Any exception raised while streaming is
            logged and reported to the client as a single generic "error"
            event; no "end" event follows in that case.
        """
        try:
            response_full, thought, source_log_docs, tool_calls = "", "", [], []
            is_structured = False
            schema_info = None
            structured_chunks = []

            for line in agent.gen(query=question, retriever=retriever):
                if "answer" in line:
                    response_full += str(line["answer"])
                    if line.get("structured"):
                        # Structured chunks are held back and sent as one
                        # "structured_answer" event after the loop.
                        is_structured = True
                        schema_info = line.get("schema")
                        structured_chunks.append(line["answer"])
                    else:
                        data = json.dumps({"type": "answer", "answer": line["answer"]})
                        yield f"data: {data}\n\n"
                elif "sources" in line:
                    truncated_sources = []
                    # Full sources are kept for persistence/logging; the
                    # client only receives a shortened preview of each text.
                    source_log_docs = line["sources"]
                    for source in line["sources"]:
                        truncated_source = source.copy()
                        if "text" in truncated_source:
                            truncated_source["text"] = (
                                truncated_source["text"][:100].strip() + "..."
                            )
                        truncated_sources.append(truncated_source)
                    if truncated_sources:
                        data = json.dumps(
                            {"type": "source", "source": truncated_sources}
                        )
                        yield f"data: {data}\n\n"
                elif "tool_calls" in line:
                    # NOTE(review): tool calls are recorded for persistence but
                    # not forwarded as an SSE event, so consumers of this
                    # stream never see them -- confirm this is intended.
                    tool_calls = line["tool_calls"]
                elif "thought" in line:
                    thought += line["thought"]
                    data = json.dumps({"type": "thought", "thought": line["thought"]})
                    yield f"data: {data}\n\n"
                elif "type" in line:
                    # Any other typed chunk is passed through unchanged.
                    data = json.dumps(line)
                    yield f"data: {data}\n\n"

            if is_structured and structured_chunks:
                structured_data = {
                    "type": "structured_answer",
                    "answer": response_full,
                    "structured": True,
                    "schema": schema_info,
                }
                data = json.dumps(structured_data)
                yield f"data: {data}\n\n"

            if isNoneDoc:
                for doc in source_log_docs:
                    doc["source"] = "None"

            llm = LLMCreator.create_llm(
                settings.LLM_PROVIDER,
                api_key=settings.API_KEY,
                user_api_key=user_api_key,
                decoded_token=decoded_token,
            )

            if should_save_conversation:
                conversation_id = self.conversation_service.save_conversation(
                    conversation_id,
                    question,
                    response_full,
                    thought,
                    source_log_docs,
                    tool_calls,
                    llm,
                    self.gpt_model,
                    decoded_token,
                    index=index,
                    api_key=user_api_key,
                    agent_id=agent_id,
                    is_shared_usage=is_shared_usage,
                    shared_token=shared_token,
                    attachment_ids=attachment_ids,
                )
            else:
                conversation_id = None

            # The id is stringified, so an unsaved conversation is reported
            # as the literal string "None".
            id_data = {"type": "id", "id": str(conversation_id)}
            data = json.dumps(id_data)
            yield f"data: {data}\n\n"

            retriever_params = retriever.get_params()
            log_data = {
                "action": "stream_answer",
                "level": "info",
                "user": decoded_token.get("sub"),
                "api_key": user_api_key,
                "question": question,
                "response": response_full,
                "sources": source_log_docs,
                "retriever_params": retriever_params,
                "attachments": attachment_ids,
                "timestamp": datetime.datetime.now(datetime.timezone.utc),
            }
            if is_structured:
                log_data["structured_output"] = True
                if schema_info:
                    log_data["schema"] = schema_info

            # clean up text fields to be no longer than 10000 characters
            # (only existing keys are reassigned, so iterating is safe)
            for key, value in log_data.items():
                if isinstance(value, str) and len(value) > 10000:
                    log_data[key] = value[:10000]

            self.user_logs_collection.insert_one(log_data)

            # End of stream
            data = json.dumps({"type": "end"})
            yield f"data: {data}\n\n"
        except Exception as e:
            # Top-level boundary of the generator: log the details, but show
            # the client only a generic message.
            logger.error(f"Error in stream: {str(e)}", exc_info=True)
            data = json.dumps(
                {
                    "type": "error",
                    "error": "Please try again later. We apologize for any inconvenience.",
                }
            )
            yield f"data: {data}\n\n"

    def process_response_stream(self, stream):
        """Process the stream response for non-streaming endpoint.

        Consumes SSE strings of the form ``data: <json>`` (as produced by
        ``complete_stream``) and folds them into plain values.

        Args:
            stream: Iterable of SSE strings.

        Returns:
            On success a tuple ``(conversation_id, response, sources,
            tool_calls, thought, None)``; when a "structured_answer" event was
            seen, a seventh item ``{"structured": True, "schema": ...}`` is
            appended. On failure (an "error" event, or the stream finishing
            without an "end" event) the tuple
            ``(None, None, None, None, None, error_message)``, i.e. the error
            sits in the same slot as in the success shape.
        """
        sse_prefix = "data: "
        conversation_id = ""
        response_full = ""
        source_log_docs = []
        tool_calls = []
        thought = ""
        stream_ended = False
        is_structured = False
        schema_info = None

        for line in stream:
            try:
                # Strip only the leading SSE prefix: the JSON payload itself
                # may legitimately contain the text "data: ".
                event_data = line.strip()
                if event_data.startswith(sse_prefix):
                    event_data = event_data[len(sse_prefix):]
                event = json.loads(event_data)

                event_type = event["type"]
                if event_type == "id":
                    conversation_id = event["id"]
                elif event_type == "answer":
                    response_full += event["answer"]
                elif event_type == "structured_answer":
                    # Carries the whole answer, so it replaces any partial text.
                    response_full = event["answer"]
                    is_structured = True
                    schema_info = event.get("schema")
                elif event_type == "source":
                    source_log_docs = event["source"]
                elif event_type == "tool_calls":
                    tool_calls = event["tool_calls"]
                elif event_type == "thought":
                    # Thoughts arrive in chunks; accumulate them like answers.
                    thought += event["thought"]
                elif event_type == "error":
                    logger.error(f"Error from stream: {event['error']}")
                    return None, None, None, None, None, event["error"]
                elif event_type == "end":
                    stream_ended = True
            except (json.JSONDecodeError, KeyError) as e:
                # Malformed events are skipped rather than aborting the run.
                logger.warning(f"Error parsing stream event: {e}, line: {line}")
                continue

        if not stream_ended:
            logger.error("Stream ended unexpectedly without an 'end' event.")
            return None, None, None, None, None, "Stream ended unexpectedly"

        result = (
            conversation_id,
            response_full,
            source_log_docs,
            tool_calls,
            thought,
            None,
        )
        if is_structured:
            result = result + ({"structured": True, "schema": schema_info},)
        return result

    def error_stream_generate(self, err_response):
        """Yield a single SSE "error" event carrying ``err_response``."""
        data = json.dumps({"type": "error", "error": err_response})
        yield f"data: {data}\n\n"