mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
feat: template-based prompt rendering with dynamic namespace injection (#2091)
* feat: template-based prompt rendering with dynamic namespace injection * refactor: improve template engine initialization with clearer formatting * refactor: streamline ReActAgent methods and improve content extraction logic feat: enhance error handling in NamespaceManager and TemplateEngine fix: update NewAgent component to ensure consistent form data submission test: modify tests for ReActAgent and prompt renderer to reflect method changes and improve coverage * feat: tools namespace + three-tier token budget * refactor: remove unused variable assignment in message building tests * Enhance prompt customization and tool pre-fetching functionality * ruff lint fix * refactor: cleaner error handling and reduce code clutter --------- Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
@@ -54,6 +54,10 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
default=True,
|
||||
description="Whether to save the conversation",
|
||||
),
|
||||
"passthrough": fields.Raw(
|
||||
required=False,
|
||||
description="Dynamic parameters to inject into prompt template",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -69,8 +73,17 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
processor.initialize()
|
||||
if not processor.decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
agent = processor.create_agent()
|
||||
retriever = processor.create_retriever()
|
||||
|
||||
docs_together, docs_list = processor.pre_fetch_docs(
|
||||
data.get("question", "")
|
||||
)
|
||||
tools_data = processor.pre_fetch_tools()
|
||||
|
||||
agent = processor.create_agent(
|
||||
docs_together=docs_together,
|
||||
docs=docs_list,
|
||||
tools_data=tools_data,
|
||||
)
|
||||
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
@@ -78,7 +91,6 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
stream = self.complete_stream(
|
||||
question=data["question"],
|
||||
agent=agent,
|
||||
retriever=retriever,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
import logging
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
from flask import Response, make_response, jsonify
|
||||
from flask import jsonify, make_response, Response
|
||||
from flask_restx import Namespace
|
||||
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
@@ -41,9 +41,7 @@ class BaseAnswerResource:
|
||||
return missing_fields
|
||||
return None
|
||||
|
||||
def check_usage(
|
||||
self, agent_config: Dict
|
||||
) -> Optional[Response]:
|
||||
def check_usage(self, agent_config: Dict) -> Optional[Response]:
|
||||
"""Check if there is a usage limit and if it is exceeded
|
||||
|
||||
Args:
|
||||
@@ -51,30 +49,40 @@ class BaseAnswerResource:
|
||||
|
||||
Returns:
|
||||
None or Response if either of limits exceeded.
|
||||
|
||||
|
||||
"""
|
||||
api_key = agent_config.get("user_api_key")
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
|
||||
agents_collection = self.db["agents"]
|
||||
agent = agents_collection.find_one({"key": api_key})
|
||||
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Invalid API key."
|
||||
}
|
||||
),
|
||||
401
|
||||
jsonify({"success": False, "message": "Invalid API key."}), 401
|
||||
)
|
||||
|
||||
limited_token_mode = agent.get("limited_token_mode", False)
|
||||
limited_request_mode = agent.get("limited_request_mode", False)
|
||||
token_limit = int(agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]))
|
||||
request_limit = int(agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]))
|
||||
limited_token_mode_raw = agent.get("limited_token_mode", False)
|
||||
limited_request_mode_raw = agent.get("limited_request_mode", False)
|
||||
|
||||
limited_token_mode = (
|
||||
limited_token_mode_raw
|
||||
if isinstance(limited_token_mode_raw, bool)
|
||||
else limited_token_mode_raw == "True"
|
||||
)
|
||||
limited_request_mode = (
|
||||
limited_request_mode_raw
|
||||
if isinstance(limited_request_mode_raw, bool)
|
||||
else limited_request_mode_raw == "True"
|
||||
)
|
||||
|
||||
token_limit = int(
|
||||
agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
|
||||
)
|
||||
request_limit = int(
|
||||
agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
|
||||
)
|
||||
|
||||
token_usage_collection = self.db["token_usage"]
|
||||
|
||||
@@ -83,18 +91,20 @@ class BaseAnswerResource:
|
||||
|
||||
match_query = {
|
||||
"timestamp": {"$gte": start_date, "$lte": end_date},
|
||||
"api_key": api_key
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
||||
|
||||
if limited_token_mode:
|
||||
token_pipeline = [
|
||||
{"$match": match_query},
|
||||
{
|
||||
"$group": {
|
||||
"_id": None,
|
||||
"total_tokens": {"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}}
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
]
|
||||
token_result = list(token_usage_collection.aggregate(token_pipeline))
|
||||
daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
|
||||
@@ -108,26 +118,33 @@ class BaseAnswerResource:
|
||||
|
||||
if not limited_token_mode and not limited_request_mode:
|
||||
return None
|
||||
elif limited_token_mode and token_limit > daily_token_usage:
|
||||
return None
|
||||
elif limited_request_mode and request_limit > daily_request_usage:
|
||||
return None
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Exceeding usage limit, please try again later."
|
||||
}
|
||||
),
|
||||
429, # too many requests
|
||||
token_exceeded = (
|
||||
limited_token_mode and token_limit > 0 and daily_token_usage >= token_limit
|
||||
)
|
||||
request_exceeded = (
|
||||
limited_request_mode
|
||||
and request_limit > 0
|
||||
and daily_request_usage >= request_limit
|
||||
)
|
||||
|
||||
if token_exceeded or request_exceeded:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Exceeding usage limit, please try again later.",
|
||||
}
|
||||
),
|
||||
429,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def complete_stream(
|
||||
self,
|
||||
question: str,
|
||||
agent: Any,
|
||||
retriever: Any,
|
||||
conversation_id: Optional[str],
|
||||
user_api_key: Optional[str],
|
||||
decoded_token: Dict[str, Any],
|
||||
@@ -156,6 +173,7 @@ class BaseAnswerResource:
|
||||
agent_id: ID of agent used
|
||||
is_shared_usage: Flag for shared agent usage
|
||||
shared_token: Token for shared agent
|
||||
retrieved_docs: Pre-fetched documents for sources (optional)
|
||||
|
||||
Yields:
|
||||
Server-sent event strings
|
||||
@@ -166,7 +184,7 @@ class BaseAnswerResource:
|
||||
schema_info = None
|
||||
structured_chunks = []
|
||||
|
||||
for line in agent.gen(query=question, retriever=retriever):
|
||||
for line in agent.gen(query=question):
|
||||
if "answer" in line:
|
||||
response_full += str(line["answer"])
|
||||
if line.get("structured"):
|
||||
@@ -247,7 +265,6 @@ class BaseAnswerResource:
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
retriever_params = retriever.get_params()
|
||||
log_data = {
|
||||
"action": "stream_answer",
|
||||
"level": "info",
|
||||
@@ -256,7 +273,6 @@ class BaseAnswerResource:
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"sources": source_log_docs,
|
||||
"retriever_params": retriever_params,
|
||||
"attachments": attachment_ids,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
@@ -264,24 +280,19 @@ class BaseAnswerResource:
|
||||
log_data["structured_output"] = True
|
||||
if schema_info:
|
||||
log_data["schema"] = schema_info
|
||||
|
||||
# clean up text fields to be no longer than 10000 characters
|
||||
|
||||
# Clean up text fields to be no longer than 10000 characters
|
||||
for key, value in log_data.items():
|
||||
if isinstance(value, str) and len(value) > 10000:
|
||||
log_data[key] = value[:10000]
|
||||
|
||||
self.user_logs_collection.insert_one(log_data)
|
||||
|
||||
# End of stream
|
||||
self.user_logs_collection.insert_one(log_data)
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
except GeneratorExit:
|
||||
# Client aborted the connection
|
||||
logger.info(
|
||||
f"Stream aborted by client for question: {question[:50]}... "
|
||||
)
|
||||
# Save partial response to database before exiting
|
||||
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
|
||||
# Save partial response
|
||||
if should_save_conversation and response_full:
|
||||
try:
|
||||
if isNoneDoc:
|
||||
@@ -311,7 +322,9 @@ class BaseAnswerResource:
|
||||
attachment_ids=attachment_ids,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving partial response: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error saving partial response: {str(e)}", exc_info=True
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream: {str(e)}", exc_info=True)
|
||||
|
||||
@@ -60,6 +60,10 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
"attachments": fields.List(
|
||||
fields.String, required=False, description="List of attachment IDs"
|
||||
),
|
||||
"passthrough": fields.Raw(
|
||||
required=False,
|
||||
description="Dynamic parameters to inject into prompt template",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -73,17 +77,20 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
processor = StreamProcessor(data, decoded_token)
|
||||
try:
|
||||
processor.initialize()
|
||||
agent = processor.create_agent()
|
||||
retriever = processor.create_retriever()
|
||||
|
||||
docs_together, docs_list = processor.pre_fetch_docs(data["question"])
|
||||
tools_data = processor.pre_fetch_tools()
|
||||
|
||||
agent = processor.create_agent(
|
||||
docs_together=docs_together, docs=docs_list, tools_data=tools_data
|
||||
)
|
||||
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
|
||||
return Response(
|
||||
self.complete_stream(
|
||||
question=data["question"],
|
||||
agent=agent,
|
||||
retriever=retriever,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
|
||||
@@ -133,10 +133,9 @@ class ConversationService:
|
||||
|
||||
messages_summary = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Summarise following conversation in no more than 3 "
|
||||
"words, respond ONLY with the summary, use the same "
|
||||
"language as the user query",
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that creates concise conversation titles. "
|
||||
"Summarize conversations in 3 words or less using the same language as the user.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
|
||||
97
application/api/answer/services/prompt_renderer.py
Normal file
97
application/api/answer/services/prompt_renderer.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from application.templates.namespaces import NamespaceManager
|
||||
|
||||
from application.templates.template_engine import TemplateEngine, TemplateRenderError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PromptRenderer:
|
||||
"""Service for rendering prompts with dynamic context using namespaces"""
|
||||
|
||||
def __init__(self):
|
||||
self.template_engine = TemplateEngine()
|
||||
self.namespace_manager = NamespaceManager()
|
||||
|
||||
def render_prompt(
|
||||
self,
|
||||
prompt_content: str,
|
||||
user_id: Optional[str] = None,
|
||||
request_id: Optional[str] = None,
|
||||
passthrough_data: Optional[Dict[str, Any]] = None,
|
||||
docs: Optional[list] = None,
|
||||
docs_together: Optional[str] = None,
|
||||
tools_data: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Render prompt with full context from all namespaces.
|
||||
|
||||
Args:
|
||||
prompt_content: Raw prompt template string
|
||||
user_id: Current user identifier
|
||||
request_id: Unique request identifier
|
||||
passthrough_data: Parameters from web request
|
||||
docs: RAG retrieved documents
|
||||
docs_together: Concatenated document content
|
||||
tools_data: Pre-fetched tool results organized by tool name
|
||||
**kwargs: Additional parameters for namespace builders
|
||||
|
||||
Returns:
|
||||
Rendered prompt string with all variables substituted
|
||||
|
||||
Raises:
|
||||
TemplateRenderError: If template rendering fails
|
||||
"""
|
||||
if not prompt_content:
|
||||
return ""
|
||||
|
||||
uses_template = self._uses_template_syntax(prompt_content)
|
||||
|
||||
if not uses_template:
|
||||
return self._apply_legacy_substitutions(prompt_content, docs_together)
|
||||
|
||||
try:
|
||||
context = self.namespace_manager.build_context(
|
||||
user_id=user_id,
|
||||
request_id=request_id,
|
||||
passthrough_data=passthrough_data,
|
||||
docs=docs,
|
||||
docs_together=docs_together,
|
||||
tools_data=tools_data,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return self.template_engine.render(prompt_content, context)
|
||||
except TemplateRenderError:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Prompt rendering failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise TemplateRenderError(error_msg) from e
|
||||
|
||||
def _uses_template_syntax(self, prompt_content: str) -> bool:
|
||||
"""Check if prompt uses Jinja2 template syntax"""
|
||||
return "{{" in prompt_content and "}}" in prompt_content
|
||||
|
||||
def _apply_legacy_substitutions(
|
||||
self, prompt_content: str, docs_together: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Apply backward-compatible substitutions for old prompt format.
|
||||
|
||||
Handles legacy {summaries} and {query} placeholders during transition period.
|
||||
"""
|
||||
if docs_together:
|
||||
prompt_content = prompt_content.replace("{summaries}", docs_together)
|
||||
return prompt_content
|
||||
|
||||
def validate_template(self, prompt_content: str) -> bool:
|
||||
"""Validate prompt template syntax"""
|
||||
return self.template_engine.validate_template(prompt_content)
|
||||
|
||||
def extract_variables(self, prompt_content: str) -> set[str]:
|
||||
"""Extract all variable names from prompt template"""
|
||||
return self.template_engine.extract_variables(prompt_content)
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Set
|
||||
|
||||
from bson.dbref import DBRef
|
||||
|
||||
@@ -11,10 +11,15 @@ from bson.objectid import ObjectId
|
||||
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.utils import get_gpt_model, limit_chat_history
|
||||
from application.utils import (
|
||||
calculate_doc_token_budget,
|
||||
get_gpt_model,
|
||||
limit_chat_history,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -73,12 +78,16 @@ class StreamProcessor:
|
||||
self.all_sources = []
|
||||
self.attachments = []
|
||||
self.history = []
|
||||
self.retrieved_docs = []
|
||||
self.agent_config = {}
|
||||
self.retriever_config = {}
|
||||
self.is_shared_usage = False
|
||||
self.shared_token = None
|
||||
self.gpt_model = get_gpt_model()
|
||||
self.conversation_service = ConversationService()
|
||||
self.prompt_renderer = PromptRenderer()
|
||||
self._prompt_content: Optional[str] = None
|
||||
self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize all required components for processing"""
|
||||
@@ -311,19 +320,312 @@ class StreamProcessor:
|
||||
)
|
||||
|
||||
def _configure_retriever(self):
|
||||
"""Configure the retriever based on request data"""
|
||||
history_token_limit = int(self.data.get("token_limit", 2000))
|
||||
doc_token_limit = calculate_doc_token_budget(
|
||||
gpt_model=self.gpt_model, history_token_limit=history_token_limit
|
||||
)
|
||||
|
||||
self.retriever_config = {
|
||||
"retriever_name": self.data.get("retriever", "classic"),
|
||||
"chunks": int(self.data.get("chunks", 2)),
|
||||
"token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY),
|
||||
"doc_token_limit": doc_token_limit,
|
||||
"history_token_limit": history_token_limit,
|
||||
}
|
||||
|
||||
api_key = self.data.get("api_key") or self.agent_key
|
||||
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
||||
self.retriever_config["chunks"] = 0
|
||||
|
||||
def create_agent(self):
|
||||
"""Create and return the configured agent"""
|
||||
def create_retriever(self):
|
||||
return RetrieverCreator.create_retriever(
|
||||
self.retriever_config["retriever_name"],
|
||||
source=self.source,
|
||||
chat_history=self.history,
|
||||
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
|
||||
chunks=self.retriever_config["chunks"],
|
||||
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
|
||||
gpt_model=self.gpt_model,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
|
||||
def pre_fetch_docs(self, question: str) -> tuple[Optional[str], Optional[list]]:
|
||||
"""Pre-fetch documents for template rendering before agent creation"""
|
||||
if self.data.get("isNoneDoc", False):
|
||||
logger.info("Pre-fetch skipped: isNoneDoc=True")
|
||||
return None, None
|
||||
try:
|
||||
retriever = self.create_retriever()
|
||||
logger.info(
|
||||
f"Pre-fetching docs with chunks={retriever.chunks}, doc_token_limit={retriever.doc_token_limit}"
|
||||
)
|
||||
docs = retriever.search(question)
|
||||
logger.info(f"Pre-fetch retrieved {len(docs) if docs else 0} documents")
|
||||
|
||||
if not docs:
|
||||
logger.info("Pre-fetch: No documents returned from search")
|
||||
return None, None
|
||||
self.retrieved_docs = docs
|
||||
|
||||
docs_with_filenames = []
|
||||
for doc in docs:
|
||||
filename = doc.get("filename") or doc.get("title") or doc.get("source")
|
||||
if filename:
|
||||
chunk_header = str(filename)
|
||||
docs_with_filenames.append(f"{chunk_header}\n{doc['text']}")
|
||||
else:
|
||||
docs_with_filenames.append(doc["text"])
|
||||
docs_together = "\n\n".join(docs_with_filenames)
|
||||
|
||||
logger.info(f"Pre-fetch docs_together size: {len(docs_together)} chars")
|
||||
|
||||
return docs_together, docs
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to pre-fetch docs: {str(e)}", exc_info=True)
|
||||
return None, None
|
||||
|
||||
def pre_fetch_tools(self) -> Optional[Dict[str, Any]]:
|
||||
"""Pre-fetch tool data for template rendering before agent creation
|
||||
|
||||
Can be controlled via:
|
||||
1. Global setting: ENABLE_TOOL_PREFETCH in .env
|
||||
2. Per-request: disable_tool_prefetch in request data
|
||||
"""
|
||||
if not settings.ENABLE_TOOL_PREFETCH:
|
||||
logger.info(
|
||||
"Tool pre-fetching disabled globally via ENABLE_TOOL_PREFETCH setting"
|
||||
)
|
||||
return None
|
||||
|
||||
if self.data.get("disable_tool_prefetch", False):
|
||||
logger.info("Tool pre-fetching disabled for this request")
|
||||
return None
|
||||
|
||||
required_tool_actions = self._get_required_tool_actions()
|
||||
filtering_enabled = required_tool_actions is not None
|
||||
|
||||
try:
|
||||
user_tools_collection = self.db["user_tools"]
|
||||
user_id = self.initial_user_id or "local"
|
||||
|
||||
user_tools = list(
|
||||
user_tools_collection.find({"user": user_id, "status": True})
|
||||
)
|
||||
|
||||
if not user_tools:
|
||||
return None
|
||||
|
||||
tools_data = {}
|
||||
|
||||
for tool_doc in user_tools:
|
||||
tool_name = tool_doc.get("name")
|
||||
tool_id = str(tool_doc.get("_id"))
|
||||
|
||||
if filtering_enabled:
|
||||
required_actions_by_name = required_tool_actions.get(
|
||||
tool_name, set()
|
||||
)
|
||||
required_actions_by_id = required_tool_actions.get(tool_id, set())
|
||||
|
||||
required_actions = required_actions_by_name | required_actions_by_id
|
||||
|
||||
if not required_actions:
|
||||
continue
|
||||
else:
|
||||
required_actions = None
|
||||
|
||||
tool_data = self._fetch_tool_data(tool_doc, required_actions)
|
||||
if tool_data:
|
||||
tools_data[tool_name] = tool_data
|
||||
tools_data[tool_id] = tool_data
|
||||
|
||||
return tools_data if tools_data else None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to pre-fetch tools: {type(e).__name__}")
|
||||
return None
|
||||
|
||||
def _fetch_tool_data(
|
||||
self,
|
||||
tool_doc: Dict[str, Any],
|
||||
required_actions: Optional[Set[Optional[str]]],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Fetch and execute tool actions with saved parameters"""
|
||||
try:
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
|
||||
tool_name = tool_doc.get("name")
|
||||
tool_config = tool_doc.get("config", {}).copy()
|
||||
tool_config["tool_id"] = str(tool_doc["_id"])
|
||||
|
||||
tool_manager = ToolManager(config={tool_name: tool_config})
|
||||
user_id = self.initial_user_id or "local"
|
||||
tool = tool_manager.load_tool(tool_name, tool_config, user_id=user_id)
|
||||
|
||||
if not tool:
|
||||
logger.debug(f"Tool '{tool_name}' failed to load")
|
||||
return None
|
||||
|
||||
tool_actions = tool.get_actions_metadata()
|
||||
if not tool_actions:
|
||||
logger.debug(f"Tool '{tool_name}' has no actions")
|
||||
return None
|
||||
|
||||
saved_actions = tool_doc.get("actions", [])
|
||||
|
||||
include_all_actions = required_actions is None or (
|
||||
required_actions and None in required_actions
|
||||
)
|
||||
allowed_actions: Set[str] = (
|
||||
{action for action in required_actions if isinstance(action, str)}
|
||||
if required_actions
|
||||
else set()
|
||||
)
|
||||
|
||||
action_results = {}
|
||||
for action_meta in tool_actions:
|
||||
action_name = action_meta.get("name")
|
||||
if action_name is None:
|
||||
continue
|
||||
if (
|
||||
not include_all_actions
|
||||
and allowed_actions
|
||||
and action_name not in allowed_actions
|
||||
):
|
||||
continue
|
||||
|
||||
try:
|
||||
saved_action = None
|
||||
for sa in saved_actions:
|
||||
if sa.get("name") == action_name:
|
||||
saved_action = sa
|
||||
break
|
||||
|
||||
action_params = action_meta.get("parameters", {})
|
||||
properties = action_params.get("properties", {})
|
||||
|
||||
kwargs = {}
|
||||
for param_name, param_spec in properties.items():
|
||||
if saved_action:
|
||||
saved_props = saved_action.get("parameters", {}).get(
|
||||
"properties", {}
|
||||
)
|
||||
if param_name in saved_props:
|
||||
param_value = saved_props[param_name].get("value")
|
||||
if param_value is not None:
|
||||
kwargs[param_name] = param_value
|
||||
continue
|
||||
|
||||
if param_name in tool_config:
|
||||
kwargs[param_name] = tool_config[param_name]
|
||||
elif "default" in param_spec:
|
||||
kwargs[param_name] = param_spec["default"]
|
||||
|
||||
result = tool.execute_action(action_name, **kwargs)
|
||||
action_results[action_name] = result
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Action '{action_name}' execution failed: {type(e).__name__}"
|
||||
)
|
||||
continue
|
||||
|
||||
return action_results if action_results else None
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Tool pre-fetch failed for '{tool_name}': {type(e).__name__}")
|
||||
return None
|
||||
|
||||
def _get_prompt_content(self) -> Optional[str]:
|
||||
"""Retrieve and cache the raw prompt content for the current agent configuration."""
|
||||
if self._prompt_content is not None:
|
||||
return self._prompt_content
|
||||
prompt_id = (
|
||||
self.agent_config.get("prompt_id")
|
||||
if isinstance(self.agent_config, dict)
|
||||
else None
|
||||
)
|
||||
if not prompt_id:
|
||||
return None
|
||||
try:
|
||||
self._prompt_content = get_prompt(prompt_id, self.prompts_collection)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Invalid prompt ID '{prompt_id}': {str(e)}")
|
||||
self._prompt_content = None
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to fetch prompt '{prompt_id}': {type(e).__name__}")
|
||||
self._prompt_content = None
|
||||
return self._prompt_content
|
||||
|
||||
def _get_required_tool_actions(self) -> Optional[Dict[str, Set[Optional[str]]]]:
|
||||
"""Determine which tool actions are referenced in the prompt template"""
|
||||
if self._required_tool_actions is not None:
|
||||
return self._required_tool_actions
|
||||
|
||||
prompt_content = self._get_prompt_content()
|
||||
if prompt_content is None:
|
||||
return None
|
||||
|
||||
if "{{" not in prompt_content or "}}" not in prompt_content:
|
||||
self._required_tool_actions = {}
|
||||
return self._required_tool_actions
|
||||
|
||||
try:
|
||||
from application.templates.template_engine import TemplateEngine
|
||||
|
||||
template_engine = TemplateEngine()
|
||||
usages = template_engine.extract_tool_usages(prompt_content)
|
||||
self._required_tool_actions = usages
|
||||
return self._required_tool_actions
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to extract tool usages: {type(e).__name__}")
|
||||
self._required_tool_actions = {}
|
||||
return self._required_tool_actions
|
||||
|
||||
def _fetch_memory_tool_data(
|
||||
self, tool_doc: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Fetch memory tool data for pre-injection into prompt"""
|
||||
try:
|
||||
tool_config = tool_doc.get("config", {}).copy()
|
||||
tool_config["tool_id"] = str(tool_doc["_id"])
|
||||
|
||||
from application.agents.tools.memory import MemoryTool
|
||||
|
||||
memory_tool = MemoryTool(tool_config, self.initial_user_id)
|
||||
|
||||
root_view = memory_tool.execute_action("view", path="/")
|
||||
|
||||
if "Error:" in root_view or not root_view.strip():
|
||||
return None
|
||||
|
||||
return {"root": root_view, "available": True}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch memory tool data: {str(e)}")
|
||||
return None
|
||||
|
||||
def create_agent(
|
||||
self,
|
||||
docs_together: Optional[str] = None,
|
||||
docs: Optional[list] = None,
|
||||
tools_data: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""Create and return the configured agent with rendered prompt"""
|
||||
raw_prompt = self._get_prompt_content()
|
||||
if raw_prompt is None:
|
||||
raw_prompt = get_prompt(
|
||||
self.agent_config["prompt_id"], self.prompts_collection
|
||||
)
|
||||
self._prompt_content = raw_prompt
|
||||
|
||||
rendered_prompt = self.prompt_renderer.render_prompt(
|
||||
prompt_content=raw_prompt,
|
||||
user_id=self.initial_user_id,
|
||||
request_id=self.data.get("request_id"),
|
||||
passthrough_data=self.data.get("passthrough"),
|
||||
docs=docs,
|
||||
docs_together=docs_together,
|
||||
tools_data=tools_data,
|
||||
)
|
||||
|
||||
return AgentCreator.create_agent(
|
||||
self.agent_config["agent_type"],
|
||||
endpoint="stream",
|
||||
@@ -331,23 +633,10 @@ class StreamProcessor:
|
||||
gpt_model=self.gpt_model,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
|
||||
prompt=rendered_prompt,
|
||||
chat_history=self.history,
|
||||
retrieved_docs=self.retrieved_docs,
|
||||
decoded_token=self.decoded_token,
|
||||
attachments=self.attachments,
|
||||
json_schema=self.agent_config.get("json_schema"),
|
||||
)
|
||||
|
||||
def create_retriever(self):
|
||||
"""Create and return the configured retriever"""
|
||||
return RetrieverCreator.create_retriever(
|
||||
self.retriever_config["retriever_name"],
|
||||
source=self.source,
|
||||
chat_history=self.history,
|
||||
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
|
||||
chunks=self.retriever_config["chunks"],
|
||||
token_limit=self.retriever_config["token_limit"],
|
||||
gpt_model=self.gpt_model,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
|
||||
@@ -10,7 +10,6 @@ from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.core.settings import settings
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
db,
|
||||
@@ -20,6 +19,7 @@ from application.api.user.base import (
|
||||
storage,
|
||||
users_collection,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from application.utils import (
|
||||
check_required_fields,
|
||||
generate_image_url,
|
||||
@@ -76,9 +76,13 @@ class GetAgent(Resource):
|
||||
"status": agent.get("status", ""),
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"limited_token_mode": agent.get("limited_token_mode", False),
|
||||
"token_limit": agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
|
||||
"token_limit": agent.get(
|
||||
"token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]
|
||||
),
|
||||
"limited_request_mode": agent.get("limited_request_mode", False),
|
||||
"request_limit": agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
|
||||
"request_limit": agent.get(
|
||||
"request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]
|
||||
),
|
||||
"created_at": agent.get("createdAt", ""),
|
||||
"updated_at": agent.get("updatedAt", ""),
|
||||
"last_used_at": agent.get("lastUsedAt", ""),
|
||||
@@ -149,9 +153,13 @@ class GetAgents(Resource):
|
||||
"status": agent.get("status", ""),
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"limited_token_mode": agent.get("limited_token_mode", False),
|
||||
"token_limit": agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
|
||||
"token_limit": agent.get(
|
||||
"token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]
|
||||
),
|
||||
"limited_request_mode": agent.get("limited_request_mode", False),
|
||||
"request_limit": agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
|
||||
"request_limit": agent.get(
|
||||
"request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]
|
||||
),
|
||||
"created_at": agent.get("createdAt", ""),
|
||||
"updated_at": agent.get("updatedAt", ""),
|
||||
"last_used_at": agent.get("lastUsedAt", ""),
|
||||
@@ -209,21 +217,19 @@ class CreateAgent(Resource):
|
||||
description="JSON schema for enforcing structured output format",
|
||||
),
|
||||
"limited_token_mode": fields.Boolean(
|
||||
required=False,
|
||||
description="Whether the agent is in limited token mode"
|
||||
required=False, description="Whether the agent is in limited token mode"
|
||||
),
|
||||
"token_limit": fields.Integer(
|
||||
required=False,
|
||||
description="Token limit for the agent in limited mode"
|
||||
required=False, description="Token limit for the agent in limited mode"
|
||||
),
|
||||
"limited_request_mode": fields.Boolean(
|
||||
required=False,
|
||||
description="Whether the agent is in limited request mode"
|
||||
description="Whether the agent is in limited request mode",
|
||||
),
|
||||
"request_limit": fields.Integer(
|
||||
required=False,
|
||||
description="Request limit for the agent in limited mode"
|
||||
)
|
||||
description="Request limit for the agent in limited mode",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -369,10 +375,26 @@ class CreateAgent(Resource):
|
||||
"agent_type": data.get("agent_type", ""),
|
||||
"status": data.get("status"),
|
||||
"json_schema": data.get("json_schema"),
|
||||
"limited_token_mode": data.get("limited_token_mode", False),
|
||||
"token_limit": data.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
|
||||
"limited_request_mode": data.get("limited_request_mode", False),
|
||||
"request_limit": data.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
|
||||
"limited_token_mode": (
|
||||
data.get("limited_token_mode") == "True"
|
||||
if isinstance(data.get("limited_token_mode"), str)
|
||||
else bool(data.get("limited_token_mode", False))
|
||||
),
|
||||
"token_limit": int(
|
||||
data.get(
|
||||
"token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]
|
||||
)
|
||||
),
|
||||
"limited_request_mode": (
|
||||
data.get("limited_request_mode") == "True"
|
||||
if isinstance(data.get("limited_request_mode"), str)
|
||||
else bool(data.get("limited_request_mode", False))
|
||||
),
|
||||
"request_limit": int(
|
||||
data.get(
|
||||
"request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]
|
||||
)
|
||||
),
|
||||
"createdAt": datetime.datetime.now(datetime.timezone.utc),
|
||||
"updatedAt": datetime.datetime.now(datetime.timezone.utc),
|
||||
"lastUsedAt": None,
|
||||
@@ -429,21 +451,19 @@ class UpdateAgent(Resource):
|
||||
description="JSON schema for enforcing structured output format",
|
||||
),
|
||||
"limited_token_mode": fields.Boolean(
|
||||
required=False,
|
||||
description="Whether the agent is in limited token mode"
|
||||
required=False, description="Whether the agent is in limited token mode"
|
||||
),
|
||||
"token_limit": fields.Integer(
|
||||
required=False,
|
||||
description="Token limit for the agent in limited mode"
|
||||
required=False, description="Token limit for the agent in limited mode"
|
||||
),
|
||||
"limited_request_mode": fields.Boolean(
|
||||
require=False,
|
||||
description="Whether the agent is in limited request mode"
|
||||
description="Whether the agent is in limited request mode",
|
||||
),
|
||||
"request_limit": fields.Integer(
|
||||
required=False,
|
||||
description="Request limit for the agent in limited mode"
|
||||
)
|
||||
description="Request limit for the agent in limited mode",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -534,7 +554,7 @@ class UpdateAgent(Resource):
|
||||
"limited_token_mode",
|
||||
"token_limit",
|
||||
"limited_request_mode",
|
||||
"request_limit"
|
||||
"request_limit",
|
||||
]
|
||||
|
||||
for field in allowed_fields:
|
||||
@@ -652,8 +672,15 @@ class UpdateAgent(Resource):
|
||||
else:
|
||||
update_fields[field] = None
|
||||
elif field == "limited_token_mode":
|
||||
is_mode_enabled = data.get("limited_token_mode", False)
|
||||
if is_mode_enabled and data.get("token_limit") is None:
|
||||
raw_value = data.get("limited_token_mode", False)
|
||||
bool_value = (
|
||||
raw_value == "True"
|
||||
if isinstance(raw_value, str)
|
||||
else bool(raw_value)
|
||||
)
|
||||
update_fields[field] = bool_value
|
||||
|
||||
if bool_value and data.get("token_limit") is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -664,8 +691,15 @@ class UpdateAgent(Resource):
|
||||
400,
|
||||
)
|
||||
elif field == "limited_request_mode":
|
||||
is_mode_enabled = data.get("limited_request_mode", False)
|
||||
if is_mode_enabled and data.get("request_limit") is None:
|
||||
raw_value = data.get("limited_request_mode", False)
|
||||
bool_value = (
|
||||
raw_value == "True"
|
||||
if isinstance(raw_value, str)
|
||||
else bool(raw_value)
|
||||
)
|
||||
update_fields[field] = bool_value
|
||||
|
||||
if bool_value and data.get("request_limit") is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -677,7 +711,11 @@ class UpdateAgent(Resource):
|
||||
)
|
||||
elif field == "token_limit":
|
||||
token_limit = data.get("token_limit")
|
||||
if token_limit is not None and not data.get("limited_token_mode"):
|
||||
# Convert to int and store
|
||||
update_fields[field] = int(token_limit) if token_limit else 0
|
||||
|
||||
# Validate consistency with mode
|
||||
if update_fields[field] > 0 and not data.get("limited_token_mode"):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -689,7 +727,9 @@ class UpdateAgent(Resource):
|
||||
)
|
||||
elif field == "request_limit":
|
||||
request_limit = data.get("request_limit")
|
||||
if request_limit is not None and not data.get("limited_request_mode"):
|
||||
update_fields[field] = int(request_limit) if request_limit else 0
|
||||
|
||||
if update_fields[field] > 0 and not data.get("limited_request_mode"):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user