mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-12-01 09:33:14 +00:00
feat: template-based prompt rendering with dynamic namespace injection (#2091)
* feat: template-based prompt rendering with dynamic namespace injection * refactor: improve template engine initialization with clearer formatting * refactor: streamline ReActAgent methods and improve content extraction logic feat: enhance error handling in NamespaceManager and TemplateEngine fix: update NewAgent component to ensure consistent form data submission test: modify tests for ReActAgent and prompt renderer to reflect method changes and improve coverage * feat: tools namespace + three-tier token budget * refactor: remove unused variable assignment in message building tests * Enhance prompt customization and tool pre-fetching functionality * ruff lint fix * refactor: cleaner error handling and reduce code clutter --------- Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
@@ -315,16 +315,12 @@ class TestCompleteStreamMethod:
|
||||
]
|
||||
)
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_params.return_value = {}
|
||||
|
||||
decoded_token = {"sub": "user123"}
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Test question",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
@@ -351,16 +347,12 @@ class TestCompleteStreamMethod:
|
||||
]
|
||||
)
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_params.return_value = {}
|
||||
|
||||
decoded_token = {"sub": "user123"}
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Test?",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
@@ -381,16 +373,12 @@ class TestCompleteStreamMethod:
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.side_effect = Exception("Test error")
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_params.return_value = {}
|
||||
|
||||
decoded_token = {"sub": "user123"}
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Test?",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
@@ -413,9 +401,6 @@ class TestCompleteStreamMethod:
|
||||
]
|
||||
)
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_params.return_value = {}
|
||||
|
||||
decoded_token = {"sub": "user123"}
|
||||
|
||||
with patch.object(
|
||||
@@ -427,8 +412,7 @@ class TestCompleteStreamMethod:
|
||||
resource.complete_stream(
|
||||
question="Test?",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
should_save_conversation=True,
|
||||
@@ -461,7 +445,6 @@ class TestCompleteStreamMethod:
|
||||
resource.complete_stream(
|
||||
question="Test question?",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
user_api_key="test_key",
|
||||
decoded_token=decoded_token,
|
||||
|
||||
850
tests/api/answer/services/test_prompt_renderer.py
Normal file
850
tests/api/answer/services/test_prompt_renderer.py
Normal file
@@ -0,0 +1,850 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestTemplateEngine:
|
||||
|
||||
def test_render_simple_template(self):
|
||||
from application.templates.template_engine import TemplateEngine
|
||||
|
||||
engine = TemplateEngine()
|
||||
result = engine.render("Hello {{ name }}", {"name": "World"})
|
||||
|
||||
assert result == "Hello World"
|
||||
|
||||
def test_render_with_namespace(self):
|
||||
from application.templates.template_engine import TemplateEngine
|
||||
|
||||
engine = TemplateEngine()
|
||||
context = {
|
||||
"user": {"name": "Alice", "role": "admin"},
|
||||
"system": {"date": "2025-10-22"},
|
||||
}
|
||||
result = engine.render(
|
||||
"{{ user.name }} is a {{ user.role }} on {{ system.date }}", context
|
||||
)
|
||||
|
||||
assert result == "Alice is a admin on 2025-10-22"
|
||||
|
||||
def test_render_empty_template(self):
|
||||
from application.templates.template_engine import TemplateEngine
|
||||
|
||||
engine = TemplateEngine()
|
||||
result = engine.render("", {"key": "value"})
|
||||
|
||||
assert result == ""
|
||||
|
||||
def test_render_template_without_variables(self):
|
||||
from application.templates.template_engine import TemplateEngine
|
||||
|
||||
engine = TemplateEngine()
|
||||
result = engine.render("Just plain text", {})
|
||||
|
||||
assert result == "Just plain text"
|
||||
|
||||
def test_render_undefined_variable_returns_empty_string(self):
|
||||
from application.templates.template_engine import TemplateEngine
|
||||
|
||||
engine = TemplateEngine()
|
||||
|
||||
result = engine.render("Hello {{ undefined_var }}", {})
|
||||
assert result == "Hello "
|
||||
|
||||
def test_render_syntax_error_raises_error(self):
|
||||
from application.templates.template_engine import (
|
||||
TemplateEngine,
|
||||
TemplateRenderError,
|
||||
)
|
||||
|
||||
engine = TemplateEngine()
|
||||
|
||||
with pytest.raises(TemplateRenderError, match="Template syntax error"):
|
||||
engine.render("Hello {{ name", {"name": "World"})
|
||||
|
||||
def test_validate_template_valid(self):
|
||||
from application.templates.template_engine import TemplateEngine
|
||||
|
||||
engine = TemplateEngine()
|
||||
assert engine.validate_template("Valid {{ variable }}") is True
|
||||
|
||||
def test_validate_template_invalid(self):
|
||||
from application.templates.template_engine import TemplateEngine
|
||||
|
||||
engine = TemplateEngine()
|
||||
assert engine.validate_template("Invalid {{ variable") is False
|
||||
|
||||
def test_validate_empty_template(self):
|
||||
from application.templates.template_engine import TemplateEngine
|
||||
|
||||
engine = TemplateEngine()
|
||||
assert engine.validate_template("") is True
|
||||
|
||||
def test_extract_variables(self):
|
||||
from application.templates.template_engine import TemplateEngine
|
||||
|
||||
engine = TemplateEngine()
|
||||
template = "{{ user.name }} and {{ user.email }}"
|
||||
|
||||
result = engine.extract_variables(template)
|
||||
|
||||
assert isinstance(result, set)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSystemNamespace:
|
||||
|
||||
def test_system_namespace_build(self):
|
||||
from application.templates.namespaces import SystemNamespace
|
||||
|
||||
builder = SystemNamespace()
|
||||
context = builder.build(
|
||||
request_id="req_123", user_id="user_456", extra_param="ignored"
|
||||
)
|
||||
|
||||
assert context["request_id"] == "req_123"
|
||||
assert context["user_id"] == "user_456"
|
||||
assert "date" in context
|
||||
assert "time" in context
|
||||
assert "timestamp" in context
|
||||
|
||||
def test_system_namespace_generates_request_id(self):
|
||||
from application.templates.namespaces import SystemNamespace
|
||||
|
||||
builder = SystemNamespace()
|
||||
context = builder.build(user_id="user_123")
|
||||
|
||||
assert context["request_id"] is not None
|
||||
assert len(context["request_id"]) > 0
|
||||
|
||||
def test_system_namespace_name(self):
|
||||
from application.templates.namespaces import SystemNamespace
|
||||
|
||||
builder = SystemNamespace()
|
||||
assert builder.namespace_name == "system"
|
||||
|
||||
def test_system_namespace_date_format(self):
|
||||
from application.templates.namespaces import SystemNamespace
|
||||
|
||||
builder = SystemNamespace()
|
||||
context = builder.build()
|
||||
|
||||
import re
|
||||
|
||||
assert re.match(r"\d{4}-\d{2}-\d{2}", context["date"])
|
||||
assert re.match(r"\d{2}:\d{2}:\d{2}", context["time"])
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPassthroughNamespace:
|
||||
|
||||
def test_passthrough_namespace_build(self):
|
||||
from application.templates.namespaces import PassthroughNamespace
|
||||
|
||||
builder = PassthroughNamespace()
|
||||
passthrough_data = {"company": "Acme", "user_name": "John", "count": 42}
|
||||
|
||||
context = builder.build(passthrough_data=passthrough_data)
|
||||
|
||||
assert context["company"] == "Acme"
|
||||
assert context["user_name"] == "John"
|
||||
assert context["count"] == 42
|
||||
|
||||
def test_passthrough_namespace_empty(self):
|
||||
from application.templates.namespaces import PassthroughNamespace
|
||||
|
||||
builder = PassthroughNamespace()
|
||||
context = builder.build(passthrough_data=None)
|
||||
|
||||
assert context == {}
|
||||
|
||||
def test_passthrough_namespace_filters_unsafe_values(self):
|
||||
from application.templates.namespaces import PassthroughNamespace
|
||||
|
||||
builder = PassthroughNamespace()
|
||||
passthrough_data = {
|
||||
"safe_string": "value",
|
||||
"unsafe_object": {"key": "value"},
|
||||
"safe_bool": True,
|
||||
"unsafe_list": [1, 2, 3],
|
||||
"safe_float": 3.14,
|
||||
}
|
||||
|
||||
context = builder.build(passthrough_data=passthrough_data)
|
||||
|
||||
assert context["safe_string"] == "value"
|
||||
assert context["safe_bool"] is True
|
||||
assert context["safe_float"] == 3.14
|
||||
assert "unsafe_object" not in context
|
||||
assert "unsafe_list" not in context
|
||||
|
||||
def test_passthrough_namespace_allows_none_values(self):
|
||||
from application.templates.namespaces import PassthroughNamespace
|
||||
|
||||
builder = PassthroughNamespace()
|
||||
passthrough_data = {"nullable_field": None}
|
||||
|
||||
context = builder.build(passthrough_data=passthrough_data)
|
||||
|
||||
assert context["nullable_field"] is None
|
||||
|
||||
def test_passthrough_namespace_name(self):
|
||||
from application.templates.namespaces import PassthroughNamespace
|
||||
|
||||
builder = PassthroughNamespace()
|
||||
assert builder.namespace_name == "passthrough"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSourceNamespace:
|
||||
|
||||
def test_source_namespace_build_with_docs(self):
|
||||
from application.templates.namespaces import SourceNamespace
|
||||
|
||||
builder = SourceNamespace()
|
||||
docs = [
|
||||
{"text": "Doc 1", "filename": "file1.txt"},
|
||||
{"text": "Doc 2", "filename": "file2.txt"},
|
||||
]
|
||||
docs_together = "Doc 1 content\n\nDoc 2 content"
|
||||
|
||||
context = builder.build(docs=docs, docs_together=docs_together)
|
||||
|
||||
assert context["documents"] == docs
|
||||
assert context["count"] == 2
|
||||
assert context["content"] == docs_together
|
||||
assert context["summaries"] == docs_together
|
||||
|
||||
def test_source_namespace_build_empty(self):
|
||||
from application.templates.namespaces import SourceNamespace
|
||||
|
||||
builder = SourceNamespace()
|
||||
context = builder.build(docs=None, docs_together=None)
|
||||
|
||||
assert context == {}
|
||||
|
||||
def test_source_namespace_build_docs_only(self):
|
||||
from application.templates.namespaces import SourceNamespace
|
||||
|
||||
builder = SourceNamespace()
|
||||
docs = [{"text": "Doc 1"}]
|
||||
|
||||
context = builder.build(docs=docs)
|
||||
|
||||
assert context["documents"] == docs
|
||||
assert context["count"] == 1
|
||||
assert "content" not in context
|
||||
|
||||
def test_source_namespace_build_docs_together_only(self):
|
||||
from application.templates.namespaces import SourceNamespace
|
||||
|
||||
builder = SourceNamespace()
|
||||
docs_together = "Content here"
|
||||
|
||||
context = builder.build(docs_together=docs_together)
|
||||
|
||||
assert context["content"] == docs_together
|
||||
assert context["summaries"] == docs_together
|
||||
assert "documents" not in context
|
||||
|
||||
def test_source_namespace_name(self):
|
||||
from application.templates.namespaces import SourceNamespace
|
||||
|
||||
builder = SourceNamespace()
|
||||
assert builder.namespace_name == "source"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestToolsNamespace:
|
||||
|
||||
def test_tools_namespace_build_with_memory_data(self):
|
||||
from application.templates.namespaces import ToolsNamespace
|
||||
|
||||
builder = ToolsNamespace()
|
||||
tools_data = {
|
||||
"memory": {"root": "Files:\n- /notes.txt\n- /tasks.txt", "available": True}
|
||||
}
|
||||
|
||||
context = builder.build(tools_data=tools_data)
|
||||
|
||||
assert context["memory"]["root"] == "Files:\n- /notes.txt\n- /tasks.txt"
|
||||
assert context["memory"]["available"] is True
|
||||
|
||||
def test_tools_namespace_build_empty(self):
|
||||
from application.templates.namespaces import ToolsNamespace
|
||||
|
||||
builder = ToolsNamespace()
|
||||
context = builder.build(tools_data=None)
|
||||
|
||||
assert context == {}
|
||||
|
||||
def test_tools_namespace_build_multiple_tools(self):
|
||||
from application.templates.namespaces import ToolsNamespace
|
||||
|
||||
builder = ToolsNamespace()
|
||||
tools_data = {
|
||||
"memory": {"root": "content", "available": True},
|
||||
"search": {"results": ["result1", "result2"]},
|
||||
"api": {"status": "success"},
|
||||
}
|
||||
|
||||
context = builder.build(tools_data=tools_data)
|
||||
|
||||
assert "memory" in context
|
||||
assert "search" in context
|
||||
assert "api" in context
|
||||
assert context["memory"]["root"] == "content"
|
||||
assert context["search"]["results"] == ["result1", "result2"]
|
||||
assert context["api"]["status"] == "success"
|
||||
|
||||
def test_tools_namespace_filters_unsafe_values(self):
|
||||
from application.templates.namespaces import ToolsNamespace
|
||||
|
||||
builder = ToolsNamespace()
|
||||
|
||||
class UnsafeObject:
|
||||
pass
|
||||
|
||||
tools_data = {"safe_tool": {"result": "success"}, "unsafe_tool": UnsafeObject()}
|
||||
|
||||
context = builder.build(tools_data=tools_data)
|
||||
|
||||
assert "safe_tool" in context
|
||||
assert "unsafe_tool" not in context
|
||||
|
||||
def test_tools_namespace_name(self):
|
||||
from application.templates.namespaces import ToolsNamespace
|
||||
|
||||
builder = ToolsNamespace()
|
||||
assert builder.namespace_name == "tools"
|
||||
|
||||
def test_tools_namespace_with_empty_dict(self):
|
||||
from application.templates.namespaces import ToolsNamespace
|
||||
|
||||
builder = ToolsNamespace()
|
||||
context = builder.build(tools_data={})
|
||||
|
||||
assert context == {}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestNamespaceManagerWithTools:
|
||||
|
||||
def test_namespace_manager_includes_tools_in_context(self):
|
||||
from application.templates.namespaces import NamespaceManager
|
||||
|
||||
manager = NamespaceManager()
|
||||
tools_data = {"memory": {"root": "content", "available": True}}
|
||||
|
||||
context = manager.build_context(tools_data=tools_data)
|
||||
|
||||
assert "tools" in context
|
||||
assert context["tools"]["memory"]["root"] == "content"
|
||||
|
||||
def test_namespace_manager_build_context_all_namespaces(self):
|
||||
from application.templates.namespaces import NamespaceManager
|
||||
|
||||
manager = NamespaceManager()
|
||||
context = manager.build_context(
|
||||
request_id="req_123",
|
||||
user_id="user_456",
|
||||
passthrough_data={"key": "value"},
|
||||
docs_together="Document content",
|
||||
tools_data={"memory": {"root": "notes"}},
|
||||
)
|
||||
|
||||
assert "system" in context
|
||||
assert "passthrough" in context
|
||||
assert "source" in context
|
||||
assert "tools" in context
|
||||
assert context["tools"]["memory"]["root"] == "notes"
|
||||
|
||||
def test_namespace_manager_build_context_partial_data(self):
|
||||
from application.templates.namespaces import NamespaceManager
|
||||
|
||||
manager = NamespaceManager()
|
||||
context = manager.build_context(request_id="req_123")
|
||||
|
||||
assert "system" in context
|
||||
assert context["system"]["request_id"] == "req_123"
|
||||
|
||||
def test_namespace_manager_get_builder(self):
|
||||
from application.templates.namespaces import NamespaceManager, SystemNamespace
|
||||
|
||||
manager = NamespaceManager()
|
||||
builder = manager.get_builder("system")
|
||||
|
||||
assert isinstance(builder, SystemNamespace)
|
||||
|
||||
def test_namespace_manager_get_builder_nonexistent(self):
|
||||
from application.templates.namespaces import NamespaceManager
|
||||
|
||||
manager = NamespaceManager()
|
||||
builder = manager.get_builder("nonexistent")
|
||||
|
||||
assert builder is None
|
||||
|
||||
def test_namespace_manager_handles_builder_exceptions(self):
|
||||
from unittest.mock import patch
|
||||
|
||||
from application.templates.namespaces import NamespaceManager
|
||||
|
||||
manager = NamespaceManager()
|
||||
|
||||
with patch.object(
|
||||
manager._builders["system"],
|
||||
"build",
|
||||
side_effect=Exception("Builder error"),
|
||||
):
|
||||
context = manager.build_context()
|
||||
# Namespace should be present but empty when builder fails
|
||||
|
||||
assert "system" in context
|
||||
assert context["system"] == {}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPromptRenderer:
|
||||
|
||||
def test_render_prompt_with_template_syntax(self):
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
|
||||
renderer = PromptRenderer()
|
||||
prompt = "Hello {{ system.user_id }}, today is {{ system.date }}"
|
||||
|
||||
result = renderer.render_prompt(prompt, user_id="user_123")
|
||||
|
||||
assert "user_123" in result
|
||||
assert "202" in result
|
||||
|
||||
def test_render_prompt_with_passthrough_data(self):
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
|
||||
renderer = PromptRenderer()
|
||||
prompt = "Company: {{ passthrough.company }}\nUser: {{ passthrough.user_name }}"
|
||||
passthrough_data = {"company": "Acme", "user_name": "John"}
|
||||
|
||||
result = renderer.render_prompt(prompt, passthrough_data=passthrough_data)
|
||||
|
||||
assert "Company: Acme" in result
|
||||
assert "User: John" in result
|
||||
|
||||
def test_render_prompt_with_source_docs(self):
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
|
||||
renderer = PromptRenderer()
|
||||
prompt = "Use this information:\n{{ source.content }}"
|
||||
docs_together = "Important document content"
|
||||
|
||||
result = renderer.render_prompt(prompt, docs_together=docs_together)
|
||||
|
||||
assert "Use this information:" in result
|
||||
assert "Important document content" in result
|
||||
|
||||
def test_render_prompt_empty_content(self):
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
|
||||
renderer = PromptRenderer()
|
||||
result = renderer.render_prompt("")
|
||||
|
||||
assert result == ""
|
||||
|
||||
def test_render_prompt_legacy_format_with_summaries(self):
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
|
||||
renderer = PromptRenderer()
|
||||
prompt = "Context: {summaries}\nQuestion: What is this?"
|
||||
docs_together = "This is the document content"
|
||||
|
||||
result = renderer.render_prompt(prompt, docs_together=docs_together)
|
||||
|
||||
assert "Context: This is the document content" in result
|
||||
|
||||
def test_render_prompt_legacy_format_without_docs(self):
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
|
||||
renderer = PromptRenderer()
|
||||
prompt = "Context: {summaries}\nQuestion: What is this?"
|
||||
|
||||
result = renderer.render_prompt(prompt)
|
||||
|
||||
assert "Context: {summaries}" in result
|
||||
|
||||
def test_render_prompt_combined_namespace_variables(self):
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
|
||||
renderer = PromptRenderer()
|
||||
prompt = "User: {{ passthrough.user }}, Date: {{ system.date }}, Docs: {{ source.content }}"
|
||||
passthrough_data = {"user": "Alice"}
|
||||
docs_together = "Doc content"
|
||||
|
||||
result = renderer.render_prompt(
|
||||
prompt,
|
||||
passthrough_data=passthrough_data,
|
||||
docs_together=docs_together,
|
||||
)
|
||||
|
||||
assert "User: Alice" in result
|
||||
assert "Date: 202" in result
|
||||
assert "Doc content" in result
|
||||
|
||||
def test_render_prompt_with_tools_data(self):
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
|
||||
renderer = PromptRenderer()
|
||||
prompt = "Memory contents:\n{{ tools.memory.root }}\n\nStatus: {{ tools.memory.available }}"
|
||||
tools_data = {
|
||||
"memory": {"root": "Files:\n- /notes.txt\n- /tasks.txt", "available": True}
|
||||
}
|
||||
|
||||
result = renderer.render_prompt(prompt, tools_data=tools_data)
|
||||
|
||||
assert "Memory contents:" in result
|
||||
assert "Files:" in result
|
||||
assert "/notes.txt" in result
|
||||
assert "/tasks.txt" in result
|
||||
assert "Status: True" in result
|
||||
|
||||
def test_render_prompt_with_all_namespaces(self):
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
|
||||
renderer = PromptRenderer()
|
||||
prompt = """
|
||||
System: {{ system.date }}
|
||||
User: {{ passthrough.user }}
|
||||
Docs: {{ source.content }}
|
||||
Memory: {{ tools.memory.root }}
|
||||
"""
|
||||
passthrough_data = {"user": "Alice"}
|
||||
docs_together = "Important docs"
|
||||
tools_data = {"memory": {"root": "Notes content", "available": True}}
|
||||
|
||||
result = renderer.render_prompt(
|
||||
prompt,
|
||||
passthrough_data=passthrough_data,
|
||||
docs_together=docs_together,
|
||||
tools_data=tools_data,
|
||||
)
|
||||
|
||||
assert "202" in result
|
||||
assert "Alice" in result
|
||||
assert "Important docs" in result
|
||||
assert "Notes content" in result
|
||||
|
||||
def test_render_prompt_undefined_variable_returns_empty_string(self):
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
|
||||
renderer = PromptRenderer()
|
||||
prompt = "Hello {{ undefined_var }}"
|
||||
|
||||
result = renderer.render_prompt(prompt)
|
||||
assert result == "Hello "
|
||||
|
||||
def test_render_prompt_with_undefined_variable_in_template(self):
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
|
||||
renderer = PromptRenderer()
|
||||
prompt = "Hello {{ undefined_name }}"
|
||||
|
||||
result = renderer.render_prompt(prompt)
|
||||
assert result == "Hello "
|
||||
|
||||
def test_validate_template_valid(self):
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
|
||||
renderer = PromptRenderer()
|
||||
assert renderer.validate_template("Valid {{ variable }}") is True
|
||||
|
||||
def test_validate_template_invalid(self):
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
|
||||
renderer = PromptRenderer()
|
||||
assert renderer.validate_template("Invalid {{ variable") is False
|
||||
|
||||
def test_extract_variables(self):
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
|
||||
renderer = PromptRenderer()
|
||||
template = "{{ var1 }} and {{ var2 }}"
|
||||
|
||||
result = renderer.extract_variables(template)
|
||||
|
||||
assert isinstance(result, set)
|
||||
|
||||
def test_uses_template_syntax_detection(self):
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
|
||||
renderer = PromptRenderer()
|
||||
|
||||
assert renderer._uses_template_syntax("Text with {{ var }}") is True
|
||||
assert renderer._uses_template_syntax("Text with {var}") is False
|
||||
assert renderer._uses_template_syntax("Plain text") is False
|
||||
|
||||
def test_apply_legacy_substitutions(self):
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
|
||||
renderer = PromptRenderer()
|
||||
prompt = "Use {summaries} to answer"
|
||||
docs_together = "Important info"
|
||||
|
||||
result = renderer._apply_legacy_substitutions(prompt, docs_together)
|
||||
|
||||
assert "Use Important info to answer" in result
|
||||
|
||||
def test_apply_legacy_substitutions_without_docs(self):
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
|
||||
renderer = PromptRenderer()
|
||||
prompt = "Use {summaries} to answer"
|
||||
|
||||
result = renderer._apply_legacy_substitutions(prompt, None)
|
||||
|
||||
assert result == prompt
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPromptRendererIntegration:
|
||||
|
||||
def test_render_prompt_real_world_scenario(self):
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
|
||||
renderer = PromptRenderer()
|
||||
prompt = "You are helping {{ passthrough.company }}.\n\nUser: {{ passthrough.user_name }}\n\nRequest ID: {{ system.request_id }}\n\nDate: {{ system.date }}\n\nReference Documents:\n\n{{ source.content }}\n\nPlease answer the question professionally."
|
||||
|
||||
passthrough_data = {"company": "Tech Corp", "user_name": "Alice"}
|
||||
docs_together = "Document 1: Technical specs\nDocument 2: Requirements"
|
||||
|
||||
result = renderer.render_prompt(
|
||||
prompt,
|
||||
request_id="req_123",
|
||||
user_id="user_456",
|
||||
passthrough_data=passthrough_data,
|
||||
docs_together=docs_together,
|
||||
)
|
||||
|
||||
assert "Tech Corp" in result
|
||||
assert "Alice" in result
|
||||
assert "req_123" in result
|
||||
assert "Technical specs" in result
|
||||
assert "professionally" in result
|
||||
|
||||
def test_render_prompt_multiple_doc_references(self):
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
|
||||
renderer = PromptRenderer()
|
||||
prompt = """Documents: {{ source.content }} \n\nAlso summaries: {{ source.summaries }}"""
|
||||
docs_together = "Content here"
|
||||
|
||||
result = renderer.render_prompt(prompt, docs_together=docs_together)
|
||||
|
||||
assert result.count("Content here") == 2
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamProcessorPromptRendering:
|
||||
|
||||
def test_stream_processor_pre_fetch_docs_none_doc_mode(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
request_data = {"question": "Test question", "isNoneDoc": True}
|
||||
processor = StreamProcessor(request_data, None)
|
||||
|
||||
docs_together, docs_list = processor.pre_fetch_docs("Test question")
|
||||
|
||||
assert docs_together is None
|
||||
assert docs_list is None
|
||||
|
||||
def test_pre_fetch_tools_disabled_globally(self, mock_mongo_db, monkeypatch):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
monkeypatch.setattr(settings, "ENABLE_TOOL_PREFETCH", False)
|
||||
|
||||
request_data = {"question": "test"}
|
||||
processor = StreamProcessor(request_data, {"sub": "user1"})
|
||||
|
||||
result = processor.pre_fetch_tools()
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_pre_fetch_tools_disabled_per_request(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
request_data = {"question": "test", "disable_tool_prefetch": True}
|
||||
processor = StreamProcessor(request_data, {"sub": "user1"})
|
||||
|
||||
result = processor.pre_fetch_tools()
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_pre_fetch_tools_skips_tool_with_no_actions(self, mock_mongo_db):
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.mongo_db import MongoDB
|
||||
from bson import ObjectId
|
||||
|
||||
db = MongoDB.get_client()[list(MongoDB.get_client().keys())[0]]
|
||||
tool_doc = {
|
||||
"_id": ObjectId(),
|
||||
"name": "memory",
|
||||
"user": "user1",
|
||||
"status": True,
|
||||
"config": {},
|
||||
}
|
||||
db["user_tools"].insert_one(tool_doc)
|
||||
|
||||
request_data = {"question": "test"}
|
||||
processor = StreamProcessor(request_data, {"sub": "user1"})
|
||||
|
||||
with patch(
|
||||
"application.agents.tools.tool_manager.ToolManager"
|
||||
) as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager_class.return_value = mock_manager
|
||||
|
||||
# Mock the tool instance
|
||||
mock_tool = MagicMock()
|
||||
mock_manager.load_tool.return_value = mock_tool
|
||||
|
||||
# Tool has no actions
|
||||
mock_tool.get_actions_metadata.return_value = []
|
||||
|
||||
result = processor.pre_fetch_tools()
|
||||
|
||||
# Should return None when tool has no actions
|
||||
assert result is None
|
||||
|
||||
def test_pre_fetch_tools_enabled_by_default(self, mock_mongo_db, monkeypatch):
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.mongo_db import MongoDB
|
||||
from bson import ObjectId
|
||||
|
||||
db = MongoDB.get_client()[list(MongoDB.get_client().keys())[0]]
|
||||
tool_doc = {
|
||||
"_id": ObjectId(),
|
||||
"name": "memory",
|
||||
"user": "user1",
|
||||
"status": True,
|
||||
"config": {},
|
||||
}
|
||||
db["user_tools"].insert_one(tool_doc)
|
||||
|
||||
request_data = {"question": "test"}
|
||||
processor = StreamProcessor(request_data, {"sub": "user1"})
|
||||
|
||||
with patch(
|
||||
"application.agents.tools.tool_manager.ToolManager"
|
||||
) as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager_class.return_value = mock_manager
|
||||
|
||||
# Mock the tool instance returned by load_tool
|
||||
mock_tool = MagicMock()
|
||||
mock_manager.load_tool.return_value = mock_tool
|
||||
|
||||
# Mock get_actions_metadata on the tool instance
|
||||
mock_tool.get_actions_metadata.return_value = [
|
||||
{"name": "memory_ls", "description": "List files", "parameters": {"properties": {}}}
|
||||
]
|
||||
mock_tool.execute_action.return_value = "Directory: /\n- file.txt"
|
||||
|
||||
result = processor.pre_fetch_tools()
|
||||
|
||||
assert result is not None
|
||||
assert "memory" in result
|
||||
assert "memory_ls" in result["memory"]
|
||||
|
||||
def test_pre_fetch_tools_no_tools_configured(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
request_data = {"question": "test"}
|
||||
processor = StreamProcessor(request_data, {"sub": "user1"})
|
||||
|
||||
result = processor.pre_fetch_tools()
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_pre_fetch_tools_memory_returns_error(self, mock_mongo_db):
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.mongo_db import MongoDB
|
||||
from bson import ObjectId
|
||||
|
||||
db = MongoDB.get_client()[list(MongoDB.get_client().keys())[0]]
|
||||
tool_doc = {
|
||||
"_id": ObjectId(),
|
||||
"name": "memory",
|
||||
"user": "user1",
|
||||
"status": True,
|
||||
"config": {},
|
||||
}
|
||||
db["user_tools"].insert_one(tool_doc)
|
||||
|
||||
request_data = {"question": "test"}
|
||||
processor = StreamProcessor(request_data, {"sub": "user1"})
|
||||
|
||||
with patch(
|
||||
"application.agents.tools.tool_manager.ToolManager"
|
||||
) as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager_class.return_value = mock_manager
|
||||
|
||||
# Mock the tool instance
|
||||
mock_tool = MagicMock()
|
||||
mock_manager.load_tool.return_value = mock_tool
|
||||
|
||||
mock_tool.get_actions_metadata.return_value = [
|
||||
{"name": "memory_ls", "description": "List files", "parameters": {"properties": {}}}
|
||||
]
|
||||
# Simulate execution error
|
||||
mock_tool.execute_action.side_effect = Exception("Tool error")
|
||||
|
||||
result = processor.pre_fetch_tools()
|
||||
|
||||
# Should return None when all actions fail
|
||||
assert result is None
|
||||
|
||||
def test_pre_fetch_tools_memory_returns_empty(self, mock_mongo_db):
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.mongo_db import MongoDB
|
||||
from bson import ObjectId
|
||||
|
||||
db = MongoDB.get_client()[list(MongoDB.get_client().keys())[0]]
|
||||
tool_doc = {
|
||||
"_id": ObjectId(),
|
||||
"name": "memory",
|
||||
"user": "user1",
|
||||
"status": True,
|
||||
"config": {},
|
||||
}
|
||||
db["user_tools"].insert_one(tool_doc)
|
||||
|
||||
request_data = {"question": "test"}
|
||||
processor = StreamProcessor(request_data, {"sub": "user1"})
|
||||
|
||||
with patch(
|
||||
"application.agents.tools.tool_manager.ToolManager"
|
||||
) as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager_class.return_value = mock_manager
|
||||
|
||||
# Mock the tool instance
|
||||
mock_tool = MagicMock()
|
||||
mock_manager.load_tool.return_value = mock_tool
|
||||
|
||||
mock_tool.get_actions_metadata.return_value = [
|
||||
{"name": "memory_ls", "description": "List files", "parameters": {"properties": {}}}
|
||||
]
|
||||
# Return empty string
|
||||
mock_tool.execute_action.return_value = ""
|
||||
|
||||
result = processor.pre_fetch_tools()
|
||||
|
||||
# Empty result should still be included
|
||||
assert result is not None
|
||||
assert "memory" in result
|
||||
@@ -250,3 +250,330 @@ class TestStreamProcessorAttachments:
|
||||
"attachments" not in processor.data
|
||||
or processor.data.get("attachments") is None
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestToolPreFetch:
|
||||
"""Tests for tool pre-fetching with saved parameter values from MongoDB"""
|
||||
|
||||
def test_cryptoprice_prefetch_with_saved_parameters(self, mock_mongo_db):
|
||||
"""Test that cryptoprice tool is pre-fetched with saved parameter values from MongoDB structure"""
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
# Setup MongoDB with cryptoprice tool configuration
|
||||
# NOTE: The collection is called "user_tools" not "tools"
|
||||
tools_collection = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
|
||||
tool_id = ObjectId()
|
||||
|
||||
tools_collection.insert_one(
|
||||
{
|
||||
"_id": tool_id,
|
||||
"name": "cryptoprice",
|
||||
"user": "user_123",
|
||||
"status": True, # Must be True for tool to be included
|
||||
"actions": [
|
||||
{
|
||||
"name": "cryptoprice_get",
|
||||
"description": "Get cryptocurrency price",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"symbol": {
|
||||
"type": "string",
|
||||
"description": "Crypto symbol",
|
||||
"value": "BTC" # Saved value in MongoDB
|
||||
},
|
||||
"currency": {
|
||||
"type": "string",
|
||||
"description": "Currency for price",
|
||||
"value": "USD" # Saved value in MongoDB
|
||||
}
|
||||
},
|
||||
"required": ["symbol", "currency"]
|
||||
}
|
||||
}
|
||||
],
|
||||
"config": {
|
||||
"token": ""
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"question": "What is the price of Bitcoin?",
|
||||
"tools": [str(tool_id)]
|
||||
}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
processor._required_tool_actions = {"cryptoprice": {"cryptoprice_get"}}
|
||||
|
||||
# Mock the ToolManager and tool instance
|
||||
with patch("application.agents.tools.tool_manager.ToolManager") as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager_class.return_value = mock_manager
|
||||
|
||||
# Mock the tool instance returned by load_tool
|
||||
mock_tool = MagicMock()
|
||||
mock_manager.load_tool.return_value = mock_tool
|
||||
|
||||
# Mock get_actions_metadata on the tool instance
|
||||
mock_tool.get_actions_metadata.return_value = [
|
||||
{
|
||||
"name": "cryptoprice_get",
|
||||
"description": "Get cryptocurrency price",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"symbol": {"type": "string", "description": "Crypto symbol"},
|
||||
"currency": {"type": "string", "description": "Currency for price"}
|
||||
},
|
||||
"required": ["symbol", "currency"]
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
# Mock execute_action on the tool instance to return price data
|
||||
mock_tool.execute_action.return_value = {
|
||||
"status_code": 200,
|
||||
"price": 45000.50,
|
||||
"message": "Price of BTC in USD retrieved successfully."
|
||||
}
|
||||
|
||||
# Execute pre-fetch
|
||||
tools_data = processor.pre_fetch_tools()
|
||||
|
||||
# Verify the tool was called
|
||||
assert mock_tool.execute_action.called
|
||||
|
||||
# Verify it was called with the saved parameters from MongoDB
|
||||
call_args = mock_tool.execute_action.call_args
|
||||
assert call_args is not None
|
||||
|
||||
# Check action name uses the full metadata name for execution
|
||||
assert call_args[0][0] == "cryptoprice_get"
|
||||
|
||||
# Check kwargs contain saved values
|
||||
kwargs = call_args[1]
|
||||
assert kwargs.get("symbol") == "BTC"
|
||||
assert kwargs.get("currency") == "USD"
|
||||
|
||||
# Verify tools_data structure
|
||||
assert "cryptoprice" in tools_data
|
||||
# Results are exposed under the full action name
|
||||
assert "cryptoprice_get" in tools_data["cryptoprice"]
|
||||
assert tools_data["cryptoprice"]["cryptoprice_get"]["price"] == 45000.50
|
||||
|
||||
def test_prefetch_with_missing_saved_values_uses_defaults(self, mock_mongo_db):
|
||||
"""Test that pre-fetch falls back to defaults when saved values are missing"""
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
tools_collection = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
|
||||
tool_id = ObjectId()
|
||||
|
||||
# Tool configuration without saved values
|
||||
tools_collection.insert_one(
|
||||
{
|
||||
"_id": tool_id,
|
||||
"name": "cryptoprice",
|
||||
"user": "user_123",
|
||||
"status": True,
|
||||
"actions": [
|
||||
{
|
||||
"name": "cryptoprice_get",
|
||||
"description": "Get cryptocurrency price",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"symbol": {
|
||||
"type": "string",
|
||||
"description": "Crypto symbol",
|
||||
"default": "ETH" # Only default, no saved value
|
||||
},
|
||||
"currency": {
|
||||
"type": "string",
|
||||
"description": "Currency",
|
||||
"default": "EUR"
|
||||
}
|
||||
},
|
||||
"required": ["symbol", "currency"]
|
||||
}
|
||||
}
|
||||
],
|
||||
"config": {}
|
||||
}
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"question": "Crypto price?",
|
||||
"tools": [str(tool_id)]
|
||||
}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
processor._required_tool_actions = {"cryptoprice": {"cryptoprice_get"}}
|
||||
|
||||
with patch("application.agents.tools.tool_manager.ToolManager") as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager_class.return_value = mock_manager
|
||||
|
||||
# Mock the tool instance
|
||||
mock_tool = MagicMock()
|
||||
mock_manager.load_tool.return_value = mock_tool
|
||||
|
||||
mock_tool.get_actions_metadata.return_value = [
|
||||
{
|
||||
"name": "cryptoprice_get",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"symbol": {"type": "string", "default": "ETH"},
|
||||
"currency": {"type": "string", "default": "EUR"}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
mock_tool.execute_action.return_value = {
|
||||
"status_code": 200,
|
||||
"price": 2500.00
|
||||
}
|
||||
|
||||
processor.pre_fetch_tools()
|
||||
|
||||
# Should use default values when saved values are missing
|
||||
call_args = mock_tool.execute_action.call_args
|
||||
if call_args:
|
||||
kwargs = call_args[1]
|
||||
# Either uses defaults or skips if no values available
|
||||
assert kwargs.get("symbol") in ["ETH", None]
|
||||
assert kwargs.get("currency") in ["EUR", None]
|
||||
|
||||
def test_prefetch_with_tool_id_reference(self, mock_mongo_db):
|
||||
"""Test that tools can be referenced by MongoDB ObjectId in templates"""
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
tools_collection = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
|
||||
tool_id = ObjectId()
|
||||
|
||||
# Create a tool in the database
|
||||
tools_collection.insert_one(
|
||||
{
|
||||
"_id": tool_id,
|
||||
"name": "memory",
|
||||
"user": "user_123",
|
||||
"status": True,
|
||||
"actions": [
|
||||
{
|
||||
"name": "memory_ls",
|
||||
"description": "List files",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
],
|
||||
"config": {},
|
||||
}
|
||||
)
|
||||
|
||||
request_data = {"question": "test"}
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
|
||||
# Mock the filtering to require this specific tool by ID
|
||||
processor._required_tool_actions = {
|
||||
str(tool_id): {"memory_ls"} # Reference by ObjectId string
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.agents.tools.tool_manager.ToolManager"
|
||||
) as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager_class.return_value = mock_manager
|
||||
|
||||
# Mock the tool instance
|
||||
mock_tool = MagicMock()
|
||||
mock_manager.load_tool.return_value = mock_tool
|
||||
|
||||
mock_tool.get_actions_metadata.return_value = [
|
||||
{"name": "memory_ls", "description": "List files", "parameters": {"properties": {}}}
|
||||
]
|
||||
mock_tool.execute_action.return_value = "Directory: /\n- file.txt"
|
||||
|
||||
result = processor.pre_fetch_tools()
|
||||
|
||||
# Tool data should be available under both name and ID
|
||||
assert result is not None
|
||||
assert "memory" in result
|
||||
assert str(tool_id) in result
|
||||
# Both should point to the same data
|
||||
assert result["memory"] == result[str(tool_id)]
|
||||
assert "memory_ls" in result[str(tool_id)]
|
||||
|
||||
def test_prefetch_with_multiple_same_name_tools(self, mock_mongo_db):
|
||||
"""Test that multiple tools with the same name can be distinguished by ID"""
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
tools_collection = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
|
||||
|
||||
# Create two memory tools with different IDs
|
||||
tool_id_1 = ObjectId()
|
||||
tool_id_2 = ObjectId()
|
||||
|
||||
tools_collection.insert_many([
|
||||
{
|
||||
"_id": tool_id_1,
|
||||
"name": "memory",
|
||||
"user": "user_123",
|
||||
"status": True,
|
||||
"actions": [{"name": "memory_ls", "parameters": {"properties": {}}}],
|
||||
"config": {"path": "/home"},
|
||||
},
|
||||
{
|
||||
"_id": tool_id_2,
|
||||
"name": "memory",
|
||||
"user": "user_123",
|
||||
"status": True,
|
||||
"actions": [{"name": "memory_ls", "parameters": {"properties": {}}}],
|
||||
"config": {"path": "/work"},
|
||||
}
|
||||
])
|
||||
|
||||
request_data = {"question": "test"}
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
|
||||
# Mock the filtering to require only the second tool by ID
|
||||
processor._required_tool_actions = {
|
||||
str(tool_id_2): {"memory_ls"} # Only reference the second one
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.agents.tools.tool_manager.ToolManager"
|
||||
) as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager_class.return_value = mock_manager
|
||||
|
||||
# Mock the tool instance
|
||||
mock_tool = MagicMock()
|
||||
mock_manager.load_tool.return_value = mock_tool
|
||||
|
||||
mock_tool.get_actions_metadata.return_value = [
|
||||
{"name": "memory_ls", "parameters": {"properties": {}}}
|
||||
]
|
||||
mock_tool.execute_action.return_value = "Work directory"
|
||||
|
||||
result = processor.pre_fetch_tools()
|
||||
|
||||
# Only the second tool should be fetched (referenced by ID)
|
||||
assert result is not None
|
||||
assert str(tool_id_2) in result
|
||||
# Since filtering is enabled and only tool_id_2 is referenced,
|
||||
# only tool_id_2 should be pre-fetched
|
||||
# The "memory" key will still exist because we store under both name and ID
|
||||
assert "memory" in result
|
||||
|
||||
Reference in New Issue
Block a user