test: add agent test coverage and standardize test suite (#2051)

- Add 104 comprehensive tests for agent system
- Integrate agent tests into CI/CD pipeline
- Standardize tests with @pytest.mark.unit markers
- Fix cross-platform path compatibility
- Clean up unused imports and dependencies
This commit is contained in:
Siddhant Rai
2025-10-13 17:13:35 +05:30
committed by GitHub
parent 1805292528
commit d6c49bdbf0
24 changed files with 2775 additions and 501 deletions

View File

@@ -16,15 +16,15 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install pytest pytest-cov
cd application cd application
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
cd ../tests
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest and generate coverage report - name: Test with pytest and generate coverage report
run: | run: |
python -m pytest --cov=application --cov-report=xml python -m pytest --cov=application --cov-report=xml --cov-report=term-missing
- name: Upload coverage reports to Codecov - name: Upload coverage reports to Codecov
if: github.event_name == 'pull_request' && matrix.python-version == '3.12' if: github.event_name == 'pull_request' && matrix.python-version == '3.12'
uses: codecov/codecov-action@v5 uses: codecov/codecov-action@v5
env: env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -38,7 +38,7 @@ class BaseAgent(ABC):
self.user_api_key = user_api_key self.user_api_key = user_api_key
self.prompt = prompt self.prompt = prompt
self.decoded_token = decoded_token or {} self.decoded_token = decoded_token or {}
self.user: str = decoded_token.get("sub") self.user: str = self.decoded_token.get("sub")
self.tool_config: Dict = {} self.tool_config: Dict = {}
self.tools: List[Dict] = [] self.tools: List[Dict] = []
self.tool_calls: List[Dict] = [] self.tool_calls: List[Dict] = []

View File

@@ -23,7 +23,9 @@ class ToolActionParser:
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool # If the tool name doesn't contain an underscore, it's likely a hallucinated tool
if len(tool_parts) < 2: if len(tool_parts) < 2:
logger.warning(f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id") logger.warning(
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
)
return None, None, None return None, None, None
tool_id = tool_parts[-1] tool_id = tool_parts[-1]
@@ -31,9 +33,11 @@ class ToolActionParser:
# Validate that tool_id looks like a numerical ID # Validate that tool_id looks like a numerical ID
if not tool_id.isdigit(): if not tool_id.isdigit():
logger.warning(f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call.") logger.warning(
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
)
except (AttributeError, TypeError) as e: except (AttributeError, TypeError, json.JSONDecodeError) as e:
logger.error(f"Error parsing OpenAI LLM call: {e}") logger.error(f"Error parsing OpenAI LLM call: {e}")
return None, None, None return None, None, None
return tool_id, action_name, call_args return tool_id, action_name, call_args
@@ -45,7 +49,9 @@ class ToolActionParser:
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool # If the tool name doesn't contain an underscore, it's likely a hallucinated tool
if len(tool_parts) < 2: if len(tool_parts) < 2:
logger.warning(f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id") logger.warning(
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
)
return None, None, None return None, None, None
tool_id = tool_parts[-1] tool_id = tool_parts[-1]
@@ -53,7 +59,9 @@ class ToolActionParser:
# Validate that tool_id looks like a numerical ID # Validate that tool_id looks like a numerical ID
if not tool_id.isdigit(): if not tool_id.isdigit():
logger.warning(f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call.") logger.warning(
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
)
except (AttributeError, TypeError) as e: except (AttributeError, TypeError) as e:
logger.error(f"Error parsing Google LLM call: {e}") logger.error(f"Error parsing Google LLM call: {e}")

20
pytest.ini Normal file
View File

@@ -0,0 +1,20 @@
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts =
-v
--strict-markers
--tb=short
--cov=application
--cov-report=html
--cov-report=term-missing
--cov-report=xml
markers =
unit: Unit tests
integration: Integration tests
slow: Slow running tests
filterwarnings =
ignore::DeprecationWarning
ignore::PendingDeprecationWarning

0
tests/agents/__init__.py Normal file
View File

View File

@@ -0,0 +1,56 @@
import pytest
from application.agents.agent_creator import AgentCreator
from application.agents.classic_agent import ClassicAgent
from application.agents.react_agent import ReActAgent
@pytest.mark.unit
class TestAgentCreator:
def test_create_classic_agent(self, agent_base_params):
agent = AgentCreator.create_agent("classic", **agent_base_params)
assert isinstance(agent, ClassicAgent)
assert agent.endpoint == agent_base_params["endpoint"]
assert agent.llm_name == agent_base_params["llm_name"]
assert agent.gpt_model == agent_base_params["gpt_model"]
def test_create_react_agent(self, agent_base_params):
agent = AgentCreator.create_agent("react", **agent_base_params)
assert isinstance(agent, ReActAgent)
assert agent.endpoint == agent_base_params["endpoint"]
assert agent.llm_name == agent_base_params["llm_name"]
def test_create_agent_case_insensitive(self, agent_base_params):
agent_upper = AgentCreator.create_agent("CLASSIC", **agent_base_params)
agent_mixed = AgentCreator.create_agent("ClAsSiC", **agent_base_params)
assert isinstance(agent_upper, ClassicAgent)
assert isinstance(agent_mixed, ClassicAgent)
def test_create_agent_invalid_type(self, agent_base_params):
with pytest.raises(ValueError, match="No agent class found for type"):
AgentCreator.create_agent("invalid_agent_type", **agent_base_params)
def test_agent_registry_contains_expected_agents(self):
assert "classic" in AgentCreator.agents
assert "react" in AgentCreator.agents
assert AgentCreator.agents["classic"] == ClassicAgent
assert AgentCreator.agents["react"] == ReActAgent
def test_create_agent_with_optional_params(self, agent_base_params):
agent_base_params["user_api_key"] = "user_key_123"
agent_base_params["chat_history"] = [{"prompt": "test", "response": "test"}]
agent_base_params["json_schema"] = {"type": "object"}
agent = AgentCreator.create_agent("classic", **agent_base_params)
assert agent.user_api_key == "user_key_123"
assert len(agent.chat_history) == 1
assert agent.json_schema == {"type": "object"}
def test_create_agent_with_attachments(self, agent_base_params):
attachments = [{"name": "file.txt", "content": "test"}]
agent_base_params["attachments"] = attachments
agent = AgentCreator.create_agent("classic", **agent_base_params)
assert agent.attachments == attachments

View File

@@ -0,0 +1,641 @@
from unittest.mock import Mock
import pytest
from application.agents.classic_agent import ClassicAgent
from application.core.settings import settings
from tests.conftest import FakeMongoCollection
@pytest.mark.unit
class TestBaseAgentInitialization:
def test_agent_initialization(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
assert agent.endpoint == agent_base_params["endpoint"]
assert agent.llm_name == agent_base_params["llm_name"]
assert agent.gpt_model == agent_base_params["gpt_model"]
assert agent.api_key == agent_base_params["api_key"]
assert agent.prompt == agent_base_params["prompt"]
assert agent.user == agent_base_params["decoded_token"]["sub"]
assert agent.tools == []
assert agent.tool_calls == []
def test_agent_initialization_with_none_chat_history(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent_base_params["chat_history"] = None
agent = ClassicAgent(**agent_base_params)
assert agent.chat_history == []
def test_agent_initialization_with_chat_history(
self,
agent_base_params,
sample_chat_history,
mock_llm_creator,
mock_llm_handler_creator,
):
agent_base_params["chat_history"] = sample_chat_history
agent = ClassicAgent(**agent_base_params)
assert len(agent.chat_history) == 2
assert agent.chat_history[0]["prompt"] == "What is Python?"
def test_agent_decoded_token_defaults_to_empty_dict(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent_base_params["decoded_token"] = None
agent = ClassicAgent(**agent_base_params)
assert agent.decoded_token == {}
assert agent.user is None
def test_agent_user_extracted_from_token(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent_base_params["decoded_token"] = {"sub": "user123"}
agent = ClassicAgent(**agent_base_params)
assert agent.user == "user123"
@pytest.mark.unit
class TestBaseAgentBuildMessages:
def test_build_messages_basic(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
system_prompt = "System: {summaries}"
query = "What is Python?"
retrieved_data = [
{"text": "Python is a programming language", "filename": "python.txt"}
]
messages = agent._build_messages(system_prompt, query, retrieved_data)
assert len(messages) >= 2
assert messages[0]["role"] == "system"
assert "Python is a programming language" in messages[0]["content"]
assert messages[-1]["role"] == "user"
assert messages[-1]["content"] == query
def test_build_messages_with_chat_history(
self,
agent_base_params,
sample_chat_history,
mock_llm_creator,
mock_llm_handler_creator,
):
agent_base_params["chat_history"] = sample_chat_history
agent = ClassicAgent(**agent_base_params)
system_prompt = "System: {summaries}"
query = "New question?"
retrieved_data = [{"text": "Data", "filename": "file.txt"}]
messages = agent._build_messages(system_prompt, query, retrieved_data)
user_messages = [m for m in messages if m["role"] == "user"]
assistant_messages = [m for m in messages if m["role"] == "assistant"]
assert len(user_messages) >= 3
assert len(assistant_messages) >= 2
def test_build_messages_with_tool_calls_in_history(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
tool_call_history = [
{
"tool_calls": [
{
"call_id": "123",
"action_name": "test_action",
"arguments": {"arg": "value"},
"result": "success",
}
]
}
]
agent_base_params["chat_history"] = tool_call_history
agent = ClassicAgent(**agent_base_params)
messages = agent._build_messages(
"System: {summaries}", "query", [{"text": "data", "filename": "file.txt"}]
)
tool_messages = [m for m in messages if m["role"] == "tool"]
assert len(tool_messages) > 0
def test_build_messages_handles_missing_filename(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
retrieved_data = [{"text": "Document without filename or title"}]
messages = agent._build_messages("System: {summaries}", "query", retrieved_data)
assert messages[0]["role"] == "system"
assert "Document without filename" in messages[0]["content"]
def test_build_messages_uses_title_as_fallback(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
retrieved_data = [{"text": "Data", "title": "Title Doc"}]
messages = agent._build_messages("System: {summaries}", "query", retrieved_data)
assert "Title Doc" in messages[0]["content"]
def test_build_messages_uses_source_as_fallback(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
retrieved_data = [{"text": "Data", "source": "source.txt"}]
messages = agent._build_messages("System: {summaries}", "query", retrieved_data)
assert "source.txt" in messages[0]["content"]
@pytest.mark.unit
class TestBaseAgentTools:
def test_get_user_tools(
self,
agent_base_params,
mock_mongo_db,
mock_llm_creator,
mock_llm_handler_creator,
):
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
"1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True},
"2": {"_id": "2", "user": "test_user", "name": "tool2", "status": True},
}
agent = ClassicAgent(**agent_base_params)
tools = agent._get_user_tools("test_user")
assert len(tools) == 2
assert "0" in tools
assert "1" in tools
def test_get_user_tools_filters_by_status(
self,
agent_base_params,
mock_mongo_db,
mock_llm_creator,
mock_llm_handler_creator,
):
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
"1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True},
"2": {"_id": "2", "user": "test_user", "name": "tool2", "status": False},
}
agent = ClassicAgent(**agent_base_params)
tools = agent._get_user_tools("test_user")
assert len(tools) == 1
def test_get_tools_by_api_key(
self,
agent_base_params,
mock_mongo_db,
mock_llm_creator,
mock_llm_handler_creator,
):
from bson.objectid import ObjectId
tool_id = str(ObjectId())
tool_obj_id = ObjectId(tool_id)
fake_agent_collection = FakeMongoCollection()
fake_agent_collection.docs["api_key_123"] = {
"key": "api_key_123",
"tools": [tool_id],
}
fake_tools_collection = FakeMongoCollection()
fake_tools_collection.docs[tool_id] = {"_id": tool_obj_id, "name": "api_tool"}
mock_mongo_db[settings.MONGO_DB_NAME]["agents"] = fake_agent_collection
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"] = fake_tools_collection
agent = ClassicAgent(**agent_base_params)
tools = agent._get_tools("api_key_123")
assert tool_id in tools
def test_build_tool_parameters(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
action = {
"parameters": {
"properties": {
"param1": {
"type": "string",
"description": "Test param",
"filled_by_llm": True,
},
"param2": {"type": "number", "filled_by_llm": False, "value": 42},
}
}
}
params = agent._build_tool_parameters(action)
assert "param1" in params["properties"]
assert "param1" in params["required"]
assert "filled_by_llm" not in params["properties"]["param1"]
def test_prepare_tools_with_api_tool(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
tools_dict = {
"1": {
"name": "api_tool",
"config": {
"actions": {
"get_data": {
"name": "get_data",
"description": "Get data from API",
"active": True,
"url": "https://api.example.com/data",
"method": "GET",
"parameters": {"properties": {}},
}
}
},
}
}
agent._prepare_tools(tools_dict)
assert len(agent.tools) == 1
assert agent.tools[0]["type"] == "function"
assert agent.tools[0]["function"]["name"] == "get_data_1"
def test_prepare_tools_with_regular_tool(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
tools_dict = {
"1": {
"name": "custom_tool",
"actions": [
{
"name": "action1",
"description": "Custom action",
"active": True,
"parameters": {"properties": {}},
}
],
}
}
agent._prepare_tools(tools_dict)
assert len(agent.tools) == 1
assert agent.tools[0]["function"]["name"] == "action1_1"
def test_prepare_tools_filters_inactive_actions(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
tools_dict = {
"1": {
"name": "custom_tool",
"actions": [
{
"name": "active_action",
"description": "Active",
"active": True,
"parameters": {"properties": {}},
},
{
"name": "inactive_action",
"description": "Inactive",
"active": False,
"parameters": {"properties": {}},
},
],
}
}
agent._prepare_tools(tools_dict)
assert len(agent.tools) == 1
assert agent.tools[0]["function"]["name"] == "active_action_1"
@pytest.mark.unit
class TestBaseAgentToolExecution:
def test_execute_tool_action_success(
self,
agent_base_params,
mock_llm_creator,
mock_llm_handler_creator,
mock_tool_manager,
):
agent = ClassicAgent(**agent_base_params)
call = Mock()
call.id = "call_123"
call.name = "test_action_1"
call.arguments = '{"param1": "value1"}'
tools_dict = {
"1": {
"name": "custom_tool",
"config": {},
"actions": [
{
"name": "test_action",
"description": "Test",
"parameters": {"properties": {}},
}
],
}
}
results = list(agent._execute_tool_action(tools_dict, call))
assert len(results) >= 2
assert results[0]["type"] == "tool_call"
assert results[0]["data"]["status"] == "pending"
assert results[-1]["data"]["status"] == "completed"
def test_execute_tool_action_invalid_tool_name(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
call = Mock()
call.id = "call_123"
call.name = "invalid_format"
call.arguments = "{}"
tools_dict = {}
results = list(agent._execute_tool_action(tools_dict, call))
assert results[0]["type"] == "tool_call"
assert results[0]["data"]["status"] == "error"
assert (
"Failed to parse" in results[0]["data"]["result"]
or "not found" in results[0]["data"]["result"]
)
def test_execute_tool_action_tool_not_found(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
call = Mock()
call.id = "call_123"
call.name = "action_999"
call.arguments = "{}"
tools_dict = {"1": {"name": "tool1", "config": {}, "actions": []}}
results = list(agent._execute_tool_action(tools_dict, call))
assert results[0]["type"] == "tool_call"
assert results[0]["data"]["status"] == "error"
assert "not found" in results[0]["data"]["result"]
def test_execute_tool_action_with_parameters(
self,
agent_base_params,
mock_llm_creator,
mock_llm_handler_creator,
mock_tool_manager,
):
agent = ClassicAgent(**agent_base_params)
call = Mock()
call.id = "call_123"
call.name = "test_action_1"
call.arguments = '{"param1": "value1", "param2": "value2"}'
tools_dict = {
"1": {
"name": "custom_tool",
"config": {},
"actions": [
{
"name": "test_action",
"description": "Test",
"parameters": {
"properties": {
"param1": {"type": "string"},
"param2": {"type": "string"},
}
},
}
],
}
}
results = list(agent._execute_tool_action(tools_dict, call))
assert results[-1]["data"]["status"] == "completed"
assert results[-1]["data"]["arguments"]["param1"] == "value1"
def test_get_truncated_tool_calls(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
agent.tool_calls = [
{
"tool_name": "test_tool",
"call_id": "123",
"action_name": "action",
"arguments": {},
"result": "a" * 100,
}
]
truncated = agent._get_truncated_tool_calls()
assert len(truncated) == 1
assert len(truncated[0]["result"]) <= 53
assert truncated[0]["result"].endswith("...")
@pytest.mark.unit
class TestBaseAgentRetrieverSearch:
def test_retriever_search(
self,
agent_base_params,
mock_retriever,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
agent = ClassicAgent(**agent_base_params)
results = agent._retriever_search(mock_retriever, "test query", log_context)
assert len(results) == 2
mock_retriever.search.assert_called_once_with("test query")
def test_retriever_search_logs_context(
self,
agent_base_params,
mock_retriever,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
agent = ClassicAgent(**agent_base_params)
agent._retriever_search(mock_retriever, "test query", log_context)
assert len(log_context.stacks) == 1
assert log_context.stacks[0]["component"] == "retriever"
@pytest.mark.unit
class TestBaseAgentLLMGeneration:
def test_llm_gen_basic(
self,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
agent = ClassicAgent(**agent_base_params)
messages = [{"role": "user", "content": "test"}]
agent._llm_gen(messages, log_context)
mock_llm.gen_stream.assert_called_once()
call_args = mock_llm.gen_stream.call_args[1]
assert call_args["model"] == agent.gpt_model
assert call_args["messages"] == messages
def test_llm_gen_with_tools(
self,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
agent = ClassicAgent(**agent_base_params)
agent.tools = [{"type": "function", "function": {"name": "test"}}]
messages = [{"role": "user", "content": "test"}]
agent._llm_gen(messages, log_context)
call_args = mock_llm.gen_stream.call_args[1]
assert "tools" in call_args
assert call_args["tools"] == agent.tools
def test_llm_gen_with_json_schema(
self,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
mock_llm._supports_structured_output = Mock(return_value=True)
mock_llm.prepare_structured_output_format = Mock(
return_value={"schema": "test"}
)
agent_base_params["json_schema"] = {"type": "object"}
agent_base_params["llm_name"] = "openai"
agent = ClassicAgent(**agent_base_params)
messages = [{"role": "user", "content": "test"}]
agent._llm_gen(messages, log_context)
call_args = mock_llm.gen_stream.call_args[1]
assert "response_format" in call_args
@pytest.mark.unit
class TestBaseAgentHandleResponse:
def test_handle_response_string(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator, log_context
):
agent = ClassicAgent(**agent_base_params)
response = "Simple string response"
results = list(agent._handle_response(response, {}, [], log_context))
assert len(results) == 1
assert results[0]["answer"] == "Simple string response"
def test_handle_response_with_message(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator, log_context
):
agent = ClassicAgent(**agent_base_params)
response = Mock()
response.message = Mock()
response.message.content = "Message content"
results = list(agent._handle_response(response, {}, [], log_context))
assert len(results) == 1
assert results[0]["answer"] == "Message content"
def test_handle_response_with_structured_output(
self,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
mock_llm._supports_structured_output = Mock(return_value=True)
agent_base_params["json_schema"] = {"type": "object"}
agent = ClassicAgent(**agent_base_params)
response = "Structured response"
results = list(agent._handle_response(response, {}, [], log_context))
assert results[0]["structured"] is True
assert results[0]["schema"] == {"type": "object"}
def test_handle_response_with_handler(
self,
agent_base_params,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
def mock_process(*args):
yield {"type": "tool_call", "data": {}}
yield "Final answer"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_process)
agent = ClassicAgent(**agent_base_params)
response = Mock()
response.message = None
results = list(agent._handle_response(response, {}, [], log_context))
assert len(results) == 2
assert results[0]["type"] == "tool_call"
assert results[1]["answer"] == "Final answer"

View File

@@ -0,0 +1,242 @@
from unittest.mock import Mock
import pytest
from application.agents.classic_agent import ClassicAgent
@pytest.mark.unit
class TestClassicAgent:
def test_classic_agent_initialization(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
assert isinstance(agent, ClassicAgent)
assert agent.endpoint == agent_base_params["endpoint"]
assert agent.llm_name == agent_base_params["llm_name"]
def test_gen_inner_basic_flow(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
def mock_gen_stream(*args, **kwargs):
yield "Answer chunk 1"
yield "Answer chunk 2"
mock_llm.gen_stream = Mock(return_value=mock_gen_stream())
def mock_handler(*args, **kwargs):
yield "Processed answer"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ClassicAgent(**agent_base_params)
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
assert len(results) >= 2
sources = [r for r in results if "sources" in r]
tool_calls = [r for r in results if "tool_calls" in r]
assert len(sources) == 1
assert len(tool_calls) == 1
def test_gen_inner_retrieves_documents(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
def mock_handler(*args, **kwargs):
yield "Processed"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ClassicAgent(**agent_base_params)
list(agent._gen_inner("Test query", mock_retriever, log_context))
mock_retriever.search.assert_called_once_with("Test query")
def test_gen_inner_uses_user_api_key_tools(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
from application.core.settings import settings
from bson.objectid import ObjectId
tool_id = str(ObjectId())
mock_mongo_db[settings.MONGO_DB_NAME]["agents"].docs = {
"api_key_123": {"key": "api_key_123", "tools": [tool_id]}
}
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
tool_id: {"_id": ObjectId(tool_id), "name": "test_tool"}
}
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
def mock_handler(*args, **kwargs):
yield "Processed"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent_base_params["user_api_key"] = "api_key_123"
agent = ClassicAgent(**agent_base_params)
list(agent._gen_inner("Test query", mock_retriever, log_context))
assert len(agent.tools) >= 0
def test_gen_inner_uses_user_tools(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
from application.core.settings import settings
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
"1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True}
}
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
def mock_handler(*args, **kwargs):
yield "Processed"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ClassicAgent(**agent_base_params)
list(agent._gen_inner("Test query", mock_retriever, log_context))
assert len(agent.tools) >= 0
def test_gen_inner_builds_correct_messages(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
def mock_handler(*args, **kwargs):
yield "Processed"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ClassicAgent(**agent_base_params)
list(agent._gen_inner("Test query", mock_retriever, log_context))
call_kwargs = mock_llm.gen_stream.call_args[1]
messages = call_kwargs["messages"]
assert len(messages) >= 2
assert messages[0]["role"] == "system"
assert messages[-1]["role"] == "user"
assert messages[-1]["content"] == "Test query"
def test_gen_inner_logs_tool_calls(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
def mock_handler(*args, **kwargs):
yield "Processed"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ClassicAgent(**agent_base_params)
agent.tool_calls = [{"tool": "test", "result": "success"}]
list(agent._gen_inner("Test query", mock_retriever, log_context))
agent_logs = [s for s in log_context.stacks if s["component"] == "agent"]
assert len(agent_logs) == 1
assert "tool_calls" in agent_logs[0]["data"]
@pytest.mark.integration
class TestClassicAgentIntegration:
def test_gen_method_with_logging(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
def mock_handler(*args, **kwargs):
yield "Processed"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ClassicAgent(**agent_base_params)
results = list(agent.gen("Test query", mock_retriever))
assert len(results) >= 1
def test_gen_method_decorator_applied(
self,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
def mock_handler(*args, **kwargs):
yield "Processed"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ClassicAgent(**agent_base_params)
assert hasattr(agent.gen, "__wrapped__")

View File

@@ -0,0 +1,519 @@
from unittest.mock import Mock, mock_open, patch
import pytest
from application.agents.react_agent import ReActAgent
@pytest.mark.unit
class TestReActAgent:
def test_react_agent_initialization(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
assert isinstance(agent, ReActAgent)
assert agent.plan == ""
assert agent.observations == []
def test_react_agent_inherits_base_properties(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
assert agent.endpoint == agent_base_params["endpoint"]
assert agent.llm_name == agent_base_params["llm_name"]
assert agent.gpt_model == agent_base_params["gpt_model"]
@pytest.mark.unit
class TestReActAgentContentExtraction:
def test_extract_content_from_string(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
response = "Simple string response"
content = agent._extract_content_from_llm_response(response)
assert content == "Simple string response"
def test_extract_content_from_message_object(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
response = Mock()
response.message = Mock()
response.message.content = "Message content"
content = agent._extract_content_from_llm_response(response)
assert content == "Message content"
def test_extract_content_from_openai_response(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
response = Mock()
response.choices = [Mock()]
response.choices[0].message = Mock()
response.choices[0].message.content = "OpenAI content"
response.message = None
response.content = None
content = agent._extract_content_from_llm_response(response)
assert content == "OpenAI content"
def test_extract_content_from_anthropic_response(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
text_block = Mock()
text_block.text = "Anthropic content"
response = Mock()
response.content = [text_block]
response.message = None
response.choices = None
content = agent._extract_content_from_llm_response(response)
assert content == "Anthropic content"
def test_extract_content_from_openai_stream(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
chunk1 = Mock()
chunk1.choices = [Mock()]
chunk1.choices[0].delta = Mock()
chunk1.choices[0].delta.content = "Part 1 "
chunk2 = Mock()
chunk2.choices = [Mock()]
chunk2.choices[0].delta = Mock()
chunk2.choices[0].delta.content = "Part 2"
response = iter([chunk1, chunk2])
content = agent._extract_content_from_llm_response(response)
assert content == "Part 1 Part 2"
def test_extract_content_from_anthropic_stream(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
chunk1 = Mock()
chunk1.type = "content_block_delta"
chunk1.delta = Mock()
chunk1.delta.text = "Stream 1 "
chunk1.choices = []
chunk2 = Mock()
chunk2.type = "content_block_delta"
chunk2.delta = Mock()
chunk2.delta.text = "Stream 2"
chunk2.choices = []
response = iter([chunk1, chunk2])
content = agent._extract_content_from_llm_response(response)
assert content == "Stream 1 Stream 2"
def test_extract_content_from_string_stream(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
response = iter(["chunk1", "chunk2", "chunk3"])
content = agent._extract_content_from_llm_response(response)
assert content == "chunk1chunk2chunk3"
def test_extract_content_handles_none_content(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ReActAgent(**agent_base_params)
response = Mock()
response.message = Mock()
response.message.content = None
response.choices = None
response.content = None
content = agent._extract_content_from_llm_response(response)
assert content == ""
@pytest.mark.unit
class TestReActAgentPlanning:
@patch(
"builtins.open",
new_callable=mock_open,
read_data="Test planning prompt: {query} {summaries} {prompt} {observations}",
)
def test_create_plan(
self,
mock_file,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
def mock_gen_stream(*args, **kwargs):
yield "Plan step 1"
yield "Plan step 2"
mock_llm.gen_stream = Mock(return_value=mock_gen_stream())
agent = ReActAgent(**agent_base_params)
agent.observations = ["Observation 1"]
plan_chunks = list(agent._create_plan("Test query", "Test docs", log_context))
assert len(plan_chunks) == 2
assert plan_chunks[0] == "Plan step 1"
assert plan_chunks[1] == "Plan step 2"
mock_llm.gen_stream.assert_called_once()
@patch("builtins.open", new_callable=mock_open, read_data="Test: {query}")
def test_create_plan_fills_template(
self,
mock_file,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Plan"]))
agent = ReActAgent(**agent_base_params)
list(agent._create_plan("My query", "Docs", log_context))
call_args = mock_llm.gen_stream.call_args[1]
messages = call_args["messages"]
assert "My query" in messages[0]["content"]
@pytest.mark.unit
class TestReActAgentFinalAnswer:
@patch(
"builtins.open",
new_callable=mock_open,
read_data="Final answer for: {query} with {observations}",
)
def test_create_final_answer(
self,
mock_file,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
def mock_gen_stream(*args, **kwargs):
yield "Final "
yield "answer"
mock_llm.gen_stream = Mock(return_value=mock_gen_stream())
agent = ReActAgent(**agent_base_params)
observations = ["Obs 1", "Obs 2"]
answer_chunks = list(
agent._create_final_answer("Test query", observations, log_context)
)
assert len(answer_chunks) == 2
assert answer_chunks[0] == "Final "
assert answer_chunks[1] == "answer"
@patch("builtins.open", new_callable=mock_open, read_data="Answer: {observations}")
def test_create_final_answer_truncates_long_observations(
self,
mock_file,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
agent = ReActAgent(**agent_base_params)
long_obs = ["A" * 15000]
list(agent._create_final_answer("Query", long_obs, log_context))
call_args = mock_llm.gen_stream.call_args[1]
messages = call_args["messages"]
assert "observations truncated" in messages[0]["content"]
@patch("builtins.open", new_callable=mock_open, read_data="Test: {query}")
def test_create_final_answer_no_tools(
self,
mock_file,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
agent = ReActAgent(**agent_base_params)
list(agent._create_final_answer("Query", ["Obs"], log_context))
call_args = mock_llm.gen_stream.call_args[1]
assert call_args["tools"] is None
@pytest.mark.unit
class TestReActAgentGenInner:
@patch(
"builtins.open", new_callable=mock_open, read_data="Prompt template: {query}"
)
def test_gen_inner_resets_state(
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["SATISFIED"]))
def mock_handler(*args, **kwargs):
yield "SATISFIED"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
agent.plan = "Old plan"
agent.observations = ["Old obs"]
list(agent._gen_inner("New query", mock_retriever, log_context))
assert agent.plan != "Old plan"
assert len(agent.observations) > 0
@patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}")
def test_gen_inner_stops_on_satisfied(
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
iteration_count = 0
def mock_gen_stream(*args, **kwargs):
nonlocal iteration_count
iteration_count += 1
if iteration_count == 1:
yield "Plan"
else:
yield "SATISFIED - done"
mock_llm.gen_stream = Mock(
side_effect=lambda *args, **kwargs: mock_gen_stream(*args, **kwargs)
)
def mock_handler(*args, **kwargs):
yield "SATISFIED - done"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
assert any("answer" in r for r in results)
@patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}")
def test_gen_inner_max_iterations(
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
call_count = 0
def mock_gen_stream(*args, **kwargs):
nonlocal call_count
call_count += 1
yield f"Iteration {call_count}"
mock_llm.gen_stream = Mock(
side_effect=lambda *args, **kwargs: mock_gen_stream(*args, **kwargs)
)
def mock_handler(*args, **kwargs):
yield "Continue..."
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
thought_results = [r for r in results if "thought" in r]
assert len(thought_results) > 0
@patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}")
def test_gen_inner_yields_sources(
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["SATISFIED"]))
def mock_handler(*args, **kwargs):
yield "SATISFIED"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
sources = [r for r in results if "sources" in r]
assert len(sources) >= 1
@patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}")
def test_gen_inner_yields_tool_calls(
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["SATISFIED"]))
def mock_handler(*args, **kwargs):
yield "SATISFIED"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
agent.tool_calls = [{"tool": "test", "result": "A" * 100}]
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
tool_call_results = [r for r in results if "tool_calls" in r]
if tool_call_results:
assert len(tool_call_results[0]["tool_calls"][0]["result"]) <= 53
@patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}")
def test_gen_inner_logs_observations(
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["SATISFIED"]))
def mock_handler(*args, **kwargs):
yield "SATISFIED"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
list(agent._gen_inner("Test query", mock_retriever, log_context))
assert len(agent.observations) > 0
@pytest.mark.integration
class TestReActAgentIntegration:
@patch(
"builtins.open",
new_callable=mock_open,
read_data="Prompt: {query} {summaries} {prompt} {observations}",
)
def test_full_react_workflow(
self,
mock_file,
agent_base_params,
mock_retriever,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
call_sequence = []
def mock_gen_stream(*args, **kwargs):
call_sequence.append("gen_stream")
if len(call_sequence) <= 2:
yield "Planning..."
else:
yield "SATISFIED final answer"
mock_llm.gen_stream = Mock(
side_effect=lambda *args, **kwargs: mock_gen_stream(*args, **kwargs)
)
def mock_handler(*args, **kwargs):
call_sequence.append("handler")
yield "SATISFIED final answer"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = ReActAgent(**agent_base_params)
results = list(agent._gen_inner("Complex query", mock_retriever, log_context))
assert len(results) > 0
assert any("thought" in r for r in results)
assert any("answer" in r for r in results)

View File

@@ -0,0 +1,204 @@
from unittest.mock import Mock
import pytest
from application.agents.tools.tool_action_parser import ToolActionParser
@pytest.mark.unit
class TestToolActionParser:
def test_parser_initialization(self):
parser = ToolActionParser("OpenAILLM")
assert parser.llm_type == "OpenAILLM"
assert "OpenAILLM" in parser.parsers
assert "GoogleLLM" in parser.parsers
def test_parse_openai_llm_valid_call(self):
parser = ToolActionParser("OpenAILLM")
call = Mock()
call.name = "get_data_123"
call.arguments = '{"param1": "value1", "param2": "value2"}'
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "123"
assert action_name == "get_data"
assert call_args == {"param1": "value1", "param2": "value2"}
def test_parse_openai_llm_with_underscore_in_action(self):
parser = ToolActionParser("OpenAILLM")
call = Mock()
call.name = "send_email_notification_456"
call.arguments = '{"to": "user@example.com"}'
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "456"
assert action_name == "send_email_notification"
assert call_args == {"to": "user@example.com"}
def test_parse_openai_llm_invalid_format_no_underscore(self):
parser = ToolActionParser("OpenAILLM")
call = Mock()
call.name = "invalidtoolname"
call.arguments = "{}"
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id is None
assert action_name is None
assert call_args is None
def test_parse_openai_llm_non_numeric_tool_id(self):
parser = ToolActionParser("OpenAILLM")
call = Mock()
call.name = "action_notanumber"
call.arguments = "{}"
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "notanumber"
assert action_name == "action"
def test_parse_openai_llm_malformed_json(self):
parser = ToolActionParser("OpenAILLM")
call = Mock()
call.name = "action_123"
call.arguments = "invalid json"
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id is None
assert action_name is None
assert call_args is None
def test_parse_openai_llm_missing_attributes(self):
parser = ToolActionParser("OpenAILLM")
call = Mock(spec=[])
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id is None
assert action_name is None
assert call_args is None
def test_parse_google_llm_valid_call(self):
parser = ToolActionParser("GoogleLLM")
call = Mock()
call.name = "search_documents_789"
call.arguments = {"query": "test query", "limit": 10}
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "789"
assert action_name == "search_documents"
assert call_args == {"query": "test query", "limit": 10}
def test_parse_google_llm_with_complex_action_name(self):
parser = ToolActionParser("GoogleLLM")
call = Mock()
call.name = "create_new_user_account_999"
call.arguments = {"username": "test"}
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "999"
assert action_name == "create_new_user_account"
def test_parse_google_llm_invalid_format(self):
parser = ToolActionParser("GoogleLLM")
call = Mock()
call.name = "nounderscores"
call.arguments = {}
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id is None
assert action_name is None
assert call_args is None
def test_parse_google_llm_missing_attributes(self):
parser = ToolActionParser("GoogleLLM")
call = Mock(spec=[])
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id is None
assert action_name is None
assert call_args is None
def test_parse_unknown_llm_type_defaults_to_openai(self):
parser = ToolActionParser("UnknownLLM")
call = Mock()
call.name = "action_123"
call.arguments = '{"key": "value"}'
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "123"
assert action_name == "action"
assert call_args == {"key": "value"}
def test_parse_args_empty_arguments_openai(self):
parser = ToolActionParser("OpenAILLM")
call = Mock()
call.name = "action_123"
call.arguments = "{}"
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "123"
assert action_name == "action"
assert call_args == {}
def test_parse_args_empty_arguments_google(self):
parser = ToolActionParser("GoogleLLM")
call = Mock()
call.name = "action_456"
call.arguments = {}
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "456"
assert action_name == "action"
assert call_args == {}
def test_parse_args_with_special_characters(self):
parser = ToolActionParser("OpenAILLM")
call = Mock()
call.name = "send_message_123"
call.arguments = '{"message": "Hello, World! 你好"}'
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "123"
assert action_name == "send_message"
assert call_args["message"] == "Hello, World! 你好"
def test_parse_args_with_nested_objects(self):
parser = ToolActionParser("OpenAILLM")
call = Mock()
call.name = "create_record_123"
call.arguments = '{"data": {"name": "John", "age": 30}}'
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "123"
assert action_name == "create_record"
assert call_args["data"]["name"] == "John"
assert call_args["data"]["age"] == 30

View File

@@ -0,0 +1,235 @@
from unittest.mock import Mock, patch
import pytest
from application.agents.tools.base import Tool
from application.agents.tools.tool_manager import ToolManager
class MockTool(Tool):
def __init__(self, config):
self.config = config
def execute_action(self, action_name: str, **kwargs):
return f"Executed {action_name} with {kwargs}"
def get_actions_metadata(self):
return [{"name": "test_action", "description": "Test action"}]
def get_config_requirements(self):
return {"required": ["api_key"]}
@pytest.mark.unit
class TestToolManager:
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
def test_tool_manager_initialization(self, mock_iter):
mock_iter.return_value = []
config = {"tool1": {"key": "value"}}
manager = ToolManager(config)
assert manager.config == config
assert isinstance(manager.tools, dict)
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
@patch("application.agents.tools.tool_manager.importlib.import_module")
def test_load_tools_skips_base_and_private(self, mock_import, mock_iter):
mock_iter.return_value = [
(None, "base", False),
(None, "__init__", False),
(None, "__pycache__", False),
(None, "valid_tool", False),
]
mock_module = Mock()
mock_module.MockTool = MockTool
mock_import.return_value = mock_module
manager = ToolManager({})
assert "base" not in manager.tools
assert "__init__" not in manager.tools
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
def test_load_tools_creates_tool_instances(self, mock_iter):
mock_iter.return_value = []
manager = ToolManager({})
mock_tool = MockTool({"test": "config"})
manager.tools["mock_tool"] = mock_tool
assert "mock_tool" in manager.tools
assert isinstance(manager.tools["mock_tool"], MockTool)
assert manager.tools["mock_tool"].config == {"test": "config"}
def test_load_tool_with_user_id(self):
with patch(
"application.agents.tools.tool_manager.pkgutil.iter_modules",
return_value=[],
):
manager = ToolManager({})
tool = MockTool({"key": "value"})
assert tool.config == {"key": "value"}
manager.config["test_tool"] = {"key": "value"}
assert "test_tool" in manager.config
def test_load_tool_without_user_id(self):
tool = MockTool({"api_key": "test123"})
assert isinstance(tool, MockTool)
assert tool.config == {"api_key": "test123"}
assert hasattr(tool, "execute_action")
assert hasattr(tool, "get_actions_metadata")
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
def test_load_tool_updates_config(self, mock_iter):
mock_iter.return_value = []
manager = ToolManager({})
new_config = {"new_key": "new_value"}
manager.config["test_tool"] = new_config
assert manager.config["test_tool"] == new_config
assert "test_tool" in manager.config
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
@patch("application.agents.tools.tool_manager.importlib.import_module")
def test_execute_action_on_loaded_tool(self, mock_import, mock_iter):
mock_iter.return_value = [(None, "mock_tool", False)]
mock_tool_instance = MockTool({})
with patch("inspect.getmembers", return_value=[("MockTool", MockTool)]):
with patch("inspect.isclass", return_value=True):
with patch.object(MockTool, "__init__", return_value=None):
manager = ToolManager({})
manager.tools["mock_tool"] = mock_tool_instance
result = manager.execute_action(
"mock_tool", "test_action", param="value"
)
assert "Executed test_action" in result
def test_execute_action_tool_not_loaded(self):
with patch(
"application.agents.tools.tool_manager.pkgutil.iter_modules",
return_value=[],
):
manager = ToolManager({})
with pytest.raises(ValueError, match="Tool 'nonexistent' not loaded"):
manager.execute_action("nonexistent", "action")
@patch("application.agents.tools.tool_manager.importlib.import_module")
def test_execute_action_with_user_id_for_mcp_tool(self, mock_import):
mock_tool = MockTool({})
with patch("inspect.getmembers", return_value=[("MockTool", MockTool)]):
with patch("inspect.isclass", return_value=True):
manager = ToolManager({"mcp_tool": {}})
manager.tools["mcp_tool"] = mock_tool
with patch.object(
manager, "load_tool", return_value=mock_tool
) as mock_load:
manager.execute_action("mcp_tool", "action", user_id="user123")
mock_load.assert_called_once_with("mcp_tool", {}, "user123")
@patch("application.agents.tools.tool_manager.importlib.import_module")
def test_execute_action_with_user_id_for_memory_tool(self, mock_import):
mock_tool = MockTool({})
with patch("inspect.getmembers", return_value=[("MockTool", MockTool)]):
with patch("inspect.isclass", return_value=True):
manager = ToolManager({"memory": {}})
manager.tools["memory"] = mock_tool
with patch.object(
manager, "load_tool", return_value=mock_tool
) as mock_load:
manager.execute_action("memory", "view", user_id="user456")
mock_load.assert_called_once_with("memory", {}, "user456")
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
@patch("application.agents.tools.tool_manager.importlib.import_module")
def test_get_all_actions_metadata(self, mock_import, mock_iter):
mock_iter.return_value = [(None, "tool1", False), (None, "tool2", False)]
mock_tool1 = Mock()
mock_tool1.get_actions_metadata.return_value = [{"name": "action1"}]
mock_tool2 = Mock()
mock_tool2.get_actions_metadata.return_value = [{"name": "action2"}]
manager = ToolManager({})
manager.tools = {"tool1": mock_tool1, "tool2": mock_tool2}
metadata = manager.get_all_actions_metadata()
assert len(metadata) == 2
assert {"name": "action1"} in metadata
assert {"name": "action2"} in metadata
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
def test_get_all_actions_metadata_empty(self, mock_iter):
mock_iter.return_value = []
manager = ToolManager({})
manager.tools = {}
metadata = manager.get_all_actions_metadata()
assert metadata == []
def test_load_tool_with_notes_tool(self):
tool = MockTool({"key": "value"})
assert isinstance(tool, MockTool)
assert tool.config == {"key": "value"}
result = tool.execute_action("test_action", param="value")
assert "test_action" in result
@pytest.mark.unit
class TestToolBase:
def test_tool_base_is_abstract(self):
with pytest.raises(TypeError):
Tool()
def test_mock_tool_implements_interface(self):
tool = MockTool({"test": "config"})
assert hasattr(tool, "execute_action")
assert hasattr(tool, "get_actions_metadata")
assert hasattr(tool, "get_config_requirements")
def test_mock_tool_execute_action(self):
tool = MockTool({})
result = tool.execute_action("test", param="value")
assert "Executed test" in result
assert "param" in result
def test_mock_tool_get_actions_metadata(self):
tool = MockTool({})
metadata = tool.get_actions_metadata()
assert isinstance(metadata, list)
assert len(metadata) > 0
assert "name" in metadata[0]
def test_mock_tool_get_config_requirements(self):
tool = MockTool({})
requirements = tool.get_config_requirements()
assert isinstance(requirements, dict)
assert "required" in requirements

199
tests/conftest.py Normal file
View File

@@ -0,0 +1,199 @@
from unittest.mock import Mock
import pytest
from application.core.settings import settings
@pytest.fixture
def mock_llm():
llm = Mock()
llm.gen_stream = Mock()
llm._supports_tools = True
llm._supports_structured_output = Mock(return_value=False)
llm.__class__.__name__ = "MockLLM"
return llm
@pytest.fixture
def mock_llm_handler():
handler = Mock()
handler.process_message_flow = Mock()
return handler
@pytest.fixture
def mock_retriever():
retriever = Mock()
retriever.search = Mock(
return_value=[
{"text": "Test document 1", "filename": "doc1.txt", "source": "test"},
{"text": "Test document 2", "title": "doc2.txt", "source": "test"},
]
)
return retriever
@pytest.fixture
def mock_mongo_db(monkeypatch):
fake_collection = FakeMongoCollection()
fake_db = {
"agents": fake_collection,
"user_tools": fake_collection,
"memories": fake_collection,
}
fake_client = {settings.MONGO_DB_NAME: fake_db}
monkeypatch.setattr(
"application.core.mongo_db.MongoDB.get_client", lambda: fake_client
)
return fake_client
@pytest.fixture
def sample_chat_history():
return [
{"prompt": "What is Python?", "response": "Python is a programming language."},
{"prompt": "Tell me more.", "response": "Python is known for simplicity."},
]
@pytest.fixture
def sample_tool_call():
return {
"tool_name": "test_tool",
"call_id": "123",
"action_name": "test_action",
"arguments": {"arg1": "value1"},
"result": "Tool executed successfully",
}
@pytest.fixture
def decoded_token():
return {"sub": "test_user", "email": "test@example.com"}
@pytest.fixture
def log_context():
from application.logging import LogContext
context = LogContext(
endpoint="test_endpoint",
activity_id="test_activity",
user="test_user",
api_key="test_key",
query="test query",
)
return context
class FakeMongoCollection:
def __init__(self):
self.docs = {}
def find_one(self, query, projection=None):
if "key" in query:
return self.docs.get(query["key"])
if "_id" in query:
return self.docs.get(str(query["_id"]))
if "user" in query:
for doc in self.docs.values():
if doc.get("user") == query["user"]:
return doc
return None
def find(self, query, projection=None):
results = []
if "_id" in query and "$in" in query["_id"]:
for doc_id in query["_id"]["$in"]:
doc = self.docs.get(str(doc_id))
if doc:
results.append(doc)
elif "user" in query:
for doc in self.docs.values():
if doc.get("user") == query["user"]:
if "status" in query:
if doc.get("status") == query["status"]:
results.append(doc)
else:
results.append(doc)
return results
def insert_one(self, doc):
doc_id = doc.get("_id", len(self.docs))
self.docs[str(doc_id)] = doc
return Mock(inserted_id=doc_id)
def update_one(self, query, update, upsert=False):
return Mock(modified_count=1)
def delete_one(self, query):
return Mock(deleted_count=1)
def delete_many(self, query):
return Mock(deleted_count=0)
@pytest.fixture
def mock_llm_creator(mock_llm, monkeypatch):
monkeypatch.setattr(
"application.llm.llm_creator.LLMCreator.create_llm", Mock(return_value=mock_llm)
)
return mock_llm
@pytest.fixture
def mock_llm_handler_creator(mock_llm_handler, monkeypatch):
monkeypatch.setattr(
"application.llm.handlers.handler_creator.LLMHandlerCreator.create_handler",
Mock(return_value=mock_llm_handler),
)
return mock_llm_handler
@pytest.fixture
def agent_base_params(decoded_token):
return {
"endpoint": "https://api.example.com",
"llm_name": "openai",
"gpt_model": "gpt-4",
"api_key": "test_api_key",
"user_api_key": None,
"prompt": "You are a helpful assistant.",
"chat_history": [],
"decoded_token": decoded_token,
"attachments": [],
"json_schema": None,
}
@pytest.fixture
def mock_tool():
tool = Mock()
tool.execute_action = Mock(return_value="Tool result")
tool.get_actions_metadata = Mock(
return_value=[
{
"name": "test_action",
"description": "A test action",
"parameters": {
"type": "object",
"properties": {
"param1": {"type": "string", "description": "Test parameter"}
},
"required": ["param1"],
},
}
]
)
return tool
@pytest.fixture
def mock_tool_manager(mock_tool, monkeypatch):
manager = Mock()
manager.load_tool = Mock(return_value=mock_tool)
monkeypatch.setattr(
"application.agents.base.ToolManager", Mock(return_value=manager)
)
return manager

View File

@@ -1,6 +1,6 @@
import types import types
import pytest
import pytest
from application.llm.openai import OpenAILLM from application.llm.openai import OpenAILLM
@@ -42,16 +42,16 @@ class FakeChatCompletions:
def create(self, **kwargs): def create(self, **kwargs):
self.last_kwargs = kwargs self.last_kwargs = kwargs
# default non-streaming: return content
if not kwargs.get("stream"): if not kwargs.get("stream"):
return FakeChatCompletions._Response(choices=[ return FakeChatCompletions._Response(
FakeChatCompletions._Choice(content="hello world") choices=[FakeChatCompletions._Choice(content="hello world")]
]) )
# streaming: yield line objects each with choices[0].delta.content return FakeChatCompletions._Response(
return FakeChatCompletions._Response(lines=[ lines=[
FakeChatCompletions._StreamLine(["part1"]), FakeChatCompletions._StreamLine(["part1"]),
FakeChatCompletions._StreamLine(["part2"]), FakeChatCompletions._StreamLine(["part2"]),
]) ]
)
class FakeClient: class FakeClient:
@@ -71,16 +71,29 @@ def openai_llm(monkeypatch):
return llm return llm
@pytest.mark.unit
def test_clean_messages_openai_variants(openai_llm): def test_clean_messages_openai_variants(openai_llm):
messages = [ messages = [
{"role": "system", "content": "sys"}, {"role": "system", "content": "sys"},
{"role": "model", "content": "asst"}, {"role": "model", "content": "asst"},
{"role": "user", "content": [ {
"role": "user",
"content": [
{"text": "hello"}, {"text": "hello"},
{"function_call": {"call_id": "c1", "name": "fn", "args": {"a": 1}}}, {"function_call": {"call_id": "c1", "name": "fn", "args": {"a": 1}}},
{"function_response": {"call_id": "c1", "name": "fn", "response": {"result": 42}}}, {
{"type": "image_url", "image_url": {"url": "data:image/png;base64,AAA"}}, "function_response": {
]}, "call_id": "c1",
"name": "fn",
"response": {"result": 42},
}
},
{
"type": "image_url",
"image_url": {"url": "data:image/png;base64,AAA"},
},
],
},
] ]
cleaned = openai_llm._clean_messages_openai(messages) cleaned = openai_llm._clean_messages_openai(messages)
@@ -89,17 +102,27 @@ def test_clean_messages_openai_variants(openai_llm):
assert roles.count("assistant") >= 1 assert roles.count("assistant") >= 1
assert any(m["role"] == "tool" for m in cleaned) assert any(m["role"] == "tool" for m in cleaned)
assert any(isinstance(m["content"], list) and any( assert any(
part.get("type") == "image_url" for part in m["content"] if isinstance(part, dict) isinstance(m["content"], list)
) for m in cleaned if m["role"] == "user") and any(
part.get("type") == "image_url"
for part in m["content"]
if isinstance(part, dict)
)
for m in cleaned
if m["role"] == "user"
)
@pytest.mark.unit
def test_raw_gen_calls_openai_client_and_returns_content(openai_llm): def test_raw_gen_calls_openai_client_and_returns_content(openai_llm):
msgs = [ msgs = [
{"role": "system", "content": "sys"}, {"role": "system", "content": "sys"},
{"role": "user", "content": "hello"}, {"role": "user", "content": "hello"},
] ]
content = openai_llm._raw_gen(openai_llm, model="gpt-4o", messages=msgs, stream=False) content = openai_llm._raw_gen(
openai_llm, model="gpt-4o", messages=msgs, stream=False
)
assert content == "hello world" assert content == "hello world"
passed = openai_llm.client.chat.completions.last_kwargs passed = openai_llm.client.chat.completions.last_kwargs
@@ -108,16 +131,20 @@ def test_raw_gen_calls_openai_client_and_returns_content(openai_llm):
assert passed["stream"] is False assert passed["stream"] is False
@pytest.mark.unit
def test_raw_gen_stream_yields_chunks(openai_llm): def test_raw_gen_stream_yields_chunks(openai_llm):
msgs = [ msgs = [
{"role": "user", "content": "hi"}, {"role": "user", "content": "hi"},
] ]
gen = openai_llm._raw_gen_stream(openai_llm, model="gpt", messages=msgs, stream=True) gen = openai_llm._raw_gen_stream(
openai_llm, model="gpt", messages=msgs, stream=True
)
chunks = list(gen) chunks = list(gen)
assert "part1" in "".join(chunks) assert "part1" in "".join(chunks)
assert "part2" in "".join(chunks) assert "part2" in "".join(chunks)
@pytest.mark.unit
def test_prepare_structured_output_format_enforces_required_and_strict(openai_llm): def test_prepare_structured_output_format_enforces_required_and_strict(openai_llm):
schema = { schema = {
"type": "object", "type": "object",
@@ -134,8 +161,8 @@ def test_prepare_structured_output_format_enforces_required_and_strict(openai_ll
assert js["schema"]["additionalProperties"] is False assert js["schema"]["additionalProperties"] is False
@pytest.mark.unit
def test_prepare_messages_with_attachments_image_and_pdf(openai_llm, monkeypatch): def test_prepare_messages_with_attachments_image_and_pdf(openai_llm, monkeypatch):
monkeypatch.setattr(openai_llm, "_get_base64_image", lambda att: "AAA=") monkeypatch.setattr(openai_llm, "_get_base64_image", lambda att: "AAA=")
monkeypatch.setattr(openai_llm, "_upload_file_to_openai", lambda att: "file_xyz") monkeypatch.setattr(openai_llm, "_upload_file_to_openai", lambda att: "file_xyz")
@@ -146,12 +173,15 @@ def test_prepare_messages_with_attachments_image_and_pdf(openai_llm, monkeypatch
] ]
out = openai_llm.prepare_messages_with_attachments(messages, attachments) out = openai_llm.prepare_messages_with_attachments(messages, attachments)
# last user message should have list content with text and two attachments
user_msg = next(m for m in out if m["role"] == "user") user_msg = next(m for m in out if m["role"] == "user")
assert isinstance(user_msg["content"], list) assert isinstance(user_msg["content"], list)
types_in_content = [p.get("type") for p in user_msg["content"] if isinstance(p, dict)] types_in_content = [
p.get("type") for p in user_msg["content"] if isinstance(p, dict)
]
assert "image_url" in types_in_content or any( assert "image_url" in types_in_content or any(
isinstance(p, dict) and p.get("image_url") for p in user_msg["content"] isinstance(p, dict) and p.get("image_url") for p in user_msg["content"]
) )
assert any(isinstance(p, dict) and p.get("file", {}).get("file_id") == "file_xyz" for p in user_msg["content"]) assert any(
isinstance(p, dict) and p.get("file", {}).get("file_id") == "file_xyz"
for p in user_msg["content"]
)

3
tests/requirements.txt Normal file
View File

@@ -0,0 +1,3 @@
pytest>=8.0.0
pytest-cov>=4.1.0
coverage>=7.4.0

View File

@@ -1,12 +1,24 @@
import base64 import base64
import pytest
from application.security import encryption
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from application.security import encryption
def _fake_os_urandom_factory(values):
values_iter = iter(values)
def _fake(length):
value = next(values_iter)
assert len(value) == length
return value
return _fake
@pytest.mark.unit
def test_derive_key_uses_secret_and_user(monkeypatch): def test_derive_key_uses_secret_and_user(monkeypatch):
monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret") monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret")
salt = bytes(range(16)) salt = bytes(range(16))
@@ -25,17 +37,7 @@ def test_derive_key_uses_secret_and_user(monkeypatch):
assert derived == expected_key assert derived == expected_key
def _fake_os_urandom_factory(values): @pytest.mark.unit
values_iter = iter(values)
def _fake(length):
value = next(values_iter)
assert len(value) == length
return value
return _fake
def test_encrypt_and_decrypt_round_trip(monkeypatch): def test_encrypt_and_decrypt_round_trip(monkeypatch):
monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret") monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret")
salt = bytes(range(16)) salt = bytes(range(16))
@@ -55,6 +57,7 @@ def test_encrypt_and_decrypt_round_trip(monkeypatch):
assert decrypted == credentials assert decrypted == credentials
@pytest.mark.unit
def test_encrypt_credentials_returns_empty_for_empty_input(monkeypatch): def test_encrypt_credentials_returns_empty_for_empty_input(monkeypatch):
monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret") monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret")
@@ -62,11 +65,12 @@ def test_encrypt_credentials_returns_empty_for_empty_input(monkeypatch):
assert encryption.encrypt_credentials(None, "user-123") == "" assert encryption.encrypt_credentials(None, "user-123") == ""
@pytest.mark.unit
def test_encrypt_credentials_returns_empty_on_serialization_error(monkeypatch): def test_encrypt_credentials_returns_empty_on_serialization_error(monkeypatch):
monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret") monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret")
monkeypatch.setattr(encryption.os, "urandom", lambda length: b"\x00" * length) monkeypatch.setattr(encryption.os, "urandom", lambda length: b"\x00" * length)
class NonSerializable: # pragma: no cover - simple helper container class NonSerializable:
pass pass
credentials = {"bad": NonSerializable()} credentials = {"bad": NonSerializable()}
@@ -74,6 +78,7 @@ def test_encrypt_credentials_returns_empty_on_serialization_error(monkeypatch):
assert encryption.encrypt_credentials(credentials, "user-123") == "" assert encryption.encrypt_credentials(credentials, "user-123") == ""
@pytest.mark.unit
def test_decrypt_credentials_returns_empty_for_invalid_input(monkeypatch): def test_decrypt_credentials_returns_empty_for_invalid_input(monkeypatch):
monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret") monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret")
@@ -84,6 +89,7 @@ def test_decrypt_credentials_returns_empty_for_invalid_input(monkeypatch):
assert encryption.decrypt_credentials(invalid_payload, "user-123") == {} assert encryption.decrypt_credentials(invalid_payload, "user-123") == {}
@pytest.mark.unit
def test_pad_and_unpad_are_inverse(): def test_pad_and_unpad_are_inverse():
original = b"secret-data" original = b"secret-data"
@@ -91,4 +97,3 @@ def test_pad_and_unpad_are_inverse():
assert len(padded) % 16 == 0 assert len(padded) % 16 == 0
assert encryption._unpad_data(padded) == original assert encryption._unpad_data(padded) == original

View File

@@ -1,352 +1,401 @@
"""Tests for LocalStorage implementation
"""
import io import io
import pytest import os
from unittest.mock import patch, MagicMock, mock_open from unittest.mock import MagicMock, mock_open, patch
import pytest
from application.storage.local import LocalStorage from application.storage.local import LocalStorage
@pytest.fixture @pytest.fixture
def temp_base_dir(): def temp_base_dir():
"""Provide a temporary base directory path for testing."""
return "/tmp/test_storage" return "/tmp/test_storage"
@pytest.fixture @pytest.fixture
def local_storage(temp_base_dir): def local_storage(temp_base_dir):
"""Create LocalStorage instance with test base directory."""
return LocalStorage(base_dir=temp_base_dir) return LocalStorage(base_dir=temp_base_dir)
@pytest.mark.unit
class TestLocalStorageInitialization: class TestLocalStorageInitialization:
"""Test LocalStorage initialization and configuration."""
def test_init_with_custom_base_dir(self): def test_init_with_custom_base_dir(self):
"""Should use provided base directory."""
storage = LocalStorage(base_dir="/custom/path") storage = LocalStorage(base_dir="/custom/path")
assert storage.base_dir == "/custom/path" assert storage.base_dir == "/custom/path"
def test_init_with_default_base_dir(self): def test_init_with_default_base_dir(self):
"""Should use default base directory when none provided."""
storage = LocalStorage() storage = LocalStorage()
# Default is three levels up from the file location
assert storage.base_dir is not None assert storage.base_dir is not None
assert isinstance(storage.base_dir, str) assert isinstance(storage.base_dir, str)
def test_get_full_path_with_relative_path(self, local_storage): def test_get_full_path_with_relative_path(self, local_storage):
"""Should combine base_dir with relative path."""
result = local_storage._get_full_path("documents/test.txt") result = local_storage._get_full_path("documents/test.txt")
assert result == "/tmp/test_storage/documents/test.txt" expected = os.path.join("/tmp/test_storage", "documents/test.txt")
assert os.path.normpath(result) == os.path.normpath(expected)
def test_get_full_path_with_absolute_path(self, local_storage): def test_get_full_path_with_absolute_path(self, local_storage):
"""Should return absolute path unchanged."""
result = local_storage._get_full_path("/absolute/path/test.txt") result = local_storage._get_full_path("/absolute/path/test.txt")
assert result == "/absolute/path/test.txt" assert result == "/absolute/path/test.txt"
@patch("os.makedirs")
class TestLocalStorageSaveFile: @patch("builtins.open", new_callable=mock_open)
"""Test file saving functionality.""" @patch("shutil.copyfileobj")
@patch('os.makedirs')
@patch('builtins.open', new_callable=mock_open)
@patch('shutil.copyfileobj')
def test_save_file_creates_directory_and_saves( def test_save_file_creates_directory_and_saves(
self, mock_copyfileobj, mock_file, mock_makedirs, local_storage self, mock_copyfileobj, mock_file, mock_makedirs, local_storage
): ):
"""Should create directory and save file content."""
file_data = io.BytesIO(b"test content") file_data = io.BytesIO(b"test content")
path = "documents/test.txt" path = "documents/test.txt"
result = local_storage.save_file(file_data, path) result = local_storage.save_file(file_data, path)
# Verify directory creation expected_dir = os.path.join("/tmp/test_storage", "documents")
mock_makedirs.assert_called_once_with( expected_file = os.path.join("/tmp/test_storage", "documents/test.txt")
"/tmp/test_storage/documents",
exist_ok=True assert mock_makedirs.call_count == 1
assert os.path.normpath(mock_makedirs.call_args[0][0]) == os.path.normpath(
expected_dir
) )
assert mock_makedirs.call_args[1] == {"exist_ok": True}
assert mock_file.call_count == 1
assert os.path.normpath(mock_file.call_args[0][0]) == os.path.normpath(
expected_file
)
assert mock_file.call_args[0][1] == "wb"
# Verify file write
mock_file.assert_called_once_with("/tmp/test_storage/documents/test.txt", 'wb')
mock_copyfileobj.assert_called_once_with(file_data, mock_file()) mock_copyfileobj.assert_called_once_with(file_data, mock_file())
assert result == {"storage_type": "local"}
# Verify result @patch("os.makedirs")
assert result == {'storage_type': 'local'}
@patch('os.makedirs')
def test_save_file_with_save_method(self, mock_makedirs, local_storage): def test_save_file_with_save_method(self, mock_makedirs, local_storage):
"""Should use save method if file_data has it."""
file_data = MagicMock() file_data = MagicMock()
file_data.save = MagicMock() file_data.save = MagicMock()
path = "documents/test.txt" path = "documents/test.txt"
result = local_storage.save_file(file_data, path) result = local_storage.save_file(file_data, path)
# Verify save method was called expected_file = os.path.join("/tmp/test_storage", "documents/test.txt")
file_data.save.assert_called_once_with("/tmp/test_storage/documents/test.txt") assert file_data.save.call_count == 1
assert os.path.normpath(file_data.save.call_args[0][0]) == os.path.normpath(
expected_file
)
assert result == {"storage_type": "local"}
# Verify result @patch("os.makedirs")
assert result == {'storage_type': 'local'} @patch("builtins.open", new_callable=mock_open)
def test_save_file_with_absolute_path(
@patch('os.makedirs') self, mock_file, mock_makedirs, local_storage
@patch('builtins.open', new_callable=mock_open) ):
def test_save_file_with_absolute_path(self, mock_file, mock_makedirs, local_storage):
"""Should handle absolute paths correctly."""
file_data = io.BytesIO(b"test content") file_data = io.BytesIO(b"test content")
path = "/absolute/path/test.txt" path = "/absolute/path/test.txt"
local_storage.save_file(file_data, path) local_storage.save_file(file_data, path)
mock_makedirs.assert_called_once_with("/absolute/path", exist_ok=True) mock_makedirs.assert_called_once_with("/absolute/path", exist_ok=True)
mock_file.assert_called_once_with("/absolute/path/test.txt", 'wb') mock_file.assert_called_once_with("/absolute/path/test.txt", "wb")
@pytest.mark.unit
class TestLocalStorageGetFile: class TestLocalStorageGetFile:
"""Test file retrieval functionality."""
@patch('os.path.exists', return_value=True) @patch("os.path.exists", return_value=True)
@patch('builtins.open', new_callable=mock_open, read_data=b"file content") @patch("builtins.open", new_callable=mock_open, read_data=b"file content")
def test_get_file_returns_file_handle(self, mock_file, mock_exists, local_storage): def test_get_file_returns_file_handle(self, mock_file, mock_exists, local_storage):
"""Should open and return file handle when file exists."""
path = "documents/test.txt" path = "documents/test.txt"
result = local_storage.get_file(path) result = local_storage.get_file(path)
mock_exists.assert_called_once_with("/tmp/test_storage/documents/test.txt") expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
mock_file.assert_called_once_with("/tmp/test_storage/documents/test.txt", 'rb') assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
expected_path
)
assert mock_file.call_count == 1
assert os.path.normpath(mock_file.call_args[0][0]) == os.path.normpath(
expected_path
)
assert result is not None assert result is not None
@patch('os.path.exists', return_value=False) @patch("os.path.exists", return_value=False)
def test_get_file_raises_error_when_not_found(self, mock_exists, local_storage): def test_get_file_raises_error_when_not_found(self, mock_exists, local_storage):
"""Should raise FileNotFoundError when file doesn't exist."""
path = "documents/nonexistent.txt" path = "documents/nonexistent.txt"
with pytest.raises(FileNotFoundError, match="File not found"): with pytest.raises(FileNotFoundError, match="File not found"):
local_storage.get_file(path) local_storage.get_file(path)
expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt")
mock_exists.assert_called_once_with("/tmp/test_storage/documents/nonexistent.txt") assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
expected_path
)
@pytest.mark.unit
class TestLocalStorageDeleteFile: class TestLocalStorageDeleteFile:
"""Test file deletion functionality."""
@patch('os.remove') @patch("os.remove")
@patch('os.path.exists', return_value=True) @patch("os.path.exists", return_value=True)
def test_delete_file_removes_existing_file(self, mock_exists, mock_remove, local_storage): def test_delete_file_removes_existing_file(
"""Should delete file and return True when file exists.""" self, mock_exists, mock_remove, local_storage
):
path = "documents/test.txt" path = "documents/test.txt"
result = local_storage.delete_file(path) result = local_storage.delete_file(path)
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
assert result is True assert result is True
mock_exists.assert_called_once_with("/tmp/test_storage/documents/test.txt") assert mock_exists.call_count == 1
mock_remove.assert_called_once_with("/tmp/test_storage/documents/test.txt") assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
expected_path
)
assert mock_remove.call_count == 1
assert os.path.normpath(mock_remove.call_args[0][0]) == os.path.normpath(
expected_path
)
@patch('os.path.exists', return_value=False) @patch("os.path.exists", return_value=False)
def test_delete_file_returns_false_when_not_found(self, mock_exists, local_storage): def test_delete_file_returns_false_when_not_found(self, mock_exists, local_storage):
"""Should return False when file doesn't exist."""
path = "documents/nonexistent.txt" path = "documents/nonexistent.txt"
result = local_storage.delete_file(path) result = local_storage.delete_file(path)
expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt")
assert result is False assert result is False
mock_exists.assert_called_once_with("/tmp/test_storage/documents/nonexistent.txt") assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
expected_path
)
@pytest.mark.unit
class TestLocalStorageFileExists: class TestLocalStorageFileExists:
"""Test file existence checking."""
@patch('os.path.exists', return_value=True) @patch("os.path.exists", return_value=True)
def test_file_exists_returns_true_when_file_found(self, mock_exists, local_storage): def test_file_exists_returns_true_when_file_found(self, mock_exists, local_storage):
"""Should return True when file exists."""
path = "documents/test.txt" path = "documents/test.txt"
result = local_storage.file_exists(path) result = local_storage.file_exists(path)
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
assert result is True assert result is True
mock_exists.assert_called_once_with("/tmp/test_storage/documents/test.txt") assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
expected_path
)
@patch('os.path.exists', return_value=False) @patch("os.path.exists", return_value=False)
def test_file_exists_returns_false_when_not_found(self, mock_exists, local_storage): def test_file_exists_returns_false_when_not_found(self, mock_exists, local_storage):
"""Should return False when file doesn't exist."""
path = "documents/nonexistent.txt" path = "documents/nonexistent.txt"
result = local_storage.file_exists(path) result = local_storage.file_exists(path)
expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt")
assert result is False assert result is False
mock_exists.assert_called_once_with("/tmp/test_storage/documents/nonexistent.txt") assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
expected_path
)
@pytest.mark.unit
class TestLocalStorageListFiles: class TestLocalStorageListFiles:
"""Test directory listing functionality."""
@patch('os.walk') @patch("os.walk")
@patch('os.path.exists', return_value=True) @patch("os.path.exists", return_value=True)
def test_list_files_returns_all_files_in_directory( def test_list_files_returns_all_files_in_directory(
self, mock_exists, mock_walk, local_storage self, mock_exists, mock_walk, local_storage
): ):
"""Should return all files in directory and subdirectories."""
directory = "documents" directory = "documents"
base_dir = os.path.join("/tmp/test_storage", "documents")
# Mock os.walk to return files in directory structure
mock_walk.return_value = [ mock_walk.return_value = [
("/tmp/test_storage/documents", ["subdir"], ["file1.txt", "file2.txt"]), (base_dir, ["subdir"], ["file1.txt", "file2.txt"]),
("/tmp/test_storage/documents/subdir", [], ["file3.txt"]) (os.path.join(base_dir, "subdir"), [], ["file3.txt"]),
] ]
result = local_storage.list_files(directory) result = local_storage.list_files(directory)
assert len(result) == 3 assert len(result) == 3
assert "documents/file1.txt" in result result_normalized = [os.path.normpath(f) for f in result]
assert "documents/file2.txt" in result assert os.path.normpath("documents/file1.txt") in result_normalized
assert "documents/subdir/file3.txt" in result assert os.path.normpath("documents/file2.txt") in result_normalized
assert os.path.normpath("documents/subdir/file3.txt") in result_normalized
mock_exists.assert_called_once_with("/tmp/test_storage/documents") @patch("os.path.exists", return_value=False)
mock_walk.assert_called_once_with("/tmp/test_storage/documents")
@patch('os.path.exists', return_value=False)
def test_list_files_returns_empty_list_when_directory_not_found( def test_list_files_returns_empty_list_when_directory_not_found(
self, mock_exists, local_storage self, mock_exists, local_storage
): ):
"""Should return empty list when directory doesn't exist."""
directory = "nonexistent" directory = "nonexistent"
result = local_storage.list_files(directory) result = local_storage.list_files(directory)
expected_path = os.path.join("/tmp/test_storage", "nonexistent")
assert result == [] assert result == []
mock_exists.assert_called_once_with("/tmp/test_storage/nonexistent") assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
expected_path
)
@pytest.mark.unit
class TestLocalStorageProcessFile: class TestLocalStorageProcessFile:
"""Test file processing functionality."""
@patch('os.path.exists', return_value=True) @patch("os.path.exists", return_value=True)
def test_process_file_calls_processor_with_full_path( def test_process_file_calls_processor_with_full_path(
self, mock_exists, local_storage self, mock_exists, local_storage
): ):
"""Should call processor function with full file path."""
path = "documents/test.txt" path = "documents/test.txt"
processor_func = MagicMock(return_value="processed") processor_func = MagicMock(return_value="processed")
result = local_storage.process_file(path, processor_func, extra_arg="value") result = local_storage.process_file(path, processor_func, extra_arg="value")
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
assert result == "processed" assert result == "processed"
processor_func.assert_called_once_with( assert processor_func.call_count == 1
local_path="/tmp/test_storage/documents/test.txt", call_kwargs = processor_func.call_args[1]
extra_arg="value" assert os.path.normpath(call_kwargs["local_path"]) == os.path.normpath(
expected_path
) )
mock_exists.assert_called_once_with("/tmp/test_storage/documents/test.txt") assert call_kwargs["extra_arg"] == "value"
@patch('os.path.exists', return_value=False) @patch("os.path.exists", return_value=False)
def test_process_file_raises_error_when_file_not_found(self, mock_exists, local_storage): def test_process_file_raises_error_when_file_not_found(
"""Should raise FileNotFoundError when file doesn't exist.""" self, mock_exists, local_storage
):
path = "documents/nonexistent.txt" path = "documents/nonexistent.txt"
processor_func = MagicMock() processor_func = MagicMock()
with pytest.raises(FileNotFoundError, match="File not found"): with pytest.raises(FileNotFoundError, match="File not found"):
local_storage.process_file(path, processor_func) local_storage.process_file(path, processor_func)
processor_func.assert_not_called() processor_func.assert_not_called()
@pytest.mark.unit
class TestLocalStorageIsDirectory: class TestLocalStorageIsDirectory:
"""Test directory checking functionality."""
@patch('os.path.isdir', return_value=True) @patch("os.path.isdir", return_value=True)
def test_is_directory_returns_true_when_directory_exists( def test_is_directory_returns_true_when_directory_exists(
self, mock_isdir, local_storage self, mock_isdir, local_storage
): ):
"""Should return True when path is a directory."""
path = "documents" path = "documents"
result = local_storage.is_directory(path) result = local_storage.is_directory(path)
expected_path = os.path.join("/tmp/test_storage", "documents")
assert result is True assert result is True
mock_isdir.assert_called_once_with("/tmp/test_storage/documents") assert mock_isdir.call_count == 1
assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath(
expected_path
)
@patch('os.path.isdir', return_value=False) @patch("os.path.isdir", return_value=False)
def test_is_directory_returns_false_when_not_directory( def test_is_directory_returns_false_when_not_directory(
self, mock_isdir, local_storage self, mock_isdir, local_storage
): ):
"""Should return False when path is not a directory or doesn't exist."""
path = "documents/test.txt" path = "documents/test.txt"
result = local_storage.is_directory(path) result = local_storage.is_directory(path)
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
assert result is False assert result is False
mock_isdir.assert_called_once_with("/tmp/test_storage/documents/test.txt") assert mock_isdir.call_count == 1
assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath(
expected_path
)
@pytest.mark.unit
class TestLocalStorageRemoveDirectory: class TestLocalStorageRemoveDirectory:
"""Test directory removal functionality."""
@patch('shutil.rmtree') @patch("shutil.rmtree")
@patch('os.path.isdir', return_value=True) @patch("os.path.isdir", return_value=True)
@patch('os.path.exists', return_value=True) @patch("os.path.exists", return_value=True)
def test_remove_directory_deletes_directory( def test_remove_directory_deletes_directory(
self, mock_exists, mock_isdir, mock_rmtree, local_storage self, mock_exists, mock_isdir, mock_rmtree, local_storage
): ):
"""Should remove directory and return True when successful."""
directory = "documents" directory = "documents"
result = local_storage.remove_directory(directory) result = local_storage.remove_directory(directory)
expected_path = os.path.join("/tmp/test_storage", "documents")
assert result is True assert result is True
mock_exists.assert_called_once_with("/tmp/test_storage/documents") assert mock_exists.call_count == 1
mock_isdir.assert_called_once_with("/tmp/test_storage/documents") assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
mock_rmtree.assert_called_once_with("/tmp/test_storage/documents") expected_path
)
assert mock_isdir.call_count == 1
assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath(
expected_path
)
assert mock_rmtree.call_count == 1
assert os.path.normpath(mock_rmtree.call_args[0][0]) == os.path.normpath(
expected_path
)
@patch('os.path.exists', return_value=False) @patch("os.path.exists", return_value=False)
def test_remove_directory_returns_false_when_not_exists( def test_remove_directory_returns_false_when_not_exists(
self, mock_exists, local_storage self, mock_exists, local_storage
): ):
"""Should return False when directory doesn't exist."""
directory = "nonexistent" directory = "nonexistent"
result = local_storage.remove_directory(directory) result = local_storage.remove_directory(directory)
expected_path = os.path.join("/tmp/test_storage", "nonexistent")
assert result is False assert result is False
mock_exists.assert_called_once_with("/tmp/test_storage/nonexistent") assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
expected_path
)
@patch('os.path.isdir', return_value=False) @patch("os.path.isdir", return_value=False)
@patch('os.path.exists', return_value=True) @patch("os.path.exists", return_value=True)
def test_remove_directory_returns_false_when_not_directory( def test_remove_directory_returns_false_when_not_directory(
self, mock_exists, mock_isdir, local_storage self, mock_exists, mock_isdir, local_storage
): ):
"""Should return False when path is not a directory."""
path = "documents/test.txt" path = "documents/test.txt"
result = local_storage.remove_directory(path) result = local_storage.remove_directory(path)
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
assert result is False assert result is False
mock_exists.assert_called_once_with("/tmp/test_storage/documents/test.txt") assert mock_exists.call_count == 1
mock_isdir.assert_called_once_with("/tmp/test_storage/documents/test.txt") assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
expected_path
)
assert mock_isdir.call_count == 1
assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath(
expected_path
)
@patch('shutil.rmtree', side_effect=OSError("Permission denied")) @patch("shutil.rmtree", side_effect=OSError("Permission denied"))
@patch('os.path.isdir', return_value=True) @patch("os.path.isdir", return_value=True)
@patch('os.path.exists', return_value=True) @patch("os.path.exists", return_value=True)
def test_remove_directory_returns_false_on_os_error( def test_remove_directory_returns_false_on_os_error(
self, mock_exists, mock_isdir, mock_rmtree, local_storage self, mock_exists, mock_isdir, mock_rmtree, local_storage
): ):
"""Should return False when OSError occurs during removal."""
directory = "documents" directory = "documents"
result = local_storage.remove_directory(directory) result = local_storage.remove_directory(directory)
expected_path = os.path.join("/tmp/test_storage", "documents")
assert result is False assert result is False
mock_rmtree.assert_called_once_with("/tmp/test_storage/documents") assert mock_rmtree.call_count == 1
assert os.path.normpath(mock_rmtree.call_args[0][0]) == os.path.normpath(
expected_path
)
@patch('shutil.rmtree', side_effect=PermissionError("Access denied")) @patch("shutil.rmtree", side_effect=PermissionError("Access denied"))
@patch('os.path.isdir', return_value=True) @patch("os.path.isdir", return_value=True)
@patch('os.path.exists', return_value=True) @patch("os.path.exists", return_value=True)
def test_remove_directory_returns_false_on_permission_error( def test_remove_directory_returns_false_on_permission_error(
self, mock_exists, mock_isdir, mock_rmtree, local_storage self, mock_exists, mock_isdir, mock_rmtree, local_storage
): ):
"""Should return False when PermissionError occurs during removal."""
directory = "documents" directory = "documents"
result = local_storage.remove_directory(directory) result = local_storage.remove_directory(directory)
expected_path = os.path.join("/tmp/test_storage", "documents")
assert result is False assert result is False
mock_rmtree.assert_called_once_with("/tmp/test_storage/documents") assert mock_rmtree.call_count == 1
assert os.path.normpath(mock_rmtree.call_args[0][0]) == os.path.normpath(
expected_path
)

View File

@@ -1,18 +1,18 @@
"""Tests for S3 storage implementation. """Tests for S3 storage implementation."""
"""
import io import io
from unittest.mock import MagicMock, patch
import pytest import pytest
from unittest.mock import patch, MagicMock
from botocore.exceptions import ClientError
from application.storage.s3 import S3Storage from application.storage.s3 import S3Storage
from botocore.exceptions import ClientError
@pytest.fixture @pytest.fixture
def mock_boto3_client(): def mock_boto3_client():
"""Mock boto3.client to isolate S3 client creation.""" """Mock boto3.client to isolate S3 client creation."""
with patch('boto3.client') as mock_client: with patch("boto3.client") as mock_client:
s3_mock = MagicMock() s3_mock = MagicMock()
mock_client.return_value = s3_mock mock_client.return_value = s3_mock
yield s3_mock yield s3_mock
@@ -27,22 +27,26 @@ def s3_storage(mock_boto3_client):
class TestS3StorageInitialization: class TestS3StorageInitialization:
"""Test S3Storage initialization and configuration.""" """Test S3Storage initialization and configuration."""
@pytest.mark.unit
def test_init_with_default_bucket(self): def test_init_with_default_bucket(self):
"""Should use default bucket name when none provided.""" """Should use default bucket name when none provided."""
with patch('boto3.client'): with patch("boto3.client"):
storage = S3Storage() storage = S3Storage()
assert storage.bucket_name == "docsgpt-test-bucket" assert storage.bucket_name == "docsgpt-test-bucket"
@pytest.mark.unit
def test_init_with_custom_bucket(self): def test_init_with_custom_bucket(self):
"""Should use provided bucket name.""" """Should use provided bucket name."""
with patch('boto3.client'): with patch("boto3.client"):
storage = S3Storage(bucket_name="custom-bucket") storage = S3Storage(bucket_name="custom-bucket")
assert storage.bucket_name == "custom-bucket" assert storage.bucket_name == "custom-bucket"
@pytest.mark.unit
def test_init_creates_boto3_client(self): def test_init_creates_boto3_client(self):
"""Should create boto3 S3 client with credentials from settings.""" """Should create boto3 S3 client with credentials from settings."""
with patch('boto3.client') as mock_client, \ with patch("boto3.client") as mock_client, patch(
patch('application.storage.s3.settings') as mock_settings: "application.storage.s3.settings"
) as mock_settings:
mock_settings.SAGEMAKER_ACCESS_KEY = "test-key" mock_settings.SAGEMAKER_ACCESS_KEY = "test-key"
mock_settings.SAGEMAKER_SECRET_KEY = "test-secret" mock_settings.SAGEMAKER_SECRET_KEY = "test-secret"
@@ -54,52 +58,50 @@ class TestS3StorageInitialization:
"s3", "s3",
aws_access_key_id="test-key", aws_access_key_id="test-key",
aws_secret_access_key="test-secret", aws_secret_access_key="test-secret",
region_name="us-west-2" region_name="us-west-2",
) )
class TestS3StorageSaveFile: class TestS3StorageSaveFile:
"""Test file saving functionality.""" """Test file saving functionality."""
@pytest.mark.unit
def test_save_file_uploads_to_s3(self, s3_storage, mock_boto3_client): def test_save_file_uploads_to_s3(self, s3_storage, mock_boto3_client):
"""Should upload file to S3 with correct parameters.""" """Should upload file to S3 with correct parameters."""
file_data = io.BytesIO(b"test content") file_data = io.BytesIO(b"test content")
path = "documents/test.txt" path = "documents/test.txt"
with patch('application.storage.s3.settings') as mock_settings: with patch("application.storage.s3.settings") as mock_settings:
mock_settings.SAGEMAKER_REGION = "us-east-1" mock_settings.SAGEMAKER_REGION = "us-east-1"
result = s3_storage.save_file(file_data, path) result = s3_storage.save_file(file_data, path)
mock_boto3_client.upload_fileobj.assert_called_once_with( mock_boto3_client.upload_fileobj.assert_called_once_with(
file_data, file_data,
"test-bucket", "test-bucket",
path, path,
ExtraArgs={"StorageClass": "INTELLIGENT_TIERING"} ExtraArgs={"StorageClass": "INTELLIGENT_TIERING"},
) )
assert result == { assert result == {
"storage_type": "s3", "storage_type": "s3",
"bucket_name": "test-bucket", "bucket_name": "test-bucket",
"uri": "s3://test-bucket/documents/test.txt", "uri": "s3://test-bucket/documents/test.txt",
"region": "us-east-1" "region": "us-east-1",
} }
@pytest.mark.unit
def test_save_file_with_custom_storage_class(self, s3_storage, mock_boto3_client): def test_save_file_with_custom_storage_class(self, s3_storage, mock_boto3_client):
"""Should use custom storage class when provided.""" """Should use custom storage class when provided."""
file_data = io.BytesIO(b"test content") file_data = io.BytesIO(b"test content")
path = "documents/test.txt" path = "documents/test.txt"
with patch('application.storage.s3.settings') as mock_settings: with patch("application.storage.s3.settings") as mock_settings:
mock_settings.SAGEMAKER_REGION = "us-east-1" mock_settings.SAGEMAKER_REGION = "us-east-1"
s3_storage.save_file(file_data, path, storage_class="STANDARD") s3_storage.save_file(file_data, path, storage_class="STANDARD")
mock_boto3_client.upload_fileobj.assert_called_once_with( mock_boto3_client.upload_fileobj.assert_called_once_with(
file_data, file_data, "test-bucket", path, ExtraArgs={"StorageClass": "STANDARD"}
"test-bucket",
path,
ExtraArgs={"StorageClass": "STANDARD"}
) )
@pytest.mark.unit
def test_save_file_propagates_client_error(self, s3_storage, mock_boto3_client): def test_save_file_propagates_client_error(self, s3_storage, mock_boto3_client):
"""Should propagate ClientError when upload fails.""" """Should propagate ClientError when upload fails."""
file_data = io.BytesIO(b"test content") file_data = io.BytesIO(b"test content")
@@ -107,7 +109,7 @@ class TestS3StorageSaveFile:
mock_boto3_client.upload_fileobj.side_effect = ClientError( mock_boto3_client.upload_fileobj.side_effect = ClientError(
{"Error": {"Code": "AccessDenied", "Message": "Access denied"}}, {"Error": {"Code": "AccessDenied", "Message": "Access denied"}},
"upload_fileobj" "upload_fileobj",
) )
with pytest.raises(ClientError): with pytest.raises(ClientError):
@@ -117,7 +119,10 @@ class TestS3StorageSaveFile:
class TestS3StorageFileExists: class TestS3StorageFileExists:
"""Test file existence checking.""" """Test file existence checking."""
def test_file_exists_returns_true_when_file_found(self, s3_storage, mock_boto3_client): @pytest.mark.unit
def test_file_exists_returns_true_when_file_found(
self, s3_storage, mock_boto3_client
):
"""Should return True when head_object succeeds.""" """Should return True when head_object succeeds."""
path = "documents/test.txt" path = "documents/test.txt"
mock_boto3_client.head_object.return_value = {"ContentLength": 100} mock_boto3_client.head_object.return_value = {"ContentLength": 100}
@@ -126,16 +131,17 @@ class TestS3StorageFileExists:
assert result is True assert result is True
mock_boto3_client.head_object.assert_called_once_with( mock_boto3_client.head_object.assert_called_once_with(
Bucket="test-bucket", Bucket="test-bucket", Key=path
Key=path
) )
def test_file_exists_returns_false_on_client_error(self, s3_storage, mock_boto3_client): @pytest.mark.unit
def test_file_exists_returns_false_on_client_error(
self, s3_storage, mock_boto3_client
):
"""Should return False when head_object raises ClientError.""" """Should return False when head_object raises ClientError."""
path = "documents/nonexistent.txt" path = "documents/nonexistent.txt"
mock_boto3_client.head_object.side_effect = ClientError( mock_boto3_client.head_object.side_effect = ClientError(
{"Error": {"Code": "NoSuchKey", "Message": "Not found"}}, {"Error": {"Code": "NoSuchKey", "Message": "Not found"}}, "head_object"
"head_object"
) )
result = s3_storage.file_exists(path) result = s3_storage.file_exists(path)
@@ -146,7 +152,10 @@ class TestS3StorageFileExists:
class TestS3StorageGetFile: class TestS3StorageGetFile:
"""Test file retrieval functionality.""" """Test file retrieval functionality."""
def test_get_file_downloads_and_returns_file_object(self, s3_storage, mock_boto3_client): @pytest.mark.unit
def test_get_file_downloads_and_returns_file_object(
self, s3_storage, mock_boto3_client
):
"""Should download file from S3 and return BytesIO object.""" """Should download file from S3 and return BytesIO object."""
path = "documents/test.txt" path = "documents/test.txt"
test_content = b"file content" test_content = b"file content"
@@ -164,12 +173,14 @@ class TestS3StorageGetFile:
assert result.read() == test_content assert result.read() == test_content
mock_boto3_client.download_fileobj.assert_called_once() mock_boto3_client.download_fileobj.assert_called_once()
def test_get_file_raises_error_when_file_not_found(self, s3_storage, mock_boto3_client): @pytest.mark.unit
def test_get_file_raises_error_when_file_not_found(
self, s3_storage, mock_boto3_client
):
"""Should raise FileNotFoundError when file doesn't exist.""" """Should raise FileNotFoundError when file doesn't exist."""
path = "documents/nonexistent.txt" path = "documents/nonexistent.txt"
mock_boto3_client.head_object.side_effect = ClientError( mock_boto3_client.head_object.side_effect = ClientError(
{"Error": {"Code": "NoSuchKey", "Message": "Not found"}}, {"Error": {"Code": "NoSuchKey", "Message": "Not found"}}, "head_object"
"head_object"
) )
with pytest.raises(FileNotFoundError, match="File not found"): with pytest.raises(FileNotFoundError, match="File not found"):
@@ -179,6 +190,7 @@ class TestS3StorageGetFile:
class TestS3StorageDeleteFile: class TestS3StorageDeleteFile:
"""Test file deletion functionality.""" """Test file deletion functionality."""
@pytest.mark.unit
def test_delete_file_returns_true_on_success(self, s3_storage, mock_boto3_client): def test_delete_file_returns_true_on_success(self, s3_storage, mock_boto3_client):
"""Should return True when deletion succeeds.""" """Should return True when deletion succeeds."""
path = "documents/test.txt" path = "documents/test.txt"
@@ -188,16 +200,18 @@ class TestS3StorageDeleteFile:
assert result is True assert result is True
mock_boto3_client.delete_object.assert_called_once_with( mock_boto3_client.delete_object.assert_called_once_with(
Bucket="test-bucket", Bucket="test-bucket", Key=path
Key=path
) )
def test_delete_file_returns_false_on_client_error(self, s3_storage, mock_boto3_client): @pytest.mark.unit
def test_delete_file_returns_false_on_client_error(
self, s3_storage, mock_boto3_client
):
"""Should return False when deletion fails with ClientError.""" """Should return False when deletion fails with ClientError."""
path = "documents/test.txt" path = "documents/test.txt"
mock_boto3_client.delete_object.side_effect = ClientError( mock_boto3_client.delete_object.side_effect = ClientError(
{"Error": {"Code": "AccessDenied", "Message": "Access denied"}}, {"Error": {"Code": "AccessDenied", "Message": "Access denied"}},
"delete_object" "delete_object",
) )
result = s3_storage.delete_file(path) result = s3_storage.delete_file(path)
@@ -208,7 +222,10 @@ class TestS3StorageDeleteFile:
class TestS3StorageListFiles: class TestS3StorageListFiles:
"""Test directory listing functionality.""" """Test directory listing functionality."""
def test_list_files_returns_all_keys_with_prefix(self, s3_storage, mock_boto3_client): @pytest.mark.unit
def test_list_files_returns_all_keys_with_prefix(
self, s3_storage, mock_boto3_client
):
"""Should return all file keys matching the directory prefix.""" """Should return all file keys matching the directory prefix."""
directory = "documents/" directory = "documents/"
@@ -219,7 +236,7 @@ class TestS3StorageListFiles:
"Contents": [ "Contents": [
{"Key": "documents/file1.txt"}, {"Key": "documents/file1.txt"},
{"Key": "documents/file2.txt"}, {"Key": "documents/file2.txt"},
{"Key": "documents/subdir/file3.txt"} {"Key": "documents/subdir/file3.txt"},
] ]
} }
] ]
@@ -231,13 +248,15 @@ class TestS3StorageListFiles:
assert "documents/file2.txt" in result assert "documents/file2.txt" in result
assert "documents/subdir/file3.txt" in result assert "documents/subdir/file3.txt" in result
mock_boto3_client.get_paginator.assert_called_once_with('list_objects_v2') mock_boto3_client.get_paginator.assert_called_once_with("list_objects_v2")
paginator_mock.paginate.assert_called_once_with( paginator_mock.paginate.assert_called_once_with(
Bucket="test-bucket", Bucket="test-bucket", Prefix="documents/"
Prefix="documents/"
) )
def test_list_files_returns_empty_list_when_no_contents(self, s3_storage, mock_boto3_client): @pytest.mark.unit
def test_list_files_returns_empty_list_when_no_contents(
self, s3_storage, mock_boto3_client
):
"""Should return empty list when directory has no files.""" """Should return empty list when directory has no files."""
directory = "empty/" directory = "empty/"
@@ -253,30 +272,36 @@ class TestS3StorageListFiles:
class TestS3StorageProcessFile: class TestS3StorageProcessFile:
"""Test file processing functionality.""" """Test file processing functionality."""
def test_process_file_downloads_and_processes_file(self, s3_storage, mock_boto3_client): @pytest.mark.unit
def test_process_file_downloads_and_processes_file(
self, s3_storage, mock_boto3_client
):
"""Should download file to temp location and call processor function.""" """Should download file to temp location and call processor function."""
path = "documents/test.txt" path = "documents/test.txt"
mock_boto3_client.head_object.return_value = {} mock_boto3_client.head_object.return_value = {}
with patch('tempfile.NamedTemporaryFile') as mock_temp: with patch("tempfile.NamedTemporaryFile") as mock_temp:
mock_file = MagicMock() mock_file = MagicMock()
mock_file.name = "/tmp/test_file" mock_file.name = "/tmp/test_file"
mock_temp.return_value.__enter__.return_value = mock_file mock_temp.return_value.__enter__.return_value = mock_file
processor_func = MagicMock(return_value="processed") processor_func = MagicMock(return_value="processed")
result = s3_storage.process_file(path, processor_func, extra_arg="value") result = s3_storage.process_file(path, processor_func, extra_arg="value")
assert result == "processed" assert result == "processed"
processor_func.assert_called_once_with(local_path="/tmp/test_file", extra_arg="value") processor_func.assert_called_once_with(
local_path="/tmp/test_file", extra_arg="value"
)
mock_boto3_client.download_fileobj.assert_called_once() mock_boto3_client.download_fileobj.assert_called_once()
def test_process_file_raises_error_when_file_not_found(self, s3_storage, mock_boto3_client): @pytest.mark.unit
def test_process_file_raises_error_when_file_not_found(
self, s3_storage, mock_boto3_client
):
"""Should raise FileNotFoundError when file doesn't exist.""" """Should raise FileNotFoundError when file doesn't exist."""
path = "documents/nonexistent.txt" path = "documents/nonexistent.txt"
mock_boto3_client.head_object.side_effect = ClientError( mock_boto3_client.head_object.side_effect = ClientError(
{"Error": {"Code": "NoSuchKey", "Message": "Not found"}}, {"Error": {"Code": "NoSuchKey", "Message": "Not found"}}, "head_object"
"head_object"
) )
processor_func = MagicMock() processor_func = MagicMock()
@@ -288,7 +313,10 @@ class TestS3StorageProcessFile:
class TestS3StorageIsDirectory: class TestS3StorageIsDirectory:
"""Test directory checking functionality.""" """Test directory checking functionality."""
def test_is_directory_returns_true_when_objects_exist(self, s3_storage, mock_boto3_client): @pytest.mark.unit
def test_is_directory_returns_true_when_objects_exist(
self, s3_storage, mock_boto3_client
):
"""Should return True when objects exist with the directory prefix.""" """Should return True when objects exist with the directory prefix."""
path = "documents/" path = "documents/"
@@ -300,12 +328,13 @@ class TestS3StorageIsDirectory:
assert result is True assert result is True
mock_boto3_client.list_objects_v2.assert_called_once_with( mock_boto3_client.list_objects_v2.assert_called_once_with(
Bucket="test-bucket", Bucket="test-bucket", Prefix="documents/", MaxKeys=1
Prefix="documents/",
MaxKeys=1
) )
def test_is_directory_returns_false_when_no_objects_exist(self, s3_storage, mock_boto3_client): @pytest.mark.unit
def test_is_directory_returns_false_when_no_objects_exist(
self, s3_storage, mock_boto3_client
):
"""Should return False when no objects exist with the directory prefix.""" """Should return False when no objects exist with the directory prefix."""
path = "nonexistent/" path = "nonexistent/"
@@ -319,6 +348,7 @@ class TestS3StorageIsDirectory:
class TestS3StorageRemoveDirectory: class TestS3StorageRemoveDirectory:
"""Test directory removal functionality.""" """Test directory removal functionality."""
@pytest.mark.unit
def test_remove_directory_deletes_all_objects(self, s3_storage, mock_boto3_client): def test_remove_directory_deletes_all_objects(self, s3_storage, mock_boto3_client):
"""Should delete all objects with the directory prefix.""" """Should delete all objects with the directory prefix."""
directory = "documents/" directory = "documents/"
@@ -329,16 +359,13 @@ class TestS3StorageRemoveDirectory:
{ {
"Contents": [ "Contents": [
{"Key": "documents/file1.txt"}, {"Key": "documents/file1.txt"},
{"Key": "documents/file2.txt"} {"Key": "documents/file2.txt"},
] ]
} }
] ]
mock_boto3_client.delete_objects.return_value = { mock_boto3_client.delete_objects.return_value = {
"Deleted": [ "Deleted": [{"Key": "documents/file1.txt"}, {"Key": "documents/file2.txt"}]
{"Key": "documents/file1.txt"},
{"Key": "documents/file2.txt"}
]
} }
result = s3_storage.remove_directory(directory) result = s3_storage.remove_directory(directory)
@@ -349,7 +376,10 @@ class TestS3StorageRemoveDirectory:
assert call_args["Bucket"] == "test-bucket" assert call_args["Bucket"] == "test-bucket"
assert len(call_args["Delete"]["Objects"]) == 2 assert len(call_args["Delete"]["Objects"]) == 2
def test_remove_directory_returns_false_when_empty(self, s3_storage, mock_boto3_client): @pytest.mark.unit
def test_remove_directory_returns_false_when_empty(
self, s3_storage, mock_boto3_client
):
"""Should return False when directory is empty (no objects to delete).""" """Should return False when directory is empty (no objects to delete)."""
directory = "empty/" directory = "empty/"
@@ -362,7 +392,10 @@ class TestS3StorageRemoveDirectory:
assert result is False assert result is False
mock_boto3_client.delete_objects.assert_not_called() mock_boto3_client.delete_objects.assert_not_called()
def test_remove_directory_returns_false_on_client_error(self, s3_storage, mock_boto3_client): @pytest.mark.unit
def test_remove_directory_returns_false_on_client_error(
self, s3_storage, mock_boto3_client
):
"""Should return False when deletion fails with ClientError.""" """Should return False when deletion fails with ClientError."""
directory = "documents/" directory = "documents/"
@@ -374,7 +407,7 @@ class TestS3StorageRemoveDirectory:
mock_boto3_client.delete_objects.side_effect = ClientError( mock_boto3_client.delete_objects.side_effect = ClientError(
{"Error": {"Code": "AccessDenied", "Message": "Access denied"}}, {"Error": {"Code": "AccessDenied", "Message": "Access denied"}},
"delete_objects" "delete_objects",
) )
result = s3_storage.remove_directory(directory) result = s3_storage.remove_directory(directory)

View File

@@ -1,12 +1,12 @@
from flask import Flask import pytest
from application.api.answer import answer from application.api.answer import answer
from application.api.internal.routes import internal from application.api.internal.routes import internal
from application.api.user.routes import user from application.api.user.routes import user
from application.core.settings import settings from application.core.settings import settings
from flask import Flask
@pytest.mark.unit
def test_app_config(): def test_app_config():
app = Flask(__name__) app = Flask(__name__)
app.register_blueprint(user) app.register_blueprint(user)

View File

@@ -1,20 +1,20 @@
import unittest
import json import json
from unittest.mock import patch, MagicMock from unittest.mock import MagicMock, patch
from application.cache import gen_cache_key, stream_cache, gen_cache
import pytest
from application.cache import gen_cache, gen_cache_key, stream_cache
from application.utils import get_hash from application.utils import get_hash
# Test for gen_cache_key function @pytest.mark.unit
def test_make_gen_cache_key(): def test_make_gen_cache_key():
messages = [ messages = [
{'role': 'user', 'content': 'test_user_message'}, {"role": "user", "content": "test_user_message"},
{'role': 'system', 'content': 'test_system_message'}, {"role": "system", "content": "test_system_message"},
] ]
model = "test_docgpt" model = "test_docgpt"
tools = None tools = None
# Manually calculate the expected hash
messages_str = json.dumps(messages) messages_str = json.dumps(messages)
tools_str = json.dumps(tools) if tools else "" tools_str = json.dumps(tools) if tools else ""
expected_combined = f"{model}_{messages_str}_{tools_str}" expected_combined = f"{model}_{messages_str}_{tools_str}"
@@ -23,112 +23,100 @@ def test_make_gen_cache_key():
assert cache_key == expected_hash assert cache_key == expected_hash
def test_gen_cache_key_invalid_message_format():
# Test when messages is not a list
with unittest.TestCase.assertRaises(unittest.TestCase, ValueError) as context:
gen_cache_key("This is not a list", model="docgpt", tools=None)
assert str(context.exception) == "All messages must be dictionaries."
# Test for gen_cache decorator @pytest.mark.unit
@patch('application.cache.get_redis_instance') # Mock the Redis client def test_gen_cache_key_invalid_message_format():
with pytest.raises(ValueError, match="All messages must be dictionaries."):
gen_cache_key("This is not a list", model="docgpt", tools=None)
@pytest.mark.unit
@patch("application.cache.get_redis_instance")
def test_gen_cache_hit(mock_make_redis): def test_gen_cache_hit(mock_make_redis):
# Arrange
mock_redis_instance = MagicMock() mock_redis_instance = MagicMock()
mock_make_redis.return_value = mock_redis_instance mock_make_redis.return_value = mock_redis_instance
mock_redis_instance.get.return_value = b"cached_result" # Simulate a cache hit mock_redis_instance.get.return_value = b"cached_result"
@gen_cache @gen_cache
def mock_function(self, model, messages, stream, tools): def mock_function(self, model, messages, stream, tools):
return "new_result" return "new_result"
messages = [{'role': 'user', 'content': 'test_user_message'}] messages = [{"role": "user", "content": "test_user_message"}]
model = "test_docgpt" model = "test_docgpt"
# Act
result = mock_function(None, model, messages, stream=False, tools=None) result = mock_function(None, model, messages, stream=False, tools=None)
# Assert assert result == "cached_result"
assert result == "cached_result" # Should return cached result mock_redis_instance.get.assert_called_once()
mock_redis_instance.get.assert_called_once() # Ensure Redis get was called mock_redis_instance.set.assert_not_called()
mock_redis_instance.set.assert_not_called() # Ensure the function result is not cached again
@patch('application.cache.get_redis_instance') # Mock the Redis client @pytest.mark.unit
@patch("application.cache.get_redis_instance")
def test_gen_cache_miss(mock_make_redis): def test_gen_cache_miss(mock_make_redis):
# Arrange
mock_redis_instance = MagicMock() mock_redis_instance = MagicMock()
mock_make_redis.return_value = mock_redis_instance mock_make_redis.return_value = mock_redis_instance
mock_redis_instance.get.return_value = None # Simulate a cache miss mock_redis_instance.get.return_value = None
@gen_cache @gen_cache
def mock_function(self, model, messages, steam, tools): def mock_function(self, model, messages, steam, tools):
return "new_result" return "new_result"
messages = [ messages = [
{'role': 'user', 'content': 'test_user_message'}, {"role": "user", "content": "test_user_message"},
{'role': 'system', 'content': 'test_system_message'}, {"role": "system", "content": "test_system_message"},
] ]
model = "test_docgpt" model = "test_docgpt"
# Act
result = mock_function(None, model, messages, stream=False, tools=None) result = mock_function(None, model, messages, stream=False, tools=None)
# Assert
assert result == "new_result" assert result == "new_result"
mock_redis_instance.get.assert_called_once() mock_redis_instance.get.assert_called_once()
@patch('application.cache.get_redis_instance')
@pytest.mark.unit
@patch("application.cache.get_redis_instance")
def test_stream_cache_hit(mock_make_redis): def test_stream_cache_hit(mock_make_redis):
# Arrange
mock_redis_instance = MagicMock() mock_redis_instance = MagicMock()
mock_make_redis.return_value = mock_redis_instance mock_make_redis.return_value = mock_redis_instance
cached_chunk = json.dumps(["chunk1", "chunk2"]).encode('utf-8') cached_chunk = json.dumps(["chunk1", "chunk2"]).encode("utf-8")
mock_redis_instance.get.return_value = cached_chunk mock_redis_instance.get.return_value = cached_chunk
@stream_cache @stream_cache
def mock_function(self, model, messages, stream, tools): def mock_function(self, model, messages, stream, tools):
yield "new_chunk" yield "new_chunk"
messages = [{'role': 'user', 'content': 'test_user_message'}] messages = [{"role": "user", "content": "test_user_message"}]
model = "test_docgpt" model = "test_docgpt"
# Act
result = list(mock_function(None, model, messages, stream=True, tools=None)) result = list(mock_function(None, model, messages, stream=True, tools=None))
# Assert assert result == ["chunk1", "chunk2"]
assert result == ["chunk1", "chunk2"] # Should return cached chunks
mock_redis_instance.get.assert_called_once() mock_redis_instance.get.assert_called_once()
mock_redis_instance.set.assert_not_called() mock_redis_instance.set.assert_not_called()
@patch('application.cache.get_redis_instance') @pytest.mark.unit
@patch("application.cache.get_redis_instance")
def test_stream_cache_miss(mock_make_redis): def test_stream_cache_miss(mock_make_redis):
# Arrange
mock_redis_instance = MagicMock() mock_redis_instance = MagicMock()
mock_make_redis.return_value = mock_redis_instance mock_make_redis.return_value = mock_redis_instance
mock_redis_instance.get.return_value = None # Simulate a cache miss mock_redis_instance.get.return_value = None
@stream_cache @stream_cache
def mock_function(self, model, messages, stream, tools): def mock_function(self, model, messages, stream, tools):
yield "new_chunk" yield "new_chunk"
messages = [ messages = [
{'role': 'user', 'content': 'This is the context'}, {"role": "user", "content": "This is the context"},
{'role': 'system', 'content': 'Some other message'}, {"role": "system", "content": "Some other message"},
{'role': 'user', 'content': 'What is the answer?'} {"role": "user", "content": "What is the answer?"},
] ]
model = "test_docgpt" model = "test_docgpt"
# Act
result = list(mock_function(None, model, messages, stream=True, tools=None)) result = list(mock_function(None, model, messages, stream=True, tools=None))
# Assert
assert result == ["new_chunk"] assert result == ["new_chunk"]
mock_redis_instance.get.assert_called_once() mock_redis_instance.get.assert_called_once()
mock_redis_instance.set.assert_called_once() mock_redis_instance.set.assert_called_once()

View File

@@ -1,21 +1,21 @@
from unittest.mock import patch from unittest.mock import patch
from application.core.settings import settings
import pytest
from application.celery_init import make_celery from application.celery_init import make_celery
from application.core.settings import settings
@patch('application.celery_init.Celery') @pytest.mark.unit
@patch("application.celery_init.Celery")
def test_make_celery(mock_celery): def test_make_celery(mock_celery):
# Arrange app_name = "test_app_name"
app_name = 'test_app_name'
# Act
celery = make_celery(app_name) celery = make_celery(app_name)
# Assert
mock_celery.assert_called_once_with( mock_celery.assert_called_once_with(
app_name, app_name,
broker=settings.CELERY_BROKER_URL, broker=settings.CELERY_BROKER_URL,
backend=settings.CELERY_RESULT_BACKEND backend=settings.CELERY_RESULT_BACKEND,
) )
celery.conf.update.assert_called_once_with(settings) celery.conf.update.assert_called_once_with(settings)
assert celery == mock_celery.return_value assert celery == mock_celery.return_value

View File

@@ -1,6 +1,6 @@
import pytest import pytest
from flask import Flask
from application.error import bad_request, response_error from application.error import bad_request, response_error
from flask import Flask
@pytest.fixture @pytest.fixture
@@ -9,31 +9,35 @@ def app():
return app return app
@pytest.mark.unit
def test_bad_request_with_message(app): def test_bad_request_with_message(app):
with app.app_context(): with app.app_context():
message = "Invalid input" message = "Invalid input"
response = bad_request(status_code=400, message=message) response = bad_request(status_code=400, message=message)
assert response.status_code == 400 assert response.status_code == 400
assert response.json == {'error': 'Bad Request', 'message': message} assert response.json == {"error": "Bad Request", "message": message}
@pytest.mark.unit
def test_bad_request_without_message(app): def test_bad_request_without_message(app):
with app.app_context(): with app.app_context():
response = bad_request(status_code=400) response = bad_request(status_code=400)
assert response.status_code == 400 assert response.status_code == 400
assert response.json == {'error': 'Bad Request'} assert response.json == {"error": "Bad Request"}
@pytest.mark.unit
def test_response_error_with_message(app): def test_response_error_with_message(app):
with app.app_context(): with app.app_context():
message = "Something went wrong" message = "Something went wrong"
response = response_error(code_status=500, message=message) response = response_error(code_status=500, message=message)
assert response.status_code == 500 assert response.status_code == 500
assert response.json == {'error': 'Internal Server Error', 'message': message} assert response.json == {"error": "Internal Server Error", "message": message}
@pytest.mark.unit
def test_response_error_without_message(app): def test_response_error_without_message(app):
with app.app_context(): with app.app_context():
response = response_error(code_status=500) response = response_error(code_status=500)
assert response.status_code == 500 assert response.status_code == 500
assert response.json == {'error': 'Internal Server Error'} assert response.json == {"error": "Internal Server Error"}

View File

@@ -17,6 +17,7 @@ def memory_tool(monkeypatch) -> MemoryTool:
path = doc.get("path") path = doc.get("path")
key = f"{user_id}:{tool_id}:{path}" key = f"{user_id}:{tool_id}:{path}"
# Add _id to document if not present # Add _id to document if not present
if "_id" not in doc: if "_id" not in doc:
doc["_id"] = key doc["_id"] = key
self.docs[key] = doc self.docs[key] = doc
@@ -24,16 +25,17 @@ def memory_tool(monkeypatch) -> MemoryTool:
def update_one(self, q, u, upsert=False): def update_one(self, q, u, upsert=False):
# Handle query by _id # Handle query by _id
if "_id" in q: if "_id" in q:
doc_id = q["_id"] doc_id = q["_id"]
if doc_id not in self.docs: if doc_id not in self.docs:
return type("res", (), {"modified_count": 0}) return type("res", (), {"modified_count": 0})
if "$set" in u: if "$set" in u:
old_doc = self.docs[doc_id].copy() old_doc = self.docs[doc_id].copy()
old_doc.update(u["$set"]) old_doc.update(u["$set"])
# If path changed, update the dictionary key # If path changed, update the dictionary key
if "path" in u["$set"]: if "path" in u["$set"]:
new_path = u["$set"]["path"] new_path = u["$set"]["path"]
user_id = old_doc.get("user_id") user_id = old_doc.get("user_id")
@@ -41,15 +43,15 @@ def memory_tool(monkeypatch) -> MemoryTool:
new_key = f"{user_id}:{tool_id}:{new_path}" new_key = f"{user_id}:{tool_id}:{new_path}"
# Remove old key and add with new key # Remove old key and add with new key
del self.docs[doc_id] del self.docs[doc_id]
old_doc["_id"] = new_key old_doc["_id"] = new_key
self.docs[new_key] = old_doc self.docs[new_key] = old_doc
else: else:
self.docs[doc_id] = old_doc self.docs[doc_id] = old_doc
return type("res", (), {"modified_count": 1}) return type("res", (), {"modified_count": 1})
# Handle query by user_id, tool_id, path # Handle query by user_id, tool_id, path
user_id = q.get("user_id") user_id = q.get("user_id")
tool_id = q.get("tool_id") tool_id = q.get("tool_id")
path = q.get("path") path = q.get("path")
@@ -57,13 +59,16 @@ def memory_tool(monkeypatch) -> MemoryTool:
if key not in self.docs and not upsert: if key not in self.docs and not upsert:
return type("res", (), {"modified_count": 0}) return type("res", (), {"modified_count": 0})
if key not in self.docs and upsert: if key not in self.docs and upsert:
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key} self.docs[key] = {
"user_id": user_id,
"tool_id": tool_id,
"path": path,
"content": "",
"_id": key,
}
if "$set" in u: if "$set" in u:
self.docs[key].update(u["$set"]) self.docs[key].update(u["$set"])
return type("res", (), {"modified_count": 1}) return type("res", (), {"modified_count": 1})
def find_one(self, q, projection=None): def find_one(self, q, projection=None):
@@ -74,7 +79,6 @@ def memory_tool(monkeypatch) -> MemoryTool:
if path: if path:
key = f"{user_id}:{tool_id}:{path}" key = f"{user_id}:{tool_id}:{path}"
return self.docs.get(key) return self.docs.get(key)
return None return None
def find(self, q, projection=None): def find(self, q, projection=None):
@@ -83,9 +87,11 @@ def memory_tool(monkeypatch) -> MemoryTool:
results = [] results = []
# Handle regex queries for directory listing # Handle regex queries for directory listing
if "path" in q and isinstance(q["path"], dict) and "$regex" in q["path"]: if "path" in q and isinstance(q["path"], dict) and "$regex" in q["path"]:
regex_pattern = q["path"]["$regex"] regex_pattern = q["path"]["$regex"]
# Remove regex escape characters and ^ anchor for simple matching # Remove regex escape characters and ^ anchor for simple matching
pattern = regex_pattern.replace("\\", "").lstrip("^") pattern = regex_pattern.replace("\\", "").lstrip("^")
for key, doc in self.docs.items(): for key, doc in self.docs.items():
@@ -97,7 +103,6 @@ def memory_tool(monkeypatch) -> MemoryTool:
for key, doc in self.docs.items(): for key, doc in self.docs.items():
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id: if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id:
results.append(doc) results.append(doc)
return results return results
def delete_one(self, q): def delete_one(self, q):
@@ -109,7 +114,6 @@ def memory_tool(monkeypatch) -> MemoryTool:
if key in self.docs: if key in self.docs:
del self.docs[key] del self.docs[key]
return type("res", (), {"deleted_count": 1}) return type("res", (), {"deleted_count": 1})
return type("res", (), {"deleted_count": 0}) return type("res", (), {"deleted_count": 0})
def delete_many(self, q): def delete_many(self, q):
@@ -118,6 +122,7 @@ def memory_tool(monkeypatch) -> MemoryTool:
deleted = 0 deleted = 0
# Handle regex queries for directory deletion # Handle regex queries for directory deletion
if "path" in q and isinstance(q["path"], dict) and "$regex" in q["path"]: if "path" in q and isinstance(q["path"], dict) and "$regex" in q["path"]:
regex_pattern = q["path"]["$regex"] regex_pattern = q["path"]["$regex"]
pattern = regex_pattern.replace("\\", "").lstrip("^") pattern = regex_pattern.replace("\\", "").lstrip("^")
@@ -128,32 +133,36 @@ def memory_tool(monkeypatch) -> MemoryTool:
doc_path = doc.get("path", "") doc_path = doc.get("path", "")
if doc_path.startswith(pattern): if doc_path.startswith(pattern):
keys_to_delete.append(key) keys_to_delete.append(key)
for key in keys_to_delete: for key in keys_to_delete:
del self.docs[key] del self.docs[key]
deleted += 1 deleted += 1
else: else:
# Delete all for user and tool # Delete all for user and tool
keys_to_delete = [ keys_to_delete = [
key for key, doc in self.docs.items() key
for key, doc in self.docs.items()
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id
] ]
for key in keys_to_delete: for key in keys_to_delete:
del self.docs[key] del self.docs[key]
deleted += 1 deleted += 1
return type("res", (), {"deleted_count": deleted}) return type("res", (), {"deleted_count": deleted})
fake_collection = FakeCollection() fake_collection = FakeCollection()
fake_db = {"memories": fake_collection} fake_db = {"memories": fake_collection}
fake_client = {settings.MONGO_DB_NAME: fake_db} fake_client = {settings.MONGO_DB_NAME: fake_db}
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client) monkeypatch.setattr(
"application.core.mongo_db.MongoDB.get_client", lambda: fake_client
)
# Return tool with a fixed tool_id for consistency in tests # Return tool with a fixed tool_id for consistency in tests
return MemoryTool({"tool_id": "test_tool_id"}, user_id="test_user") return MemoryTool({"tool_id": "test_tool_id"}, user_id="test_user")
@pytest.mark.unit
def test_init_without_user_id(): def test_init_without_user_id():
"""Should fail gracefully if no user_id is provided.""" """Should fail gracefully if no user_id is provided."""
memory_tool = MemoryTool(tool_config={}) memory_tool = MemoryTool(tool_config={})
@@ -161,90 +170,78 @@ def test_init_without_user_id():
assert "user_id" in result.lower() assert "user_id" in result.lower()
@pytest.mark.unit
def test_view_empty_directory(memory_tool: MemoryTool) -> None: def test_view_empty_directory(memory_tool: MemoryTool) -> None:
"""Should show empty directory when no files exist.""" """Should show empty directory when no files exist."""
result = memory_tool.execute_action("view", path="/") result = memory_tool.execute_action("view", path="/")
assert "empty" in result.lower() assert "empty" in result.lower()
@pytest.mark.unit
def test_create_and_view_file(memory_tool: MemoryTool) -> None: def test_create_and_view_file(memory_tool: MemoryTool) -> None:
"""Test creating a file and viewing it.""" """Test creating a file and viewing it."""
# Create a file # Create a file
result = memory_tool.execute_action( result = memory_tool.execute_action(
"create", "create", path="/notes.txt", file_text="Hello world"
path="/notes.txt",
file_text="Hello world"
) )
assert "created" in result.lower() assert "created" in result.lower()
# View the file # View the file
result = memory_tool.execute_action("view", path="/notes.txt") result = memory_tool.execute_action("view", path="/notes.txt")
assert "Hello world" in result assert "Hello world" in result
@pytest.mark.unit
def test_create_overwrite_file(memory_tool: MemoryTool) -> None: def test_create_overwrite_file(memory_tool: MemoryTool) -> None:
"""Test that create overwrites existing files.""" """Test that create overwrites existing files."""
# Create initial file # Create initial file
memory_tool.execute_action(
"create", memory_tool.execute_action("create", path="/test.txt", file_text="Original content")
path="/test.txt",
file_text="Original content"
)
# Overwrite # Overwrite
memory_tool.execute_action(
"create", memory_tool.execute_action("create", path="/test.txt", file_text="New content")
path="/test.txt",
file_text="New content"
)
# Verify overwrite # Verify overwrite
result = memory_tool.execute_action("view", path="/test.txt") result = memory_tool.execute_action("view", path="/test.txt")
assert "New content" in result assert "New content" in result
assert "Original content" not in result assert "Original content" not in result
@pytest.mark.unit
def test_view_directory_with_files(memory_tool: MemoryTool) -> None: def test_view_directory_with_files(memory_tool: MemoryTool) -> None:
"""Test viewing directory contents.""" """Test viewing directory contents."""
# Create multiple files # Create multiple files
memory_tool.execute_action("create", path="/file1.txt", file_text="Content 1")
memory_tool.execute_action("create", path="/file2.txt", file_text="Content 2")
memory_tool.execute_action( memory_tool.execute_action(
"create", "create", path="/subdir/file3.txt", file_text="Content 3"
path="/file1.txt",
file_text="Content 1"
)
memory_tool.execute_action(
"create",
path="/file2.txt",
file_text="Content 2"
)
memory_tool.execute_action(
"create",
path="/subdir/file3.txt",
file_text="Content 3"
) )
# View directory # View directory
result = memory_tool.execute_action("view", path="/") result = memory_tool.execute_action("view", path="/")
assert "file1.txt" in result assert "file1.txt" in result
assert "file2.txt" in result assert "file2.txt" in result
assert "subdir/file3.txt" in result assert "subdir/file3.txt" in result
@pytest.mark.unit
def test_view_file_with_line_range(memory_tool: MemoryTool) -> None: def test_view_file_with_line_range(memory_tool: MemoryTool) -> None:
"""Test viewing specific lines from a file.""" """Test viewing specific lines from a file."""
# Create a multiline file # Create a multiline file
content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
memory_tool.execute_action( memory_tool.execute_action("create", path="/multiline.txt", file_text=content)
"create",
path="/multiline.txt",
file_text=content
)
# View lines 2-4 # View lines 2-4
result = memory_tool.execute_action( result = memory_tool.execute_action(
"view", "view", path="/multiline.txt", view_range=[2, 4]
path="/multiline.txt",
view_range=[2, 4]
) )
assert "Line 2" in result assert "Line 2" in result
assert "Line 3" in result assert "Line 3" in result
@@ -253,197 +250,177 @@ def test_view_file_with_line_range(memory_tool: MemoryTool) -> None:
assert "Line 5" not in result assert "Line 5" not in result
@pytest.mark.unit
def test_str_replace(memory_tool: MemoryTool) -> None: def test_str_replace(memory_tool: MemoryTool) -> None:
"""Test string replacement in a file.""" """Test string replacement in a file."""
# Create a file # Create a file
memory_tool.execute_action( memory_tool.execute_action(
"create", "create", path="/replace.txt", file_text="Hello world, hello universe"
path="/replace.txt",
file_text="Hello world, hello universe"
) )
# Replace text # Replace text
result = memory_tool.execute_action( result = memory_tool.execute_action(
"str_replace", "str_replace", path="/replace.txt", old_str="hello", new_str="hi"
path="/replace.txt",
old_str="hello",
new_str="hi"
) )
assert "updated" in result.lower() assert "updated" in result.lower()
# Verify replacement # Verify replacement
content = memory_tool.execute_action("view", path="/replace.txt") content = memory_tool.execute_action("view", path="/replace.txt")
assert "hi world, hi universe" in content assert "hi world, hi universe" in content
@pytest.mark.unit
def test_str_replace_not_found(memory_tool: MemoryTool) -> None: def test_str_replace_not_found(memory_tool: MemoryTool) -> None:
"""Test string replacement when string not found.""" """Test string replacement when string not found."""
memory_tool.execute_action( memory_tool.execute_action("create", path="/test.txt", file_text="Hello world")
"create",
path="/test.txt",
file_text="Hello world"
)
result = memory_tool.execute_action( result = memory_tool.execute_action(
"str_replace", "str_replace", path="/test.txt", old_str="goodbye", new_str="hi"
path="/test.txt",
old_str="goodbye",
new_str="hi"
) )
assert "not found" in result.lower() assert "not found" in result.lower()
@pytest.mark.unit
def test_insert_line(memory_tool: MemoryTool) -> None: def test_insert_line(memory_tool: MemoryTool) -> None:
"""Test inserting text at a line number.""" """Test inserting text at a line number."""
# Create a multiline file # Create a multiline file
memory_tool.execute_action( memory_tool.execute_action(
"create", "create", path="/insert.txt", file_text="Line 1\nLine 2\nLine 3"
path="/insert.txt",
file_text="Line 1\nLine 2\nLine 3"
) )
# Insert at line 2 # Insert at line 2
result = memory_tool.execute_action( result = memory_tool.execute_action(
"insert", "insert", path="/insert.txt", insert_line=2, insert_text="Inserted line"
path="/insert.txt",
insert_line=2,
insert_text="Inserted line"
) )
assert "inserted" in result.lower() assert "inserted" in result.lower()
# Verify insertion # Verify insertion
content = memory_tool.execute_action("view", path="/insert.txt") content = memory_tool.execute_action("view", path="/insert.txt")
lines = content.split("\n") lines = content.split("\n")
assert lines[1] == "Inserted line" assert lines[1] == "Inserted line"
assert lines[2] == "Line 2" assert lines[2] == "Line 2"
@pytest.mark.unit
def test_insert_invalid_line(memory_tool: MemoryTool) -> None: def test_insert_invalid_line(memory_tool: MemoryTool) -> None:
"""Test inserting at an invalid line number.""" """Test inserting at an invalid line number."""
memory_tool.execute_action( memory_tool.execute_action("create", path="/test.txt", file_text="Line 1\nLine 2")
"create",
path="/test.txt",
file_text="Line 1\nLine 2"
)
result = memory_tool.execute_action( result = memory_tool.execute_action(
"insert", "insert", path="/test.txt", insert_line=100, insert_text="Text"
path="/test.txt",
insert_line=100,
insert_text="Text"
) )
assert "invalid" in result.lower() assert "invalid" in result.lower()
@pytest.mark.unit
def test_delete_file(memory_tool: MemoryTool) -> None: def test_delete_file(memory_tool: MemoryTool) -> None:
"""Test deleting a file.""" """Test deleting a file."""
# Create a file # Create a file
memory_tool.execute_action(
"create", memory_tool.execute_action("create", path="/delete_me.txt", file_text="Content")
path="/delete_me.txt",
file_text="Content"
)
# Delete it # Delete it
result = memory_tool.execute_action("delete", path="/delete_me.txt") result = memory_tool.execute_action("delete", path="/delete_me.txt")
assert "deleted" in result.lower() assert "deleted" in result.lower()
# Verify it's gone # Verify it's gone
result = memory_tool.execute_action("view", path="/delete_me.txt") result = memory_tool.execute_action("view", path="/delete_me.txt")
assert "not found" in result.lower() assert "not found" in result.lower()
@pytest.mark.unit
def test_delete_nonexistent_file(memory_tool: MemoryTool) -> None: def test_delete_nonexistent_file(memory_tool: MemoryTool) -> None:
"""Test deleting a file that doesn't exist.""" """Test deleting a file that doesn't exist."""
result = memory_tool.execute_action("delete", path="/nonexistent.txt") result = memory_tool.execute_action("delete", path="/nonexistent.txt")
assert "not found" in result.lower() assert "not found" in result.lower()
@pytest.mark.unit
def test_delete_directory(memory_tool: MemoryTool) -> None: def test_delete_directory(memory_tool: MemoryTool) -> None:
"""Test deleting a directory with files.""" """Test deleting a directory with files."""
# Create files in a directory # Create files in a directory
memory_tool.execute_action( memory_tool.execute_action(
"create", "create", path="/subdir/file1.txt", file_text="Content 1"
path="/subdir/file1.txt",
file_text="Content 1"
) )
memory_tool.execute_action( memory_tool.execute_action(
"create", "create", path="/subdir/file2.txt", file_text="Content 2"
path="/subdir/file2.txt",
file_text="Content 2"
) )
# Delete the directory # Delete the directory
result = memory_tool.execute_action("delete", path="/subdir/") result = memory_tool.execute_action("delete", path="/subdir/")
assert "deleted" in result.lower() assert "deleted" in result.lower()
# Verify files are gone # Verify files are gone
result = memory_tool.execute_action("view", path="/subdir/file1.txt") result = memory_tool.execute_action("view", path="/subdir/file1.txt")
assert "not found" in result.lower() assert "not found" in result.lower()
@pytest.mark.unit
def test_rename_file(memory_tool: MemoryTool) -> None: def test_rename_file(memory_tool: MemoryTool) -> None:
"""Test renaming a file.""" """Test renaming a file."""
# Create a file # Create a file
memory_tool.execute_action(
"create", memory_tool.execute_action("create", path="/old_name.txt", file_text="Content")
path="/old_name.txt",
file_text="Content"
)
# Rename it # Rename it
result = memory_tool.execute_action( result = memory_tool.execute_action(
"rename", "rename", old_path="/old_name.txt", new_path="/new_name.txt"
old_path="/old_name.txt",
new_path="/new_name.txt"
) )
assert "renamed" in result.lower() assert "renamed" in result.lower()
# Verify old path doesn't exist # Verify old path doesn't exist
result = memory_tool.execute_action("view", path="/old_name.txt") result = memory_tool.execute_action("view", path="/old_name.txt")
assert "not found" in result.lower() assert "not found" in result.lower()
# Verify new path exists # Verify new path exists
result = memory_tool.execute_action("view", path="/new_name.txt") result = memory_tool.execute_action("view", path="/new_name.txt")
assert "Content" in result assert "Content" in result
@pytest.mark.unit
def test_rename_nonexistent_file(memory_tool: MemoryTool) -> None: def test_rename_nonexistent_file(memory_tool: MemoryTool) -> None:
"""Test renaming a file that doesn't exist.""" """Test renaming a file that doesn't exist."""
result = memory_tool.execute_action( result = memory_tool.execute_action(
"rename", "rename", old_path="/nonexistent.txt", new_path="/new.txt"
old_path="/nonexistent.txt",
new_path="/new.txt"
) )
assert "not found" in result.lower() assert "not found" in result.lower()
@pytest.mark.unit
def test_rename_to_existing_file(memory_tool: MemoryTool) -> None: def test_rename_to_existing_file(memory_tool: MemoryTool) -> None:
"""Test renaming to a path that already exists.""" """Test renaming to a path that already exists."""
# Create two files # Create two files
memory_tool.execute_action(
"create", memory_tool.execute_action("create", path="/file1.txt", file_text="Content 1")
path="/file1.txt", memory_tool.execute_action("create", path="/file2.txt", file_text="Content 2")
file_text="Content 1"
)
memory_tool.execute_action(
"create",
path="/file2.txt",
file_text="Content 2"
)
# Try to rename file1 to file2 # Try to rename file1 to file2
result = memory_tool.execute_action( result = memory_tool.execute_action(
"rename", "rename", old_path="/file1.txt", new_path="/file2.txt"
old_path="/file1.txt",
new_path="/file2.txt"
) )
assert "already exists" in result.lower() assert "already exists" in result.lower()
@pytest.mark.unit
def test_path_traversal_protection(memory_tool: MemoryTool) -> None: def test_path_traversal_protection(memory_tool: MemoryTool) -> None:
"""Test that directory traversal attacks are prevented.""" """Test that directory traversal attacks are prevented."""
# Try various path traversal attempts # Try various path traversal attempts
invalid_paths = [ invalid_paths = [
"/../secrets.txt", "/../secrets.txt",
"/../../etc/passwd", "/../../etc/passwd",
@@ -453,16 +430,16 @@ def test_path_traversal_protection(memory_tool: MemoryTool) -> None:
for path in invalid_paths: for path in invalid_paths:
result = memory_tool.execute_action( result = memory_tool.execute_action(
"create", "create", path=path, file_text="malicious content"
path=path,
file_text="malicious content"
) )
assert "invalid path" in result.lower() assert "invalid path" in result.lower()
@pytest.mark.unit
def test_path_must_start_with_slash(memory_tool: MemoryTool) -> None: def test_path_must_start_with_slash(memory_tool: MemoryTool) -> None:
"""Test that paths work with or without leading slash (auto-normalized).""" """Test that paths work with or without leading slash (auto-normalized)."""
# These paths should all work now (auto-prepended with /) # These paths should all work now (auto-prepended with /)
valid_paths = [ valid_paths = [
"etc/passwd", # Auto-prepended with / "etc/passwd", # Auto-prepended with /
"home/user/file.txt", # Auto-prepended with / "home/user/file.txt", # Auto-prepended with /
@@ -470,33 +447,29 @@ def test_path_must_start_with_slash(memory_tool: MemoryTool) -> None:
] ]
for path in valid_paths: for path in valid_paths:
result = memory_tool.execute_action( result = memory_tool.execute_action("create", path=path, file_text="content")
"create",
path=path,
file_text="content"
)
assert "created" in result.lower() assert "created" in result.lower()
# Verify the file can be accessed with or without leading slash # Verify the file can be accessed with or without leading slash
view_result = memory_tool.execute_action("view", path=path) view_result = memory_tool.execute_action("view", path=path)
assert "content" in view_result assert "content" in view_result
@pytest.mark.unit
def test_cannot_create_directory_as_file(memory_tool: MemoryTool) -> None: def test_cannot_create_directory_as_file(memory_tool: MemoryTool) -> None:
"""Test that you cannot create a file at a directory path.""" """Test that you cannot create a file at a directory path."""
result = memory_tool.execute_action( result = memory_tool.execute_action("create", path="/", file_text="content")
"create",
path="/",
file_text="content"
)
assert "cannot create a file at directory path" in result.lower() assert "cannot create a file at directory path" in result.lower()
@pytest.mark.unit
def test_get_actions_metadata(memory_tool: MemoryTool) -> None: def test_get_actions_metadata(memory_tool: MemoryTool) -> None:
"""Test that action metadata is properly defined.""" """Test that action metadata is properly defined."""
metadata = memory_tool.get_actions_metadata() metadata = memory_tool.get_actions_metadata()
# Check that all expected actions are defined # Check that all expected actions are defined
action_names = [action["name"] for action in metadata] action_names = [action["name"] for action in metadata]
assert "view" in action_names assert "view" in action_names
assert "create" in action_names assert "create" in action_names
@@ -506,15 +479,18 @@ def test_get_actions_metadata(memory_tool: MemoryTool) -> None:
assert "rename" in action_names assert "rename" in action_names
# Check that each action has required fields # Check that each action has required fields
for action in metadata: for action in metadata:
assert "name" in action assert "name" in action
assert "description" in action assert "description" in action
assert "parameters" in action assert "parameters" in action
@pytest.mark.unit
def test_memory_tool_isolation(monkeypatch) -> None: def test_memory_tool_isolation(monkeypatch) -> None:
"""Test that different memory tool instances have isolated memories.""" """Test that different memory tool instances have isolated memories."""
# Create fake collection # Create fake collection
class FakeCollection: class FakeCollection:
def __init__(self) -> None: def __init__(self) -> None:
self.docs = {} self.docs = {}
@@ -529,16 +505,17 @@ def test_memory_tool_isolation(monkeypatch) -> None:
def update_one(self, q, u, upsert=False): def update_one(self, q, u, upsert=False):
# Handle query by _id # Handle query by _id
if "_id" in q: if "_id" in q:
doc_id = q["_id"] doc_id = q["_id"]
if doc_id not in self.docs: if doc_id not in self.docs:
return type("res", (), {"modified_count": 0}) return type("res", (), {"modified_count": 0})
if "$set" in u: if "$set" in u:
old_doc = self.docs[doc_id].copy() old_doc = self.docs[doc_id].copy()
old_doc.update(u["$set"]) old_doc.update(u["$set"])
# If path changed, update the dictionary key # If path changed, update the dictionary key
if "path" in u["$set"]: if "path" in u["$set"]:
new_path = u["$set"]["path"] new_path = u["$set"]["path"]
user_id = old_doc.get("user_id") user_id = old_doc.get("user_id")
@@ -546,15 +523,15 @@ def test_memory_tool_isolation(monkeypatch) -> None:
new_key = f"{user_id}:{tool_id}:{new_path}" new_key = f"{user_id}:{tool_id}:{new_path}"
# Remove old key and add with new key # Remove old key and add with new key
del self.docs[doc_id] del self.docs[doc_id]
old_doc["_id"] = new_key old_doc["_id"] = new_key
self.docs[new_key] = old_doc self.docs[new_key] = old_doc
else: else:
self.docs[doc_id] = old_doc self.docs[doc_id] = old_doc
return type("res", (), {"modified_count": 1}) return type("res", (), {"modified_count": 1})
# Handle query by user_id, tool_id, path # Handle query by user_id, tool_id, path
user_id = q.get("user_id") user_id = q.get("user_id")
tool_id = q.get("tool_id") tool_id = q.get("tool_id")
path = q.get("path") path = q.get("path")
@@ -562,13 +539,16 @@ def test_memory_tool_isolation(monkeypatch) -> None:
if key not in self.docs and not upsert: if key not in self.docs and not upsert:
return type("res", (), {"modified_count": 0}) return type("res", (), {"modified_count": 0})
if key not in self.docs and upsert: if key not in self.docs and upsert:
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key} self.docs[key] = {
"user_id": user_id,
"tool_id": tool_id,
"path": path,
"content": "",
"_id": key,
}
if "$set" in u: if "$set" in u:
self.docs[key].update(u["$set"]) self.docs[key].update(u["$set"])
return type("res", (), {"modified_count": 1}) return type("res", (), {"modified_count": 1})
def find_one(self, q, projection=None): def find_one(self, q, projection=None):
@@ -579,26 +559,31 @@ def test_memory_tool_isolation(monkeypatch) -> None:
if path: if path:
key = f"{user_id}:{tool_id}:{path}" key = f"{user_id}:{tool_id}:{path}"
return self.docs.get(key) return self.docs.get(key)
return None return None
fake_collection = FakeCollection() fake_collection = FakeCollection()
fake_db = {"memories": fake_collection} fake_db = {"memories": fake_collection}
fake_client = {settings.MONGO_DB_NAME: fake_db} fake_client = {settings.MONGO_DB_NAME: fake_db}
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client) monkeypatch.setattr(
"application.core.mongo_db.MongoDB.get_client", lambda: fake_client
)
# Create two memory tools with different tool_ids for the same user # Create two memory tools with different tool_ids for the same user
tool1 = MemoryTool({"tool_id": "tool_1"}, user_id="test_user") tool1 = MemoryTool({"tool_id": "tool_1"}, user_id="test_user")
tool2 = MemoryTool({"tool_id": "tool_2"}, user_id="test_user") tool2 = MemoryTool({"tool_id": "tool_2"}, user_id="test_user")
# Create a file in tool1 # Create a file in tool1
tool1.execute_action("create", path="/file.txt", file_text="Content from tool 1") tool1.execute_action("create", path="/file.txt", file_text="Content from tool 1")
# Create a file with the same path in tool2 # Create a file with the same path in tool2
tool2.execute_action("create", path="/file.txt", file_text="Content from tool 2") tool2.execute_action("create", path="/file.txt", file_text="Content from tool 2")
# Verify that each tool sees only its own content # Verify that each tool sees only its own content
result1 = tool1.execute_action("view", path="/file.txt") result1 = tool1.execute_action("view", path="/file.txt")
result2 = tool2.execute_action("view", path="/file.txt") result2 = tool2.execute_action("view", path="/file.txt")
@@ -609,8 +594,10 @@ def test_memory_tool_isolation(monkeypatch) -> None:
assert "Content from tool 1" not in result2 assert "Content from tool 1" not in result2
@pytest.mark.unit
def test_memory_tool_auto_generates_tool_id(monkeypatch) -> None: def test_memory_tool_auto_generates_tool_id(monkeypatch) -> None:
"""Test that tool_id defaults to 'default_{user_id}' for persistence.""" """Test that tool_id defaults to 'default_{user_id}' for persistence."""
class FakeCollection: class FakeCollection:
def __init__(self) -> None: def __init__(self) -> None:
self.docs = {} self.docs = {}
@@ -622,78 +609,94 @@ def test_memory_tool_auto_generates_tool_id(monkeypatch) -> None:
fake_db = {"memories": fake_collection} fake_db = {"memories": fake_collection}
fake_client = {settings.MONGO_DB_NAME: fake_db} fake_client = {settings.MONGO_DB_NAME: fake_db}
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client) monkeypatch.setattr(
"application.core.mongo_db.MongoDB.get_client", lambda: fake_client
)
# Create two tools without providing tool_id for the same user # Create two tools without providing tool_id for the same user
tool1 = MemoryTool({}, user_id="test_user") tool1 = MemoryTool({}, user_id="test_user")
tool2 = MemoryTool({}, user_id="test_user") tool2 = MemoryTool({}, user_id="test_user")
# Both should have the same default tool_id for persistence # Both should have the same default tool_id for persistence
assert tool1.tool_id == "default_test_user" assert tool1.tool_id == "default_test_user"
assert tool2.tool_id == "default_test_user" assert tool2.tool_id == "default_test_user"
assert tool1.tool_id == tool2.tool_id assert tool1.tool_id == tool2.tool_id
# Different users should have different tool_ids # Different users should have different tool_ids
tool3 = MemoryTool({}, user_id="another_user") tool3 = MemoryTool({}, user_id="another_user")
assert tool3.tool_id == "default_another_user" assert tool3.tool_id == "default_another_user"
assert tool3.tool_id != tool1.tool_id assert tool3.tool_id != tool1.tool_id
@pytest.mark.unit
def test_paths_without_leading_slash(memory_tool) -> None: def test_paths_without_leading_slash(memory_tool) -> None:
"""Test that paths without leading slash work correctly.""" """Test that paths without leading slash work correctly."""
# Create file without leading slash # Create file without leading slash
result = memory_tool.execute_action("create", path="cat_breeds.txt", file_text="- Korat\n- Chartreux\n- British Shorthair\n- Nebelung")
result = memory_tool.execute_action(
"create",
path="cat_breeds.txt",
file_text="- Korat\n- Chartreux\n- British Shorthair\n- Nebelung",
)
assert "created" in result.lower() assert "created" in result.lower()
# View file without leading slash # View file without leading slash
view_result = memory_tool.execute_action("view", path="cat_breeds.txt") view_result = memory_tool.execute_action("view", path="cat_breeds.txt")
assert "Korat" in view_result assert "Korat" in view_result
assert "Chartreux" in view_result assert "Chartreux" in view_result
# View same file with leading slash (should work the same) # View same file with leading slash (should work the same)
view_result2 = memory_tool.execute_action("view", path="/cat_breeds.txt") view_result2 = memory_tool.execute_action("view", path="/cat_breeds.txt")
assert "Korat" in view_result2 assert "Korat" in view_result2
# Test str_replace without leading slash # Test str_replace without leading slash
replace_result = memory_tool.execute_action("str_replace", path="cat_breeds.txt", old_str="Korat", new_str="Maine Coon")
replace_result = memory_tool.execute_action(
"str_replace", path="cat_breeds.txt", old_str="Korat", new_str="Maine Coon"
)
assert "updated" in replace_result.lower() assert "updated" in replace_result.lower()
# Test nested path without leading slash # Test nested path without leading slash
nested_result = memory_tool.execute_action("create", path="projects/tasks.txt", file_text="Task 1\nTask 2")
nested_result = memory_tool.execute_action(
"create", path="projects/tasks.txt", file_text="Task 1\nTask 2"
)
assert "created" in nested_result.lower() assert "created" in nested_result.lower()
view_nested = memory_tool.execute_action("view", path="projects/tasks.txt") view_nested = memory_tool.execute_action("view", path="projects/tasks.txt")
assert "Task 1" in view_nested assert "Task 1" in view_nested
@pytest.mark.unit
def test_rename_directory(memory_tool: MemoryTool) -> None: def test_rename_directory(memory_tool: MemoryTool) -> None:
"""Test renaming a directory with files.""" """Test renaming a directory with files."""
# Create files in a directory # Create files in a directory
memory_tool.execute_action("create", path="/docs/file1.txt", file_text="Content 1")
memory_tool.execute_action( memory_tool.execute_action(
"create", "create", path="/docs/sub/file2.txt", file_text="Content 2"
path="/docs/file1.txt",
file_text="Content 1"
)
memory_tool.execute_action(
"create",
path="/docs/sub/file2.txt",
file_text="Content 2"
) )
# Rename directory (with trailing slash) # Rename directory (with trailing slash)
result = memory_tool.execute_action( result = memory_tool.execute_action(
"rename", "rename", old_path="/docs/", new_path="/archive/"
old_path="/docs/",
new_path="/archive/"
) )
assert "renamed" in result.lower() assert "renamed" in result.lower()
assert "2 files" in result.lower() assert "2 files" in result.lower()
# Verify old paths don't exist # Verify old paths don't exist
result = memory_tool.execute_action("view", path="/docs/file1.txt") result = memory_tool.execute_action("view", path="/docs/file1.txt")
assert "not found" in result.lower() assert "not found" in result.lower()
# Verify new paths exist # Verify new paths exist
result = memory_tool.execute_action("view", path="/archive/file1.txt") result = memory_tool.execute_action("view", path="/archive/file1.txt")
assert "Content 1" in result assert "Content 1" in result
@@ -701,29 +704,25 @@ def test_rename_directory(memory_tool: MemoryTool) -> None:
assert "Content 2" in result assert "Content 2" in result
@pytest.mark.unit
def test_rename_directory_without_trailing_slash(memory_tool: MemoryTool) -> None: def test_rename_directory_without_trailing_slash(memory_tool: MemoryTool) -> None:
"""Test renaming a directory when new path is missing trailing slash.""" """Test renaming a directory when new path is missing trailing slash."""
# Create files in a directory # Create files in a directory
memory_tool.execute_action("create", path="/docs/file1.txt", file_text="Content 1")
memory_tool.execute_action( memory_tool.execute_action(
"create", "create", path="/docs/sub/file2.txt", file_text="Content 2"
path="/docs/file1.txt",
file_text="Content 1"
)
memory_tool.execute_action(
"create",
path="/docs/sub/file2.txt",
file_text="Content 2"
) )
# Rename directory - old path has slash, new path doesn't # Rename directory - old path has slash, new path doesn't
result = memory_tool.execute_action( result = memory_tool.execute_action(
"rename", "rename", old_path="/docs/", new_path="/archive" # Missing trailing slash
old_path="/docs/",
new_path="/archive" # Missing trailing slash
) )
assert "renamed" in result.lower() assert "renamed" in result.lower()
# Verify paths are correct (not corrupted like "/archivesub/file2.txt") # Verify paths are correct (not corrupted like "/archivesub/file2.txt")
result = memory_tool.execute_action("view", path="/archive/file1.txt") result = memory_tool.execute_action("view", path="/archive/file1.txt")
assert "Content 1" in result assert "Content 1" in result
@@ -731,28 +730,25 @@ def test_rename_directory_without_trailing_slash(memory_tool: MemoryTool) -> Non
assert "Content 2" in result assert "Content 2" in result
# Verify corrupted path doesn't exist # Verify corrupted path doesn't exist
result = memory_tool.execute_action("view", path="/archivesub/file2.txt") result = memory_tool.execute_action("view", path="/archivesub/file2.txt")
assert "not found" in result.lower() assert "not found" in result.lower()
@pytest.mark.unit
def test_view_file_line_numbers(memory_tool: MemoryTool) -> None: def test_view_file_line_numbers(memory_tool: MemoryTool) -> None:
"""Test that view_range displays correct line numbers.""" """Test that view_range displays correct line numbers."""
# Create a multiline file # Create a multiline file
content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
memory_tool.execute_action( memory_tool.execute_action("create", path="/numbered.txt", file_text=content)
"create",
path="/numbered.txt",
file_text=content
)
# View lines 2-4 # View lines 2-4
result = memory_tool.execute_action(
"view", result = memory_tool.execute_action("view", path="/numbered.txt", view_range=[2, 4])
path="/numbered.txt",
view_range=[2, 4]
)
# Check that line numbers are correct (should be 2, 3, 4 not 3, 4, 5) # Check that line numbers are correct (should be 2, 3, 4 not 3, 4, 5)
assert "2: Line 2" in result assert "2: Line 2" in result
assert "3: Line 3" in result assert "3: Line 3" in result
assert "4: Line 4" in result assert "4: Line 4" in result
@@ -760,6 +756,7 @@ def test_view_file_line_numbers(memory_tool: MemoryTool) -> None:
assert "5: Line 5" not in result assert "5: Line 5" not in result
# Verify no off-by-one error # Verify no off-by-one error
assert "3: Line 2" not in result # Wrong line number assert "3: Line 2" not in result # Wrong line number
assert "4: Line 3" not in result # Wrong line number assert "4: Line 3" not in result # Wrong line number
assert "5: Line 4" not in result # Wrong line number assert "5: Line 4" not in result # Wrong line number

View File

@@ -3,10 +3,10 @@ from application.agents.tools.notes import NotesTool
from application.core.settings import settings from application.core.settings import settings
@pytest.fixture @pytest.fixture
def notes_tool(monkeypatch) -> NotesTool: def notes_tool(monkeypatch) -> NotesTool:
"""Provide a NotesTool with a fake Mongo collection and fixed user_id.""" """Provide a NotesTool with a fake Mongo collection and fixed user_id."""
class FakeCollection: class FakeCollection:
def __init__(self) -> None: def __init__(self) -> None:
self.docs = {} # key: user_id:tool_id -> doc self.docs = {} # key: user_id:tool_id -> doc
@@ -17,6 +17,7 @@ def notes_tool(monkeypatch) -> NotesTool:
key = f"{user_id}:{tool_id}" key = f"{user_id}:{tool_id}"
# emulate single-note storage with optional upsert # emulate single-note storage with optional upsert
if key not in self.docs and not upsert: if key not in self.docs and not upsert:
return type("res", (), {"modified_count": 0}) return type("res", (), {"modified_count": 0})
if key not in self.docs and upsert: if key not in self.docs and upsert:
@@ -45,34 +46,45 @@ def notes_tool(monkeypatch) -> NotesTool:
fake_client = {settings.MONGO_DB_NAME: fake_db} fake_client = {settings.MONGO_DB_NAME: fake_db}
# Patch MongoDB client globally for the tool # Patch MongoDB client globally for the tool
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
monkeypatch.setattr(
"application.core.mongo_db.MongoDB.get_client", lambda: fake_client
)
# Return tool with a fixed tool_id for consistency in tests # Return tool with a fixed tool_id for consistency in tests
return NotesTool({"tool_id": "test_tool_id"}, user_id="test_user") return NotesTool({"tool_id": "test_tool_id"}, user_id="test_user")
@pytest.mark.unit
def test_view(notes_tool: NotesTool) -> None: def test_view(notes_tool: NotesTool) -> None:
# Manually insert a note to test retrieval # Manually insert a note to test retrieval
notes_tool.collection.update_one( notes_tool.collection.update_one(
{"user_id": "test_user", "tool_id": "test_tool_id"}, {"user_id": "test_user", "tool_id": "test_tool_id"},
{"$set": {"note": "hello"}}, {"$set": {"note": "hello"}},
upsert=True upsert=True,
) )
assert "hello" in notes_tool.execute_action("view") assert "hello" in notes_tool.execute_action("view")
@pytest.mark.unit
def test_overwrite_and_delete(notes_tool: NotesTool) -> None: def test_overwrite_and_delete(notes_tool: NotesTool) -> None:
# Overwrite creates a new note # Overwrite creates a new note
assert "saved" in notes_tool.execute_action("overwrite", text="first").lower() assert "saved" in notes_tool.execute_action("overwrite", text="first").lower()
assert "first" in notes_tool.execute_action("view") assert "first" in notes_tool.execute_action("view")
# Overwrite replaces existing note # Overwrite replaces existing note
assert "saved" in notes_tool.execute_action("overwrite", text="second").lower() assert "saved" in notes_tool.execute_action("overwrite", text="second").lower()
assert "second" in notes_tool.execute_action("view") assert "second" in notes_tool.execute_action("view")
assert "deleted" in notes_tool.execute_action("delete").lower() assert "deleted" in notes_tool.execute_action("delete").lower()
assert "no note" in notes_tool.execute_action("view").lower() assert "no note" in notes_tool.execute_action("view").lower()
@pytest.mark.unit
def test_init_without_user_id(monkeypatch): def test_init_without_user_id(monkeypatch):
"""Should fail gracefully if no user_id is provided.""" """Should fail gracefully if no user_id is provided."""
notes_tool = NotesTool(tool_config={}) notes_tool = NotesTool(tool_config={})
@@ -80,26 +92,32 @@ def test_init_without_user_id(monkeypatch):
assert "user_id" in str(result).lower() assert "user_id" in str(result).lower()
@pytest.mark.unit
def test_view_not_found(notes_tool: NotesTool) -> None: def test_view_not_found(notes_tool: NotesTool) -> None:
"""Should return 'No note found.' when no note exists""" """Should return 'No note found.' when no note exists"""
result = notes_tool.execute_action("view") result = notes_tool.execute_action("view")
assert "no note found" in result.lower() assert "no note found" in result.lower()
@pytest.mark.unit
def test_str_replace(notes_tool: NotesTool) -> None: def test_str_replace(notes_tool: NotesTool) -> None:
"""Test string replacement in note""" """Test string replacement in note"""
# Create a note # Create a note
notes_tool.execute_action("overwrite", text="Hello world, hello universe") notes_tool.execute_action("overwrite", text="Hello world, hello universe")
# Replace text # Replace text
result = notes_tool.execute_action("str_replace", old_str="hello", new_str="hi") result = notes_tool.execute_action("str_replace", old_str="hello", new_str="hi")
assert "updated" in result.lower() assert "updated" in result.lower()
# Verify replacement # Verify replacement
note = notes_tool.execute_action("view") note = notes_tool.execute_action("view")
assert "hi world, hi universe" in note.lower() assert "hi world, hi universe" in note.lower()
@pytest.mark.unit
def test_str_replace_not_found(notes_tool: NotesTool) -> None: def test_str_replace_not_found(notes_tool: NotesTool) -> None:
"""Test string replacement when string not found""" """Test string replacement when string not found"""
notes_tool.execute_action("overwrite", text="Hello world") notes_tool.execute_action("overwrite", text="Hello world")
@@ -107,22 +125,27 @@ def test_str_replace_not_found(notes_tool: NotesTool) -> None:
assert "not found" in result.lower() assert "not found" in result.lower()
@pytest.mark.unit
def test_insert_line(notes_tool: NotesTool) -> None: def test_insert_line(notes_tool: NotesTool) -> None:
"""Test inserting text at a line number""" """Test inserting text at a line number"""
# Create a multiline note # Create a multiline note
notes_tool.execute_action("overwrite", text="Line 1\nLine 2\nLine 3") notes_tool.execute_action("overwrite", text="Line 1\nLine 2\nLine 3")
# Insert at line 2 # Insert at line 2
result = notes_tool.execute_action("insert", line_number=2, text="Inserted line") result = notes_tool.execute_action("insert", line_number=2, text="Inserted line")
assert "inserted" in result.lower() assert "inserted" in result.lower()
# Verify insertion # Verify insertion
note = notes_tool.execute_action("view") note = notes_tool.execute_action("view")
lines = note.split("\n") lines = note.split("\n")
assert lines[1] == "Inserted line" assert lines[1] == "Inserted line"
assert lines[2] == "Line 2" assert lines[2] == "Line 2"
@pytest.mark.unit
def test_delete_nonexistent_note(monkeypatch): def test_delete_nonexistent_note(monkeypatch):
class FakeResult: class FakeResult:
deleted_count = 0 deleted_count = 0
@@ -133,7 +156,7 @@ def test_delete_nonexistent_note(monkeypatch):
monkeypatch.setattr( monkeypatch.setattr(
"application.core.mongo_db.MongoDB.get_client", "application.core.mongo_db.MongoDB.get_client",
lambda: {"docsgpt": {"notes": FakeCollection()}} lambda: {"docsgpt": {"notes": FakeCollection()}},
) )
notes_tool = NotesTool(tool_config={}, user_id="user123") notes_tool = NotesTool(tool_config={}, user_id="user123")
@@ -141,8 +164,10 @@ def test_delete_nonexistent_note(monkeypatch):
assert "no note found" in result.lower() assert "no note found" in result.lower()
@pytest.mark.unit
def test_notes_tool_isolation(monkeypatch) -> None: def test_notes_tool_isolation(monkeypatch) -> None:
"""Test that different notes tool instances have isolated notes.""" """Test that different notes tool instances have isolated notes."""
class FakeCollection: class FakeCollection:
def __init__(self) -> None: def __init__(self) -> None:
self.docs = {} self.docs = {}
@@ -170,19 +195,25 @@ def test_notes_tool_isolation(monkeypatch) -> None:
fake_db = {"notes": fake_collection} fake_db = {"notes": fake_collection}
fake_client = {settings.MONGO_DB_NAME: fake_db} fake_client = {settings.MONGO_DB_NAME: fake_db}
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client) monkeypatch.setattr(
"application.core.mongo_db.MongoDB.get_client", lambda: fake_client
)
# Create two notes tools with different tool_ids for the same user # Create two notes tools with different tool_ids for the same user
tool1 = NotesTool({"tool_id": "tool_1"}, user_id="test_user") tool1 = NotesTool({"tool_id": "tool_1"}, user_id="test_user")
tool2 = NotesTool({"tool_id": "tool_2"}, user_id="test_user") tool2 = NotesTool({"tool_id": "tool_2"}, user_id="test_user")
# Create a note in tool1 # Create a note in tool1
tool1.execute_action("overwrite", text="Content from tool 1") tool1.execute_action("overwrite", text="Content from tool 1")
# Create a note in tool2 # Create a note in tool2
tool2.execute_action("overwrite", text="Content from tool 2") tool2.execute_action("overwrite", text="Content from tool 2")
# Verify that each tool sees only its own content # Verify that each tool sees only its own content
result1 = tool1.execute_action("view") result1 = tool1.execute_action("view")
result2 = tool2.execute_action("view") result2 = tool2.execute_action("view")
@@ -193,8 +224,10 @@ def test_notes_tool_isolation(monkeypatch) -> None:
assert "Content from tool 1" not in result2 assert "Content from tool 1" not in result2
@pytest.mark.unit
def test_notes_tool_auto_generates_tool_id(monkeypatch) -> None: def test_notes_tool_auto_generates_tool_id(monkeypatch) -> None:
"""Test that tool_id defaults to 'default_{user_id}' for persistence.""" """Test that tool_id defaults to 'default_{user_id}' for persistence."""
class FakeCollection: class FakeCollection:
def __init__(self) -> None: def __init__(self) -> None:
self.docs = {} self.docs = {}
@@ -206,18 +239,23 @@ def test_notes_tool_auto_generates_tool_id(monkeypatch) -> None:
fake_db = {"notes": fake_collection} fake_db = {"notes": fake_collection}
fake_client = {settings.MONGO_DB_NAME: fake_db} fake_client = {settings.MONGO_DB_NAME: fake_db}
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client) monkeypatch.setattr(
"application.core.mongo_db.MongoDB.get_client", lambda: fake_client
)
# Create two tools without providing tool_id for the same user # Create two tools without providing tool_id for the same user
tool1 = NotesTool({}, user_id="test_user") tool1 = NotesTool({}, user_id="test_user")
tool2 = NotesTool({}, user_id="test_user") tool2 = NotesTool({}, user_id="test_user")
# Both should have the same default tool_id for persistence # Both should have the same default tool_id for persistence
assert tool1.tool_id == "default_test_user" assert tool1.tool_id == "default_test_user"
assert tool2.tool_id == "default_test_user" assert tool2.tool_id == "default_test_user"
assert tool1.tool_id == tool2.tool_id assert tool1.tool_id == tool2.tool_id
# Different users should have different tool_ids # Different users should have different tool_ids
tool3 = NotesTool({}, user_id="another_user") tool3 = NotesTool({}, user_id="another_user")
assert tool3.tool_id == "default_another_user" assert tool3.tool_id == "default_another_user"
assert tool3.tool_id != tool1.tool_id assert tool3.tool_id != tool1.tool_id

View File

@@ -1,6 +1,6 @@
import pytest import pytest
from openapi_parser import parse
from application.parser.file.openapi3_parser import OpenAPI3Parser from application.parser.file.openapi3_parser import OpenAPI3Parser
from openapi_parser import parse
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -17,10 +17,12 @@ from application.parser.file.openapi3_parser import OpenAPI3Parser
), ),
], ],
) )
@pytest.mark.unit
def test_get_base_urls(urls, expected_base_urls): def test_get_base_urls(urls, expected_base_urls):
assert OpenAPI3Parser().get_base_urls(urls) == expected_base_urls assert OpenAPI3Parser().get_base_urls(urls) == expected_base_urls
@pytest.mark.unit
def test_get_info_from_paths(): def test_get_info_from_paths():
file_path = "tests/test_openapi3.yaml" file_path = "tests/test_openapi3.yaml"
data = parse(file_path) data = parse(file_path)
@@ -31,6 +33,7 @@ def test_get_info_from_paths():
) )
@pytest.mark.unit
def test_parse_file(): def test_parse_file():
file_path = "tests/test_openapi3.yaml" file_path = "tests/test_openapi3.yaml"
results_expected = ( results_expected = (