Files
DocsGPT/tests/api/user/sources/test_routes.py
Alex 9e7f1ad1c0 Add Amazon S3 support and synchronization features (#2244)
* Add Amazon S3 support and synchronization features

* refactor: remove unused variable in load_data test
2025-12-30 20:26:51 +02:00

358 lines
12 KiB
Python

"""Tests for sources routes."""
import json
import pytest
from unittest.mock import MagicMock, patch
from bson import ObjectId
class TestGetProviderFromRemoteData:
"""Test the _get_provider_from_remote_data helper function."""
def test_returns_none_for_none_input(self):
"""Should return None when remote_data is None."""
from application.api.user.sources.routes import _get_provider_from_remote_data
result = _get_provider_from_remote_data(None)
assert result is None
def test_returns_none_for_empty_string(self):
"""Should return None when remote_data is empty string."""
from application.api.user.sources.routes import _get_provider_from_remote_data
result = _get_provider_from_remote_data("")
assert result is None
def test_extracts_provider_from_dict(self):
"""Should extract provider from dict remote_data."""
from application.api.user.sources.routes import _get_provider_from_remote_data
remote_data = {"provider": "s3", "bucket": "my-bucket"}
result = _get_provider_from_remote_data(remote_data)
assert result == "s3"
def test_extracts_provider_from_json_string(self):
"""Should extract provider from JSON string remote_data."""
from application.api.user.sources.routes import _get_provider_from_remote_data
remote_data = json.dumps({"provider": "github", "repo": "test/repo"})
result = _get_provider_from_remote_data(remote_data)
assert result == "github"
def test_returns_none_for_dict_without_provider(self):
"""Should return None when dict has no provider key."""
from application.api.user.sources.routes import _get_provider_from_remote_data
remote_data = {"bucket": "my-bucket", "region": "us-east-1"}
result = _get_provider_from_remote_data(remote_data)
assert result is None
def test_returns_none_for_invalid_json(self):
"""Should return None for invalid JSON string."""
from application.api.user.sources.routes import _get_provider_from_remote_data
result = _get_provider_from_remote_data("not valid json")
assert result is None
def test_returns_none_for_json_array(self):
"""Should return None when JSON parses to non-dict."""
from application.api.user.sources.routes import _get_provider_from_remote_data
result = _get_provider_from_remote_data('["item1", "item2"]')
assert result is None
def test_returns_none_for_non_string_non_dict(self):
"""Should return None for other types like int."""
from application.api.user.sources.routes import _get_provider_from_remote_data
result = _get_provider_from_remote_data(123)
assert result is None
def _get_response_status(response):
"""Helper to get status code from response (handles both tuple and Response)."""
if isinstance(response, tuple):
return response[1]
return response.status_code
def _get_response_json(response):
"""Helper to get JSON from response (handles both tuple and Response)."""
if isinstance(response, tuple):
return response[0].json
return response.json
@pytest.mark.unit
class TestSyncSourceEndpoint:
"""Test the /sync_source endpoint."""
@pytest.fixture
def mock_sources_collection(self, mock_mongo_db):
"""Get mock sources collection."""
from application.core.settings import settings
return mock_mongo_db[settings.MONGO_DB_NAME]["sources"]
def test_sync_source_returns_401_without_token(self, flask_app):
"""Should return 401 when no decoded_token is present."""
from flask import Flask
from application.api.user.sources.routes import SyncSource
app = Flask(__name__)
with app.test_request_context(
"/api/sync_source", method="POST", json={"source_id": "123"}
):
from flask import request
request.decoded_token = None
resource = SyncSource()
response = resource.post()
assert _get_response_status(response) == 401
def test_sync_source_returns_400_for_missing_source_id(self, flask_app):
"""Should return 400 when source_id is missing."""
from flask import Flask
from application.api.user.sources.routes import SyncSource
app = Flask(__name__)
with app.test_request_context("/api/sync_source", method="POST", json={}):
from flask import request
request.decoded_token = {"sub": "test_user"}
resource = SyncSource()
response = resource.post()
# check_required_fields returns a response tuple on missing fields
assert response is not None
def test_sync_source_returns_400_for_invalid_source_id(self, flask_app):
"""Should return 400 for invalid ObjectId."""
from flask import Flask
from application.api.user.sources.routes import SyncSource
app = Flask(__name__)
with app.test_request_context(
"/api/sync_source", method="POST", json={"source_id": "invalid"}
):
from flask import request
request.decoded_token = {"sub": "test_user"}
resource = SyncSource()
response = resource.post()
assert _get_response_status(response) == 400
assert "Invalid source ID" in _get_response_json(response)["message"]
def test_sync_source_returns_404_for_nonexistent_source(
self, flask_app, mock_mongo_db
):
"""Should return 404 when source doesn't exist."""
from flask import Flask
from application.api.user.sources.routes import SyncSource
app = Flask(__name__)
source_id = str(ObjectId())
with app.test_request_context(
"/api/sync_source", method="POST", json={"source_id": source_id}
):
from flask import request
request.decoded_token = {"sub": "test_user"}
with patch(
"application.api.user.sources.routes.sources_collection",
mock_mongo_db["docsgpt"]["sources"],
):
resource = SyncSource()
response = resource.post()
assert _get_response_status(response) == 404
assert "not found" in _get_response_json(response)["message"]
def test_sync_source_returns_400_for_connector_type(
self, flask_app, mock_mongo_db, mock_sources_collection
):
"""Should return 400 for connector sources."""
from flask import Flask
from application.api.user.sources.routes import SyncSource
app = Flask(__name__)
source_id = ObjectId()
# Insert a connector source
mock_sources_collection.insert_one(
{
"_id": source_id,
"user": "test_user",
"type": "connector_slack",
"name": "Slack Source",
}
)
with app.test_request_context(
"/api/sync_source", method="POST", json={"source_id": str(source_id)}
):
from flask import request
request.decoded_token = {"sub": "test_user"}
with patch(
"application.api.user.sources.routes.sources_collection",
mock_sources_collection,
):
resource = SyncSource()
response = resource.post()
assert _get_response_status(response) == 400
assert "Connector sources" in _get_response_json(response)["message"]
def test_sync_source_returns_400_for_non_syncable_source(
self, flask_app, mock_mongo_db, mock_sources_collection
):
"""Should return 400 when source has no remote_data."""
from flask import Flask
from application.api.user.sources.routes import SyncSource
app = Flask(__name__)
source_id = ObjectId()
# Insert a source without remote_data
mock_sources_collection.insert_one(
{
"_id": source_id,
"user": "test_user",
"type": "file",
"name": "Local Source",
"remote_data": None,
}
)
with app.test_request_context(
"/api/sync_source", method="POST", json={"source_id": str(source_id)}
):
from flask import request
request.decoded_token = {"sub": "test_user"}
with patch(
"application.api.user.sources.routes.sources_collection",
mock_sources_collection,
):
resource = SyncSource()
response = resource.post()
assert _get_response_status(response) == 400
assert "not syncable" in _get_response_json(response)["message"]
def test_sync_source_triggers_sync_task(
self, flask_app, mock_mongo_db, mock_sources_collection
):
"""Should trigger sync task for valid syncable source."""
from flask import Flask
from application.api.user.sources.routes import SyncSource
app = Flask(__name__)
source_id = ObjectId()
# Insert a valid syncable source
mock_sources_collection.insert_one(
{
"_id": source_id,
"user": "test_user",
"type": "s3",
"name": "S3 Source",
"remote_data": json.dumps(
{
"provider": "s3",
"bucket": "my-bucket",
"aws_access_key_id": "key",
"aws_secret_access_key": "secret",
}
),
"sync_frequency": "daily",
"retriever": "classic",
}
)
mock_task = MagicMock()
mock_task.id = "task-123"
with app.test_request_context(
"/api/sync_source", method="POST", json={"source_id": str(source_id)}
):
from flask import request
request.decoded_token = {"sub": "test_user"}
with patch(
"application.api.user.sources.routes.sources_collection",
mock_sources_collection,
):
with patch(
"application.api.user.sources.routes.sync_source"
) as mock_sync:
mock_sync.delay.return_value = mock_task
resource = SyncSource()
response = resource.post()
assert _get_response_status(response) == 200
assert _get_response_json(response)["success"] is True
assert _get_response_json(response)["task_id"] == "task-123"
mock_sync.delay.assert_called_once()
call_kwargs = mock_sync.delay.call_args[1]
assert call_kwargs["user"] == "test_user"
assert call_kwargs["loader"] == "s3"
assert call_kwargs["doc_id"] == str(source_id)
def test_sync_source_handles_task_error(
self, flask_app, mock_mongo_db, mock_sources_collection
):
"""Should return 400 when task fails to start."""
from flask import Flask
from application.api.user.sources.routes import SyncSource
app = Flask(__name__)
source_id = ObjectId()
mock_sources_collection.insert_one(
{
"_id": source_id,
"user": "test_user",
"type": "github",
"name": "GitHub Source",
"remote_data": "https://github.com/test/repo",
"sync_frequency": "weekly",
"retriever": "classic",
}
)
with app.test_request_context(
"/api/sync_source", method="POST", json={"source_id": str(source_id)}
):
from flask import request
request.decoded_token = {"sub": "test_user"}
with patch(
"application.api.user.sources.routes.sources_collection",
mock_sources_collection,
):
with patch(
"application.api.user.sources.routes.sync_source"
) as mock_sync:
mock_sync.delay.side_effect = Exception("Celery error")
resource = SyncSource()
response = resource.post()
assert _get_response_status(response) == 400
assert _get_response_json(response)["success"] is False