Files
DocsGPT/tests/agents/test_classic_agent.py
Siddhant Rai 21e5c261ef feat: template-based prompt rendering with dynamic namespace injection (#2091)
* feat: template-based prompt rendering with dynamic namespace injection

* refactor: improve template engine initialization with clearer formatting

* refactor: streamline ReActAgent methods and improve content extraction logic

feat: enhance error handling in NamespaceManager and TemplateEngine

fix: update NewAgent component to ensure consistent form data submission

test: modify tests for ReActAgent and prompt renderer to reflect method changes and improve coverage

* feat: tools namespace + three-tier token budget

* refactor: remove unused variable assignment in message building tests

* Enhance prompt customization and tool pre-fetching functionality

* ruff lint fix

* refactor: cleaner error handling and reduce code clutter

---------

Co-authored-by: Alex <a@tushynski.me>
2025-10-31 12:47:44 +00:00

233 lines
6.6 KiB
Python

from unittest.mock import Mock
import pytest
from application.agents.classic_agent import ClassicAgent
@pytest.mark.unit
class TestClassicAgent:
    """Unit tests for ``ClassicAgent`` construction and ``_gen_inner``.

    Each generation test stubs the streaming output of the ``mock_llm`` and
    ``mock_llm_handler`` fixtures so the flow runs without a real model.
    These mocks are presumably injected into the agent by the
    ``mock_llm_creator`` / ``mock_llm_handler_creator`` fixtures (confirm
    in conftest).
    """

    @staticmethod
    def _stub_llm(
        mock_llm, mock_llm_handler, chunks=("Answer",), processed="Processed"
    ):
        """Make ``mock_llm`` stream *chunks* and the handler yield *processed*.

        ``process_message_flow`` is given a generator *function* as its
        ``side_effect``, so every call produces a fresh generator.
        """
        mock_llm.gen_stream = Mock(return_value=iter(chunks))

        def fake_process_message_flow(*args, **kwargs):
            yield processed

        mock_llm_handler.process_message_flow = Mock(
            side_effect=fake_process_message_flow
        )

    def test_classic_agent_initialization(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        agent = ClassicAgent(**agent_base_params)

        assert isinstance(agent, ClassicAgent)
        assert agent.endpoint == agent_base_params["endpoint"]
        assert agent.llm_name == agent_base_params["llm_name"]

    def test_gen_inner_basic_flow(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
        log_context,
    ):
        self._stub_llm(
            mock_llm,
            mock_llm_handler,
            chunks=("Answer chunk 1", "Answer chunk 2"),
            processed="Processed answer",
        )
        agent = ClassicAgent(**agent_base_params)

        results = list(agent._gen_inner("Test query", log_context))

        assert len(results) >= 2
        # The stream must carry exactly one "sources" and one "tool_calls" event.
        sources = [r for r in results if "sources" in r]
        tool_calls = [r for r in results if "tool_calls" in r]
        assert len(sources) == 1
        assert len(tool_calls) == 1

    def test_gen_inner_retrieves_documents(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
        log_context,
    ):
        self._stub_llm(mock_llm, mock_llm_handler)
        agent = ClassicAgent(**agent_base_params)

        results = list(agent._gen_inner("Test query", log_context))

        # At minimum, the single "sources" event checked in the basic flow
        # must also be emitted here, otherwise this test proves nothing.
        assert sum(1 for r in results if "sources" in r) == 1

    def test_gen_inner_uses_user_api_key_tools(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
        log_context,
    ):
        from application.core.settings import settings
        from bson.objectid import ObjectId

        tool_id = str(ObjectId())
        # The agent record for the API key references one user tool by id.
        mock_mongo_db[settings.MONGO_DB_NAME]["agents"].docs = {
            "api_key_123": {"key": "api_key_123", "tools": [tool_id]}
        }
        mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
            tool_id: {"_id": ObjectId(tool_id), "name": "test_tool"}
        }
        self._stub_llm(mock_llm, mock_llm_handler)
        agent_base_params["user_api_key"] = "api_key_123"
        agent = ClassicAgent(**agent_base_params)

        list(agent._gen_inner("Test query", log_context))

        # Smoke test: resolving tools for the API key must not prevent
        # generation from reaching the LLM. The exact tool list produced from
        # this minimal tool record is out of scope here.
        assert hasattr(agent, "tools")
        mock_llm.gen_stream.assert_called()

    def test_gen_inner_uses_user_tools(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
        log_context,
    ):
        from application.core.settings import settings

        mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
            "1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True}
        }
        self._stub_llm(mock_llm, mock_llm_handler)
        agent = ClassicAgent(**agent_base_params)

        list(agent._gen_inner("Test query", log_context))

        # Smoke test: loading the user's tools must not prevent generation
        # from reaching the LLM.
        assert hasattr(agent, "tools")
        mock_llm.gen_stream.assert_called()

    def test_gen_inner_builds_correct_messages(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
        log_context,
    ):
        self._stub_llm(mock_llm, mock_llm_handler)
        agent = ClassicAgent(**agent_base_params)

        list(agent._gen_inner("Test query", log_context))

        messages = mock_llm.gen_stream.call_args.kwargs["messages"]
        # System prompt first, the user's query last.
        assert len(messages) >= 2
        assert messages[0]["role"] == "system"
        assert messages[-1]["role"] == "user"
        assert messages[-1]["content"] == "Test query"

    def test_gen_inner_logs_tool_calls(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
        log_context,
    ):
        self._stub_llm(mock_llm, mock_llm_handler)
        agent = ClassicAgent(**agent_base_params)
        agent.tool_calls = [{"tool": "test", "result": "success"}]

        list(agent._gen_inner("Test query", log_context))

        agent_logs = [s for s in log_context.stacks if s["component"] == "agent"]
        assert len(agent_logs) == 1
        assert "tool_calls" in agent_logs[0]["data"]
@pytest.mark.integration
class TestClassicAgentIntegration:
    """Tests for the public ``ClassicAgent.gen`` entry point."""

    def test_gen_method_with_logging(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
    ):
        mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))

        def fake_process_message_flow(*args, **kwargs):
            yield "Processed"

        # A generator function as side_effect gives each call a fresh generator.
        mock_llm_handler.process_message_flow = Mock(
            side_effect=fake_process_message_flow
        )
        agent = ClassicAgent(**agent_base_params)

        results = list(agent.gen("Test query"))

        assert len(results) >= 1

    def test_gen_method_decorator_applied(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
    ):
        # Only construction is needed: ``gen`` is never invoked here, so the
        # LLM stream and handler need no stubbing.
        agent = ClassicAgent(**agent_base_params)

        # ``__wrapped__`` is what ``functools.wraps`` sets on a decorated
        # function; its presence shows ``gen`` is wrapped by a decorator.
        assert hasattr(agent.gen, "__wrapped__")