mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
feat: template-based prompt rendering with dynamic namespace injection (#2091)
* feat: template-based prompt rendering with dynamic namespace injection * refactor: improve template engine initialization with clearer formatting * refactor: streamline ReActAgent methods and improve content extraction logic feat: enhance error handling in NamespaceManager and TemplateEngine fix: update NewAgent component to ensure consistent form data submission test: modify tests for ReActAgent and prompt renderer to reflect method changes and improve coverage * feat: tools namespace + three-tier token budget * refactor: remove unused variable assignment in message building tests * Enhance prompt customization and tool pre-fetching functionality * ruff lint fix * refactor: cleaner error handling and reduce code clutter --------- Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
@@ -35,7 +35,7 @@ class TestReActAgentContentExtraction:
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
|
||||
response = "Simple string response"
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
content = agent._extract_content(response)
|
||||
|
||||
assert content == "Simple string response"
|
||||
|
||||
@@ -48,7 +48,7 @@ class TestReActAgentContentExtraction:
|
||||
response.message = Mock()
|
||||
response.message.content = "Message content"
|
||||
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
content = agent._extract_content(response)
|
||||
|
||||
assert content == "Message content"
|
||||
|
||||
@@ -64,7 +64,7 @@ class TestReActAgentContentExtraction:
|
||||
response.message = None
|
||||
response.content = None
|
||||
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
content = agent._extract_content(response)
|
||||
|
||||
assert content == "OpenAI content"
|
||||
|
||||
@@ -81,7 +81,7 @@ class TestReActAgentContentExtraction:
|
||||
response.message = None
|
||||
response.choices = None
|
||||
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
content = agent._extract_content(response)
|
||||
|
||||
assert content == "Anthropic content"
|
||||
|
||||
@@ -101,7 +101,7 @@ class TestReActAgentContentExtraction:
|
||||
chunk2.choices[0].delta.content = "Part 2"
|
||||
|
||||
response = iter([chunk1, chunk2])
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
content = agent._extract_content(response)
|
||||
|
||||
assert content == "Part 1 Part 2"
|
||||
|
||||
@@ -123,7 +123,7 @@ class TestReActAgentContentExtraction:
|
||||
chunk2.choices = []
|
||||
|
||||
response = iter([chunk1, chunk2])
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
content = agent._extract_content(response)
|
||||
|
||||
assert content == "Stream 1 Stream 2"
|
||||
|
||||
@@ -133,7 +133,7 @@ class TestReActAgentContentExtraction:
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
|
||||
response = iter(["chunk1", "chunk2", "chunk3"])
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
content = agent._extract_content(response)
|
||||
|
||||
assert content == "chunk1chunk2chunk3"
|
||||
|
||||
@@ -148,7 +148,7 @@ class TestReActAgentContentExtraction:
|
||||
response.choices = None
|
||||
response.content = None
|
||||
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
content = agent._extract_content(response)
|
||||
|
||||
assert content == ""
|
||||
|
||||
@@ -161,7 +161,7 @@ class TestReActAgentPlanning:
|
||||
new_callable=mock_open,
|
||||
read_data="Test planning prompt: {query} {summaries} {prompt} {observations}",
|
||||
)
|
||||
def test_create_plan(
|
||||
def test_planning_phase(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
@@ -171,24 +171,27 @@ class TestReActAgentPlanning:
|
||||
log_context,
|
||||
):
|
||||
def mock_gen_stream(*args, **kwargs):
|
||||
yield "Plan step 1"
|
||||
yield "Plan step 2"
|
||||
# Return simple strings - _extract_content handles strings directly
|
||||
|
||||
yield "Plan "
|
||||
yield "content"
|
||||
|
||||
mock_llm.gen_stream = Mock(return_value=mock_gen_stream())
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
agent.observations = ["Observation 1"]
|
||||
|
||||
plan_chunks = list(agent._create_plan("Test query", "Test docs", log_context))
|
||||
plan_chunks = list(agent._planning_phase("Test query", log_context))
|
||||
|
||||
assert len(plan_chunks) == 2
|
||||
assert plan_chunks[0] == "Plan step 1"
|
||||
assert plan_chunks[1] == "Plan step 2"
|
||||
# Should yield thought dicts
|
||||
|
||||
assert any("thought" in chunk for chunk in plan_chunks)
|
||||
assert agent.plan == "Plan content"
|
||||
|
||||
mock_llm.gen_stream.assert_called_once()
|
||||
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="Test: {query}")
|
||||
def test_create_plan_fills_template(
|
||||
def test_planning_phase_fills_template(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
@@ -197,10 +200,10 @@ class TestReActAgentPlanning:
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Plan"]))
|
||||
mock_llm.gen_stream = Mock(return_value=iter([]))
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
list(agent._create_plan("My query", "Docs", log_context))
|
||||
list(agent._planning_phase("My query", log_context))
|
||||
|
||||
call_args = mock_llm.gen_stream.call_args[1]
|
||||
messages = call_args["messages"]
|
||||
@@ -216,7 +219,7 @@ class TestReActAgentFinalAnswer:
|
||||
new_callable=mock_open,
|
||||
read_data="Final answer for: {query} with {observations}",
|
||||
)
|
||||
def test_create_final_answer(
|
||||
def test_synthesis_phase(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
@@ -226,24 +229,22 @@ class TestReActAgentFinalAnswer:
|
||||
log_context,
|
||||
):
|
||||
def mock_gen_stream(*args, **kwargs):
|
||||
yield "Final "
|
||||
yield "answer"
|
||||
yield Mock(choices=[Mock(delta=Mock(content="Final "))])
|
||||
yield Mock(choices=[Mock(delta=Mock(content="answer"))])
|
||||
|
||||
mock_llm.gen_stream = Mock(return_value=mock_gen_stream())
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
observations = ["Obs 1", "Obs 2"]
|
||||
agent.observations = ["Obs 1", "Obs 2"]
|
||||
|
||||
answer_chunks = list(
|
||||
agent._create_final_answer("Test query", observations, log_context)
|
||||
)
|
||||
answer_chunks = list(agent._synthesis_phase("Test query", log_context))
|
||||
|
||||
assert len(answer_chunks) == 2
|
||||
assert answer_chunks[0] == "Final "
|
||||
assert answer_chunks[1] == "answer"
|
||||
# Should yield answer dicts
|
||||
|
||||
assert any("answer" in chunk for chunk in answer_chunks)
|
||||
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="Answer: {observations}")
|
||||
def test_create_final_answer_truncates_long_observations(
|
||||
def test_synthesis_phase_truncates_long_observations(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
@@ -252,20 +253,20 @@ class TestReActAgentFinalAnswer:
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
|
||||
mock_llm.gen_stream = Mock(return_value=iter([]))
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
long_obs = ["A" * 15000]
|
||||
agent.observations = ["A" * 15000]
|
||||
|
||||
list(agent._create_final_answer("Query", long_obs, log_context))
|
||||
list(agent._synthesis_phase("Query", log_context))
|
||||
|
||||
call_args = mock_llm.gen_stream.call_args[1]
|
||||
messages = call_args["messages"]
|
||||
|
||||
assert "observations truncated" in messages[0]["content"]
|
||||
assert "truncated" in messages[0]["content"]
|
||||
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="Test: {query}")
|
||||
def test_create_final_answer_no_tools(
|
||||
def test_synthesis_phase_no_tools(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
@@ -274,10 +275,11 @@ class TestReActAgentFinalAnswer:
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
|
||||
mock_llm.gen_stream = Mock(return_value=iter([]))
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
list(agent._create_final_answer("Query", ["Obs"], log_context))
|
||||
agent.observations = ["Obs"]
|
||||
list(agent._synthesis_phase("Query", log_context))
|
||||
|
||||
call_args = mock_llm.gen_stream.call_args[1]
|
||||
|
||||
@@ -294,7 +296,6 @@ class TestReActAgentGenInner:
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -313,7 +314,7 @@ class TestReActAgentGenInner:
|
||||
agent.plan = "Old plan"
|
||||
agent.observations = ["Old obs"]
|
||||
|
||||
list(agent._gen_inner("New query", mock_retriever, log_context))
|
||||
list(agent._gen_inner("New query", log_context))
|
||||
|
||||
assert agent.plan != "Old plan"
|
||||
assert len(agent.observations) > 0
|
||||
@@ -323,7 +324,6 @@ class TestReActAgentGenInner:
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -351,7 +351,7 @@ class TestReActAgentGenInner:
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
results = list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
assert any("answer" in r for r in results)
|
||||
|
||||
@@ -360,7 +360,6 @@ class TestReActAgentGenInner:
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -386,7 +385,7 @@ class TestReActAgentGenInner:
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
|
||||
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
results = list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
thought_results = [r for r in results if "thought" in r]
|
||||
assert len(thought_results) > 0
|
||||
@@ -396,7 +395,6 @@ class TestReActAgentGenInner:
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -412,7 +410,7 @@ class TestReActAgentGenInner:
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
results = list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
sources = [r for r in results if "sources" in r]
|
||||
assert len(sources) >= 1
|
||||
@@ -422,7 +420,6 @@ class TestReActAgentGenInner:
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -440,7 +437,7 @@ class TestReActAgentGenInner:
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
agent.tool_calls = [{"tool": "test", "result": "A" * 100}]
|
||||
|
||||
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
results = list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
tool_call_results = [r for r in results if "tool_calls" in r]
|
||||
if tool_call_results:
|
||||
@@ -451,7 +448,6 @@ class TestReActAgentGenInner:
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -467,7 +463,7 @@ class TestReActAgentGenInner:
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
assert len(agent.observations) > 0
|
||||
|
||||
@@ -484,7 +480,6 @@ class TestReActAgentIntegration:
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -512,7 +507,7 @@ class TestReActAgentIntegration:
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
results = list(agent._gen_inner("Complex query", mock_retriever, log_context))
|
||||
results = list(agent._gen_inner("Complex query", log_context))
|
||||
|
||||
assert len(results) > 0
|
||||
assert any("thought" in r for r in results)
|
||||
|
||||
Reference in New Issue
Block a user