# Mirror of https://github.com/arc53/DocsGPT.git
# Synced 2025-11-29 16:43:16 +00:00
#
# Commit summary:
# - Add 104 comprehensive tests for agent system
# - Integrate agent tests into CI/CD pipeline
# - Standardize tests with @pytest.mark.unit markers
# - Fix cross-platform path compatibility
# - Clean up unused imports and dependencies
#
# File: 243 lines, 7.0 KiB, Python
from unittest.mock import Mock
|
|
|
|
import pytest
|
|
from application.agents.classic_agent import ClassicAgent
|
|
|
|
|
|
@pytest.mark.unit
class TestClassicAgent:
    """Unit tests for ``ClassicAgent`` construction and ``_gen_inner``.

    Every test receives its collaborators as pytest fixtures
    (``agent_base_params``, ``mock_retriever``, ``mock_llm``, ...). The
    fixtures are defined outside this module -- presumably in a shared
    ``conftest.py``; confirm there for what each one patches.
    """

    def test_classic_agent_initialization(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """The agent exposes the ``endpoint`` and ``llm_name`` it was built with."""
        agent = ClassicAgent(**agent_base_params)

        assert isinstance(agent, ClassicAgent)
        assert agent.endpoint == agent_base_params["endpoint"]
        assert agent.llm_name == agent_base_params["llm_name"]

    def test_gen_inner_basic_flow(
        self,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
        log_context,
    ):
        """``_gen_inner`` yields exactly one ``sources`` and one ``tool_calls`` item."""

        def mock_gen_stream(*args, **kwargs):
            yield "Answer chunk 1"
            yield "Answer chunk 2"

        # ``return_value`` is one generator object, so every call to
        # ``gen_stream`` hands back the same single-use stream.
        mock_llm.gen_stream = Mock(return_value=mock_gen_stream())

        def mock_handler(*args, **kwargs):
            yield "Processed answer"

        # ``side_effect`` with a generator function gives a fresh generator
        # on each call.
        mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)

        agent = ClassicAgent(**agent_base_params)

        results = list(agent._gen_inner("Test query", mock_retriever, log_context))

        assert len(results) >= 2
        sources = [r for r in results if "sources" in r]
        tool_calls = [r for r in results if "tool_calls" in r]

        assert len(sources) == 1
        assert len(tool_calls) == 1

    def test_gen_inner_retrieves_documents(
        self,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
        log_context,
    ):
        """``_gen_inner`` calls ``retriever.search`` once, with the raw query."""
        mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))

        def mock_handler(*args, **kwargs):
            yield "Processed"

        mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)

        agent = ClassicAgent(**agent_base_params)
        # The generator must be drained for the retrieval call to happen.
        list(agent._gen_inner("Test query", mock_retriever, log_context))

        mock_retriever.search.assert_called_once_with("Test query")

    def test_gen_inner_uses_user_api_key_tools(
        self,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
        log_context,
    ):
        """``_gen_inner`` completes when the agent is built with a ``user_api_key``.

        Seeds the mocked ``agents`` collection with an entry keyed by the API
        key that references one tool id, and the ``user_tools`` collection
        with that tool.
        """
        from application.core.settings import settings
        from bson.objectid import ObjectId

        tool_id = str(ObjectId())
        mock_mongo_db[settings.MONGO_DB_NAME]["agents"].docs = {
            "api_key_123": {"key": "api_key_123", "tools": [tool_id]}
        }
        mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
            tool_id: {"_id": ObjectId(tool_id), "name": "test_tool"}
        }

        mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))

        def mock_handler(*args, **kwargs):
            yield "Processed"

        mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)

        agent_base_params["user_api_key"] = "api_key_123"
        agent = ClassicAgent(**agent_base_params)

        list(agent._gen_inner("Test query", mock_retriever, log_context))

        # NOTE(review): ``len(x) >= 0`` is true for every sized object, so
        # this only proves the run finished and ``agent.tools`` supports
        # ``len()``. Tighten it once the expected tool entries are confirmed.
        assert len(agent.tools) >= 0

    def test_gen_inner_uses_user_tools(
        self,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
        log_context,
    ):
        """``_gen_inner`` completes when ``user_tools`` holds a tool for ``test_user``."""
        from application.core.settings import settings

        mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
            "1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True}
        }

        mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))

        def mock_handler(*args, **kwargs):
            yield "Processed"

        mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)

        agent = ClassicAgent(**agent_base_params)
        list(agent._gen_inner("Test query", mock_retriever, log_context))

        # NOTE(review): always true for any sized object; see the matching
        # note in ``test_gen_inner_uses_user_api_key_tools``.
        assert len(agent.tools) >= 0

    def test_gen_inner_builds_correct_messages(
        self,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
        log_context,
    ):
        """Messages sent to the LLM start with ``system`` and end with the user query."""
        mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))

        def mock_handler(*args, **kwargs):
            yield "Processed"

        mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)

        agent = ClassicAgent(**agent_base_params)
        list(agent._gen_inner("Test query", mock_retriever, log_context))

        # ``call_args[1]`` is the kwargs dict of the most recent call.
        call_kwargs = mock_llm.gen_stream.call_args[1]
        messages = call_kwargs["messages"]

        assert len(messages) >= 2
        assert messages[0]["role"] == "system"
        assert messages[-1]["role"] == "user"
        assert messages[-1]["content"] == "Test query"

    def test_gen_inner_logs_tool_calls(
        self,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
        log_context,
    ):
        """Exactly one ``agent`` log entry is recorded, and it carries ``tool_calls``."""
        mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))

        def mock_handler(*args, **kwargs):
            yield "Processed"

        mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)

        agent = ClassicAgent(**agent_base_params)
        # Pre-populate tool calls before the run.
        agent.tool_calls = [{"tool": "test", "result": "success"}]

        list(agent._gen_inner("Test query", mock_retriever, log_context))

        agent_logs = [s for s in log_context.stacks if s["component"] == "agent"]
        assert len(agent_logs) == 1
        assert "tool_calls" in agent_logs[0]["data"]
|
|
|
|
|
|
@pytest.mark.integration
class TestClassicAgentIntegration:
    """Tests for the public ``ClassicAgent.gen`` entry point.

    Collaborators arrive as pytest fixtures defined outside this module --
    presumably in a shared ``conftest.py``; confirm there for what each one
    patches.
    """

    def test_gen_method_with_logging(
        self,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
    ):
        """``gen`` yields at least one item for a simple query."""
        mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))

        def mock_handler(*args, **kwargs):
            yield "Processed"

        mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)

        agent = ClassicAgent(**agent_base_params)

        results = list(agent.gen("Test query", mock_retriever))

        assert len(results) >= 1

    def test_gen_method_decorator_applied(
        self,
        agent_base_params,
        mock_retriever,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
    ):
        """``gen`` exposes ``__wrapped__``, i.e. it has been wrapped by a decorator."""
        # NOTE(review): ``gen`` is never called here, so this mock wiring
        # looks unused. It is kept in case ``ClassicAgent`` construction
        # depends on it -- confirm before removing.
        mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))

        def mock_handler(*args, **kwargs):
            yield "Processed"

        mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)

        agent = ClassicAgent(**agent_base_params)

        assert hasattr(agent.gen, "__wrapped__")
|