mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-01-25 16:30:36 +00:00
* Add Amazon S3 support and synchronization features * refactor: remove unused variable in load_data test
358 lines
12 KiB
Python
358 lines
12 KiB
Python
"""Tests for sources routes."""
|
|
|
|
import json
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
from bson import ObjectId
|
|
|
|
|
|
class TestGetProviderFromRemoteData:
|
|
"""Test the _get_provider_from_remote_data helper function."""
|
|
|
|
def test_returns_none_for_none_input(self):
|
|
"""Should return None when remote_data is None."""
|
|
from application.api.user.sources.routes import _get_provider_from_remote_data
|
|
|
|
result = _get_provider_from_remote_data(None)
|
|
assert result is None
|
|
|
|
def test_returns_none_for_empty_string(self):
|
|
"""Should return None when remote_data is empty string."""
|
|
from application.api.user.sources.routes import _get_provider_from_remote_data
|
|
|
|
result = _get_provider_from_remote_data("")
|
|
assert result is None
|
|
|
|
def test_extracts_provider_from_dict(self):
|
|
"""Should extract provider from dict remote_data."""
|
|
from application.api.user.sources.routes import _get_provider_from_remote_data
|
|
|
|
remote_data = {"provider": "s3", "bucket": "my-bucket"}
|
|
result = _get_provider_from_remote_data(remote_data)
|
|
assert result == "s3"
|
|
|
|
def test_extracts_provider_from_json_string(self):
|
|
"""Should extract provider from JSON string remote_data."""
|
|
from application.api.user.sources.routes import _get_provider_from_remote_data
|
|
|
|
remote_data = json.dumps({"provider": "github", "repo": "test/repo"})
|
|
result = _get_provider_from_remote_data(remote_data)
|
|
assert result == "github"
|
|
|
|
def test_returns_none_for_dict_without_provider(self):
|
|
"""Should return None when dict has no provider key."""
|
|
from application.api.user.sources.routes import _get_provider_from_remote_data
|
|
|
|
remote_data = {"bucket": "my-bucket", "region": "us-east-1"}
|
|
result = _get_provider_from_remote_data(remote_data)
|
|
assert result is None
|
|
|
|
def test_returns_none_for_invalid_json(self):
|
|
"""Should return None for invalid JSON string."""
|
|
from application.api.user.sources.routes import _get_provider_from_remote_data
|
|
|
|
result = _get_provider_from_remote_data("not valid json")
|
|
assert result is None
|
|
|
|
def test_returns_none_for_json_array(self):
|
|
"""Should return None when JSON parses to non-dict."""
|
|
from application.api.user.sources.routes import _get_provider_from_remote_data
|
|
|
|
result = _get_provider_from_remote_data('["item1", "item2"]')
|
|
assert result is None
|
|
|
|
def test_returns_none_for_non_string_non_dict(self):
|
|
"""Should return None for other types like int."""
|
|
from application.api.user.sources.routes import _get_provider_from_remote_data
|
|
|
|
result = _get_provider_from_remote_data(123)
|
|
assert result is None
|
|
|
|
|
|
def _get_response_status(response):
|
|
"""Helper to get status code from response (handles both tuple and Response)."""
|
|
if isinstance(response, tuple):
|
|
return response[1]
|
|
return response.status_code
|
|
|
|
|
|
def _get_response_json(response):
|
|
"""Helper to get JSON from response (handles both tuple and Response)."""
|
|
if isinstance(response, tuple):
|
|
return response[0].json
|
|
return response.json
|
|
|
|
|
|
@pytest.mark.unit
|
|
class TestSyncSourceEndpoint:
|
|
"""Test the /sync_source endpoint."""
|
|
|
|
@pytest.fixture
|
|
def mock_sources_collection(self, mock_mongo_db):
|
|
"""Get mock sources collection."""
|
|
from application.core.settings import settings
|
|
|
|
return mock_mongo_db[settings.MONGO_DB_NAME]["sources"]
|
|
|
|
def test_sync_source_returns_401_without_token(self, flask_app):
|
|
"""Should return 401 when no decoded_token is present."""
|
|
from flask import Flask
|
|
from application.api.user.sources.routes import SyncSource
|
|
|
|
app = Flask(__name__)
|
|
|
|
with app.test_request_context(
|
|
"/api/sync_source", method="POST", json={"source_id": "123"}
|
|
):
|
|
from flask import request
|
|
|
|
request.decoded_token = None
|
|
resource = SyncSource()
|
|
response = resource.post()
|
|
|
|
assert _get_response_status(response) == 401
|
|
|
|
def test_sync_source_returns_400_for_missing_source_id(self, flask_app):
|
|
"""Should return 400 when source_id is missing."""
|
|
from flask import Flask
|
|
from application.api.user.sources.routes import SyncSource
|
|
|
|
app = Flask(__name__)
|
|
|
|
with app.test_request_context("/api/sync_source", method="POST", json={}):
|
|
from flask import request
|
|
|
|
request.decoded_token = {"sub": "test_user"}
|
|
resource = SyncSource()
|
|
response = resource.post()
|
|
|
|
# check_required_fields returns a response tuple on missing fields
|
|
assert response is not None
|
|
|
|
def test_sync_source_returns_400_for_invalid_source_id(self, flask_app):
|
|
"""Should return 400 for invalid ObjectId."""
|
|
from flask import Flask
|
|
from application.api.user.sources.routes import SyncSource
|
|
|
|
app = Flask(__name__)
|
|
|
|
with app.test_request_context(
|
|
"/api/sync_source", method="POST", json={"source_id": "invalid"}
|
|
):
|
|
from flask import request
|
|
|
|
request.decoded_token = {"sub": "test_user"}
|
|
resource = SyncSource()
|
|
response = resource.post()
|
|
|
|
assert _get_response_status(response) == 400
|
|
assert "Invalid source ID" in _get_response_json(response)["message"]
|
|
|
|
def test_sync_source_returns_404_for_nonexistent_source(
|
|
self, flask_app, mock_mongo_db
|
|
):
|
|
"""Should return 404 when source doesn't exist."""
|
|
from flask import Flask
|
|
from application.api.user.sources.routes import SyncSource
|
|
|
|
app = Flask(__name__)
|
|
source_id = str(ObjectId())
|
|
|
|
with app.test_request_context(
|
|
"/api/sync_source", method="POST", json={"source_id": source_id}
|
|
):
|
|
from flask import request
|
|
|
|
request.decoded_token = {"sub": "test_user"}
|
|
|
|
with patch(
|
|
"application.api.user.sources.routes.sources_collection",
|
|
mock_mongo_db["docsgpt"]["sources"],
|
|
):
|
|
resource = SyncSource()
|
|
response = resource.post()
|
|
|
|
assert _get_response_status(response) == 404
|
|
assert "not found" in _get_response_json(response)["message"]
|
|
|
|
def test_sync_source_returns_400_for_connector_type(
|
|
self, flask_app, mock_mongo_db, mock_sources_collection
|
|
):
|
|
"""Should return 400 for connector sources."""
|
|
from flask import Flask
|
|
from application.api.user.sources.routes import SyncSource
|
|
|
|
app = Flask(__name__)
|
|
source_id = ObjectId()
|
|
|
|
# Insert a connector source
|
|
mock_sources_collection.insert_one(
|
|
{
|
|
"_id": source_id,
|
|
"user": "test_user",
|
|
"type": "connector_slack",
|
|
"name": "Slack Source",
|
|
}
|
|
)
|
|
|
|
with app.test_request_context(
|
|
"/api/sync_source", method="POST", json={"source_id": str(source_id)}
|
|
):
|
|
from flask import request
|
|
|
|
request.decoded_token = {"sub": "test_user"}
|
|
|
|
with patch(
|
|
"application.api.user.sources.routes.sources_collection",
|
|
mock_sources_collection,
|
|
):
|
|
resource = SyncSource()
|
|
response = resource.post()
|
|
|
|
assert _get_response_status(response) == 400
|
|
assert "Connector sources" in _get_response_json(response)["message"]
|
|
|
|
def test_sync_source_returns_400_for_non_syncable_source(
|
|
self, flask_app, mock_mongo_db, mock_sources_collection
|
|
):
|
|
"""Should return 400 when source has no remote_data."""
|
|
from flask import Flask
|
|
from application.api.user.sources.routes import SyncSource
|
|
|
|
app = Flask(__name__)
|
|
source_id = ObjectId()
|
|
|
|
# Insert a source without remote_data
|
|
mock_sources_collection.insert_one(
|
|
{
|
|
"_id": source_id,
|
|
"user": "test_user",
|
|
"type": "file",
|
|
"name": "Local Source",
|
|
"remote_data": None,
|
|
}
|
|
)
|
|
|
|
with app.test_request_context(
|
|
"/api/sync_source", method="POST", json={"source_id": str(source_id)}
|
|
):
|
|
from flask import request
|
|
|
|
request.decoded_token = {"sub": "test_user"}
|
|
|
|
with patch(
|
|
"application.api.user.sources.routes.sources_collection",
|
|
mock_sources_collection,
|
|
):
|
|
resource = SyncSource()
|
|
response = resource.post()
|
|
|
|
assert _get_response_status(response) == 400
|
|
assert "not syncable" in _get_response_json(response)["message"]
|
|
|
|
def test_sync_source_triggers_sync_task(
|
|
self, flask_app, mock_mongo_db, mock_sources_collection
|
|
):
|
|
"""Should trigger sync task for valid syncable source."""
|
|
from flask import Flask
|
|
from application.api.user.sources.routes import SyncSource
|
|
|
|
app = Flask(__name__)
|
|
source_id = ObjectId()
|
|
|
|
# Insert a valid syncable source
|
|
mock_sources_collection.insert_one(
|
|
{
|
|
"_id": source_id,
|
|
"user": "test_user",
|
|
"type": "s3",
|
|
"name": "S3 Source",
|
|
"remote_data": json.dumps(
|
|
{
|
|
"provider": "s3",
|
|
"bucket": "my-bucket",
|
|
"aws_access_key_id": "key",
|
|
"aws_secret_access_key": "secret",
|
|
}
|
|
),
|
|
"sync_frequency": "daily",
|
|
"retriever": "classic",
|
|
}
|
|
)
|
|
|
|
mock_task = MagicMock()
|
|
mock_task.id = "task-123"
|
|
|
|
with app.test_request_context(
|
|
"/api/sync_source", method="POST", json={"source_id": str(source_id)}
|
|
):
|
|
from flask import request
|
|
|
|
request.decoded_token = {"sub": "test_user"}
|
|
|
|
with patch(
|
|
"application.api.user.sources.routes.sources_collection",
|
|
mock_sources_collection,
|
|
):
|
|
with patch(
|
|
"application.api.user.sources.routes.sync_source"
|
|
) as mock_sync:
|
|
mock_sync.delay.return_value = mock_task
|
|
|
|
resource = SyncSource()
|
|
response = resource.post()
|
|
|
|
assert _get_response_status(response) == 200
|
|
assert _get_response_json(response)["success"] is True
|
|
assert _get_response_json(response)["task_id"] == "task-123"
|
|
|
|
mock_sync.delay.assert_called_once()
|
|
call_kwargs = mock_sync.delay.call_args[1]
|
|
assert call_kwargs["user"] == "test_user"
|
|
assert call_kwargs["loader"] == "s3"
|
|
assert call_kwargs["doc_id"] == str(source_id)
|
|
|
|
def test_sync_source_handles_task_error(
|
|
self, flask_app, mock_mongo_db, mock_sources_collection
|
|
):
|
|
"""Should return 400 when task fails to start."""
|
|
from flask import Flask
|
|
from application.api.user.sources.routes import SyncSource
|
|
|
|
app = Flask(__name__)
|
|
source_id = ObjectId()
|
|
|
|
mock_sources_collection.insert_one(
|
|
{
|
|
"_id": source_id,
|
|
"user": "test_user",
|
|
"type": "github",
|
|
"name": "GitHub Source",
|
|
"remote_data": "https://github.com/test/repo",
|
|
"sync_frequency": "weekly",
|
|
"retriever": "classic",
|
|
}
|
|
)
|
|
|
|
with app.test_request_context(
|
|
"/api/sync_source", method="POST", json={"source_id": str(source_id)}
|
|
):
|
|
from flask import request
|
|
|
|
request.decoded_token = {"sub": "test_user"}
|
|
|
|
with patch(
|
|
"application.api.user.sources.routes.sources_collection",
|
|
mock_sources_collection,
|
|
):
|
|
with patch(
|
|
"application.api.user.sources.routes.sync_source"
|
|
) as mock_sync:
|
|
mock_sync.delay.side_effect = Exception("Celery error")
|
|
|
|
resource = SyncSource()
|
|
response = resource.post()
|
|
|
|
assert _get_response_status(response) == 400
|
|
assert _get_response_json(response)["success"] is False
|