mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-12-03 18:43:14 +00:00
feat: template-based prompt rendering with dynamic namespace injection (#2091)
* feat: template-based prompt rendering with dynamic namespace injection * refactor: improve template engine initialization with clearer formatting * refactor: streamline ReActAgent methods and improve content extraction logic feat: enhance error handling in NamespaceManager and TemplateEngine fix: update NewAgent component to ensure consistent form data submission test: modify tests for ReActAgent and prompt renderer to reflect method changes and improve coverage * feat: tools namespace + three-tier token budget * refactor: remove unused variable assignment in message building tests * Enhance prompt customization and tool pre-fetching functionality * ruff lint fix * refactor: cleaner error handling and reduce code clutter --------- Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
@@ -64,17 +64,14 @@ class TestBaseAgentBuildMessages:
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
system_prompt = "System: {summaries}"
|
||||
system_prompt = "System prompt content"
|
||||
query = "What is Python?"
|
||||
retrieved_data = [
|
||||
{"text": "Python is a programming language", "filename": "python.txt"}
|
||||
]
|
||||
|
||||
messages = agent._build_messages(system_prompt, query, retrieved_data)
|
||||
messages = agent._build_messages(system_prompt, query)
|
||||
|
||||
assert len(messages) >= 2
|
||||
assert messages[0]["role"] == "system"
|
||||
assert "Python is a programming language" in messages[0]["content"]
|
||||
assert messages[0]["content"] == system_prompt
|
||||
assert messages[-1]["role"] == "user"
|
||||
assert messages[-1]["content"] == query
|
||||
|
||||
@@ -88,11 +85,10 @@ class TestBaseAgentBuildMessages:
|
||||
agent_base_params["chat_history"] = sample_chat_history
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
system_prompt = "System: {summaries}"
|
||||
system_prompt = "System prompt"
|
||||
query = "New question?"
|
||||
retrieved_data = [{"text": "Data", "filename": "file.txt"}]
|
||||
|
||||
messages = agent._build_messages(system_prompt, query, retrieved_data)
|
||||
messages = agent._build_messages(system_prompt, query)
|
||||
|
||||
user_messages = [m for m in messages if m["role"] == "user"]
|
||||
assistant_messages = [m for m in messages if m["role"] == "assistant"]
|
||||
@@ -118,9 +114,7 @@ class TestBaseAgentBuildMessages:
|
||||
agent_base_params["chat_history"] = tool_call_history
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
messages = agent._build_messages(
|
||||
"System: {summaries}", "query", [{"text": "data", "filename": "file.txt"}]
|
||||
)
|
||||
messages = agent._build_messages("System prompt", "query")
|
||||
|
||||
tool_messages = [m for m in messages if m["role"] == "tool"]
|
||||
assert len(tool_messages) > 0
|
||||
@@ -129,32 +123,25 @@ class TestBaseAgentBuildMessages:
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
retrieved_data = [{"text": "Document without filename or title"}]
|
||||
|
||||
messages = agent._build_messages("System: {summaries}", "query", retrieved_data)
|
||||
messages = agent._build_messages("System prompt", "query")
|
||||
|
||||
assert messages[0]["role"] == "system"
|
||||
assert "Document without filename" in messages[0]["content"]
|
||||
assert messages[0]["content"] == "System prompt"
|
||||
|
||||
def test_build_messages_uses_title_as_fallback(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
retrieved_data = [{"text": "Data", "title": "Title Doc"}]
|
||||
|
||||
messages = agent._build_messages("System: {summaries}", "query", retrieved_data)
|
||||
|
||||
assert "Title Doc" in messages[0]["content"]
|
||||
agent._build_messages("System prompt", "query")
|
||||
|
||||
def test_build_messages_uses_source_as_fallback(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
retrieved_data = [{"text": "Data", "source": "source.txt"}]
|
||||
|
||||
messages = agent._build_messages("System: {summaries}", "query", retrieved_data)
|
||||
|
||||
assert "source.txt" in messages[0]["content"]
|
||||
agent._build_messages("System prompt", "query")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -475,40 +462,6 @@ class TestBaseAgentToolExecution:
|
||||
assert truncated[0]["result"].endswith("...")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseAgentRetrieverSearch:
|
||||
|
||||
def test_retriever_search(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
results = agent._retriever_search(mock_retriever, "test query", log_context)
|
||||
|
||||
assert len(results) == 2
|
||||
mock_retriever.search.assert_called_once_with("test query")
|
||||
|
||||
def test_retriever_search_logs_context(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
agent._retriever_search(mock_retriever, "test query", log_context)
|
||||
|
||||
assert len(log_context.stacks) == 1
|
||||
assert log_context.stacks[0]["component"] == "retriever"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseAgentLLMGeneration:
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ class TestClassicAgent:
|
||||
def test_gen_inner_basic_flow(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -40,7 +39,7 @@ class TestClassicAgent:
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
results = list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
assert len(results) >= 2
|
||||
sources = [r for r in results if "sources" in r]
|
||||
@@ -52,7 +51,6 @@ class TestClassicAgent:
|
||||
def test_gen_inner_retrieves_documents(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -68,14 +66,11 @@ class TestClassicAgent:
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
|
||||
mock_retriever.search.assert_called_once_with("Test query")
|
||||
list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
def test_gen_inner_uses_user_api_key_tools(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -104,14 +99,13 @@ class TestClassicAgent:
|
||||
agent_base_params["user_api_key"] = "api_key_123"
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
assert len(agent.tools) >= 0
|
||||
|
||||
def test_gen_inner_uses_user_tools(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -133,14 +127,13 @@ class TestClassicAgent:
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
assert len(agent.tools) >= 0
|
||||
|
||||
def test_gen_inner_builds_correct_messages(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -156,7 +149,7 @@ class TestClassicAgent:
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
call_kwargs = mock_llm.gen_stream.call_args[1]
|
||||
messages = call_kwargs["messages"]
|
||||
@@ -169,7 +162,6 @@ class TestClassicAgent:
|
||||
def test_gen_inner_logs_tool_calls(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -187,7 +179,7 @@ class TestClassicAgent:
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
agent.tool_calls = [{"tool": "test", "result": "success"}]
|
||||
|
||||
list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
agent_logs = [s for s in log_context.stacks if s["component"] == "agent"]
|
||||
assert len(agent_logs) == 1
|
||||
@@ -200,7 +192,6 @@ class TestClassicAgentIntegration:
|
||||
def test_gen_method_with_logging(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -216,14 +207,13 @@ class TestClassicAgentIntegration:
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
results = list(agent.gen("Test query", mock_retriever))
|
||||
results = list(agent.gen("Test query"))
|
||||
|
||||
assert len(results) >= 1
|
||||
|
||||
def test_gen_method_decorator_applied(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
|
||||
@@ -35,7 +35,7 @@ class TestReActAgentContentExtraction:
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
|
||||
response = "Simple string response"
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
content = agent._extract_content(response)
|
||||
|
||||
assert content == "Simple string response"
|
||||
|
||||
@@ -48,7 +48,7 @@ class TestReActAgentContentExtraction:
|
||||
response.message = Mock()
|
||||
response.message.content = "Message content"
|
||||
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
content = agent._extract_content(response)
|
||||
|
||||
assert content == "Message content"
|
||||
|
||||
@@ -64,7 +64,7 @@ class TestReActAgentContentExtraction:
|
||||
response.message = None
|
||||
response.content = None
|
||||
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
content = agent._extract_content(response)
|
||||
|
||||
assert content == "OpenAI content"
|
||||
|
||||
@@ -81,7 +81,7 @@ class TestReActAgentContentExtraction:
|
||||
response.message = None
|
||||
response.choices = None
|
||||
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
content = agent._extract_content(response)
|
||||
|
||||
assert content == "Anthropic content"
|
||||
|
||||
@@ -101,7 +101,7 @@ class TestReActAgentContentExtraction:
|
||||
chunk2.choices[0].delta.content = "Part 2"
|
||||
|
||||
response = iter([chunk1, chunk2])
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
content = agent._extract_content(response)
|
||||
|
||||
assert content == "Part 1 Part 2"
|
||||
|
||||
@@ -123,7 +123,7 @@ class TestReActAgentContentExtraction:
|
||||
chunk2.choices = []
|
||||
|
||||
response = iter([chunk1, chunk2])
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
content = agent._extract_content(response)
|
||||
|
||||
assert content == "Stream 1 Stream 2"
|
||||
|
||||
@@ -133,7 +133,7 @@ class TestReActAgentContentExtraction:
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
|
||||
response = iter(["chunk1", "chunk2", "chunk3"])
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
content = agent._extract_content(response)
|
||||
|
||||
assert content == "chunk1chunk2chunk3"
|
||||
|
||||
@@ -148,7 +148,7 @@ class TestReActAgentContentExtraction:
|
||||
response.choices = None
|
||||
response.content = None
|
||||
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
content = agent._extract_content(response)
|
||||
|
||||
assert content == ""
|
||||
|
||||
@@ -161,7 +161,7 @@ class TestReActAgentPlanning:
|
||||
new_callable=mock_open,
|
||||
read_data="Test planning prompt: {query} {summaries} {prompt} {observations}",
|
||||
)
|
||||
def test_create_plan(
|
||||
def test_planning_phase(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
@@ -171,24 +171,27 @@ class TestReActAgentPlanning:
|
||||
log_context,
|
||||
):
|
||||
def mock_gen_stream(*args, **kwargs):
|
||||
yield "Plan step 1"
|
||||
yield "Plan step 2"
|
||||
# Return simple strings - _extract_content handles strings directly
|
||||
|
||||
yield "Plan "
|
||||
yield "content"
|
||||
|
||||
mock_llm.gen_stream = Mock(return_value=mock_gen_stream())
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
agent.observations = ["Observation 1"]
|
||||
|
||||
plan_chunks = list(agent._create_plan("Test query", "Test docs", log_context))
|
||||
plan_chunks = list(agent._planning_phase("Test query", log_context))
|
||||
|
||||
assert len(plan_chunks) == 2
|
||||
assert plan_chunks[0] == "Plan step 1"
|
||||
assert plan_chunks[1] == "Plan step 2"
|
||||
# Should yield thought dicts
|
||||
|
||||
assert any("thought" in chunk for chunk in plan_chunks)
|
||||
assert agent.plan == "Plan content"
|
||||
|
||||
mock_llm.gen_stream.assert_called_once()
|
||||
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="Test: {query}")
|
||||
def test_create_plan_fills_template(
|
||||
def test_planning_phase_fills_template(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
@@ -197,10 +200,10 @@ class TestReActAgentPlanning:
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Plan"]))
|
||||
mock_llm.gen_stream = Mock(return_value=iter([]))
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
list(agent._create_plan("My query", "Docs", log_context))
|
||||
list(agent._planning_phase("My query", log_context))
|
||||
|
||||
call_args = mock_llm.gen_stream.call_args[1]
|
||||
messages = call_args["messages"]
|
||||
@@ -216,7 +219,7 @@ class TestReActAgentFinalAnswer:
|
||||
new_callable=mock_open,
|
||||
read_data="Final answer for: {query} with {observations}",
|
||||
)
|
||||
def test_create_final_answer(
|
||||
def test_synthesis_phase(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
@@ -226,24 +229,22 @@ class TestReActAgentFinalAnswer:
|
||||
log_context,
|
||||
):
|
||||
def mock_gen_stream(*args, **kwargs):
|
||||
yield "Final "
|
||||
yield "answer"
|
||||
yield Mock(choices=[Mock(delta=Mock(content="Final "))])
|
||||
yield Mock(choices=[Mock(delta=Mock(content="answer"))])
|
||||
|
||||
mock_llm.gen_stream = Mock(return_value=mock_gen_stream())
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
observations = ["Obs 1", "Obs 2"]
|
||||
agent.observations = ["Obs 1", "Obs 2"]
|
||||
|
||||
answer_chunks = list(
|
||||
agent._create_final_answer("Test query", observations, log_context)
|
||||
)
|
||||
answer_chunks = list(agent._synthesis_phase("Test query", log_context))
|
||||
|
||||
assert len(answer_chunks) == 2
|
||||
assert answer_chunks[0] == "Final "
|
||||
assert answer_chunks[1] == "answer"
|
||||
# Should yield answer dicts
|
||||
|
||||
assert any("answer" in chunk for chunk in answer_chunks)
|
||||
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="Answer: {observations}")
|
||||
def test_create_final_answer_truncates_long_observations(
|
||||
def test_synthesis_phase_truncates_long_observations(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
@@ -252,20 +253,20 @@ class TestReActAgentFinalAnswer:
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
|
||||
mock_llm.gen_stream = Mock(return_value=iter([]))
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
long_obs = ["A" * 15000]
|
||||
agent.observations = ["A" * 15000]
|
||||
|
||||
list(agent._create_final_answer("Query", long_obs, log_context))
|
||||
list(agent._synthesis_phase("Query", log_context))
|
||||
|
||||
call_args = mock_llm.gen_stream.call_args[1]
|
||||
messages = call_args["messages"]
|
||||
|
||||
assert "observations truncated" in messages[0]["content"]
|
||||
assert "truncated" in messages[0]["content"]
|
||||
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="Test: {query}")
|
||||
def test_create_final_answer_no_tools(
|
||||
def test_synthesis_phase_no_tools(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
@@ -274,10 +275,11 @@ class TestReActAgentFinalAnswer:
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
|
||||
mock_llm.gen_stream = Mock(return_value=iter([]))
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
list(agent._create_final_answer("Query", ["Obs"], log_context))
|
||||
agent.observations = ["Obs"]
|
||||
list(agent._synthesis_phase("Query", log_context))
|
||||
|
||||
call_args = mock_llm.gen_stream.call_args[1]
|
||||
|
||||
@@ -294,7 +296,6 @@ class TestReActAgentGenInner:
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -313,7 +314,7 @@ class TestReActAgentGenInner:
|
||||
agent.plan = "Old plan"
|
||||
agent.observations = ["Old obs"]
|
||||
|
||||
list(agent._gen_inner("New query", mock_retriever, log_context))
|
||||
list(agent._gen_inner("New query", log_context))
|
||||
|
||||
assert agent.plan != "Old plan"
|
||||
assert len(agent.observations) > 0
|
||||
@@ -323,7 +324,6 @@ class TestReActAgentGenInner:
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -351,7 +351,7 @@ class TestReActAgentGenInner:
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
results = list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
assert any("answer" in r for r in results)
|
||||
|
||||
@@ -360,7 +360,6 @@ class TestReActAgentGenInner:
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -386,7 +385,7 @@ class TestReActAgentGenInner:
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
|
||||
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
results = list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
thought_results = [r for r in results if "thought" in r]
|
||||
assert len(thought_results) > 0
|
||||
@@ -396,7 +395,6 @@ class TestReActAgentGenInner:
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -412,7 +410,7 @@ class TestReActAgentGenInner:
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
results = list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
sources = [r for r in results if "sources" in r]
|
||||
assert len(sources) >= 1
|
||||
@@ -422,7 +420,6 @@ class TestReActAgentGenInner:
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -440,7 +437,7 @@ class TestReActAgentGenInner:
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
agent.tool_calls = [{"tool": "test", "result": "A" * 100}]
|
||||
|
||||
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
results = list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
tool_call_results = [r for r in results if "tool_calls" in r]
|
||||
if tool_call_results:
|
||||
@@ -451,7 +448,6 @@ class TestReActAgentGenInner:
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -467,7 +463,7 @@ class TestReActAgentGenInner:
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
assert len(agent.observations) > 0
|
||||
|
||||
@@ -484,7 +480,6 @@ class TestReActAgentIntegration:
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -512,7 +507,7 @@ class TestReActAgentIntegration:
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
results = list(agent._gen_inner("Complex query", mock_retriever, log_context))
|
||||
results = list(agent._gen_inner("Complex query", log_context))
|
||||
|
||||
assert len(results) > 0
|
||||
assert any("thought" in r for r in results)
|
||||
|
||||
Reference in New Issue
Block a user