test: implement full API test suite with mongomock and centralized fixtures (#2068)

This commit is contained in:
Siddhant Rai
2025-10-17 14:31:14 +05:30
committed by GitHub
parent ababc9ae04
commit 125ce0aad3
14 changed files with 1601 additions and 167 deletions

View File

@@ -21,7 +21,7 @@ def get_encoding():
def get_gpt_model() -> str:
"""Get the appropriate GPT model based on provider"""
"""Get GPT model based on provider"""
model_map = {
"openai": "gpt-4o-mini",
"anthropic": "claude-2",
@@ -32,16 +32,7 @@ def get_gpt_model() -> str:
def safe_filename(filename):
"""
Creates a safe filename that preserves the original extension.
Uses secure_filename, but ensures a proper filename is returned even with non-Latin characters.
Args:
filename (str): The original filename
Returns:
str: A safe filename that can be used for storage
"""
"""Create safe filename, preserving extension. Handles non-Latin characters."""
if not filename:
return str(uuid.uuid4())
_, extension = os.path.splitext(filename)
@@ -83,8 +74,14 @@ def count_tokens_docs(docs):
return tokens
def get_missing_fields(data, required_fields):
"""Check for missing required fields. Returns list of missing field names."""
return [field for field in required_fields if field not in data]
def check_required_fields(data, required_fields):
missing_fields = [field for field in required_fields if field not in data]
"""Validate required fields. Returns Flask 400 response if validation fails, None otherwise."""
missing_fields = get_missing_fields(data, required_fields)
if missing_fields:
return make_response(
jsonify(
@@ -98,7 +95,8 @@ def check_required_fields(data, required_fields):
return None
def validate_required_fields(data, required_fields):
def get_field_validation_errors(data, required_fields):
"""Check for missing and empty fields. Returns dict with 'missing_fields' and 'empty_fields', or None."""
missing_fields = []
empty_fields = []
@@ -107,12 +105,24 @@ def validate_required_fields(data, required_fields):
missing_fields.append(field)
elif not data[field]:
empty_fields.append(field)
errors = []
if missing_fields:
errors.append(f"Missing required fields: {', '.join(missing_fields)}")
if empty_fields:
errors.append(f"Empty values in required fields: {', '.join(empty_fields)}")
if errors:
if missing_fields or empty_fields:
return {"missing_fields": missing_fields, "empty_fields": empty_fields}
return None
def validate_required_fields(data, required_fields):
"""Validate required fields (must exist and be non-empty). Returns Flask 400 response if validation fails, None otherwise."""
errors_dict = get_field_validation_errors(data, required_fields)
if errors_dict:
errors = []
if errors_dict["missing_fields"]:
errors.append(
f"Missing required fields: {', '.join(errors_dict['missing_fields'])}"
)
if errors_dict["empty_fields"]:
errors.append(
f"Empty values in required fields: {', '.join(errors_dict['empty_fields'])}"
)
return make_response(
jsonify({"success": False, "message": " | ".join(errors)}), 400
)
@@ -124,10 +134,7 @@ def get_hash(data):
def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
"""
Limits chat history based on token count.
Returns a list of messages that fit within the token limit.
"""
"""Limit chat history to fit within token limit."""
from application.core.settings import settings
max_token_limit = (
@@ -161,7 +168,7 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
def validate_function_name(function_name):
"""Validates if a function name matches the allowed pattern."""
"""Validate function name matches allowed pattern (alphanumeric, underscore, hyphen)."""
if not re.match(r"^[a-zA-Z0-9_-]+$", function_name):
return False
return True

View File

@@ -3,7 +3,6 @@ from unittest.mock import Mock
import pytest
from application.agents.classic_agent import ClassicAgent
from application.core.settings import settings
from tests.conftest import FakeMongoCollection
@pytest.mark.unit
@@ -168,10 +167,13 @@ class TestBaseAgentTools:
mock_llm_creator,
mock_llm_handler_creator,
):
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
"1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True},
"2": {"_id": "2", "user": "test_user", "name": "tool2", "status": True},
}
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
user_tools.insert_one(
{"_id": "1", "user": "test_user", "name": "tool1", "status": True}
)
user_tools.insert_one(
{"_id": "2", "user": "test_user", "name": "tool2", "status": True}
)
agent = ClassicAgent(**agent_base_params)
tools = agent._get_user_tools("test_user")
@@ -187,10 +189,13 @@ class TestBaseAgentTools:
mock_llm_creator,
mock_llm_handler_creator,
):
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
"1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True},
"2": {"_id": "2", "user": "test_user", "name": "tool2", "status": False},
}
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
user_tools.insert_one(
{"_id": "1", "user": "test_user", "name": "tool1", "status": True}
)
user_tools.insert_one(
{"_id": "2", "user": "test_user", "name": "tool2", "status": False}
)
agent = ClassicAgent(**agent_base_params)
tools = agent._get_user_tools("test_user")
@@ -209,17 +214,16 @@ class TestBaseAgentTools:
tool_id = str(ObjectId())
tool_obj_id = ObjectId(tool_id)
fake_agent_collection = FakeMongoCollection()
fake_agent_collection.docs["api_key_123"] = {
"key": "api_key_123",
"tools": [tool_id],
}
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
agents_collection.insert_one(
{
"key": "api_key_123",
"tools": [tool_id],
}
)
fake_tools_collection = FakeMongoCollection()
fake_tools_collection.docs[tool_id] = {"_id": tool_obj_id, "name": "api_tool"}
mock_mongo_db[settings.MONGO_DB_NAME]["agents"] = fake_agent_collection
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"] = fake_tools_collection
tools_collection = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
tools_collection.insert_one({"_id": tool_obj_id, "name": "api_tool"})
agent = ClassicAgent(**agent_base_params)
tools = agent._get_tools("api_key_123")

0
tests/api/__init__.py Normal file
View File

View File

View File

View File

@@ -0,0 +1,552 @@
import datetime
from unittest.mock import MagicMock, patch
import pytest
from bson import ObjectId
@pytest.mark.unit
class TestBaseAnswerValidation:
def test_validate_request_passes_with_required_fields(
self, mock_mongo_db, flask_app
):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
data = {"question": "What is Python?"}
result = resource.validate_request(data)
assert result is None
def test_validate_request_fails_without_question(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
data = {}
result = resource.validate_request(data)
assert result is not None
assert result.status_code == 400
assert "question" in result.json["message"].lower()
def test_validate_with_conversation_id_required(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
data = {"question": "Test"}
result = resource.validate_request(data, require_conversation_id=True)
assert result is not None
assert result.status_code == 400
assert "conversation_id" in result.json["message"].lower()
def test_validate_passes_with_all_required_fields(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
data = {"question": "Test", "conversation_id": str(ObjectId())}
result = resource.validate_request(data, require_conversation_id=True)
assert result is None
@pytest.mark.unit
class TestUsageChecking:
def test_returns_none_when_no_api_key(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
agent_config = {}
result = resource.check_usage(agent_config)
assert result is None
def test_returns_error_for_invalid_api_key(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
agent_config = {"user_api_key": "invalid_key_123"}
result = resource.check_usage(agent_config)
assert result is not None
assert result.status_code == 401
assert result.json["success"] is False
assert "invalid" in result.json["message"].lower()
def test_checks_token_limit_when_enabled(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
from application.core.settings import settings
with flask_app.app_context():
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
agent_id = ObjectId()
agents_collection.insert_one(
{
"_id": agent_id,
"key": "test_key",
"limited_token_mode": True,
"token_limit": 1000,
"limited_request_mode": False,
}
)
resource = BaseAnswerResource()
agent_config = {"user_api_key": "test_key"}
result = resource.check_usage(agent_config)
assert result is None
def test_checks_request_limit_when_enabled(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
from application.core.settings import settings
with flask_app.app_context():
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
agent_id = ObjectId()
agents_collection.insert_one(
{
"_id": agent_id,
"key": "test_key",
"limited_token_mode": False,
"limited_request_mode": True,
"request_limit": 100,
}
)
resource = BaseAnswerResource()
agent_config = {"user_api_key": "test_key"}
result = resource.check_usage(agent_config)
assert result is None
def test_uses_default_limits_when_not_specified(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
from application.core.settings import settings
with flask_app.app_context():
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
agent_id = ObjectId()
agents_collection.insert_one(
{
"_id": agent_id,
"key": "test_key",
"limited_token_mode": True,
"limited_request_mode": True,
}
)
resource = BaseAnswerResource()
agent_config = {"user_api_key": "test_key"}
result = resource.check_usage(agent_config)
assert result is None
def test_exceeds_token_limit(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
from application.core.settings import settings
with flask_app.app_context():
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
token_usage_collection = mock_mongo_db[settings.MONGO_DB_NAME][
"token_usage"
]
agent_id = ObjectId()
agents_collection.insert_one(
{
"_id": agent_id,
"key": "test_key",
"limited_token_mode": True,
"token_limit": 100,
"limited_request_mode": False,
}
)
token_usage_collection.insert_one(
{
"_id": ObjectId(),
"api_key": "test_key",
"prompt_tokens": 60,
"generated_tokens": 50,
"timestamp": datetime.datetime.now(),
}
)
resource = BaseAnswerResource()
agent_config = {"user_api_key": "test_key"}
result = resource.check_usage(agent_config)
assert result is not None
assert result.status_code == 429
assert result.json["success"] is False
assert "usage limit" in result.json["message"].lower()
def test_exceeds_request_limit(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
from application.core.settings import settings
with flask_app.app_context():
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
token_usage_collection = mock_mongo_db[settings.MONGO_DB_NAME][
"token_usage"
]
agent_id = ObjectId()
agents_collection.insert_one(
{
"_id": agent_id,
"key": "test_key",
"limited_token_mode": False,
"limited_request_mode": True,
"request_limit": 2,
}
)
now = datetime.datetime.now()
for i in range(3):
token_usage_collection.insert_one(
{
"_id": ObjectId(),
"api_key": "test_key",
"prompt_tokens": 10,
"generated_tokens": 10,
"timestamp": now,
}
)
resource = BaseAnswerResource()
agent_config = {"user_api_key": "test_key"}
result = resource.check_usage(agent_config)
assert result is not None
assert result.status_code == 429
assert result.json["success"] is False
def test_both_limits_disabled_returns_none(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
from application.core.settings import settings
with flask_app.app_context():
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
agent_id = ObjectId()
agents_collection.insert_one(
{
"_id": agent_id,
"key": "test_key",
"limited_token_mode": False,
"limited_request_mode": False,
}
)
resource = BaseAnswerResource()
agent_config = {"user_api_key": "test_key"}
result = resource.check_usage(agent_config)
assert result is None
@pytest.mark.unit
class TestGPTModelRetrieval:
def test_initializes_gpt_model(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
assert hasattr(resource, "gpt_model")
assert resource.gpt_model is not None
@pytest.mark.unit
class TestConversationServiceIntegration:
def test_initializes_conversation_service(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
assert hasattr(resource, "conversation_service")
assert resource.conversation_service is not None
def test_has_access_to_user_logs_collection(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
assert hasattr(resource, "user_logs_collection")
assert resource.user_logs_collection is not None
@pytest.mark.unit
class TestCompleteStreamMethod:
def test_streams_answer_chunks(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
mock_agent.gen.return_value = iter(
[
{"answer": "Hello "},
{"answer": "world!"},
]
)
mock_retriever = MagicMock()
mock_retriever.get_params.return_value = {}
decoded_token = {"sub": "user123"}
stream = list(
resource.complete_stream(
question="Test question",
agent=mock_agent,
retriever=mock_retriever,
conversation_id=None,
user_api_key=None,
decoded_token=decoded_token,
should_save_conversation=False,
)
)
answer_chunks = [s for s in stream if '"type": "answer"' in s]
assert len(answer_chunks) == 2
assert '"answer": "Hello "' in answer_chunks[0]
assert '"answer": "world!"' in answer_chunks[1]
def test_streams_sources(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
mock_agent.gen.return_value = iter(
[
{"answer": "Test answer"},
{"sources": [{"title": "doc1.txt", "text": "x" * 200}]},
]
)
mock_retriever = MagicMock()
mock_retriever.get_params.return_value = {}
decoded_token = {"sub": "user123"}
stream = list(
resource.complete_stream(
question="Test?",
agent=mock_agent,
retriever=mock_retriever,
conversation_id=None,
user_api_key=None,
decoded_token=decoded_token,
should_save_conversation=False,
)
)
source_chunks = [s for s in stream if '"type": "source"' in s]
assert len(source_chunks) == 1
assert '"title": "doc1.txt"' in source_chunks[0]
def test_handles_error_during_streaming(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
mock_agent.gen.side_effect = Exception("Test error")
mock_retriever = MagicMock()
mock_retriever.get_params.return_value = {}
decoded_token = {"sub": "user123"}
stream = list(
resource.complete_stream(
question="Test?",
agent=mock_agent,
retriever=mock_retriever,
conversation_id=None,
user_api_key=None,
decoded_token=decoded_token,
should_save_conversation=False,
)
)
assert any('"type": "error"' in s for s in stream)
def test_saves_conversation_when_enabled(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
mock_agent.gen.return_value = iter(
[
{"answer": "Test answer"},
]
)
mock_retriever = MagicMock()
mock_retriever.get_params.return_value = {}
decoded_token = {"sub": "user123"}
with patch.object(
resource.conversation_service, "save_conversation"
) as mock_save:
mock_save.return_value = str(ObjectId())
list(
resource.complete_stream(
question="Test?",
agent=mock_agent,
retriever=mock_retriever,
conversation_id=None,
user_api_key=None,
decoded_token=decoded_token,
should_save_conversation=True,
)
)
mock_save.assert_called_once()
def test_logs_to_user_logs_collection(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
from application.core.settings import settings
with flask_app.app_context():
resource = BaseAnswerResource()
user_logs = mock_mongo_db[settings.MONGO_DB_NAME]["user_logs"]
mock_agent = MagicMock()
mock_agent.gen.return_value = iter(
[
{"answer": "Test answer"},
]
)
mock_retriever = MagicMock()
mock_retriever.get_params.return_value = {"retriever": "test"}
decoded_token = {"sub": "user123"}
list(
resource.complete_stream(
question="Test question?",
agent=mock_agent,
retriever=mock_retriever,
conversation_id=None,
user_api_key="test_key",
decoded_token=decoded_token,
should_save_conversation=False,
)
)
assert user_logs.count_documents({}) == 1
log_entry = user_logs.find_one({})
assert log_entry["action"] == "stream_answer"
assert log_entry["user"] == "user123"
assert log_entry["api_key"] == "test_key"
assert log_entry["question"] == "Test question?"
@pytest.mark.unit
class TestProcessResponseStream:
def test_processes_complete_stream(self, mock_mongo_db, flask_app):
import json
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
conv_id = str(ObjectId())
stream = [
f'data: {json.dumps({"type": "answer", "answer": "Hello "})}\n\n',
f'data: {json.dumps({"type": "answer", "answer": "world"})}\n\n',
f'data: {json.dumps({"type": "source", "source": [{"title": "doc1"}]})}\n\n',
f'data: {json.dumps({"type": "id", "id": conv_id})}\n\n',
f'data: {json.dumps({"type": "end"})}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert result[0] == conv_id
assert result[1] == "Hello world"
assert result[2] == [{"title": "doc1"}]
assert result[5] is None
def test_handles_stream_error(self, mock_mongo_db, flask_app):
import json
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
stream = [
f'data: {json.dumps({"type": "error", "error": "Test error"})}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert len(result) == 5
assert result[0] is None
assert result[4] == "Test error"
def test_handles_malformed_stream_data(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
stream = [
"data: invalid json\n\n",
'data: {"type": "end"}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert result is not None
@pytest.mark.unit
class TestErrorStreamGenerate:
def test_generates_error_stream(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
error_stream = list(resource.error_stream_generate("Test error message"))
assert len(error_stream) == 1
assert '"type": "error"' in error_stream[0]
assert '"error": "Test error message"' in error_stream[0]

View File

View File

@@ -0,0 +1,242 @@
from unittest.mock import Mock
import pytest
from bson import ObjectId
@pytest.mark.unit
class TestConversationServiceGet:
def test_returns_none_when_no_conversation_id(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
service = ConversationService()
result = service.get_conversation("", "user_123")
assert result is None
def test_returns_none_when_no_user_id(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
service = ConversationService()
result = service.get_conversation(str(ObjectId()), "")
assert result is None
def test_returns_conversation_for_owner(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
conv_id = ObjectId()
conversation = {
"_id": conv_id,
"user": "user_123",
"name": "Test Conv",
"queries": [],
}
collection.insert_one(conversation)
result = service.get_conversation(str(conv_id), "user_123")
assert result is not None
assert result["name"] == "Test Conv"
assert result["_id"] == str(conv_id)
def test_returns_none_for_unauthorized_user(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
conv_id = ObjectId()
collection.insert_one(
{"_id": conv_id, "user": "owner_123", "name": "Private Conv"}
)
result = service.get_conversation(str(conv_id), "hacker_456")
assert result is None
def test_converts_objectid_to_string(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
conv_id = ObjectId()
collection.insert_one({"_id": conv_id, "user": "user_123", "name": "Test"})
result = service.get_conversation(str(conv_id), "user_123")
assert isinstance(result["_id"], str)
assert result["_id"] == str(conv_id)
@pytest.mark.unit
class TestConversationServiceSave:
def test_raises_error_when_no_user_in_token(self, mock_mongo_db):
"""Test validation: user ID required"""
from application.api.answer.services.conversation_service import (
ConversationService,
)
service = ConversationService()
mock_llm = Mock()
with pytest.raises(ValueError, match="User ID not found"):
service.save_conversation(
conversation_id=None,
question="Test?",
response="Answer",
thought="",
sources=[],
tool_calls=[],
llm=mock_llm,
gpt_model="gpt-4",
decoded_token={}, # No 'sub' key
)
def test_truncates_long_source_text(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
from bson import ObjectId
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
mock_llm = Mock()
mock_llm.gen.return_value = "Test Summary"
long_text = "x" * 2000
sources = [{"text": long_text, "title": "Doc"}]
conv_id = service.save_conversation(
conversation_id=None,
question="Question",
response="Response",
thought="",
sources=sources,
tool_calls=[],
llm=mock_llm,
gpt_model="gpt-4",
decoded_token={"sub": "user_123"},
)
saved_conv = collection.find_one({"_id": ObjectId(conv_id)})
saved_source_text = saved_conv["queries"][0]["sources"][0]["text"]
assert len(saved_source_text) == 1000
assert saved_source_text == "x" * 1000
def test_creates_new_conversation_with_summary(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
from bson import ObjectId
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
mock_llm = Mock()
mock_llm.gen.return_value = "Python Basics"
conv_id = service.save_conversation(
conversation_id=None,
question="What is Python?",
response="Python is a programming language",
thought="",
sources=[],
tool_calls=[],
llm=mock_llm,
gpt_model="gpt-4",
decoded_token={"sub": "user_123"},
)
assert conv_id is not None
saved_conv = collection.find_one({"_id": ObjectId(conv_id)})
assert saved_conv["name"] == "Python Basics"
assert saved_conv["user"] == "user_123"
assert len(saved_conv["queries"]) == 1
assert saved_conv["queries"][0]["prompt"] == "What is Python?"
def test_appends_to_existing_conversation(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
from bson import ObjectId
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
existing_conv_id = ObjectId()
collection.insert_one(
{
"_id": existing_conv_id,
"user": "user_123",
"name": "Old Conv",
"queries": [{"prompt": "Q1", "response": "A1"}],
}
)
mock_llm = Mock()
result = service.save_conversation(
conversation_id=str(existing_conv_id),
question="Q2",
response="A2",
thought="",
sources=[],
tool_calls=[],
llm=mock_llm,
gpt_model="gpt-4",
decoded_token={"sub": "user_123"},
)
assert result == str(existing_conv_id)
def test_prevents_unauthorized_conversation_update(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
conv_id = ObjectId()
collection.insert_one({"_id": conv_id, "user": "owner_123", "queries": []})
mock_llm = Mock()
with pytest.raises(ValueError, match="not found or unauthorized"):
service.save_conversation(
conversation_id=str(conv_id),
question="Hack",
response="Attempt",
thought="",
sources=[],
tool_calls=[],
llm=mock_llm,
gpt_model="gpt-4",
decoded_token={"sub": "hacker_456"},
)

View File

@@ -0,0 +1,252 @@
import pytest
from bson import ObjectId
@pytest.mark.unit
class TestGetPromptFunction:
def test_loads_custom_prompt_from_database(self, mock_mongo_db):
from application.api.answer.services.stream_processor import get_prompt
from application.core.settings import settings
prompts_collection = mock_mongo_db[settings.MONGO_DB_NAME]["prompts"]
prompt_id = ObjectId()
prompts_collection.insert_one(
{
"_id": prompt_id,
"content": "Custom prompt from database",
"user": "user_123",
}
)
result = get_prompt(str(prompt_id), prompts_collection)
assert result == "Custom prompt from database"
def test_raises_error_for_invalid_prompt_id(self, mock_mongo_db):
from application.api.answer.services.stream_processor import get_prompt
from application.core.settings import settings
prompts_collection = mock_mongo_db[settings.MONGO_DB_NAME]["prompts"]
with pytest.raises(ValueError, match="Invalid prompt ID"):
get_prompt(str(ObjectId()), prompts_collection)
def test_raises_error_for_malformed_id(self, mock_mongo_db):
from application.api.answer.services.stream_processor import get_prompt
from application.core.settings import settings
prompts_collection = mock_mongo_db[settings.MONGO_DB_NAME]["prompts"]
with pytest.raises(ValueError, match="Invalid prompt ID"):
get_prompt("not_a_valid_id", prompts_collection)
@pytest.mark.unit
class TestStreamProcessorInitialization:
def test_initializes_with_decoded_token(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
request_data = {
"question": "What is Python?",
"conversation_id": str(ObjectId()),
}
decoded_token = {"sub": "user_123", "email": "test@example.com"}
processor = StreamProcessor(request_data, decoded_token)
assert processor.data == request_data
assert processor.decoded_token == decoded_token
assert processor.initial_user_id == "user_123"
assert processor.conversation_id == request_data["conversation_id"]
def test_initializes_without_token(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
request_data = {"question": "Test question"}
processor = StreamProcessor(request_data, None)
assert processor.decoded_token is None
assert processor.initial_user_id is None
assert processor.data == request_data
def test_initializes_default_attributes(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
processor = StreamProcessor({"question": "Test"}, {"sub": "user_123"})
assert processor.source == {}
assert processor.all_sources == []
assert processor.attachments == []
assert processor.history == []
assert processor.agent_config == {}
assert processor.retriever_config == {}
assert processor.is_shared_usage is False
assert processor.shared_token is None
def test_extracts_conversation_id_from_request(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
conv_id = str(ObjectId())
request_data = {"question": "Test", "conversation_id": conv_id}
processor = StreamProcessor(request_data, {"sub": "user_123"})
assert processor.conversation_id == conv_id
@pytest.mark.unit
class TestStreamProcessorHistoryLoading:
def test_loads_history_from_existing_conversation(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
from application.core.settings import settings
conversations_collection = mock_mongo_db[settings.MONGO_DB_NAME][
"conversations"
]
conv_id = ObjectId()
conversations_collection.insert_one(
{
"_id": conv_id,
"user": "user_123",
"name": "Test Conv",
"queries": [
{"prompt": "What is Python?", "response": "Python is a language"},
{"prompt": "Tell me more", "response": "Python is versatile"},
],
}
)
request_data = {
"question": "How to install it?",
"conversation_id": str(conv_id),
}
processor = StreamProcessor(request_data, {"sub": "user_123"})
processor._load_conversation_history()
assert len(processor.history) == 2
assert processor.history[0]["prompt"] == "What is Python?"
assert processor.history[1]["response"] == "Python is versatile"
def test_raises_error_for_unauthorized_conversation(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
from application.core.settings import settings
conversations_collection = mock_mongo_db[settings.MONGO_DB_NAME][
"conversations"
]
conv_id = ObjectId()
conversations_collection.insert_one(
{
"_id": conv_id,
"user": "owner_123",
"name": "Private Conv",
"queries": [],
}
)
request_data = {"question": "Hack attempt", "conversation_id": str(conv_id)}
processor = StreamProcessor(request_data, {"sub": "hacker_456"})
with pytest.raises(ValueError, match="Conversation not found or unauthorized"):
processor._load_conversation_history()
def test_uses_request_history_when_no_conversation_id(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
request_data = {
"question": "What is Python?",
"history": [{"prompt": "Hello", "response": "Hi there!"}],
}
processor = StreamProcessor(request_data, {"sub": "user_123"})
assert processor.conversation_id is None
@pytest.mark.unit
class TestStreamProcessorAgentConfiguration:
def test_configures_agent_from_valid_api_key(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
from application.core.settings import settings
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
agent_id = ObjectId()
agents_collection.insert_one(
{
"_id": agent_id,
"key": "test_api_key_123",
"endpoint": "openai",
"model": "gpt-4",
"prompt_id": "default",
"user": "user_123",
}
)
request_data = {"question": "Test", "api_key": "test_api_key_123"}
processor = StreamProcessor(request_data, None)
try:
processor._configure_agent()
assert processor.agent_config is not None
except Exception as e:
assert "Invalid API Key" in str(e)
def test_uses_default_config_without_api_key(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
request_data = {"question": "Test"}
processor = StreamProcessor(request_data, {"sub": "user_123"})
processor._configure_agent()
assert isinstance(processor.agent_config, dict)
@pytest.mark.unit
class TestStreamProcessorAttachments:
def test_processes_attachments_from_request(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
from application.core.settings import settings
attachments_collection = mock_mongo_db[settings.MONGO_DB_NAME]["attachments"]
att_id = ObjectId()
attachments_collection.insert_one(
{
"_id": att_id,
"filename": "document.pdf",
"content": "Document content",
"user": "user_123",
}
)
request_data = {"question": "Analyze this", "attachments": [str(att_id)]}
processor = StreamProcessor(request_data, {"sub": "user_123"})
assert processor.data.get("attachments") == [str(att_id)]
def test_handles_empty_attachments(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
request_data = {"question": "Simple question"}
processor = StreamProcessor(request_data, {"sub": "user_123"})
assert processor.attachments == []
assert (
"attachments" not in processor.data
or processor.data.get("attachments") is None
)

89
tests/api/conftest.py Normal file
View File

@@ -0,0 +1,89 @@
"""API-specific test fixtures."""
import pytest
from bson import ObjectId
@pytest.fixture
def auth_headers():
return {"Authorization": "Bearer test_token"}
@pytest.fixture
def mock_request_token(monkeypatch, decoded_token):
def mock_decorator(f):
def wrapper(*args, **kwargs):
from flask import request
request.decoded_token = decoded_token
return f(*args, **kwargs)
return wrapper
monkeypatch.setattr("application.auth.api_key_required", lambda: mock_decorator)
return decoded_token
@pytest.fixture
def sample_conversation():
return {
"_id": ObjectId(),
"user": "test_user",
"name": "Test Conversation",
"queries": [
{
"prompt": "What is Python?",
"response": "Python is a programming language",
}
],
"date": "2025-01-01T00:00:00",
}
@pytest.fixture
def sample_prompt():
return {
"_id": ObjectId(),
"user": "test_user",
"name": "Helpful Assistant",
"content": "You are a helpful assistant that provides clear and concise answers.",
"type": "custom",
}
@pytest.fixture
def sample_agent():
return {
"_id": ObjectId(),
"user": "test_user",
"name": "Test Agent",
"type": "classic",
"endpoint": "openai",
"model": "gpt-4",
"prompt_id": "default",
"status": "active",
}
@pytest.fixture
def sample_answer_request():
return {
"question": "What is Python?",
"history": [],
"conversation_id": None,
"prompt_id": "default",
"chunks": 2,
"token_limit": 1000,
"retriever": "classic_rag",
"active_docs": "local/test/",
"isNoneDoc": False,
"save_conversation": True,
}
@pytest.fixture
def flask_app():
from flask import Flask
app = Flask(__name__)
return app

311
tests/api/user/test_base.py Normal file
View File

@@ -0,0 +1,311 @@
import datetime
import io
from unittest.mock import Mock, patch
import pytest
from bson import ObjectId
from werkzeug.datastructures import FileStorage
@pytest.mark.unit
class TestTimeRangeGenerators:
def test_generate_minute_range(self):
from application.api.user.base import generate_minute_range
start = datetime.datetime(2024, 1, 1, 10, 0, 0)
end = datetime.datetime(2024, 1, 1, 10, 5, 0)
result = generate_minute_range(start, end)
assert len(result) == 6
assert "2024-01-01 10:00:00" in result
assert "2024-01-01 10:05:00" in result
assert all(val == 0 for val in result.values())
def test_generate_hourly_range(self):
from application.api.user.base import generate_hourly_range
start = datetime.datetime(2024, 1, 1, 10, 0, 0)
end = datetime.datetime(2024, 1, 1, 15, 0, 0)
result = generate_hourly_range(start, end)
assert len(result) == 6
assert "2024-01-01 10:00" in result
assert "2024-01-01 15:00" in result
assert all(val == 0 for val in result.values())
def test_generate_date_range(self):
from application.api.user.base import generate_date_range
start = datetime.date(2024, 1, 1)
end = datetime.date(2024, 1, 5)
result = generate_date_range(start, end)
assert len(result) == 5
assert "2024-01-01" in result
assert "2024-01-05" in result
assert all(val == 0 for val in result.values())
def test_single_minute_range(self):
from application.api.user.base import generate_minute_range
time = datetime.datetime(2024, 1, 1, 10, 30, 0)
result = generate_minute_range(time, time)
assert len(result) == 1
assert "2024-01-01 10:30:00" in result
@pytest.mark.unit
class TestEnsureUserDoc:
def test_creates_new_user_with_defaults(self, mock_mongo_db):
from application.api.user.base import ensure_user_doc
user_id = "test_user_123"
result = ensure_user_doc(user_id)
assert result is not None
assert result["user_id"] == user_id
assert "agent_preferences" in result
assert result["agent_preferences"]["pinned"] == []
assert result["agent_preferences"]["shared_with_me"] == []
def test_returns_existing_user(self, mock_mongo_db):
from application.api.user.base import ensure_user_doc
from application.core.settings import settings
users_collection = mock_mongo_db[settings.MONGO_DB_NAME]["users"]
user_id = "existing_user"
existing_doc = {
"user_id": user_id,
"agent_preferences": {"pinned": ["agent1"], "shared_with_me": ["agent2"]},
}
users_collection.insert_one(existing_doc)
result = ensure_user_doc(user_id)
assert result["user_id"] == user_id
assert result["agent_preferences"]["pinned"] == ["agent1"]
assert result["agent_preferences"]["shared_with_me"] == ["agent2"]
def test_adds_missing_preferences_fields(self, mock_mongo_db):
from application.api.user.base import ensure_user_doc
from application.core.settings import settings
users_collection = mock_mongo_db[settings.MONGO_DB_NAME]["users"]
user_id = "incomplete_user"
users_collection.insert_one(
{"user_id": user_id, "agent_preferences": {"pinned": ["agent1"]}}
)
result = ensure_user_doc(user_id)
assert "shared_with_me" in result["agent_preferences"]
assert result["agent_preferences"]["shared_with_me"] == []
@pytest.mark.unit
class TestResolveToolDetails:
def test_resolves_tool_ids_to_details(self, mock_mongo_db):
from application.api.user.base import resolve_tool_details
from application.core.settings import settings
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
tool_id1 = ObjectId()
tool_id2 = ObjectId()
user_tools.insert_one(
{"_id": tool_id1, "name": "calculator", "displayName": "Calculator Tool"}
)
user_tools.insert_one(
{"_id": tool_id2, "name": "weather", "displayName": "Weather API"}
)
result = resolve_tool_details([str(tool_id1), str(tool_id2)])
assert len(result) == 2
assert result[0]["id"] == str(tool_id1)
assert result[0]["name"] == "calculator"
assert result[0]["display_name"] == "Calculator Tool"
assert result[1]["name"] == "weather"
def test_handles_missing_display_name(self, mock_mongo_db):
from application.api.user.base import resolve_tool_details
from application.core.settings import settings
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
tool_id = ObjectId()
user_tools.insert_one({"_id": tool_id, "name": "test_tool"})
result = resolve_tool_details([str(tool_id)])
assert result[0]["display_name"] == "test_tool"
def test_empty_tool_ids_list(self, mock_mongo_db):
from application.api.user.base import resolve_tool_details
result = resolve_tool_details([])
assert result == []
@pytest.mark.unit
class TestGetVectorStore:
@patch("application.api.user.base.VectorCreator.create_vectorstore")
def test_creates_vector_store(self, mock_create, mock_mongo_db):
from application.api.user.base import get_vector_store
mock_store = Mock()
mock_create.return_value = mock_store
source_id = "test_source_123"
result = get_vector_store(source_id)
assert result == mock_store
mock_create.assert_called_once()
args, kwargs = mock_create.call_args
assert kwargs.get("source_id") == source_id
@pytest.mark.unit
class TestHandleImageUpload:
def test_returns_existing_url_when_no_file(self, flask_app):
from application.api.user.base import handle_image_upload
with flask_app.test_request_context():
mock_request = Mock()
mock_request.files = {}
mock_storage = Mock()
existing_url = "existing/path/image.jpg"
url, error = handle_image_upload(
mock_request, existing_url, "user123", mock_storage
)
assert url == existing_url
assert error is None
def test_uploads_new_image(self, flask_app):
from application.api.user.base import handle_image_upload
with flask_app.test_request_context():
mock_file = FileStorage(
stream=io.BytesIO(b"fake image data"), filename="test_image.png"
)
mock_request = Mock()
mock_request.files = {"image": mock_file}
mock_storage = Mock()
mock_storage.save_file.return_value = {"success": True}
url, error = handle_image_upload(
mock_request, "old_url", "user123", mock_storage
)
assert error is None
assert url is not None
assert "test_image.png" in url
assert "user123" in url
mock_storage.save_file.assert_called_once()
def test_ignores_empty_filename(self, flask_app):
from application.api.user.base import handle_image_upload
with flask_app.test_request_context():
mock_file = Mock()
mock_file.filename = ""
mock_request = Mock()
mock_request.files = {"image": mock_file}
mock_storage = Mock()
existing_url = "existing.jpg"
url, error = handle_image_upload(
mock_request, existing_url, "user123", mock_storage
)
assert url == existing_url
assert error is None
mock_storage.save_file.assert_not_called()
def test_handles_upload_error(self, flask_app):
from application.api.user.base import handle_image_upload
with flask_app.app_context():
mock_file = FileStorage(stream=io.BytesIO(b"data"), filename="test.png")
mock_request = Mock()
mock_request.files = {"image": mock_file}
mock_storage = Mock()
mock_storage.save_file.side_effect = Exception("Storage error")
url, error = handle_image_upload(
mock_request, "old.jpg", "user123", mock_storage
)
assert url is None
assert error is not None
assert error.status_code == 400
@pytest.mark.unit
class TestRequireAgentDecorator:
def test_validates_webhook_token(self, mock_mongo_db, flask_app):
from application.api.user.base import require_agent
from application.core.settings import settings
with flask_app.app_context():
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
agent_id = ObjectId()
webhook_token = "valid_webhook_token_123"
agents_collection.insert_one(
{"_id": agent_id, "incoming_webhook_token": webhook_token}
)
@require_agent
def test_func(webhook_token=None, agent=None, agent_id_str=None):
return {"agent_id": agent_id_str}
result = test_func(webhook_token=webhook_token)
assert result["agent_id"] == str(agent_id)
def test_returns_400_for_missing_token(self, mock_mongo_db, flask_app):
from application.api.user.base import require_agent
with flask_app.app_context():
@require_agent
def test_func(webhook_token=None, agent=None, agent_id_str=None):
return {"success": True}
result = test_func()
assert result.status_code == 400
assert result.json["success"] is False
assert "missing" in result.json["message"].lower()
def test_returns_404_for_invalid_token(self, mock_mongo_db, flask_app):
from application.api.user.base import require_agent
with flask_app.app_context():
@require_agent
def test_func(webhook_token=None, agent=None, agent_id_str=None):
return {"success": True}
result = test_func(webhook_token="invalid_token_999")
assert result.status_code == 404
assert result.json["success"] is False
assert "not found" in result.json["message"].lower()

View File

@@ -1,7 +1,15 @@
from unittest.mock import Mock
import mongomock
import pytest
from application.core.settings import settings
def get_settings():
"""Lazy load settings to avoid import-time errors."""
from application.core.settings import settings
return settings
@pytest.fixture
@@ -35,18 +43,51 @@ def mock_retriever():
@pytest.fixture
def mock_mongo_db(monkeypatch):
fake_collection = FakeMongoCollection()
fake_db = {
"agents": fake_collection,
"user_tools": fake_collection,
"memories": fake_collection,
}
fake_client = {settings.MONGO_DB_NAME: fake_db}
"""Mock MongoDB using mongomock - industry standard MongoDB mocking library."""
settings = get_settings()
mock_client = mongomock.MongoClient()
mock_db = mock_client[settings.MONGO_DB_NAME]
def get_mock_client():
return {settings.MONGO_DB_NAME: mock_db}
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", get_mock_client)
monkeypatch.setattr("application.api.user.base.users_collection", mock_db["users"])
monkeypatch.setattr(
"application.core.mongo_db.MongoDB.get_client", lambda: fake_client
"application.api.user.base.user_tools_collection", mock_db["user_tools"]
)
return fake_client
monkeypatch.setattr(
"application.api.user.base.agents_collection", mock_db["agents"]
)
monkeypatch.setattr(
"application.api.user.base.conversations_collection", mock_db["conversations"]
)
monkeypatch.setattr(
"application.api.user.base.sources_collection", mock_db["sources"]
)
monkeypatch.setattr(
"application.api.user.base.prompts_collection", mock_db["prompts"]
)
monkeypatch.setattr(
"application.api.user.base.feedback_collection", mock_db["feedback"]
)
monkeypatch.setattr(
"application.api.user.base.token_usage_collection", mock_db["token_usage"]
)
monkeypatch.setattr(
"application.api.user.base.attachments_collection", mock_db["attachments"]
)
monkeypatch.setattr(
"application.api.user.base.user_logs_collection", mock_db["user_logs"]
)
monkeypatch.setattr(
"application.api.user.base.shared_conversations_collections",
mock_db["shared_conversations"],
)
return get_mock_client()
@pytest.fixture
@@ -87,53 +128,6 @@ def log_context():
return context
class FakeMongoCollection:
def __init__(self):
self.docs = {}
def find_one(self, query, projection=None):
if "key" in query:
return self.docs.get(query["key"])
if "_id" in query:
return self.docs.get(str(query["_id"]))
if "user" in query:
for doc in self.docs.values():
if doc.get("user") == query["user"]:
return doc
return None
def find(self, query, projection=None):
results = []
if "_id" in query and "$in" in query["_id"]:
for doc_id in query["_id"]["$in"]:
doc = self.docs.get(str(doc_id))
if doc:
results.append(doc)
elif "user" in query:
for doc in self.docs.values():
if doc.get("user") == query["user"]:
if "status" in query:
if doc.get("status") == query["status"]:
results.append(doc)
else:
results.append(doc)
return results
def insert_one(self, doc):
doc_id = doc.get("_id", len(self.docs))
self.docs[str(doc_id)] = doc
return Mock(inserted_id=doc_id)
def update_one(self, query, update, upsert=False):
return Mock(modified_count=1)
def delete_one(self, query):
return Mock(deleted_count=1)
def delete_many(self, query):
return Mock(deleted_count=0)
@pytest.fixture
def mock_llm_creator(mock_llm, monkeypatch):
monkeypatch.setattr(

View File

@@ -1,19 +1,13 @@
from unittest.mock import Mock, patch
from typing import Any, Dict, Generator
from unittest.mock import Mock, patch
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
class TestToolCall:
"""Test ToolCall dataclass."""
def test_tool_call_creation(self):
"""Test basic ToolCall creation."""
tool_call = ToolCall(
id="test_id",
name="test_function",
arguments={"arg1": "value1"},
index=0
id="test_id", name="test_function", arguments={"arg1": "value1"}, index=0
)
assert tool_call.id == "test_id"
assert tool_call.name == "test_function"
@@ -21,12 +15,11 @@ class TestToolCall:
assert tool_call.index == 0
def test_tool_call_from_dict(self):
"""Test ToolCall creation from dictionary."""
data = {
"id": "call_123",
"name": "get_weather",
"arguments": {"location": "New York"},
"index": 1
"index": 1,
}
tool_call = ToolCall.from_dict(data)
assert tool_call.id == "call_123"
@@ -35,7 +28,6 @@ class TestToolCall:
assert tool_call.index == 1
def test_tool_call_from_dict_missing_fields(self):
"""Test ToolCall creation with missing fields."""
data = {"name": "test_func"}
tool_call = ToolCall.from_dict(data)
assert tool_call.id == ""
@@ -45,16 +37,13 @@ class TestToolCall:
class TestLLMResponse:
"""Test LLMResponse dataclass."""
def test_llm_response_creation(self):
"""Test basic LLMResponse creation."""
tool_calls = [ToolCall(id="1", name="func", arguments={})]
response = LLMResponse(
content="Hello",
tool_calls=tool_calls,
finish_reason="tool_calls",
raw_response={"test": "data"}
raw_response={"test": "data"},
)
assert response.content == "Hello"
assert len(response.tool_calls) == 1
@@ -62,55 +51,43 @@ class TestLLMResponse:
assert response.raw_response == {"test": "data"}
def test_requires_tool_call_true(self):
"""Test requires_tool_call property when tool calls are needed."""
tool_calls = [ToolCall(id="1", name="func", arguments={})]
response = LLMResponse(
content="",
tool_calls=tool_calls,
finish_reason="tool_calls",
raw_response={}
raw_response={},
)
assert response.requires_tool_call is True
def test_requires_tool_call_false_no_tools(self):
"""Test requires_tool_call property when no tool calls."""
response = LLMResponse(
content="Hello",
tool_calls=[],
finish_reason="stop",
raw_response={}
content="Hello", tool_calls=[], finish_reason="stop", raw_response={}
)
assert response.requires_tool_call is False
def test_requires_tool_call_false_wrong_finish_reason(self):
"""Test requires_tool_call property with tools but wrong finish reason."""
tool_calls = [ToolCall(id="1", name="func", arguments={})]
response = LLMResponse(
content="Hello",
tool_calls=tool_calls,
finish_reason="stop",
raw_response={}
raw_response={},
)
assert response.requires_tool_call is False
class ConcreteHandler(LLMHandler):
"""Concrete implementation for testing abstract base class."""
def parse_response(self, response: Any) -> LLMResponse:
return LLMResponse(
content=str(response),
tool_calls=[],
finish_reason="stop",
raw_response=response
raw_response=response,
)
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
return {
"role": "tool",
"content": str(result),
"tool_call_id": tool_call.id
}
return {"role": "tool", "content": str(result), "tool_call_id": tool_call.id}
def _iterate_stream(self, response: Any) -> Generator:
for chunk in response:
@@ -118,114 +95,119 @@ class ConcreteHandler(LLMHandler):
class TestLLMHandler:
"""Test LLMHandler base class."""
def test_handler_initialization(self):
"""Test handler initialization."""
handler = ConcreteHandler()
assert handler.llm_calls == []
assert handler.tool_calls == []
def test_prepare_messages_no_attachments(self):
"""Test prepare_messages with no attachments."""
handler = ConcreteHandler()
messages = [{"role": "user", "content": "Hello"}]
mock_agent = Mock()
result = handler.prepare_messages(mock_agent, messages, None)
assert result == messages
def test_prepare_messages_with_supported_attachments(self):
"""Test prepare_messages with supported attachments."""
handler = ConcreteHandler()
messages = [{"role": "user", "content": "Hello"}]
attachments = [{"mime_type": "image/png", "path": "/test.png"}]
mock_agent = Mock()
mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"]
mock_agent.llm.prepare_messages_with_attachments.return_value = messages
result = handler.prepare_messages(mock_agent, messages, attachments)
mock_agent.llm.prepare_messages_with_attachments.assert_called_once_with(
messages, attachments
)
assert result == messages
@patch('application.llm.handlers.base.logger')
@patch("application.llm.handlers.base.logger")
def test_prepare_messages_with_unsupported_attachments(self, mock_logger):
"""Test prepare_messages with unsupported attachments."""
handler = ConcreteHandler()
messages = [{"role": "user", "content": "Hello"}]
attachments = [{"mime_type": "text/plain", "path": "/test.txt"}]
mock_agent = Mock()
mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"]
with patch.object(handler, '_append_unsupported_attachments', return_value=messages) as mock_append:
with patch.object(
handler, "_append_unsupported_attachments", return_value=messages
) as mock_append:
result = handler.prepare_messages(mock_agent, messages, attachments)
mock_append.assert_called_once_with(messages, attachments)
assert result == messages
def test_prepare_messages_mixed_attachments(self):
"""Test prepare_messages with both supported and unsupported attachments."""
handler = ConcreteHandler()
messages = [{"role": "user", "content": "Hello"}]
attachments = [
{"mime_type": "image/png", "path": "/test.png"},
{"mime_type": "text/plain", "path": "/test.txt"}
{"mime_type": "text/plain", "path": "/test.txt"},
]
mock_agent = Mock()
mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"]
mock_agent.llm.prepare_messages_with_attachments.return_value = messages
with patch.object(handler, '_append_unsupported_attachments', return_value=messages) as mock_append:
with patch.object(
handler, "_append_unsupported_attachments", return_value=messages
) as mock_append:
result = handler.prepare_messages(mock_agent, messages, attachments)
# Should call both methods
mock_agent.llm.prepare_messages_with_attachments.assert_called_once()
mock_append.assert_called_once()
assert result == messages
def test_process_message_flow_non_streaming(self):
"""Test process_message_flow for non-streaming."""
handler = ConcreteHandler()
mock_agent = Mock()
initial_response = "test response"
tools_dict = {}
messages = [{"role": "user", "content": "Hello"}]
with patch.object(handler, 'prepare_messages', return_value=messages) as mock_prepare:
with patch.object(handler, 'handle_non_streaming', return_value="final") as mock_handle:
with patch.object(
handler, "prepare_messages", return_value=messages
) as mock_prepare:
with patch.object(
handler, "handle_non_streaming", return_value="final"
) as mock_handle:
result = handler.process_message_flow(
mock_agent, initial_response, tools_dict, messages, stream=False
)
mock_prepare.assert_called_once_with(mock_agent, messages, None)
mock_handle.assert_called_once_with(mock_agent, initial_response, tools_dict, messages)
mock_handle.assert_called_once_with(
mock_agent, initial_response, tools_dict, messages
)
assert result == "final"
def test_process_message_flow_streaming(self):
"""Test process_message_flow for streaming."""
handler = ConcreteHandler()
mock_agent = Mock()
initial_response = "test response"
tools_dict = {}
messages = [{"role": "user", "content": "Hello"}]
def mock_generator():
yield "chunk1"
yield "chunk2"
with patch.object(handler, 'prepare_messages', return_value=messages) as mock_prepare:
with patch.object(handler, 'handle_streaming', return_value=mock_generator()) as mock_handle:
with patch.object(
handler, "prepare_messages", return_value=messages
) as mock_prepare:
with patch.object(
handler, "handle_streaming", return_value=mock_generator()
) as mock_handle:
result = handler.process_message_flow(
mock_agent, initial_response, tools_dict, messages, stream=True
)
mock_prepare.assert_called_once_with(mock_agent, messages, None)
mock_handle.assert_called_once_with(mock_agent, initial_response, tools_dict, messages)
# Verify it's a generator
mock_handle.assert_called_once_with(
mock_agent, initial_response, tools_dict, messages
)
chunks = list(result)
assert chunks == ["chunk1", "chunk2"]

View File

@@ -1,3 +1,4 @@
pytest>=8.0.0
pytest-cov>=4.1.0
coverage>=7.4.0
mongomock>=4.3.0