mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-03-03 20:33:45 +00:00
test: implement full API test suite with mongomock and centralized fixtures (#2068)
This commit is contained in:
0
tests/api/__init__.py
Normal file
0
tests/api/__init__.py
Normal file
0
tests/api/answer/__init__.py
Normal file
0
tests/api/answer/__init__.py
Normal file
0
tests/api/answer/routes/__init__.py
Normal file
0
tests/api/answer/routes/__init__.py
Normal file
552
tests/api/answer/routes/test_base.py
Normal file
552
tests/api/answer/routes/test_base.py
Normal file
@@ -0,0 +1,552 @@
|
||||
import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseAnswerValidation:
|
||||
def test_validate_request_passes_with_required_fields(
|
||||
self, mock_mongo_db, flask_app
|
||||
):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
data = {"question": "What is Python?"}
|
||||
|
||||
result = resource.validate_request(data)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_validate_request_fails_without_question(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
data = {}
|
||||
|
||||
result = resource.validate_request(data)
|
||||
|
||||
assert result is not None
|
||||
assert result.status_code == 400
|
||||
assert "question" in result.json["message"].lower()
|
||||
|
||||
def test_validate_with_conversation_id_required(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
data = {"question": "Test"}
|
||||
|
||||
result = resource.validate_request(data, require_conversation_id=True)
|
||||
|
||||
assert result is not None
|
||||
assert result.status_code == 400
|
||||
assert "conversation_id" in result.json["message"].lower()
|
||||
|
||||
def test_validate_passes_with_all_required_fields(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
data = {"question": "Test", "conversation_id": str(ObjectId())}
|
||||
|
||||
result = resource.validate_request(data, require_conversation_id=True)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestUsageChecking:
|
||||
def test_returns_none_when_no_api_key(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_error_for_invalid_api_key(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "invalid_key_123"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is not None
|
||||
assert result.status_code == 401
|
||||
assert result.json["success"] is False
|
||||
assert "invalid" in result.json["message"].lower()
|
||||
|
||||
def test_checks_token_limit_when_enabled(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_key",
|
||||
"limited_token_mode": True,
|
||||
"token_limit": 1000,
|
||||
"limited_request_mode": False,
|
||||
}
|
||||
)
|
||||
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "test_key"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_checks_request_limit_when_enabled(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_key",
|
||||
"limited_token_mode": False,
|
||||
"limited_request_mode": True,
|
||||
"request_limit": 100,
|
||||
}
|
||||
)
|
||||
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "test_key"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_uses_default_limits_when_not_specified(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_key",
|
||||
"limited_token_mode": True,
|
||||
"limited_request_mode": True,
|
||||
}
|
||||
)
|
||||
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "test_key"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_exceeds_token_limit(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
token_usage_collection = mock_mongo_db[settings.MONGO_DB_NAME][
|
||||
"token_usage"
|
||||
]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_key",
|
||||
"limited_token_mode": True,
|
||||
"token_limit": 100,
|
||||
"limited_request_mode": False,
|
||||
}
|
||||
)
|
||||
|
||||
token_usage_collection.insert_one(
|
||||
{
|
||||
"_id": ObjectId(),
|
||||
"api_key": "test_key",
|
||||
"prompt_tokens": 60,
|
||||
"generated_tokens": 50,
|
||||
"timestamp": datetime.datetime.now(),
|
||||
}
|
||||
)
|
||||
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "test_key"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is not None
|
||||
assert result.status_code == 429
|
||||
assert result.json["success"] is False
|
||||
assert "usage limit" in result.json["message"].lower()
|
||||
|
||||
def test_exceeds_request_limit(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
token_usage_collection = mock_mongo_db[settings.MONGO_DB_NAME][
|
||||
"token_usage"
|
||||
]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_key",
|
||||
"limited_token_mode": False,
|
||||
"limited_request_mode": True,
|
||||
"request_limit": 2,
|
||||
}
|
||||
)
|
||||
|
||||
now = datetime.datetime.now()
|
||||
for i in range(3):
|
||||
token_usage_collection.insert_one(
|
||||
{
|
||||
"_id": ObjectId(),
|
||||
"api_key": "test_key",
|
||||
"prompt_tokens": 10,
|
||||
"generated_tokens": 10,
|
||||
"timestamp": now,
|
||||
}
|
||||
)
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "test_key"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is not None
|
||||
assert result.status_code == 429
|
||||
assert result.json["success"] is False
|
||||
|
||||
def test_both_limits_disabled_returns_none(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_key",
|
||||
"limited_token_mode": False,
|
||||
"limited_request_mode": False,
|
||||
}
|
||||
)
|
||||
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "test_key"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGPTModelRetrieval:
|
||||
def test_initializes_gpt_model(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
assert hasattr(resource, "gpt_model")
|
||||
assert resource.gpt_model is not None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestConversationServiceIntegration:
|
||||
def test_initializes_conversation_service(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
assert hasattr(resource, "conversation_service")
|
||||
assert resource.conversation_service is not None
|
||||
|
||||
def test_has_access_to_user_logs_collection(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
assert hasattr(resource, "user_logs_collection")
|
||||
assert resource.user_logs_collection is not None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompleteStreamMethod:
|
||||
def test_streams_answer_chunks(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{"answer": "Hello "},
|
||||
{"answer": "world!"},
|
||||
]
|
||||
)
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_params.return_value = {}
|
||||
|
||||
decoded_token = {"sub": "user123"}
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Test question",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
|
||||
answer_chunks = [s for s in stream if '"type": "answer"' in s]
|
||||
assert len(answer_chunks) == 2
|
||||
assert '"answer": "Hello "' in answer_chunks[0]
|
||||
assert '"answer": "world!"' in answer_chunks[1]
|
||||
|
||||
def test_streams_sources(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{"answer": "Test answer"},
|
||||
{"sources": [{"title": "doc1.txt", "text": "x" * 200}]},
|
||||
]
|
||||
)
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_params.return_value = {}
|
||||
|
||||
decoded_token = {"sub": "user123"}
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Test?",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
|
||||
source_chunks = [s for s in stream if '"type": "source"' in s]
|
||||
assert len(source_chunks) == 1
|
||||
assert '"title": "doc1.txt"' in source_chunks[0]
|
||||
|
||||
def test_handles_error_during_streaming(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.side_effect = Exception("Test error")
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_params.return_value = {}
|
||||
|
||||
decoded_token = {"sub": "user123"}
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Test?",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
|
||||
assert any('"type": "error"' in s for s in stream)
|
||||
|
||||
def test_saves_conversation_when_enabled(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{"answer": "Test answer"},
|
||||
]
|
||||
)
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_params.return_value = {}
|
||||
|
||||
decoded_token = {"sub": "user123"}
|
||||
|
||||
with patch.object(
|
||||
resource.conversation_service, "save_conversation"
|
||||
) as mock_save:
|
||||
mock_save.return_value = str(ObjectId())
|
||||
|
||||
list(
|
||||
resource.complete_stream(
|
||||
question="Test?",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
should_save_conversation=True,
|
||||
)
|
||||
)
|
||||
|
||||
mock_save.assert_called_once()
|
||||
|
||||
def test_logs_to_user_logs_collection(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
user_logs = mock_mongo_db[settings.MONGO_DB_NAME]["user_logs"]
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{"answer": "Test answer"},
|
||||
]
|
||||
)
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_params.return_value = {"retriever": "test"}
|
||||
|
||||
decoded_token = {"sub": "user123"}
|
||||
|
||||
list(
|
||||
resource.complete_stream(
|
||||
question="Test question?",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
user_api_key="test_key",
|
||||
decoded_token=decoded_token,
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
|
||||
assert user_logs.count_documents({}) == 1
|
||||
log_entry = user_logs.find_one({})
|
||||
assert log_entry["action"] == "stream_answer"
|
||||
assert log_entry["user"] == "user123"
|
||||
assert log_entry["api_key"] == "test_key"
|
||||
assert log_entry["question"] == "Test question?"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestProcessResponseStream:
|
||||
def test_processes_complete_stream(self, mock_mongo_db, flask_app):
|
||||
import json
|
||||
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
conv_id = str(ObjectId())
|
||||
stream = [
|
||||
f'data: {json.dumps({"type": "answer", "answer": "Hello "})}\n\n',
|
||||
f'data: {json.dumps({"type": "answer", "answer": "world"})}\n\n',
|
||||
f'data: {json.dumps({"type": "source", "source": [{"title": "doc1"}]})}\n\n',
|
||||
f'data: {json.dumps({"type": "id", "id": conv_id})}\n\n',
|
||||
f'data: {json.dumps({"type": "end"})}\n\n',
|
||||
]
|
||||
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
|
||||
assert result[0] == conv_id
|
||||
assert result[1] == "Hello world"
|
||||
assert result[2] == [{"title": "doc1"}]
|
||||
assert result[5] is None
|
||||
|
||||
def test_handles_stream_error(self, mock_mongo_db, flask_app):
|
||||
import json
|
||||
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
stream = [
|
||||
f'data: {json.dumps({"type": "error", "error": "Test error"})}\n\n',
|
||||
]
|
||||
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
|
||||
assert len(result) == 5
|
||||
assert result[0] is None
|
||||
assert result[4] == "Test error"
|
||||
|
||||
def test_handles_malformed_stream_data(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
stream = [
|
||||
"data: invalid json\n\n",
|
||||
'data: {"type": "end"}\n\n',
|
||||
]
|
||||
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestErrorStreamGenerate:
|
||||
def test_generates_error_stream(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
error_stream = list(resource.error_stream_generate("Test error message"))
|
||||
|
||||
assert len(error_stream) == 1
|
||||
assert '"type": "error"' in error_stream[0]
|
||||
assert '"error": "Test error message"' in error_stream[0]
|
||||
0
tests/api/answer/services/__init__.py
Normal file
0
tests/api/answer/services/__init__.py
Normal file
242
tests/api/answer/services/test_conversation_service.py
Normal file
242
tests/api/answer/services/test_conversation_service.py
Normal file
@@ -0,0 +1,242 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestConversationServiceGet:
|
||||
|
||||
def test_returns_none_when_no_conversation_id(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
|
||||
service = ConversationService()
|
||||
result = service.get_conversation("", "user_123")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_when_no_user_id(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
|
||||
service = ConversationService()
|
||||
result = service.get_conversation(str(ObjectId()), "")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_conversation_for_owner(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
conversation = {
|
||||
"_id": conv_id,
|
||||
"user": "user_123",
|
||||
"name": "Test Conv",
|
||||
"queries": [],
|
||||
}
|
||||
collection.insert_one(conversation)
|
||||
|
||||
result = service.get_conversation(str(conv_id), "user_123")
|
||||
|
||||
assert result is not None
|
||||
assert result["name"] == "Test Conv"
|
||||
assert result["_id"] == str(conv_id)
|
||||
|
||||
def test_returns_none_for_unauthorized_user(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
collection.insert_one(
|
||||
{"_id": conv_id, "user": "owner_123", "name": "Private Conv"}
|
||||
)
|
||||
|
||||
result = service.get_conversation(str(conv_id), "hacker_456")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_converts_objectid_to_string(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
collection.insert_one({"_id": conv_id, "user": "user_123", "name": "Test"})
|
||||
|
||||
result = service.get_conversation(str(conv_id), "user_123")
|
||||
|
||||
assert isinstance(result["_id"], str)
|
||||
assert result["_id"] == str(conv_id)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestConversationServiceSave:
|
||||
|
||||
def test_raises_error_when_no_user_in_token(self, mock_mongo_db):
|
||||
"""Test validation: user ID required"""
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
|
||||
service = ConversationService()
|
||||
mock_llm = Mock()
|
||||
|
||||
with pytest.raises(ValueError, match="User ID not found"):
|
||||
service.save_conversation(
|
||||
conversation_id=None,
|
||||
question="Test?",
|
||||
response="Answer",
|
||||
thought="",
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
decoded_token={}, # No 'sub' key
|
||||
)
|
||||
|
||||
def test_truncates_long_source_text(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from bson import ObjectId
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.gen.return_value = "Test Summary"
|
||||
|
||||
long_text = "x" * 2000
|
||||
sources = [{"text": long_text, "title": "Doc"}]
|
||||
|
||||
conv_id = service.save_conversation(
|
||||
conversation_id=None,
|
||||
question="Question",
|
||||
response="Response",
|
||||
thought="",
|
||||
sources=sources,
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
decoded_token={"sub": "user_123"},
|
||||
)
|
||||
|
||||
saved_conv = collection.find_one({"_id": ObjectId(conv_id)})
|
||||
saved_source_text = saved_conv["queries"][0]["sources"][0]["text"]
|
||||
|
||||
assert len(saved_source_text) == 1000
|
||||
assert saved_source_text == "x" * 1000
|
||||
|
||||
def test_creates_new_conversation_with_summary(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from bson import ObjectId
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.gen.return_value = "Python Basics"
|
||||
|
||||
conv_id = service.save_conversation(
|
||||
conversation_id=None,
|
||||
question="What is Python?",
|
||||
response="Python is a programming language",
|
||||
thought="",
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
decoded_token={"sub": "user_123"},
|
||||
)
|
||||
|
||||
assert conv_id is not None
|
||||
saved_conv = collection.find_one({"_id": ObjectId(conv_id)})
|
||||
assert saved_conv["name"] == "Python Basics"
|
||||
assert saved_conv["user"] == "user_123"
|
||||
assert len(saved_conv["queries"]) == 1
|
||||
assert saved_conv["queries"][0]["prompt"] == "What is Python?"
|
||||
|
||||
def test_appends_to_existing_conversation(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from bson import ObjectId
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
existing_conv_id = ObjectId()
|
||||
collection.insert_one(
|
||||
{
|
||||
"_id": existing_conv_id,
|
||||
"user": "user_123",
|
||||
"name": "Old Conv",
|
||||
"queries": [{"prompt": "Q1", "response": "A1"}],
|
||||
}
|
||||
)
|
||||
|
||||
mock_llm = Mock()
|
||||
|
||||
result = service.save_conversation(
|
||||
conversation_id=str(existing_conv_id),
|
||||
question="Q2",
|
||||
response="A2",
|
||||
thought="",
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
decoded_token={"sub": "user_123"},
|
||||
)
|
||||
|
||||
assert result == str(existing_conv_id)
|
||||
|
||||
def test_prevents_unauthorized_conversation_update(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
collection.insert_one({"_id": conv_id, "user": "owner_123", "queries": []})
|
||||
|
||||
mock_llm = Mock()
|
||||
|
||||
with pytest.raises(ValueError, match="not found or unauthorized"):
|
||||
service.save_conversation(
|
||||
conversation_id=str(conv_id),
|
||||
question="Hack",
|
||||
response="Attempt",
|
||||
thought="",
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
decoded_token={"sub": "hacker_456"},
|
||||
)
|
||||
252
tests/api/answer/services/test_stream_processor.py
Normal file
252
tests/api/answer/services/test_stream_processor.py
Normal file
@@ -0,0 +1,252 @@
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetPromptFunction:
|
||||
|
||||
def test_loads_custom_prompt_from_database(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import get_prompt
|
||||
from application.core.settings import settings
|
||||
|
||||
prompts_collection = mock_mongo_db[settings.MONGO_DB_NAME]["prompts"]
|
||||
prompt_id = ObjectId()
|
||||
|
||||
prompts_collection.insert_one(
|
||||
{
|
||||
"_id": prompt_id,
|
||||
"content": "Custom prompt from database",
|
||||
"user": "user_123",
|
||||
}
|
||||
)
|
||||
|
||||
result = get_prompt(str(prompt_id), prompts_collection)
|
||||
assert result == "Custom prompt from database"
|
||||
|
||||
def test_raises_error_for_invalid_prompt_id(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import get_prompt
|
||||
from application.core.settings import settings
|
||||
|
||||
prompts_collection = mock_mongo_db[settings.MONGO_DB_NAME]["prompts"]
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid prompt ID"):
|
||||
get_prompt(str(ObjectId()), prompts_collection)
|
||||
|
||||
def test_raises_error_for_malformed_id(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import get_prompt
|
||||
from application.core.settings import settings
|
||||
|
||||
prompts_collection = mock_mongo_db[settings.MONGO_DB_NAME]["prompts"]
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid prompt ID"):
|
||||
get_prompt("not_a_valid_id", prompts_collection)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamProcessorInitialization:
|
||||
|
||||
def test_initializes_with_decoded_token(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
request_data = {
|
||||
"question": "What is Python?",
|
||||
"conversation_id": str(ObjectId()),
|
||||
}
|
||||
decoded_token = {"sub": "user_123", "email": "test@example.com"}
|
||||
|
||||
processor = StreamProcessor(request_data, decoded_token)
|
||||
|
||||
assert processor.data == request_data
|
||||
assert processor.decoded_token == decoded_token
|
||||
assert processor.initial_user_id == "user_123"
|
||||
assert processor.conversation_id == request_data["conversation_id"]
|
||||
|
||||
def test_initializes_without_token(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
request_data = {"question": "Test question"}
|
||||
|
||||
processor = StreamProcessor(request_data, None)
|
||||
|
||||
assert processor.decoded_token is None
|
||||
assert processor.initial_user_id is None
|
||||
assert processor.data == request_data
|
||||
|
||||
def test_initializes_default_attributes(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
processor = StreamProcessor({"question": "Test"}, {"sub": "user_123"})
|
||||
|
||||
assert processor.source == {}
|
||||
assert processor.all_sources == []
|
||||
assert processor.attachments == []
|
||||
assert processor.history == []
|
||||
assert processor.agent_config == {}
|
||||
assert processor.retriever_config == {}
|
||||
assert processor.is_shared_usage is False
|
||||
assert processor.shared_token is None
|
||||
|
||||
def test_extracts_conversation_id_from_request(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
conv_id = str(ObjectId())
|
||||
request_data = {"question": "Test", "conversation_id": conv_id}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
|
||||
assert processor.conversation_id == conv_id
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamProcessorHistoryLoading:
|
||||
|
||||
def test_loads_history_from_existing_conversation(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
conversations_collection = mock_mongo_db[settings.MONGO_DB_NAME][
|
||||
"conversations"
|
||||
]
|
||||
conv_id = ObjectId()
|
||||
|
||||
conversations_collection.insert_one(
|
||||
{
|
||||
"_id": conv_id,
|
||||
"user": "user_123",
|
||||
"name": "Test Conv",
|
||||
"queries": [
|
||||
{"prompt": "What is Python?", "response": "Python is a language"},
|
||||
{"prompt": "Tell me more", "response": "Python is versatile"},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"question": "How to install it?",
|
||||
"conversation_id": str(conv_id),
|
||||
}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
processor._load_conversation_history()
|
||||
|
||||
assert len(processor.history) == 2
|
||||
assert processor.history[0]["prompt"] == "What is Python?"
|
||||
assert processor.history[1]["response"] == "Python is versatile"
|
||||
|
||||
def test_raises_error_for_unauthorized_conversation(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
conversations_collection = mock_mongo_db[settings.MONGO_DB_NAME][
|
||||
"conversations"
|
||||
]
|
||||
conv_id = ObjectId()
|
||||
|
||||
conversations_collection.insert_one(
|
||||
{
|
||||
"_id": conv_id,
|
||||
"user": "owner_123",
|
||||
"name": "Private Conv",
|
||||
"queries": [],
|
||||
}
|
||||
)
|
||||
|
||||
request_data = {"question": "Hack attempt", "conversation_id": str(conv_id)}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "hacker_456"})
|
||||
|
||||
with pytest.raises(ValueError, match="Conversation not found or unauthorized"):
|
||||
processor._load_conversation_history()
|
||||
|
||||
def test_uses_request_history_when_no_conversation_id(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
request_data = {
|
||||
"question": "What is Python?",
|
||||
"history": [{"prompt": "Hello", "response": "Hi there!"}],
|
||||
}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
|
||||
assert processor.conversation_id is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamProcessorAgentConfiguration:
|
||||
|
||||
def test_configures_agent_from_valid_api_key(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_api_key_123",
|
||||
"endpoint": "openai",
|
||||
"model": "gpt-4",
|
||||
"prompt_id": "default",
|
||||
"user": "user_123",
|
||||
}
|
||||
)
|
||||
|
||||
request_data = {"question": "Test", "api_key": "test_api_key_123"}
|
||||
|
||||
processor = StreamProcessor(request_data, None)
|
||||
|
||||
try:
|
||||
processor._configure_agent()
|
||||
assert processor.agent_config is not None
|
||||
except Exception as e:
|
||||
assert "Invalid API Key" in str(e)
|
||||
|
||||
def test_uses_default_config_without_api_key(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
request_data = {"question": "Test"}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
processor._configure_agent()
|
||||
|
||||
assert isinstance(processor.agent_config, dict)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamProcessorAttachments:
|
||||
|
||||
def test_processes_attachments_from_request(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
attachments_collection = mock_mongo_db[settings.MONGO_DB_NAME]["attachments"]
|
||||
att_id = ObjectId()
|
||||
|
||||
attachments_collection.insert_one(
|
||||
{
|
||||
"_id": att_id,
|
||||
"filename": "document.pdf",
|
||||
"content": "Document content",
|
||||
"user": "user_123",
|
||||
}
|
||||
)
|
||||
|
||||
request_data = {"question": "Analyze this", "attachments": [str(att_id)]}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
|
||||
assert processor.data.get("attachments") == [str(att_id)]
|
||||
|
||||
def test_handles_empty_attachments(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
request_data = {"question": "Simple question"}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
|
||||
assert processor.attachments == []
|
||||
assert (
|
||||
"attachments" not in processor.data
|
||||
or processor.data.get("attachments") is None
|
||||
)
|
||||
89
tests/api/conftest.py
Normal file
89
tests/api/conftest.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""API-specific test fixtures."""
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_headers():
|
||||
return {"Authorization": "Bearer test_token"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request_token(monkeypatch, decoded_token):
|
||||
def mock_decorator(f):
|
||||
def wrapper(*args, **kwargs):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = decoded_token
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
monkeypatch.setattr("application.auth.api_key_required", lambda: mock_decorator)
|
||||
return decoded_token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation():
|
||||
return {
|
||||
"_id": ObjectId(),
|
||||
"user": "test_user",
|
||||
"name": "Test Conversation",
|
||||
"queries": [
|
||||
{
|
||||
"prompt": "What is Python?",
|
||||
"response": "Python is a programming language",
|
||||
}
|
||||
],
|
||||
"date": "2025-01-01T00:00:00",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_prompt():
|
||||
return {
|
||||
"_id": ObjectId(),
|
||||
"user": "test_user",
|
||||
"name": "Helpful Assistant",
|
||||
"content": "You are a helpful assistant that provides clear and concise answers.",
|
||||
"type": "custom",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_agent():
|
||||
return {
|
||||
"_id": ObjectId(),
|
||||
"user": "test_user",
|
||||
"name": "Test Agent",
|
||||
"type": "classic",
|
||||
"endpoint": "openai",
|
||||
"model": "gpt-4",
|
||||
"prompt_id": "default",
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_answer_request():
|
||||
return {
|
||||
"question": "What is Python?",
|
||||
"history": [],
|
||||
"conversation_id": None,
|
||||
"prompt_id": "default",
|
||||
"chunks": 2,
|
||||
"token_limit": 1000,
|
||||
"retriever": "classic_rag",
|
||||
"active_docs": "local/test/",
|
||||
"isNoneDoc": False,
|
||||
"save_conversation": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app():
|
||||
from flask import Flask
|
||||
|
||||
app = Flask(__name__)
|
||||
return app
|
||||
311
tests/api/user/test_base.py
Normal file
311
tests/api/user/test_base.py
Normal file
@@ -0,0 +1,311 @@
|
||||
import datetime
|
||||
import io
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestTimeRangeGenerators:
|
||||
|
||||
def test_generate_minute_range(self):
|
||||
from application.api.user.base import generate_minute_range
|
||||
|
||||
start = datetime.datetime(2024, 1, 1, 10, 0, 0)
|
||||
end = datetime.datetime(2024, 1, 1, 10, 5, 0)
|
||||
|
||||
result = generate_minute_range(start, end)
|
||||
|
||||
assert len(result) == 6
|
||||
assert "2024-01-01 10:00:00" in result
|
||||
assert "2024-01-01 10:05:00" in result
|
||||
assert all(val == 0 for val in result.values())
|
||||
|
||||
def test_generate_hourly_range(self):
|
||||
from application.api.user.base import generate_hourly_range
|
||||
|
||||
start = datetime.datetime(2024, 1, 1, 10, 0, 0)
|
||||
end = datetime.datetime(2024, 1, 1, 15, 0, 0)
|
||||
|
||||
result = generate_hourly_range(start, end)
|
||||
|
||||
assert len(result) == 6
|
||||
assert "2024-01-01 10:00" in result
|
||||
assert "2024-01-01 15:00" in result
|
||||
assert all(val == 0 for val in result.values())
|
||||
|
||||
def test_generate_date_range(self):
|
||||
from application.api.user.base import generate_date_range
|
||||
|
||||
start = datetime.date(2024, 1, 1)
|
||||
end = datetime.date(2024, 1, 5)
|
||||
|
||||
result = generate_date_range(start, end)
|
||||
|
||||
assert len(result) == 5
|
||||
assert "2024-01-01" in result
|
||||
assert "2024-01-05" in result
|
||||
assert all(val == 0 for val in result.values())
|
||||
|
||||
def test_single_minute_range(self):
|
||||
from application.api.user.base import generate_minute_range
|
||||
|
||||
time = datetime.datetime(2024, 1, 1, 10, 30, 0)
|
||||
result = generate_minute_range(time, time)
|
||||
|
||||
assert len(result) == 1
|
||||
assert "2024-01-01 10:30:00" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestEnsureUserDoc:
|
||||
|
||||
def test_creates_new_user_with_defaults(self, mock_mongo_db):
|
||||
from application.api.user.base import ensure_user_doc
|
||||
|
||||
user_id = "test_user_123"
|
||||
|
||||
result = ensure_user_doc(user_id)
|
||||
|
||||
assert result is not None
|
||||
assert result["user_id"] == user_id
|
||||
assert "agent_preferences" in result
|
||||
assert result["agent_preferences"]["pinned"] == []
|
||||
assert result["agent_preferences"]["shared_with_me"] == []
|
||||
|
||||
def test_returns_existing_user(self, mock_mongo_db):
|
||||
from application.api.user.base import ensure_user_doc
|
||||
from application.core.settings import settings
|
||||
|
||||
users_collection = mock_mongo_db[settings.MONGO_DB_NAME]["users"]
|
||||
user_id = "existing_user"
|
||||
|
||||
existing_doc = {
|
||||
"user_id": user_id,
|
||||
"agent_preferences": {"pinned": ["agent1"], "shared_with_me": ["agent2"]},
|
||||
}
|
||||
users_collection.insert_one(existing_doc)
|
||||
|
||||
result = ensure_user_doc(user_id)
|
||||
|
||||
assert result["user_id"] == user_id
|
||||
assert result["agent_preferences"]["pinned"] == ["agent1"]
|
||||
assert result["agent_preferences"]["shared_with_me"] == ["agent2"]
|
||||
|
||||
def test_adds_missing_preferences_fields(self, mock_mongo_db):
|
||||
from application.api.user.base import ensure_user_doc
|
||||
from application.core.settings import settings
|
||||
|
||||
users_collection = mock_mongo_db[settings.MONGO_DB_NAME]["users"]
|
||||
user_id = "incomplete_user"
|
||||
|
||||
users_collection.insert_one(
|
||||
{"user_id": user_id, "agent_preferences": {"pinned": ["agent1"]}}
|
||||
)
|
||||
|
||||
result = ensure_user_doc(user_id)
|
||||
|
||||
assert "shared_with_me" in result["agent_preferences"]
|
||||
assert result["agent_preferences"]["shared_with_me"] == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestResolveToolDetails:
|
||||
|
||||
def test_resolves_tool_ids_to_details(self, mock_mongo_db):
|
||||
from application.api.user.base import resolve_tool_details
|
||||
from application.core.settings import settings
|
||||
|
||||
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
|
||||
tool_id1 = ObjectId()
|
||||
tool_id2 = ObjectId()
|
||||
|
||||
user_tools.insert_one(
|
||||
{"_id": tool_id1, "name": "calculator", "displayName": "Calculator Tool"}
|
||||
)
|
||||
user_tools.insert_one(
|
||||
{"_id": tool_id2, "name": "weather", "displayName": "Weather API"}
|
||||
)
|
||||
|
||||
result = resolve_tool_details([str(tool_id1), str(tool_id2)])
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == str(tool_id1)
|
||||
assert result[0]["name"] == "calculator"
|
||||
assert result[0]["display_name"] == "Calculator Tool"
|
||||
assert result[1]["name"] == "weather"
|
||||
|
||||
def test_handles_missing_display_name(self, mock_mongo_db):
|
||||
from application.api.user.base import resolve_tool_details
|
||||
from application.core.settings import settings
|
||||
|
||||
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
|
||||
tool_id = ObjectId()
|
||||
|
||||
user_tools.insert_one({"_id": tool_id, "name": "test_tool"})
|
||||
|
||||
result = resolve_tool_details([str(tool_id)])
|
||||
|
||||
assert result[0]["display_name"] == "test_tool"
|
||||
|
||||
def test_empty_tool_ids_list(self, mock_mongo_db):
|
||||
from application.api.user.base import resolve_tool_details
|
||||
|
||||
result = resolve_tool_details([])
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetVectorStore:
|
||||
|
||||
@patch("application.api.user.base.VectorCreator.create_vectorstore")
|
||||
def test_creates_vector_store(self, mock_create, mock_mongo_db):
|
||||
from application.api.user.base import get_vector_store
|
||||
|
||||
mock_store = Mock()
|
||||
mock_create.return_value = mock_store
|
||||
source_id = "test_source_123"
|
||||
|
||||
result = get_vector_store(source_id)
|
||||
|
||||
assert result == mock_store
|
||||
mock_create.assert_called_once()
|
||||
args, kwargs = mock_create.call_args
|
||||
assert kwargs.get("source_id") == source_id
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestHandleImageUpload:
|
||||
|
||||
def test_returns_existing_url_when_no_file(self, flask_app):
|
||||
from application.api.user.base import handle_image_upload
|
||||
|
||||
with flask_app.test_request_context():
|
||||
mock_request = Mock()
|
||||
mock_request.files = {}
|
||||
mock_storage = Mock()
|
||||
existing_url = "existing/path/image.jpg"
|
||||
|
||||
url, error = handle_image_upload(
|
||||
mock_request, existing_url, "user123", mock_storage
|
||||
)
|
||||
|
||||
assert url == existing_url
|
||||
assert error is None
|
||||
|
||||
def test_uploads_new_image(self, flask_app):
|
||||
from application.api.user.base import handle_image_upload
|
||||
|
||||
with flask_app.test_request_context():
|
||||
mock_file = FileStorage(
|
||||
stream=io.BytesIO(b"fake image data"), filename="test_image.png"
|
||||
)
|
||||
mock_request = Mock()
|
||||
mock_request.files = {"image": mock_file}
|
||||
mock_storage = Mock()
|
||||
mock_storage.save_file.return_value = {"success": True}
|
||||
|
||||
url, error = handle_image_upload(
|
||||
mock_request, "old_url", "user123", mock_storage
|
||||
)
|
||||
|
||||
assert error is None
|
||||
assert url is not None
|
||||
assert "test_image.png" in url
|
||||
assert "user123" in url
|
||||
mock_storage.save_file.assert_called_once()
|
||||
|
||||
def test_ignores_empty_filename(self, flask_app):
|
||||
from application.api.user.base import handle_image_upload
|
||||
|
||||
with flask_app.test_request_context():
|
||||
mock_file = Mock()
|
||||
mock_file.filename = ""
|
||||
mock_request = Mock()
|
||||
mock_request.files = {"image": mock_file}
|
||||
mock_storage = Mock()
|
||||
existing_url = "existing.jpg"
|
||||
|
||||
url, error = handle_image_upload(
|
||||
mock_request, existing_url, "user123", mock_storage
|
||||
)
|
||||
|
||||
assert url == existing_url
|
||||
assert error is None
|
||||
mock_storage.save_file.assert_not_called()
|
||||
|
||||
def test_handles_upload_error(self, flask_app):
|
||||
from application.api.user.base import handle_image_upload
|
||||
|
||||
with flask_app.app_context():
|
||||
mock_file = FileStorage(stream=io.BytesIO(b"data"), filename="test.png")
|
||||
mock_request = Mock()
|
||||
mock_request.files = {"image": mock_file}
|
||||
mock_storage = Mock()
|
||||
mock_storage.save_file.side_effect = Exception("Storage error")
|
||||
|
||||
url, error = handle_image_upload(
|
||||
mock_request, "old.jpg", "user123", mock_storage
|
||||
)
|
||||
|
||||
assert url is None
|
||||
assert error is not None
|
||||
assert error.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRequireAgentDecorator:
|
||||
|
||||
def test_validates_webhook_token(self, mock_mongo_db, flask_app):
|
||||
from application.api.user.base import require_agent
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agent_id = ObjectId()
|
||||
webhook_token = "valid_webhook_token_123"
|
||||
|
||||
agents_collection.insert_one(
|
||||
{"_id": agent_id, "incoming_webhook_token": webhook_token}
|
||||
)
|
||||
|
||||
@require_agent
|
||||
def test_func(webhook_token=None, agent=None, agent_id_str=None):
|
||||
return {"agent_id": agent_id_str}
|
||||
|
||||
result = test_func(webhook_token=webhook_token)
|
||||
|
||||
assert result["agent_id"] == str(agent_id)
|
||||
|
||||
def test_returns_400_for_missing_token(self, mock_mongo_db, flask_app):
|
||||
from application.api.user.base import require_agent
|
||||
|
||||
with flask_app.app_context():
|
||||
|
||||
@require_agent
|
||||
def test_func(webhook_token=None, agent=None, agent_id_str=None):
|
||||
return {"success": True}
|
||||
|
||||
result = test_func()
|
||||
|
||||
assert result.status_code == 400
|
||||
assert result.json["success"] is False
|
||||
assert "missing" in result.json["message"].lower()
|
||||
|
||||
def test_returns_404_for_invalid_token(self, mock_mongo_db, flask_app):
|
||||
from application.api.user.base import require_agent
|
||||
|
||||
with flask_app.app_context():
|
||||
|
||||
@require_agent
|
||||
def test_func(webhook_token=None, agent=None, agent_id_str=None):
|
||||
return {"success": True}
|
||||
|
||||
result = test_func(webhook_token="invalid_token_999")
|
||||
|
||||
assert result.status_code == 404
|
||||
assert result.json["success"] is False
|
||||
assert "not found" in result.json["message"].lower()
|
||||
Reference in New Issue
Block a user