feat: template-based prompt rendering with dynamic namespace injection (#2091)

* feat: template-based prompt rendering with dynamic namespace injection

* refactor: improve template engine initialization with clearer formatting

* refactor: streamline ReActAgent methods and improve content extraction logic

feat: enhance error handling in NamespaceManager and TemplateEngine

fix: update NewAgent component to ensure consistent form data submission

test: modify tests for ReActAgent and prompt renderer to reflect method changes and improve coverage

* feat: tools namespace + three-tier token budget

* refactor: remove unused variable assignment in message building tests

* Enhance prompt customization and tool pre-fetching functionality

* ruff lint fix

* refactor: cleaner error handling and reduce code clutter

---------

Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
Siddhant Rai
2025-10-31 18:17:44 +05:30
committed by GitHub
parent a7d61b9d59
commit 21e5c261ef
33 changed files with 2917 additions and 646 deletions

View File

@@ -64,17 +64,14 @@ class TestBaseAgentBuildMessages:
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
system_prompt = "System: {summaries}"
system_prompt = "System prompt content"
query = "What is Python?"
retrieved_data = [
{"text": "Python is a programming language", "filename": "python.txt"}
]
messages = agent._build_messages(system_prompt, query, retrieved_data)
messages = agent._build_messages(system_prompt, query)
assert len(messages) >= 2
assert messages[0]["role"] == "system"
assert "Python is a programming language" in messages[0]["content"]
assert messages[0]["content"] == system_prompt
assert messages[-1]["role"] == "user"
assert messages[-1]["content"] == query
@@ -88,11 +85,10 @@ class TestBaseAgentBuildMessages:
agent_base_params["chat_history"] = sample_chat_history
agent = ClassicAgent(**agent_base_params)
system_prompt = "System: {summaries}"
system_prompt = "System prompt"
query = "New question?"
retrieved_data = [{"text": "Data", "filename": "file.txt"}]
messages = agent._build_messages(system_prompt, query, retrieved_data)
messages = agent._build_messages(system_prompt, query)
user_messages = [m for m in messages if m["role"] == "user"]
assistant_messages = [m for m in messages if m["role"] == "assistant"]
@@ -118,9 +114,7 @@ class TestBaseAgentBuildMessages:
agent_base_params["chat_history"] = tool_call_history
agent = ClassicAgent(**agent_base_params)
messages = agent._build_messages(
"System: {summaries}", "query", [{"text": "data", "filename": "file.txt"}]
)
messages = agent._build_messages("System prompt", "query")
tool_messages = [m for m in messages if m["role"] == "tool"]
assert len(tool_messages) > 0
@@ -129,32 +123,25 @@ class TestBaseAgentBuildMessages:
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
retrieved_data = [{"text": "Document without filename or title"}]
messages = agent._build_messages("System: {summaries}", "query", retrieved_data)
messages = agent._build_messages("System prompt", "query")
assert messages[0]["role"] == "system"
assert "Document without filename" in messages[0]["content"]
assert messages[0]["content"] == "System prompt"
def test_build_messages_uses_title_as_fallback(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
retrieved_data = [{"text": "Data", "title": "Title Doc"}]
messages = agent._build_messages("System: {summaries}", "query", retrieved_data)
assert "Title Doc" in messages[0]["content"]
agent._build_messages("System prompt", "query")
def test_build_messages_uses_source_as_fallback(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
retrieved_data = [{"text": "Data", "source": "source.txt"}]
messages = agent._build_messages("System: {summaries}", "query", retrieved_data)
assert "source.txt" in messages[0]["content"]
agent._build_messages("System prompt", "query")
@pytest.mark.unit
@@ -475,40 +462,6 @@ class TestBaseAgentToolExecution:
assert truncated[0]["result"].endswith("...")
@pytest.mark.unit
class TestBaseAgentRetrieverSearch:
def test_retriever_search(
self,
agent_base_params,
mock_retriever,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
agent = ClassicAgent(**agent_base_params)
results = agent._retriever_search(mock_retriever, "test query", log_context)
assert len(results) == 2
mock_retriever.search.assert_called_once_with("test query")
def test_retriever_search_logs_context(
self,
agent_base_params,
mock_retriever,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
agent = ClassicAgent(**agent_base_params)
agent._retriever_search(mock_retriever, "test query", log_context)
assert len(log_context.stacks) == 1
assert log_context.stacks[0]["component"] == "retriever"
@pytest.mark.unit
class TestBaseAgentLLMGeneration:

View File

@@ -19,7 +19,6 @@ class TestClassicAgent:
def test_gen_inner_basic_flow(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -40,7 +39,7 @@ class TestClassicAgent:
agent = ClassicAgent(**agent_base_params)
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
results = list(agent._gen_inner("Test query", log_context))
assert len(results) >= 2
sources = [r for r in results if "sources" in r]
@@ -52,7 +51,6 @@ class TestClassicAgent:
def test_gen_inner_retrieves_documents(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -68,14 +66,11 @@ class TestClassicAgent:
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ClassicAgent(**agent_base_params)
list(agent._gen_inner("Test query", mock_retriever, log_context))
mock_retriever.search.assert_called_once_with("Test query")
list(agent._gen_inner("Test query", log_context))
def test_gen_inner_uses_user_api_key_tools(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -104,14 +99,13 @@ class TestClassicAgent:
agent_base_params["user_api_key"] = "api_key_123"
agent = ClassicAgent(**agent_base_params)
list(agent._gen_inner("Test query", mock_retriever, log_context))
list(agent._gen_inner("Test query", log_context))
assert len(agent.tools) >= 0
def test_gen_inner_uses_user_tools(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -133,14 +127,13 @@ class TestClassicAgent:
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ClassicAgent(**agent_base_params)
list(agent._gen_inner("Test query", mock_retriever, log_context))
list(agent._gen_inner("Test query", log_context))
assert len(agent.tools) >= 0
def test_gen_inner_builds_correct_messages(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -156,7 +149,7 @@ class TestClassicAgent:
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ClassicAgent(**agent_base_params)
list(agent._gen_inner("Test query", mock_retriever, log_context))
list(agent._gen_inner("Test query", log_context))
call_kwargs = mock_llm.gen_stream.call_args[1]
messages = call_kwargs["messages"]
@@ -169,7 +162,6 @@ class TestClassicAgent:
def test_gen_inner_logs_tool_calls(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -187,7 +179,7 @@ class TestClassicAgent:
agent = ClassicAgent(**agent_base_params)
agent.tool_calls = [{"tool": "test", "result": "success"}]
list(agent._gen_inner("Test query", mock_retriever, log_context))
list(agent._gen_inner("Test query", log_context))
agent_logs = [s for s in log_context.stacks if s["component"] == "agent"]
assert len(agent_logs) == 1
@@ -200,7 +192,6 @@ class TestClassicAgentIntegration:
def test_gen_method_with_logging(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -216,14 +207,13 @@ class TestClassicAgentIntegration:
agent = ClassicAgent(**agent_base_params)
results = list(agent.gen("Test query", mock_retriever))
results = list(agent.gen("Test query"))
assert len(results) >= 1
def test_gen_method_decorator_applied(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,

View File

@@ -35,7 +35,7 @@ class TestReActAgentContentExtraction:
agent = ReActAgent(**agent_base_params)
response = "Simple string response"
content = agent._extract_content_from_llm_response(response)
content = agent._extract_content(response)
assert content == "Simple string response"
@@ -48,7 +48,7 @@ class TestReActAgentContentExtraction:
response.message = Mock()
response.message.content = "Message content"
content = agent._extract_content_from_llm_response(response)
content = agent._extract_content(response)
assert content == "Message content"
@@ -64,7 +64,7 @@ class TestReActAgentContentExtraction:
response.message = None
response.content = None
content = agent._extract_content_from_llm_response(response)
content = agent._extract_content(response)
assert content == "OpenAI content"
@@ -81,7 +81,7 @@ class TestReActAgentContentExtraction:
response.message = None
response.choices = None
content = agent._extract_content_from_llm_response(response)
content = agent._extract_content(response)
assert content == "Anthropic content"
@@ -101,7 +101,7 @@ class TestReActAgentContentExtraction:
chunk2.choices[0].delta.content = "Part 2"
response = iter([chunk1, chunk2])
content = agent._extract_content_from_llm_response(response)
content = agent._extract_content(response)
assert content == "Part 1 Part 2"
@@ -123,7 +123,7 @@ class TestReActAgentContentExtraction:
chunk2.choices = []
response = iter([chunk1, chunk2])
content = agent._extract_content_from_llm_response(response)
content = agent._extract_content(response)
assert content == "Stream 1 Stream 2"
@@ -133,7 +133,7 @@ class TestReActAgentContentExtraction:
agent = ReActAgent(**agent_base_params)
response = iter(["chunk1", "chunk2", "chunk3"])
content = agent._extract_content_from_llm_response(response)
content = agent._extract_content(response)
assert content == "chunk1chunk2chunk3"
@@ -148,7 +148,7 @@ class TestReActAgentContentExtraction:
response.choices = None
response.content = None
content = agent._extract_content_from_llm_response(response)
content = agent._extract_content(response)
assert content == ""
@@ -161,7 +161,7 @@ class TestReActAgentPlanning:
new_callable=mock_open,
read_data="Test planning prompt: {query} {summaries} {prompt} {observations}",
)
def test_create_plan(
def test_planning_phase(
self,
mock_file,
agent_base_params,
@@ -171,24 +171,27 @@ class TestReActAgentPlanning:
log_context,
):
def mock_gen_stream(*args, **kwargs):
yield "Plan step 1"
yield "Plan step 2"
# Return simple strings - _extract_content handles strings directly
yield "Plan "
yield "content"
mock_llm.gen_stream = Mock(return_value=mock_gen_stream())
agent = ReActAgent(**agent_base_params)
agent.observations = ["Observation 1"]
plan_chunks = list(agent._create_plan("Test query", "Test docs", log_context))
plan_chunks = list(agent._planning_phase("Test query", log_context))
assert len(plan_chunks) == 2
assert plan_chunks[0] == "Plan step 1"
assert plan_chunks[1] == "Plan step 2"
# Should yield thought dicts
assert any("thought" in chunk for chunk in plan_chunks)
assert agent.plan == "Plan content"
mock_llm.gen_stream.assert_called_once()
@patch("builtins.open", new_callable=mock_open, read_data="Test: {query}")
def test_create_plan_fills_template(
def test_planning_phase_fills_template(
self,
mock_file,
agent_base_params,
@@ -197,10 +200,10 @@ class TestReActAgentPlanning:
mock_llm_handler_creator,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Plan"]))
mock_llm.gen_stream = Mock(return_value=iter([]))
agent = ReActAgent(**agent_base_params)
list(agent._create_plan("My query", "Docs", log_context))
list(agent._planning_phase("My query", log_context))
call_args = mock_llm.gen_stream.call_args[1]
messages = call_args["messages"]
@@ -216,7 +219,7 @@ class TestReActAgentFinalAnswer:
new_callable=mock_open,
read_data="Final answer for: {query} with {observations}",
)
def test_create_final_answer(
def test_synthesis_phase(
self,
mock_file,
agent_base_params,
@@ -226,24 +229,22 @@ class TestReActAgentFinalAnswer:
log_context,
):
def mock_gen_stream(*args, **kwargs):
yield "Final "
yield "answer"
yield Mock(choices=[Mock(delta=Mock(content="Final "))])
yield Mock(choices=[Mock(delta=Mock(content="answer"))])
mock_llm.gen_stream = Mock(return_value=mock_gen_stream())
agent = ReActAgent(**agent_base_params)
observations = ["Obs 1", "Obs 2"]
agent.observations = ["Obs 1", "Obs 2"]
answer_chunks = list(
agent._create_final_answer("Test query", observations, log_context)
)
answer_chunks = list(agent._synthesis_phase("Test query", log_context))
assert len(answer_chunks) == 2
assert answer_chunks[0] == "Final "
assert answer_chunks[1] == "answer"
# Should yield answer dicts
assert any("answer" in chunk for chunk in answer_chunks)
@patch("builtins.open", new_callable=mock_open, read_data="Answer: {observations}")
def test_create_final_answer_truncates_long_observations(
def test_synthesis_phase_truncates_long_observations(
self,
mock_file,
agent_base_params,
@@ -252,20 +253,20 @@ class TestReActAgentFinalAnswer:
mock_llm_handler_creator,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
mock_llm.gen_stream = Mock(return_value=iter([]))
agent = ReActAgent(**agent_base_params)
long_obs = ["A" * 15000]
agent.observations = ["A" * 15000]
list(agent._create_final_answer("Query", long_obs, log_context))
list(agent._synthesis_phase("Query", log_context))
call_args = mock_llm.gen_stream.call_args[1]
messages = call_args["messages"]
assert "observations truncated" in messages[0]["content"]
assert "truncated" in messages[0]["content"]
@patch("builtins.open", new_callable=mock_open, read_data="Test: {query}")
def test_create_final_answer_no_tools(
def test_synthesis_phase_no_tools(
self,
mock_file,
agent_base_params,
@@ -274,10 +275,11 @@ class TestReActAgentFinalAnswer:
mock_llm_handler_creator,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
mock_llm.gen_stream = Mock(return_value=iter([]))
agent = ReActAgent(**agent_base_params)
list(agent._create_final_answer("Query", ["Obs"], log_context))
agent.observations = ["Obs"]
list(agent._synthesis_phase("Query", log_context))
call_args = mock_llm.gen_stream.call_args[1]
@@ -294,7 +296,6 @@ class TestReActAgentGenInner:
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -313,7 +314,7 @@ class TestReActAgentGenInner:
agent.plan = "Old plan"
agent.observations = ["Old obs"]
list(agent._gen_inner("New query", mock_retriever, log_context))
list(agent._gen_inner("New query", log_context))
assert agent.plan != "Old plan"
assert len(agent.observations) > 0
@@ -323,7 +324,6 @@ class TestReActAgentGenInner:
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -351,7 +351,7 @@ class TestReActAgentGenInner:
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
results = list(agent._gen_inner("Test query", log_context))
assert any("answer" in r for r in results)
@@ -360,7 +360,6 @@ class TestReActAgentGenInner:
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -386,7 +385,7 @@ class TestReActAgentGenInner:
agent = ReActAgent(**agent_base_params)
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
results = list(agent._gen_inner("Test query", log_context))
thought_results = [r for r in results if "thought" in r]
assert len(thought_results) > 0
@@ -396,7 +395,6 @@ class TestReActAgentGenInner:
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -412,7 +410,7 @@ class TestReActAgentGenInner:
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
results = list(agent._gen_inner("Test query", log_context))
sources = [r for r in results if "sources" in r]
assert len(sources) >= 1
@@ -422,7 +420,6 @@ class TestReActAgentGenInner:
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -440,7 +437,7 @@ class TestReActAgentGenInner:
agent = ReActAgent(**agent_base_params)
agent.tool_calls = [{"tool": "test", "result": "A" * 100}]
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
results = list(agent._gen_inner("Test query", log_context))
tool_call_results = [r for r in results if "tool_calls" in r]
if tool_call_results:
@@ -451,7 +448,6 @@ class TestReActAgentGenInner:
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -467,7 +463,7 @@ class TestReActAgentGenInner:
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
list(agent._gen_inner("Test query", mock_retriever, log_context))
list(agent._gen_inner("Test query", log_context))
assert len(agent.observations) > 0
@@ -484,7 +480,6 @@ class TestReActAgentIntegration:
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -512,7 +507,7 @@ class TestReActAgentIntegration:
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
results = list(agent._gen_inner("Complex query", mock_retriever, log_context))
results = list(agent._gen_inner("Complex query", log_context))
assert len(results) > 0
assert any("thought" in r for r in results)