mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 00:23:17 +00:00
354 lines
14 KiB
Python
354 lines
14 KiB
Python
import datetime
|
|
import json
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Optional
|
|
|
|
from bson.dbref import DBRef
|
|
|
|
from bson.objectid import ObjectId
|
|
|
|
from application.agents.agent_creator import AgentCreator
|
|
from application.api.answer.services.conversation_service import ConversationService
|
|
from application.core.mongo_db import MongoDB
|
|
from application.core.settings import settings
|
|
from application.retriever.retriever_creator import RetrieverCreator
|
|
from application.utils import get_gpt_model, limit_chat_history
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def get_prompt(prompt_id: str, prompts_collection=None) -> str:
    """
    Get a prompt by preset name or MongoDB ID.

    Preset names ("default", "creative", "strict", "reduce") are read from a
    text file in the ``prompts`` directory found under
    ``Path(__file__).resolve().parents[3]``. Any other value is treated as the
    ``_id`` of a document in the prompts collection, and that document's
    ``content`` field is returned.

    Args:
        prompt_id: Preset name or string form of a prompt document's ObjectId.
        prompts_collection: Collection to query. When ``None``, the
            ``prompts`` collection of ``settings.MONGO_DB_NAME`` is used.

    Returns:
        The prompt text.

    Raises:
        FileNotFoundError: The file backing a preset does not exist.
        ValueError: The ID could not be used for a lookup ("Invalid prompt
            ID: ...") or no document matched it ("Prompt with ID ... not
            found").
    """
    preset_mapping = {
        "default": "chat_combine_default.txt",
        "creative": "chat_combine_creative.txt",
        "strict": "chat_combine_strict.txt",
        "reduce": "chat_reduce_prompt.txt",
    }

    if prompt_id in preset_mapping:
        # The directory is only needed for presets, so it is resolved here
        # rather than on every call.
        prompts_dir = Path(__file__).resolve().parents[3] / "prompts"
        file_path = prompts_dir / preset_mapping[prompt_id]
        try:
            with open(file_path, "r") as f:
                return f.read()
        except FileNotFoundError as exc:
            raise FileNotFoundError(f"Prompt file not found: {file_path}") from exc

    try:
        if prompts_collection is None:
            mongo = MongoDB.get_client()
            db = mongo[settings.MONGO_DB_NAME]
            prompts_collection = db["prompts"]
        prompt_doc = prompts_collection.find_one({"_id": ObjectId(prompt_id)})
        content = prompt_doc["content"] if prompt_doc else None
    except Exception as e:
        # Any failure while building the ObjectId, querying, or reading the
        # document is reported to callers as a ValueError.
        raise ValueError(f"Invalid prompt ID: {prompt_id}") from e

    # Checked outside the try block so the blanket handler above cannot
    # relabel "not found" as "Invalid prompt ID".
    if not prompt_doc:
        raise ValueError(f"Prompt with ID {prompt_id} not found")
    return content
|
|
|
|
|
|
class StreamProcessor:
    """
    Prepare everything needed to answer one streaming request.

    The processor resolves the agent (from ``agent_id`` or ``api_key`` in the
    request), the document source(s), the retriever settings, the chat history
    and any attachments, and can then build the agent and retriever objects.
    """

    def __init__(
        self, request_data: Dict[str, Any], decoded_token: Optional[Dict[str, Any]]
    ):
        """
        Store the request and open handles to the MongoDB collections used.

        Args:
            request_data: Parsed request payload; read via ``.get`` throughout.
            decoded_token: Decoded auth token, or ``None``. Its ``"sub"``
                entry, when present, becomes ``initial_user_id``.
        """
        mongo = MongoDB.get_client()
        self.db = mongo[settings.MONGO_DB_NAME]
        self.agents_collection = self.db["agents"]
        self.attachments_collection = self.db["attachments"]
        self.prompts_collection = self.db["prompts"]

        self.data = request_data
        self.decoded_token = decoded_token
        self.initial_user_id = (
            self.decoded_token.get("sub") if self.decoded_token is not None else None
        )
        self.conversation_id = self.data.get("conversation_id")
        self.source = {}
        self.all_sources = []
        self.attachments = []
        self.history = []
        self.agent_config = {}
        self.retriever_config = {}
        self.is_shared_usage = False
        self.shared_token = None
        # Set for real by _configure_agent(); initialised here so that
        # _configure_source()/_configure_retriever() never hit a missing
        # attribute if they run first.
        self.agent_key = None
        self.gpt_model = get_gpt_model()
        self.conversation_service = ConversationService()

    def initialize(self):
        """Initialize all required components for processing"""
        # First pass: resolves self.agent_key, which _configure_source() and
        # _configure_retriever() both read.
        self._configure_agent()
        self._configure_source()
        self._configure_retriever()
        # Second pass: _configure_retriever() replaces self.retriever_config
        # with request-derived values, so the agent-derived retriever name
        # and chunk count are applied again on top of it.
        self._configure_agent()
        self._load_conversation_history()
        self._process_attachments()

    def _load_conversation_history(self):
        """Load conversation history either from DB or request"""
        if self.conversation_id and self.initial_user_id:
            conversation = self.conversation_service.get_conversation(
                self.conversation_id, self.initial_user_id
            )
            if not conversation:
                raise ValueError("Conversation not found or unauthorized")
            self.history = [
                {"prompt": query["prompt"], "response": query["response"]}
                for query in conversation.get("queries", [])
            ]
        else:
            # No stored conversation to read: use the JSON-encoded history
            # sent with the request, trimmed by limit_chat_history().
            self.history = limit_chat_history(
                json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model
            )

    def _process_attachments(self):
        """Process any attachments in the request"""
        attachment_ids = self.data.get("attachments", [])
        self.attachments = self._get_attachments_content(
            attachment_ids, self.initial_user_id
        )

    def _get_attachments_content(self, attachment_ids, user_id):
        """
        Retrieve attachment documents based on their IDs.

        Only documents whose ``user`` field equals ``user_id`` are returned.
        A failure on one ID is logged and skipped so the remaining
        attachments are still loaded.

        Args:
            attachment_ids: Iterable of attachment ID strings (may be empty).
            user_id: Owner the attachment documents must belong to.

        Returns:
            List of the attachment documents that were found.
        """
        if not attachment_ids:
            return []
        attachments = []
        for attachment_id in attachment_ids:
            try:
                attachment_doc = self.attachments_collection.find_one(
                    {"_id": ObjectId(attachment_id), "user": user_id}
                )
                if attachment_doc:
                    attachments.append(attachment_doc)
            except Exception as e:
                logger.error(
                    f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
                )
        return attachments

    def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
        """
        Get API key for agent with access control.

        Args:
            agent_id: String form of the agent's ObjectId, or a falsy value.
            user_id: ID of the requesting user.

        Returns:
            ``(key, is_shared_usage, shared_token)``. When ``agent_id`` is
            falsy this is ``(None, False, None)``. ``is_shared_usage`` is
            true when the requester is not the agent's owner.

        Raises:
            Exception: The agent does not exist, the user is neither its
                owner nor someone it is shared with, or the lookup failed.
                The error is logged before being re-raised.
        """
        if not agent_id:
            return None, False, None
        try:
            agent = self.agents_collection.find_one({"_id": ObjectId(agent_id)})
            if agent is None:
                raise Exception("Agent not found")
            is_owner = agent.get("user") == user_id
            is_shared_with_user = agent.get(
                "shared_publicly", False
            ) or user_id in agent.get("shared_with", [])

            if not (is_owner or is_shared_with_user):
                raise Exception("Unauthorized access to the agent")
            if is_owner:
                # The usage timestamp is only refreshed for the owner.
                self.agents_collection.update_one(
                    {"_id": ObjectId(agent_id)},
                    {
                        "$set": {
                            "lastUsedAt": datetime.datetime.now(datetime.timezone.utc)
                        }
                    },
                )
            return str(agent["key"]), not is_owner, agent.get("shared_token")
        except Exception as e:
            logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
            raise

    def _get_data_from_api_key(self, api_key: str) -> Dict[str, Any]:
        """
        Load the agent document for ``api_key`` and normalise its sources.

        ``source`` becomes the referenced document's ID string (with
        ``retriever``/``chunks`` taken from that document when it defines
        them), the literal ``"default"``, or ``None``. ``sources`` becomes a
        list of ``{"id", "retriever", "chunks"}`` dicts; entries that are
        neither ``"default"`` nor a resolvable DBRef are dropped.

        Args:
            api_key: Value of the agent document's ``key`` field.

        Returns:
            The agent document, modified in place as described above.

        Raises:
            Exception: No agent document has this key.
        """
        data = self.agents_collection.find_one({"key": api_key})
        if not data:
            raise Exception("Invalid API Key, please generate a new key", 401)

        source = data.get("source")
        if isinstance(source, DBRef):
            source_doc = self.db.dereference(source)
            if source_doc:
                data["source"] = str(source_doc["_id"])
                data["retriever"] = source_doc.get("retriever", data.get("retriever"))
                data["chunks"] = source_doc.get("chunks", data.get("chunks"))
            else:
                data["source"] = None
        elif source == "default":
            data["source"] = "default"
        else:
            data["source"] = None

        # Handle multiple sources
        sources = data.get("sources", [])
        if sources and isinstance(sources, list):
            sources_list = []
            for source_ref in sources:
                if source_ref == "default":
                    sources_list.append(
                        {
                            "id": "default",
                            "retriever": "classic",
                            "chunks": data.get("chunks", "2"),
                        }
                    )
                elif isinstance(source_ref, DBRef):
                    source_doc = self.db.dereference(source_ref)
                    if source_doc:
                        sources_list.append(
                            {
                                "id": str(source_doc["_id"]),
                                "retriever": source_doc.get("retriever", "classic"),
                                "chunks": source_doc.get(
                                    "chunks", data.get("chunks", "2")
                                ),
                            }
                        )
            data["sources"] = sources_list
        else:
            data["sources"] = []
        return data

    def _configure_source(self):
        """Configure the source based on agent data"""
        api_key = self.data.get("api_key") or self.agent_key

        if api_key:
            agent_data = self._get_data_from_api_key(api_key)

            if agent_data.get("sources"):
                # Multi-source agent: collect every source that has an ID.
                source_ids = [
                    source["id"] for source in agent_data["sources"] if source.get("id")
                ]
                self.source = {"active_docs": source_ids} if source_ids else {}
                self.all_sources = agent_data["sources"]
            elif agent_data.get("source"):
                self.source = {"active_docs": agent_data["source"]}
                self.all_sources = [
                    {
                        "id": agent_data["source"],
                        "retriever": agent_data.get("retriever", "classic"),
                    }
                ]
            else:
                self.source = {}
                self.all_sources = []
            return

        if "active_docs" in self.data:
            # No agent involved: use the documents named in the request.
            self.source = {"active_docs": self.data["active_docs"]}
            return

        self.source = {}
        self.all_sources = []

    @staticmethod
    def _coerce_chunks(value, default=2):
        """Return ``value`` as an int, or ``default`` when it is None or invalid."""
        if value is None:
            return default
        try:
            return int(value)
        except (ValueError, TypeError):
            logger.warning(
                f"Invalid chunks value: {value}, using default value {default}"
            )
            return default

    def _apply_agent_data(self, data_key: Dict[str, Any], user_api_key: str):
        """Copy prompt, agent type, source and retriever settings from an agent document."""
        self.agent_config.update(
            {
                "prompt_id": data_key.get("prompt_id", "default"),
                "agent_type": data_key.get("agent_type", settings.AGENT_NAME),
                "user_api_key": user_api_key,
                "json_schema": data_key.get("json_schema"),
            }
        )
        if data_key.get("source"):
            self.source = {"active_docs": data_key["source"]}
        if data_key.get("retriever"):
            self.retriever_config["retriever_name"] = data_key["retriever"]
        if data_key.get("chunks") is not None:
            self.retriever_config["chunks"] = self._coerce_chunks(data_key["chunks"])

    def _configure_agent(self):
        """
        Configure the agent based on request data.

        Resolution order: an ``api_key`` in the request, then the key of the
        agent named by ``agent_id``, then plain request defaults.
        """
        agent_id = self.data.get("agent_id")
        self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
            agent_id, self.initial_user_id
        )

        api_key = self.data.get("api_key")
        if api_key:
            data_key = self._get_data_from_api_key(api_key)
            self._apply_agent_data(data_key, api_key)
            # Requests made with an API key act as the user stored on the agent.
            self.initial_user_id = data_key.get("user")
            self.decoded_token = {"sub": data_key.get("user")}
        elif self.agent_key:
            data_key = self._get_data_from_api_key(self.agent_key)
            self._apply_agent_data(data_key, self.agent_key)
            # For shared usage the caller's own token is kept; otherwise the
            # token is replaced by one naming the agent's user.
            if not self.is_shared_usage:
                self.decoded_token = {"sub": data_key.get("user")}
        else:
            self.agent_config.update(
                {
                    "prompt_id": self.data.get("prompt_id", "default"),
                    "agent_type": settings.AGENT_NAME,
                    "user_api_key": None,
                    "json_schema": None,
                }
            )

    def _configure_retriever(self):
        """
        Configure the retriever based on request data.

        A missing, null or non-numeric ``chunks`` value falls back to 2, the
        same fallback used for agent-supplied chunk counts.
        """
        if "token_limit" in self.data:
            token_limit = self.data["token_limit"]
        else:
            token_limit = settings.DEFAULT_MAX_HISTORY

        self.retriever_config = {
            "retriever_name": self.data.get("retriever", "classic"),
            "chunks": self._coerce_chunks(self.data.get("chunks")),
            "token_limit": token_limit,
        }

        api_key = self.data.get("api_key") or self.agent_key
        if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
            # The request set isNoneDoc and no agent is involved: fetch no chunks.
            self.retriever_config["chunks"] = 0

    def create_agent(self):
        """Create and return the configured agent"""
        return AgentCreator.create_agent(
            self.agent_config["agent_type"],
            endpoint="stream",
            llm_name=settings.LLM_PROVIDER,
            gpt_model=self.gpt_model,
            api_key=settings.API_KEY,
            user_api_key=self.agent_config["user_api_key"],
            prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
            chat_history=self.history,
            decoded_token=self.decoded_token,
            attachments=self.attachments,
            json_schema=self.agent_config.get("json_schema"),
        )

    def create_retriever(self):
        """Create and return the configured retriever"""
        return RetrieverCreator.create_retriever(
            self.retriever_config["retriever_name"],
            source=self.source,
            chat_history=self.history,
            prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
            chunks=self.retriever_config["chunks"],
            token_limit=self.retriever_config["token_limit"],
            gpt_model=self.gpt_model,
            user_api_key=self.agent_config["user_api_key"],
            decoded_token=self.decoded_token,
        )
|