Mirror of https://github.com/arc53/DocsGPT.git (synced 2025-11-29 08:33:20 +00:00)
feat: template-based prompt rendering with dynamic namespace injection (#2091)
* feat: template-based prompt rendering with dynamic namespace injection
* refactor: improve template engine initialization with clearer formatting
* refactor: streamline ReActAgent methods and improve content extraction logic
  feat: enhance error handling in NamespaceManager and TemplateEngine
  fix: update NewAgent component to ensure consistent form data submission
  test: modify tests for ReActAgent and prompt renderer to reflect method changes and improve coverage
* feat: tools namespace + three-tier token budget
* refactor: remove unused variable assignment in message building tests
* Enhance prompt customization and tool pre-fetching functionality
* ruff lint fix
* refactor: cleaner error handling and reduce code clutter

---------

Co-authored-by: Alex <a@tushynski.me>
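To make the feature concrete: a prompt under this scheme is a Jinja2-style template whose variables are filled from injected namespaces. The sketch below is illustrative only; the namespace names mirror what the diff wires up, while the template body and the "company_name" key are invented.

    # Illustrative only: a Jinja2-style template using the namespaces this PR
    # introduces (passthrough, docs, tools). "company_name" is a made-up key.
    from jinja2 import Template  # the {{ ... }} syntax detection below suggests Jinja2

    prompt_template = """You are a support assistant for {{ passthrough.company_name }}.

    Relevant documentation:
    {{ docs_together }}

    Saved memory (pre-fetched tool output):
    {{ tools.memory.view }}"""

    # A minimal render, standing in for TemplateEngine.render:
    rendered = Template(prompt_template).render(
        passthrough={"company_name": "Acme"},
        docs_together="chunk one\n\nchunk two",
        tools={"memory": {"view": "- user prefers short answers"}},
    )
    print(rendered)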
.gitignore (vendored): 1 line changed
@@ -3,6 +3,7 @@ __pycache__/
 *.py[cod]
 *$py.class
+experiments

 # C extensions
 *.so
 *.next
@@ -12,7 +12,6 @@ from application.core.settings import settings
 from application.llm.handlers.handler_creator import LLMHandlerCreator
 from application.llm.llm_creator import LLMCreator
 from application.logging import build_stack_data, log_activity, LogContext
-from application.retriever.base import BaseRetriever

 logger = logging.getLogger(__name__)

@@ -27,6 +26,7 @@ class BaseAgent(ABC):
         user_api_key: Optional[str] = None,
         prompt: str = "",
         chat_history: Optional[List[Dict]] = None,
+        retrieved_docs: Optional[List[Dict]] = None,
         decoded_token: Optional[Dict] = None,
         attachments: Optional[List[Dict]] = None,
         json_schema: Optional[Dict] = None,
@@ -53,6 +53,7 @@ class BaseAgent(ABC):
             user_api_key=user_api_key,
             decoded_token=decoded_token,
         )
+        self.retrieved_docs = retrieved_docs or []
         self.llm_handler = LLMHandlerCreator.create_handler(
             llm_name if llm_name else "default"
         )
@@ -65,13 +66,13 @@ class BaseAgent(ABC):

     @log_activity()
     def gen(
-        self, query: str, retriever: BaseRetriever, log_context: LogContext = None
+        self, query: str, log_context: LogContext = None
     ) -> Generator[Dict, None, None]:
-        yield from self._gen_inner(query, retriever, log_context)
+        yield from self._gen_inner(query, log_context)

     @abstractmethod
     def _gen_inner(
-        self, query: str, retriever: BaseRetriever, log_context: LogContext
+        self, query: str, log_context: LogContext
     ) -> Generator[Dict, None, None]:
         pass

@@ -150,6 +151,7 @@ class BaseAgent(ABC):
         call_id = getattr(call, "id", None) or str(uuid.uuid4())

         # Check if parsing failed
         if tool_id is None or action_name is None:
             error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
             logger.error(error_message)
@@ -164,13 +166,14 @@ class BaseAgent(ABC):
             yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
             self.tool_calls.append(tool_call_data)
             return "Failed to parse tool call.", call_id

         # Check if tool_id exists in available tools
         if tool_id not in tools_dict:
             error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
             logger.error(error_message)

             # Return error result
             tool_call_data = {
                 "tool_name": "unknown",
                 "call_id": call_id,
@@ -181,7 +184,6 @@ class BaseAgent(ABC):
             yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
             self.tool_calls.append(tool_call_data)
             return f"Tool with ID {tool_id} not found.", call_id

         tool_call_data = {
             "tool_name": tools_dict[tool_id]["name"],
             "call_id": call_id,
@@ -223,6 +225,7 @@ class BaseAgent(ABC):
         tm = ToolManager(config={})

         # Prepare tool_config and add tool_id for memory tools
         if tool_data["name"] == "api_tool":
             tool_config = {
                 "url": tool_data["config"]["actions"][action_name]["url"],
@@ -234,8 +237,8 @@ class BaseAgent(ABC):
             tool_config = tool_data["config"].copy() if tool_data["config"] else {}
-            # Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool)
-            # Use MongoDB _id if available, otherwise fall back to enumerated tool_id
             tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
         tool = tm.load_tool(
             tool_data["name"],
             tool_config=tool_config,
@@ -276,24 +279,14 @@ class BaseAgent(ABC):
         self,
         system_prompt: str,
         query: str,
-        retrieved_data: List[Dict],
     ) -> List[Dict]:
-        docs_with_filenames = []
-        for doc in retrieved_data:
-            filename = doc.get("filename") or doc.get("title") or doc.get("source")
-            if filename:
-                chunk_header = str(filename)
-                docs_with_filenames.append(f"{chunk_header}\n{doc['text']}")
-            else:
-                docs_with_filenames.append(doc["text"])
-        docs_together = "\n\n".join(docs_with_filenames)
-        p_chat_combine = system_prompt.replace("{summaries}", docs_together)
-        messages_combine = [{"role": "system", "content": p_chat_combine}]
+        """Build messages using pre-rendered system prompt"""
+        messages = [{"role": "system", "content": system_prompt}]

         for i in self.chat_history:
             if "prompt" in i and "response" in i:
-                messages_combine.append({"role": "user", "content": i["prompt"]})
-                messages_combine.append({"role": "assistant", "content": i["response"]})
+                messages.append({"role": "user", "content": i["prompt"]})
+                messages.append({"role": "assistant", "content": i["response"]})
             if "tool_calls" in i:
                 for tool_call in i["tool_calls"]:
                     call_id = tool_call.get("call_id") or str(uuid.uuid4())
@@ -313,26 +306,14 @@ class BaseAgent(ABC):
                         }
                     }

-                    messages_combine.append(
+                    messages.append(
                         {"role": "assistant", "content": [function_call_dict]}
                     )
-                    messages_combine.append(
+                    messages.append(
                         {"role": "tool", "content": [function_response_dict]}
                     )
-        messages_combine.append({"role": "user", "content": query})
-        return messages_combine
-
-    def _retriever_search(
-        self,
-        retriever: BaseRetriever,
-        query: str,
-        log_context: Optional[LogContext] = None,
-    ) -> List[Dict]:
-        retrieved_data = retriever.search(query)
-        if log_context:
-            data = build_stack_data(retriever, exclude_attributes=["llm"])
-            log_context.stacks.append({"component": "retriever", "data": data})
-        return retrieved_data
+        messages.append({"role": "user", "content": query})
+        return messages

     def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
         gen_kwargs = {"model": self.gpt_model, "messages": messages}
@@ -343,7 +324,6 @@ class BaseAgent(ABC):
             and self.tools
         ):
             gen_kwargs["tools"] = self.tools

         if (
             self.json_schema
             and hasattr(self.llm, "_supports_structured_output")
@@ -357,7 +337,6 @@ class BaseAgent(ABC):
             gen_kwargs["response_format"] = structured_format
         elif self.llm_name == "google":
             gen_kwargs["response_schema"] = structured_format

         resp = self.llm.gen_stream(**gen_kwargs)

         if log_context:
@@ -1,32 +1,20 @@
+import logging
 from typing import Dict, Generator

 from application.agents.base import BaseAgent
 from application.logging import LogContext
-from application.retriever.base import BaseRetriever
-import logging

 logger = logging.getLogger(__name__)


 class ClassicAgent(BaseAgent):
-    """A simplified agent with clear execution flow.
-
-    Usage:
-    1. Processes a query through retrieval
-    2. Sets up available tools
-    3. Generates responses using LLM
-    4. Handles tool interactions if needed
-    5. Returns standardized outputs
-
-    Easy to extend by overriding specific steps.
-    """
+    """A simplified agent with clear execution flow"""

     def _gen_inner(
-        self, query: str, retriever: BaseRetriever, log_context: LogContext
+        self, query: str, log_context: LogContext
     ) -> Generator[Dict, None, None]:
-        # Step 1: Retrieve relevant data
-        retrieved_data = self._retriever_search(retriever, query, log_context)
+        """Core generator function for ClassicAgent execution flow"""

         # Step 2: Prepare tools
         tools_dict = (
             self._get_user_tools(self.user)
             if not self.user_api_key
@@ -34,20 +22,16 @@ class ClassicAgent(BaseAgent):
         )
         self._prepare_tools(tools_dict)

         # Step 3: Build and process messages
-        messages = self._build_messages(self.prompt, query, retrieved_data)
+        messages = self._build_messages(self.prompt, query)
         llm_response = self._llm_gen(messages, log_context)

         # Step 4: Handle the response
         yield from self._handle_response(
             llm_response, tools_dict, messages, log_context
         )

         # Step 5: Return metadata
-        yield {"sources": retrieved_data}
+        yield {"sources": self.retrieved_docs}
         yield {"tool_calls": self._get_truncated_tool_calls()}

         # Log tool calls for debugging
         log_context.stacks.append(
             {"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
         )
@@ -1,284 +1,238 @@
-import os
-from typing import Dict, Generator, List, Any
+import logging
+import os
+from typing import Any, Dict, Generator, List

 from application.agents.base import BaseAgent
 from application.logging import build_stack_data, LogContext
-from application.retriever.base import BaseRetriever

 logger = logging.getLogger(__name__)

+MAX_ITERATIONS_REASONING = 10
+
 current_dir = os.path.dirname(
     os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 )
 with open(
     os.path.join(current_dir, "application/prompts", "react_planning_prompt.txt"), "r"
 ) as f:
-    planning_prompt_template = f.read()
+    PLANNING_PROMPT_TEMPLATE = f.read()
 with open(
-    os.path.join(current_dir, "application/prompts", "react_final_prompt.txt"),
-    "r",
+    os.path.join(current_dir, "application/prompts", "react_final_prompt.txt"), "r"
 ) as f:
-    final_prompt_template = f.read()
-
-MAX_ITERATIONS_REASONING = 10
+    FINAL_PROMPT_TEMPLATE = f.read()


 class ReActAgent(BaseAgent):
     """
     Research and Action (ReAct) Agent - Advanced reasoning agent with iterative planning.

     Implements a think-act-observe loop for complex problem-solving:
     1. Creates a strategic plan based on the query
     2. Executes tools and gathers observations
     3. Iteratively refines approach until satisfied
     4. Synthesizes final answer from all observations
     """

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.plan: str = ""
         self.observations: List[str] = []

-    def _extract_content_from_llm_response(self, resp: Any) -> str:
-        """
-        Helper to extract string content from various LLM response types.
-        Handles strings, message objects (OpenAI-like), and streams.
-        Adapt stream handling for your specific LLM client if not OpenAI.
-        """
-        collected_content = []
-        if isinstance(resp, str):
-            collected_content.append(resp)
-        elif (  # OpenAI non-streaming or Anthropic non-streaming (older SDK style)
-            hasattr(resp, "message")
-            and hasattr(resp.message, "content")
-            and resp.message.content is not None
-        ):
-            collected_content.append(resp.message.content)
-        elif (  # OpenAI non-streaming (Pydantic model), Anthropic new SDK non-streaming
-            hasattr(resp, "choices")
-            and resp.choices
-            and hasattr(resp.choices[0], "message")
-            and hasattr(resp.choices[0].message, "content")
-            and resp.choices[0].message.content is not None
-        ):
-            collected_content.append(resp.choices[0].message.content)  # OpenAI
-        elif (  # Anthropic new SDK non-streaming content block
-            hasattr(resp, "content")
-            and isinstance(resp.content, list)
-            and resp.content
-            and hasattr(resp.content[0], "text")
-        ):
-            collected_content.append(resp.content[0].text)  # Anthropic
-        else:
-            # Assume resp is a stream if not a recognized object
-            chunk = None
-            try:
-                for chunk in resp:  # This will fail if resp is not iterable (e.g. a non-streaming response object)
-                    content_piece = ""
-                    # OpenAI-like stream
-                    if (
-                        hasattr(chunk, "choices")
-                        and len(chunk.choices) > 0
-                        and hasattr(chunk.choices[0], "delta")
-                        and hasattr(chunk.choices[0].delta, "content")
-                        and chunk.choices[0].delta.content is not None
-                    ):
-                        content_piece = chunk.choices[0].delta.content
-                    # Anthropic-like stream (ContentBlockDelta)
-                    elif (
-                        hasattr(chunk, "type")
-                        and chunk.type == "content_block_delta"
-                        and hasattr(chunk, "delta")
-                        and hasattr(chunk.delta, "text")
-                    ):
-                        content_piece = chunk.delta.text
-                    elif isinstance(chunk, str):  # Simplest case: stream of strings
-                        content_piece = chunk
-
-                    if content_piece:
-                        collected_content.append(content_piece)
-            except TypeError:  # If resp is not iterable (e.g. a final response object that wasn't caught above)
-                logger.debug(
-                    f"Response type {type(resp)} could not be iterated as a stream. It might be a non-streaming object not handled by specific checks."
-                )
-            except Exception as e:
-                logger.error(
-                    f"Error processing potential stream chunk: {e}, chunk was: {getattr(chunk, '__dict__', chunk) if chunk is not None else 'N/A'}"
-                )
-
-        return "".join(collected_content)
-
     def _gen_inner(
-        self, query: str, retriever: BaseRetriever, log_context: LogContext
+        self, query: str, log_context: LogContext
     ) -> Generator[Dict, None, None]:
-        # Reset state for this generation call
-        self.plan = ""
-        self.observations = []
-        retrieved_data = self._retriever_search(retriever, query, log_context)
+        """Execute ReAct reasoning loop with planning, action, and observation cycles"""

-        if self.user_api_key:
-            tools_dict = self._get_tools(self.user_api_key)
-        else:
-            tools_dict = self._get_user_tools(self.user)
+        self._reset_state()

+        tools_dict = (
+            self._get_tools(self.user_api_key)
+            if self.user_api_key
+            else self._get_user_tools(self.user)
+        )
         self._prepare_tools(tools_dict)

-        docs_together = "\n".join([doc["text"] for doc in retrieved_data])
-        iterating_reasoning = 0
-        while iterating_reasoning < MAX_ITERATIONS_REASONING:
-            iterating_reasoning += 1
-            # 1. Create Plan
-            logger.info("ReActAgent: Creating plan...")
-            plan_stream = self._create_plan(query, docs_together, log_context)
-            current_plan_parts = []
-            yield {"thought": f"Reasoning... (iteration {iterating_reasoning})\n\n"}
-            for line_chunk in plan_stream:
-                current_plan_parts.append(line_chunk)
-                yield {"thought": line_chunk}
-            self.plan = "".join(current_plan_parts)
-            if self.plan:
-                self.observations.append(
-                    f"Plan: {self.plan} Iteration: {iterating_reasoning}"
-                )
+        for iteration in range(1, MAX_ITERATIONS_REASONING + 1):
+            yield {"thought": f"Reasoning... (iteration {iteration})\n\n"}

-            max_obs_len = 20000
-            obs_str = "\n".join(self.observations)
-            if len(obs_str) > max_obs_len:
-                obs_str = obs_str[:max_obs_len] + "\n...[observations truncated]"
-            execution_prompt_str = (
-                (self.prompt or "")
-                + f"\n\nFollow this plan:\n{self.plan}"
-                + f"\n\nObservations:\n{obs_str}"
-                + f"\n\nIf there is enough data to complete user query '{query}', Respond with 'SATISFIED' only. Otherwise, continue. Dont Menstion 'SATISFIED' in your response if you are not ready. "
-            )
+            yield from self._planning_phase(query, log_context)

-            messages = self._build_messages(execution_prompt_str, query, retrieved_data)
-
-            resp_from_llm_gen = self._llm_gen(messages, log_context)
-
-            initial_llm_thought_content = self._extract_content_from_llm_response(
-                resp_from_llm_gen
-            )
-            if initial_llm_thought_content:
-                self.observations.append(
-                    f"Initial thought/response: {initial_llm_thought_content}"
-                )
-            else:
-                logger.info(
-                    "ReActAgent: Initial LLM response (before handler) had no textual content (might be only tool calls)."
-                )
-            resp_after_handler = self._llm_handler(
-                resp_from_llm_gen, tools_dict, messages, log_context
-            )
-
-            for tool_call_info in self.tool_calls:  # Iterate over self.tool_calls populated by _llm_handler
-                observation_string = (
-                    f"Executed Action: Tool '{tool_call_info.get('tool_name', 'N/A')}' "
-                    f"with arguments '{tool_call_info.get('arguments', '{}')}'. Result: '{str(tool_call_info.get('result', ''))[:200]}...'"
-                )
-                self.observations.append(observation_string)
-
-            content_after_handler = self._extract_content_from_llm_response(
-                resp_after_handler
-            )
-            if content_after_handler:
-                self.observations.append(
-                    f"Response after tool execution: {content_after_handler}"
-                )
-            else:
-                logger.info(
-                    "ReActAgent: LLM response after handler had no textual content."
-                )
-
-            if log_context:
-                log_context.stacks.append(
-                    {
-                        "component": "agent_tool_calls",
-                        "data": {"tool_calls": self.tool_calls.copy()},
-                    }
-                )
-
-            yield {"sources": retrieved_data}
-
-            display_tool_calls = []
-            for tc in self.tool_calls:
-                cleaned_tc = tc.copy()
-                if len(str(cleaned_tc.get("result", ""))) > 50:
-                    cleaned_tc["result"] = str(cleaned_tc["result"])[:50] + "..."
-                display_tool_calls.append(cleaned_tc)
-            if display_tool_calls:
-                yield {"tool_calls": display_tool_calls}
-
-            if "SATISFIED" in content_after_handler:
-                logger.info(
-                    "ReActAgent: LLM satisfied with the plan and data. Stopping reasoning."
-                )
+            if not self.plan:
+                logger.warning(
+                    f"ReActAgent: No plan generated in iteration {iteration}"
+                )
                 break
+            self.observations.append(f"Plan (iteration {iteration}): {self.plan}")

-        # 3. Create Final Answer based on all observations
-        final_answer_stream = self._create_final_answer(
-            query, self.observations, log_context
-        )
-        for answer_chunk in final_answer_stream:
-            yield {"answer": answer_chunk}
-        logger.info("ReActAgent: Finished generating final answer.")
+            satisfied = yield from self._execution_phase(query, tools_dict, log_context)

-    def _create_plan(
-        self, query: str, docs_data: str, log_context: LogContext = None
-    ) -> Generator[str, None, None]:
-        plan_prompt_filled = planning_prompt_template.replace("{query}", query)
-        if "{summaries}" in plan_prompt_filled:
-            summaries = docs_data if docs_data else "No documents retrieved."
-            plan_prompt_filled = plan_prompt_filled.replace("{summaries}", summaries)
-        plan_prompt_filled = plan_prompt_filled.replace("{prompt}", self.prompt or "")
-        plan_prompt_filled = plan_prompt_filled.replace(
-            "{observations}", "\n".join(self.observations)
-        )
+            if satisfied:
+                logger.info("ReActAgent: Goal satisfied, stopping reasoning loop")
+                break

-        messages = [{"role": "user", "content": plan_prompt_filled}]
+        yield from self._synthesis_phase(query, log_context)

-        plan_stream_from_llm = self.llm.gen_stream(
+    def _reset_state(self):
+        """Reset agent state for new query"""
+        self.plan = ""
+        self.observations = []
+
+    def _planning_phase(
+        self, query: str, log_context: LogContext
+    ) -> Generator[Dict, None, None]:
+        """Generate strategic plan for query"""
+        logger.info("ReActAgent: Creating plan...")
+
+        plan_prompt = self._build_planning_prompt(query)
+        messages = [{"role": "user", "content": plan_prompt}]
+
+        plan_stream = self.llm.gen_stream(
             model=self.gpt_model,
             messages=messages,
-            tools=getattr(self, "tools", None),  # Use self.tools
+            tools=self.tools if self.tools else None,
         )

         if log_context:
-            data = build_stack_data(self.llm)
-            log_context.stacks.append({"component": "planning_llm", "data": data})
-
-        for chunk in plan_stream_from_llm:
-            content_piece = self._extract_content_from_llm_response(chunk)
-            if content_piece:
-                yield content_piece
-
-    def _create_final_answer(
-        self, query: str, observations: List[str], log_context: LogContext = None
-    ) -> Generator[str, None, None]:
-        observation_string = "\n".join(observations)
-        max_obs_len = 10000
-        if len(observation_string) > max_obs_len:
-            observation_string = (
-                observation_string[:max_obs_len] + "\n...[observations truncated]"
-            )
-            logger.warning(
-                "ReActAgent: Truncated observations for final answer prompt due to length."
-            )
+            log_context.stacks.append(
+                {"component": "planning_llm", "data": build_stack_data(self.llm)}
+            )
+        plan_parts = []
+        for chunk in plan_stream:
+            content = self._extract_content(chunk)
+            if content:
+                plan_parts.append(content)
+                yield {"thought": content}
+        self.plan = "".join(plan_parts)

-        final_answer_prompt_filled = final_prompt_template.format(
-            query=query, observations=observation_string
-        )
+    def _execution_phase(
+        self, query: str, tools_dict: Dict, log_context: LogContext
+    ) -> Generator[bool, None, None]:
+        """Execute plan with tool calls and observations"""
+        execution_prompt = self._build_execution_prompt(query)
+        messages = self._build_messages(execution_prompt, query)

-        messages = [{"role": "user", "content": final_answer_prompt_filled}]
+        llm_response = self._llm_gen(messages, log_context)
+        initial_content = self._extract_content(llm_response)
+
+        if initial_content:
+            self.observations.append(f"Initial response: {initial_content}")
+        processed_response = self._llm_handler(
+            llm_response, tools_dict, messages, log_context
+        )
+
+        for tool_call in self.tool_calls:
+            observation = (
+                f"Executed: {tool_call.get('tool_name', 'Unknown')} "
+                f"with args {tool_call.get('arguments', {})}. "
+                f"Result: {str(tool_call.get('result', ''))[:200]}"
+            )
+            self.observations.append(observation)
+        final_content = self._extract_content(processed_response)
+        if final_content:
+            self.observations.append(f"Response after tools: {final_content}")
+        if log_context:
+            log_context.stacks.append(
+                {
+                    "component": "agent_tool_calls",
+                    "data": {"tool_calls": self.tool_calls.copy()},
+                }
+            )
+        yield {"sources": self.retrieved_docs}
+        yield {"tool_calls": self._get_truncated_tool_calls()}

-        # Final answer should synthesize, not call tools.
-        final_answer_stream_from_llm = self.llm.gen_stream(
+        return "SATISFIED" in (final_content or "")
+
+    def _synthesis_phase(
+        self, query: str, log_context: LogContext
+    ) -> Generator[Dict, None, None]:
+        """Synthesize final answer from all observations"""
+        logger.info("ReActAgent: Generating final answer...")
+
+        final_prompt = self._build_final_answer_prompt(query)
+        messages = [{"role": "user", "content": final_prompt}]
+
+        final_stream = self.llm.gen_stream(
             model=self.gpt_model, messages=messages, tools=None
         )
         if log_context:
-            data = build_stack_data(self.llm)
-            log_context.stacks.append({"component": "final_answer_llm", "data": data})
-
-        for chunk in final_answer_stream_from_llm:
-            content_piece = self._extract_content_from_llm_response(chunk)
-            if content_piece:
-                yield content_piece
+            log_context.stacks.append(
+                {"component": "final_answer_llm", "data": build_stack_data(self.llm)}
+            )
+        for chunk in final_stream:
+            content = self._extract_content(chunk)
+            if content:
+                yield {"answer": content}
+
+    def _build_planning_prompt(self, query: str) -> str:
+        """Build planning phase prompt"""
+        prompt = PLANNING_PROMPT_TEMPLATE.replace("{query}", query)
+        prompt = prompt.replace("{prompt}", self.prompt or "")
+        prompt = prompt.replace("{summaries}", "")
+        prompt = prompt.replace("{observations}", "\n".join(self.observations))
+        return prompt
+
+    def _build_execution_prompt(self, query: str) -> str:
+        """Build execution phase prompt with plan and observations"""
+        observations_str = "\n".join(self.observations)
+
+        if len(observations_str) > 20000:
+            observations_str = observations_str[:20000] + "\n...[truncated]"
+        return (
+            f"{self.prompt or ''}\n\n"
+            f"Follow this plan:\n{self.plan}\n\n"
+            f"Observations:\n{observations_str}\n\n"
+            f"If sufficient data exists to answer '{query}', respond with 'SATISFIED'. "
+            f"Otherwise, continue executing the plan."
+        )
+
+    def _build_final_answer_prompt(self, query: str) -> str:
+        """Build final synthesis prompt"""
+        observations_str = "\n".join(self.observations)
+
+        if len(observations_str) > 10000:
+            observations_str = observations_str[:10000] + "\n...[truncated]"
+            logger.warning("ReActAgent: Observations truncated for final answer")
+        return FINAL_PROMPT_TEMPLATE.format(query=query, observations=observations_str)
+
+    def _extract_content(self, response: Any) -> str:
+        """Extract text content from various LLM response formats"""
+        if not response:
+            return ""
+        collected = []
+
+        if isinstance(response, str):
+            return response
+        if hasattr(response, "message") and hasattr(response.message, "content"):
+            if response.message.content:
+                return response.message.content
+        if hasattr(response, "choices") and response.choices:
+            if hasattr(response.choices[0], "message"):
+                content = response.choices[0].message.content
+                if content:
+                    return content
+        if hasattr(response, "content") and isinstance(response.content, list):
+            if response.content and hasattr(response.content[0], "text"):
+                return response.content[0].text
+        try:
+            for chunk in response:
+                content_piece = ""
+
+                if hasattr(chunk, "choices") and chunk.choices:
+                    if hasattr(chunk.choices[0], "delta"):
+                        delta_content = chunk.choices[0].delta.content
+                        if delta_content:
+                            content_piece = delta_content
+                elif hasattr(chunk, "type") and chunk.type == "content_block_delta":
+                    if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
+                        content_piece = chunk.delta.text
+                elif isinstance(chunk, str):
+                    content_piece = chunk
+                if content_piece:
+                    collected.append(content_piece)
+        except (TypeError, AttributeError):
+            logger.debug(
+                f"Response not iterable or unexpected format: {type(response)}"
+            )
+        except Exception as e:
+            logger.error(f"Error extracting content: {e}")
+        return "".join(collected)
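Both agents now emit a stream of typed events rather than taking a retriever at call time. A minimal consumer of the generator API, assuming an already-constructed ReActAgent, could look like this sketch:

    # Sketch of consuming the gen() API shown above; `agent` is assumed to be
    # a fully configured ReActAgent (retrieval is now injected at construction
    # time via retrieved_docs rather than passed to gen()).
    def print_stream(agent, question: str) -> str:
        answer_parts = []
        for event in agent.gen(query=question):
            if "thought" in event:
                print(f"[thought] {event['thought']}", end="")
            elif "answer" in event:
                answer_parts.append(event["answer"])
            elif "sources" in event:
                print(f"[sources] {len(event['sources'])} documents")
            elif "tool_calls" in event:
                print(f"[tools] {len(event['tool_calls'])} calls")
        return "".join(answer_parts)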
@@ -54,6 +54,10 @@ class AnswerResource(Resource, BaseAnswerResource):
             default=True,
             description="Whether to save the conversation",
         ),
+        "passthrough": fields.Raw(
+            required=False,
+            description="Dynamic parameters to inject into prompt template",
+        ),
     },
 )

@@ -69,8 +73,17 @@ class AnswerResource(Resource, BaseAnswerResource):
         processor.initialize()
         if not processor.decoded_token:
             return make_response({"error": "Unauthorized"}, 401)
-        agent = processor.create_agent()
         retriever = processor.create_retriever()

+        docs_together, docs_list = processor.pre_fetch_docs(
+            data.get("question", "")
+        )
+        tools_data = processor.pre_fetch_tools()
+
+        agent = processor.create_agent(
+            docs_together=docs_together,
+            docs=docs_list,
+            tools_data=tools_data,
+        )

         if error := self.check_usage(processor.agent_config):
             return error
@@ -78,7 +91,6 @@ class AnswerResource(Resource, BaseAnswerResource):
         stream = self.complete_stream(
             question=data["question"],
             agent=agent,
-            retriever=retriever,
             conversation_id=processor.conversation_id,
             user_api_key=processor.agent_config.get("user_api_key"),
             decoded_token=processor.decoded_token,
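For reference, a request that exercises the new passthrough field might look like the following; everything except question and passthrough is an assumption (host, port, and the passthrough keys are invented):

    # Example request against the answer endpoint; the "passthrough" keys are
    # whatever the prompt template references via {{ passthrough.<name> }} --
    # "plan" and "region" here are made up.
    import json
    import requests  # assumes the requests library

    payload = {
        "question": "How do I rotate my API key?",
        "passthrough": {"plan": "enterprise", "region": "eu-west-1"},
    }
    resp = requests.post(
        "http://localhost:7091/api/answer",  # host/port assumed
        headers={"Content-Type": "application/json"},
        data=json.dumps(payload),
    )
    print(resp.json())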
@@ -3,7 +3,7 @@ import json
 import logging
 from typing import Any, Dict, Generator, List, Optional

-from flask import Response, make_response, jsonify
+from flask import jsonify, make_response, Response
 from flask_restx import Namespace

 from application.api.answer.services.conversation_service import ConversationService
@@ -41,9 +41,7 @@ class BaseAnswerResource:
                 return missing_fields
         return None

-    def check_usage(
-        self, agent_config: Dict
-    ) -> Optional[Response]:
+    def check_usage(self, agent_config: Dict) -> Optional[Response]:
         """Check if there is a usage limit and if it is exceeded

         Args:
@@ -51,30 +49,40 @@ class BaseAnswerResource:

         Returns:
             None or Response if either of limits exceeded.
-
-
         """
         api_key = agent_config.get("user_api_key")
         if not api_key:
             return None
-
-
         agents_collection = self.db["agents"]
         agent = agents_collection.find_one({"key": api_key})

         if not agent:
             return make_response(
-                jsonify(
-                    {
-                        "success": False,
-                        "message": "Invalid API key."
-                    }
-                ),
-                401
+                jsonify({"success": False, "message": "Invalid API key."}), 401
             )

-        limited_token_mode = agent.get("limited_token_mode", False)
-        limited_request_mode = agent.get("limited_request_mode", False)
-        token_limit = int(agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]))
-        request_limit = int(agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]))
+        limited_token_mode_raw = agent.get("limited_token_mode", False)
+        limited_request_mode_raw = agent.get("limited_request_mode", False)
+
+        limited_token_mode = (
+            limited_token_mode_raw
+            if isinstance(limited_token_mode_raw, bool)
+            else limited_token_mode_raw == "True"
+        )
+        limited_request_mode = (
+            limited_request_mode_raw
+            if isinstance(limited_request_mode_raw, bool)
+            else limited_request_mode_raw == "True"
+        )
+
+        token_limit = int(
+            agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
+        )
+        request_limit = int(
+            agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
+        )

         token_usage_collection = self.db["token_usage"]

@@ -83,18 +91,20 @@ class BaseAnswerResource:

         match_query = {
             "timestamp": {"$gte": start_date, "$lte": end_date},
-            "api_key": api_key
+            "api_key": api_key,
         }

         if limited_token_mode:
             token_pipeline = [
                 {"$match": match_query},
                 {
                     "$group": {
                         "_id": None,
-                        "total_tokens": {"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}}
+                        "total_tokens": {
+                            "$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
+                        },
                     }
                 },
             ]
             token_result = list(token_usage_collection.aggregate(token_pipeline))
             daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
@@ -108,26 +118,33 @@ class BaseAnswerResource:

         if not limited_token_mode and not limited_request_mode:
             return None
-        elif limited_token_mode and token_limit > daily_token_usage:
-            return None
-        elif limited_request_mode and request_limit > daily_request_usage:
-            return None
-
-        return make_response(
-            jsonify(
-                {
-                    "success": False,
-                    "message": "Exceeding usage limit, please try again later."
-                }
-            ),
-            429,  # too many requests
+        token_exceeded = (
+            limited_token_mode and token_limit > 0 and daily_token_usage >= token_limit
         )
+        request_exceeded = (
+            limited_request_mode
+            and request_limit > 0
+            and daily_request_usage >= request_limit
+        )
+
+        if token_exceeded or request_exceeded:
+            return make_response(
+                jsonify(
+                    {
+                        "success": False,
+                        "message": "Exceeding usage limit, please try again later.",
+                    }
+                ),
+                429,
+            )
+
+        return None

     def complete_stream(
         self,
         question: str,
         agent: Any,
         retriever: Any,
         conversation_id: Optional[str],
         user_api_key: Optional[str],
         decoded_token: Dict[str, Any],
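The rewritten check normalizes modes that Mongo may store as the strings "True"/"False" and only trips when a limit is positive and actually reached. The predicate restated in isolation, handy for boundary tests:

    # Standalone restatement of the predicate introduced above, useful for
    # unit-testing the boundary behaviour.
    def usage_exceeded(limited_mode, limit: int, usage: int) -> bool:
        # Mongo documents sometimes hold "True"/"False" strings instead of booleans.
        mode = limited_mode if isinstance(limited_mode, bool) else limited_mode == "True"
        return mode and limit > 0 and usage >= limit

    assert usage_exceeded("True", 1000, 1000) is True   # boundary counts as exceeded
    assert usage_exceeded(False, 1000, 99999) is False  # unlimited agent
    assert usage_exceeded(True, 0, 50) is False         # zero limit treated as unset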
@@ -156,6 +173,7 @@ class BaseAnswerResource:
             agent_id: ID of agent used
             is_shared_usage: Flag for shared agent usage
             shared_token: Token for shared agent
+            retrieved_docs: Pre-fetched documents for sources (optional)

         Yields:
             Server-sent event strings
@@ -166,7 +184,7 @@ class BaseAnswerResource:
         schema_info = None
         structured_chunks = []

-        for line in agent.gen(query=question, retriever=retriever):
+        for line in agent.gen(query=question):
             if "answer" in line:
                 response_full += str(line["answer"])
                 if line.get("structured"):
@@ -247,7 +265,6 @@ class BaseAnswerResource:
                 data = json.dumps(id_data)
                 yield f"data: {data}\n\n"

-            retriever_params = retriever.get_params()
             log_data = {
                 "action": "stream_answer",
                 "level": "info",
@@ -256,7 +273,6 @@ class BaseAnswerResource:
                 "question": question,
                 "response": response_full,
                 "sources": source_log_docs,
-                "retriever_params": retriever_params,
                 "attachments": attachment_ids,
                 "timestamp": datetime.datetime.now(datetime.timezone.utc),
             }
@@ -264,24 +280,19 @@ class BaseAnswerResource:
                 log_data["structured_output"] = True
                 if schema_info:
                     log_data["schema"] = schema_info

-            # clean up text fields to be no longer than 10000 characters
+            # Clean up text fields to be no longer than 10000 characters
             for key, value in log_data.items():
                 if isinstance(value, str) and len(value) > 10000:
                     log_data[key] = value[:10000]
-
-            self.user_logs_collection.insert_one(log_data)

-            # End of stream
+            self.user_logs_collection.insert_one(log_data)

             data = json.dumps({"type": "end"})
             yield f"data: {data}\n\n"
         except GeneratorExit:
-            # Client aborted the connection
-            logger.info(
-                f"Stream aborted by client for question: {question[:50]}... "
-            )
-            # Save partial response to database before exiting
+            logger.info(f"Stream aborted by client for question: {question[:50]}... ")
+            # Save partial response
             if should_save_conversation and response_full:
                 try:
                     if isNoneDoc:
@@ -311,7 +322,9 @@ class BaseAnswerResource:
                         attachment_ids=attachment_ids,
                     )
                 except Exception as e:
-                    logger.error(f"Error saving partial response: {str(e)}", exc_info=True)
+                    logger.error(
+                        f"Error saving partial response: {str(e)}", exc_info=True
+                    )
             raise
         except Exception as e:
             logger.error(f"Error in stream: {str(e)}", exc_info=True)
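complete_stream emits server-sent events terminated by a {"type": "end"} message. A small client sketch (the URL and the requests dependency are assumptions):

    # Minimal SSE consumer for the stream endpoint above; assumes the
    # `requests` library and a locally running instance.
    import json
    import requests

    def read_sse(url: str, payload: dict):
        with requests.post(url, json=payload, stream=True) as resp:
            for raw in resp.iter_lines(decode_unicode=True):
                if not raw or not raw.startswith("data: "):
                    continue
                event = json.loads(raw[len("data: "):])
                if event.get("type") == "end":
                    break
                yield event

    for event in read_sse("http://localhost:7091/stream", {"question": "hello"}):
        print(event)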
@@ -60,6 +60,10 @@ class StreamResource(Resource, BaseAnswerResource):
         "attachments": fields.List(
             fields.String, required=False, description="List of attachment IDs"
         ),
+        "passthrough": fields.Raw(
+            required=False,
+            description="Dynamic parameters to inject into prompt template",
+        ),
     },
 )

@@ -73,17 +77,20 @@ class StreamResource(Resource, BaseAnswerResource):
         processor = StreamProcessor(data, decoded_token)
         try:
             processor.initialize()
-            agent = processor.create_agent()
             retriever = processor.create_retriever()

+            docs_together, docs_list = processor.pre_fetch_docs(data["question"])
+            tools_data = processor.pre_fetch_tools()
+
+            agent = processor.create_agent(
+                docs_together=docs_together, docs=docs_list, tools_data=tools_data
+            )
+
             if error := self.check_usage(processor.agent_config):
                 return error

             return Response(
                 self.complete_stream(
                     question=data["question"],
                     agent=agent,
                     retriever=retriever,
                     conversation_id=processor.conversation_id,
                     user_api_key=processor.agent_config.get("user_api_key"),
                     decoded_token=processor.decoded_token,
@@ -133,10 +133,9 @@ class ConversationService:

         messages_summary = [
             {
-                "role": "assistant",
-                "content": "Summarise following conversation in no more than 3 "
-                "words, respond ONLY with the summary, use the same "
-                "language as the user query",
+                "role": "system",
+                "content": "You are a helpful assistant that creates concise conversation titles. "
+                "Summarize conversations in 3 words or less using the same language as the user.",
             },
             {
                 "role": "user",
application/api/answer/services/prompt_renderer.py (new file, 97 lines)
@@ -0,0 +1,97 @@
+import logging
+from typing import Any, Dict, Optional
+
+from application.templates.namespaces import NamespaceManager
+from application.templates.template_engine import TemplateEngine, TemplateRenderError
+
+logger = logging.getLogger(__name__)
+
+
+class PromptRenderer:
+    """Service for rendering prompts with dynamic context using namespaces"""
+
+    def __init__(self):
+        self.template_engine = TemplateEngine()
+        self.namespace_manager = NamespaceManager()
+
+    def render_prompt(
+        self,
+        prompt_content: str,
+        user_id: Optional[str] = None,
+        request_id: Optional[str] = None,
+        passthrough_data: Optional[Dict[str, Any]] = None,
+        docs: Optional[list] = None,
+        docs_together: Optional[str] = None,
+        tools_data: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ) -> str:
+        """
+        Render prompt with full context from all namespaces.
+
+        Args:
+            prompt_content: Raw prompt template string
+            user_id: Current user identifier
+            request_id: Unique request identifier
+            passthrough_data: Parameters from web request
+            docs: RAG retrieved documents
+            docs_together: Concatenated document content
+            tools_data: Pre-fetched tool results organized by tool name
+            **kwargs: Additional parameters for namespace builders
+
+        Returns:
+            Rendered prompt string with all variables substituted
+
+        Raises:
+            TemplateRenderError: If template rendering fails
+        """
+        if not prompt_content:
+            return ""
+
+        uses_template = self._uses_template_syntax(prompt_content)
+
+        if not uses_template:
+            return self._apply_legacy_substitutions(prompt_content, docs_together)
+
+        try:
+            context = self.namespace_manager.build_context(
+                user_id=user_id,
+                request_id=request_id,
+                passthrough_data=passthrough_data,
+                docs=docs,
+                docs_together=docs_together,
+                tools_data=tools_data,
+                **kwargs,
+            )
+
+            return self.template_engine.render(prompt_content, context)
+        except TemplateRenderError:
+            raise
+        except Exception as e:
+            error_msg = f"Prompt rendering failed: {str(e)}"
+            logger.error(error_msg)
+            raise TemplateRenderError(error_msg) from e
+
+    def _uses_template_syntax(self, prompt_content: str) -> bool:
+        """Check if prompt uses Jinja2 template syntax"""
+        return "{{" in prompt_content and "}}" in prompt_content
+
+    def _apply_legacy_substitutions(
+        self, prompt_content: str, docs_together: Optional[str] = None
+    ) -> str:
+        """
+        Apply backward-compatible substitutions for old prompt format.
+
+        Handles legacy {summaries} and {query} placeholders during transition period.
+        """
+        if docs_together:
+            prompt_content = prompt_content.replace("{summaries}", docs_together)
+        return prompt_content
+
+    def validate_template(self, prompt_content: str) -> bool:
+        """Validate prompt template syntax"""
+        return self.template_engine.validate_template(prompt_content)
+
+    def extract_variables(self, prompt_content: str) -> set[str]:
+        """Extract all variable names from prompt template"""
+        return self.template_engine.extract_variables(prompt_content)
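A usage sketch for the new service. The template body and passthrough values are invented, and it assumes NamespaceManager exposes request parameters under the passthrough namespace, as the call sites suggest:

    # Hypothetical caller of the PromptRenderer defined above. Legacy prompts
    # (no {{ }}) fall back to the old {summaries} substitution path.
    renderer = PromptRenderer()

    templated = "Answer as {{ passthrough.persona }}.\n\nContext:\n{{ docs_together }}"
    print(renderer.render_prompt(
        templated,
        user_id="user-123",
        passthrough_data={"persona": "a terse SRE"},
        docs_together="runbook: restart the worker first",
    ))

    legacy = "Use these summaries:\n{summaries}"
    print(renderer.render_prompt(legacy, docs_together="summary text"))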
@@ -3,7 +3,7 @@ import json
 import logging
 import os
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Set

 from bson.dbref import DBRef

@@ -11,10 +11,15 @@ from bson.objectid import ObjectId

 from application.agents.agent_creator import AgentCreator
 from application.api.answer.services.conversation_service import ConversationService
+from application.api.answer.services.prompt_renderer import PromptRenderer
 from application.core.mongo_db import MongoDB
 from application.core.settings import settings
 from application.retriever.retriever_creator import RetrieverCreator
-from application.utils import get_gpt_model, limit_chat_history
+from application.utils import (
+    calculate_doc_token_budget,
+    get_gpt_model,
+    limit_chat_history,
+)

 logger = logging.getLogger(__name__)

@@ -73,12 +78,16 @@ class StreamProcessor:
         self.all_sources = []
         self.attachments = []
         self.history = []
+        self.retrieved_docs = []
         self.agent_config = {}
         self.retriever_config = {}
         self.is_shared_usage = False
         self.shared_token = None
         self.gpt_model = get_gpt_model()
         self.conversation_service = ConversationService()
+        self.prompt_renderer = PromptRenderer()
+        self._prompt_content: Optional[str] = None
+        self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None

     def initialize(self):
         """Initialize all required components for processing"""
@@ -311,19 +320,312 @@ class StreamProcessor:
         )

     def _configure_retriever(self):
         """Configure the retriever based on request data"""
+        history_token_limit = int(self.data.get("token_limit", 2000))
+        doc_token_limit = calculate_doc_token_budget(
+            gpt_model=self.gpt_model, history_token_limit=history_token_limit
+        )
+
         self.retriever_config = {
             "retriever_name": self.data.get("retriever", "classic"),
             "chunks": int(self.data.get("chunks", 2)),
-            "token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY),
+            "doc_token_limit": doc_token_limit,
+            "history_token_limit": history_token_limit,
         }

         api_key = self.data.get("api_key") or self.agent_key
         if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
             self.retriever_config["chunks"] = 0
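calculate_doc_token_budget itself is not shown in this diff; a plausible reading of the three-tier budget (model context minus history reserve minus response headroom) is sketched below, with invented numbers:

    # Sketch of the three-tier budget idea; the real calculate_doc_token_budget
    # lives in application/utils and may differ. Numbers are illustrative.
    MODEL_CONTEXT = {"gpt-4o": 128_000, "gpt-3.5-turbo": 16_385}
    RESPONSE_RESERVE = 4_000  # headroom kept for the model's answer (assumed)

    def doc_token_budget(gpt_model: str, history_token_limit: int) -> int:
        context = MODEL_CONTEXT.get(gpt_model, 8_192)
        return max(context - history_token_limit - RESPONSE_RESERVE, 0)

    print(doc_token_budget("gpt-4o", history_token_limit=2000))  # 122000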
-    def create_agent(self):
-        """Create and return the configured agent"""
+    def create_retriever(self):
+        return RetrieverCreator.create_retriever(
+            self.retriever_config["retriever_name"],
+            source=self.source,
+            chat_history=self.history,
+            prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
+            chunks=self.retriever_config["chunks"],
+            doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
+            gpt_model=self.gpt_model,
+            user_api_key=self.agent_config["user_api_key"],
+            decoded_token=self.decoded_token,
+        )
+
+    def pre_fetch_docs(self, question: str) -> tuple[Optional[str], Optional[list]]:
+        """Pre-fetch documents for template rendering before agent creation"""
+        if self.data.get("isNoneDoc", False):
+            logger.info("Pre-fetch skipped: isNoneDoc=True")
+            return None, None
+        try:
+            retriever = self.create_retriever()
+            logger.info(
+                f"Pre-fetching docs with chunks={retriever.chunks}, doc_token_limit={retriever.doc_token_limit}"
+            )
+            docs = retriever.search(question)
+            logger.info(f"Pre-fetch retrieved {len(docs) if docs else 0} documents")
+
+            if not docs:
+                logger.info("Pre-fetch: No documents returned from search")
+                return None, None
+            self.retrieved_docs = docs
+
+            docs_with_filenames = []
+            for doc in docs:
+                filename = doc.get("filename") or doc.get("title") or doc.get("source")
+                if filename:
+                    chunk_header = str(filename)
+                    docs_with_filenames.append(f"{chunk_header}\n{doc['text']}")
+                else:
+                    docs_with_filenames.append(doc["text"])
+            docs_together = "\n\n".join(docs_with_filenames)
+
+            logger.info(f"Pre-fetch docs_together size: {len(docs_together)} chars")
+
+            return docs_together, docs
+        except Exception as e:
+            logger.error(f"Failed to pre-fetch docs: {str(e)}", exc_info=True)
+            return None, None
+    def pre_fetch_tools(self) -> Optional[Dict[str, Any]]:
+        """Pre-fetch tool data for template rendering before agent creation
+
+        Can be controlled via:
+        1. Global setting: ENABLE_TOOL_PREFETCH in .env
+        2. Per-request: disable_tool_prefetch in request data
+        """
+        if not settings.ENABLE_TOOL_PREFETCH:
+            logger.info(
+                "Tool pre-fetching disabled globally via ENABLE_TOOL_PREFETCH setting"
+            )
+            return None
+
+        if self.data.get("disable_tool_prefetch", False):
+            logger.info("Tool pre-fetching disabled for this request")
+            return None
+
+        required_tool_actions = self._get_required_tool_actions()
+        filtering_enabled = required_tool_actions is not None
+
+        try:
+            user_tools_collection = self.db["user_tools"]
+            user_id = self.initial_user_id or "local"
+
+            user_tools = list(
+                user_tools_collection.find({"user": user_id, "status": True})
+            )
+
+            if not user_tools:
+                return None
+
+            tools_data = {}
+
+            for tool_doc in user_tools:
+                tool_name = tool_doc.get("name")
+                tool_id = str(tool_doc.get("_id"))
+
+                if filtering_enabled:
+                    required_actions_by_name = required_tool_actions.get(
+                        tool_name, set()
+                    )
+                    required_actions_by_id = required_tool_actions.get(tool_id, set())
+
+                    required_actions = required_actions_by_name | required_actions_by_id
+
+                    if not required_actions:
+                        continue
+                else:
+                    required_actions = None
+
+                tool_data = self._fetch_tool_data(tool_doc, required_actions)
+                if tool_data:
+                    tools_data[tool_name] = tool_data
+                    tools_data[tool_id] = tool_data
+
+            return tools_data if tools_data else None
+        except Exception as e:
+            logger.warning(f"Failed to pre-fetch tools: {type(e).__name__}")
+            return None
+
+    def _fetch_tool_data(
+        self,
+        tool_doc: Dict[str, Any],
+        required_actions: Optional[Set[Optional[str]]],
+    ) -> Optional[Dict[str, Any]]:
+        """Fetch and execute tool actions with saved parameters"""
+        try:
+            from application.agents.tools.tool_manager import ToolManager
+
+            tool_name = tool_doc.get("name")
+            tool_config = tool_doc.get("config", {}).copy()
+            tool_config["tool_id"] = str(tool_doc["_id"])
+
+            tool_manager = ToolManager(config={tool_name: tool_config})
+            user_id = self.initial_user_id or "local"
+            tool = tool_manager.load_tool(tool_name, tool_config, user_id=user_id)
+
+            if not tool:
+                logger.debug(f"Tool '{tool_name}' failed to load")
+                return None
+
+            tool_actions = tool.get_actions_metadata()
+            if not tool_actions:
+                logger.debug(f"Tool '{tool_name}' has no actions")
+                return None
+
+            saved_actions = tool_doc.get("actions", [])
+
+            include_all_actions = required_actions is None or (
+                required_actions and None in required_actions
+            )
+            allowed_actions: Set[str] = (
+                {action for action in required_actions if isinstance(action, str)}
+                if required_actions
+                else set()
+            )
+
+            action_results = {}
+            for action_meta in tool_actions:
+                action_name = action_meta.get("name")
+                if action_name is None:
+                    continue
+                if (
+                    not include_all_actions
+                    and allowed_actions
+                    and action_name not in allowed_actions
+                ):
+                    continue
+
+                try:
+                    saved_action = None
+                    for sa in saved_actions:
+                        if sa.get("name") == action_name:
+                            saved_action = sa
+                            break
+
+                    action_params = action_meta.get("parameters", {})
+                    properties = action_params.get("properties", {})
+
+                    kwargs = {}
+                    for param_name, param_spec in properties.items():
+                        if saved_action:
+                            saved_props = saved_action.get("parameters", {}).get(
+                                "properties", {}
+                            )
+                            if param_name in saved_props:
+                                param_value = saved_props[param_name].get("value")
+                                if param_value is not None:
+                                    kwargs[param_name] = param_value
+                                    continue
+
+                        if param_name in tool_config:
+                            kwargs[param_name] = tool_config[param_name]
+                        elif "default" in param_spec:
+                            kwargs[param_name] = param_spec["default"]
+
+                    result = tool.execute_action(action_name, **kwargs)
+                    action_results[action_name] = result
+                except Exception as e:
+                    logger.debug(
+                        f"Action '{action_name}' execution failed: {type(e).__name__}"
+                    )
+                    continue
+
+            return action_results if action_results else None
+
+        except Exception as e:
+            logger.debug(f"Tool pre-fetch failed for '{tool_name}': {type(e).__name__}")
+            return None
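Pre-fetching only executes actions the template actually references. The real detection lives in TemplateEngine.extract_tool_usages, which this diff does not show; a rough approximation of the idea:

    # Approximation of extract_tool_usages: scan for {{ tools.<tool>.<action> }}
    # references so only those actions are executed during pre-fetch. The real
    # implementation may parse the template properly instead of using a regex.
    import re
    from typing import Dict, Optional, Set

    TOOL_REF = re.compile(r"{{\s*tools\.([A-Za-z0-9_]+)(?:\.([A-Za-z0-9_]+))?")

    def extract_tool_usages(template: str) -> Dict[str, Set[Optional[str]]]:
        usages: Dict[str, Set[Optional[str]]] = {}
        for tool, action in TOOL_REF.findall(template):
            # A bare {{ tools.cron }} reference maps to None, meaning "all actions".
            usages.setdefault(tool, set()).add(action or None)
        return usages

    print(extract_tool_usages("Memory:\n{{ tools.memory.view }}\n{{ tools.cron }}"))
    # {'memory': {'view'}, 'cron': {None}}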
def _get_prompt_content(self) -> Optional[str]:
|
||||
"""Retrieve and cache the raw prompt content for the current agent configuration."""
|
||||
if self._prompt_content is not None:
|
||||
return self._prompt_content
|
||||
prompt_id = (
|
||||
self.agent_config.get("prompt_id")
|
||||
if isinstance(self.agent_config, dict)
|
||||
else None
|
||||
)
|
||||
if not prompt_id:
|
||||
return None
|
||||
try:
|
||||
self._prompt_content = get_prompt(prompt_id, self.prompts_collection)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Invalid prompt ID '{prompt_id}': {str(e)}")
|
||||
self._prompt_content = None
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to fetch prompt '{prompt_id}': {type(e).__name__}")
|
||||
self._prompt_content = None
|
||||
return self._prompt_content
|
||||
|
||||
def _get_required_tool_actions(self) -> Optional[Dict[str, Set[Optional[str]]]]:
|
||||
"""Determine which tool actions are referenced in the prompt template"""
|
||||
        if self._required_tool_actions is not None:
            return self._required_tool_actions
        prompt_content = self._get_prompt_content()
        if prompt_content is None:
            return None
        if "{{" not in prompt_content or "}}" not in prompt_content:
            self._required_tool_actions = {}
            return self._required_tool_actions
        try:
            from application.templates.template_engine import TemplateEngine

            template_engine = TemplateEngine()
            usages = template_engine.extract_tool_usages(prompt_content)
            self._required_tool_actions = usages
            return self._required_tool_actions
        except Exception as e:
            logger.debug(f"Failed to extract tool usages: {type(e).__name__}")
            self._required_tool_actions = {}
            return self._required_tool_actions

    def _fetch_memory_tool_data(
        self, tool_doc: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Fetch memory tool data for pre-injection into prompt"""
        try:
            tool_config = tool_doc.get("config", {}).copy()
            tool_config["tool_id"] = str(tool_doc["_id"])

            from application.agents.tools.memory import MemoryTool

            memory_tool = MemoryTool(tool_config, self.initial_user_id)

            root_view = memory_tool.execute_action("view", path="/")

            if "Error:" in root_view or not root_view.strip():
                return None

            return {"root": root_view, "available": True}
        except Exception as e:
            logger.warning(f"Failed to fetch memory tool data: {str(e)}")
            return None

    def create_agent(
        self,
        docs_together: Optional[str] = None,
        docs: Optional[list] = None,
        tools_data: Optional[Dict[str, Any]] = None,
    ):
        """Create and return the configured agent with rendered prompt"""
        raw_prompt = self._get_prompt_content()
        if raw_prompt is None:
            raw_prompt = get_prompt(
                self.agent_config["prompt_id"], self.prompts_collection
            )
            self._prompt_content = raw_prompt

        rendered_prompt = self.prompt_renderer.render_prompt(
            prompt_content=raw_prompt,
            user_id=self.initial_user_id,
            request_id=self.data.get("request_id"),
            passthrough_data=self.data.get("passthrough"),
            docs=docs,
            docs_together=docs_together,
            tools_data=tools_data,
        )

        return AgentCreator.create_agent(
            self.agent_config["agent_type"],
            endpoint="stream",
@@ -331,23 +633,10 @@ class StreamProcessor:
            gpt_model=self.gpt_model,
            api_key=settings.API_KEY,
            user_api_key=self.agent_config["user_api_key"],
            prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
            prompt=rendered_prompt,
            chat_history=self.history,
            retrieved_docs=self.retrieved_docs,
            decoded_token=self.decoded_token,
            attachments=self.attachments,
            json_schema=self.agent_config.get("json_schema"),
        )

    def create_retriever(self):
        """Create and return the configured retriever"""
        return RetrieverCreator.create_retriever(
            self.retriever_config["retriever_name"],
            source=self.source,
            chat_history=self.history,
            prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
            chunks=self.retriever_config["chunks"],
            token_limit=self.retriever_config["token_limit"],
            gpt_model=self.gpt_model,
            user_api_key=self.agent_config["user_api_key"],
            decoded_token=self.decoded_token,
        )

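The two helpers above cooperate: `extract_tool_usages` reads `{{ tools.* }}` references out of the prompt, and the stream processor then pre-fetches those tools (memory, for example) before `create_agent` renders the final prompt. A minimal sketch, assuming only the `TemplateEngine` added in this commit:

```python
# Illustrative sketch (not part of the commit); import path taken from this diff.
from application.templates.template_engine import TemplateEngine

engine = TemplateEngine()
usages = engine.extract_tool_usages(
    "Context: {{ tools.memory.root }} on {{ system.date }}"
)
# Only the `tools` namespace is tracked, so "memory" is recorded here and the
# stream processor knows to pre-fetch it before building the prompt; the
# {{ system.date }} reference is ignored by this pass.
assert "memory" in usages
```
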
@@ -10,7 +10,6 @@ from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource

from application.api import api
from application.core.settings import settings
from application.api.user.base import (
    agents_collection,
    db,
@@ -20,6 +19,7 @@ from application.api.user.base import (
    storage,
    users_collection,
)
from application.core.settings import settings
from application.utils import (
    check_required_fields,
    generate_image_url,
@@ -76,9 +76,13 @@ class GetAgent(Resource):
                "status": agent.get("status", ""),
                "json_schema": agent.get("json_schema"),
                "limited_token_mode": agent.get("limited_token_mode", False),
                "token_limit": agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
                "token_limit": agent.get(
                    "token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]
                ),
                "limited_request_mode": agent.get("limited_request_mode", False),
                "request_limit": agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
                "request_limit": agent.get(
                    "request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]
                ),
                "created_at": agent.get("createdAt", ""),
                "updated_at": agent.get("updatedAt", ""),
                "last_used_at": agent.get("lastUsedAt", ""),
@@ -149,9 +153,13 @@ class GetAgents(Resource):
                "status": agent.get("status", ""),
                "json_schema": agent.get("json_schema"),
                "limited_token_mode": agent.get("limited_token_mode", False),
                "token_limit": agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
                "token_limit": agent.get(
                    "token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]
                ),
                "limited_request_mode": agent.get("limited_request_mode", False),
                "request_limit": agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
                "request_limit": agent.get(
                    "request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]
                ),
                "created_at": agent.get("createdAt", ""),
                "updated_at": agent.get("updatedAt", ""),
                "last_used_at": agent.get("lastUsedAt", ""),
@@ -209,21 +217,19 @@ class CreateAgent(Resource):
            description="JSON schema for enforcing structured output format",
        ),
        "limited_token_mode": fields.Boolean(
            required=False,
            description="Whether the agent is in limited token mode"
            required=False, description="Whether the agent is in limited token mode"
        ),
        "token_limit": fields.Integer(
            required=False,
            description="Token limit for the agent in limited mode"
            required=False, description="Token limit for the agent in limited mode"
        ),
        "limited_request_mode": fields.Boolean(
            required=False,
            description="Whether the agent is in limited request mode"
            description="Whether the agent is in limited request mode",
        ),
        "request_limit": fields.Integer(
            required=False,
            description="Request limit for the agent in limited mode"
        )
            description="Request limit for the agent in limited mode",
        ),
    },
)

@@ -369,10 +375,26 @@ class CreateAgent(Resource):
            "agent_type": data.get("agent_type", ""),
            "status": data.get("status"),
            "json_schema": data.get("json_schema"),
            "limited_token_mode": data.get("limited_token_mode", False),
            "token_limit": data.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
            "limited_request_mode": data.get("limited_request_mode", False),
            "request_limit": data.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
            "limited_token_mode": (
                data.get("limited_token_mode") == "True"
                if isinstance(data.get("limited_token_mode"), str)
                else bool(data.get("limited_token_mode", False))
            ),
            "token_limit": int(
                data.get(
                    "token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]
                )
            ),
            "limited_request_mode": (
                data.get("limited_request_mode") == "True"
                if isinstance(data.get("limited_request_mode"), str)
                else bool(data.get("limited_request_mode", False))
            ),
            "request_limit": int(
                data.get(
                    "request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]
                )
            ),
            "createdAt": datetime.datetime.now(datetime.timezone.utc),
            "updatedAt": datetime.datetime.now(datetime.timezone.utc),
            "lastUsedAt": None,
@@ -429,21 +451,19 @@ class UpdateAgent(Resource):
            description="JSON schema for enforcing structured output format",
        ),
        "limited_token_mode": fields.Boolean(
            required=False,
            description="Whether the agent is in limited token mode"
            required=False, description="Whether the agent is in limited token mode"
        ),
        "token_limit": fields.Integer(
            required=False,
            description="Token limit for the agent in limited mode"
            required=False, description="Token limit for the agent in limited mode"
        ),
        "limited_request_mode": fields.Boolean(
            required=False,
            description="Whether the agent is in limited request mode"
            description="Whether the agent is in limited request mode",
        ),
        "request_limit": fields.Integer(
            required=False,
            description="Request limit for the agent in limited mode"
        )
            description="Request limit for the agent in limited mode",
        ),
    },
)

@@ -534,7 +554,7 @@ class UpdateAgent(Resource):
        "limited_token_mode",
        "token_limit",
        "limited_request_mode",
        "request_limit"
        "request_limit",
    ]

    for field in allowed_fields:
@@ -652,8 +672,15 @@ class UpdateAgent(Resource):
            else:
                update_fields[field] = None
        elif field == "limited_token_mode":
            is_mode_enabled = data.get("limited_token_mode", False)
            if is_mode_enabled and data.get("token_limit") is None:
            raw_value = data.get("limited_token_mode", False)
            bool_value = (
                raw_value == "True"
                if isinstance(raw_value, str)
                else bool(raw_value)
            )
            update_fields[field] = bool_value

            if bool_value and data.get("token_limit") is None:
                return make_response(
                    jsonify(
                        {
@@ -664,8 +691,15 @@ class UpdateAgent(Resource):
                    400,
                )
        elif field == "limited_request_mode":
            is_mode_enabled = data.get("limited_request_mode", False)
            if is_mode_enabled and data.get("request_limit") is None:
            raw_value = data.get("limited_request_mode", False)
            bool_value = (
                raw_value == "True"
                if isinstance(raw_value, str)
                else bool(raw_value)
            )
            update_fields[field] = bool_value

            if bool_value and data.get("request_limit") is None:
                return make_response(
                    jsonify(
                        {
@@ -677,7 +711,11 @@ class UpdateAgent(Resource):
                )
        elif field == "token_limit":
            token_limit = data.get("token_limit")
            if token_limit is not None and not data.get("limited_token_mode"):
            # Convert to int and store
            update_fields[field] = int(token_limit) if token_limit else 0

            # Validate consistency with mode
            if update_fields[field] > 0 and not data.get("limited_token_mode"):
                return make_response(
                    jsonify(
                        {
@@ -689,7 +727,9 @@ class UpdateAgent(Resource):
                )
        elif field == "request_limit":
            request_limit = data.get("request_limit")
            if request_limit is not None and not data.get("limited_request_mode"):
            update_fields[field] = int(request_limit) if request_limit else 0

            if update_fields[field] > 0 and not data.get("limited_request_mode"):
                return make_response(
                    jsonify(
                        {

@@ -23,10 +23,18 @@ class Settings(BaseSettings):
    LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
    DEFAULT_MAX_HISTORY: int = 150
    LLM_TOKEN_LIMITS: dict = {
        "gpt-4o": 128000,
        "gpt-4o-mini": 128000,
        "gpt-4": 8192,
        "gpt-3.5-turbo": 4096,
        "claude-2": 1e5,
        "gemini-2.5-flash": 1e6,
        "claude-2": int(1e5),
        "gemini-2.5-flash": int(1e6),
    }
    DEFAULT_LLM_TOKEN_LIMIT: int = 128000
    RESERVED_TOKENS: dict = {
        "system_prompt": 500,
        "current_query": 500,
        "safety_buffer": 1000,
    }
    DEFAULT_AGENT_LIMITS: dict = {
        "token_limit": 50000,
@@ -133,5 +141,8 @@ class Settings(BaseSettings):
    TTS_PROVIDER: str = "google_tts"  # google_tts or elevenlabs
    ELEVENLABS_API_KEY: Optional[str] = None

    # Tool pre-fetch settings
    ENABLE_TOOL_PREFETCH: bool = True


path = Path(__file__).parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")

@@ -44,6 +44,12 @@ class BaseLLM(ABC):
            )
        return self._fallback_llm

    @staticmethod
    def _remove_null_values(args_dict):
        if not isinstance(args_dict, dict):
            return args_dict
        return {k: v for k, v in args_dict.items() if v is not None}

    def _execute_with_fallback(
        self, method_name: str, decorators: list, *args, **kwargs
    ):

@@ -33,14 +33,15 @@ class DocsGPTAPILLM(BaseLLM):
                        {"role": role, "content": item["text"]}
                    )
                elif "function_call" in item:
                    cleaned_args = self._remove_null_values(
                        item["function_call"]["args"]
                    )
                    tool_call = {
                        "id": item["function_call"]["call_id"],
                        "type": "function",
                        "function": {
                            "name": item["function_call"]["name"],
                            "arguments": json.dumps(
                                item["function_call"]["args"]
                            ),
                            "arguments": json.dumps(cleaned_args),
                        },
                    }
                    cleaned_messages.append(

@@ -163,10 +163,14 @@ class GoogleLLM(BaseLLM):
                    if "text" in item:
                        parts.append(types.Part.from_text(text=item["text"]))
                    elif "function_call" in item:
                        # Remove null values from args to avoid API errors
                        cleaned_args = self._remove_null_values(
                            item["function_call"]["args"]
                        )
                        parts.append(
                            types.Part.from_function_call(
                                name=item["function_call"]["name"],
                                args=item["function_call"]["args"],
                                args=cleaned_args,
                            )
                        )
                    elif "function_response" in item:
@@ -386,7 +390,7 @@ class GoogleLLM(BaseLLM):
                elif hasattr(chunk, "text"):
                    yield chunk.text
        finally:
            if hasattr(response, 'close'):
            if hasattr(response, "close"):
                response.close()

    def _supports_tools(self):

@@ -44,14 +44,15 @@ class OpenAILLM(BaseLLM):
                        {"role": role, "content": item["text"]}
                    )
                elif "function_call" in item:
                    cleaned_args = self._remove_null_values(
                        item["function_call"]["args"]
                    )
                    tool_call = {
                        "id": item["function_call"]["call_id"],
                        "type": "function",
                        "function": {
                            "name": item["function_call"]["name"],
                            "arguments": json.dumps(
                                item["function_call"]["args"]
                            ),
                            "arguments": json.dumps(cleaned_args),
                        },
                    }
                    cleaned_messages.append(
@@ -181,7 +182,7 @@ class OpenAILLM(BaseLLM):
                elif len(line.choices) > 0:
                    yield line.choices[0]
        finally:
            if hasattr(response, 'close'):
            if hasattr(response, "close"):
                response.close()

    def _supports_tools(self):

@@ -8,7 +8,3 @@ class BaseRetriever(ABC):
    @abstractmethod
    def search(self, *args, **kwargs):
        pass

    @abstractmethod
    def get_params(self):
        pass

@@ -4,7 +4,7 @@ import os
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.retriever.base import BaseRetriever

from application.utils import num_tokens_from_string
from application.vectorstore.vector_creator import VectorCreator


@@ -15,14 +15,13 @@ class ClassicRAG(BaseRetriever):
        chat_history=None,
        prompt="",
        chunks=2,
        token_limit=150,
        doc_token_limit=50000,
        gpt_model="docsgpt",
        user_api_key=None,
        llm_name=settings.LLM_PROVIDER,
        api_key=settings.API_KEY,
        decoded_token=None,
    ):
        """Initialize ClassicRAG retriever with vectorstore sources and LLM configuration"""
        self.original_question = source.get("question", "")
        self.chat_history = chat_history if chat_history is not None else []
        self.prompt = prompt
@@ -42,16 +41,7 @@ class ClassicRAG(BaseRetriever):
            f"sources={'active_docs' in source and source['active_docs'] is not None}"
        )
        self.gpt_model = gpt_model
        self.token_limit = (
            token_limit
            if token_limit
            < settings.LLM_TOKEN_LIMITS.get(
                self.gpt_model, settings.DEFAULT_MAX_HISTORY
            )
            else settings.LLM_TOKEN_LIMITS.get(
                self.gpt_model, settings.DEFAULT_MAX_HISTORY
            )
        )
        self.doc_token_limit = doc_token_limit
        self.user_api_key = user_api_key
        self.llm_name = llm_name
        self.api_key = api_key
@@ -118,21 +108,17 @@ class ClassicRAG(BaseRetriever):
        return self.original_question

    def _get_data(self):
        """Retrieve relevant documents from configured vectorstores"""
        if self.chunks == 0 or not self.vectorstores:
            logging.info(
                f"ClassicRAG._get_data: Skipping retrieval - chunks={self.chunks}, "
                f"vectorstores_count={len(self.vectorstores) if self.vectorstores else 0}"
            )
            return []

        all_docs = []
        chunks_per_source = max(1, self.chunks // len(self.vectorstores))

        logging.info(
            f"ClassicRAG._get_data: Starting retrieval with chunks={self.chunks}, "
            f"vectorstores={self.vectorstores}, chunks_per_source={chunks_per_source}, "
            f"query='{self.question[:50]}...'"
        )
        token_budget = max(int(self.doc_token_limit * 0.9), 100)
        cumulative_tokens = 0

        for vectorstore_id in self.vectorstores:
            if vectorstore_id:
@@ -140,15 +126,21 @@ class ClassicRAG(BaseRetriever):
                    docsearch = VectorCreator.create_vectorstore(
                        settings.VECTOR_STORE, vectorstore_id, settings.EMBEDDINGS_KEY
                    )
                    docs_temp = docsearch.search(self.question, k=chunks_per_source)
                    docs_temp = docsearch.search(
                        self.question, k=max(chunks_per_source * 2, 20)
                    )

                    for doc in docs_temp:
                        if cumulative_tokens >= token_budget:
                            break

                        if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
                            page_content = doc.page_content
                            metadata = doc.metadata
                        else:
                            page_content = doc.get("text", doc.get("page_content", ""))
                            metadata = doc.get("metadata", {})

                        title = metadata.get(
                            "title", metadata.get("post_title", page_content)
                        )
@@ -168,23 +160,35 @@ class ClassicRAG(BaseRetriever):
                        if not filename:
                            filename = title
                        source_path = metadata.get("source") or vectorstore_id
                        all_docs.append(
                            {
                                "title": title,
                                "text": page_content,
                                "source": source_path,
                                "filename": filename,
                            }
                        )

                        doc_text_with_header = f"{filename}\n{page_content}"
                        doc_tokens = num_tokens_from_string(doc_text_with_header)

                        if cumulative_tokens + doc_tokens < token_budget:
                            all_docs.append(
                                {
                                    "title": title,
                                    "text": page_content,
                                    "source": source_path,
                                    "filename": filename,
                                }
                            )
                            cumulative_tokens += doc_tokens

                        if cumulative_tokens >= token_budget:
                            break

                except Exception as e:
                    logging.error(
                        f"Error searching vectorstore {vectorstore_id}: {e}",
                        exc_info=True,
                    )
                    continue

        logging.info(
            f"ClassicRAG._get_data: Retrieval complete - retrieved {len(all_docs)} documents "
            f"(requested chunks={self.chunks}, chunks_per_source={chunks_per_source})"
            f"(requested chunks={self.chunks}, chunks_per_source={chunks_per_source}, "
            f"cumulative_tokens={cumulative_tokens}/{token_budget})"
        )
        return all_docs

@@ -194,15 +198,3 @@ class ClassicRAG(BaseRetriever):
        self.original_question = query
        self.question = self._rephrase_query()
        return self._get_data()

    def get_params(self):
        """Return current retriever configuration parameters"""
        return {
            "question": self.original_question,
            "rephrased_question": self.question,
            "sources": self.vectorstores,
            "chunks": self.chunks,
            "token_limit": self.token_limit,
            "gpt_model": self.gpt_model,
            "user_api_key": self.user_api_key,
        }

0	application/templates/__init__.py	Normal file
190	application/templates/namespaces.py	Normal file
@@ -0,0 +1,190 @@
import logging
import uuid
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)


class NamespaceBuilder(ABC):
    """Base class for building template context namespaces"""

    @abstractmethod
    def build(self, **kwargs) -> Dict[str, Any]:
        """Build namespace context dictionary"""
        pass

    @property
    @abstractmethod
    def namespace_name(self) -> str:
        """Name of this namespace for template access"""
        pass


class SystemNamespace(NamespaceBuilder):
    """System metadata namespace: {{ system.* }}"""

    @property
    def namespace_name(self) -> str:
        return "system"

    def build(
        self, request_id: Optional[str] = None, user_id: Optional[str] = None, **kwargs
    ) -> Dict[str, Any]:
        """
        Build system context with metadata.

        Args:
            request_id: Unique request identifier
            user_id: Current user identifier

        Returns:
            Dictionary with system variables
        """
        now = datetime.now(timezone.utc)

        return {
            "date": now.strftime("%Y-%m-%d"),
            "time": now.strftime("%H:%M:%S"),
            "timestamp": now.isoformat(),
            "request_id": request_id or str(uuid.uuid4()),
            "user_id": user_id,
        }


class PassthroughNamespace(NamespaceBuilder):
    """Request parameters namespace: {{ passthrough.* }}"""

    @property
    def namespace_name(self) -> str:
        return "passthrough"

    def build(
        self, passthrough_data: Optional[Dict[str, Any]] = None, **kwargs
    ) -> Dict[str, Any]:
        """
        Build passthrough context from request parameters.

        Args:
            passthrough_data: Dictionary of parameters from web request

        Returns:
            Dictionary with passthrough variables
        """
        if not passthrough_data:
            return {}
        safe_data = {}
        for key, value in passthrough_data.items():
            if isinstance(value, (str, int, float, bool, type(None))):
                safe_data[key] = value
            else:
                logger.warning(
                    f"Skipping non-serializable passthrough value for key '{key}': {type(value)}"
                )
        return safe_data


class SourceNamespace(NamespaceBuilder):
    """RAG source documents namespace: {{ source.* }}"""

    @property
    def namespace_name(self) -> str:
        return "source"

    def build(
        self, docs: Optional[list] = None, docs_together: Optional[str] = None, **kwargs
    ) -> Dict[str, Any]:
        """
        Build source context from RAG retrieval results.

        Args:
            docs: List of retrieved documents
            docs_together: Concatenated document content (for backward compatibility)

        Returns:
            Dictionary with source variables
        """
        context = {}

        if docs:
            context["documents"] = docs
            context["count"] = len(docs)
        if docs_together:
            context["docs_together"] = docs_together  # Add docs_together for custom templates
            context["content"] = docs_together
            context["summaries"] = docs_together
        return context


class ToolsNamespace(NamespaceBuilder):
    """Pre-executed tools namespace: {{ tools.* }}"""

    @property
    def namespace_name(self) -> str:
        return "tools"

    def build(
        self, tools_data: Optional[Dict[str, Any]] = None, **kwargs
    ) -> Dict[str, Any]:
        """
        Build tools context with pre-executed tool results.

        Args:
            tools_data: Dictionary of pre-fetched tool results organized by tool name
                e.g., {"memory": {"notes": "content", "tasks": "list"}}

        Returns:
            Dictionary with tool results organized by tool name
        """
        if not tools_data:
            return {}

        safe_data = {}
        for tool_name, tool_result in tools_data.items():
            if isinstance(tool_result, (str, dict, list, int, float, bool, type(None))):
                safe_data[tool_name] = tool_result
            else:
                logger.warning(
                    f"Skipping non-serializable tool result for '{tool_name}': {type(tool_result)}"
                )
        return safe_data


class NamespaceManager:
    """Manages all namespace builders and context assembly"""

    def __init__(self):
        self._builders = {
            "system": SystemNamespace(),
            "passthrough": PassthroughNamespace(),
            "source": SourceNamespace(),
            "tools": ToolsNamespace(),
        }

    def build_context(self, **kwargs) -> Dict[str, Any]:
        """
        Build complete template context from all namespaces.

        Args:
            **kwargs: Parameters to pass to namespace builders

        Returns:
            Complete context dictionary for template rendering
        """
        context = {}

        for namespace_name, builder in self._builders.items():
            try:
                namespace_context = builder.build(**kwargs)
                # Always include namespace, even if empty, to prevent undefined errors
                context[namespace_name] = namespace_context if namespace_context else {}
            except Exception as e:
                logger.error(f"Failed to build {namespace_name} namespace: {str(e)}")
                # Include empty namespace on error to prevent template failures
                context[namespace_name] = {}
        return context

    def get_builder(self, namespace_name: str) -> Optional[NamespaceBuilder]:
        """Get specific namespace builder"""
        return self._builders.get(namespace_name)
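A minimal usage sketch for the namespace manager (illustrative, not part of the commit):

```python
from application.templates.namespaces import NamespaceManager

manager = NamespaceManager()
context = manager.build_context(
    request_id="req_123",                      # feeds {{ system.request_id }}
    user_id="user_1",                          # feeds {{ system.user_id }}
    passthrough_data={"company": "Acme"},      # feeds {{ passthrough.company }}
    docs_together="Concatenated doc text...",  # feeds {{ source.content }}
    tools_data={"memory": {"root": "notes.txt", "available": True}},
)
# All four namespaces are always present, empty or not, so templates never
# hit an undefined top-level name.
assert set(context) == {"system", "passthrough", "source", "tools"}
```
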
161	application/templates/template_engine.py	Normal file
@@ -0,0 +1,161 @@
import logging
from typing import Any, Dict, List, Optional, Set

from jinja2 import (
    ChainableUndefined,
    Environment,
    meta,
    nodes,
    select_autoescape,
    TemplateSyntaxError,
)
from jinja2.exceptions import UndefinedError

logger = logging.getLogger(__name__)


class TemplateRenderError(Exception):
    """Raised when template rendering fails"""

    pass


class TemplateEngine:
    """Jinja2-based template engine for dynamic prompt rendering"""

    def __init__(self):
        self._env = Environment(
            undefined=ChainableUndefined,
            trim_blocks=True,
            lstrip_blocks=True,
            autoescape=select_autoescape(default_for_string=True, default=True),
        )

    def render(self, template_content: str, context: Dict[str, Any]) -> str:
        """
        Render template with provided context.

        Args:
            template_content: Raw template string with Jinja2 syntax
            context: Dictionary of variables to inject into template

        Returns:
            Rendered template string

        Raises:
            TemplateRenderError: If template syntax is invalid or variables undefined
        """
        if not template_content:
            return ""
        try:
            template = self._env.from_string(template_content)
            return template.render(**context)
        except TemplateSyntaxError as e:
            error_msg = f"Template syntax error at line {e.lineno}: {e.message}"
            logger.error(error_msg)
            raise TemplateRenderError(error_msg) from e
        except UndefinedError as e:
            error_msg = f"Undefined variable in template: {e.message}"
            logger.error(error_msg)
            raise TemplateRenderError(error_msg) from e
        except Exception as e:
            error_msg = f"Template rendering failed: {str(e)}"
            logger.error(error_msg)
            raise TemplateRenderError(error_msg) from e

    def validate_template(self, template_content: str) -> bool:
        """
        Validate template syntax without rendering.

        Args:
            template_content: Template string to validate

        Returns:
            True if template is syntactically valid
        """
        if not template_content:
            return True
        try:
            self._env.from_string(template_content)
            return True
        except TemplateSyntaxError as e:
            logger.debug(f"Template syntax invalid at line {e.lineno}: {e.message}")
            return False
        except Exception as e:
            logger.debug(f"Template validation error: {type(e).__name__}: {str(e)}")
            return False

    def extract_variables(self, template_content: str) -> Set[str]:
        """
        Extract all variable names from template.

        Args:
            template_content: Template string to analyze

        Returns:
            Set of variable names found in template
        """
        if not template_content:
            return set()
        try:
            ast = self._env.parse(template_content)
            return set(meta.find_undeclared_variables(ast))
        except TemplateSyntaxError as e:
            logger.debug(f"Cannot extract variables - syntax error at line {e.lineno}")
            return set()
        except Exception as e:
            logger.debug(f"Cannot extract variables: {type(e).__name__}")
            return set()

    def extract_tool_usages(
        self, template_content: str
    ) -> Dict[str, Set[Optional[str]]]:
        """Extract tool and action references from a template"""
        if not template_content:
            return {}
        try:
            ast = self._env.parse(template_content)
        except TemplateSyntaxError as e:
            logger.debug(f"extract_tool_usages - syntax error at line {e.lineno}")
            return {}
        except Exception as e:
            logger.debug(f"extract_tool_usages - parse error: {type(e).__name__}")
            return {}

        usages: Dict[str, Set[Optional[str]]] = {}

        def record(path: List[str]) -> None:
            if not path:
                return
            tool_name = path[0]
            action_name = path[1] if len(path) > 1 else None
            if not tool_name:
                return
            tool_entry = usages.setdefault(tool_name, set())
            tool_entry.add(action_name)

        for node in ast.find_all(nodes.Getattr):
            path = []
            current = node
            while isinstance(current, nodes.Getattr):
                path.append(current.attr)
                current = current.node
            if isinstance(current, nodes.Name) and current.name == "tools":
                path.reverse()
                record(path)

        for node in ast.find_all(nodes.Getitem):
            path = []
            current = node
            while isinstance(current, nodes.Getitem):
                key = current.arg
                if isinstance(key, nodes.Const) and isinstance(key.value, str):
                    path.append(key.value)
                else:
                    path = []
                    break
                current = current.node
            if path and isinstance(current, nodes.Name) and current.name == "tools":
                path.reverse()
                record(path)

        return usages
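A minimal rendering sketch for the engine (illustrative, not part of the commit):

```python
from application.templates.template_engine import TemplateEngine, TemplateRenderError

engine = TemplateEngine()
try:
    rendered = engine.render(
        "Hello {{ passthrough.name }}, today is {{ system.date }}.",
        {"passthrough": {"name": "Alice"}, "system": {"date": "2025-10-30"}},
    )
except TemplateRenderError:
    rendered = ""  # callers decide the fallback; here we simply blank the prompt
```
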
@@ -74,6 +74,17 @@ def count_tokens_docs(docs):
    return tokens


def calculate_doc_token_budget(
    gpt_model: str = "gpt-4o", history_token_limit: int = 2000
) -> int:
    total_context = settings.LLM_TOKEN_LIMITS.get(
        gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT
    )
    reserved = sum(settings.RESERVED_TOKENS.values())
    doc_budget = total_context - history_token_limit - reserved
    return max(doc_budget, 1000)


def get_missing_fields(data, required_fields):
    """Check for missing required fields. Returns list of missing field names."""
    return [field for field in required_fields if field not in data]
@@ -141,8 +152,8 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
        max_token_limit
        if max_token_limit
        and max_token_limit
        < settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
        else settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
        < settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT)
        else settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT)
    )

    if not history:

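A worked example of the three-tier token budget with the defaults above (values taken from this commit's settings):

```python
# gpt-4o context window, per LLM_TOKEN_LIMITS
total_context = 128000
# RESERVED_TOKENS: system_prompt (500) + current_query (500) + safety_buffer (1000)
reserved = 500 + 500 + 1000
# default history_token_limit in calculate_doc_token_budget
history = 2000

doc_budget = total_context - history - reserved
assert doc_budget == 124000
# ClassicRAG then stops adding retrieved chunks once it has spent 90% of its
# doc_token_limit (see token_budget in _get_data above).
```
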
@@ -1,49 +1,453 @@
---
title: Customizing Prompts
description: This guide will explain how to change prompts in DocsGPT and why it might be beneficial. Additionally, this article explains additional variables that can be used in prompts.
title: Customizing Prompts
description: This guide explains how to customize prompts in DocsGPT using the new template-based system with dynamic variable injection.
---

import Image from 'next/image'

# Customizing the Main Prompt
# Customizing Prompts in DocsGPT

Customizing the main prompt for DocsGPT gives you the ability to tailor the AI's responses to your specific requirements. By modifying the prompt text, you can achieve more accurate and relevant answers. Here's how you can do it:
Customizing prompts for DocsGPT gives you powerful control over the AI's behavior and responses. With the new template-based system, you can inject dynamic context through organized namespaces, making prompts flexible and maintainable without hardcoding values.

## Quick Start

1. Navigate to `SideBar -> Settings`.

2. In Settings, select the `Active Prompt`; you will now see the various prompt styles.

3. Click on the `edit icon` on the prompt of your choice to see its current text; you can now customize the prompt as you like.
2. In Settings, select the `Active Prompt` to see various prompt styles.
3. Click on the `edit icon` on your chosen prompt to customize it.

### Video Demo
<Image src="/prompts.gif" alt="prompts" width={800} height={500} />

---

## Template-Based Prompt System

## Example Prompt Modification
DocsGPT now uses **Jinja2 templating** with four organized namespaces for dynamic variable injection:

### Available Namespaces

#### 1. **`system`** - System Metadata
Access system-level information:

```jinja
{{ system.date }}        # Current date (YYYY-MM-DD)
{{ system.time }}        # Current time (HH:MM:SS)
{{ system.timestamp }}   # ISO 8601 timestamp
{{ system.request_id }}  # Unique request identifier
{{ system.user_id }}     # Current user ID
```

#### 2. **`source`** - Retrieved Documents
Access RAG (Retrieval-Augmented Generation) document context:

```jinja
{{ source.content }}    # Concatenated document content
{{ source.summaries }}  # Alias for content (backward compatible)
{{ source.documents }}  # List of document objects
{{ source.count }}      # Number of retrieved documents
```

#### 3. **`passthrough`** - Request Parameters
Access custom parameters passed in the API request:

```jinja
{{ passthrough.company }}    # Custom field from request
{{ passthrough.user_name }}  # User-provided data
{{ passthrough.context }}    # Any custom parameter
```

To use passthrough data, send it in your API request:
```json
{
  "question": "What is the pricing?",
  "passthrough": {
    "company": "Acme Corp",
    "user_name": "Alice",
    "plan_type": "enterprise"
  }
}
```

#### 4. **`tools`** - Pre-fetched Tool Data
Access results from tools that run before the agent (like the memory tool):

```jinja
{{ tools.memory.root }}       # Memory tool directory listing
{{ tools.memory.available }}  # Boolean: is memory available
```

---

## Example Prompts

### Basic Prompt with Documents
```jinja
You are a helpful AI assistant for DocsGPT.

Current date: {{ system.date }}

Use the following documents to answer the question:

{{ source.content }}

Provide accurate, helpful answers with code examples when relevant.
```

### Advanced Prompt with All Namespaces
```jinja
You are an AI assistant for {{ passthrough.company }}.

**System Info:**
- Date: {{ system.date }}
- Request ID: {{ system.request_id }}

**User Context:**
- User: {{ passthrough.user_name }}
- Role: {{ passthrough.role }}

**Available Documents ({{ source.count }}):**
{{ source.content }}

**Memory Context:**
{% if tools.memory.available %}
{{ tools.memory.root }}
{% else %}
No saved context available.
{% endif %}

Please provide detailed, accurate answers based on the documents above.
```

### Conditional Logic Example
```jinja
You are a DocsGPT assistant.

{% if source.count > 0 %}
I found {{ source.count }} relevant document(s):

{{ source.content }}

Base your answer on these documents.
{% else %}
No documents were found. Please answer based on your general knowledge.
{% endif %}
```

---

## Migration Guide

### Legacy Format (Still Supported)
The old `{summaries}` format continues to work for backward compatibility:

**Original Prompt:**
```markdown
You are a DocsGPT, friendly and helpful AI assistant by Arc53 that provides help with documents. You give thorough answers with code examples if possible.
Use the following pieces of context to help answer the user's question. If it's not relevant to the question, provide friendly responses.
You have access to chat history, and can use it to help answer the question.
When using code examples, use the following format:
You are a helpful assistant.

(code)
Documents:
{summaries}
```

Note that `{summaries}` allows the model to see and respond to your uploaded documents. If you don't want this functionality, you can safely remove it from the customized prompt.
This will automatically substitute `{summaries}` with document content.

Feel free to customize the prompt to align it with your specific use case or the kind of responses you want from the AI. For example, you can focus on specific document types, industries, or topics to get more targeted results.
### New Template Format (Recommended)
Migrate to the new template syntax for more flexibility:

```jinja
You are a helpful assistant.

Documents:
{{ source.content }}
```

**Migration mapping:**
- `{summaries}` → `{{ source.content }}` or `{{ source.summaries }}`

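Under the hood, the switch between the two formats is a simple delimiter check; the sketch below is illustrative, mirroring the `"{{" not in prompt_content` guard used by the stream processor earlier in this diff:

```python
# Illustrative sketch of the legacy/template switch (not the exact code path).
def uses_template_syntax(prompt: str) -> bool:
    # Prompts without Jinja2 delimiters fall back to the legacy {summaries}
    # substitution; anything containing "{{ ... }}" is rendered through the
    # TemplateEngine instead.
    return "{{" in prompt and "}}" in prompt
```
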
---

## Best Practices

### 1. **Use Descriptive Context**
```jinja
**Retrieved Documents:**
{{ source.content }}

**User Query Context:**
- Company: {{ passthrough.company }}
- Department: {{ passthrough.department }}
```

### 2. **Handle Missing Data Gracefully**
```jinja
{% if passthrough.user_name %}
Hello {{ passthrough.user_name }}!
{% endif %}
```

### 3. **Leverage Memory for Continuity**
```jinja
{% if tools.memory.available %}
**Previous Context:**
{{ tools.memory.root }}
{% endif %}

**Current Question:**
Please consider the above context when answering.
```

### 4. **Add Clear Instructions**
```jinja
You are a technical support assistant.

**Guidelines:**
1. Always reference the documents below
2. Provide step-by-step instructions
3. Include code examples when relevant

**Reference Documents:**
{{ source.content }}
```

---

## Advanced Features

### Looping Over Documents
```jinja
{% for doc in source.documents %}
**Source {{ loop.index }}:** {{ doc.filename }}
{{ doc.text }}

{% endfor %}
```

### Date-Based Behavior
```jinja
{% if system.date > "2025-01-01" %}
Note: This is information from 2025 or later.
{% endif %}
```

### Custom Formatting
```jinja
**Request Information**
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
• Request ID: {{ system.request_id }}
• User: {{ passthrough.user_name | default("Guest") }}
• Time: {{ system.time }}
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
```

---

## Tool Pre-Fetching

### Memory Tool Configuration
Enable memory tool pre-fetching to inject saved context into prompts:

```python
# In your tool configuration
{
    "name": "memory",
    "config": {
        "pre_fetch_enabled": true  # Default: true
    }
}
```

Control pre-fetching globally:
```bash
# .env file
ENABLE_TOOL_PREFETCH=true
```

Or per-request:
```json
{
  "question": "What are the requirements?",
  "disable_tool_prefetch": false
}
```

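Taken together, the three switches behave roughly like this (illustrative pseudologic; `should_prefetch` is a hypothetical helper, not a function in the codebase):

```python
from application.core.settings import settings

def should_prefetch(tool_config: dict, request_data: dict) -> bool:
    # Global kill switch from .env
    if not settings.ENABLE_TOOL_PREFETCH:
        return False
    # Per-request opt-out
    if request_data.get("disable_tool_prefetch"):
        return False
    # Per-tool setting, enabled by default
    return tool_config.get("pre_fetch_enabled", True)
```
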
---

## Debugging Prompts

### View Rendered Prompts in Logs
Set log level to `INFO` to see the final rendered prompt sent to the LLM:

```bash
export LOG_LEVEL=INFO
```

You'll see output like:
```
INFO - Rendered system prompt for agent (length: 1234 chars):
================================================================================
You are a helpful assistant for Acme Corp.

Current date: 2025-10-30
Request ID: req_abc123

Documents:
Technical documentation about...
================================================================================
```

### Template Validation
Test your template syntax before saving:
```python
from application.api.answer.services.prompt_renderer import PromptRenderer

renderer = PromptRenderer()
is_valid = renderer.validate_template("Your prompt with {{ variables }}")
```

---

## Common Use Cases

### 1. Customer Support Bot
```jinja
You are a customer support assistant for {{ passthrough.company }}.

**Customer:** {{ passthrough.customer_name }}
**Ticket ID:** {{ system.request_id }}
**Date:** {{ system.date }}

**Knowledge Base:**
{{ source.content }}

**Previous Interactions:**
{{ tools.memory.root }}

Please provide helpful, friendly support based on the knowledge base above.
```

### 2. Technical Documentation Assistant
```jinja
You are a technical documentation expert.

**Available Documentation ({{ source.count }} documents):**
{{ source.content }}

**Requirements:**
- Provide code examples in {{ passthrough.language }}
- Focus on {{ passthrough.framework }} best practices
- Include relevant links when possible
```

### 3. Internal Knowledge Base
```jinja
You are an internal AI assistant for {{ passthrough.department }}.

**Employee:** {{ passthrough.employee_name }}
**Access Level:** {{ passthrough.access_level }}

**Relevant Documents:**
{{ source.content }}

Provide detailed answers appropriate for {{ passthrough.access_level }} access level.
```

---

## Template Syntax Reference

### Variables
```jinja
{{ variable_name }}              # Output variable
{{ namespace.field }}            # Access nested field
{{ variable | default("N/A") }}  # Default value
```

### Conditionals
```jinja
{% if condition %}
Content
{% elif other_condition %}
Other content
{% else %}
Default content
{% endif %}
```

### Loops
```jinja
{% for item in list %}
{{ item.field }}
{% endfor %}
```

### Comments
```jinja
{# This is a comment and won't appear in output #}
```

---

## Security Considerations

1. **Input Sanitization**: Passthrough data is automatically sanitized to prevent injection attacks
2. **Type Filtering**: Only primitive types (string, int, float, bool, None) are allowed in passthrough; see the sketch after this list
3. **Autoescaping**: Jinja2 autoescaping is enabled by default
4. **Size Limits**: Consider the token budget when including large documents

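The type filter in point 2 behaves like the following sketch (it mirrors `PassthroughNamespace.build`; the sample values are hypothetical):

```python
passthrough = {
    "company": "Acme",   # str  -> kept
    "plan_tier": 3,      # int  -> kept
    "beta": True,        # bool -> kept
    "callback": print,   # function -> dropped with a warning
}
safe = {
    k: v
    for k, v in passthrough.items()
    if isinstance(v, (str, int, float, bool, type(None)))
}
assert "callback" not in safe
```
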
---

## Troubleshooting

### Problem: Variables Not Rendering
**Solution:** Ensure you're using the correct namespace:
```jinja
❌ {{ company }}
✅ {{ passthrough.company }}
```

### Problem: Empty Output for Tool Data
**Solution:** Check that tool pre-fetching is enabled and the tool is configured correctly.

### Problem: Syntax Errors
**Solution:** Validate template syntax. Common issues:
```jinja
❌ {{ variable }     # Missing closing brace
❌ {% if x %         # Missing closing %}
✅ {{ variable }}
✅ {% if x %}...{% endif %}
```

### Problem: Legacy Prompts Not Working
**Solution:** The system auto-detects template syntax. If your prompt uses `{summaries}`, it will work in legacy mode. To use new features, add `{{ }}` syntax.

---

## API Reference

### Render Prompt via API
```python
from application.api.answer.services.prompt_renderer import PromptRenderer

renderer = PromptRenderer()
rendered = renderer.render_prompt(
    prompt_content="Your template with {{ passthrough.name }}",
    user_id="user_123",
    request_id="req_456",
    passthrough_data={"name": "Alice"},
    docs_together="Document content here",
    tools_data={"memory": {"root": "Files: notes.txt"}}
)
```

---

## Conclusion

Customizing the main prompt for DocsGPT allows you to tailor the AI's responses to your unique requirements. Whether you need in-depth explanations, code examples, or specific insights, you can achieve it by modifying the main prompt. Remember to experiment and fine-tune your prompts to get the best results.
The new template-based prompt system provides powerful flexibility while maintaining backward compatibility. By leveraging namespaces, you can create dynamic, context-aware prompts that adapt to your specific use case.

**Key Benefits:**
- ✅ Dynamic variable injection
- ✅ Organized namespaces
- ✅ Backward compatible
- ✅ Security built-in
- ✅ Easy to debug

Start with simple templates and gradually add complexity as needed. Happy prompting! 🚀

@@ -200,13 +200,19 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {

    if (agent.limited_token_mode && agent.token_limit) {
      formData.append('limited_token_mode', 'True');
      formData.append('token_limit', JSON.stringify(agent.token_limit));
    } else formData.append('token_limit', '0');
      formData.append('token_limit', agent.token_limit.toString());
    } else {
      formData.append('limited_token_mode', 'False');
      formData.append('token_limit', '0');
    }

    if (agent.limited_request_mode && agent.request_limit) {
      formData.append('limited_request_mode', 'True');
      formData.append('request_limit', JSON.stringify(agent.request_limit));
    } else formData.append('request_limit', '0');
      formData.append('request_limit', agent.request_limit.toString());
    } else {
      formData.append('limited_request_mode', 'False');
      formData.append('request_limit', '0');
    }

    if (imageFile) formData.append('image', imageFile);

@@ -297,15 +303,22 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
      formData.append('json_schema', JSON.stringify(agent.json_schema));
    }

    // Always send the limited mode fields
    if (agent.limited_token_mode && agent.token_limit) {
      formData.append('limited_token_mode', 'True');
      formData.append('token_limit', JSON.stringify(agent.token_limit));
    } else formData.append('token_limit', '0');
      formData.append('token_limit', agent.token_limit.toString());
    } else {
      formData.append('limited_token_mode', 'False');
      formData.append('token_limit', '0');
    }

    if (agent.limited_request_mode && agent.request_limit) {
      formData.append('limited_request_mode', 'True');
      formData.append('request_limit', JSON.stringify(agent.request_limit));
    } else formData.append('request_limit', '0');
      formData.append('request_limit', agent.request_limit.toString());
    } else {
      formData.append('limited_request_mode', 'False');
      formData.append('request_limit', '0');
    }

    try {
      setPublishLoading(true);

@@ -130,7 +130,7 @@ export default function Conversation() {
        }),
      );
      handleQuestion({
        question: queries[queries.length - 1].prompt,
        question: question,
        isRetry: true,
      });
    } else {

@@ -90,15 +90,20 @@ export default function ConversationMessages({
    setHasScrolledToLast(isAtBottom);
  }, [setHasScrolledToLast]);

  const lastQuery = queries[queries.length - 1];
  const lastQueryResponse = lastQuery?.response;
  const lastQueryError = lastQuery?.error;
  const lastQueryThought = lastQuery?.thought;

  useEffect(() => {
    if (!userInterruptedScroll) {
      scrollConversationToBottom();
    }
  }, [
    queries.length,
    queries[queries.length - 1]?.response,
    queries[queries.length - 1]?.error,
    queries[queries.length - 1]?.thought,
    lastQueryResponse,
    lastQueryError,
    lastQueryThought,
    userInterruptedScroll,
    scrollConversationToBottom,
  ]);

@@ -370,7 +370,10 @@ export const conversationSlice = createSlice({
        return state;
      }
      state.status = 'failed';
      state.queries[state.queries.length - 1].error = 'Something went wrong';
      if (state.queries.length > 0) {
        state.queries[state.queries.length - 1].error =
          'Something went wrong';
      }
    });
  },
});

@@ -64,17 +64,14 @@ class TestBaseAgentBuildMessages:
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
system_prompt = "System: {summaries}"
|
||||
system_prompt = "System prompt content"
|
||||
query = "What is Python?"
|
||||
retrieved_data = [
|
||||
{"text": "Python is a programming language", "filename": "python.txt"}
|
||||
]
|
||||
|
||||
messages = agent._build_messages(system_prompt, query, retrieved_data)
|
||||
messages = agent._build_messages(system_prompt, query)
|
||||
|
||||
assert len(messages) >= 2
|
||||
assert messages[0]["role"] == "system"
|
||||
assert "Python is a programming language" in messages[0]["content"]
|
||||
assert messages[0]["content"] == system_prompt
|
||||
assert messages[-1]["role"] == "user"
|
||||
assert messages[-1]["content"] == query
|
||||
|
||||
@@ -88,11 +85,10 @@ class TestBaseAgentBuildMessages:
|
||||
agent_base_params["chat_history"] = sample_chat_history
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
system_prompt = "System: {summaries}"
|
||||
system_prompt = "System prompt"
|
||||
query = "New question?"
|
||||
retrieved_data = [{"text": "Data", "filename": "file.txt"}]
|
||||
|
||||
messages = agent._build_messages(system_prompt, query, retrieved_data)
|
||||
messages = agent._build_messages(system_prompt, query)
|
||||
|
||||
user_messages = [m for m in messages if m["role"] == "user"]
|
||||
assistant_messages = [m for m in messages if m["role"] == "assistant"]
|
||||
@@ -118,9 +114,7 @@ class TestBaseAgentBuildMessages:
|
||||
agent_base_params["chat_history"] = tool_call_history
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
messages = agent._build_messages(
|
||||
"System: {summaries}", "query", [{"text": "data", "filename": "file.txt"}]
|
||||
)
|
||||
messages = agent._build_messages("System prompt", "query")
|
||||
|
||||
tool_messages = [m for m in messages if m["role"] == "tool"]
|
||||
assert len(tool_messages) > 0
|
||||
@@ -129,32 +123,25 @@ class TestBaseAgentBuildMessages:
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
retrieved_data = [{"text": "Document without filename or title"}]
|
||||
|
||||
messages = agent._build_messages("System: {summaries}", "query", retrieved_data)
|
||||
messages = agent._build_messages("System prompt", "query")
|
||||
|
||||
assert messages[0]["role"] == "system"
|
||||
assert "Document without filename" in messages[0]["content"]
|
||||
assert messages[0]["content"] == "System prompt"
|
||||
|
||||
def test_build_messages_uses_title_as_fallback(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
retrieved_data = [{"text": "Data", "title": "Title Doc"}]
|
||||
|
||||
messages = agent._build_messages("System: {summaries}", "query", retrieved_data)
|
||||
|
||||
assert "Title Doc" in messages[0]["content"]
|
||||
agent._build_messages("System prompt", "query")
|
||||
|
||||
def test_build_messages_uses_source_as_fallback(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
retrieved_data = [{"text": "Data", "source": "source.txt"}]
|
||||
|
||||
messages = agent._build_messages("System: {summaries}", "query", retrieved_data)
|
||||
|
||||
assert "source.txt" in messages[0]["content"]
|
||||
agent._build_messages("System prompt", "query")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -475,40 +462,6 @@ class TestBaseAgentToolExecution:
|
||||
assert truncated[0]["result"].endswith("...")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseAgentRetrieverSearch:
|
||||
|
||||
def test_retriever_search(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
        results = agent._retriever_search(mock_retriever, "test query", log_context)

        assert len(results) == 2
        mock_retriever.search.assert_called_once_with("test query")

    def test_retriever_search_logs_context(
        self,
        agent_base_params,
        mock_retriever,
        mock_llm_creator,
        mock_llm_handler_creator,
        log_context,
    ):
        agent = ClassicAgent(**agent_base_params)

        agent._retriever_search(mock_retriever, "test query", log_context)

        assert len(log_context.stacks) == 1
        assert log_context.stacks[0]["component"] == "retriever"


@pytest.mark.unit
class TestBaseAgentLLMGeneration:

@@ -19,7 +19,6 @@ class TestClassicAgent:
    def test_gen_inner_basic_flow(
        self,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
@@ -40,7 +39,7 @@ class TestClassicAgent:

        agent = ClassicAgent(**agent_base_params)

        results = list(agent._gen_inner("Test query", mock_retriever, log_context))
        results = list(agent._gen_inner("Test query", log_context))

        assert len(results) >= 2
        sources = [r for r in results if "sources" in r]
@@ -52,7 +51,6 @@ class TestClassicAgent:
    def test_gen_inner_retrieves_documents(
        self,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
@@ -68,14 +66,11 @@ class TestClassicAgent:
        mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)

        agent = ClassicAgent(**agent_base_params)
        list(agent._gen_inner("Test query", mock_retriever, log_context))

        mock_retriever.search.assert_called_once_with("Test query")
        list(agent._gen_inner("Test query", log_context))

    def test_gen_inner_uses_user_api_key_tools(
        self,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
@@ -104,14 +99,13 @@ class TestClassicAgent:
        agent_base_params["user_api_key"] = "api_key_123"
        agent = ClassicAgent(**agent_base_params)

        list(agent._gen_inner("Test query", mock_retriever, log_context))
        list(agent._gen_inner("Test query", log_context))

        assert len(agent.tools) >= 0

    def test_gen_inner_uses_user_tools(
        self,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
@@ -133,14 +127,13 @@ class TestClassicAgent:
        mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)

        agent = ClassicAgent(**agent_base_params)
        list(agent._gen_inner("Test query", mock_retriever, log_context))
        list(agent._gen_inner("Test query", log_context))

        assert len(agent.tools) >= 0

    def test_gen_inner_builds_correct_messages(
        self,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
@@ -156,7 +149,7 @@ class TestClassicAgent:
        mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)

        agent = ClassicAgent(**agent_base_params)
        list(agent._gen_inner("Test query", mock_retriever, log_context))
        list(agent._gen_inner("Test query", log_context))

        call_kwargs = mock_llm.gen_stream.call_args[1]
        messages = call_kwargs["messages"]
@@ -169,7 +162,6 @@ class TestClassicAgent:
    def test_gen_inner_logs_tool_calls(
        self,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
@@ -187,7 +179,7 @@ class TestClassicAgent:
        agent = ClassicAgent(**agent_base_params)
        agent.tool_calls = [{"tool": "test", "result": "success"}]

        list(agent._gen_inner("Test query", mock_retriever, log_context))
        list(agent._gen_inner("Test query", log_context))

        agent_logs = [s for s in log_context.stacks if s["component"] == "agent"]
        assert len(agent_logs) == 1
@@ -200,7 +192,6 @@ class TestClassicAgentIntegration:
    def test_gen_method_with_logging(
        self,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
@@ -216,14 +207,13 @@ class TestClassicAgentIntegration:

        agent = ClassicAgent(**agent_base_params)

        results = list(agent.gen("Test query", mock_retriever))
        results = list(agent.gen("Test query"))

        assert len(results) >= 1

    def test_gen_method_decorator_applied(
        self,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
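Every hunk above makes the same mechanical change: the retriever no longer travels through gen()/_gen_inner(), so retrieval has to happen before generation starts. A minimal before/after sketch of a call site, assuming the pre-fetched documents are handed to the agent at construction time (the exact wiring lives outside these hunks):

    # Hypothetical call site illustrating the signature change.
    agent = ClassicAgent(**agent_base_params)  # documents supplied up front

    # Old: results = list(agent.gen("Test query", retriever))
    results = list(agent.gen("Test query"))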
@@ -35,7 +35,7 @@ class TestReActAgentContentExtraction:
        agent = ReActAgent(**agent_base_params)

        response = "Simple string response"
        content = agent._extract_content_from_llm_response(response)
        content = agent._extract_content(response)

        assert content == "Simple string response"

@@ -48,7 +48,7 @@ class TestReActAgentContentExtraction:
        response.message = Mock()
        response.message.content = "Message content"

        content = agent._extract_content_from_llm_response(response)
        content = agent._extract_content(response)

        assert content == "Message content"

@@ -64,7 +64,7 @@ class TestReActAgentContentExtraction:
        response.message = None
        response.content = None

        content = agent._extract_content_from_llm_response(response)
        content = agent._extract_content(response)

        assert content == "OpenAI content"

@@ -81,7 +81,7 @@ class TestReActAgentContentExtraction:
        response.message = None
        response.choices = None

        content = agent._extract_content_from_llm_response(response)
        content = agent._extract_content(response)

        assert content == "Anthropic content"

@@ -101,7 +101,7 @@ class TestReActAgentContentExtraction:
        chunk2.choices[0].delta.content = "Part 2"

        response = iter([chunk1, chunk2])
        content = agent._extract_content_from_llm_response(response)
        content = agent._extract_content(response)

        assert content == "Part 1 Part 2"

@@ -123,7 +123,7 @@ class TestReActAgentContentExtraction:
        chunk2.choices = []

        response = iter([chunk1, chunk2])
        content = agent._extract_content_from_llm_response(response)
        content = agent._extract_content(response)

        assert content == "Stream 1 Stream 2"

@@ -133,7 +133,7 @@ class TestReActAgentContentExtraction:
        agent = ReActAgent(**agent_base_params)

        response = iter(["chunk1", "chunk2", "chunk3"])
        content = agent._extract_content_from_llm_response(response)
        content = agent._extract_content(response)

        assert content == "chunk1chunk2chunk3"

@@ -148,7 +148,7 @@ class TestReActAgentContentExtraction:
        response.choices = None
        response.content = None

        content = agent._extract_content_from_llm_response(response)
        content = agent._extract_content(response)

        assert content == ""

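These hunks rename _extract_content_from_llm_response to _extract_content without changing its contract. A sketch of the dispatch the assertions above pin down, for the non-streaming cases (strings pass through; then .message.content, OpenAI-style .choices, and Anthropic-style .content are tried in turn); the real method in ReActAgent also concatenates streaming iterators, which is omitted here:

    def _extract_content(self, response):
        # Plain strings pass straight through.
        if isinstance(response, str):
            return response
        # Chat-style objects, in the priority order the tests imply.
        message = getattr(response, "message", None)
        if message is not None and getattr(message, "content", None):
            return message.content
        choices = getattr(response, "choices", None)
        if choices:
            return choices[0].message.content
        if getattr(response, "content", None):
            return response.content
        return ""  # nothing recognisable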
@@ -161,7 +161,7 @@ class TestReActAgentPlanning:
        new_callable=mock_open,
        read_data="Test planning prompt: {query} {summaries} {prompt} {observations}",
    )
    def test_create_plan(
    def test_planning_phase(
        self,
        mock_file,
        agent_base_params,
@@ -171,24 +171,27 @@ class TestReActAgentPlanning:
        log_context,
    ):
        def mock_gen_stream(*args, **kwargs):
            yield "Plan step 1"
            yield "Plan step 2"
            # Return simple strings - _extract_content handles strings directly
            yield "Plan "
            yield "content"

        mock_llm.gen_stream = Mock(return_value=mock_gen_stream())

        agent = ReActAgent(**agent_base_params)
        agent.observations = ["Observation 1"]

        plan_chunks = list(agent._create_plan("Test query", "Test docs", log_context))
        plan_chunks = list(agent._planning_phase("Test query", log_context))

        assert len(plan_chunks) == 2
        assert plan_chunks[0] == "Plan step 1"
        assert plan_chunks[1] == "Plan step 2"
        # Should yield thought dicts
        assert any("thought" in chunk for chunk in plan_chunks)
        assert agent.plan == "Plan content"

        mock_llm.gen_stream.assert_called_once()

    @patch("builtins.open", new_callable=mock_open, read_data="Test: {query}")
    def test_create_plan_fills_template(
    def test_planning_phase_fills_template(
        self,
        mock_file,
        agent_base_params,
@@ -197,10 +200,10 @@ class TestReActAgentPlanning:
        mock_llm_handler_creator,
        log_context,
    ):
        mock_llm.gen_stream = Mock(return_value=iter(["Plan"]))
        mock_llm.gen_stream = Mock(return_value=iter([]))

        agent = ReActAgent(**agent_base_params)
        list(agent._create_plan("My query", "Docs", log_context))
        list(agent._planning_phase("My query", log_context))

        call_args = mock_llm.gen_stream.call_args[1]
        messages = call_args["messages"]
@@ -216,7 +219,7 @@ class TestReActAgentFinalAnswer:
        new_callable=mock_open,
        read_data="Final answer for: {query} with {observations}",
    )
    def test_create_final_answer(
    def test_synthesis_phase(
        self,
        mock_file,
        agent_base_params,
@@ -226,24 +229,22 @@ class TestReActAgentFinalAnswer:
        log_context,
    ):
        def mock_gen_stream(*args, **kwargs):
            yield "Final "
            yield "answer"
            yield Mock(choices=[Mock(delta=Mock(content="Final "))])
            yield Mock(choices=[Mock(delta=Mock(content="answer"))])

        mock_llm.gen_stream = Mock(return_value=mock_gen_stream())

        agent = ReActAgent(**agent_base_params)
        observations = ["Obs 1", "Obs 2"]
        agent.observations = ["Obs 1", "Obs 2"]

        answer_chunks = list(
            agent._create_final_answer("Test query", observations, log_context)
        )
        answer_chunks = list(agent._synthesis_phase("Test query", log_context))

        assert len(answer_chunks) == 2
        assert answer_chunks[0] == "Final "
        assert answer_chunks[1] == "answer"
        # Should yield answer dicts
        assert any("answer" in chunk for chunk in answer_chunks)

    @patch("builtins.open", new_callable=mock_open, read_data="Answer: {observations}")
    def test_create_final_answer_truncates_long_observations(
    def test_synthesis_phase_truncates_long_observations(
        self,
        mock_file,
        agent_base_params,
@@ -252,20 +253,20 @@ class TestReActAgentFinalAnswer:
        mock_llm_handler_creator,
        log_context,
    ):
        mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
        mock_llm.gen_stream = Mock(return_value=iter([]))

        agent = ReActAgent(**agent_base_params)
        long_obs = ["A" * 15000]
        agent.observations = ["A" * 15000]

        list(agent._create_final_answer("Query", long_obs, log_context))
        list(agent._synthesis_phase("Query", log_context))

        call_args = mock_llm.gen_stream.call_args[1]
        messages = call_args["messages"]

        assert "observations truncated" in messages[0]["content"]
        assert "truncated" in messages[0]["content"]

    @patch("builtins.open", new_callable=mock_open, read_data="Test: {query}")
    def test_create_final_answer_no_tools(
    def test_synthesis_phase_no_tools(
        self,
        mock_file,
        agent_base_params,
@@ -274,10 +275,11 @@ class TestReActAgentFinalAnswer:
        mock_llm_handler_creator,
        log_context,
    ):
        mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
        mock_llm.gen_stream = Mock(return_value=iter([]))

        agent = ReActAgent(**agent_base_params)
        list(agent._create_final_answer("Query", ["Obs"], log_context))
        agent.observations = ["Obs"]
        list(agent._synthesis_phase("Query", log_context))

        call_args = mock_llm.gen_stream.call_args[1]

@@ -294,7 +296,6 @@ class TestReActAgentGenInner:
        self,
        mock_file,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
@@ -313,7 +314,7 @@ class TestReActAgentGenInner:
        agent.plan = "Old plan"
        agent.observations = ["Old obs"]

        list(agent._gen_inner("New query", mock_retriever, log_context))
        list(agent._gen_inner("New query", log_context))

        assert agent.plan != "Old plan"
        assert len(agent.observations) > 0
@@ -323,7 +324,6 @@ class TestReActAgentGenInner:
        self,
        mock_file,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
@@ -351,7 +351,7 @@ class TestReActAgentGenInner:
        mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)

        agent = ReActAgent(**agent_base_params)
        results = list(agent._gen_inner("Test query", mock_retriever, log_context))
        results = list(agent._gen_inner("Test query", log_context))

        assert any("answer" in r for r in results)

@@ -360,7 +360,6 @@ class TestReActAgentGenInner:
        self,
        mock_file,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
@@ -386,7 +385,7 @@ class TestReActAgentGenInner:

        agent = ReActAgent(**agent_base_params)

        results = list(agent._gen_inner("Test query", mock_retriever, log_context))
        results = list(agent._gen_inner("Test query", log_context))

        thought_results = [r for r in results if "thought" in r]
        assert len(thought_results) > 0
@@ -396,7 +395,6 @@ class TestReActAgentGenInner:
        self,
        mock_file,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
@@ -412,7 +410,7 @@ class TestReActAgentGenInner:
        mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)

        agent = ReActAgent(**agent_base_params)
        results = list(agent._gen_inner("Test query", mock_retriever, log_context))
        results = list(agent._gen_inner("Test query", log_context))

        sources = [r for r in results if "sources" in r]
        assert len(sources) >= 1
@@ -422,7 +420,6 @@ class TestReActAgentGenInner:
        self,
        mock_file,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
@@ -440,7 +437,7 @@ class TestReActAgentGenInner:
        agent = ReActAgent(**agent_base_params)
        agent.tool_calls = [{"tool": "test", "result": "A" * 100}]

        results = list(agent._gen_inner("Test query", mock_retriever, log_context))
        results = list(agent._gen_inner("Test query", log_context))

        tool_call_results = [r for r in results if "tool_calls" in r]
        if tool_call_results:
@@ -451,7 +448,6 @@ class TestReActAgentGenInner:
        self,
        mock_file,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
@@ -467,7 +463,7 @@ class TestReActAgentGenInner:
        mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)

        agent = ReActAgent(**agent_base_params)
        list(agent._gen_inner("Test query", mock_retriever, log_context))
        list(agent._gen_inner("Test query", log_context))

        assert len(agent.observations) > 0

@@ -484,7 +480,6 @@ class TestReActAgentIntegration:
        self,
        mock_file,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
@@ -512,7 +507,7 @@ class TestReActAgentIntegration:
        mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)

        agent = ReActAgent(**agent_base_params)
        results = list(agent._gen_inner("Complex query", mock_retriever, log_context))
        results = list(agent._gen_inner("Complex query", log_context))

        assert len(results) > 0
        assert any("thought" in r for r in results)

@@ -315,16 +315,12 @@ class TestCompleteStreamMethod:
            ]
        )

        mock_retriever = MagicMock()
        mock_retriever.get_params.return_value = {}

        decoded_token = {"sub": "user123"}

        stream = list(
            resource.complete_stream(
                question="Test question",
                agent=mock_agent,
                retriever=mock_retriever,
                conversation_id=None,
                user_api_key=None,
                decoded_token=decoded_token,
@@ -351,16 +347,12 @@ class TestCompleteStreamMethod:
            ]
        )

        mock_retriever = MagicMock()
        mock_retriever.get_params.return_value = {}

        decoded_token = {"sub": "user123"}

        stream = list(
            resource.complete_stream(
                question="Test?",
                agent=mock_agent,
                retriever=mock_retriever,
                conversation_id=None,
                user_api_key=None,
                decoded_token=decoded_token,
@@ -381,16 +373,12 @@ class TestCompleteStreamMethod:
        mock_agent = MagicMock()
        mock_agent.gen.side_effect = Exception("Test error")

        mock_retriever = MagicMock()
        mock_retriever.get_params.return_value = {}

        decoded_token = {"sub": "user123"}

        stream = list(
            resource.complete_stream(
                question="Test?",
                agent=mock_agent,
                retriever=mock_retriever,
                conversation_id=None,
                user_api_key=None,
                decoded_token=decoded_token,
@@ -413,9 +401,6 @@ class TestCompleteStreamMethod:
            ]
        )

        mock_retriever = MagicMock()
        mock_retriever.get_params.return_value = {}

        decoded_token = {"sub": "user123"}

        with patch.object(
@@ -427,8 +412,7 @@ class TestCompleteStreamMethod:
            resource.complete_stream(
                question="Test?",
                agent=mock_agent,
                retriever=mock_retriever,
                conversation_id=None,
                conversation_id=None,
                user_api_key=None,
                decoded_token=decoded_token,
                should_save_conversation=True,
@@ -461,7 +445,6 @@ class TestCompleteStreamMethod:
            resource.complete_stream(
                question="Test question?",
                agent=mock_agent,
                retriever=mock_retriever,
                conversation_id=None,
                user_api_key="test_key",
                decoded_token=decoded_token,

850
tests/api/answer/services/test_prompt_renderer.py
Normal file
@@ -0,0 +1,850 @@
import pytest


@pytest.mark.unit
class TestTemplateEngine:

    def test_render_simple_template(self):
        from application.templates.template_engine import TemplateEngine

        engine = TemplateEngine()
        result = engine.render("Hello {{ name }}", {"name": "World"})

        assert result == "Hello World"

    def test_render_with_namespace(self):
        from application.templates.template_engine import TemplateEngine

        engine = TemplateEngine()
        context = {
            "user": {"name": "Alice", "role": "admin"},
            "system": {"date": "2025-10-22"},
        }
        result = engine.render(
            "{{ user.name }} is a {{ user.role }} on {{ system.date }}", context
        )

        assert result == "Alice is a admin on 2025-10-22"

    def test_render_empty_template(self):
        from application.templates.template_engine import TemplateEngine

        engine = TemplateEngine()
        result = engine.render("", {"key": "value"})

        assert result == ""

    def test_render_template_without_variables(self):
        from application.templates.template_engine import TemplateEngine

        engine = TemplateEngine()
        result = engine.render("Just plain text", {})

        assert result == "Just plain text"

    def test_render_undefined_variable_returns_empty_string(self):
        from application.templates.template_engine import TemplateEngine

        engine = TemplateEngine()

        result = engine.render("Hello {{ undefined_var }}", {})
        assert result == "Hello "

    def test_render_syntax_error_raises_error(self):
        from application.templates.template_engine import (
            TemplateEngine,
            TemplateRenderError,
        )

        engine = TemplateEngine()

        with pytest.raises(TemplateRenderError, match="Template syntax error"):
            engine.render("Hello {{ name", {"name": "World"})

    def test_validate_template_valid(self):
        from application.templates.template_engine import TemplateEngine

        engine = TemplateEngine()
        assert engine.validate_template("Valid {{ variable }}") is True

    def test_validate_template_invalid(self):
        from application.templates.template_engine import TemplateEngine

        engine = TemplateEngine()
        assert engine.validate_template("Invalid {{ variable") is False

    def test_validate_empty_template(self):
        from application.templates.template_engine import TemplateEngine

        engine = TemplateEngine()
        assert engine.validate_template("") is True

    def test_extract_variables(self):
        from application.templates.template_engine import TemplateEngine

        engine = TemplateEngine()
        template = "{{ user.name }} and {{ user.email }}"

        result = engine.extract_variables(template)

        assert isinstance(result, set)

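The assertions above fix the engine's contract: undefined variables render as empty strings, syntax errors surface as TemplateRenderError, and validate_template/extract_variables never raise. A minimal sketch that satisfies them, assuming Jinja2 as the backing library (the actual TemplateEngine in application/templates may differ in detail):

    from jinja2 import Environment, TemplateSyntaxError, Undefined, meta


    class TemplateRenderError(Exception):
        """Raised when a template cannot be parsed or rendered."""


    class TemplateEngine:
        def __init__(self):
            # The default Undefined prints as "", which is what
            # test_render_undefined_variable_returns_empty_string expects.
            self._env = Environment(undefined=Undefined)

        def render(self, template: str, context: dict) -> str:
            try:
                return self._env.from_string(template).render(**context)
            except TemplateSyntaxError as e:
                raise TemplateRenderError(f"Template syntax error: {e}") from e

        def validate_template(self, template: str) -> bool:
            try:
                self._env.parse(template)
                return True
            except TemplateSyntaxError:
                return False

        def extract_variables(self, template: str) -> set:
            # Returns top-level names only: "{{ user.name }}" yields {"user"}.
            return meta.find_undeclared_variables(self._env.parse(template))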
@pytest.mark.unit
class TestSystemNamespace:

    def test_system_namespace_build(self):
        from application.templates.namespaces import SystemNamespace

        builder = SystemNamespace()
        context = builder.build(
            request_id="req_123", user_id="user_456", extra_param="ignored"
        )

        assert context["request_id"] == "req_123"
        assert context["user_id"] == "user_456"
        assert "date" in context
        assert "time" in context
        assert "timestamp" in context

    def test_system_namespace_generates_request_id(self):
        from application.templates.namespaces import SystemNamespace

        builder = SystemNamespace()
        context = builder.build(user_id="user_123")

        assert context["request_id"] is not None
        assert len(context["request_id"]) > 0

    def test_system_namespace_name(self):
        from application.templates.namespaces import SystemNamespace

        builder = SystemNamespace()
        assert builder.namespace_name == "system"

    def test_system_namespace_date_format(self):
        from application.templates.namespaces import SystemNamespace

        builder = SystemNamespace()
        context = builder.build()

        import re

        assert re.match(r"\d{4}-\d{2}-\d{2}", context["date"])
        assert re.match(r"\d{2}:\d{2}:\d{2}", context["time"])


@pytest.mark.unit
class TestPassthroughNamespace:

    def test_passthrough_namespace_build(self):
        from application.templates.namespaces import PassthroughNamespace

        builder = PassthroughNamespace()
        passthrough_data = {"company": "Acme", "user_name": "John", "count": 42}

        context = builder.build(passthrough_data=passthrough_data)

        assert context["company"] == "Acme"
        assert context["user_name"] == "John"
        assert context["count"] == 42

    def test_passthrough_namespace_empty(self):
        from application.templates.namespaces import PassthroughNamespace

        builder = PassthroughNamespace()
        context = builder.build(passthrough_data=None)

        assert context == {}

    def test_passthrough_namespace_filters_unsafe_values(self):
        from application.templates.namespaces import PassthroughNamespace

        builder = PassthroughNamespace()
        passthrough_data = {
            "safe_string": "value",
            "unsafe_object": {"key": "value"},
            "safe_bool": True,
            "unsafe_list": [1, 2, 3],
            "safe_float": 3.14,
        }

        context = builder.build(passthrough_data=passthrough_data)

        assert context["safe_string"] == "value"
        assert context["safe_bool"] is True
        assert context["safe_float"] == 3.14
        assert "unsafe_object" not in context
        assert "unsafe_list" not in context

    def test_passthrough_namespace_allows_none_values(self):
        from application.templates.namespaces import PassthroughNamespace

        builder = PassthroughNamespace()
        passthrough_data = {"nullable_field": None}

        context = builder.build(passthrough_data=passthrough_data)

        assert context["nullable_field"] is None

    def test_passthrough_namespace_name(self):
        from application.templates.namespaces import PassthroughNamespace

        builder = PassthroughNamespace()
        assert builder.namespace_name == "passthrough"


@pytest.mark.unit
class TestSourceNamespace:

    def test_source_namespace_build_with_docs(self):
        from application.templates.namespaces import SourceNamespace

        builder = SourceNamespace()
        docs = [
            {"text": "Doc 1", "filename": "file1.txt"},
            {"text": "Doc 2", "filename": "file2.txt"},
        ]
        docs_together = "Doc 1 content\n\nDoc 2 content"

        context = builder.build(docs=docs, docs_together=docs_together)

        assert context["documents"] == docs
        assert context["count"] == 2
        assert context["content"] == docs_together
        assert context["summaries"] == docs_together

    def test_source_namespace_build_empty(self):
        from application.templates.namespaces import SourceNamespace

        builder = SourceNamespace()
        context = builder.build(docs=None, docs_together=None)

        assert context == {}

    def test_source_namespace_build_docs_only(self):
        from application.templates.namespaces import SourceNamespace

        builder = SourceNamespace()
        docs = [{"text": "Doc 1"}]

        context = builder.build(docs=docs)

        assert context["documents"] == docs
        assert context["count"] == 1
        assert "content" not in context

    def test_source_namespace_build_docs_together_only(self):
        from application.templates.namespaces import SourceNamespace

        builder = SourceNamespace()
        docs_together = "Content here"

        context = builder.build(docs_together=docs_together)

        assert context["content"] == docs_together
        assert context["summaries"] == docs_together
        assert "documents" not in context

    def test_source_namespace_name(self):
        from application.templates.namespaces import SourceNamespace

        builder = SourceNamespace()
        assert builder.namespace_name == "source"


@pytest.mark.unit
class TestToolsNamespace:

    def test_tools_namespace_build_with_memory_data(self):
        from application.templates.namespaces import ToolsNamespace

        builder = ToolsNamespace()
        tools_data = {
            "memory": {"root": "Files:\n- /notes.txt\n- /tasks.txt", "available": True}
        }

        context = builder.build(tools_data=tools_data)

        assert context["memory"]["root"] == "Files:\n- /notes.txt\n- /tasks.txt"
        assert context["memory"]["available"] is True

    def test_tools_namespace_build_empty(self):
        from application.templates.namespaces import ToolsNamespace

        builder = ToolsNamespace()
        context = builder.build(tools_data=None)

        assert context == {}

    def test_tools_namespace_build_multiple_tools(self):
        from application.templates.namespaces import ToolsNamespace

        builder = ToolsNamespace()
        tools_data = {
            "memory": {"root": "content", "available": True},
            "search": {"results": ["result1", "result2"]},
            "api": {"status": "success"},
        }

        context = builder.build(tools_data=tools_data)

        assert "memory" in context
        assert "search" in context
        assert "api" in context
        assert context["memory"]["root"] == "content"
        assert context["search"]["results"] == ["result1", "result2"]
        assert context["api"]["status"] == "success"

    def test_tools_namespace_filters_unsafe_values(self):
        from application.templates.namespaces import ToolsNamespace

        builder = ToolsNamespace()

        class UnsafeObject:
            pass

        tools_data = {"safe_tool": {"result": "success"}, "unsafe_tool": UnsafeObject()}

        context = builder.build(tools_data=tools_data)

        assert "safe_tool" in context
        assert "unsafe_tool" not in context

    def test_tools_namespace_name(self):
        from application.templates.namespaces import ToolsNamespace

        builder = ToolsNamespace()
        assert builder.namespace_name == "tools"

    def test_tools_namespace_with_empty_dict(self):
        from application.templates.namespaces import ToolsNamespace

        builder = ToolsNamespace()
        context = builder.build(tools_data={})

        assert context == {}


@pytest.mark.unit
class TestNamespaceManagerWithTools:

    def test_namespace_manager_includes_tools_in_context(self):
        from application.templates.namespaces import NamespaceManager

        manager = NamespaceManager()
        tools_data = {"memory": {"root": "content", "available": True}}

        context = manager.build_context(tools_data=tools_data)

        assert "tools" in context
        assert context["tools"]["memory"]["root"] == "content"

    def test_namespace_manager_build_context_all_namespaces(self):
        from application.templates.namespaces import NamespaceManager

        manager = NamespaceManager()
        context = manager.build_context(
            request_id="req_123",
            user_id="user_456",
            passthrough_data={"key": "value"},
            docs_together="Document content",
            tools_data={"memory": {"root": "notes"}},
        )

        assert "system" in context
        assert "passthrough" in context
        assert "source" in context
        assert "tools" in context
        assert context["tools"]["memory"]["root"] == "notes"

    def test_namespace_manager_build_context_partial_data(self):
        from application.templates.namespaces import NamespaceManager

        manager = NamespaceManager()
        context = manager.build_context(request_id="req_123")

        assert "system" in context
        assert context["system"]["request_id"] == "req_123"

    def test_namespace_manager_get_builder(self):
        from application.templates.namespaces import NamespaceManager, SystemNamespace

        manager = NamespaceManager()
        builder = manager.get_builder("system")

        assert isinstance(builder, SystemNamespace)

    def test_namespace_manager_get_builder_nonexistent(self):
        from application.templates.namespaces import NamespaceManager

        manager = NamespaceManager()
        builder = manager.get_builder("nonexistent")

        assert builder is None

    def test_namespace_manager_handles_builder_exceptions(self):
        from unittest.mock import patch

        from application.templates.namespaces import NamespaceManager

        manager = NamespaceManager()

        with patch.object(
            manager._builders["system"],
            "build",
            side_effect=Exception("Builder error"),
        ):
            context = manager.build_context()

            # Namespace should be present but empty when builder fails
            assert "system" in context
            assert context["system"] == {}

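Taken together, these tests describe a manager that fans one kwargs bag out to every registered builder and degrades a failing builder to an empty namespace instead of failing the render. A sketch under those assumptions (builder classes as imported in the tests; the real NamespaceManager may differ):

    import logging

    from application.templates.namespaces import (
        PassthroughNamespace,
        SourceNamespace,
        SystemNamespace,
        ToolsNamespace,
    )

    logger = logging.getLogger(__name__)


    class NamespaceManager:
        def __init__(self):
            builders = [
                SystemNamespace(),
                PassthroughNamespace(),
                SourceNamespace(),
                ToolsNamespace(),
            ]
            self._builders = {b.namespace_name: b for b in builders}

        def get_builder(self, name):
            # Returns None for unknown namespaces, as the tests require.
            return self._builders.get(name)

        def build_context(self, **kwargs) -> dict:
            context = {}
            for name, builder in self._builders.items():
                try:
                    context[name] = builder.build(**kwargs)
                except Exception:
                    # A broken builder yields an empty namespace rather than
                    # aborting prompt rendering.
                    logger.exception("Namespace builder '%s' failed", name)
                    context[name] = {}
            return context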
@pytest.mark.unit
class TestPromptRenderer:

    def test_render_prompt_with_template_syntax(self):
        from application.api.answer.services.prompt_renderer import PromptRenderer

        renderer = PromptRenderer()
        prompt = "Hello {{ system.user_id }}, today is {{ system.date }}"

        result = renderer.render_prompt(prompt, user_id="user_123")

        assert "user_123" in result
        assert "202" in result

    def test_render_prompt_with_passthrough_data(self):
        from application.api.answer.services.prompt_renderer import PromptRenderer

        renderer = PromptRenderer()
        prompt = "Company: {{ passthrough.company }}\nUser: {{ passthrough.user_name }}"
        passthrough_data = {"company": "Acme", "user_name": "John"}

        result = renderer.render_prompt(prompt, passthrough_data=passthrough_data)

        assert "Company: Acme" in result
        assert "User: John" in result

    def test_render_prompt_with_source_docs(self):
        from application.api.answer.services.prompt_renderer import PromptRenderer

        renderer = PromptRenderer()
        prompt = "Use this information:\n{{ source.content }}"
        docs_together = "Important document content"

        result = renderer.render_prompt(prompt, docs_together=docs_together)

        assert "Use this information:" in result
        assert "Important document content" in result

    def test_render_prompt_empty_content(self):
        from application.api.answer.services.prompt_renderer import PromptRenderer

        renderer = PromptRenderer()
        result = renderer.render_prompt("")

        assert result == ""

    def test_render_prompt_legacy_format_with_summaries(self):
        from application.api.answer.services.prompt_renderer import PromptRenderer

        renderer = PromptRenderer()
        prompt = "Context: {summaries}\nQuestion: What is this?"
        docs_together = "This is the document content"

        result = renderer.render_prompt(prompt, docs_together=docs_together)

        assert "Context: This is the document content" in result

    def test_render_prompt_legacy_format_without_docs(self):
        from application.api.answer.services.prompt_renderer import PromptRenderer

        renderer = PromptRenderer()
        prompt = "Context: {summaries}\nQuestion: What is this?"

        result = renderer.render_prompt(prompt)

        assert "Context: {summaries}" in result

    def test_render_prompt_combined_namespace_variables(self):
        from application.api.answer.services.prompt_renderer import PromptRenderer

        renderer = PromptRenderer()
        prompt = "User: {{ passthrough.user }}, Date: {{ system.date }}, Docs: {{ source.content }}"
        passthrough_data = {"user": "Alice"}
        docs_together = "Doc content"

        result = renderer.render_prompt(
            prompt,
            passthrough_data=passthrough_data,
            docs_together=docs_together,
        )

        assert "User: Alice" in result
        assert "Date: 202" in result
        assert "Doc content" in result

    def test_render_prompt_with_tools_data(self):
        from application.api.answer.services.prompt_renderer import PromptRenderer

        renderer = PromptRenderer()
        prompt = "Memory contents:\n{{ tools.memory.root }}\n\nStatus: {{ tools.memory.available }}"
        tools_data = {
            "memory": {"root": "Files:\n- /notes.txt\n- /tasks.txt", "available": True}
        }

        result = renderer.render_prompt(prompt, tools_data=tools_data)

        assert "Memory contents:" in result
        assert "Files:" in result
        assert "/notes.txt" in result
        assert "/tasks.txt" in result
        assert "Status: True" in result

    def test_render_prompt_with_all_namespaces(self):
        from application.api.answer.services.prompt_renderer import PromptRenderer

        renderer = PromptRenderer()
        prompt = """
System: {{ system.date }}
User: {{ passthrough.user }}
Docs: {{ source.content }}
Memory: {{ tools.memory.root }}
"""
        passthrough_data = {"user": "Alice"}
        docs_together = "Important docs"
        tools_data = {"memory": {"root": "Notes content", "available": True}}

        result = renderer.render_prompt(
            prompt,
            passthrough_data=passthrough_data,
            docs_together=docs_together,
            tools_data=tools_data,
        )

        assert "202" in result
        assert "Alice" in result
        assert "Important docs" in result
        assert "Notes content" in result

    def test_render_prompt_undefined_variable_returns_empty_string(self):
        from application.api.answer.services.prompt_renderer import PromptRenderer

        renderer = PromptRenderer()
        prompt = "Hello {{ undefined_var }}"

        result = renderer.render_prompt(prompt)
        assert result == "Hello "

    def test_render_prompt_with_undefined_variable_in_template(self):
        from application.api.answer.services.prompt_renderer import PromptRenderer

        renderer = PromptRenderer()
        prompt = "Hello {{ undefined_name }}"

        result = renderer.render_prompt(prompt)
        assert result == "Hello "

    def test_validate_template_valid(self):
        from application.api.answer.services.prompt_renderer import PromptRenderer

        renderer = PromptRenderer()
        assert renderer.validate_template("Valid {{ variable }}") is True

    def test_validate_template_invalid(self):
        from application.api.answer.services.prompt_renderer import PromptRenderer

        renderer = PromptRenderer()
        assert renderer.validate_template("Invalid {{ variable") is False

    def test_extract_variables(self):
        from application.api.answer.services.prompt_renderer import PromptRenderer

        renderer = PromptRenderer()
        template = "{{ var1 }} and {{ var2 }}"

        result = renderer.extract_variables(template)

        assert isinstance(result, set)

    def test_uses_template_syntax_detection(self):
        from application.api.answer.services.prompt_renderer import PromptRenderer

        renderer = PromptRenderer()

        assert renderer._uses_template_syntax("Text with {{ var }}") is True
        assert renderer._uses_template_syntax("Text with {var}") is False
        assert renderer._uses_template_syntax("Plain text") is False

    def test_apply_legacy_substitutions(self):
        from application.api.answer.services.prompt_renderer import PromptRenderer

        renderer = PromptRenderer()
        prompt = "Use {summaries} to answer"
        docs_together = "Important info"

        result = renderer._apply_legacy_substitutions(prompt, docs_together)

        assert "Use Important info to answer" in result

    def test_apply_legacy_substitutions_without_docs(self):
        from application.api.answer.services.prompt_renderer import PromptRenderer

        renderer = PromptRenderer()
        prompt = "Use {summaries} to answer"

        result = renderer._apply_legacy_substitutions(prompt, None)

        assert result == prompt

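The private-method tests spell out a two-path renderer: prompts containing {{ ... }} go through the template engine with the full namespace context, while legacy prompts only ever get {summaries} swapped for the pre-fetched documents. A sketch of that control flow, reusing the TemplateEngine and NamespaceManager contracts above (a simplified reading, not the exact implementation):

    class PromptRenderer:
        def __init__(self):
            self._engine = TemplateEngine()
            self._namespaces = NamespaceManager()

        def _uses_template_syntax(self, prompt: str) -> bool:
            # "{var}" is legacy syntax; only "{{ var }}" triggers templating.
            return "{{" in prompt and "}}" in prompt

        def _apply_legacy_substitutions(self, prompt: str, docs_together):
            if docs_together is None:
                return prompt  # leave "{summaries}" untouched
            return prompt.replace("{summaries}", docs_together)

        def render_prompt(self, prompt: str, docs_together=None, **kwargs) -> str:
            if not prompt:
                return prompt
            if self._uses_template_syntax(prompt):
                context = self._namespaces.build_context(
                    docs_together=docs_together, **kwargs
                )
                return self._engine.render(prompt, context)
            return self._apply_legacy_substitutions(prompt, docs_together)

    # Usage mirroring the tests:
    # PromptRenderer().render_prompt("Context: {summaries}", docs_together="docs")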
@pytest.mark.unit
class TestPromptRendererIntegration:

    def test_render_prompt_real_world_scenario(self):
        from application.api.answer.services.prompt_renderer import PromptRenderer

        renderer = PromptRenderer()
        prompt = "You are helping {{ passthrough.company }}.\n\nUser: {{ passthrough.user_name }}\n\nRequest ID: {{ system.request_id }}\n\nDate: {{ system.date }}\n\nReference Documents:\n\n{{ source.content }}\n\nPlease answer the question professionally."

        passthrough_data = {"company": "Tech Corp", "user_name": "Alice"}
        docs_together = "Document 1: Technical specs\nDocument 2: Requirements"

        result = renderer.render_prompt(
            prompt,
            request_id="req_123",
            user_id="user_456",
            passthrough_data=passthrough_data,
            docs_together=docs_together,
        )

        assert "Tech Corp" in result
        assert "Alice" in result
        assert "req_123" in result
        assert "Technical specs" in result
        assert "professionally" in result

    def test_render_prompt_multiple_doc_references(self):
        from application.api.answer.services.prompt_renderer import PromptRenderer

        renderer = PromptRenderer()
        prompt = """Documents: {{ source.content }} \n\nAlso summaries: {{ source.summaries }}"""
        docs_together = "Content here"

        result = renderer.render_prompt(prompt, docs_together=docs_together)

        assert result.count("Content here") == 2


@pytest.mark.unit
class TestStreamProcessorPromptRendering:

    def test_stream_processor_pre_fetch_docs_none_doc_mode(self, mock_mongo_db):
        from application.api.answer.services.stream_processor import StreamProcessor

        request_data = {"question": "Test question", "isNoneDoc": True}
        processor = StreamProcessor(request_data, None)

        docs_together, docs_list = processor.pre_fetch_docs("Test question")

        assert docs_together is None
        assert docs_list is None

    def test_pre_fetch_tools_disabled_globally(self, mock_mongo_db, monkeypatch):
        from application.api.answer.services.stream_processor import StreamProcessor
        from application.core.settings import settings

        monkeypatch.setattr(settings, "ENABLE_TOOL_PREFETCH", False)

        request_data = {"question": "test"}
        processor = StreamProcessor(request_data, {"sub": "user1"})

        result = processor.pre_fetch_tools()

        assert result is None

    def test_pre_fetch_tools_disabled_per_request(self, mock_mongo_db):
        from application.api.answer.services.stream_processor import StreamProcessor

        request_data = {"question": "test", "disable_tool_prefetch": True}
        processor = StreamProcessor(request_data, {"sub": "user1"})

        result = processor.pre_fetch_tools()

        assert result is None

    def test_pre_fetch_tools_skips_tool_with_no_actions(self, mock_mongo_db):
        from unittest.mock import MagicMock, patch

        from application.api.answer.services.stream_processor import StreamProcessor
        from application.core.mongo_db import MongoDB
        from bson import ObjectId

        db = MongoDB.get_client()[list(MongoDB.get_client().keys())[0]]
        tool_doc = {
            "_id": ObjectId(),
            "name": "memory",
            "user": "user1",
            "status": True,
            "config": {},
        }
        db["user_tools"].insert_one(tool_doc)

        request_data = {"question": "test"}
        processor = StreamProcessor(request_data, {"sub": "user1"})

        with patch(
            "application.agents.tools.tool_manager.ToolManager"
        ) as mock_manager_class:
            mock_manager = MagicMock()
            mock_manager_class.return_value = mock_manager

            # Mock the tool instance
            mock_tool = MagicMock()
            mock_manager.load_tool.return_value = mock_tool

            # Tool has no actions
            mock_tool.get_actions_metadata.return_value = []

            result = processor.pre_fetch_tools()

            # Should return None when tool has no actions
            assert result is None

    def test_pre_fetch_tools_enabled_by_default(self, mock_mongo_db, monkeypatch):
        from unittest.mock import MagicMock, patch

        from application.api.answer.services.stream_processor import StreamProcessor
        from application.core.mongo_db import MongoDB
        from bson import ObjectId

        db = MongoDB.get_client()[list(MongoDB.get_client().keys())[0]]
        tool_doc = {
            "_id": ObjectId(),
            "name": "memory",
            "user": "user1",
            "status": True,
            "config": {},
        }
        db["user_tools"].insert_one(tool_doc)

        request_data = {"question": "test"}
        processor = StreamProcessor(request_data, {"sub": "user1"})

        with patch(
            "application.agents.tools.tool_manager.ToolManager"
        ) as mock_manager_class:
            mock_manager = MagicMock()
            mock_manager_class.return_value = mock_manager

            # Mock the tool instance returned by load_tool
            mock_tool = MagicMock()
            mock_manager.load_tool.return_value = mock_tool

            # Mock get_actions_metadata on the tool instance
            mock_tool.get_actions_metadata.return_value = [
                {"name": "memory_ls", "description": "List files", "parameters": {"properties": {}}}
            ]
            mock_tool.execute_action.return_value = "Directory: /\n- file.txt"

            result = processor.pre_fetch_tools()

            assert result is not None
            assert "memory" in result
            assert "memory_ls" in result["memory"]

    def test_pre_fetch_tools_no_tools_configured(self, mock_mongo_db):
        from application.api.answer.services.stream_processor import StreamProcessor

        request_data = {"question": "test"}
        processor = StreamProcessor(request_data, {"sub": "user1"})

        result = processor.pre_fetch_tools()

        assert result is None

    def test_pre_fetch_tools_memory_returns_error(self, mock_mongo_db):
        from unittest.mock import MagicMock, patch

        from application.api.answer.services.stream_processor import StreamProcessor
        from application.core.mongo_db import MongoDB
        from bson import ObjectId

        db = MongoDB.get_client()[list(MongoDB.get_client().keys())[0]]
        tool_doc = {
            "_id": ObjectId(),
            "name": "memory",
            "user": "user1",
            "status": True,
            "config": {},
        }
        db["user_tools"].insert_one(tool_doc)

        request_data = {"question": "test"}
        processor = StreamProcessor(request_data, {"sub": "user1"})

        with patch(
            "application.agents.tools.tool_manager.ToolManager"
        ) as mock_manager_class:
            mock_manager = MagicMock()
            mock_manager_class.return_value = mock_manager

            # Mock the tool instance
            mock_tool = MagicMock()
            mock_manager.load_tool.return_value = mock_tool

            mock_tool.get_actions_metadata.return_value = [
                {"name": "memory_ls", "description": "List files", "parameters": {"properties": {}}}
            ]
            # Simulate execution error
            mock_tool.execute_action.side_effect = Exception("Tool error")

            result = processor.pre_fetch_tools()

            # Should return None when all actions fail
            assert result is None

    def test_pre_fetch_tools_memory_returns_empty(self, mock_mongo_db):
        from unittest.mock import MagicMock, patch

        from application.api.answer.services.stream_processor import StreamProcessor
        from application.core.mongo_db import MongoDB
        from bson import ObjectId

        db = MongoDB.get_client()[list(MongoDB.get_client().keys())[0]]
        tool_doc = {
            "_id": ObjectId(),
            "name": "memory",
            "user": "user1",
            "status": True,
            "config": {},
        }
        db["user_tools"].insert_one(tool_doc)

        request_data = {"question": "test"}
        processor = StreamProcessor(request_data, {"sub": "user1"})

        with patch(
            "application.agents.tools.tool_manager.ToolManager"
        ) as mock_manager_class:
            mock_manager = MagicMock()
            mock_manager_class.return_value = mock_manager

            # Mock the tool instance
            mock_tool = MagicMock()
            mock_manager.load_tool.return_value = mock_tool

            mock_tool.get_actions_metadata.return_value = [
                {"name": "memory_ls", "description": "List files", "parameters": {"properties": {}}}
            ]
            # Return empty string
            mock_tool.execute_action.return_value = ""

            result = processor.pre_fetch_tools()

            # Empty result should still be included
            assert result is not None
            assert "memory" in result
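These tests pin down the short-circuits around pre-fetching: the global settings.ENABLE_TOOL_PREFETCH switch, the per-request disable_tool_prefetch flag, no configured tools, actionless tools, and all-actions-failed each yield None, while an empty-but-successful result is kept. A sketch of that gating (load_enabled_user_tools is a hypothetical helper standing in for the MongoDB lookup; the real pre_fetch_tools may differ):

    def pre_fetch_tools(processor):
        if not settings.ENABLE_TOOL_PREFETCH:  # global kill switch
            return None
        if processor.data.get("disable_tool_prefetch"):  # per-request opt-out
            return None
        tools = load_enabled_user_tools(processor)  # hypothetical helper
        if not tools:
            return None  # nothing configured for this user
        results = {}
        for tool in tools:
            actions = tool.get_actions_metadata()
            if not actions:
                continue  # actionless tools are skipped
            fetched = {}
            for action in actions:
                try:
                    # Empty strings still count as a result.
                    fetched[action["name"]] = tool.execute_action(action["name"])
                except Exception:
                    continue  # failed actions are dropped
            if fetched:
                results[tool.name] = fetched
        return results or None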
@@ -250,3 +250,330 @@ class TestStreamProcessorAttachments:
            "attachments" not in processor.data
            or processor.data.get("attachments") is None
        )


@pytest.mark.unit
class TestToolPreFetch:
    """Tests for tool pre-fetching with saved parameter values from MongoDB"""

    def test_cryptoprice_prefetch_with_saved_parameters(self, mock_mongo_db):
        """Test that cryptoprice tool is pre-fetched with saved parameter values from MongoDB structure"""
        from application.api.answer.services.stream_processor import StreamProcessor
        from application.core.settings import settings
        from unittest.mock import patch, MagicMock

        # Setup MongoDB with cryptoprice tool configuration
        # NOTE: The collection is called "user_tools" not "tools"
        tools_collection = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
        tool_id = ObjectId()

        tools_collection.insert_one(
            {
                "_id": tool_id,
                "name": "cryptoprice",
                "user": "user_123",
                "status": True,  # Must be True for tool to be included
                "actions": [
                    {
                        "name": "cryptoprice_get",
                        "description": "Get cryptocurrency price",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "symbol": {
                                    "type": "string",
                                    "description": "Crypto symbol",
                                    "value": "BTC"  # Saved value in MongoDB
                                },
                                "currency": {
                                    "type": "string",
                                    "description": "Currency for price",
                                    "value": "USD"  # Saved value in MongoDB
                                }
                            },
                            "required": ["symbol", "currency"]
                        }
                    }
                ],
                "config": {
                    "token": ""
                }
            }
        )

        request_data = {
            "question": "What is the price of Bitcoin?",
            "tools": [str(tool_id)]
        }

        processor = StreamProcessor(request_data, {"sub": "user_123"})
        processor._required_tool_actions = {"cryptoprice": {"cryptoprice_get"}}

        # Mock the ToolManager and tool instance
        with patch("application.agents.tools.tool_manager.ToolManager") as mock_manager_class:
            mock_manager = MagicMock()
            mock_manager_class.return_value = mock_manager

            # Mock the tool instance returned by load_tool
            mock_tool = MagicMock()
            mock_manager.load_tool.return_value = mock_tool

            # Mock get_actions_metadata on the tool instance
            mock_tool.get_actions_metadata.return_value = [
                {
                    "name": "cryptoprice_get",
                    "description": "Get cryptocurrency price",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "symbol": {"type": "string", "description": "Crypto symbol"},
                            "currency": {"type": "string", "description": "Currency for price"}
                        },
                        "required": ["symbol", "currency"]
                    }
                }
            ]

            # Mock execute_action on the tool instance to return price data
            mock_tool.execute_action.return_value = {
                "status_code": 200,
                "price": 45000.50,
                "message": "Price of BTC in USD retrieved successfully."
            }

            # Execute pre-fetch
            tools_data = processor.pre_fetch_tools()

            # Verify the tool was called
            assert mock_tool.execute_action.called

            # Verify it was called with the saved parameters from MongoDB
            call_args = mock_tool.execute_action.call_args
            assert call_args is not None

            # Check action name uses the full metadata name for execution
            assert call_args[0][0] == "cryptoprice_get"

            # Check kwargs contain saved values
            kwargs = call_args[1]
            assert kwargs.get("symbol") == "BTC"
            assert kwargs.get("currency") == "USD"

            # Verify tools_data structure
            assert "cryptoprice" in tools_data
            # Results are exposed under the full action name
            assert "cryptoprice_get" in tools_data["cryptoprice"]
            assert tools_data["cryptoprice"]["cryptoprice_get"]["price"] == 45000.50

    def test_prefetch_with_missing_saved_values_uses_defaults(self, mock_mongo_db):
        """Test that pre-fetch falls back to defaults when saved values are missing"""
        from application.api.answer.services.stream_processor import StreamProcessor
        from application.core.settings import settings
        from unittest.mock import patch, MagicMock

        tools_collection = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
        tool_id = ObjectId()

        # Tool configuration without saved values
        tools_collection.insert_one(
            {
                "_id": tool_id,
                "name": "cryptoprice",
                "user": "user_123",
                "status": True,
                "actions": [
                    {
                        "name": "cryptoprice_get",
                        "description": "Get cryptocurrency price",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "symbol": {
                                    "type": "string",
                                    "description": "Crypto symbol",
                                    "default": "ETH"  # Only default, no saved value
                                },
                                "currency": {
                                    "type": "string",
                                    "description": "Currency",
                                    "default": "EUR"
                                }
                            },
                            "required": ["symbol", "currency"]
                        }
                    }
                ],
                "config": {}
            }
        )

        request_data = {
            "question": "Crypto price?",
            "tools": [str(tool_id)]
        }

        processor = StreamProcessor(request_data, {"sub": "user_123"})
        processor._required_tool_actions = {"cryptoprice": {"cryptoprice_get"}}

        with patch("application.agents.tools.tool_manager.ToolManager") as mock_manager_class:
            mock_manager = MagicMock()
            mock_manager_class.return_value = mock_manager

            # Mock the tool instance
            mock_tool = MagicMock()
            mock_manager.load_tool.return_value = mock_tool

            mock_tool.get_actions_metadata.return_value = [
                {
                    "name": "cryptoprice_get",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "symbol": {"type": "string", "default": "ETH"},
                            "currency": {"type": "string", "default": "EUR"}
                        }
                    }
                }
            ]

            mock_tool.execute_action.return_value = {
                "status_code": 200,
                "price": 2500.00
            }

            processor.pre_fetch_tools()

            # Should use default values when saved values are missing
            call_args = mock_tool.execute_action.call_args
            if call_args:
                kwargs = call_args[1]
                # Either uses defaults or skips if no values available
                assert kwargs.get("symbol") in ["ETH", None]
                assert kwargs.get("currency") in ["EUR", None]

    def test_prefetch_with_tool_id_reference(self, mock_mongo_db):
        """Test that tools can be referenced by MongoDB ObjectId in templates"""
        from application.api.answer.services.stream_processor import StreamProcessor
        from application.core.settings import settings
        from unittest.mock import patch, MagicMock

        tools_collection = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
        tool_id = ObjectId()

        # Create a tool in the database
        tools_collection.insert_one(
            {
                "_id": tool_id,
                "name": "memory",
                "user": "user_123",
                "status": True,
                "actions": [
                    {
                        "name": "memory_ls",
                        "description": "List files",
                        "parameters": {
                            "type": "object",
                            "properties": {}
                        }
                    }
                ],
                "config": {},
            }
        )

        request_data = {"question": "test"}
        processor = StreamProcessor(request_data, {"sub": "user_123"})

        # Mock the filtering to require this specific tool by ID
        processor._required_tool_actions = {
            str(tool_id): {"memory_ls"}  # Reference by ObjectId string
        }

        with patch(
            "application.agents.tools.tool_manager.ToolManager"
        ) as mock_manager_class:
            mock_manager = MagicMock()
            mock_manager_class.return_value = mock_manager

            # Mock the tool instance
            mock_tool = MagicMock()
            mock_manager.load_tool.return_value = mock_tool

            mock_tool.get_actions_metadata.return_value = [
                {"name": "memory_ls", "description": "List files", "parameters": {"properties": {}}}
            ]
            mock_tool.execute_action.return_value = "Directory: /\n- file.txt"

            result = processor.pre_fetch_tools()

            # Tool data should be available under both name and ID
            assert result is not None
            assert "memory" in result
            assert str(tool_id) in result
            # Both should point to the same data
            assert result["memory"] == result[str(tool_id)]
            assert "memory_ls" in result[str(tool_id)]

    def test_prefetch_with_multiple_same_name_tools(self, mock_mongo_db):
        """Test that multiple tools with the same name can be distinguished by ID"""
        from application.api.answer.services.stream_processor import StreamProcessor
        from application.core.settings import settings
        from unittest.mock import patch, MagicMock

        tools_collection = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]

        # Create two memory tools with different IDs
        tool_id_1 = ObjectId()
        tool_id_2 = ObjectId()

        tools_collection.insert_many([
            {
                "_id": tool_id_1,
                "name": "memory",
                "user": "user_123",
                "status": True,
                "actions": [{"name": "memory_ls", "parameters": {"properties": {}}}],
                "config": {"path": "/home"},
            },
            {
                "_id": tool_id_2,
                "name": "memory",
                "user": "user_123",
                "status": True,
                "actions": [{"name": "memory_ls", "parameters": {"properties": {}}}],
                "config": {"path": "/work"},
            }
        ])

        request_data = {"question": "test"}
        processor = StreamProcessor(request_data, {"sub": "user_123"})

        # Mock the filtering to require only the second tool by ID
        processor._required_tool_actions = {
            str(tool_id_2): {"memory_ls"}  # Only reference the second one
        }

        with patch(
            "application.agents.tools.tool_manager.ToolManager"
        ) as mock_manager_class:
            mock_manager = MagicMock()
            mock_manager_class.return_value = mock_manager

            # Mock the tool instance
            mock_tool = MagicMock()
            mock_manager.load_tool.return_value = mock_tool

            mock_tool.get_actions_metadata.return_value = [
                {"name": "memory_ls", "parameters": {"properties": {}}}
            ]
            mock_tool.execute_action.return_value = "Work directory"

            result = processor.pre_fetch_tools()

            # Only the second tool should be fetched (referenced by ID)
            assert result is not None
            assert str(tool_id_2) in result
            # Since filtering is enabled and only tool_id_2 is referenced,
            # only tool_id_2 should be pre-fetched
            # The "memory" key will still exist because we store under both name and ID
            assert "memory" in result