mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-22 05:15:08 +00:00
1252 lines
52 KiB
Python
1252 lines
52 KiB
Python
import datetime
|
||
import json
|
||
import logging
|
||
import os
|
||
from pathlib import Path
|
||
from typing import Any, Dict, Optional, Set
|
||
|
||
from application.agents.agent_creator import AgentCreator
|
||
from application.agents.default_tools import synthesized_default_tools
|
||
from application.api.answer.services.compression import CompressionOrchestrator
|
||
from application.api.answer.services.compression.token_counter import TokenCounter
|
||
from application.api.answer.services.conversation_service import ConversationService
|
||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||
from application.core.model_utils import (
|
||
get_api_key_for_provider,
|
||
get_default_model_id,
|
||
get_provider_from_model_id,
|
||
validate_model_id,
|
||
)
|
||
from application.core.settings import settings
|
||
from sqlalchemy import text as sql_text
|
||
|
||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||
from application.storage.db.repositories.agents import AgentsRepository
|
||
from application.storage.db.repositories.attachments import AttachmentsRepository
|
||
from application.storage.db.repositories.prompts import PromptsRepository
|
||
from application.storage.db.repositories.sources import SourcesRepository
|
||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||
from application.storage.db.repositories.users import UsersRepository
|
||
from application.storage.db.session import db_readonly, db_session
|
||
from application.retriever.retriever_creator import RetrieverCreator
|
||
from application.utils import (
|
||
calculate_doc_token_budget,
|
||
limit_chat_history,
|
||
)
|
||
|
||
# Module-scoped logger named after this module's import path (``__name__``),
# so handlers can filter or tune this module's output independently.
logger: logging.Logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
def get_prompt(prompt_id: Any, prompts_collection=None) -> str:
    """Return prompt text for a preset name or a stored prompt ID.

    ``prompt_id`` may be:

    * ``None`` or ``""``: treated as ``"default"``;
    * a classic preset (``default``, ``creative``, ``strict``, ``reduce``)
      or an agentic preset (``agentic_default``, ``agentic_creative``,
      ``agentic_strict``), read from the repository's ``prompts`` directory;
    * a Postgres UUID or a legacy ObjectId string (non-string values such
      as ``uuid.UUID`` are stringified first), looked up through
      ``PromptsRepository``.

    The ``prompts_collection`` parameter is retained for backwards
    compatibility with call sites that still pass it positionally; it is
    ignored post-cutover.

    Raises:
        FileNotFoundError: a preset name maps to a file that does not exist.
        ValueError: the ID is unknown, or the database lookup failed.
    """
    del prompts_collection  # unused — retained for call-site compatibility

    # ``None`` / empty means "use the default prompt" — agents that never
    # set a custom prompt land here (PG ``agents.prompt_id`` is NULL).
    if prompt_id is None or prompt_id == "":
        prompt_id = "default"
    elif not isinstance(prompt_id, str):
        prompt_id = str(prompt_id)

    prompts_dir = Path(__file__).resolve().parents[3] / "prompts"

    classic_presets = {
        "default": "chat_combine_default.txt",
        "creative": "chat_combine_creative.txt",
        "strict": "chat_combine_strict.txt",
        "reduce": "chat_reduce_prompt.txt",
    }
    agentic_presets = {
        "default": "agentic/default.txt",
        "creative": "agentic/creative.txt",
        "strict": "agentic/strict.txt",
    }
    preset_mapping = {
        **classic_presets,
        **{f"agentic_{k}": v for k, v in agentic_presets.items()},
    }

    preset_file = preset_mapping.get(prompt_id)
    if preset_file is not None:
        file_path = os.path.join(prompts_dir, preset_file)
        try:
            # Explicit UTF-8: the platform's default text encoding is
            # locale-dependent and may not decode non-ASCII prompt text.
            with open(file_path, "r", encoding="utf-8") as f:
                return f.read()
        except FileNotFoundError as exc:
            raise FileNotFoundError(f"Prompt file not found: {file_path}") from exc

    try:
        with db_readonly() as conn:
            repo = PromptsRepository(conn)
            prompt_doc = None
            if looks_like_uuid(prompt_id):
                prompt_doc = repo.get_for_rendering(prompt_id)
            if prompt_doc is None:
                # Non-UUID ids, or UUIDs with no match, are tried as legacy ids.
                prompt_doc = repo.get_by_legacy_id(prompt_id)
            if not prompt_doc:
                raise ValueError(f"Prompt with ID {prompt_id} not found")
            return prompt_doc["content"]
    except ValueError:
        raise
    except Exception as e:
        raise ValueError(f"Invalid prompt ID: {prompt_id}") from e
|
||
|
||
|
||
class StreamProcessor:
|
||
def __init__(
    self, request_data: Dict[str, Any], decoded_token: Optional[Dict[str, Any]]
):
    """Store the request payload and reset all per-request processing state.

    Args:
        request_data: Parsed request body. ``conversation_id`` and
            ``agent_id`` are read from it immediately.
        decoded_token: Decoded auth token or ``None``. Its ``sub`` claim
            seeds ``initial_user_id``.
    """
    # Legacy attribute retained as None for any external callers that
    # introspect the processor; all DB access uses per-op connections.
    self.prompts_collection = None

    # Raw inputs and the caller's identity.
    self.data = request_data
    self.decoded_token = decoded_token
    self.initial_user_id = (
        None if decoded_token is None else decoded_token.get("sub")
    )
    self.conversation_id = request_data.get("conversation_id")

    # Retrieval / context state, filled in during initialize().
    self.source = {}
    self.all_sources = []
    self.attachments = []
    self.history = []
    self.retrieved_docs = []
    self.agent_config = {}
    self.retriever_config = {}

    # Agent access and sharing state.
    self.is_shared_usage = False
    self.shared_token = None
    self.agent_id = request_data.get("agent_id")
    self.agent_key = None

    # Model selection; model_user_id is the BYOM-resolution scope set by
    # _validate_and_set_model.
    self.model_id: Optional[str] = None
    self.model_user_id: Optional[str] = None
    # WAL placeholder id pulled from continuation state on resume.
    self.reserved_message_id: Optional[str] = None
    # Carried through resumes so multi-pause runs keep one request_id.
    self.request_id: Optional[str] = None

    # Collaborating services.
    conversation_service = ConversationService()
    self.conversation_service = conversation_service
    self.compression_orchestrator = CompressionOrchestrator(conversation_service)
    self.prompt_renderer = PromptRenderer()

    # Lazily computed caches and compression results.
    self._prompt_content: Optional[str] = None
    self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
    self.compressed_summary: Optional[str] = None
    self.compressed_summary_tokens: int = 0
    self._agent_data: Optional[Dict[str, Any]] = None
|
||
|
||
def initialize(self):
    """Run every setup step needed before the processor can answer.

    The order is significant: model selection reads agent state set by
    agent configuration, and retriever setup reads the chosen model.
    """
    setup_steps = (
        self._configure_agent,
        self._validate_and_set_model,
        self._configure_source,
        self._configure_retriever,
        self._load_conversation_history,
        self._process_attachments,
    )
    for step in setup_steps:
        step()
|
||
|
||
def build_agent(self, question: str):
    """Initialise the processor and return an agent ready to run.

    Convenience wrapper around initialize(), pre_fetch_docs(),
    pre_fetch_tools() and create_agent(). Agentic and research agents
    search on demand through tools, so document pre-fetching is skipped
    for them.
    """
    self.initialize()

    if self.agent_config.get("agent_type", "classic") in ("agentic", "research"):
        return self.create_agent(tools_data=self.pre_fetch_tools())

    docs_together, docs_list = self.pre_fetch_docs(question)
    return self.create_agent(
        docs_together=docs_together,
        docs=docs_list,
        tools_data=self.pre_fetch_tools(),
    )
|
||
|
||
def _load_conversation_history(self):
    """Populate ``self.history`` from the stored conversation or the request.

    With a ``conversation_id`` and an authenticated user, history comes from
    the stored conversation. It is routed through ``_handle_compression``
    when ``settings.ENABLE_CONVERSATION_COMPRESSION`` is on. Otherwise the
    request's ``history`` field is used, trimmed by ``limit_chat_history``.

    Raises:
        ValueError: the conversation does not exist or is not accessible
            to ``initial_user_id``.
    """
    if self.conversation_id and self.initial_user_id:
        conversation = self.conversation_service.get_conversation(
            self.conversation_id, self.initial_user_id
        )
        if not conversation:
            raise ValueError("Conversation not found or unauthorized")

        if settings.ENABLE_CONVERSATION_COMPRESSION:
            self._handle_compression(conversation)
        else:
            # Full stored history, keeping per-turn metadata when present.
            self.history = [
                {
                    "prompt": query["prompt"],
                    "response": query["response"],
                    **({"metadata": query["metadata"]} if "metadata" in query else {}),
                }
                for query in conversation.get("queries", [])
            ]
        return

    # Request-supplied history. Legacy clients send a JSON-encoded string,
    # while API clients may send an already-decoded list. A missing, null
    # or empty value means "no history" instead of a json.loads error.
    raw_history = self.data.get("history")
    if raw_history is None or raw_history == "":
        history = []
    elif isinstance(raw_history, list):
        history = list(raw_history)  # copy: don't hand the request's list onward
    else:
        history = json.loads(raw_history)

    # model_user_id keeps history trim aligned with the BYOM's
    # actual context window instead of the default 128k.
    self.history = limit_chat_history(
        history,
        model_id=self.model_id,
        user_id=self.model_user_id,
    )
|
||
|
||
def _handle_compression(self, conversation: Dict[str, Any]):
|
||
"""Handle conversation compression logic using orchestrator."""
|
||
try:
|
||
# initial_user_id for conversation access; model_user_id
|
||
# for BYOM context-window / provider lookups.
|
||
result = self.compression_orchestrator.compress_if_needed(
|
||
conversation_id=self.conversation_id,
|
||
user_id=self.initial_user_id,
|
||
model_user_id=self.model_user_id,
|
||
model_id=self.model_id,
|
||
decoded_token=self.decoded_token,
|
||
)
|
||
|
||
if not result.success:
|
||
logger.error(f"Compression failed: {result.error}, using full history")
|
||
self.history = [
|
||
{
|
||
"prompt": query["prompt"],
|
||
"response": query["response"],
|
||
**({"metadata": query["metadata"]} if "metadata" in query else {}),
|
||
}
|
||
for query in conversation.get("queries", [])
|
||
]
|
||
return
|
||
|
||
if result.compression_performed and result.compressed_summary:
|
||
self.compressed_summary = result.compressed_summary
|
||
self.compressed_summary_tokens = TokenCounter.count_message_tokens(
|
||
[{"content": result.compressed_summary}]
|
||
)
|
||
logger.info(
|
||
f"Using compressed summary ({self.compressed_summary_tokens} tokens) "
|
||
f"+ {len(result.recent_queries)} recent messages"
|
||
)
|
||
|
||
self.history = result.as_history()
|
||
# Preserve metadata from recent queries (as_history only has prompt/response)
|
||
recent = result.recent_queries if result.recent_queries else conversation.get("queries", [])
|
||
for i, entry in enumerate(self.history):
|
||
# Match by index from the end of recent queries
|
||
offset = len(recent) - len(self.history)
|
||
qi = offset + i
|
||
if 0 <= qi < len(recent) and "metadata" in recent[qi]:
|
||
entry["metadata"] = recent[qi]["metadata"]
|
||
|
||
except Exception as e:
|
||
logger.error(
|
||
f"Error handling compression, falling back to standard history: {str(e)}",
|
||
exc_info=True,
|
||
)
|
||
self.history = [
|
||
{
|
||
"prompt": query["prompt"],
|
||
"response": query["response"],
|
||
**({"metadata": query["metadata"]} if "metadata" in query else {}),
|
||
}
|
||
for query in conversation.get("queries", [])
|
||
]
|
||
|
||
def _process_attachments(self):
|
||
"""Process any attachments in the request"""
|
||
attachment_ids = self.data.get("attachments", [])
|
||
self.attachments = self._get_attachments_content(
|
||
attachment_ids, self.initial_user_id
|
||
)
|
||
|
||
def _get_attachments_content(self, attachment_ids, user_id):
|
||
if not attachment_ids:
|
||
return []
|
||
attachments = []
|
||
try:
|
||
with db_readonly() as conn:
|
||
repo = AttachmentsRepository(conn)
|
||
for attachment_id in attachment_ids:
|
||
try:
|
||
attachment_doc = repo.get_any(str(attachment_id), user_id)
|
||
if attachment_doc:
|
||
attachments.append(attachment_doc)
|
||
except Exception as e:
|
||
logger.error(
|
||
f"Error retrieving attachment {attachment_id}: {e}",
|
||
exc_info=True,
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"Error opening attachments connection: {e}", exc_info=True)
|
||
return attachments
|
||
|
||
def _validate_and_set_model(self):
    """Choose ``model_id`` and the BYOM scope (``model_user_id``) to resolve it in.

    Agent-bound chats: the agent's ``default_model_id`` wins. It is
    validated in the agent owner's model layer, and any ``model_id`` in the
    request body is ignored. An unusable agent default falls back to the
    system default.

    Agentless chats: a requested ``model_id`` must be valid in the
    caller's layer; with no request, the system default is used.

    Raises:
        ValueError: an agentless request names a model the caller cannot use.
    """
    from application.core.model_settings import ModelRegistry

    requested_model = self.data.get("model_id")
    # Caller picks from their own BYOM layer; agent defaults resolve under
    # the owner's layer (shared agents have caller != owner).
    caller_user_id = self.initial_user_id
    owner_user_id = self.agent_config.get("user_id") or caller_user_id

    if self._agent_data is not None:
        agent_default_model = self.agent_config.get("default_model_id", "")
        if agent_default_model and validate_model_id(
            agent_default_model, user_id=owner_user_id
        ):
            self.model_id, self.model_user_id = agent_default_model, owner_user_id
        else:
            self.model_id, self.model_user_id = get_default_model_id(), None
        return

    if not requested_model:
        self.model_id, self.model_user_id = get_default_model_id(), None
        return

    if not validate_model_id(requested_model, user_id=caller_user_id):
        registry = ModelRegistry.get_instance()
        available_models = [
            model.id for model in registry.get_enabled_models(user_id=caller_user_id)
        ]
        overflow_note = (
            f" and {len(available_models) - 5} more"
            if len(available_models) > 5
            else ""
        )
        raise ValueError(
            f"Invalid model_id '{requested_model}'. "
            f"Available models: {', '.join(available_models[:5])}"
            + overflow_note
        )

    self.model_id, self.model_user_id = requested_model, caller_user_id
|
||
|
||
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
|
||
"""Get API key for agent with access control."""
|
||
if not agent_id:
|
||
return None, False, None
|
||
try:
|
||
with db_readonly() as conn:
|
||
# Lookup without user scoping — access control is done
|
||
# against ``user_id`` / ``shared_with`` / ``shared`` flags
|
||
# right below, matching the legacy Mongo semantics.
|
||
repo = AgentsRepository(conn)
|
||
agent = None
|
||
if looks_like_uuid(str(agent_id)):
|
||
result = conn.execute(
|
||
sql_text(
|
||
"SELECT * FROM agents WHERE id = CAST(:id AS uuid)"
|
||
),
|
||
{"id": str(agent_id)},
|
||
)
|
||
row = result.fetchone()
|
||
if row is not None:
|
||
agent = row_to_dict(row)
|
||
if agent is None:
|
||
agent = repo.get_by_legacy_id(str(agent_id))
|
||
if agent is None:
|
||
raise Exception("Agent not found")
|
||
agent_owner = agent.get("user_id")
|
||
is_owner = agent_owner == user_id
|
||
is_shared_with_user = bool(agent.get("shared", False))
|
||
|
||
if not (is_owner or is_shared_with_user):
|
||
raise Exception("Unauthorized access to the agent")
|
||
if is_owner:
|
||
now = datetime.datetime.now(datetime.timezone.utc)
|
||
try:
|
||
with db_session() as conn:
|
||
AgentsRepository(conn).update(
|
||
str(agent["id"]), agent_owner,
|
||
{"last_used_at": now},
|
||
)
|
||
except Exception:
|
||
logger.warning(
|
||
"Failed to update last_used_at for agent",
|
||
exc_info=True,
|
||
)
|
||
return (
|
||
str(agent["key"]) if agent.get("key") else None,
|
||
not is_owner,
|
||
agent.get("shared_token"),
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
|
||
raise
|
||
|
||
def _get_data_from_api_key(self, api_key: str) -> Dict[str, Any]:
    """Resolve agent metadata and its active source set from an API key.

    Returns a copy of the agent row plus these keys:

    * ``user``: the owner id, under its legacy identity key;
    * ``source``: the primary source id, or ``None``;
    * ``sources``: primary first, then extra sources, deduplicated, each as
      ``{"id", "retriever", "chunks"}``;
    * ``default_model_id``: never ``None``.

    A primary source's own ``retriever``/``chunks`` override the agent's
    values when they are set.

    Raises:
        Exception: no agent matches ``api_key`` (the args carry status 401).
    """
    with db_readonly() as conn:
        agent = AgentsRepository(conn).find_by_key(api_key)
        if not agent:
            raise Exception("Invalid API Key, please generate a new key", 401)
        sources_repo = SourcesRepository(conn)

        # The repo dict uses "user_id" — the streaming path expects
        # a "user" key (legacy Mongo shape) for identity propagation.
        data: Dict[str, Any] = dict(agent)
        owner = agent.get("user_id")
        data["user"] = owner

        def agent_chunks_fallback():
            # Read at call time: the primary source may have replaced
            # ``data["chunks"]``. ``dict.get(k, default)`` returns None when
            # the key exists with value None, so test for None explicitly.
            current = data.get("chunks")
            return current if current is not None else "2"

        def source_entry(sid, src_retriever, src_chunks):
            # One ``data["sources"]`` entry; ``sources`` rows may have NULL
            # ``retriever``/``chunks``, which fall back to defaults.
            return {
                "id": sid,
                "retriever": src_retriever or "classic",
                "chunks": (
                    src_chunks if src_chunks is not None else agent_chunks_fallback()
                ),
            }

        # Active sources = primary ∪ extras, primary first, deduplicated.
        # ``_configure_source`` ignores an empty ``data["sources"]``, so the
        # primary must appear in the union too — not only in ``data["source"]``.
        sources_list: list = []
        seen: set = set()
        data["source"] = None

        primary_id = agent.get("source_id")
        primary_doc = sources_repo.get(str(primary_id), owner) if primary_id else None
        if primary_doc:
            sid = str(primary_doc["id"])
            src_retriever = primary_doc.get("retriever")
            src_chunks = primary_doc.get("chunks")
            data["source"] = sid
            if src_retriever:
                data["retriever"] = src_retriever
            if src_chunks is not None:
                data["chunks"] = src_chunks
            sources_list.append(source_entry(sid, src_retriever, src_chunks))
            seen.add(sid)

        for sid_raw in agent.get("extra_source_ids") or []:
            if not sid_raw:
                continue
            source_doc = sources_repo.get(str(sid_raw), owner)
            if not source_doc:
                continue
            sid = str(source_doc["id"])
            if sid in seen:
                continue
            sources_list.append(
                source_entry(sid, source_doc.get("retriever"), source_doc.get("chunks"))
            )
            seen.add(sid)

        data["sources"] = sources_list
        # Normalise NULL as well as a missing key to "".
        data["default_model_id"] = data.get("default_model_id") or ""
        return data
|
||
|
||
def _configure_source(self):
|
||
"""Configure the source based on agent data.
|
||
|
||
The literal string ``"default"`` is a placeholder meaning "no
|
||
ingested source" and is normalized to an empty source so that no
|
||
retrieval is attempted.
|
||
"""
|
||
if self._agent_data:
|
||
agent_data = self._agent_data
|
||
|
||
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
|
||
source_ids = [
|
||
source["id"]
|
||
for source in agent_data["sources"]
|
||
if source.get("id") and source["id"] != "default"
|
||
]
|
||
if source_ids:
|
||
self.source = {"active_docs": source_ids}
|
||
else:
|
||
self.source = {}
|
||
self.all_sources = [
|
||
s for s in agent_data["sources"] if s.get("id") != "default"
|
||
]
|
||
elif agent_data.get("source") and agent_data["source"] != "default":
|
||
self.source = {"active_docs": agent_data["source"]}
|
||
self.all_sources = [
|
||
{
|
||
"id": agent_data["source"],
|
||
"retriever": agent_data.get("retriever", "classic"),
|
||
}
|
||
]
|
||
else:
|
||
self.source = {}
|
||
self.all_sources = []
|
||
return
|
||
if "active_docs" in self.data:
|
||
active_docs = self.data["active_docs"]
|
||
if active_docs and active_docs != "default":
|
||
self.source = {"active_docs": active_docs}
|
||
else:
|
||
self.source = {}
|
||
return
|
||
self.source = {}
|
||
self.all_sources = []
|
||
|
||
def _has_active_docs(self) -> bool:
|
||
"""Return True if a real document source is configured for retrieval."""
|
||
active_docs = self.source.get("active_docs") if self.source else None
|
||
if not active_docs:
|
||
return False
|
||
if active_docs == "default":
|
||
return False
|
||
return True
|
||
|
||
def _resolve_agent_id(self) -> Optional[str]:
|
||
"""Resolve agent_id from request, then fall back to conversation context."""
|
||
request_agent_id = self.data.get("agent_id")
|
||
if request_agent_id:
|
||
return str(request_agent_id)
|
||
|
||
if not self.conversation_id or not self.initial_user_id:
|
||
return None
|
||
|
||
try:
|
||
conversation = self.conversation_service.get_conversation(
|
||
self.conversation_id, self.initial_user_id
|
||
)
|
||
except Exception:
|
||
return None
|
||
|
||
if not conversation:
|
||
return None
|
||
|
||
conversation_agent_id = conversation.get("agent_id")
|
||
if conversation_agent_id:
|
||
return str(conversation_agent_id)
|
||
|
||
return None
|
||
|
||
def _configure_agent(self):
    """Resolve the effective agent for this request and fill ``agent_config``.

    The effective API key is the request's explicit ``api_key`` if given,
    otherwise the key of the agent from ``_resolve_agent_id``, obtained via
    the access checks in ``_get_agent_key``.

    With a key, settings come from the agent record and identity is set:

    * an explicit API key acts as the key's owner;
    * a shared agent keeps the caller's identity;
    * an owner keeps their own.

    Without a key, request-level settings are used, including an inline
    ``workflow`` dict when one is provided.
    """
    agent_id = self._resolve_agent_id()

    self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
        agent_id, self.initial_user_id
    )
    self.agent_id = str(agent_id) if agent_id else None

    # Determine the effective API key (explicit > agent-derived).
    effective_key = self.data.get("api_key") or self.agent_key

    if effective_key:
        agent_data = self._get_data_from_api_key(effective_key)
        self._agent_data = agent_data

        # Records may carry ``_id`` (legacy shape) and/or ``id`` (PG row).
        # Accept either, so key-only requests still record the agent used.
        record_id = agent_data.get("_id") or agent_data.get("id")
        if record_id:
            self.agent_id = str(record_id)

        owner = agent_data.get("user")
        self.agent_config.update(
            {
                "prompt_id": agent_data.get("prompt_id", "default"),
                "agent_type": agent_data.get("agent_type", settings.AGENT_NAME),
                "user_api_key": effective_key,
                "json_schema": agent_data.get("json_schema"),
                "default_model_id": agent_data.get("default_model_id", ""),
                "models": agent_data.get("models", []),
                "allow_system_prompt_override": agent_data.get(
                    "allow_system_prompt_override", False
                ),
                # Owner identity — _validate_and_set_model reads this to
                # resolve owner-stored BYOM default_model_id against the
                # owner's per-user model layer rather than the caller's.
                "user_id": owner,
            }
        )

        # Set identity context.
        if self.data.get("api_key"):
            # External API key: use the key owner's identity.
            self.initial_user_id = owner
            self.decoded_token = {"sub": owner}
        elif not self.is_shared_usage:
            # Owner using their own agent.
            self.decoded_token = {"sub": owner}
        # Shared agent (no explicit key): keep the caller's identity.

        # PG row exposes the workflow as ``workflow_id`` (UUID column);
        # legacy Mongo shape used the key ``workflow``. Accept either so
        # API-key-invoked workflow agents bind correctly downstream.
        wf_ref = agent_data.get("workflow") or agent_data.get("workflow_id")
        if wf_ref:
            self.agent_config["workflow"] = str(wf_ref)
            self.agent_config["workflow_owner"] = owner
        return

    # No API key — default/workflow configuration.
    agent_type = settings.AGENT_NAME
    inline_workflow = self.data.get("workflow")
    if inline_workflow and isinstance(inline_workflow, dict):
        agent_type = "workflow"
        self.agent_config["workflow"] = inline_workflow
        if isinstance(self.decoded_token, dict):
            self.agent_config["workflow_owner"] = self.decoded_token.get("sub")

    self.agent_config.update(
        {
            "prompt_id": self.data.get("prompt_id", "default"),
            "agent_type": agent_type,
            "user_api_key": None,
            "json_schema": None,
            "default_model_id": "",
        }
    )
|
||
|
||
def _configure_retriever(self):
    """Build ``self.retriever_config`` (retriever name, chunk count, token budget).

    When an agent is bound, its ``retriever``/``chunks`` are authoritative
    and the request body's values are ignored. Otherwise the request's
    values are used. Unparseable chunk values are logged and the default
    of 2 is kept. Requests without any API key that set ``isNoneDoc`` get
    ``chunks = 0``, i.e. no retrieval.
    """
    # BYOM scope: owner for shared-agent BYOM, caller for own BYOM, None for
    # built-ins. Passing ``user_id`` keeps the doc budget inside a small
    # BYOM's context window instead of settings.DEFAULT_LLM_TOKEN_LIMIT.
    doc_token_limit = calculate_doc_token_budget(
        model_id=self.model_id, user_id=self.model_user_id
    )

    chunks = 2
    agent_data = self._agent_data
    if agent_data is not None:
        retriever_name = agent_data.get("retriever") or "classic"
        if agent_data.get("chunks") is not None:
            try:
                chunks = int(agent_data["chunks"])
            except (ValueError, TypeError):
                logger.warning(
                    f"Invalid agent chunks value: {agent_data['chunks']}, "
                    "using default value 2"
                )
    else:
        retriever_name = self.data.get("retriever", "classic")
        if "chunks" in self.data:
            try:
                chunks = int(self.data["chunks"])
            except (ValueError, TypeError):
                logger.warning(
                    f"Invalid request chunks value: {self.data['chunks']}, "
                    "using default value 2"
                )

    self.retriever_config = {
        "retriever_name": retriever_name,
        "chunks": chunks,
        "doc_token_limit": doc_token_limit,
    }

    has_api_key = self.data.get("api_key") or self.agent_key
    if not has_api_key and self.data.get("isNoneDoc"):
        self.retriever_config["chunks"] = 0
|
||
|
||
def create_retriever(self):
    """Instantiate the retriever described by ``retriever_config`` for this request.

    ``doc_token_limit`` defaults to 50000 when the config lacks it.
    """
    config = self.retriever_config
    retriever_name = config["retriever_name"]
    return RetrieverCreator.create_retriever(
        retriever_name,
        source=self.source,
        chat_history=self.history,
        prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
        chunks=config["chunks"],
        doc_token_limit=config.get("doc_token_limit", 50000),
        model_id=self.model_id,
        model_user_id=self.model_user_id,
        user_api_key=self.agent_config["user_api_key"],
        agent_id=self.agent_id,
        decoded_token=self.decoded_token,
    )
|
||
|
||
def pre_fetch_docs(self, question: str) -> tuple[Optional[str], Optional[list]]:
    """Retrieve documents for ``question`` before the agent is built.

    Returns ``(docs_together, docs)``. ``docs_together`` is the chunks joined
    by blank lines, each prefixed with its filename/title/source line when
    one exists. ``docs`` is the raw result list, which is also stored on
    ``self.retrieved_docs``.

    Returns ``(None, None)`` when ``isNoneDoc`` is set without an agent,
    when no documents are configured, when nothing is found, or when
    retrieval fails (the error is logged).
    """
    if self.data.get("isNoneDoc", False) and not self.agent_id:
        logger.info("Pre-fetch skipped: isNoneDoc=True")
        return None, None
    if not self._has_active_docs():
        logger.info("Pre-fetch skipped: no active docs configured")
        return None, None

    def render(doc):
        # Prefix the chunk with a header line naming where it came from.
        label = doc.get("filename") or doc.get("title") or doc.get("source")
        if not label:
            return doc["text"]
        return f"{str(label)}\n{doc['text']}"

    try:
        retriever = self.create_retriever()
        logger.info(
            f"Pre-fetching docs with chunks={retriever.chunks}, doc_token_limit={retriever.doc_token_limit}"
        )
        docs = retriever.search(question)
        logger.info(f"Pre-fetch retrieved {len(docs) if docs else 0} documents")

        if not docs:
            logger.info("Pre-fetch: No documents returned from search")
            return None, None
        self.retrieved_docs = docs

        docs_together = "\n\n".join([render(doc) for doc in docs])
        logger.info(f"Pre-fetch docs_together size: {len(docs_together)} chars")
        return docs_together, docs
    except Exception as e:
        logger.error(f"Failed to pre-fetch docs: {str(e)}", exc_info=True)
        return None, None
|
||
|
||
def pre_fetch_tools(self) -> Optional[Dict[str, Any]]:
    """Run saved tool actions up front so their output can fill the prompt template.

    When the prompt references tools, only referenced tools/actions run.
    Otherwise every non-default tool runs. Agentless requests also consider
    the synthesized default tools.

    Returns a mapping from tool name and tool id to that tool's action
    results (default tools are keyed by id only). Returns ``None`` when
    prefetch is disabled, nothing produced data, or loading failed; a
    failure is logged as a warning with the exception type only.
    """
    if not settings.ENABLE_TOOL_PREFETCH:
        logger.info(
            "Tool pre-fetching disabled globally via ENABLE_TOOL_PREFETCH setting"
        )
        return None
    if self.data.get("disable_tool_prefetch", False):
        logger.info("Tool pre-fetching disabled for this request")
        return None

    # ``None`` means the template imposes no filter.
    template_requirements = self._get_required_tool_actions()

    try:
        owner_id = self.initial_user_id or "local"
        is_agentless = self.agent_id is None
        with db_readonly() as conn:
            active_tools = UserToolsRepository(conn).list_active_for_user(owner_id)
            user_record = (
                UsersRepository(conn).get(owner_id) if is_agentless else None
            )

        defaults = synthesized_default_tools(user_record) if is_agentless else []
        candidates = list(active_tools) + defaults
        if not candidates:
            return None

        collected = {}
        for tool_doc in candidates:
            name = tool_doc.get("name")
            ident = str(tool_doc.get("_id") or tool_doc.get("id"))
            is_default = bool(tool_doc.get("default"))

            if template_requirements is None:
                # No template names a default tool, so running its
                # actions blind would only inject noise.
                if is_default:
                    continue
                wanted = None
            else:
                wanted = template_requirements.get(name, set()) | template_requirements.get(
                    ident, set()
                )
                if not wanted:
                    continue

            results = self._fetch_tool_data(tool_doc, wanted)
            if not results:
                continue
            # Defaults reachable by synthetic id only — the name
            # key stays bound to an explicit row of the same name.
            if not is_default:
                collected[name] = results
            collected[ident] = results

        return collected or None
    except Exception as e:
        logger.warning(f"Failed to pre-fetch tools: {type(e).__name__}")
        return None
|
||
|
||
def _fetch_tool_data(
    self,
    tool_doc: Dict[str, Any],
    required_actions: Optional[Set[Optional[str]]],
) -> Optional[Dict[str, Any]]:
    """Load a user tool and execute its actions with saved parameters.

    Args:
        tool_doc: Tool record. Reads ``name``, ``config``, ``actions`` and
            the record id (``_id``, falling back to ``id``).
        required_actions: Actions to run. ``None``, an empty set, or a set
            containing ``None`` runs every action. Otherwise only actions
            whose names are in the set run.

    Returns:
        ``{action_name: result}`` for each action that executed. Returns
        ``None`` if the tool could not be loaded, exposes no actions, no
        action succeeded, or an unexpected error occurred. Failures are
        logged at debug level with the exception type only.
    """
    # Bound before ``try`` so the handler's log line can always reference
    # it, even when the first statement in the ``try`` (the import) fails.
    tool_name = tool_doc.get("name")
    try:
        from application.agents.tools.tool_manager import ToolManager

        # Stored ``config`` may be NULL; treat it like an empty config.
        tool_config = dict(tool_doc.get("config") or {})
        # Records may carry ``_id`` and/or ``id``; accept either.
        tool_config["tool_id"] = str(tool_doc.get("_id") or tool_doc.get("id"))

        tool_manager = ToolManager(config={tool_name: tool_config})
        user_id = self.initial_user_id or "local"
        tool = tool_manager.load_tool(tool_name, tool_config, user_id=user_id)
        if not tool:
            logger.debug(f"Tool '{tool_name}' failed to load")
            return None

        tool_actions = tool.get_actions_metadata()
        if not tool_actions:
            logger.debug(f"Tool '{tool_name}' has no actions")
            return None

        saved_actions = tool_doc.get("actions") or []

        include_all_actions = required_actions is None or (
            required_actions and None in required_actions
        )
        allowed_actions: Set[str] = (
            {action for action in required_actions if isinstance(action, str)}
            if required_actions
            else set()
        )

        action_results = {}
        for action_meta in tool_actions:
            action_name = action_meta.get("name")
            if action_name is None:
                continue
            # An empty ``allowed_actions`` means "no restriction".
            if (
                not include_all_actions
                and allowed_actions
                and action_name not in allowed_actions
            ):
                continue

            try:
                # First saved entry with a matching name wins.
                saved_action = next(
                    (sa for sa in saved_actions if sa.get("name") == action_name),
                    None,
                )
                saved_props = (
                    ((saved_action.get("parameters") or {}).get("properties") or {})
                    if saved_action
                    else {}
                )
                properties = (
                    (action_meta.get("parameters") or {}).get("properties") or {}
                )

                # Precedence: a non-None saved value, then the tool config,
                # then the schema's declared default.
                kwargs = {}
                for param_name, param_spec in properties.items():
                    if param_name in saved_props:
                        param_value = saved_props[param_name].get("value")
                        if param_value is not None:
                            kwargs[param_name] = param_value
                            continue
                    if param_name in tool_config:
                        kwargs[param_name] = tool_config[param_name]
                    elif "default" in param_spec:
                        kwargs[param_name] = param_spec["default"]

                action_results[action_name] = tool.execute_action(action_name, **kwargs)
            except Exception as e:
                logger.debug(
                    f"Action '{action_name}' execution failed: {type(e).__name__}"
                )
                continue

        return action_results or None

    except Exception as e:
        logger.debug(f"Tool pre-fetch failed for '{tool_name}': {type(e).__name__}")
        return None
|
||
|
||
def _get_prompt_content(self) -> Optional[str]:
|
||
"""Retrieve and cache the raw prompt content for the current agent configuration."""
|
||
if self._prompt_content is not None:
|
||
return self._prompt_content
|
||
prompt_id = (
|
||
self.agent_config.get("prompt_id")
|
||
if isinstance(self.agent_config, dict)
|
||
else None
|
||
)
|
||
if not prompt_id:
|
||
return None
|
||
try:
|
||
self._prompt_content = get_prompt(prompt_id, self.prompts_collection)
|
||
except ValueError as e:
|
||
logger.debug(f"Invalid prompt ID '{prompt_id}': {str(e)}")
|
||
self._prompt_content = None
|
||
except Exception as e:
|
||
logger.debug(f"Failed to fetch prompt '{prompt_id}': {type(e).__name__}")
|
||
self._prompt_content = None
|
||
return self._prompt_content
|
||
|
||
def _get_required_tool_actions(self) -> Optional[Dict[str, Set[Optional[str]]]]:
    """Return the tool actions referenced by the prompt template, memoised.

    Returns ``None`` (and caches nothing) when no prompt content is
    available. A prompt lacking either the ``{{`` or the ``}}`` marker is
    not parsed and yields ``{}``. Otherwise the result of
    ``TemplateEngine.extract_tool_usages`` is cached and returned; if that
    raises, the failure is logged at debug level and ``{}`` is cached.
    """
    if self._required_tool_actions is not None:
        return self._required_tool_actions

    content = self._get_prompt_content()
    if content is None:
        return None

    has_template_markers = "{{" in content and "}}" in content
    usages = {}
    if has_template_markers:
        try:
            # Imported lazily: only needed when the prompt contains markers.
            from application.templates.template_engine import TemplateEngine

            usages = TemplateEngine().extract_tool_usages(content)
        except Exception as exc:
            logger.debug(f"Failed to extract tool usages: {type(exc).__name__}")
            usages = {}

    self._required_tool_actions = usages
    return usages
|
||
|
||
def _fetch_memory_tool_data(
    self, tool_doc: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
    """Fetch the memory tool's root listing for pre-injection into the prompt.

    Args:
        tool_doc: Stored tool record. Its ``config`` mapping (absent or
            null is treated as empty) is copied, and ``str(tool_doc["_id"])``
            is added to it as ``tool_id``.

    Returns:
        ``{"root": <view output>, "available": True}`` when viewing ``/``
        returns non-blank output that does not contain ``"Error:"``;
        otherwise ``None``. Any exception is logged as a warning and also
        yields ``None``, so a failure never blocks prompt construction.
    """
    try:
        # ``config`` may be stored as an explicit null rather than omitted;
        # ``.get(key, {})`` would return None and ``.copy()`` would raise,
        # silently dropping memory injection.
        tool_config = (tool_doc.get("config") or {}).copy()
        tool_config["tool_id"] = str(tool_doc["_id"])

        from application.agents.tools.memory import MemoryTool

        memory_tool = MemoryTool(tool_config, self.initial_user_id)

        root_view = memory_tool.execute_action("view", path="/")

        # In-band error text or an empty listing means nothing to inject.
        if "Error:" in root_view or not root_view.strip():
            return None

        return {"root": root_view, "available": True}
    except Exception as e:
        logger.warning(f"Failed to fetch memory tool data: {str(e)}")
        return None
|
||
|
||
def resume_from_tool_actions(
    self,
    tool_actions: list,
    conversation_id: str,
):
    """Resume a paused agent from saved continuation state.

    Loads the pending state through ``ContinuationService``, claims it
    (``pending`` → ``resuming``) so concurrent resumes cannot double-run,
    recreates the LLM, LLM handler, tool executor and agent from the saved
    configuration, and mirrors the saved routing fields onto this processor.
    The returned agent is ready to call ``gen_continuation()``.

    Args:
        tool_actions: Client-provided actions (approvals / results).
        conversation_id: The conversation being resumed.

    Returns:
        Tuple of (agent, messages, tools_dict, pending_tool_calls, tool_actions).

    Raises:
        ValueError: If no pending state exists for the conversation, or if
            another resume has already claimed it.
    """
    from application.api.answer.services.continuation_service import (
        ContinuationService,
    )
    from application.agents.agent_creator import AgentCreator
    from application.agents.tool_executor import ToolExecutor
    from application.llm.handlers.handler_creator import LLMHandlerCreator
    from application.llm.llm_creator import LLMCreator

    cont_service = ContinuationService()
    state = cont_service.load_state(conversation_id, self.initial_user_id)
    if not state:
        raise ValueError("No pending tool state found for this conversation")

    # Claim the resume up-front. ``mark_resuming`` only flips ``pending``
    # → ``resuming``; if it returns False, another resume already
    # claimed this row (status='resuming') — bail before any further
    # LLM/tool work to avoid double-execution. The cleanup janitor
    # reverts a stale ``resuming`` claim back to ``pending`` after the
    # 10-minute grace window so the user can retry.
    if not cont_service.mark_resuming(
        conversation_id, self.initial_user_id,
    ):
        raise ValueError(
            "Resume already in progress for this conversation; "
            "retry after the grace window if it stalls."
        )

    messages = state["messages"]
    pending_tool_calls = state["pending_tool_calls"]
    tools_dict = state["tools_dict"]
    # ``or []`` so an explicit null in the saved state does not end up as
    # ``agent.tools = None``.
    tool_schemas = state.get("tool_schemas") or []
    agent_config = state["agent_config"]

    model_id = agent_config.get("model_id")
    # BYOM scope captured at initial dispatch. None for built-ins or
    # caller-owned BYOM where decoded_token['sub'] is already the
    # right scope; non-None for shared-agent owner BYOM where the
    # caller's identity differs from the model owner's.
    model_user_id = agent_config.get("model_user_id")
    # ``or`` rather than a ``.get`` default, so an explicit null/empty
    # provider falls back the same way create_agent's
    # ``provider or settings.LLM_PROVIDER`` does.
    llm_name = agent_config.get("llm_name") or settings.LLM_PROVIDER
    api_key = agent_config.get("api_key")
    user_api_key = agent_config.get("user_api_key")
    agent_id = agent_config.get("agent_id")
    prompt = agent_config.get("prompt", "")
    json_schema = agent_config.get("json_schema")
    retriever_config = agent_config.get("retriever_config")

    # Recreate dependencies
    system_api_key = api_key or get_api_key_for_provider(llm_name)
    llm = LLMCreator.create_llm(
        llm_name,
        api_key=system_api_key,
        user_api_key=user_api_key,
        decoded_token=self.decoded_token,
        model_id=model_id,
        agent_id=agent_id,
        model_user_id=model_user_id,
    )
    llm_handler = LLMHandlerCreator.create_handler(llm_name or "default")
    tool_executor = ToolExecutor(
        user_api_key=user_api_key,
        user=self.initial_user_id,
        decoded_token=self.decoded_token,
        agent_id=agent_id,
    )
    tool_executor.conversation_id = conversation_id
    # Restore client tools so they stay available for subsequent LLM calls
    saved_client_tools = state.get("client_tools")
    if saved_client_tools:
        tool_executor.client_tools = saved_client_tools
        # Re-merge into tools_dict (they may have been stripped during serialization)
        tool_executor.merge_client_tools(tools_dict, saved_client_tools)

    agent_type = agent_config.get("agent_type", "ClassicAgent")
    # Map class names back to agent creator keys. A value that is already
    # one of the creator keys is accepted as-is; anything unrecognised
    # falls back to "classic".
    type_map = {
        "ClassicAgent": "classic",
        "AgenticAgent": "agentic",
        "ResearchAgent": "research",
        "WorkflowAgent": "workflow",
    }
    agent_key = type_map.get(agent_type)
    if agent_key is None:
        agent_key = agent_type if agent_type in type_map.values() else "classic"

    agent_kwargs = {
        "endpoint": "stream",
        "llm_name": llm_name,
        "model_id": model_id,
        "model_user_id": model_user_id,
        "api_key": system_api_key,
        "agent_id": agent_id,
        "user_api_key": user_api_key,
        "prompt": prompt,
        "chat_history": [],
        "decoded_token": self.decoded_token,
        "json_schema": json_schema,
        "llm": llm,
        "llm_handler": llm_handler,
        "tool_executor": tool_executor,
    }

    if agent_key in ("agentic", "research") and retriever_config:
        agent_kwargs["retriever_config"] = retriever_config

    agent = AgentCreator.create_agent(agent_key, **agent_kwargs)
    agent.conversation_id = conversation_id
    agent.initial_user_id = self.initial_user_id
    agent.tools = tool_schemas

    # Store config for the route layer
    self.model_id = model_id
    # Mirror ``model_user_id`` back onto the processor so the route
    # layer (StreamResource) reads the owner scope captured at
    # initial dispatch. Without this, ``processor.model_user_id``
    # stays at the __init__ default (None) and complete_stream
    # falls back to the caller's sub: the post-resume title-LLM
    # save misses the owner's BYOM layer, and any second tool
    # pause persists ``model_user_id=None`` — losing owner scope
    # for every subsequent resume of this conversation.
    self.model_user_id = model_user_id
    self.agent_id = agent_id
    self.agent_config["user_api_key"] = user_api_key
    self.conversation_id = conversation_id
    # Reused on resume so the same WAL row gets finalised and
    # request_id stays consistent across token_usage rows.
    self.reserved_message_id = agent_config.get("reserved_message_id")
    self.request_id = agent_config.get("request_id")

    return agent, messages, tools_dict, pending_tool_calls, tool_actions
|
||
|
||
def create_agent(
    self,
    docs_together: Optional[str] = None,
    docs: Optional[list] = None,
    tools_data: Optional[Dict[str, Any]] = None,
):
    """Create and return the configured agent with its rendered system prompt.

    Resolves the raw prompt. When no prompt content was resolved,
    agentic/research agents get the agentic variant of a standard preset.
    The prompt is then rendered, unless the agent allows a caller-supplied
    ``system_prompt_override`` and one was sent. Finally a fresh LLM, LLM
    handler and ToolExecutor are wired into the agent.

    Args:
        docs_together: Forwarded to the prompt renderer as ``docs_together``.
        docs: Forwarded to the prompt renderer as ``docs``.
        tools_data: Pre-fetched tool data forwarded to the prompt renderer.

    Returns:
        The agent built by ``AgentCreator.create_agent`` for
        ``agent_config["agent_type"]``, with ``conversation_id`` and
        ``initial_user_id`` set on it.
    """
    agent_type = self.agent_config["agent_type"]

    raw_prompt = self._get_prompt_content()
    if raw_prompt is None:
        # ``or`` rather than a ``.get`` default: _get_prompt_content also
        # returns None for an explicit null/empty prompt_id (e.g. a DB
        # NULL), which ``.get("prompt_id", "default")`` would hand straight
        # to ``get_prompt`` instead of using the default preset.
        prompt_id = self.agent_config.get("prompt_id") or "default"
        # For agentic agents, swap standard presets for their agentic
        # counterparts (which include search tool instructions instead of
        # {summaries}). Custom / user-provided prompts pass through as-is.
        # NOTE(review): this swap only runs when _get_prompt_content
        # resolved nothing; an explicitly configured preset id is presumably
        # resolved there unswapped — confirm that is intended.
        agentic_presets = {"default", "creative", "strict"}
        if agent_type in ("agentic", "research") and prompt_id in agentic_presets:
            raw_prompt = get_prompt(
                f"agentic_{prompt_id}", self.prompts_collection
            )
        else:
            raw_prompt = get_prompt(prompt_id, self.prompts_collection)
        self._prompt_content = raw_prompt

    # Allow API callers to override the system prompt when the agent
    # has opted in via allow_system_prompt_override.
    if (
        self.agent_config.get("allow_system_prompt_override", False)
        and self.data.get("system_prompt_override")
    ):
        rendered_prompt = self.data["system_prompt_override"]
    else:
        rendered_prompt = self.prompt_renderer.render_prompt(
            prompt_content=raw_prompt,
            user_id=self.initial_user_id,
            request_id=self.data.get("request_id"),
            passthrough_data=self.data.get("passthrough"),
            docs=docs,
            docs_together=docs_together,
            tools_data=tools_data,
        )

    # Use the user_id that resolved the model so owner-scoped BYOM
    # records dispatch correctly on shared-agent requests. This one value
    # is also forwarded to the agent and its retriever config below, so
    # the getattr fallback never diverges from (or crashes on) a direct
    # ``self.model_user_id`` read.
    model_user_id = getattr(self, "model_user_id", self.initial_user_id)
    provider = (
        get_provider_from_model_id(self.model_id, user_id=model_user_id)
        if self.model_id
        else settings.LLM_PROVIDER
    )
    system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)

    # Create LLM and handler (dependency injection)
    from application.llm.llm_creator import LLMCreator
    from application.llm.handlers.handler_creator import LLMHandlerCreator
    from application.agents.tool_executor import ToolExecutor

    # Compute backup models: agent's configured models minus the active one.
    # PG agents may carry an explicit ``models: NULL`` (not absent), so
    # ``.get("models", [])`` isn't enough — coerce None → [].
    agent_models = self.agent_config.get("models") or []
    backup_models = [m for m in agent_models if m != self.model_id]

    llm = LLMCreator.create_llm(
        provider or settings.LLM_PROVIDER,
        api_key=system_api_key,
        user_api_key=self.agent_config["user_api_key"],
        decoded_token=self.decoded_token,
        model_id=self.model_id,
        agent_id=self.agent_id,
        backup_models=backup_models,
        # Owner-scope on shared-agent BYOM dispatch.
        model_user_id=model_user_id,
    )
    llm_handler = LLMHandlerCreator.create_handler(
        provider if provider else "default"
    )

    user = self.decoded_token.get("sub") if self.decoded_token else None
    tool_executor = ToolExecutor(
        user_api_key=self.agent_config["user_api_key"],
        user=user,
        decoded_token=self.decoded_token,
        agent_id=self.agent_id,
    )
    tool_executor.conversation_id = self.conversation_id
    # Pass client-side tools so they get merged in get_tools()
    client_tools = self.data.get("client_tools")
    if client_tools:
        tool_executor.client_tools = client_tools

    agent_kwargs = {
        "endpoint": "stream",
        "llm_name": provider or settings.LLM_PROVIDER,
        "model_id": self.model_id,
        "model_user_id": model_user_id,
        "api_key": system_api_key,
        "agent_id": self.agent_id,
        "user_api_key": self.agent_config["user_api_key"],
        "prompt": rendered_prompt,
        "chat_history": self.history,
        "retrieved_docs": self.retrieved_docs,
        "decoded_token": self.decoded_token,
        "attachments": self.attachments,
        "json_schema": self.agent_config.get("json_schema"),
        "compressed_summary": self.compressed_summary,
        "llm": llm,
        "llm_handler": llm_handler,
        "tool_executor": tool_executor,
    }

    # Type-specific kwargs
    if agent_type in ("agentic", "research"):
        agent_kwargs["retriever_config"] = {
            "source": self.source,
            "retriever_name": self.retriever_config.get(
                "retriever_name", "classic"
            ),
            "chunks": self.retriever_config.get("chunks", 2),
            "doc_token_limit": self.retriever_config.get(
                "doc_token_limit", 50000
            ),
            "model_id": self.model_id,
            "model_user_id": model_user_id,
            "user_api_key": self.agent_config["user_api_key"],
            "agent_id": self.agent_id,
            "llm_name": provider or settings.LLM_PROVIDER,
            "api_key": system_api_key,
            "decoded_token": self.decoded_token,
        }

    elif agent_type == "workflow":
        workflow_config = self.agent_config.get("workflow")
        if isinstance(workflow_config, str):
            agent_kwargs["workflow_id"] = workflow_config
        elif isinstance(workflow_config, dict):
            agent_kwargs["workflow"] = workflow_config
        workflow_owner = self.agent_config.get("workflow_owner")
        if workflow_owner:
            agent_kwargs["workflow_owner"] = workflow_owner

    agent = AgentCreator.create_agent(agent_type, **agent_kwargs)

    agent.conversation_id = self.conversation_id
    agent.initial_user_id = self.initial_user_id

    return agent
|