mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-22 05:15:08 +00:00
Compare commits
4 Commits
fix-glibc
...
fix-bad-sy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ae0ef6cc0 | ||
|
|
2399aff245 | ||
|
|
c06646519e | ||
|
|
97a362b703 |
@@ -8,7 +8,7 @@ RUN apt-get update && \
|
||||
add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Verify Python installation and setup symlink
|
||||
RUN if [ -f /usr/bin/python3.12 ]; then \
|
||||
@@ -73,7 +73,7 @@ COPY --from=builder /models /app/models
|
||||
COPY . /app/application
|
||||
|
||||
# Change the ownership of the /app directory to the appuser
|
||||
|
||||
|
||||
RUN mkdir -p /app/application/inputs/local
|
||||
RUN chown -R appuser:appuser /app
|
||||
|
||||
@@ -82,6 +82,11 @@ ENV FLASK_APP=app.py \
|
||||
FLASK_DEBUG=true \
|
||||
PATH="/venv/bin:$PATH"
|
||||
|
||||
ENV MALLOC_ARENA_MAX=2 \
|
||||
OMP_NUM_THREADS=4 \
|
||||
MKL_NUM_THREADS=4 \
|
||||
OPENBLAS_NUM_THREADS=4
|
||||
|
||||
# Expose the port the app runs on
|
||||
EXPOSE 7091
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from flask_restx import fields, Namespace, Resource
|
||||
from application.api import api
|
||||
from application.api.user.tasks import sync_source
|
||||
from application.core.settings import settings
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
@@ -322,7 +323,7 @@ class SyncSource(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
source_data = doc.get("remote_data")
|
||||
source_data = normalize_remote_data(source_type, doc.get("remote_data"))
|
||||
if not source_data:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Source is not syncable"}), 400
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import ctypes
|
||||
import gc
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
|
||||
from celery import Celery
|
||||
@@ -98,6 +101,34 @@ def _unbind_task_log_context(task_id, **_):
|
||||
)
|
||||
|
||||
|
||||
def _trim_native_heap() -> None:
|
||||
"""Return freed glibc heap pages to the OS (Linux only; no-op elsewhere)."""
|
||||
# docling/torch parsing makes large transient allocations; glibc keeps the
|
||||
# freed pages in per-thread malloc arenas rather than returning them, so a
|
||||
# long-lived worker child's RSS only ever climbs. malloc_trim hands them
|
||||
# back. The symbol is glibc-only — absent in macOS libc.
|
||||
if not sys.platform.startswith("linux"):
|
||||
return
|
||||
try:
|
||||
ctypes.CDLL("libc.so.6").malloc_trim(0)
|
||||
except (OSError, AttributeError):
|
||||
pass
|
||||
|
||||
|
||||
@task_postrun.connect
|
||||
def _reclaim_memory_after_task(*args, **kwargs):
|
||||
"""Drop per-task allocations so the prefork child's RSS doesn't ratchet."""
|
||||
gc.collect()
|
||||
torch = sys.modules.get("torch")
|
||||
if torch is not None:
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
_trim_native_heap()
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def _run_version_check(*args, **kwargs):
|
||||
"""Kick off the anonymous version check on worker startup.
|
||||
|
||||
@@ -31,3 +31,10 @@ worker_prefetch_multiplier = settings.CELERY_WORKER_PREFETCH_MULTIPLIER
|
||||
broker_transport_options = {"visibility_timeout": settings.CELERY_VISIBILITY_TIMEOUT}
|
||||
result_expires = 86400 * 7
|
||||
task_track_started = True
|
||||
|
||||
# Recycle the prefork worker child to bound native-heap growth from
|
||||
# docling/torch parsing. Left unset (Celery's unlimited default) when 0.
|
||||
if settings.CELERY_WORKER_MAX_MEMORY_PER_CHILD > 0:
|
||||
worker_max_memory_per_child = settings.CELERY_WORKER_MAX_MEMORY_PER_CHILD
|
||||
if settings.CELERY_WORKER_MAX_TASKS_PER_CHILD > 0:
|
||||
worker_max_tasks_per_child = settings.CELERY_WORKER_MAX_TASKS_PER_CHILD
|
||||
|
||||
@@ -36,6 +36,11 @@ class Settings(BaseSettings):
|
||||
# and Dify defaults; long ingests can override via env.
|
||||
CELERY_WORKER_PREFETCH_MULTIPLIER: int = 1
|
||||
CELERY_VISIBILITY_TIMEOUT: int = 3600
|
||||
# Recycle the prefork worker child once its resident size crosses this many
|
||||
# kilobytes — backstops native-heap growth from docling/torch parsing. 0 disables.
|
||||
CELERY_WORKER_MAX_MEMORY_PER_CHILD: int = 4194304
|
||||
# Recycle the child after this many tasks; 0 disables (memory cap is the primary knob).
|
||||
CELERY_WORKER_MAX_TASKS_PER_CHILD: int = 0
|
||||
# Only consulted when VECTOR_STORE=mongodb or when running scripts/db/backfill.py; user data lives in Postgres.
|
||||
MONGO_URI: Optional[str] = None
|
||||
# User-data Postgres DB.
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import json
|
||||
|
||||
from application.parser.remote.sitemap_loader import SitemapLoader
|
||||
from application.parser.remote.crawler_loader import CrawlerLoader
|
||||
from application.parser.remote.web_loader import WebLoader
|
||||
@@ -32,3 +34,59 @@ class RemoteCreator:
|
||||
if not loader_class:
|
||||
raise ValueError(f"No loader class found for type {type}")
|
||||
return loader_class(*args, **kwargs)
|
||||
|
||||
|
||||
# Loader types whose load_data expects a URL string, not a config dict.
|
||||
_URL_LOADER_TYPES = {"url", "crawler", "sitemap", "github"}
|
||||
|
||||
# Keys a remote_data dict may hold the URL under (``raw`` is the legacy shape).
|
||||
_URL_DATA_KEYS = ("url", "urls", "repo_url", "raw")
|
||||
|
||||
|
||||
def normalize_remote_data(source_type, remote_data):
|
||||
"""Convert a stored ``sources.remote_data`` JSONB value into the
|
||||
``source_data`` shape the matching loader expects.
|
||||
|
||||
Args:
|
||||
source_type: The ``sources.type`` value (the loader name).
|
||||
remote_data: The stored ``remote_data`` (dict, list, str, or None).
|
||||
|
||||
Returns:
|
||||
Loader input: a URL string or list for url/crawler/sitemap/github,
|
||||
a JSON string for reddit, a dict for s3; ``None`` when the row has
|
||||
nothing syncable.
|
||||
"""
|
||||
if remote_data is None:
|
||||
return None
|
||||
|
||||
# Some legacy rows stored the JSON itself as a string.
|
||||
if isinstance(remote_data, str):
|
||||
stripped = remote_data.strip()
|
||||
if stripped[:1] in ("{", "["):
|
||||
try:
|
||||
remote_data = json.loads(stripped)
|
||||
except json.JSONDecodeError:
|
||||
# Not actually JSON — leave remote_data as the original
|
||||
# string; the per-loader branches below handle a string.
|
||||
pass
|
||||
|
||||
loader = (source_type or "").lower()
|
||||
|
||||
if loader in _URL_LOADER_TYPES:
|
||||
if isinstance(remote_data, dict):
|
||||
for key in _URL_DATA_KEYS:
|
||||
value = remote_data.get(key)
|
||||
if value:
|
||||
return value
|
||||
# No URL key — None keeps the loader off the dict-crash path.
|
||||
return None
|
||||
return remote_data
|
||||
|
||||
if loader == "reddit":
|
||||
# reddit's loader runs json.loads() on its input — needs a string.
|
||||
if isinstance(remote_data, (dict, list)):
|
||||
return json.dumps(remote_data)
|
||||
return remote_data
|
||||
|
||||
# s3's loader accepts a dict or JSON string; pass it through unchanged.
|
||||
return remote_data
|
||||
|
||||
@@ -63,7 +63,8 @@ class ToolCallAttemptsRepository:
|
||||
message_id: Optional[str] = None,
|
||||
artifact_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Insert OR upgrade a row to ``executed``.
|
||||
"""Insert OR upgrade a row to ``executed`` — or ``confirmed`` when
|
||||
there is no ``message_id``, as in ``mark_executed``.
|
||||
|
||||
Used as a fallback when ``record_proposed`` failed (DB outage)
|
||||
and the tool ran anyway — preserves the journal so the
|
||||
@@ -72,6 +73,7 @@ class ToolCallAttemptsRepository:
|
||||
result_payload: dict = {"result": result}
|
||||
if artifact_id:
|
||||
result_payload["artifact_id"] = artifact_id
|
||||
status = "executed" if message_id is not None else "confirmed"
|
||||
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
@@ -82,9 +84,9 @@ class ToolCallAttemptsRepository:
|
||||
(:call_id, CAST(:tool_id AS uuid), :tool_name,
|
||||
:action_name, CAST(:arguments AS jsonb),
|
||||
CAST(:result AS jsonb), CAST(:message_id AS uuid),
|
||||
'executed')
|
||||
:status)
|
||||
ON CONFLICT (call_id) DO UPDATE
|
||||
SET status = 'executed',
|
||||
SET status = :status,
|
||||
result = EXCLUDED.result,
|
||||
message_id = COALESCE(EXCLUDED.message_id, tool_call_attempts.message_id)
|
||||
"""
|
||||
@@ -97,6 +99,7 @@ class ToolCallAttemptsRepository:
|
||||
"arguments": json.dumps(arguments if arguments is not None else {}, cls=PGNativeJSONEncoder),
|
||||
"result": json.dumps(result_payload, cls=PGNativeJSONEncoder),
|
||||
"message_id": message_id,
|
||||
"status": status,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -108,7 +111,9 @@ class ToolCallAttemptsRepository:
|
||||
message_id: Optional[str] = None,
|
||||
artifact_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Flip ``proposed`` → ``executed`` with the tool result.
|
||||
"""Flip ``proposed`` → ``executed``, or straight to ``confirmed``
|
||||
when there is no ``message_id`` (a ``save_conversation=False``
|
||||
request reserves no message, so no finalize will confirm it).
|
||||
|
||||
``artifact_id`` (when present) is stored alongside ``result`` in
|
||||
the JSONB as audit data — the reconciler reads it for diagnostic
|
||||
@@ -117,12 +122,14 @@ class ToolCallAttemptsRepository:
|
||||
result_payload: dict = {"result": result}
|
||||
if artifact_id:
|
||||
result_payload["artifact_id"] = artifact_id
|
||||
status = "executed" if message_id is not None else "confirmed"
|
||||
sql = (
|
||||
"UPDATE tool_call_attempts SET "
|
||||
"status = 'executed', result = CAST(:result AS jsonb)"
|
||||
"status = :status, result = CAST(:result AS jsonb)"
|
||||
)
|
||||
params: dict[str, Any] = {
|
||||
"call_id": call_id,
|
||||
"status": status,
|
||||
"result": json.dumps(result_payload, cls=PGNativeJSONEncoder),
|
||||
}
|
||||
if message_id is not None:
|
||||
|
||||
@@ -29,7 +29,10 @@ from application.parser.embedding_pipeline import (
|
||||
)
|
||||
from application.parser.file.bulk import SimpleDirectoryReader, get_default_file_extractor
|
||||
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
|
||||
from application.parser.remote.remote_creator import RemoteCreator
|
||||
from application.parser.remote.remote_creator import (
|
||||
RemoteCreator,
|
||||
normalize_remote_data,
|
||||
)
|
||||
from application.parser.schema.base import Document
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
|
||||
@@ -1431,19 +1434,35 @@ def sync_worker(self, frequency):
|
||||
name = doc.get("name")
|
||||
user = doc.get("user_id")
|
||||
source_type = doc.get("type")
|
||||
source_data = doc.get("remote_data")
|
||||
retriever = doc.get("retriever")
|
||||
doc_id = str(doc.get("id"))
|
||||
|
||||
sync_counts["total_sync_count"] += 1
|
||||
|
||||
# Connector sources have no RemoteCreator loader and need an OAuth
|
||||
# token to sync, which a scheduled task lacks — skip them.
|
||||
if source_type and source_type.startswith("connector"):
|
||||
sync_counts["sync_skipped"] += 1
|
||||
continue
|
||||
|
||||
source_data = normalize_remote_data(source_type, doc.get("remote_data"))
|
||||
if not source_data:
|
||||
# No syncable URL/config — skip instead of dispatching a sync
|
||||
# that can only fail (and emit a spurious failed event).
|
||||
sync_counts["sync_skipped"] += 1
|
||||
continue
|
||||
|
||||
resp = sync(
|
||||
self, source_data, name, user, source_type, frequency, retriever, doc_id
|
||||
)
|
||||
sync_counts["total_sync_count"] += 1
|
||||
sync_counts[
|
||||
"sync_success" if resp["status"] == "success" else "sync_failure"
|
||||
] += 1
|
||||
return {
|
||||
key: sync_counts[key]
|
||||
for key in ["total_sync_count", "sync_success", "sync_failure"]
|
||||
for key in [
|
||||
"total_sync_count", "sync_success", "sync_failure", "sync_skipped",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
"""Tests for the journaled execute path on ToolExecutor.
|
||||
|
||||
Each tool call inserts a row into ``tool_call_attempts`` then flips
|
||||
through ``proposed → executed`` (or ``proposed → failed``). The flip
|
||||
to ``confirmed`` is owned by the message-finalize path and is only
|
||||
asserted indirectly here (rows stay in ``executed`` so the reconciler
|
||||
can pick them up).
|
||||
Each tool call inserts a ``tool_call_attempts`` row and flips it
|
||||
``proposed → executed`` (or ``→ failed``). With a ``message_id`` it
|
||||
stays ``executed`` for the finalize path to confirm; without one
|
||||
(``save_conversation=False``) it goes straight to ``confirmed``.
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
@@ -75,11 +74,24 @@ def _make_call(name="test_action_t1", call_id="c1"):
|
||||
return call
|
||||
|
||||
|
||||
_TOOLS_DICT = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestExecuteJournaling:
|
||||
def test_happy_path_proposed_then_executed(
|
||||
def test_no_message_id_proposed_then_confirmed(
|
||||
self, pg_conn, mock_tool_manager, monkeypatch
|
||||
):
|
||||
"""No reserved message (``save_conversation=False``) → row lands ``confirmed``, not ``executed``."""
|
||||
executor = ToolExecutor(user="u")
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
@@ -89,23 +101,12 @@ class TestExecuteJournaling:
|
||||
)
|
||||
_patch_db(monkeypatch, pg_conn)
|
||||
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
events, result = _drain(executor.execute(tools_dict, _make_call(), "MockLLM"))
|
||||
events, result = _drain(executor.execute(_TOOLS_DICT, _make_call(), "MockLLM"))
|
||||
assert result[0] == "Tool result"
|
||||
|
||||
row = _select_attempt(pg_conn, "c1")
|
||||
assert row is not None
|
||||
assert row["status"] == "executed"
|
||||
assert row["status"] == "confirmed"
|
||||
assert row["tool_name"] == "test_tool"
|
||||
assert row["action_name"] == "test_action"
|
||||
assert row["arguments"] == {"q": "v"}
|
||||
@@ -117,10 +118,7 @@ class TestExecuteJournaling:
|
||||
def test_executor_message_id_is_persisted_on_executed_row(
|
||||
self, pg_conn, mock_tool_manager, monkeypatch
|
||||
):
|
||||
"""When the route stamps a placeholder message_id on the executor,
|
||||
the journal row carries it forward so ``confirm_executed_tool_calls``
|
||||
can later flip it to ``confirmed``.
|
||||
"""
|
||||
"""The executor's message_id is carried onto the journal row, which stays ``executed``."""
|
||||
from application.storage.db.repositories.conversations import (
|
||||
ConversationsRepository,
|
||||
)
|
||||
@@ -147,18 +145,7 @@ class TestExecuteJournaling:
|
||||
)
|
||||
_patch_db(monkeypatch, pg_conn)
|
||||
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
_drain(executor.execute(tools_dict, _make_call(call_id="cm1"), "MockLLM"))
|
||||
_drain(executor.execute(_TOOLS_DICT, _make_call(call_id="cm1"), "MockLLM"))
|
||||
|
||||
row = _select_attempt(pg_conn, "cm1")
|
||||
assert row is not None
|
||||
@@ -180,18 +167,7 @@ class TestExecuteJournaling:
|
||||
RuntimeError("boom")
|
||||
)
|
||||
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
gen = executor.execute(tools_dict, _make_call(call_id="c2"), "MockLLM")
|
||||
gen = executor.execute(_TOOLS_DICT, _make_call(call_id="c2"), "MockLLM")
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
_drain(gen)
|
||||
|
||||
@@ -200,42 +176,10 @@ class TestExecuteJournaling:
|
||||
assert row["status"] == "failed"
|
||||
assert row["error"] == "boom"
|
||||
|
||||
def test_executed_row_lingers_for_reconciler_when_no_confirm(
|
||||
self, pg_conn, mock_tool_manager, monkeypatch
|
||||
):
|
||||
"""No finalize_message call → row sits in ``executed``."""
|
||||
executor = ToolExecutor(user="u")
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
lambda _cls, **kw: Mock(
|
||||
parse_args=Mock(return_value=("t1", "test_action", {}))
|
||||
),
|
||||
)
|
||||
_patch_db(monkeypatch, pg_conn)
|
||||
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
_drain(executor.execute(tools_dict, _make_call(call_id="c3"), "MockLLM"))
|
||||
|
||||
row = _select_attempt(pg_conn, "c3")
|
||||
assert row["status"] == "executed"
|
||||
# Partial index `tool_call_attempts_pending_ts_idx` selects rows
|
||||
# in ('proposed','executed') — the reconciler reads those.
|
||||
assert row["status"] in ("proposed", "executed")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRepository:
|
||||
def test_proposed_then_executed_round_trip(self, pg_conn):
|
||||
def test_proposed_then_confirmed_when_no_message(self, pg_conn):
|
||||
from application.storage.db.repositories.tool_call_attempts import (
|
||||
ToolCallAttemptsRepository,
|
||||
)
|
||||
@@ -249,7 +193,50 @@ class TestRepository:
|
||||
|
||||
assert repo.mark_executed("c-x", {"out": "ok"}) is True
|
||||
row = _select_attempt(pg_conn, "c-x")
|
||||
assert row["status"] == "confirmed"
|
||||
assert row["message_id"] is None
|
||||
assert row["result"] == {"result": {"out": "ok"}}
|
||||
|
||||
def test_mark_executed_with_message_stays_executed(self, pg_conn):
|
||||
from application.storage.db.repositories.conversations import (
|
||||
ConversationsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.tool_call_attempts import (
|
||||
ToolCallAttemptsRepository,
|
||||
)
|
||||
|
||||
# FK constraint: message_id must reference a real row.
|
||||
conv_repo = ConversationsRepository(pg_conn)
|
||||
conv = conv_repo.create("u-repo", "repo-msg-test")
|
||||
msg = conv_repo.reserve_message(
|
||||
str(conv["id"]),
|
||||
prompt="q?",
|
||||
placeholder_response="...",
|
||||
request_id="req-repo-1",
|
||||
status="pending",
|
||||
)
|
||||
message_uuid = str(msg["id"])
|
||||
|
||||
repo = ToolCallAttemptsRepository(pg_conn)
|
||||
repo.record_proposed("c-m", "tool", "act", {})
|
||||
assert (
|
||||
repo.mark_executed("c-m", {"out": "ok"}, message_id=message_uuid) is True
|
||||
)
|
||||
row = _select_attempt(pg_conn, "c-m")
|
||||
assert row["status"] == "executed"
|
||||
assert str(row["message_id"]) == message_uuid
|
||||
|
||||
def test_upsert_executed_without_message_confirms(self, pg_conn):
|
||||
"""``upsert_executed`` (DB-outage fallback) with no ``message_id`` lands ``confirmed``."""
|
||||
from application.storage.db.repositories.tool_call_attempts import (
|
||||
ToolCallAttemptsRepository,
|
||||
)
|
||||
|
||||
repo = ToolCallAttemptsRepository(pg_conn)
|
||||
repo.upsert_executed("c-up", "tool", "act", {"a": 1}, {"out": "ok"})
|
||||
row = _select_attempt(pg_conn, "c-up")
|
||||
assert row["status"] == "confirmed"
|
||||
assert row["message_id"] is None
|
||||
assert row["result"] == {"result": {"out": "ok"}}
|
||||
|
||||
def test_mark_failed_sets_error(self, pg_conn):
|
||||
|
||||
@@ -553,6 +553,35 @@ class TestSyncSource:
|
||||
assert response.status_code == 200
|
||||
assert response.json["task_id"] == "task-123"
|
||||
|
||||
def test_normalizes_dict_remote_data_before_dispatch(self, app, pg_conn):
|
||||
"""The route must hand the sync task the normalized URL string."""
|
||||
from application.api.user.sources.routes import SyncSource
|
||||
|
||||
user = "u-normalize"
|
||||
src = _seed_source(
|
||||
pg_conn, user, name="crawl-src", type="crawler",
|
||||
remote_data=json.dumps(
|
||||
{"url": "https://example.com", "provider": "crawler"}
|
||||
),
|
||||
)
|
||||
|
||||
fake_task = MagicMock(id="task-norm")
|
||||
with _patch_db(pg_conn), patch(
|
||||
"application.api.user.sources.routes.sync_source.delay",
|
||||
return_value=fake_task,
|
||||
) as mock_delay, app.test_request_context(
|
||||
"/api/sync_source",
|
||||
method="POST",
|
||||
json={"source_id": str(src["id"])},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = SyncSource().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert mock_delay.call_args.kwargs["source_data"] == "https://example.com"
|
||||
assert mock_delay.call_args.kwargs["loader"] == "crawler"
|
||||
|
||||
def test_sync_task_raises_returns_400(self, app, pg_conn):
|
||||
from application.api.user.sources.routes import SyncSource
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
"""Tests for application.parser.remote.remote_creator covering lines 31-34."""
|
||||
"""Tests for application.parser.remote.remote_creator."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
@@ -38,3 +40,92 @@ class TestRemoteCreator:
|
||||
mock_loader_cls.assert_called_once()
|
||||
finally:
|
||||
RemoteCreator.loaders = original_loaders
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestNormalizeRemoteData:
|
||||
"""``normalize_remote_data`` maps a stored JSONB ``remote_data`` value
|
||||
back to the ``source_data`` shape each loader expects."""
|
||||
|
||||
def test_none_passes_through(self):
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
assert normalize_remote_data("crawler", None) is None
|
||||
|
||||
def test_crawler_dict_with_url_key(self):
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
result = normalize_remote_data(
|
||||
"crawler", {"url": "https://example.com", "provider": "crawler"}
|
||||
)
|
||||
assert result == "https://example.com"
|
||||
|
||||
def test_url_dict_with_url_key(self):
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
result = normalize_remote_data("url", {"url": "https://example.com"})
|
||||
assert result == "https://example.com"
|
||||
|
||||
def test_url_legacy_raw_key(self):
|
||||
"""Legacy rows wrapped a bare URL string as ``{"raw": ...}``."""
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
result = normalize_remote_data("crawler", {"raw": "https://legacy.example.com"})
|
||||
assert result == "https://legacy.example.com"
|
||||
|
||||
def test_url_dict_with_urls_list(self):
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
result = normalize_remote_data(
|
||||
"url", {"urls": ["https://a.example.com", "https://b.example.com"]}
|
||||
)
|
||||
assert result == ["https://a.example.com", "https://b.example.com"]
|
||||
|
||||
def test_github_repo_url_key(self):
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
result = normalize_remote_data(
|
||||
"github", {"repo_url": "https://github.com/arc53/DocsGPT"}
|
||||
)
|
||||
assert result == "https://github.com/arc53/DocsGPT"
|
||||
|
||||
def test_sitemap_dict_with_url_key(self):
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
result = normalize_remote_data("sitemap", {"url": "https://example.com/sitemap.xml"})
|
||||
assert result == "https://example.com/sitemap.xml"
|
||||
|
||||
def test_plain_string_url_passes_through(self):
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
assert normalize_remote_data("crawler", "https://example.com") == "https://example.com"
|
||||
|
||||
def test_url_dict_without_url_key_returns_none(self):
|
||||
"""A URL-type loader must never receive a dict, even a malformed one."""
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
assert normalize_remote_data("crawler", {"provider": "crawler"}) is None
|
||||
|
||||
def test_reddit_dict_serialized_to_json_string(self):
|
||||
"""reddit's loader runs json.loads() — it needs a JSON string."""
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
result = normalize_remote_data(
|
||||
"reddit", {"client_id": "x", "search_queries": ["y"]}
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
assert json.loads(result) == {"client_id": "x", "search_queries": ["y"]}
|
||||
|
||||
def test_s3_dict_passes_through(self):
|
||||
"""S3Loader.load_data() accepts a dict, so it is left untouched."""
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
data = {"bucket": "b", "prefix": "k"}
|
||||
assert normalize_remote_data("s3", data) == data
|
||||
|
||||
def test_json_string_remote_data_is_parsed(self):
|
||||
"""Legacy rows that stored the JSON itself as a string still resolve."""
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
result = normalize_remote_data("crawler", '{"url": "https://example.com"}')
|
||||
assert result == "https://example.com"
|
||||
|
||||
@@ -148,6 +148,130 @@ class TestSyncWorker:
|
||||
assert captured[0]["loader"] == "url"
|
||||
assert captured[0]["doc_id"] == str(src["id"])
|
||||
|
||||
def test_connector_sources_are_skipped(
|
||||
self,
|
||||
pg_conn,
|
||||
patch_worker_db,
|
||||
task_self,
|
||||
monkeypatch,
|
||||
):
|
||||
"""connector:* sources have no RemoteCreator loader — sync_worker
|
||||
must skip them, not dispatch them into sync()."""
|
||||
from application import worker
|
||||
|
||||
SourcesRepository(pg_conn).create(
|
||||
"drive-folder",
|
||||
user_id="dave",
|
||||
type="connector:file",
|
||||
retriever="classic",
|
||||
sync_frequency="daily",
|
||||
remote_data={
|
||||
"provider": "google_drive",
|
||||
"file_ids": ["f1"],
|
||||
"folder_ids": [],
|
||||
"recursive": False,
|
||||
},
|
||||
)
|
||||
|
||||
def _must_not_run(*args, **kwargs):
|
||||
raise AssertionError("sync() must not run for connector sources")
|
||||
|
||||
monkeypatch.setattr(worker, "sync", _must_not_run)
|
||||
|
||||
result = worker.sync_worker(task_self, "daily")
|
||||
|
||||
assert result["total_sync_count"] == 1
|
||||
assert result["sync_skipped"] == 1
|
||||
assert result["sync_success"] == 0
|
||||
assert result["sync_failure"] == 0
|
||||
|
||||
def test_dict_remote_data_is_normalized_before_loader(
|
||||
self,
|
||||
pg_conn,
|
||||
patch_worker_db,
|
||||
task_self,
|
||||
monkeypatch,
|
||||
):
|
||||
"""Regression: remote_data reads back as a dict; sync_worker must
|
||||
hand the loader the URL string, not the raw dict."""
|
||||
from application import worker
|
||||
|
||||
SourcesRepository(pg_conn).create(
|
||||
"docs-crawl",
|
||||
user_id="erin",
|
||||
type="crawler",
|
||||
retriever="classic",
|
||||
sync_frequency="weekly",
|
||||
remote_data={"url": "https://example.com", "provider": "crawler"},
|
||||
)
|
||||
|
||||
received: list = []
|
||||
fake_loader = MagicMock(name="remote_loader")
|
||||
|
||||
def _capture(source_data):
|
||||
received.append(source_data)
|
||||
return [
|
||||
Document(
|
||||
text="page body",
|
||||
extra_info={"file_path": "index.md", "title": "home"},
|
||||
doc_id="d1",
|
||||
)
|
||||
]
|
||||
|
||||
fake_loader.load_data.side_effect = _capture
|
||||
monkeypatch.setattr(
|
||||
worker.RemoteCreator, "create_loader", lambda loader: fake_loader
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
worker,
|
||||
"embed_and_store_documents",
|
||||
lambda docs, full_path, source_id, task, **kw: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
worker, "upload_index", lambda full_path, file_data: None
|
||||
)
|
||||
|
||||
result = worker.sync_worker(task_self, "weekly")
|
||||
|
||||
assert result["total_sync_count"] == 1
|
||||
assert result["sync_success"] == 1
|
||||
assert result["sync_failure"] == 0
|
||||
assert received == ["https://example.com"], (
|
||||
"loader must receive the URL string, not the remote_data dict"
|
||||
)
|
||||
|
||||
def test_unsyncable_remote_data_is_skipped(
|
||||
self,
|
||||
pg_conn,
|
||||
patch_worker_db,
|
||||
task_self,
|
||||
monkeypatch,
|
||||
):
|
||||
"""A URL source whose remote_data dict has no URL key normalizes
|
||||
to None — sync_worker must skip it, not dispatch a doomed sync()."""
|
||||
from application import worker
|
||||
|
||||
SourcesRepository(pg_conn).create(
|
||||
"broken-feed",
|
||||
user_id="frank",
|
||||
type="url",
|
||||
retriever="classic",
|
||||
sync_frequency="monthly",
|
||||
remote_data={"provider": "url"},
|
||||
)
|
||||
|
||||
def _must_not_run(*args, **kwargs):
|
||||
raise AssertionError("sync() must not run for unsyncable sources")
|
||||
|
||||
monkeypatch.setattr(worker, "sync", _must_not_run)
|
||||
|
||||
result = worker.sync_worker(task_self, "monthly")
|
||||
|
||||
assert result["total_sync_count"] == 1
|
||||
assert result["sync_skipped"] == 1
|
||||
assert result["sync_failure"] == 0
|
||||
assert result["sync_success"] == 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRemoteWorkerPathTraversal:
|
||||
|
||||
Reference in New Issue
Block a user