Files
DocsGPT/application/api/answer/services/stream_processor.py

354 lines
14 KiB
Python

import datetime
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional
from bson.dbref import DBRef
from bson.objectid import ObjectId
from application.agents.agent_creator import AgentCreator
from application.api.answer.services.conversation_service import ConversationService
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import get_gpt_model, limit_chat_history
logger = logging.getLogger(__name__)
def get_prompt(prompt_id: str, prompts_collection=None) -> str:
"""
Get a prompt by preset name or MongoDB ID
"""
current_dir = Path(__file__).resolve().parents[3]
prompts_dir = current_dir / "prompts"
preset_mapping = {
"default": "chat_combine_default.txt",
"creative": "chat_combine_creative.txt",
"strict": "chat_combine_strict.txt",
"reduce": "chat_reduce_prompt.txt",
}
if prompt_id in preset_mapping:
file_path = os.path.join(prompts_dir, preset_mapping[prompt_id])
try:
with open(file_path, "r") as f:
return f.read()
except FileNotFoundError:
raise FileNotFoundError(f"Prompt file not found: {file_path}")
try:
if prompts_collection is None:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
prompts_collection = db["prompts"]
prompt_doc = prompts_collection.find_one({"_id": ObjectId(prompt_id)})
if not prompt_doc:
raise ValueError(f"Prompt with ID {prompt_id} not found")
return prompt_doc["content"]
except Exception as e:
raise ValueError(f"Invalid prompt ID: {prompt_id}") from e
class StreamProcessor:
def __init__(
self, request_data: Dict[str, Any], decoded_token: Optional[Dict[str, Any]]
):
mongo = MongoDB.get_client()
self.db = mongo[settings.MONGO_DB_NAME]
self.agents_collection = self.db["agents"]
self.attachments_collection = self.db["attachments"]
self.prompts_collection = self.db["prompts"]
self.data = request_data
self.decoded_token = decoded_token
self.initial_user_id = (
self.decoded_token.get("sub") if self.decoded_token is not None else None
)
self.conversation_id = self.data.get("conversation_id")
self.source = {}
self.all_sources = []
self.attachments = []
self.history = []
self.agent_config = {}
self.retriever_config = {}
self.is_shared_usage = False
self.shared_token = None
self.gpt_model = get_gpt_model()
self.conversation_service = ConversationService()
def initialize(self):
"""Initialize all required components for processing"""
self._configure_agent()
self._configure_source()
self._configure_retriever()
self._configure_agent()
self._load_conversation_history()
self._process_attachments()
def _load_conversation_history(self):
"""Load conversation history either from DB or request"""
if self.conversation_id and self.initial_user_id:
conversation = self.conversation_service.get_conversation(
self.conversation_id, self.initial_user_id
)
if not conversation:
raise ValueError("Conversation not found or unauthorized")
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
else:
self.history = limit_chat_history(
json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model
)
def _process_attachments(self):
"""Process any attachments in the request"""
attachment_ids = self.data.get("attachments", [])
self.attachments = self._get_attachments_content(
attachment_ids, self.initial_user_id
)
def _get_attachments_content(self, attachment_ids, user_id):
"""
Retrieve content from attachment documents based on their IDs.
"""
if not attachment_ids:
return []
attachments = []
for attachment_id in attachment_ids:
try:
attachment_doc = self.attachments_collection.find_one(
{"_id": ObjectId(attachment_id), "user": user_id}
)
if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
logger.error(
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
)
return attachments
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
"""Get API key for agent with access control"""
if not agent_id:
return None, False, None
try:
agent = self.agents_collection.find_one({"_id": ObjectId(agent_id)})
if agent is None:
raise Exception("Agent not found")
is_owner = agent.get("user") == user_id
is_shared_with_user = agent.get(
"shared_publicly", False
) or user_id in agent.get("shared_with", [])
if not (is_owner or is_shared_with_user):
raise Exception("Unauthorized access to the agent")
if is_owner:
self.agents_collection.update_one(
{"_id": ObjectId(agent_id)},
{
"$set": {
"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)
}
},
)
return str(agent["key"]), not is_owner, agent.get("shared_token")
except Exception as e:
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
raise
def _get_data_from_api_key(self, api_key: str) -> Dict[str, Any]:
data = self.agents_collection.find_one({"key": api_key})
if not data:
raise Exception("Invalid API Key, please generate a new key", 401)
source = data.get("source")
if isinstance(source, DBRef):
source_doc = self.db.dereference(source)
if source_doc:
data["source"] = str(source_doc["_id"])
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
else:
data["source"] = None
elif source == "default":
data["source"] = "default"
else:
data["source"] = None
# Handle multiple sources
sources = data.get("sources", [])
if sources and isinstance(sources, list):
sources_list = []
for i, source_ref in enumerate(sources):
if source_ref == "default":
processed_source = {
"id": "default",
"retriever": "classic",
"chunks": data.get("chunks", "2"),
}
sources_list.append(processed_source)
elif isinstance(source_ref, DBRef):
source_doc = self.db.dereference(source_ref)
if source_doc:
processed_source = {
"id": str(source_doc["_id"]),
"retriever": source_doc.get("retriever", "classic"),
"chunks": source_doc.get("chunks", data.get("chunks", "2")),
}
sources_list.append(processed_source)
data["sources"] = sources_list
else:
data["sources"] = []
return data
def _configure_source(self):
"""Configure the source based on agent data"""
api_key = self.data.get("api_key") or self.agent_key
if api_key:
agent_data = self._get_data_from_api_key(api_key)
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
source_ids = [
source["id"] for source in agent_data["sources"] if source.get("id")
]
if source_ids:
self.source = {"active_docs": source_ids}
else:
self.source = {}
self.all_sources = agent_data["sources"]
elif agent_data.get("source"):
self.source = {"active_docs": agent_data["source"]}
self.all_sources = [
{
"id": agent_data["source"],
"retriever": agent_data.get("retriever", "classic"),
}
]
else:
self.source = {}
self.all_sources = []
return
if "active_docs" in self.data:
self.source = {"active_docs": self.data["active_docs"]}
return
self.source = {}
self.all_sources = []
def _configure_agent(self):
"""Configure the agent based on request data"""
agent_id = self.data.get("agent_id")
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
agent_id, self.initial_user_id
)
api_key = self.data.get("api_key")
if api_key:
data_key = self._get_data_from_api_key(api_key)
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": api_key,
"json_schema": data_key.get("json_schema"),
}
)
self.initial_user_id = data_key.get("user")
self.decoded_token = {"sub": data_key.get("user")}
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
try:
self.retriever_config["chunks"] = int(data_key["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
)
self.retriever_config["chunks"] = 2
elif self.agent_key:
data_key = self._get_data_from_api_key(self.agent_key)
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": self.agent_key,
"json_schema": data_key.get("json_schema"),
}
)
self.decoded_token = (
self.decoded_token
if self.is_shared_usage
else {"sub": data_key.get("user")}
)
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
try:
self.retriever_config["chunks"] = int(data_key["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
)
self.retriever_config["chunks"] = 2
else:
self.agent_config.update(
{
"prompt_id": self.data.get("prompt_id", "default"),
"agent_type": settings.AGENT_NAME,
"user_api_key": None,
"json_schema": None,
}
)
def _configure_retriever(self):
"""Configure the retriever based on request data"""
self.retriever_config = {
"retriever_name": self.data.get("retriever", "classic"),
"chunks": int(self.data.get("chunks", 2)),
"token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY),
}
api_key = self.data.get("api_key") or self.agent_key
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
self.retriever_config["chunks"] = 0
def create_agent(self):
"""Create and return the configured agent"""
return AgentCreator.create_agent(
self.agent_config["agent_type"],
endpoint="stream",
llm_name=settings.LLM_PROVIDER,
gpt_model=self.gpt_model,
api_key=settings.API_KEY,
user_api_key=self.agent_config["user_api_key"],
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
chat_history=self.history,
decoded_token=self.decoded_token,
attachments=self.attachments,
json_schema=self.agent_config.get("json_schema"),
)
def create_retriever(self):
"""Create and return the configured retriever"""
return RetrieverCreator.create_retriever(
self.retriever_config["retriever_name"],
source=self.source,
chat_history=self.history,
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
chunks=self.retriever_config["chunks"],
token_limit=self.retriever_config["token_limit"],
gpt_model=self.gpt_model,
user_api_key=self.agent_config["user_api_key"],
decoded_token=self.decoded_token,
)