mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-22 05:15:08 +00:00
Compare commits
9 Commits
feat-notif
...
hardening-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9a5ea8fe00 | ||
|
|
1de82ca040 | ||
|
|
8f7742c937 | ||
|
|
e3bf6a5471 | ||
|
|
e167cf8247 | ||
|
|
c06646519e | ||
|
|
97a362b703 | ||
|
|
29477b40b3 | ||
|
|
e351f45d88 |
@@ -8,7 +8,7 @@ RUN apt-get update && \
|
||||
add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Verify Python installation and setup symlink
|
||||
RUN if [ -f /usr/bin/python3.12 ]; then \
|
||||
@@ -73,7 +73,7 @@ COPY --from=builder /models /app/models
|
||||
COPY . /app/application
|
||||
|
||||
# Change the ownership of the /app directory to the appuser
|
||||
|
||||
|
||||
RUN mkdir -p /app/application/inputs/local
|
||||
RUN chown -R appuser:appuser /app
|
||||
|
||||
@@ -82,6 +82,11 @@ ENV FLASK_APP=app.py \
|
||||
FLASK_DEBUG=true \
|
||||
PATH="/venv/bin:$PATH"
|
||||
|
||||
ENV MALLOC_ARENA_MAX=2 \
|
||||
OMP_NUM_THREADS=4 \
|
||||
MKL_NUM_THREADS=4 \
|
||||
OPENBLAS_NUM_THREADS=4
|
||||
|
||||
# Expose the port the app runs on
|
||||
EXPOSE 7091
|
||||
|
||||
|
||||
@@ -114,6 +114,8 @@ class BaseAgent(ABC):
|
||||
self.compressed_summary = compressed_summary
|
||||
self.current_token_count = 0
|
||||
self.context_limit_reached = False
|
||||
self.conversation_id: Optional[str] = None
|
||||
self.initial_user_id: Optional[str] = None
|
||||
|
||||
@log_activity()
|
||||
def gen(
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import urlencode
|
||||
from urllib.parse import quote, urlencode
|
||||
|
||||
import requests
|
||||
|
||||
@@ -11,7 +11,7 @@ from application.agents.tools.api_body_serializer import (
|
||||
RequestBodySerializer,
|
||||
)
|
||||
from application.agents.tools.base import Tool
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
from application.security.safe_url import UnsafeUserUrlError, pinned_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -70,18 +70,16 @@ class APITool(Tool):
|
||||
Returns:
|
||||
Dict with status_code, data, and message
|
||||
"""
|
||||
_VALID_METHODS = {"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}
|
||||
|
||||
request_url = url
|
||||
request_headers = headers.copy() if headers else {}
|
||||
response = None
|
||||
|
||||
# Validate URL to prevent SSRF attacks
|
||||
try:
|
||||
validate_url(request_url)
|
||||
except SSRFError as e:
|
||||
logger.error(f"URL validation failed: {e}")
|
||||
if method.upper() not in _VALID_METHODS:
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"URL validation error: {e}",
|
||||
"message": f"Unsupported HTTP method: {method}",
|
||||
"data": None,
|
||||
}
|
||||
|
||||
@@ -91,8 +89,9 @@ class APITool(Tool):
|
||||
for match in re.finditer(r"\{([^}]+)\}", request_url):
|
||||
param_name = match.group(1)
|
||||
if param_name in query_params:
|
||||
safe_value = quote(str(query_params[param_name]), safe="")
|
||||
request_url = request_url.replace(
|
||||
f"{{{param_name}}}", str(query_params[param_name])
|
||||
f"{{{param_name}}}", safe_value
|
||||
)
|
||||
path_params_used.add(param_name)
|
||||
remaining_params = {
|
||||
@@ -103,19 +102,6 @@ class APITool(Tool):
|
||||
separator = "&" if "?" in request_url else "?"
|
||||
request_url = f"{request_url}{separator}{query_string}"
|
||||
|
||||
# Re-validate URL after parameter substitution to prevent SSRF via path params
|
||||
try:
|
||||
validate_url(request_url)
|
||||
except SSRFError as e:
|
||||
logger.error(f"URL validation failed after parameter substitution: {e}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"URL validation error: {e}",
|
||||
"data": None,
|
||||
}
|
||||
|
||||
# Serialize body based on content type
|
||||
|
||||
if body and body != {}:
|
||||
try:
|
||||
serialized_body, body_headers = RequestBodySerializer.serialize(
|
||||
@@ -141,49 +127,13 @@ class APITool(Tool):
|
||||
f"API Call: {method} {request_url} | Content-Type: {request_headers.get('Content-Type', 'N/A')}"
|
||||
)
|
||||
|
||||
if method.upper() == "GET":
|
||||
response = requests.get(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "POST":
|
||||
response = requests.post(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "PUT":
|
||||
response = requests.put(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "DELETE":
|
||||
response = requests.delete(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "PATCH":
|
||||
response = requests.patch(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "HEAD":
|
||||
response = requests.head(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "OPTIONS":
|
||||
response = requests.options(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
else:
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Unsupported HTTP method: {method}",
|
||||
"data": None,
|
||||
}
|
||||
response = pinned_request(
|
||||
method,
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
data = self._parse_response(response)
|
||||
@@ -193,6 +143,13 @@ class APITool(Tool):
|
||||
"data": data,
|
||||
"message": "API call successful.",
|
||||
}
|
||||
except UnsafeUserUrlError as e:
|
||||
logger.error(f"URL validation failed: {e}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"URL validation error: {e}",
|
||||
"data": None,
|
||||
}
|
||||
except requests.exceptions.Timeout:
|
||||
logger.error(f"Request timeout for {request_url}")
|
||||
return {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import requests
|
||||
from application.agents.tools.base import Tool
|
||||
from application.security.safe_url import UnsafeUserUrlError, pinned_request
|
||||
|
||||
class NtfyTool(Tool):
|
||||
"""
|
||||
@@ -71,7 +71,12 @@ class NtfyTool(Tool):
|
||||
if self.token:
|
||||
headers["Authorization"] = f"Basic {self.token}"
|
||||
data = message.encode("utf-8")
|
||||
response = requests.post(url, headers=headers, data=data, timeout=100)
|
||||
try:
|
||||
response = pinned_request(
|
||||
"POST", url, data=data, headers=headers, timeout=100,
|
||||
)
|
||||
except UnsafeUserUrlError as e:
|
||||
return {"status_code": None, "message": f"URL validation error: {e}"}
|
||||
return {"status_code": response.status_code, "message": "Message sent"}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import requests
|
||||
from markdownify import markdownify
|
||||
from application.agents.tools.base import Tool
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
from application.security.safe_url import UnsafeUserUrlError, pinned_request
|
||||
|
||||
class ReadWebpageTool(Tool):
|
||||
"""
|
||||
@@ -31,28 +30,24 @@ class ReadWebpageTool(Tool):
|
||||
if not url:
|
||||
return "Error: URL parameter is missing."
|
||||
|
||||
# Validate URL to prevent SSRF attacks
|
||||
try:
|
||||
url = validate_url(url)
|
||||
except SSRFError as e:
|
||||
return f"Error: URL validation failed - {e}"
|
||||
response = pinned_request(
|
||||
"GET",
|
||||
url,
|
||||
headers={'User-Agent': 'DocsGPT-Agent/1.0'},
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
try:
|
||||
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
|
||||
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
|
||||
|
||||
html_content = response.text
|
||||
#soup = BeautifulSoup(html_content, 'html.parser')
|
||||
|
||||
|
||||
markdown_content = markdownify(html_content, heading_style="ATX", newline_style="BACKSLASH")
|
||||
|
||||
|
||||
return markdown_content
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
return f"Error fetching URL {url}: {e}"
|
||||
except UnsafeUserUrlError as e:
|
||||
return f"Error: URL validation failed - {e}"
|
||||
except Exception as e:
|
||||
return f"Error processing URL {url}: {e}"
|
||||
return f"Error fetching URL {url}: {e}"
|
||||
|
||||
def get_actions_metadata(self):
|
||||
"""
|
||||
|
||||
44
application/alembic/versions/0008_ingest_progress_status.py
Normal file
44
application/alembic/versions/0008_ingest_progress_status.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""0008 ingest_chunk_progress.status — terminal flag for stalled ingests.
|
||||
|
||||
The reconciler's stalled-ingest sweep had no terminal write, so a dead
|
||||
ingest re-alerted every ~30 min forever. ``status`` lets it escalate a
|
||||
stalled checkpoint to ``'stalled'`` once and stop re-selecting it;
|
||||
``init_progress`` resets it to ``'active'`` on reingest.
|
||||
|
||||
Revision ID: 0008_ingest_progress_status
|
||||
Revises: 0007_message_events
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0008_ingest_progress_status"
|
||||
down_revision: Union[str, None] = "0007_message_events"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Constant DEFAULT — metadata-only ADD COLUMN, no table rewrite.
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE ingest_chunk_progress
|
||||
ADD COLUMN status TEXT NOT NULL DEFAULT 'active'
|
||||
CHECK (status IN ('active', 'stalled'));
|
||||
"""
|
||||
)
|
||||
# Partial index for the reconciler's stalled-ingest sweep.
|
||||
op.execute(
|
||||
"CREATE INDEX ingest_chunk_progress_active_idx "
|
||||
"ON ingest_chunk_progress (last_updated) "
|
||||
"WHERE status = 'active';"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP INDEX IF EXISTS ingest_chunk_progress_active_idx;")
|
||||
op.execute(
|
||||
"ALTER TABLE ingest_chunk_progress DROP COLUMN IF EXISTS status;"
|
||||
)
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
@@ -26,13 +27,20 @@ LEASE_HEARTBEAT_INTERVAL = 30
|
||||
LEASE_RETRY_MAX = 10
|
||||
|
||||
|
||||
def with_idempotency(task_name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def with_idempotency(
|
||||
task_name: str,
|
||||
*,
|
||||
on_poison: Optional[Callable[[str, dict], None]] = None,
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
"""Short-circuit on completed key; gate concurrent runs via a lease.
|
||||
|
||||
The guard key is the caller's ``idempotency_key``, or one synthesized
|
||||
from ``source_id`` so a keyless dispatch is still poison-guarded.
|
||||
|
||||
Entry short-circuits:
|
||||
- completed row → return cached result
|
||||
- live lease held → retry(countdown=LEASE_TTL_SECONDS)
|
||||
- attempt_count > MAX_TASK_ATTEMPTS → poison-loop alert
|
||||
- attempt_count > MAX_TASK_ATTEMPTS → poison alert; ``on_poison`` fires
|
||||
Success writes ``completed``; exceptions leave ``pending`` for
|
||||
autoretry until the poison-loop guard trips.
|
||||
"""
|
||||
@@ -40,7 +48,14 @@ def with_idempotency(task_name: str) -> Callable[[Callable[..., Any]], Callable[
|
||||
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||
@functools.wraps(fn)
|
||||
def wrapper(self, *args: Any, idempotency_key: Any = None, **kwargs: Any) -> Any:
|
||||
key = idempotency_key if isinstance(idempotency_key, str) and idempotency_key else None
|
||||
explicit_key = (
|
||||
idempotency_key
|
||||
if isinstance(idempotency_key, str) and idempotency_key
|
||||
else None
|
||||
)
|
||||
# A keyless dispatch still gets the guard via a synthesized key;
|
||||
# None means no anchor exists — run unguarded, as before.
|
||||
key = explicit_key or _synthesize_guard_key(task_name, kwargs)
|
||||
if key is None:
|
||||
return fn(self, *args, idempotency_key=idempotency_key, **kwargs)
|
||||
|
||||
@@ -88,6 +103,9 @@ def with_idempotency(task_name: str) -> Callable[[Callable[..., Any]], Callable[
|
||||
"attempts": attempt,
|
||||
}
|
||||
_finalize(key, poisoned, status="failed")
|
||||
_run_poison_hook(
|
||||
on_poison, task_name, fn, self, args, kwargs, idempotency_key,
|
||||
)
|
||||
return poisoned
|
||||
|
||||
heartbeat_thread, heartbeat_stop = _start_lease_heartbeat(
|
||||
@@ -109,6 +127,45 @@ def with_idempotency(task_name: str) -> Callable[[Callable[..., Any]], Callable[
|
||||
return decorator
|
||||
|
||||
|
||||
def _synthesize_guard_key(task_name: str, kwargs: dict) -> Optional[str]:
|
||||
"""Derive a deterministic guard key from ``source_id`` for a keyless dispatch.
|
||||
|
||||
``source_id`` is stable across broker redeliveries and unique per
|
||||
upload, so the poison-loop counter survives an OOM SIGKILL. Returns
|
||||
``None`` when absent — the dispatch then runs unguarded as before.
|
||||
"""
|
||||
source_id = kwargs.get("source_id")
|
||||
if source_id:
|
||||
return f"auto:{task_name}:{source_id}"
|
||||
return None
|
||||
|
||||
|
||||
def _run_poison_hook(
|
||||
on_poison: Optional[Callable[[str, dict], None]],
|
||||
task_name: str,
|
||||
fn: Callable[..., Any],
|
||||
task_self: Any,
|
||||
args: tuple,
|
||||
kwargs: dict,
|
||||
idempotency_key: Any,
|
||||
) -> None:
|
||||
"""Invoke a task's poison-path hook with named call args; swallow failures.
|
||||
|
||||
A hook failure must never change the poison-guard outcome.
|
||||
"""
|
||||
if on_poison is None:
|
||||
return
|
||||
try:
|
||||
bound = inspect.signature(fn).bind_partial(
|
||||
task_self, *args, idempotency_key=idempotency_key, **kwargs,
|
||||
)
|
||||
on_poison(task_name, dict(bound.arguments))
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"idempotency: poison hook failed for task=%s", task_name,
|
||||
)
|
||||
|
||||
|
||||
def _lookup_completed(key: str) -> Any:
|
||||
"""Return cached ``result_json`` if a completed row exists for ``key``, else None."""
|
||||
with db_readonly() as conn:
|
||||
|
||||
@@ -114,11 +114,11 @@ def run_reconciliation() -> Dict[str, Any]:
|
||||
},
|
||||
)
|
||||
|
||||
# Q4: ingest checkpoints whose heartbeat has gone silent. The
|
||||
# reconciler only escalates (alerts) — it doesn't kill the worker
|
||||
# or roll back the partial embed. The next dispatch resumes from
|
||||
# ``last_index`` thanks to the per-chunk checkpoint, so this is an
|
||||
# observability sweep, not a recovery action.
|
||||
# Q4: ingest checkpoints whose heartbeat has gone silent. Each is
|
||||
# escalated to terminal ``status='stalled'`` and alerted once — no
|
||||
# worker kill, no rollback of the partial embed. The 'stalled' flag
|
||||
# ends the re-alert loop and drives the "indexing failed" badge the
|
||||
# sources list derives from this row.
|
||||
with engine.begin() as conn:
|
||||
repo = ReconciliationRepository(conn)
|
||||
for row in repo.find_and_lock_stalled_ingests():
|
||||
@@ -134,8 +134,7 @@ def run_reconciliation() -> Dict[str, Any]:
|
||||
"last_updated": str(row.get("last_updated")),
|
||||
},
|
||||
)
|
||||
# Bump the heartbeat so we don't re-alert every tick.
|
||||
repo.touch_ingest_progress(str(row["source_id"]))
|
||||
repo.mark_ingest_stalled(str(row["source_id"]))
|
||||
|
||||
# Q5: idempotency rows whose lease expired with attempts exhausted.
|
||||
# The wrapper's poison-loop guard normally finalises these, but if
|
||||
|
||||
@@ -7,8 +7,12 @@ from flask import current_app, jsonify, make_response, redirect, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.tasks import sync_source
|
||||
from application.api.user.tasks import reingest_source_task, sync_source
|
||||
from application.core.settings import settings
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
from application.storage.db.repositories.ingest_chunk_progress import (
|
||||
IngestChunkProgressRepository,
|
||||
)
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
@@ -139,6 +143,8 @@ class PaginatedSources(Resource):
|
||||
"provider": provider,
|
||||
"isNested": bool(doc.get("directory_structure")),
|
||||
"type": doc.get("type", "file"),
|
||||
# Derived in SourcesRepository.list_for_user.
|
||||
"ingestStatus": doc.get("ingest_status"),
|
||||
}
|
||||
)
|
||||
response = {
|
||||
@@ -322,7 +328,7 @@ class SyncSource(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
source_data = doc.get("remote_data")
|
||||
source_data = normalize_remote_data(source_type, doc.get("remote_data"))
|
||||
if not source_data:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Source is not syncable"}), 400
|
||||
@@ -346,6 +352,70 @@ class SyncSource(Resource):
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/sources/reingest")
|
||||
class ReingestSource(Resource):
|
||||
reingest_source_model = api.model(
|
||||
"ReingestSourceModel",
|
||||
{"source_id": fields.String(required=True, description="Source ID")},
|
||||
)
|
||||
|
||||
@api.expect(reingest_source_model)
|
||||
@api.doc(
|
||||
description="Re-run ingestion for a source — e.g. to recover a "
|
||||
"stalled embed flagged by the reconciler."
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json() or {}
|
||||
missing_fields = check_required_fields(data, ["source_id"])
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
source_id = data["source_id"]
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
doc = SourcesRepository(conn).get_any(source_id, user)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error looking up source: {err}", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid source ID"}), 400
|
||||
)
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Source not found"}), 404
|
||||
)
|
||||
resolved_source_id = str(doc["id"])
|
||||
# Drop the stale chunk-progress row so the sources list stops
|
||||
# deriving a 'failed' status; reingest never rewrites it itself.
|
||||
try:
|
||||
with db_session() as conn:
|
||||
IngestChunkProgressRepository(conn).delete(resolved_source_id)
|
||||
except Exception as err:
|
||||
current_app.logger.warning(
|
||||
f"Could not clear ingest progress for {resolved_source_id}: "
|
||||
f"{err}",
|
||||
exc_info=True,
|
||||
)
|
||||
try:
|
||||
# Scoped key so repeated clicks collapse onto one reingest.
|
||||
task = reingest_source_task.delay(
|
||||
source_id=resolved_source_id,
|
||||
user=user,
|
||||
idempotency_key=f"reingest-source:{user}:{resolved_source_id}",
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error starting reingest for source {source_id}: {err}",
|
||||
exc_info=True,
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/directory_structure")
|
||||
class DirectoryStructure(Resource):
|
||||
@api.doc(
|
||||
|
||||
@@ -27,8 +27,42 @@ DURABLE_TASK = dict(
|
||||
)
|
||||
|
||||
|
||||
# operation tag for the poison-path source.ingest.failed event, per task.
|
||||
_INGEST_POISON_OPERATION = {
|
||||
"ingest": "upload",
|
||||
"ingest_remote": "upload",
|
||||
"ingest_connector_task": "upload",
|
||||
"reingest_source_task": "reingest",
|
||||
}
|
||||
|
||||
|
||||
def _emit_ingest_poison_event(task_name, bound):
|
||||
"""Publish a terminal ``source.ingest.failed`` when the poison-guard trips.
|
||||
|
||||
The guard returns before the worker runs, so the worker's own failed
|
||||
event never fires — without this the upload toast spins on "training".
|
||||
"""
|
||||
user = bound.get("user")
|
||||
source_id = bound.get("source_id")
|
||||
if not user or not source_id:
|
||||
return
|
||||
from application.events.publisher import publish_user_event
|
||||
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.failed",
|
||||
{
|
||||
"source_id": str(source_id),
|
||||
"filename": bound.get("filename") or "",
|
||||
"operation": _INGEST_POISON_OPERATION.get(task_name, "upload"),
|
||||
"error": "Ingestion stopped after repeated failures.",
|
||||
},
|
||||
scope={"kind": "source", "id": str(source_id)},
|
||||
)
|
||||
|
||||
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="ingest")
|
||||
@with_idempotency(task_name="ingest", on_poison=_emit_ingest_poison_event)
|
||||
def ingest(
|
||||
self,
|
||||
directory,
|
||||
@@ -57,7 +91,7 @@ def ingest(
|
||||
|
||||
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="ingest_remote")
|
||||
@with_idempotency(task_name="ingest_remote", on_poison=_emit_ingest_poison_event)
|
||||
def ingest_remote(
|
||||
self, source_data, job_name, user, loader,
|
||||
idempotency_key=None, source_id=None,
|
||||
@@ -71,7 +105,9 @@ def ingest_remote(
|
||||
|
||||
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="reingest_source_task")
|
||||
@with_idempotency(
|
||||
task_name="reingest_source_task", on_poison=_emit_ingest_poison_event,
|
||||
)
|
||||
def reingest_source_task(self, source_id, user, idempotency_key=None):
|
||||
from application.worker import reingest_source_worker
|
||||
|
||||
@@ -128,7 +164,9 @@ def process_agent_webhook(self, agent_id, payload, idempotency_key=None):
|
||||
|
||||
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="ingest_connector_task")
|
||||
@with_idempotency(
|
||||
task_name="ingest_connector_task", on_poison=_emit_ingest_poison_event,
|
||||
)
|
||||
def ingest_connector_task(
|
||||
self,
|
||||
job_name,
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import ctypes
|
||||
import gc
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
|
||||
from celery import Celery
|
||||
@@ -98,6 +101,34 @@ def _unbind_task_log_context(task_id, **_):
|
||||
)
|
||||
|
||||
|
||||
def _trim_native_heap() -> None:
|
||||
"""Return freed glibc heap pages to the OS (Linux only; no-op elsewhere)."""
|
||||
# docling/torch parsing makes large transient allocations; glibc keeps the
|
||||
# freed pages in per-thread malloc arenas rather than returning them, so a
|
||||
# long-lived worker child's RSS only ever climbs. malloc_trim hands them
|
||||
# back. The symbol is glibc-only — absent in macOS libc.
|
||||
if not sys.platform.startswith("linux"):
|
||||
return
|
||||
try:
|
||||
ctypes.CDLL("libc.so.6").malloc_trim(0)
|
||||
except (OSError, AttributeError):
|
||||
pass
|
||||
|
||||
|
||||
@task_postrun.connect
|
||||
def _reclaim_memory_after_task(*args, **kwargs):
|
||||
"""Drop per-task allocations so the prefork child's RSS doesn't ratchet."""
|
||||
gc.collect()
|
||||
torch = sys.modules.get("torch")
|
||||
if torch is not None:
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
_trim_native_heap()
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def _run_version_check(*args, **kwargs):
|
||||
"""Kick off the anonymous version check on worker startup.
|
||||
|
||||
@@ -31,3 +31,10 @@ worker_prefetch_multiplier = settings.CELERY_WORKER_PREFETCH_MULTIPLIER
|
||||
broker_transport_options = {"visibility_timeout": settings.CELERY_VISIBILITY_TIMEOUT}
|
||||
result_expires = 86400 * 7
|
||||
task_track_started = True
|
||||
|
||||
# Recycle the prefork worker child to bound native-heap growth from
|
||||
# docling/torch parsing. Left unset (Celery's unlimited default) when 0.
|
||||
if settings.CELERY_WORKER_MAX_MEMORY_PER_CHILD > 0:
|
||||
worker_max_memory_per_child = settings.CELERY_WORKER_MAX_MEMORY_PER_CHILD
|
||||
if settings.CELERY_WORKER_MAX_TASKS_PER_CHILD > 0:
|
||||
worker_max_tasks_per_child = settings.CELERY_WORKER_MAX_TASKS_PER_CHILD
|
||||
|
||||
@@ -36,6 +36,11 @@ class Settings(BaseSettings):
|
||||
# and Dify defaults; long ingests can override via env.
|
||||
CELERY_WORKER_PREFETCH_MULTIPLIER: int = 1
|
||||
CELERY_VISIBILITY_TIMEOUT: int = 3600
|
||||
# Recycle the prefork worker child once its resident size crosses this many
|
||||
# kilobytes — backstops native-heap growth from docling/torch parsing. 0 disables.
|
||||
CELERY_WORKER_MAX_MEMORY_PER_CHILD: int = 4194304
|
||||
# Recycle the child after this many tasks; 0 disables (memory cap is the primary knob).
|
||||
CELERY_WORKER_MAX_TASKS_PER_CHILD: int = 0
|
||||
# Only consulted when VECTOR_STORE=mongodb or when running scripts/db/backfill.py; user data lives in Postgres.
|
||||
MONGO_URI: Optional[str] = None
|
||||
# User-data Postgres DB.
|
||||
@@ -61,6 +66,9 @@ class Settings(BaseSettings):
|
||||
PARSE_IMAGE_REMOTE: bool = False
|
||||
DOCLING_OCR_ENABLED: bool = False # Enable OCR for docling parsers (PDF, images)
|
||||
DOCLING_OCR_ATTACHMENTS_ENABLED: bool = False # Enable OCR for docling when parsing attachments
|
||||
# Pages docling's threaded pipeline buffers in flight; the library
|
||||
# default (100) drives worker RSS to ~3 GB on a mid-size PDF.
|
||||
DOCLING_PIPELINE_QUEUE_MAX_SIZE: int = 2
|
||||
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
|
||||
RETRIEVERS_ENABLED: list = ["classic_rag"]
|
||||
AGENT_NAME: str = "classic"
|
||||
|
||||
@@ -154,6 +154,8 @@ def embed_and_store_documents(
|
||||
*,
|
||||
attempt_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
progress_start: int = 0,
|
||||
progress_end: int = 100,
|
||||
) -> None:
|
||||
"""Embeds documents and stores them in a vector store.
|
||||
|
||||
@@ -176,6 +178,11 @@ def embed_and_store_documents(
|
||||
published to ``user:{user_id}`` for the in-app upload toast.
|
||||
``None`` is the safe default — workers without a user
|
||||
context (e.g. background syncs) skip the publish.
|
||||
progress_start: Percent the reported progress maps to at chunk 0.
|
||||
Lets a caller reserve the lower band for an earlier stage
|
||||
(e.g. parsing). Defaults to ``0`` (embed owns the whole bar).
|
||||
progress_end: Percent the reported progress maps to at the final
|
||||
chunk. Defaults to ``100``.
|
||||
|
||||
Returns:
|
||||
None
|
||||
@@ -257,6 +264,7 @@ def embed_and_store_documents(
|
||||
failed_idx: int | None = None
|
||||
last_published_pct = -1
|
||||
source_id_str = str(source_id)
|
||||
progress_span = progress_end - progress_start
|
||||
for idx in tqdm(
|
||||
range(loop_start, total_docs),
|
||||
desc="Embedding 🦖",
|
||||
@@ -266,8 +274,10 @@ def embed_and_store_documents(
|
||||
):
|
||||
doc = docs[idx]
|
||||
try:
|
||||
# Update task status for progress tracking
|
||||
progress = int(((idx + 1) / total_docs) * 100)
|
||||
# Map the embed loop into [progress_start, progress_end].
|
||||
progress = progress_start + int(
|
||||
((idx + 1) / total_docs) * progress_span
|
||||
)
|
||||
task_status.update_state(state="PROGRESS", meta={"current": progress})
|
||||
|
||||
# SSE push for sub-second upload-toast updates. Throttled to one
|
||||
|
||||
@@ -211,13 +211,22 @@ class SimpleDirectoryReader(BaseReader):
|
||||
|
||||
return new_input_files
|
||||
|
||||
def load_data(self, concatenate: bool = False) -> List[Document]:
|
||||
def load_data(
|
||||
self,
|
||||
concatenate: bool = False,
|
||||
progress_callback: Optional[Callable[[int, int], None]] = None,
|
||||
) -> List[Document]:
|
||||
"""Load data from the input directory.
|
||||
|
||||
Args:
|
||||
concatenate (bool): whether to concatenate all files into one document.
|
||||
If set to True, file metadata is ignored.
|
||||
False by default.
|
||||
progress_callback (Optional[Callable[[int, int], None]]): Called
|
||||
after each file is parsed with ``(files_done, total_files)``.
|
||||
Lets callers surface parse/OCR progress before embedding
|
||||
begins. Exceptions raised by the callback are swallowed so
|
||||
progress reporting can never fail ingestion.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of documents.
|
||||
@@ -226,8 +235,9 @@ class SimpleDirectoryReader(BaseReader):
|
||||
data_list: List[str] = []
|
||||
metadata_list = []
|
||||
self.file_token_counts = {}
|
||||
|
||||
for input_file in self.input_files:
|
||||
|
||||
total_files = len(self.input_files)
|
||||
for file_index, input_file in enumerate(self.input_files):
|
||||
suffix_lower = input_file.suffix.lower()
|
||||
parser_metadata = {}
|
||||
if suffix_lower in self.file_extractor:
|
||||
@@ -277,7 +287,15 @@ class SimpleDirectoryReader(BaseReader):
|
||||
else:
|
||||
data_list.append(str(data))
|
||||
metadata_list.append(base_metadata)
|
||||
|
||||
|
||||
if progress_callback is not None:
|
||||
try:
|
||||
progress_callback(file_index + 1, total_files)
|
||||
except Exception:
|
||||
logging.warning(
|
||||
"load_data progress callback failed", exc_info=True
|
||||
)
|
||||
|
||||
# Build directory structure if input_dir is provided
|
||||
if hasattr(self, 'input_dir'):
|
||||
self.directory_structure = self.build_directory_structure(self.input_dir)
|
||||
|
||||
@@ -16,6 +16,29 @@ from application.parser.file.base_parser import BaseParser
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Per-stage batch size for docling's threaded pipeline; 1 holds the
|
||||
# concurrent working set to a single page (see _apply_pipeline_caps).
|
||||
_PIPELINE_BATCH_SIZE = 1
|
||||
|
||||
|
||||
def _apply_pipeline_caps(pipeline_options) -> None:
|
||||
"""Cap docling's threaded-pipeline queue depth and batch sizes in place.
|
||||
|
||||
hasattr-guarded so docling builds without these knobs are unaffected.
|
||||
"""
|
||||
from application.core.settings import settings
|
||||
|
||||
caps = {
|
||||
"queue_max_size": max(1, settings.DOCLING_PIPELINE_QUEUE_MAX_SIZE),
|
||||
"layout_batch_size": _PIPELINE_BATCH_SIZE,
|
||||
"table_batch_size": _PIPELINE_BATCH_SIZE,
|
||||
"ocr_batch_size": _PIPELINE_BATCH_SIZE,
|
||||
}
|
||||
for name, value in caps.items():
|
||||
if hasattr(pipeline_options, name):
|
||||
setattr(pipeline_options, name, value)
|
||||
|
||||
|
||||
class DoclingParser(BaseParser):
|
||||
"""Parser using docling for advanced document processing.
|
||||
|
||||
@@ -86,6 +109,7 @@ class DoclingParser(BaseParser):
|
||||
do_ocr=self.ocr_enabled,
|
||||
do_table_structure=self.table_structure,
|
||||
)
|
||||
_apply_pipeline_caps(pipeline_options)
|
||||
|
||||
if self.ocr_enabled:
|
||||
ocr_options = self._get_ocr_options()
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import logging
|
||||
import os
|
||||
import requests
|
||||
from urllib.parse import urlparse, urljoin
|
||||
from bs4 import BeautifulSoup
|
||||
from application.parser.remote.base import BaseRemote
|
||||
from application.parser.schema.base import Document
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
from application.security.safe_url import UnsafeUserUrlError, pinned_request
|
||||
from langchain_community.document_loaders import WebBaseLoader
|
||||
|
||||
class CrawlerLoader(BaseRemote):
|
||||
@@ -35,14 +35,7 @@ class CrawlerLoader(BaseRemote):
|
||||
visited_urls.add(current_url)
|
||||
|
||||
try:
|
||||
# Validate each URL before making requests
|
||||
try:
|
||||
validate_url(current_url)
|
||||
except SSRFError as e:
|
||||
logging.warning(f"Skipping URL due to validation failure: {current_url} - {e}")
|
||||
continue
|
||||
|
||||
response = requests.get(current_url, timeout=30)
|
||||
response = pinned_request("GET", current_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
loader = self.loader([current_url])
|
||||
docs = loader.load()
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import requests
|
||||
from urllib.parse import urlparse, urljoin
|
||||
from bs4 import BeautifulSoup
|
||||
from application.parser.remote.base import BaseRemote
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
from application.security.safe_url import UnsafeUserUrlError, pinned_request
|
||||
import re
|
||||
from markdownify import markdownify
|
||||
from application.parser.schema.base import Document
|
||||
@@ -20,7 +20,6 @@ class CrawlerLoader(BaseRemote):
|
||||
"""
|
||||
self.limit = limit
|
||||
self.allow_subdomains = allow_subdomains
|
||||
self.session = requests.Session()
|
||||
|
||||
def load_data(self, inputs):
|
||||
url = inputs
|
||||
@@ -91,15 +90,13 @@ class CrawlerLoader(BaseRemote):
|
||||
|
||||
def _fetch_page(self, url):
|
||||
try:
|
||||
# Validate URL before fetching to prevent SSRF
|
||||
validate_url(url)
|
||||
response = self.session.get(url, timeout=10)
|
||||
response = pinned_request("GET", url, timeout=10)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except SSRFError as e:
|
||||
except UnsafeUserUrlError as e:
|
||||
print(f"URL validation failed for {url}: {e}")
|
||||
return None
|
||||
except requests.exceptions.RequestException as e:
|
||||
except Exception as e:
|
||||
print(f"Error fetching URL {url}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import json
|
||||
|
||||
from application.parser.remote.sitemap_loader import SitemapLoader
|
||||
from application.parser.remote.crawler_loader import CrawlerLoader
|
||||
from application.parser.remote.web_loader import WebLoader
|
||||
@@ -32,3 +34,59 @@ class RemoteCreator:
|
||||
if not loader_class:
|
||||
raise ValueError(f"No loader class found for type {type}")
|
||||
return loader_class(*args, **kwargs)
|
||||
|
||||
|
||||
# Loader types whose load_data expects a URL string, not a config dict.
|
||||
_URL_LOADER_TYPES = {"url", "crawler", "sitemap", "github"}
|
||||
|
||||
# Keys a remote_data dict may hold the URL under (``raw`` is the legacy shape).
|
||||
_URL_DATA_KEYS = ("url", "urls", "repo_url", "raw")
|
||||
|
||||
|
||||
def normalize_remote_data(source_type, remote_data):
|
||||
"""Convert a stored ``sources.remote_data`` JSONB value into the
|
||||
``source_data`` shape the matching loader expects.
|
||||
|
||||
Args:
|
||||
source_type: The ``sources.type`` value (the loader name).
|
||||
remote_data: The stored ``remote_data`` (dict, list, str, or None).
|
||||
|
||||
Returns:
|
||||
Loader input: a URL string or list for url/crawler/sitemap/github,
|
||||
a JSON string for reddit, a dict for s3; ``None`` when the row has
|
||||
nothing syncable.
|
||||
"""
|
||||
if remote_data is None:
|
||||
return None
|
||||
|
||||
# Some legacy rows stored the JSON itself as a string.
|
||||
if isinstance(remote_data, str):
|
||||
stripped = remote_data.strip()
|
||||
if stripped[:1] in ("{", "["):
|
||||
try:
|
||||
remote_data = json.loads(stripped)
|
||||
except json.JSONDecodeError:
|
||||
# Not actually JSON — leave remote_data as the original
|
||||
# string; the per-loader branches below handle a string.
|
||||
pass
|
||||
|
||||
loader = (source_type or "").lower()
|
||||
|
||||
if loader in _URL_LOADER_TYPES:
|
||||
if isinstance(remote_data, dict):
|
||||
for key in _URL_DATA_KEYS:
|
||||
value = remote_data.get(key)
|
||||
if value:
|
||||
return value
|
||||
# No URL key — None keeps the loader off the dict-crash path.
|
||||
return None
|
||||
return remote_data
|
||||
|
||||
if loader == "reddit":
|
||||
# reddit's loader runs json.loads() on its input — needs a string.
|
||||
if isinstance(remote_data, (dict, list)):
|
||||
return json.dumps(remote_data)
|
||||
return remote_data
|
||||
|
||||
# s3's loader accepts a dict or JSON string; pass it through unchanged.
|
||||
return remote_data
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import logging
|
||||
import requests
|
||||
import re # Import regular expression library
|
||||
import re
|
||||
import defusedxml.ElementTree as ET
|
||||
from application.parser.remote.base import BaseRemote
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
from application.security.safe_url import UnsafeUserUrlError, pinned_request
|
||||
|
||||
class SitemapLoader(BaseRemote):
|
||||
def __init__(self, limit=20):
|
||||
@@ -53,14 +53,12 @@ class SitemapLoader(BaseRemote):
|
||||
|
||||
def _extract_urls(self, sitemap_url):
|
||||
try:
|
||||
# Validate URL before fetching to prevent SSRF
|
||||
validate_url(sitemap_url)
|
||||
response = requests.get(sitemap_url, timeout=30)
|
||||
response.raise_for_status() # Raise an exception for HTTP errors
|
||||
except SSRFError as e:
|
||||
response = pinned_request("GET", sitemap_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
except UnsafeUserUrlError as e:
|
||||
print(f"URL validation failed for sitemap: {sitemap_url}. Error: {e}")
|
||||
return []
|
||||
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e:
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch sitemap: {sitemap_url}. Error: {e}")
|
||||
return []
|
||||
|
||||
@@ -97,13 +95,6 @@ class SitemapLoader(BaseRemote):
|
||||
nested_sitemap_url = sitemap.text
|
||||
if not nested_sitemap_url:
|
||||
continue
|
||||
try:
|
||||
nested_sitemap_url = validate_url(nested_sitemap_url)
|
||||
except SSRFError as e:
|
||||
logging.error(
|
||||
f"URL validation failed for nested sitemap {nested_sitemap_url}: {e}"
|
||||
)
|
||||
continue
|
||||
urls.extend(self._extract_urls(nested_sitemap_url))
|
||||
|
||||
return urls
|
||||
|
||||
@@ -291,6 +291,55 @@ def _ip_to_url_host(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> str:
|
||||
return str(ip)
|
||||
|
||||
|
||||
def pinned_request(
|
||||
method: str,
|
||||
url: str,
|
||||
*,
|
||||
data: Any = None,
|
||||
json: Any = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float = 90.0,
|
||||
allow_redirects: bool = False,
|
||||
) -> requests.Response:
|
||||
"""Send an HTTP request with the connection pinned to a validated IP,
|
||||
closing the DNS-rebinding TOCTOU window left by the naive
|
||||
validate-then-``requests`` pattern.
|
||||
|
||||
Raises:
|
||||
UnsafeUserUrlError: If the URL fails the SSRF guard.
|
||||
requests.RequestException: For network-level failures.
|
||||
"""
|
||||
|
||||
host, ip, parts = _validate_and_pick_ip(url)
|
||||
|
||||
netloc = _ip_to_url_host(ip)
|
||||
if parts.port is not None:
|
||||
netloc = f"{netloc}:{parts.port}"
|
||||
pinned_url = urlunsplit(
|
||||
(parts.scheme, netloc, parts.path, parts.query, parts.fragment)
|
||||
)
|
||||
|
||||
request_headers = dict(headers or {})
|
||||
host_header = host if parts.port is None else f"{host}:{parts.port}"
|
||||
request_headers["Host"] = host_header
|
||||
|
||||
session = requests.Session()
|
||||
if parts.scheme == "https":
|
||||
session.mount("https://", _PinnedHostAdapter(host))
|
||||
try:
|
||||
return session.request(
|
||||
method=method.upper(),
|
||||
url=pinned_url,
|
||||
data=data,
|
||||
json=json,
|
||||
headers=request_headers,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
)
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def pinned_post(
|
||||
url: str,
|
||||
*,
|
||||
@@ -328,33 +377,15 @@ def pinned_post(
|
||||
requests.RequestException: For network-level failures.
|
||||
"""
|
||||
|
||||
host, ip, parts = _validate_and_pick_ip(url)
|
||||
|
||||
netloc = _ip_to_url_host(ip)
|
||||
if parts.port is not None:
|
||||
netloc = f"{netloc}:{parts.port}"
|
||||
pinned_url = urlunsplit(
|
||||
(parts.scheme, netloc, parts.path, parts.query, parts.fragment)
|
||||
return pinned_request(
|
||||
"POST",
|
||||
url,
|
||||
json=json,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
)
|
||||
|
||||
request_headers = dict(headers or {})
|
||||
host_header = host if parts.port is None else f"{host}:{parts.port}"
|
||||
request_headers["Host"] = host_header
|
||||
|
||||
session = requests.Session()
|
||||
if parts.scheme == "https":
|
||||
session.mount("https://", _PinnedHostAdapter(host))
|
||||
try:
|
||||
return session.post(
|
||||
pinned_url,
|
||||
json=json,
|
||||
headers=request_headers,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
)
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
class _PinnedHTTPSTransport(httpx.HTTPTransport):
|
||||
"""``httpx`` transport pinned to a single validated IP literal.
|
||||
|
||||
@@ -514,6 +514,9 @@ ingest_chunk_progress_table = Table(
|
||||
# same task resumes from the checkpoint, but a separate invocation
|
||||
# (manual reingest, scheduled sync) resets to a clean re-index.
|
||||
Column("attempt_id", Text),
|
||||
# Added in ``0008_ingest_progress_status``. The reconciler flips
|
||||
# this to 'stalled'; ``init_progress`` resets it to 'active'.
|
||||
Column("status", Text, nullable=False, server_default="active"),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -41,6 +41,9 @@ class IngestChunkProgressRepository:
|
||||
rows with NULL ``attempt_id`` resume against another NULL
|
||||
caller (e.g. test fixtures), but get reset the moment a real
|
||||
``attempt_id`` arrives.
|
||||
|
||||
Both branches also reset ``status`` to ``'active'``, clearing a
|
||||
prior reconciler ``'stalled'`` escalation.
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
@@ -68,7 +71,8 @@ class IngestChunkProgressRepository:
|
||||
THEN ingest_chunk_progress.embedded_chunks
|
||||
ELSE 0
|
||||
END,
|
||||
attempt_id = EXCLUDED.attempt_id
|
||||
attempt_id = EXCLUDED.attempt_id,
|
||||
status = 'active'
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
@@ -113,6 +117,23 @@ class IngestChunkProgressRepository:
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def delete(self, source_id: str) -> bool:
|
||||
"""Delete the progress row for ``source_id``.
|
||||
|
||||
A manual reingest supersedes any prior ingest state — including a
|
||||
reconciler ``'stalled'`` escalation — so dropping the row clears
|
||||
the derived ``failed`` ingest status the sources list shows.
|
||||
Returns ``True`` when a row was removed.
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM ingest_chunk_progress "
|
||||
"WHERE source_id = CAST(:source_id AS uuid)"
|
||||
),
|
||||
{"source_id": str(source_id)},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def bump_heartbeat(self, source_id: str) -> None:
|
||||
"""Refresh ``last_updated`` so the row looks alive to the reconciler."""
|
||||
self._conn.execute(
|
||||
|
||||
@@ -107,7 +107,11 @@ class ReconciliationRepository:
|
||||
def find_and_lock_stalled_ingests(
|
||||
self, *, age_minutes: int = 30, limit: int = 100,
|
||||
) -> list[dict]:
|
||||
"""Lock ingest checkpoints whose heartbeat hasn't ticked recently."""
|
||||
"""Lock still-active ingest checkpoints with a silent heartbeat.
|
||||
|
||||
The ``status = 'active'`` filter skips rows already escalated to
|
||||
``'stalled'``, so a dead ingest is alerted once, not every tick.
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
@@ -116,6 +120,7 @@ class ReconciliationRepository:
|
||||
FROM ingest_chunk_progress
|
||||
WHERE last_updated < now() - make_interval(mins => :age)
|
||||
AND embedded_chunks < total_chunks
|
||||
AND status = 'active'
|
||||
ORDER BY last_updated ASC
|
||||
LIMIT :limit
|
||||
FOR UPDATE SKIP LOCKED
|
||||
@@ -125,11 +130,15 @@ class ReconciliationRepository:
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def touch_ingest_progress(self, source_id: str) -> bool:
|
||||
"""Bump ``last_updated`` so a once-stalled ingest re-enters the watch window."""
|
||||
def mark_ingest_stalled(self, source_id: str) -> bool:
|
||||
"""Escalate a stalled checkpoint to terminal ``status='stalled'``.
|
||||
|
||||
Drops the row out of the sweep so the reconciler alerts once;
|
||||
``init_progress`` flips it back to ``'active'`` on reingest.
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE ingest_chunk_progress SET last_updated = now() "
|
||||
"UPDATE ingest_chunk_progress SET status = 'stalled' "
|
||||
"WHERE source_id = CAST(:sid AS uuid)"
|
||||
),
|
||||
{"sid": str(source_id)},
|
||||
|
||||
@@ -5,10 +5,10 @@ from __future__ import annotations
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import Connection, func, select, text
|
||||
from sqlalchemy import case, Connection, func, select, text
|
||||
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
from application.storage.db.models import sources_table
|
||||
from application.storage.db.models import ingest_chunk_progress_table, sources_table
|
||||
|
||||
|
||||
_SCALAR_COLUMNS = {
|
||||
@@ -61,6 +61,21 @@ def _coerce_jsonb(value: Any) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
def _ingest_status_case():
|
||||
"""Derive a user-facing ingest status from the joined progress row.
|
||||
|
||||
``failed`` — reconciler-escalated stall. ``processing`` — embed in
|
||||
flight. ``None`` — no progress row, or the embed completed.
|
||||
"""
|
||||
icp = ingest_chunk_progress_table
|
||||
return case(
|
||||
(icp.c.source_id.is_(None), None),
|
||||
(icp.c.status == "stalled", "failed"),
|
||||
(icp.c.embedded_chunks < icp.c.total_chunks, "processing"),
|
||||
else_=None,
|
||||
).label("ingest_status")
|
||||
|
||||
|
||||
class SourcesRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
@@ -192,13 +207,25 @@ class SourcesRepository:
|
||||
as ``"desc"``.
|
||||
|
||||
Returns:
|
||||
A list of source rows as plain dicts (via ``row_to_dict``).
|
||||
A list of source rows as plain dicts (via ``row_to_dict``),
|
||||
each carrying a derived ``ingest_status`` (``failed`` /
|
||||
``processing`` / ``None``) from the joined progress row.
|
||||
"""
|
||||
column_name = sort_field if sort_field in _SORTABLE_COLUMNS else "date"
|
||||
sort_column = sources_table.c[column_name]
|
||||
ascending = sort_order.lower() == "asc"
|
||||
|
||||
stmt = select(sources_table).where(sources_table.c.user_id == user_id)
|
||||
stmt = (
|
||||
select(sources_table, _ingest_status_case())
|
||||
.select_from(
|
||||
sources_table.outerjoin(
|
||||
ingest_chunk_progress_table,
|
||||
ingest_chunk_progress_table.c.source_id
|
||||
== sources_table.c.id,
|
||||
)
|
||||
)
|
||||
.where(sources_table.c.user_id == user_id)
|
||||
)
|
||||
if search_term:
|
||||
stmt = stmt.where(
|
||||
sources_table.c.name.ilike(
|
||||
|
||||
@@ -63,7 +63,8 @@ class ToolCallAttemptsRepository:
|
||||
message_id: Optional[str] = None,
|
||||
artifact_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Insert OR upgrade a row to ``executed``.
|
||||
"""Insert OR upgrade a row to ``executed`` — or ``confirmed`` when
|
||||
there is no ``message_id``, as in ``mark_executed``.
|
||||
|
||||
Used as a fallback when ``record_proposed`` failed (DB outage)
|
||||
and the tool ran anyway — preserves the journal so the
|
||||
@@ -72,6 +73,7 @@ class ToolCallAttemptsRepository:
|
||||
result_payload: dict = {"result": result}
|
||||
if artifact_id:
|
||||
result_payload["artifact_id"] = artifact_id
|
||||
status = "executed" if message_id is not None else "confirmed"
|
||||
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
@@ -82,9 +84,9 @@ class ToolCallAttemptsRepository:
|
||||
(:call_id, CAST(:tool_id AS uuid), :tool_name,
|
||||
:action_name, CAST(:arguments AS jsonb),
|
||||
CAST(:result AS jsonb), CAST(:message_id AS uuid),
|
||||
'executed')
|
||||
:status)
|
||||
ON CONFLICT (call_id) DO UPDATE
|
||||
SET status = 'executed',
|
||||
SET status = :status,
|
||||
result = EXCLUDED.result,
|
||||
message_id = COALESCE(EXCLUDED.message_id, tool_call_attempts.message_id)
|
||||
"""
|
||||
@@ -97,6 +99,7 @@ class ToolCallAttemptsRepository:
|
||||
"arguments": json.dumps(arguments if arguments is not None else {}, cls=PGNativeJSONEncoder),
|
||||
"result": json.dumps(result_payload, cls=PGNativeJSONEncoder),
|
||||
"message_id": message_id,
|
||||
"status": status,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -108,7 +111,9 @@ class ToolCallAttemptsRepository:
|
||||
message_id: Optional[str] = None,
|
||||
artifact_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Flip ``proposed`` → ``executed`` with the tool result.
|
||||
"""Flip ``proposed`` → ``executed``, or straight to ``confirmed``
|
||||
when there is no ``message_id`` (a ``save_conversation=False``
|
||||
request reserves no message, so no finalize will confirm it).
|
||||
|
||||
``artifact_id`` (when present) is stored alongside ``result`` in
|
||||
the JSONB as audit data — the reconciler reads it for diagnostic
|
||||
@@ -117,12 +122,14 @@ class ToolCallAttemptsRepository:
|
||||
result_payload: dict = {"result": result}
|
||||
if artifact_id:
|
||||
result_payload["artifact_id"] = artifact_id
|
||||
status = "executed" if message_id is not None else "confirmed"
|
||||
sql = (
|
||||
"UPDATE tool_call_attempts SET "
|
||||
"status = 'executed', result = CAST(:result AS jsonb)"
|
||||
"status = :status, result = CAST(:result AS jsonb)"
|
||||
)
|
||||
params: dict[str, Any] = {
|
||||
"call_id": call_id,
|
||||
"status": status,
|
||||
"result": json.dumps(result_payload, cls=PGNativeJSONEncoder),
|
||||
}
|
||||
if message_id is not None:
|
||||
|
||||
@@ -29,7 +29,10 @@ from application.parser.embedding_pipeline import (
|
||||
)
|
||||
from application.parser.file.bulk import SimpleDirectoryReader, get_default_file_extractor
|
||||
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
|
||||
from application.parser.remote.remote_creator import RemoteCreator
|
||||
from application.parser.remote.remote_creator import (
|
||||
RemoteCreator,
|
||||
normalize_remote_data,
|
||||
)
|
||||
from application.parser.schema.base import Document
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
|
||||
@@ -97,6 +100,40 @@ def _stop_ingest_heartbeat(thread, stop_event):
|
||||
thread.join(timeout=5)
|
||||
|
||||
|
||||
def _make_parse_progress_callback(task, user, source_id, start_pct, end_pct):
|
||||
"""Build a ``load_data`` callback mapping parse progress to
|
||||
``[start_pct, end_pct]`` via ``update_state`` + a throttled
|
||||
``stage='parsing'`` SSE event.
|
||||
"""
|
||||
span = end_pct - start_pct
|
||||
source_id_str = str(source_id)
|
||||
state = {"last_pct": -1}
|
||||
|
||||
def _callback(files_done, total_files):
|
||||
if not total_files:
|
||||
return
|
||||
pct = start_pct + int((files_done / total_files) * span)
|
||||
task.update_state(
|
||||
state="PROGRESS",
|
||||
meta={"current": pct, "status": "Parsing files"},
|
||||
)
|
||||
if user and pct > state["last_pct"]:
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.progress",
|
||||
{
|
||||
"current": pct,
|
||||
"total": total_files,
|
||||
"files_done": files_done,
|
||||
"stage": "parsing",
|
||||
},
|
||||
scope={"kind": "source", "id": source_id_str},
|
||||
)
|
||||
state["last_pct"] = pct
|
||||
|
||||
return _callback
|
||||
|
||||
|
||||
# Define a function to extract metadata from a given filename.
|
||||
|
||||
|
||||
@@ -637,7 +674,12 @@ def ingest_worker(
|
||||
exclude_hidden=exclude,
|
||||
file_metadata=metadata_from_filename,
|
||||
)
|
||||
raw_docs = reader.load_data()
|
||||
# Parsing/OCR owns 1-50% of the bar; embedding takes 50-100%.
|
||||
raw_docs = reader.load_data(
|
||||
progress_callback=_make_parse_progress_callback(
|
||||
self, user, source_uuid, start_pct=1, end_pct=50,
|
||||
)
|
||||
)
|
||||
|
||||
directory_structure = getattr(reader, "directory_structure", {})
|
||||
logging.info(f"Directory structure from reader: {directory_structure}")
|
||||
@@ -677,6 +719,7 @@ def ingest_worker(
|
||||
docs, vector_store_path, source_uuid, self,
|
||||
attempt_id=getattr(self.request, "id", None),
|
||||
user_id=user,
|
||||
progress_start=50, progress_end=100,
|
||||
)
|
||||
finally:
|
||||
_stop_ingest_heartbeat(heartbeat_thread, heartbeat_stop)
|
||||
@@ -807,6 +850,8 @@ def reingest_source_worker(self, source_id, user):
|
||||
{
|
||||
"source_id": source_id,
|
||||
"name": source_name,
|
||||
# ``filename`` labels the upload toast on auto-create.
|
||||
"filename": source_name,
|
||||
"operation": "reingest",
|
||||
},
|
||||
scope={"kind": "source", "id": source_id},
|
||||
@@ -914,6 +959,7 @@ def reingest_source_worker(self, source_id, user):
|
||||
{
|
||||
"source_id": source_id,
|
||||
"name": source_name,
|
||||
"filename": source_name,
|
||||
"operation": "reingest",
|
||||
"no_changes": True,
|
||||
"chunks_added": 0,
|
||||
@@ -1101,6 +1147,7 @@ def reingest_source_worker(self, source_id, user):
|
||||
completed_payload: dict = {
|
||||
"source_id": source_id,
|
||||
"name": source_name,
|
||||
"filename": source_name,
|
||||
"operation": "reingest",
|
||||
"chunks_added": added,
|
||||
"chunks_deleted": deleted,
|
||||
@@ -1140,6 +1187,7 @@ def reingest_source_worker(self, source_id, user):
|
||||
{
|
||||
"source_id": str(source_id),
|
||||
"name": source_name,
|
||||
"filename": source_name,
|
||||
"operation": "reingest",
|
||||
"error": str(e)[:1024],
|
||||
},
|
||||
@@ -1431,19 +1479,35 @@ def sync_worker(self, frequency):
|
||||
name = doc.get("name")
|
||||
user = doc.get("user_id")
|
||||
source_type = doc.get("type")
|
||||
source_data = doc.get("remote_data")
|
||||
retriever = doc.get("retriever")
|
||||
doc_id = str(doc.get("id"))
|
||||
|
||||
sync_counts["total_sync_count"] += 1
|
||||
|
||||
# Connector sources have no RemoteCreator loader and need an OAuth
|
||||
# token to sync, which a scheduled task lacks — skip them.
|
||||
if source_type and source_type.startswith("connector"):
|
||||
sync_counts["sync_skipped"] += 1
|
||||
continue
|
||||
|
||||
source_data = normalize_remote_data(source_type, doc.get("remote_data"))
|
||||
if not source_data:
|
||||
# No syncable URL/config — skip instead of dispatching a sync
|
||||
# that can only fail (and emit a spurious failed event).
|
||||
sync_counts["sync_skipped"] += 1
|
||||
continue
|
||||
|
||||
resp = sync(
|
||||
self, source_data, name, user, source_type, frequency, retriever, doc_id
|
||||
)
|
||||
sync_counts["total_sync_count"] += 1
|
||||
sync_counts[
|
||||
"sync_success" if resp["status"] == "success" else "sync_failure"
|
||||
] += 1
|
||||
return {
|
||||
key: sync_counts[key]
|
||||
for key in ["total_sync_count", "sync_success", "sync_failure"]
|
||||
for key in [
|
||||
"total_sync_count", "sync_success", "sync_failure", "sync_skipped",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -1785,14 +1849,15 @@ def ingest_connector(
|
||||
exclude_hidden=True,
|
||||
file_metadata=metadata_from_filename,
|
||||
)
|
||||
raw_docs = reader.load_data()
|
||||
# Parsing/OCR fills 40-60% of the bar; embedding takes 60-100%.
|
||||
raw_docs = reader.load_data(
|
||||
progress_callback=_make_parse_progress_callback(
|
||||
self, user, source_uuid, start_pct=40, end_pct=60,
|
||||
)
|
||||
)
|
||||
directory_structure = getattr(reader, "directory_structure", {})
|
||||
|
||||
# Step 4: Process documents (chunking, embedding, etc.)
|
||||
self.update_state(
|
||||
state="PROGRESS", meta={"current": 60, "status": "Processing documents"}
|
||||
)
|
||||
|
||||
chunker = Chunker(
|
||||
chunking_strategy="classic_chunk",
|
||||
max_tokens=MAX_TOKENS,
|
||||
@@ -1829,12 +1894,13 @@ def ingest_connector(
|
||||
os.makedirs(vector_store_path, exist_ok=True)
|
||||
|
||||
self.update_state(
|
||||
state="PROGRESS", meta={"current": 80, "status": "Storing documents"}
|
||||
state="PROGRESS", meta={"current": 60, "status": "Storing documents"}
|
||||
)
|
||||
embed_and_store_documents(
|
||||
docs, vector_store_path, source_uuid, self,
|
||||
attempt_id=getattr(self.request, "id", None),
|
||||
user_id=user,
|
||||
progress_start=60, progress_end=100,
|
||||
)
|
||||
assert_index_complete(source_uuid)
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ const endpoints = {
|
||||
LOGS: `/api/get_user_logs`,
|
||||
MANAGE_SYNC: '/api/manage_sync',
|
||||
SYNC_SOURCE: '/api/sync_source',
|
||||
REINGEST_SOURCE: '/api/sources/reingest',
|
||||
GET_AVAILABLE_TOOLS: '/api/available_tools',
|
||||
GET_USER_TOOLS: '/api/get_tools',
|
||||
CREATE_TOOL: '/api/create_tool',
|
||||
|
||||
@@ -73,6 +73,8 @@ const userService = {
|
||||
apiClient.post(endpoints.USER.MANAGE_SYNC, data, token),
|
||||
syncSource: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.SYNC_SOURCE, data, token),
|
||||
reingestSource: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.REINGEST_SOURCE, data, token),
|
||||
getAvailableTools: (token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.GET_AVAILABLE_TOOLS, token),
|
||||
getUserTools: (token: string | null): Promise<any> =>
|
||||
|
||||
@@ -165,12 +165,19 @@ function UploadRow({
|
||||
return (
|
||||
<li className="border-border/50 border-b last:border-b-0">
|
||||
<div className="flex items-center justify-between px-5 py-3">
|
||||
<p
|
||||
className="font-inter dark:text-muted-foreground max-w-[200px] truncate text-[13px] leading-[16.5px] font-normal text-black"
|
||||
title={task.fileName}
|
||||
>
|
||||
{task.fileName}
|
||||
</p>
|
||||
<div className="flex min-w-0 flex-col">
|
||||
<p
|
||||
className="font-inter dark:text-muted-foreground max-w-[200px] truncate text-[13px] leading-[16.5px] font-normal text-black"
|
||||
title={task.fileName}
|
||||
>
|
||||
{task.fileName}
|
||||
</p>
|
||||
{task.status === 'training' && task.stage && (
|
||||
<span className="font-inter text-muted-foreground mt-0.5 text-[11px] leading-[14px]">
|
||||
{t(`modals.uploadDoc.progress.${task.stage}`)}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex items-center gap-2">
|
||||
{showProgress && (
|
||||
|
||||
@@ -154,10 +154,10 @@ const ConversationBubble = forwardRef<
|
||||
<img
|
||||
src={DocumentationDark}
|
||||
alt="Attachment"
|
||||
className="h-[15px] w-[15px] object-fill"
|
||||
className="h-3.75 w-3.75 object-fill"
|
||||
/>
|
||||
</div>
|
||||
<span className="max-w-[150px] truncate font-normal">
|
||||
<span className="max-w-37.5 truncate font-normal">
|
||||
{file.fileName}
|
||||
</span>
|
||||
</div>
|
||||
@@ -328,7 +328,7 @@ const ConversationBubble = forwardRef<
|
||||
<div className="mb-4 flex flex-col flex-wrap items-start self-start lg:flex-nowrap">
|
||||
<div className="my-2 flex flex-row items-center justify-center gap-3">
|
||||
<Avatar
|
||||
className="h-[26px] w-[30px] text-xl"
|
||||
className="h-6.5 w-7.5 text-xl"
|
||||
avatar={
|
||||
<img
|
||||
src={Sources}
|
||||
@@ -376,7 +376,7 @@ const ConversationBubble = forwardRef<
|
||||
<img
|
||||
src={Document}
|
||||
alt="Document"
|
||||
className="h-[17px] w-[17px] object-fill"
|
||||
className="h-4.25 w-4.25 object-fill"
|
||||
/>
|
||||
<p
|
||||
className="mt-0.5 truncate text-xs"
|
||||
@@ -394,11 +394,11 @@ const ConversationBubble = forwardRef<
|
||||
</div>
|
||||
{activeTooltip === index && (
|
||||
<div
|
||||
className={`dark:bg-card dark:text-foreground absolute left-1/2 z-50 max-h-48 w-40 translate-x-[-50%] translate-y-[3px] rounded-xl bg-[#FBFBFB] p-4 text-black shadow-xl sm:w-56`}
|
||||
className={`dark:bg-card dark:text-foreground absolute left-1/2 z-50 max-h-48 w-40 translate-x-[-50%] translate-y-0.75 rounded-xl bg-[#FBFBFB] p-4 text-black shadow-xl sm:w-56`}
|
||||
onMouseOver={() => setActiveTooltip(index)}
|
||||
onMouseOut={() => setActiveTooltip(null)}
|
||||
>
|
||||
<p className="line-clamp-6 max-h-[164px] overflow-hidden rounded-md text-sm wrap-break-word text-ellipsis">
|
||||
<p className="line-clamp-6 max-h-41 overflow-hidden rounded-md text-sm wrap-break-word text-ellipsis">
|
||||
{source.text}
|
||||
</p>
|
||||
</div>
|
||||
@@ -471,7 +471,7 @@ const ConversationBubble = forwardRef<
|
||||
<div className="flex max-w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
|
||||
<div className="my-2 flex flex-row items-center justify-center gap-3">
|
||||
<Avatar
|
||||
className="h-[34px] w-[34px] text-2xl"
|
||||
className="h-8.5 w-8.5 text-2xl"
|
||||
avatar={
|
||||
<img
|
||||
src={DocsGPT3}
|
||||
@@ -1023,7 +1023,7 @@ function ToolCalls({
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="mb-4 flex w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
|
||||
<div className="mb-4 relative flex w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
|
||||
{/* Approval bars — always visible, compact inline */}
|
||||
{awaitingCalls.length > 0 && (
|
||||
<div className="fade-in mt-4 ml-3 w-[90vw] md:w-[70vw] lg:w-full">
|
||||
@@ -1042,7 +1042,7 @@ function ToolCalls({
|
||||
<>
|
||||
<div className="my-2 flex flex-row items-center justify-center gap-3">
|
||||
<Avatar
|
||||
className="h-[26px] w-[30px] text-xl"
|
||||
className="h-6.5 w-7.5 text-xl"
|
||||
avatar={
|
||||
<img
|
||||
src={Sources}
|
||||
@@ -1089,7 +1089,7 @@ function ToolCalls({
|
||||
</p>
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="dark:text-muted-foreground leading-[23px] text-black"
|
||||
className="dark:text-muted-foreground leading-5.75 text-black"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{JSON.stringify(toolCall.arguments, null, 2)}
|
||||
@@ -1117,7 +1117,7 @@ function ToolCalls({
|
||||
{toolCall.status === 'completed' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="dark:text-muted-foreground leading-[23px] text-black"
|
||||
className="dark:text-muted-foreground leading-5.75 text-black"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{JSON.stringify(toolCall.result, null, 2)}
|
||||
@@ -1127,7 +1127,7 @@ function ToolCalls({
|
||||
{toolCall.status === 'error' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="text-destructive leading-[23px]"
|
||||
className="text-destructive leading-5.75"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{toolCall.error}
|
||||
@@ -1137,7 +1137,7 @@ function ToolCalls({
|
||||
{toolCall.status === 'denied' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="text-muted-foreground leading-[23px]"
|
||||
className="text-muted-foreground leading-5.75"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
Denied by user
|
||||
@@ -1172,7 +1172,7 @@ function Thought({
|
||||
<div className="mb-4 flex w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
|
||||
<div className="my-2 flex flex-row items-center justify-center gap-3">
|
||||
<Avatar
|
||||
className="h-[26px] w-[30px] text-xl"
|
||||
className="h-6.5 w-7.5 text-xl"
|
||||
avatar={
|
||||
<img
|
||||
src={Cloud}
|
||||
@@ -1197,7 +1197,7 @@ function Thought({
|
||||
</div>
|
||||
{isThoughtOpen && (
|
||||
<div className="fade-in mr-5 ml-2 max-w-[90vw] md:max-w-[70vw] lg:max-w-[50vw]">
|
||||
<div className="bg-muted dark:bg-answer-bubble rounded-[28px] px-7 py-[18px]">
|
||||
<div className="bg-muted dark:bg-answer-bubble rounded-[28px] px-7 py-4.5">
|
||||
<ReactMarkdown
|
||||
className="fade-in leading-normal wrap-break-word whitespace-pre-wrap"
|
||||
remarkPlugins={[remarkGfm, remarkMath]}
|
||||
|
||||
@@ -1,10 +1,4 @@
|
||||
import {
|
||||
useCallback,
|
||||
useEffect,
|
||||
useRef,
|
||||
useState,
|
||||
RefObject,
|
||||
} from 'react';
|
||||
import { useCallback, useEffect, useRef, useState, RefObject } from 'react';
|
||||
|
||||
export function useOutsideAlerter<T extends HTMLElement>(
|
||||
ref: RefObject<T | null>,
|
||||
|
||||
@@ -70,6 +70,9 @@
|
||||
"sync": "Synchronisieren",
|
||||
"syncNow": "Jetzt synchronisieren",
|
||||
"syncing": "Synchronisiere...",
|
||||
"reingest": "Erneut indexieren",
|
||||
"ingestFailed": "Indexierung fehlgeschlagen",
|
||||
"ingestProcessing": "Indexierung...",
|
||||
"syncConfirmation": "Bist du sicher, dass du \"{{sourceName}}\" synchronisieren möchtest? Dies aktualisiert den Inhalt mit deinem Cloud-Speicher und kann Änderungen an einzelnen Chunks überschreiben.",
|
||||
"syncFrequency": {
|
||||
"never": "Nie",
|
||||
@@ -353,6 +356,8 @@
|
||||
"failed": "Upload fehlgeschlagen",
|
||||
"wait": "Dies kann einige Minuten dauern",
|
||||
"preparing": "Upload wird vorbereitet",
|
||||
"parsing": "Dateien werden verarbeitet...",
|
||||
"embedding": "Einbettung...",
|
||||
"tokenLimit": "Token-Limit überschritten, bitte lade ein kleineres Dokument hoch",
|
||||
"expandDetails": "Upload-Details erweitern",
|
||||
"collapseDetails": "Upload-Details einklappen",
|
||||
|
||||
@@ -70,6 +70,9 @@
|
||||
"sync": "Sync",
|
||||
"syncNow": "Sync now",
|
||||
"syncing": "Syncing...",
|
||||
"reingest": "Reingest",
|
||||
"ingestFailed": "Indexing failed",
|
||||
"ingestProcessing": "Indexing…",
|
||||
"syncConfirmation": "Are you sure you want to sync \"{{sourceName}}\"? This will update the content with your cloud storage and may override any edits you made to individual chunks.",
|
||||
"syncFrequency": {
|
||||
"never": "Never",
|
||||
@@ -365,6 +368,8 @@
|
||||
"failed": "Upload failed",
|
||||
"wait": "This may take several minutes",
|
||||
"preparing": "Preparing upload",
|
||||
"parsing": "Parsing files…",
|
||||
"embedding": "Embedding…",
|
||||
"tokenLimit": "Over the token limit, please consider uploading smaller document",
|
||||
"expandDetails": "Expand upload details",
|
||||
"collapseDetails": "Collapse upload details",
|
||||
|
||||
@@ -70,6 +70,9 @@
|
||||
"sync": "Sincronizar",
|
||||
"syncNow": "Sincronizar ahora",
|
||||
"syncing": "Sincronizando...",
|
||||
"reingest": "Reindexar",
|
||||
"ingestFailed": "Error de indexación",
|
||||
"ingestProcessing": "Indexando...",
|
||||
"syncConfirmation": "¿Estás seguro de que deseas sincronizar \"{{sourceName}}\"? Esto actualizará el contenido con tu almacenamiento en la nube y puede anular cualquier edición que hayas realizado en fragmentos individuales.",
|
||||
"syncFrequency": {
|
||||
"never": "Nunca",
|
||||
@@ -353,6 +356,8 @@
|
||||
"failed": "Error al subir",
|
||||
"wait": "Esto puede tardar varios minutos",
|
||||
"preparing": "Preparando subida",
|
||||
"parsing": "Analizando archivos...",
|
||||
"embedding": "Generando incrustaciones...",
|
||||
"tokenLimit": "Excede el límite de tokens, considere cargar un documento más pequeño",
|
||||
"expandDetails": "Expandir detalles de subida",
|
||||
"collapseDetails": "Contraer detalles de subida",
|
||||
|
||||
@@ -70,6 +70,9 @@
|
||||
"sync": "同期",
|
||||
"syncNow": "今すぐ同期",
|
||||
"syncing": "同期中...",
|
||||
"reingest": "再インデックス",
|
||||
"ingestFailed": "インデックス作成に失敗しました",
|
||||
"ingestProcessing": "インデックス作成中...",
|
||||
"syncConfirmation": "\"{{sourceName}}\"を同期してもよろしいですか?これにより、コンテンツがクラウドストレージで更新され、個々のチャンクに加えた編集が上書きされる可能性があります。",
|
||||
"syncFrequency": {
|
||||
"never": "なし",
|
||||
@@ -353,6 +356,8 @@
|
||||
"failed": "アップロード失敗",
|
||||
"wait": "数分かかる場合があります",
|
||||
"preparing": "アップロードを準備中",
|
||||
"parsing": "ファイルを解析中...",
|
||||
"embedding": "埋め込み処理中...",
|
||||
"tokenLimit": "トークン制限を超えています。より小さいドキュメントをアップロードしてください",
|
||||
"expandDetails": "アップロードの詳細を展開",
|
||||
"collapseDetails": "アップロードの詳細を折りたたむ",
|
||||
|
||||
@@ -70,6 +70,9 @@
|
||||
"sync": "Синхронизация",
|
||||
"syncNow": "Синхронизировать сейчас",
|
||||
"syncing": "Синхронизация...",
|
||||
"reingest": "Переиндексировать",
|
||||
"ingestFailed": "Ошибка индексации",
|
||||
"ingestProcessing": "Индексация...",
|
||||
"syncConfirmation": "Вы уверены, что хотите синхронизировать \"{{sourceName}}\"? Это обновит содержимое с вашим облачным хранилищем и может перезаписать любые изменения, внесенные вами в отдельные фрагменты.",
|
||||
"syncFrequency": {
|
||||
"never": "Никогда",
|
||||
@@ -353,6 +356,8 @@
|
||||
"failed": "Ошибка загрузки",
|
||||
"wait": "Это может занять несколько минут",
|
||||
"preparing": "Подготовка загрузки",
|
||||
"parsing": "Обработка файлов...",
|
||||
"embedding": "Создание эмбеддингов...",
|
||||
"tokenLimit": "Превышен лимит токенов, рассмотрите возможность загрузки документа меньшего размера",
|
||||
"expandDetails": "Развернуть детали загрузки",
|
||||
"collapseDetails": "Свернуть детали загрузки",
|
||||
|
||||
@@ -70,6 +70,9 @@
|
||||
"sync": "同步",
|
||||
"syncNow": "立即同步",
|
||||
"syncing": "同步中...",
|
||||
"reingest": "重新索引",
|
||||
"ingestFailed": "索引失敗",
|
||||
"ingestProcessing": "索引中...",
|
||||
"syncConfirmation": "您確定要同步 \"{{sourceName}}\" 嗎?這將使用您的雲端儲存更新內容,並可能覆蓋您對個別文本塊所做的任何編輯。",
|
||||
"syncFrequency": {
|
||||
"never": "從不",
|
||||
@@ -353,6 +356,8 @@
|
||||
"failed": "上傳失敗",
|
||||
"wait": "這可能需要幾分鐘",
|
||||
"preparing": "準備上傳",
|
||||
"parsing": "正在解析檔案...",
|
||||
"embedding": "正在生成嵌入...",
|
||||
"tokenLimit": "超出令牌限制,請考慮上傳較小的文檔",
|
||||
"expandDetails": "展開上傳詳情",
|
||||
"collapseDetails": "摺疊上傳詳情",
|
||||
|
||||
@@ -70,6 +70,9 @@
|
||||
"sync": "同步",
|
||||
"syncNow": "立即同步",
|
||||
"syncing": "同步中...",
|
||||
"reingest": "重新索引",
|
||||
"ingestFailed": "索引失败",
|
||||
"ingestProcessing": "索引中...",
|
||||
"syncConfirmation": "您确定要同步 \"{{sourceName}}\" 吗?这将使用您的云存储更新内容,并可能覆盖您对单个文本块所做的任何编辑。",
|
||||
"syncFrequency": {
|
||||
"never": "从不",
|
||||
@@ -353,6 +356,8 @@
|
||||
"failed": "上传失败",
|
||||
"wait": "这可能需要几分钟",
|
||||
"preparing": "准备上传",
|
||||
"parsing": "正在解析文件...",
|
||||
"embedding": "正在生成嵌入...",
|
||||
"tokenLimit": "超出令牌限制,请考虑上传较小的文档",
|
||||
"expandDetails": "展开上传详情",
|
||||
"collapseDetails": "折叠上传详情",
|
||||
|
||||
@@ -14,6 +14,8 @@ export type Doc = {
|
||||
syncFrequency?: string;
|
||||
isNested?: boolean;
|
||||
provider?: string;
|
||||
// Derived server-side from ingest_chunk_progress (sources API).
|
||||
ingestStatus?: 'processing' | 'failed';
|
||||
};
|
||||
|
||||
export type GetDocsResponse = {
|
||||
|
||||
@@ -27,6 +27,12 @@ import {
|
||||
setSourceDocs,
|
||||
} from '../preferences/preferenceSlice';
|
||||
import Upload from '../upload/Upload';
|
||||
import {
|
||||
addUploadTask,
|
||||
removeUploadTask,
|
||||
selectUploadTasks,
|
||||
updateUploadTask,
|
||||
} from '../upload/uploadSlice';
|
||||
import { formatDate } from '../utils/dateTimeUtils';
|
||||
import FileTree from '../components/FileTree';
|
||||
import ConnectorTree from '../components/ConnectorTree';
|
||||
@@ -56,6 +62,7 @@ export default function Sources({
|
||||
const [isDarkTheme] = useDarkTheme();
|
||||
const dispatch = useDispatch();
|
||||
const token = useSelector(selectToken);
|
||||
const uploadTasks = useSelector(selectUploadTasks);
|
||||
|
||||
const [searchTerm, setSearchTerm] = useState<string>('');
|
||||
const debouncedSearchTerm = useDebouncedValue(searchTerm, 500);
|
||||
@@ -249,6 +256,57 @@ export default function Sources({
|
||||
}
|
||||
};
|
||||
|
||||
const handleReingest = async (doc: Doc) => {
|
||||
if (!doc.id) {
|
||||
return;
|
||||
}
|
||||
const sourceId = doc.id;
|
||||
// Drop stale toast rows for this source (a finished/dismissed task
|
||||
// would swallow the reingest's SSE events), then open a fresh one.
|
||||
uploadTasks
|
||||
.filter((task) => task.sourceId === sourceId)
|
||||
.forEach((task) => dispatch(removeUploadTask(task.id)));
|
||||
const reingestTaskId = `reingest-${sourceId}-${Date.now()}`;
|
||||
dispatch(
|
||||
addUploadTask({
|
||||
id: reingestTaskId,
|
||||
fileName: doc.name || sourceId,
|
||||
progress: 0,
|
||||
status: 'training',
|
||||
sourceId,
|
||||
}),
|
||||
);
|
||||
try {
|
||||
const response = await userService.reingestSource(
|
||||
{ source_id: sourceId },
|
||||
token,
|
||||
);
|
||||
const data = await response.json();
|
||||
if (!data.success) {
|
||||
console.error('Reingest failed:', data.error || data.message);
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: reingestTaskId,
|
||||
updates: {
|
||||
status: 'failed',
|
||||
errorMessage: data.error || data.message,
|
||||
},
|
||||
}),
|
||||
);
|
||||
return;
|
||||
}
|
||||
refreshDocs(undefined, currentPage, rowsPerPage);
|
||||
} catch (error) {
|
||||
console.error('Error reingesting source:', error);
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: reingestTaskId,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
const [documentToDelete, setDocumentToDelete] = useState<{
|
||||
index: number;
|
||||
document: Doc;
|
||||
@@ -283,6 +341,19 @@ export default function Sources({
|
||||
},
|
||||
];
|
||||
|
||||
if (document.ingestStatus === 'failed') {
|
||||
actions.push({
|
||||
icon: SyncIcon,
|
||||
label: t('settings.sources.reingest'),
|
||||
onClick: () => {
|
||||
handleReingest(document);
|
||||
},
|
||||
iconWidth: 14,
|
||||
iconHeight: 14,
|
||||
variant: 'primary',
|
||||
});
|
||||
}
|
||||
|
||||
if (document.syncFrequency) {
|
||||
actions.push({
|
||||
icon: SyncIcon,
|
||||
@@ -483,6 +554,16 @@ export default function Sources({
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col items-start justify-start gap-1">
|
||||
{document.ingestStatus === 'failed' && (
|
||||
<span className="rounded-full bg-red-100 px-2 py-0.5 text-[11px] leading-[16px] font-medium text-red-700 dark:bg-red-900/30 dark:text-red-400">
|
||||
{t('settings.sources.ingestFailed')}
|
||||
</span>
|
||||
)}
|
||||
{document.ingestStatus === 'processing' && (
|
||||
<span className="bg-muted-foreground/10 text-muted-foreground rounded-full px-2 py-0.5 text-[11px] leading-[16px] font-medium">
|
||||
{t('settings.sources.ingestProcessing')}
|
||||
</span>
|
||||
)}
|
||||
<div className="flex items-center gap-2">
|
||||
<img
|
||||
src={CalendarIcon}
|
||||
|
||||
@@ -286,6 +286,26 @@ describe('source.ingest.progress', () => {
|
||||
state = reducer(state, ingest('source.ingest.progress', { current: -10 }));
|
||||
expect(state.tasks[0].progress).toBe(100);
|
||||
});
|
||||
|
||||
it('records the ingest stage from the payload', () => {
|
||||
let state = stateWithTask(makeTask({ status: 'training' }));
|
||||
state = reducer(
|
||||
state,
|
||||
ingest('source.ingest.progress', { current: 20, stage: 'parsing' }),
|
||||
);
|
||||
expect(state.tasks[0].stage).toBe('parsing');
|
||||
state = reducer(
|
||||
state,
|
||||
ingest('source.ingest.progress', { current: 70, stage: 'embedding' }),
|
||||
);
|
||||
expect(state.tasks[0].stage).toBe('embedding');
|
||||
// An unknown/absent stage leaves the last known value intact.
|
||||
state = reducer(
|
||||
state,
|
||||
ingest('source.ingest.progress', { current: 80, stage: 'bogus' }),
|
||||
);
|
||||
expect(state.tasks[0].stage).toBe('embedding');
|
||||
});
|
||||
});
|
||||
|
||||
describe('source.ingest.completed', () => {
|
||||
|
||||
@@ -66,6 +66,12 @@ export interface UploadTask {
|
||||
sourceId?: string;
|
||||
errorMessage?: string;
|
||||
dismissed?: boolean;
|
||||
/**
|
||||
* Ingest phase from the latest ``source.ingest.progress`` event:
|
||||
* ``parsing`` (parse/OCR, lower band of the bar) or ``embedding``
|
||||
* (upper band). Drives the phase label in ``UploadToast``.
|
||||
*/
|
||||
stage?: 'parsing' | 'embedding';
|
||||
/**
|
||||
* Flipped when ``source.ingest.completed`` carries
|
||||
* ``payload.limited === true`` (the worker hit a token cap during
|
||||
@@ -334,6 +340,9 @@ export const uploadSlice = createSlice({
|
||||
if (task.status === 'completed' || task.status === 'failed') break;
|
||||
task.status = 'training';
|
||||
if (clamped > task.progress) task.progress = clamped;
|
||||
if (payload.stage === 'parsing' || payload.stage === 'embedding') {
|
||||
task.stage = payload.stage;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'source.ingest.completed':
|
||||
|
||||
@@ -4,19 +4,24 @@ Fixed 5-second generation (100 tokens × 50 ms/token). No auth. Emits SSE
|
||||
chunks in OpenAI's chat.completions streaming format, or a single response
|
||||
when stream=false. Run on 127.0.0.1:8090 — point DocsGPT at it via
|
||||
OPENAI_BASE_URL=http://127.0.0.1:8090/v1.
|
||||
|
||||
Flags:
|
||||
--tool-calls First response returns a tool call instead of text.
|
||||
Subsequent responses (after a tool_result) return text.
|
||||
Useful for triggering the tool-execution loop.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from flask import Flask, Response, request, jsonify
|
||||
|
||||
TOKEN_COUNT = 100
|
||||
TOKEN_DELAY_S = 0.05 # 100 * 0.05 = 5.0 s
|
||||
TOOL_CALL_MODE = False
|
||||
|
||||
logger = logging.getLogger("mock_llm")
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s mock: %(message)s")
|
||||
@@ -39,7 +44,7 @@ FILLER_TOKENS = [
|
||||
".",
|
||||
]
|
||||
|
||||
app = FastAPI()
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
def _token_stream_id() -> str:
|
||||
@@ -63,11 +68,57 @@ def _sse_chunk(completion_id: str, model: str, delta: dict, finish_reason=None)
|
||||
return f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
|
||||
async def _stream_response(model: str, req_id: str):
|
||||
def _gen_tool_call_stream(model: str, req_id: str):
|
||||
"""Emit two tool_calls (search) in streaming format.
|
||||
|
||||
Two calls ensure the handler executes the first (which can return a
|
||||
huge result), then hits _check_context_limit before the second.
|
||||
"""
|
||||
completion_id = _token_stream_id()
|
||||
call_id_1 = f"call_{uuid.uuid4().hex[:12]}"
|
||||
call_id_2 = f"call_{uuid.uuid4().hex[:12]}"
|
||||
|
||||
yield _sse_chunk(completion_id, model, {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": 0,
|
||||
"id": call_id_1,
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": ""},
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"id": call_id_2,
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": ""},
|
||||
},
|
||||
],
|
||||
})
|
||||
args_json = json.dumps({"query": "Python programming basics"})
|
||||
for ch in args_json:
|
||||
time.sleep(TOKEN_DELAY_S)
|
||||
yield _sse_chunk(completion_id, model, {
|
||||
"tool_calls": [
|
||||
{"index": 0, "function": {"arguments": ch}},
|
||||
{"index": 1, "function": {"arguments": ch}},
|
||||
],
|
||||
})
|
||||
yield _sse_chunk(completion_id, model, {}, finish_reason="tool_calls")
|
||||
yield "data: [DONE]\n\n"
|
||||
logger.info("[%s] tool_call stream done (ids=%s, %s)", req_id, call_id_1, call_id_2)
|
||||
|
||||
|
||||
def _has_tool_result(messages: list) -> bool:
|
||||
return any(m.get("role") == "tool" for m in messages)
|
||||
|
||||
|
||||
def _gen_text_stream(model: str, req_id: str):
|
||||
completion_id = _token_stream_id()
|
||||
yield _sse_chunk(completion_id, model, {"role": "assistant", "content": ""})
|
||||
for i, tok in enumerate(FILLER_TOKENS[:TOKEN_COUNT]):
|
||||
await asyncio.sleep(TOKEN_DELAY_S)
|
||||
for tok in FILLER_TOKENS[:TOKEN_COUNT]:
|
||||
time.sleep(TOKEN_DELAY_S)
|
||||
yield _sse_chunk(completion_id, model, {"content": tok})
|
||||
yield _sse_chunk(completion_id, model, {}, finish_reason="stop")
|
||||
yield "data: [DONE]\n\n"
|
||||
@@ -75,63 +126,84 @@ async def _stream_response(model: str, req_id: str):
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: Request):
|
||||
body = await request.json()
|
||||
def chat_completions():
|
||||
body = request.get_json(force=True)
|
||||
model = body.get("model", "mock")
|
||||
stream = bool(body.get("stream", False))
|
||||
messages = body.get("messages", [])
|
||||
tools = body.get("tools")
|
||||
req_id = uuid.uuid4().hex[:8]
|
||||
logger.info("[%s] /chat/completions stream=%s model=%s max_tokens=%s", req_id, stream, model, body.get("max_tokens"))
|
||||
logger.info(
|
||||
"[%s] /chat/completions stream=%s model=%s tools=%s msgs=%d",
|
||||
req_id, stream, model, bool(tools), len(messages),
|
||||
)
|
||||
|
||||
use_tool_call = (
|
||||
TOOL_CALL_MODE
|
||||
and tools
|
||||
and not _has_tool_result(messages)
|
||||
)
|
||||
|
||||
if stream:
|
||||
return StreamingResponse(
|
||||
_stream_response(model, req_id),
|
||||
media_type="text/event-stream",
|
||||
gen = (
|
||||
_gen_tool_call_stream(model, req_id) if use_tool_call
|
||||
else _gen_text_stream(model, req_id)
|
||||
)
|
||||
return Response(
|
||||
gen,
|
||||
mimetype="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
await asyncio.sleep(TOKEN_COUNT * TOKEN_DELAY_S)
|
||||
time.sleep(TOKEN_COUNT * TOKEN_DELAY_S)
|
||||
logger.info("[%s] non-stream done", req_id)
|
||||
text = "".join(FILLER_TOKENS[:TOKEN_COUNT])
|
||||
completion_id = _token_stream_id()
|
||||
return JSONResponse(
|
||||
{
|
||||
"id": completion_id,
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": text},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": TOKEN_COUNT,
|
||||
"total_tokens": 10 + TOKEN_COUNT,
|
||||
},
|
||||
}
|
||||
)
|
||||
return jsonify({
|
||||
"id": completion_id,
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": text},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": TOKEN_COUNT,
|
||||
"total_tokens": 10 + TOKEN_COUNT,
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def list_models():
|
||||
return {
|
||||
def list_models():
|
||||
return jsonify({
|
||||
"object": "list",
|
||||
"data": [{"id": "mock", "object": "model", "owned_by": "mock"}],
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
def health():
|
||||
return jsonify({"status": "ok"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="127.0.0.1", port=8090, log_level="info")
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--tool-calls", action="store_true",
|
||||
help="First response returns a tool_call; subsequent responses return text.",
|
||||
)
|
||||
parser.add_argument("--port", type=int, default=8090)
|
||||
args = parser.parse_args()
|
||||
TOOL_CALL_MODE = args.tool_calls
|
||||
if TOOL_CALL_MODE:
|
||||
logger.info("Tool-call mode enabled")
|
||||
app.run(host="127.0.0.1", port=args.port, debug=False, threaded=True)
|
||||
|
||||
@@ -45,15 +45,14 @@ class TestAPIToolInit:
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMakeApiCall:
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_successful_get(self, mock_get, mock_validate, tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_successful_get(self, mock_pinned, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"result": "ok"}
|
||||
mock_resp.content = b'{"result":"ok"}'
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("any_action")
|
||||
|
||||
@@ -61,54 +60,50 @@ class TestMakeApiCall:
|
||||
assert result["data"] == {"result": "ok"}
|
||||
assert result["message"] == "API call successful."
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_successful_post(self, mock_post, mock_validate, post_tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_successful_post(self, mock_pinned, post_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 201
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"id": 1}
|
||||
mock_resp.content = b'{"id":1}'
|
||||
mock_post.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = post_tool.execute_action("create", name="test")
|
||||
|
||||
assert result["status_code"] == 201
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
def test_ssrf_blocked(self, mock_validate, tool):
|
||||
from application.core.url_validation import SSRFError
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_ssrf_blocked(self, mock_pinned, tool):
|
||||
from application.security.safe_url import UnsafeUserUrlError
|
||||
|
||||
mock_validate.side_effect = SSRFError("blocked")
|
||||
mock_pinned.side_effect = UnsafeUserUrlError("blocked")
|
||||
|
||||
result = tool.execute_action("any")
|
||||
|
||||
assert result["status_code"] is None
|
||||
assert "URL validation error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_timeout_error(self, mock_get, mock_validate, tool):
|
||||
mock_get.side_effect = requests.exceptions.Timeout()
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_timeout_error(self, mock_pinned, tool):
|
||||
mock_pinned.side_effect = requests.exceptions.Timeout()
|
||||
|
||||
result = tool.execute_action("any")
|
||||
|
||||
assert result["status_code"] is None
|
||||
assert "timeout" in result["message"].lower()
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_connection_error(self, mock_get, mock_validate, tool):
|
||||
mock_get.side_effect = requests.exceptions.ConnectionError("refused")
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_connection_error(self, mock_pinned, tool):
|
||||
mock_pinned.side_effect = requests.exceptions.ConnectionError("refused")
|
||||
|
||||
result = tool.execute_action("any")
|
||||
|
||||
assert result["status_code"] is None
|
||||
assert "Connection error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_http_error(self, mock_get, mock_validate, tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_http_error(self, mock_pinned, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 404
|
||||
mock_resp.text = "Not Found"
|
||||
@@ -116,15 +111,14 @@ class TestMakeApiCall:
|
||||
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
|
||||
response=mock_resp
|
||||
)
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("any")
|
||||
|
||||
assert result["status_code"] == 404
|
||||
assert "HTTP Error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
def test_unsupported_method(self, mock_validate):
|
||||
def test_unsupported_method(self):
|
||||
tool = APITool(
|
||||
config={"url": "https://example.com", "method": "CUSTOM"}
|
||||
)
|
||||
@@ -132,69 +126,64 @@ class TestMakeApiCall:
|
||||
assert result["status_code"] is None
|
||||
assert "Unsupported" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.put")
|
||||
def test_put_method(self, mock_put, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_put_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "PUT"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_put.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("update", name="new")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.delete")
|
||||
def test_delete_method(self, mock_delete, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_delete_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "DELETE"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 204
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.content = b''
|
||||
mock_delete.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("delete")
|
||||
assert result["status_code"] == 204
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.patch")
|
||||
def test_patch_method(self, mock_patch, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_patch_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "PATCH"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"patched": True}
|
||||
mock_resp.content = b'{"patched":true}'
|
||||
mock_patch.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("patch", field="val")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.head")
|
||||
def test_head_method(self, mock_head, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_head_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "HEAD"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "text/html"}
|
||||
mock_resp.content = b''
|
||||
mock_head.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("check")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.options")
|
||||
def test_options_method(self, mock_options, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_options_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "OPTIONS"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.content = b''
|
||||
mock_options.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("options")
|
||||
assert result["status_code"] == 200
|
||||
@@ -202,9 +191,8 @@ class TestMakeApiCall:
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPathParamSubstitution:
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_path_params_substituted(self, mock_get, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_path_params_substituted(self, mock_pinned):
|
||||
tool = APITool(
|
||||
config={
|
||||
"url": "https://api.example.com/users/{user_id}/posts/{post_id}",
|
||||
@@ -217,11 +205,11 @@ class TestPathParamSubstitution:
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = []
|
||||
mock_resp.content = b'[]'
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
called_url = mock_pinned.call_args[0][1]
|
||||
assert "/users/42/posts/7" in called_url
|
||||
assert "{user_id}" not in called_url
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
"""Tests for the journaled execute path on ToolExecutor.
|
||||
|
||||
Each tool call inserts a row into ``tool_call_attempts`` then flips
|
||||
through ``proposed → executed`` (or ``proposed → failed``). The flip
|
||||
to ``confirmed`` is owned by the message-finalize path and is only
|
||||
asserted indirectly here (rows stay in ``executed`` so the reconciler
|
||||
can pick them up).
|
||||
Each tool call inserts a ``tool_call_attempts`` row and flips it
|
||||
``proposed → executed`` (or ``→ failed``). With a ``message_id`` it
|
||||
stays ``executed`` for the finalize path to confirm; without one
|
||||
(``save_conversation=False``) it goes straight to ``confirmed``.
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
@@ -75,11 +74,24 @@ def _make_call(name="test_action_t1", call_id="c1"):
|
||||
return call
|
||||
|
||||
|
||||
_TOOLS_DICT = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestExecuteJournaling:
|
||||
def test_happy_path_proposed_then_executed(
|
||||
def test_no_message_id_proposed_then_confirmed(
|
||||
self, pg_conn, mock_tool_manager, monkeypatch
|
||||
):
|
||||
"""No reserved message (``save_conversation=False``) → row lands ``confirmed``, not ``executed``."""
|
||||
executor = ToolExecutor(user="u")
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
@@ -89,23 +101,12 @@ class TestExecuteJournaling:
|
||||
)
|
||||
_patch_db(monkeypatch, pg_conn)
|
||||
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
events, result = _drain(executor.execute(tools_dict, _make_call(), "MockLLM"))
|
||||
events, result = _drain(executor.execute(_TOOLS_DICT, _make_call(), "MockLLM"))
|
||||
assert result[0] == "Tool result"
|
||||
|
||||
row = _select_attempt(pg_conn, "c1")
|
||||
assert row is not None
|
||||
assert row["status"] == "executed"
|
||||
assert row["status"] == "confirmed"
|
||||
assert row["tool_name"] == "test_tool"
|
||||
assert row["action_name"] == "test_action"
|
||||
assert row["arguments"] == {"q": "v"}
|
||||
@@ -117,10 +118,7 @@ class TestExecuteJournaling:
|
||||
def test_executor_message_id_is_persisted_on_executed_row(
|
||||
self, pg_conn, mock_tool_manager, monkeypatch
|
||||
):
|
||||
"""When the route stamps a placeholder message_id on the executor,
|
||||
the journal row carries it forward so ``confirm_executed_tool_calls``
|
||||
can later flip it to ``confirmed``.
|
||||
"""
|
||||
"""The executor's message_id is carried onto the journal row, which stays ``executed``."""
|
||||
from application.storage.db.repositories.conversations import (
|
||||
ConversationsRepository,
|
||||
)
|
||||
@@ -147,18 +145,7 @@ class TestExecuteJournaling:
|
||||
)
|
||||
_patch_db(monkeypatch, pg_conn)
|
||||
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
_drain(executor.execute(tools_dict, _make_call(call_id="cm1"), "MockLLM"))
|
||||
_drain(executor.execute(_TOOLS_DICT, _make_call(call_id="cm1"), "MockLLM"))
|
||||
|
||||
row = _select_attempt(pg_conn, "cm1")
|
||||
assert row is not None
|
||||
@@ -180,18 +167,7 @@ class TestExecuteJournaling:
|
||||
RuntimeError("boom")
|
||||
)
|
||||
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
gen = executor.execute(tools_dict, _make_call(call_id="c2"), "MockLLM")
|
||||
gen = executor.execute(_TOOLS_DICT, _make_call(call_id="c2"), "MockLLM")
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
_drain(gen)
|
||||
|
||||
@@ -200,42 +176,10 @@ class TestExecuteJournaling:
|
||||
assert row["status"] == "failed"
|
||||
assert row["error"] == "boom"
|
||||
|
||||
def test_executed_row_lingers_for_reconciler_when_no_confirm(
|
||||
self, pg_conn, mock_tool_manager, monkeypatch
|
||||
):
|
||||
"""No finalize_message call → row sits in ``executed``."""
|
||||
executor = ToolExecutor(user="u")
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
lambda _cls, **kw: Mock(
|
||||
parse_args=Mock(return_value=("t1", "test_action", {}))
|
||||
),
|
||||
)
|
||||
_patch_db(monkeypatch, pg_conn)
|
||||
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
_drain(executor.execute(tools_dict, _make_call(call_id="c3"), "MockLLM"))
|
||||
|
||||
row = _select_attempt(pg_conn, "c3")
|
||||
assert row["status"] == "executed"
|
||||
# Partial index `tool_call_attempts_pending_ts_idx` selects rows
|
||||
# in ('proposed','executed') — the reconciler reads those.
|
||||
assert row["status"] in ("proposed", "executed")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRepository:
|
||||
def test_proposed_then_executed_round_trip(self, pg_conn):
|
||||
def test_proposed_then_confirmed_when_no_message(self, pg_conn):
|
||||
from application.storage.db.repositories.tool_call_attempts import (
|
||||
ToolCallAttemptsRepository,
|
||||
)
|
||||
@@ -249,7 +193,50 @@ class TestRepository:
|
||||
|
||||
assert repo.mark_executed("c-x", {"out": "ok"}) is True
|
||||
row = _select_attempt(pg_conn, "c-x")
|
||||
assert row["status"] == "confirmed"
|
||||
assert row["message_id"] is None
|
||||
assert row["result"] == {"result": {"out": "ok"}}
|
||||
|
||||
def test_mark_executed_with_message_stays_executed(self, pg_conn):
|
||||
from application.storage.db.repositories.conversations import (
|
||||
ConversationsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.tool_call_attempts import (
|
||||
ToolCallAttemptsRepository,
|
||||
)
|
||||
|
||||
# FK constraint: message_id must reference a real row.
|
||||
conv_repo = ConversationsRepository(pg_conn)
|
||||
conv = conv_repo.create("u-repo", "repo-msg-test")
|
||||
msg = conv_repo.reserve_message(
|
||||
str(conv["id"]),
|
||||
prompt="q?",
|
||||
placeholder_response="...",
|
||||
request_id="req-repo-1",
|
||||
status="pending",
|
||||
)
|
||||
message_uuid = str(msg["id"])
|
||||
|
||||
repo = ToolCallAttemptsRepository(pg_conn)
|
||||
repo.record_proposed("c-m", "tool", "act", {})
|
||||
assert (
|
||||
repo.mark_executed("c-m", {"out": "ok"}, message_id=message_uuid) is True
|
||||
)
|
||||
row = _select_attempt(pg_conn, "c-m")
|
||||
assert row["status"] == "executed"
|
||||
assert str(row["message_id"]) == message_uuid
|
||||
|
||||
def test_upsert_executed_without_message_confirms(self, pg_conn):
|
||||
"""``upsert_executed`` (DB-outage fallback) with no ``message_id`` lands ``confirmed``."""
|
||||
from application.storage.db.repositories.tool_call_attempts import (
|
||||
ToolCallAttemptsRepository,
|
||||
)
|
||||
|
||||
repo = ToolCallAttemptsRepository(pg_conn)
|
||||
repo.upsert_executed("c-up", "tool", "act", {"a": 1}, {"out": "ok"})
|
||||
row = _select_attempt(pg_conn, "c-up")
|
||||
assert row["status"] == "confirmed"
|
||||
assert row["message_id"] is None
|
||||
assert row["result"] == {"result": {"out": "ok"}}
|
||||
|
||||
def test_mark_failed_sets_error(self, pg_conn):
|
||||
|
||||
@@ -81,104 +81,103 @@ class TestAPIToolInit:
|
||||
@pytest.mark.unit
|
||||
class TestMakeApiCall:
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_successful_get(self, mock_get, mock_validate, get_tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_successful_get(self, mock_pinned, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"result": "ok"}
|
||||
mock_resp.content = b'{"result":"ok"}'
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = get_tool.execute_action("any_action")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["data"] == {"result": "ok"}
|
||||
assert result["message"] == "API call successful."
|
||||
assert mock_pinned.call_args[0][0] == "GET"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_successful_post(self, mock_post, mock_validate, post_tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_successful_post(self, mock_pinned, post_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 201
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"id": 1}
|
||||
mock_resp.content = b'{"id":1}'
|
||||
mock_post.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = post_tool.execute_action("create", name="test")
|
||||
assert result["status_code"] == 201
|
||||
assert mock_pinned.call_args[0][0] == "POST"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.put")
|
||||
def test_put_method(self, mock_put, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_put_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "PUT"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_put.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("update", name="new")
|
||||
assert result["status_code"] == 200
|
||||
assert mock_pinned.call_args[0][0] == "PUT"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.delete")
|
||||
def test_delete_method(self, mock_delete, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_delete_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "DELETE"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 204
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.content = b''
|
||||
mock_delete.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("delete")
|
||||
assert result["status_code"] == 204
|
||||
assert mock_pinned.call_args[0][0] == "DELETE"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.patch")
|
||||
def test_patch_method(self, mock_patch, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_patch_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "PATCH"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"patched": True}
|
||||
mock_resp.content = b'{"patched":true}'
|
||||
mock_patch.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("patch", field="val")
|
||||
assert result["status_code"] == 200
|
||||
assert mock_pinned.call_args[0][0] == "PATCH"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.head")
|
||||
def test_head_method(self, mock_head, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_head_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "HEAD"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "text/html"}
|
||||
mock_resp.content = b''
|
||||
mock_head.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("check")
|
||||
assert result["status_code"] == 200
|
||||
assert mock_pinned.call_args[0][0] == "HEAD"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.options")
|
||||
def test_options_method(self, mock_options, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_options_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "OPTIONS"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.content = b''
|
||||
mock_options.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("options")
|
||||
assert result["status_code"] == 200
|
||||
assert mock_pinned.call_args[0][0] == "OPTIONS"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
def test_unsupported_method(self, mock_validate):
|
||||
def test_unsupported_method(self):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "CUSTOM"})
|
||||
result = tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
@@ -193,19 +192,18 @@ class TestMakeApiCall:
|
||||
@pytest.mark.unit
|
||||
class TestSSRFValidation:
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
def test_ssrf_blocked_initial_url(self, mock_validate, get_tool):
|
||||
from application.core.url_validation import SSRFError
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_ssrf_blocked(self, mock_pinned, get_tool):
|
||||
from application.security.safe_url import UnsafeUserUrlError
|
||||
|
||||
mock_validate.side_effect = SSRFError("blocked")
|
||||
mock_pinned.side_effect = UnsafeUserUrlError("blocked")
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
assert "URL validation error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_ssrf_blocked_after_param_substitution(self, mock_get, mock_validate):
|
||||
from application.core.url_validation import SSRFError
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_ssrf_blocked_with_path_params(self, mock_pinned):
|
||||
from application.security.safe_url import UnsafeUserUrlError
|
||||
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/{host}/data",
|
||||
@@ -213,14 +211,7 @@ class TestSSRFValidation:
|
||||
"query_params": {"host": "169.254.169.254"},
|
||||
})
|
||||
|
||||
call_count = [0]
|
||||
|
||||
def side_effect(url):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 2:
|
||||
raise SSRFError("blocked after substitution")
|
||||
|
||||
mock_validate.side_effect = side_effect
|
||||
mock_pinned.side_effect = UnsafeUserUrlError("blocked")
|
||||
result = tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
assert "URL validation error" in result["message"]
|
||||
@@ -234,40 +225,36 @@ class TestSSRFValidation:
|
||||
@pytest.mark.unit
|
||||
class TestErrorHandling:
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_timeout_error(self, mock_get, mock_validate, get_tool):
|
||||
mock_get.side_effect = requests.exceptions.Timeout()
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_timeout_error(self, mock_pinned, get_tool):
|
||||
mock_pinned.side_effect = requests.exceptions.Timeout()
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
assert "timeout" in result["message"].lower()
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_connection_error(self, mock_get, mock_validate, get_tool):
|
||||
mock_get.side_effect = requests.exceptions.ConnectionError("refused")
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_connection_error(self, mock_pinned, get_tool):
|
||||
mock_pinned.side_effect = requests.exceptions.ConnectionError("refused")
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
assert "Connection error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_http_error_with_json(self, mock_get, mock_validate, get_tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_http_error_with_json(self, mock_pinned, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 422
|
||||
mock_resp.json.return_value = {"error": "invalid_field"}
|
||||
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
|
||||
response=mock_resp
|
||||
)
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] == 422
|
||||
assert result["data"] == {"error": "invalid_field"}
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_http_error_non_json_body(self, mock_get, mock_validate, get_tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_http_error_non_json_body(self, mock_pinned, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 404
|
||||
mock_resp.text = "Not Found"
|
||||
@@ -275,29 +262,26 @@ class TestErrorHandling:
|
||||
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
|
||||
response=mock_resp
|
||||
)
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] == 404
|
||||
assert result["data"] == "Not Found"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_request_exception(self, mock_get, mock_validate, get_tool):
|
||||
mock_get.side_effect = requests.exceptions.RequestException("something")
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_request_exception(self, mock_pinned, get_tool):
|
||||
mock_pinned.side_effect = requests.exceptions.RequestException("something")
|
||||
result = get_tool.execute_action("any")
|
||||
assert "API call failed" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_unexpected_exception(self, mock_get, mock_validate, get_tool):
|
||||
mock_get.side_effect = RuntimeError("unexpected")
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_unexpected_exception(self, mock_pinned, get_tool):
|
||||
mock_pinned.side_effect = RuntimeError("unexpected")
|
||||
result = get_tool.execute_action("any")
|
||||
assert "Unexpected error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_body_serialization_error(self, mock_post, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_body_serialization_error(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://example.com",
|
||||
"method": "POST",
|
||||
@@ -320,9 +304,8 @@ class TestErrorHandling:
|
||||
@pytest.mark.unit
|
||||
class TestPathParamSubstitution:
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_path_params_substituted(self, mock_get, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_path_params_substituted(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/users/{user_id}/posts/{post_id}",
|
||||
"method": "GET",
|
||||
@@ -333,17 +316,16 @@ class TestPathParamSubstitution:
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = []
|
||||
mock_resp.content = b'[]'
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
called_url = mock_pinned.call_args[0][1]
|
||||
assert "/users/42/posts/7" in called_url
|
||||
assert "{user_id}" not in called_url
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_remaining_query_params_appended(self, mock_get, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_remaining_query_params_appended(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/items",
|
||||
"method": "GET",
|
||||
@@ -354,19 +336,16 @@ class TestPathParamSubstitution:
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = []
|
||||
mock_resp.content = b'[]'
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
called_url = mock_pinned.call_args[0][1]
|
||||
assert "page=2" in called_url
|
||||
assert "limit=10" in called_url
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_query_params_append_with_existing_query_string(
|
||||
self, mock_get, mock_validate
|
||||
):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_query_params_append_with_existing_query_string(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/items?existing=true",
|
||||
"method": "GET",
|
||||
@@ -377,27 +356,65 @@ class TestPathParamSubstitution:
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = []
|
||||
mock_resp.content = b'[]'
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
called_url = mock_pinned.call_args[0][1]
|
||||
assert "&page=1" in called_url
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_empty_body_no_serialization(self, mock_post, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_empty_body_no_serialization(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "POST"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_post.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("create")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_path_params_are_url_encoded(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/users/{user_id}/profile",
|
||||
"method": "GET",
|
||||
"query_params": {"user_id": "../../admin"},
|
||||
})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_pinned.call_args[0][1]
|
||||
assert "../../admin" not in called_url
|
||||
assert "%2F" in called_url or "%2f" in called_url
|
||||
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_path_params_query_injection_encoded(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/items/{item_id}",
|
||||
"method": "GET",
|
||||
"query_params": {"item_id": "x?admin=true"},
|
||||
})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_pinned.call_args[0][1]
|
||||
assert "x?admin=true" not in called_url
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Parse Response
|
||||
@@ -494,11 +511,8 @@ class TestAPIToolMetadata:
|
||||
def test_config_requirements_empty(self, get_tool):
|
||||
assert get_tool.get_config_requirements() == {}
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_content_type_set_for_post_with_no_headers(
|
||||
self, mock_post, mock_validate
|
||||
):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_content_type_set_for_post_with_no_headers(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://example.com",
|
||||
"method": "POST",
|
||||
@@ -509,8 +523,8 @@ class TestAPIToolMetadata:
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_post.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("create")
|
||||
call_headers = mock_post.call_args[1]["headers"]
|
||||
call_headers = mock_pinned.call_args.kwargs["headers"]
|
||||
assert "Content-Type" in call_headers
|
||||
|
||||
@@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from sqlalchemy import text
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -256,9 +257,39 @@ class TestPaginatedSources:
|
||||
for key in (
|
||||
"id", "name", "date", "model", "location", "tokens",
|
||||
"retriever", "syncFrequency", "provider", "isNested", "type",
|
||||
"ingestStatus",
|
||||
):
|
||||
assert key in row
|
||||
|
||||
def test_exposes_stalled_ingest_status(self, app, pg_conn):
|
||||
"""A source whose ingest the reconciler escalated to 'stalled'
|
||||
surfaces ingestStatus='failed' so the UI can badge it.
|
||||
"""
|
||||
from application.api.user.sources.routes import PaginatedSources
|
||||
|
||||
user = "u-ingest-status"
|
||||
src = _seed_source(pg_conn, user, name="stalled-doc", type="file")
|
||||
pg_conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO ingest_chunk_progress (
|
||||
source_id, total_chunks, embedded_chunks, last_index,
|
||||
status
|
||||
)
|
||||
VALUES (CAST(:sid AS uuid), 907, 9, 8, 'stalled')
|
||||
"""
|
||||
),
|
||||
{"sid": str(src["id"])},
|
||||
)
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
"/api/sources/paginated?page=1&rows=10"
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = PaginatedSources().get()
|
||||
row = response.json["paginated"][0]
|
||||
assert row["ingestStatus"] == "failed"
|
||||
|
||||
|
||||
class TestDeleteOldIndexes:
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
@@ -553,6 +584,35 @@ class TestSyncSource:
|
||||
assert response.status_code == 200
|
||||
assert response.json["task_id"] == "task-123"
|
||||
|
||||
def test_normalizes_dict_remote_data_before_dispatch(self, app, pg_conn):
|
||||
"""The route must hand the sync task the normalized URL string."""
|
||||
from application.api.user.sources.routes import SyncSource
|
||||
|
||||
user = "u-normalize"
|
||||
src = _seed_source(
|
||||
pg_conn, user, name="crawl-src", type="crawler",
|
||||
remote_data=json.dumps(
|
||||
{"url": "https://example.com", "provider": "crawler"}
|
||||
),
|
||||
)
|
||||
|
||||
fake_task = MagicMock(id="task-norm")
|
||||
with _patch_db(pg_conn), patch(
|
||||
"application.api.user.sources.routes.sync_source.delay",
|
||||
return_value=fake_task,
|
||||
) as mock_delay, app.test_request_context(
|
||||
"/api/sync_source",
|
||||
method="POST",
|
||||
json={"source_id": str(src["id"])},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = SyncSource().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert mock_delay.call_args.kwargs["source_data"] == "https://example.com"
|
||||
assert mock_delay.call_args.kwargs["loader"] == "crawler"
|
||||
|
||||
def test_sync_task_raises_returns_400(self, app, pg_conn):
|
||||
from application.api.user.sources.routes import SyncSource
|
||||
|
||||
@@ -576,6 +636,135 @@ class TestSyncSource:
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
class TestReingestSource:
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.sources.routes import ReingestSource
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/sources/reingest", method="POST", json={"source_id": "x"}
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = None
|
||||
response = ReingestSource().post()
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_returns_400_missing_id(self, app):
|
||||
from application.api.user.sources.routes import ReingestSource
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/sources/reingest", method="POST", json={}
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u"}
|
||||
response = ReingestSource().post()
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_returns_404_missing_source(self, app, pg_conn):
|
||||
from application.api.user.sources.routes import ReingestSource
|
||||
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
"/api/sources/reingest",
|
||||
method="POST",
|
||||
json={"source_id": "00000000-0000-0000-0000-000000000000"},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u"}
|
||||
response = ReingestSource().post()
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_triggers_reingest_task(self, app, pg_conn):
|
||||
from application.api.user.sources.routes import ReingestSource
|
||||
|
||||
user = "u-reingest"
|
||||
src = _seed_source(pg_conn, user, name="stalled-src", type="file")
|
||||
|
||||
fake_task = MagicMock(id="reingest-task-1")
|
||||
with _patch_db(pg_conn), patch(
|
||||
"application.api.user.sources.routes.reingest_source_task.delay",
|
||||
return_value=fake_task,
|
||||
) as mock_delay, app.test_request_context(
|
||||
"/api/sources/reingest",
|
||||
method="POST",
|
||||
json={"source_id": str(src["id"])},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = ReingestSource().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["task_id"] == "reingest-task-1"
|
||||
assert mock_delay.call_args.kwargs["source_id"] == str(src["id"])
|
||||
assert mock_delay.call_args.kwargs["user"] == user
|
||||
# Scoped idempotency key engages the task's lease so repeated
|
||||
# clicks collapse onto one reingest instead of racing.
|
||||
assert mock_delay.call_args.kwargs["idempotency_key"] == (
|
||||
f"reingest-source:{user}:{src['id']}"
|
||||
)
|
||||
|
||||
def test_clears_stalled_ingest_progress_row(self, app, pg_conn):
|
||||
"""Reingest drops the stale chunk-progress row so the sources
|
||||
list stops deriving a 'failed' ingest status for the source.
|
||||
"""
|
||||
from application.api.user.sources.routes import ReingestSource
|
||||
|
||||
user = "u-reingest-clear"
|
||||
src = _seed_source(pg_conn, user, name="stalled-doc", type="file")
|
||||
pg_conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO ingest_chunk_progress (
|
||||
source_id, total_chunks, embedded_chunks, last_index,
|
||||
status
|
||||
)
|
||||
VALUES (CAST(:sid AS uuid), 100, 9, 8, 'stalled')
|
||||
"""
|
||||
),
|
||||
{"sid": str(src["id"])},
|
||||
)
|
||||
|
||||
fake_task = MagicMock(id="reingest-task-2")
|
||||
with _patch_db(pg_conn), patch(
|
||||
"application.api.user.sources.routes.reingest_source_task.delay",
|
||||
return_value=fake_task,
|
||||
), app.test_request_context(
|
||||
"/api/sources/reingest",
|
||||
method="POST",
|
||||
json={"source_id": str(src["id"])},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = ReingestSource().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
remaining = pg_conn.execute(
|
||||
text(
|
||||
"SELECT count(*) FROM ingest_chunk_progress "
|
||||
"WHERE source_id = CAST(:sid AS uuid)"
|
||||
),
|
||||
{"sid": str(src["id"])},
|
||||
).scalar()
|
||||
assert remaining == 0
|
||||
|
||||
def test_reingest_task_raises_returns_400(self, app, pg_conn):
|
||||
from application.api.user.sources.routes import ReingestSource
|
||||
|
||||
user = "u-reingest-fail"
|
||||
src = _seed_source(pg_conn, user, name="fail-src", type="file")
|
||||
|
||||
with _patch_db(pg_conn), patch(
|
||||
"application.api.user.sources.routes.reingest_source_task.delay",
|
||||
side_effect=RuntimeError("boom"),
|
||||
), app.test_request_context(
|
||||
"/api/sources/reingest",
|
||||
method="POST",
|
||||
json={"source_id": str(src["id"])},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = ReingestSource().post()
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
class TestDirectoryStructure:
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.sources.routes import DirectoryStructure
|
||||
|
||||
@@ -417,3 +417,181 @@ class TestSuccessfulRunClearsLease:
|
||||
assert row[0] == "completed"
|
||||
assert row[1] is None
|
||||
assert row[2] is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSynthesizedKeyGuardsKeylessDispatch:
|
||||
"""A keyless dispatch carrying ``source_id`` is still poison-guarded:
|
||||
the wrapper synthesizes a deterministic key from ``source_id``.
|
||||
"""
|
||||
|
||||
def test_keyless_with_source_id_records_dedup_row(self, pg_conn):
|
||||
from application.api.user.idempotency import with_idempotency
|
||||
|
||||
@with_idempotency(task_name="ingest")
|
||||
def task(self, idempotency_key=None, source_id=None):
|
||||
return {"ran": True}
|
||||
|
||||
with _patch_decorator_db(pg_conn):
|
||||
result = task(_fake_celery_self(), source_id="src-abc")
|
||||
|
||||
assert result == {"ran": True}
|
||||
row = _row_for(pg_conn, "auto:ingest:src-abc")
|
||||
assert row is not None
|
||||
assert row[0] == "ingest"
|
||||
assert row[2] == "completed"
|
||||
|
||||
def test_synthesized_key_stable_across_redeliveries(self, pg_conn):
|
||||
"""Same ``source_id`` → same key → a redelivery short-circuits to
|
||||
the cached result instead of re-running the body.
|
||||
"""
|
||||
from application.api.user.idempotency import with_idempotency
|
||||
|
||||
runs = {"count": 0}
|
||||
|
||||
@with_idempotency(task_name="ingest")
|
||||
def task(self, idempotency_key=None, source_id=None):
|
||||
runs["count"] += 1
|
||||
return {"n": runs["count"]}
|
||||
|
||||
with _patch_decorator_db(pg_conn):
|
||||
first = task(_fake_celery_self(), source_id="src-1")
|
||||
second = task(_fake_celery_self(), source_id="src-1")
|
||||
|
||||
assert first == second == {"n": 1}
|
||||
assert runs["count"] == 1
|
||||
|
||||
def test_poison_guard_trips_for_keyless_dispatch(self, pg_conn):
|
||||
"""The core fix: a keyless OOM-looping dispatch is bounded — the
|
||||
guard trips after MAX_TASK_ATTEMPTS with no explicit key.
|
||||
"""
|
||||
from application.api.user.idempotency import (
|
||||
MAX_TASK_ATTEMPTS, with_idempotency,
|
||||
)
|
||||
|
||||
runs = {"count": 0}
|
||||
|
||||
@with_idempotency(task_name="ingest")
|
||||
def task(self, idempotency_key=None, source_id=None):
|
||||
runs["count"] += 1
|
||||
raise RuntimeError("OOM-style failure")
|
||||
|
||||
with _patch_decorator_db(pg_conn):
|
||||
for _ in range(MAX_TASK_ATTEMPTS):
|
||||
with pytest.raises(RuntimeError):
|
||||
task(_fake_celery_self(), source_id="src-poison")
|
||||
result = task(_fake_celery_self(), source_id="src-poison")
|
||||
|
||||
assert runs["count"] == MAX_TASK_ATTEMPTS
|
||||
assert result["success"] is False
|
||||
assert "poison-loop" in result["error"]
|
||||
assert _row_for(pg_conn, "auto:ingest:src-poison")[2] == "failed"
|
||||
|
||||
def test_no_source_id_no_key_runs_unguarded(self, pg_conn):
|
||||
"""No explicit key and no ``source_id`` anchor → pass through with
|
||||
no DB writes, exactly as before.
|
||||
"""
|
||||
from application.api.user.idempotency import with_idempotency
|
||||
|
||||
@with_idempotency(task_name="store_attachment")
|
||||
def task(self, idempotency_key=None):
|
||||
return {"ran": True}
|
||||
|
||||
with patch(
|
||||
"application.api.user.idempotency.db_session"
|
||||
) as mock_session, patch(
|
||||
"application.api.user.idempotency.db_readonly"
|
||||
) as mock_readonly:
|
||||
result = task(_fake_celery_self())
|
||||
|
||||
assert result == {"ran": True}
|
||||
assert mock_session.call_count == 0
|
||||
assert mock_readonly.call_count == 0
|
||||
|
||||
def test_explicit_key_takes_precedence_over_source_id(self, pg_conn):
|
||||
"""An explicit key wins; the synthesized ``auto:`` key is unused."""
|
||||
from application.api.user.idempotency import with_idempotency
|
||||
|
||||
@with_idempotency(task_name="ingest")
|
||||
def task(self, idempotency_key=None, source_id=None):
|
||||
return {"ran": True}
|
||||
|
||||
with _patch_decorator_db(pg_conn):
|
||||
task(
|
||||
_fake_celery_self(),
|
||||
idempotency_key="explicit-k",
|
||||
source_id="src-x",
|
||||
)
|
||||
|
||||
assert _row_for(pg_conn, "explicit-k") is not None
|
||||
assert _row_for(pg_conn, "auto:ingest:src-x") is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPoisonHook:
|
||||
"""``on_poison`` fires on the poison-guard branch with the task's
|
||||
bound arguments, and never on the success path.
|
||||
"""
|
||||
|
||||
def test_hook_invoked_with_bound_args_on_poison(self, pg_conn):
|
||||
from application.api.user.idempotency import (
|
||||
MAX_TASK_ATTEMPTS, with_idempotency,
|
||||
)
|
||||
|
||||
captured = []
|
||||
|
||||
def _hook(task_name, bound):
|
||||
captured.append((task_name, bound))
|
||||
|
||||
@with_idempotency(task_name="ingest", on_poison=_hook)
|
||||
def task(self, idempotency_key=None, source_id=None):
|
||||
raise RuntimeError("never converges")
|
||||
|
||||
with _patch_decorator_db(pg_conn):
|
||||
for _ in range(MAX_TASK_ATTEMPTS):
|
||||
with pytest.raises(RuntimeError):
|
||||
task(_fake_celery_self(), source_id="src-h")
|
||||
task(_fake_celery_self(), source_id="src-h")
|
||||
|
||||
assert len(captured) == 1
|
||||
task_name, bound = captured[0]
|
||||
assert task_name == "ingest"
|
||||
assert bound["source_id"] == "src-h"
|
||||
|
||||
def test_hook_not_invoked_on_success(self, pg_conn):
|
||||
from application.api.user.idempotency import with_idempotency
|
||||
|
||||
calls = []
|
||||
|
||||
@with_idempotency(
|
||||
task_name="ingest", on_poison=lambda *a: calls.append(a)
|
||||
)
|
||||
def task(self, idempotency_key=None, source_id=None):
|
||||
return {"ok": True}
|
||||
|
||||
with _patch_decorator_db(pg_conn):
|
||||
task(_fake_celery_self(), source_id="src-ok")
|
||||
|
||||
assert calls == []
|
||||
|
||||
def test_hook_failure_does_not_break_poison_return(self, pg_conn):
|
||||
"""A throwing hook must not change the poison-guard outcome."""
|
||||
from application.api.user.idempotency import (
|
||||
MAX_TASK_ATTEMPTS, with_idempotency,
|
||||
)
|
||||
|
||||
def _bad_hook(task_name, bound):
|
||||
raise ValueError("hook blew up")
|
||||
|
||||
@with_idempotency(task_name="ingest", on_poison=_bad_hook)
|
||||
def task(self, idempotency_key=None, source_id=None):
|
||||
raise RuntimeError("never converges")
|
||||
|
||||
with _patch_decorator_db(pg_conn):
|
||||
for _ in range(MAX_TASK_ATTEMPTS):
|
||||
with pytest.raises(RuntimeError):
|
||||
task(_fake_celery_self(), source_id="src-bad")
|
||||
result = task(_fake_celery_self(), source_id="src-bad")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "poison-loop" in result["error"]
|
||||
|
||||
@@ -529,6 +529,142 @@ class TestStuckExecutedToolCalls:
|
||||
assert row[0] == "executed"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Q4 — stalled ingest checkpoints (escalate to terminal 'stalled' + alert)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _seed_ingest_progress(
|
||||
conn,
|
||||
*,
|
||||
source_id: str,
|
||||
embedded: int,
|
||||
total: int,
|
||||
age_minutes: int = 31,
|
||||
status: str = "active",
|
||||
) -> str:
|
||||
"""Insert an ingest_chunk_progress row with a backdated last_updated."""
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO ingest_chunk_progress (
|
||||
source_id, total_chunks, embedded_chunks, last_index,
|
||||
last_updated, status
|
||||
)
|
||||
VALUES (
|
||||
CAST(:sid AS uuid), :total, :embedded, :embedded - 1,
|
||||
clock_timestamp() - make_interval(mins => :age),
|
||||
:status
|
||||
)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"sid": source_id,
|
||||
"total": total,
|
||||
"embedded": embedded,
|
||||
"age": age_minutes,
|
||||
"status": status,
|
||||
},
|
||||
)
|
||||
return source_id
|
||||
|
||||
|
||||
def _ingest_status(conn, source_id: str) -> str | None:
|
||||
"""Return the ``status`` of an ingest_chunk_progress row, or None."""
|
||||
row = conn.execute(
|
||||
text(
|
||||
"SELECT status FROM ingest_chunk_progress "
|
||||
"WHERE source_id = CAST(:sid AS uuid)"
|
||||
),
|
||||
{"sid": source_id},
|
||||
).fetchone()
|
||||
return row[0] if row is not None else None
|
||||
|
||||
|
||||
class TestStalledIngests:
|
||||
@pytest.mark.unit
|
||||
def test_stalled_ingest_escalated_with_alert(self, pg_conn, caplog):
|
||||
from application.api.user.reconciliation import run_reconciliation
|
||||
|
||||
sid = "1a000000-0000-0000-0000-0000000000a1"
|
||||
_seed_ingest_progress(pg_conn, source_id=sid, embedded=9, total=907)
|
||||
before = _stack_logs_count(pg_conn, "reconciler_ingest_stalled")
|
||||
|
||||
with _route_engine_to(pg_conn), caplog.at_level(
|
||||
logging.ERROR, logger="application.api.user.reconciliation",
|
||||
):
|
||||
r = run_reconciliation()
|
||||
|
||||
assert r["ingests_stalled"] == 1
|
||||
# Escalated to a terminal status so the next tick skips it.
|
||||
assert _ingest_status(pg_conn, sid) == "stalled"
|
||||
# Structured alert + stack_logs row both surface the failure.
|
||||
assert any(
|
||||
getattr(rec, "alert", None) == "reconciler_ingest_stalled"
|
||||
and rec.levelname == "ERROR"
|
||||
for rec in caplog.records
|
||||
)
|
||||
assert (
|
||||
_stack_logs_count(pg_conn, "reconciler_ingest_stalled")
|
||||
== before + 1
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_stalled_ingest_alerts_once_not_every_tick(self, pg_conn):
|
||||
"""The escalate-to-'stalled' write ends the re-alert loop: a
|
||||
second tick neither re-counts nor re-logs the same dead ingest.
|
||||
"""
|
||||
from application.api.user.reconciliation import run_reconciliation
|
||||
|
||||
sid = "1a000000-0000-0000-0000-0000000000a2"
|
||||
_seed_ingest_progress(pg_conn, source_id=sid, embedded=1, total=95)
|
||||
before = _stack_logs_count(pg_conn, "reconciler_ingest_stalled")
|
||||
|
||||
with _route_engine_to(pg_conn):
|
||||
r1 = run_reconciliation()
|
||||
r2 = run_reconciliation()
|
||||
|
||||
assert r1["ingests_stalled"] == 1
|
||||
assert r2["ingests_stalled"] == 0
|
||||
# Only the first tick wrote an alert row.
|
||||
assert (
|
||||
_stack_logs_count(pg_conn, "reconciler_ingest_stalled")
|
||||
== before + 1
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_fresh_ingest_left_alone(self, pg_conn):
|
||||
from application.api.user.reconciliation import run_reconciliation
|
||||
|
||||
sid = "1a000000-0000-0000-0000-0000000000a3"
|
||||
# 2 minutes old — well under the 30-minute staleness threshold.
|
||||
_seed_ingest_progress(
|
||||
pg_conn, source_id=sid, embedded=3, total=20, age_minutes=2,
|
||||
)
|
||||
|
||||
with _route_engine_to(pg_conn):
|
||||
r = run_reconciliation()
|
||||
|
||||
assert r["ingests_stalled"] == 0
|
||||
assert _ingest_status(pg_conn, sid) == "active"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_completed_ingest_left_alone(self, pg_conn):
|
||||
"""A stale checkpoint that finished embedding (embedded == total)
|
||||
is not a stall and must not be flagged.
|
||||
"""
|
||||
from application.api.user.reconciliation import run_reconciliation
|
||||
|
||||
sid = "1a000000-0000-0000-0000-0000000000a4"
|
||||
_seed_ingest_progress(pg_conn, source_id=sid, embedded=50, total=50)
|
||||
|
||||
with _route_engine_to(pg_conn):
|
||||
r = run_reconciliation()
|
||||
|
||||
assert r["ingests_stalled"] == 0
|
||||
assert _ingest_status(pg_conn, sid) == "active"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Q5 — stuck idempotency pending rows (lease expired + attempts exhausted)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -546,3 +546,63 @@ class TestIngestIdempotency:
|
||||
assert first == second
|
||||
assert first == {"status": "ok", "directory": "dir"}
|
||||
assert len(worker_calls) == 1
|
||||
|
||||
|
||||
class TestIngestPoisonEvent:
|
||||
"""The poison hook publishes a terminal source.ingest.failed so the
|
||||
upload toast resolves instead of hanging on "training".
|
||||
"""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_publishes_failed_event(self):
|
||||
from application.api.user.tasks import _emit_ingest_poison_event
|
||||
|
||||
published = []
|
||||
|
||||
def _fake_publish(user, event_type, payload, *, scope=None):
|
||||
published.append((user, event_type, payload, scope))
|
||||
|
||||
with patch(
|
||||
"application.events.publisher.publish_user_event",
|
||||
side_effect=_fake_publish,
|
||||
):
|
||||
_emit_ingest_poison_event(
|
||||
"ingest",
|
||||
{"user": "u1", "source_id": "src-9", "filename": "doc.pdf"},
|
||||
)
|
||||
|
||||
assert len(published) == 1
|
||||
user, event_type, payload, scope = published[0]
|
||||
assert user == "u1"
|
||||
assert event_type == "source.ingest.failed"
|
||||
assert payload["source_id"] == "src-9"
|
||||
assert payload["filename"] == "doc.pdf"
|
||||
assert payload["operation"] == "upload"
|
||||
assert scope == {"kind": "source", "id": "src-9"}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_skips_when_source_id_missing(self):
|
||||
from application.api.user.tasks import _emit_ingest_poison_event
|
||||
|
||||
with patch(
|
||||
"application.events.publisher.publish_user_event",
|
||||
) as mock_publish:
|
||||
_emit_ingest_poison_event("ingest", {"user": "u1"})
|
||||
|
||||
mock_publish.assert_not_called()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_reingest_uses_reingest_operation(self):
|
||||
from application.api.user.tasks import _emit_ingest_poison_event
|
||||
|
||||
published = []
|
||||
with patch(
|
||||
"application.events.publisher.publish_user_event",
|
||||
side_effect=lambda *a, **k: published.append((a, k)),
|
||||
):
|
||||
_emit_ingest_poison_event(
|
||||
"reingest_source_task",
|
||||
{"user": "u1", "source_id": "src-r"},
|
||||
)
|
||||
|
||||
assert published[0][0][2]["operation"] == "reingest"
|
||||
|
||||
@@ -158,6 +158,35 @@ class TestSimpleDirectoryReaderLoadData:
|
||||
for doc in docs:
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
def test_load_data_progress_callback_fires_per_file(self, temp_dir):
|
||||
from application.parser.file.bulk import SimpleDirectoryReader
|
||||
|
||||
reader = SimpleDirectoryReader(
|
||||
input_dir=str(temp_dir), recursive=False, exclude_hidden=True,
|
||||
)
|
||||
calls = []
|
||||
reader.load_data(progress_callback=lambda done, total: calls.append((done, total)))
|
||||
|
||||
total_files = len(reader.input_files)
|
||||
assert total_files >= 1
|
||||
# One callback per file, monotonically increasing, ending at total.
|
||||
assert [c[0] for c in calls] == list(range(1, total_files + 1))
|
||||
assert all(c[1] == total_files for c in calls)
|
||||
|
||||
def test_load_data_progress_callback_errors_swallowed(self, temp_dir):
|
||||
from application.parser.file.bulk import SimpleDirectoryReader
|
||||
|
||||
reader = SimpleDirectoryReader(
|
||||
input_dir=str(temp_dir), recursive=False, exclude_hidden=True,
|
||||
)
|
||||
|
||||
def _boom(done, total):
|
||||
raise RuntimeError("callback blew up")
|
||||
|
||||
# A failing callback must not abort ingestion.
|
||||
docs = reader.load_data(progress_callback=_boom)
|
||||
assert len(docs) >= 1
|
||||
|
||||
def test_load_data_concatenate(self, temp_dir):
|
||||
from application.parser.file.bulk import SimpleDirectoryReader
|
||||
|
||||
|
||||
@@ -421,3 +421,85 @@ class TestDoclingParserGaps:
|
||||
parser = DoclingCSVParser()
|
||||
assert parser.export_format == "markdown"
|
||||
assert parser.ocr_enabled is True
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Pipeline memory caps
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestApplyPipelineCaps:
|
||||
"""_apply_pipeline_caps bounds docling's threaded-pipeline buffering."""
|
||||
|
||||
def test_caps_threaded_pipeline_knobs(self, monkeypatch):
|
||||
from application.core.settings import settings
|
||||
from application.parser.file.docling_parser import _apply_pipeline_caps
|
||||
|
||||
monkeypatch.setattr(
|
||||
settings, "DOCLING_PIPELINE_QUEUE_MAX_SIZE", 2, raising=False
|
||||
)
|
||||
|
||||
class Opts:
|
||||
# docling >= 2.94 threaded pipeline — all knobs present.
|
||||
queue_max_size = 100
|
||||
layout_batch_size = 4
|
||||
table_batch_size = 4
|
||||
ocr_batch_size = 4
|
||||
|
||||
opts = Opts()
|
||||
_apply_pipeline_caps(opts)
|
||||
|
||||
assert opts.queue_max_size == 2
|
||||
assert opts.layout_batch_size == 1
|
||||
assert opts.table_batch_size == 1
|
||||
assert opts.ocr_batch_size == 1
|
||||
|
||||
def test_queue_size_is_settings_driven(self, monkeypatch):
|
||||
from application.core.settings import settings
|
||||
from application.parser.file.docling_parser import _apply_pipeline_caps
|
||||
|
||||
monkeypatch.setattr(
|
||||
settings, "DOCLING_PIPELINE_QUEUE_MAX_SIZE", 6, raising=False
|
||||
)
|
||||
|
||||
class Opts:
|
||||
queue_max_size = 100
|
||||
|
||||
opts = Opts()
|
||||
_apply_pipeline_caps(opts)
|
||||
assert opts.queue_max_size == 6
|
||||
|
||||
def test_misconfigured_zero_floors_to_one(self, monkeypatch):
|
||||
"""A 0 queue depth could deadlock the threaded pipeline — floor it."""
|
||||
from application.core.settings import settings
|
||||
from application.parser.file.docling_parser import _apply_pipeline_caps
|
||||
|
||||
monkeypatch.setattr(
|
||||
settings, "DOCLING_PIPELINE_QUEUE_MAX_SIZE", 0, raising=False
|
||||
)
|
||||
|
||||
class Opts:
|
||||
queue_max_size = 100
|
||||
|
||||
opts = Opts()
|
||||
_apply_pipeline_caps(opts)
|
||||
assert opts.queue_max_size == 1
|
||||
|
||||
def test_noop_on_docling_without_threaded_pipeline(self):
|
||||
"""Builds predating the threaded pipeline lack the knobs — the cap
|
||||
must be a silent no-op, not an AttributeError."""
|
||||
from application.parser.file.docling_parser import _apply_pipeline_caps
|
||||
|
||||
class LegacyOpts:
|
||||
__slots__ = ("do_ocr", "do_table_structure")
|
||||
|
||||
def __init__(self):
|
||||
self.do_ocr = False
|
||||
self.do_table_structure = True
|
||||
|
||||
opts = LegacyOpts()
|
||||
_apply_pipeline_caps(opts) # must not raise
|
||||
|
||||
assert not hasattr(opts, "queue_max_size")
|
||||
assert not hasattr(opts, "layout_batch_size")
|
||||
|
||||
@@ -94,6 +94,35 @@ def test_embed_and_store_documents_non_faiss(tmp_path, mock_settings, mock_vecto
|
||||
assert folder_name.exists()
|
||||
|
||||
|
||||
def test_embed_and_store_documents_progress_band(
|
||||
tmp_path, mock_settings, mock_vector_creator
|
||||
):
|
||||
"""progress_start/progress_end remap the embed loop into a sub-band
|
||||
so an earlier stage (parsing) can own the lower part of the bar.
|
||||
"""
|
||||
mock_settings.VECTOR_STORE = "chromadb"
|
||||
|
||||
docs = [MagicMock(page_content=f"d{i}", metadata={}) for i in range(4)]
|
||||
task_status = MagicMock()
|
||||
mock_vector_creator.create_vectorstore.return_value = MagicMock()
|
||||
|
||||
embed_and_store_documents(
|
||||
docs, str(tmp_path / "store"), "sid", task_status,
|
||||
progress_start=50, progress_end=100,
|
||||
)
|
||||
|
||||
currents = [
|
||||
call.kwargs["meta"]["current"]
|
||||
for call in task_status.update_state.call_args_list
|
||||
if "meta" in call.kwargs and "current" in call.kwargs["meta"]
|
||||
]
|
||||
assert currents, "expected progress updates"
|
||||
# Embedding stays in the upper band and tops out at 100.
|
||||
assert min(currents) > 50
|
||||
assert max(currents) == 100
|
||||
assert currents == sorted(currents)
|
||||
|
||||
|
||||
@patch("application.parser.embedding_pipeline.add_text_to_store_with_retry")
|
||||
def test_embed_and_store_documents_partial_failure_raises(
|
||||
mock_add_retry, tmp_path, mock_settings, mock_vector_creator, caplog
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
"""Tests for application.parser.remote.remote_creator covering lines 31-34."""
|
||||
"""Tests for application.parser.remote.remote_creator."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
@@ -38,3 +40,92 @@ class TestRemoteCreator:
|
||||
mock_loader_cls.assert_called_once()
|
||||
finally:
|
||||
RemoteCreator.loaders = original_loaders
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestNormalizeRemoteData:
|
||||
"""``normalize_remote_data`` maps a stored JSONB ``remote_data`` value
|
||||
back to the ``source_data`` shape each loader expects."""
|
||||
|
||||
def test_none_passes_through(self):
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
assert normalize_remote_data("crawler", None) is None
|
||||
|
||||
def test_crawler_dict_with_url_key(self):
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
result = normalize_remote_data(
|
||||
"crawler", {"url": "https://example.com", "provider": "crawler"}
|
||||
)
|
||||
assert result == "https://example.com"
|
||||
|
||||
def test_url_dict_with_url_key(self):
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
result = normalize_remote_data("url", {"url": "https://example.com"})
|
||||
assert result == "https://example.com"
|
||||
|
||||
def test_url_legacy_raw_key(self):
|
||||
"""Legacy rows wrapped a bare URL string as ``{"raw": ...}``."""
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
result = normalize_remote_data("crawler", {"raw": "https://legacy.example.com"})
|
||||
assert result == "https://legacy.example.com"
|
||||
|
||||
def test_url_dict_with_urls_list(self):
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
result = normalize_remote_data(
|
||||
"url", {"urls": ["https://a.example.com", "https://b.example.com"]}
|
||||
)
|
||||
assert result == ["https://a.example.com", "https://b.example.com"]
|
||||
|
||||
def test_github_repo_url_key(self):
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
result = normalize_remote_data(
|
||||
"github", {"repo_url": "https://github.com/arc53/DocsGPT"}
|
||||
)
|
||||
assert result == "https://github.com/arc53/DocsGPT"
|
||||
|
||||
def test_sitemap_dict_with_url_key(self):
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
result = normalize_remote_data("sitemap", {"url": "https://example.com/sitemap.xml"})
|
||||
assert result == "https://example.com/sitemap.xml"
|
||||
|
||||
def test_plain_string_url_passes_through(self):
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
assert normalize_remote_data("crawler", "https://example.com") == "https://example.com"
|
||||
|
||||
def test_url_dict_without_url_key_returns_none(self):
|
||||
"""A URL-type loader must never receive a dict, even a malformed one."""
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
assert normalize_remote_data("crawler", {"provider": "crawler"}) is None
|
||||
|
||||
def test_reddit_dict_serialized_to_json_string(self):
|
||||
"""reddit's loader runs json.loads() — it needs a JSON string."""
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
result = normalize_remote_data(
|
||||
"reddit", {"client_id": "x", "search_queries": ["y"]}
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
assert json.loads(result) == {"client_id": "x", "search_queries": ["y"]}
|
||||
|
||||
def test_s3_dict_passes_through(self):
|
||||
"""S3Loader.load_data() accepts a dict, so it is left untouched."""
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
data = {"bucket": "b", "prefix": "k"}
|
||||
assert normalize_remote_data("s3", data) == data
|
||||
|
||||
def test_json_string_remote_data_is_parsed(self):
|
||||
"""Legacy rows that stored the JSON itself as a string still resolve."""
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
|
||||
result = normalize_remote_data("crawler", '{"url": "https://example.com"}')
|
||||
assert result == "https://example.com"
|
||||
|
||||
74
tests/storage/db/repositories/test_ingest_chunk_progress.py
Normal file
74
tests/storage/db/repositories/test_ingest_chunk_progress.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Tests for IngestChunkProgressRepository against ephemeral Postgres."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from application.storage.db.repositories.ingest_chunk_progress import (
|
||||
IngestChunkProgressRepository,
|
||||
)
|
||||
|
||||
|
||||
def _status(conn, source_id: str) -> str:
|
||||
return conn.execute(
|
||||
text(
|
||||
"SELECT status FROM ingest_chunk_progress "
|
||||
"WHERE source_id = CAST(:sid AS uuid)"
|
||||
),
|
||||
{"sid": source_id},
|
||||
).scalar()
|
||||
|
||||
|
||||
def _mark_stalled(conn, source_id: str) -> None:
|
||||
conn.execute(
|
||||
text(
|
||||
"UPDATE ingest_chunk_progress SET status = 'stalled' "
|
||||
"WHERE source_id = CAST(:sid AS uuid)"
|
||||
),
|
||||
{"sid": source_id},
|
||||
)
|
||||
|
||||
|
||||
class TestInitProgressStatus:
|
||||
def test_new_row_starts_active(self, pg_conn):
|
||||
sid = "3c000000-0000-0000-0000-0000000000c1"
|
||||
IngestChunkProgressRepository(pg_conn).init_progress(sid, 10, "att-1")
|
||||
assert _status(pg_conn, sid) == "active"
|
||||
|
||||
def test_reingest_resets_stalled_to_active(self, pg_conn):
|
||||
"""A reconciler-escalated 'stalled' row flips back to 'active'
|
||||
when the source is reingested under a fresh attempt id.
|
||||
"""
|
||||
sid = "3c000000-0000-0000-0000-0000000000c2"
|
||||
repo = IngestChunkProgressRepository(pg_conn)
|
||||
repo.init_progress(sid, 10, "att-1")
|
||||
_mark_stalled(pg_conn, sid)
|
||||
|
||||
repo.init_progress(sid, 10, "att-2")
|
||||
assert _status(pg_conn, sid) == "active"
|
||||
|
||||
def test_same_attempt_retry_also_clears_stalled(self, pg_conn):
|
||||
"""A same-attempt resume (Celery autoretry) also clears a stale
|
||||
'stalled' flag — the task is running again.
|
||||
"""
|
||||
sid = "3c000000-0000-0000-0000-0000000000c3"
|
||||
repo = IngestChunkProgressRepository(pg_conn)
|
||||
repo.init_progress(sid, 10, "att-1")
|
||||
_mark_stalled(pg_conn, sid)
|
||||
|
||||
repo.init_progress(sid, 10, "att-1")
|
||||
assert _status(pg_conn, sid) == "active"
|
||||
|
||||
|
||||
class TestDelete:
|
||||
def test_delete_removes_row(self, pg_conn):
|
||||
sid = "3c000000-0000-0000-0000-0000000000d1"
|
||||
repo = IngestChunkProgressRepository(pg_conn)
|
||||
repo.init_progress(sid, 10, "att-1")
|
||||
|
||||
assert repo.delete(sid) is True
|
||||
assert repo.get_progress(sid) is None
|
||||
|
||||
def test_delete_missing_row_returns_false(self, pg_conn):
|
||||
repo = IngestChunkProgressRepository(pg_conn)
|
||||
assert repo.delete("3c000000-0000-0000-0000-0000000000df") is False
|
||||
@@ -342,6 +342,89 @@ class TestMarkToolCallFailed:
|
||||
assert row[1] == "oops"
|
||||
|
||||
|
||||
def _seed_ingest_progress(
|
||||
conn,
|
||||
*,
|
||||
source_id: str,
|
||||
embedded: int,
|
||||
total: int,
|
||||
age_minutes: int = 31,
|
||||
status: str = "active",
|
||||
) -> None:
|
||||
"""Seed an ingest_chunk_progress row with a backdated last_updated."""
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO ingest_chunk_progress (
|
||||
source_id, total_chunks, embedded_chunks, last_index,
|
||||
last_updated, status
|
||||
)
|
||||
VALUES (
|
||||
CAST(:sid AS uuid), :total, :embedded, :embedded - 1,
|
||||
clock_timestamp() - make_interval(mins => :age), :status
|
||||
)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"sid": source_id, "total": total, "embedded": embedded,
|
||||
"age": age_minutes, "status": status,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class TestFindAndLockStalledIngests:
|
||||
def test_returns_stale_active_partial(self, pg_conn):
|
||||
sid = "2b000000-0000-0000-0000-0000000000b1"
|
||||
_seed_ingest_progress(pg_conn, source_id=sid, embedded=2, total=10)
|
||||
rows = ReconciliationRepository(pg_conn).find_and_lock_stalled_ingests()
|
||||
assert any(str(r["source_id"]) == sid for r in rows)
|
||||
|
||||
def test_excludes_already_stalled(self, pg_conn):
|
||||
sid = "2b000000-0000-0000-0000-0000000000b2"
|
||||
_seed_ingest_progress(
|
||||
pg_conn, source_id=sid, embedded=2, total=10, status="stalled",
|
||||
)
|
||||
rows = ReconciliationRepository(pg_conn).find_and_lock_stalled_ingests()
|
||||
assert all(str(r["source_id"]) != sid for r in rows)
|
||||
|
||||
def test_excludes_completed(self, pg_conn):
|
||||
sid = "2b000000-0000-0000-0000-0000000000b3"
|
||||
_seed_ingest_progress(pg_conn, source_id=sid, embedded=10, total=10)
|
||||
rows = ReconciliationRepository(pg_conn).find_and_lock_stalled_ingests()
|
||||
assert all(str(r["source_id"]) != sid for r in rows)
|
||||
|
||||
def test_excludes_under_age_threshold(self, pg_conn):
|
||||
sid = "2b000000-0000-0000-0000-0000000000b4"
|
||||
_seed_ingest_progress(
|
||||
pg_conn, source_id=sid, embedded=2, total=10, age_minutes=2,
|
||||
)
|
||||
rows = ReconciliationRepository(pg_conn).find_and_lock_stalled_ingests()
|
||||
assert all(str(r["source_id"]) != sid for r in rows)
|
||||
|
||||
|
||||
class TestMarkIngestStalled:
|
||||
def test_flips_status_to_stalled(self, pg_conn):
|
||||
sid = "2b000000-0000-0000-0000-0000000000b5"
|
||||
_seed_ingest_progress(pg_conn, source_id=sid, embedded=2, total=10)
|
||||
repo = ReconciliationRepository(pg_conn)
|
||||
assert repo.mark_ingest_stalled(sid) is True
|
||||
row = pg_conn.execute(
|
||||
text(
|
||||
"SELECT status FROM ingest_chunk_progress "
|
||||
"WHERE source_id = CAST(:sid AS uuid)"
|
||||
),
|
||||
{"sid": sid},
|
||||
).fetchone()
|
||||
assert row[0] == "stalled"
|
||||
|
||||
def test_returns_false_for_missing_source(self, pg_conn):
|
||||
repo = ReconciliationRepository(pg_conn)
|
||||
assert (
|
||||
repo.mark_ingest_stalled("2b000000-0000-0000-0000-0000000000bf")
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def _seed_stuck_idempotency(
|
||||
conn,
|
||||
*,
|
||||
|
||||
@@ -148,6 +148,130 @@ class TestSyncWorker:
|
||||
assert captured[0]["loader"] == "url"
|
||||
assert captured[0]["doc_id"] == str(src["id"])
|
||||
|
||||
def test_connector_sources_are_skipped(
|
||||
self,
|
||||
pg_conn,
|
||||
patch_worker_db,
|
||||
task_self,
|
||||
monkeypatch,
|
||||
):
|
||||
"""connector:* sources have no RemoteCreator loader — sync_worker
|
||||
must skip them, not dispatch them into sync()."""
|
||||
from application import worker
|
||||
|
||||
SourcesRepository(pg_conn).create(
|
||||
"drive-folder",
|
||||
user_id="dave",
|
||||
type="connector:file",
|
||||
retriever="classic",
|
||||
sync_frequency="daily",
|
||||
remote_data={
|
||||
"provider": "google_drive",
|
||||
"file_ids": ["f1"],
|
||||
"folder_ids": [],
|
||||
"recursive": False,
|
||||
},
|
||||
)
|
||||
|
||||
def _must_not_run(*args, **kwargs):
|
||||
raise AssertionError("sync() must not run for connector sources")
|
||||
|
||||
monkeypatch.setattr(worker, "sync", _must_not_run)
|
||||
|
||||
result = worker.sync_worker(task_self, "daily")
|
||||
|
||||
assert result["total_sync_count"] == 1
|
||||
assert result["sync_skipped"] == 1
|
||||
assert result["sync_success"] == 0
|
||||
assert result["sync_failure"] == 0
|
||||
|
||||
def test_dict_remote_data_is_normalized_before_loader(
|
||||
self,
|
||||
pg_conn,
|
||||
patch_worker_db,
|
||||
task_self,
|
||||
monkeypatch,
|
||||
):
|
||||
"""Regression: remote_data reads back as a dict; sync_worker must
|
||||
hand the loader the URL string, not the raw dict."""
|
||||
from application import worker
|
||||
|
||||
SourcesRepository(pg_conn).create(
|
||||
"docs-crawl",
|
||||
user_id="erin",
|
||||
type="crawler",
|
||||
retriever="classic",
|
||||
sync_frequency="weekly",
|
||||
remote_data={"url": "https://example.com", "provider": "crawler"},
|
||||
)
|
||||
|
||||
received: list = []
|
||||
fake_loader = MagicMock(name="remote_loader")
|
||||
|
||||
def _capture(source_data):
|
||||
received.append(source_data)
|
||||
return [
|
||||
Document(
|
||||
text="page body",
|
||||
extra_info={"file_path": "index.md", "title": "home"},
|
||||
doc_id="d1",
|
||||
)
|
||||
]
|
||||
|
||||
fake_loader.load_data.side_effect = _capture
|
||||
monkeypatch.setattr(
|
||||
worker.RemoteCreator, "create_loader", lambda loader: fake_loader
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
worker,
|
||||
"embed_and_store_documents",
|
||||
lambda docs, full_path, source_id, task, **kw: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
worker, "upload_index", lambda full_path, file_data: None
|
||||
)
|
||||
|
||||
result = worker.sync_worker(task_self, "weekly")
|
||||
|
||||
assert result["total_sync_count"] == 1
|
||||
assert result["sync_success"] == 1
|
||||
assert result["sync_failure"] == 0
|
||||
assert received == ["https://example.com"], (
|
||||
"loader must receive the URL string, not the remote_data dict"
|
||||
)
|
||||
|
||||
def test_unsyncable_remote_data_is_skipped(
|
||||
self,
|
||||
pg_conn,
|
||||
patch_worker_db,
|
||||
task_self,
|
||||
monkeypatch,
|
||||
):
|
||||
"""A URL source whose remote_data dict has no URL key normalizes
|
||||
to None — sync_worker must skip it, not dispatch a doomed sync()."""
|
||||
from application import worker
|
||||
|
||||
SourcesRepository(pg_conn).create(
|
||||
"broken-feed",
|
||||
user_id="frank",
|
||||
type="url",
|
||||
retriever="classic",
|
||||
sync_frequency="monthly",
|
||||
remote_data={"provider": "url"},
|
||||
)
|
||||
|
||||
def _must_not_run(*args, **kwargs):
|
||||
raise AssertionError("sync() must not run for unsyncable sources")
|
||||
|
||||
monkeypatch.setattr(worker, "sync", _must_not_run)
|
||||
|
||||
result = worker.sync_worker(task_self, "monthly")
|
||||
|
||||
assert result["total_sync_count"] == 1
|
||||
assert result["sync_skipped"] == 1
|
||||
assert result["sync_failure"] == 0
|
||||
assert result["sync_success"] == 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRemoteWorkerPathTraversal:
|
||||
|
||||
Reference in New Issue
Block a user