Files
DocsGPT/application/api/answer/services/stream_processor.py
2026-05-20 10:40:15 +01:00

1252 lines
52 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import datetime
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional, Set
from application.agents.agent_creator import AgentCreator
from application.agents.default_tools import synthesized_default_tools
from application.api.answer.services.compression import CompressionOrchestrator
from application.api.answer.services.compression.token_counter import TokenCounter
from application.api.answer.services.conversation_service import ConversationService
from application.api.answer.services.prompt_renderer import PromptRenderer
from application.core.model_utils import (
get_api_key_for_provider,
get_default_model_id,
get_provider_from_model_id,
validate_model_id,
)
from application.core.settings import settings
from sqlalchemy import text as sql_text
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.attachments import AttachmentsRepository
from application.storage.db.repositories.prompts import PromptsRepository
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.repositories.user_tools import UserToolsRepository
from application.storage.db.repositories.users import UsersRepository
from application.storage.db.session import db_readonly, db_session
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import (
calculate_doc_token_budget,
limit_chat_history,
)
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
def get_prompt(prompt_id: str, prompts_collection=None) -> str:
    """Get a prompt by preset name or Postgres ID (UUID or legacy ObjectId).

    Preset names ("default", "creative", "strict", "reduce", and the
    ``agentic_``-prefixed "default"/"creative"/"strict") are read from the
    ``prompts`` directory under ``parents[3]`` of this file's resolved path.
    Any other value is looked up through ``PromptsRepository``: UUID-shaped
    IDs via ``get_for_rendering`` first, then ``get_by_legacy_id``.

    Args:
        prompt_id: Preset name or stored prompt ID. ``None`` / ``""`` selects
            the "default" preset; non-string values are converted with
            ``str``.
        prompts_collection: Ignored. Retained for backwards compatibility
            with call sites that still pass it positionally.

    Returns:
        The prompt text.

    Raises:
        FileNotFoundError: The file backing a preset does not exist.
        ValueError: No stored prompt matches ``prompt_id``, or the lookup
            failed for any other reason (the original error is chained).
    """
    del prompts_collection  # unused — retained for call-site compatibility

    # Callers may pass a ``uuid.UUID`` (from a PG ``prompt_id`` column) or a
    # plain string ("default"/"creative"/legacy ObjectId). Normalise to str
    # so both the preset lookup and the UUID-vs-legacy branching work.
    # ``None`` / empty means "use the default prompt" — agents that never
    # set a custom prompt land here (PG ``agents.prompt_id`` is NULL).
    if prompt_id is None or prompt_id == "":
        prompt_id = "default"
    elif not isinstance(prompt_id, str):
        prompt_id = str(prompt_id)

    current_dir = Path(__file__).resolve().parents[3]
    prompts_dir = current_dir / "prompts"

    CLASSIC_PRESETS = {
        "default": "chat_combine_default.txt",
        "creative": "chat_combine_creative.txt",
        "strict": "chat_combine_strict.txt",
        "reduce": "chat_reduce_prompt.txt",
    }
    AGENTIC_PRESETS = {
        "default": "agentic/default.txt",
        "creative": "agentic/creative.txt",
        "strict": "agentic/strict.txt",
    }
    # Agentic presets are addressed as "agentic_<name>" so they cannot
    # collide with the classic preset keys.
    preset_mapping = {
        **CLASSIC_PRESETS,
        **{f"agentic_{k}": v for k, v in AGENTIC_PRESETS.items()},
    }

    if prompt_id in preset_mapping:
        file_path = os.path.join(prompts_dir, preset_mapping[prompt_id])
        try:
            # Explicit encoding: without it the text is decoded with the
            # host's locale encoding, which is not UTF-8 on every platform.
            with open(file_path, "r", encoding="utf-8") as f:
                return f.read()
        except FileNotFoundError as exc:
            raise FileNotFoundError(f"Prompt file not found: {file_path}") from exc

    try:
        with db_readonly() as conn:
            repo = PromptsRepository(conn)
            prompt_doc = None
            if looks_like_uuid(prompt_id):
                prompt_doc = repo.get_for_rendering(prompt_id)
            # Fall back to the legacy-ID lookup for non-UUID values and for
            # UUID-shaped values the first lookup did not find.
            if prompt_doc is None:
                prompt_doc = repo.get_by_legacy_id(prompt_id)
            if not prompt_doc:
                raise ValueError(f"Prompt with ID {prompt_id} not found")
            return prompt_doc["content"]
    except ValueError:
        # "Not found" is already the documented error type; pass it through.
        raise
    except Exception as e:
        raise ValueError(f"Invalid prompt ID: {prompt_id}") from e
class StreamProcessor:
def __init__(
    self, request_data: Dict[str, Any], decoded_token: Optional[Dict[str, Any]]
):
    """Store the request payload and caller token and set default state.

    Args:
        request_data: Request body. ``conversation_id`` and ``agent_id``
            are read from it here.
        decoded_token: The caller's decoded token, or ``None``. Its
            ``sub`` entry becomes ``initial_user_id``.
    """
    # Request payload and caller identity.
    self.data = request_data
    self.decoded_token = decoded_token
    if decoded_token is None:
        self.initial_user_id = None
    else:
        self.initial_user_id = decoded_token.get("sub")
    self.conversation_id = request_data.get("conversation_id")
    self.agent_id = request_data.get("agent_id")

    # Legacy attribute retained as None for any external callers that
    # introspect the processor; all DB access uses per-op connections.
    self.prompts_collection = None

    # Per-request state filled in by the configuration steps.
    self.source = {}
    self.all_sources = []
    self.attachments = []
    self.history = []
    self.retrieved_docs = []
    self.agent_config = {}
    self.retriever_config = {}

    # Agent access details.
    self.is_shared_usage = False
    self.shared_token = None
    self.agent_key = None
    self._agent_data: Optional[Dict[str, Any]] = None

    self.model_id: Optional[str] = None
    # BYOM-resolution scope, set by _validate_and_set_model.
    self.model_user_id: Optional[str] = None
    # WAL placeholder id pulled from continuation state on resume.
    self.reserved_message_id: Optional[str] = None
    # Carried through resumes so multi-pause runs keep one request_id.
    self.request_id: Optional[str] = None

    # Collaborators; the orchestrator is handed the conversation service.
    self.conversation_service = ConversationService()
    self.compression_orchestrator = CompressionOrchestrator(
        self.conversation_service
    )
    self.prompt_renderer = PromptRenderer()

    # Lazily filled caches and compression results.
    self._prompt_content: Optional[str] = None
    self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
    self.compressed_summary: Optional[str] = None
    self.compressed_summary_tokens: int = 0
def initialize(self):
    """Run all setup steps needed before processing, in a fixed order."""
    # Order matters: model selection reads the agent config, and the
    # retriever and history steps read the selected model.
    setup_steps = (
        self._configure_agent,
        self._validate_and_set_model,
        self._configure_source,
        self._configure_retriever,
        self._load_conversation_history,
        self._process_attachments,
    )
    for run_step in setup_steps:
        run_step()
def build_agent(self, question: str):
    """Go from request data to a ready-to-run agent in one call.

    Runs ``initialize()``, then the pre-fetch steps that apply to the
    agent type, then ``create_agent()``.

    Args:
        question: The user question used for document pre-fetch.

    Returns:
        Whatever ``create_agent`` returns.
    """
    self.initialize()

    # Agentic/research agents skip pre-fetch — the LLM searches on-demand via tools
    searches_on_demand = self.agent_config.get("agent_type", "classic") in (
        "agentic",
        "research",
    )
    if searches_on_demand:
        return self.create_agent(tools_data=self.pre_fetch_tools())

    # Documents are fetched before tools for every other agent type.
    docs_together, docs_list = self.pre_fetch_docs(question)
    prefetched_tools = self.pre_fetch_tools()
    return self.create_agent(
        docs_together=docs_together,
        docs=docs_list,
        tools_data=prefetched_tools,
    )
def _load_conversation_history(self):
    """Populate ``self.history`` from the stored conversation or the request.

    Raises:
        ValueError: A conversation id and user id are present but the
            conversation lookup returns nothing.
    """
    if not (self.conversation_id and self.initial_user_id):
        # No stored conversation to read: use the JSON-encoded history
        # from the request body.
        # model_user_id keeps history trim aligned with the BYOM's
        # actual context window instead of the default 128k.
        request_history = json.loads(self.data.get("history", "[]"))
        self.history = limit_chat_history(
            request_history,
            model_id=self.model_id,
            user_id=self.model_user_id,
        )
        return

    conversation = self.conversation_service.get_conversation(
        self.conversation_id, self.initial_user_id
    )
    if not conversation:
        raise ValueError("Conversation not found or unauthorized")

    # Check if compression is enabled and needed
    if settings.ENABLE_CONVERSATION_COMPRESSION:
        self._handle_compression(conversation)
        return

    # Original behavior - load all history (include metadata if present)
    loaded = []
    for query in conversation.get("queries", []):
        entry = {"prompt": query["prompt"], "response": query["response"]}
        if "metadata" in query:
            entry["metadata"] = query["metadata"]
        loaded.append(entry)
    self.history = loaded
def _handle_compression(self, conversation: Dict[str, Any]):
    """Set ``self.history`` via the compression orchestrator.

    Falls back to the conversation's full query list when the
    orchestrator reports failure or anything in the compression path
    raises.

    Args:
        conversation: The stored conversation; its ``queries`` list is
            the fallback history.
    """

    def full_history():
        # Uncompressed history: prompt/response plus metadata when present.
        entries = []
        for query in conversation.get("queries", []):
            entry = {"prompt": query["prompt"], "response": query["response"]}
            if "metadata" in query:
                entry["metadata"] = query["metadata"]
            entries.append(entry)
        return entries

    try:
        # initial_user_id for conversation access; model_user_id
        # for BYOM context-window / provider lookups.
        outcome = self.compression_orchestrator.compress_if_needed(
            conversation_id=self.conversation_id,
            user_id=self.initial_user_id,
            model_user_id=self.model_user_id,
            model_id=self.model_id,
            decoded_token=self.decoded_token,
        )

        if not outcome.success:
            logger.error(f"Compression failed: {outcome.error}, using full history")
            self.history = full_history()
            return

        if outcome.compression_performed and outcome.compressed_summary:
            self.compressed_summary = outcome.compressed_summary
            self.compressed_summary_tokens = TokenCounter.count_message_tokens(
                [{"content": outcome.compressed_summary}]
            )
            logger.info(
                f"Using compressed summary ({self.compressed_summary_tokens} tokens) "
                f"+ {len(outcome.recent_queries)} recent messages"
            )

        self.history = outcome.as_history()

        # Preserve metadata from recent queries (as_history only has prompt/response)
        if outcome.recent_queries:
            recent = outcome.recent_queries
        else:
            recent = conversation.get("queries", [])
        for position, entry in enumerate(self.history):
            # Align history entries with the tail of ``recent``.
            source_index = (len(recent) - len(self.history)) + position
            if not 0 <= source_index < len(recent):
                continue
            if "metadata" in recent[source_index]:
                entry["metadata"] = recent[source_index]["metadata"]
    except Exception as e:
        logger.error(
            f"Error handling compression, falling back to standard history: {str(e)}",
            exc_info=True,
        )
        self.history = full_history()
def _process_attachments(self):
"""Process any attachments in the request"""
attachment_ids = self.data.get("attachments", [])
self.attachments = self._get_attachments_content(
attachment_ids, self.initial_user_id
)
def _get_attachments_content(self, attachment_ids, user_id):
if not attachment_ids:
return []
attachments = []
try:
with db_readonly() as conn:
repo = AttachmentsRepository(conn)
for attachment_id in attachment_ids:
try:
attachment_doc = repo.get_any(str(attachment_id), user_id)
if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
logger.error(
f"Error retrieving attachment {attachment_id}: {e}",
exc_info=True,
)
except Exception as e:
logger.error(f"Error opening attachments connection: {e}", exc_info=True)
return attachments
def _validate_and_set_model(self):
    """Choose ``model_id`` and ``model_user_id`` for this request.

    On agent-bound chats the agent's ``default_model_id`` wins and the
    request's ``model_id`` is ignored. Otherwise the requested model is
    validated against the caller, falling back to the default model when
    none was requested.

    Raises:
        ValueError: An agentless request names a model that fails
            validation for the caller.
    """
    from application.core.model_settings import ModelRegistry

    # Caller picks from their own BYOM layer; agent defaults resolve
    # under the owner's layer (shared agents have caller != owner).
    caller_user_id = self.initial_user_id

    if self._agent_data is not None:
        # Agent-bound: agent's default_model_id wins, body's model_id is dropped.
        owner_user_id = self.agent_config.get("user_id") or caller_user_id
        agent_model = self.agent_config.get("default_model_id", "")
        if agent_model and validate_model_id(agent_model, user_id=owner_user_id):
            self.model_id = agent_model
            self.model_user_id = owner_user_id
        else:
            self.model_id = get_default_model_id()
            self.model_user_id = None
        return

    requested_model = self.data.get("model_id")
    if not requested_model:
        self.model_id = get_default_model_id()
        self.model_user_id = None
        return

    if not validate_model_id(requested_model, user_id=caller_user_id):
        registry = ModelRegistry.get_instance()
        available_models = [
            model.id for model in registry.get_enabled_models(user_id=caller_user_id)
        ]
        # Name at most five models, then summarise how many were left out.
        listed = ", ".join(available_models[:5])
        if len(available_models) > 5:
            remainder = f" and {len(available_models) - 5} more"
        else:
            remainder = ""
        raise ValueError(
            f"Invalid model_id '{requested_model}'. "
            f"Available models: {listed}" + remainder
        )

    self.model_id = requested_model
    self.model_user_id = caller_user_id
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
"""Get API key for agent with access control."""
if not agent_id:
return None, False, None
try:
with db_readonly() as conn:
# Lookup without user scoping — access control is done
# against ``user_id`` / ``shared_with`` / ``shared`` flags
# right below, matching the legacy Mongo semantics.
repo = AgentsRepository(conn)
agent = None
if looks_like_uuid(str(agent_id)):
result = conn.execute(
sql_text(
"SELECT * FROM agents WHERE id = CAST(:id AS uuid)"
),
{"id": str(agent_id)},
)
row = result.fetchone()
if row is not None:
agent = row_to_dict(row)
if agent is None:
agent = repo.get_by_legacy_id(str(agent_id))
if agent is None:
raise Exception("Agent not found")
agent_owner = agent.get("user_id")
is_owner = agent_owner == user_id
is_shared_with_user = bool(agent.get("shared", False))
if not (is_owner or is_shared_with_user):
raise Exception("Unauthorized access to the agent")
if is_owner:
now = datetime.datetime.now(datetime.timezone.utc)
try:
with db_session() as conn:
AgentsRepository(conn).update(
str(agent["id"]), agent_owner,
{"last_used_at": now},
)
except Exception:
logger.warning(
"Failed to update last_used_at for agent",
exc_info=True,
)
return (
str(agent["key"]) if agent.get("key") else None,
not is_owner,
agent.get("shared_token"),
)
except Exception as e:
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
raise
def _get_data_from_api_key(self, api_key: str) -> Dict[str, Any]:
    """Resolve agent metadata plus the combined source list for a key.

    Args:
        api_key: The agent API key to look up.

    Returns:
        A copy of the agent row with ``user``, ``source``, ``sources`` and
        ``default_model_id`` set. ``retriever`` / ``chunks`` are
        overwritten by the primary source's values when that source
        provides them.

    Raises:
        Exception: No agent matches ``api_key``.
    """
    with db_readonly() as conn:
        agent = AgentsRepository(conn).find_by_key(api_key)
        if not agent:
            raise Exception("Invalid API Key, please generate a new key", 401)
        sources_repo = SourcesRepository(conn)

        owner = agent.get("user_id")
        # The repo dict uses "user_id" — the streaming path expects
        # a "user" key (legacy Mongo shape) for identity propagation.
        data: Dict[str, Any] = dict(agent)
        data["user"] = owner

        def source_entry(source_doc):
            # One ``sources`` item. ``sources`` row may have NULL
            # ``retriever``/``chunks`` — fall back to "classic" and to the
            # agent-level chunks value (``dict.get`` returns None even when
            # the key exists with value None).
            doc_chunks = source_doc.get("chunks")
            return {
                "id": str(source_doc["id"]),
                "retriever": source_doc.get("retriever") or "classic",
                "chunks": (
                    doc_chunks if doc_chunks is not None
                    else data.get("chunks", "2")
                ),
            }

        # Active sources = primary plus extras, primary first, deduplicated.
        # ``_configure_source`` ignores an empty ``data["sources"]``,
        # so the primary must appear in the list too — not only in
        # the legacy ``data["source"]`` slot.
        combined: list = []
        seen_ids: set = set()

        data["source"] = None
        primary_id = agent.get("source_id")
        primary_doc = sources_repo.get(str(primary_id), owner) if primary_id else None
        if primary_doc:
            primary_sid = str(primary_doc["id"])
            data["source"] = primary_sid
            primary_retriever = primary_doc.get("retriever")
            if primary_retriever:
                data["retriever"] = primary_retriever
            primary_chunks = primary_doc.get("chunks")
            if primary_chunks is not None:
                data["chunks"] = primary_chunks
            combined.append(source_entry(primary_doc))
            seen_ids.add(primary_sid)

        for extra_id in agent.get("extra_source_ids") or []:
            if not extra_id:
                continue
            extra_doc = sources_repo.get(str(extra_id), owner)
            if not extra_doc:
                continue
            extra_sid = str(extra_doc["id"])
            if extra_sid in seen_ids:
                continue
            combined.append(source_entry(extra_doc))
            seen_ids.add(extra_sid)

        data["sources"] = combined
        data["default_model_id"] = data.get("default_model_id", "")
        return data
def _configure_source(self):
"""Configure the source based on agent data.
The literal string ``"default"`` is a placeholder meaning "no
ingested source" and is normalized to an empty source so that no
retrieval is attempted.
"""
if self._agent_data:
agent_data = self._agent_data
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
source_ids = [
source["id"]
for source in agent_data["sources"]
if source.get("id") and source["id"] != "default"
]
if source_ids:
self.source = {"active_docs": source_ids}
else:
self.source = {}
self.all_sources = [
s for s in agent_data["sources"] if s.get("id") != "default"
]
elif agent_data.get("source") and agent_data["source"] != "default":
self.source = {"active_docs": agent_data["source"]}
self.all_sources = [
{
"id": agent_data["source"],
"retriever": agent_data.get("retriever", "classic"),
}
]
else:
self.source = {}
self.all_sources = []
return
if "active_docs" in self.data:
active_docs = self.data["active_docs"]
if active_docs and active_docs != "default":
self.source = {"active_docs": active_docs}
else:
self.source = {}
return
self.source = {}
self.all_sources = []
def _has_active_docs(self) -> bool:
"""Return True if a real document source is configured for retrieval."""
active_docs = self.source.get("active_docs") if self.source else None
if not active_docs:
return False
if active_docs == "default":
return False
return True
def _resolve_agent_id(self) -> Optional[str]:
"""Resolve agent_id from request, then fall back to conversation context."""
request_agent_id = self.data.get("agent_id")
if request_agent_id:
return str(request_agent_id)
if not self.conversation_id or not self.initial_user_id:
return None
try:
conversation = self.conversation_service.get_conversation(
self.conversation_id, self.initial_user_id
)
except Exception:
return None
if not conversation:
return None
conversation_agent_id = conversation.get("agent_id")
if conversation_agent_id:
return str(conversation_agent_id)
return None
def _configure_agent(self):
    """Configure the agent from request data.

    Resolves the effective API key (an explicit ``api_key`` in the
    request beats the agent-derived key), then fills ``agent_config``
    from the agent that key belongs to, or from request defaults when
    there is no key.
    """
    agent_id = self._resolve_agent_id()
    self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
        agent_id, self.initial_user_id
    )
    self.agent_id = str(agent_id) if agent_id else None

    # Determine the effective API key (explicit > agent-derived)
    request_api_key = self.data.get("api_key")
    effective_key = request_api_key or self.agent_key

    if not effective_key:
        # No API key — default/workflow configuration
        agent_type = settings.AGENT_NAME
        inline_workflow = self.data.get("workflow")
        if inline_workflow and isinstance(inline_workflow, dict):
            agent_type = "workflow"
            self.agent_config["workflow"] = self.data["workflow"]
            if isinstance(self.decoded_token, dict):
                self.agent_config["workflow_owner"] = self.decoded_token.get("sub")
        self.agent_config.update(
            {
                "prompt_id": self.data.get("prompt_id", "default"),
                "agent_type": agent_type,
                "user_api_key": None,
                "json_schema": None,
                "default_model_id": "",
            }
        )
        return

    agent_data = self._get_data_from_api_key(effective_key)
    self._agent_data = agent_data
    owner = agent_data.get("user")

    if agent_data.get("_id"):
        self.agent_id = str(agent_data.get("_id"))

    self.agent_config.update(
        {
            "prompt_id": agent_data.get("prompt_id", "default"),
            "agent_type": agent_data.get("agent_type", settings.AGENT_NAME),
            "user_api_key": effective_key,
            "json_schema": agent_data.get("json_schema"),
            "default_model_id": agent_data.get("default_model_id", ""),
            "models": agent_data.get("models", []),
            "allow_system_prompt_override": agent_data.get(
                "allow_system_prompt_override", False
            ),
            # Owner identity — _validate_and_set_model reads this to
            # resolve owner-stored BYOM default_model_id against the
            # owner's per-user model layer rather than the caller's.
            "user_id": owner,
        }
    )

    # Set identity context
    if request_api_key:
        # External API key: use the key owner's identity
        self.initial_user_id = owner
        self.decoded_token = {"sub": owner}
    elif not self.is_shared_usage:
        # Owner using their own agent. (Shared agent usage keeps the
        # caller's identity, so nothing changes in that case.)
        self.decoded_token = {"sub": owner}

    # PG row exposes the workflow as ``workflow_id`` (UUID column);
    # legacy Mongo shape used the key ``workflow``. Accept either so
    # API-key-invoked workflow agents bind correctly downstream.
    wf_ref = agent_data.get("workflow") or agent_data.get("workflow_id")
    if wf_ref:
        self.agent_config["workflow"] = str(wf_ref)
        self.agent_config["workflow_owner"] = owner
def _configure_retriever(self):
    """Assemble ``self.retriever_config``.

    When an agent is bound, its ``retriever`` / ``chunks`` are
    authoritative and the request's values are ignored. Defaults are
    ``"classic"`` and ``2``; an unparsable ``chunks`` value is logged and
    the default kept.
    """
    # BYOM scope: owner for shared-agent BYOM, caller for own BYOM,
    # None for built-ins. Without ``user_id`` here, the doc budget
    # falls back to settings.DEFAULT_LLM_TOKEN_LIMIT and overfills
    # the upstream context window for any small (e.g. 8k/32k) BYOM.
    doc_token_limit = calculate_doc_token_budget(
        model_id=self.model_id, user_id=self.model_user_id
    )

    retriever_name, chunks = "classic", 2
    agent_data = self._agent_data

    if agent_data is None:
        # Agentless: the request body may override either default.
        if "retriever" in self.data:
            retriever_name = self.data["retriever"]
        if "chunks" in self.data:
            try:
                chunks = int(self.data["chunks"])
            except (ValueError, TypeError):
                logger.warning(
                    f"Invalid request chunks value: {self.data['chunks']}, "
                    "using default value 2"
                )
    else:
        # Agent-bound: agent wins, body's retriever/chunks are dropped.
        if agent_data.get("retriever"):
            retriever_name = agent_data["retriever"]
        if agent_data.get("chunks") is not None:
            try:
                chunks = int(agent_data["chunks"])
            except (ValueError, TypeError):
                logger.warning(
                    f"Invalid agent chunks value: {agent_data['chunks']}, "
                    "using default value 2"
                )

    self.retriever_config = {
        "retriever_name": retriever_name,
        "chunks": chunks,
        "doc_token_limit": doc_token_limit,
    }

    # isNoneDoc without an API key forces no retrieval (agentless only)
    any_api_key = self.data.get("api_key") or self.agent_key
    if not any_api_key and self.data.get("isNoneDoc"):
        self.retriever_config["chunks"] = 0
def create_retriever(self):
    """Build a retriever from the current retriever config and state.

    Returns:
        Whatever ``RetrieverCreator.create_retriever`` returns.
    """
    config = self.retriever_config
    retriever_name = config["retriever_name"]
    prompt_text = get_prompt(self.agent_config["prompt_id"], self.prompts_collection)
    chunk_count = config["chunks"]
    # 50000 applies only when the config carries no ``doc_token_limit``.
    token_limit = config.get("doc_token_limit", 50000)
    return RetrieverCreator.create_retriever(
        retriever_name,
        source=self.source,
        chat_history=self.history,
        prompt=prompt_text,
        chunks=chunk_count,
        doc_token_limit=token_limit,
        model_id=self.model_id,
        model_user_id=self.model_user_id,
        user_api_key=self.agent_config["user_api_key"],
        agent_id=self.agent_id,
        decoded_token=self.decoded_token,
    )
def pre_fetch_docs(self, question: str) -> tuple[Optional[str], Optional[list]]:
    """Pre-fetch documents for template rendering before agent creation.

    Args:
        question: The text passed to the retriever's ``search``.

    Returns:
        ``(docs_together, docs)``: the joined document text and the raw
        result list, or ``(None, None)`` when retrieval is skipped,
        returns nothing, or fails (failures are logged, not raised).
    """
    if self.data.get("isNoneDoc", False) and not self.agent_id:
        logger.info("Pre-fetch skipped: isNoneDoc=True")
        return None, None

    if not self._has_active_docs():
        logger.info("Pre-fetch skipped: no active docs configured")
        return None, None

    try:
        retriever = self.create_retriever()
        logger.info(
            f"Pre-fetching docs with chunks={retriever.chunks}, doc_token_limit={retriever.doc_token_limit}"
        )
        docs = retriever.search(question)
        logger.info(f"Pre-fetch retrieved {len(docs) if docs else 0} documents")

        if not docs:
            logger.info("Pre-fetch: No documents returned from search")
            return None, None

        self.retrieved_docs = docs

        # Prefix each chunk with the first available label, if any.
        rendered_chunks = []
        for doc in docs:
            label = doc.get("filename") or doc.get("title") or doc.get("source")
            if label:
                rendered_chunks.append(f"{label!s}\n{doc['text']}")
            else:
                rendered_chunks.append(doc["text"])

        docs_together = "\n\n".join(rendered_chunks)
        logger.info(f"Pre-fetch docs_together size: {len(docs_together)} chars")
        return docs_together, docs
    except Exception as e:
        logger.error(f"Failed to pre-fetch docs: {str(e)}", exc_info=True)
        return None, None
def pre_fetch_tools(self) -> Optional[Dict[str, Any]]:
    """Pre-fetch tool data for template rendering before agent creation.

    Returns:
        A mapping from tool id (and, for non-default tools, tool name)
        to that tool's fetched data. ``None`` when pre-fetch is disabled,
        no tool produced data, or anything failed (failures are logged,
        not raised).
    """
    if not settings.ENABLE_TOOL_PREFETCH:
        logger.info(
            "Tool pre-fetching disabled globally via ENABLE_TOOL_PREFETCH setting"
        )
        return None

    if self.data.get("disable_tool_prefetch", False):
        logger.info("Tool pre-fetching disabled for this request")
        return None

    # ``None`` means "no filter"; a dict restricts which tools/actions run.
    required_by_tool = self._get_required_tool_actions()

    try:
        user_id = self.initial_user_id or "local"
        agentless = self.agent_id is None

        with db_readonly() as conn:
            user_tools = UserToolsRepository(conn).list_active_for_user(user_id)
            user_doc = UsersRepository(conn).get(user_id) if agentless else None

        default_docs = synthesized_default_tools(user_doc) if agentless else []
        tool_docs = list(user_tools) + default_docs
        if not tool_docs:
            return None

        prefetched = {}
        for tool_doc in tool_docs:
            tool_name = tool_doc.get("name")
            tool_id = str(tool_doc.get("_id") or tool_doc.get("id"))
            is_default = bool(tool_doc.get("default"))

            if required_by_tool is not None:
                # A template may reference a tool by name or by id.
                wanted_actions = required_by_tool.get(
                    tool_name, set()
                ) | required_by_tool.get(tool_id, set())
                if not wanted_actions:
                    continue
            elif is_default:
                # No template names a default tool, so running its
                # actions blind would only inject noise.
                continue
            else:
                wanted_actions = None

            tool_data = self._fetch_tool_data(tool_doc, wanted_actions)
            if not tool_data:
                continue

            # Defaults reachable by synthetic id only — the name
            # key stays bound to an explicit row of the same name.
            if not is_default:
                prefetched[tool_name] = tool_data
            prefetched[tool_id] = tool_data

        return prefetched if prefetched else None
    except Exception as e:
        logger.warning(f"Failed to pre-fetch tools: {type(e).__name__}")
        return None
def _fetch_tool_data(
self,
tool_doc: Dict[str, Any],
required_actions: Optional[Set[Optional[str]]],
) -> Optional[Dict[str, Any]]:
"""Fetch and execute tool actions with saved parameters"""
try:
from application.agents.tools.tool_manager import ToolManager
tool_name = tool_doc.get("name")
tool_config = tool_doc.get("config", {}).copy()
tool_config["tool_id"] = str(tool_doc["_id"])
tool_manager = ToolManager(config={tool_name: tool_config})
user_id = self.initial_user_id or "local"
tool = tool_manager.load_tool(tool_name, tool_config, user_id=user_id)
if not tool:
logger.debug(f"Tool '{tool_name}' failed to load")
return None
tool_actions = tool.get_actions_metadata()
if not tool_actions:
logger.debug(f"Tool '{tool_name}' has no actions")
return None
saved_actions = tool_doc.get("actions", [])
include_all_actions = required_actions is None or (
required_actions and None in required_actions
)
allowed_actions: Set[str] = (
{action for action in required_actions if isinstance(action, str)}
if required_actions
else set()
)
action_results = {}
for action_meta in tool_actions:
action_name = action_meta.get("name")
if action_name is None:
continue
if (
not include_all_actions
and allowed_actions
and action_name not in allowed_actions
):
continue
try:
saved_action = None
for sa in saved_actions:
if sa.get("name") == action_name:
saved_action = sa
break
action_params = action_meta.get("parameters", {})
properties = action_params.get("properties", {})
kwargs = {}
for param_name, param_spec in properties.items():
if saved_action:
saved_props = saved_action.get("parameters", {}).get(
"properties", {}
)
if param_name in saved_props:
param_value = saved_props[param_name].get("value")
if param_value is not None:
kwargs[param_name] = param_value
continue
if param_name in tool_config:
kwargs[param_name] = tool_config[param_name]
elif "default" in param_spec:
kwargs[param_name] = param_spec["default"]
result = tool.execute_action(action_name, **kwargs)
action_results[action_name] = result
except Exception as e:
logger.debug(
f"Action '{action_name}' execution failed: {type(e).__name__}"
)
continue
return action_results if action_results else None
except Exception as e:
logger.debug(f"Tool pre-fetch failed for '{tool_name}': {type(e).__name__}")
return None
def _get_prompt_content(self) -> Optional[str]:
"""Retrieve and cache the raw prompt content for the current agent configuration."""
if self._prompt_content is not None:
return self._prompt_content
prompt_id = (
self.agent_config.get("prompt_id")
if isinstance(self.agent_config, dict)
else None
)
if not prompt_id:
return None
try:
self._prompt_content = get_prompt(prompt_id, self.prompts_collection)
except ValueError as e:
logger.debug(f"Invalid prompt ID '{prompt_id}': {str(e)}")
self._prompt_content = None
except Exception as e:
logger.debug(f"Failed to fetch prompt '{prompt_id}': {type(e).__name__}")
self._prompt_content = None
return self._prompt_content
def _get_required_tool_actions(self) -> Optional[Dict[str, Set[Optional[str]]]]:
"""Determine which tool actions are referenced in the prompt template"""
if self._required_tool_actions is not None:
return self._required_tool_actions
prompt_content = self._get_prompt_content()
if prompt_content is None:
return None
if "{{" not in prompt_content or "}}" not in prompt_content:
self._required_tool_actions = {}
return self._required_tool_actions
try:
from application.templates.template_engine import TemplateEngine
template_engine = TemplateEngine()
usages = template_engine.extract_tool_usages(prompt_content)
self._required_tool_actions = usages
return self._required_tool_actions
except Exception as e:
logger.debug(f"Failed to extract tool usages: {type(e).__name__}")
self._required_tool_actions = {}
return self._required_tool_actions
def _fetch_memory_tool_data(
    self, tool_doc: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
    """Fetch the memory tool's root view for pre-injection into the prompt.

    Args:
        tool_doc: Tool document; ``config`` and ``_id`` are read from it.

    Returns:
        ``{"root": <view>, "available": True}``, or ``None`` when the view
        is blank, contains ``"Error:"``, or anything fails (failures are
        logged as warnings, not raised).
    """
    try:
        memory_config = tool_doc.get("config", {}).copy()
        memory_config["tool_id"] = str(tool_doc["_id"])

        from application.agents.tools.memory import MemoryTool

        root_view = MemoryTool(memory_config, self.initial_user_id).execute_action(
            "view", path="/"
        )
        is_usable = "Error:" not in root_view and root_view.strip()
        if is_usable:
            return {"root": root_view, "available": True}
        return None
    except Exception as e:
        logger.warning(f"Failed to fetch memory tool data: {str(e)}")
        return None
def resume_from_tool_actions(
    self,
    tool_actions: list,
    conversation_id: str,
):
    """Resume a paused agent from saved continuation state.

    Loads the pending state through ``ContinuationService``, claims it
    for this resume, recreates the LLM, handler, tool executor and agent
    from the saved configuration, and returns an agent ready to call
    ``gen_continuation()``.

    As a side effect the processor itself is updated for the route layer:
    ``model_id``, ``model_user_id``, ``agent_id``, ``conversation_id``,
    ``reserved_message_id``, ``request_id`` and
    ``agent_config["user_api_key"]`` are overwritten from the saved state.

    Args:
        tool_actions: Client-provided actions (approvals / results).
            Returned unchanged as the last tuple element.
        conversation_id: The conversation being resumed.

    Returns:
        Tuple of (agent, messages, tools_dict, pending_tool_calls, tool_actions).

    Raises:
        ValueError: If no pending state exists for the conversation, or
            if another resume has already claimed it.
        KeyError: If the saved state lacks ``messages``,
            ``pending_tool_calls``, ``tools_dict`` or ``agent_config``.
    """
    from application.api.answer.services.continuation_service import (
        ContinuationService,
    )
    from application.agents.agent_creator import AgentCreator
    from application.agents.tool_executor import ToolExecutor
    from application.llm.handlers.handler_creator import LLMHandlerCreator
    from application.llm.llm_creator import LLMCreator

    cont_service = ContinuationService()
    state = cont_service.load_state(conversation_id, self.initial_user_id)
    if not state:
        raise ValueError("No pending tool state found for this conversation")

    # Claim the resume up-front. ``mark_resuming`` only flips ``pending``
    # → ``resuming``; if it returns False, another resume already
    # claimed this row (status='resuming') — bail before any further
    # LLM/tool work to avoid double-execution. The cleanup janitor
    # reverts a stale ``resuming`` claim back to ``pending`` after the
    # 10-minute grace window so the user can retry.
    if not cont_service.mark_resuming(
        conversation_id, self.initial_user_id,
    ):
        raise ValueError(
            "Resume already in progress for this conversation; "
            "retry after the grace window if it stalls."
        )

    messages = state["messages"]
    pending_tool_calls = state["pending_tool_calls"]
    tools_dict = state["tools_dict"]
    tool_schemas = state.get("tool_schemas", [])
    agent_config = state["agent_config"]

    model_id = agent_config.get("model_id")
    # BYOM scope captured at initial dispatch. None for built-ins or
    # caller-owned BYOM where decoded_token['sub'] is already the
    # right scope; non-None for shared-agent owner BYOM where the
    # caller's identity differs from the model owner's.
    model_user_id = agent_config.get("model_user_id")

    # ``.get(key, default)`` only falls back when the key is absent, so a
    # saved ``llm_name`` of None/"" would otherwise reach the key lookup
    # and LLM factory as-is. Keep the raw value for the handler choice
    # (falsy → "default", unchanged) and coerce it to the configured
    # provider everywhere else, as ``create_agent`` does for its provider.
    saved_llm_name = agent_config.get("llm_name", settings.LLM_PROVIDER)
    llm_name = saved_llm_name or settings.LLM_PROVIDER

    api_key = agent_config.get("api_key")
    user_api_key = agent_config.get("user_api_key")
    agent_id = agent_config.get("agent_id")
    prompt = agent_config.get("prompt", "")
    json_schema = agent_config.get("json_schema")
    retriever_config = agent_config.get("retriever_config")

    # Recreate dependencies. A saved API key wins over the provider key.
    system_api_key = api_key or get_api_key_for_provider(llm_name)
    llm = LLMCreator.create_llm(
        llm_name,
        api_key=system_api_key,
        user_api_key=user_api_key,
        decoded_token=self.decoded_token,
        model_id=model_id,
        agent_id=agent_id,
        model_user_id=model_user_id,
    )
    llm_handler = LLMHandlerCreator.create_handler(saved_llm_name or "default")
    tool_executor = ToolExecutor(
        user_api_key=user_api_key,
        user=self.initial_user_id,
        decoded_token=self.decoded_token,
        agent_id=agent_id,
    )
    tool_executor.conversation_id = conversation_id

    # Restore client tools so they stay available for subsequent LLM calls
    saved_client_tools = state.get("client_tools")
    if saved_client_tools:
        tool_executor.client_tools = saved_client_tools
        # Re-merge into tools_dict (they may have been stripped during serialization)
        tool_executor.merge_client_tools(tools_dict, saved_client_tools)

    agent_type = agent_config.get("agent_type", "ClassicAgent")
    # Map class names back to agent creator keys; unknown names fall back
    # to the classic agent.
    type_map = {
        "ClassicAgent": "classic",
        "AgenticAgent": "agentic",
        "ResearchAgent": "research",
        "WorkflowAgent": "workflow",
    }
    agent_key = type_map.get(agent_type, "classic")

    agent_kwargs = {
        "endpoint": "stream",
        "llm_name": llm_name,
        "model_id": model_id,
        "model_user_id": model_user_id,
        "api_key": system_api_key,
        "agent_id": agent_id,
        "user_api_key": user_api_key,
        "prompt": prompt,
        # History is passed empty here; the saved ``messages`` are
        # returned to the caller separately.
        "chat_history": [],
        "decoded_token": self.decoded_token,
        "json_schema": json_schema,
        "llm": llm,
        "llm_handler": llm_handler,
        "tool_executor": tool_executor,
    }
    if agent_key in ("agentic", "research") and retriever_config:
        agent_kwargs["retriever_config"] = retriever_config

    agent = AgentCreator.create_agent(agent_key, **agent_kwargs)
    agent.conversation_id = conversation_id
    agent.initial_user_id = self.initial_user_id
    agent.tools = tool_schemas

    # Store config for the route layer
    self.model_id = model_id
    # Mirror ``model_user_id`` back onto the processor so the route
    # layer (StreamResource) reads the owner scope captured at
    # initial dispatch. Without this, ``processor.model_user_id``
    # stays at the __init__ default (None) and complete_stream
    # falls back to the caller's sub: the post-resume title-LLM
    # save misses the owner's BYOM layer, and any second tool
    # pause persists ``model_user_id=None`` — losing owner scope
    # for every subsequent resume of this conversation.
    self.model_user_id = model_user_id
    self.agent_id = agent_id
    self.agent_config["user_api_key"] = user_api_key
    self.conversation_id = conversation_id
    # Reused on resume so the same WAL row gets finalised and
    # request_id stays consistent across token_usage rows.
    self.reserved_message_id = agent_config.get("reserved_message_id")
    self.request_id = agent_config.get("request_id")

    return agent, messages, tools_dict, pending_tool_calls, tool_actions
def create_agent(
    self,
    docs_together: Optional[str] = None,
    docs: Optional[list] = None,
    tools_data: Optional[Dict[str, Any]] = None,
):
    """Create and return the configured agent with rendered prompt.

    Resolves the prompt template, renders it (or substitutes the caller's
    override when the agent allows it), builds the LLM, LLM handler and
    tool executor, then asks ``AgentCreator`` for an agent of the
    configured type.

    Args:
        docs_together: Forwarded to the prompt renderer.
        docs: Forwarded to the prompt renderer.
        tools_data: Forwarded to the prompt renderer.

    Returns:
        The agent built by ``AgentCreator.create_agent``, with
        ``conversation_id`` and ``initial_user_id`` set from this processor.

    Raises:
        KeyError: If ``self.agent_config`` lacks ``agent_type`` or
            ``user_api_key``.
    """
    agent_type = self.agent_config["agent_type"]

    # For agentic agents, swap standard presets for their agentic
    # counterparts (which include search tool instructions instead of
    # {summaries}). Custom / user-provided prompts pass through as-is.
    raw_prompt = self._get_prompt_content()
    if raw_prompt is None:
        prompt_id = self.agent_config.get("prompt_id", "default")
        agentic_presets = {"default", "creative", "strict"}
        if agent_type in ("agentic", "research") and prompt_id in agentic_presets:
            raw_prompt = get_prompt(
                f"agentic_{prompt_id}", self.prompts_collection
            )
        else:
            raw_prompt = get_prompt(prompt_id, self.prompts_collection)
        # Cache the resolved template on the processor.
        self._prompt_content = raw_prompt

    # Allow API callers to override the system prompt when the agent
    # has opted in via allow_system_prompt_override. The override is used
    # verbatim; it is not passed through the prompt renderer.
    if (
        self.agent_config.get("allow_system_prompt_override", False)
        and self.data.get("system_prompt_override")
    ):
        rendered_prompt = self.data["system_prompt_override"]
    else:
        rendered_prompt = self.prompt_renderer.render_prompt(
            prompt_content=raw_prompt,
            user_id=self.initial_user_id,
            request_id=self.data.get("request_id"),
            passthrough_data=self.data.get("passthrough"),
            docs=docs,
            docs_together=docs_together,
            tools_data=tools_data,
        )

    # Use the user_id that resolved the model so owner-scoped BYOM
    # records dispatch correctly on shared-agent requests. This single
    # resolved value is used for provider lookup, the LLM, the agent
    # kwargs and the retriever config; previously the last two read
    # ``self.model_user_id`` directly and could disagree with (or raise
    # where) the fallback applied.
    model_user_id = getattr(self, "model_user_id", self.initial_user_id)
    provider = (
        get_provider_from_model_id(self.model_id, user_id=model_user_id)
        if self.model_id
        else settings.LLM_PROVIDER
    )
    # Provider lookup may yield a falsy value; fall back to the configured one.
    llm_name = provider or settings.LLM_PROVIDER
    system_api_key = get_api_key_for_provider(llm_name)

    # Create LLM and handler (dependency injection)
    from application.llm.llm_creator import LLMCreator
    from application.llm.handlers.handler_creator import LLMHandlerCreator
    from application.agents.tool_executor import ToolExecutor

    # Compute backup models: agent's configured models minus the active one.
    # PG agents may carry an explicit ``models: NULL`` (not absent), so
    # ``.get("models", [])`` isn't enough — coerce None → [].
    agent_models = self.agent_config.get("models") or []
    backup_models = [m for m in agent_models if m != self.model_id]

    user_api_key = self.agent_config["user_api_key"]

    llm = LLMCreator.create_llm(
        llm_name,
        api_key=system_api_key,
        user_api_key=user_api_key,
        decoded_token=self.decoded_token,
        model_id=self.model_id,
        agent_id=self.agent_id,
        backup_models=backup_models,
        # Owner-scope on shared-agent BYOM dispatch.
        model_user_id=model_user_id,
    )
    # The handler uses the raw provider value ("default" when falsy), not
    # the ``settings.LLM_PROVIDER`` fallback applied to ``llm_name``.
    llm_handler = LLMHandlerCreator.create_handler(
        provider if provider else "default"
    )

    user = self.decoded_token.get("sub") if self.decoded_token else None
    tool_executor = ToolExecutor(
        user_api_key=user_api_key,
        user=user,
        decoded_token=self.decoded_token,
        agent_id=self.agent_id,
    )
    tool_executor.conversation_id = self.conversation_id

    # Pass client-side tools so they get merged in get_tools()
    client_tools = self.data.get("client_tools")
    if client_tools:
        tool_executor.client_tools = client_tools

    agent_kwargs = {
        "endpoint": "stream",
        "llm_name": llm_name,
        "model_id": self.model_id,
        "model_user_id": model_user_id,
        "api_key": system_api_key,
        "agent_id": self.agent_id,
        "user_api_key": user_api_key,
        "prompt": rendered_prompt,
        "chat_history": self.history,
        "retrieved_docs": self.retrieved_docs,
        "decoded_token": self.decoded_token,
        "attachments": self.attachments,
        "json_schema": self.agent_config.get("json_schema"),
        "compressed_summary": self.compressed_summary,
        "llm": llm,
        "llm_handler": llm_handler,
        "tool_executor": tool_executor,
    }

    # Type-specific kwargs
    if agent_type in ("agentic", "research"):
        agent_kwargs["retriever_config"] = {
            "source": self.source,
            "retriever_name": self.retriever_config.get(
                "retriever_name", "classic"
            ),
            "chunks": self.retriever_config.get("chunks", 2),
            "doc_token_limit": self.retriever_config.get(
                "doc_token_limit", 50000
            ),
            "model_id": self.model_id,
            "model_user_id": model_user_id,
            "user_api_key": user_api_key,
            "agent_id": self.agent_id,
            "llm_name": llm_name,
            "api_key": system_api_key,
            "decoded_token": self.decoded_token,
        }
    elif agent_type == "workflow":
        # ``workflow`` may be a string (passed on as ``workflow_id``) or a
        # dict (passed on as ``workflow``); any other type is ignored.
        workflow_config = self.agent_config.get("workflow")
        if isinstance(workflow_config, str):
            agent_kwargs["workflow_id"] = workflow_config
        elif isinstance(workflow_config, dict):
            agent_kwargs["workflow"] = workflow_config
        workflow_owner = self.agent_config.get("workflow_owner")
        if workflow_owner:
            agent_kwargs["workflow_owner"] = workflow_owner

    agent = AgentCreator.create_agent(agent_type, **agent_kwargs)
    agent.conversation_id = self.conversation_id
    agent.initial_user_id = self.initial_user_id
    return agent