diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index d5b31109..8b85366a 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -16,15 +16,15 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install pytest pytest-cov cd application if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + cd ../tests + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Test with pytest and generate coverage report run: | - python -m pytest --cov=application --cov-report=xml + python -m pytest --cov=application --cov-report=xml --cov-report=term-missing - name: Upload coverage reports to Codecov if: github.event_name == 'pull_request' && matrix.python-version == '3.12' uses: codecov/codecov-action@v5 env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} - diff --git a/application/agents/base.py b/application/agents/base.py index f975abad..9cdaa13a 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -38,7 +38,7 @@ class BaseAgent(ABC): self.user_api_key = user_api_key self.prompt = prompt self.decoded_token = decoded_token or {} - self.user: str = decoded_token.get("sub") + self.user: str = self.decoded_token.get("sub") self.tool_config: Dict = {} self.tools: List[Dict] = [] self.tool_calls: List[Dict] = [] diff --git a/application/agents/tools/tool_action_parser.py b/application/agents/tools/tool_action_parser.py index ea544338..aff7d2e2 100644 --- a/application/agents/tools/tool_action_parser.py +++ b/application/agents/tools/tool_action_parser.py @@ -20,20 +20,24 @@ class ToolActionParser: try: call_args = json.loads(call.arguments) tool_parts = call.name.split("_") - + # If the tool name doesn't contain an underscore, it's likely a hallucinated tool if len(tool_parts) < 2: - logger.warning(f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id") + logger.warning( + f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id" + ) return None, None, None - + tool_id = tool_parts[-1] action_name = "_".join(tool_parts[:-1]) - + # Validate that tool_id looks like a numerical ID if not tool_id.isdigit(): - logger.warning(f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call.") - - except (AttributeError, TypeError) as e: + logger.warning( + f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call." + ) + + except (AttributeError, TypeError, json.JSONDecodeError) as e: logger.error(f"Error parsing OpenAI LLM call: {e}") return None, None, None return tool_id, action_name, call_args @@ -42,19 +46,23 @@ class ToolActionParser: try: call_args = call.arguments tool_parts = call.name.split("_") - + # If the tool name doesn't contain an underscore, it's likely a hallucinated tool if len(tool_parts) < 2: - logger.warning(f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id") + logger.warning( + f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id" + ) return None, None, None - + tool_id = tool_parts[-1] action_name = "_".join(tool_parts[:-1]) - + # Validate that tool_id looks like a numerical ID if not tool_id.isdigit(): - logger.warning(f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call.") - + logger.warning( + f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call." + ) + except (AttributeError, TypeError) as e: logger.error(f"Error parsing Google LLM call: {e}") return None, None, None diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..30b76ad3 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,20 @@ +[pytest] +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* +addopts = + -v + --strict-markers + --tb=short + --cov=application + --cov-report=html + --cov-report=term-missing + --cov-report=xml +markers = + unit: Unit tests + integration: Integration tests + slow: Slow running tests +filterwarnings = + ignore::DeprecationWarning + ignore::PendingDeprecationWarning diff --git a/tests/agents/__init__.py b/tests/agents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/agents/test_agent_creator.py b/tests/agents/test_agent_creator.py new file mode 100644 index 00000000..508e98a4 --- /dev/null +++ b/tests/agents/test_agent_creator.py @@ -0,0 +1,56 @@ +import pytest +from application.agents.agent_creator import AgentCreator +from application.agents.classic_agent import ClassicAgent +from application.agents.react_agent import ReActAgent + + +@pytest.mark.unit +class TestAgentCreator: + + def test_create_classic_agent(self, agent_base_params): + agent = AgentCreator.create_agent("classic", **agent_base_params) + assert isinstance(agent, ClassicAgent) + assert agent.endpoint == agent_base_params["endpoint"] + assert agent.llm_name == agent_base_params["llm_name"] + assert agent.gpt_model == agent_base_params["gpt_model"] + + def test_create_react_agent(self, agent_base_params): + agent = AgentCreator.create_agent("react", **agent_base_params) + assert isinstance(agent, ReActAgent) + assert agent.endpoint == agent_base_params["endpoint"] + assert agent.llm_name == agent_base_params["llm_name"] + + def test_create_agent_case_insensitive(self, agent_base_params): + agent_upper = AgentCreator.create_agent("CLASSIC", **agent_base_params) + agent_mixed = AgentCreator.create_agent("ClAsSiC", **agent_base_params) + + assert isinstance(agent_upper, ClassicAgent) + assert isinstance(agent_mixed, ClassicAgent) + + def test_create_agent_invalid_type(self, agent_base_params): + with pytest.raises(ValueError, match="No agent class found for type"): + AgentCreator.create_agent("invalid_agent_type", **agent_base_params) + + def test_agent_registry_contains_expected_agents(self): + assert "classic" in AgentCreator.agents + assert "react" in AgentCreator.agents + assert AgentCreator.agents["classic"] == ClassicAgent + assert AgentCreator.agents["react"] == ReActAgent + + def test_create_agent_with_optional_params(self, agent_base_params): + agent_base_params["user_api_key"] = "user_key_123" + agent_base_params["chat_history"] = [{"prompt": "test", "response": "test"}] + agent_base_params["json_schema"] = {"type": "object"} + + agent = AgentCreator.create_agent("classic", **agent_base_params) + + assert agent.user_api_key == "user_key_123" + assert len(agent.chat_history) == 1 + assert agent.json_schema == {"type": "object"} + + def test_create_agent_with_attachments(self, agent_base_params): + attachments = [{"name": "file.txt", "content": "test"}] + agent_base_params["attachments"] = attachments + + agent = AgentCreator.create_agent("classic", **agent_base_params) + assert agent.attachments == attachments diff --git a/tests/agents/test_base_agent.py b/tests/agents/test_base_agent.py new file mode 100644 index 00000000..50195da4 --- /dev/null +++ b/tests/agents/test_base_agent.py @@ -0,0 +1,641 @@ +from unittest.mock import Mock + +import pytest +from application.agents.classic_agent import ClassicAgent +from application.core.settings import settings +from tests.conftest import FakeMongoCollection + + +@pytest.mark.unit +class TestBaseAgentInitialization: + + def test_agent_initialization( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ClassicAgent(**agent_base_params) + + assert agent.endpoint == agent_base_params["endpoint"] + assert agent.llm_name == agent_base_params["llm_name"] + assert agent.gpt_model == agent_base_params["gpt_model"] + assert agent.api_key == agent_base_params["api_key"] + assert agent.prompt == agent_base_params["prompt"] + assert agent.user == agent_base_params["decoded_token"]["sub"] + assert agent.tools == [] + assert agent.tool_calls == [] + + def test_agent_initialization_with_none_chat_history( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent_base_params["chat_history"] = None + agent = ClassicAgent(**agent_base_params) + assert agent.chat_history == [] + + def test_agent_initialization_with_chat_history( + self, + agent_base_params, + sample_chat_history, + mock_llm_creator, + mock_llm_handler_creator, + ): + agent_base_params["chat_history"] = sample_chat_history + agent = ClassicAgent(**agent_base_params) + assert len(agent.chat_history) == 2 + assert agent.chat_history[0]["prompt"] == "What is Python?" + + def test_agent_decoded_token_defaults_to_empty_dict( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent_base_params["decoded_token"] = None + agent = ClassicAgent(**agent_base_params) + assert agent.decoded_token == {} + assert agent.user is None + + def test_agent_user_extracted_from_token( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent_base_params["decoded_token"] = {"sub": "user123"} + agent = ClassicAgent(**agent_base_params) + assert agent.user == "user123" + + +@pytest.mark.unit +class TestBaseAgentBuildMessages: + + def test_build_messages_basic( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ClassicAgent(**agent_base_params) + system_prompt = "System: {summaries}" + query = "What is Python?" + retrieved_data = [ + {"text": "Python is a programming language", "filename": "python.txt"} + ] + + messages = agent._build_messages(system_prompt, query, retrieved_data) + + assert len(messages) >= 2 + assert messages[0]["role"] == "system" + assert "Python is a programming language" in messages[0]["content"] + assert messages[-1]["role"] == "user" + assert messages[-1]["content"] == query + + def test_build_messages_with_chat_history( + self, + agent_base_params, + sample_chat_history, + mock_llm_creator, + mock_llm_handler_creator, + ): + agent_base_params["chat_history"] = sample_chat_history + agent = ClassicAgent(**agent_base_params) + + system_prompt = "System: {summaries}" + query = "New question?" + retrieved_data = [{"text": "Data", "filename": "file.txt"}] + + messages = agent._build_messages(system_prompt, query, retrieved_data) + + user_messages = [m for m in messages if m["role"] == "user"] + assistant_messages = [m for m in messages if m["role"] == "assistant"] + + assert len(user_messages) >= 3 + assert len(assistant_messages) >= 2 + + def test_build_messages_with_tool_calls_in_history( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + tool_call_history = [ + { + "tool_calls": [ + { + "call_id": "123", + "action_name": "test_action", + "arguments": {"arg": "value"}, + "result": "success", + } + ] + } + ] + agent_base_params["chat_history"] = tool_call_history + agent = ClassicAgent(**agent_base_params) + + messages = agent._build_messages( + "System: {summaries}", "query", [{"text": "data", "filename": "file.txt"}] + ) + + tool_messages = [m for m in messages if m["role"] == "tool"] + assert len(tool_messages) > 0 + + def test_build_messages_handles_missing_filename( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ClassicAgent(**agent_base_params) + retrieved_data = [{"text": "Document without filename or title"}] + + messages = agent._build_messages("System: {summaries}", "query", retrieved_data) + + assert messages[0]["role"] == "system" + assert "Document without filename" in messages[0]["content"] + + def test_build_messages_uses_title_as_fallback( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ClassicAgent(**agent_base_params) + retrieved_data = [{"text": "Data", "title": "Title Doc"}] + + messages = agent._build_messages("System: {summaries}", "query", retrieved_data) + + assert "Title Doc" in messages[0]["content"] + + def test_build_messages_uses_source_as_fallback( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ClassicAgent(**agent_base_params) + retrieved_data = [{"text": "Data", "source": "source.txt"}] + + messages = agent._build_messages("System: {summaries}", "query", retrieved_data) + + assert "source.txt" in messages[0]["content"] + + +@pytest.mark.unit +class TestBaseAgentTools: + + def test_get_user_tools( + self, + agent_base_params, + mock_mongo_db, + mock_llm_creator, + mock_llm_handler_creator, + ): + mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = { + "1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True}, + "2": {"_id": "2", "user": "test_user", "name": "tool2", "status": True}, + } + + agent = ClassicAgent(**agent_base_params) + tools = agent._get_user_tools("test_user") + + assert len(tools) == 2 + assert "0" in tools + assert "1" in tools + + def test_get_user_tools_filters_by_status( + self, + agent_base_params, + mock_mongo_db, + mock_llm_creator, + mock_llm_handler_creator, + ): + mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = { + "1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True}, + "2": {"_id": "2", "user": "test_user", "name": "tool2", "status": False}, + } + + agent = ClassicAgent(**agent_base_params) + tools = agent._get_user_tools("test_user") + + assert len(tools) == 1 + + def test_get_tools_by_api_key( + self, + agent_base_params, + mock_mongo_db, + mock_llm_creator, + mock_llm_handler_creator, + ): + from bson.objectid import ObjectId + + tool_id = str(ObjectId()) + tool_obj_id = ObjectId(tool_id) + + fake_agent_collection = FakeMongoCollection() + fake_agent_collection.docs["api_key_123"] = { + "key": "api_key_123", + "tools": [tool_id], + } + + fake_tools_collection = FakeMongoCollection() + fake_tools_collection.docs[tool_id] = {"_id": tool_obj_id, "name": "api_tool"} + + mock_mongo_db[settings.MONGO_DB_NAME]["agents"] = fake_agent_collection + mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"] = fake_tools_collection + + agent = ClassicAgent(**agent_base_params) + tools = agent._get_tools("api_key_123") + + assert tool_id in tools + + def test_build_tool_parameters( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ClassicAgent(**agent_base_params) + + action = { + "parameters": { + "properties": { + "param1": { + "type": "string", + "description": "Test param", + "filled_by_llm": True, + }, + "param2": {"type": "number", "filled_by_llm": False, "value": 42}, + } + } + } + + params = agent._build_tool_parameters(action) + + assert "param1" in params["properties"] + assert "param1" in params["required"] + assert "filled_by_llm" not in params["properties"]["param1"] + + def test_prepare_tools_with_api_tool( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ClassicAgent(**agent_base_params) + + tools_dict = { + "1": { + "name": "api_tool", + "config": { + "actions": { + "get_data": { + "name": "get_data", + "description": "Get data from API", + "active": True, + "url": "https://api.example.com/data", + "method": "GET", + "parameters": {"properties": {}}, + } + } + }, + } + } + + agent._prepare_tools(tools_dict) + + assert len(agent.tools) == 1 + assert agent.tools[0]["type"] == "function" + assert agent.tools[0]["function"]["name"] == "get_data_1" + + def test_prepare_tools_with_regular_tool( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ClassicAgent(**agent_base_params) + + tools_dict = { + "1": { + "name": "custom_tool", + "actions": [ + { + "name": "action1", + "description": "Custom action", + "active": True, + "parameters": {"properties": {}}, + } + ], + } + } + + agent._prepare_tools(tools_dict) + + assert len(agent.tools) == 1 + assert agent.tools[0]["function"]["name"] == "action1_1" + + def test_prepare_tools_filters_inactive_actions( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ClassicAgent(**agent_base_params) + + tools_dict = { + "1": { + "name": "custom_tool", + "actions": [ + { + "name": "active_action", + "description": "Active", + "active": True, + "parameters": {"properties": {}}, + }, + { + "name": "inactive_action", + "description": "Inactive", + "active": False, + "parameters": {"properties": {}}, + }, + ], + } + } + + agent._prepare_tools(tools_dict) + + assert len(agent.tools) == 1 + assert agent.tools[0]["function"]["name"] == "active_action_1" + + +@pytest.mark.unit +class TestBaseAgentToolExecution: + + def test_execute_tool_action_success( + self, + agent_base_params, + mock_llm_creator, + mock_llm_handler_creator, + mock_tool_manager, + ): + agent = ClassicAgent(**agent_base_params) + + call = Mock() + call.id = "call_123" + call.name = "test_action_1" + call.arguments = '{"param1": "value1"}' + + tools_dict = { + "1": { + "name": "custom_tool", + "config": {}, + "actions": [ + { + "name": "test_action", + "description": "Test", + "parameters": {"properties": {}}, + } + ], + } + } + + results = list(agent._execute_tool_action(tools_dict, call)) + + assert len(results) >= 2 + assert results[0]["type"] == "tool_call" + assert results[0]["data"]["status"] == "pending" + assert results[-1]["data"]["status"] == "completed" + + def test_execute_tool_action_invalid_tool_name( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ClassicAgent(**agent_base_params) + + call = Mock() + call.id = "call_123" + call.name = "invalid_format" + call.arguments = "{}" + + tools_dict = {} + + results = list(agent._execute_tool_action(tools_dict, call)) + + assert results[0]["type"] == "tool_call" + assert results[0]["data"]["status"] == "error" + assert ( + "Failed to parse" in results[0]["data"]["result"] + or "not found" in results[0]["data"]["result"] + ) + + def test_execute_tool_action_tool_not_found( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ClassicAgent(**agent_base_params) + + call = Mock() + call.id = "call_123" + call.name = "action_999" + call.arguments = "{}" + + tools_dict = {"1": {"name": "tool1", "config": {}, "actions": []}} + + results = list(agent._execute_tool_action(tools_dict, call)) + + assert results[0]["type"] == "tool_call" + assert results[0]["data"]["status"] == "error" + assert "not found" in results[0]["data"]["result"] + + def test_execute_tool_action_with_parameters( + self, + agent_base_params, + mock_llm_creator, + mock_llm_handler_creator, + mock_tool_manager, + ): + agent = ClassicAgent(**agent_base_params) + + call = Mock() + call.id = "call_123" + call.name = "test_action_1" + call.arguments = '{"param1": "value1", "param2": "value2"}' + + tools_dict = { + "1": { + "name": "custom_tool", + "config": {}, + "actions": [ + { + "name": "test_action", + "description": "Test", + "parameters": { + "properties": { + "param1": {"type": "string"}, + "param2": {"type": "string"}, + } + }, + } + ], + } + } + + results = list(agent._execute_tool_action(tools_dict, call)) + + assert results[-1]["data"]["status"] == "completed" + assert results[-1]["data"]["arguments"]["param1"] == "value1" + + def test_get_truncated_tool_calls( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ClassicAgent(**agent_base_params) + + agent.tool_calls = [ + { + "tool_name": "test_tool", + "call_id": "123", + "action_name": "action", + "arguments": {}, + "result": "a" * 100, + } + ] + + truncated = agent._get_truncated_tool_calls() + + assert len(truncated) == 1 + assert len(truncated[0]["result"]) <= 53 + assert truncated[0]["result"].endswith("...") + + +@pytest.mark.unit +class TestBaseAgentRetrieverSearch: + + def test_retriever_search( + self, + agent_base_params, + mock_retriever, + mock_llm_creator, + mock_llm_handler_creator, + log_context, + ): + agent = ClassicAgent(**agent_base_params) + + results = agent._retriever_search(mock_retriever, "test query", log_context) + + assert len(results) == 2 + mock_retriever.search.assert_called_once_with("test query") + + def test_retriever_search_logs_context( + self, + agent_base_params, + mock_retriever, + mock_llm_creator, + mock_llm_handler_creator, + log_context, + ): + agent = ClassicAgent(**agent_base_params) + + agent._retriever_search(mock_retriever, "test query", log_context) + + assert len(log_context.stacks) == 1 + assert log_context.stacks[0]["component"] == "retriever" + + +@pytest.mark.unit +class TestBaseAgentLLMGeneration: + + def test_llm_gen_basic( + self, + agent_base_params, + mock_llm, + mock_llm_creator, + mock_llm_handler_creator, + log_context, + ): + agent = ClassicAgent(**agent_base_params) + + messages = [{"role": "user", "content": "test"}] + agent._llm_gen(messages, log_context) + + mock_llm.gen_stream.assert_called_once() + call_args = mock_llm.gen_stream.call_args[1] + assert call_args["model"] == agent.gpt_model + assert call_args["messages"] == messages + + def test_llm_gen_with_tools( + self, + agent_base_params, + mock_llm, + mock_llm_creator, + mock_llm_handler_creator, + log_context, + ): + agent = ClassicAgent(**agent_base_params) + agent.tools = [{"type": "function", "function": {"name": "test"}}] + + messages = [{"role": "user", "content": "test"}] + agent._llm_gen(messages, log_context) + + call_args = mock_llm.gen_stream.call_args[1] + assert "tools" in call_args + assert call_args["tools"] == agent.tools + + def test_llm_gen_with_json_schema( + self, + agent_base_params, + mock_llm, + mock_llm_creator, + mock_llm_handler_creator, + log_context, + ): + mock_llm._supports_structured_output = Mock(return_value=True) + mock_llm.prepare_structured_output_format = Mock( + return_value={"schema": "test"} + ) + + agent_base_params["json_schema"] = {"type": "object"} + agent_base_params["llm_name"] = "openai" + agent = ClassicAgent(**agent_base_params) + + messages = [{"role": "user", "content": "test"}] + agent._llm_gen(messages, log_context) + + call_args = mock_llm.gen_stream.call_args[1] + assert "response_format" in call_args + + +@pytest.mark.unit +class TestBaseAgentHandleResponse: + + def test_handle_response_string( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator, log_context + ): + agent = ClassicAgent(**agent_base_params) + + response = "Simple string response" + results = list(agent._handle_response(response, {}, [], log_context)) + + assert len(results) == 1 + assert results[0]["answer"] == "Simple string response" + + def test_handle_response_with_message( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator, log_context + ): + agent = ClassicAgent(**agent_base_params) + + response = Mock() + response.message = Mock() + response.message.content = "Message content" + + results = list(agent._handle_response(response, {}, [], log_context)) + + assert len(results) == 1 + assert results[0]["answer"] == "Message content" + + def test_handle_response_with_structured_output( + self, + agent_base_params, + mock_llm, + mock_llm_creator, + mock_llm_handler_creator, + log_context, + ): + mock_llm._supports_structured_output = Mock(return_value=True) + agent_base_params["json_schema"] = {"type": "object"} + + agent = ClassicAgent(**agent_base_params) + + response = "Structured response" + results = list(agent._handle_response(response, {}, [], log_context)) + + assert results[0]["structured"] is True + assert results[0]["schema"] == {"type": "object"} + + def test_handle_response_with_handler( + self, + agent_base_params, + mock_llm_handler, + mock_llm_creator, + mock_llm_handler_creator, + log_context, + ): + def mock_process(*args): + yield {"type": "tool_call", "data": {}} + yield "Final answer" + + mock_llm_handler.process_message_flow = Mock(side_effect=mock_process) + + agent = ClassicAgent(**agent_base_params) + + response = Mock() + response.message = None + + results = list(agent._handle_response(response, {}, [], log_context)) + + assert len(results) == 2 + assert results[0]["type"] == "tool_call" + assert results[1]["answer"] == "Final answer" diff --git a/tests/agents/test_classic_agent.py b/tests/agents/test_classic_agent.py new file mode 100644 index 00000000..826cc19e --- /dev/null +++ b/tests/agents/test_classic_agent.py @@ -0,0 +1,242 @@ +from unittest.mock import Mock + +import pytest +from application.agents.classic_agent import ClassicAgent + + +@pytest.mark.unit +class TestClassicAgent: + + def test_classic_agent_initialization( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ClassicAgent(**agent_base_params) + + assert isinstance(agent, ClassicAgent) + assert agent.endpoint == agent_base_params["endpoint"] + assert agent.llm_name == agent_base_params["llm_name"] + + def test_gen_inner_basic_flow( + self, + agent_base_params, + mock_retriever, + mock_llm, + mock_llm_handler, + mock_llm_creator, + mock_llm_handler_creator, + mock_mongo_db, + log_context, + ): + def mock_gen_stream(*args, **kwargs): + yield "Answer chunk 1" + yield "Answer chunk 2" + + mock_llm.gen_stream = Mock(return_value=mock_gen_stream()) + + def mock_handler(*args, **kwargs): + yield "Processed answer" + + mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler) + + agent = ClassicAgent(**agent_base_params) + + results = list(agent._gen_inner("Test query", mock_retriever, log_context)) + + assert len(results) >= 2 + sources = [r for r in results if "sources" in r] + tool_calls = [r for r in results if "tool_calls" in r] + + assert len(sources) == 1 + assert len(tool_calls) == 1 + + def test_gen_inner_retrieves_documents( + self, + agent_base_params, + mock_retriever, + mock_llm, + mock_llm_handler, + mock_llm_creator, + mock_llm_handler_creator, + mock_mongo_db, + log_context, + ): + mock_llm.gen_stream = Mock(return_value=iter(["Answer"])) + + def mock_handler(*args, **kwargs): + yield "Processed" + + mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler) + + agent = ClassicAgent(**agent_base_params) + list(agent._gen_inner("Test query", mock_retriever, log_context)) + + mock_retriever.search.assert_called_once_with("Test query") + + def test_gen_inner_uses_user_api_key_tools( + self, + agent_base_params, + mock_retriever, + mock_llm, + mock_llm_handler, + mock_llm_creator, + mock_llm_handler_creator, + mock_mongo_db, + log_context, + ): + from application.core.settings import settings + from bson.objectid import ObjectId + + tool_id = str(ObjectId()) + mock_mongo_db[settings.MONGO_DB_NAME]["agents"].docs = { + "api_key_123": {"key": "api_key_123", "tools": [tool_id]} + } + mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = { + tool_id: {"_id": ObjectId(tool_id), "name": "test_tool"} + } + + mock_llm.gen_stream = Mock(return_value=iter(["Answer"])) + + def mock_handler(*args, **kwargs): + yield "Processed" + + mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler) + + agent_base_params["user_api_key"] = "api_key_123" + agent = ClassicAgent(**agent_base_params) + + list(agent._gen_inner("Test query", mock_retriever, log_context)) + + assert len(agent.tools) >= 0 + + def test_gen_inner_uses_user_tools( + self, + agent_base_params, + mock_retriever, + mock_llm, + mock_llm_handler, + mock_llm_creator, + mock_llm_handler_creator, + mock_mongo_db, + log_context, + ): + from application.core.settings import settings + + mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = { + "1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True} + } + + mock_llm.gen_stream = Mock(return_value=iter(["Answer"])) + + def mock_handler(*args, **kwargs): + yield "Processed" + + mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler) + + agent = ClassicAgent(**agent_base_params) + list(agent._gen_inner("Test query", mock_retriever, log_context)) + + assert len(agent.tools) >= 0 + + def test_gen_inner_builds_correct_messages( + self, + agent_base_params, + mock_retriever, + mock_llm, + mock_llm_handler, + mock_llm_creator, + mock_llm_handler_creator, + mock_mongo_db, + log_context, + ): + mock_llm.gen_stream = Mock(return_value=iter(["Answer"])) + + def mock_handler(*args, **kwargs): + yield "Processed" + + mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler) + + agent = ClassicAgent(**agent_base_params) + list(agent._gen_inner("Test query", mock_retriever, log_context)) + + call_kwargs = mock_llm.gen_stream.call_args[1] + messages = call_kwargs["messages"] + + assert len(messages) >= 2 + assert messages[0]["role"] == "system" + assert messages[-1]["role"] == "user" + assert messages[-1]["content"] == "Test query" + + def test_gen_inner_logs_tool_calls( + self, + agent_base_params, + mock_retriever, + mock_llm, + mock_llm_handler, + mock_llm_creator, + mock_llm_handler_creator, + mock_mongo_db, + log_context, + ): + mock_llm.gen_stream = Mock(return_value=iter(["Answer"])) + + def mock_handler(*args, **kwargs): + yield "Processed" + + mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler) + + agent = ClassicAgent(**agent_base_params) + agent.tool_calls = [{"tool": "test", "result": "success"}] + + list(agent._gen_inner("Test query", mock_retriever, log_context)) + + agent_logs = [s for s in log_context.stacks if s["component"] == "agent"] + assert len(agent_logs) == 1 + assert "tool_calls" in agent_logs[0]["data"] + + +@pytest.mark.integration +class TestClassicAgentIntegration: + + def test_gen_method_with_logging( + self, + agent_base_params, + mock_retriever, + mock_llm, + mock_llm_handler, + mock_llm_creator, + mock_llm_handler_creator, + mock_mongo_db, + ): + mock_llm.gen_stream = Mock(return_value=iter(["Answer"])) + + def mock_handler(*args, **kwargs): + yield "Processed" + + mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler) + + agent = ClassicAgent(**agent_base_params) + + results = list(agent.gen("Test query", mock_retriever)) + + assert len(results) >= 1 + + def test_gen_method_decorator_applied( + self, + agent_base_params, + mock_retriever, + mock_llm, + mock_llm_handler, + mock_llm_creator, + mock_llm_handler_creator, + mock_mongo_db, + ): + mock_llm.gen_stream = Mock(return_value=iter(["Answer"])) + + def mock_handler(*args, **kwargs): + yield "Processed" + + mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler) + + agent = ClassicAgent(**agent_base_params) + + assert hasattr(agent.gen, "__wrapped__") diff --git a/tests/agents/test_react_agent.py b/tests/agents/test_react_agent.py new file mode 100644 index 00000000..15a3dd5a --- /dev/null +++ b/tests/agents/test_react_agent.py @@ -0,0 +1,519 @@ +from unittest.mock import Mock, mock_open, patch + +import pytest +from application.agents.react_agent import ReActAgent + + +@pytest.mark.unit +class TestReActAgent: + + def test_react_agent_initialization( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ReActAgent(**agent_base_params) + + assert isinstance(agent, ReActAgent) + assert agent.plan == "" + assert agent.observations == [] + + def test_react_agent_inherits_base_properties( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ReActAgent(**agent_base_params) + + assert agent.endpoint == agent_base_params["endpoint"] + assert agent.llm_name == agent_base_params["llm_name"] + assert agent.gpt_model == agent_base_params["gpt_model"] + + +@pytest.mark.unit +class TestReActAgentContentExtraction: + + def test_extract_content_from_string( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ReActAgent(**agent_base_params) + + response = "Simple string response" + content = agent._extract_content_from_llm_response(response) + + assert content == "Simple string response" + + def test_extract_content_from_message_object( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ReActAgent(**agent_base_params) + + response = Mock() + response.message = Mock() + response.message.content = "Message content" + + content = agent._extract_content_from_llm_response(response) + + assert content == "Message content" + + def test_extract_content_from_openai_response( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ReActAgent(**agent_base_params) + + response = Mock() + response.choices = [Mock()] + response.choices[0].message = Mock() + response.choices[0].message.content = "OpenAI content" + response.message = None + response.content = None + + content = agent._extract_content_from_llm_response(response) + + assert content == "OpenAI content" + + def test_extract_content_from_anthropic_response( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ReActAgent(**agent_base_params) + + text_block = Mock() + text_block.text = "Anthropic content" + + response = Mock() + response.content = [text_block] + response.message = None + response.choices = None + + content = agent._extract_content_from_llm_response(response) + + assert content == "Anthropic content" + + def test_extract_content_from_openai_stream( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ReActAgent(**agent_base_params) + + chunk1 = Mock() + chunk1.choices = [Mock()] + chunk1.choices[0].delta = Mock() + chunk1.choices[0].delta.content = "Part 1 " + + chunk2 = Mock() + chunk2.choices = [Mock()] + chunk2.choices[0].delta = Mock() + chunk2.choices[0].delta.content = "Part 2" + + response = iter([chunk1, chunk2]) + content = agent._extract_content_from_llm_response(response) + + assert content == "Part 1 Part 2" + + def test_extract_content_from_anthropic_stream( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ReActAgent(**agent_base_params) + + chunk1 = Mock() + chunk1.type = "content_block_delta" + chunk1.delta = Mock() + chunk1.delta.text = "Stream 1 " + chunk1.choices = [] + + chunk2 = Mock() + chunk2.type = "content_block_delta" + chunk2.delta = Mock() + chunk2.delta.text = "Stream 2" + chunk2.choices = [] + + response = iter([chunk1, chunk2]) + content = agent._extract_content_from_llm_response(response) + + assert content == "Stream 1 Stream 2" + + def test_extract_content_from_string_stream( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ReActAgent(**agent_base_params) + + response = iter(["chunk1", "chunk2", "chunk3"]) + content = agent._extract_content_from_llm_response(response) + + assert content == "chunk1chunk2chunk3" + + def test_extract_content_handles_none_content( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ReActAgent(**agent_base_params) + + response = Mock() + response.message = Mock() + response.message.content = None + response.choices = None + response.content = None + + content = agent._extract_content_from_llm_response(response) + + assert content == "" + + +@pytest.mark.unit +class TestReActAgentPlanning: + + @patch( + "builtins.open", + new_callable=mock_open, + read_data="Test planning prompt: {query} {summaries} {prompt} {observations}", + ) + def test_create_plan( + self, + mock_file, + agent_base_params, + mock_llm, + mock_llm_creator, + mock_llm_handler_creator, + log_context, + ): + def mock_gen_stream(*args, **kwargs): + yield "Plan step 1" + yield "Plan step 2" + + mock_llm.gen_stream = Mock(return_value=mock_gen_stream()) + + agent = ReActAgent(**agent_base_params) + agent.observations = ["Observation 1"] + + plan_chunks = list(agent._create_plan("Test query", "Test docs", log_context)) + + assert len(plan_chunks) == 2 + assert plan_chunks[0] == "Plan step 1" + assert plan_chunks[1] == "Plan step 2" + + mock_llm.gen_stream.assert_called_once() + + @patch("builtins.open", new_callable=mock_open, read_data="Test: {query}") + def test_create_plan_fills_template( + self, + mock_file, + agent_base_params, + mock_llm, + mock_llm_creator, + mock_llm_handler_creator, + log_context, + ): + mock_llm.gen_stream = Mock(return_value=iter(["Plan"])) + + agent = ReActAgent(**agent_base_params) + list(agent._create_plan("My query", "Docs", log_context)) + + call_args = mock_llm.gen_stream.call_args[1] + messages = call_args["messages"] + + assert "My query" in messages[0]["content"] + + +@pytest.mark.unit +class TestReActAgentFinalAnswer: + + @patch( + "builtins.open", + new_callable=mock_open, + read_data="Final answer for: {query} with {observations}", + ) + def test_create_final_answer( + self, + mock_file, + agent_base_params, + mock_llm, + mock_llm_creator, + mock_llm_handler_creator, + log_context, + ): + def mock_gen_stream(*args, **kwargs): + yield "Final " + yield "answer" + + mock_llm.gen_stream = Mock(return_value=mock_gen_stream()) + + agent = ReActAgent(**agent_base_params) + observations = ["Obs 1", "Obs 2"] + + answer_chunks = list( + agent._create_final_answer("Test query", observations, log_context) + ) + + assert len(answer_chunks) == 2 + assert answer_chunks[0] == "Final " + assert answer_chunks[1] == "answer" + + @patch("builtins.open", new_callable=mock_open, read_data="Answer: {observations}") + def test_create_final_answer_truncates_long_observations( + self, + mock_file, + agent_base_params, + mock_llm, + mock_llm_creator, + mock_llm_handler_creator, + log_context, + ): + mock_llm.gen_stream = Mock(return_value=iter(["Answer"])) + + agent = ReActAgent(**agent_base_params) + long_obs = ["A" * 15000] + + list(agent._create_final_answer("Query", long_obs, log_context)) + + call_args = mock_llm.gen_stream.call_args[1] + messages = call_args["messages"] + + assert "observations truncated" in messages[0]["content"] + + @patch("builtins.open", new_callable=mock_open, read_data="Test: {query}") + def test_create_final_answer_no_tools( + self, + mock_file, + agent_base_params, + mock_llm, + mock_llm_creator, + mock_llm_handler_creator, + log_context, + ): + mock_llm.gen_stream = Mock(return_value=iter(["Answer"])) + + agent = ReActAgent(**agent_base_params) + list(agent._create_final_answer("Query", ["Obs"], log_context)) + + call_args = mock_llm.gen_stream.call_args[1] + + assert call_args["tools"] is None + + +@pytest.mark.unit +class TestReActAgentGenInner: + + @patch( + "builtins.open", new_callable=mock_open, read_data="Prompt template: {query}" + ) + def test_gen_inner_resets_state( + self, + mock_file, + agent_base_params, + mock_retriever, + mock_llm, + mock_llm_handler, + mock_llm_creator, + mock_llm_handler_creator, + mock_mongo_db, + log_context, + ): + mock_llm.gen_stream = Mock(return_value=iter(["SATISFIED"])) + + def mock_handler(*args, **kwargs): + yield "SATISFIED" + + mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler) + + agent = ReActAgent(**agent_base_params) + agent.plan = "Old plan" + agent.observations = ["Old obs"] + + list(agent._gen_inner("New query", mock_retriever, log_context)) + + assert agent.plan != "Old plan" + assert len(agent.observations) > 0 + + @patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}") + def test_gen_inner_stops_on_satisfied( + self, + mock_file, + agent_base_params, + mock_retriever, + mock_llm, + mock_llm_handler, + mock_llm_creator, + mock_llm_handler_creator, + mock_mongo_db, + log_context, + ): + iteration_count = 0 + + def mock_gen_stream(*args, **kwargs): + nonlocal iteration_count + iteration_count += 1 + if iteration_count == 1: + yield "Plan" + else: + yield "SATISFIED - done" + + mock_llm.gen_stream = Mock( + side_effect=lambda *args, **kwargs: mock_gen_stream(*args, **kwargs) + ) + + def mock_handler(*args, **kwargs): + yield "SATISFIED - done" + + mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler) + + agent = ReActAgent(**agent_base_params) + results = list(agent._gen_inner("Test query", mock_retriever, log_context)) + + assert any("answer" in r for r in results) + + @patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}") + def test_gen_inner_max_iterations( + self, + mock_file, + agent_base_params, + mock_retriever, + mock_llm, + mock_llm_handler, + mock_llm_creator, + mock_llm_handler_creator, + mock_mongo_db, + log_context, + ): + call_count = 0 + + def mock_gen_stream(*args, **kwargs): + nonlocal call_count + call_count += 1 + yield f"Iteration {call_count}" + + mock_llm.gen_stream = Mock( + side_effect=lambda *args, **kwargs: mock_gen_stream(*args, **kwargs) + ) + + def mock_handler(*args, **kwargs): + yield "Continue..." + + mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler) + + agent = ReActAgent(**agent_base_params) + + results = list(agent._gen_inner("Test query", mock_retriever, log_context)) + + thought_results = [r for r in results if "thought" in r] + assert len(thought_results) > 0 + + @patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}") + def test_gen_inner_yields_sources( + self, + mock_file, + agent_base_params, + mock_retriever, + mock_llm, + mock_llm_handler, + mock_llm_creator, + mock_llm_handler_creator, + mock_mongo_db, + log_context, + ): + mock_llm.gen_stream = Mock(return_value=iter(["SATISFIED"])) + + def mock_handler(*args, **kwargs): + yield "SATISFIED" + + mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler) + + agent = ReActAgent(**agent_base_params) + results = list(agent._gen_inner("Test query", mock_retriever, log_context)) + + sources = [r for r in results if "sources" in r] + assert len(sources) >= 1 + + @patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}") + def test_gen_inner_yields_tool_calls( + self, + mock_file, + agent_base_params, + mock_retriever, + mock_llm, + mock_llm_handler, + mock_llm_creator, + mock_llm_handler_creator, + mock_mongo_db, + log_context, + ): + mock_llm.gen_stream = Mock(return_value=iter(["SATISFIED"])) + + def mock_handler(*args, **kwargs): + yield "SATISFIED" + + mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler) + + agent = ReActAgent(**agent_base_params) + agent.tool_calls = [{"tool": "test", "result": "A" * 100}] + + results = list(agent._gen_inner("Test query", mock_retriever, log_context)) + + tool_call_results = [r for r in results if "tool_calls" in r] + if tool_call_results: + assert len(tool_call_results[0]["tool_calls"][0]["result"]) <= 53 + + @patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}") + def test_gen_inner_logs_observations( + self, + mock_file, + agent_base_params, + mock_retriever, + mock_llm, + mock_llm_handler, + mock_llm_creator, + mock_llm_handler_creator, + mock_mongo_db, + log_context, + ): + mock_llm.gen_stream = Mock(return_value=iter(["SATISFIED"])) + + def mock_handler(*args, **kwargs): + yield "SATISFIED" + + mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler) + + agent = ReActAgent(**agent_base_params) + list(agent._gen_inner("Test query", mock_retriever, log_context)) + + assert len(agent.observations) > 0 + + +@pytest.mark.integration +class TestReActAgentIntegration: + + @patch( + "builtins.open", + new_callable=mock_open, + read_data="Prompt: {query} {summaries} {prompt} {observations}", + ) + def test_full_react_workflow( + self, + mock_file, + agent_base_params, + mock_retriever, + mock_llm, + mock_llm_handler, + mock_llm_creator, + mock_llm_handler_creator, + mock_mongo_db, + log_context, + ): + call_sequence = [] + + def mock_gen_stream(*args, **kwargs): + call_sequence.append("gen_stream") + if len(call_sequence) <= 2: + yield "Planning..." + else: + yield "SATISFIED final answer" + + mock_llm.gen_stream = Mock( + side_effect=lambda *args, **kwargs: mock_gen_stream(*args, **kwargs) + ) + + def mock_handler(*args, **kwargs): + call_sequence.append("handler") + yield "SATISFIED final answer" + + mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler) + + agent = ReActAgent(**agent_base_params) + results = list(agent._gen_inner("Complex query", mock_retriever, log_context)) + + assert len(results) > 0 + assert any("thought" in r for r in results) + assert any("answer" in r for r in results) diff --git a/tests/agents/test_tool_action_parser.py b/tests/agents/test_tool_action_parser.py new file mode 100644 index 00000000..840d070c --- /dev/null +++ b/tests/agents/test_tool_action_parser.py @@ -0,0 +1,204 @@ +from unittest.mock import Mock + +import pytest +from application.agents.tools.tool_action_parser import ToolActionParser + + +@pytest.mark.unit +class TestToolActionParser: + + def test_parser_initialization(self): + parser = ToolActionParser("OpenAILLM") + assert parser.llm_type == "OpenAILLM" + assert "OpenAILLM" in parser.parsers + assert "GoogleLLM" in parser.parsers + + def test_parse_openai_llm_valid_call(self): + parser = ToolActionParser("OpenAILLM") + + call = Mock() + call.name = "get_data_123" + call.arguments = '{"param1": "value1", "param2": "value2"}' + + tool_id, action_name, call_args = parser.parse_args(call) + + assert tool_id == "123" + assert action_name == "get_data" + assert call_args == {"param1": "value1", "param2": "value2"} + + def test_parse_openai_llm_with_underscore_in_action(self): + parser = ToolActionParser("OpenAILLM") + + call = Mock() + call.name = "send_email_notification_456" + call.arguments = '{"to": "user@example.com"}' + + tool_id, action_name, call_args = parser.parse_args(call) + + assert tool_id == "456" + assert action_name == "send_email_notification" + assert call_args == {"to": "user@example.com"} + + def test_parse_openai_llm_invalid_format_no_underscore(self): + parser = ToolActionParser("OpenAILLM") + + call = Mock() + call.name = "invalidtoolname" + call.arguments = "{}" + + tool_id, action_name, call_args = parser.parse_args(call) + + assert tool_id is None + assert action_name is None + assert call_args is None + + def test_parse_openai_llm_non_numeric_tool_id(self): + parser = ToolActionParser("OpenAILLM") + + call = Mock() + call.name = "action_notanumber" + call.arguments = "{}" + + tool_id, action_name, call_args = parser.parse_args(call) + + assert tool_id == "notanumber" + assert action_name == "action" + + def test_parse_openai_llm_malformed_json(self): + parser = ToolActionParser("OpenAILLM") + + call = Mock() + call.name = "action_123" + call.arguments = "invalid json" + + tool_id, action_name, call_args = parser.parse_args(call) + + assert tool_id is None + assert action_name is None + assert call_args is None + + def test_parse_openai_llm_missing_attributes(self): + parser = ToolActionParser("OpenAILLM") + + call = Mock(spec=[]) + + tool_id, action_name, call_args = parser.parse_args(call) + + assert tool_id is None + assert action_name is None + assert call_args is None + + def test_parse_google_llm_valid_call(self): + parser = ToolActionParser("GoogleLLM") + + call = Mock() + call.name = "search_documents_789" + call.arguments = {"query": "test query", "limit": 10} + + tool_id, action_name, call_args = parser.parse_args(call) + + assert tool_id == "789" + assert action_name == "search_documents" + assert call_args == {"query": "test query", "limit": 10} + + def test_parse_google_llm_with_complex_action_name(self): + parser = ToolActionParser("GoogleLLM") + + call = Mock() + call.name = "create_new_user_account_999" + call.arguments = {"username": "test"} + + tool_id, action_name, call_args = parser.parse_args(call) + + assert tool_id == "999" + assert action_name == "create_new_user_account" + + def test_parse_google_llm_invalid_format(self): + parser = ToolActionParser("GoogleLLM") + + call = Mock() + call.name = "nounderscores" + call.arguments = {} + + tool_id, action_name, call_args = parser.parse_args(call) + + assert tool_id is None + assert action_name is None + assert call_args is None + + def test_parse_google_llm_missing_attributes(self): + parser = ToolActionParser("GoogleLLM") + + call = Mock(spec=[]) + + tool_id, action_name, call_args = parser.parse_args(call) + + assert tool_id is None + assert action_name is None + assert call_args is None + + def test_parse_unknown_llm_type_defaults_to_openai(self): + parser = ToolActionParser("UnknownLLM") + + call = Mock() + call.name = "action_123" + call.arguments = '{"key": "value"}' + + tool_id, action_name, call_args = parser.parse_args(call) + + assert tool_id == "123" + assert action_name == "action" + assert call_args == {"key": "value"} + + def test_parse_args_empty_arguments_openai(self): + parser = ToolActionParser("OpenAILLM") + + call = Mock() + call.name = "action_123" + call.arguments = "{}" + + tool_id, action_name, call_args = parser.parse_args(call) + + assert tool_id == "123" + assert action_name == "action" + assert call_args == {} + + def test_parse_args_empty_arguments_google(self): + parser = ToolActionParser("GoogleLLM") + + call = Mock() + call.name = "action_456" + call.arguments = {} + + tool_id, action_name, call_args = parser.parse_args(call) + + assert tool_id == "456" + assert action_name == "action" + assert call_args == {} + + def test_parse_args_with_special_characters(self): + parser = ToolActionParser("OpenAILLM") + + call = Mock() + call.name = "send_message_123" + call.arguments = '{"message": "Hello, World! 你好"}' + + tool_id, action_name, call_args = parser.parse_args(call) + + assert tool_id == "123" + assert action_name == "send_message" + assert call_args["message"] == "Hello, World! 你好" + + def test_parse_args_with_nested_objects(self): + parser = ToolActionParser("OpenAILLM") + + call = Mock() + call.name = "create_record_123" + call.arguments = '{"data": {"name": "John", "age": 30}}' + + tool_id, action_name, call_args = parser.parse_args(call) + + assert tool_id == "123" + assert action_name == "create_record" + assert call_args["data"]["name"] == "John" + assert call_args["data"]["age"] == 30 diff --git a/tests/agents/test_tool_manager.py b/tests/agents/test_tool_manager.py new file mode 100644 index 00000000..9217e3c2 --- /dev/null +++ b/tests/agents/test_tool_manager.py @@ -0,0 +1,235 @@ +from unittest.mock import Mock, patch + +import pytest +from application.agents.tools.base import Tool +from application.agents.tools.tool_manager import ToolManager + + +class MockTool(Tool): + def __init__(self, config): + self.config = config + + def execute_action(self, action_name: str, **kwargs): + return f"Executed {action_name} with {kwargs}" + + def get_actions_metadata(self): + return [{"name": "test_action", "description": "Test action"}] + + def get_config_requirements(self): + return {"required": ["api_key"]} + + +@pytest.mark.unit +class TestToolManager: + + @patch("application.agents.tools.tool_manager.pkgutil.iter_modules") + def test_tool_manager_initialization(self, mock_iter): + mock_iter.return_value = [] + + config = {"tool1": {"key": "value"}} + manager = ToolManager(config) + + assert manager.config == config + assert isinstance(manager.tools, dict) + + @patch("application.agents.tools.tool_manager.pkgutil.iter_modules") + @patch("application.agents.tools.tool_manager.importlib.import_module") + def test_load_tools_skips_base_and_private(self, mock_import, mock_iter): + mock_iter.return_value = [ + (None, "base", False), + (None, "__init__", False), + (None, "__pycache__", False), + (None, "valid_tool", False), + ] + + mock_module = Mock() + mock_module.MockTool = MockTool + mock_import.return_value = mock_module + + manager = ToolManager({}) + + assert "base" not in manager.tools + assert "__init__" not in manager.tools + + @patch("application.agents.tools.tool_manager.pkgutil.iter_modules") + def test_load_tools_creates_tool_instances(self, mock_iter): + mock_iter.return_value = [] + + manager = ToolManager({}) + + mock_tool = MockTool({"test": "config"}) + manager.tools["mock_tool"] = mock_tool + + assert "mock_tool" in manager.tools + assert isinstance(manager.tools["mock_tool"], MockTool) + assert manager.tools["mock_tool"].config == {"test": "config"} + + def test_load_tool_with_user_id(self): + with patch( + "application.agents.tools.tool_manager.pkgutil.iter_modules", + return_value=[], + ): + manager = ToolManager({}) + tool = MockTool({"key": "value"}) + assert tool.config == {"key": "value"} + + manager.config["test_tool"] = {"key": "value"} + assert "test_tool" in manager.config + + def test_load_tool_without_user_id(self): + tool = MockTool({"api_key": "test123"}) + + assert isinstance(tool, MockTool) + assert tool.config == {"api_key": "test123"} + + assert hasattr(tool, "execute_action") + assert hasattr(tool, "get_actions_metadata") + + @patch("application.agents.tools.tool_manager.pkgutil.iter_modules") + def test_load_tool_updates_config(self, mock_iter): + mock_iter.return_value = [] + + manager = ToolManager({}) + new_config = {"new_key": "new_value"} + + manager.config["test_tool"] = new_config + + assert manager.config["test_tool"] == new_config + assert "test_tool" in manager.config + + @patch("application.agents.tools.tool_manager.pkgutil.iter_modules") + @patch("application.agents.tools.tool_manager.importlib.import_module") + def test_execute_action_on_loaded_tool(self, mock_import, mock_iter): + mock_iter.return_value = [(None, "mock_tool", False)] + + mock_tool_instance = MockTool({}) + + with patch("inspect.getmembers", return_value=[("MockTool", MockTool)]): + with patch("inspect.isclass", return_value=True): + with patch.object(MockTool, "__init__", return_value=None): + manager = ToolManager({}) + manager.tools["mock_tool"] = mock_tool_instance + + result = manager.execute_action( + "mock_tool", "test_action", param="value" + ) + + assert "Executed test_action" in result + + def test_execute_action_tool_not_loaded(self): + with patch( + "application.agents.tools.tool_manager.pkgutil.iter_modules", + return_value=[], + ): + manager = ToolManager({}) + with pytest.raises(ValueError, match="Tool 'nonexistent' not loaded"): + manager.execute_action("nonexistent", "action") + + @patch("application.agents.tools.tool_manager.importlib.import_module") + def test_execute_action_with_user_id_for_mcp_tool(self, mock_import): + mock_tool = MockTool({}) + + with patch("inspect.getmembers", return_value=[("MockTool", MockTool)]): + with patch("inspect.isclass", return_value=True): + manager = ToolManager({"mcp_tool": {}}) + manager.tools["mcp_tool"] = mock_tool + + with patch.object( + manager, "load_tool", return_value=mock_tool + ) as mock_load: + manager.execute_action("mcp_tool", "action", user_id="user123") + + mock_load.assert_called_once_with("mcp_tool", {}, "user123") + + @patch("application.agents.tools.tool_manager.importlib.import_module") + def test_execute_action_with_user_id_for_memory_tool(self, mock_import): + mock_tool = MockTool({}) + + with patch("inspect.getmembers", return_value=[("MockTool", MockTool)]): + with patch("inspect.isclass", return_value=True): + manager = ToolManager({"memory": {}}) + manager.tools["memory"] = mock_tool + + with patch.object( + manager, "load_tool", return_value=mock_tool + ) as mock_load: + manager.execute_action("memory", "view", user_id="user456") + + mock_load.assert_called_once_with("memory", {}, "user456") + + @patch("application.agents.tools.tool_manager.pkgutil.iter_modules") + @patch("application.agents.tools.tool_manager.importlib.import_module") + def test_get_all_actions_metadata(self, mock_import, mock_iter): + mock_iter.return_value = [(None, "tool1", False), (None, "tool2", False)] + + mock_tool1 = Mock() + mock_tool1.get_actions_metadata.return_value = [{"name": "action1"}] + + mock_tool2 = Mock() + mock_tool2.get_actions_metadata.return_value = [{"name": "action2"}] + + manager = ToolManager({}) + manager.tools = {"tool1": mock_tool1, "tool2": mock_tool2} + + metadata = manager.get_all_actions_metadata() + + assert len(metadata) == 2 + assert {"name": "action1"} in metadata + assert {"name": "action2"} in metadata + + @patch("application.agents.tools.tool_manager.pkgutil.iter_modules") + def test_get_all_actions_metadata_empty(self, mock_iter): + mock_iter.return_value = [] + + manager = ToolManager({}) + manager.tools = {} + + metadata = manager.get_all_actions_metadata() + + assert metadata == [] + + def test_load_tool_with_notes_tool(self): + tool = MockTool({"key": "value"}) + + assert isinstance(tool, MockTool) + assert tool.config == {"key": "value"} + + result = tool.execute_action("test_action", param="value") + assert "test_action" in result + + +@pytest.mark.unit +class TestToolBase: + + def test_tool_base_is_abstract(self): + with pytest.raises(TypeError): + Tool() + + def test_mock_tool_implements_interface(self): + tool = MockTool({"test": "config"}) + + assert hasattr(tool, "execute_action") + assert hasattr(tool, "get_actions_metadata") + assert hasattr(tool, "get_config_requirements") + + def test_mock_tool_execute_action(self): + tool = MockTool({}) + result = tool.execute_action("test", param="value") + + assert "Executed test" in result + assert "param" in result + + def test_mock_tool_get_actions_metadata(self): + tool = MockTool({}) + metadata = tool.get_actions_metadata() + + assert isinstance(metadata, list) + assert len(metadata) > 0 + assert "name" in metadata[0] + + def test_mock_tool_get_config_requirements(self): + tool = MockTool({}) + requirements = tool.get_config_requirements() + + assert isinstance(requirements, dict) + assert "required" in requirements diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..3fd46f2f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,199 @@ +from unittest.mock import Mock + +import pytest +from application.core.settings import settings + + +@pytest.fixture +def mock_llm(): + llm = Mock() + llm.gen_stream = Mock() + llm._supports_tools = True + llm._supports_structured_output = Mock(return_value=False) + llm.__class__.__name__ = "MockLLM" + return llm + + +@pytest.fixture +def mock_llm_handler(): + handler = Mock() + handler.process_message_flow = Mock() + return handler + + +@pytest.fixture +def mock_retriever(): + retriever = Mock() + retriever.search = Mock( + return_value=[ + {"text": "Test document 1", "filename": "doc1.txt", "source": "test"}, + {"text": "Test document 2", "title": "doc2.txt", "source": "test"}, + ] + ) + return retriever + + +@pytest.fixture +def mock_mongo_db(monkeypatch): + fake_collection = FakeMongoCollection() + fake_db = { + "agents": fake_collection, + "user_tools": fake_collection, + "memories": fake_collection, + } + fake_client = {settings.MONGO_DB_NAME: fake_db} + + monkeypatch.setattr( + "application.core.mongo_db.MongoDB.get_client", lambda: fake_client + ) + return fake_client + + +@pytest.fixture +def sample_chat_history(): + return [ + {"prompt": "What is Python?", "response": "Python is a programming language."}, + {"prompt": "Tell me more.", "response": "Python is known for simplicity."}, + ] + + +@pytest.fixture +def sample_tool_call(): + return { + "tool_name": "test_tool", + "call_id": "123", + "action_name": "test_action", + "arguments": {"arg1": "value1"}, + "result": "Tool executed successfully", + } + + +@pytest.fixture +def decoded_token(): + return {"sub": "test_user", "email": "test@example.com"} + + +@pytest.fixture +def log_context(): + from application.logging import LogContext + + context = LogContext( + endpoint="test_endpoint", + activity_id="test_activity", + user="test_user", + api_key="test_key", + query="test query", + ) + return context + + +class FakeMongoCollection: + def __init__(self): + self.docs = {} + + def find_one(self, query, projection=None): + if "key" in query: + return self.docs.get(query["key"]) + if "_id" in query: + return self.docs.get(str(query["_id"])) + if "user" in query: + for doc in self.docs.values(): + if doc.get("user") == query["user"]: + return doc + return None + + def find(self, query, projection=None): + results = [] + if "_id" in query and "$in" in query["_id"]: + for doc_id in query["_id"]["$in"]: + doc = self.docs.get(str(doc_id)) + if doc: + results.append(doc) + elif "user" in query: + for doc in self.docs.values(): + if doc.get("user") == query["user"]: + if "status" in query: + if doc.get("status") == query["status"]: + results.append(doc) + else: + results.append(doc) + return results + + def insert_one(self, doc): + doc_id = doc.get("_id", len(self.docs)) + self.docs[str(doc_id)] = doc + return Mock(inserted_id=doc_id) + + def update_one(self, query, update, upsert=False): + return Mock(modified_count=1) + + def delete_one(self, query): + return Mock(deleted_count=1) + + def delete_many(self, query): + return Mock(deleted_count=0) + + +@pytest.fixture +def mock_llm_creator(mock_llm, monkeypatch): + monkeypatch.setattr( + "application.llm.llm_creator.LLMCreator.create_llm", Mock(return_value=mock_llm) + ) + return mock_llm + + +@pytest.fixture +def mock_llm_handler_creator(mock_llm_handler, monkeypatch): + monkeypatch.setattr( + "application.llm.handlers.handler_creator.LLMHandlerCreator.create_handler", + Mock(return_value=mock_llm_handler), + ) + return mock_llm_handler + + +@pytest.fixture +def agent_base_params(decoded_token): + return { + "endpoint": "https://api.example.com", + "llm_name": "openai", + "gpt_model": "gpt-4", + "api_key": "test_api_key", + "user_api_key": None, + "prompt": "You are a helpful assistant.", + "chat_history": [], + "decoded_token": decoded_token, + "attachments": [], + "json_schema": None, + } + + +@pytest.fixture +def mock_tool(): + tool = Mock() + tool.execute_action = Mock(return_value="Tool result") + tool.get_actions_metadata = Mock( + return_value=[ + { + "name": "test_action", + "description": "A test action", + "parameters": { + "type": "object", + "properties": { + "param1": {"type": "string", "description": "Test parameter"} + }, + "required": ["param1"], + }, + } + ] + ) + return tool + + +@pytest.fixture +def mock_tool_manager(mock_tool, monkeypatch): + manager = Mock() + manager.load_tool = Mock(return_value=mock_tool) + monkeypatch.setattr( + "application.agents.base.ToolManager", Mock(return_value=manager) + ) + return manager diff --git a/tests/llm/test_openai_llm.py b/tests/llm/test_openai_llm.py index 82596359..77ffe14c 100644 --- a/tests/llm/test_openai_llm.py +++ b/tests/llm/test_openai_llm.py @@ -1,6 +1,6 @@ import types -import pytest +import pytest from application.llm.openai import OpenAILLM @@ -42,16 +42,16 @@ class FakeChatCompletions: def create(self, **kwargs): self.last_kwargs = kwargs - # default non-streaming: return content if not kwargs.get("stream"): - return FakeChatCompletions._Response(choices=[ - FakeChatCompletions._Choice(content="hello world") - ]) - # streaming: yield line objects each with choices[0].delta.content - return FakeChatCompletions._Response(lines=[ - FakeChatCompletions._StreamLine(["part1"]), - FakeChatCompletions._StreamLine(["part2"]), - ]) + return FakeChatCompletions._Response( + choices=[FakeChatCompletions._Choice(content="hello world")] + ) + return FakeChatCompletions._Response( + lines=[ + FakeChatCompletions._StreamLine(["part1"]), + FakeChatCompletions._StreamLine(["part2"]), + ] + ) class FakeClient: @@ -71,16 +71,29 @@ def openai_llm(monkeypatch): return llm +@pytest.mark.unit def test_clean_messages_openai_variants(openai_llm): messages = [ {"role": "system", "content": "sys"}, - {"role": "model", "content": "asst"}, - {"role": "user", "content": [ - {"text": "hello"}, - {"function_call": {"call_id": "c1", "name": "fn", "args": {"a": 1}}}, - {"function_response": {"call_id": "c1", "name": "fn", "response": {"result": 42}}}, - {"type": "image_url", "image_url": {"url": "data:image/png;base64,AAA"}}, - ]}, + {"role": "model", "content": "asst"}, + { + "role": "user", + "content": [ + {"text": "hello"}, + {"function_call": {"call_id": "c1", "name": "fn", "args": {"a": 1}}}, + { + "function_response": { + "call_id": "c1", + "name": "fn", + "response": {"result": 42}, + } + }, + { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,AAA"}, + }, + ], + }, ] cleaned = openai_llm._clean_messages_openai(messages) @@ -89,17 +102,27 @@ def test_clean_messages_openai_variants(openai_llm): assert roles.count("assistant") >= 1 assert any(m["role"] == "tool" for m in cleaned) - assert any(isinstance(m["content"], list) and any( - part.get("type") == "image_url" for part in m["content"] if isinstance(part, dict) - ) for m in cleaned if m["role"] == "user") + assert any( + isinstance(m["content"], list) + and any( + part.get("type") == "image_url" + for part in m["content"] + if isinstance(part, dict) + ) + for m in cleaned + if m["role"] == "user" + ) +@pytest.mark.unit def test_raw_gen_calls_openai_client_and_returns_content(openai_llm): msgs = [ {"role": "system", "content": "sys"}, {"role": "user", "content": "hello"}, ] - content = openai_llm._raw_gen(openai_llm, model="gpt-4o", messages=msgs, stream=False) + content = openai_llm._raw_gen( + openai_llm, model="gpt-4o", messages=msgs, stream=False + ) assert content == "hello world" passed = openai_llm.client.chat.completions.last_kwargs @@ -108,16 +131,20 @@ def test_raw_gen_calls_openai_client_and_returns_content(openai_llm): assert passed["stream"] is False +@pytest.mark.unit def test_raw_gen_stream_yields_chunks(openai_llm): msgs = [ {"role": "user", "content": "hi"}, ] - gen = openai_llm._raw_gen_stream(openai_llm, model="gpt", messages=msgs, stream=True) + gen = openai_llm._raw_gen_stream( + openai_llm, model="gpt", messages=msgs, stream=True + ) chunks = list(gen) assert "part1" in "".join(chunks) assert "part2" in "".join(chunks) +@pytest.mark.unit def test_prepare_structured_output_format_enforces_required_and_strict(openai_llm): schema = { "type": "object", @@ -134,8 +161,8 @@ def test_prepare_structured_output_format_enforces_required_and_strict(openai_ll assert js["schema"]["additionalProperties"] is False +@pytest.mark.unit def test_prepare_messages_with_attachments_image_and_pdf(openai_llm, monkeypatch): - monkeypatch.setattr(openai_llm, "_get_base64_image", lambda att: "AAA=") monkeypatch.setattr(openai_llm, "_upload_file_to_openai", lambda att: "file_xyz") @@ -146,12 +173,15 @@ def test_prepare_messages_with_attachments_image_and_pdf(openai_llm, monkeypatch ] out = openai_llm.prepare_messages_with_attachments(messages, attachments) - # last user message should have list content with text and two attachments user_msg = next(m for m in out if m["role"] == "user") assert isinstance(user_msg["content"], list) - types_in_content = [p.get("type") for p in user_msg["content"] if isinstance(p, dict)] + types_in_content = [ + p.get("type") for p in user_msg["content"] if isinstance(p, dict) + ] assert "image_url" in types_in_content or any( isinstance(p, dict) and p.get("image_url") for p in user_msg["content"] ) - assert any(isinstance(p, dict) and p.get("file", {}).get("file_id") == "file_xyz" for p in user_msg["content"]) - + assert any( + isinstance(p, dict) and p.get("file", {}).get("file_id") == "file_xyz" + for p in user_msg["content"] + ) diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 00000000..57a2c092 --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,3 @@ +pytest>=8.0.0 +pytest-cov>=4.1.0 +coverage>=7.4.0 diff --git a/tests/security/test_encryption.py b/tests/security/test_encryption.py index 9b99eb2f..2d7e4efd 100644 --- a/tests/security/test_encryption.py +++ b/tests/security/test_encryption.py @@ -1,12 +1,24 @@ import base64 +import pytest +from application.security import encryption from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC -from application.security import encryption + +def _fake_os_urandom_factory(values): + values_iter = iter(values) + + def _fake(length): + value = next(values_iter) + assert len(value) == length + return value + + return _fake +@pytest.mark.unit def test_derive_key_uses_secret_and_user(monkeypatch): monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret") salt = bytes(range(16)) @@ -25,17 +37,7 @@ def test_derive_key_uses_secret_and_user(monkeypatch): assert derived == expected_key -def _fake_os_urandom_factory(values): - values_iter = iter(values) - - def _fake(length): - value = next(values_iter) - assert len(value) == length - return value - - return _fake - - +@pytest.mark.unit def test_encrypt_and_decrypt_round_trip(monkeypatch): monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret") salt = bytes(range(16)) @@ -55,6 +57,7 @@ def test_encrypt_and_decrypt_round_trip(monkeypatch): assert decrypted == credentials +@pytest.mark.unit def test_encrypt_credentials_returns_empty_for_empty_input(monkeypatch): monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret") @@ -62,11 +65,12 @@ def test_encrypt_credentials_returns_empty_for_empty_input(monkeypatch): assert encryption.encrypt_credentials(None, "user-123") == "" +@pytest.mark.unit def test_encrypt_credentials_returns_empty_on_serialization_error(monkeypatch): monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret") monkeypatch.setattr(encryption.os, "urandom", lambda length: b"\x00" * length) - class NonSerializable: # pragma: no cover - simple helper container + class NonSerializable: pass credentials = {"bad": NonSerializable()} @@ -74,6 +78,7 @@ def test_encrypt_credentials_returns_empty_on_serialization_error(monkeypatch): assert encryption.encrypt_credentials(credentials, "user-123") == "" +@pytest.mark.unit def test_decrypt_credentials_returns_empty_for_invalid_input(monkeypatch): monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret") @@ -84,6 +89,7 @@ def test_decrypt_credentials_returns_empty_for_invalid_input(monkeypatch): assert encryption.decrypt_credentials(invalid_payload, "user-123") == {} +@pytest.mark.unit def test_pad_and_unpad_are_inverse(): original = b"secret-data" @@ -91,4 +97,3 @@ def test_pad_and_unpad_are_inverse(): assert len(padded) % 16 == 0 assert encryption._unpad_data(padded) == original - diff --git a/tests/storage/test_local_storage.py b/tests/storage/test_local_storage.py index 4277af3a..397f8e9a 100644 --- a/tests/storage/test_local_storage.py +++ b/tests/storage/test_local_storage.py @@ -1,352 +1,401 @@ -"""Tests for LocalStorage implementation -""" - import io -import pytest -from unittest.mock import patch, MagicMock, mock_open +import os +from unittest.mock import MagicMock, mock_open, patch +import pytest from application.storage.local import LocalStorage @pytest.fixture def temp_base_dir(): - """Provide a temporary base directory path for testing.""" return "/tmp/test_storage" @pytest.fixture def local_storage(temp_base_dir): - """Create LocalStorage instance with test base directory.""" return LocalStorage(base_dir=temp_base_dir) +@pytest.mark.unit class TestLocalStorageInitialization: - """Test LocalStorage initialization and configuration.""" def test_init_with_custom_base_dir(self): - """Should use provided base directory.""" storage = LocalStorage(base_dir="/custom/path") assert storage.base_dir == "/custom/path" def test_init_with_default_base_dir(self): - """Should use default base directory when none provided.""" storage = LocalStorage() - # Default is three levels up from the file location assert storage.base_dir is not None assert isinstance(storage.base_dir, str) def test_get_full_path_with_relative_path(self, local_storage): - """Should combine base_dir with relative path.""" result = local_storage._get_full_path("documents/test.txt") - assert result == "/tmp/test_storage/documents/test.txt" + expected = os.path.join("/tmp/test_storage", "documents/test.txt") + assert os.path.normpath(result) == os.path.normpath(expected) def test_get_full_path_with_absolute_path(self, local_storage): - """Should return absolute path unchanged.""" result = local_storage._get_full_path("/absolute/path/test.txt") assert result == "/absolute/path/test.txt" - -class TestLocalStorageSaveFile: - """Test file saving functionality.""" - - @patch('os.makedirs') - @patch('builtins.open', new_callable=mock_open) - @patch('shutil.copyfileobj') + @patch("os.makedirs") + @patch("builtins.open", new_callable=mock_open) + @patch("shutil.copyfileobj") def test_save_file_creates_directory_and_saves( self, mock_copyfileobj, mock_file, mock_makedirs, local_storage ): - """Should create directory and save file content.""" file_data = io.BytesIO(b"test content") path = "documents/test.txt" result = local_storage.save_file(file_data, path) - # Verify directory creation - mock_makedirs.assert_called_once_with( - "/tmp/test_storage/documents", - exist_ok=True + expected_dir = os.path.join("/tmp/test_storage", "documents") + expected_file = os.path.join("/tmp/test_storage", "documents/test.txt") + + assert mock_makedirs.call_count == 1 + assert os.path.normpath(mock_makedirs.call_args[0][0]) == os.path.normpath( + expected_dir ) + assert mock_makedirs.call_args[1] == {"exist_ok": True} + + assert mock_file.call_count == 1 + assert os.path.normpath(mock_file.call_args[0][0]) == os.path.normpath( + expected_file + ) + assert mock_file.call_args[0][1] == "wb" - # Verify file write - mock_file.assert_called_once_with("/tmp/test_storage/documents/test.txt", 'wb') mock_copyfileobj.assert_called_once_with(file_data, mock_file()) + assert result == {"storage_type": "local"} - # Verify result - assert result == {'storage_type': 'local'} - - @patch('os.makedirs') + @patch("os.makedirs") def test_save_file_with_save_method(self, mock_makedirs, local_storage): - """Should use save method if file_data has it.""" file_data = MagicMock() file_data.save = MagicMock() path = "documents/test.txt" result = local_storage.save_file(file_data, path) - # Verify save method was called - file_data.save.assert_called_once_with("/tmp/test_storage/documents/test.txt") + expected_file = os.path.join("/tmp/test_storage", "documents/test.txt") + assert file_data.save.call_count == 1 + assert os.path.normpath(file_data.save.call_args[0][0]) == os.path.normpath( + expected_file + ) + assert result == {"storage_type": "local"} - # Verify result - assert result == {'storage_type': 'local'} - - @patch('os.makedirs') - @patch('builtins.open', new_callable=mock_open) - def test_save_file_with_absolute_path(self, mock_file, mock_makedirs, local_storage): - """Should handle absolute paths correctly.""" + @patch("os.makedirs") + @patch("builtins.open", new_callable=mock_open) + def test_save_file_with_absolute_path( + self, mock_file, mock_makedirs, local_storage + ): file_data = io.BytesIO(b"test content") path = "/absolute/path/test.txt" local_storage.save_file(file_data, path) mock_makedirs.assert_called_once_with("/absolute/path", exist_ok=True) - mock_file.assert_called_once_with("/absolute/path/test.txt", 'wb') + mock_file.assert_called_once_with("/absolute/path/test.txt", "wb") +@pytest.mark.unit class TestLocalStorageGetFile: - """Test file retrieval functionality.""" - @patch('os.path.exists', return_value=True) - @patch('builtins.open', new_callable=mock_open, read_data=b"file content") + @patch("os.path.exists", return_value=True) + @patch("builtins.open", new_callable=mock_open, read_data=b"file content") def test_get_file_returns_file_handle(self, mock_file, mock_exists, local_storage): - """Should open and return file handle when file exists.""" path = "documents/test.txt" result = local_storage.get_file(path) - mock_exists.assert_called_once_with("/tmp/test_storage/documents/test.txt") - mock_file.assert_called_once_with("/tmp/test_storage/documents/test.txt", 'rb') + expected_path = os.path.join("/tmp/test_storage", "documents/test.txt") + assert mock_exists.call_count == 1 + assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath( + expected_path + ) + assert mock_file.call_count == 1 + assert os.path.normpath(mock_file.call_args[0][0]) == os.path.normpath( + expected_path + ) assert result is not None - @patch('os.path.exists', return_value=False) + @patch("os.path.exists", return_value=False) def test_get_file_raises_error_when_not_found(self, mock_exists, local_storage): - """Should raise FileNotFoundError when file doesn't exist.""" path = "documents/nonexistent.txt" with pytest.raises(FileNotFoundError, match="File not found"): local_storage.get_file(path) - - mock_exists.assert_called_once_with("/tmp/test_storage/documents/nonexistent.txt") + expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt") + assert mock_exists.call_count == 1 + assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath( + expected_path + ) +@pytest.mark.unit class TestLocalStorageDeleteFile: - """Test file deletion functionality.""" - @patch('os.remove') - @patch('os.path.exists', return_value=True) - def test_delete_file_removes_existing_file(self, mock_exists, mock_remove, local_storage): - """Should delete file and return True when file exists.""" + @patch("os.remove") + @patch("os.path.exists", return_value=True) + def test_delete_file_removes_existing_file( + self, mock_exists, mock_remove, local_storage + ): path = "documents/test.txt" result = local_storage.delete_file(path) + expected_path = os.path.join("/tmp/test_storage", "documents/test.txt") assert result is True - mock_exists.assert_called_once_with("/tmp/test_storage/documents/test.txt") - mock_remove.assert_called_once_with("/tmp/test_storage/documents/test.txt") + assert mock_exists.call_count == 1 + assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath( + expected_path + ) + assert mock_remove.call_count == 1 + assert os.path.normpath(mock_remove.call_args[0][0]) == os.path.normpath( + expected_path + ) - @patch('os.path.exists', return_value=False) + @patch("os.path.exists", return_value=False) def test_delete_file_returns_false_when_not_found(self, mock_exists, local_storage): - """Should return False when file doesn't exist.""" path = "documents/nonexistent.txt" result = local_storage.delete_file(path) + expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt") assert result is False - mock_exists.assert_called_once_with("/tmp/test_storage/documents/nonexistent.txt") + assert mock_exists.call_count == 1 + assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath( + expected_path + ) +@pytest.mark.unit class TestLocalStorageFileExists: - """Test file existence checking.""" - @patch('os.path.exists', return_value=True) + @patch("os.path.exists", return_value=True) def test_file_exists_returns_true_when_file_found(self, mock_exists, local_storage): - """Should return True when file exists.""" path = "documents/test.txt" result = local_storage.file_exists(path) + expected_path = os.path.join("/tmp/test_storage", "documents/test.txt") assert result is True - mock_exists.assert_called_once_with("/tmp/test_storage/documents/test.txt") + assert mock_exists.call_count == 1 + assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath( + expected_path + ) - @patch('os.path.exists', return_value=False) + @patch("os.path.exists", return_value=False) def test_file_exists_returns_false_when_not_found(self, mock_exists, local_storage): - """Should return False when file doesn't exist.""" path = "documents/nonexistent.txt" result = local_storage.file_exists(path) + expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt") assert result is False - mock_exists.assert_called_once_with("/tmp/test_storage/documents/nonexistent.txt") + assert mock_exists.call_count == 1 + assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath( + expected_path + ) +@pytest.mark.unit class TestLocalStorageListFiles: - """Test directory listing functionality.""" - @patch('os.walk') - @patch('os.path.exists', return_value=True) + @patch("os.walk") + @patch("os.path.exists", return_value=True) def test_list_files_returns_all_files_in_directory( self, mock_exists, mock_walk, local_storage ): - """Should return all files in directory and subdirectories.""" directory = "documents" + base_dir = os.path.join("/tmp/test_storage", "documents") - # Mock os.walk to return files in directory structure mock_walk.return_value = [ - ("/tmp/test_storage/documents", ["subdir"], ["file1.txt", "file2.txt"]), - ("/tmp/test_storage/documents/subdir", [], ["file3.txt"]) + (base_dir, ["subdir"], ["file1.txt", "file2.txt"]), + (os.path.join(base_dir, "subdir"), [], ["file3.txt"]), ] result = local_storage.list_files(directory) assert len(result) == 3 - assert "documents/file1.txt" in result - assert "documents/file2.txt" in result - assert "documents/subdir/file3.txt" in result + result_normalized = [os.path.normpath(f) for f in result] + assert os.path.normpath("documents/file1.txt") in result_normalized + assert os.path.normpath("documents/file2.txt") in result_normalized + assert os.path.normpath("documents/subdir/file3.txt") in result_normalized - mock_exists.assert_called_once_with("/tmp/test_storage/documents") - mock_walk.assert_called_once_with("/tmp/test_storage/documents") - - @patch('os.path.exists', return_value=False) + @patch("os.path.exists", return_value=False) def test_list_files_returns_empty_list_when_directory_not_found( self, mock_exists, local_storage ): - """Should return empty list when directory doesn't exist.""" directory = "nonexistent" result = local_storage.list_files(directory) + expected_path = os.path.join("/tmp/test_storage", "nonexistent") assert result == [] - mock_exists.assert_called_once_with("/tmp/test_storage/nonexistent") + assert mock_exists.call_count == 1 + assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath( + expected_path + ) +@pytest.mark.unit class TestLocalStorageProcessFile: - """Test file processing functionality.""" - @patch('os.path.exists', return_value=True) + @patch("os.path.exists", return_value=True) def test_process_file_calls_processor_with_full_path( self, mock_exists, local_storage ): - """Should call processor function with full file path.""" path = "documents/test.txt" processor_func = MagicMock(return_value="processed") result = local_storage.process_file(path, processor_func, extra_arg="value") + expected_path = os.path.join("/tmp/test_storage", "documents/test.txt") assert result == "processed" - processor_func.assert_called_once_with( - local_path="/tmp/test_storage/documents/test.txt", - extra_arg="value" + assert processor_func.call_count == 1 + call_kwargs = processor_func.call_args[1] + assert os.path.normpath(call_kwargs["local_path"]) == os.path.normpath( + expected_path ) - mock_exists.assert_called_once_with("/tmp/test_storage/documents/test.txt") + assert call_kwargs["extra_arg"] == "value" - @patch('os.path.exists', return_value=False) - def test_process_file_raises_error_when_file_not_found(self, mock_exists, local_storage): - """Should raise FileNotFoundError when file doesn't exist.""" + @patch("os.path.exists", return_value=False) + def test_process_file_raises_error_when_file_not_found( + self, mock_exists, local_storage + ): path = "documents/nonexistent.txt" processor_func = MagicMock() with pytest.raises(FileNotFoundError, match="File not found"): local_storage.process_file(path, processor_func) - processor_func.assert_not_called() +@pytest.mark.unit class TestLocalStorageIsDirectory: - """Test directory checking functionality.""" - @patch('os.path.isdir', return_value=True) + @patch("os.path.isdir", return_value=True) def test_is_directory_returns_true_when_directory_exists( self, mock_isdir, local_storage ): - """Should return True when path is a directory.""" path = "documents" result = local_storage.is_directory(path) + expected_path = os.path.join("/tmp/test_storage", "documents") assert result is True - mock_isdir.assert_called_once_with("/tmp/test_storage/documents") + assert mock_isdir.call_count == 1 + assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath( + expected_path + ) - @patch('os.path.isdir', return_value=False) + @patch("os.path.isdir", return_value=False) def test_is_directory_returns_false_when_not_directory( self, mock_isdir, local_storage ): - """Should return False when path is not a directory or doesn't exist.""" path = "documents/test.txt" result = local_storage.is_directory(path) + expected_path = os.path.join("/tmp/test_storage", "documents/test.txt") assert result is False - mock_isdir.assert_called_once_with("/tmp/test_storage/documents/test.txt") + assert mock_isdir.call_count == 1 + assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath( + expected_path + ) +@pytest.mark.unit class TestLocalStorageRemoveDirectory: - """Test directory removal functionality.""" - @patch('shutil.rmtree') - @patch('os.path.isdir', return_value=True) - @patch('os.path.exists', return_value=True) + @patch("shutil.rmtree") + @patch("os.path.isdir", return_value=True) + @patch("os.path.exists", return_value=True) def test_remove_directory_deletes_directory( self, mock_exists, mock_isdir, mock_rmtree, local_storage ): - """Should remove directory and return True when successful.""" directory = "documents" result = local_storage.remove_directory(directory) + expected_path = os.path.join("/tmp/test_storage", "documents") assert result is True - mock_exists.assert_called_once_with("/tmp/test_storage/documents") - mock_isdir.assert_called_once_with("/tmp/test_storage/documents") - mock_rmtree.assert_called_once_with("/tmp/test_storage/documents") + assert mock_exists.call_count == 1 + assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath( + expected_path + ) + assert mock_isdir.call_count == 1 + assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath( + expected_path + ) + assert mock_rmtree.call_count == 1 + assert os.path.normpath(mock_rmtree.call_args[0][0]) == os.path.normpath( + expected_path + ) - @patch('os.path.exists', return_value=False) + @patch("os.path.exists", return_value=False) def test_remove_directory_returns_false_when_not_exists( self, mock_exists, local_storage ): - """Should return False when directory doesn't exist.""" directory = "nonexistent" result = local_storage.remove_directory(directory) + expected_path = os.path.join("/tmp/test_storage", "nonexistent") assert result is False - mock_exists.assert_called_once_with("/tmp/test_storage/nonexistent") + assert mock_exists.call_count == 1 + assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath( + expected_path + ) - @patch('os.path.isdir', return_value=False) - @patch('os.path.exists', return_value=True) + @patch("os.path.isdir", return_value=False) + @patch("os.path.exists", return_value=True) def test_remove_directory_returns_false_when_not_directory( self, mock_exists, mock_isdir, local_storage ): - """Should return False when path is not a directory.""" path = "documents/test.txt" result = local_storage.remove_directory(path) + expected_path = os.path.join("/tmp/test_storage", "documents/test.txt") assert result is False - mock_exists.assert_called_once_with("/tmp/test_storage/documents/test.txt") - mock_isdir.assert_called_once_with("/tmp/test_storage/documents/test.txt") + assert mock_exists.call_count == 1 + assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath( + expected_path + ) + assert mock_isdir.call_count == 1 + assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath( + expected_path + ) - @patch('shutil.rmtree', side_effect=OSError("Permission denied")) - @patch('os.path.isdir', return_value=True) - @patch('os.path.exists', return_value=True) + @patch("shutil.rmtree", side_effect=OSError("Permission denied")) + @patch("os.path.isdir", return_value=True) + @patch("os.path.exists", return_value=True) def test_remove_directory_returns_false_on_os_error( self, mock_exists, mock_isdir, mock_rmtree, local_storage ): - """Should return False when OSError occurs during removal.""" directory = "documents" result = local_storage.remove_directory(directory) + expected_path = os.path.join("/tmp/test_storage", "documents") assert result is False - mock_rmtree.assert_called_once_with("/tmp/test_storage/documents") + assert mock_rmtree.call_count == 1 + assert os.path.normpath(mock_rmtree.call_args[0][0]) == os.path.normpath( + expected_path + ) - @patch('shutil.rmtree', side_effect=PermissionError("Access denied")) - @patch('os.path.isdir', return_value=True) - @patch('os.path.exists', return_value=True) + @patch("shutil.rmtree", side_effect=PermissionError("Access denied")) + @patch("os.path.isdir", return_value=True) + @patch("os.path.exists", return_value=True) def test_remove_directory_returns_false_on_permission_error( self, mock_exists, mock_isdir, mock_rmtree, local_storage ): - """Should return False when PermissionError occurs during removal.""" directory = "documents" result = local_storage.remove_directory(directory) + expected_path = os.path.join("/tmp/test_storage", "documents") assert result is False - mock_rmtree.assert_called_once_with("/tmp/test_storage/documents") + assert mock_rmtree.call_count == 1 + assert os.path.normpath(mock_rmtree.call_args[0][0]) == os.path.normpath( + expected_path + ) diff --git a/tests/storage/test_s3_storage.py b/tests/storage/test_s3_storage.py index a9cc3c5a..da2afac2 100644 --- a/tests/storage/test_s3_storage.py +++ b/tests/storage/test_s3_storage.py @@ -1,18 +1,18 @@ -"""Tests for S3 storage implementation. -""" +"""Tests for S3 storage implementation.""" import io +from unittest.mock import MagicMock, patch + import pytest -from unittest.mock import patch, MagicMock -from botocore.exceptions import ClientError from application.storage.s3 import S3Storage +from botocore.exceptions import ClientError @pytest.fixture def mock_boto3_client(): """Mock boto3.client to isolate S3 client creation.""" - with patch('boto3.client') as mock_client: + with patch("boto3.client") as mock_client: s3_mock = MagicMock() mock_client.return_value = s3_mock yield s3_mock @@ -27,22 +27,26 @@ def s3_storage(mock_boto3_client): class TestS3StorageInitialization: """Test S3Storage initialization and configuration.""" + @pytest.mark.unit def test_init_with_default_bucket(self): """Should use default bucket name when none provided.""" - with patch('boto3.client'): + with patch("boto3.client"): storage = S3Storage() assert storage.bucket_name == "docsgpt-test-bucket" + @pytest.mark.unit def test_init_with_custom_bucket(self): """Should use provided bucket name.""" - with patch('boto3.client'): + with patch("boto3.client"): storage = S3Storage(bucket_name="custom-bucket") assert storage.bucket_name == "custom-bucket" + @pytest.mark.unit def test_init_creates_boto3_client(self): """Should create boto3 S3 client with credentials from settings.""" - with patch('boto3.client') as mock_client, \ - patch('application.storage.s3.settings') as mock_settings: + with patch("boto3.client") as mock_client, patch( + "application.storage.s3.settings" + ) as mock_settings: mock_settings.SAGEMAKER_ACCESS_KEY = "test-key" mock_settings.SAGEMAKER_SECRET_KEY = "test-secret" @@ -54,52 +58,50 @@ class TestS3StorageInitialization: "s3", aws_access_key_id="test-key", aws_secret_access_key="test-secret", - region_name="us-west-2" + region_name="us-west-2", ) class TestS3StorageSaveFile: """Test file saving functionality.""" + @pytest.mark.unit def test_save_file_uploads_to_s3(self, s3_storage, mock_boto3_client): """Should upload file to S3 with correct parameters.""" file_data = io.BytesIO(b"test content") path = "documents/test.txt" - with patch('application.storage.s3.settings') as mock_settings: + with patch("application.storage.s3.settings") as mock_settings: mock_settings.SAGEMAKER_REGION = "us-east-1" result = s3_storage.save_file(file_data, path) - mock_boto3_client.upload_fileobj.assert_called_once_with( file_data, "test-bucket", path, - ExtraArgs={"StorageClass": "INTELLIGENT_TIERING"} + ExtraArgs={"StorageClass": "INTELLIGENT_TIERING"}, ) assert result == { "storage_type": "s3", "bucket_name": "test-bucket", "uri": "s3://test-bucket/documents/test.txt", - "region": "us-east-1" + "region": "us-east-1", } + @pytest.mark.unit def test_save_file_with_custom_storage_class(self, s3_storage, mock_boto3_client): """Should use custom storage class when provided.""" file_data = io.BytesIO(b"test content") path = "documents/test.txt" - with patch('application.storage.s3.settings') as mock_settings: + with patch("application.storage.s3.settings") as mock_settings: mock_settings.SAGEMAKER_REGION = "us-east-1" s3_storage.save_file(file_data, path, storage_class="STANDARD") - mock_boto3_client.upload_fileobj.assert_called_once_with( - file_data, - "test-bucket", - path, - ExtraArgs={"StorageClass": "STANDARD"} + file_data, "test-bucket", path, ExtraArgs={"StorageClass": "STANDARD"} ) + @pytest.mark.unit def test_save_file_propagates_client_error(self, s3_storage, mock_boto3_client): """Should propagate ClientError when upload fails.""" file_data = io.BytesIO(b"test content") @@ -107,7 +109,7 @@ class TestS3StorageSaveFile: mock_boto3_client.upload_fileobj.side_effect = ClientError( {"Error": {"Code": "AccessDenied", "Message": "Access denied"}}, - "upload_fileobj" + "upload_fileobj", ) with pytest.raises(ClientError): @@ -117,7 +119,10 @@ class TestS3StorageSaveFile: class TestS3StorageFileExists: """Test file existence checking.""" - def test_file_exists_returns_true_when_file_found(self, s3_storage, mock_boto3_client): + @pytest.mark.unit + def test_file_exists_returns_true_when_file_found( + self, s3_storage, mock_boto3_client + ): """Should return True when head_object succeeds.""" path = "documents/test.txt" mock_boto3_client.head_object.return_value = {"ContentLength": 100} @@ -126,16 +131,17 @@ class TestS3StorageFileExists: assert result is True mock_boto3_client.head_object.assert_called_once_with( - Bucket="test-bucket", - Key=path + Bucket="test-bucket", Key=path ) - def test_file_exists_returns_false_on_client_error(self, s3_storage, mock_boto3_client): + @pytest.mark.unit + def test_file_exists_returns_false_on_client_error( + self, s3_storage, mock_boto3_client + ): """Should return False when head_object raises ClientError.""" path = "documents/nonexistent.txt" mock_boto3_client.head_object.side_effect = ClientError( - {"Error": {"Code": "NoSuchKey", "Message": "Not found"}}, - "head_object" + {"Error": {"Code": "NoSuchKey", "Message": "Not found"}}, "head_object" ) result = s3_storage.file_exists(path) @@ -146,7 +152,10 @@ class TestS3StorageFileExists: class TestS3StorageGetFile: """Test file retrieval functionality.""" - def test_get_file_downloads_and_returns_file_object(self, s3_storage, mock_boto3_client): + @pytest.mark.unit + def test_get_file_downloads_and_returns_file_object( + self, s3_storage, mock_boto3_client + ): """Should download file from S3 and return BytesIO object.""" path = "documents/test.txt" test_content = b"file content" @@ -164,12 +173,14 @@ class TestS3StorageGetFile: assert result.read() == test_content mock_boto3_client.download_fileobj.assert_called_once() - def test_get_file_raises_error_when_file_not_found(self, s3_storage, mock_boto3_client): + @pytest.mark.unit + def test_get_file_raises_error_when_file_not_found( + self, s3_storage, mock_boto3_client + ): """Should raise FileNotFoundError when file doesn't exist.""" path = "documents/nonexistent.txt" mock_boto3_client.head_object.side_effect = ClientError( - {"Error": {"Code": "NoSuchKey", "Message": "Not found"}}, - "head_object" + {"Error": {"Code": "NoSuchKey", "Message": "Not found"}}, "head_object" ) with pytest.raises(FileNotFoundError, match="File not found"): @@ -179,6 +190,7 @@ class TestS3StorageGetFile: class TestS3StorageDeleteFile: """Test file deletion functionality.""" + @pytest.mark.unit def test_delete_file_returns_true_on_success(self, s3_storage, mock_boto3_client): """Should return True when deletion succeeds.""" path = "documents/test.txt" @@ -188,16 +200,18 @@ class TestS3StorageDeleteFile: assert result is True mock_boto3_client.delete_object.assert_called_once_with( - Bucket="test-bucket", - Key=path + Bucket="test-bucket", Key=path ) - def test_delete_file_returns_false_on_client_error(self, s3_storage, mock_boto3_client): + @pytest.mark.unit + def test_delete_file_returns_false_on_client_error( + self, s3_storage, mock_boto3_client + ): """Should return False when deletion fails with ClientError.""" path = "documents/test.txt" mock_boto3_client.delete_object.side_effect = ClientError( {"Error": {"Code": "AccessDenied", "Message": "Access denied"}}, - "delete_object" + "delete_object", ) result = s3_storage.delete_file(path) @@ -208,7 +222,10 @@ class TestS3StorageDeleteFile: class TestS3StorageListFiles: """Test directory listing functionality.""" - def test_list_files_returns_all_keys_with_prefix(self, s3_storage, mock_boto3_client): + @pytest.mark.unit + def test_list_files_returns_all_keys_with_prefix( + self, s3_storage, mock_boto3_client + ): """Should return all file keys matching the directory prefix.""" directory = "documents/" @@ -219,7 +236,7 @@ class TestS3StorageListFiles: "Contents": [ {"Key": "documents/file1.txt"}, {"Key": "documents/file2.txt"}, - {"Key": "documents/subdir/file3.txt"} + {"Key": "documents/subdir/file3.txt"}, ] } ] @@ -231,13 +248,15 @@ class TestS3StorageListFiles: assert "documents/file2.txt" in result assert "documents/subdir/file3.txt" in result - mock_boto3_client.get_paginator.assert_called_once_with('list_objects_v2') + mock_boto3_client.get_paginator.assert_called_once_with("list_objects_v2") paginator_mock.paginate.assert_called_once_with( - Bucket="test-bucket", - Prefix="documents/" + Bucket="test-bucket", Prefix="documents/" ) - def test_list_files_returns_empty_list_when_no_contents(self, s3_storage, mock_boto3_client): + @pytest.mark.unit + def test_list_files_returns_empty_list_when_no_contents( + self, s3_storage, mock_boto3_client + ): """Should return empty list when directory has no files.""" directory = "empty/" @@ -253,30 +272,36 @@ class TestS3StorageListFiles: class TestS3StorageProcessFile: """Test file processing functionality.""" - def test_process_file_downloads_and_processes_file(self, s3_storage, mock_boto3_client): + @pytest.mark.unit + def test_process_file_downloads_and_processes_file( + self, s3_storage, mock_boto3_client + ): """Should download file to temp location and call processor function.""" path = "documents/test.txt" mock_boto3_client.head_object.return_value = {} - with patch('tempfile.NamedTemporaryFile') as mock_temp: + with patch("tempfile.NamedTemporaryFile") as mock_temp: mock_file = MagicMock() mock_file.name = "/tmp/test_file" mock_temp.return_value.__enter__.return_value = mock_file processor_func = MagicMock(return_value="processed") result = s3_storage.process_file(path, processor_func, extra_arg="value") - assert result == "processed" - processor_func.assert_called_once_with(local_path="/tmp/test_file", extra_arg="value") + processor_func.assert_called_once_with( + local_path="/tmp/test_file", extra_arg="value" + ) mock_boto3_client.download_fileobj.assert_called_once() - def test_process_file_raises_error_when_file_not_found(self, s3_storage, mock_boto3_client): + @pytest.mark.unit + def test_process_file_raises_error_when_file_not_found( + self, s3_storage, mock_boto3_client + ): """Should raise FileNotFoundError when file doesn't exist.""" path = "documents/nonexistent.txt" mock_boto3_client.head_object.side_effect = ClientError( - {"Error": {"Code": "NoSuchKey", "Message": "Not found"}}, - "head_object" + {"Error": {"Code": "NoSuchKey", "Message": "Not found"}}, "head_object" ) processor_func = MagicMock() @@ -288,7 +313,10 @@ class TestS3StorageProcessFile: class TestS3StorageIsDirectory: """Test directory checking functionality.""" - def test_is_directory_returns_true_when_objects_exist(self, s3_storage, mock_boto3_client): + @pytest.mark.unit + def test_is_directory_returns_true_when_objects_exist( + self, s3_storage, mock_boto3_client + ): """Should return True when objects exist with the directory prefix.""" path = "documents/" @@ -300,12 +328,13 @@ class TestS3StorageIsDirectory: assert result is True mock_boto3_client.list_objects_v2.assert_called_once_with( - Bucket="test-bucket", - Prefix="documents/", - MaxKeys=1 + Bucket="test-bucket", Prefix="documents/", MaxKeys=1 ) - def test_is_directory_returns_false_when_no_objects_exist(self, s3_storage, mock_boto3_client): + @pytest.mark.unit + def test_is_directory_returns_false_when_no_objects_exist( + self, s3_storage, mock_boto3_client + ): """Should return False when no objects exist with the directory prefix.""" path = "nonexistent/" @@ -319,6 +348,7 @@ class TestS3StorageIsDirectory: class TestS3StorageRemoveDirectory: """Test directory removal functionality.""" + @pytest.mark.unit def test_remove_directory_deletes_all_objects(self, s3_storage, mock_boto3_client): """Should delete all objects with the directory prefix.""" directory = "documents/" @@ -329,16 +359,13 @@ class TestS3StorageRemoveDirectory: { "Contents": [ {"Key": "documents/file1.txt"}, - {"Key": "documents/file2.txt"} + {"Key": "documents/file2.txt"}, ] } ] mock_boto3_client.delete_objects.return_value = { - "Deleted": [ - {"Key": "documents/file1.txt"}, - {"Key": "documents/file2.txt"} - ] + "Deleted": [{"Key": "documents/file1.txt"}, {"Key": "documents/file2.txt"}] } result = s3_storage.remove_directory(directory) @@ -349,7 +376,10 @@ class TestS3StorageRemoveDirectory: assert call_args["Bucket"] == "test-bucket" assert len(call_args["Delete"]["Objects"]) == 2 - def test_remove_directory_returns_false_when_empty(self, s3_storage, mock_boto3_client): + @pytest.mark.unit + def test_remove_directory_returns_false_when_empty( + self, s3_storage, mock_boto3_client + ): """Should return False when directory is empty (no objects to delete).""" directory = "empty/" @@ -362,7 +392,10 @@ class TestS3StorageRemoveDirectory: assert result is False mock_boto3_client.delete_objects.assert_not_called() - def test_remove_directory_returns_false_on_client_error(self, s3_storage, mock_boto3_client): + @pytest.mark.unit + def test_remove_directory_returns_false_on_client_error( + self, s3_storage, mock_boto3_client + ): """Should return False when deletion fails with ClientError.""" directory = "documents/" @@ -374,7 +407,7 @@ class TestS3StorageRemoveDirectory: mock_boto3_client.delete_objects.side_effect = ClientError( {"Error": {"Code": "AccessDenied", "Message": "Access denied"}}, - "delete_objects" + "delete_objects", ) result = s3_storage.remove_directory(directory) diff --git a/tests/test_app.py b/tests/test_app.py index 451dd0af..3c5d7937 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,12 +1,12 @@ -from flask import Flask - +import pytest from application.api.answer import answer from application.api.internal.routes import internal from application.api.user.routes import user from application.core.settings import settings +from flask import Flask - +@pytest.mark.unit def test_app_config(): app = Flask(__name__) app.register_blueprint(user) diff --git a/tests/test_cache.py b/tests/test_cache.py index af2b5e00..dc802c64 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,20 +1,20 @@ -import unittest import json -from unittest.mock import patch, MagicMock -from application.cache import gen_cache_key, stream_cache, gen_cache +from unittest.mock import MagicMock, patch + +import pytest +from application.cache import gen_cache, gen_cache_key, stream_cache from application.utils import get_hash -# Test for gen_cache_key function +@pytest.mark.unit def test_make_gen_cache_key(): messages = [ - {'role': 'user', 'content': 'test_user_message'}, - {'role': 'system', 'content': 'test_system_message'}, + {"role": "user", "content": "test_user_message"}, + {"role": "system", "content": "test_system_message"}, ] model = "test_docgpt" tools = None - - # Manually calculate the expected hash + messages_str = json.dumps(messages) tools_str = json.dumps(tools) if tools else "" expected_combined = f"{model}_{messages_str}_{tools_str}" @@ -23,112 +23,100 @@ def test_make_gen_cache_key(): assert cache_key == expected_hash -def test_gen_cache_key_invalid_message_format(): - # Test when messages is not a list - with unittest.TestCase.assertRaises(unittest.TestCase, ValueError) as context: - gen_cache_key("This is not a list", model="docgpt", tools=None) - assert str(context.exception) == "All messages must be dictionaries." -# Test for gen_cache decorator -@patch('application.cache.get_redis_instance') # Mock the Redis client +@pytest.mark.unit +def test_gen_cache_key_invalid_message_format(): + with pytest.raises(ValueError, match="All messages must be dictionaries."): + gen_cache_key("This is not a list", model="docgpt", tools=None) + + +@pytest.mark.unit +@patch("application.cache.get_redis_instance") def test_gen_cache_hit(mock_make_redis): - # Arrange mock_redis_instance = MagicMock() mock_make_redis.return_value = mock_redis_instance - mock_redis_instance.get.return_value = b"cached_result" # Simulate a cache hit + mock_redis_instance.get.return_value = b"cached_result" @gen_cache def mock_function(self, model, messages, stream, tools): return "new_result" - messages = [{'role': 'user', 'content': 'test_user_message'}] + messages = [{"role": "user", "content": "test_user_message"}] model = "test_docgpt" - # Act result = mock_function(None, model, messages, stream=False, tools=None) - # Assert - assert result == "cached_result" # Should return cached result - mock_redis_instance.get.assert_called_once() # Ensure Redis get was called - mock_redis_instance.set.assert_not_called() # Ensure the function result is not cached again + assert result == "cached_result" + mock_redis_instance.get.assert_called_once() + mock_redis_instance.set.assert_not_called() -@patch('application.cache.get_redis_instance') # Mock the Redis client +@pytest.mark.unit +@patch("application.cache.get_redis_instance") def test_gen_cache_miss(mock_make_redis): - # Arrange mock_redis_instance = MagicMock() mock_make_redis.return_value = mock_redis_instance - mock_redis_instance.get.return_value = None # Simulate a cache miss + mock_redis_instance.get.return_value = None @gen_cache def mock_function(self, model, messages, steam, tools): return "new_result" messages = [ - {'role': 'user', 'content': 'test_user_message'}, - {'role': 'system', 'content': 'test_system_message'}, + {"role": "user", "content": "test_user_message"}, + {"role": "system", "content": "test_system_message"}, ] model = "test_docgpt" - # Act + result = mock_function(None, model, messages, stream=False, tools=None) - # Assert assert result == "new_result" - mock_redis_instance.get.assert_called_once() + mock_redis_instance.get.assert_called_once() -@patch('application.cache.get_redis_instance') + +@pytest.mark.unit +@patch("application.cache.get_redis_instance") def test_stream_cache_hit(mock_make_redis): - # Arrange mock_redis_instance = MagicMock() mock_make_redis.return_value = mock_redis_instance - cached_chunk = json.dumps(["chunk1", "chunk2"]).encode('utf-8') + cached_chunk = json.dumps(["chunk1", "chunk2"]).encode("utf-8") mock_redis_instance.get.return_value = cached_chunk @stream_cache def mock_function(self, model, messages, stream, tools): yield "new_chunk" - messages = [{'role': 'user', 'content': 'test_user_message'}] + messages = [{"role": "user", "content": "test_user_message"}] model = "test_docgpt" - # Act result = list(mock_function(None, model, messages, stream=True, tools=None)) - # Assert - assert result == ["chunk1", "chunk2"] # Should return cached chunks + assert result == ["chunk1", "chunk2"] mock_redis_instance.get.assert_called_once() mock_redis_instance.set.assert_not_called() -@patch('application.cache.get_redis_instance') +@pytest.mark.unit +@patch("application.cache.get_redis_instance") def test_stream_cache_miss(mock_make_redis): - # Arrange mock_redis_instance = MagicMock() mock_make_redis.return_value = mock_redis_instance - mock_redis_instance.get.return_value = None # Simulate a cache miss + mock_redis_instance.get.return_value = None @stream_cache def mock_function(self, model, messages, stream, tools): yield "new_chunk" messages = [ - {'role': 'user', 'content': 'This is the context'}, - {'role': 'system', 'content': 'Some other message'}, - {'role': 'user', 'content': 'What is the answer?'} + {"role": "user", "content": "This is the context"}, + {"role": "system", "content": "Some other message"}, + {"role": "user", "content": "What is the answer?"}, ] model = "test_docgpt" - # Act result = list(mock_function(None, model, messages, stream=True, tools=None)) - # Assert assert result == ["new_chunk"] - mock_redis_instance.get.assert_called_once() - mock_redis_instance.set.assert_called_once() - - - - - - + mock_redis_instance.get.assert_called_once() + mock_redis_instance.set.assert_called_once() diff --git a/tests/test_celery.py b/tests/test_celery.py index f4ec6a03..0efb7206 100644 --- a/tests/test_celery.py +++ b/tests/test_celery.py @@ -1,21 +1,21 @@ from unittest.mock import patch -from application.core.settings import settings + +import pytest from application.celery_init import make_celery +from application.core.settings import settings -@patch('application.celery_init.Celery') +@pytest.mark.unit +@patch("application.celery_init.Celery") def test_make_celery(mock_celery): - # Arrange - app_name = 'test_app_name' + app_name = "test_app_name" - # Act celery = make_celery(app_name) - # Assert mock_celery.assert_called_once_with( - app_name, - broker=settings.CELERY_BROKER_URL, - backend=settings.CELERY_RESULT_BACKEND + app_name, + broker=settings.CELERY_BROKER_URL, + backend=settings.CELERY_RESULT_BACKEND, ) celery.conf.update.assert_called_once_with(settings) - assert celery == mock_celery.return_value \ No newline at end of file + assert celery == mock_celery.return_value diff --git a/tests/test_error.py b/tests/test_error.py index 07f362fc..b86de00a 100644 --- a/tests/test_error.py +++ b/tests/test_error.py @@ -1,6 +1,6 @@ import pytest -from flask import Flask from application.error import bad_request, response_error +from flask import Flask @pytest.fixture @@ -9,31 +9,35 @@ def app(): return app +@pytest.mark.unit def test_bad_request_with_message(app): with app.app_context(): message = "Invalid input" response = bad_request(status_code=400, message=message) assert response.status_code == 400 - assert response.json == {'error': 'Bad Request', 'message': message} + assert response.json == {"error": "Bad Request", "message": message} +@pytest.mark.unit def test_bad_request_without_message(app): with app.app_context(): response = bad_request(status_code=400) assert response.status_code == 400 - assert response.json == {'error': 'Bad Request'} + assert response.json == {"error": "Bad Request"} +@pytest.mark.unit def test_response_error_with_message(app): with app.app_context(): message = "Something went wrong" response = response_error(code_status=500, message=message) assert response.status_code == 500 - assert response.json == {'error': 'Internal Server Error', 'message': message} + assert response.json == {"error": "Internal Server Error", "message": message} +@pytest.mark.unit def test_response_error_without_message(app): with app.app_context(): response = response_error(code_status=500) assert response.status_code == 500 - assert response.json == {'error': 'Internal Server Error'} \ No newline at end of file + assert response.json == {"error": "Internal Server Error"} diff --git a/tests/test_memory_tool.py b/tests/test_memory_tool.py index 2b041c6e..be00dbd4 100644 --- a/tests/test_memory_tool.py +++ b/tests/test_memory_tool.py @@ -17,6 +17,7 @@ def memory_tool(monkeypatch) -> MemoryTool: path = doc.get("path") key = f"{user_id}:{tool_id}:{path}" # Add _id to document if not present + if "_id" not in doc: doc["_id"] = key self.docs[key] = doc @@ -24,16 +25,17 @@ def memory_tool(monkeypatch) -> MemoryTool: def update_one(self, q, u, upsert=False): # Handle query by _id + if "_id" in q: doc_id = q["_id"] if doc_id not in self.docs: return type("res", (), {"modified_count": 0}) - if "$set" in u: old_doc = self.docs[doc_id].copy() old_doc.update(u["$set"]) # If path changed, update the dictionary key + if "path" in u["$set"]: new_path = u["$set"]["path"] user_id = old_doc.get("user_id") @@ -41,15 +43,15 @@ def memory_tool(monkeypatch) -> MemoryTool: new_key = f"{user_id}:{tool_id}:{new_path}" # Remove old key and add with new key + del self.docs[doc_id] old_doc["_id"] = new_key self.docs[new_key] = old_doc else: self.docs[doc_id] = old_doc - return type("res", (), {"modified_count": 1}) - # Handle query by user_id, tool_id, path + user_id = q.get("user_id") tool_id = q.get("tool_id") path = q.get("path") @@ -57,13 +59,16 @@ def memory_tool(monkeypatch) -> MemoryTool: if key not in self.docs and not upsert: return type("res", (), {"modified_count": 0}) - if key not in self.docs and upsert: - self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key} - + self.docs[key] = { + "user_id": user_id, + "tool_id": tool_id, + "path": path, + "content": "", + "_id": key, + } if "$set" in u: self.docs[key].update(u["$set"]) - return type("res", (), {"modified_count": 1}) def find_one(self, q, projection=None): @@ -74,7 +79,6 @@ def memory_tool(monkeypatch) -> MemoryTool: if path: key = f"{user_id}:{tool_id}:{path}" return self.docs.get(key) - return None def find(self, q, projection=None): @@ -83,9 +87,11 @@ def memory_tool(monkeypatch) -> MemoryTool: results = [] # Handle regex queries for directory listing + if "path" in q and isinstance(q["path"], dict) and "$regex" in q["path"]: regex_pattern = q["path"]["$regex"] # Remove regex escape characters and ^ anchor for simple matching + pattern = regex_pattern.replace("\\", "").lstrip("^") for key, doc in self.docs.items(): @@ -97,7 +103,6 @@ def memory_tool(monkeypatch) -> MemoryTool: for key, doc in self.docs.items(): if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id: results.append(doc) - return results def delete_one(self, q): @@ -109,7 +114,6 @@ def memory_tool(monkeypatch) -> MemoryTool: if key in self.docs: del self.docs[key] return type("res", (), {"deleted_count": 1}) - return type("res", (), {"deleted_count": 0}) def delete_many(self, q): @@ -118,6 +122,7 @@ def memory_tool(monkeypatch) -> MemoryTool: deleted = 0 # Handle regex queries for directory deletion + if "path" in q and isinstance(q["path"], dict) and "$regex" in q["path"]: regex_pattern = q["path"]["$regex"] pattern = regex_pattern.replace("\\", "").lstrip("^") @@ -128,32 +133,36 @@ def memory_tool(monkeypatch) -> MemoryTool: doc_path = doc.get("path", "") if doc_path.startswith(pattern): keys_to_delete.append(key) - for key in keys_to_delete: del self.docs[key] deleted += 1 else: # Delete all for user and tool + keys_to_delete = [ - key for key, doc in self.docs.items() + key + for key, doc in self.docs.items() if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id ] for key in keys_to_delete: del self.docs[key] deleted += 1 - return type("res", (), {"deleted_count": deleted}) fake_collection = FakeCollection() fake_db = {"memories": fake_collection} fake_client = {settings.MONGO_DB_NAME: fake_db} - monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client) + monkeypatch.setattr( + "application.core.mongo_db.MongoDB.get_client", lambda: fake_client + ) # Return tool with a fixed tool_id for consistency in tests + return MemoryTool({"tool_id": "test_tool_id"}, user_id="test_user") +@pytest.mark.unit def test_init_without_user_id(): """Should fail gracefully if no user_id is provided.""" memory_tool = MemoryTool(tool_config={}) @@ -161,90 +170,78 @@ def test_init_without_user_id(): assert "user_id" in result.lower() +@pytest.mark.unit def test_view_empty_directory(memory_tool: MemoryTool) -> None: """Should show empty directory when no files exist.""" result = memory_tool.execute_action("view", path="/") assert "empty" in result.lower() +@pytest.mark.unit def test_create_and_view_file(memory_tool: MemoryTool) -> None: """Test creating a file and viewing it.""" # Create a file + result = memory_tool.execute_action( - "create", - path="/notes.txt", - file_text="Hello world" + "create", path="/notes.txt", file_text="Hello world" ) assert "created" in result.lower() # View the file + result = memory_tool.execute_action("view", path="/notes.txt") assert "Hello world" in result +@pytest.mark.unit def test_create_overwrite_file(memory_tool: MemoryTool) -> None: """Test that create overwrites existing files.""" # Create initial file - memory_tool.execute_action( - "create", - path="/test.txt", - file_text="Original content" - ) + + memory_tool.execute_action("create", path="/test.txt", file_text="Original content") # Overwrite - memory_tool.execute_action( - "create", - path="/test.txt", - file_text="New content" - ) + + memory_tool.execute_action("create", path="/test.txt", file_text="New content") # Verify overwrite + result = memory_tool.execute_action("view", path="/test.txt") assert "New content" in result assert "Original content" not in result +@pytest.mark.unit def test_view_directory_with_files(memory_tool: MemoryTool) -> None: """Test viewing directory contents.""" # Create multiple files + + memory_tool.execute_action("create", path="/file1.txt", file_text="Content 1") + memory_tool.execute_action("create", path="/file2.txt", file_text="Content 2") memory_tool.execute_action( - "create", - path="/file1.txt", - file_text="Content 1" - ) - memory_tool.execute_action( - "create", - path="/file2.txt", - file_text="Content 2" - ) - memory_tool.execute_action( - "create", - path="/subdir/file3.txt", - file_text="Content 3" + "create", path="/subdir/file3.txt", file_text="Content 3" ) # View directory + result = memory_tool.execute_action("view", path="/") assert "file1.txt" in result assert "file2.txt" in result assert "subdir/file3.txt" in result +@pytest.mark.unit def test_view_file_with_line_range(memory_tool: MemoryTool) -> None: """Test viewing specific lines from a file.""" # Create a multiline file + content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" - memory_tool.execute_action( - "create", - path="/multiline.txt", - file_text=content - ) + memory_tool.execute_action("create", path="/multiline.txt", file_text=content) # View lines 2-4 + result = memory_tool.execute_action( - "view", - path="/multiline.txt", - view_range=[2, 4] + "view", path="/multiline.txt", view_range=[2, 4] ) assert "Line 2" in result assert "Line 3" in result @@ -253,197 +250,177 @@ def test_view_file_with_line_range(memory_tool: MemoryTool) -> None: assert "Line 5" not in result +@pytest.mark.unit def test_str_replace(memory_tool: MemoryTool) -> None: """Test string replacement in a file.""" # Create a file + memory_tool.execute_action( - "create", - path="/replace.txt", - file_text="Hello world, hello universe" + "create", path="/replace.txt", file_text="Hello world, hello universe" ) # Replace text + result = memory_tool.execute_action( - "str_replace", - path="/replace.txt", - old_str="hello", - new_str="hi" + "str_replace", path="/replace.txt", old_str="hello", new_str="hi" ) assert "updated" in result.lower() # Verify replacement + content = memory_tool.execute_action("view", path="/replace.txt") assert "hi world, hi universe" in content +@pytest.mark.unit def test_str_replace_not_found(memory_tool: MemoryTool) -> None: """Test string replacement when string not found.""" - memory_tool.execute_action( - "create", - path="/test.txt", - file_text="Hello world" - ) + memory_tool.execute_action("create", path="/test.txt", file_text="Hello world") result = memory_tool.execute_action( - "str_replace", - path="/test.txt", - old_str="goodbye", - new_str="hi" + "str_replace", path="/test.txt", old_str="goodbye", new_str="hi" ) assert "not found" in result.lower() +@pytest.mark.unit def test_insert_line(memory_tool: MemoryTool) -> None: """Test inserting text at a line number.""" # Create a multiline file + memory_tool.execute_action( - "create", - path="/insert.txt", - file_text="Line 1\nLine 2\nLine 3" + "create", path="/insert.txt", file_text="Line 1\nLine 2\nLine 3" ) # Insert at line 2 + result = memory_tool.execute_action( - "insert", - path="/insert.txt", - insert_line=2, - insert_text="Inserted line" + "insert", path="/insert.txt", insert_line=2, insert_text="Inserted line" ) assert "inserted" in result.lower() # Verify insertion + content = memory_tool.execute_action("view", path="/insert.txt") lines = content.split("\n") assert lines[1] == "Inserted line" assert lines[2] == "Line 2" +@pytest.mark.unit def test_insert_invalid_line(memory_tool: MemoryTool) -> None: """Test inserting at an invalid line number.""" - memory_tool.execute_action( - "create", - path="/test.txt", - file_text="Line 1\nLine 2" - ) + memory_tool.execute_action("create", path="/test.txt", file_text="Line 1\nLine 2") result = memory_tool.execute_action( - "insert", - path="/test.txt", - insert_line=100, - insert_text="Text" + "insert", path="/test.txt", insert_line=100, insert_text="Text" ) assert "invalid" in result.lower() +@pytest.mark.unit def test_delete_file(memory_tool: MemoryTool) -> None: """Test deleting a file.""" # Create a file - memory_tool.execute_action( - "create", - path="/delete_me.txt", - file_text="Content" - ) + + memory_tool.execute_action("create", path="/delete_me.txt", file_text="Content") # Delete it + result = memory_tool.execute_action("delete", path="/delete_me.txt") assert "deleted" in result.lower() # Verify it's gone + result = memory_tool.execute_action("view", path="/delete_me.txt") assert "not found" in result.lower() +@pytest.mark.unit def test_delete_nonexistent_file(memory_tool: MemoryTool) -> None: """Test deleting a file that doesn't exist.""" result = memory_tool.execute_action("delete", path="/nonexistent.txt") assert "not found" in result.lower() +@pytest.mark.unit def test_delete_directory(memory_tool: MemoryTool) -> None: """Test deleting a directory with files.""" # Create files in a directory + memory_tool.execute_action( - "create", - path="/subdir/file1.txt", - file_text="Content 1" + "create", path="/subdir/file1.txt", file_text="Content 1" ) memory_tool.execute_action( - "create", - path="/subdir/file2.txt", - file_text="Content 2" + "create", path="/subdir/file2.txt", file_text="Content 2" ) # Delete the directory + result = memory_tool.execute_action("delete", path="/subdir/") assert "deleted" in result.lower() # Verify files are gone + result = memory_tool.execute_action("view", path="/subdir/file1.txt") assert "not found" in result.lower() +@pytest.mark.unit def test_rename_file(memory_tool: MemoryTool) -> None: """Test renaming a file.""" # Create a file - memory_tool.execute_action( - "create", - path="/old_name.txt", - file_text="Content" - ) + + memory_tool.execute_action("create", path="/old_name.txt", file_text="Content") # Rename it + result = memory_tool.execute_action( - "rename", - old_path="/old_name.txt", - new_path="/new_name.txt" + "rename", old_path="/old_name.txt", new_path="/new_name.txt" ) assert "renamed" in result.lower() # Verify old path doesn't exist + result = memory_tool.execute_action("view", path="/old_name.txt") assert "not found" in result.lower() # Verify new path exists + result = memory_tool.execute_action("view", path="/new_name.txt") assert "Content" in result +@pytest.mark.unit def test_rename_nonexistent_file(memory_tool: MemoryTool) -> None: """Test renaming a file that doesn't exist.""" result = memory_tool.execute_action( - "rename", - old_path="/nonexistent.txt", - new_path="/new.txt" + "rename", old_path="/nonexistent.txt", new_path="/new.txt" ) assert "not found" in result.lower() +@pytest.mark.unit def test_rename_to_existing_file(memory_tool: MemoryTool) -> None: """Test renaming to a path that already exists.""" # Create two files - memory_tool.execute_action( - "create", - path="/file1.txt", - file_text="Content 1" - ) - memory_tool.execute_action( - "create", - path="/file2.txt", - file_text="Content 2" - ) + + memory_tool.execute_action("create", path="/file1.txt", file_text="Content 1") + memory_tool.execute_action("create", path="/file2.txt", file_text="Content 2") # Try to rename file1 to file2 + result = memory_tool.execute_action( - "rename", - old_path="/file1.txt", - new_path="/file2.txt" + "rename", old_path="/file1.txt", new_path="/file2.txt" ) assert "already exists" in result.lower() +@pytest.mark.unit def test_path_traversal_protection(memory_tool: MemoryTool) -> None: """Test that directory traversal attacks are prevented.""" # Try various path traversal attempts + invalid_paths = [ "/../secrets.txt", "/../../etc/passwd", @@ -453,16 +430,16 @@ def test_path_traversal_protection(memory_tool: MemoryTool) -> None: for path in invalid_paths: result = memory_tool.execute_action( - "create", - path=path, - file_text="malicious content" + "create", path=path, file_text="malicious content" ) assert "invalid path" in result.lower() +@pytest.mark.unit def test_path_must_start_with_slash(memory_tool: MemoryTool) -> None: """Test that paths work with or without leading slash (auto-normalized).""" # These paths should all work now (auto-prepended with /) + valid_paths = [ "etc/passwd", # Auto-prepended with / "home/user/file.txt", # Auto-prepended with / @@ -470,33 +447,29 @@ def test_path_must_start_with_slash(memory_tool: MemoryTool) -> None: ] for path in valid_paths: - result = memory_tool.execute_action( - "create", - path=path, - file_text="content" - ) + result = memory_tool.execute_action("create", path=path, file_text="content") assert "created" in result.lower() # Verify the file can be accessed with or without leading slash + view_result = memory_tool.execute_action("view", path=path) assert "content" in view_result +@pytest.mark.unit def test_cannot_create_directory_as_file(memory_tool: MemoryTool) -> None: """Test that you cannot create a file at a directory path.""" - result = memory_tool.execute_action( - "create", - path="/", - file_text="content" - ) + result = memory_tool.execute_action("create", path="/", file_text="content") assert "cannot create a file at directory path" in result.lower() +@pytest.mark.unit def test_get_actions_metadata(memory_tool: MemoryTool) -> None: """Test that action metadata is properly defined.""" metadata = memory_tool.get_actions_metadata() # Check that all expected actions are defined + action_names = [action["name"] for action in metadata] assert "view" in action_names assert "create" in action_names @@ -506,15 +479,18 @@ def test_get_actions_metadata(memory_tool: MemoryTool) -> None: assert "rename" in action_names # Check that each action has required fields + for action in metadata: assert "name" in action assert "description" in action assert "parameters" in action +@pytest.mark.unit def test_memory_tool_isolation(monkeypatch) -> None: """Test that different memory tool instances have isolated memories.""" # Create fake collection + class FakeCollection: def __init__(self) -> None: self.docs = {} @@ -529,16 +505,17 @@ def test_memory_tool_isolation(monkeypatch) -> None: def update_one(self, q, u, upsert=False): # Handle query by _id + if "_id" in q: doc_id = q["_id"] if doc_id not in self.docs: return type("res", (), {"modified_count": 0}) - if "$set" in u: old_doc = self.docs[doc_id].copy() old_doc.update(u["$set"]) # If path changed, update the dictionary key + if "path" in u["$set"]: new_path = u["$set"]["path"] user_id = old_doc.get("user_id") @@ -546,15 +523,15 @@ def test_memory_tool_isolation(monkeypatch) -> None: new_key = f"{user_id}:{tool_id}:{new_path}" # Remove old key and add with new key + del self.docs[doc_id] old_doc["_id"] = new_key self.docs[new_key] = old_doc else: self.docs[doc_id] = old_doc - return type("res", (), {"modified_count": 1}) - # Handle query by user_id, tool_id, path + user_id = q.get("user_id") tool_id = q.get("tool_id") path = q.get("path") @@ -562,13 +539,16 @@ def test_memory_tool_isolation(monkeypatch) -> None: if key not in self.docs and not upsert: return type("res", (), {"modified_count": 0}) - if key not in self.docs and upsert: - self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key} - + self.docs[key] = { + "user_id": user_id, + "tool_id": tool_id, + "path": path, + "content": "", + "_id": key, + } if "$set" in u: self.docs[key].update(u["$set"]) - return type("res", (), {"modified_count": 1}) def find_one(self, q, projection=None): @@ -579,26 +559,31 @@ def test_memory_tool_isolation(monkeypatch) -> None: if path: key = f"{user_id}:{tool_id}:{path}" return self.docs.get(key) - return None fake_collection = FakeCollection() fake_db = {"memories": fake_collection} fake_client = {settings.MONGO_DB_NAME: fake_db} - monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client) + monkeypatch.setattr( + "application.core.mongo_db.MongoDB.get_client", lambda: fake_client + ) # Create two memory tools with different tool_ids for the same user + tool1 = MemoryTool({"tool_id": "tool_1"}, user_id="test_user") tool2 = MemoryTool({"tool_id": "tool_2"}, user_id="test_user") # Create a file in tool1 + tool1.execute_action("create", path="/file.txt", file_text="Content from tool 1") # Create a file with the same path in tool2 + tool2.execute_action("create", path="/file.txt", file_text="Content from tool 2") # Verify that each tool sees only its own content + result1 = tool1.execute_action("view", path="/file.txt") result2 = tool2.execute_action("view", path="/file.txt") @@ -609,8 +594,10 @@ def test_memory_tool_isolation(monkeypatch) -> None: assert "Content from tool 1" not in result2 +@pytest.mark.unit def test_memory_tool_auto_generates_tool_id(monkeypatch) -> None: """Test that tool_id defaults to 'default_{user_id}' for persistence.""" + class FakeCollection: def __init__(self) -> None: self.docs = {} @@ -622,78 +609,94 @@ def test_memory_tool_auto_generates_tool_id(monkeypatch) -> None: fake_db = {"memories": fake_collection} fake_client = {settings.MONGO_DB_NAME: fake_db} - monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client) + monkeypatch.setattr( + "application.core.mongo_db.MongoDB.get_client", lambda: fake_client + ) # Create two tools without providing tool_id for the same user + tool1 = MemoryTool({}, user_id="test_user") tool2 = MemoryTool({}, user_id="test_user") # Both should have the same default tool_id for persistence + assert tool1.tool_id == "default_test_user" assert tool2.tool_id == "default_test_user" assert tool1.tool_id == tool2.tool_id # Different users should have different tool_ids + tool3 = MemoryTool({}, user_id="another_user") assert tool3.tool_id == "default_another_user" assert tool3.tool_id != tool1.tool_id +@pytest.mark.unit def test_paths_without_leading_slash(memory_tool) -> None: """Test that paths without leading slash work correctly.""" # Create file without leading slash - result = memory_tool.execute_action("create", path="cat_breeds.txt", file_text="- Korat\n- Chartreux\n- British Shorthair\n- Nebelung") + + result = memory_tool.execute_action( + "create", + path="cat_breeds.txt", + file_text="- Korat\n- Chartreux\n- British Shorthair\n- Nebelung", + ) assert "created" in result.lower() # View file without leading slash + view_result = memory_tool.execute_action("view", path="cat_breeds.txt") assert "Korat" in view_result assert "Chartreux" in view_result # View same file with leading slash (should work the same) + view_result2 = memory_tool.execute_action("view", path="/cat_breeds.txt") assert "Korat" in view_result2 # Test str_replace without leading slash - replace_result = memory_tool.execute_action("str_replace", path="cat_breeds.txt", old_str="Korat", new_str="Maine Coon") + + replace_result = memory_tool.execute_action( + "str_replace", path="cat_breeds.txt", old_str="Korat", new_str="Maine Coon" + ) assert "updated" in replace_result.lower() # Test nested path without leading slash - nested_result = memory_tool.execute_action("create", path="projects/tasks.txt", file_text="Task 1\nTask 2") + + nested_result = memory_tool.execute_action( + "create", path="projects/tasks.txt", file_text="Task 1\nTask 2" + ) assert "created" in nested_result.lower() view_nested = memory_tool.execute_action("view", path="projects/tasks.txt") assert "Task 1" in view_nested +@pytest.mark.unit def test_rename_directory(memory_tool: MemoryTool) -> None: """Test renaming a directory with files.""" # Create files in a directory + + memory_tool.execute_action("create", path="/docs/file1.txt", file_text="Content 1") memory_tool.execute_action( - "create", - path="/docs/file1.txt", - file_text="Content 1" - ) - memory_tool.execute_action( - "create", - path="/docs/sub/file2.txt", - file_text="Content 2" + "create", path="/docs/sub/file2.txt", file_text="Content 2" ) # Rename directory (with trailing slash) + result = memory_tool.execute_action( - "rename", - old_path="/docs/", - new_path="/archive/" + "rename", old_path="/docs/", new_path="/archive/" ) assert "renamed" in result.lower() assert "2 files" in result.lower() # Verify old paths don't exist + result = memory_tool.execute_action("view", path="/docs/file1.txt") assert "not found" in result.lower() # Verify new paths exist + result = memory_tool.execute_action("view", path="/archive/file1.txt") assert "Content 1" in result @@ -701,29 +704,25 @@ def test_rename_directory(memory_tool: MemoryTool) -> None: assert "Content 2" in result +@pytest.mark.unit def test_rename_directory_without_trailing_slash(memory_tool: MemoryTool) -> None: """Test renaming a directory when new path is missing trailing slash.""" # Create files in a directory + + memory_tool.execute_action("create", path="/docs/file1.txt", file_text="Content 1") memory_tool.execute_action( - "create", - path="/docs/file1.txt", - file_text="Content 1" - ) - memory_tool.execute_action( - "create", - path="/docs/sub/file2.txt", - file_text="Content 2" + "create", path="/docs/sub/file2.txt", file_text="Content 2" ) # Rename directory - old path has slash, new path doesn't + result = memory_tool.execute_action( - "rename", - old_path="/docs/", - new_path="/archive" # Missing trailing slash + "rename", old_path="/docs/", new_path="/archive" # Missing trailing slash ) assert "renamed" in result.lower() # Verify paths are correct (not corrupted like "/archivesub/file2.txt") + result = memory_tool.execute_action("view", path="/archive/file1.txt") assert "Content 1" in result @@ -731,28 +730,25 @@ def test_rename_directory_without_trailing_slash(memory_tool: MemoryTool) -> Non assert "Content 2" in result # Verify corrupted path doesn't exist + result = memory_tool.execute_action("view", path="/archivesub/file2.txt") assert "not found" in result.lower() +@pytest.mark.unit def test_view_file_line_numbers(memory_tool: MemoryTool) -> None: """Test that view_range displays correct line numbers.""" # Create a multiline file + content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" - memory_tool.execute_action( - "create", - path="/numbered.txt", - file_text=content - ) + memory_tool.execute_action("create", path="/numbered.txt", file_text=content) # View lines 2-4 - result = memory_tool.execute_action( - "view", - path="/numbered.txt", - view_range=[2, 4] - ) + + result = memory_tool.execute_action("view", path="/numbered.txt", view_range=[2, 4]) # Check that line numbers are correct (should be 2, 3, 4 not 3, 4, 5) + assert "2: Line 2" in result assert "3: Line 3" in result assert "4: Line 4" in result @@ -760,6 +756,7 @@ def test_view_file_line_numbers(memory_tool: MemoryTool) -> None: assert "5: Line 5" not in result # Verify no off-by-one error + assert "3: Line 2" not in result # Wrong line number assert "4: Line 3" not in result # Wrong line number assert "5: Line 4" not in result # Wrong line number diff --git a/tests/test_notes_tool.py b/tests/test_notes_tool.py index f57c5435..d6bd63f8 100644 --- a/tests/test_notes_tool.py +++ b/tests/test_notes_tool.py @@ -3,10 +3,10 @@ from application.agents.tools.notes import NotesTool from application.core.settings import settings - @pytest.fixture def notes_tool(monkeypatch) -> NotesTool: """Provide a NotesTool with a fake Mongo collection and fixed user_id.""" + class FakeCollection: def __init__(self) -> None: self.docs = {} # key: user_id:tool_id -> doc @@ -17,6 +17,7 @@ def notes_tool(monkeypatch) -> NotesTool: key = f"{user_id}:{tool_id}" # emulate single-note storage with optional upsert + if key not in self.docs and not upsert: return type("res", (), {"modified_count": 0}) if key not in self.docs and upsert: @@ -45,34 +46,45 @@ def notes_tool(monkeypatch) -> NotesTool: fake_client = {settings.MONGO_DB_NAME: fake_db} # Patch MongoDB client globally for the tool - monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client) + + monkeypatch.setattr( + "application.core.mongo_db.MongoDB.get_client", lambda: fake_client + ) # Return tool with a fixed tool_id for consistency in tests + return NotesTool({"tool_id": "test_tool_id"}, user_id="test_user") +@pytest.mark.unit def test_view(notes_tool: NotesTool) -> None: # Manually insert a note to test retrieval + notes_tool.collection.update_one( {"user_id": "test_user", "tool_id": "test_tool_id"}, {"$set": {"note": "hello"}}, - upsert=True + upsert=True, ) assert "hello" in notes_tool.execute_action("view") +@pytest.mark.unit def test_overwrite_and_delete(notes_tool: NotesTool) -> None: # Overwrite creates a new note + assert "saved" in notes_tool.execute_action("overwrite", text="first").lower() assert "first" in notes_tool.execute_action("view") # Overwrite replaces existing note + assert "saved" in notes_tool.execute_action("overwrite", text="second").lower() assert "second" in notes_tool.execute_action("view") assert "deleted" in notes_tool.execute_action("delete").lower() assert "no note" in notes_tool.execute_action("view").lower() + +@pytest.mark.unit def test_init_without_user_id(monkeypatch): """Should fail gracefully if no user_id is provided.""" notes_tool = NotesTool(tool_config={}) @@ -80,26 +92,32 @@ def test_init_without_user_id(monkeypatch): assert "user_id" in str(result).lower() +@pytest.mark.unit def test_view_not_found(notes_tool: NotesTool) -> None: """Should return 'No note found.' when no note exists""" result = notes_tool.execute_action("view") assert "no note found" in result.lower() +@pytest.mark.unit def test_str_replace(notes_tool: NotesTool) -> None: """Test string replacement in note""" # Create a note + notes_tool.execute_action("overwrite", text="Hello world, hello universe") # Replace text + result = notes_tool.execute_action("str_replace", old_str="hello", new_str="hi") assert "updated" in result.lower() # Verify replacement + note = notes_tool.execute_action("view") assert "hi world, hi universe" in note.lower() +@pytest.mark.unit def test_str_replace_not_found(notes_tool: NotesTool) -> None: """Test string replacement when string not found""" notes_tool.execute_action("overwrite", text="Hello world") @@ -107,22 +125,27 @@ def test_str_replace_not_found(notes_tool: NotesTool) -> None: assert "not found" in result.lower() +@pytest.mark.unit def test_insert_line(notes_tool: NotesTool) -> None: """Test inserting text at a line number""" # Create a multiline note + notes_tool.execute_action("overwrite", text="Line 1\nLine 2\nLine 3") # Insert at line 2 + result = notes_tool.execute_action("insert", line_number=2, text="Inserted line") assert "inserted" in result.lower() # Verify insertion + note = notes_tool.execute_action("view") lines = note.split("\n") assert lines[1] == "Inserted line" assert lines[2] == "Line 2" +@pytest.mark.unit def test_delete_nonexistent_note(monkeypatch): class FakeResult: deleted_count = 0 @@ -133,7 +156,7 @@ def test_delete_nonexistent_note(monkeypatch): monkeypatch.setattr( "application.core.mongo_db.MongoDB.get_client", - lambda: {"docsgpt": {"notes": FakeCollection()}} + lambda: {"docsgpt": {"notes": FakeCollection()}}, ) notes_tool = NotesTool(tool_config={}, user_id="user123") @@ -141,8 +164,10 @@ def test_delete_nonexistent_note(monkeypatch): assert "no note found" in result.lower() +@pytest.mark.unit def test_notes_tool_isolation(monkeypatch) -> None: """Test that different notes tool instances have isolated notes.""" + class FakeCollection: def __init__(self) -> None: self.docs = {} @@ -170,19 +195,25 @@ def test_notes_tool_isolation(monkeypatch) -> None: fake_db = {"notes": fake_collection} fake_client = {settings.MONGO_DB_NAME: fake_db} - monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client) + monkeypatch.setattr( + "application.core.mongo_db.MongoDB.get_client", lambda: fake_client + ) # Create two notes tools with different tool_ids for the same user + tool1 = NotesTool({"tool_id": "tool_1"}, user_id="test_user") tool2 = NotesTool({"tool_id": "tool_2"}, user_id="test_user") # Create a note in tool1 + tool1.execute_action("overwrite", text="Content from tool 1") # Create a note in tool2 + tool2.execute_action("overwrite", text="Content from tool 2") # Verify that each tool sees only its own content + result1 = tool1.execute_action("view") result2 = tool2.execute_action("view") @@ -193,8 +224,10 @@ def test_notes_tool_isolation(monkeypatch) -> None: assert "Content from tool 1" not in result2 +@pytest.mark.unit def test_notes_tool_auto_generates_tool_id(monkeypatch) -> None: """Test that tool_id defaults to 'default_{user_id}' for persistence.""" + class FakeCollection: def __init__(self) -> None: self.docs = {} @@ -206,18 +239,23 @@ def test_notes_tool_auto_generates_tool_id(monkeypatch) -> None: fake_db = {"notes": fake_collection} fake_client = {settings.MONGO_DB_NAME: fake_db} - monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client) + monkeypatch.setattr( + "application.core.mongo_db.MongoDB.get_client", lambda: fake_client + ) # Create two tools without providing tool_id for the same user + tool1 = NotesTool({}, user_id="test_user") tool2 = NotesTool({}, user_id="test_user") # Both should have the same default tool_id for persistence + assert tool1.tool_id == "default_test_user" assert tool2.tool_id == "default_test_user" assert tool1.tool_id == tool2.tool_id # Different users should have different tool_ids + tool3 = NotesTool({}, user_id="another_user") assert tool3.tool_id == "default_another_user" assert tool3.tool_id != tool1.tool_id diff --git a/tests/test_openapi3parser.py b/tests/test_openapi3parser.py index 1ce01ef5..1d23d990 100644 --- a/tests/test_openapi3parser.py +++ b/tests/test_openapi3parser.py @@ -1,6 +1,6 @@ import pytest -from openapi_parser import parse from application.parser.file.openapi3_parser import OpenAPI3Parser +from openapi_parser import parse @pytest.mark.parametrize( @@ -17,10 +17,12 @@ from application.parser.file.openapi3_parser import OpenAPI3Parser ), ], ) +@pytest.mark.unit def test_get_base_urls(urls, expected_base_urls): assert OpenAPI3Parser().get_base_urls(urls) == expected_base_urls +@pytest.mark.unit def test_get_info_from_paths(): file_path = "tests/test_openapi3.yaml" data = parse(file_path) @@ -31,6 +33,7 @@ def test_get_info_from_paths(): ) +@pytest.mark.unit def test_parse_file(): file_path = "tests/test_openapi3.yaml" results_expected = (