mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
feat: template-based prompt rendering with dynamic namespace injection (#2091)
* feat: template-based prompt rendering with dynamic namespace injection * refactor: improve template engine initialization with clearer formatting * refactor: streamline ReActAgent methods and improve content extraction logic feat: enhance error handling in NamespaceManager and TemplateEngine fix: update NewAgent component to ensure consistent form data submission test: modify tests for ReActAgent and prompt renderer to reflect method changes and improve coverage * feat: tools namespace + three-tier token budget * refactor: remove unused variable assignment in message building tests * Enhance prompt customization and tool pre-fetching functionality * ruff lint fix * refactor: cleaner error handling and reduce code clutter --------- Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
@@ -19,7 +19,6 @@ class TestClassicAgent:
|
||||
def test_gen_inner_basic_flow(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -40,7 +39,7 @@ class TestClassicAgent:
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
results = list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
assert len(results) >= 2
|
||||
sources = [r for r in results if "sources" in r]
|
||||
@@ -52,7 +51,6 @@ class TestClassicAgent:
|
||||
def test_gen_inner_retrieves_documents(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -68,14 +66,11 @@ class TestClassicAgent:
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
|
||||
mock_retriever.search.assert_called_once_with("Test query")
|
||||
list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
def test_gen_inner_uses_user_api_key_tools(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -104,14 +99,13 @@ class TestClassicAgent:
|
||||
agent_base_params["user_api_key"] = "api_key_123"
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
assert len(agent.tools) >= 0
|
||||
|
||||
def test_gen_inner_uses_user_tools(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -133,14 +127,13 @@ class TestClassicAgent:
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
assert len(agent.tools) >= 0
|
||||
|
||||
def test_gen_inner_builds_correct_messages(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -156,7 +149,7 @@ class TestClassicAgent:
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
call_kwargs = mock_llm.gen_stream.call_args[1]
|
||||
messages = call_kwargs["messages"]
|
||||
@@ -169,7 +162,6 @@ class TestClassicAgent:
|
||||
def test_gen_inner_logs_tool_calls(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -187,7 +179,7 @@ class TestClassicAgent:
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
agent.tool_calls = [{"tool": "test", "result": "success"}]
|
||||
|
||||
list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
agent_logs = [s for s in log_context.stacks if s["component"] == "agent"]
|
||||
assert len(agent_logs) == 1
|
||||
@@ -200,7 +192,6 @@ class TestClassicAgentIntegration:
|
||||
def test_gen_method_with_logging(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
@@ -216,14 +207,13 @@ class TestClassicAgentIntegration:
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
results = list(agent.gen("Test query", mock_retriever))
|
||||
results = list(agent.gen("Test query"))
|
||||
|
||||
assert len(results) >= 1
|
||||
|
||||
def test_gen_method_decorator_applied(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
|
||||
Reference in New Issue
Block a user