test: add agent test coverage and standardize test suite (#2051)

- Add 104 comprehensive tests for agent system
- Integrate agent tests into CI/CD pipeline
- Standardize tests with @pytest.mark.unit markers
- Fix cross-platform path compatibility
- Clean up unused imports and dependencies
This commit is contained in:
Siddhant Rai
2025-10-13 17:13:35 +05:30
committed by GitHub
parent 1805292528
commit d6c49bdbf0
24 changed files with 2775 additions and 501 deletions

0
tests/agents/__init__.py Normal file
View File

View File

@@ -0,0 +1,56 @@
import pytest
from application.agents.agent_creator import AgentCreator
from application.agents.classic_agent import ClassicAgent
from application.agents.react_agent import ReActAgent
@pytest.mark.unit
class TestAgentCreator:
def test_create_classic_agent(self, agent_base_params):
agent = AgentCreator.create_agent("classic", **agent_base_params)
assert isinstance(agent, ClassicAgent)
assert agent.endpoint == agent_base_params["endpoint"]
assert agent.llm_name == agent_base_params["llm_name"]
assert agent.gpt_model == agent_base_params["gpt_model"]
def test_create_react_agent(self, agent_base_params):
agent = AgentCreator.create_agent("react", **agent_base_params)
assert isinstance(agent, ReActAgent)
assert agent.endpoint == agent_base_params["endpoint"]
assert agent.llm_name == agent_base_params["llm_name"]
def test_create_agent_case_insensitive(self, agent_base_params):
agent_upper = AgentCreator.create_agent("CLASSIC", **agent_base_params)
agent_mixed = AgentCreator.create_agent("ClAsSiC", **agent_base_params)
assert isinstance(agent_upper, ClassicAgent)
assert isinstance(agent_mixed, ClassicAgent)
def test_create_agent_invalid_type(self, agent_base_params):
with pytest.raises(ValueError, match="No agent class found for type"):
AgentCreator.create_agent("invalid_agent_type", **agent_base_params)
def test_agent_registry_contains_expected_agents(self):
assert "classic" in AgentCreator.agents
assert "react" in AgentCreator.agents
assert AgentCreator.agents["classic"] == ClassicAgent
assert AgentCreator.agents["react"] == ReActAgent
def test_create_agent_with_optional_params(self, agent_base_params):
agent_base_params["user_api_key"] = "user_key_123"
agent_base_params["chat_history"] = [{"prompt": "test", "response": "test"}]
agent_base_params["json_schema"] = {"type": "object"}
agent = AgentCreator.create_agent("classic", **agent_base_params)
assert agent.user_api_key == "user_key_123"
assert len(agent.chat_history) == 1
assert agent.json_schema == {"type": "object"}
def test_create_agent_with_attachments(self, agent_base_params):
attachments = [{"name": "file.txt", "content": "test"}]
agent_base_params["attachments"] = attachments
agent = AgentCreator.create_agent("classic", **agent_base_params)
assert agent.attachments == attachments

View File

@@ -0,0 +1,641 @@
from unittest.mock import Mock
import pytest
from application.agents.classic_agent import ClassicAgent
from application.core.settings import settings
from tests.conftest import FakeMongoCollection
@pytest.mark.unit
class TestBaseAgentInitialization:
def test_agent_initialization(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
assert agent.endpoint == agent_base_params["endpoint"]
assert agent.llm_name == agent_base_params["llm_name"]
assert agent.gpt_model == agent_base_params["gpt_model"]
assert agent.api_key == agent_base_params["api_key"]
assert agent.prompt == agent_base_params["prompt"]
assert agent.user == agent_base_params["decoded_token"]["sub"]
assert agent.tools == []
assert agent.tool_calls == []
def test_agent_initialization_with_none_chat_history(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent_base_params["chat_history"] = None
agent = ClassicAgent(**agent_base_params)
assert agent.chat_history == []
def test_agent_initialization_with_chat_history(
self,
agent_base_params,
sample_chat_history,
mock_llm_creator,
mock_llm_handler_creator,
):
agent_base_params["chat_history"] = sample_chat_history
agent = ClassicAgent(**agent_base_params)
assert len(agent.chat_history) == 2
assert agent.chat_history[0]["prompt"] == "What is Python?"
def test_agent_decoded_token_defaults_to_empty_dict(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent_base_params["decoded_token"] = None
agent = ClassicAgent(**agent_base_params)
assert agent.decoded_token == {}
assert agent.user is None
def test_agent_user_extracted_from_token(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent_base_params["decoded_token"] = {"sub": "user123"}
agent = ClassicAgent(**agent_base_params)
assert agent.user == "user123"
@pytest.mark.unit
class TestBaseAgentBuildMessages:
def test_build_messages_basic(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
system_prompt = "System: {summaries}"
query = "What is Python?"
retrieved_data = [
{"text": "Python is a programming language", "filename": "python.txt"}
]
messages = agent._build_messages(system_prompt, query, retrieved_data)
assert len(messages) >= 2
assert messages[0]["role"] == "system"
assert "Python is a programming language" in messages[0]["content"]
assert messages[-1]["role"] == "user"
assert messages[-1]["content"] == query
def test_build_messages_with_chat_history(
self,
agent_base_params,
sample_chat_history,
mock_llm_creator,
mock_llm_handler_creator,
):
agent_base_params["chat_history"] = sample_chat_history
agent = ClassicAgent(**agent_base_params)
system_prompt = "System: {summaries}"
query = "New question?"
retrieved_data = [{"text": "Data", "filename": "file.txt"}]
messages = agent._build_messages(system_prompt, query, retrieved_data)
user_messages = [m for m in messages if m["role"] == "user"]
assistant_messages = [m for m in messages if m["role"] == "assistant"]
assert len(user_messages) >= 3
assert len(assistant_messages) >= 2
def test_build_messages_with_tool_calls_in_history(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
tool_call_history = [
{
"tool_calls": [
{
"call_id": "123",
"action_name": "test_action",
"arguments": {"arg": "value"},
"result": "success",
}
]
}
]
agent_base_params["chat_history"] = tool_call_history
agent = ClassicAgent(**agent_base_params)
messages = agent._build_messages(
"System: {summaries}", "query", [{"text": "data", "filename": "file.txt"}]
)
tool_messages = [m for m in messages if m["role"] == "tool"]
assert len(tool_messages) > 0
def test_build_messages_handles_missing_filename(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
retrieved_data = [{"text": "Document without filename or title"}]
messages = agent._build_messages("System: {summaries}", "query", retrieved_data)
assert messages[0]["role"] == "system"
assert "Document without filename" in messages[0]["content"]
def test_build_messages_uses_title_as_fallback(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
retrieved_data = [{"text": "Data", "title": "Title Doc"}]
messages = agent._build_messages("System: {summaries}", "query", retrieved_data)
assert "Title Doc" in messages[0]["content"]
def test_build_messages_uses_source_as_fallback(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
retrieved_data = [{"text": "Data", "source": "source.txt"}]
messages = agent._build_messages("System: {summaries}", "query", retrieved_data)
assert "source.txt" in messages[0]["content"]
@pytest.mark.unit
class TestBaseAgentTools:
def test_get_user_tools(
self,
agent_base_params,
mock_mongo_db,
mock_llm_creator,
mock_llm_handler_creator,
):
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
"1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True},
"2": {"_id": "2", "user": "test_user", "name": "tool2", "status": True},
}
agent = ClassicAgent(**agent_base_params)
tools = agent._get_user_tools("test_user")
assert len(tools) == 2
assert "0" in tools
assert "1" in tools
def test_get_user_tools_filters_by_status(
self,
agent_base_params,
mock_mongo_db,
mock_llm_creator,
mock_llm_handler_creator,
):
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
"1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True},
"2": {"_id": "2", "user": "test_user", "name": "tool2", "status": False},
}
agent = ClassicAgent(**agent_base_params)
tools = agent._get_user_tools("test_user")
assert len(tools) == 1
def test_get_tools_by_api_key(
self,
agent_base_params,
mock_mongo_db,
mock_llm_creator,
mock_llm_handler_creator,
):
from bson.objectid import ObjectId
tool_id = str(ObjectId())
tool_obj_id = ObjectId(tool_id)
fake_agent_collection = FakeMongoCollection()
fake_agent_collection.docs["api_key_123"] = {
"key": "api_key_123",
"tools": [tool_id],
}
fake_tools_collection = FakeMongoCollection()
fake_tools_collection.docs[tool_id] = {"_id": tool_obj_id, "name": "api_tool"}
mock_mongo_db[settings.MONGO_DB_NAME]["agents"] = fake_agent_collection
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"] = fake_tools_collection
agent = ClassicAgent(**agent_base_params)
tools = agent._get_tools("api_key_123")
assert tool_id in tools
def test_build_tool_parameters(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
action = {
"parameters": {
"properties": {
"param1": {
"type": "string",
"description": "Test param",
"filled_by_llm": True,
},
"param2": {"type": "number", "filled_by_llm": False, "value": 42},
}
}
}
params = agent._build_tool_parameters(action)
assert "param1" in params["properties"]
assert "param1" in params["required"]
assert "filled_by_llm" not in params["properties"]["param1"]
def test_prepare_tools_with_api_tool(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
tools_dict = {
"1": {
"name": "api_tool",
"config": {
"actions": {
"get_data": {
"name": "get_data",
"description": "Get data from API",
"active": True,
"url": "https://api.example.com/data",
"method": "GET",
"parameters": {"properties": {}},
}
}
},
}
}
agent._prepare_tools(tools_dict)
assert len(agent.tools) == 1
assert agent.tools[0]["type"] == "function"
assert agent.tools[0]["function"]["name"] == "get_data_1"
def test_prepare_tools_with_regular_tool(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
tools_dict = {
"1": {
"name": "custom_tool",
"actions": [
{
"name": "action1",
"description": "Custom action",
"active": True,
"parameters": {"properties": {}},
}
],
}
}
agent._prepare_tools(tools_dict)
assert len(agent.tools) == 1
assert agent.tools[0]["function"]["name"] == "action1_1"
def test_prepare_tools_filters_inactive_actions(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
tools_dict = {
"1": {
"name": "custom_tool",
"actions": [
{
"name": "active_action",
"description": "Active",
"active": True,
"parameters": {"properties": {}},
},
{
"name": "inactive_action",
"description": "Inactive",
"active": False,
"parameters": {"properties": {}},
},
],
}
}
agent._prepare_tools(tools_dict)
assert len(agent.tools) == 1
assert agent.tools[0]["function"]["name"] == "active_action_1"
@pytest.mark.unit
class TestBaseAgentToolExecution:
def test_execute_tool_action_success(
self,
agent_base_params,
mock_llm_creator,
mock_llm_handler_creator,
mock_tool_manager,
):
agent = ClassicAgent(**agent_base_params)
call = Mock()
call.id = "call_123"
call.name = "test_action_1"
call.arguments = '{"param1": "value1"}'
tools_dict = {
"1": {
"name": "custom_tool",
"config": {},
"actions": [
{
"name": "test_action",
"description": "Test",
"parameters": {"properties": {}},
}
],
}
}
results = list(agent._execute_tool_action(tools_dict, call))
assert len(results) >= 2
assert results[0]["type"] == "tool_call"
assert results[0]["data"]["status"] == "pending"
assert results[-1]["data"]["status"] == "completed"
def test_execute_tool_action_invalid_tool_name(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
call = Mock()
call.id = "call_123"
call.name = "invalid_format"
call.arguments = "{}"
tools_dict = {}
results = list(agent._execute_tool_action(tools_dict, call))
assert results[0]["type"] == "tool_call"
assert results[0]["data"]["status"] == "error"
assert (
"Failed to parse" in results[0]["data"]["result"]
or "not found" in results[0]["data"]["result"]
)
def test_execute_tool_action_tool_not_found(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
call = Mock()
call.id = "call_123"
call.name = "action_999"
call.arguments = "{}"
tools_dict = {"1": {"name": "tool1", "config": {}, "actions": []}}
results = list(agent._execute_tool_action(tools_dict, call))
assert results[0]["type"] == "tool_call"
assert results[0]["data"]["status"] == "error"
assert "not found" in results[0]["data"]["result"]
def test_execute_tool_action_with_parameters(
self,
agent_base_params,
mock_llm_creator,
mock_llm_handler_creator,
mock_tool_manager,
):
agent = ClassicAgent(**agent_base_params)
call = Mock()
call.id = "call_123"
call.name = "test_action_1"
call.arguments = '{"param1": "value1", "param2": "value2"}'
tools_dict = {
"1": {
"name": "custom_tool",
"config": {},
"actions": [
{
"name": "test_action",
"description": "Test",
"parameters": {
"properties": {
"param1": {"type": "string"},
"param2": {"type": "string"},
}
},
}
],
}
}
results = list(agent._execute_tool_action(tools_dict, call))
assert results[-1]["data"]["status"] == "completed"
assert results[-1]["data"]["arguments"]["param1"] == "value1"
def test_get_truncated_tool_calls(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
agent.tool_calls = [
{
"tool_name": "test_tool",
"call_id": "123",
"action_name": "action",
"arguments": {},
"result": "a" * 100,
}
]
truncated = agent._get_truncated_tool_calls()
assert len(truncated) == 1
assert len(truncated[0]["result"]) <= 53
assert truncated[0]["result"].endswith("...")
@pytest.mark.unit
class TestBaseAgentRetrieverSearch:
def test_retriever_search(
self,
agent_base_params,
mock_retriever,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
agent = ClassicAgent(**agent_base_params)
results = agent._retriever_search(mock_retriever, "test query", log_context)
assert len(results) == 2
mock_retriever.search.assert_called_once_with("test query")
def test_retriever_search_logs_context(
self,
agent_base_params,
mock_retriever,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
agent = ClassicAgent(**agent_base_params)
agent._retriever_search(mock_retriever, "test query", log_context)
assert len(log_context.stacks) == 1
assert log_context.stacks[0]["component"] == "retriever"
@pytest.mark.unit
class TestBaseAgentLLMGeneration:
def test_llm_gen_basic(
self,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
agent = ClassicAgent(**agent_base_params)
messages = [{"role": "user", "content": "test"}]
agent._llm_gen(messages, log_context)
mock_llm.gen_stream.assert_called_once()
call_args = mock_llm.gen_stream.call_args[1]
assert call_args["model"] == agent.gpt_model
assert call_args["messages"] == messages
def test_llm_gen_with_tools(
self,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
agent = ClassicAgent(**agent_base_params)
agent.tools = [{"type": "function", "function": {"name": "test"}}]
messages = [{"role": "user", "content": "test"}]
agent._llm_gen(messages, log_context)
call_args = mock_llm.gen_stream.call_args[1]
assert "tools" in call_args
assert call_args["tools"] == agent.tools
def test_llm_gen_with_json_schema(
self,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
mock_llm._supports_structured_output = Mock(return_value=True)
mock_llm.prepare_structured_output_format = Mock(
return_value={"schema": "test"}
)
agent_base_params["json_schema"] = {"type": "object"}
agent_base_params["llm_name"] = "openai"
agent = ClassicAgent(**agent_base_params)
messages = [{"role": "user", "content": "test"}]
agent._llm_gen(messages, log_context)
call_args = mock_llm.gen_stream.call_args[1]
assert "response_format" in call_args
@pytest.mark.unit
class TestBaseAgentHandleResponse:
def test_handle_response_string(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator, log_context
):
agent = ClassicAgent(**agent_base_params)
response = "Simple string response"
results = list(agent._handle_response(response, {}, [], log_context))
assert len(results) == 1
assert results[0]["answer"] == "Simple string response"
def test_handle_response_with_message(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator, log_context
):
agent = ClassicAgent(**agent_base_params)
response = Mock()
response.message = Mock()
response.message.content = "Message content"
results = list(agent._handle_response(response, {}, [], log_context))
assert len(results) == 1
assert results[0]["answer"] == "Message content"
def test_handle_response_with_structured_output(
self,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
mock_llm._supports_structured_output = Mock(return_value=True)
agent_base_params["json_schema"] = {"type": "object"}
agent = ClassicAgent(**agent_base_params)
response = "Structured response"
results = list(agent._handle_response(response, {}, [], log_context))
assert results[0]["structured"] is True
assert results[0]["schema"] == {"type": "object"}
def test_handle_response_with_handler(
self,
agent_base_params,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
def mock_process(*args):
yield {"type": "tool_call", "data": {}}
yield "Final answer"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_process)
agent = ClassicAgent(**agent_base_params)
response = Mock()
response.message = None
results = list(agent._handle_response(response, {}, [], log_context))
assert len(results) == 2
assert results[0]["type"] == "tool_call"
assert results[1]["answer"] == "Final answer"

View File

@@ -0,0 +1,242 @@
from unittest.mock import Mock
import pytest
from application.agents.classic_agent import ClassicAgent
@pytest.mark.unit
class TestClassicAgent:
def test_classic_agent_initialization(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
assert isinstance(agent, ClassicAgent)
assert agent.endpoint == agent_base_params["endpoint"]
assert agent.llm_name == agent_base_params["llm_name"]
def test_gen_inner_basic_flow(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
def mock_gen_stream(*args, **kwargs):
yield "Answer chunk 1"
yield "Answer chunk 2"
mock_llm.gen_stream = Mock(return_value=mock_gen_stream())
def mock_handler(*args, **kwargs):
yield "Processed answer"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ClassicAgent(**agent_base_params)
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
assert len(results) >= 2
sources = [r for r in results if "sources" in r]
tool_calls = [r for r in results if "tool_calls" in r]
assert len(sources) == 1
assert len(tool_calls) == 1
def test_gen_inner_retrieves_documents(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
def mock_handler(*args, **kwargs):
yield "Processed"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ClassicAgent(**agent_base_params)
list(agent._gen_inner("Test query", mock_retriever, log_context))
mock_retriever.search.assert_called_once_with("Test query")
def test_gen_inner_uses_user_api_key_tools(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
from application.core.settings import settings
from bson.objectid import ObjectId
tool_id = str(ObjectId())
mock_mongo_db[settings.MONGO_DB_NAME]["agents"].docs = {
"api_key_123": {"key": "api_key_123", "tools": [tool_id]}
}
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
tool_id: {"_id": ObjectId(tool_id), "name": "test_tool"}
}
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
def mock_handler(*args, **kwargs):
yield "Processed"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent_base_params["user_api_key"] = "api_key_123"
agent = ClassicAgent(**agent_base_params)
list(agent._gen_inner("Test query", mock_retriever, log_context))
assert len(agent.tools) >= 0
def test_gen_inner_uses_user_tools(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
from application.core.settings import settings
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
"1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True}
}
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
def mock_handler(*args, **kwargs):
yield "Processed"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ClassicAgent(**agent_base_params)
list(agent._gen_inner("Test query", mock_retriever, log_context))
assert len(agent.tools) >= 0
def test_gen_inner_builds_correct_messages(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
def mock_handler(*args, **kwargs):
yield "Processed"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ClassicAgent(**agent_base_params)
list(agent._gen_inner("Test query", mock_retriever, log_context))
call_kwargs = mock_llm.gen_stream.call_args[1]
messages = call_kwargs["messages"]
assert len(messages) >= 2
assert messages[0]["role"] == "system"
assert messages[-1]["role"] == "user"
assert messages[-1]["content"] == "Test query"
def test_gen_inner_logs_tool_calls(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
def mock_handler(*args, **kwargs):
yield "Processed"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ClassicAgent(**agent_base_params)
agent.tool_calls = [{"tool": "test", "result": "success"}]
list(agent._gen_inner("Test query", mock_retriever, log_context))
agent_logs = [s for s in log_context.stacks if s["component"] == "agent"]
assert len(agent_logs) == 1
assert "tool_calls" in agent_logs[0]["data"]
@pytest.mark.integration
class TestClassicAgentIntegration:
def test_gen_method_with_logging(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
def mock_handler(*args, **kwargs):
yield "Processed"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ClassicAgent(**agent_base_params)
results = list(agent.gen("Test query", mock_retriever))
assert len(results) >= 1
def test_gen_method_decorator_applied(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
def mock_handler(*args, **kwargs):
yield "Processed"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ClassicAgent(**agent_base_params)
assert hasattr(agent.gen, "__wrapped__")

View File

@@ -0,0 +1,519 @@
from unittest.mock import Mock, mock_open, patch
import pytest
from application.agents.react_agent import ReActAgent
@pytest.mark.unit
class TestReActAgent:
def test_react_agent_initialization(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
assert isinstance(agent, ReActAgent)
assert agent.plan == ""
assert agent.observations == []
def test_react_agent_inherits_base_properties(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
assert agent.endpoint == agent_base_params["endpoint"]
assert agent.llm_name == agent_base_params["llm_name"]
assert agent.gpt_model == agent_base_params["gpt_model"]
@pytest.mark.unit
class TestReActAgentContentExtraction:
def test_extract_content_from_string(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
response = "Simple string response"
content = agent._extract_content_from_llm_response(response)
assert content == "Simple string response"
def test_extract_content_from_message_object(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
response = Mock()
response.message = Mock()
response.message.content = "Message content"
content = agent._extract_content_from_llm_response(response)
assert content == "Message content"
def test_extract_content_from_openai_response(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
response = Mock()
response.choices = [Mock()]
response.choices[0].message = Mock()
response.choices[0].message.content = "OpenAI content"
response.message = None
response.content = None
content = agent._extract_content_from_llm_response(response)
assert content == "OpenAI content"
def test_extract_content_from_anthropic_response(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
text_block = Mock()
text_block.text = "Anthropic content"
response = Mock()
response.content = [text_block]
response.message = None
response.choices = None
content = agent._extract_content_from_llm_response(response)
assert content == "Anthropic content"
def test_extract_content_from_openai_stream(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
chunk1 = Mock()
chunk1.choices = [Mock()]
chunk1.choices[0].delta = Mock()
chunk1.choices[0].delta.content = "Part 1 "
chunk2 = Mock()
chunk2.choices = [Mock()]
chunk2.choices[0].delta = Mock()
chunk2.choices[0].delta.content = "Part 2"
response = iter([chunk1, chunk2])
content = agent._extract_content_from_llm_response(response)
assert content == "Part 1 Part 2"
def test_extract_content_from_anthropic_stream(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
chunk1 = Mock()
chunk1.type = "content_block_delta"
chunk1.delta = Mock()
chunk1.delta.text = "Stream 1 "
chunk1.choices = []
chunk2 = Mock()
chunk2.type = "content_block_delta"
chunk2.delta = Mock()
chunk2.delta.text = "Stream 2"
chunk2.choices = []
response = iter([chunk1, chunk2])
content = agent._extract_content_from_llm_response(response)
assert content == "Stream 1 Stream 2"
def test_extract_content_from_string_stream(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
response = iter(["chunk1", "chunk2", "chunk3"])
content = agent._extract_content_from_llm_response(response)
assert content == "chunk1chunk2chunk3"
def test_extract_content_handles_none_content(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
response = Mock()
response.message = Mock()
response.message.content = None
response.choices = None
response.content = None
content = agent._extract_content_from_llm_response(response)
assert content == ""
@pytest.mark.unit
class TestReActAgentPlanning:
@patch(
"builtins.open",
new_callable=mock_open,
read_data="Test planning prompt: {query} {summaries} {prompt} {observations}",
)
def test_create_plan(
self,
mock_file,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
def mock_gen_stream(*args, **kwargs):
yield "Plan step 1"
yield "Plan step 2"
mock_llm.gen_stream = Mock(return_value=mock_gen_stream())
agent = ReActAgent(**agent_base_params)
agent.observations = ["Observation 1"]
plan_chunks = list(agent._create_plan("Test query", "Test docs", log_context))
assert len(plan_chunks) == 2
assert plan_chunks[0] == "Plan step 1"
assert plan_chunks[1] == "Plan step 2"
mock_llm.gen_stream.assert_called_once()
@patch("builtins.open", new_callable=mock_open, read_data="Test: {query}")
def test_create_plan_fills_template(
self,
mock_file,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Plan"]))
agent = ReActAgent(**agent_base_params)
list(agent._create_plan("My query", "Docs", log_context))
call_args = mock_llm.gen_stream.call_args[1]
messages = call_args["messages"]
assert "My query" in messages[0]["content"]
@pytest.mark.unit
class TestReActAgentFinalAnswer:
@patch(
"builtins.open",
new_callable=mock_open,
read_data="Final answer for: {query} with {observations}",
)
def test_create_final_answer(
self,
mock_file,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
def mock_gen_stream(*args, **kwargs):
yield "Final "
yield "answer"
mock_llm.gen_stream = Mock(return_value=mock_gen_stream())
agent = ReActAgent(**agent_base_params)
observations = ["Obs 1", "Obs 2"]
answer_chunks = list(
agent._create_final_answer("Test query", observations, log_context)
)
assert len(answer_chunks) == 2
assert answer_chunks[0] == "Final "
assert answer_chunks[1] == "answer"
@patch("builtins.open", new_callable=mock_open, read_data="Answer: {observations}")
def test_create_final_answer_truncates_long_observations(
self,
mock_file,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
agent = ReActAgent(**agent_base_params)
long_obs = ["A" * 15000]
list(agent._create_final_answer("Query", long_obs, log_context))
call_args = mock_llm.gen_stream.call_args[1]
messages = call_args["messages"]
assert "observations truncated" in messages[0]["content"]
@patch("builtins.open", new_callable=mock_open, read_data="Test: {query}")
def test_create_final_answer_no_tools(
self,
mock_file,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
agent = ReActAgent(**agent_base_params)
list(agent._create_final_answer("Query", ["Obs"], log_context))
call_args = mock_llm.gen_stream.call_args[1]
assert call_args["tools"] is None
@pytest.mark.unit
class TestReActAgentGenInner:
@patch(
"builtins.open", new_callable=mock_open, read_data="Prompt template: {query}"
)
def test_gen_inner_resets_state(
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["SATISFIED"]))
def mock_handler(*args, **kwargs):
yield "SATISFIED"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
agent.plan = "Old plan"
agent.observations = ["Old obs"]
list(agent._gen_inner("New query", mock_retriever, log_context))
assert agent.plan != "Old plan"
assert len(agent.observations) > 0
@patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}")
def test_gen_inner_stops_on_satisfied(
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
iteration_count = 0
def mock_gen_stream(*args, **kwargs):
nonlocal iteration_count
iteration_count += 1
if iteration_count == 1:
yield "Plan"
else:
yield "SATISFIED - done"
mock_llm.gen_stream = Mock(
side_effect=lambda *args, **kwargs: mock_gen_stream(*args, **kwargs)
)
def mock_handler(*args, **kwargs):
yield "SATISFIED - done"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
assert any("answer" in r for r in results)
@patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}")
def test_gen_inner_max_iterations(
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
call_count = 0
def mock_gen_stream(*args, **kwargs):
nonlocal call_count
call_count += 1
yield f"Iteration {call_count}"
mock_llm.gen_stream = Mock(
side_effect=lambda *args, **kwargs: mock_gen_stream(*args, **kwargs)
)
def mock_handler(*args, **kwargs):
yield "Continue..."
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
thought_results = [r for r in results if "thought" in r]
assert len(thought_results) > 0
@patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}")
def test_gen_inner_yields_sources(
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["SATISFIED"]))
def mock_handler(*args, **kwargs):
yield "SATISFIED"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
sources = [r for r in results if "sources" in r]
assert len(sources) >= 1
@patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}")
def test_gen_inner_yields_tool_calls(
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["SATISFIED"]))
def mock_handler(*args, **kwargs):
yield "SATISFIED"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
agent.tool_calls = [{"tool": "test", "result": "A" * 100}]
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
tool_call_results = [r for r in results if "tool_calls" in r]
if tool_call_results:
assert len(tool_call_results[0]["tool_calls"][0]["result"]) <= 53
@patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}")
def test_gen_inner_logs_observations(
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["SATISFIED"]))
def mock_handler(*args, **kwargs):
yield "SATISFIED"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
list(agent._gen_inner("Test query", mock_retriever, log_context))
assert len(agent.observations) > 0
@pytest.mark.integration
class TestReActAgentIntegration:
@patch(
"builtins.open",
new_callable=mock_open,
read_data="Prompt: {query} {summaries} {prompt} {observations}",
)
def test_full_react_workflow(
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
call_sequence = []
def mock_gen_stream(*args, **kwargs):
call_sequence.append("gen_stream")
if len(call_sequence) <= 2:
yield "Planning..."
else:
yield "SATISFIED final answer"
mock_llm.gen_stream = Mock(
side_effect=lambda *args, **kwargs: mock_gen_stream(*args, **kwargs)
)
def mock_handler(*args, **kwargs):
call_sequence.append("handler")
yield "SATISFIED final answer"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
results = list(agent._gen_inner("Complex query", mock_retriever, log_context))
assert len(results) > 0
assert any("thought" in r for r in results)
assert any("answer" in r for r in results)

View File

@@ -0,0 +1,204 @@
from unittest.mock import Mock
import pytest
from application.agents.tools.tool_action_parser import ToolActionParser
@pytest.mark.unit
class TestToolActionParser:
def test_parser_initialization(self):
parser = ToolActionParser("OpenAILLM")
assert parser.llm_type == "OpenAILLM"
assert "OpenAILLM" in parser.parsers
assert "GoogleLLM" in parser.parsers
def test_parse_openai_llm_valid_call(self):
parser = ToolActionParser("OpenAILLM")
call = Mock()
call.name = "get_data_123"
call.arguments = '{"param1": "value1", "param2": "value2"}'
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "123"
assert action_name == "get_data"
assert call_args == {"param1": "value1", "param2": "value2"}
def test_parse_openai_llm_with_underscore_in_action(self):
parser = ToolActionParser("OpenAILLM")
call = Mock()
call.name = "send_email_notification_456"
call.arguments = '{"to": "user@example.com"}'
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "456"
assert action_name == "send_email_notification"
assert call_args == {"to": "user@example.com"}
def test_parse_openai_llm_invalid_format_no_underscore(self):
parser = ToolActionParser("OpenAILLM")
call = Mock()
call.name = "invalidtoolname"
call.arguments = "{}"
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id is None
assert action_name is None
assert call_args is None
def test_parse_openai_llm_non_numeric_tool_id(self):
parser = ToolActionParser("OpenAILLM")
call = Mock()
call.name = "action_notanumber"
call.arguments = "{}"
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "notanumber"
assert action_name == "action"
def test_parse_openai_llm_malformed_json(self):
parser = ToolActionParser("OpenAILLM")
call = Mock()
call.name = "action_123"
call.arguments = "invalid json"
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id is None
assert action_name is None
assert call_args is None
def test_parse_openai_llm_missing_attributes(self):
parser = ToolActionParser("OpenAILLM")
call = Mock(spec=[])
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id is None
assert action_name is None
assert call_args is None
def test_parse_google_llm_valid_call(self):
parser = ToolActionParser("GoogleLLM")
call = Mock()
call.name = "search_documents_789"
call.arguments = {"query": "test query", "limit": 10}
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "789"
assert action_name == "search_documents"
assert call_args == {"query": "test query", "limit": 10}
def test_parse_google_llm_with_complex_action_name(self):
parser = ToolActionParser("GoogleLLM")
call = Mock()
call.name = "create_new_user_account_999"
call.arguments = {"username": "test"}
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "999"
assert action_name == "create_new_user_account"
def test_parse_google_llm_invalid_format(self):
parser = ToolActionParser("GoogleLLM")
call = Mock()
call.name = "nounderscores"
call.arguments = {}
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id is None
assert action_name is None
assert call_args is None
def test_parse_google_llm_missing_attributes(self):
parser = ToolActionParser("GoogleLLM")
call = Mock(spec=[])
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id is None
assert action_name is None
assert call_args is None
def test_parse_unknown_llm_type_defaults_to_openai(self):
parser = ToolActionParser("UnknownLLM")
call = Mock()
call.name = "action_123"
call.arguments = '{"key": "value"}'
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "123"
assert action_name == "action"
assert call_args == {"key": "value"}
def test_parse_args_empty_arguments_openai(self):
parser = ToolActionParser("OpenAILLM")
call = Mock()
call.name = "action_123"
call.arguments = "{}"
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "123"
assert action_name == "action"
assert call_args == {}
def test_parse_args_empty_arguments_google(self):
parser = ToolActionParser("GoogleLLM")
call = Mock()
call.name = "action_456"
call.arguments = {}
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "456"
assert action_name == "action"
assert call_args == {}
def test_parse_args_with_special_characters(self):
parser = ToolActionParser("OpenAILLM")
call = Mock()
call.name = "send_message_123"
call.arguments = '{"message": "Hello, World! 你好"}'
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "123"
assert action_name == "send_message"
assert call_args["message"] == "Hello, World! 你好"
def test_parse_args_with_nested_objects(self):
parser = ToolActionParser("OpenAILLM")
call = Mock()
call.name = "create_record_123"
call.arguments = '{"data": {"name": "John", "age": 30}}'
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "123"
assert action_name == "create_record"
assert call_args["data"]["name"] == "John"
assert call_args["data"]["age"] == 30

View File

@@ -0,0 +1,235 @@
from unittest.mock import Mock, patch
import pytest
from application.agents.tools.base import Tool
from application.agents.tools.tool_manager import ToolManager
class MockTool(Tool):
def __init__(self, config):
self.config = config
def execute_action(self, action_name: str, **kwargs):
return f"Executed {action_name} with {kwargs}"
def get_actions_metadata(self):
return [{"name": "test_action", "description": "Test action"}]
def get_config_requirements(self):
return {"required": ["api_key"]}
@pytest.mark.unit
class TestToolManager:
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
def test_tool_manager_initialization(self, mock_iter):
mock_iter.return_value = []
config = {"tool1": {"key": "value"}}
manager = ToolManager(config)
assert manager.config == config
assert isinstance(manager.tools, dict)
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
@patch("application.agents.tools.tool_manager.importlib.import_module")
def test_load_tools_skips_base_and_private(self, mock_import, mock_iter):
mock_iter.return_value = [
(None, "base", False),
(None, "__init__", False),
(None, "__pycache__", False),
(None, "valid_tool", False),
]
mock_module = Mock()
mock_module.MockTool = MockTool
mock_import.return_value = mock_module
manager = ToolManager({})
assert "base" not in manager.tools
assert "__init__" not in manager.tools
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
def test_load_tools_creates_tool_instances(self, mock_iter):
mock_iter.return_value = []
manager = ToolManager({})
mock_tool = MockTool({"test": "config"})
manager.tools["mock_tool"] = mock_tool
assert "mock_tool" in manager.tools
assert isinstance(manager.tools["mock_tool"], MockTool)
assert manager.tools["mock_tool"].config == {"test": "config"}
def test_load_tool_with_user_id(self):
with patch(
"application.agents.tools.tool_manager.pkgutil.iter_modules",
return_value=[],
):
manager = ToolManager({})
tool = MockTool({"key": "value"})
assert tool.config == {"key": "value"}
manager.config["test_tool"] = {"key": "value"}
assert "test_tool" in manager.config
def test_load_tool_without_user_id(self):
tool = MockTool({"api_key": "test123"})
assert isinstance(tool, MockTool)
assert tool.config == {"api_key": "test123"}
assert hasattr(tool, "execute_action")
assert hasattr(tool, "get_actions_metadata")
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
def test_load_tool_updates_config(self, mock_iter):
mock_iter.return_value = []
manager = ToolManager({})
new_config = {"new_key": "new_value"}
manager.config["test_tool"] = new_config
assert manager.config["test_tool"] == new_config
assert "test_tool" in manager.config
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
@patch("application.agents.tools.tool_manager.importlib.import_module")
def test_execute_action_on_loaded_tool(self, mock_import, mock_iter):
mock_iter.return_value = [(None, "mock_tool", False)]
mock_tool_instance = MockTool({})
with patch("inspect.getmembers", return_value=[("MockTool", MockTool)]):
with patch("inspect.isclass", return_value=True):
with patch.object(MockTool, "__init__", return_value=None):
manager = ToolManager({})
manager.tools["mock_tool"] = mock_tool_instance
result = manager.execute_action(
"mock_tool", "test_action", param="value"
)
assert "Executed test_action" in result
def test_execute_action_tool_not_loaded(self):
with patch(
"application.agents.tools.tool_manager.pkgutil.iter_modules",
return_value=[],
):
manager = ToolManager({})
with pytest.raises(ValueError, match="Tool 'nonexistent' not loaded"):
manager.execute_action("nonexistent", "action")
@patch("application.agents.tools.tool_manager.importlib.import_module")
def test_execute_action_with_user_id_for_mcp_tool(self, mock_import):
mock_tool = MockTool({})
with patch("inspect.getmembers", return_value=[("MockTool", MockTool)]):
with patch("inspect.isclass", return_value=True):
manager = ToolManager({"mcp_tool": {}})
manager.tools["mcp_tool"] = mock_tool
with patch.object(
manager, "load_tool", return_value=mock_tool
) as mock_load:
manager.execute_action("mcp_tool", "action", user_id="user123")
mock_load.assert_called_once_with("mcp_tool", {}, "user123")
@patch("application.agents.tools.tool_manager.importlib.import_module")
def test_execute_action_with_user_id_for_memory_tool(self, mock_import):
mock_tool = MockTool({})
with patch("inspect.getmembers", return_value=[("MockTool", MockTool)]):
with patch("inspect.isclass", return_value=True):
manager = ToolManager({"memory": {}})
manager.tools["memory"] = mock_tool
with patch.object(
manager, "load_tool", return_value=mock_tool
) as mock_load:
manager.execute_action("memory", "view", user_id="user456")
mock_load.assert_called_once_with("memory", {}, "user456")
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
@patch("application.agents.tools.tool_manager.importlib.import_module")
def test_get_all_actions_metadata(self, mock_import, mock_iter):
mock_iter.return_value = [(None, "tool1", False), (None, "tool2", False)]
mock_tool1 = Mock()
mock_tool1.get_actions_metadata.return_value = [{"name": "action1"}]
mock_tool2 = Mock()
mock_tool2.get_actions_metadata.return_value = [{"name": "action2"}]
manager = ToolManager({})
manager.tools = {"tool1": mock_tool1, "tool2": mock_tool2}
metadata = manager.get_all_actions_metadata()
assert len(metadata) == 2
assert {"name": "action1"} in metadata
assert {"name": "action2"} in metadata
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
def test_get_all_actions_metadata_empty(self, mock_iter):
mock_iter.return_value = []
manager = ToolManager({})
manager.tools = {}
metadata = manager.get_all_actions_metadata()
assert metadata == []
def test_load_tool_with_notes_tool(self):
tool = MockTool({"key": "value"})
assert isinstance(tool, MockTool)
assert tool.config == {"key": "value"}
result = tool.execute_action("test_action", param="value")
assert "test_action" in result
@pytest.mark.unit
class TestToolBase:
def test_tool_base_is_abstract(self):
with pytest.raises(TypeError):
Tool()
def test_mock_tool_implements_interface(self):
tool = MockTool({"test": "config"})
assert hasattr(tool, "execute_action")
assert hasattr(tool, "get_actions_metadata")
assert hasattr(tool, "get_config_requirements")
def test_mock_tool_execute_action(self):
tool = MockTool({})
result = tool.execute_action("test", param="value")
assert "Executed test" in result
assert "param" in result
def test_mock_tool_get_actions_metadata(self):
tool = MockTool({})
metadata = tool.get_actions_metadata()
assert isinstance(metadata, list)
assert len(metadata) > 0
assert "name" in metadata[0]
def test_mock_tool_get_config_requirements(self):
tool = MockTool({})
requirements = tool.get_config_requirements()
assert isinstance(requirements, dict)
assert "required" in requirements