diff --git a/application/utils.py b/application/utils.py index 5ef28376..528cbac5 100644 --- a/application/utils.py +++ b/application/utils.py @@ -21,7 +21,7 @@ def get_encoding(): def get_gpt_model() -> str: - """Get the appropriate GPT model based on provider""" + """Get GPT model based on provider""" model_map = { "openai": "gpt-4o-mini", "anthropic": "claude-2", @@ -32,16 +32,7 @@ def get_gpt_model() -> str: def safe_filename(filename): - """ - Creates a safe filename that preserves the original extension. - Uses secure_filename, but ensures a proper filename is returned even with non-Latin characters. - - Args: - filename (str): The original filename - - Returns: - str: A safe filename that can be used for storage - """ + """Create safe filename, preserving extension. Handles non-Latin characters.""" if not filename: return str(uuid.uuid4()) _, extension = os.path.splitext(filename) @@ -83,8 +74,14 @@ def count_tokens_docs(docs): return tokens +def get_missing_fields(data, required_fields): + """Check for missing required fields. Returns list of missing field names.""" + return [field for field in required_fields if field not in data] + + def check_required_fields(data, required_fields): - missing_fields = [field for field in required_fields if field not in data] + """Validate required fields. Returns Flask 400 response if validation fails, None otherwise.""" + missing_fields = get_missing_fields(data, required_fields) if missing_fields: return make_response( jsonify( @@ -98,7 +95,8 @@ def check_required_fields(data, required_fields): return None -def validate_required_fields(data, required_fields): +def get_field_validation_errors(data, required_fields): + """Check for missing and empty fields. Returns dict with 'missing_fields' and 'empty_fields', or None.""" missing_fields = [] empty_fields = [] @@ -107,12 +105,24 @@ def validate_required_fields(data, required_fields): missing_fields.append(field) elif not data[field]: empty_fields.append(field) - errors = [] - if missing_fields: - errors.append(f"Missing required fields: {', '.join(missing_fields)}") - if empty_fields: - errors.append(f"Empty values in required fields: {', '.join(empty_fields)}") - if errors: + if missing_fields or empty_fields: + return {"missing_fields": missing_fields, "empty_fields": empty_fields} + return None + + +def validate_required_fields(data, required_fields): + """Validate required fields (must exist and be non-empty). Returns Flask 400 response if validation fails, None otherwise.""" + errors_dict = get_field_validation_errors(data, required_fields) + if errors_dict: + errors = [] + if errors_dict["missing_fields"]: + errors.append( + f"Missing required fields: {', '.join(errors_dict['missing_fields'])}" + ) + if errors_dict["empty_fields"]: + errors.append( + f"Empty values in required fields: {', '.join(errors_dict['empty_fields'])}" + ) return make_response( jsonify({"success": False, "message": " | ".join(errors)}), 400 ) @@ -124,10 +134,7 @@ def get_hash(data): def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"): - """ - Limits chat history based on token count. - Returns a list of messages that fit within the token limit. - """ + """Limit chat history to fit within token limit.""" from application.core.settings import settings max_token_limit = ( @@ -161,7 +168,7 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"): def validate_function_name(function_name): - """Validates if a function name matches the allowed pattern.""" + """Validate function name matches allowed pattern (alphanumeric, underscore, hyphen).""" if not re.match(r"^[a-zA-Z0-9_-]+$", function_name): return False return True diff --git a/tests/agents/test_base_agent.py b/tests/agents/test_base_agent.py index 50195da4..6f828499 100644 --- a/tests/agents/test_base_agent.py +++ b/tests/agents/test_base_agent.py @@ -3,7 +3,6 @@ from unittest.mock import Mock import pytest from application.agents.classic_agent import ClassicAgent from application.core.settings import settings -from tests.conftest import FakeMongoCollection @pytest.mark.unit @@ -168,10 +167,13 @@ class TestBaseAgentTools: mock_llm_creator, mock_llm_handler_creator, ): - mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = { - "1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True}, - "2": {"_id": "2", "user": "test_user", "name": "tool2", "status": True}, - } + user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"] + user_tools.insert_one( + {"_id": "1", "user": "test_user", "name": "tool1", "status": True} + ) + user_tools.insert_one( + {"_id": "2", "user": "test_user", "name": "tool2", "status": True} + ) agent = ClassicAgent(**agent_base_params) tools = agent._get_user_tools("test_user") @@ -187,10 +189,13 @@ class TestBaseAgentTools: mock_llm_creator, mock_llm_handler_creator, ): - mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = { - "1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True}, - "2": {"_id": "2", "user": "test_user", "name": "tool2", "status": False}, - } + user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"] + user_tools.insert_one( + {"_id": "1", "user": "test_user", "name": "tool1", "status": True} + ) + user_tools.insert_one( + {"_id": "2", "user": "test_user", "name": "tool2", "status": False} + ) agent = ClassicAgent(**agent_base_params) tools = agent._get_user_tools("test_user") @@ -209,17 +214,16 @@ class TestBaseAgentTools: tool_id = str(ObjectId()) tool_obj_id = ObjectId(tool_id) - fake_agent_collection = FakeMongoCollection() - fake_agent_collection.docs["api_key_123"] = { - "key": "api_key_123", - "tools": [tool_id], - } + agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"] + agents_collection.insert_one( + { + "key": "api_key_123", + "tools": [tool_id], + } + ) - fake_tools_collection = FakeMongoCollection() - fake_tools_collection.docs[tool_id] = {"_id": tool_obj_id, "name": "api_tool"} - - mock_mongo_db[settings.MONGO_DB_NAME]["agents"] = fake_agent_collection - mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"] = fake_tools_collection + tools_collection = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"] + tools_collection.insert_one({"_id": tool_obj_id, "name": "api_tool"}) agent = ClassicAgent(**agent_base_params) tools = agent._get_tools("api_key_123") diff --git a/tests/api/__init__.py b/tests/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/api/answer/__init__.py b/tests/api/answer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/api/answer/routes/__init__.py b/tests/api/answer/routes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/api/answer/routes/test_base.py b/tests/api/answer/routes/test_base.py new file mode 100644 index 00000000..1a184108 --- /dev/null +++ b/tests/api/answer/routes/test_base.py @@ -0,0 +1,552 @@ +import datetime +from unittest.mock import MagicMock, patch + +import pytest +from bson import ObjectId + + +@pytest.mark.unit +class TestBaseAnswerValidation: + def test_validate_request_passes_with_required_fields( + self, mock_mongo_db, flask_app + ): + from application.api.answer.routes.base import BaseAnswerResource + + with flask_app.app_context(): + resource = BaseAnswerResource() + data = {"question": "What is Python?"} + + result = resource.validate_request(data) + + assert result is None + + def test_validate_request_fails_without_question(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + + with flask_app.app_context(): + resource = BaseAnswerResource() + data = {} + + result = resource.validate_request(data) + + assert result is not None + assert result.status_code == 400 + assert "question" in result.json["message"].lower() + + def test_validate_with_conversation_id_required(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + + with flask_app.app_context(): + resource = BaseAnswerResource() + data = {"question": "Test"} + + result = resource.validate_request(data, require_conversation_id=True) + + assert result is not None + assert result.status_code == 400 + assert "conversation_id" in result.json["message"].lower() + + def test_validate_passes_with_all_required_fields(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + + with flask_app.app_context(): + resource = BaseAnswerResource() + data = {"question": "Test", "conversation_id": str(ObjectId())} + + result = resource.validate_request(data, require_conversation_id=True) + + assert result is None + + +@pytest.mark.unit +class TestUsageChecking: + def test_returns_none_when_no_api_key(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + + with flask_app.app_context(): + resource = BaseAnswerResource() + agent_config = {} + + result = resource.check_usage(agent_config) + + assert result is None + + def test_returns_error_for_invalid_api_key(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + + with flask_app.app_context(): + resource = BaseAnswerResource() + agent_config = {"user_api_key": "invalid_key_123"} + + result = resource.check_usage(agent_config) + + assert result is not None + assert result.status_code == 401 + assert result.json["success"] is False + assert "invalid" in result.json["message"].lower() + + def test_checks_token_limit_when_enabled(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + from application.core.settings import settings + + with flask_app.app_context(): + agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"] + agent_id = ObjectId() + + agents_collection.insert_one( + { + "_id": agent_id, + "key": "test_key", + "limited_token_mode": True, + "token_limit": 1000, + "limited_request_mode": False, + } + ) + + resource = BaseAnswerResource() + agent_config = {"user_api_key": "test_key"} + + result = resource.check_usage(agent_config) + + assert result is None + + def test_checks_request_limit_when_enabled(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + from application.core.settings import settings + + with flask_app.app_context(): + agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"] + agent_id = ObjectId() + + agents_collection.insert_one( + { + "_id": agent_id, + "key": "test_key", + "limited_token_mode": False, + "limited_request_mode": True, + "request_limit": 100, + } + ) + + resource = BaseAnswerResource() + agent_config = {"user_api_key": "test_key"} + + result = resource.check_usage(agent_config) + + assert result is None + + def test_uses_default_limits_when_not_specified(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + from application.core.settings import settings + + with flask_app.app_context(): + agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"] + agent_id = ObjectId() + + agents_collection.insert_one( + { + "_id": agent_id, + "key": "test_key", + "limited_token_mode": True, + "limited_request_mode": True, + } + ) + + resource = BaseAnswerResource() + agent_config = {"user_api_key": "test_key"} + + result = resource.check_usage(agent_config) + + assert result is None + + def test_exceeds_token_limit(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + from application.core.settings import settings + + with flask_app.app_context(): + agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"] + token_usage_collection = mock_mongo_db[settings.MONGO_DB_NAME][ + "token_usage" + ] + agent_id = ObjectId() + + agents_collection.insert_one( + { + "_id": agent_id, + "key": "test_key", + "limited_token_mode": True, + "token_limit": 100, + "limited_request_mode": False, + } + ) + + token_usage_collection.insert_one( + { + "_id": ObjectId(), + "api_key": "test_key", + "prompt_tokens": 60, + "generated_tokens": 50, + "timestamp": datetime.datetime.now(), + } + ) + + resource = BaseAnswerResource() + agent_config = {"user_api_key": "test_key"} + + result = resource.check_usage(agent_config) + + assert result is not None + assert result.status_code == 429 + assert result.json["success"] is False + assert "usage limit" in result.json["message"].lower() + + def test_exceeds_request_limit(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + from application.core.settings import settings + + with flask_app.app_context(): + agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"] + token_usage_collection = mock_mongo_db[settings.MONGO_DB_NAME][ + "token_usage" + ] + agent_id = ObjectId() + + agents_collection.insert_one( + { + "_id": agent_id, + "key": "test_key", + "limited_token_mode": False, + "limited_request_mode": True, + "request_limit": 2, + } + ) + + now = datetime.datetime.now() + for i in range(3): + token_usage_collection.insert_one( + { + "_id": ObjectId(), + "api_key": "test_key", + "prompt_tokens": 10, + "generated_tokens": 10, + "timestamp": now, + } + ) + resource = BaseAnswerResource() + agent_config = {"user_api_key": "test_key"} + + result = resource.check_usage(agent_config) + + assert result is not None + assert result.status_code == 429 + assert result.json["success"] is False + + def test_both_limits_disabled_returns_none(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + from application.core.settings import settings + + with flask_app.app_context(): + agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"] + agent_id = ObjectId() + + agents_collection.insert_one( + { + "_id": agent_id, + "key": "test_key", + "limited_token_mode": False, + "limited_request_mode": False, + } + ) + + resource = BaseAnswerResource() + agent_config = {"user_api_key": "test_key"} + + result = resource.check_usage(agent_config) + + assert result is None + + +@pytest.mark.unit +class TestGPTModelRetrieval: + def test_initializes_gpt_model(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + + with flask_app.app_context(): + resource = BaseAnswerResource() + + assert hasattr(resource, "gpt_model") + assert resource.gpt_model is not None + + +@pytest.mark.unit +class TestConversationServiceIntegration: + def test_initializes_conversation_service(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + + with flask_app.app_context(): + resource = BaseAnswerResource() + + assert hasattr(resource, "conversation_service") + assert resource.conversation_service is not None + + def test_has_access_to_user_logs_collection(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + + with flask_app.app_context(): + resource = BaseAnswerResource() + + assert hasattr(resource, "user_logs_collection") + assert resource.user_logs_collection is not None + + +@pytest.mark.unit +class TestCompleteStreamMethod: + def test_streams_answer_chunks(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + + with flask_app.app_context(): + resource = BaseAnswerResource() + + mock_agent = MagicMock() + mock_agent.gen.return_value = iter( + [ + {"answer": "Hello "}, + {"answer": "world!"}, + ] + ) + + mock_retriever = MagicMock() + mock_retriever.get_params.return_value = {} + + decoded_token = {"sub": "user123"} + + stream = list( + resource.complete_stream( + question="Test question", + agent=mock_agent, + retriever=mock_retriever, + conversation_id=None, + user_api_key=None, + decoded_token=decoded_token, + should_save_conversation=False, + ) + ) + + answer_chunks = [s for s in stream if '"type": "answer"' in s] + assert len(answer_chunks) == 2 + assert '"answer": "Hello "' in answer_chunks[0] + assert '"answer": "world!"' in answer_chunks[1] + + def test_streams_sources(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + + with flask_app.app_context(): + resource = BaseAnswerResource() + + mock_agent = MagicMock() + mock_agent.gen.return_value = iter( + [ + {"answer": "Test answer"}, + {"sources": [{"title": "doc1.txt", "text": "x" * 200}]}, + ] + ) + + mock_retriever = MagicMock() + mock_retriever.get_params.return_value = {} + + decoded_token = {"sub": "user123"} + + stream = list( + resource.complete_stream( + question="Test?", + agent=mock_agent, + retriever=mock_retriever, + conversation_id=None, + user_api_key=None, + decoded_token=decoded_token, + should_save_conversation=False, + ) + ) + + source_chunks = [s for s in stream if '"type": "source"' in s] + assert len(source_chunks) == 1 + assert '"title": "doc1.txt"' in source_chunks[0] + + def test_handles_error_during_streaming(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + + with flask_app.app_context(): + resource = BaseAnswerResource() + + mock_agent = MagicMock() + mock_agent.gen.side_effect = Exception("Test error") + + mock_retriever = MagicMock() + mock_retriever.get_params.return_value = {} + + decoded_token = {"sub": "user123"} + + stream = list( + resource.complete_stream( + question="Test?", + agent=mock_agent, + retriever=mock_retriever, + conversation_id=None, + user_api_key=None, + decoded_token=decoded_token, + should_save_conversation=False, + ) + ) + + assert any('"type": "error"' in s for s in stream) + + def test_saves_conversation_when_enabled(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + + with flask_app.app_context(): + resource = BaseAnswerResource() + + mock_agent = MagicMock() + mock_agent.gen.return_value = iter( + [ + {"answer": "Test answer"}, + ] + ) + + mock_retriever = MagicMock() + mock_retriever.get_params.return_value = {} + + decoded_token = {"sub": "user123"} + + with patch.object( + resource.conversation_service, "save_conversation" + ) as mock_save: + mock_save.return_value = str(ObjectId()) + + list( + resource.complete_stream( + question="Test?", + agent=mock_agent, + retriever=mock_retriever, + conversation_id=None, + user_api_key=None, + decoded_token=decoded_token, + should_save_conversation=True, + ) + ) + + mock_save.assert_called_once() + + def test_logs_to_user_logs_collection(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + from application.core.settings import settings + + with flask_app.app_context(): + resource = BaseAnswerResource() + user_logs = mock_mongo_db[settings.MONGO_DB_NAME]["user_logs"] + + mock_agent = MagicMock() + mock_agent.gen.return_value = iter( + [ + {"answer": "Test answer"}, + ] + ) + + mock_retriever = MagicMock() + mock_retriever.get_params.return_value = {"retriever": "test"} + + decoded_token = {"sub": "user123"} + + list( + resource.complete_stream( + question="Test question?", + agent=mock_agent, + retriever=mock_retriever, + conversation_id=None, + user_api_key="test_key", + decoded_token=decoded_token, + should_save_conversation=False, + ) + ) + + assert user_logs.count_documents({}) == 1 + log_entry = user_logs.find_one({}) + assert log_entry["action"] == "stream_answer" + assert log_entry["user"] == "user123" + assert log_entry["api_key"] == "test_key" + assert log_entry["question"] == "Test question?" + + +@pytest.mark.unit +class TestProcessResponseStream: + def test_processes_complete_stream(self, mock_mongo_db, flask_app): + import json + + from application.api.answer.routes.base import BaseAnswerResource + + with flask_app.app_context(): + resource = BaseAnswerResource() + + conv_id = str(ObjectId()) + stream = [ + f'data: {json.dumps({"type": "answer", "answer": "Hello "})}\n\n', + f'data: {json.dumps({"type": "answer", "answer": "world"})}\n\n', + f'data: {json.dumps({"type": "source", "source": [{"title": "doc1"}]})}\n\n', + f'data: {json.dumps({"type": "id", "id": conv_id})}\n\n', + f'data: {json.dumps({"type": "end"})}\n\n', + ] + + result = resource.process_response_stream(iter(stream)) + + assert result[0] == conv_id + assert result[1] == "Hello world" + assert result[2] == [{"title": "doc1"}] + assert result[5] is None + + def test_handles_stream_error(self, mock_mongo_db, flask_app): + import json + + from application.api.answer.routes.base import BaseAnswerResource + + with flask_app.app_context(): + resource = BaseAnswerResource() + + stream = [ + f'data: {json.dumps({"type": "error", "error": "Test error"})}\n\n', + ] + + result = resource.process_response_stream(iter(stream)) + + assert len(result) == 5 + assert result[0] is None + assert result[4] == "Test error" + + def test_handles_malformed_stream_data(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + + with flask_app.app_context(): + resource = BaseAnswerResource() + + stream = [ + "data: invalid json\n\n", + 'data: {"type": "end"}\n\n', + ] + + result = resource.process_response_stream(iter(stream)) + + assert result is not None + + +@pytest.mark.unit +class TestErrorStreamGenerate: + def test_generates_error_stream(self, mock_mongo_db, flask_app): + from application.api.answer.routes.base import BaseAnswerResource + + with flask_app.app_context(): + resource = BaseAnswerResource() + + error_stream = list(resource.error_stream_generate("Test error message")) + + assert len(error_stream) == 1 + assert '"type": "error"' in error_stream[0] + assert '"error": "Test error message"' in error_stream[0] diff --git a/tests/api/answer/services/__init__.py b/tests/api/answer/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/api/answer/services/test_conversation_service.py b/tests/api/answer/services/test_conversation_service.py new file mode 100644 index 00000000..80a18ebe --- /dev/null +++ b/tests/api/answer/services/test_conversation_service.py @@ -0,0 +1,242 @@ +from unittest.mock import Mock + +import pytest +from bson import ObjectId + + +@pytest.mark.unit +class TestConversationServiceGet: + + def test_returns_none_when_no_conversation_id(self, mock_mongo_db): + from application.api.answer.services.conversation_service import ( + ConversationService, + ) + + service = ConversationService() + result = service.get_conversation("", "user_123") + + assert result is None + + def test_returns_none_when_no_user_id(self, mock_mongo_db): + from application.api.answer.services.conversation_service import ( + ConversationService, + ) + + service = ConversationService() + result = service.get_conversation(str(ObjectId()), "") + + assert result is None + + def test_returns_conversation_for_owner(self, mock_mongo_db): + from application.api.answer.services.conversation_service import ( + ConversationService, + ) + from application.core.settings import settings + + service = ConversationService() + collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"] + + conv_id = ObjectId() + conversation = { + "_id": conv_id, + "user": "user_123", + "name": "Test Conv", + "queries": [], + } + collection.insert_one(conversation) + + result = service.get_conversation(str(conv_id), "user_123") + + assert result is not None + assert result["name"] == "Test Conv" + assert result["_id"] == str(conv_id) + + def test_returns_none_for_unauthorized_user(self, mock_mongo_db): + from application.api.answer.services.conversation_service import ( + ConversationService, + ) + from application.core.settings import settings + + service = ConversationService() + collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"] + + conv_id = ObjectId() + collection.insert_one( + {"_id": conv_id, "user": "owner_123", "name": "Private Conv"} + ) + + result = service.get_conversation(str(conv_id), "hacker_456") + + assert result is None + + def test_converts_objectid_to_string(self, mock_mongo_db): + from application.api.answer.services.conversation_service import ( + ConversationService, + ) + from application.core.settings import settings + + service = ConversationService() + collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"] + + conv_id = ObjectId() + collection.insert_one({"_id": conv_id, "user": "user_123", "name": "Test"}) + + result = service.get_conversation(str(conv_id), "user_123") + + assert isinstance(result["_id"], str) + assert result["_id"] == str(conv_id) + + +@pytest.mark.unit +class TestConversationServiceSave: + + def test_raises_error_when_no_user_in_token(self, mock_mongo_db): + """Test validation: user ID required""" + from application.api.answer.services.conversation_service import ( + ConversationService, + ) + + service = ConversationService() + mock_llm = Mock() + + with pytest.raises(ValueError, match="User ID not found"): + service.save_conversation( + conversation_id=None, + question="Test?", + response="Answer", + thought="", + sources=[], + tool_calls=[], + llm=mock_llm, + gpt_model="gpt-4", + decoded_token={}, # No 'sub' key + ) + + def test_truncates_long_source_text(self, mock_mongo_db): + from application.api.answer.services.conversation_service import ( + ConversationService, + ) + from application.core.settings import settings + from bson import ObjectId + + service = ConversationService() + collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"] + + mock_llm = Mock() + mock_llm.gen.return_value = "Test Summary" + + long_text = "x" * 2000 + sources = [{"text": long_text, "title": "Doc"}] + + conv_id = service.save_conversation( + conversation_id=None, + question="Question", + response="Response", + thought="", + sources=sources, + tool_calls=[], + llm=mock_llm, + gpt_model="gpt-4", + decoded_token={"sub": "user_123"}, + ) + + saved_conv = collection.find_one({"_id": ObjectId(conv_id)}) + saved_source_text = saved_conv["queries"][0]["sources"][0]["text"] + + assert len(saved_source_text) == 1000 + assert saved_source_text == "x" * 1000 + + def test_creates_new_conversation_with_summary(self, mock_mongo_db): + from application.api.answer.services.conversation_service import ( + ConversationService, + ) + from application.core.settings import settings + from bson import ObjectId + + service = ConversationService() + collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"] + + mock_llm = Mock() + mock_llm.gen.return_value = "Python Basics" + + conv_id = service.save_conversation( + conversation_id=None, + question="What is Python?", + response="Python is a programming language", + thought="", + sources=[], + tool_calls=[], + llm=mock_llm, + gpt_model="gpt-4", + decoded_token={"sub": "user_123"}, + ) + + assert conv_id is not None + saved_conv = collection.find_one({"_id": ObjectId(conv_id)}) + assert saved_conv["name"] == "Python Basics" + assert saved_conv["user"] == "user_123" + assert len(saved_conv["queries"]) == 1 + assert saved_conv["queries"][0]["prompt"] == "What is Python?" + + def test_appends_to_existing_conversation(self, mock_mongo_db): + from application.api.answer.services.conversation_service import ( + ConversationService, + ) + from application.core.settings import settings + from bson import ObjectId + + service = ConversationService() + collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"] + + existing_conv_id = ObjectId() + collection.insert_one( + { + "_id": existing_conv_id, + "user": "user_123", + "name": "Old Conv", + "queries": [{"prompt": "Q1", "response": "A1"}], + } + ) + + mock_llm = Mock() + + result = service.save_conversation( + conversation_id=str(existing_conv_id), + question="Q2", + response="A2", + thought="", + sources=[], + tool_calls=[], + llm=mock_llm, + gpt_model="gpt-4", + decoded_token={"sub": "user_123"}, + ) + + assert result == str(existing_conv_id) + + def test_prevents_unauthorized_conversation_update(self, mock_mongo_db): + from application.api.answer.services.conversation_service import ( + ConversationService, + ) + from application.core.settings import settings + + service = ConversationService() + collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"] + + conv_id = ObjectId() + collection.insert_one({"_id": conv_id, "user": "owner_123", "queries": []}) + + mock_llm = Mock() + + with pytest.raises(ValueError, match="not found or unauthorized"): + service.save_conversation( + conversation_id=str(conv_id), + question="Hack", + response="Attempt", + thought="", + sources=[], + tool_calls=[], + llm=mock_llm, + gpt_model="gpt-4", + decoded_token={"sub": "hacker_456"}, + ) diff --git a/tests/api/answer/services/test_stream_processor.py b/tests/api/answer/services/test_stream_processor.py new file mode 100644 index 00000000..6fd1856e --- /dev/null +++ b/tests/api/answer/services/test_stream_processor.py @@ -0,0 +1,252 @@ +import pytest +from bson import ObjectId + + +@pytest.mark.unit +class TestGetPromptFunction: + + def test_loads_custom_prompt_from_database(self, mock_mongo_db): + from application.api.answer.services.stream_processor import get_prompt + from application.core.settings import settings + + prompts_collection = mock_mongo_db[settings.MONGO_DB_NAME]["prompts"] + prompt_id = ObjectId() + + prompts_collection.insert_one( + { + "_id": prompt_id, + "content": "Custom prompt from database", + "user": "user_123", + } + ) + + result = get_prompt(str(prompt_id), prompts_collection) + assert result == "Custom prompt from database" + + def test_raises_error_for_invalid_prompt_id(self, mock_mongo_db): + from application.api.answer.services.stream_processor import get_prompt + from application.core.settings import settings + + prompts_collection = mock_mongo_db[settings.MONGO_DB_NAME]["prompts"] + + with pytest.raises(ValueError, match="Invalid prompt ID"): + get_prompt(str(ObjectId()), prompts_collection) + + def test_raises_error_for_malformed_id(self, mock_mongo_db): + from application.api.answer.services.stream_processor import get_prompt + from application.core.settings import settings + + prompts_collection = mock_mongo_db[settings.MONGO_DB_NAME]["prompts"] + + with pytest.raises(ValueError, match="Invalid prompt ID"): + get_prompt("not_a_valid_id", prompts_collection) + + +@pytest.mark.unit +class TestStreamProcessorInitialization: + + def test_initializes_with_decoded_token(self, mock_mongo_db): + from application.api.answer.services.stream_processor import StreamProcessor + + request_data = { + "question": "What is Python?", + "conversation_id": str(ObjectId()), + } + decoded_token = {"sub": "user_123", "email": "test@example.com"} + + processor = StreamProcessor(request_data, decoded_token) + + assert processor.data == request_data + assert processor.decoded_token == decoded_token + assert processor.initial_user_id == "user_123" + assert processor.conversation_id == request_data["conversation_id"] + + def test_initializes_without_token(self, mock_mongo_db): + from application.api.answer.services.stream_processor import StreamProcessor + + request_data = {"question": "Test question"} + + processor = StreamProcessor(request_data, None) + + assert processor.decoded_token is None + assert processor.initial_user_id is None + assert processor.data == request_data + + def test_initializes_default_attributes(self, mock_mongo_db): + from application.api.answer.services.stream_processor import StreamProcessor + + processor = StreamProcessor({"question": "Test"}, {"sub": "user_123"}) + + assert processor.source == {} + assert processor.all_sources == [] + assert processor.attachments == [] + assert processor.history == [] + assert processor.agent_config == {} + assert processor.retriever_config == {} + assert processor.is_shared_usage is False + assert processor.shared_token is None + + def test_extracts_conversation_id_from_request(self, mock_mongo_db): + from application.api.answer.services.stream_processor import StreamProcessor + + conv_id = str(ObjectId()) + request_data = {"question": "Test", "conversation_id": conv_id} + + processor = StreamProcessor(request_data, {"sub": "user_123"}) + + assert processor.conversation_id == conv_id + + +@pytest.mark.unit +class TestStreamProcessorHistoryLoading: + + def test_loads_history_from_existing_conversation(self, mock_mongo_db): + from application.api.answer.services.stream_processor import StreamProcessor + from application.core.settings import settings + + conversations_collection = mock_mongo_db[settings.MONGO_DB_NAME][ + "conversations" + ] + conv_id = ObjectId() + + conversations_collection.insert_one( + { + "_id": conv_id, + "user": "user_123", + "name": "Test Conv", + "queries": [ + {"prompt": "What is Python?", "response": "Python is a language"}, + {"prompt": "Tell me more", "response": "Python is versatile"}, + ], + } + ) + + request_data = { + "question": "How to install it?", + "conversation_id": str(conv_id), + } + + processor = StreamProcessor(request_data, {"sub": "user_123"}) + processor._load_conversation_history() + + assert len(processor.history) == 2 + assert processor.history[0]["prompt"] == "What is Python?" + assert processor.history[1]["response"] == "Python is versatile" + + def test_raises_error_for_unauthorized_conversation(self, mock_mongo_db): + from application.api.answer.services.stream_processor import StreamProcessor + from application.core.settings import settings + + conversations_collection = mock_mongo_db[settings.MONGO_DB_NAME][ + "conversations" + ] + conv_id = ObjectId() + + conversations_collection.insert_one( + { + "_id": conv_id, + "user": "owner_123", + "name": "Private Conv", + "queries": [], + } + ) + + request_data = {"question": "Hack attempt", "conversation_id": str(conv_id)} + + processor = StreamProcessor(request_data, {"sub": "hacker_456"}) + + with pytest.raises(ValueError, match="Conversation not found or unauthorized"): + processor._load_conversation_history() + + def test_uses_request_history_when_no_conversation_id(self, mock_mongo_db): + from application.api.answer.services.stream_processor import StreamProcessor + + request_data = { + "question": "What is Python?", + "history": [{"prompt": "Hello", "response": "Hi there!"}], + } + + processor = StreamProcessor(request_data, {"sub": "user_123"}) + + assert processor.conversation_id is None + + +@pytest.mark.unit +class TestStreamProcessorAgentConfiguration: + + def test_configures_agent_from_valid_api_key(self, mock_mongo_db): + from application.api.answer.services.stream_processor import StreamProcessor + from application.core.settings import settings + + agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"] + agent_id = ObjectId() + + agents_collection.insert_one( + { + "_id": agent_id, + "key": "test_api_key_123", + "endpoint": "openai", + "model": "gpt-4", + "prompt_id": "default", + "user": "user_123", + } + ) + + request_data = {"question": "Test", "api_key": "test_api_key_123"} + + processor = StreamProcessor(request_data, None) + + try: + processor._configure_agent() + assert processor.agent_config is not None + except Exception as e: + assert "Invalid API Key" in str(e) + + def test_uses_default_config_without_api_key(self, mock_mongo_db): + from application.api.answer.services.stream_processor import StreamProcessor + + request_data = {"question": "Test"} + + processor = StreamProcessor(request_data, {"sub": "user_123"}) + processor._configure_agent() + + assert isinstance(processor.agent_config, dict) + + +@pytest.mark.unit +class TestStreamProcessorAttachments: + + def test_processes_attachments_from_request(self, mock_mongo_db): + from application.api.answer.services.stream_processor import StreamProcessor + from application.core.settings import settings + + attachments_collection = mock_mongo_db[settings.MONGO_DB_NAME]["attachments"] + att_id = ObjectId() + + attachments_collection.insert_one( + { + "_id": att_id, + "filename": "document.pdf", + "content": "Document content", + "user": "user_123", + } + ) + + request_data = {"question": "Analyze this", "attachments": [str(att_id)]} + + processor = StreamProcessor(request_data, {"sub": "user_123"}) + + assert processor.data.get("attachments") == [str(att_id)] + + def test_handles_empty_attachments(self, mock_mongo_db): + from application.api.answer.services.stream_processor import StreamProcessor + + request_data = {"question": "Simple question"} + + processor = StreamProcessor(request_data, {"sub": "user_123"}) + + assert processor.attachments == [] + assert ( + "attachments" not in processor.data + or processor.data.get("attachments") is None + ) diff --git a/tests/api/conftest.py b/tests/api/conftest.py new file mode 100644 index 00000000..2c98b14f --- /dev/null +++ b/tests/api/conftest.py @@ -0,0 +1,89 @@ +"""API-specific test fixtures.""" + +import pytest +from bson import ObjectId + + +@pytest.fixture +def auth_headers(): + return {"Authorization": "Bearer test_token"} + + +@pytest.fixture +def mock_request_token(monkeypatch, decoded_token): + def mock_decorator(f): + def wrapper(*args, **kwargs): + from flask import request + + request.decoded_token = decoded_token + return f(*args, **kwargs) + + return wrapper + + monkeypatch.setattr("application.auth.api_key_required", lambda: mock_decorator) + return decoded_token + + +@pytest.fixture +def sample_conversation(): + return { + "_id": ObjectId(), + "user": "test_user", + "name": "Test Conversation", + "queries": [ + { + "prompt": "What is Python?", + "response": "Python is a programming language", + } + ], + "date": "2025-01-01T00:00:00", + } + + +@pytest.fixture +def sample_prompt(): + return { + "_id": ObjectId(), + "user": "test_user", + "name": "Helpful Assistant", + "content": "You are a helpful assistant that provides clear and concise answers.", + "type": "custom", + } + + +@pytest.fixture +def sample_agent(): + return { + "_id": ObjectId(), + "user": "test_user", + "name": "Test Agent", + "type": "classic", + "endpoint": "openai", + "model": "gpt-4", + "prompt_id": "default", + "status": "active", + } + + +@pytest.fixture +def sample_answer_request(): + return { + "question": "What is Python?", + "history": [], + "conversation_id": None, + "prompt_id": "default", + "chunks": 2, + "token_limit": 1000, + "retriever": "classic_rag", + "active_docs": "local/test/", + "isNoneDoc": False, + "save_conversation": True, + } + + +@pytest.fixture +def flask_app(): + from flask import Flask + + app = Flask(__name__) + return app diff --git a/tests/api/user/test_base.py b/tests/api/user/test_base.py new file mode 100644 index 00000000..0190ac46 --- /dev/null +++ b/tests/api/user/test_base.py @@ -0,0 +1,311 @@ +import datetime +import io +from unittest.mock import Mock, patch + +import pytest +from bson import ObjectId +from werkzeug.datastructures import FileStorage + + +@pytest.mark.unit +class TestTimeRangeGenerators: + + def test_generate_minute_range(self): + from application.api.user.base import generate_minute_range + + start = datetime.datetime(2024, 1, 1, 10, 0, 0) + end = datetime.datetime(2024, 1, 1, 10, 5, 0) + + result = generate_minute_range(start, end) + + assert len(result) == 6 + assert "2024-01-01 10:00:00" in result + assert "2024-01-01 10:05:00" in result + assert all(val == 0 for val in result.values()) + + def test_generate_hourly_range(self): + from application.api.user.base import generate_hourly_range + + start = datetime.datetime(2024, 1, 1, 10, 0, 0) + end = datetime.datetime(2024, 1, 1, 15, 0, 0) + + result = generate_hourly_range(start, end) + + assert len(result) == 6 + assert "2024-01-01 10:00" in result + assert "2024-01-01 15:00" in result + assert all(val == 0 for val in result.values()) + + def test_generate_date_range(self): + from application.api.user.base import generate_date_range + + start = datetime.date(2024, 1, 1) + end = datetime.date(2024, 1, 5) + + result = generate_date_range(start, end) + + assert len(result) == 5 + assert "2024-01-01" in result + assert "2024-01-05" in result + assert all(val == 0 for val in result.values()) + + def test_single_minute_range(self): + from application.api.user.base import generate_minute_range + + time = datetime.datetime(2024, 1, 1, 10, 30, 0) + result = generate_minute_range(time, time) + + assert len(result) == 1 + assert "2024-01-01 10:30:00" in result + + +@pytest.mark.unit +class TestEnsureUserDoc: + + def test_creates_new_user_with_defaults(self, mock_mongo_db): + from application.api.user.base import ensure_user_doc + + user_id = "test_user_123" + + result = ensure_user_doc(user_id) + + assert result is not None + assert result["user_id"] == user_id + assert "agent_preferences" in result + assert result["agent_preferences"]["pinned"] == [] + assert result["agent_preferences"]["shared_with_me"] == [] + + def test_returns_existing_user(self, mock_mongo_db): + from application.api.user.base import ensure_user_doc + from application.core.settings import settings + + users_collection = mock_mongo_db[settings.MONGO_DB_NAME]["users"] + user_id = "existing_user" + + existing_doc = { + "user_id": user_id, + "agent_preferences": {"pinned": ["agent1"], "shared_with_me": ["agent2"]}, + } + users_collection.insert_one(existing_doc) + + result = ensure_user_doc(user_id) + + assert result["user_id"] == user_id + assert result["agent_preferences"]["pinned"] == ["agent1"] + assert result["agent_preferences"]["shared_with_me"] == ["agent2"] + + def test_adds_missing_preferences_fields(self, mock_mongo_db): + from application.api.user.base import ensure_user_doc + from application.core.settings import settings + + users_collection = mock_mongo_db[settings.MONGO_DB_NAME]["users"] + user_id = "incomplete_user" + + users_collection.insert_one( + {"user_id": user_id, "agent_preferences": {"pinned": ["agent1"]}} + ) + + result = ensure_user_doc(user_id) + + assert "shared_with_me" in result["agent_preferences"] + assert result["agent_preferences"]["shared_with_me"] == [] + + +@pytest.mark.unit +class TestResolveToolDetails: + + def test_resolves_tool_ids_to_details(self, mock_mongo_db): + from application.api.user.base import resolve_tool_details + from application.core.settings import settings + + user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"] + tool_id1 = ObjectId() + tool_id2 = ObjectId() + + user_tools.insert_one( + {"_id": tool_id1, "name": "calculator", "displayName": "Calculator Tool"} + ) + user_tools.insert_one( + {"_id": tool_id2, "name": "weather", "displayName": "Weather API"} + ) + + result = resolve_tool_details([str(tool_id1), str(tool_id2)]) + + assert len(result) == 2 + assert result[0]["id"] == str(tool_id1) + assert result[0]["name"] == "calculator" + assert result[0]["display_name"] == "Calculator Tool" + assert result[1]["name"] == "weather" + + def test_handles_missing_display_name(self, mock_mongo_db): + from application.api.user.base import resolve_tool_details + from application.core.settings import settings + + user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"] + tool_id = ObjectId() + + user_tools.insert_one({"_id": tool_id, "name": "test_tool"}) + + result = resolve_tool_details([str(tool_id)]) + + assert result[0]["display_name"] == "test_tool" + + def test_empty_tool_ids_list(self, mock_mongo_db): + from application.api.user.base import resolve_tool_details + + result = resolve_tool_details([]) + + assert result == [] + + +@pytest.mark.unit +class TestGetVectorStore: + + @patch("application.api.user.base.VectorCreator.create_vectorstore") + def test_creates_vector_store(self, mock_create, mock_mongo_db): + from application.api.user.base import get_vector_store + + mock_store = Mock() + mock_create.return_value = mock_store + source_id = "test_source_123" + + result = get_vector_store(source_id) + + assert result == mock_store + mock_create.assert_called_once() + args, kwargs = mock_create.call_args + assert kwargs.get("source_id") == source_id + + +@pytest.mark.unit +class TestHandleImageUpload: + + def test_returns_existing_url_when_no_file(self, flask_app): + from application.api.user.base import handle_image_upload + + with flask_app.test_request_context(): + mock_request = Mock() + mock_request.files = {} + mock_storage = Mock() + existing_url = "existing/path/image.jpg" + + url, error = handle_image_upload( + mock_request, existing_url, "user123", mock_storage + ) + + assert url == existing_url + assert error is None + + def test_uploads_new_image(self, flask_app): + from application.api.user.base import handle_image_upload + + with flask_app.test_request_context(): + mock_file = FileStorage( + stream=io.BytesIO(b"fake image data"), filename="test_image.png" + ) + mock_request = Mock() + mock_request.files = {"image": mock_file} + mock_storage = Mock() + mock_storage.save_file.return_value = {"success": True} + + url, error = handle_image_upload( + mock_request, "old_url", "user123", mock_storage + ) + + assert error is None + assert url is not None + assert "test_image.png" in url + assert "user123" in url + mock_storage.save_file.assert_called_once() + + def test_ignores_empty_filename(self, flask_app): + from application.api.user.base import handle_image_upload + + with flask_app.test_request_context(): + mock_file = Mock() + mock_file.filename = "" + mock_request = Mock() + mock_request.files = {"image": mock_file} + mock_storage = Mock() + existing_url = "existing.jpg" + + url, error = handle_image_upload( + mock_request, existing_url, "user123", mock_storage + ) + + assert url == existing_url + assert error is None + mock_storage.save_file.assert_not_called() + + def test_handles_upload_error(self, flask_app): + from application.api.user.base import handle_image_upload + + with flask_app.app_context(): + mock_file = FileStorage(stream=io.BytesIO(b"data"), filename="test.png") + mock_request = Mock() + mock_request.files = {"image": mock_file} + mock_storage = Mock() + mock_storage.save_file.side_effect = Exception("Storage error") + + url, error = handle_image_upload( + mock_request, "old.jpg", "user123", mock_storage + ) + + assert url is None + assert error is not None + assert error.status_code == 400 + + +@pytest.mark.unit +class TestRequireAgentDecorator: + + def test_validates_webhook_token(self, mock_mongo_db, flask_app): + from application.api.user.base import require_agent + from application.core.settings import settings + + with flask_app.app_context(): + agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"] + agent_id = ObjectId() + webhook_token = "valid_webhook_token_123" + + agents_collection.insert_one( + {"_id": agent_id, "incoming_webhook_token": webhook_token} + ) + + @require_agent + def test_func(webhook_token=None, agent=None, agent_id_str=None): + return {"agent_id": agent_id_str} + + result = test_func(webhook_token=webhook_token) + + assert result["agent_id"] == str(agent_id) + + def test_returns_400_for_missing_token(self, mock_mongo_db, flask_app): + from application.api.user.base import require_agent + + with flask_app.app_context(): + + @require_agent + def test_func(webhook_token=None, agent=None, agent_id_str=None): + return {"success": True} + + result = test_func() + + assert result.status_code == 400 + assert result.json["success"] is False + assert "missing" in result.json["message"].lower() + + def test_returns_404_for_invalid_token(self, mock_mongo_db, flask_app): + from application.api.user.base import require_agent + + with flask_app.app_context(): + + @require_agent + def test_func(webhook_token=None, agent=None, agent_id_str=None): + return {"success": True} + + result = test_func(webhook_token="invalid_token_999") + + assert result.status_code == 404 + assert result.json["success"] is False + assert "not found" in result.json["message"].lower() diff --git a/tests/conftest.py b/tests/conftest.py index 3fd46f2f..9218378f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,15 @@ from unittest.mock import Mock +import mongomock + import pytest -from application.core.settings import settings + + +def get_settings(): + """Lazy load settings to avoid import-time errors.""" + from application.core.settings import settings + + return settings @pytest.fixture @@ -35,18 +43,51 @@ def mock_retriever(): @pytest.fixture def mock_mongo_db(monkeypatch): - fake_collection = FakeMongoCollection() - fake_db = { - "agents": fake_collection, - "user_tools": fake_collection, - "memories": fake_collection, - } - fake_client = {settings.MONGO_DB_NAME: fake_db} + """Mock MongoDB using mongomock - industry standard MongoDB mocking library.""" + settings = get_settings() + mock_client = mongomock.MongoClient() + mock_db = mock_client[settings.MONGO_DB_NAME] + + def get_mock_client(): + return {settings.MONGO_DB_NAME: mock_db} + + monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", get_mock_client) + + monkeypatch.setattr("application.api.user.base.users_collection", mock_db["users"]) monkeypatch.setattr( - "application.core.mongo_db.MongoDB.get_client", lambda: fake_client + "application.api.user.base.user_tools_collection", mock_db["user_tools"] ) - return fake_client + monkeypatch.setattr( + "application.api.user.base.agents_collection", mock_db["agents"] + ) + monkeypatch.setattr( + "application.api.user.base.conversations_collection", mock_db["conversations"] + ) + monkeypatch.setattr( + "application.api.user.base.sources_collection", mock_db["sources"] + ) + monkeypatch.setattr( + "application.api.user.base.prompts_collection", mock_db["prompts"] + ) + monkeypatch.setattr( + "application.api.user.base.feedback_collection", mock_db["feedback"] + ) + monkeypatch.setattr( + "application.api.user.base.token_usage_collection", mock_db["token_usage"] + ) + monkeypatch.setattr( + "application.api.user.base.attachments_collection", mock_db["attachments"] + ) + monkeypatch.setattr( + "application.api.user.base.user_logs_collection", mock_db["user_logs"] + ) + monkeypatch.setattr( + "application.api.user.base.shared_conversations_collections", + mock_db["shared_conversations"], + ) + + return get_mock_client() @pytest.fixture @@ -87,53 +128,6 @@ def log_context(): return context -class FakeMongoCollection: - def __init__(self): - self.docs = {} - - def find_one(self, query, projection=None): - if "key" in query: - return self.docs.get(query["key"]) - if "_id" in query: - return self.docs.get(str(query["_id"])) - if "user" in query: - for doc in self.docs.values(): - if doc.get("user") == query["user"]: - return doc - return None - - def find(self, query, projection=None): - results = [] - if "_id" in query and "$in" in query["_id"]: - for doc_id in query["_id"]["$in"]: - doc = self.docs.get(str(doc_id)) - if doc: - results.append(doc) - elif "user" in query: - for doc in self.docs.values(): - if doc.get("user") == query["user"]: - if "status" in query: - if doc.get("status") == query["status"]: - results.append(doc) - else: - results.append(doc) - return results - - def insert_one(self, doc): - doc_id = doc.get("_id", len(self.docs)) - self.docs[str(doc_id)] = doc - return Mock(inserted_id=doc_id) - - def update_one(self, query, update, upsert=False): - return Mock(modified_count=1) - - def delete_one(self, query): - return Mock(deleted_count=1) - - def delete_many(self, query): - return Mock(deleted_count=0) - - @pytest.fixture def mock_llm_creator(mock_llm, monkeypatch): monkeypatch.setattr( diff --git a/tests/llm/handlers/test_base.py b/tests/llm/handlers/test_llm_handlers.py similarity index 69% rename from tests/llm/handlers/test_base.py rename to tests/llm/handlers/test_llm_handlers.py index 8e70793c..a44b2de5 100644 --- a/tests/llm/handlers/test_base.py +++ b/tests/llm/handlers/test_llm_handlers.py @@ -1,19 +1,13 @@ -from unittest.mock import Mock, patch from typing import Any, Dict, Generator +from unittest.mock import Mock, patch from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall class TestToolCall: - """Test ToolCall dataclass.""" - def test_tool_call_creation(self): - """Test basic ToolCall creation.""" tool_call = ToolCall( - id="test_id", - name="test_function", - arguments={"arg1": "value1"}, - index=0 + id="test_id", name="test_function", arguments={"arg1": "value1"}, index=0 ) assert tool_call.id == "test_id" assert tool_call.name == "test_function" @@ -21,12 +15,11 @@ class TestToolCall: assert tool_call.index == 0 def test_tool_call_from_dict(self): - """Test ToolCall creation from dictionary.""" data = { "id": "call_123", "name": "get_weather", "arguments": {"location": "New York"}, - "index": 1 + "index": 1, } tool_call = ToolCall.from_dict(data) assert tool_call.id == "call_123" @@ -35,7 +28,6 @@ class TestToolCall: assert tool_call.index == 1 def test_tool_call_from_dict_missing_fields(self): - """Test ToolCall creation with missing fields.""" data = {"name": "test_func"} tool_call = ToolCall.from_dict(data) assert tool_call.id == "" @@ -45,16 +37,13 @@ class TestToolCall: class TestLLMResponse: - """Test LLMResponse dataclass.""" - def test_llm_response_creation(self): - """Test basic LLMResponse creation.""" tool_calls = [ToolCall(id="1", name="func", arguments={})] response = LLMResponse( content="Hello", tool_calls=tool_calls, finish_reason="tool_calls", - raw_response={"test": "data"} + raw_response={"test": "data"}, ) assert response.content == "Hello" assert len(response.tool_calls) == 1 @@ -62,55 +51,43 @@ class TestLLMResponse: assert response.raw_response == {"test": "data"} def test_requires_tool_call_true(self): - """Test requires_tool_call property when tool calls are needed.""" tool_calls = [ToolCall(id="1", name="func", arguments={})] response = LLMResponse( content="", tool_calls=tool_calls, finish_reason="tool_calls", - raw_response={} + raw_response={}, ) assert response.requires_tool_call is True def test_requires_tool_call_false_no_tools(self): - """Test requires_tool_call property when no tool calls.""" response = LLMResponse( - content="Hello", - tool_calls=[], - finish_reason="stop", - raw_response={} + content="Hello", tool_calls=[], finish_reason="stop", raw_response={} ) assert response.requires_tool_call is False def test_requires_tool_call_false_wrong_finish_reason(self): - """Test requires_tool_call property with tools but wrong finish reason.""" tool_calls = [ToolCall(id="1", name="func", arguments={})] response = LLMResponse( content="Hello", tool_calls=tool_calls, finish_reason="stop", - raw_response={} + raw_response={}, ) assert response.requires_tool_call is False class ConcreteHandler(LLMHandler): - """Concrete implementation for testing abstract base class.""" - def parse_response(self, response: Any) -> LLMResponse: return LLMResponse( content=str(response), tool_calls=[], finish_reason="stop", - raw_response=response + raw_response=response, ) def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict: - return { - "role": "tool", - "content": str(result), - "tool_call_id": tool_call.id - } + return {"role": "tool", "content": str(result), "tool_call_id": tool_call.id} def _iterate_stream(self, response: Any) -> Generator: for chunk in response: @@ -118,114 +95,119 @@ class ConcreteHandler(LLMHandler): class TestLLMHandler: - """Test LLMHandler base class.""" - def test_handler_initialization(self): - """Test handler initialization.""" handler = ConcreteHandler() assert handler.llm_calls == [] assert handler.tool_calls == [] def test_prepare_messages_no_attachments(self): - """Test prepare_messages with no attachments.""" handler = ConcreteHandler() messages = [{"role": "user", "content": "Hello"}] - + mock_agent = Mock() result = handler.prepare_messages(mock_agent, messages, None) assert result == messages def test_prepare_messages_with_supported_attachments(self): - """Test prepare_messages with supported attachments.""" handler = ConcreteHandler() messages = [{"role": "user", "content": "Hello"}] attachments = [{"mime_type": "image/png", "path": "/test.png"}] - + mock_agent = Mock() mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"] mock_agent.llm.prepare_messages_with_attachments.return_value = messages - + result = handler.prepare_messages(mock_agent, messages, attachments) mock_agent.llm.prepare_messages_with_attachments.assert_called_once_with( messages, attachments ) assert result == messages - @patch('application.llm.handlers.base.logger') + @patch("application.llm.handlers.base.logger") def test_prepare_messages_with_unsupported_attachments(self, mock_logger): - """Test prepare_messages with unsupported attachments.""" handler = ConcreteHandler() messages = [{"role": "user", "content": "Hello"}] attachments = [{"mime_type": "text/plain", "path": "/test.txt"}] - + mock_agent = Mock() mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"] - - with patch.object(handler, '_append_unsupported_attachments', return_value=messages) as mock_append: + + with patch.object( + handler, "_append_unsupported_attachments", return_value=messages + ) as mock_append: result = handler.prepare_messages(mock_agent, messages, attachments) mock_append.assert_called_once_with(messages, attachments) assert result == messages def test_prepare_messages_mixed_attachments(self): - """Test prepare_messages with both supported and unsupported attachments.""" handler = ConcreteHandler() messages = [{"role": "user", "content": "Hello"}] attachments = [ {"mime_type": "image/png", "path": "/test.png"}, - {"mime_type": "text/plain", "path": "/test.txt"} + {"mime_type": "text/plain", "path": "/test.txt"}, ] - + mock_agent = Mock() mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"] mock_agent.llm.prepare_messages_with_attachments.return_value = messages - - with patch.object(handler, '_append_unsupported_attachments', return_value=messages) as mock_append: + + with patch.object( + handler, "_append_unsupported_attachments", return_value=messages + ) as mock_append: result = handler.prepare_messages(mock_agent, messages, attachments) - - # Should call both methods + mock_agent.llm.prepare_messages_with_attachments.assert_called_once() mock_append.assert_called_once() assert result == messages def test_process_message_flow_non_streaming(self): - """Test process_message_flow for non-streaming.""" handler = ConcreteHandler() mock_agent = Mock() initial_response = "test response" tools_dict = {} messages = [{"role": "user", "content": "Hello"}] - - with patch.object(handler, 'prepare_messages', return_value=messages) as mock_prepare: - with patch.object(handler, 'handle_non_streaming', return_value="final") as mock_handle: + + with patch.object( + handler, "prepare_messages", return_value=messages + ) as mock_prepare: + with patch.object( + handler, "handle_non_streaming", return_value="final" + ) as mock_handle: result = handler.process_message_flow( mock_agent, initial_response, tools_dict, messages, stream=False ) - + mock_prepare.assert_called_once_with(mock_agent, messages, None) - mock_handle.assert_called_once_with(mock_agent, initial_response, tools_dict, messages) + mock_handle.assert_called_once_with( + mock_agent, initial_response, tools_dict, messages + ) assert result == "final" def test_process_message_flow_streaming(self): - """Test process_message_flow for streaming.""" handler = ConcreteHandler() mock_agent = Mock() initial_response = "test response" tools_dict = {} messages = [{"role": "user", "content": "Hello"}] - + def mock_generator(): yield "chunk1" yield "chunk2" - - with patch.object(handler, 'prepare_messages', return_value=messages) as mock_prepare: - with patch.object(handler, 'handle_streaming', return_value=mock_generator()) as mock_handle: + + with patch.object( + handler, "prepare_messages", return_value=messages + ) as mock_prepare: + with patch.object( + handler, "handle_streaming", return_value=mock_generator() + ) as mock_handle: result = handler.process_message_flow( mock_agent, initial_response, tools_dict, messages, stream=True ) - + mock_prepare.assert_called_once_with(mock_agent, messages, None) - mock_handle.assert_called_once_with(mock_agent, initial_response, tools_dict, messages) - - # Verify it's a generator + mock_handle.assert_called_once_with( + mock_agent, initial_response, tools_dict, messages + ) + chunks = list(result) assert chunks == ["chunk1", "chunk2"] diff --git a/tests/requirements.txt b/tests/requirements.txt index 57a2c092..777cd498 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,3 +1,4 @@ pytest>=8.0.0 pytest-cov>=4.1.0 coverage>=7.4.0 +mongomock>=4.3.0