mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 00:23:17 +00:00
test: add agent test coverage and standardize test suite (#2051)
- Add 104 comprehensive tests for agent system - Integrate agent tests into CI/CD pipeline - Standardize tests with @pytest.mark.unit markers - Fix cross-platform path compatibility - Clean up unused imports and dependencies
This commit is contained in:
6
.github/workflows/pytest.yml
vendored
6
.github/workflows/pytest.yml
vendored
@@ -16,15 +16,15 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pytest pytest-cov
|
||||
cd application
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
cd ../tests
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
- name: Test with pytest and generate coverage report
|
||||
run: |
|
||||
python -m pytest --cov=application --cov-report=xml
|
||||
python -m pytest --cov=application --cov-report=xml --cov-report=term-missing
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: github.event_name == 'pull_request' && matrix.python-version == '3.12'
|
||||
uses: codecov/codecov-action@v5
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ class BaseAgent(ABC):
|
||||
self.user_api_key = user_api_key
|
||||
self.prompt = prompt
|
||||
self.decoded_token = decoded_token or {}
|
||||
self.user: str = decoded_token.get("sub")
|
||||
self.user: str = self.decoded_token.get("sub")
|
||||
self.tool_config: Dict = {}
|
||||
self.tools: List[Dict] = []
|
||||
self.tool_calls: List[Dict] = []
|
||||
|
||||
@@ -20,20 +20,24 @@ class ToolActionParser:
|
||||
try:
|
||||
call_args = json.loads(call.arguments)
|
||||
tool_parts = call.name.split("_")
|
||||
|
||||
|
||||
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
|
||||
if len(tool_parts) < 2:
|
||||
logger.warning(f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id")
|
||||
logger.warning(
|
||||
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
|
||||
)
|
||||
return None, None, None
|
||||
|
||||
|
||||
tool_id = tool_parts[-1]
|
||||
action_name = "_".join(tool_parts[:-1])
|
||||
|
||||
|
||||
# Validate that tool_id looks like a numerical ID
|
||||
if not tool_id.isdigit():
|
||||
logger.warning(f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call.")
|
||||
|
||||
except (AttributeError, TypeError) as e:
|
||||
logger.warning(
|
||||
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
||||
)
|
||||
|
||||
except (AttributeError, TypeError, json.JSONDecodeError) as e:
|
||||
logger.error(f"Error parsing OpenAI LLM call: {e}")
|
||||
return None, None, None
|
||||
return tool_id, action_name, call_args
|
||||
@@ -42,19 +46,23 @@ class ToolActionParser:
|
||||
try:
|
||||
call_args = call.arguments
|
||||
tool_parts = call.name.split("_")
|
||||
|
||||
|
||||
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
|
||||
if len(tool_parts) < 2:
|
||||
logger.warning(f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id")
|
||||
logger.warning(
|
||||
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
|
||||
)
|
||||
return None, None, None
|
||||
|
||||
|
||||
tool_id = tool_parts[-1]
|
||||
action_name = "_".join(tool_parts[:-1])
|
||||
|
||||
|
||||
# Validate that tool_id looks like a numerical ID
|
||||
if not tool_id.isdigit():
|
||||
logger.warning(f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call.")
|
||||
|
||||
logger.warning(
|
||||
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
||||
)
|
||||
|
||||
except (AttributeError, TypeError) as e:
|
||||
logger.error(f"Error parsing Google LLM call: {e}")
|
||||
return None, None, None
|
||||
|
||||
20
pytest.ini
Normal file
20
pytest.ini
Normal file
@@ -0,0 +1,20 @@
|
||||
[pytest]
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
addopts =
|
||||
-v
|
||||
--strict-markers
|
||||
--tb=short
|
||||
--cov=application
|
||||
--cov-report=html
|
||||
--cov-report=term-missing
|
||||
--cov-report=xml
|
||||
markers =
|
||||
unit: Unit tests
|
||||
integration: Integration tests
|
||||
slow: Slow running tests
|
||||
filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
ignore::PendingDeprecationWarning
|
||||
0
tests/agents/__init__.py
Normal file
0
tests/agents/__init__.py
Normal file
56
tests/agents/test_agent_creator.py
Normal file
56
tests/agents/test_agent_creator.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import pytest
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
from application.agents.react_agent import ReActAgent
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAgentCreator:
|
||||
|
||||
def test_create_classic_agent(self, agent_base_params):
|
||||
agent = AgentCreator.create_agent("classic", **agent_base_params)
|
||||
assert isinstance(agent, ClassicAgent)
|
||||
assert agent.endpoint == agent_base_params["endpoint"]
|
||||
assert agent.llm_name == agent_base_params["llm_name"]
|
||||
assert agent.gpt_model == agent_base_params["gpt_model"]
|
||||
|
||||
def test_create_react_agent(self, agent_base_params):
|
||||
agent = AgentCreator.create_agent("react", **agent_base_params)
|
||||
assert isinstance(agent, ReActAgent)
|
||||
assert agent.endpoint == agent_base_params["endpoint"]
|
||||
assert agent.llm_name == agent_base_params["llm_name"]
|
||||
|
||||
def test_create_agent_case_insensitive(self, agent_base_params):
|
||||
agent_upper = AgentCreator.create_agent("CLASSIC", **agent_base_params)
|
||||
agent_mixed = AgentCreator.create_agent("ClAsSiC", **agent_base_params)
|
||||
|
||||
assert isinstance(agent_upper, ClassicAgent)
|
||||
assert isinstance(agent_mixed, ClassicAgent)
|
||||
|
||||
def test_create_agent_invalid_type(self, agent_base_params):
|
||||
with pytest.raises(ValueError, match="No agent class found for type"):
|
||||
AgentCreator.create_agent("invalid_agent_type", **agent_base_params)
|
||||
|
||||
def test_agent_registry_contains_expected_agents(self):
|
||||
assert "classic" in AgentCreator.agents
|
||||
assert "react" in AgentCreator.agents
|
||||
assert AgentCreator.agents["classic"] == ClassicAgent
|
||||
assert AgentCreator.agents["react"] == ReActAgent
|
||||
|
||||
def test_create_agent_with_optional_params(self, agent_base_params):
|
||||
agent_base_params["user_api_key"] = "user_key_123"
|
||||
agent_base_params["chat_history"] = [{"prompt": "test", "response": "test"}]
|
||||
agent_base_params["json_schema"] = {"type": "object"}
|
||||
|
||||
agent = AgentCreator.create_agent("classic", **agent_base_params)
|
||||
|
||||
assert agent.user_api_key == "user_key_123"
|
||||
assert len(agent.chat_history) == 1
|
||||
assert agent.json_schema == {"type": "object"}
|
||||
|
||||
def test_create_agent_with_attachments(self, agent_base_params):
|
||||
attachments = [{"name": "file.txt", "content": "test"}]
|
||||
agent_base_params["attachments"] = attachments
|
||||
|
||||
agent = AgentCreator.create_agent("classic", **agent_base_params)
|
||||
assert agent.attachments == attachments
|
||||
641
tests/agents/test_base_agent.py
Normal file
641
tests/agents/test_base_agent.py
Normal file
@@ -0,0 +1,641 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
from application.core.settings import settings
|
||||
from tests.conftest import FakeMongoCollection
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseAgentInitialization:
|
||||
|
||||
def test_agent_initialization(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
assert agent.endpoint == agent_base_params["endpoint"]
|
||||
assert agent.llm_name == agent_base_params["llm_name"]
|
||||
assert agent.gpt_model == agent_base_params["gpt_model"]
|
||||
assert agent.api_key == agent_base_params["api_key"]
|
||||
assert agent.prompt == agent_base_params["prompt"]
|
||||
assert agent.user == agent_base_params["decoded_token"]["sub"]
|
||||
assert agent.tools == []
|
||||
assert agent.tool_calls == []
|
||||
|
||||
def test_agent_initialization_with_none_chat_history(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent_base_params["chat_history"] = None
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.chat_history == []
|
||||
|
||||
def test_agent_initialization_with_chat_history(
|
||||
self,
|
||||
agent_base_params,
|
||||
sample_chat_history,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
agent_base_params["chat_history"] = sample_chat_history
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert len(agent.chat_history) == 2
|
||||
assert agent.chat_history[0]["prompt"] == "What is Python?"
|
||||
|
||||
def test_agent_decoded_token_defaults_to_empty_dict(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent_base_params["decoded_token"] = None
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.decoded_token == {}
|
||||
assert agent.user is None
|
||||
|
||||
def test_agent_user_extracted_from_token(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent_base_params["decoded_token"] = {"sub": "user123"}
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.user == "user123"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseAgentBuildMessages:
|
||||
|
||||
def test_build_messages_basic(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
system_prompt = "System: {summaries}"
|
||||
query = "What is Python?"
|
||||
retrieved_data = [
|
||||
{"text": "Python is a programming language", "filename": "python.txt"}
|
||||
]
|
||||
|
||||
messages = agent._build_messages(system_prompt, query, retrieved_data)
|
||||
|
||||
assert len(messages) >= 2
|
||||
assert messages[0]["role"] == "system"
|
||||
assert "Python is a programming language" in messages[0]["content"]
|
||||
assert messages[-1]["role"] == "user"
|
||||
assert messages[-1]["content"] == query
|
||||
|
||||
def test_build_messages_with_chat_history(
|
||||
self,
|
||||
agent_base_params,
|
||||
sample_chat_history,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
agent_base_params["chat_history"] = sample_chat_history
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
system_prompt = "System: {summaries}"
|
||||
query = "New question?"
|
||||
retrieved_data = [{"text": "Data", "filename": "file.txt"}]
|
||||
|
||||
messages = agent._build_messages(system_prompt, query, retrieved_data)
|
||||
|
||||
user_messages = [m for m in messages if m["role"] == "user"]
|
||||
assistant_messages = [m for m in messages if m["role"] == "assistant"]
|
||||
|
||||
assert len(user_messages) >= 3
|
||||
assert len(assistant_messages) >= 2
|
||||
|
||||
def test_build_messages_with_tool_calls_in_history(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
tool_call_history = [
|
||||
{
|
||||
"tool_calls": [
|
||||
{
|
||||
"call_id": "123",
|
||||
"action_name": "test_action",
|
||||
"arguments": {"arg": "value"},
|
||||
"result": "success",
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
agent_base_params["chat_history"] = tool_call_history
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
messages = agent._build_messages(
|
||||
"System: {summaries}", "query", [{"text": "data", "filename": "file.txt"}]
|
||||
)
|
||||
|
||||
tool_messages = [m for m in messages if m["role"] == "tool"]
|
||||
assert len(tool_messages) > 0
|
||||
|
||||
def test_build_messages_handles_missing_filename(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
retrieved_data = [{"text": "Document without filename or title"}]
|
||||
|
||||
messages = agent._build_messages("System: {summaries}", "query", retrieved_data)
|
||||
|
||||
assert messages[0]["role"] == "system"
|
||||
assert "Document without filename" in messages[0]["content"]
|
||||
|
||||
def test_build_messages_uses_title_as_fallback(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
retrieved_data = [{"text": "Data", "title": "Title Doc"}]
|
||||
|
||||
messages = agent._build_messages("System: {summaries}", "query", retrieved_data)
|
||||
|
||||
assert "Title Doc" in messages[0]["content"]
|
||||
|
||||
def test_build_messages_uses_source_as_fallback(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
retrieved_data = [{"text": "Data", "source": "source.txt"}]
|
||||
|
||||
messages = agent._build_messages("System: {summaries}", "query", retrieved_data)
|
||||
|
||||
assert "source.txt" in messages[0]["content"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseAgentTools:
|
||||
|
||||
def test_get_user_tools(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_mongo_db,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
|
||||
"1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True},
|
||||
"2": {"_id": "2", "user": "test_user", "name": "tool2", "status": True},
|
||||
}
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
tools = agent._get_user_tools("test_user")
|
||||
|
||||
assert len(tools) == 2
|
||||
assert "0" in tools
|
||||
assert "1" in tools
|
||||
|
||||
def test_get_user_tools_filters_by_status(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_mongo_db,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
|
||||
"1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True},
|
||||
"2": {"_id": "2", "user": "test_user", "name": "tool2", "status": False},
|
||||
}
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
tools = agent._get_user_tools("test_user")
|
||||
|
||||
assert len(tools) == 1
|
||||
|
||||
def test_get_tools_by_api_key(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_mongo_db,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
tool_id = str(ObjectId())
|
||||
tool_obj_id = ObjectId(tool_id)
|
||||
|
||||
fake_agent_collection = FakeMongoCollection()
|
||||
fake_agent_collection.docs["api_key_123"] = {
|
||||
"key": "api_key_123",
|
||||
"tools": [tool_id],
|
||||
}
|
||||
|
||||
fake_tools_collection = FakeMongoCollection()
|
||||
fake_tools_collection.docs[tool_id] = {"_id": tool_obj_id, "name": "api_tool"}
|
||||
|
||||
mock_mongo_db[settings.MONGO_DB_NAME]["agents"] = fake_agent_collection
|
||||
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"] = fake_tools_collection
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
tools = agent._get_tools("api_key_123")
|
||||
|
||||
assert tool_id in tools
|
||||
|
||||
def test_build_tool_parameters(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
action = {
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"param1": {
|
||||
"type": "string",
|
||||
"description": "Test param",
|
||||
"filled_by_llm": True,
|
||||
},
|
||||
"param2": {"type": "number", "filled_by_llm": False, "value": 42},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
params = agent._build_tool_parameters(action)
|
||||
|
||||
assert "param1" in params["properties"]
|
||||
assert "param1" in params["required"]
|
||||
assert "filled_by_llm" not in params["properties"]["param1"]
|
||||
|
||||
def test_prepare_tools_with_api_tool(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
tools_dict = {
|
||||
"1": {
|
||||
"name": "api_tool",
|
||||
"config": {
|
||||
"actions": {
|
||||
"get_data": {
|
||||
"name": "get_data",
|
||||
"description": "Get data from API",
|
||||
"active": True,
|
||||
"url": "https://api.example.com/data",
|
||||
"method": "GET",
|
||||
"parameters": {"properties": {}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
agent._prepare_tools(tools_dict)
|
||||
|
||||
assert len(agent.tools) == 1
|
||||
assert agent.tools[0]["type"] == "function"
|
||||
assert agent.tools[0]["function"]["name"] == "get_data_1"
|
||||
|
||||
def test_prepare_tools_with_regular_tool(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
tools_dict = {
|
||||
"1": {
|
||||
"name": "custom_tool",
|
||||
"actions": [
|
||||
{
|
||||
"name": "action1",
|
||||
"description": "Custom action",
|
||||
"active": True,
|
||||
"parameters": {"properties": {}},
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
agent._prepare_tools(tools_dict)
|
||||
|
||||
assert len(agent.tools) == 1
|
||||
assert agent.tools[0]["function"]["name"] == "action1_1"
|
||||
|
||||
def test_prepare_tools_filters_inactive_actions(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
tools_dict = {
|
||||
"1": {
|
||||
"name": "custom_tool",
|
||||
"actions": [
|
||||
{
|
||||
"name": "active_action",
|
||||
"description": "Active",
|
||||
"active": True,
|
||||
"parameters": {"properties": {}},
|
||||
},
|
||||
{
|
||||
"name": "inactive_action",
|
||||
"description": "Inactive",
|
||||
"active": False,
|
||||
"parameters": {"properties": {}},
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
agent._prepare_tools(tools_dict)
|
||||
|
||||
assert len(agent.tools) == 1
|
||||
assert agent.tools[0]["function"]["name"] == "active_action_1"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseAgentToolExecution:
|
||||
|
||||
def test_execute_tool_action_success(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_tool_manager,
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
call = Mock()
|
||||
call.id = "call_123"
|
||||
call.name = "test_action_1"
|
||||
call.arguments = '{"param1": "value1"}'
|
||||
|
||||
tools_dict = {
|
||||
"1": {
|
||||
"name": "custom_tool",
|
||||
"config": {},
|
||||
"actions": [
|
||||
{
|
||||
"name": "test_action",
|
||||
"description": "Test",
|
||||
"parameters": {"properties": {}},
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
results = list(agent._execute_tool_action(tools_dict, call))
|
||||
|
||||
assert len(results) >= 2
|
||||
assert results[0]["type"] == "tool_call"
|
||||
assert results[0]["data"]["status"] == "pending"
|
||||
assert results[-1]["data"]["status"] == "completed"
|
||||
|
||||
def test_execute_tool_action_invalid_tool_name(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
call = Mock()
|
||||
call.id = "call_123"
|
||||
call.name = "invalid_format"
|
||||
call.arguments = "{}"
|
||||
|
||||
tools_dict = {}
|
||||
|
||||
results = list(agent._execute_tool_action(tools_dict, call))
|
||||
|
||||
assert results[0]["type"] == "tool_call"
|
||||
assert results[0]["data"]["status"] == "error"
|
||||
assert (
|
||||
"Failed to parse" in results[0]["data"]["result"]
|
||||
or "not found" in results[0]["data"]["result"]
|
||||
)
|
||||
|
||||
def test_execute_tool_action_tool_not_found(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
call = Mock()
|
||||
call.id = "call_123"
|
||||
call.name = "action_999"
|
||||
call.arguments = "{}"
|
||||
|
||||
tools_dict = {"1": {"name": "tool1", "config": {}, "actions": []}}
|
||||
|
||||
results = list(agent._execute_tool_action(tools_dict, call))
|
||||
|
||||
assert results[0]["type"] == "tool_call"
|
||||
assert results[0]["data"]["status"] == "error"
|
||||
assert "not found" in results[0]["data"]["result"]
|
||||
|
||||
def test_execute_tool_action_with_parameters(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_tool_manager,
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
call = Mock()
|
||||
call.id = "call_123"
|
||||
call.name = "test_action_1"
|
||||
call.arguments = '{"param1": "value1", "param2": "value2"}'
|
||||
|
||||
tools_dict = {
|
||||
"1": {
|
||||
"name": "custom_tool",
|
||||
"config": {},
|
||||
"actions": [
|
||||
{
|
||||
"name": "test_action",
|
||||
"description": "Test",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"param1": {"type": "string"},
|
||||
"param2": {"type": "string"},
|
||||
}
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
results = list(agent._execute_tool_action(tools_dict, call))
|
||||
|
||||
assert results[-1]["data"]["status"] == "completed"
|
||||
assert results[-1]["data"]["arguments"]["param1"] == "value1"
|
||||
|
||||
def test_get_truncated_tool_calls(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
agent.tool_calls = [
|
||||
{
|
||||
"tool_name": "test_tool",
|
||||
"call_id": "123",
|
||||
"action_name": "action",
|
||||
"arguments": {},
|
||||
"result": "a" * 100,
|
||||
}
|
||||
]
|
||||
|
||||
truncated = agent._get_truncated_tool_calls()
|
||||
|
||||
assert len(truncated) == 1
|
||||
assert len(truncated[0]["result"]) <= 53
|
||||
assert truncated[0]["result"].endswith("...")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseAgentRetrieverSearch:
|
||||
|
||||
def test_retriever_search(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
results = agent._retriever_search(mock_retriever, "test query", log_context)
|
||||
|
||||
assert len(results) == 2
|
||||
mock_retriever.search.assert_called_once_with("test query")
|
||||
|
||||
def test_retriever_search_logs_context(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
agent._retriever_search(mock_retriever, "test query", log_context)
|
||||
|
||||
assert len(log_context.stacks) == 1
|
||||
assert log_context.stacks[0]["component"] == "retriever"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseAgentLLMGeneration:
|
||||
|
||||
def test_llm_gen_basic(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
messages = [{"role": "user", "content": "test"}]
|
||||
agent._llm_gen(messages, log_context)
|
||||
|
||||
mock_llm.gen_stream.assert_called_once()
|
||||
call_args = mock_llm.gen_stream.call_args[1]
|
||||
assert call_args["model"] == agent.gpt_model
|
||||
assert call_args["messages"] == messages
|
||||
|
||||
def test_llm_gen_with_tools(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
agent.tools = [{"type": "function", "function": {"name": "test"}}]
|
||||
|
||||
messages = [{"role": "user", "content": "test"}]
|
||||
agent._llm_gen(messages, log_context)
|
||||
|
||||
call_args = mock_llm.gen_stream.call_args[1]
|
||||
assert "tools" in call_args
|
||||
assert call_args["tools"] == agent.tools
|
||||
|
||||
def test_llm_gen_with_json_schema(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
mock_llm._supports_structured_output = Mock(return_value=True)
|
||||
mock_llm.prepare_structured_output_format = Mock(
|
||||
return_value={"schema": "test"}
|
||||
)
|
||||
|
||||
agent_base_params["json_schema"] = {"type": "object"}
|
||||
agent_base_params["llm_name"] = "openai"
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
messages = [{"role": "user", "content": "test"}]
|
||||
agent._llm_gen(messages, log_context)
|
||||
|
||||
call_args = mock_llm.gen_stream.call_args[1]
|
||||
assert "response_format" in call_args
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseAgentHandleResponse:
|
||||
|
||||
def test_handle_response_string(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator, log_context
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
response = "Simple string response"
|
||||
results = list(agent._handle_response(response, {}, [], log_context))
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["answer"] == "Simple string response"
|
||||
|
||||
def test_handle_response_with_message(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator, log_context
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
response = Mock()
|
||||
response.message = Mock()
|
||||
response.message.content = "Message content"
|
||||
|
||||
results = list(agent._handle_response(response, {}, [], log_context))
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["answer"] == "Message content"
|
||||
|
||||
def test_handle_response_with_structured_output(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
mock_llm._supports_structured_output = Mock(return_value=True)
|
||||
agent_base_params["json_schema"] = {"type": "object"}
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
response = "Structured response"
|
||||
results = list(agent._handle_response(response, {}, [], log_context))
|
||||
|
||||
assert results[0]["structured"] is True
|
||||
assert results[0]["schema"] == {"type": "object"}
|
||||
|
||||
def test_handle_response_with_handler(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
def mock_process(*args):
|
||||
yield {"type": "tool_call", "data": {}}
|
||||
yield "Final answer"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_process)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
response = Mock()
|
||||
response.message = None
|
||||
|
||||
results = list(agent._handle_response(response, {}, [], log_context))
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["type"] == "tool_call"
|
||||
assert results[1]["answer"] == "Final answer"
|
||||
242
tests/agents/test_classic_agent.py
Normal file
242
tests/agents/test_classic_agent.py
Normal file
@@ -0,0 +1,242 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestClassicAgent:
|
||||
|
||||
def test_classic_agent_initialization(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
assert isinstance(agent, ClassicAgent)
|
||||
assert agent.endpoint == agent_base_params["endpoint"]
|
||||
assert agent.llm_name == agent_base_params["llm_name"]
|
||||
|
||||
def test_gen_inner_basic_flow(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_mongo_db,
|
||||
log_context,
|
||||
):
|
||||
def mock_gen_stream(*args, **kwargs):
|
||||
yield "Answer chunk 1"
|
||||
yield "Answer chunk 2"
|
||||
|
||||
mock_llm.gen_stream = Mock(return_value=mock_gen_stream())
|
||||
|
||||
def mock_handler(*args, **kwargs):
|
||||
yield "Processed answer"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
|
||||
assert len(results) >= 2
|
||||
sources = [r for r in results if "sources" in r]
|
||||
tool_calls = [r for r in results if "tool_calls" in r]
|
||||
|
||||
assert len(sources) == 1
|
||||
assert len(tool_calls) == 1
|
||||
|
||||
def test_gen_inner_retrieves_documents(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_mongo_db,
|
||||
log_context,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
|
||||
|
||||
def mock_handler(*args, **kwargs):
|
||||
yield "Processed"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
|
||||
mock_retriever.search.assert_called_once_with("Test query")
|
||||
|
||||
def test_gen_inner_uses_user_api_key_tools(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_mongo_db,
|
||||
log_context,
|
||||
):
|
||||
from application.core.settings import settings
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
tool_id = str(ObjectId())
|
||||
mock_mongo_db[settings.MONGO_DB_NAME]["agents"].docs = {
|
||||
"api_key_123": {"key": "api_key_123", "tools": [tool_id]}
|
||||
}
|
||||
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
|
||||
tool_id: {"_id": ObjectId(tool_id), "name": "test_tool"}
|
||||
}
|
||||
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
|
||||
|
||||
def mock_handler(*args, **kwargs):
|
||||
yield "Processed"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent_base_params["user_api_key"] = "api_key_123"
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
|
||||
assert len(agent.tools) >= 0
|
||||
|
||||
def test_gen_inner_uses_user_tools(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_mongo_db,
|
||||
log_context,
|
||||
):
|
||||
from application.core.settings import settings
|
||||
|
||||
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
|
||||
"1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True}
|
||||
}
|
||||
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
|
||||
|
||||
def mock_handler(*args, **kwargs):
|
||||
yield "Processed"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
|
||||
assert len(agent.tools) >= 0
|
||||
|
||||
def test_gen_inner_builds_correct_messages(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_mongo_db,
|
||||
log_context,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
|
||||
|
||||
def mock_handler(*args, **kwargs):
|
||||
yield "Processed"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
|
||||
call_kwargs = mock_llm.gen_stream.call_args[1]
|
||||
messages = call_kwargs["messages"]
|
||||
|
||||
assert len(messages) >= 2
|
||||
assert messages[0]["role"] == "system"
|
||||
assert messages[-1]["role"] == "user"
|
||||
assert messages[-1]["content"] == "Test query"
|
||||
|
||||
def test_gen_inner_logs_tool_calls(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_mongo_db,
|
||||
log_context,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
|
||||
|
||||
def mock_handler(*args, **kwargs):
|
||||
yield "Processed"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
agent.tool_calls = [{"tool": "test", "result": "success"}]
|
||||
|
||||
list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
|
||||
agent_logs = [s for s in log_context.stacks if s["component"] == "agent"]
|
||||
assert len(agent_logs) == 1
|
||||
assert "tool_calls" in agent_logs[0]["data"]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestClassicAgentIntegration:
|
||||
|
||||
def test_gen_method_with_logging(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_mongo_db,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
|
||||
|
||||
def mock_handler(*args, **kwargs):
|
||||
yield "Processed"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
results = list(agent.gen("Test query", mock_retriever))
|
||||
|
||||
assert len(results) >= 1
|
||||
|
||||
def test_gen_method_decorator_applied(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_mongo_db,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
|
||||
|
||||
def mock_handler(*args, **kwargs):
|
||||
yield "Processed"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
assert hasattr(agent.gen, "__wrapped__")
|
||||
519
tests/agents/test_react_agent.py
Normal file
519
tests/agents/test_react_agent.py
Normal file
@@ -0,0 +1,519 @@
|
||||
from unittest.mock import Mock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
from application.agents.react_agent import ReActAgent
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestReActAgent:
|
||||
|
||||
def test_react_agent_initialization(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
|
||||
assert isinstance(agent, ReActAgent)
|
||||
assert agent.plan == ""
|
||||
assert agent.observations == []
|
||||
|
||||
def test_react_agent_inherits_base_properties(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
|
||||
assert agent.endpoint == agent_base_params["endpoint"]
|
||||
assert agent.llm_name == agent_base_params["llm_name"]
|
||||
assert agent.gpt_model == agent_base_params["gpt_model"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestReActAgentContentExtraction:
|
||||
|
||||
def test_extract_content_from_string(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
|
||||
response = "Simple string response"
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
|
||||
assert content == "Simple string response"
|
||||
|
||||
def test_extract_content_from_message_object(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
|
||||
response = Mock()
|
||||
response.message = Mock()
|
||||
response.message.content = "Message content"
|
||||
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
|
||||
assert content == "Message content"
|
||||
|
||||
def test_extract_content_from_openai_response(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
|
||||
response = Mock()
|
||||
response.choices = [Mock()]
|
||||
response.choices[0].message = Mock()
|
||||
response.choices[0].message.content = "OpenAI content"
|
||||
response.message = None
|
||||
response.content = None
|
||||
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
|
||||
assert content == "OpenAI content"
|
||||
|
||||
def test_extract_content_from_anthropic_response(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
|
||||
text_block = Mock()
|
||||
text_block.text = "Anthropic content"
|
||||
|
||||
response = Mock()
|
||||
response.content = [text_block]
|
||||
response.message = None
|
||||
response.choices = None
|
||||
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
|
||||
assert content == "Anthropic content"
|
||||
|
||||
def test_extract_content_from_openai_stream(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
|
||||
chunk1 = Mock()
|
||||
chunk1.choices = [Mock()]
|
||||
chunk1.choices[0].delta = Mock()
|
||||
chunk1.choices[0].delta.content = "Part 1 "
|
||||
|
||||
chunk2 = Mock()
|
||||
chunk2.choices = [Mock()]
|
||||
chunk2.choices[0].delta = Mock()
|
||||
chunk2.choices[0].delta.content = "Part 2"
|
||||
|
||||
response = iter([chunk1, chunk2])
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
|
||||
assert content == "Part 1 Part 2"
|
||||
|
||||
def test_extract_content_from_anthropic_stream(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
|
||||
chunk1 = Mock()
|
||||
chunk1.type = "content_block_delta"
|
||||
chunk1.delta = Mock()
|
||||
chunk1.delta.text = "Stream 1 "
|
||||
chunk1.choices = []
|
||||
|
||||
chunk2 = Mock()
|
||||
chunk2.type = "content_block_delta"
|
||||
chunk2.delta = Mock()
|
||||
chunk2.delta.text = "Stream 2"
|
||||
chunk2.choices = []
|
||||
|
||||
response = iter([chunk1, chunk2])
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
|
||||
assert content == "Stream 1 Stream 2"
|
||||
|
||||
def test_extract_content_from_string_stream(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
|
||||
response = iter(["chunk1", "chunk2", "chunk3"])
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
|
||||
assert content == "chunk1chunk2chunk3"
|
||||
|
||||
def test_extract_content_handles_none_content(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
|
||||
response = Mock()
|
||||
response.message = Mock()
|
||||
response.message.content = None
|
||||
response.choices = None
|
||||
response.content = None
|
||||
|
||||
content = agent._extract_content_from_llm_response(response)
|
||||
|
||||
assert content == ""
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestReActAgentPlanning:
|
||||
|
||||
@patch(
|
||||
"builtins.open",
|
||||
new_callable=mock_open,
|
||||
read_data="Test planning prompt: {query} {summaries} {prompt} {observations}",
|
||||
)
|
||||
def test_create_plan(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
def mock_gen_stream(*args, **kwargs):
|
||||
yield "Plan step 1"
|
||||
yield "Plan step 2"
|
||||
|
||||
mock_llm.gen_stream = Mock(return_value=mock_gen_stream())
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
agent.observations = ["Observation 1"]
|
||||
|
||||
plan_chunks = list(agent._create_plan("Test query", "Test docs", log_context))
|
||||
|
||||
assert len(plan_chunks) == 2
|
||||
assert plan_chunks[0] == "Plan step 1"
|
||||
assert plan_chunks[1] == "Plan step 2"
|
||||
|
||||
mock_llm.gen_stream.assert_called_once()
|
||||
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="Test: {query}")
|
||||
def test_create_plan_fills_template(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Plan"]))
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
list(agent._create_plan("My query", "Docs", log_context))
|
||||
|
||||
call_args = mock_llm.gen_stream.call_args[1]
|
||||
messages = call_args["messages"]
|
||||
|
||||
assert "My query" in messages[0]["content"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestReActAgentFinalAnswer:
|
||||
|
||||
@patch(
|
||||
"builtins.open",
|
||||
new_callable=mock_open,
|
||||
read_data="Final answer for: {query} with {observations}",
|
||||
)
|
||||
def test_create_final_answer(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
def mock_gen_stream(*args, **kwargs):
|
||||
yield "Final "
|
||||
yield "answer"
|
||||
|
||||
mock_llm.gen_stream = Mock(return_value=mock_gen_stream())
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
observations = ["Obs 1", "Obs 2"]
|
||||
|
||||
answer_chunks = list(
|
||||
agent._create_final_answer("Test query", observations, log_context)
|
||||
)
|
||||
|
||||
assert len(answer_chunks) == 2
|
||||
assert answer_chunks[0] == "Final "
|
||||
assert answer_chunks[1] == "answer"
|
||||
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="Answer: {observations}")
|
||||
def test_create_final_answer_truncates_long_observations(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
long_obs = ["A" * 15000]
|
||||
|
||||
list(agent._create_final_answer("Query", long_obs, log_context))
|
||||
|
||||
call_args = mock_llm.gen_stream.call_args[1]
|
||||
messages = call_args["messages"]
|
||||
|
||||
assert "observations truncated" in messages[0]["content"]
|
||||
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="Test: {query}")
|
||||
def test_create_final_answer_no_tools(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
list(agent._create_final_answer("Query", ["Obs"], log_context))
|
||||
|
||||
call_args = mock_llm.gen_stream.call_args[1]
|
||||
|
||||
assert call_args["tools"] is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestReActAgentGenInner:
|
||||
|
||||
@patch(
|
||||
"builtins.open", new_callable=mock_open, read_data="Prompt template: {query}"
|
||||
)
|
||||
def test_gen_inner_resets_state(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_mongo_db,
|
||||
log_context,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["SATISFIED"]))
|
||||
|
||||
def mock_handler(*args, **kwargs):
|
||||
yield "SATISFIED"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
agent.plan = "Old plan"
|
||||
agent.observations = ["Old obs"]
|
||||
|
||||
list(agent._gen_inner("New query", mock_retriever, log_context))
|
||||
|
||||
assert agent.plan != "Old plan"
|
||||
assert len(agent.observations) > 0
|
||||
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}")
|
||||
def test_gen_inner_stops_on_satisfied(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_mongo_db,
|
||||
log_context,
|
||||
):
|
||||
iteration_count = 0
|
||||
|
||||
def mock_gen_stream(*args, **kwargs):
|
||||
nonlocal iteration_count
|
||||
iteration_count += 1
|
||||
if iteration_count == 1:
|
||||
yield "Plan"
|
||||
else:
|
||||
yield "SATISFIED - done"
|
||||
|
||||
mock_llm.gen_stream = Mock(
|
||||
side_effect=lambda *args, **kwargs: mock_gen_stream(*args, **kwargs)
|
||||
)
|
||||
|
||||
def mock_handler(*args, **kwargs):
|
||||
yield "SATISFIED - done"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
|
||||
assert any("answer" in r for r in results)
|
||||
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}")
|
||||
def test_gen_inner_max_iterations(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_mongo_db,
|
||||
log_context,
|
||||
):
|
||||
call_count = 0
|
||||
|
||||
def mock_gen_stream(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
yield f"Iteration {call_count}"
|
||||
|
||||
mock_llm.gen_stream = Mock(
|
||||
side_effect=lambda *args, **kwargs: mock_gen_stream(*args, **kwargs)
|
||||
)
|
||||
|
||||
def mock_handler(*args, **kwargs):
|
||||
yield "Continue..."
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
|
||||
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
|
||||
thought_results = [r for r in results if "thought" in r]
|
||||
assert len(thought_results) > 0
|
||||
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}")
|
||||
def test_gen_inner_yields_sources(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_mongo_db,
|
||||
log_context,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["SATISFIED"]))
|
||||
|
||||
def mock_handler(*args, **kwargs):
|
||||
yield "SATISFIED"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
|
||||
sources = [r for r in results if "sources" in r]
|
||||
assert len(sources) >= 1
|
||||
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}")
|
||||
def test_gen_inner_yields_tool_calls(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_mongo_db,
|
||||
log_context,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["SATISFIED"]))
|
||||
|
||||
def mock_handler(*args, **kwargs):
|
||||
yield "SATISFIED"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
agent.tool_calls = [{"tool": "test", "result": "A" * 100}]
|
||||
|
||||
results = list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
|
||||
tool_call_results = [r for r in results if "tool_calls" in r]
|
||||
if tool_call_results:
|
||||
assert len(tool_call_results[0]["tool_calls"][0]["result"]) <= 53
|
||||
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="Prompt: {query}")
|
||||
def test_gen_inner_logs_observations(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_mongo_db,
|
||||
log_context,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["SATISFIED"]))
|
||||
|
||||
def mock_handler(*args, **kwargs):
|
||||
yield "SATISFIED"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
list(agent._gen_inner("Test query", mock_retriever, log_context))
|
||||
|
||||
assert len(agent.observations) > 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestReActAgentIntegration:
|
||||
|
||||
@patch(
|
||||
"builtins.open",
|
||||
new_callable=mock_open,
|
||||
read_data="Prompt: {query} {summaries} {prompt} {observations}",
|
||||
)
|
||||
def test_full_react_workflow(
|
||||
self,
|
||||
mock_file,
|
||||
agent_base_params,
|
||||
mock_retriever,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_mongo_db,
|
||||
log_context,
|
||||
):
|
||||
call_sequence = []
|
||||
|
||||
def mock_gen_stream(*args, **kwargs):
|
||||
call_sequence.append("gen_stream")
|
||||
if len(call_sequence) <= 2:
|
||||
yield "Planning..."
|
||||
else:
|
||||
yield "SATISFIED final answer"
|
||||
|
||||
mock_llm.gen_stream = Mock(
|
||||
side_effect=lambda *args, **kwargs: mock_gen_stream(*args, **kwargs)
|
||||
)
|
||||
|
||||
def mock_handler(*args, **kwargs):
|
||||
call_sequence.append("handler")
|
||||
yield "SATISFIED final answer"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = ReActAgent(**agent_base_params)
|
||||
results = list(agent._gen_inner("Complex query", mock_retriever, log_context))
|
||||
|
||||
assert len(results) > 0
|
||||
assert any("thought" in r for r in results)
|
||||
assert any("answer" in r for r in results)
|
||||
204
tests/agents/test_tool_action_parser.py
Normal file
204
tests/agents/test_tool_action_parser.py
Normal file
@@ -0,0 +1,204 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from application.agents.tools.tool_action_parser import ToolActionParser
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestToolActionParser:
|
||||
|
||||
def test_parser_initialization(self):
|
||||
parser = ToolActionParser("OpenAILLM")
|
||||
assert parser.llm_type == "OpenAILLM"
|
||||
assert "OpenAILLM" in parser.parsers
|
||||
assert "GoogleLLM" in parser.parsers
|
||||
|
||||
def test_parse_openai_llm_valid_call(self):
|
||||
parser = ToolActionParser("OpenAILLM")
|
||||
|
||||
call = Mock()
|
||||
call.name = "get_data_123"
|
||||
call.arguments = '{"param1": "value1", "param2": "value2"}'
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
assert tool_id == "123"
|
||||
assert action_name == "get_data"
|
||||
assert call_args == {"param1": "value1", "param2": "value2"}
|
||||
|
||||
def test_parse_openai_llm_with_underscore_in_action(self):
|
||||
parser = ToolActionParser("OpenAILLM")
|
||||
|
||||
call = Mock()
|
||||
call.name = "send_email_notification_456"
|
||||
call.arguments = '{"to": "user@example.com"}'
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
assert tool_id == "456"
|
||||
assert action_name == "send_email_notification"
|
||||
assert call_args == {"to": "user@example.com"}
|
||||
|
||||
def test_parse_openai_llm_invalid_format_no_underscore(self):
|
||||
parser = ToolActionParser("OpenAILLM")
|
||||
|
||||
call = Mock()
|
||||
call.name = "invalidtoolname"
|
||||
call.arguments = "{}"
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
assert tool_id is None
|
||||
assert action_name is None
|
||||
assert call_args is None
|
||||
|
||||
def test_parse_openai_llm_non_numeric_tool_id(self):
|
||||
parser = ToolActionParser("OpenAILLM")
|
||||
|
||||
call = Mock()
|
||||
call.name = "action_notanumber"
|
||||
call.arguments = "{}"
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
assert tool_id == "notanumber"
|
||||
assert action_name == "action"
|
||||
|
||||
def test_parse_openai_llm_malformed_json(self):
|
||||
parser = ToolActionParser("OpenAILLM")
|
||||
|
||||
call = Mock()
|
||||
call.name = "action_123"
|
||||
call.arguments = "invalid json"
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
assert tool_id is None
|
||||
assert action_name is None
|
||||
assert call_args is None
|
||||
|
||||
def test_parse_openai_llm_missing_attributes(self):
|
||||
parser = ToolActionParser("OpenAILLM")
|
||||
|
||||
call = Mock(spec=[])
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
assert tool_id is None
|
||||
assert action_name is None
|
||||
assert call_args is None
|
||||
|
||||
def test_parse_google_llm_valid_call(self):
|
||||
parser = ToolActionParser("GoogleLLM")
|
||||
|
||||
call = Mock()
|
||||
call.name = "search_documents_789"
|
||||
call.arguments = {"query": "test query", "limit": 10}
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
assert tool_id == "789"
|
||||
assert action_name == "search_documents"
|
||||
assert call_args == {"query": "test query", "limit": 10}
|
||||
|
||||
def test_parse_google_llm_with_complex_action_name(self):
|
||||
parser = ToolActionParser("GoogleLLM")
|
||||
|
||||
call = Mock()
|
||||
call.name = "create_new_user_account_999"
|
||||
call.arguments = {"username": "test"}
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
assert tool_id == "999"
|
||||
assert action_name == "create_new_user_account"
|
||||
|
||||
def test_parse_google_llm_invalid_format(self):
|
||||
parser = ToolActionParser("GoogleLLM")
|
||||
|
||||
call = Mock()
|
||||
call.name = "nounderscores"
|
||||
call.arguments = {}
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
assert tool_id is None
|
||||
assert action_name is None
|
||||
assert call_args is None
|
||||
|
||||
def test_parse_google_llm_missing_attributes(self):
|
||||
parser = ToolActionParser("GoogleLLM")
|
||||
|
||||
call = Mock(spec=[])
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
assert tool_id is None
|
||||
assert action_name is None
|
||||
assert call_args is None
|
||||
|
||||
def test_parse_unknown_llm_type_defaults_to_openai(self):
|
||||
parser = ToolActionParser("UnknownLLM")
|
||||
|
||||
call = Mock()
|
||||
call.name = "action_123"
|
||||
call.arguments = '{"key": "value"}'
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
assert tool_id == "123"
|
||||
assert action_name == "action"
|
||||
assert call_args == {"key": "value"}
|
||||
|
||||
def test_parse_args_empty_arguments_openai(self):
|
||||
parser = ToolActionParser("OpenAILLM")
|
||||
|
||||
call = Mock()
|
||||
call.name = "action_123"
|
||||
call.arguments = "{}"
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
assert tool_id == "123"
|
||||
assert action_name == "action"
|
||||
assert call_args == {}
|
||||
|
||||
def test_parse_args_empty_arguments_google(self):
|
||||
parser = ToolActionParser("GoogleLLM")
|
||||
|
||||
call = Mock()
|
||||
call.name = "action_456"
|
||||
call.arguments = {}
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
assert tool_id == "456"
|
||||
assert action_name == "action"
|
||||
assert call_args == {}
|
||||
|
||||
def test_parse_args_with_special_characters(self):
|
||||
parser = ToolActionParser("OpenAILLM")
|
||||
|
||||
call = Mock()
|
||||
call.name = "send_message_123"
|
||||
call.arguments = '{"message": "Hello, World! 你好"}'
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
assert tool_id == "123"
|
||||
assert action_name == "send_message"
|
||||
assert call_args["message"] == "Hello, World! 你好"
|
||||
|
||||
def test_parse_args_with_nested_objects(self):
|
||||
parser = ToolActionParser("OpenAILLM")
|
||||
|
||||
call = Mock()
|
||||
call.name = "create_record_123"
|
||||
call.arguments = '{"data": {"name": "John", "age": 30}}'
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
assert tool_id == "123"
|
||||
assert action_name == "create_record"
|
||||
assert call_args["data"]["name"] == "John"
|
||||
assert call_args["data"]["age"] == 30
|
||||
235
tests/agents/test_tool_manager.py
Normal file
235
tests/agents/test_tool_manager.py
Normal file
@@ -0,0 +1,235 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from application.agents.tools.base import Tool
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
|
||||
|
||||
class MockTool(Tool):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def execute_action(self, action_name: str, **kwargs):
|
||||
return f"Executed {action_name} with {kwargs}"
|
||||
|
||||
def get_actions_metadata(self):
|
||||
return [{"name": "test_action", "description": "Test action"}]
|
||||
|
||||
def get_config_requirements(self):
|
||||
return {"required": ["api_key"]}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestToolManager:
|
||||
|
||||
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
|
||||
def test_tool_manager_initialization(self, mock_iter):
|
||||
mock_iter.return_value = []
|
||||
|
||||
config = {"tool1": {"key": "value"}}
|
||||
manager = ToolManager(config)
|
||||
|
||||
assert manager.config == config
|
||||
assert isinstance(manager.tools, dict)
|
||||
|
||||
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
|
||||
@patch("application.agents.tools.tool_manager.importlib.import_module")
|
||||
def test_load_tools_skips_base_and_private(self, mock_import, mock_iter):
|
||||
mock_iter.return_value = [
|
||||
(None, "base", False),
|
||||
(None, "__init__", False),
|
||||
(None, "__pycache__", False),
|
||||
(None, "valid_tool", False),
|
||||
]
|
||||
|
||||
mock_module = Mock()
|
||||
mock_module.MockTool = MockTool
|
||||
mock_import.return_value = mock_module
|
||||
|
||||
manager = ToolManager({})
|
||||
|
||||
assert "base" not in manager.tools
|
||||
assert "__init__" not in manager.tools
|
||||
|
||||
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
|
||||
def test_load_tools_creates_tool_instances(self, mock_iter):
|
||||
mock_iter.return_value = []
|
||||
|
||||
manager = ToolManager({})
|
||||
|
||||
mock_tool = MockTool({"test": "config"})
|
||||
manager.tools["mock_tool"] = mock_tool
|
||||
|
||||
assert "mock_tool" in manager.tools
|
||||
assert isinstance(manager.tools["mock_tool"], MockTool)
|
||||
assert manager.tools["mock_tool"].config == {"test": "config"}
|
||||
|
||||
def test_load_tool_with_user_id(self):
|
||||
with patch(
|
||||
"application.agents.tools.tool_manager.pkgutil.iter_modules",
|
||||
return_value=[],
|
||||
):
|
||||
manager = ToolManager({})
|
||||
tool = MockTool({"key": "value"})
|
||||
assert tool.config == {"key": "value"}
|
||||
|
||||
manager.config["test_tool"] = {"key": "value"}
|
||||
assert "test_tool" in manager.config
|
||||
|
||||
def test_load_tool_without_user_id(self):
|
||||
tool = MockTool({"api_key": "test123"})
|
||||
|
||||
assert isinstance(tool, MockTool)
|
||||
assert tool.config == {"api_key": "test123"}
|
||||
|
||||
assert hasattr(tool, "execute_action")
|
||||
assert hasattr(tool, "get_actions_metadata")
|
||||
|
||||
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
|
||||
def test_load_tool_updates_config(self, mock_iter):
|
||||
mock_iter.return_value = []
|
||||
|
||||
manager = ToolManager({})
|
||||
new_config = {"new_key": "new_value"}
|
||||
|
||||
manager.config["test_tool"] = new_config
|
||||
|
||||
assert manager.config["test_tool"] == new_config
|
||||
assert "test_tool" in manager.config
|
||||
|
||||
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
|
||||
@patch("application.agents.tools.tool_manager.importlib.import_module")
|
||||
def test_execute_action_on_loaded_tool(self, mock_import, mock_iter):
|
||||
mock_iter.return_value = [(None, "mock_tool", False)]
|
||||
|
||||
mock_tool_instance = MockTool({})
|
||||
|
||||
with patch("inspect.getmembers", return_value=[("MockTool", MockTool)]):
|
||||
with patch("inspect.isclass", return_value=True):
|
||||
with patch.object(MockTool, "__init__", return_value=None):
|
||||
manager = ToolManager({})
|
||||
manager.tools["mock_tool"] = mock_tool_instance
|
||||
|
||||
result = manager.execute_action(
|
||||
"mock_tool", "test_action", param="value"
|
||||
)
|
||||
|
||||
assert "Executed test_action" in result
|
||||
|
||||
def test_execute_action_tool_not_loaded(self):
|
||||
with patch(
|
||||
"application.agents.tools.tool_manager.pkgutil.iter_modules",
|
||||
return_value=[],
|
||||
):
|
||||
manager = ToolManager({})
|
||||
with pytest.raises(ValueError, match="Tool 'nonexistent' not loaded"):
|
||||
manager.execute_action("nonexistent", "action")
|
||||
|
||||
@patch("application.agents.tools.tool_manager.importlib.import_module")
|
||||
def test_execute_action_with_user_id_for_mcp_tool(self, mock_import):
|
||||
mock_tool = MockTool({})
|
||||
|
||||
with patch("inspect.getmembers", return_value=[("MockTool", MockTool)]):
|
||||
with patch("inspect.isclass", return_value=True):
|
||||
manager = ToolManager({"mcp_tool": {}})
|
||||
manager.tools["mcp_tool"] = mock_tool
|
||||
|
||||
with patch.object(
|
||||
manager, "load_tool", return_value=mock_tool
|
||||
) as mock_load:
|
||||
manager.execute_action("mcp_tool", "action", user_id="user123")
|
||||
|
||||
mock_load.assert_called_once_with("mcp_tool", {}, "user123")
|
||||
|
||||
@patch("application.agents.tools.tool_manager.importlib.import_module")
|
||||
def test_execute_action_with_user_id_for_memory_tool(self, mock_import):
|
||||
mock_tool = MockTool({})
|
||||
|
||||
with patch("inspect.getmembers", return_value=[("MockTool", MockTool)]):
|
||||
with patch("inspect.isclass", return_value=True):
|
||||
manager = ToolManager({"memory": {}})
|
||||
manager.tools["memory"] = mock_tool
|
||||
|
||||
with patch.object(
|
||||
manager, "load_tool", return_value=mock_tool
|
||||
) as mock_load:
|
||||
manager.execute_action("memory", "view", user_id="user456")
|
||||
|
||||
mock_load.assert_called_once_with("memory", {}, "user456")
|
||||
|
||||
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
|
||||
@patch("application.agents.tools.tool_manager.importlib.import_module")
|
||||
def test_get_all_actions_metadata(self, mock_import, mock_iter):
|
||||
mock_iter.return_value = [(None, "tool1", False), (None, "tool2", False)]
|
||||
|
||||
mock_tool1 = Mock()
|
||||
mock_tool1.get_actions_metadata.return_value = [{"name": "action1"}]
|
||||
|
||||
mock_tool2 = Mock()
|
||||
mock_tool2.get_actions_metadata.return_value = [{"name": "action2"}]
|
||||
|
||||
manager = ToolManager({})
|
||||
manager.tools = {"tool1": mock_tool1, "tool2": mock_tool2}
|
||||
|
||||
metadata = manager.get_all_actions_metadata()
|
||||
|
||||
assert len(metadata) == 2
|
||||
assert {"name": "action1"} in metadata
|
||||
assert {"name": "action2"} in metadata
|
||||
|
||||
@patch("application.agents.tools.tool_manager.pkgutil.iter_modules")
|
||||
def test_get_all_actions_metadata_empty(self, mock_iter):
|
||||
mock_iter.return_value = []
|
||||
|
||||
manager = ToolManager({})
|
||||
manager.tools = {}
|
||||
|
||||
metadata = manager.get_all_actions_metadata()
|
||||
|
||||
assert metadata == []
|
||||
|
||||
def test_load_tool_with_notes_tool(self):
|
||||
tool = MockTool({"key": "value"})
|
||||
|
||||
assert isinstance(tool, MockTool)
|
||||
assert tool.config == {"key": "value"}
|
||||
|
||||
result = tool.execute_action("test_action", param="value")
|
||||
assert "test_action" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestToolBase:
|
||||
|
||||
def test_tool_base_is_abstract(self):
|
||||
with pytest.raises(TypeError):
|
||||
Tool()
|
||||
|
||||
def test_mock_tool_implements_interface(self):
|
||||
tool = MockTool({"test": "config"})
|
||||
|
||||
assert hasattr(tool, "execute_action")
|
||||
assert hasattr(tool, "get_actions_metadata")
|
||||
assert hasattr(tool, "get_config_requirements")
|
||||
|
||||
def test_mock_tool_execute_action(self):
|
||||
tool = MockTool({})
|
||||
result = tool.execute_action("test", param="value")
|
||||
|
||||
assert "Executed test" in result
|
||||
assert "param" in result
|
||||
|
||||
def test_mock_tool_get_actions_metadata(self):
|
||||
tool = MockTool({})
|
||||
metadata = tool.get_actions_metadata()
|
||||
|
||||
assert isinstance(metadata, list)
|
||||
assert len(metadata) > 0
|
||||
assert "name" in metadata[0]
|
||||
|
||||
def test_mock_tool_get_config_requirements(self):
|
||||
tool = MockTool({})
|
||||
requirements = tool.get_config_requirements()
|
||||
|
||||
assert isinstance(requirements, dict)
|
||||
assert "required" in requirements
|
||||
199
tests/conftest.py
Normal file
199
tests/conftest.py
Normal file
@@ -0,0 +1,199 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm():
|
||||
llm = Mock()
|
||||
llm.gen_stream = Mock()
|
||||
llm._supports_tools = True
|
||||
llm._supports_structured_output = Mock(return_value=False)
|
||||
llm.__class__.__name__ = "MockLLM"
|
||||
return llm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_handler():
|
||||
handler = Mock()
|
||||
handler.process_message_flow = Mock()
|
||||
return handler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_retriever():
|
||||
retriever = Mock()
|
||||
retriever.search = Mock(
|
||||
return_value=[
|
||||
{"text": "Test document 1", "filename": "doc1.txt", "source": "test"},
|
||||
{"text": "Test document 2", "title": "doc2.txt", "source": "test"},
|
||||
]
|
||||
)
|
||||
return retriever
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mongo_db(monkeypatch):
|
||||
fake_collection = FakeMongoCollection()
|
||||
fake_db = {
|
||||
"agents": fake_collection,
|
||||
"user_tools": fake_collection,
|
||||
"memories": fake_collection,
|
||||
}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.core.mongo_db.MongoDB.get_client", lambda: fake_client
|
||||
)
|
||||
return fake_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_chat_history():
|
||||
return [
|
||||
{"prompt": "What is Python?", "response": "Python is a programming language."},
|
||||
{"prompt": "Tell me more.", "response": "Python is known for simplicity."},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tool_call():
|
||||
return {
|
||||
"tool_name": "test_tool",
|
||||
"call_id": "123",
|
||||
"action_name": "test_action",
|
||||
"arguments": {"arg1": "value1"},
|
||||
"result": "Tool executed successfully",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def decoded_token():
|
||||
return {"sub": "test_user", "email": "test@example.com"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def log_context():
|
||||
from application.logging import LogContext
|
||||
|
||||
context = LogContext(
|
||||
endpoint="test_endpoint",
|
||||
activity_id="test_activity",
|
||||
user="test_user",
|
||||
api_key="test_key",
|
||||
query="test query",
|
||||
)
|
||||
return context
|
||||
|
||||
|
||||
class FakeMongoCollection:
|
||||
def __init__(self):
|
||||
self.docs = {}
|
||||
|
||||
def find_one(self, query, projection=None):
|
||||
if "key" in query:
|
||||
return self.docs.get(query["key"])
|
||||
if "_id" in query:
|
||||
return self.docs.get(str(query["_id"]))
|
||||
if "user" in query:
|
||||
for doc in self.docs.values():
|
||||
if doc.get("user") == query["user"]:
|
||||
return doc
|
||||
return None
|
||||
|
||||
def find(self, query, projection=None):
|
||||
results = []
|
||||
if "_id" in query and "$in" in query["_id"]:
|
||||
for doc_id in query["_id"]["$in"]:
|
||||
doc = self.docs.get(str(doc_id))
|
||||
if doc:
|
||||
results.append(doc)
|
||||
elif "user" in query:
|
||||
for doc in self.docs.values():
|
||||
if doc.get("user") == query["user"]:
|
||||
if "status" in query:
|
||||
if doc.get("status") == query["status"]:
|
||||
results.append(doc)
|
||||
else:
|
||||
results.append(doc)
|
||||
return results
|
||||
|
||||
def insert_one(self, doc):
|
||||
doc_id = doc.get("_id", len(self.docs))
|
||||
self.docs[str(doc_id)] = doc
|
||||
return Mock(inserted_id=doc_id)
|
||||
|
||||
def update_one(self, query, update, upsert=False):
|
||||
return Mock(modified_count=1)
|
||||
|
||||
def delete_one(self, query):
|
||||
return Mock(deleted_count=1)
|
||||
|
||||
def delete_many(self, query):
|
||||
return Mock(deleted_count=0)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_creator(mock_llm, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"application.llm.llm_creator.LLMCreator.create_llm", Mock(return_value=mock_llm)
|
||||
)
|
||||
return mock_llm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_handler_creator(mock_llm_handler, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"application.llm.handlers.handler_creator.LLMHandlerCreator.create_handler",
|
||||
Mock(return_value=mock_llm_handler),
|
||||
)
|
||||
return mock_llm_handler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_base_params(decoded_token):
|
||||
return {
|
||||
"endpoint": "https://api.example.com",
|
||||
"llm_name": "openai",
|
||||
"gpt_model": "gpt-4",
|
||||
"api_key": "test_api_key",
|
||||
"user_api_key": None,
|
||||
"prompt": "You are a helpful assistant.",
|
||||
"chat_history": [],
|
||||
"decoded_token": decoded_token,
|
||||
"attachments": [],
|
||||
"json_schema": None,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool():
|
||||
tool = Mock()
|
||||
tool.execute_action = Mock(return_value="Tool result")
|
||||
tool.get_actions_metadata = Mock(
|
||||
return_value=[
|
||||
{
|
||||
"name": "test_action",
|
||||
"description": "A test action",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param1": {"type": "string", "description": "Test parameter"}
|
||||
},
|
||||
"required": ["param1"],
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
return tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_manager(mock_tool, monkeypatch):
|
||||
manager = Mock()
|
||||
manager.load_tool = Mock(return_value=mock_tool)
|
||||
monkeypatch.setattr(
|
||||
"application.agents.base.ToolManager", Mock(return_value=manager)
|
||||
)
|
||||
return manager
|
||||
@@ -1,6 +1,6 @@
|
||||
import types
|
||||
import pytest
|
||||
|
||||
import pytest
|
||||
from application.llm.openai import OpenAILLM
|
||||
|
||||
|
||||
@@ -42,16 +42,16 @@ class FakeChatCompletions:
|
||||
|
||||
def create(self, **kwargs):
|
||||
self.last_kwargs = kwargs
|
||||
# default non-streaming: return content
|
||||
if not kwargs.get("stream"):
|
||||
return FakeChatCompletions._Response(choices=[
|
||||
FakeChatCompletions._Choice(content="hello world")
|
||||
])
|
||||
# streaming: yield line objects each with choices[0].delta.content
|
||||
return FakeChatCompletions._Response(lines=[
|
||||
FakeChatCompletions._StreamLine(["part1"]),
|
||||
FakeChatCompletions._StreamLine(["part2"]),
|
||||
])
|
||||
return FakeChatCompletions._Response(
|
||||
choices=[FakeChatCompletions._Choice(content="hello world")]
|
||||
)
|
||||
return FakeChatCompletions._Response(
|
||||
lines=[
|
||||
FakeChatCompletions._StreamLine(["part1"]),
|
||||
FakeChatCompletions._StreamLine(["part2"]),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class FakeClient:
|
||||
@@ -71,16 +71,29 @@ def openai_llm(monkeypatch):
|
||||
return llm
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_clean_messages_openai_variants(openai_llm):
|
||||
messages = [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "model", "content": "asst"},
|
||||
{"role": "user", "content": [
|
||||
{"text": "hello"},
|
||||
{"function_call": {"call_id": "c1", "name": "fn", "args": {"a": 1}}},
|
||||
{"function_response": {"call_id": "c1", "name": "fn", "response": {"result": 42}}},
|
||||
{"type": "image_url", "image_url": {"url": ""}},
|
||||
]},
|
||||
{"role": "model", "content": "asst"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"text": "hello"},
|
||||
{"function_call": {"call_id": "c1", "name": "fn", "args": {"a": 1}}},
|
||||
{
|
||||
"function_response": {
|
||||
"call_id": "c1",
|
||||
"name": "fn",
|
||||
"response": {"result": 42},
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": ""},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
cleaned = openai_llm._clean_messages_openai(messages)
|
||||
@@ -89,17 +102,27 @@ def test_clean_messages_openai_variants(openai_llm):
|
||||
assert roles.count("assistant") >= 1
|
||||
assert any(m["role"] == "tool" for m in cleaned)
|
||||
|
||||
assert any(isinstance(m["content"], list) and any(
|
||||
part.get("type") == "image_url" for part in m["content"] if isinstance(part, dict)
|
||||
) for m in cleaned if m["role"] == "user")
|
||||
assert any(
|
||||
isinstance(m["content"], list)
|
||||
and any(
|
||||
part.get("type") == "image_url"
|
||||
for part in m["content"]
|
||||
if isinstance(part, dict)
|
||||
)
|
||||
for m in cleaned
|
||||
if m["role"] == "user"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_raw_gen_calls_openai_client_and_returns_content(openai_llm):
|
||||
msgs = [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "hello"},
|
||||
]
|
||||
content = openai_llm._raw_gen(openai_llm, model="gpt-4o", messages=msgs, stream=False)
|
||||
content = openai_llm._raw_gen(
|
||||
openai_llm, model="gpt-4o", messages=msgs, stream=False
|
||||
)
|
||||
assert content == "hello world"
|
||||
|
||||
passed = openai_llm.client.chat.completions.last_kwargs
|
||||
@@ -108,16 +131,20 @@ def test_raw_gen_calls_openai_client_and_returns_content(openai_llm):
|
||||
assert passed["stream"] is False
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_raw_gen_stream_yields_chunks(openai_llm):
|
||||
msgs = [
|
||||
{"role": "user", "content": "hi"},
|
||||
]
|
||||
gen = openai_llm._raw_gen_stream(openai_llm, model="gpt", messages=msgs, stream=True)
|
||||
gen = openai_llm._raw_gen_stream(
|
||||
openai_llm, model="gpt", messages=msgs, stream=True
|
||||
)
|
||||
chunks = list(gen)
|
||||
assert "part1" in "".join(chunks)
|
||||
assert "part2" in "".join(chunks)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_prepare_structured_output_format_enforces_required_and_strict(openai_llm):
|
||||
schema = {
|
||||
"type": "object",
|
||||
@@ -134,8 +161,8 @@ def test_prepare_structured_output_format_enforces_required_and_strict(openai_ll
|
||||
assert js["schema"]["additionalProperties"] is False
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_prepare_messages_with_attachments_image_and_pdf(openai_llm, monkeypatch):
|
||||
|
||||
monkeypatch.setattr(openai_llm, "_get_base64_image", lambda att: "AAA=")
|
||||
monkeypatch.setattr(openai_llm, "_upload_file_to_openai", lambda att: "file_xyz")
|
||||
|
||||
@@ -146,12 +173,15 @@ def test_prepare_messages_with_attachments_image_and_pdf(openai_llm, monkeypatch
|
||||
]
|
||||
out = openai_llm.prepare_messages_with_attachments(messages, attachments)
|
||||
|
||||
# last user message should have list content with text and two attachments
|
||||
user_msg = next(m for m in out if m["role"] == "user")
|
||||
assert isinstance(user_msg["content"], list)
|
||||
types_in_content = [p.get("type") for p in user_msg["content"] if isinstance(p, dict)]
|
||||
types_in_content = [
|
||||
p.get("type") for p in user_msg["content"] if isinstance(p, dict)
|
||||
]
|
||||
assert "image_url" in types_in_content or any(
|
||||
isinstance(p, dict) and p.get("image_url") for p in user_msg["content"]
|
||||
)
|
||||
assert any(isinstance(p, dict) and p.get("file", {}).get("file_id") == "file_xyz" for p in user_msg["content"])
|
||||
|
||||
assert any(
|
||||
isinstance(p, dict) and p.get("file", {}).get("file_id") == "file_xyz"
|
||||
for p in user_msg["content"]
|
||||
)
|
||||
|
||||
3
tests/requirements.txt
Normal file
3
tests/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
pytest>=8.0.0
|
||||
pytest-cov>=4.1.0
|
||||
coverage>=7.4.0
|
||||
@@ -1,12 +1,24 @@
|
||||
import base64
|
||||
|
||||
import pytest
|
||||
from application.security import encryption
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
|
||||
from application.security import encryption
|
||||
|
||||
def _fake_os_urandom_factory(values):
|
||||
values_iter = iter(values)
|
||||
|
||||
def _fake(length):
|
||||
value = next(values_iter)
|
||||
assert len(value) == length
|
||||
return value
|
||||
|
||||
return _fake
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_derive_key_uses_secret_and_user(monkeypatch):
|
||||
monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret")
|
||||
salt = bytes(range(16))
|
||||
@@ -25,17 +37,7 @@ def test_derive_key_uses_secret_and_user(monkeypatch):
|
||||
assert derived == expected_key
|
||||
|
||||
|
||||
def _fake_os_urandom_factory(values):
|
||||
values_iter = iter(values)
|
||||
|
||||
def _fake(length):
|
||||
value = next(values_iter)
|
||||
assert len(value) == length
|
||||
return value
|
||||
|
||||
return _fake
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_encrypt_and_decrypt_round_trip(monkeypatch):
|
||||
monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret")
|
||||
salt = bytes(range(16))
|
||||
@@ -55,6 +57,7 @@ def test_encrypt_and_decrypt_round_trip(monkeypatch):
|
||||
assert decrypted == credentials
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_encrypt_credentials_returns_empty_for_empty_input(monkeypatch):
|
||||
monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret")
|
||||
|
||||
@@ -62,11 +65,12 @@ def test_encrypt_credentials_returns_empty_for_empty_input(monkeypatch):
|
||||
assert encryption.encrypt_credentials(None, "user-123") == ""
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_encrypt_credentials_returns_empty_on_serialization_error(monkeypatch):
|
||||
monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret")
|
||||
monkeypatch.setattr(encryption.os, "urandom", lambda length: b"\x00" * length)
|
||||
|
||||
class NonSerializable: # pragma: no cover - simple helper container
|
||||
class NonSerializable:
|
||||
pass
|
||||
|
||||
credentials = {"bad": NonSerializable()}
|
||||
@@ -74,6 +78,7 @@ def test_encrypt_credentials_returns_empty_on_serialization_error(monkeypatch):
|
||||
assert encryption.encrypt_credentials(credentials, "user-123") == ""
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_decrypt_credentials_returns_empty_for_invalid_input(monkeypatch):
|
||||
monkeypatch.setattr(encryption.settings, "ENCRYPTION_SECRET_KEY", "test-secret")
|
||||
|
||||
@@ -84,6 +89,7 @@ def test_decrypt_credentials_returns_empty_for_invalid_input(monkeypatch):
|
||||
assert encryption.decrypt_credentials(invalid_payload, "user-123") == {}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_pad_and_unpad_are_inverse():
|
||||
original = b"secret-data"
|
||||
|
||||
@@ -91,4 +97,3 @@ def test_pad_and_unpad_are_inverse():
|
||||
|
||||
assert len(padded) % 16 == 0
|
||||
assert encryption._unpad_data(padded) == original
|
||||
|
||||
|
||||
@@ -1,352 +1,401 @@
|
||||
"""Tests for LocalStorage implementation
|
||||
"""
|
||||
|
||||
import io
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock, mock_open
|
||||
import os
|
||||
from unittest.mock import MagicMock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
from application.storage.local import LocalStorage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_base_dir():
|
||||
"""Provide a temporary base directory path for testing."""
|
||||
return "/tmp/test_storage"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def local_storage(temp_base_dir):
|
||||
"""Create LocalStorage instance with test base directory."""
|
||||
return LocalStorage(base_dir=temp_base_dir)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLocalStorageInitialization:
|
||||
"""Test LocalStorage initialization and configuration."""
|
||||
|
||||
def test_init_with_custom_base_dir(self):
|
||||
"""Should use provided base directory."""
|
||||
storage = LocalStorage(base_dir="/custom/path")
|
||||
assert storage.base_dir == "/custom/path"
|
||||
|
||||
def test_init_with_default_base_dir(self):
|
||||
"""Should use default base directory when none provided."""
|
||||
storage = LocalStorage()
|
||||
# Default is three levels up from the file location
|
||||
assert storage.base_dir is not None
|
||||
assert isinstance(storage.base_dir, str)
|
||||
|
||||
def test_get_full_path_with_relative_path(self, local_storage):
|
||||
"""Should combine base_dir with relative path."""
|
||||
result = local_storage._get_full_path("documents/test.txt")
|
||||
assert result == "/tmp/test_storage/documents/test.txt"
|
||||
expected = os.path.join("/tmp/test_storage", "documents/test.txt")
|
||||
assert os.path.normpath(result) == os.path.normpath(expected)
|
||||
|
||||
def test_get_full_path_with_absolute_path(self, local_storage):
|
||||
"""Should return absolute path unchanged."""
|
||||
result = local_storage._get_full_path("/absolute/path/test.txt")
|
||||
assert result == "/absolute/path/test.txt"
|
||||
|
||||
|
||||
class TestLocalStorageSaveFile:
|
||||
"""Test file saving functionality."""
|
||||
|
||||
@patch('os.makedirs')
|
||||
@patch('builtins.open', new_callable=mock_open)
|
||||
@patch('shutil.copyfileobj')
|
||||
@patch("os.makedirs")
|
||||
@patch("builtins.open", new_callable=mock_open)
|
||||
@patch("shutil.copyfileobj")
|
||||
def test_save_file_creates_directory_and_saves(
|
||||
self, mock_copyfileobj, mock_file, mock_makedirs, local_storage
|
||||
):
|
||||
"""Should create directory and save file content."""
|
||||
file_data = io.BytesIO(b"test content")
|
||||
path = "documents/test.txt"
|
||||
|
||||
result = local_storage.save_file(file_data, path)
|
||||
|
||||
# Verify directory creation
|
||||
mock_makedirs.assert_called_once_with(
|
||||
"/tmp/test_storage/documents",
|
||||
exist_ok=True
|
||||
expected_dir = os.path.join("/tmp/test_storage", "documents")
|
||||
expected_file = os.path.join("/tmp/test_storage", "documents/test.txt")
|
||||
|
||||
assert mock_makedirs.call_count == 1
|
||||
assert os.path.normpath(mock_makedirs.call_args[0][0]) == os.path.normpath(
|
||||
expected_dir
|
||||
)
|
||||
assert mock_makedirs.call_args[1] == {"exist_ok": True}
|
||||
|
||||
assert mock_file.call_count == 1
|
||||
assert os.path.normpath(mock_file.call_args[0][0]) == os.path.normpath(
|
||||
expected_file
|
||||
)
|
||||
assert mock_file.call_args[0][1] == "wb"
|
||||
|
||||
# Verify file write
|
||||
mock_file.assert_called_once_with("/tmp/test_storage/documents/test.txt", 'wb')
|
||||
mock_copyfileobj.assert_called_once_with(file_data, mock_file())
|
||||
assert result == {"storage_type": "local"}
|
||||
|
||||
# Verify result
|
||||
assert result == {'storage_type': 'local'}
|
||||
|
||||
@patch('os.makedirs')
|
||||
@patch("os.makedirs")
|
||||
def test_save_file_with_save_method(self, mock_makedirs, local_storage):
|
||||
"""Should use save method if file_data has it."""
|
||||
file_data = MagicMock()
|
||||
file_data.save = MagicMock()
|
||||
path = "documents/test.txt"
|
||||
|
||||
result = local_storage.save_file(file_data, path)
|
||||
|
||||
# Verify save method was called
|
||||
file_data.save.assert_called_once_with("/tmp/test_storage/documents/test.txt")
|
||||
expected_file = os.path.join("/tmp/test_storage", "documents/test.txt")
|
||||
assert file_data.save.call_count == 1
|
||||
assert os.path.normpath(file_data.save.call_args[0][0]) == os.path.normpath(
|
||||
expected_file
|
||||
)
|
||||
assert result == {"storage_type": "local"}
|
||||
|
||||
# Verify result
|
||||
assert result == {'storage_type': 'local'}
|
||||
|
||||
@patch('os.makedirs')
|
||||
@patch('builtins.open', new_callable=mock_open)
|
||||
def test_save_file_with_absolute_path(self, mock_file, mock_makedirs, local_storage):
|
||||
"""Should handle absolute paths correctly."""
|
||||
@patch("os.makedirs")
|
||||
@patch("builtins.open", new_callable=mock_open)
|
||||
def test_save_file_with_absolute_path(
|
||||
self, mock_file, mock_makedirs, local_storage
|
||||
):
|
||||
file_data = io.BytesIO(b"test content")
|
||||
path = "/absolute/path/test.txt"
|
||||
|
||||
local_storage.save_file(file_data, path)
|
||||
|
||||
mock_makedirs.assert_called_once_with("/absolute/path", exist_ok=True)
|
||||
mock_file.assert_called_once_with("/absolute/path/test.txt", 'wb')
|
||||
mock_file.assert_called_once_with("/absolute/path/test.txt", "wb")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLocalStorageGetFile:
|
||||
"""Test file retrieval functionality."""
|
||||
|
||||
@patch('os.path.exists', return_value=True)
|
||||
@patch('builtins.open', new_callable=mock_open, read_data=b"file content")
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("builtins.open", new_callable=mock_open, read_data=b"file content")
|
||||
def test_get_file_returns_file_handle(self, mock_file, mock_exists, local_storage):
|
||||
"""Should open and return file handle when file exists."""
|
||||
path = "documents/test.txt"
|
||||
|
||||
result = local_storage.get_file(path)
|
||||
|
||||
mock_exists.assert_called_once_with("/tmp/test_storage/documents/test.txt")
|
||||
mock_file.assert_called_once_with("/tmp/test_storage/documents/test.txt", 'rb')
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
assert mock_file.call_count == 1
|
||||
assert os.path.normpath(mock_file.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
assert result is not None
|
||||
|
||||
@patch('os.path.exists', return_value=False)
|
||||
@patch("os.path.exists", return_value=False)
|
||||
def test_get_file_raises_error_when_not_found(self, mock_exists, local_storage):
|
||||
"""Should raise FileNotFoundError when file doesn't exist."""
|
||||
path = "documents/nonexistent.txt"
|
||||
|
||||
with pytest.raises(FileNotFoundError, match="File not found"):
|
||||
local_storage.get_file(path)
|
||||
|
||||
mock_exists.assert_called_once_with("/tmp/test_storage/documents/nonexistent.txt")
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt")
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLocalStorageDeleteFile:
|
||||
"""Test file deletion functionality."""
|
||||
|
||||
@patch('os.remove')
|
||||
@patch('os.path.exists', return_value=True)
|
||||
def test_delete_file_removes_existing_file(self, mock_exists, mock_remove, local_storage):
|
||||
"""Should delete file and return True when file exists."""
|
||||
@patch("os.remove")
|
||||
@patch("os.path.exists", return_value=True)
|
||||
def test_delete_file_removes_existing_file(
|
||||
self, mock_exists, mock_remove, local_storage
|
||||
):
|
||||
path = "documents/test.txt"
|
||||
|
||||
result = local_storage.delete_file(path)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
|
||||
assert result is True
|
||||
mock_exists.assert_called_once_with("/tmp/test_storage/documents/test.txt")
|
||||
mock_remove.assert_called_once_with("/tmp/test_storage/documents/test.txt")
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
assert mock_remove.call_count == 1
|
||||
assert os.path.normpath(mock_remove.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
|
||||
@patch('os.path.exists', return_value=False)
|
||||
@patch("os.path.exists", return_value=False)
|
||||
def test_delete_file_returns_false_when_not_found(self, mock_exists, local_storage):
|
||||
"""Should return False when file doesn't exist."""
|
||||
path = "documents/nonexistent.txt"
|
||||
|
||||
result = local_storage.delete_file(path)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt")
|
||||
assert result is False
|
||||
mock_exists.assert_called_once_with("/tmp/test_storage/documents/nonexistent.txt")
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLocalStorageFileExists:
|
||||
"""Test file existence checking."""
|
||||
|
||||
@patch('os.path.exists', return_value=True)
|
||||
@patch("os.path.exists", return_value=True)
|
||||
def test_file_exists_returns_true_when_file_found(self, mock_exists, local_storage):
|
||||
"""Should return True when file exists."""
|
||||
path = "documents/test.txt"
|
||||
|
||||
result = local_storage.file_exists(path)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
|
||||
assert result is True
|
||||
mock_exists.assert_called_once_with("/tmp/test_storage/documents/test.txt")
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
|
||||
@patch('os.path.exists', return_value=False)
|
||||
@patch("os.path.exists", return_value=False)
|
||||
def test_file_exists_returns_false_when_not_found(self, mock_exists, local_storage):
|
||||
"""Should return False when file doesn't exist."""
|
||||
path = "documents/nonexistent.txt"
|
||||
|
||||
result = local_storage.file_exists(path)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt")
|
||||
assert result is False
|
||||
mock_exists.assert_called_once_with("/tmp/test_storage/documents/nonexistent.txt")
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLocalStorageListFiles:
|
||||
"""Test directory listing functionality."""
|
||||
|
||||
@patch('os.walk')
|
||||
@patch('os.path.exists', return_value=True)
|
||||
@patch("os.walk")
|
||||
@patch("os.path.exists", return_value=True)
|
||||
def test_list_files_returns_all_files_in_directory(
|
||||
self, mock_exists, mock_walk, local_storage
|
||||
):
|
||||
"""Should return all files in directory and subdirectories."""
|
||||
directory = "documents"
|
||||
base_dir = os.path.join("/tmp/test_storage", "documents")
|
||||
|
||||
# Mock os.walk to return files in directory structure
|
||||
mock_walk.return_value = [
|
||||
("/tmp/test_storage/documents", ["subdir"], ["file1.txt", "file2.txt"]),
|
||||
("/tmp/test_storage/documents/subdir", [], ["file3.txt"])
|
||||
(base_dir, ["subdir"], ["file1.txt", "file2.txt"]),
|
||||
(os.path.join(base_dir, "subdir"), [], ["file3.txt"]),
|
||||
]
|
||||
|
||||
result = local_storage.list_files(directory)
|
||||
|
||||
assert len(result) == 3
|
||||
assert "documents/file1.txt" in result
|
||||
assert "documents/file2.txt" in result
|
||||
assert "documents/subdir/file3.txt" in result
|
||||
result_normalized = [os.path.normpath(f) for f in result]
|
||||
assert os.path.normpath("documents/file1.txt") in result_normalized
|
||||
assert os.path.normpath("documents/file2.txt") in result_normalized
|
||||
assert os.path.normpath("documents/subdir/file3.txt") in result_normalized
|
||||
|
||||
mock_exists.assert_called_once_with("/tmp/test_storage/documents")
|
||||
mock_walk.assert_called_once_with("/tmp/test_storage/documents")
|
||||
|
||||
@patch('os.path.exists', return_value=False)
|
||||
@patch("os.path.exists", return_value=False)
|
||||
def test_list_files_returns_empty_list_when_directory_not_found(
|
||||
self, mock_exists, local_storage
|
||||
):
|
||||
"""Should return empty list when directory doesn't exist."""
|
||||
directory = "nonexistent"
|
||||
|
||||
result = local_storage.list_files(directory)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "nonexistent")
|
||||
assert result == []
|
||||
mock_exists.assert_called_once_with("/tmp/test_storage/nonexistent")
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLocalStorageProcessFile:
|
||||
"""Test file processing functionality."""
|
||||
|
||||
@patch('os.path.exists', return_value=True)
|
||||
@patch("os.path.exists", return_value=True)
|
||||
def test_process_file_calls_processor_with_full_path(
|
||||
self, mock_exists, local_storage
|
||||
):
|
||||
"""Should call processor function with full file path."""
|
||||
path = "documents/test.txt"
|
||||
processor_func = MagicMock(return_value="processed")
|
||||
|
||||
result = local_storage.process_file(path, processor_func, extra_arg="value")
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
|
||||
assert result == "processed"
|
||||
processor_func.assert_called_once_with(
|
||||
local_path="/tmp/test_storage/documents/test.txt",
|
||||
extra_arg="value"
|
||||
assert processor_func.call_count == 1
|
||||
call_kwargs = processor_func.call_args[1]
|
||||
assert os.path.normpath(call_kwargs["local_path"]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
mock_exists.assert_called_once_with("/tmp/test_storage/documents/test.txt")
|
||||
assert call_kwargs["extra_arg"] == "value"
|
||||
|
||||
@patch('os.path.exists', return_value=False)
|
||||
def test_process_file_raises_error_when_file_not_found(self, mock_exists, local_storage):
|
||||
"""Should raise FileNotFoundError when file doesn't exist."""
|
||||
@patch("os.path.exists", return_value=False)
|
||||
def test_process_file_raises_error_when_file_not_found(
|
||||
self, mock_exists, local_storage
|
||||
):
|
||||
path = "documents/nonexistent.txt"
|
||||
processor_func = MagicMock()
|
||||
|
||||
with pytest.raises(FileNotFoundError, match="File not found"):
|
||||
local_storage.process_file(path, processor_func)
|
||||
|
||||
processor_func.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLocalStorageIsDirectory:
|
||||
"""Test directory checking functionality."""
|
||||
|
||||
@patch('os.path.isdir', return_value=True)
|
||||
@patch("os.path.isdir", return_value=True)
|
||||
def test_is_directory_returns_true_when_directory_exists(
|
||||
self, mock_isdir, local_storage
|
||||
):
|
||||
"""Should return True when path is a directory."""
|
||||
path = "documents"
|
||||
|
||||
result = local_storage.is_directory(path)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents")
|
||||
assert result is True
|
||||
mock_isdir.assert_called_once_with("/tmp/test_storage/documents")
|
||||
assert mock_isdir.call_count == 1
|
||||
assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
|
||||
@patch('os.path.isdir', return_value=False)
|
||||
@patch("os.path.isdir", return_value=False)
|
||||
def test_is_directory_returns_false_when_not_directory(
|
||||
self, mock_isdir, local_storage
|
||||
):
|
||||
"""Should return False when path is not a directory or doesn't exist."""
|
||||
path = "documents/test.txt"
|
||||
|
||||
result = local_storage.is_directory(path)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
|
||||
assert result is False
|
||||
mock_isdir.assert_called_once_with("/tmp/test_storage/documents/test.txt")
|
||||
assert mock_isdir.call_count == 1
|
||||
assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLocalStorageRemoveDirectory:
|
||||
"""Test directory removal functionality."""
|
||||
|
||||
@patch('shutil.rmtree')
|
||||
@patch('os.path.isdir', return_value=True)
|
||||
@patch('os.path.exists', return_value=True)
|
||||
@patch("shutil.rmtree")
|
||||
@patch("os.path.isdir", return_value=True)
|
||||
@patch("os.path.exists", return_value=True)
|
||||
def test_remove_directory_deletes_directory(
|
||||
self, mock_exists, mock_isdir, mock_rmtree, local_storage
|
||||
):
|
||||
"""Should remove directory and return True when successful."""
|
||||
directory = "documents"
|
||||
|
||||
result = local_storage.remove_directory(directory)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents")
|
||||
assert result is True
|
||||
mock_exists.assert_called_once_with("/tmp/test_storage/documents")
|
||||
mock_isdir.assert_called_once_with("/tmp/test_storage/documents")
|
||||
mock_rmtree.assert_called_once_with("/tmp/test_storage/documents")
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
assert mock_isdir.call_count == 1
|
||||
assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
assert mock_rmtree.call_count == 1
|
||||
assert os.path.normpath(mock_rmtree.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
|
||||
@patch('os.path.exists', return_value=False)
|
||||
@patch("os.path.exists", return_value=False)
|
||||
def test_remove_directory_returns_false_when_not_exists(
|
||||
self, mock_exists, local_storage
|
||||
):
|
||||
"""Should return False when directory doesn't exist."""
|
||||
directory = "nonexistent"
|
||||
|
||||
result = local_storage.remove_directory(directory)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "nonexistent")
|
||||
assert result is False
|
||||
mock_exists.assert_called_once_with("/tmp/test_storage/nonexistent")
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
|
||||
@patch('os.path.isdir', return_value=False)
|
||||
@patch('os.path.exists', return_value=True)
|
||||
@patch("os.path.isdir", return_value=False)
|
||||
@patch("os.path.exists", return_value=True)
|
||||
def test_remove_directory_returns_false_when_not_directory(
|
||||
self, mock_exists, mock_isdir, local_storage
|
||||
):
|
||||
"""Should return False when path is not a directory."""
|
||||
path = "documents/test.txt"
|
||||
|
||||
result = local_storage.remove_directory(path)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
|
||||
assert result is False
|
||||
mock_exists.assert_called_once_with("/tmp/test_storage/documents/test.txt")
|
||||
mock_isdir.assert_called_once_with("/tmp/test_storage/documents/test.txt")
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
assert mock_isdir.call_count == 1
|
||||
assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
|
||||
@patch('shutil.rmtree', side_effect=OSError("Permission denied"))
|
||||
@patch('os.path.isdir', return_value=True)
|
||||
@patch('os.path.exists', return_value=True)
|
||||
@patch("shutil.rmtree", side_effect=OSError("Permission denied"))
|
||||
@patch("os.path.isdir", return_value=True)
|
||||
@patch("os.path.exists", return_value=True)
|
||||
def test_remove_directory_returns_false_on_os_error(
|
||||
self, mock_exists, mock_isdir, mock_rmtree, local_storage
|
||||
):
|
||||
"""Should return False when OSError occurs during removal."""
|
||||
directory = "documents"
|
||||
|
||||
result = local_storage.remove_directory(directory)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents")
|
||||
assert result is False
|
||||
mock_rmtree.assert_called_once_with("/tmp/test_storage/documents")
|
||||
assert mock_rmtree.call_count == 1
|
||||
assert os.path.normpath(mock_rmtree.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
|
||||
@patch('shutil.rmtree', side_effect=PermissionError("Access denied"))
|
||||
@patch('os.path.isdir', return_value=True)
|
||||
@patch('os.path.exists', return_value=True)
|
||||
@patch("shutil.rmtree", side_effect=PermissionError("Access denied"))
|
||||
@patch("os.path.isdir", return_value=True)
|
||||
@patch("os.path.exists", return_value=True)
|
||||
def test_remove_directory_returns_false_on_permission_error(
|
||||
self, mock_exists, mock_isdir, mock_rmtree, local_storage
|
||||
):
|
||||
"""Should return False when PermissionError occurs during removal."""
|
||||
directory = "documents"
|
||||
|
||||
result = local_storage.remove_directory(directory)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents")
|
||||
assert result is False
|
||||
mock_rmtree.assert_called_once_with("/tmp/test_storage/documents")
|
||||
assert mock_rmtree.call_count == 1
|
||||
assert os.path.normpath(mock_rmtree.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
)
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
"""Tests for S3 storage implementation.
|
||||
"""
|
||||
"""Tests for S3 storage implementation."""
|
||||
|
||||
import io
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from application.storage.s3 import S3Storage
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_boto3_client():
|
||||
"""Mock boto3.client to isolate S3 client creation."""
|
||||
with patch('boto3.client') as mock_client:
|
||||
with patch("boto3.client") as mock_client:
|
||||
s3_mock = MagicMock()
|
||||
mock_client.return_value = s3_mock
|
||||
yield s3_mock
|
||||
@@ -27,22 +27,26 @@ def s3_storage(mock_boto3_client):
|
||||
class TestS3StorageInitialization:
|
||||
"""Test S3Storage initialization and configuration."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_with_default_bucket(self):
|
||||
"""Should use default bucket name when none provided."""
|
||||
with patch('boto3.client'):
|
||||
with patch("boto3.client"):
|
||||
storage = S3Storage()
|
||||
assert storage.bucket_name == "docsgpt-test-bucket"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_with_custom_bucket(self):
|
||||
"""Should use provided bucket name."""
|
||||
with patch('boto3.client'):
|
||||
with patch("boto3.client"):
|
||||
storage = S3Storage(bucket_name="custom-bucket")
|
||||
assert storage.bucket_name == "custom-bucket"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_creates_boto3_client(self):
|
||||
"""Should create boto3 S3 client with credentials from settings."""
|
||||
with patch('boto3.client') as mock_client, \
|
||||
patch('application.storage.s3.settings') as mock_settings:
|
||||
with patch("boto3.client") as mock_client, patch(
|
||||
"application.storage.s3.settings"
|
||||
) as mock_settings:
|
||||
|
||||
mock_settings.SAGEMAKER_ACCESS_KEY = "test-key"
|
||||
mock_settings.SAGEMAKER_SECRET_KEY = "test-secret"
|
||||
@@ -54,52 +58,50 @@ class TestS3StorageInitialization:
|
||||
"s3",
|
||||
aws_access_key_id="test-key",
|
||||
aws_secret_access_key="test-secret",
|
||||
region_name="us-west-2"
|
||||
region_name="us-west-2",
|
||||
)
|
||||
|
||||
|
||||
class TestS3StorageSaveFile:
|
||||
"""Test file saving functionality."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_save_file_uploads_to_s3(self, s3_storage, mock_boto3_client):
|
||||
"""Should upload file to S3 with correct parameters."""
|
||||
file_data = io.BytesIO(b"test content")
|
||||
path = "documents/test.txt"
|
||||
|
||||
with patch('application.storage.s3.settings') as mock_settings:
|
||||
with patch("application.storage.s3.settings") as mock_settings:
|
||||
mock_settings.SAGEMAKER_REGION = "us-east-1"
|
||||
result = s3_storage.save_file(file_data, path)
|
||||
|
||||
mock_boto3_client.upload_fileobj.assert_called_once_with(
|
||||
file_data,
|
||||
"test-bucket",
|
||||
path,
|
||||
ExtraArgs={"StorageClass": "INTELLIGENT_TIERING"}
|
||||
ExtraArgs={"StorageClass": "INTELLIGENT_TIERING"},
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"storage_type": "s3",
|
||||
"bucket_name": "test-bucket",
|
||||
"uri": "s3://test-bucket/documents/test.txt",
|
||||
"region": "us-east-1"
|
||||
"region": "us-east-1",
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_save_file_with_custom_storage_class(self, s3_storage, mock_boto3_client):
|
||||
"""Should use custom storage class when provided."""
|
||||
file_data = io.BytesIO(b"test content")
|
||||
path = "documents/test.txt"
|
||||
|
||||
with patch('application.storage.s3.settings') as mock_settings:
|
||||
with patch("application.storage.s3.settings") as mock_settings:
|
||||
mock_settings.SAGEMAKER_REGION = "us-east-1"
|
||||
s3_storage.save_file(file_data, path, storage_class="STANDARD")
|
||||
|
||||
mock_boto3_client.upload_fileobj.assert_called_once_with(
|
||||
file_data,
|
||||
"test-bucket",
|
||||
path,
|
||||
ExtraArgs={"StorageClass": "STANDARD"}
|
||||
file_data, "test-bucket", path, ExtraArgs={"StorageClass": "STANDARD"}
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_save_file_propagates_client_error(self, s3_storage, mock_boto3_client):
|
||||
"""Should propagate ClientError when upload fails."""
|
||||
file_data = io.BytesIO(b"test content")
|
||||
@@ -107,7 +109,7 @@ class TestS3StorageSaveFile:
|
||||
|
||||
mock_boto3_client.upload_fileobj.side_effect = ClientError(
|
||||
{"Error": {"Code": "AccessDenied", "Message": "Access denied"}},
|
||||
"upload_fileobj"
|
||||
"upload_fileobj",
|
||||
)
|
||||
|
||||
with pytest.raises(ClientError):
|
||||
@@ -117,7 +119,10 @@ class TestS3StorageSaveFile:
|
||||
class TestS3StorageFileExists:
|
||||
"""Test file existence checking."""
|
||||
|
||||
def test_file_exists_returns_true_when_file_found(self, s3_storage, mock_boto3_client):
|
||||
@pytest.mark.unit
|
||||
def test_file_exists_returns_true_when_file_found(
|
||||
self, s3_storage, mock_boto3_client
|
||||
):
|
||||
"""Should return True when head_object succeeds."""
|
||||
path = "documents/test.txt"
|
||||
mock_boto3_client.head_object.return_value = {"ContentLength": 100}
|
||||
@@ -126,16 +131,17 @@ class TestS3StorageFileExists:
|
||||
|
||||
assert result is True
|
||||
mock_boto3_client.head_object.assert_called_once_with(
|
||||
Bucket="test-bucket",
|
||||
Key=path
|
||||
Bucket="test-bucket", Key=path
|
||||
)
|
||||
|
||||
def test_file_exists_returns_false_on_client_error(self, s3_storage, mock_boto3_client):
|
||||
@pytest.mark.unit
|
||||
def test_file_exists_returns_false_on_client_error(
|
||||
self, s3_storage, mock_boto3_client
|
||||
):
|
||||
"""Should return False when head_object raises ClientError."""
|
||||
path = "documents/nonexistent.txt"
|
||||
mock_boto3_client.head_object.side_effect = ClientError(
|
||||
{"Error": {"Code": "NoSuchKey", "Message": "Not found"}},
|
||||
"head_object"
|
||||
{"Error": {"Code": "NoSuchKey", "Message": "Not found"}}, "head_object"
|
||||
)
|
||||
|
||||
result = s3_storage.file_exists(path)
|
||||
@@ -146,7 +152,10 @@ class TestS3StorageFileExists:
|
||||
class TestS3StorageGetFile:
|
||||
"""Test file retrieval functionality."""
|
||||
|
||||
def test_get_file_downloads_and_returns_file_object(self, s3_storage, mock_boto3_client):
|
||||
@pytest.mark.unit
|
||||
def test_get_file_downloads_and_returns_file_object(
|
||||
self, s3_storage, mock_boto3_client
|
||||
):
|
||||
"""Should download file from S3 and return BytesIO object."""
|
||||
path = "documents/test.txt"
|
||||
test_content = b"file content"
|
||||
@@ -164,12 +173,14 @@ class TestS3StorageGetFile:
|
||||
assert result.read() == test_content
|
||||
mock_boto3_client.download_fileobj.assert_called_once()
|
||||
|
||||
def test_get_file_raises_error_when_file_not_found(self, s3_storage, mock_boto3_client):
|
||||
@pytest.mark.unit
|
||||
def test_get_file_raises_error_when_file_not_found(
|
||||
self, s3_storage, mock_boto3_client
|
||||
):
|
||||
"""Should raise FileNotFoundError when file doesn't exist."""
|
||||
path = "documents/nonexistent.txt"
|
||||
mock_boto3_client.head_object.side_effect = ClientError(
|
||||
{"Error": {"Code": "NoSuchKey", "Message": "Not found"}},
|
||||
"head_object"
|
||||
{"Error": {"Code": "NoSuchKey", "Message": "Not found"}}, "head_object"
|
||||
)
|
||||
|
||||
with pytest.raises(FileNotFoundError, match="File not found"):
|
||||
@@ -179,6 +190,7 @@ class TestS3StorageGetFile:
|
||||
class TestS3StorageDeleteFile:
|
||||
"""Test file deletion functionality."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_delete_file_returns_true_on_success(self, s3_storage, mock_boto3_client):
|
||||
"""Should return True when deletion succeeds."""
|
||||
path = "documents/test.txt"
|
||||
@@ -188,16 +200,18 @@ class TestS3StorageDeleteFile:
|
||||
|
||||
assert result is True
|
||||
mock_boto3_client.delete_object.assert_called_once_with(
|
||||
Bucket="test-bucket",
|
||||
Key=path
|
||||
Bucket="test-bucket", Key=path
|
||||
)
|
||||
|
||||
def test_delete_file_returns_false_on_client_error(self, s3_storage, mock_boto3_client):
|
||||
@pytest.mark.unit
|
||||
def test_delete_file_returns_false_on_client_error(
|
||||
self, s3_storage, mock_boto3_client
|
||||
):
|
||||
"""Should return False when deletion fails with ClientError."""
|
||||
path = "documents/test.txt"
|
||||
mock_boto3_client.delete_object.side_effect = ClientError(
|
||||
{"Error": {"Code": "AccessDenied", "Message": "Access denied"}},
|
||||
"delete_object"
|
||||
"delete_object",
|
||||
)
|
||||
|
||||
result = s3_storage.delete_file(path)
|
||||
@@ -208,7 +222,10 @@ class TestS3StorageDeleteFile:
|
||||
class TestS3StorageListFiles:
|
||||
"""Test directory listing functionality."""
|
||||
|
||||
def test_list_files_returns_all_keys_with_prefix(self, s3_storage, mock_boto3_client):
|
||||
@pytest.mark.unit
|
||||
def test_list_files_returns_all_keys_with_prefix(
|
||||
self, s3_storage, mock_boto3_client
|
||||
):
|
||||
"""Should return all file keys matching the directory prefix."""
|
||||
directory = "documents/"
|
||||
|
||||
@@ -219,7 +236,7 @@ class TestS3StorageListFiles:
|
||||
"Contents": [
|
||||
{"Key": "documents/file1.txt"},
|
||||
{"Key": "documents/file2.txt"},
|
||||
{"Key": "documents/subdir/file3.txt"}
|
||||
{"Key": "documents/subdir/file3.txt"},
|
||||
]
|
||||
}
|
||||
]
|
||||
@@ -231,13 +248,15 @@ class TestS3StorageListFiles:
|
||||
assert "documents/file2.txt" in result
|
||||
assert "documents/subdir/file3.txt" in result
|
||||
|
||||
mock_boto3_client.get_paginator.assert_called_once_with('list_objects_v2')
|
||||
mock_boto3_client.get_paginator.assert_called_once_with("list_objects_v2")
|
||||
paginator_mock.paginate.assert_called_once_with(
|
||||
Bucket="test-bucket",
|
||||
Prefix="documents/"
|
||||
Bucket="test-bucket", Prefix="documents/"
|
||||
)
|
||||
|
||||
def test_list_files_returns_empty_list_when_no_contents(self, s3_storage, mock_boto3_client):
|
||||
@pytest.mark.unit
|
||||
def test_list_files_returns_empty_list_when_no_contents(
|
||||
self, s3_storage, mock_boto3_client
|
||||
):
|
||||
"""Should return empty list when directory has no files."""
|
||||
directory = "empty/"
|
||||
|
||||
@@ -253,30 +272,36 @@ class TestS3StorageListFiles:
|
||||
class TestS3StorageProcessFile:
|
||||
"""Test file processing functionality."""
|
||||
|
||||
def test_process_file_downloads_and_processes_file(self, s3_storage, mock_boto3_client):
|
||||
@pytest.mark.unit
|
||||
def test_process_file_downloads_and_processes_file(
|
||||
self, s3_storage, mock_boto3_client
|
||||
):
|
||||
"""Should download file to temp location and call processor function."""
|
||||
path = "documents/test.txt"
|
||||
|
||||
mock_boto3_client.head_object.return_value = {}
|
||||
|
||||
with patch('tempfile.NamedTemporaryFile') as mock_temp:
|
||||
with patch("tempfile.NamedTemporaryFile") as mock_temp:
|
||||
mock_file = MagicMock()
|
||||
mock_file.name = "/tmp/test_file"
|
||||
mock_temp.return_value.__enter__.return_value = mock_file
|
||||
|
||||
processor_func = MagicMock(return_value="processed")
|
||||
result = s3_storage.process_file(path, processor_func, extra_arg="value")
|
||||
|
||||
assert result == "processed"
|
||||
processor_func.assert_called_once_with(local_path="/tmp/test_file", extra_arg="value")
|
||||
processor_func.assert_called_once_with(
|
||||
local_path="/tmp/test_file", extra_arg="value"
|
||||
)
|
||||
mock_boto3_client.download_fileobj.assert_called_once()
|
||||
|
||||
def test_process_file_raises_error_when_file_not_found(self, s3_storage, mock_boto3_client):
|
||||
@pytest.mark.unit
|
||||
def test_process_file_raises_error_when_file_not_found(
|
||||
self, s3_storage, mock_boto3_client
|
||||
):
|
||||
"""Should raise FileNotFoundError when file doesn't exist."""
|
||||
path = "documents/nonexistent.txt"
|
||||
mock_boto3_client.head_object.side_effect = ClientError(
|
||||
{"Error": {"Code": "NoSuchKey", "Message": "Not found"}},
|
||||
"head_object"
|
||||
{"Error": {"Code": "NoSuchKey", "Message": "Not found"}}, "head_object"
|
||||
)
|
||||
|
||||
processor_func = MagicMock()
|
||||
@@ -288,7 +313,10 @@ class TestS3StorageProcessFile:
|
||||
class TestS3StorageIsDirectory:
|
||||
"""Test directory checking functionality."""
|
||||
|
||||
def test_is_directory_returns_true_when_objects_exist(self, s3_storage, mock_boto3_client):
|
||||
@pytest.mark.unit
|
||||
def test_is_directory_returns_true_when_objects_exist(
|
||||
self, s3_storage, mock_boto3_client
|
||||
):
|
||||
"""Should return True when objects exist with the directory prefix."""
|
||||
path = "documents/"
|
||||
|
||||
@@ -300,12 +328,13 @@ class TestS3StorageIsDirectory:
|
||||
|
||||
assert result is True
|
||||
mock_boto3_client.list_objects_v2.assert_called_once_with(
|
||||
Bucket="test-bucket",
|
||||
Prefix="documents/",
|
||||
MaxKeys=1
|
||||
Bucket="test-bucket", Prefix="documents/", MaxKeys=1
|
||||
)
|
||||
|
||||
def test_is_directory_returns_false_when_no_objects_exist(self, s3_storage, mock_boto3_client):
|
||||
@pytest.mark.unit
|
||||
def test_is_directory_returns_false_when_no_objects_exist(
|
||||
self, s3_storage, mock_boto3_client
|
||||
):
|
||||
"""Should return False when no objects exist with the directory prefix."""
|
||||
path = "nonexistent/"
|
||||
|
||||
@@ -319,6 +348,7 @@ class TestS3StorageIsDirectory:
|
||||
class TestS3StorageRemoveDirectory:
|
||||
"""Test directory removal functionality."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_remove_directory_deletes_all_objects(self, s3_storage, mock_boto3_client):
|
||||
"""Should delete all objects with the directory prefix."""
|
||||
directory = "documents/"
|
||||
@@ -329,16 +359,13 @@ class TestS3StorageRemoveDirectory:
|
||||
{
|
||||
"Contents": [
|
||||
{"Key": "documents/file1.txt"},
|
||||
{"Key": "documents/file2.txt"}
|
||||
{"Key": "documents/file2.txt"},
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
mock_boto3_client.delete_objects.return_value = {
|
||||
"Deleted": [
|
||||
{"Key": "documents/file1.txt"},
|
||||
{"Key": "documents/file2.txt"}
|
||||
]
|
||||
"Deleted": [{"Key": "documents/file1.txt"}, {"Key": "documents/file2.txt"}]
|
||||
}
|
||||
|
||||
result = s3_storage.remove_directory(directory)
|
||||
@@ -349,7 +376,10 @@ class TestS3StorageRemoveDirectory:
|
||||
assert call_args["Bucket"] == "test-bucket"
|
||||
assert len(call_args["Delete"]["Objects"]) == 2
|
||||
|
||||
def test_remove_directory_returns_false_when_empty(self, s3_storage, mock_boto3_client):
|
||||
@pytest.mark.unit
|
||||
def test_remove_directory_returns_false_when_empty(
|
||||
self, s3_storage, mock_boto3_client
|
||||
):
|
||||
"""Should return False when directory is empty (no objects to delete)."""
|
||||
directory = "empty/"
|
||||
|
||||
@@ -362,7 +392,10 @@ class TestS3StorageRemoveDirectory:
|
||||
assert result is False
|
||||
mock_boto3_client.delete_objects.assert_not_called()
|
||||
|
||||
def test_remove_directory_returns_false_on_client_error(self, s3_storage, mock_boto3_client):
|
||||
@pytest.mark.unit
|
||||
def test_remove_directory_returns_false_on_client_error(
|
||||
self, s3_storage, mock_boto3_client
|
||||
):
|
||||
"""Should return False when deletion fails with ClientError."""
|
||||
directory = "documents/"
|
||||
|
||||
@@ -374,7 +407,7 @@ class TestS3StorageRemoveDirectory:
|
||||
|
||||
mock_boto3_client.delete_objects.side_effect = ClientError(
|
||||
{"Error": {"Code": "AccessDenied", "Message": "Access denied"}},
|
||||
"delete_objects"
|
||||
"delete_objects",
|
||||
)
|
||||
|
||||
result = s3_storage.remove_directory(directory)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from flask import Flask
|
||||
|
||||
import pytest
|
||||
from application.api.answer import answer
|
||||
from application.api.internal.routes import internal
|
||||
from application.api.user.routes import user
|
||||
from application.core.settings import settings
|
||||
from flask import Flask
|
||||
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_app_config():
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(user)
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
import unittest
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock
|
||||
from application.cache import gen_cache_key, stream_cache, gen_cache
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from application.cache import gen_cache, gen_cache_key, stream_cache
|
||||
from application.utils import get_hash
|
||||
|
||||
|
||||
# Test for gen_cache_key function
|
||||
@pytest.mark.unit
|
||||
def test_make_gen_cache_key():
|
||||
messages = [
|
||||
{'role': 'user', 'content': 'test_user_message'},
|
||||
{'role': 'system', 'content': 'test_system_message'},
|
||||
{"role": "user", "content": "test_user_message"},
|
||||
{"role": "system", "content": "test_system_message"},
|
||||
]
|
||||
model = "test_docgpt"
|
||||
tools = None
|
||||
|
||||
# Manually calculate the expected hash
|
||||
|
||||
messages_str = json.dumps(messages)
|
||||
tools_str = json.dumps(tools) if tools else ""
|
||||
expected_combined = f"{model}_{messages_str}_{tools_str}"
|
||||
@@ -23,112 +23,100 @@ def test_make_gen_cache_key():
|
||||
|
||||
assert cache_key == expected_hash
|
||||
|
||||
def test_gen_cache_key_invalid_message_format():
|
||||
# Test when messages is not a list
|
||||
with unittest.TestCase.assertRaises(unittest.TestCase, ValueError) as context:
|
||||
gen_cache_key("This is not a list", model="docgpt", tools=None)
|
||||
assert str(context.exception) == "All messages must be dictionaries."
|
||||
|
||||
# Test for gen_cache decorator
|
||||
@patch('application.cache.get_redis_instance') # Mock the Redis client
|
||||
@pytest.mark.unit
|
||||
def test_gen_cache_key_invalid_message_format():
|
||||
with pytest.raises(ValueError, match="All messages must be dictionaries."):
|
||||
gen_cache_key("This is not a list", model="docgpt", tools=None)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.cache.get_redis_instance")
|
||||
def test_gen_cache_hit(mock_make_redis):
|
||||
# Arrange
|
||||
mock_redis_instance = MagicMock()
|
||||
mock_make_redis.return_value = mock_redis_instance
|
||||
mock_redis_instance.get.return_value = b"cached_result" # Simulate a cache hit
|
||||
mock_redis_instance.get.return_value = b"cached_result"
|
||||
|
||||
@gen_cache
|
||||
def mock_function(self, model, messages, stream, tools):
|
||||
return "new_result"
|
||||
|
||||
messages = [{'role': 'user', 'content': 'test_user_message'}]
|
||||
messages = [{"role": "user", "content": "test_user_message"}]
|
||||
model = "test_docgpt"
|
||||
|
||||
# Act
|
||||
result = mock_function(None, model, messages, stream=False, tools=None)
|
||||
|
||||
# Assert
|
||||
assert result == "cached_result" # Should return cached result
|
||||
mock_redis_instance.get.assert_called_once() # Ensure Redis get was called
|
||||
mock_redis_instance.set.assert_not_called() # Ensure the function result is not cached again
|
||||
assert result == "cached_result"
|
||||
mock_redis_instance.get.assert_called_once()
|
||||
mock_redis_instance.set.assert_not_called()
|
||||
|
||||
|
||||
@patch('application.cache.get_redis_instance') # Mock the Redis client
|
||||
@pytest.mark.unit
|
||||
@patch("application.cache.get_redis_instance")
|
||||
def test_gen_cache_miss(mock_make_redis):
|
||||
# Arrange
|
||||
mock_redis_instance = MagicMock()
|
||||
mock_make_redis.return_value = mock_redis_instance
|
||||
mock_redis_instance.get.return_value = None # Simulate a cache miss
|
||||
mock_redis_instance.get.return_value = None
|
||||
|
||||
@gen_cache
|
||||
def mock_function(self, model, messages, steam, tools):
|
||||
return "new_result"
|
||||
|
||||
messages = [
|
||||
{'role': 'user', 'content': 'test_user_message'},
|
||||
{'role': 'system', 'content': 'test_system_message'},
|
||||
{"role": "user", "content": "test_user_message"},
|
||||
{"role": "system", "content": "test_system_message"},
|
||||
]
|
||||
model = "test_docgpt"
|
||||
# Act
|
||||
|
||||
result = mock_function(None, model, messages, stream=False, tools=None)
|
||||
|
||||
# Assert
|
||||
assert result == "new_result"
|
||||
mock_redis_instance.get.assert_called_once()
|
||||
mock_redis_instance.get.assert_called_once()
|
||||
|
||||
@patch('application.cache.get_redis_instance')
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.cache.get_redis_instance")
|
||||
def test_stream_cache_hit(mock_make_redis):
|
||||
# Arrange
|
||||
mock_redis_instance = MagicMock()
|
||||
mock_make_redis.return_value = mock_redis_instance
|
||||
|
||||
cached_chunk = json.dumps(["chunk1", "chunk2"]).encode('utf-8')
|
||||
cached_chunk = json.dumps(["chunk1", "chunk2"]).encode("utf-8")
|
||||
mock_redis_instance.get.return_value = cached_chunk
|
||||
|
||||
@stream_cache
|
||||
def mock_function(self, model, messages, stream, tools):
|
||||
yield "new_chunk"
|
||||
|
||||
messages = [{'role': 'user', 'content': 'test_user_message'}]
|
||||
messages = [{"role": "user", "content": "test_user_message"}]
|
||||
model = "test_docgpt"
|
||||
|
||||
# Act
|
||||
result = list(mock_function(None, model, messages, stream=True, tools=None))
|
||||
|
||||
# Assert
|
||||
assert result == ["chunk1", "chunk2"] # Should return cached chunks
|
||||
assert result == ["chunk1", "chunk2"]
|
||||
mock_redis_instance.get.assert_called_once()
|
||||
mock_redis_instance.set.assert_not_called()
|
||||
|
||||
|
||||
@patch('application.cache.get_redis_instance')
|
||||
@pytest.mark.unit
|
||||
@patch("application.cache.get_redis_instance")
|
||||
def test_stream_cache_miss(mock_make_redis):
|
||||
# Arrange
|
||||
mock_redis_instance = MagicMock()
|
||||
mock_make_redis.return_value = mock_redis_instance
|
||||
mock_redis_instance.get.return_value = None # Simulate a cache miss
|
||||
mock_redis_instance.get.return_value = None
|
||||
|
||||
@stream_cache
|
||||
def mock_function(self, model, messages, stream, tools):
|
||||
yield "new_chunk"
|
||||
|
||||
messages = [
|
||||
{'role': 'user', 'content': 'This is the context'},
|
||||
{'role': 'system', 'content': 'Some other message'},
|
||||
{'role': 'user', 'content': 'What is the answer?'}
|
||||
{"role": "user", "content": "This is the context"},
|
||||
{"role": "system", "content": "Some other message"},
|
||||
{"role": "user", "content": "What is the answer?"},
|
||||
]
|
||||
model = "test_docgpt"
|
||||
|
||||
# Act
|
||||
result = list(mock_function(None, model, messages, stream=True, tools=None))
|
||||
|
||||
# Assert
|
||||
assert result == ["new_chunk"]
|
||||
mock_redis_instance.get.assert_called_once()
|
||||
mock_redis_instance.set.assert_called_once()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
mock_redis_instance.get.assert_called_once()
|
||||
mock_redis_instance.set.assert_called_once()
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
from unittest.mock import patch
|
||||
from application.core.settings import settings
|
||||
|
||||
import pytest
|
||||
from application.celery_init import make_celery
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
@patch('application.celery_init.Celery')
|
||||
@pytest.mark.unit
|
||||
@patch("application.celery_init.Celery")
|
||||
def test_make_celery(mock_celery):
|
||||
# Arrange
|
||||
app_name = 'test_app_name'
|
||||
app_name = "test_app_name"
|
||||
|
||||
# Act
|
||||
celery = make_celery(app_name)
|
||||
|
||||
# Assert
|
||||
mock_celery.assert_called_once_with(
|
||||
app_name,
|
||||
broker=settings.CELERY_BROKER_URL,
|
||||
backend=settings.CELERY_RESULT_BACKEND
|
||||
app_name,
|
||||
broker=settings.CELERY_BROKER_URL,
|
||||
backend=settings.CELERY_RESULT_BACKEND,
|
||||
)
|
||||
celery.conf.update.assert_called_once_with(settings)
|
||||
assert celery == mock_celery.return_value
|
||||
assert celery == mock_celery.return_value
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from application.error import bad_request, response_error
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -9,31 +9,35 @@ def app():
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bad_request_with_message(app):
|
||||
with app.app_context():
|
||||
message = "Invalid input"
|
||||
response = bad_request(status_code=400, message=message)
|
||||
assert response.status_code == 400
|
||||
assert response.json == {'error': 'Bad Request', 'message': message}
|
||||
assert response.json == {"error": "Bad Request", "message": message}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bad_request_without_message(app):
|
||||
with app.app_context():
|
||||
response = bad_request(status_code=400)
|
||||
assert response.status_code == 400
|
||||
assert response.json == {'error': 'Bad Request'}
|
||||
assert response.json == {"error": "Bad Request"}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_response_error_with_message(app):
|
||||
with app.app_context():
|
||||
message = "Something went wrong"
|
||||
response = response_error(code_status=500, message=message)
|
||||
assert response.status_code == 500
|
||||
assert response.json == {'error': 'Internal Server Error', 'message': message}
|
||||
assert response.json == {"error": "Internal Server Error", "message": message}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_response_error_without_message(app):
|
||||
with app.app_context():
|
||||
response = response_error(code_status=500)
|
||||
assert response.status_code == 500
|
||||
assert response.json == {'error': 'Internal Server Error'}
|
||||
assert response.json == {"error": "Internal Server Error"}
|
||||
|
||||
@@ -17,6 +17,7 @@ def memory_tool(monkeypatch) -> MemoryTool:
|
||||
path = doc.get("path")
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
# Add _id to document if not present
|
||||
|
||||
if "_id" not in doc:
|
||||
doc["_id"] = key
|
||||
self.docs[key] = doc
|
||||
@@ -24,16 +25,17 @@ def memory_tool(monkeypatch) -> MemoryTool:
|
||||
|
||||
def update_one(self, q, u, upsert=False):
|
||||
# Handle query by _id
|
||||
|
||||
if "_id" in q:
|
||||
doc_id = q["_id"]
|
||||
if doc_id not in self.docs:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
|
||||
if "$set" in u:
|
||||
old_doc = self.docs[doc_id].copy()
|
||||
old_doc.update(u["$set"])
|
||||
|
||||
# If path changed, update the dictionary key
|
||||
|
||||
if "path" in u["$set"]:
|
||||
new_path = u["$set"]["path"]
|
||||
user_id = old_doc.get("user_id")
|
||||
@@ -41,15 +43,15 @@ def memory_tool(monkeypatch) -> MemoryTool:
|
||||
new_key = f"{user_id}:{tool_id}:{new_path}"
|
||||
|
||||
# Remove old key and add with new key
|
||||
|
||||
del self.docs[doc_id]
|
||||
old_doc["_id"] = new_key
|
||||
self.docs[new_key] = old_doc
|
||||
else:
|
||||
self.docs[doc_id] = old_doc
|
||||
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
# Handle query by user_id, tool_id, path
|
||||
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
path = q.get("path")
|
||||
@@ -57,13 +59,16 @@ def memory_tool(monkeypatch) -> MemoryTool:
|
||||
|
||||
if key not in self.docs and not upsert:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
|
||||
if key not in self.docs and upsert:
|
||||
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key}
|
||||
|
||||
self.docs[key] = {
|
||||
"user_id": user_id,
|
||||
"tool_id": tool_id,
|
||||
"path": path,
|
||||
"content": "",
|
||||
"_id": key,
|
||||
}
|
||||
if "$set" in u:
|
||||
self.docs[key].update(u["$set"])
|
||||
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
def find_one(self, q, projection=None):
|
||||
@@ -74,7 +79,6 @@ def memory_tool(monkeypatch) -> MemoryTool:
|
||||
if path:
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
return self.docs.get(key)
|
||||
|
||||
return None
|
||||
|
||||
def find(self, q, projection=None):
|
||||
@@ -83,9 +87,11 @@ def memory_tool(monkeypatch) -> MemoryTool:
|
||||
results = []
|
||||
|
||||
# Handle regex queries for directory listing
|
||||
|
||||
if "path" in q and isinstance(q["path"], dict) and "$regex" in q["path"]:
|
||||
regex_pattern = q["path"]["$regex"]
|
||||
# Remove regex escape characters and ^ anchor for simple matching
|
||||
|
||||
pattern = regex_pattern.replace("\\", "").lstrip("^")
|
||||
|
||||
for key, doc in self.docs.items():
|
||||
@@ -97,7 +103,6 @@ def memory_tool(monkeypatch) -> MemoryTool:
|
||||
for key, doc in self.docs.items():
|
||||
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id:
|
||||
results.append(doc)
|
||||
|
||||
return results
|
||||
|
||||
def delete_one(self, q):
|
||||
@@ -109,7 +114,6 @@ def memory_tool(monkeypatch) -> MemoryTool:
|
||||
if key in self.docs:
|
||||
del self.docs[key]
|
||||
return type("res", (), {"deleted_count": 1})
|
||||
|
||||
return type("res", (), {"deleted_count": 0})
|
||||
|
||||
def delete_many(self, q):
|
||||
@@ -118,6 +122,7 @@ def memory_tool(monkeypatch) -> MemoryTool:
|
||||
deleted = 0
|
||||
|
||||
# Handle regex queries for directory deletion
|
||||
|
||||
if "path" in q and isinstance(q["path"], dict) and "$regex" in q["path"]:
|
||||
regex_pattern = q["path"]["$regex"]
|
||||
pattern = regex_pattern.replace("\\", "").lstrip("^")
|
||||
@@ -128,32 +133,36 @@ def memory_tool(monkeypatch) -> MemoryTool:
|
||||
doc_path = doc.get("path", "")
|
||||
if doc_path.startswith(pattern):
|
||||
keys_to_delete.append(key)
|
||||
|
||||
for key in keys_to_delete:
|
||||
del self.docs[key]
|
||||
deleted += 1
|
||||
else:
|
||||
# Delete all for user and tool
|
||||
|
||||
keys_to_delete = [
|
||||
key for key, doc in self.docs.items()
|
||||
key
|
||||
for key, doc in self.docs.items()
|
||||
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id
|
||||
]
|
||||
for key in keys_to_delete:
|
||||
del self.docs[key]
|
||||
deleted += 1
|
||||
|
||||
return type("res", (), {"deleted_count": deleted})
|
||||
|
||||
fake_collection = FakeCollection()
|
||||
fake_db = {"memories": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
monkeypatch.setattr(
|
||||
"application.core.mongo_db.MongoDB.get_client", lambda: fake_client
|
||||
)
|
||||
|
||||
# Return tool with a fixed tool_id for consistency in tests
|
||||
|
||||
return MemoryTool({"tool_id": "test_tool_id"}, user_id="test_user")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_without_user_id():
|
||||
"""Should fail gracefully if no user_id is provided."""
|
||||
memory_tool = MemoryTool(tool_config={})
|
||||
@@ -161,90 +170,78 @@ def test_init_without_user_id():
|
||||
assert "user_id" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_view_empty_directory(memory_tool: MemoryTool) -> None:
|
||||
"""Should show empty directory when no files exist."""
|
||||
result = memory_tool.execute_action("view", path="/")
|
||||
assert "empty" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_create_and_view_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test creating a file and viewing it."""
|
||||
# Create a file
|
||||
|
||||
result = memory_tool.execute_action(
|
||||
"create",
|
||||
path="/notes.txt",
|
||||
file_text="Hello world"
|
||||
"create", path="/notes.txt", file_text="Hello world"
|
||||
)
|
||||
assert "created" in result.lower()
|
||||
|
||||
# View the file
|
||||
|
||||
result = memory_tool.execute_action("view", path="/notes.txt")
|
||||
assert "Hello world" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_create_overwrite_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test that create overwrites existing files."""
|
||||
# Create initial file
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/test.txt",
|
||||
file_text="Original content"
|
||||
)
|
||||
|
||||
memory_tool.execute_action("create", path="/test.txt", file_text="Original content")
|
||||
|
||||
# Overwrite
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/test.txt",
|
||||
file_text="New content"
|
||||
)
|
||||
|
||||
memory_tool.execute_action("create", path="/test.txt", file_text="New content")
|
||||
|
||||
# Verify overwrite
|
||||
|
||||
result = memory_tool.execute_action("view", path="/test.txt")
|
||||
assert "New content" in result
|
||||
assert "Original content" not in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_view_directory_with_files(memory_tool: MemoryTool) -> None:
|
||||
"""Test viewing directory contents."""
|
||||
# Create multiple files
|
||||
|
||||
memory_tool.execute_action("create", path="/file1.txt", file_text="Content 1")
|
||||
memory_tool.execute_action("create", path="/file2.txt", file_text="Content 2")
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/file1.txt",
|
||||
file_text="Content 1"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/file2.txt",
|
||||
file_text="Content 2"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/subdir/file3.txt",
|
||||
file_text="Content 3"
|
||||
"create", path="/subdir/file3.txt", file_text="Content 3"
|
||||
)
|
||||
|
||||
# View directory
|
||||
|
||||
result = memory_tool.execute_action("view", path="/")
|
||||
assert "file1.txt" in result
|
||||
assert "file2.txt" in result
|
||||
assert "subdir/file3.txt" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_view_file_with_line_range(memory_tool: MemoryTool) -> None:
|
||||
"""Test viewing specific lines from a file."""
|
||||
# Create a multiline file
|
||||
|
||||
content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/multiline.txt",
|
||||
file_text=content
|
||||
)
|
||||
memory_tool.execute_action("create", path="/multiline.txt", file_text=content)
|
||||
|
||||
# View lines 2-4
|
||||
|
||||
result = memory_tool.execute_action(
|
||||
"view",
|
||||
path="/multiline.txt",
|
||||
view_range=[2, 4]
|
||||
"view", path="/multiline.txt", view_range=[2, 4]
|
||||
)
|
||||
assert "Line 2" in result
|
||||
assert "Line 3" in result
|
||||
@@ -253,197 +250,177 @@ def test_view_file_with_line_range(memory_tool: MemoryTool) -> None:
|
||||
assert "Line 5" not in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_str_replace(memory_tool: MemoryTool) -> None:
|
||||
"""Test string replacement in a file."""
|
||||
# Create a file
|
||||
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/replace.txt",
|
||||
file_text="Hello world, hello universe"
|
||||
"create", path="/replace.txt", file_text="Hello world, hello universe"
|
||||
)
|
||||
|
||||
# Replace text
|
||||
|
||||
result = memory_tool.execute_action(
|
||||
"str_replace",
|
||||
path="/replace.txt",
|
||||
old_str="hello",
|
||||
new_str="hi"
|
||||
"str_replace", path="/replace.txt", old_str="hello", new_str="hi"
|
||||
)
|
||||
assert "updated" in result.lower()
|
||||
|
||||
# Verify replacement
|
||||
|
||||
content = memory_tool.execute_action("view", path="/replace.txt")
|
||||
assert "hi world, hi universe" in content
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_str_replace_not_found(memory_tool: MemoryTool) -> None:
|
||||
"""Test string replacement when string not found."""
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/test.txt",
|
||||
file_text="Hello world"
|
||||
)
|
||||
memory_tool.execute_action("create", path="/test.txt", file_text="Hello world")
|
||||
|
||||
result = memory_tool.execute_action(
|
||||
"str_replace",
|
||||
path="/test.txt",
|
||||
old_str="goodbye",
|
||||
new_str="hi"
|
||||
"str_replace", path="/test.txt", old_str="goodbye", new_str="hi"
|
||||
)
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_insert_line(memory_tool: MemoryTool) -> None:
|
||||
"""Test inserting text at a line number."""
|
||||
# Create a multiline file
|
||||
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/insert.txt",
|
||||
file_text="Line 1\nLine 2\nLine 3"
|
||||
"create", path="/insert.txt", file_text="Line 1\nLine 2\nLine 3"
|
||||
)
|
||||
|
||||
# Insert at line 2
|
||||
|
||||
result = memory_tool.execute_action(
|
||||
"insert",
|
||||
path="/insert.txt",
|
||||
insert_line=2,
|
||||
insert_text="Inserted line"
|
||||
"insert", path="/insert.txt", insert_line=2, insert_text="Inserted line"
|
||||
)
|
||||
assert "inserted" in result.lower()
|
||||
|
||||
# Verify insertion
|
||||
|
||||
content = memory_tool.execute_action("view", path="/insert.txt")
|
||||
lines = content.split("\n")
|
||||
assert lines[1] == "Inserted line"
|
||||
assert lines[2] == "Line 2"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_insert_invalid_line(memory_tool: MemoryTool) -> None:
|
||||
"""Test inserting at an invalid line number."""
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/test.txt",
|
||||
file_text="Line 1\nLine 2"
|
||||
)
|
||||
memory_tool.execute_action("create", path="/test.txt", file_text="Line 1\nLine 2")
|
||||
|
||||
result = memory_tool.execute_action(
|
||||
"insert",
|
||||
path="/test.txt",
|
||||
insert_line=100,
|
||||
insert_text="Text"
|
||||
"insert", path="/test.txt", insert_line=100, insert_text="Text"
|
||||
)
|
||||
assert "invalid" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_delete_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test deleting a file."""
|
||||
# Create a file
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/delete_me.txt",
|
||||
file_text="Content"
|
||||
)
|
||||
|
||||
memory_tool.execute_action("create", path="/delete_me.txt", file_text="Content")
|
||||
|
||||
# Delete it
|
||||
|
||||
result = memory_tool.execute_action("delete", path="/delete_me.txt")
|
||||
assert "deleted" in result.lower()
|
||||
|
||||
# Verify it's gone
|
||||
|
||||
result = memory_tool.execute_action("view", path="/delete_me.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_delete_nonexistent_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test deleting a file that doesn't exist."""
|
||||
result = memory_tool.execute_action("delete", path="/nonexistent.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_delete_directory(memory_tool: MemoryTool) -> None:
|
||||
"""Test deleting a directory with files."""
|
||||
# Create files in a directory
|
||||
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/subdir/file1.txt",
|
||||
file_text="Content 1"
|
||||
"create", path="/subdir/file1.txt", file_text="Content 1"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/subdir/file2.txt",
|
||||
file_text="Content 2"
|
||||
"create", path="/subdir/file2.txt", file_text="Content 2"
|
||||
)
|
||||
|
||||
# Delete the directory
|
||||
|
||||
result = memory_tool.execute_action("delete", path="/subdir/")
|
||||
assert "deleted" in result.lower()
|
||||
|
||||
# Verify files are gone
|
||||
|
||||
result = memory_tool.execute_action("view", path="/subdir/file1.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_rename_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test renaming a file."""
|
||||
# Create a file
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/old_name.txt",
|
||||
file_text="Content"
|
||||
)
|
||||
|
||||
memory_tool.execute_action("create", path="/old_name.txt", file_text="Content")
|
||||
|
||||
# Rename it
|
||||
|
||||
result = memory_tool.execute_action(
|
||||
"rename",
|
||||
old_path="/old_name.txt",
|
||||
new_path="/new_name.txt"
|
||||
"rename", old_path="/old_name.txt", new_path="/new_name.txt"
|
||||
)
|
||||
assert "renamed" in result.lower()
|
||||
|
||||
# Verify old path doesn't exist
|
||||
|
||||
result = memory_tool.execute_action("view", path="/old_name.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
# Verify new path exists
|
||||
|
||||
result = memory_tool.execute_action("view", path="/new_name.txt")
|
||||
assert "Content" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_rename_nonexistent_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test renaming a file that doesn't exist."""
|
||||
result = memory_tool.execute_action(
|
||||
"rename",
|
||||
old_path="/nonexistent.txt",
|
||||
new_path="/new.txt"
|
||||
"rename", old_path="/nonexistent.txt", new_path="/new.txt"
|
||||
)
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_rename_to_existing_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test renaming to a path that already exists."""
|
||||
# Create two files
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/file1.txt",
|
||||
file_text="Content 1"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/file2.txt",
|
||||
file_text="Content 2"
|
||||
)
|
||||
|
||||
memory_tool.execute_action("create", path="/file1.txt", file_text="Content 1")
|
||||
memory_tool.execute_action("create", path="/file2.txt", file_text="Content 2")
|
||||
|
||||
# Try to rename file1 to file2
|
||||
|
||||
result = memory_tool.execute_action(
|
||||
"rename",
|
||||
old_path="/file1.txt",
|
||||
new_path="/file2.txt"
|
||||
"rename", old_path="/file1.txt", new_path="/file2.txt"
|
||||
)
|
||||
assert "already exists" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_path_traversal_protection(memory_tool: MemoryTool) -> None:
|
||||
"""Test that directory traversal attacks are prevented."""
|
||||
# Try various path traversal attempts
|
||||
|
||||
invalid_paths = [
|
||||
"/../secrets.txt",
|
||||
"/../../etc/passwd",
|
||||
@@ -453,16 +430,16 @@ def test_path_traversal_protection(memory_tool: MemoryTool) -> None:
|
||||
|
||||
for path in invalid_paths:
|
||||
result = memory_tool.execute_action(
|
||||
"create",
|
||||
path=path,
|
||||
file_text="malicious content"
|
||||
"create", path=path, file_text="malicious content"
|
||||
)
|
||||
assert "invalid path" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_path_must_start_with_slash(memory_tool: MemoryTool) -> None:
|
||||
"""Test that paths work with or without leading slash (auto-normalized)."""
|
||||
# These paths should all work now (auto-prepended with /)
|
||||
|
||||
valid_paths = [
|
||||
"etc/passwd", # Auto-prepended with /
|
||||
"home/user/file.txt", # Auto-prepended with /
|
||||
@@ -470,33 +447,29 @@ def test_path_must_start_with_slash(memory_tool: MemoryTool) -> None:
|
||||
]
|
||||
|
||||
for path in valid_paths:
|
||||
result = memory_tool.execute_action(
|
||||
"create",
|
||||
path=path,
|
||||
file_text="content"
|
||||
)
|
||||
result = memory_tool.execute_action("create", path=path, file_text="content")
|
||||
assert "created" in result.lower()
|
||||
|
||||
# Verify the file can be accessed with or without leading slash
|
||||
|
||||
view_result = memory_tool.execute_action("view", path=path)
|
||||
assert "content" in view_result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_cannot_create_directory_as_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test that you cannot create a file at a directory path."""
|
||||
result = memory_tool.execute_action(
|
||||
"create",
|
||||
path="/",
|
||||
file_text="content"
|
||||
)
|
||||
result = memory_tool.execute_action("create", path="/", file_text="content")
|
||||
assert "cannot create a file at directory path" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_actions_metadata(memory_tool: MemoryTool) -> None:
|
||||
"""Test that action metadata is properly defined."""
|
||||
metadata = memory_tool.get_actions_metadata()
|
||||
|
||||
# Check that all expected actions are defined
|
||||
|
||||
action_names = [action["name"] for action in metadata]
|
||||
assert "view" in action_names
|
||||
assert "create" in action_names
|
||||
@@ -506,15 +479,18 @@ def test_get_actions_metadata(memory_tool: MemoryTool) -> None:
|
||||
assert "rename" in action_names
|
||||
|
||||
# Check that each action has required fields
|
||||
|
||||
for action in metadata:
|
||||
assert "name" in action
|
||||
assert "description" in action
|
||||
assert "parameters" in action
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_memory_tool_isolation(monkeypatch) -> None:
|
||||
"""Test that different memory tool instances have isolated memories."""
|
||||
# Create fake collection
|
||||
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {}
|
||||
@@ -529,16 +505,17 @@ def test_memory_tool_isolation(monkeypatch) -> None:
|
||||
|
||||
def update_one(self, q, u, upsert=False):
|
||||
# Handle query by _id
|
||||
|
||||
if "_id" in q:
|
||||
doc_id = q["_id"]
|
||||
if doc_id not in self.docs:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
|
||||
if "$set" in u:
|
||||
old_doc = self.docs[doc_id].copy()
|
||||
old_doc.update(u["$set"])
|
||||
|
||||
# If path changed, update the dictionary key
|
||||
|
||||
if "path" in u["$set"]:
|
||||
new_path = u["$set"]["path"]
|
||||
user_id = old_doc.get("user_id")
|
||||
@@ -546,15 +523,15 @@ def test_memory_tool_isolation(monkeypatch) -> None:
|
||||
new_key = f"{user_id}:{tool_id}:{new_path}"
|
||||
|
||||
# Remove old key and add with new key
|
||||
|
||||
del self.docs[doc_id]
|
||||
old_doc["_id"] = new_key
|
||||
self.docs[new_key] = old_doc
|
||||
else:
|
||||
self.docs[doc_id] = old_doc
|
||||
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
# Handle query by user_id, tool_id, path
|
||||
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
path = q.get("path")
|
||||
@@ -562,13 +539,16 @@ def test_memory_tool_isolation(monkeypatch) -> None:
|
||||
|
||||
if key not in self.docs and not upsert:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
|
||||
if key not in self.docs and upsert:
|
||||
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key}
|
||||
|
||||
self.docs[key] = {
|
||||
"user_id": user_id,
|
||||
"tool_id": tool_id,
|
||||
"path": path,
|
||||
"content": "",
|
||||
"_id": key,
|
||||
}
|
||||
if "$set" in u:
|
||||
self.docs[key].update(u["$set"])
|
||||
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
def find_one(self, q, projection=None):
|
||||
@@ -579,26 +559,31 @@ def test_memory_tool_isolation(monkeypatch) -> None:
|
||||
if path:
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
return self.docs.get(key)
|
||||
|
||||
return None
|
||||
|
||||
fake_collection = FakeCollection()
|
||||
fake_db = {"memories": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
monkeypatch.setattr(
|
||||
"application.core.mongo_db.MongoDB.get_client", lambda: fake_client
|
||||
)
|
||||
|
||||
# Create two memory tools with different tool_ids for the same user
|
||||
|
||||
tool1 = MemoryTool({"tool_id": "tool_1"}, user_id="test_user")
|
||||
tool2 = MemoryTool({"tool_id": "tool_2"}, user_id="test_user")
|
||||
|
||||
# Create a file in tool1
|
||||
|
||||
tool1.execute_action("create", path="/file.txt", file_text="Content from tool 1")
|
||||
|
||||
# Create a file with the same path in tool2
|
||||
|
||||
tool2.execute_action("create", path="/file.txt", file_text="Content from tool 2")
|
||||
|
||||
# Verify that each tool sees only its own content
|
||||
|
||||
result1 = tool1.execute_action("view", path="/file.txt")
|
||||
result2 = tool2.execute_action("view", path="/file.txt")
|
||||
|
||||
@@ -609,8 +594,10 @@ def test_memory_tool_isolation(monkeypatch) -> None:
|
||||
assert "Content from tool 1" not in result2
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_memory_tool_auto_generates_tool_id(monkeypatch) -> None:
|
||||
"""Test that tool_id defaults to 'default_{user_id}' for persistence."""
|
||||
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {}
|
||||
@@ -622,78 +609,94 @@ def test_memory_tool_auto_generates_tool_id(monkeypatch) -> None:
|
||||
fake_db = {"memories": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
monkeypatch.setattr(
|
||||
"application.core.mongo_db.MongoDB.get_client", lambda: fake_client
|
||||
)
|
||||
|
||||
# Create two tools without providing tool_id for the same user
|
||||
|
||||
tool1 = MemoryTool({}, user_id="test_user")
|
||||
tool2 = MemoryTool({}, user_id="test_user")
|
||||
|
||||
# Both should have the same default tool_id for persistence
|
||||
|
||||
assert tool1.tool_id == "default_test_user"
|
||||
assert tool2.tool_id == "default_test_user"
|
||||
assert tool1.tool_id == tool2.tool_id
|
||||
|
||||
# Different users should have different tool_ids
|
||||
|
||||
tool3 = MemoryTool({}, user_id="another_user")
|
||||
assert tool3.tool_id == "default_another_user"
|
||||
assert tool3.tool_id != tool1.tool_id
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_paths_without_leading_slash(memory_tool) -> None:
|
||||
"""Test that paths without leading slash work correctly."""
|
||||
# Create file without leading slash
|
||||
result = memory_tool.execute_action("create", path="cat_breeds.txt", file_text="- Korat\n- Chartreux\n- British Shorthair\n- Nebelung")
|
||||
|
||||
result = memory_tool.execute_action(
|
||||
"create",
|
||||
path="cat_breeds.txt",
|
||||
file_text="- Korat\n- Chartreux\n- British Shorthair\n- Nebelung",
|
||||
)
|
||||
assert "created" in result.lower()
|
||||
|
||||
# View file without leading slash
|
||||
|
||||
view_result = memory_tool.execute_action("view", path="cat_breeds.txt")
|
||||
assert "Korat" in view_result
|
||||
assert "Chartreux" in view_result
|
||||
|
||||
# View same file with leading slash (should work the same)
|
||||
|
||||
view_result2 = memory_tool.execute_action("view", path="/cat_breeds.txt")
|
||||
assert "Korat" in view_result2
|
||||
|
||||
# Test str_replace without leading slash
|
||||
replace_result = memory_tool.execute_action("str_replace", path="cat_breeds.txt", old_str="Korat", new_str="Maine Coon")
|
||||
|
||||
replace_result = memory_tool.execute_action(
|
||||
"str_replace", path="cat_breeds.txt", old_str="Korat", new_str="Maine Coon"
|
||||
)
|
||||
assert "updated" in replace_result.lower()
|
||||
|
||||
# Test nested path without leading slash
|
||||
nested_result = memory_tool.execute_action("create", path="projects/tasks.txt", file_text="Task 1\nTask 2")
|
||||
|
||||
nested_result = memory_tool.execute_action(
|
||||
"create", path="projects/tasks.txt", file_text="Task 1\nTask 2"
|
||||
)
|
||||
assert "created" in nested_result.lower()
|
||||
|
||||
view_nested = memory_tool.execute_action("view", path="projects/tasks.txt")
|
||||
assert "Task 1" in view_nested
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_rename_directory(memory_tool: MemoryTool) -> None:
|
||||
"""Test renaming a directory with files."""
|
||||
# Create files in a directory
|
||||
|
||||
memory_tool.execute_action("create", path="/docs/file1.txt", file_text="Content 1")
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/docs/file1.txt",
|
||||
file_text="Content 1"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/docs/sub/file2.txt",
|
||||
file_text="Content 2"
|
||||
"create", path="/docs/sub/file2.txt", file_text="Content 2"
|
||||
)
|
||||
|
||||
# Rename directory (with trailing slash)
|
||||
|
||||
result = memory_tool.execute_action(
|
||||
"rename",
|
||||
old_path="/docs/",
|
||||
new_path="/archive/"
|
||||
"rename", old_path="/docs/", new_path="/archive/"
|
||||
)
|
||||
assert "renamed" in result.lower()
|
||||
assert "2 files" in result.lower()
|
||||
|
||||
# Verify old paths don't exist
|
||||
|
||||
result = memory_tool.execute_action("view", path="/docs/file1.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
# Verify new paths exist
|
||||
|
||||
result = memory_tool.execute_action("view", path="/archive/file1.txt")
|
||||
assert "Content 1" in result
|
||||
|
||||
@@ -701,29 +704,25 @@ def test_rename_directory(memory_tool: MemoryTool) -> None:
|
||||
assert "Content 2" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_rename_directory_without_trailing_slash(memory_tool: MemoryTool) -> None:
|
||||
"""Test renaming a directory when new path is missing trailing slash."""
|
||||
# Create files in a directory
|
||||
|
||||
memory_tool.execute_action("create", path="/docs/file1.txt", file_text="Content 1")
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/docs/file1.txt",
|
||||
file_text="Content 1"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/docs/sub/file2.txt",
|
||||
file_text="Content 2"
|
||||
"create", path="/docs/sub/file2.txt", file_text="Content 2"
|
||||
)
|
||||
|
||||
# Rename directory - old path has slash, new path doesn't
|
||||
|
||||
result = memory_tool.execute_action(
|
||||
"rename",
|
||||
old_path="/docs/",
|
||||
new_path="/archive" # Missing trailing slash
|
||||
"rename", old_path="/docs/", new_path="/archive" # Missing trailing slash
|
||||
)
|
||||
assert "renamed" in result.lower()
|
||||
|
||||
# Verify paths are correct (not corrupted like "/archivesub/file2.txt")
|
||||
|
||||
result = memory_tool.execute_action("view", path="/archive/file1.txt")
|
||||
assert "Content 1" in result
|
||||
|
||||
@@ -731,28 +730,25 @@ def test_rename_directory_without_trailing_slash(memory_tool: MemoryTool) -> Non
|
||||
assert "Content 2" in result
|
||||
|
||||
# Verify corrupted path doesn't exist
|
||||
|
||||
result = memory_tool.execute_action("view", path="/archivesub/file2.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_view_file_line_numbers(memory_tool: MemoryTool) -> None:
|
||||
"""Test that view_range displays correct line numbers."""
|
||||
# Create a multiline file
|
||||
|
||||
content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/numbered.txt",
|
||||
file_text=content
|
||||
)
|
||||
memory_tool.execute_action("create", path="/numbered.txt", file_text=content)
|
||||
|
||||
# View lines 2-4
|
||||
result = memory_tool.execute_action(
|
||||
"view",
|
||||
path="/numbered.txt",
|
||||
view_range=[2, 4]
|
||||
)
|
||||
|
||||
result = memory_tool.execute_action("view", path="/numbered.txt", view_range=[2, 4])
|
||||
|
||||
# Check that line numbers are correct (should be 2, 3, 4 not 3, 4, 5)
|
||||
|
||||
assert "2: Line 2" in result
|
||||
assert "3: Line 3" in result
|
||||
assert "4: Line 4" in result
|
||||
@@ -760,6 +756,7 @@ def test_view_file_line_numbers(memory_tool: MemoryTool) -> None:
|
||||
assert "5: Line 5" not in result
|
||||
|
||||
# Verify no off-by-one error
|
||||
|
||||
assert "3: Line 2" not in result # Wrong line number
|
||||
assert "4: Line 3" not in result # Wrong line number
|
||||
assert "5: Line 4" not in result # Wrong line number
|
||||
|
||||
@@ -3,10 +3,10 @@ from application.agents.tools.notes import NotesTool
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def notes_tool(monkeypatch) -> NotesTool:
|
||||
"""Provide a NotesTool with a fake Mongo collection and fixed user_id."""
|
||||
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {} # key: user_id:tool_id -> doc
|
||||
@@ -17,6 +17,7 @@ def notes_tool(monkeypatch) -> NotesTool:
|
||||
key = f"{user_id}:{tool_id}"
|
||||
|
||||
# emulate single-note storage with optional upsert
|
||||
|
||||
if key not in self.docs and not upsert:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
if key not in self.docs and upsert:
|
||||
@@ -45,34 +46,45 @@ def notes_tool(monkeypatch) -> NotesTool:
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
# Patch MongoDB client globally for the tool
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.core.mongo_db.MongoDB.get_client", lambda: fake_client
|
||||
)
|
||||
|
||||
# Return tool with a fixed tool_id for consistency in tests
|
||||
|
||||
return NotesTool({"tool_id": "test_tool_id"}, user_id="test_user")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_view(notes_tool: NotesTool) -> None:
|
||||
# Manually insert a note to test retrieval
|
||||
|
||||
notes_tool.collection.update_one(
|
||||
{"user_id": "test_user", "tool_id": "test_tool_id"},
|
||||
{"$set": {"note": "hello"}},
|
||||
upsert=True
|
||||
upsert=True,
|
||||
)
|
||||
assert "hello" in notes_tool.execute_action("view")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_overwrite_and_delete(notes_tool: NotesTool) -> None:
|
||||
# Overwrite creates a new note
|
||||
|
||||
assert "saved" in notes_tool.execute_action("overwrite", text="first").lower()
|
||||
assert "first" in notes_tool.execute_action("view")
|
||||
|
||||
# Overwrite replaces existing note
|
||||
|
||||
assert "saved" in notes_tool.execute_action("overwrite", text="second").lower()
|
||||
assert "second" in notes_tool.execute_action("view")
|
||||
|
||||
assert "deleted" in notes_tool.execute_action("delete").lower()
|
||||
assert "no note" in notes_tool.execute_action("view").lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_without_user_id(monkeypatch):
|
||||
"""Should fail gracefully if no user_id is provided."""
|
||||
notes_tool = NotesTool(tool_config={})
|
||||
@@ -80,26 +92,32 @@ def test_init_without_user_id(monkeypatch):
|
||||
assert "user_id" in str(result).lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_view_not_found(notes_tool: NotesTool) -> None:
|
||||
"""Should return 'No note found.' when no note exists"""
|
||||
result = notes_tool.execute_action("view")
|
||||
assert "no note found" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_str_replace(notes_tool: NotesTool) -> None:
|
||||
"""Test string replacement in note"""
|
||||
# Create a note
|
||||
|
||||
notes_tool.execute_action("overwrite", text="Hello world, hello universe")
|
||||
|
||||
# Replace text
|
||||
|
||||
result = notes_tool.execute_action("str_replace", old_str="hello", new_str="hi")
|
||||
assert "updated" in result.lower()
|
||||
|
||||
# Verify replacement
|
||||
|
||||
note = notes_tool.execute_action("view")
|
||||
assert "hi world, hi universe" in note.lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_str_replace_not_found(notes_tool: NotesTool) -> None:
|
||||
"""Test string replacement when string not found"""
|
||||
notes_tool.execute_action("overwrite", text="Hello world")
|
||||
@@ -107,22 +125,27 @@ def test_str_replace_not_found(notes_tool: NotesTool) -> None:
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_insert_line(notes_tool: NotesTool) -> None:
|
||||
"""Test inserting text at a line number"""
|
||||
# Create a multiline note
|
||||
|
||||
notes_tool.execute_action("overwrite", text="Line 1\nLine 2\nLine 3")
|
||||
|
||||
# Insert at line 2
|
||||
|
||||
result = notes_tool.execute_action("insert", line_number=2, text="Inserted line")
|
||||
assert "inserted" in result.lower()
|
||||
|
||||
# Verify insertion
|
||||
|
||||
note = notes_tool.execute_action("view")
|
||||
lines = note.split("\n")
|
||||
assert lines[1] == "Inserted line"
|
||||
assert lines[2] == "Line 2"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_delete_nonexistent_note(monkeypatch):
|
||||
class FakeResult:
|
||||
deleted_count = 0
|
||||
@@ -133,7 +156,7 @@ def test_delete_nonexistent_note(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.core.mongo_db.MongoDB.get_client",
|
||||
lambda: {"docsgpt": {"notes": FakeCollection()}}
|
||||
lambda: {"docsgpt": {"notes": FakeCollection()}},
|
||||
)
|
||||
|
||||
notes_tool = NotesTool(tool_config={}, user_id="user123")
|
||||
@@ -141,8 +164,10 @@ def test_delete_nonexistent_note(monkeypatch):
|
||||
assert "no note found" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_notes_tool_isolation(monkeypatch) -> None:
|
||||
"""Test that different notes tool instances have isolated notes."""
|
||||
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {}
|
||||
@@ -170,19 +195,25 @@ def test_notes_tool_isolation(monkeypatch) -> None:
|
||||
fake_db = {"notes": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
monkeypatch.setattr(
|
||||
"application.core.mongo_db.MongoDB.get_client", lambda: fake_client
|
||||
)
|
||||
|
||||
# Create two notes tools with different tool_ids for the same user
|
||||
|
||||
tool1 = NotesTool({"tool_id": "tool_1"}, user_id="test_user")
|
||||
tool2 = NotesTool({"tool_id": "tool_2"}, user_id="test_user")
|
||||
|
||||
# Create a note in tool1
|
||||
|
||||
tool1.execute_action("overwrite", text="Content from tool 1")
|
||||
|
||||
# Create a note in tool2
|
||||
|
||||
tool2.execute_action("overwrite", text="Content from tool 2")
|
||||
|
||||
# Verify that each tool sees only its own content
|
||||
|
||||
result1 = tool1.execute_action("view")
|
||||
result2 = tool2.execute_action("view")
|
||||
|
||||
@@ -193,8 +224,10 @@ def test_notes_tool_isolation(monkeypatch) -> None:
|
||||
assert "Content from tool 1" not in result2
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_notes_tool_auto_generates_tool_id(monkeypatch) -> None:
|
||||
"""Test that tool_id defaults to 'default_{user_id}' for persistence."""
|
||||
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {}
|
||||
@@ -206,18 +239,23 @@ def test_notes_tool_auto_generates_tool_id(monkeypatch) -> None:
|
||||
fake_db = {"notes": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
monkeypatch.setattr(
|
||||
"application.core.mongo_db.MongoDB.get_client", lambda: fake_client
|
||||
)
|
||||
|
||||
# Create two tools without providing tool_id for the same user
|
||||
|
||||
tool1 = NotesTool({}, user_id="test_user")
|
||||
tool2 = NotesTool({}, user_id="test_user")
|
||||
|
||||
# Both should have the same default tool_id for persistence
|
||||
|
||||
assert tool1.tool_id == "default_test_user"
|
||||
assert tool2.tool_id == "default_test_user"
|
||||
assert tool1.tool_id == tool2.tool_id
|
||||
|
||||
# Different users should have different tool_ids
|
||||
|
||||
tool3 = NotesTool({}, user_id="another_user")
|
||||
assert tool3.tool_id == "default_another_user"
|
||||
assert tool3.tool_id != tool1.tool_id
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
from openapi_parser import parse
|
||||
from application.parser.file.openapi3_parser import OpenAPI3Parser
|
||||
from openapi_parser import parse
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -17,10 +17,12 @@ from application.parser.file.openapi3_parser import OpenAPI3Parser
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.unit
|
||||
def test_get_base_urls(urls, expected_base_urls):
|
||||
assert OpenAPI3Parser().get_base_urls(urls) == expected_base_urls
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_info_from_paths():
|
||||
file_path = "tests/test_openapi3.yaml"
|
||||
data = parse(file_path)
|
||||
@@ -31,6 +33,7 @@ def test_get_info_from_paths():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_parse_file():
|
||||
file_path = "tests/test_openapi3.yaml"
|
||||
results_expected = (
|
||||
|
||||
Reference in New Issue
Block a user