feat: template-based prompt rendering with dynamic namespace injection (#2091)

* feat: template-based prompt rendering with dynamic namespace injection

* refactor: improve template engine initialization with clearer formatting

* refactor: streamline ReActAgent methods and improve content extraction logic

feat: enhance error handling in NamespaceManager and TemplateEngine

fix: update NewAgent component to ensure consistent form data submission

test: modify tests for ReActAgent and prompt renderer to reflect method changes and improve coverage

* feat: tools namespace + three-tier token budget

* refactor: remove unused variable assignment in message building tests

* Enhance prompt customization and tool pre-fetching functionality

* ruff lint fix

* refactor: cleaner error handling and reduce code clutter

---------

Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
Siddhant Rai
2025-10-31 18:17:44 +05:30
committed by GitHub
parent a7d61b9d59
commit 21e5c261ef
33 changed files with 2917 additions and 646 deletions

View File

@@ -35,7 +35,7 @@ class TestReActAgentContentExtraction:
agent = ReActAgent(**agent_base_params)
response = "Simple string response"
content = agent._extract_content_from_llm_response(response)
content = agent._extract_content(response)
assert content == "Simple string response"
@@ -48,7 +48,7 @@ class TestReActAgentContentExtraction:
response.message = Mock()
response.message.content = "Message content"
content = agent._extract_content_from_llm_response(response)
content = agent._extract_content(response)
assert content == "Message content"
@@ -64,7 +64,7 @@ class TestReActAgentContentExtraction:
response.message = None
response.content = None
content = agent._extract_content_from_llm_response(response)
content = agent._extract_content(response)
assert content == "OpenAI content"
@@ -81,7 +81,7 @@ class TestReActAgentContentExtraction:
response.message = None
response.choices = None
content = agent._extract_content_from_llm_response(response)
content = agent._extract_content(response)
assert content == "Anthropic content"
@@ -101,7 +101,7 @@ class TestReActAgentContentExtraction:
chunk2.choices[0].delta.content = "Part 2"
response = iter([chunk1, chunk2])
content = agent._extract_content_from_llm_response(response)
content = agent._extract_content(response)
assert content == "Part 1 Part 2"
@@ -123,7 +123,7 @@ class TestReActAgentContentExtraction:
chunk2.choices = []
response = iter([chunk1, chunk2])
content = agent._extract_content_from_llm_response(response)
content = agent._extract_content(response)
assert content == "Stream 1 Stream 2"
@@ -133,7 +133,7 @@ class TestReActAgentContentExtraction:
agent = ReActAgent(**agent_base_params)
response = iter(["chunk1", "chunk2", "chunk3"])
content = agent._extract_content_from_llm_response(response)
content = agent._extract_content(response)
assert content == "chunk1chunk2chunk3"
@@ -148,7 +148,7 @@ class TestReActAgentContentExtraction:
response.choices = None
response.content = None
content = agent._extract_content_from_llm_response(response)
content = agent._extract_content(response)
assert content == ""
@@ -161,7 +161,7 @@ class TestReActAgentPlanning:
new_callable=mock_open,
read_data="Test planning prompt: {query} {summaries} {prompt} {observations}",
)
def test_create_plan(
def test_planning_phase(
self,
mock_file,
agent_base_params,
@@ -171,24 +171,27 @@ class TestReActAgentPlanning:
log_context,
):
def mock_gen_stream(*args, **kwargs):
yield "Plan step 1"
yield "Plan step 2"
# Return simple strings - _extract_content handles strings directly
yield "Plan "
yield "content"
mock_llm.gen_stream = Mock(return_value=mock_gen_stream())
agent = ReActAgent(**agent_base_params)
agent.observations = ["Observation 1"]
plan_chunks = list(agent._create_plan("Test query", "Test docs", log_context))
plan_chunks = list(agent._planning_phase("Test query", log_context))
assert len(plan_chunks) == 2
assert plan_chunks[0] == "Plan step 1"
assert plan_chunks[1] == "Plan step 2"
# Should yield thought dicts
assert any("thought" in chunk for chunk in plan_chunks)
assert agent.plan == "Plan content"
mock_llm.gen_stream.assert_called_once()
@patch("builtins.open", new_callable=mock_open, read_data="Test: {query}")
def test_create_plan_fills_template(
def test_planning_phase_fills_template(
self,
mock_file,
agent_base_params,
@@ -197,10 +200,10 @@ class TestReActAgentPlanning:
mock_llm_handler_creator,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Plan"]))
mock_llm.gen_stream = Mock(return_value=iter([]))
agent = ReActAgent(**agent_base_params)
list(agent._create_plan("My query", "Docs", log_context))
list(agent._planning_phase("My query", log_context))
call_args = mock_llm.gen_stream.call_args[1]
messages = call_args["messages"]
@@ -216,7 +219,7 @@ class TestReActAgentFinalAnswer:
new_callable=mock_open,
read_data="Final answer for: {query} with {observations}",
)
def test_create_final_answer(
def test_synthesis_phase(
self,
mock_file,
agent_base_params,
@@ -226,24 +229,22 @@ class TestReActAgentFinalAnswer:
log_context,
):
def mock_gen_stream(*args, **kwargs):
yield "Final "
yield "answer"
yield Mock(choices=[Mock(delta=Mock(content="Final "))])
yield Mock(choices=[Mock(delta=Mock(content="answer"))])
mock_llm.gen_stream = Mock(return_value=mock_gen_stream())
agent = ReActAgent(**agent_base_params)
observations = ["Obs 1", "Obs 2"]
agent.observations = ["Obs 1", "Obs 2"]
answer_chunks = list(
agent._create_final_answer("Test query", observations, log_context)
)
answer_chunks = list(agent._synthesis_phase("Test query", log_context))
assert len(answer_chunks) == 2
assert answer_chunks[0] == "Final "
assert answer_chunks[1] == "answer"
# Should yield answer dicts
assert any("answer" in chunk for chunk in answer_chunks)
@patch("builtins.open", new_callable=mock_open, read_data="Answer: {observations}")
def test_create_final_answer_truncates_long_observations(
def test_synthesis_phase_truncates_long_observations(
self,
mock_file,
agent_base_params,
@@ -252,20 +253,20 @@ class TestReActAgentFinalAnswer:
mock_llm_handler_creator,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
mock_llm.gen_stream = Mock(return_value=iter([]))
agent = ReActAgent(**agent_base_params)
long_obs = ["A" * 15000]
agent.observations = ["A" * 15000]
list(agent._create_final_answer("Query", long_obs, log_context))
list(agent._synthesis_phase("Query", log_context))
call_args = mock_llm.gen_stream.call_args[1]
messages = call_args["messages"]
assert "observations truncated" in messages[0]["content"]
assert "truncated" in messages[0]["content"]
@patch("builtins.open", new_callable=mock_open, read_data="Test: {query}")
def test_create_final_answer_no_tools(
def test_synthesis_phase_no_tools(
self,
mock_file,
agent_base_params,
@@ -274,10 +275,11 @@ class TestReActAgentFinalAnswer:
mock_llm_handler_creator,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
mock_llm.gen_stream = Mock(return_value=iter([]))
agent = ReActAgent(**agent_base_params)
list(agent._create_final_answer("Query", ["Obs"], log_context))
agent.observations = ["Obs"]
list(agent._synthesis_phase("Query", log_context))
call_args = mock_llm.gen_stream.call_args[1]
@@ -294,7 +296,6 @@ class TestReActAgentGenInner:
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -313,7 +314,7 @@ class TestReActAgentGenInner:
agent.plan = "Old plan"
agent.observations = ["Old obs"]
list(agent._gen_inner("New query", mock_retriever, log_context))
list(agent._gen_inner("New query", log_context))
assert agent.plan != "Old plan"
assert len(agent.observations) > 0
@@ -323,7 +324,6 @@ class TestReActAgentGenInner:
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -351,7 +351,7 @@ class TestReActAgentGenInner:
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
results = list(agent._gen_inner("Test query", log_context))
assert any("answer" in r for r in results)
@@ -360,7 +360,6 @@ class TestReActAgentGenInner:
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -386,7 +385,7 @@ class TestReActAgentGenInner:
agent = ReActAgent(**agent_base_params)
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
results = list(agent._gen_inner("Test query", log_context))
thought_results = [r for r in results if "thought" in r]
assert len(thought_results) > 0
@@ -396,7 +395,6 @@ class TestReActAgentGenInner:
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -412,7 +410,7 @@ class TestReActAgentGenInner:
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
results = list(agent._gen_inner("Test query", log_context))
sources = [r for r in results if "sources" in r]
assert len(sources) >= 1
@@ -422,7 +420,6 @@ class TestReActAgentGenInner:
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -440,7 +437,7 @@ class TestReActAgentGenInner:
agent = ReActAgent(**agent_base_params)
agent.tool_calls = [{"tool": "test", "result": "A" * 100}]
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
results = list(agent._gen_inner("Test query", log_context))
tool_call_results = [r for r in results if "tool_calls" in r]
if tool_call_results:
@@ -451,7 +448,6 @@ class TestReActAgentGenInner:
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -467,7 +463,7 @@ class TestReActAgentGenInner:
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
list(agent._gen_inner("Test query", mock_retriever, log_context))
list(agent._gen_inner("Test query", log_context))
assert len(agent.observations) > 0
@@ -484,7 +480,6 @@ class TestReActAgentIntegration:
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
@@ -512,7 +507,7 @@ class TestReActAgentIntegration:
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
results = list(agent._gen_inner("Complex query", mock_retriever, log_context))
results = list(agent._gen_inner("Complex query", log_context))
assert len(results) > 0
assert any("thought" in r for r in results)