Compare commits

...

2 Commits

Author SHA1 Message Date
Alex
2ae0ef6cc0 fix: mini fixes 2026-05-17 23:30:16 +01:00
Alex
2399aff245 fix: broken syncs 2026-05-17 23:15:36 +01:00
6 changed files with 328 additions and 6 deletions

View File

@@ -9,6 +9,7 @@ from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.tasks import sync_source
from application.core.settings import settings
from application.parser.remote.remote_creator import normalize_remote_data
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_readonly, db_session
from application.storage.storage_creator import StorageCreator
@@ -322,7 +323,7 @@ class SyncSource(Resource):
),
400,
)
source_data = doc.get("remote_data")
source_data = normalize_remote_data(source_type, doc.get("remote_data"))
if not source_data:
return make_response(
jsonify({"success": False, "message": "Source is not syncable"}), 400

View File

@@ -1,3 +1,5 @@
import json
from application.parser.remote.sitemap_loader import SitemapLoader
from application.parser.remote.crawler_loader import CrawlerLoader
from application.parser.remote.web_loader import WebLoader
@@ -32,3 +34,59 @@ class RemoteCreator:
if not loader_class:
raise ValueError(f"No loader class found for type {type}")
return loader_class(*args, **kwargs)
# Loader types whose load_data expects a URL string, not a config dict.
_URL_LOADER_TYPES = {"url", "crawler", "sitemap", "github"}
# Keys a remote_data dict may hold the URL under (``raw`` is the legacy shape).
_URL_DATA_KEYS = ("url", "urls", "repo_url", "raw")
def normalize_remote_data(source_type, remote_data):
"""Convert a stored ``sources.remote_data`` JSONB value into the
``source_data`` shape the matching loader expects.
Args:
source_type: The ``sources.type`` value (the loader name).
remote_data: The stored ``remote_data`` (dict, list, str, or None).
Returns:
Loader input: a URL string or list for url/crawler/sitemap/github,
a JSON string for reddit, a dict for s3; ``None`` when the row has
nothing syncable.
"""
if remote_data is None:
return None
# Some legacy rows stored the JSON itself as a string.
if isinstance(remote_data, str):
stripped = remote_data.strip()
if stripped[:1] in ("{", "["):
try:
remote_data = json.loads(stripped)
except json.JSONDecodeError:
# Not actually JSON — leave remote_data as the original
# string; the per-loader branches below handle a string.
pass
loader = (source_type or "").lower()
if loader in _URL_LOADER_TYPES:
if isinstance(remote_data, dict):
for key in _URL_DATA_KEYS:
value = remote_data.get(key)
if value:
return value
# No URL key — None keeps the loader off the dict-crash path.
return None
return remote_data
if loader == "reddit":
# reddit's loader runs json.loads() on its input — needs a string.
if isinstance(remote_data, (dict, list)):
return json.dumps(remote_data)
return remote_data
# s3's loader accepts a dict or JSON string; pass it through unchanged.
return remote_data

View File

@@ -29,7 +29,10 @@ from application.parser.embedding_pipeline import (
)
from application.parser.file.bulk import SimpleDirectoryReader, get_default_file_extractor
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
from application.parser.remote.remote_creator import RemoteCreator
from application.parser.remote.remote_creator import (
RemoteCreator,
normalize_remote_data,
)
from application.parser.schema.base import Document
from application.retriever.retriever_creator import RetrieverCreator
@@ -1431,19 +1434,35 @@ def sync_worker(self, frequency):
name = doc.get("name")
user = doc.get("user_id")
source_type = doc.get("type")
source_data = doc.get("remote_data")
retriever = doc.get("retriever")
doc_id = str(doc.get("id"))
sync_counts["total_sync_count"] += 1
# Connector sources have no RemoteCreator loader and need an OAuth
# token to sync, which a scheduled task lacks — skip them.
if source_type and source_type.startswith("connector"):
sync_counts["sync_skipped"] += 1
continue
source_data = normalize_remote_data(source_type, doc.get("remote_data"))
if not source_data:
# No syncable URL/config — skip instead of dispatching a sync
# that can only fail (and emit a spurious failed event).
sync_counts["sync_skipped"] += 1
continue
resp = sync(
self, source_data, name, user, source_type, frequency, retriever, doc_id
)
sync_counts["total_sync_count"] += 1
sync_counts[
"sync_success" if resp["status"] == "success" else "sync_failure"
] += 1
return {
key: sync_counts[key]
for key in ["total_sync_count", "sync_success", "sync_failure"]
for key in [
"total_sync_count", "sync_success", "sync_failure", "sync_skipped",
]
}

View File

@@ -553,6 +553,35 @@ class TestSyncSource:
assert response.status_code == 200
assert response.json["task_id"] == "task-123"
def test_normalizes_dict_remote_data_before_dispatch(self, app, pg_conn):
"""The route must hand the sync task the normalized URL string."""
from application.api.user.sources.routes import SyncSource
user = "u-normalize"
src = _seed_source(
pg_conn, user, name="crawl-src", type="crawler",
remote_data=json.dumps(
{"url": "https://example.com", "provider": "crawler"}
),
)
fake_task = MagicMock(id="task-norm")
with _patch_db(pg_conn), patch(
"application.api.user.sources.routes.sync_source.delay",
return_value=fake_task,
) as mock_delay, app.test_request_context(
"/api/sync_source",
method="POST",
json={"source_id": str(src["id"])},
):
from flask import request
request.decoded_token = {"sub": user}
response = SyncSource().post()
assert response.status_code == 200
assert mock_delay.call_args.kwargs["source_data"] == "https://example.com"
assert mock_delay.call_args.kwargs["loader"] == "crawler"
def test_sync_task_raises_returns_400(self, app, pg_conn):
from application.api.user.sources.routes import SyncSource

View File

@@ -1,4 +1,6 @@
"""Tests for application.parser.remote.remote_creator covering lines 31-34."""
"""Tests for application.parser.remote.remote_creator."""
import json
import pytest
from unittest.mock import MagicMock
@@ -38,3 +40,92 @@ class TestRemoteCreator:
mock_loader_cls.assert_called_once()
finally:
RemoteCreator.loaders = original_loaders
@pytest.mark.unit
class TestNormalizeRemoteData:
"""``normalize_remote_data`` maps a stored JSONB ``remote_data`` value
back to the ``source_data`` shape each loader expects."""
def test_none_passes_through(self):
from application.parser.remote.remote_creator import normalize_remote_data
assert normalize_remote_data("crawler", None) is None
def test_crawler_dict_with_url_key(self):
from application.parser.remote.remote_creator import normalize_remote_data
result = normalize_remote_data(
"crawler", {"url": "https://example.com", "provider": "crawler"}
)
assert result == "https://example.com"
def test_url_dict_with_url_key(self):
from application.parser.remote.remote_creator import normalize_remote_data
result = normalize_remote_data("url", {"url": "https://example.com"})
assert result == "https://example.com"
def test_url_legacy_raw_key(self):
"""Legacy rows wrapped a bare URL string as ``{"raw": ...}``."""
from application.parser.remote.remote_creator import normalize_remote_data
result = normalize_remote_data("crawler", {"raw": "https://legacy.example.com"})
assert result == "https://legacy.example.com"
def test_url_dict_with_urls_list(self):
from application.parser.remote.remote_creator import normalize_remote_data
result = normalize_remote_data(
"url", {"urls": ["https://a.example.com", "https://b.example.com"]}
)
assert result == ["https://a.example.com", "https://b.example.com"]
def test_github_repo_url_key(self):
from application.parser.remote.remote_creator import normalize_remote_data
result = normalize_remote_data(
"github", {"repo_url": "https://github.com/arc53/DocsGPT"}
)
assert result == "https://github.com/arc53/DocsGPT"
def test_sitemap_dict_with_url_key(self):
from application.parser.remote.remote_creator import normalize_remote_data
result = normalize_remote_data("sitemap", {"url": "https://example.com/sitemap.xml"})
assert result == "https://example.com/sitemap.xml"
def test_plain_string_url_passes_through(self):
from application.parser.remote.remote_creator import normalize_remote_data
assert normalize_remote_data("crawler", "https://example.com") == "https://example.com"
def test_url_dict_without_url_key_returns_none(self):
"""A URL-type loader must never receive a dict, even a malformed one."""
from application.parser.remote.remote_creator import normalize_remote_data
assert normalize_remote_data("crawler", {"provider": "crawler"}) is None
def test_reddit_dict_serialized_to_json_string(self):
"""reddit's loader runs json.loads() — it needs a JSON string."""
from application.parser.remote.remote_creator import normalize_remote_data
result = normalize_remote_data(
"reddit", {"client_id": "x", "search_queries": ["y"]}
)
assert isinstance(result, str)
assert json.loads(result) == {"client_id": "x", "search_queries": ["y"]}
def test_s3_dict_passes_through(self):
"""S3Loader.load_data() accepts a dict, so it is left untouched."""
from application.parser.remote.remote_creator import normalize_remote_data
data = {"bucket": "b", "prefix": "k"}
assert normalize_remote_data("s3", data) == data
def test_json_string_remote_data_is_parsed(self):
"""Legacy rows that stored the JSON itself as a string still resolve."""
from application.parser.remote.remote_creator import normalize_remote_data
result = normalize_remote_data("crawler", '{"url": "https://example.com"}')
assert result == "https://example.com"

View File

@@ -148,6 +148,130 @@ class TestSyncWorker:
assert captured[0]["loader"] == "url"
assert captured[0]["doc_id"] == str(src["id"])
def test_connector_sources_are_skipped(
self,
pg_conn,
patch_worker_db,
task_self,
monkeypatch,
):
"""connector:* sources have no RemoteCreator loader — sync_worker
must skip them, not dispatch them into sync()."""
from application import worker
SourcesRepository(pg_conn).create(
"drive-folder",
user_id="dave",
type="connector:file",
retriever="classic",
sync_frequency="daily",
remote_data={
"provider": "google_drive",
"file_ids": ["f1"],
"folder_ids": [],
"recursive": False,
},
)
def _must_not_run(*args, **kwargs):
raise AssertionError("sync() must not run for connector sources")
monkeypatch.setattr(worker, "sync", _must_not_run)
result = worker.sync_worker(task_self, "daily")
assert result["total_sync_count"] == 1
assert result["sync_skipped"] == 1
assert result["sync_success"] == 0
assert result["sync_failure"] == 0
def test_dict_remote_data_is_normalized_before_loader(
self,
pg_conn,
patch_worker_db,
task_self,
monkeypatch,
):
"""Regression: remote_data reads back as a dict; sync_worker must
hand the loader the URL string, not the raw dict."""
from application import worker
SourcesRepository(pg_conn).create(
"docs-crawl",
user_id="erin",
type="crawler",
retriever="classic",
sync_frequency="weekly",
remote_data={"url": "https://example.com", "provider": "crawler"},
)
received: list = []
fake_loader = MagicMock(name="remote_loader")
def _capture(source_data):
received.append(source_data)
return [
Document(
text="page body",
extra_info={"file_path": "index.md", "title": "home"},
doc_id="d1",
)
]
fake_loader.load_data.side_effect = _capture
monkeypatch.setattr(
worker.RemoteCreator, "create_loader", lambda loader: fake_loader
)
monkeypatch.setattr(
worker,
"embed_and_store_documents",
lambda docs, full_path, source_id, task, **kw: None,
)
monkeypatch.setattr(
worker, "upload_index", lambda full_path, file_data: None
)
result = worker.sync_worker(task_self, "weekly")
assert result["total_sync_count"] == 1
assert result["sync_success"] == 1
assert result["sync_failure"] == 0
assert received == ["https://example.com"], (
"loader must receive the URL string, not the remote_data dict"
)
def test_unsyncable_remote_data_is_skipped(
self,
pg_conn,
patch_worker_db,
task_self,
monkeypatch,
):
"""A URL source whose remote_data dict has no URL key normalizes
to None — sync_worker must skip it, not dispatch a doomed sync()."""
from application import worker
SourcesRepository(pg_conn).create(
"broken-feed",
user_id="frank",
type="url",
retriever="classic",
sync_frequency="monthly",
remote_data={"provider": "url"},
)
def _must_not_run(*args, **kwargs):
raise AssertionError("sync() must not run for unsyncable sources")
monkeypatch.setattr(worker, "sync", _must_not_run)
result = worker.sync_worker(task_self, "monthly")
assert result["total_sync_count"] == 1
assert result["sync_skipped"] == 1
assert result["sync_failure"] == 0
assert result["sync_success"] == 0
@pytest.mark.unit
class TestRemoteWorkerPathTraversal: