Compare commits

28 Commits

Author SHA1 Message Date
Alex
827a0bb382 Merge remote-tracking branch 'origin/main' into feat-notification-system
# Conflicts:
#	frontend/src/api/services/userService.ts
#	frontend/src/utils/providerUtils.ts
2026-05-13 22:57:58 +01:00
Alex
b04cb44ab5 fix: e2e tests 2026-05-13 22:49:37 +01:00
Alex
42384a0e92 fix: better docs 2026-05-13 17:03:14 +01:00
Alex
0bce35ad29 feat: events cleanup 2026-05-13 08:36:52 +01:00
Alex
9de8bb4499 chore(events): rename attachment.processing.progress to attachment.progress
The event-type taxonomy was inconsistent: source ingest emits
source.ingest.progress (three segments) while attachments emitted
attachment.processing.progress (four segments). Drops the
.processing. infix for parity. Worker publish sites, the slice
reducer's match, and the worker tests all flip together.

No external consumers — the event type is purely internal between
the publisher and the in-tab slice; safe to rename in one commit.
2026-05-12 19:38:18 +01:00
Alex
cdbd3f061d fix(cache): enable Redis health_check_interval to surface half-open TCP
Without health_check_interval, a half-open TCP socket (NAT silently
dropped state, ELB idle-close) can leave pubsub.get_message hanging
past the SSE generator's keepalive cadence — the kernel never
surfaces the dead socket because no payload is in flight. Setting
health_check_interval=10 makes redis-py ping every 10s when
otherwise idle, so the next get_message after the dead window
raises and the SSE loop falls into its reconnect path instead of
silently freezing on the user.
2026-05-12 19:38:11 +01:00
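The health-check setting from cdbd3f061d might look roughly like the sketch below. `health_check_interval` is a real redis-py connection option; the helper names `pubsub_client_kwargs` and `make_client` are hypothetical and only illustrate where the value would be passed, not how the repository actually builds its client.

```python
def pubsub_client_kwargs(url: str) -> dict:
    """Hypothetical helper: keyword arguments for a pub/sub-capable client."""
    return {
        "url": url,
        # Value taken from the commit message: ping every 10s when idle so a
        # half-open socket raises on the next read instead of hanging.
        "health_check_interval": 10,
    }


def make_client(url: str):
    """Hypothetical factory; imports redis-py lazily so the sketch loads without it."""
    import redis  # redis-py

    return redis.Redis.from_url(**pubsub_client_kwargs(url))
```

With a client built this way, a `pubsub.get_message` call on a dead connection should raise once the health check fails, which is the behaviour the commit relies on for its reconnect path.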
Alex
2ac46fd858 refactor(sources): move source-id derivation out of worker module
application/api/user/sources/upload.py imported _derive_source_id
from application.worker — pulling the entire Celery worker module
into the API process at import time just for a two-line helper.

Move DOCSGPT_INGEST_NAMESPACE and the derivation function to a
new application/storage/db/source_ids.py module that both layers
can import without that dependency edge. worker.py re-exports the
old names (_derive_source_id, DOCSGPT_INGEST_NAMESPACE) for
backward-compatible imports from tests and any other in-tree
callers; new code should import from the new module directly.
2026-05-12 19:38:04 +01:00
Alex
daa4320da2 docs(events): enumerate publish_user_event None-return paths
The function returns Optional[str] today, with None conflating five
distinct outcomes (missing args / push disabled / unserialisable /
Redis down / XADD failed). Every current call site is fire-and-
forget and ignores the return, so the right move is to document the
five cases rather than promote to an enum return — keeps the API
small while making the diagnostic surface (logs) obvious. If a
future caller needs to react differently per reason, promote then.
2026-05-12 19:37:56 +01:00
Alex
e70a7a5115 fix(notifications): treat /c/new as no current conversation
useMatch('/c/:conversationId') treats the literal URL /c/new as a
real conversation id, so the toast suppression check confused
'user is on /c/new' with 'user is on the conversation needing
approval'. Explicit guard: when the matched id is 'new', fall
through to the no-match case so approval toasts still surface.
2026-05-12 19:16:08 +01:00
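The guard in e70a7a5115 lives in frontend code, but its logic is small enough to sketch in Python. The function names and the `approval_conversation_id` parameter below are assumptions made for illustration; only the special-casing of the literal id `new` comes from the commit message.

```python
from typing import Optional


def current_conversation_id(matched_id: Optional[str]) -> Optional[str]:
    """Treat the route match '/c/new' as 'no current conversation'."""
    if matched_id is None or matched_id == "new":
        return None
    return matched_id


def should_suppress_toast(matched_id: Optional[str], approval_conversation_id: str) -> bool:
    """Hypothetical check: hide the toast only when the user is already
    viewing the conversation that needs approval."""
    return current_conversation_id(matched_id) == approval_conversation_id
```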
Alex
150d9f4e37 test(tasks): cover cleanup_message_events task body
Adds skipped-when-no-POSTGRES_URI and happy-path coverage for the
Celery janitor. The skipped path returns the documented short-circuit
shape without touching the repo. The happy path seeds a backdated
row, runs the task against the pg_conn fixture, and asserts the
retention window's row is deleted while in-window rows survive.
Mirrors the TestCleanupPendingToolState pattern.
2026-05-12 19:16:08 +01:00
Alex
746bcbc5f9 refactor(events): raise on malformed stream id instead of lex fallback
stream_id_compare's lex-fallback branch was a footgun: a malformed id
that sorts lex-greater than a real one would pin live-tail dedup
forever, dropping every subsequent legitimate event silently. Both
current callers in application/api/events/routes.py pre-validate
inputs against _STREAM_ID_RE before calling, so changing the function
to raise ValueError is a no-op on the happy path and turns the future-
caller footgun into a loud failure.
2026-05-12 19:16:08 +01:00
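To make the failure mode in 746bcbc5f9 concrete, here is a minimal sketch of a comparison that raises instead of falling back to lexicographic order. The names `_STREAM_ID_RE` and `stream_id_compare` come from the commit message, but the regex and the body are assumptions based on the usual Redis Stream id shape (`<milliseconds>-<sequence>`), not the repository's actual code.

```python
import re

# Assumed pattern: Redis Stream ids such as "1715540000000-3".
_STREAM_ID_RE = re.compile(r"[0-9]+-[0-9]+")


def _parse_stream_id(stream_id: str) -> tuple:
    """Split an id into (milliseconds, sequence) or raise ValueError."""
    if not isinstance(stream_id, str) or not _STREAM_ID_RE.fullmatch(stream_id):
        raise ValueError(f"malformed stream id: {stream_id!r}")
    ms, seq = stream_id.split("-")
    return int(ms), int(seq)


def stream_id_compare(a: str, b: str) -> int:
    """Return -1, 0, or 1 by numeric order; never fall back to string order."""
    pa, pb = _parse_stream_id(a), _parse_stream_id(b)
    return (pa > pb) - (pa < pb)
```

Numeric parsing matters because `"9-0"` sorts after `"10-0"` as a string but before it as a stream id, and a malformed id can no longer pin the dedup floor silently.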
Alex
aa91117fbf docs(message-events): clarify repo vs wrapper payload contract
MessageEventsRepository.record accepts any JSONB-compatible value; the
streaming wrapper record_event tightens this to dicts only because the
live and replay paths reconstruct non-dict payloads differently. Spell
the split out so the next reader of the repo method doesn't assume the
wrapper's contract applies here.
2026-05-12 19:16:08 +01:00
Alex
abbd56cb66 docs(repo): remove stale planning docs from repo root
notification-channel-design.md, plan.md, and reminder-tool-design.md
were leftover Claude planning artifacts from the SSE substrate work
that landed accidentally. CLAUDE.md prohibits creating planning docs
unless asked — delete them.
2026-05-12 19:16:08 +01:00
Alex
85d8375e6c chore(frontend): drop orphaned getTaskStatus client
After the polling-removal sweep no caller in frontend/src/ references
userService.getTaskStatus or endpoints.USER.TASK_STATUS. The backend
route /api/task_status itself stays — agents, webhooks, e2e specs,
and the public docs still depend on it.
2026-05-12 19:16:07 +01:00
Alex
7e98d21b61 chore(upload): drop dead UploadTask.lastEventAt field
The lastEventAt field on UploadTask had no remaining consumers — the
matching Attachment.lastEventAt was cleaned up earlier. Remove the
field declaration and the slice write site.
2026-05-12 19:16:07 +01:00
Alex
249f9f9fe0 perf(streaming): batch message_events INSERTs per stream
complete_stream previously opened a fresh db_session() per yielded
event, doing one Postgres INSERT + commit per chunk on the WSGI
thread. Streaming answers emit ~100s of answer chunks per response,
so the route was paying ~100 PG roundtrips per stream serialized on
commit latency.

New BatchedJournalWriter in application/streaming/message_journal.py
accumulates rows per stream and flushes on three triggers:
- size: buffer reaches 16 entries
- time: 100ms elapsed since the last flush
- lifecycle: close() at end-of-stream

Live pubsub publishes still fire synchronously per record(), so
subscribers see events in real time — only the durable journal write
is amortized. On bulk INSERT IntegrityError the writer falls back to
per-row record() with the existing seq+1 retry so a single colliding
seq doesn't drop the rest of the batch.

complete_stream wires journal_writer.close() into every exit path
(happy end, tool-approval-paused end, GeneratorExit, error handler)
so the terminal event is committed before the generator returns —
otherwise a reconnecting client could snapshot up to the last flush
boundary and live-tail waiting for an end that's still in memory.

Repository gets bulk_record() — one SQLAlchemy executemany INSERT
for the bulk path. All-or-nothing on collision (Postgres aborts the
whole batch); the writer's per-row fallback handles recovery.
2026-05-12 18:20:19 +01:00
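The three flush triggers described in 249f9f9fe0 can be sketched as follows. This is not the real `BatchedJournalWriter`: the class name, the `flush_rows` callback, and the injectable `clock` are hypothetical, and the time trigger is assumed to be checked only when `record()` is called (the commit message does not say whether a background timer is involved). The thresholds of 16 entries and 100ms are taken from the commit message.

```python
import time
from typing import Any, Callable, List


class BatchedWriterSketch:
    """Buffers rows and flushes on size, elapsed time, or close()."""

    def __init__(
        self,
        flush_rows: Callable[[List[Any]], None],  # hypothetical bulk-insert hook
        max_rows: int = 16,
        max_age_s: float = 0.1,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._flush_rows = flush_rows
        self._max_rows = max_rows
        self._max_age_s = max_age_s
        self._clock = clock
        self._buffer: List[Any] = []
        self._last_flush = clock()

    def record(self, row: Any) -> None:
        self._buffer.append(row)
        too_big = len(self._buffer) >= self._max_rows
        too_old = self._clock() - self._last_flush >= self._max_age_s
        if too_big or too_old:
            self.flush()

    def flush(self) -> None:
        if self._buffer:
            rows, self._buffer = self._buffer, []
            self._flush_rows(rows)
        self._last_flush = self._clock()

    def close(self) -> None:
        # Lifecycle trigger: commit whatever is still in memory.
        self.flush()
```

Calling `close()` on every exit path is what guarantees the terminal event is durable before the generator returns, as the commit message explains.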
Alex
6c4346eb84 fix(streaming): tighten journal contract + recover from seq collisions
Two related fixes to application/streaming/message_journal.py.

1. record_event now rejects non-dict payloads at the gate. The
   live path (base.py::_emit) wrapped non-dicts as
   {"value": payload}; the replay path in event_replay synthesized
   {"type": event_type}. A reconnecting client would receive a
   different envelope than the one originally streamed. Now both
   paths see byte-identical envelopes because non-dicts can't be
   journaled at all. The corresponding event_replay fallback is
   replaced with a warn-and-skip for any legacy rows.

2. record_event handles IntegrityError on (message_id, sequence_no)
   collisions by reading latest_sequence_no and retrying once with
   latest+1. The most likely cause is a stale seq seed on a
   continuation retry where the route read MAX(seq) from a
   separate connection before another writer committed past it.
   Previously the error was swallowed and the event silently
   dropped from the journal; now it lands at the next available
   seq. The live pubsub publish uses the materialised seq so the
   journal row and the live frame agree.
2026-05-12 17:55:16 +01:00
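Both fixes in 6c4346eb84 can be illustrated with an in-memory stand-in for the journal. Everything below is a sketch under stated assumptions: `InMemoryJournal` and `SequenceCollision` are hypothetical substitutes for the Postgres repository and its `IntegrityError`, and raising `TypeError` for non-dict payloads is one possible way to "reject at the gate" (the commit message does not say which mechanism is used).

```python
from typing import Dict, Optional, Tuple


class SequenceCollision(Exception):
    """Stand-in for IntegrityError on a duplicate (message_id, sequence_no)."""


class InMemoryJournal:
    """Hypothetical journal keyed by (message_id, sequence_no)."""

    def __init__(self) -> None:
        self.rows: Dict[Tuple[str, int], dict] = {}

    def insert(self, message_id: str, seq: int, payload: dict) -> None:
        if (message_id, seq) in self.rows:
            raise SequenceCollision((message_id, seq))
        self.rows[(message_id, seq)] = payload

    def latest_sequence_no(self, message_id: str) -> Optional[int]:
        seqs = [s for (m, s) in self.rows if m == message_id]
        return max(seqs) if seqs else None


def record_event(journal: InMemoryJournal, message_id: str, seq: int, payload) -> int:
    """Journal a dict payload; on collision retry once at latest+1."""
    if not isinstance(payload, dict):
        raise TypeError("journal payloads must be dicts")  # assumed rejection style
    try:
        journal.insert(message_id, seq, payload)
        return seq
    except SequenceCollision:
        latest = journal.latest_sequence_no(message_id)
        retry_seq = (latest if latest is not None else -1) + 1
        journal.insert(message_id, retry_seq, payload)  # a second collision propagates
        return retry_seq
```

The returned value is the materialised sequence number, which is what the live publish would need to use so that the journal row and the live frame agree.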
Alex
cb3ca8a36b fix(events): skip replay budget INCR when no snapshot work possible
_allow_replay incremented the per-user counter on every
/api/events GET, including no-op connects from a fresh client
with no cursor against an empty backlog. React StrictMode dev
double-mounts plus a few tabs trivially tripped the default
30-per-60s budget on idle reconnects.

XLEN pre-check: when last_event_id is None and the user stream
is empty, the connect can't do snapshot work — return True
without INCR. Cursor-bearing connects still INCR unconditionally
(probing the cursor's relationship to stream contents would
require a redundant XRANGE).
2026-05-12 17:55:08 +01:00
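The pre-check in cb3ca8a36b amounts to a short-circuit in front of the counter. The sketch below is an assumption-heavy simplification: a plain dict stands in for the Redis counter, the 60-second window expiry is omitted, and `stream_len` is passed in rather than fetched with XLEN. Only the default budget of 30 and the "no cursor plus empty stream" condition come from the commit message.

```python
from typing import Dict, Optional


def allow_replay(
    last_event_id: Optional[str],
    stream_len: int,
    counter: Dict[str, int],  # hypothetical stand-in for the per-user Redis counter
    user_id: str,
    budget: int = 30,
) -> bool:
    """Return True when the connect may proceed with snapshot work."""
    # No cursor and nothing in the stream: no snapshot work is possible,
    # so the connect is allowed without spending budget.
    if last_event_id is None and stream_len == 0:
        return True
    # Every other connect is charged against the budget.
    counter[user_id] = counter.get(user_id, 0) + 1
    return counter[user_id] <= budget
```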
Alex
4c8230fb6c fix(notifications): dedupe sseEventReceived against immediate dupes
Snapshot replay + live tail can both deliver the same id when the
live pubsub frame and the replay XRANGE overlap. The route's own
dedup floor catches the common case, but consumers walking
``recentEvents`` (FileTree, ConnectorTree, MCPServerModal,
ToolApprovalToast) would otherwise act on the same envelope
twice when a duplicate slipped through.

Belt-and-suspenders: short-circuit when the most recent id in
the ring matches the incoming one.
2026-05-12 17:54:59 +01:00
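The short-circuit in 4c8230fb6c is frontend reducer logic; a Python rendering of the same idea is shown here for consistency with the backend code in this comparison. The function name, the `deque`-based ring, and its capacity are hypothetical. Only the rule itself (skip when the most recent id in the ring equals the incoming id) comes from the commit message.

```python
from collections import deque
from typing import Deque


def sse_event_received(ring: Deque[dict], envelope: dict) -> bool:
    """Append the envelope unless it repeats the most recent id.

    Returns True when the envelope was stored, False when it was dropped.
    """
    event_id = envelope.get("id")
    if ring and event_id is not None and ring[-1].get("id") == event_id:
        return False
    ring.append(envelope)
    return True
```

This only catches immediate duplicates, which matches the "belt-and-suspenders" framing: the route's own dedup floor remains the primary defence.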
Alex
649557798d fix(events): drop live publish when journal write fails
application/events/publisher.py returned an envelope to live
pubsub subscribers even when the XADD to the durable journal
failed. The envelope had no ``id`` field, which bypassed the SSE
route's dedup floor and broke ``Last-Event-ID`` semantics for any
reconnecting client.

Best-effort delivery means dropping consistently, not delivering
inconsistent state. Now: if the journal write fails the publisher
returns None and skips the live publish entirely.
2026-05-12 17:54:52 +01:00
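The ordering that 649557798d enforces (journal first, live publish only on success) can be sketched like this. `journal_add` and `live_publish` are hypothetical callables standing in for the XADD and the pub/sub publish; the real function in `application/events/publisher.py` may be structured differently.

```python
from typing import Callable, Optional


def publish_sketch(
    journal_add: Callable[[dict], str],    # hypothetical: returns the stream id
    live_publish: Callable[[dict], None],  # hypothetical: pub/sub fanout
    envelope: dict,
) -> Optional[str]:
    """Publish live only after the durable write has produced an id."""
    try:
        event_id = journal_add(envelope)
    except Exception:
        # Journal write failed: drop consistently instead of sending an
        # id-less envelope that would bypass dedup and Last-Event-ID.
        return None
    live_publish({**envelope, "id": event_id})
    return event_id
```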
Alex
afe8354ca5 chore(mcp-oauth): delete orphaned getMCPOAuthStatus client
The /api/mcp_server/oauth_status/<task_id> endpoint was removed in
the prior commit; the corresponding userService method and the
MCP_OAUTH_STATUS endpoint constant had no remaining callers in the
frontend, so they're deleted along with it.
2026-05-12 16:02:41 +01:00
Alex
5483eb0e27 refactor(mcp-oauth): read status from SSE journal, drop polling endpoint
MCPOAuthManager.get_oauth_status now walks the per-user SSE Streams
journal (user:{user_id}:stream) for the latest mcp.oauth.* envelope
matching the task id, returning the status string derived from the
event type suffix and the payload fields. The worker is the single
source of truth — its publish_user_event calls write the same
record the SSE client receives live.

Removed:
- /api/mcp_server/oauth_status/<task_id> route in
  application/api/user/tools/mcp.py
- mcp_oauth_status worker function and mcp_oauth_status_task Celery
  wrapper
- All mcp_oauth_status:{task_id} Redis setex writes (4 in mcp_oauth,
  2 in DocsGPTOAuth.redirect_handler / callback_handler)
- The update_status closure in mcp_oauth that wrote the polling
  payload

Tests updated:
- get_oauth_status now takes (task_id, user_id); new coverage walks
  a fake xrevrange response for the completed envelope, the no-match
  case, and a Redis-down case
- Removed TestMCPOAuthStatus route tests and TestMcpOauthStatusTask
  celery-wrapper test
- Removed the two oauth_status methods from the integration runner

mcp_oauth:auth_url/state/code/error Redis keys remain — they are
the OAuth flow's own state (not the dropped polling payload).
2026-05-12 16:01:31 +01:00
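The journal walk described in 5483eb0e27 might be approximated as below. The envelope shape (a `type` string plus a `payload` dict carrying `task_id`) is an assumption, as is the function name; the input list mimics the newest-first order that XREVRANGE returns. The real `get_oauth_status` also derives the status from payload fields, which this sketch leaves out.

```python
from typing import Iterable, Optional, Tuple

_PREFIX = "mcp.oauth."


def latest_oauth_status(
    entries_newest_first: Iterable[Tuple[str, dict]],
    task_id: str,
) -> Optional[str]:
    """Return the event-type suffix of the newest matching envelope, or None."""
    for _entry_id, envelope in entries_newest_first:
        event_type = envelope.get("type", "")
        payload = envelope.get("payload") or {}  # assumed envelope layout
        if event_type.startswith(_PREFIX) and payload.get("task_id") == task_id:
            return event_type[len(_PREFIX):]
    return None
```

Because the worker's `publish_user_event` calls write the same records the SSE client receives, reading the status this way keeps a single source of truth instead of a parallel polling payload.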
Alex
bd2985db47 feat(source-ingest): plumb limited flag through SSE for token-cap UX
application/worker.py::ingest_worker and remote_worker now publish
``limited: bool`` on the source.ingest.completed envelope.
uploadSlice routes ``payload.limited === true`` to a failed status
with a ``tokenLimitReached`` flag, and UploadToast surfaces the
translated tokenLimit i18n string. No worker code path sets
limited=true today; this is a forward-looking contract so when
token-cap detection lands, the UX is already wired.
2026-05-12 15:49:15 +01:00
Alex
b99147ba83 refactor(mcp-oauth): carry authorization_url in SSE, remove polling
application/worker.py::mcp_oauth now publishes
authorization_url on the mcp.oauth.awaiting_redirect envelope.
frontend/src/modals/MCPServerModal.tsx consumes it from SSE
instead of polling /oauth_status/<task_id> every 1s.

The URL is generated inside DocsGPTOAuth.redirect_handler when
the FastMCP client triggers OAuth. The worker now plumbs a
publish callback through tool_config -> MCPTool -> DocsGPTOAuth
so the awaiting_redirect publish fires from inside the handler
at the exact point the URL becomes known. The legacy Redis
mcp_oauth_status setex writes and the GET
/api/mcp_server/oauth_status/<task_id> endpoint are kept as
belt-and-suspenders; nothing in the frontend reads them now.
2026-05-12 14:44:42 +01:00
Alex
c3023f8b71 refactor(source-ingest): remove polling, SSE-only
frontend/src/upload/Upload.tsx and
frontend/src/components/FileTree.tsx no longer run getTaskStatus
polling fallbacks. The source.ingest.* SSE reducers in
uploadSlice.ts and FileTree's slice walk are now the sole
drivers of upload/reingest state transitions.
2026-05-12 14:44:33 +01:00
Alex
c168a530f5 feat(connector): consume source.ingest.* SSE, remove polling
frontend/src/components/ConnectorTree.tsx now mirrors FileTree's
slice-walking pattern: it watches notifications.recentEvents
for source.ingest.{completed,failed} envelopes matching the
sync's source id, and no longer polls /task_status every 2s.
2026-05-12 14:44:27 +01:00
Alex
2d539f3199 refactor(attachments): remove polling, SSE-only
frontend/src/components/MessageInput.tsx no longer runs a 2s
setInterval against getTaskStatus for every processing
attachment. The attachment.* SSE reducers in uploadSlice.ts are
now the sole driver of attachment state transitions.
2026-05-12 14:44:21 +01:00
Alex
ed9444cf3d feat: SSE notification system
Adds a per-user SSE pipe (GET /api/events) plus a per-message
chat-stream reconnect endpoint (GET /api/messages/<id>/events).

Backend substrate:
- application/events/ — durable journal (Redis Streams) + live
  pub/sub for user-scoped events, with publish_user_event() as
  the worker-side entrypoint.
- application/streaming/ — broadcast_channel for pub/sub fanout
  and event_replay for the per-message snapshot+tail path.
- application/storage/db/repositories/message_events.py +
  alembic 0007 — Postgres journal for chat-stream events.
- application/worker.py — ingest/reingest/remote/connector/
  attachment/mcp_oauth tasks publish queued/progress/completed/
  failed envelopes alongside their existing status updates.

Frontend client:
- frontend/src/events/ — connect/reconnect, Last-Event-ID cursor,
  backoff with jitter. Each tab runs its own connection; no
  cross-tab dedup (future work).
- frontend/src/notifications/ — recentEvents ring, cursor
  tracking, tool-approval toast.
- frontend/src/upload/uploadSlice.ts — extraReducers for
  source.ingest.* and attachment.* events.

Coverage: 132 SSE tests across events substrate, replay, journal,
routes, and worker publishes.
2026-05-12 14:29:45 +01:00
58 changed files with 492 additions and 2252 deletions

View File

@@ -8,7 +8,7 @@ RUN apt-get update && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && \
apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \
rm -rf /var/lib/apt/lists/*
rm -rf /var/lib/apt/lists/*
# Verify Python installation and setup symlink
RUN if [ -f /usr/bin/python3.12 ]; then \
@@ -73,7 +73,7 @@ COPY --from=builder /models /app/models
COPY . /app/application
# Change the ownership of the /app directory to the appuser
RUN mkdir -p /app/application/inputs/local
RUN chown -R appuser:appuser /app
@@ -82,11 +82,6 @@ ENV FLASK_APP=app.py \
FLASK_DEBUG=true \
PATH="/venv/bin:$PATH"
ENV MALLOC_ARENA_MAX=2 \
OMP_NUM_THREADS=4 \
MKL_NUM_THREADS=4 \
OPENBLAS_NUM_THREADS=4
# Expose the port the app runs on
EXPOSE 7091

View File

@@ -114,8 +114,6 @@ class BaseAgent(ABC):
self.compressed_summary = compressed_summary
self.current_token_count = 0
self.context_limit_reached = False
self.conversation_id: Optional[str] = None
self.initial_user_id: Optional[str] = None
@log_activity()
def gen(

View File

@@ -2,7 +2,7 @@ import json
import logging
import re
from typing import Any, Dict, Optional
from urllib.parse import quote, urlencode
from urllib.parse import urlencode
import requests
@@ -11,7 +11,7 @@ from application.agents.tools.api_body_serializer import (
RequestBodySerializer,
)
from application.agents.tools.base import Tool
from application.security.safe_url import UnsafeUserUrlError, pinned_request
from application.core.url_validation import validate_url, SSRFError
logger = logging.getLogger(__name__)
@@ -70,16 +70,18 @@ class APITool(Tool):
Returns:
Dict with status_code, data, and message
"""
_VALID_METHODS = {"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}
request_url = url
request_headers = headers.copy() if headers else {}
response = None
if method.upper() not in _VALID_METHODS:
# Validate URL to prevent SSRF attacks
try:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed: {e}")
return {
"status_code": None,
"message": f"Unsupported HTTP method: {method}",
"message": f"URL validation error: {e}",
"data": None,
}
@@ -89,9 +91,8 @@ class APITool(Tool):
for match in re.finditer(r"\{([^}]+)\}", request_url):
param_name = match.group(1)
if param_name in query_params:
safe_value = quote(str(query_params[param_name]), safe="")
request_url = request_url.replace(
f"{{{param_name}}}", safe_value
f"{{{param_name}}}", str(query_params[param_name])
)
path_params_used.add(param_name)
remaining_params = {
@@ -102,6 +103,19 @@ class APITool(Tool):
separator = "&" if "?" in request_url else "?"
request_url = f"{request_url}{separator}{query_string}"
# Re-validate URL after parameter substitution to prevent SSRF via path params
try:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed after parameter substitution: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
# Serialize body based on content type
if body and body != {}:
try:
serialized_body, body_headers = RequestBodySerializer.serialize(
@@ -127,13 +141,49 @@ class APITool(Tool):
f"API Call: {method} {request_url} | Content-Type: {request_headers.get('Content-Type', 'N/A')}"
)
response = pinned_request(
method,
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
if method.upper() == "GET":
response = requests.get(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "POST":
response = requests.post(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "PUT":
response = requests.put(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "DELETE":
response = requests.delete(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "PATCH":
response = requests.patch(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "HEAD":
response = requests.head(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "OPTIONS":
response = requests.options(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
else:
return {
"status_code": None,
"message": f"Unsupported HTTP method: {method}",
"data": None,
}
response.raise_for_status()
data = self._parse_response(response)
@@ -143,13 +193,6 @@ class APITool(Tool):
"data": data,
"message": "API call successful.",
}
except UnsafeUserUrlError as e:
logger.error(f"URL validation failed: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
except requests.exceptions.Timeout:
logger.error(f"Request timeout for {request_url}")
return {

View File

@@ -1,5 +1,5 @@
import requests
from application.agents.tools.base import Tool
from application.security.safe_url import UnsafeUserUrlError, pinned_request
class NtfyTool(Tool):
"""
@@ -71,12 +71,7 @@ class NtfyTool(Tool):
if self.token:
headers["Authorization"] = f"Basic {self.token}"
data = message.encode("utf-8")
try:
response = pinned_request(
"POST", url, data=data, headers=headers, timeout=100,
)
except UnsafeUserUrlError as e:
return {"status_code": None, "message": f"URL validation error: {e}"}
response = requests.post(url, headers=headers, data=data, timeout=100)
return {"status_code": response.status_code, "message": "Message sent"}
def get_actions_metadata(self):

View File

@@ -1,6 +1,7 @@
import requests
from markdownify import markdownify
from application.agents.tools.base import Tool
from application.security.safe_url import UnsafeUserUrlError, pinned_request
from application.core.url_validation import validate_url, SSRFError
class ReadWebpageTool(Tool):
"""
@@ -30,24 +31,28 @@ class ReadWebpageTool(Tool):
if not url:
return "Error: URL parameter is missing."
# Validate URL to prevent SSRF attacks
try:
response = pinned_request(
"GET",
url,
headers={'User-Agent': 'DocsGPT-Agent/1.0'},
timeout=10,
)
response.raise_for_status()
url = validate_url(url)
except SSRFError as e:
return f"Error: URL validation failed - {e}"
try:
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
html_content = response.text
#soup = BeautifulSoup(html_content, 'html.parser')
markdown_content = markdownify(html_content, heading_style="ATX", newline_style="BACKSLASH")
return markdown_content
except UnsafeUserUrlError as e:
return f"Error: URL validation failed - {e}"
except Exception as e:
except requests.exceptions.RequestException as e:
return f"Error fetching URL {url}: {e}"
except Exception as e:
return f"Error processing URL {url}: {e}"
def get_actions_metadata(self):
"""

View File

@@ -1,44 +0,0 @@
"""0008 ingest_chunk_progress.status — terminal flag for stalled ingests.
The reconciler's stalled-ingest sweep had no terminal write, so a dead
ingest re-alerted every ~30 min forever. ``status`` lets it escalate a
stalled checkpoint to ``'stalled'`` once and stop re-selecting it;
``init_progress`` resets it to ``'active'`` on reingest.
Revision ID: 0008_ingest_progress_status
Revises: 0007_message_events
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0008_ingest_progress_status"
down_revision: Union[str, None] = "0007_message_events"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Constant DEFAULT — metadata-only ADD COLUMN, no table rewrite.
op.execute(
"""
ALTER TABLE ingest_chunk_progress
ADD COLUMN status TEXT NOT NULL DEFAULT 'active'
CHECK (status IN ('active', 'stalled'));
"""
)
# Partial index for the reconciler's stalled-ingest sweep.
op.execute(
"CREATE INDEX ingest_chunk_progress_active_idx "
"ON ingest_chunk_progress (last_updated) "
"WHERE status = 'active';"
)
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS ingest_chunk_progress_active_idx;")
op.execute(
"ALTER TABLE ingest_chunk_progress DROP COLUMN IF EXISTS status;"
)

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
import functools
import inspect
import logging
import threading
import uuid
@@ -27,20 +26,13 @@ LEASE_HEARTBEAT_INTERVAL = 30
LEASE_RETRY_MAX = 10
def with_idempotency(
task_name: str,
*,
on_poison: Optional[Callable[[str, dict], None]] = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def with_idempotency(task_name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Short-circuit on completed key; gate concurrent runs via a lease.
The guard key is the caller's ``idempotency_key``, or one synthesized
from ``source_id`` so a keyless dispatch is still poison-guarded.
Entry short-circuits:
- completed row → return cached result
- live lease held → retry(countdown=LEASE_TTL_SECONDS)
- attempt_count > MAX_TASK_ATTEMPTS → poison alert; ``on_poison`` fires
- attempt_count > MAX_TASK_ATTEMPTS → poison-loop alert
Success writes ``completed``; exceptions leave ``pending`` for
autoretry until the poison-loop guard trips.
"""
@@ -48,14 +40,7 @@ def with_idempotency(
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
@functools.wraps(fn)
def wrapper(self, *args: Any, idempotency_key: Any = None, **kwargs: Any) -> Any:
explicit_key = (
idempotency_key
if isinstance(idempotency_key, str) and idempotency_key
else None
)
# A keyless dispatch still gets the guard via a synthesized key;
# None means no anchor exists — run unguarded, as before.
key = explicit_key or _synthesize_guard_key(task_name, kwargs)
key = idempotency_key if isinstance(idempotency_key, str) and idempotency_key else None
if key is None:
return fn(self, *args, idempotency_key=idempotency_key, **kwargs)
@@ -103,9 +88,6 @@ def with_idempotency(
"attempts": attempt,
}
_finalize(key, poisoned, status="failed")
_run_poison_hook(
on_poison, task_name, fn, self, args, kwargs, idempotency_key,
)
return poisoned
heartbeat_thread, heartbeat_stop = _start_lease_heartbeat(
@@ -127,45 +109,6 @@ def with_idempotency(
return decorator
def _synthesize_guard_key(task_name: str, kwargs: dict) -> Optional[str]:
"""Derive a deterministic guard key from ``source_id`` for a keyless dispatch.
``source_id`` is stable across broker redeliveries and unique per
upload, so the poison-loop counter survives an OOM SIGKILL. Returns
``None`` when absent — the dispatch then runs unguarded as before.
"""
source_id = kwargs.get("source_id")
if source_id:
return f"auto:{task_name}:{source_id}"
return None
def _run_poison_hook(
on_poison: Optional[Callable[[str, dict], None]],
task_name: str,
fn: Callable[..., Any],
task_self: Any,
args: tuple,
kwargs: dict,
idempotency_key: Any,
) -> None:
"""Invoke a task's poison-path hook with named call args; swallow failures.
A hook failure must never change the poison-guard outcome.
"""
if on_poison is None:
return
try:
bound = inspect.signature(fn).bind_partial(
task_self, *args, idempotency_key=idempotency_key, **kwargs,
)
on_poison(task_name, dict(bound.arguments))
except Exception:
logger.exception(
"idempotency: poison hook failed for task=%s", task_name,
)
def _lookup_completed(key: str) -> Any:
"""Return cached ``result_json`` if a completed row exists for ``key``, else None."""
with db_readonly() as conn:

View File

@@ -114,11 +114,11 @@ def run_reconciliation() -> Dict[str, Any]:
},
)
# Q4: ingest checkpoints whose heartbeat has gone silent. Each is
# escalated to terminal ``status='stalled'`` and alerted once — no
# worker kill, no rollback of the partial embed. The 'stalled' flag
# ends the re-alert loop and drives the "indexing failed" badge the
# sources list derives from this row.
# Q4: ingest checkpoints whose heartbeat has gone silent. The
# reconciler only escalates (alerts) — it doesn't kill the worker
# or roll back the partial embed. The next dispatch resumes from
# ``last_index`` thanks to the per-chunk checkpoint, so this is an
# observability sweep, not a recovery action.
with engine.begin() as conn:
repo = ReconciliationRepository(conn)
for row in repo.find_and_lock_stalled_ingests():
@@ -134,7 +134,8 @@ def run_reconciliation() -> Dict[str, Any]:
"last_updated": str(row.get("last_updated")),
},
)
repo.mark_ingest_stalled(str(row["source_id"]))
# Bump the heartbeat so we don't re-alert every tick.
repo.touch_ingest_progress(str(row["source_id"]))
# Q5: idempotency rows whose lease expired with attempts exhausted.
# The wrapper's poison-loop guard normally finalises these, but if

View File

@@ -7,12 +7,8 @@ from flask import current_app, jsonify, make_response, redirect, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.tasks import reingest_source_task, sync_source
from application.api.user.tasks import sync_source
from application.core.settings import settings
from application.parser.remote.remote_creator import normalize_remote_data
from application.storage.db.repositories.ingest_chunk_progress import (
IngestChunkProgressRepository,
)
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_readonly, db_session
from application.storage.storage_creator import StorageCreator
@@ -143,8 +139,6 @@ class PaginatedSources(Resource):
"provider": provider,
"isNested": bool(doc.get("directory_structure")),
"type": doc.get("type", "file"),
# Derived in SourcesRepository.list_for_user.
"ingestStatus": doc.get("ingest_status"),
}
)
response = {
@@ -328,7 +322,7 @@ class SyncSource(Resource):
),
400,
)
source_data = normalize_remote_data(source_type, doc.get("remote_data"))
source_data = doc.get("remote_data")
if not source_data:
return make_response(
jsonify({"success": False, "message": "Source is not syncable"}), 400
@@ -352,70 +346,6 @@ class SyncSource(Resource):
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@sources_ns.route("/sources/reingest")
class ReingestSource(Resource):
reingest_source_model = api.model(
"ReingestSourceModel",
{"source_id": fields.String(required=True, description="Source ID")},
)
@api.expect(reingest_source_model)
@api.doc(
description="Re-run ingestion for a source — e.g. to recover a "
"stalled embed flagged by the reconciler."
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json() or {}
missing_fields = check_required_fields(data, ["source_id"])
if missing_fields:
return missing_fields
source_id = data["source_id"]
try:
with db_readonly() as conn:
doc = SourcesRepository(conn).get_any(source_id, user)
except Exception as err:
current_app.logger.error(
f"Error looking up source: {err}", exc_info=True
)
return make_response(
jsonify({"success": False, "message": "Invalid source ID"}), 400
)
if not doc:
return make_response(
jsonify({"success": False, "message": "Source not found"}), 404
)
resolved_source_id = str(doc["id"])
# Drop the stale chunk-progress row so the sources list stops
# deriving a 'failed' status; reingest never rewrites it itself.
try:
with db_session() as conn:
IngestChunkProgressRepository(conn).delete(resolved_source_id)
except Exception as err:
current_app.logger.warning(
f"Could not clear ingest progress for {resolved_source_id}: "
f"{err}",
exc_info=True,
)
try:
# Scoped key so repeated clicks collapse onto one reingest.
task = reingest_source_task.delay(
source_id=resolved_source_id,
user=user,
idempotency_key=f"reingest-source:{user}:{resolved_source_id}",
)
except Exception as err:
current_app.logger.error(
f"Error starting reingest for source {source_id}: {err}",
exc_info=True,
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@sources_ns.route("/directory_structure")
class DirectoryStructure(Resource):
@api.doc(

View File

@@ -27,42 +27,8 @@ DURABLE_TASK = dict(
)
# operation tag for the poison-path source.ingest.failed event, per task.
_INGEST_POISON_OPERATION = {
"ingest": "upload",
"ingest_remote": "upload",
"ingest_connector_task": "upload",
"reingest_source_task": "reingest",
}
def _emit_ingest_poison_event(task_name, bound):
"""Publish a terminal ``source.ingest.failed`` when the poison-guard trips.
The guard returns before the worker runs, so the worker's own failed
event never fires — without this the upload toast spins on "training".
"""
user = bound.get("user")
source_id = bound.get("source_id")
if not user or not source_id:
return
from application.events.publisher import publish_user_event
publish_user_event(
user,
"source.ingest.failed",
{
"source_id": str(source_id),
"filename": bound.get("filename") or "",
"operation": _INGEST_POISON_OPERATION.get(task_name, "upload"),
"error": "Ingestion stopped after repeated failures.",
},
scope={"kind": "source", "id": str(source_id)},
)
@celery.task(**DURABLE_TASK)
@with_idempotency(task_name="ingest", on_poison=_emit_ingest_poison_event)
@with_idempotency(task_name="ingest")
def ingest(
self,
directory,
@@ -91,7 +57,7 @@ def ingest(
@celery.task(**DURABLE_TASK)
@with_idempotency(task_name="ingest_remote", on_poison=_emit_ingest_poison_event)
@with_idempotency(task_name="ingest_remote")
def ingest_remote(
self, source_data, job_name, user, loader,
idempotency_key=None, source_id=None,
@@ -105,9 +71,7 @@ def ingest_remote(
@celery.task(**DURABLE_TASK)
@with_idempotency(
task_name="reingest_source_task", on_poison=_emit_ingest_poison_event,
)
@with_idempotency(task_name="reingest_source_task")
def reingest_source_task(self, source_id, user, idempotency_key=None):
from application.worker import reingest_source_worker
@@ -164,9 +128,7 @@ def process_agent_webhook(self, agent_id, payload, idempotency_key=None):
@celery.task(**DURABLE_TASK)
@with_idempotency(
task_name="ingest_connector_task", on_poison=_emit_ingest_poison_event,
)
@with_idempotency(task_name="ingest_connector_task")
def ingest_connector_task(
self,
job_name,


@@ -1,8 +1,5 @@
import ctypes
import gc
import inspect
import logging
import sys
import threading
from celery import Celery
@@ -101,34 +98,6 @@ def _unbind_task_log_context(task_id, **_):
)
def _trim_native_heap() -> None:
"""Return freed glibc heap pages to the OS (Linux only; no-op elsewhere)."""
# docling/torch parsing makes large transient allocations; glibc keeps the
# freed pages in per-thread malloc arenas rather than returning them, so a
# long-lived worker child's RSS only ever climbs. malloc_trim hands them
# back. The symbol is glibc-only — absent in macOS libc.
if not sys.platform.startswith("linux"):
return
try:
ctypes.CDLL("libc.so.6").malloc_trim(0)
except (OSError, AttributeError):
pass
@task_postrun.connect
def _reclaim_memory_after_task(*args, **kwargs):
"""Drop per-task allocations so the prefork child's RSS doesn't ratchet."""
gc.collect()
torch = sys.modules.get("torch")
if torch is not None:
try:
if torch.cuda.is_available():
torch.cuda.empty_cache()
except Exception:
pass
_trim_native_heap()
@worker_ready.connect
def _run_version_check(*args, **kwargs):
"""Kick off the anonymous version check on worker startup.


@@ -31,10 +31,3 @@ worker_prefetch_multiplier = settings.CELERY_WORKER_PREFETCH_MULTIPLIER
broker_transport_options = {"visibility_timeout": settings.CELERY_VISIBILITY_TIMEOUT}
result_expires = 86400 * 7
task_track_started = True
# Recycle the prefork worker child to bound native-heap growth from
# docling/torch parsing. Left unset (Celery's unlimited default) when 0.
if settings.CELERY_WORKER_MAX_MEMORY_PER_CHILD > 0:
worker_max_memory_per_child = settings.CELERY_WORKER_MAX_MEMORY_PER_CHILD
if settings.CELERY_WORKER_MAX_TASKS_PER_CHILD > 0:
worker_max_tasks_per_child = settings.CELERY_WORKER_MAX_TASKS_PER_CHILD


@@ -36,11 +36,6 @@ class Settings(BaseSettings):
# and Dify defaults; long ingests can override via env.
CELERY_WORKER_PREFETCH_MULTIPLIER: int = 1
CELERY_VISIBILITY_TIMEOUT: int = 3600
# Recycle the prefork worker child once its resident size crosses this many
# kilobytes — backstops native-heap growth from docling/torch parsing. 0 disables.
CELERY_WORKER_MAX_MEMORY_PER_CHILD: int = 4194304
# Recycle the child after this many tasks; 0 disables (memory cap is the primary knob).
CELERY_WORKER_MAX_TASKS_PER_CHILD: int = 0
# Only consulted when VECTOR_STORE=mongodb or when running scripts/db/backfill.py; user data lives in Postgres.
MONGO_URI: Optional[str] = None
# User-data Postgres DB.
@@ -66,9 +61,6 @@ class Settings(BaseSettings):
PARSE_IMAGE_REMOTE: bool = False
DOCLING_OCR_ENABLED: bool = False # Enable OCR for docling parsers (PDF, images)
DOCLING_OCR_ATTACHMENTS_ENABLED: bool = False # Enable OCR for docling when parsing attachments
# Pages docling's threaded pipeline buffers in flight; the library
# default (100) drives worker RSS to ~3 GB on a mid-size PDF.
DOCLING_PIPELINE_QUEUE_MAX_SIZE: int = 2
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
RETRIEVERS_ENABLED: list = ["classic_rag"]
AGENT_NAME: str = "classic"


@@ -154,8 +154,6 @@ def embed_and_store_documents(
*,
attempt_id: Optional[str] = None,
user_id: Optional[str] = None,
progress_start: int = 0,
progress_end: int = 100,
) -> None:
"""Embeds documents and stores them in a vector store.
@@ -178,11 +176,6 @@ def embed_and_store_documents(
published to ``user:{user_id}`` for the in-app upload toast.
``None`` is the safe default — workers without a user
context (e.g. background syncs) skip the publish.
progress_start: Percent the reported progress maps to at chunk 0.
Lets a caller reserve the lower band for an earlier stage
(e.g. parsing). Defaults to ``0`` (embed owns the whole bar).
progress_end: Percent the reported progress maps to at the final
chunk. Defaults to ``100``.
Returns:
None
@@ -264,7 +257,6 @@ def embed_and_store_documents(
failed_idx: int | None = None
last_published_pct = -1
source_id_str = str(source_id)
progress_span = progress_end - progress_start
for idx in tqdm(
range(loop_start, total_docs),
desc="Embedding 🦖",
@@ -274,10 +266,8 @@ def embed_and_store_documents(
):
doc = docs[idx]
try:
# Map the embed loop into [progress_start, progress_end].
progress = progress_start + int(
((idx + 1) / total_docs) * progress_span
)
# Update task status for progress tracking
progress = int(((idx + 1) / total_docs) * 100)
task_status.update_state(state="PROGRESS", meta={"current": progress})
# SSE push for sub-second upload-toast updates. Throttled to one
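The band mapping changed in this hunk can be sketched in isolation. The helper name `map_progress` is hypothetical (it does not appear in the diff); it assumes a zero-based chunk index and a positive `total_docs`, and reproduces the `progress_start + int(((idx + 1) / total_docs) * progress_span)` arithmetic. With the default band of 0 to 100 it reduces to the single-stage formula `int(((idx + 1) / total_docs) * 100)`.

```python
def map_progress(
    idx: int,
    total_docs: int,
    progress_start: int = 0,
    progress_end: int = 100,
) -> int:
    """Map a zero-based chunk index into the [progress_start, progress_end] band.

    Illustrative helper only; the name and the ValueError guard are assumptions.
    """
    if total_docs <= 0:
        raise ValueError("total_docs must be positive")
    progress_span = progress_end - progress_start
    # Chunk idx is reported as "idx + 1 of total_docs done", scaled into the band.
    return progress_start + int(((idx + 1) / total_docs) * progress_span)


if __name__ == "__main__":
    # Embedding reserved for the upper half of the bar (parsing owns the lower half).
    for i in range(4):
        print(i, map_progress(i, 4, progress_start=50, progress_end=100))
```

Reserving a band per stage keeps the bar monotonic when an earlier stage (such as parsing) has already reported progress below `progress_start`.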


@@ -211,22 +211,13 @@ class SimpleDirectoryReader(BaseReader):
return new_input_files
def load_data(
self,
concatenate: bool = False,
progress_callback: Optional[Callable[[int, int], None]] = None,
) -> List[Document]:
def load_data(self, concatenate: bool = False) -> List[Document]:
"""Load data from the input directory.
Args:
concatenate (bool): whether to concatenate all files into one document.
If set to True, file metadata is ignored.
False by default.
progress_callback (Optional[Callable[[int, int], None]]): Called
after each file is parsed with ``(files_done, total_files)``.
Lets callers surface parse/OCR progress before embedding
begins. Exceptions raised by the callback are swallowed so
progress reporting can never fail ingestion.
Returns:
List[Document]: A list of documents.
@@ -235,9 +226,8 @@ class SimpleDirectoryReader(BaseReader):
data_list: List[str] = []
metadata_list = []
self.file_token_counts = {}
total_files = len(self.input_files)
for file_index, input_file in enumerate(self.input_files):
for input_file in self.input_files:
suffix_lower = input_file.suffix.lower()
parser_metadata = {}
if suffix_lower in self.file_extractor:
@@ -287,15 +277,7 @@ class SimpleDirectoryReader(BaseReader):
else:
data_list.append(str(data))
metadata_list.append(base_metadata)
if progress_callback is not None:
try:
progress_callback(file_index + 1, total_files)
except Exception:
logging.warning(
"load_data progress callback failed", exc_info=True
)
# Build directory structure if input_dir is provided
if hasattr(self, 'input_dir'):
self.directory_structure = self.build_directory_structure(self.input_dir)


@@ -16,29 +16,6 @@ from application.parser.file.base_parser import BaseParser
logger = logging.getLogger(__name__)
# Per-stage batch size for docling's threaded pipeline; 1 holds the
# concurrent working set to a single page (see _apply_pipeline_caps).
_PIPELINE_BATCH_SIZE = 1
def _apply_pipeline_caps(pipeline_options) -> None:
"""Cap docling's threaded-pipeline queue depth and batch sizes in place.
hasattr-guarded so docling builds without these knobs are unaffected.
"""
from application.core.settings import settings
caps = {
"queue_max_size": max(1, settings.DOCLING_PIPELINE_QUEUE_MAX_SIZE),
"layout_batch_size": _PIPELINE_BATCH_SIZE,
"table_batch_size": _PIPELINE_BATCH_SIZE,
"ocr_batch_size": _PIPELINE_BATCH_SIZE,
}
for name, value in caps.items():
if hasattr(pipeline_options, name):
setattr(pipeline_options, name, value)
class DoclingParser(BaseParser):
"""Parser using docling for advanced document processing.
@@ -109,7 +86,6 @@ class DoclingParser(BaseParser):
do_ocr=self.ocr_enabled,
do_table_structure=self.table_structure,
)
_apply_pipeline_caps(pipeline_options)
if self.ocr_enabled:
ocr_options = self._get_ocr_options()


@@ -1,11 +1,11 @@
import logging
import os
import requests
from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
from application.parser.schema.base import Document
from application.core.url_validation import validate_url, SSRFError
from application.security.safe_url import UnsafeUserUrlError, pinned_request
from langchain_community.document_loaders import WebBaseLoader
class CrawlerLoader(BaseRemote):
@@ -35,7 +35,14 @@ class CrawlerLoader(BaseRemote):
visited_urls.add(current_url)
try:
response = pinned_request("GET", current_url, timeout=30)
# Validate each URL before making requests
try:
validate_url(current_url)
except SSRFError as e:
logging.warning(f"Skipping URL due to validation failure: {current_url} - {e}")
continue
response = requests.get(current_url, timeout=30)
response.raise_for_status()
loader = self.loader([current_url])
docs = loader.load()


@@ -1,8 +1,8 @@
import requests
from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
from application.core.url_validation import validate_url, SSRFError
from application.security.safe_url import UnsafeUserUrlError, pinned_request
import re
from markdownify import markdownify
from application.parser.schema.base import Document
@@ -20,6 +20,7 @@ class CrawlerLoader(BaseRemote):
"""
self.limit = limit
self.allow_subdomains = allow_subdomains
self.session = requests.Session()
def load_data(self, inputs):
url = inputs
@@ -90,13 +91,15 @@ class CrawlerLoader(BaseRemote):
def _fetch_page(self, url):
try:
response = pinned_request("GET", url, timeout=10)
# Validate URL before fetching to prevent SSRF
validate_url(url)
response = self.session.get(url, timeout=10)
response.raise_for_status()
return response.text
except UnsafeUserUrlError as e:
except SSRFError as e:
print(f"URL validation failed for {url}: {e}")
return None
except Exception as e:
except requests.exceptions.RequestException as e:
print(f"Error fetching URL {url}: {e}")
return None


@@ -1,5 +1,3 @@
import json
from application.parser.remote.sitemap_loader import SitemapLoader
from application.parser.remote.crawler_loader import CrawlerLoader
from application.parser.remote.web_loader import WebLoader
@@ -34,59 +32,3 @@ class RemoteCreator:
if not loader_class:
raise ValueError(f"No loader class found for type {type}")
return loader_class(*args, **kwargs)
# Loader types whose load_data expects a URL string, not a config dict.
_URL_LOADER_TYPES = {"url", "crawler", "sitemap", "github"}
# Keys a remote_data dict may hold the URL under (``raw`` is the legacy shape).
_URL_DATA_KEYS = ("url", "urls", "repo_url", "raw")
def normalize_remote_data(source_type, remote_data):
"""Convert a stored ``sources.remote_data`` JSONB value into the
``source_data`` shape the matching loader expects.
Args:
source_type: The ``sources.type`` value (the loader name).
remote_data: The stored ``remote_data`` (dict, list, str, or None).
Returns:
Loader input: a URL string or list for url/crawler/sitemap/github,
a JSON string for reddit, a dict for s3; ``None`` when the row has
nothing syncable.
"""
if remote_data is None:
return None
# Some legacy rows stored the JSON itself as a string.
if isinstance(remote_data, str):
stripped = remote_data.strip()
if stripped[:1] in ("{", "["):
try:
remote_data = json.loads(stripped)
except json.JSONDecodeError:
# Not actually JSON — leave remote_data as the original
# string; the per-loader branches below handle a string.
pass
loader = (source_type or "").lower()
if loader in _URL_LOADER_TYPES:
if isinstance(remote_data, dict):
for key in _URL_DATA_KEYS:
value = remote_data.get(key)
if value:
return value
# No URL key — None keeps the loader off the dict-crash path.
return None
return remote_data
if loader == "reddit":
# reddit's loader runs json.loads() on its input — needs a string.
if isinstance(remote_data, (dict, list)):
return json.dumps(remote_data)
return remote_data
# s3's loader accepts a dict or JSON string; pass it through unchanged.
return remote_data
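To make the input and output shapes of `normalize_remote_data` concrete, here is a standalone sketch. It is a condensed copy of the logic in this hunk, renamed `normalize_remote_data_sketch` so it runs without the `application` package; the example URLs and payloads are made up for illustration.

```python
import json

# Condensed from the hunk; names suffixed/renamed so this runs standalone.
URL_LOADER_TYPES = {"url", "crawler", "sitemap", "github"}
URL_DATA_KEYS = ("url", "urls", "repo_url", "raw")


def normalize_remote_data_sketch(source_type, remote_data):
    """Return the loader input for a stored remote_data value, or None."""
    if remote_data is None:
        return None
    # Legacy rows may hold the JSON document itself as a string.
    if isinstance(remote_data, str):
        stripped = remote_data.strip()
        if stripped[:1] in ("{", "["):
            try:
                remote_data = json.loads(stripped)
            except json.JSONDecodeError:
                pass  # Not JSON after all; keep the original string.
    loader = (source_type or "").lower()
    if loader in URL_LOADER_TYPES:
        if isinstance(remote_data, dict):
            for key in URL_DATA_KEYS:
                value = remote_data.get(key)
                if value:
                    return value
            return None  # Dict without a URL key: nothing syncable.
        return remote_data
    if loader == "reddit" and isinstance(remote_data, (dict, list)):
        return json.dumps(remote_data)  # reddit's loader expects a JSON string.
    return remote_data  # s3 and others pass through unchanged.


if __name__ == "__main__":
    print(normalize_remote_data_sketch("crawler", {"url": "https://example.com"}))
    print(normalize_remote_data_sketch("sitemap", '{"raw": "https://example.com/sitemap.xml"}'))
    print(normalize_remote_data_sketch("github", {"branch": "main"}))
    print(normalize_remote_data_sketch("reddit", {"subreddit": "python"}))
```

The `None` result for a URL-type row with no URL key is what lets `sync_worker` count the row as skipped instead of dispatching a sync that can only fail.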


@@ -1,9 +1,9 @@
import logging
import re
import requests
import re # Import regular expression library
import defusedxml.ElementTree as ET
from application.parser.remote.base import BaseRemote
from application.core.url_validation import validate_url, SSRFError
from application.security.safe_url import UnsafeUserUrlError, pinned_request
class SitemapLoader(BaseRemote):
def __init__(self, limit=20):
@@ -53,12 +53,14 @@ class SitemapLoader(BaseRemote):
def _extract_urls(self, sitemap_url):
try:
response = pinned_request("GET", sitemap_url, timeout=30)
response.raise_for_status()
except UnsafeUserUrlError as e:
# Validate URL before fetching to prevent SSRF
validate_url(sitemap_url)
response = requests.get(sitemap_url, timeout=30)
response.raise_for_status() # Raise an exception for HTTP errors
except SSRFError as e:
print(f"URL validation failed for sitemap: {sitemap_url}. Error: {e}")
return []
except Exception as e:
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e:
print(f"Failed to fetch sitemap: {sitemap_url}. Error: {e}")
return []
@@ -95,6 +97,13 @@ class SitemapLoader(BaseRemote):
nested_sitemap_url = sitemap.text
if not nested_sitemap_url:
continue
try:
nested_sitemap_url = validate_url(nested_sitemap_url)
except SSRFError as e:
logging.error(
f"URL validation failed for nested sitemap {nested_sitemap_url}: {e}"
)
continue
urls.extend(self._extract_urls(nested_sitemap_url))
return urls


@@ -291,55 +291,6 @@ def _ip_to_url_host(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> str:
return str(ip)
def pinned_request(
method: str,
url: str,
*,
data: Any = None,
json: Any = None,
headers: dict[str, str] | None = None,
timeout: float = 90.0,
allow_redirects: bool = False,
) -> requests.Response:
"""Send an HTTP request with the connection pinned to a validated IP,
closing the DNS-rebinding TOCTOU window left by the naive
validate-then-``requests`` pattern.
Raises:
UnsafeUserUrlError: If the URL fails the SSRF guard.
requests.RequestException: For network-level failures.
"""
host, ip, parts = _validate_and_pick_ip(url)
netloc = _ip_to_url_host(ip)
if parts.port is not None:
netloc = f"{netloc}:{parts.port}"
pinned_url = urlunsplit(
(parts.scheme, netloc, parts.path, parts.query, parts.fragment)
)
request_headers = dict(headers or {})
host_header = host if parts.port is None else f"{host}:{parts.port}"
request_headers["Host"] = host_header
session = requests.Session()
if parts.scheme == "https":
session.mount("https://", _PinnedHostAdapter(host))
try:
return session.request(
method=method.upper(),
url=pinned_url,
data=data,
json=json,
headers=request_headers,
timeout=timeout,
allow_redirects=allow_redirects,
)
finally:
session.close()
def pinned_post(
url: str,
*,
@@ -377,15 +328,33 @@ def pinned_post(
requests.RequestException: For network-level failures.
"""
return pinned_request(
"POST",
url,
json=json,
headers=headers,
timeout=timeout,
allow_redirects=allow_redirects,
host, ip, parts = _validate_and_pick_ip(url)
netloc = _ip_to_url_host(ip)
if parts.port is not None:
netloc = f"{netloc}:{parts.port}"
pinned_url = urlunsplit(
(parts.scheme, netloc, parts.path, parts.query, parts.fragment)
)
request_headers = dict(headers or {})
host_header = host if parts.port is None else f"{host}:{parts.port}"
request_headers["Host"] = host_header
session = requests.Session()
if parts.scheme == "https":
session.mount("https://", _PinnedHostAdapter(host))
try:
return session.post(
pinned_url,
json=json,
headers=request_headers,
timeout=timeout,
allow_redirects=allow_redirects,
)
finally:
session.close()
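The URL rewriting that both `pinned_request` and `pinned_post` perform can be sketched without any network access. This is a minimal illustration, not the module's code: `build_pinned_target` is a hypothetical name, the already-validated IP is passed in by the caller (in the diff it comes from `_validate_and_pick_ip`), and bracketing IPv6 literals is an assumption about what `_ip_to_url_host` does.

```python
import ipaddress
from urllib.parse import urlsplit, urlunsplit


def build_pinned_target(url: str, validated_ip: str) -> tuple[str, str]:
    """Return (pinned_url, host_header) for a URL and its validated IP.

    Hypothetical helper: DNS resolution and the SSRF check are assumed to
    have happened already and are out of scope here.
    """
    parts = urlsplit(url)
    host = parts.hostname
    if host is None:
        raise ValueError(f"URL has no host: {url!r}")
    ip = ipaddress.ip_address(validated_ip)
    # Assumption: IPv6 literals must be bracketed inside a URL netloc.
    netloc = f"[{ip}]" if ip.version == 6 else str(ip)
    if parts.port is not None:
        netloc = f"{netloc}:{parts.port}"
    pinned_url = urlunsplit(
        (parts.scheme, netloc, parts.path, parts.query, parts.fragment)
    )
    # The original hostname travels in the Host header so virtual hosting works.
    host_header = host if parts.port is None else f"{host}:{parts.port}"
    return pinned_url, host_header


if __name__ == "__main__":
    print(build_pinned_target("https://example.com:8443/a?b=1", "93.184.216.34"))
```

Connecting to the IP literal rather than the hostname is what closes the DNS-rebinding window: the address that was validated is the address that is contacted, with no second lookup in between.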
class _PinnedHTTPSTransport(httpx.HTTPTransport):
"""``httpx`` transport pinned to a single validated IP literal.


@@ -514,9 +514,6 @@ ingest_chunk_progress_table = Table(
# same task resumes from the checkpoint, but a separate invocation
# (manual reingest, scheduled sync) resets to a clean re-index.
Column("attempt_id", Text),
# Added in ``0008_ingest_progress_status``. The reconciler flips
# this to 'stalled'; ``init_progress`` resets it to 'active'.
Column("status", Text, nullable=False, server_default="active"),
)


@@ -41,9 +41,6 @@ class IngestChunkProgressRepository:
rows with NULL ``attempt_id`` resume against another NULL
caller (e.g. test fixtures), but get reset the moment a real
``attempt_id`` arrives.
Both branches also reset ``status`` to ``'active'``, clearing a
prior reconciler ``'stalled'`` escalation.
"""
result = self._conn.execute(
text(
@@ -71,8 +68,7 @@ class IngestChunkProgressRepository:
THEN ingest_chunk_progress.embedded_chunks
ELSE 0
END,
attempt_id = EXCLUDED.attempt_id,
status = 'active'
attempt_id = EXCLUDED.attempt_id
RETURNING *
"""
),
@@ -117,23 +113,6 @@ class IngestChunkProgressRepository:
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def delete(self, source_id: str) -> bool:
"""Delete the progress row for ``source_id``.
A manual reingest supersedes any prior ingest state — including a
reconciler ``'stalled'`` escalation — so dropping the row clears
the derived ``failed`` ingest status the sources list shows.
Returns ``True`` when a row was removed.
"""
result = self._conn.execute(
text(
"DELETE FROM ingest_chunk_progress "
"WHERE source_id = CAST(:source_id AS uuid)"
),
{"source_id": str(source_id)},
)
return result.rowcount > 0
def bump_heartbeat(self, source_id: str) -> None:
"""Refresh ``last_updated`` so the row looks alive to the reconciler."""
self._conn.execute(


@@ -107,11 +107,7 @@ class ReconciliationRepository:
def find_and_lock_stalled_ingests(
self, *, age_minutes: int = 30, limit: int = 100,
) -> list[dict]:
"""Lock still-active ingest checkpoints with a silent heartbeat.
The ``status = 'active'`` filter skips rows already escalated to
``'stalled'``, so a dead ingest is alerted once, not every tick.
"""
"""Lock ingest checkpoints whose heartbeat hasn't ticked recently."""
result = self._conn.execute(
text(
"""
@@ -120,7 +116,6 @@ class ReconciliationRepository:
FROM ingest_chunk_progress
WHERE last_updated < now() - make_interval(mins => :age)
AND embedded_chunks < total_chunks
AND status = 'active'
ORDER BY last_updated ASC
LIMIT :limit
FOR UPDATE SKIP LOCKED
@@ -130,15 +125,11 @@ class ReconciliationRepository:
)
return [row_to_dict(r) for r in result.fetchall()]
def mark_ingest_stalled(self, source_id: str) -> bool:
"""Escalate a stalled checkpoint to terminal ``status='stalled'``.
Drops the row out of the sweep so the reconciler alerts once;
``init_progress`` flips it back to ``'active'`` on reingest.
"""
def touch_ingest_progress(self, source_id: str) -> bool:
"""Bump ``last_updated`` so a once-stalled ingest re-enters the watch window."""
result = self._conn.execute(
text(
"UPDATE ingest_chunk_progress SET status = 'stalled' "
"UPDATE ingest_chunk_progress SET last_updated = now() "
"WHERE source_id = CAST(:sid AS uuid)"
),
{"sid": str(source_id)},


@@ -5,10 +5,10 @@ from __future__ import annotations
import json
from typing import Any, Optional
from sqlalchemy import case, Connection, func, select, text
from sqlalchemy import Connection, func, select, text
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.models import ingest_chunk_progress_table, sources_table
from application.storage.db.models import sources_table
_SCALAR_COLUMNS = {
@@ -61,21 +61,6 @@ def _coerce_jsonb(value: Any) -> Any:
return value
def _ingest_status_case():
"""Derive a user-facing ingest status from the joined progress row.
``failed`` — reconciler-escalated stall. ``processing`` — embed in
flight. ``None`` — no progress row, or the embed completed.
"""
icp = ingest_chunk_progress_table
return case(
(icp.c.source_id.is_(None), None),
(icp.c.status == "stalled", "failed"),
(icp.c.embedded_chunks < icp.c.total_chunks, "processing"),
else_=None,
).label("ingest_status")
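For readers less familiar with SQL `CASE`, the branch order in `_ingest_status_case` can be restated in plain Python. This is only a sketch: `derive_ingest_status` is a hypothetical name, the progress row is modeled as a dict (or `None` when the outer join finds no row), and SQL NULL handling for individual columns is ignored.

```python
from typing import Optional


def derive_ingest_status(progress_row: Optional[dict]) -> Optional[str]:
    """Plain-Python restatement of the CASE expression's branch order.

    Hypothetical helper; progress_row stands in for the joined
    ingest_chunk_progress row, or None when no row exists.
    """
    if progress_row is None:
        return None  # No progress row for this source.
    if progress_row["status"] == "stalled":
        return "failed"  # Reconciler-escalated stall.
    if progress_row["embedded_chunks"] < progress_row["total_chunks"]:
        return "processing"  # Embed still in flight.
    return None  # Embed completed.


if __name__ == "__main__":
    print(derive_ingest_status({"status": "active", "embedded_chunks": 3, "total_chunks": 10}))
```

The `stalled` check comes before the chunk comparison, so a stalled ingest reports `failed` even though its `embedded_chunks` is still below `total_chunks`.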
class SourcesRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
@@ -207,25 +192,13 @@ class SourcesRepository:
as ``"desc"``.
Returns:
A list of source rows as plain dicts (via ``row_to_dict``),
each carrying a derived ``ingest_status`` (``failed`` /
``processing`` / ``None``) from the joined progress row.
A list of source rows as plain dicts (via ``row_to_dict``).
"""
column_name = sort_field if sort_field in _SORTABLE_COLUMNS else "date"
sort_column = sources_table.c[column_name]
ascending = sort_order.lower() == "asc"
stmt = (
select(sources_table, _ingest_status_case())
.select_from(
sources_table.outerjoin(
ingest_chunk_progress_table,
ingest_chunk_progress_table.c.source_id
== sources_table.c.id,
)
)
.where(sources_table.c.user_id == user_id)
)
stmt = select(sources_table).where(sources_table.c.user_id == user_id)
if search_term:
stmt = stmt.where(
sources_table.c.name.ilike(


@@ -63,8 +63,7 @@ class ToolCallAttemptsRepository:
message_id: Optional[str] = None,
artifact_id: Optional[str] = None,
) -> None:
"""Insert OR upgrade a row to ``executed`` — or ``confirmed`` when
there is no ``message_id``, as in ``mark_executed``.
"""Insert OR upgrade a row to ``executed``.
Used as a fallback when ``record_proposed`` failed (DB outage)
and the tool ran anyway — preserves the journal so the
@@ -73,7 +72,6 @@ class ToolCallAttemptsRepository:
result_payload: dict = {"result": result}
if artifact_id:
result_payload["artifact_id"] = artifact_id
status = "executed" if message_id is not None else "confirmed"
self._conn.execute(
text(
"""
@@ -84,9 +82,9 @@ class ToolCallAttemptsRepository:
(:call_id, CAST(:tool_id AS uuid), :tool_name,
:action_name, CAST(:arguments AS jsonb),
CAST(:result AS jsonb), CAST(:message_id AS uuid),
:status)
'executed')
ON CONFLICT (call_id) DO UPDATE
SET status = :status,
SET status = 'executed',
result = EXCLUDED.result,
message_id = COALESCE(EXCLUDED.message_id, tool_call_attempts.message_id)
"""
@@ -99,7 +97,6 @@ class ToolCallAttemptsRepository:
"arguments": json.dumps(arguments if arguments is not None else {}, cls=PGNativeJSONEncoder),
"result": json.dumps(result_payload, cls=PGNativeJSONEncoder),
"message_id": message_id,
"status": status,
},
)
@@ -111,9 +108,7 @@ class ToolCallAttemptsRepository:
message_id: Optional[str] = None,
artifact_id: Optional[str] = None,
) -> bool:
"""Flip ``proposed`` → ``executed``, or straight to ``confirmed``
when there is no ``message_id`` (a ``save_conversation=False``
request reserves no message, so no finalize will confirm it).
"""Flip ``proposed`` → ``executed`` with the tool result.
``artifact_id`` (when present) is stored alongside ``result`` in
the JSONB as audit data — the reconciler reads it for diagnostic
@@ -122,14 +117,12 @@ class ToolCallAttemptsRepository:
result_payload: dict = {"result": result}
if artifact_id:
result_payload["artifact_id"] = artifact_id
status = "executed" if message_id is not None else "confirmed"
sql = (
"UPDATE tool_call_attempts SET "
"status = :status, result = CAST(:result AS jsonb)"
"status = 'executed', result = CAST(:result AS jsonb)"
)
params: dict[str, Any] = {
"call_id": call_id,
"status": status,
"result": json.dumps(result_payload, cls=PGNativeJSONEncoder),
}
if message_id is not None:


@@ -29,10 +29,7 @@ from application.parser.embedding_pipeline import (
)
from application.parser.file.bulk import SimpleDirectoryReader, get_default_file_extractor
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
from application.parser.remote.remote_creator import (
RemoteCreator,
normalize_remote_data,
)
from application.parser.remote.remote_creator import RemoteCreator
from application.parser.schema.base import Document
from application.retriever.retriever_creator import RetrieverCreator
@@ -100,40 +97,6 @@ def _stop_ingest_heartbeat(thread, stop_event):
thread.join(timeout=5)
def _make_parse_progress_callback(task, user, source_id, start_pct, end_pct):
"""Build a ``load_data`` callback mapping parse progress to
``[start_pct, end_pct]`` via ``update_state`` + a throttled
``stage='parsing'`` SSE event.
"""
span = end_pct - start_pct
source_id_str = str(source_id)
state = {"last_pct": -1}
def _callback(files_done, total_files):
if not total_files:
return
pct = start_pct + int((files_done / total_files) * span)
task.update_state(
state="PROGRESS",
meta={"current": pct, "status": "Parsing files"},
)
if user and pct > state["last_pct"]:
publish_user_event(
user,
"source.ingest.progress",
{
"current": pct,
"total": total_files,
"files_done": files_done,
"stage": "parsing",
},
scope={"kind": "source", "id": source_id_str},
)
state["last_pct"] = pct
return _callback
# Define a function to extract metadata from a given filename.
@@ -674,12 +637,7 @@ def ingest_worker(
exclude_hidden=exclude,
file_metadata=metadata_from_filename,
)
# Parsing/OCR owns 1-50% of the bar; embedding takes 50-100%.
raw_docs = reader.load_data(
progress_callback=_make_parse_progress_callback(
self, user, source_uuid, start_pct=1, end_pct=50,
)
)
raw_docs = reader.load_data()
directory_structure = getattr(reader, "directory_structure", {})
logging.info(f"Directory structure from reader: {directory_structure}")
@@ -719,7 +677,6 @@ def ingest_worker(
docs, vector_store_path, source_uuid, self,
attempt_id=getattr(self.request, "id", None),
user_id=user,
progress_start=50, progress_end=100,
)
finally:
_stop_ingest_heartbeat(heartbeat_thread, heartbeat_stop)
@@ -850,8 +807,6 @@ def reingest_source_worker(self, source_id, user):
{
"source_id": source_id,
"name": source_name,
# ``filename`` labels the upload toast on auto-create.
"filename": source_name,
"operation": "reingest",
},
scope={"kind": "source", "id": source_id},
@@ -959,7 +914,6 @@ def reingest_source_worker(self, source_id, user):
{
"source_id": source_id,
"name": source_name,
"filename": source_name,
"operation": "reingest",
"no_changes": True,
"chunks_added": 0,
@@ -1147,7 +1101,6 @@ def reingest_source_worker(self, source_id, user):
completed_payload: dict = {
"source_id": source_id,
"name": source_name,
"filename": source_name,
"operation": "reingest",
"chunks_added": added,
"chunks_deleted": deleted,
@@ -1187,7 +1140,6 @@ def reingest_source_worker(self, source_id, user):
{
"source_id": str(source_id),
"name": source_name,
"filename": source_name,
"operation": "reingest",
"error": str(e)[:1024],
},
@@ -1479,35 +1431,19 @@ def sync_worker(self, frequency):
name = doc.get("name")
user = doc.get("user_id")
source_type = doc.get("type")
source_data = doc.get("remote_data")
retriever = doc.get("retriever")
doc_id = str(doc.get("id"))
sync_counts["total_sync_count"] += 1
# Connector sources have no RemoteCreator loader and need an OAuth
# token to sync, which a scheduled task lacks — skip them.
if source_type and source_type.startswith("connector"):
sync_counts["sync_skipped"] += 1
continue
source_data = normalize_remote_data(source_type, doc.get("remote_data"))
if not source_data:
# No syncable URL/config — skip instead of dispatching a sync
# that can only fail (and emit a spurious failed event).
sync_counts["sync_skipped"] += 1
continue
resp = sync(
self, source_data, name, user, source_type, frequency, retriever, doc_id
)
sync_counts["total_sync_count"] += 1
sync_counts[
"sync_success" if resp["status"] == "success" else "sync_failure"
] += 1
return {
key: sync_counts[key]
for key in [
"total_sync_count", "sync_success", "sync_failure", "sync_skipped",
]
for key in ["total_sync_count", "sync_success", "sync_failure"]
}
@@ -1849,15 +1785,14 @@ def ingest_connector(
exclude_hidden=True,
file_metadata=metadata_from_filename,
)
# Parsing/OCR fills 40-60% of the bar; embedding takes 60-100%.
raw_docs = reader.load_data(
progress_callback=_make_parse_progress_callback(
self, user, source_uuid, start_pct=40, end_pct=60,
)
)
raw_docs = reader.load_data()
directory_structure = getattr(reader, "directory_structure", {})
# Step 4: Process documents (chunking, embedding, etc.)
self.update_state(
state="PROGRESS", meta={"current": 60, "status": "Processing documents"}
)
chunker = Chunker(
chunking_strategy="classic_chunk",
max_tokens=MAX_TOKENS,
@@ -1894,13 +1829,12 @@ def ingest_connector(
os.makedirs(vector_store_path, exist_ok=True)
self.update_state(
state="PROGRESS", meta={"current": 60, "status": "Storing documents"}
state="PROGRESS", meta={"current": 80, "status": "Storing documents"}
)
embed_and_store_documents(
docs, vector_store_path, source_uuid, self,
attempt_id=getattr(self.request, "id", None),
user_id=user,
progress_start=60, progress_end=100,
)
assert_index_complete(source_uuid)


@@ -34,7 +34,6 @@ const endpoints = {
LOGS: `/api/get_user_logs`,
MANAGE_SYNC: '/api/manage_sync',
SYNC_SOURCE: '/api/sync_source',
REINGEST_SOURCE: '/api/sources/reingest',
GET_AVAILABLE_TOOLS: '/api/available_tools',
GET_USER_TOOLS: '/api/get_tools',
CREATE_TOOL: '/api/create_tool',


@@ -73,8 +73,6 @@ const userService = {
apiClient.post(endpoints.USER.MANAGE_SYNC, data, token),
syncSource: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.SYNC_SOURCE, data, token),
reingestSource: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.REINGEST_SOURCE, data, token),
getAvailableTools: (token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.GET_AVAILABLE_TOOLS, token),
getUserTools: (token: string | null): Promise<any> =>


@@ -165,19 +165,12 @@ function UploadRow({
return (
<li className="border-border/50 border-b last:border-b-0">
<div className="flex items-center justify-between px-5 py-3">
<div className="flex min-w-0 flex-col">
<p
className="font-inter dark:text-muted-foreground max-w-[200px] truncate text-[13px] leading-[16.5px] font-normal text-black"
title={task.fileName}
>
{task.fileName}
</p>
{task.status === 'training' && task.stage && (
<span className="font-inter text-muted-foreground mt-0.5 text-[11px] leading-[14px]">
{t(`modals.uploadDoc.progress.${task.stage}`)}
</span>
)}
</div>
<p
className="font-inter dark:text-muted-foreground max-w-[200px] truncate text-[13px] leading-[16.5px] font-normal text-black"
title={task.fileName}
>
{task.fileName}
</p>
<div className="flex items-center gap-2">
{showProgress && (


@@ -154,10 +154,10 @@ const ConversationBubble = forwardRef<
<img
src={DocumentationDark}
alt="Attachment"
className="h-3.75 w-3.75 object-fill"
className="h-[15px] w-[15px] object-fill"
/>
</div>
<span className="max-w-37.5 truncate font-normal">
<span className="max-w-[150px] truncate font-normal">
{file.fileName}
</span>
</div>
@@ -328,7 +328,7 @@ const ConversationBubble = forwardRef<
<div className="mb-4 flex flex-col flex-wrap items-start self-start lg:flex-nowrap">
<div className="my-2 flex flex-row items-center justify-center gap-3">
<Avatar
className="h-6.5 w-7.5 text-xl"
className="h-[26px] w-[30px] text-xl"
avatar={
<img
src={Sources}
@@ -376,7 +376,7 @@ const ConversationBubble = forwardRef<
<img
src={Document}
alt="Document"
className="h-4.25 w-4.25 object-fill"
className="h-[17px] w-[17px] object-fill"
/>
<p
className="mt-0.5 truncate text-xs"
@@ -394,11 +394,11 @@ const ConversationBubble = forwardRef<
</div>
{activeTooltip === index && (
<div
className={`dark:bg-card dark:text-foreground absolute left-1/2 z-50 max-h-48 w-40 translate-x-[-50%] translate-y-0.75 rounded-xl bg-[#FBFBFB] p-4 text-black shadow-xl sm:w-56`}
className={`dark:bg-card dark:text-foreground absolute left-1/2 z-50 max-h-48 w-40 translate-x-[-50%] translate-y-[3px] rounded-xl bg-[#FBFBFB] p-4 text-black shadow-xl sm:w-56`}
onMouseOver={() => setActiveTooltip(index)}
onMouseOut={() => setActiveTooltip(null)}
>
<p className="line-clamp-6 max-h-41 overflow-hidden rounded-md text-sm wrap-break-word text-ellipsis">
<p className="line-clamp-6 max-h-[164px] overflow-hidden rounded-md text-sm wrap-break-word text-ellipsis">
{source.text}
</p>
</div>
@@ -471,7 +471,7 @@ const ConversationBubble = forwardRef<
<div className="flex max-w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
<div className="my-2 flex flex-row items-center justify-center gap-3">
<Avatar
className="h-8.5 w-8.5 text-2xl"
className="h-[34px] w-[34px] text-2xl"
avatar={
<img
src={DocsGPT3}
@@ -1023,7 +1023,7 @@ function ToolCalls({
);
return (
<div className="mb-4 relative flex w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
<div className="mb-4 flex w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
{/* Approval bars — always visible, compact inline */}
{awaitingCalls.length > 0 && (
<div className="fade-in mt-4 ml-3 w-[90vw] md:w-[70vw] lg:w-full">
@@ -1042,7 +1042,7 @@ function ToolCalls({
<>
<div className="my-2 flex flex-row items-center justify-center gap-3">
<Avatar
className="h-6.5 w-7.5 text-xl"
className="h-[26px] w-[30px] text-xl"
avatar={
<img
src={Sources}
@@ -1089,7 +1089,7 @@ function ToolCalls({
</p>
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="dark:text-muted-foreground leading-5.75 text-black"
className="dark:text-muted-foreground leading-[23px] text-black"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{JSON.stringify(toolCall.arguments, null, 2)}
@@ -1117,7 +1117,7 @@ function ToolCalls({
{toolCall.status === 'completed' && (
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="dark:text-muted-foreground leading-5.75 text-black"
className="dark:text-muted-foreground leading-[23px] text-black"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{JSON.stringify(toolCall.result, null, 2)}
@@ -1127,7 +1127,7 @@ function ToolCalls({
{toolCall.status === 'error' && (
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="text-destructive leading-5.75"
className="text-destructive leading-[23px]"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{toolCall.error}
@@ -1137,7 +1137,7 @@ function ToolCalls({
{toolCall.status === 'denied' && (
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="text-muted-foreground leading-5.75"
className="text-muted-foreground leading-[23px]"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
Denied by user
@@ -1172,7 +1172,7 @@ function Thought({
<div className="mb-4 flex w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
<div className="my-2 flex flex-row items-center justify-center gap-3">
<Avatar
className="h-6.5 w-7.5 text-xl"
className="h-[26px] w-[30px] text-xl"
avatar={
<img
src={Cloud}
@@ -1197,7 +1197,7 @@ function Thought({
</div>
{isThoughtOpen && (
<div className="fade-in mr-5 ml-2 max-w-[90vw] md:max-w-[70vw] lg:max-w-[50vw]">
<div className="bg-muted dark:bg-answer-bubble rounded-[28px] px-7 py-4.5">
<div className="bg-muted dark:bg-answer-bubble rounded-[28px] px-7 py-[18px]">
<ReactMarkdown
className="fade-in leading-normal wrap-break-word whitespace-pre-wrap"
remarkPlugins={[remarkGfm, remarkMath]}


@@ -1,4 +1,10 @@
import { useCallback, useEffect, useRef, useState, RefObject } from 'react';
import {
useCallback,
useEffect,
useRef,
useState,
RefObject,
} from 'react';
export function useOutsideAlerter<T extends HTMLElement>(
ref: RefObject<T | null>,


@@ -70,9 +70,6 @@
"sync": "Synchronisieren",
"syncNow": "Jetzt synchronisieren",
"syncing": "Synchronisiere...",
"reingest": "Erneut indexieren",
"ingestFailed": "Indexierung fehlgeschlagen",
"ingestProcessing": "Indexierung...",
"syncConfirmation": "Bist du sicher, dass du \"{{sourceName}}\" synchronisieren möchtest? Dies aktualisiert den Inhalt mit deinem Cloud-Speicher und kann Änderungen an einzelnen Chunks überschreiben.",
"syncFrequency": {
"never": "Nie",
@@ -356,8 +353,6 @@
"failed": "Upload fehlgeschlagen",
"wait": "Dies kann einige Minuten dauern",
"preparing": "Upload wird vorbereitet",
"parsing": "Dateien werden verarbeitet...",
"embedding": "Einbettung...",
"tokenLimit": "Token-Limit überschritten, bitte lade ein kleineres Dokument hoch",
"expandDetails": "Upload-Details erweitern",
"collapseDetails": "Upload-Details einklappen",


@@ -70,9 +70,6 @@
"sync": "Sync",
"syncNow": "Sync now",
"syncing": "Syncing...",
"reingest": "Reingest",
"ingestFailed": "Indexing failed",
"ingestProcessing": "Indexing…",
"syncConfirmation": "Are you sure you want to sync \"{{sourceName}}\"? This will update the content with your cloud storage and may override any edits you made to individual chunks.",
"syncFrequency": {
"never": "Never",
@@ -368,8 +365,6 @@
"failed": "Upload failed",
"wait": "This may take several minutes",
"preparing": "Preparing upload",
"parsing": "Parsing files…",
"embedding": "Embedding…",
"tokenLimit": "Over the token limit, please consider uploading smaller document",
"expandDetails": "Expand upload details",
"collapseDetails": "Collapse upload details",


@@ -70,9 +70,6 @@
"sync": "Sincronizar",
"syncNow": "Sincronizar ahora",
"syncing": "Sincronizando...",
"reingest": "Reindexar",
"ingestFailed": "Error de indexación",
"ingestProcessing": "Indexando...",
"syncConfirmation": "¿Estás seguro de que deseas sincronizar \"{{sourceName}}\"? Esto actualizará el contenido con tu almacenamiento en la nube y puede anular cualquier edición que hayas realizado en fragmentos individuales.",
"syncFrequency": {
"never": "Nunca",
@@ -356,8 +353,6 @@
"failed": "Error al subir",
"wait": "Esto puede tardar varios minutos",
"preparing": "Preparando subida",
"parsing": "Analizando archivos...",
"embedding": "Generando incrustaciones...",
"tokenLimit": "Excede el límite de tokens, considere cargar un documento más pequeño",
"expandDetails": "Expandir detalles de subida",
"collapseDetails": "Contraer detalles de subida",


@@ -70,9 +70,6 @@
"sync": "同期",
"syncNow": "今すぐ同期",
"syncing": "同期中...",
"reingest": "再インデックス",
"ingestFailed": "インデックス作成に失敗しました",
"ingestProcessing": "インデックス作成中...",
"syncConfirmation": "\"{{sourceName}}\"を同期してもよろしいですか?これにより、コンテンツがクラウドストレージで更新され、個々のチャンクに加えた編集が上書きされる可能性があります。",
"syncFrequency": {
"never": "なし",
@@ -356,8 +353,6 @@
"failed": "アップロード失敗",
"wait": "数分かかる場合があります",
"preparing": "アップロードを準備中",
"parsing": "ファイルを解析中...",
"embedding": "埋め込み処理中...",
"tokenLimit": "トークン制限を超えています。より小さいドキュメントをアップロードしてください",
"expandDetails": "アップロードの詳細を展開",
"collapseDetails": "アップロードの詳細を折りたたむ",


@@ -70,9 +70,6 @@
"sync": "Синхронизация",
"syncNow": "Синхронизировать сейчас",
"syncing": "Синхронизация...",
"reingest": "Переиндексировать",
"ingestFailed": "Ошибка индексации",
"ingestProcessing": "Индексация...",
"syncConfirmation": "Вы уверены, что хотите синхронизировать \"{{sourceName}}\"? Это обновит содержимое с вашим облачным хранилищем и может перезаписать любые изменения, внесенные вами в отдельные фрагменты.",
"syncFrequency": {
"never": "Никогда",
@@ -356,8 +353,6 @@
"failed": "Ошибка загрузки",
"wait": "Это может занять несколько минут",
"preparing": "Подготовка загрузки",
"parsing": "Обработка файлов...",
"embedding": "Создание эмбеддингов...",
"tokenLimit": "Превышен лимит токенов, рассмотрите возможность загрузки документа меньшего размера",
"expandDetails": "Развернуть детали загрузки",
"collapseDetails": "Свернуть детали загрузки",


@@ -70,9 +70,6 @@
"sync": "同步",
"syncNow": "立即同步",
"syncing": "同步中...",
"reingest": "重新索引",
"ingestFailed": "索引失敗",
"ingestProcessing": "索引中...",
"syncConfirmation": "您確定要同步 \"{{sourceName}}\" 嗎?這將使用您的雲端儲存更新內容,並可能覆蓋您對個別文本塊所做的任何編輯。",
"syncFrequency": {
"never": "從不",
@@ -356,8 +353,6 @@
"failed": "上傳失敗",
"wait": "這可能需要幾分鐘",
"preparing": "準備上傳",
"parsing": "正在解析檔案...",
"embedding": "正在生成嵌入...",
"tokenLimit": "超出令牌限制,請考慮上傳較小的文檔",
"expandDetails": "展開上傳詳情",
"collapseDetails": "摺疊上傳詳情",


@@ -70,9 +70,6 @@
"sync": "同步",
"syncNow": "立即同步",
"syncing": "同步中...",
"reingest": "重新索引",
"ingestFailed": "索引失败",
"ingestProcessing": "索引中...",
"syncConfirmation": "您确定要同步 \"{{sourceName}}\" 吗?这将使用您的云存储更新内容,并可能覆盖您对单个文本块所做的任何编辑。",
"syncFrequency": {
"never": "从不",
@@ -356,8 +353,6 @@
"failed": "上传失败",
"wait": "这可能需要几分钟",
"preparing": "准备上传",
"parsing": "正在解析文件...",
"embedding": "正在生成嵌入...",
"tokenLimit": "超出令牌限制,请考虑上传较小的文档",
"expandDetails": "展开上传详情",
"collapseDetails": "折叠上传详情",


@@ -14,8 +14,6 @@ export type Doc = {
syncFrequency?: string;
isNested?: boolean;
provider?: string;
// Derived server-side from ingest_chunk_progress (sources API).
ingestStatus?: 'processing' | 'failed';
};
export type GetDocsResponse = {


@@ -27,12 +27,6 @@ import {
setSourceDocs,
} from '../preferences/preferenceSlice';
import Upload from '../upload/Upload';
import {
addUploadTask,
removeUploadTask,
selectUploadTasks,
updateUploadTask,
} from '../upload/uploadSlice';
import { formatDate } from '../utils/dateTimeUtils';
import FileTree from '../components/FileTree';
import ConnectorTree from '../components/ConnectorTree';
@@ -62,7 +56,6 @@ export default function Sources({
const [isDarkTheme] = useDarkTheme();
const dispatch = useDispatch();
const token = useSelector(selectToken);
const uploadTasks = useSelector(selectUploadTasks);
const [searchTerm, setSearchTerm] = useState<string>('');
const debouncedSearchTerm = useDebouncedValue(searchTerm, 500);
@@ -256,57 +249,6 @@ export default function Sources({
}
};
const handleReingest = async (doc: Doc) => {
if (!doc.id) {
return;
}
const sourceId = doc.id;
// Drop stale toast rows for this source (a finished/dismissed task
// would swallow the reingest's SSE events), then open a fresh one.
uploadTasks
.filter((task) => task.sourceId === sourceId)
.forEach((task) => dispatch(removeUploadTask(task.id)));
const reingestTaskId = `reingest-${sourceId}-${Date.now()}`;
dispatch(
addUploadTask({
id: reingestTaskId,
fileName: doc.name || sourceId,
progress: 0,
status: 'training',
sourceId,
}),
);
try {
const response = await userService.reingestSource(
{ source_id: sourceId },
token,
);
const data = await response.json();
if (!data.success) {
console.error('Reingest failed:', data.error || data.message);
dispatch(
updateUploadTask({
id: reingestTaskId,
updates: {
status: 'failed',
errorMessage: data.error || data.message,
},
}),
);
return;
}
refreshDocs(undefined, currentPage, rowsPerPage);
} catch (error) {
console.error('Error reingesting source:', error);
dispatch(
updateUploadTask({
id: reingestTaskId,
updates: { status: 'failed' },
}),
);
}
};
const [documentToDelete, setDocumentToDelete] = useState<{
index: number;
document: Doc;
@@ -341,19 +283,6 @@ export default function Sources({
},
];
if (document.ingestStatus === 'failed') {
actions.push({
icon: SyncIcon,
label: t('settings.sources.reingest'),
onClick: () => {
handleReingest(document);
},
iconWidth: 14,
iconHeight: 14,
variant: 'primary',
});
}
if (document.syncFrequency) {
actions.push({
icon: SyncIcon,
@@ -554,16 +483,6 @@ export default function Sources({
</div>
<div className="flex flex-col items-start justify-start gap-1">
{document.ingestStatus === 'failed' && (
<span className="rounded-full bg-red-100 px-2 py-0.5 text-[11px] leading-[16px] font-medium text-red-700 dark:bg-red-900/30 dark:text-red-400">
{t('settings.sources.ingestFailed')}
</span>
)}
{document.ingestStatus === 'processing' && (
<span className="bg-muted-foreground/10 text-muted-foreground rounded-full px-2 py-0.5 text-[11px] leading-[16px] font-medium">
{t('settings.sources.ingestProcessing')}
</span>
)}
<div className="flex items-center gap-2">
<img
src={CalendarIcon}


@@ -286,26 +286,6 @@ describe('source.ingest.progress', () => {
state = reducer(state, ingest('source.ingest.progress', { current: -10 }));
expect(state.tasks[0].progress).toBe(100);
});
it('records the ingest stage from the payload', () => {
let state = stateWithTask(makeTask({ status: 'training' }));
state = reducer(
state,
ingest('source.ingest.progress', { current: 20, stage: 'parsing' }),
);
expect(state.tasks[0].stage).toBe('parsing');
state = reducer(
state,
ingest('source.ingest.progress', { current: 70, stage: 'embedding' }),
);
expect(state.tasks[0].stage).toBe('embedding');
// An unknown/absent stage leaves the last known value intact.
state = reducer(
state,
ingest('source.ingest.progress', { current: 80, stage: 'bogus' }),
);
expect(state.tasks[0].stage).toBe('embedding');
});
});
describe('source.ingest.completed', () => {


@@ -66,12 +66,6 @@ export interface UploadTask {
sourceId?: string;
errorMessage?: string;
dismissed?: boolean;
/**
* Ingest phase from the latest ``source.ingest.progress`` event:
* ``parsing`` (parse/OCR, lower band of the bar) or ``embedding``
* (upper band). Drives the phase label in ``UploadToast``.
*/
stage?: 'parsing' | 'embedding';
/**
* Flipped when ``source.ingest.completed`` carries
* ``payload.limited === true`` (the worker hit a token cap during
@@ -340,9 +334,6 @@ export const uploadSlice = createSlice({
if (task.status === 'completed' || task.status === 'failed') break;
task.status = 'training';
if (clamped > task.progress) task.progress = clamped;
if (payload.stage === 'parsing' || payload.stage === 'embedding') {
task.stage = payload.stage;
}
break;
}
case 'source.ingest.completed':


@@ -4,24 +4,19 @@ Fixed 5-second generation (100 tokens × 50 ms/token). No auth. Emits SSE
chunks in OpenAI's chat.completions streaming format, or a single response
when stream=false. Run on 127.0.0.1:8090 — point DocsGPT at it via
OPENAI_BASE_URL=http://127.0.0.1:8090/v1.
Flags:
--tool-calls First response returns a tool call instead of text.
Subsequent responses (after a tool_result) return text.
Useful for triggering the tool-execution loop.
"""
import argparse
import asyncio
import json
import logging
import time
import uuid
from flask import Flask, Response, request, jsonify
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
TOKEN_COUNT = 100
TOKEN_DELAY_S = 0.05 # 100 * 0.05 = 5.0 s
TOOL_CALL_MODE = False
logger = logging.getLogger("mock_llm")
logging.basicConfig(level=logging.INFO, format="%(asctime)s mock: %(message)s")
@@ -44,7 +39,7 @@ FILLER_TOKENS = [
".",
]
app = Flask(__name__)
app = FastAPI()
def _token_stream_id() -> str:
@@ -68,57 +63,11 @@ def _sse_chunk(completion_id: str, model: str, delta: dict, finish_reason=None)
return f"data: {json.dumps(payload)}\n\n"
def _gen_tool_call_stream(model: str, req_id: str):
"""Emit two tool_calls (search) in streaming format.
Two calls ensure the handler executes the first (which can return a
huge result), then hits _check_context_limit before the second.
"""
completion_id = _token_stream_id()
call_id_1 = f"call_{uuid.uuid4().hex[:12]}"
call_id_2 = f"call_{uuid.uuid4().hex[:12]}"
yield _sse_chunk(completion_id, model, {
"role": "assistant",
"content": None,
"tool_calls": [
{
"index": 0,
"id": call_id_1,
"type": "function",
"function": {"name": "search", "arguments": ""},
},
{
"index": 1,
"id": call_id_2,
"type": "function",
"function": {"name": "search", "arguments": ""},
},
],
})
args_json = json.dumps({"query": "Python programming basics"})
for ch in args_json:
time.sleep(TOKEN_DELAY_S)
yield _sse_chunk(completion_id, model, {
"tool_calls": [
{"index": 0, "function": {"arguments": ch}},
{"index": 1, "function": {"arguments": ch}},
],
})
yield _sse_chunk(completion_id, model, {}, finish_reason="tool_calls")
yield "data: [DONE]\n\n"
logger.info("[%s] tool_call stream done (ids=%s, %s)", req_id, call_id_1, call_id_2)
def _has_tool_result(messages: list) -> bool:
return any(m.get("role") == "tool" for m in messages)
def _gen_text_stream(model: str, req_id: str):
async def _stream_response(model: str, req_id: str):
completion_id = _token_stream_id()
yield _sse_chunk(completion_id, model, {"role": "assistant", "content": ""})
for tok in FILLER_TOKENS[:TOKEN_COUNT]:
time.sleep(TOKEN_DELAY_S)
for i, tok in enumerate(FILLER_TOKENS[:TOKEN_COUNT]):
await asyncio.sleep(TOKEN_DELAY_S)
yield _sse_chunk(completion_id, model, {"content": tok})
yield _sse_chunk(completion_id, model, {}, finish_reason="stop")
yield "data: [DONE]\n\n"
@@ -126,84 +75,63 @@ def _gen_text_stream(model: str, req_id: str):
@app.post("/v1/chat/completions")
def chat_completions():
body = request.get_json(force=True)
async def chat_completions(request: Request):
body = await request.json()
model = body.get("model", "mock")
stream = bool(body.get("stream", False))
messages = body.get("messages", [])
tools = body.get("tools")
req_id = uuid.uuid4().hex[:8]
logger.info(
"[%s] /chat/completions stream=%s model=%s tools=%s msgs=%d",
req_id, stream, model, bool(tools), len(messages),
)
use_tool_call = (
TOOL_CALL_MODE
and tools
and not _has_tool_result(messages)
)
logger.info("[%s] /chat/completions stream=%s model=%s max_tokens=%s", req_id, stream, model, body.get("max_tokens"))
if stream:
gen = (
_gen_tool_call_stream(model, req_id) if use_tool_call
else _gen_text_stream(model, req_id)
)
return Response(
gen,
mimetype="text/event-stream",
return StreamingResponse(
_stream_response(model, req_id),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache, no-transform",
"X-Accel-Buffering": "no",
},
)
time.sleep(TOKEN_COUNT * TOKEN_DELAY_S)
await asyncio.sleep(TOKEN_COUNT * TOKEN_DELAY_S)
logger.info("[%s] non-stream done", req_id)
text = "".join(FILLER_TOKENS[:TOKEN_COUNT])
completion_id = _token_stream_id()
return jsonify({
"id": completion_id,
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": text},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": TOKEN_COUNT,
"total_tokens": 10 + TOKEN_COUNT,
},
})
return JSONResponse(
{
"id": completion_id,
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": text},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": TOKEN_COUNT,
"total_tokens": 10 + TOKEN_COUNT,
},
}
)
@app.get("/v1/models")
def list_models():
return jsonify({
async def list_models():
return {
"object": "list",
"data": [{"id": "mock", "object": "model", "owned_by": "mock"}],
})
}
@app.get("/health")
def health():
return jsonify({"status": "ok"})
async def health():
return {"status": "ok"}
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--tool-calls", action="store_true",
help="First response returns a tool_call; subsequent responses return text.",
)
parser.add_argument("--port", type=int, default=8090)
args = parser.parse_args()
TOOL_CALL_MODE = args.tool_calls
if TOOL_CALL_MODE:
logger.info("Tool-call mode enabled")
app.run(host="127.0.0.1", port=args.port, debug=False, threaded=True)
import uvicorn
uvicorn.run(app, host="127.0.0.1", port=8090, log_level="info")


@@ -45,14 +45,15 @@ class TestAPIToolInit:
@pytest.mark.unit
class TestMakeApiCall:
@patch("application.agents.tools.api_tool.pinned_request")
def test_successful_get(self, mock_pinned, tool):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_successful_get(self, mock_get, mock_validate, tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"result": "ok"}
mock_resp.content = b'{"result":"ok"}'
mock_pinned.return_value = mock_resp
mock_get.return_value = mock_resp
result = tool.execute_action("any_action")
@@ -60,50 +61,54 @@ class TestMakeApiCall:
assert result["data"] == {"result": "ok"}
assert result["message"] == "API call successful."
@patch("application.agents.tools.api_tool.pinned_request")
def test_successful_post(self, mock_pinned, post_tool):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_successful_post(self, mock_post, mock_validate, post_tool):
mock_resp = MagicMock()
mock_resp.status_code = 201
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"id": 1}
mock_resp.content = b'{"id":1}'
mock_pinned.return_value = mock_resp
mock_post.return_value = mock_resp
result = post_tool.execute_action("create", name="test")
assert result["status_code"] == 201
@patch("application.agents.tools.api_tool.pinned_request")
def test_ssrf_blocked(self, mock_pinned, tool):
from application.security.safe_url import UnsafeUserUrlError
@patch("application.agents.tools.api_tool.validate_url")
def test_ssrf_blocked(self, mock_validate, tool):
from application.core.url_validation import SSRFError
mock_pinned.side_effect = UnsafeUserUrlError("blocked")
mock_validate.side_effect = SSRFError("blocked")
result = tool.execute_action("any")
assert result["status_code"] is None
assert "URL validation error" in result["message"]
@patch("application.agents.tools.api_tool.pinned_request")
def test_timeout_error(self, mock_pinned, tool):
mock_pinned.side_effect = requests.exceptions.Timeout()
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_timeout_error(self, mock_get, mock_validate, tool):
mock_get.side_effect = requests.exceptions.Timeout()
result = tool.execute_action("any")
assert result["status_code"] is None
assert "timeout" in result["message"].lower()
@patch("application.agents.tools.api_tool.pinned_request")
def test_connection_error(self, mock_pinned, tool):
mock_pinned.side_effect = requests.exceptions.ConnectionError("refused")
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_connection_error(self, mock_get, mock_validate, tool):
mock_get.side_effect = requests.exceptions.ConnectionError("refused")
result = tool.execute_action("any")
assert result["status_code"] is None
assert "Connection error" in result["message"]
@patch("application.agents.tools.api_tool.pinned_request")
def test_http_error(self, mock_pinned, tool):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_http_error(self, mock_get, mock_validate, tool):
mock_resp = MagicMock()
mock_resp.status_code = 404
mock_resp.text = "Not Found"
@@ -111,14 +116,15 @@ class TestMakeApiCall:
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
response=mock_resp
)
mock_pinned.return_value = mock_resp
mock_get.return_value = mock_resp
result = tool.execute_action("any")
assert result["status_code"] == 404
assert "HTTP Error" in result["message"]
def test_unsupported_method(self):
@patch("application.agents.tools.api_tool.validate_url")
def test_unsupported_method(self, mock_validate):
tool = APITool(
config={"url": "https://example.com", "method": "CUSTOM"}
)
@@ -126,64 +132,69 @@ class TestMakeApiCall:
assert result["status_code"] is None
assert "Unsupported" in result["message"]
@patch("application.agents.tools.api_tool.pinned_request")
def test_put_method(self, mock_pinned):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.put")
def test_put_method(self, mock_put, mock_validate):
tool = APITool(config={"url": "https://example.com/item/1", "method": "PUT"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_pinned.return_value = mock_resp
mock_put.return_value = mock_resp
result = tool.execute_action("update", name="new")
assert result["status_code"] == 200
@patch("application.agents.tools.api_tool.pinned_request")
def test_delete_method(self, mock_pinned):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.delete")
def test_delete_method(self, mock_delete, mock_validate):
tool = APITool(config={"url": "https://example.com/item/1", "method": "DELETE"})
mock_resp = MagicMock()
mock_resp.status_code = 204
mock_resp.headers = {"Content-Type": "text/plain"}
mock_resp.content = b''
mock_pinned.return_value = mock_resp
mock_delete.return_value = mock_resp
result = tool.execute_action("delete")
assert result["status_code"] == 204
@patch("application.agents.tools.api_tool.pinned_request")
def test_patch_method(self, mock_pinned):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.patch")
def test_patch_method(self, mock_patch, mock_validate):
tool = APITool(config={"url": "https://example.com/item/1", "method": "PATCH"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"patched": True}
mock_resp.content = b'{"patched":true}'
mock_pinned.return_value = mock_resp
mock_patch.return_value = mock_resp
result = tool.execute_action("patch", field="val")
assert result["status_code"] == 200
@patch("application.agents.tools.api_tool.pinned_request")
def test_head_method(self, mock_pinned):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.head")
def test_head_method(self, mock_head, mock_validate):
tool = APITool(config={"url": "https://example.com", "method": "HEAD"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "text/html"}
mock_resp.content = b''
mock_pinned.return_value = mock_resp
mock_head.return_value = mock_resp
result = tool.execute_action("check")
assert result["status_code"] == 200
@patch("application.agents.tools.api_tool.pinned_request")
def test_options_method(self, mock_pinned):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.options")
def test_options_method(self, mock_options, mock_validate):
tool = APITool(config={"url": "https://example.com", "method": "OPTIONS"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "text/plain"}
mock_resp.content = b''
mock_pinned.return_value = mock_resp
mock_options.return_value = mock_resp
result = tool.execute_action("options")
assert result["status_code"] == 200
@@ -191,8 +202,9 @@ class TestMakeApiCall:
@pytest.mark.unit
class TestPathParamSubstitution:
@patch("application.agents.tools.api_tool.pinned_request")
def test_path_params_substituted(self, mock_pinned):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_path_params_substituted(self, mock_get, mock_validate):
tool = APITool(
config={
"url": "https://api.example.com/users/{user_id}/posts/{post_id}",
@@ -205,11 +217,11 @@ class TestPathParamSubstitution:
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = []
mock_resp.content = b'[]'
mock_pinned.return_value = mock_resp
mock_get.return_value = mock_resp
tool.execute_action("get")
called_url = mock_pinned.call_args[0][1]
called_url = mock_get.call_args[0][0]
assert "/users/42/posts/7" in called_url
assert "{user_id}" not in called_url
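
Two `unittest.mock` details drive the signature changes in these tests: stacked `@patch` decorators are applied bottom-up, so the decorator closest to the function supplies the first mock argument, and `call_args[0]` is the tuple of positional arguments from the most recent call. The sketch below demonstrates both with standard-library targets (`json.dumps`, `os.getcwd`) chosen purely for illustration; they are not taken from `api_tool`.

```python
import json
import os
from unittest.mock import MagicMock, patch


@patch("os.getcwd")   # outer decorator: supplies the second mock argument
@patch("json.dumps")  # inner decorator: supplies the first mock argument
def run_with_mocks(mock_dumps, mock_getcwd):
    mock_dumps.return_value = "dumped"
    mock_getcwd.return_value = "/fake"
    return json.dumps({}), os.getcwd()


result = run_with_mocks()
print(result)

# call_args[0] holds positional arguments, call_args[1] holds keyword arguments.
fake_get = MagicMock()
fake_get("https://api.example.com/users/42/posts/7", timeout=5)
called_url = fake_get.call_args[0][0]
called_kwargs = fake_get.call_args[1]
print(called_url, called_kwargs)
```

This is why a test patching both `validate_url` (outer) and `requests.get` (inner) receives `mock_get` before `mock_validate`, and why the asserted URL index depends on where the URL sits in the mocked function's positional arguments.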


@@ -1,9 +1,10 @@
"""Tests for the journaled execute path on ToolExecutor.
Each tool call inserts a ``tool_call_attempts`` row and flips it
``proposed → executed`` (or ``→ failed``). With a ``message_id`` it
stays ``executed`` for the finalize path to confirm; without one
(``save_conversation=False``) it goes straight to ``confirmed``.
Each tool call inserts a row into ``tool_call_attempts`` then flips
through ``proposed → executed`` (or ``proposed → failed``). The flip
to ``confirmed`` is owned by the message-finalize path and is only
asserted indirectly here (rows stay in ``executed`` so the reconciler
can pick them up).
"""
from contextlib import contextmanager
@@ -74,24 +75,11 @@ def _make_call(name="test_action_t1", call_id="c1"):
return call
_TOOLS_DICT = {
"t1": {
"id": "00000000-0000-0000-0000-000000000001",
"name": "test_tool",
"config": {"key": "val"},
"actions": [
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
],
}
}
@pytest.mark.unit
class TestExecuteJournaling:
def test_no_message_id_proposed_then_confirmed(
def test_happy_path_proposed_then_executed(
self, pg_conn, mock_tool_manager, monkeypatch
):
"""No reserved message (``save_conversation=False``) → row lands ``confirmed``, not ``executed``."""
executor = ToolExecutor(user="u")
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
@@ -101,12 +89,23 @@ class TestExecuteJournaling:
)
_patch_db(monkeypatch, pg_conn)
events, result = _drain(executor.execute(_TOOLS_DICT, _make_call(), "MockLLM"))
tools_dict = {
"t1": {
"id": "00000000-0000-0000-0000-000000000001",
"name": "test_tool",
"config": {"key": "val"},
"actions": [
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
],
}
}
events, result = _drain(executor.execute(tools_dict, _make_call(), "MockLLM"))
assert result[0] == "Tool result"
row = _select_attempt(pg_conn, "c1")
assert row is not None
assert row["status"] == "confirmed"
assert row["status"] == "executed"
assert row["tool_name"] == "test_tool"
assert row["action_name"] == "test_action"
assert row["arguments"] == {"q": "v"}
@@ -118,7 +117,10 @@ class TestExecuteJournaling:
def test_executor_message_id_is_persisted_on_executed_row(
self, pg_conn, mock_tool_manager, monkeypatch
):
"""The executor's message_id is carried onto the journal row, which stays ``executed``."""
"""When the route stamps a placeholder message_id on the executor,
the journal row carries it forward so ``confirm_executed_tool_calls``
can later flip it to ``confirmed``.
"""
from application.storage.db.repositories.conversations import (
ConversationsRepository,
)
@@ -145,7 +147,18 @@ class TestExecuteJournaling:
)
_patch_db(monkeypatch, pg_conn)
_drain(executor.execute(_TOOLS_DICT, _make_call(call_id="cm1"), "MockLLM"))
tools_dict = {
"t1": {
"id": "00000000-0000-0000-0000-000000000001",
"name": "test_tool",
"config": {"key": "val"},
"actions": [
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
],
}
}
_drain(executor.execute(tools_dict, _make_call(call_id="cm1"), "MockLLM"))
row = _select_attempt(pg_conn, "cm1")
assert row is not None
@@ -167,7 +180,18 @@ class TestExecuteJournaling:
RuntimeError("boom")
)
gen = executor.execute(_TOOLS_DICT, _make_call(call_id="c2"), "MockLLM")
tools_dict = {
"t1": {
"id": "00000000-0000-0000-0000-000000000001",
"name": "test_tool",
"config": {"key": "val"},
"actions": [
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
],
}
}
gen = executor.execute(tools_dict, _make_call(call_id="c2"), "MockLLM")
with pytest.raises(RuntimeError, match="boom"):
_drain(gen)
@@ -176,10 +200,42 @@ class TestExecuteJournaling:
assert row["status"] == "failed"
assert row["error"] == "boom"
def test_executed_row_lingers_for_reconciler_when_no_confirm(
self, pg_conn, mock_tool_manager, monkeypatch
):
"""No finalize_message call → row sits in ``executed``."""
executor = ToolExecutor(user="u")
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
lambda _cls, **kw: Mock(
parse_args=Mock(return_value=("t1", "test_action", {}))
),
)
_patch_db(monkeypatch, pg_conn)
tools_dict = {
"t1": {
"id": "00000000-0000-0000-0000-000000000001",
"name": "test_tool",
"config": {"key": "val"},
"actions": [
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
],
}
}
_drain(executor.execute(tools_dict, _make_call(call_id="c3"), "MockLLM"))
row = _select_attempt(pg_conn, "c3")
assert row["status"] == "executed"
# Partial index `tool_call_attempts_pending_ts_idx` selects rows
# in ('proposed','executed') — the reconciler reads those.
assert row["status"] in ("proposed", "executed")
@pytest.mark.unit
class TestRepository:
def test_proposed_then_confirmed_when_no_message(self, pg_conn):
def test_proposed_then_executed_round_trip(self, pg_conn):
from application.storage.db.repositories.tool_call_attempts import (
ToolCallAttemptsRepository,
)
@@ -193,50 +249,7 @@ class TestRepository:
assert repo.mark_executed("c-x", {"out": "ok"}) is True
row = _select_attempt(pg_conn, "c-x")
assert row["status"] == "confirmed"
assert row["message_id"] is None
assert row["result"] == {"result": {"out": "ok"}}
def test_mark_executed_with_message_stays_executed(self, pg_conn):
from application.storage.db.repositories.conversations import (
ConversationsRepository,
)
from application.storage.db.repositories.tool_call_attempts import (
ToolCallAttemptsRepository,
)
# FK constraint: message_id must reference a real row.
conv_repo = ConversationsRepository(pg_conn)
conv = conv_repo.create("u-repo", "repo-msg-test")
msg = conv_repo.reserve_message(
str(conv["id"]),
prompt="q?",
placeholder_response="...",
request_id="req-repo-1",
status="pending",
)
message_uuid = str(msg["id"])
repo = ToolCallAttemptsRepository(pg_conn)
repo.record_proposed("c-m", "tool", "act", {})
assert (
repo.mark_executed("c-m", {"out": "ok"}, message_id=message_uuid) is True
)
row = _select_attempt(pg_conn, "c-m")
assert row["status"] == "executed"
assert str(row["message_id"]) == message_uuid
def test_upsert_executed_without_message_confirms(self, pg_conn):
"""``upsert_executed`` (DB-outage fallback) with no ``message_id`` lands ``confirmed``."""
from application.storage.db.repositories.tool_call_attempts import (
ToolCallAttemptsRepository,
)
repo = ToolCallAttemptsRepository(pg_conn)
repo.upsert_executed("c-up", "tool", "act", {"a": 1}, {"out": "ok"})
row = _select_attempt(pg_conn, "c-up")
assert row["status"] == "confirmed"
assert row["message_id"] is None
assert row["result"] == {"result": {"out": "ok"}}
def test_mark_failed_sets_error(self, pg_conn):
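Note on the journal tests above: one reading of the lifecycle they exercise is proposed, then executed, then confirmed (or failed), with the reconciler scanning rows still in 'proposed' or 'executed'. The sketch below models that reading in memory. `InMemoryJournal`, its `confirm` and `pending` methods, and the rule that a missing row makes `mark_executed` return False are assumptions for illustration; this is not the `ToolCallAttemptsRepository` implementation.

```python
# Statuses the reconciler reads, per the comment on the partial index in the test.
PENDING_STATUSES = ("proposed", "executed")


class InMemoryJournal:
    """Hypothetical in-memory stand-in for the tool_call_attempts table."""

    def __init__(self):
        self.rows = {}

    def record_proposed(self, call_id, tool_name, action_name, arguments):
        # A new attempt always starts in 'proposed'.
        self.rows[call_id] = {
            "status": "proposed",
            "tool_name": tool_name,
            "action_name": action_name,
            "arguments": arguments,
            "result": None,
            "error": None,
        }

    def mark_executed(self, call_id, result):
        # Assumption: an unknown call_id is reported as False rather than raising.
        row = self.rows.get(call_id)
        if row is None:
            return False
        row["status"] = "executed"
        # The tests expect the tool output wrapped under a "result" key.
        row["result"] = {"result": result}
        return True

    def mark_failed(self, call_id, error):
        row = self.rows.get(call_id)
        if row is None:
            return False
        row["status"] = "failed"
        row["error"] = error
        return True

    def confirm(self, call_id):
        # Only an executed row can be flipped to 'confirmed'.
        row = self.rows.get(call_id)
        if row is None or row["status"] != "executed":
            return False
        row["status"] = "confirmed"
        return True

    def pending(self):
        # Rows a reconciler would still need to look at.
        return [cid for cid, row in self.rows.items()
                if row["status"] in PENDING_STATUSES]
```

Under this reading, a row that is executed but never confirmed keeps showing up in `pending()`, which is the situation `test_executed_row_lingers_for_reconciler_when_no_confirm` checks.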

@@ -81,103 +81,104 @@ class TestAPIToolInit:
@pytest.mark.unit
class TestMakeApiCall:
@patch("application.agents.tools.api_tool.pinned_request")
def test_successful_get(self, mock_pinned, get_tool):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_successful_get(self, mock_get, mock_validate, get_tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"result": "ok"}
mock_resp.content = b'{"result":"ok"}'
mock_pinned.return_value = mock_resp
mock_get.return_value = mock_resp
result = get_tool.execute_action("any_action")
assert result["status_code"] == 200
assert result["data"] == {"result": "ok"}
assert result["message"] == "API call successful."
assert mock_pinned.call_args[0][0] == "GET"
@patch("application.agents.tools.api_tool.pinned_request")
def test_successful_post(self, mock_pinned, post_tool):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_successful_post(self, mock_post, mock_validate, post_tool):
mock_resp = MagicMock()
mock_resp.status_code = 201
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"id": 1}
mock_resp.content = b'{"id":1}'
mock_pinned.return_value = mock_resp
mock_post.return_value = mock_resp
result = post_tool.execute_action("create", name="test")
assert result["status_code"] == 201
assert mock_pinned.call_args[0][0] == "POST"
@patch("application.agents.tools.api_tool.pinned_request")
def test_put_method(self, mock_pinned):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.put")
def test_put_method(self, mock_put, mock_validate):
tool = APITool(config={"url": "https://example.com/item/1", "method": "PUT"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_pinned.return_value = mock_resp
mock_put.return_value = mock_resp
result = tool.execute_action("update", name="new")
assert result["status_code"] == 200
assert mock_pinned.call_args[0][0] == "PUT"
@patch("application.agents.tools.api_tool.pinned_request")
def test_delete_method(self, mock_pinned):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.delete")
def test_delete_method(self, mock_delete, mock_validate):
tool = APITool(config={"url": "https://example.com/item/1", "method": "DELETE"})
mock_resp = MagicMock()
mock_resp.status_code = 204
mock_resp.headers = {"Content-Type": "text/plain"}
mock_resp.content = b''
mock_pinned.return_value = mock_resp
mock_delete.return_value = mock_resp
result = tool.execute_action("delete")
assert result["status_code"] == 204
assert mock_pinned.call_args[0][0] == "DELETE"
@patch("application.agents.tools.api_tool.pinned_request")
def test_patch_method(self, mock_pinned):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.patch")
def test_patch_method(self, mock_patch, mock_validate):
tool = APITool(config={"url": "https://example.com/item/1", "method": "PATCH"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"patched": True}
mock_resp.content = b'{"patched":true}'
mock_pinned.return_value = mock_resp
mock_patch.return_value = mock_resp
result = tool.execute_action("patch", field="val")
assert result["status_code"] == 200
assert mock_pinned.call_args[0][0] == "PATCH"
@patch("application.agents.tools.api_tool.pinned_request")
def test_head_method(self, mock_pinned):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.head")
def test_head_method(self, mock_head, mock_validate):
tool = APITool(config={"url": "https://example.com", "method": "HEAD"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "text/html"}
mock_resp.content = b''
mock_pinned.return_value = mock_resp
mock_head.return_value = mock_resp
result = tool.execute_action("check")
assert result["status_code"] == 200
assert mock_pinned.call_args[0][0] == "HEAD"
@patch("application.agents.tools.api_tool.pinned_request")
def test_options_method(self, mock_pinned):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.options")
def test_options_method(self, mock_options, mock_validate):
tool = APITool(config={"url": "https://example.com", "method": "OPTIONS"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "text/plain"}
mock_resp.content = b''
mock_pinned.return_value = mock_resp
mock_options.return_value = mock_resp
result = tool.execute_action("options")
assert result["status_code"] == 200
assert mock_pinned.call_args[0][0] == "OPTIONS"
def test_unsupported_method(self):
@patch("application.agents.tools.api_tool.validate_url")
def test_unsupported_method(self, mock_validate):
tool = APITool(config={"url": "https://example.com", "method": "CUSTOM"})
result = tool.execute_action("any")
assert result["status_code"] is None
@@ -192,18 +193,19 @@ class TestMakeApiCall:
@pytest.mark.unit
class TestSSRFValidation:
@patch("application.agents.tools.api_tool.pinned_request")
def test_ssrf_blocked(self, mock_pinned, get_tool):
from application.security.safe_url import UnsafeUserUrlError
@patch("application.agents.tools.api_tool.validate_url")
def test_ssrf_blocked_initial_url(self, mock_validate, get_tool):
from application.core.url_validation import SSRFError
mock_pinned.side_effect = UnsafeUserUrlError("blocked")
mock_validate.side_effect = SSRFError("blocked")
result = get_tool.execute_action("any")
assert result["status_code"] is None
assert "URL validation error" in result["message"]
@patch("application.agents.tools.api_tool.pinned_request")
def test_ssrf_blocked_with_path_params(self, mock_pinned):
from application.security.safe_url import UnsafeUserUrlError
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_ssrf_blocked_after_param_substitution(self, mock_get, mock_validate):
from application.core.url_validation import SSRFError
tool = APITool(config={
"url": "https://api.example.com/{host}/data",
@@ -211,7 +213,14 @@ class TestSSRFValidation:
"query_params": {"host": "169.254.169.254"},
})
mock_pinned.side_effect = UnsafeUserUrlError("blocked")
call_count = [0]
def side_effect(url):
call_count[0] += 1
if call_count[0] == 2:
raise SSRFError("blocked after substitution")
mock_validate.side_effect = side_effect
result = tool.execute_action("any")
assert result["status_code"] is None
assert "URL validation error" in result["message"]
@@ -225,36 +234,40 @@ class TestSSRFValidation:
@pytest.mark.unit
class TestErrorHandling:
@patch("application.agents.tools.api_tool.pinned_request")
def test_timeout_error(self, mock_pinned, get_tool):
mock_pinned.side_effect = requests.exceptions.Timeout()
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_timeout_error(self, mock_get, mock_validate, get_tool):
mock_get.side_effect = requests.exceptions.Timeout()
result = get_tool.execute_action("any")
assert result["status_code"] is None
assert "timeout" in result["message"].lower()
@patch("application.agents.tools.api_tool.pinned_request")
def test_connection_error(self, mock_pinned, get_tool):
mock_pinned.side_effect = requests.exceptions.ConnectionError("refused")
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_connection_error(self, mock_get, mock_validate, get_tool):
mock_get.side_effect = requests.exceptions.ConnectionError("refused")
result = get_tool.execute_action("any")
assert result["status_code"] is None
assert "Connection error" in result["message"]
@patch("application.agents.tools.api_tool.pinned_request")
def test_http_error_with_json(self, mock_pinned, get_tool):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_http_error_with_json(self, mock_get, mock_validate, get_tool):
mock_resp = MagicMock()
mock_resp.status_code = 422
mock_resp.json.return_value = {"error": "invalid_field"}
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
response=mock_resp
)
mock_pinned.return_value = mock_resp
mock_get.return_value = mock_resp
result = get_tool.execute_action("any")
assert result["status_code"] == 422
assert result["data"] == {"error": "invalid_field"}
@patch("application.agents.tools.api_tool.pinned_request")
def test_http_error_non_json_body(self, mock_pinned, get_tool):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_http_error_non_json_body(self, mock_get, mock_validate, get_tool):
mock_resp = MagicMock()
mock_resp.status_code = 404
mock_resp.text = "Not Found"
@@ -262,26 +275,29 @@ class TestErrorHandling:
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
response=mock_resp
)
mock_pinned.return_value = mock_resp
mock_get.return_value = mock_resp
result = get_tool.execute_action("any")
assert result["status_code"] == 404
assert result["data"] == "Not Found"
@patch("application.agents.tools.api_tool.pinned_request")
def test_request_exception(self, mock_pinned, get_tool):
mock_pinned.side_effect = requests.exceptions.RequestException("something")
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_request_exception(self, mock_get, mock_validate, get_tool):
mock_get.side_effect = requests.exceptions.RequestException("something")
result = get_tool.execute_action("any")
assert "API call failed" in result["message"]
@patch("application.agents.tools.api_tool.pinned_request")
def test_unexpected_exception(self, mock_pinned, get_tool):
mock_pinned.side_effect = RuntimeError("unexpected")
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_unexpected_exception(self, mock_get, mock_validate, get_tool):
mock_get.side_effect = RuntimeError("unexpected")
result = get_tool.execute_action("any")
assert "Unexpected error" in result["message"]
@patch("application.agents.tools.api_tool.pinned_request")
def test_body_serialization_error(self, mock_pinned):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_body_serialization_error(self, mock_post, mock_validate):
tool = APITool(config={
"url": "https://example.com",
"method": "POST",
@@ -304,8 +320,9 @@ class TestErrorHandling:
@pytest.mark.unit
class TestPathParamSubstitution:
@patch("application.agents.tools.api_tool.pinned_request")
def test_path_params_substituted(self, mock_pinned):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_path_params_substituted(self, mock_get, mock_validate):
tool = APITool(config={
"url": "https://api.example.com/users/{user_id}/posts/{post_id}",
"method": "GET",
@@ -316,16 +333,17 @@ class TestPathParamSubstitution:
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = []
mock_resp.content = b'[]'
mock_pinned.return_value = mock_resp
mock_get.return_value = mock_resp
tool.execute_action("get")
called_url = mock_pinned.call_args[0][1]
called_url = mock_get.call_args[0][0]
assert "/users/42/posts/7" in called_url
assert "{user_id}" not in called_url
@patch("application.agents.tools.api_tool.pinned_request")
def test_remaining_query_params_appended(self, mock_pinned):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_remaining_query_params_appended(self, mock_get, mock_validate):
tool = APITool(config={
"url": "https://api.example.com/items",
"method": "GET",
@@ -336,16 +354,19 @@ class TestPathParamSubstitution:
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = []
mock_resp.content = b'[]'
mock_pinned.return_value = mock_resp
mock_get.return_value = mock_resp
tool.execute_action("get")
called_url = mock_pinned.call_args[0][1]
called_url = mock_get.call_args[0][0]
assert "page=2" in called_url
assert "limit=10" in called_url
@patch("application.agents.tools.api_tool.pinned_request")
def test_query_params_append_with_existing_query_string(self, mock_pinned):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_query_params_append_with_existing_query_string(
self, mock_get, mock_validate
):
tool = APITool(config={
"url": "https://api.example.com/items?existing=true",
"method": "GET",
@@ -356,65 +377,27 @@ class TestPathParamSubstitution:
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = []
mock_resp.content = b'[]'
mock_pinned.return_value = mock_resp
mock_get.return_value = mock_resp
tool.execute_action("get")
called_url = mock_pinned.call_args[0][1]
called_url = mock_get.call_args[0][0]
assert "&page=1" in called_url
@patch("application.agents.tools.api_tool.pinned_request")
def test_empty_body_no_serialization(self, mock_pinned):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_empty_body_no_serialization(self, mock_post, mock_validate):
tool = APITool(config={"url": "https://example.com", "method": "POST"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_pinned.return_value = mock_resp
mock_post.return_value = mock_resp
result = tool.execute_action("create")
assert result["status_code"] == 200
@patch("application.agents.tools.api_tool.pinned_request")
def test_path_params_are_url_encoded(self, mock_pinned):
tool = APITool(config={
"url": "https://api.example.com/users/{user_id}/profile",
"method": "GET",
"query_params": {"user_id": "../../admin"},
})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_pinned.return_value = mock_resp
tool.execute_action("get")
called_url = mock_pinned.call_args[0][1]
assert "../../admin" not in called_url
assert "%2F" in called_url or "%2f" in called_url
@patch("application.agents.tools.api_tool.pinned_request")
def test_path_params_query_injection_encoded(self, mock_pinned):
tool = APITool(config={
"url": "https://api.example.com/items/{item_id}",
"method": "GET",
"query_params": {"item_id": "x?admin=true"},
})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_pinned.return_value = mock_resp
tool.execute_action("get")
called_url = mock_pinned.call_args[0][1]
assert "x?admin=true" not in called_url
# =====================================================================
# Parse Response
@@ -511,8 +494,11 @@ class TestAPIToolMetadata:
def test_config_requirements_empty(self, get_tool):
assert get_tool.get_config_requirements() == {}
@patch("application.agents.tools.api_tool.pinned_request")
def test_content_type_set_for_post_with_no_headers(self, mock_pinned):
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_content_type_set_for_post_with_no_headers(
self, mock_post, mock_validate
):
tool = APITool(config={
"url": "https://example.com",
"method": "POST",
@@ -523,8 +509,8 @@ class TestAPIToolMetadata:
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_pinned.return_value = mock_resp
mock_post.return_value = mock_resp
tool.execute_action("create")
call_headers = mock_pinned.call_args.kwargs["headers"]
call_headers = mock_post.call_args[1]["headers"]
assert "Content-Type" in call_headers
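Note on the mock signatures in this file: with stacked `@patch` decorators, the decorator closest to the function supplies the first mock argument, which is why tests patching `validate_url` above `requests.get` take `(mock_get, mock_validate)`. The sketch below shows that ordering and the two `call_args` access styles used in the assertions. It patches `os.getcwd` and `os.getpid` purely as stand-ins; those targets and the `demo` function are illustrative and not part of the project.

```python
import os
from unittest.mock import MagicMock, patch


@patch("os.getpid")   # outermost decorator: becomes the LAST mock argument
@patch("os.getcwd")   # innermost decorator: becomes the FIRST mock argument
def demo(mock_getcwd, mock_getpid):
    mock_getcwd.return_value = "/tmp/x"
    mock_getpid.return_value = 1234
    return os.getcwd(), os.getpid()


# call_args is an (args, kwargs) pair, so it can be indexed positionally.
fake_get = MagicMock()
fake_get("https://example.com/items", headers={"Accept": "application/json"})

positional_url = fake_get.call_args[0][0]          # first positional argument
keyword_headers = fake_get.call_args[1]["headers"]  # same as call_args.kwargs["headers"]
```

Getting the argument order wrong does not raise; the mocks are simply bound to the wrong names, so the `return_value` and `side_effect` setup lands on the wrong target.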

@@ -6,7 +6,6 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from sqlalchemy import text
@pytest.fixture
@@ -257,39 +256,9 @@ class TestPaginatedSources:
for key in (
"id", "name", "date", "model", "location", "tokens",
"retriever", "syncFrequency", "provider", "isNested", "type",
"ingestStatus",
):
assert key in row
def test_exposes_stalled_ingest_status(self, app, pg_conn):
"""A source whose ingest the reconciler escalated to 'stalled'
surfaces ingestStatus='failed' so the UI can badge it.
"""
from application.api.user.sources.routes import PaginatedSources
user = "u-ingest-status"
src = _seed_source(pg_conn, user, name="stalled-doc", type="file")
pg_conn.execute(
text(
"""
INSERT INTO ingest_chunk_progress (
source_id, total_chunks, embedded_chunks, last_index,
status
)
VALUES (CAST(:sid AS uuid), 907, 9, 8, 'stalled')
"""
),
{"sid": str(src["id"])},
)
with _patch_db(pg_conn), app.test_request_context(
"/api/sources/paginated?page=1&rows=10"
):
from flask import request
request.decoded_token = {"sub": user}
response = PaginatedSources().get()
row = response.json["paginated"][0]
assert row["ingestStatus"] == "failed"
class TestDeleteOldIndexes:
def test_returns_401_unauthenticated(self, app):
@@ -584,35 +553,6 @@ class TestSyncSource:
assert response.status_code == 200
assert response.json["task_id"] == "task-123"
def test_normalizes_dict_remote_data_before_dispatch(self, app, pg_conn):
"""The route must hand the sync task the normalized URL string."""
from application.api.user.sources.routes import SyncSource
user = "u-normalize"
src = _seed_source(
pg_conn, user, name="crawl-src", type="crawler",
remote_data=json.dumps(
{"url": "https://example.com", "provider": "crawler"}
),
)
fake_task = MagicMock(id="task-norm")
with _patch_db(pg_conn), patch(
"application.api.user.sources.routes.sync_source.delay",
return_value=fake_task,
) as mock_delay, app.test_request_context(
"/api/sync_source",
method="POST",
json={"source_id": str(src["id"])},
):
from flask import request
request.decoded_token = {"sub": user}
response = SyncSource().post()
assert response.status_code == 200
assert mock_delay.call_args.kwargs["source_data"] == "https://example.com"
assert mock_delay.call_args.kwargs["loader"] == "crawler"
def test_sync_task_raises_returns_400(self, app, pg_conn):
from application.api.user.sources.routes import SyncSource
@@ -636,135 +576,6 @@ class TestSyncSource:
assert response.status_code == 400
class TestReingestSource:
def test_returns_401_unauthenticated(self, app):
from application.api.user.sources.routes import ReingestSource
with app.test_request_context(
"/api/sources/reingest", method="POST", json={"source_id": "x"}
):
from flask import request
request.decoded_token = None
response = ReingestSource().post()
assert response.status_code == 401
def test_returns_400_missing_id(self, app):
from application.api.user.sources.routes import ReingestSource
with app.test_request_context(
"/api/sources/reingest", method="POST", json={}
):
from flask import request
request.decoded_token = {"sub": "u"}
response = ReingestSource().post()
assert response.status_code == 400
def test_returns_404_missing_source(self, app, pg_conn):
from application.api.user.sources.routes import ReingestSource
with _patch_db(pg_conn), app.test_request_context(
"/api/sources/reingest",
method="POST",
json={"source_id": "00000000-0000-0000-0000-000000000000"},
):
from flask import request
request.decoded_token = {"sub": "u"}
response = ReingestSource().post()
assert response.status_code == 404
def test_triggers_reingest_task(self, app, pg_conn):
from application.api.user.sources.routes import ReingestSource
user = "u-reingest"
src = _seed_source(pg_conn, user, name="stalled-src", type="file")
fake_task = MagicMock(id="reingest-task-1")
with _patch_db(pg_conn), patch(
"application.api.user.sources.routes.reingest_source_task.delay",
return_value=fake_task,
) as mock_delay, app.test_request_context(
"/api/sources/reingest",
method="POST",
json={"source_id": str(src["id"])},
):
from flask import request
request.decoded_token = {"sub": user}
response = ReingestSource().post()
assert response.status_code == 200
assert response.json["task_id"] == "reingest-task-1"
assert mock_delay.call_args.kwargs["source_id"] == str(src["id"])
assert mock_delay.call_args.kwargs["user"] == user
# Scoped idempotency key engages the task's lease so repeated
# clicks collapse onto one reingest instead of racing.
assert mock_delay.call_args.kwargs["idempotency_key"] == (
f"reingest-source:{user}:{src['id']}"
)
def test_clears_stalled_ingest_progress_row(self, app, pg_conn):
"""Reingest drops the stale chunk-progress row so the sources
list stops deriving a 'failed' ingest status for the source.
"""
from application.api.user.sources.routes import ReingestSource
user = "u-reingest-clear"
src = _seed_source(pg_conn, user, name="stalled-doc", type="file")
pg_conn.execute(
text(
"""
INSERT INTO ingest_chunk_progress (
source_id, total_chunks, embedded_chunks, last_index,
status
)
VALUES (CAST(:sid AS uuid), 100, 9, 8, 'stalled')
"""
),
{"sid": str(src["id"])},
)
fake_task = MagicMock(id="reingest-task-2")
with _patch_db(pg_conn), patch(
"application.api.user.sources.routes.reingest_source_task.delay",
return_value=fake_task,
), app.test_request_context(
"/api/sources/reingest",
method="POST",
json={"source_id": str(src["id"])},
):
from flask import request
request.decoded_token = {"sub": user}
response = ReingestSource().post()
assert response.status_code == 200
remaining = pg_conn.execute(
text(
"SELECT count(*) FROM ingest_chunk_progress "
"WHERE source_id = CAST(:sid AS uuid)"
),
{"sid": str(src["id"])},
).scalar()
assert remaining == 0
def test_reingest_task_raises_returns_400(self, app, pg_conn):
from application.api.user.sources.routes import ReingestSource
user = "u-reingest-fail"
src = _seed_source(pg_conn, user, name="fail-src", type="file")
with _patch_db(pg_conn), patch(
"application.api.user.sources.routes.reingest_source_task.delay",
side_effect=RuntimeError("boom"),
), app.test_request_context(
"/api/sources/reingest",
method="POST",
json={"source_id": str(src["id"])},
):
from flask import request
request.decoded_token = {"sub": user}
response = ReingestSource().post()
assert response.status_code == 400
class TestDirectoryStructure:
def test_returns_401_unauthenticated(self, app):
from application.api.user.sources.routes import DirectoryStructure

@@ -417,181 +417,3 @@ class TestSuccessfulRunClearsLease:
assert row[0] == "completed"
assert row[1] is None
assert row[2] is None
@pytest.mark.unit
class TestSynthesizedKeyGuardsKeylessDispatch:
"""A keyless dispatch carrying ``source_id`` is still poison-guarded:
the wrapper synthesizes a deterministic key from ``source_id``.
"""
def test_keyless_with_source_id_records_dedup_row(self, pg_conn):
from application.api.user.idempotency import with_idempotency
@with_idempotency(task_name="ingest")
def task(self, idempotency_key=None, source_id=None):
return {"ran": True}
with _patch_decorator_db(pg_conn):
result = task(_fake_celery_self(), source_id="src-abc")
assert result == {"ran": True}
row = _row_for(pg_conn, "auto:ingest:src-abc")
assert row is not None
assert row[0] == "ingest"
assert row[2] == "completed"
def test_synthesized_key_stable_across_redeliveries(self, pg_conn):
"""Same ``source_id`` → same key → a redelivery short-circuits to
the cached result instead of re-running the body.
"""
from application.api.user.idempotency import with_idempotency
runs = {"count": 0}
@with_idempotency(task_name="ingest")
def task(self, idempotency_key=None, source_id=None):
runs["count"] += 1
return {"n": runs["count"]}
with _patch_decorator_db(pg_conn):
first = task(_fake_celery_self(), source_id="src-1")
second = task(_fake_celery_self(), source_id="src-1")
assert first == second == {"n": 1}
assert runs["count"] == 1
def test_poison_guard_trips_for_keyless_dispatch(self, pg_conn):
"""The core fix: a keyless OOM-looping dispatch is bounded — the
guard trips after MAX_TASK_ATTEMPTS with no explicit key.
"""
from application.api.user.idempotency import (
MAX_TASK_ATTEMPTS, with_idempotency,
)
runs = {"count": 0}
@with_idempotency(task_name="ingest")
def task(self, idempotency_key=None, source_id=None):
runs["count"] += 1
raise RuntimeError("OOM-style failure")
with _patch_decorator_db(pg_conn):
for _ in range(MAX_TASK_ATTEMPTS):
with pytest.raises(RuntimeError):
task(_fake_celery_self(), source_id="src-poison")
result = task(_fake_celery_self(), source_id="src-poison")
assert runs["count"] == MAX_TASK_ATTEMPTS
assert result["success"] is False
assert "poison-loop" in result["error"]
assert _row_for(pg_conn, "auto:ingest:src-poison")[2] == "failed"
def test_no_source_id_no_key_runs_unguarded(self, pg_conn):
"""No explicit key and no ``source_id`` anchor → pass through with
no DB writes, exactly as before.
"""
from application.api.user.idempotency import with_idempotency
@with_idempotency(task_name="store_attachment")
def task(self, idempotency_key=None):
return {"ran": True}
with patch(
"application.api.user.idempotency.db_session"
) as mock_session, patch(
"application.api.user.idempotency.db_readonly"
) as mock_readonly:
result = task(_fake_celery_self())
assert result == {"ran": True}
assert mock_session.call_count == 0
assert mock_readonly.call_count == 0
def test_explicit_key_takes_precedence_over_source_id(self, pg_conn):
"""An explicit key wins; the synthesized ``auto:`` key is unused."""
from application.api.user.idempotency import with_idempotency
@with_idempotency(task_name="ingest")
def task(self, idempotency_key=None, source_id=None):
return {"ran": True}
with _patch_decorator_db(pg_conn):
task(
_fake_celery_self(),
idempotency_key="explicit-k",
source_id="src-x",
)
assert _row_for(pg_conn, "explicit-k") is not None
assert _row_for(pg_conn, "auto:ingest:src-x") is None
@pytest.mark.unit
class TestPoisonHook:
"""``on_poison`` fires on the poison-guard branch with the task's
bound arguments, and never on the success path.
"""
def test_hook_invoked_with_bound_args_on_poison(self, pg_conn):
from application.api.user.idempotency import (
MAX_TASK_ATTEMPTS, with_idempotency,
)
captured = []
def _hook(task_name, bound):
captured.append((task_name, bound))
@with_idempotency(task_name="ingest", on_poison=_hook)
def task(self, idempotency_key=None, source_id=None):
raise RuntimeError("never converges")
with _patch_decorator_db(pg_conn):
for _ in range(MAX_TASK_ATTEMPTS):
with pytest.raises(RuntimeError):
task(_fake_celery_self(), source_id="src-h")
task(_fake_celery_self(), source_id="src-h")
assert len(captured) == 1
task_name, bound = captured[0]
assert task_name == "ingest"
assert bound["source_id"] == "src-h"
def test_hook_not_invoked_on_success(self, pg_conn):
from application.api.user.idempotency import with_idempotency
calls = []
@with_idempotency(
task_name="ingest", on_poison=lambda *a: calls.append(a)
)
def task(self, idempotency_key=None, source_id=None):
return {"ok": True}
with _patch_decorator_db(pg_conn):
task(_fake_celery_self(), source_id="src-ok")
assert calls == []
def test_hook_failure_does_not_break_poison_return(self, pg_conn):
"""A throwing hook must not change the poison-guard outcome."""
from application.api.user.idempotency import (
MAX_TASK_ATTEMPTS, with_idempotency,
)
def _bad_hook(task_name, bound):
raise ValueError("hook blew up")
@with_idempotency(task_name="ingest", on_poison=_bad_hook)
def task(self, idempotency_key=None, source_id=None):
raise RuntimeError("never converges")
with _patch_decorator_db(pg_conn):
for _ in range(MAX_TASK_ATTEMPTS):
with pytest.raises(RuntimeError):
task(_fake_celery_self(), source_id="src-bad")
result = task(_fake_celery_self(), source_id="src-bad")
assert result["success"] is False
assert "poison-loop" in result["error"]
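Note on the behaviour these tests describe: an explicit key wins, a keyless dispatch with `source_id` gets a deterministic `auto:<task_name>:<source_id>` key, and a dispatch with neither runs unguarded. The sketch below models that with an in-memory attempt counter. `resolve_key`, `run_guarded`, the `attempts` dict, and the value chosen for `MAX_TASK_ATTEMPTS` are assumptions for illustration; the real `with_idempotency` decorator persists its state in the database and also caches completed results, which this sketch leaves out.

```python
from typing import Callable, Dict, Optional

MAX_TASK_ATTEMPTS = 3  # assumed value; the real constant lives in the idempotency module


def resolve_key(task_name: str,
                idempotency_key: Optional[str] = None,
                source_id: Optional[str] = None) -> Optional[str]:
    """Hypothetical helper: explicit key first, else synthesize from source_id."""
    if idempotency_key:
        return idempotency_key
    if source_id:
        # Same source_id always yields the same key, so redeliveries collide on it.
        return f"auto:{task_name}:{source_id}"
    return None  # no anchor: the task runs unguarded


def run_guarded(task_name: str,
                fn: Callable[..., dict],
                attempts: Dict[str, int],
                idempotency_key: Optional[str] = None,
                source_id: Optional[str] = None) -> dict:
    """Run fn unless its key has already used up MAX_TASK_ATTEMPTS."""
    key = resolve_key(task_name, idempotency_key, source_id)
    if key is None:
        return fn(idempotency_key=idempotency_key, source_id=source_id)
    if attempts.get(key, 0) >= MAX_TASK_ATTEMPTS:
        # Poison guard: stop re-running a task that never converges.
        return {"success": False, "error": f"poison-loop: {key} exceeded attempts"}
    attempts[key] = attempts.get(key, 0) + 1
    return fn(idempotency_key=idempotency_key, source_id=source_id)
```

The attempt count is incremented before the body runs, so a body that crashes every time still consumes its budget, which is the property `test_poison_guard_trips_for_keyless_dispatch` relies on.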

@@ -529,142 +529,6 @@ class TestStuckExecutedToolCalls:
assert row[0] == "executed"
# ---------------------------------------------------------------------------
# Q4 — stalled ingest checkpoints (escalate to terminal 'stalled' + alert)
# ---------------------------------------------------------------------------
def _seed_ingest_progress(
    conn,
    *,
    source_id: str,
    embedded: int,
    total: int,
    age_minutes: int = 31,
    status: str = "active",
) -> str:
    """Insert an ingest_chunk_progress row with a backdated last_updated."""
    conn.execute(
        text(
            """
            INSERT INTO ingest_chunk_progress (
                source_id, total_chunks, embedded_chunks, last_index,
                last_updated, status
            )
            VALUES (
                CAST(:sid AS uuid), :total, :embedded, :embedded - 1,
                clock_timestamp() - make_interval(mins => :age),
                :status
            )
            """
        ),
        {
            "sid": source_id,
            "total": total,
            "embedded": embedded,
            "age": age_minutes,
            "status": status,
        },
    )
    return source_id


def _ingest_status(conn, source_id: str) -> str | None:
    """Return the ``status`` of an ingest_chunk_progress row, or None."""
    row = conn.execute(
        text(
            "SELECT status FROM ingest_chunk_progress "
            "WHERE source_id = CAST(:sid AS uuid)"
        ),
        {"sid": source_id},
    ).fetchone()
    return row[0] if row is not None else None


class TestStalledIngests:
    @pytest.mark.unit
    def test_stalled_ingest_escalated_with_alert(self, pg_conn, caplog):
        from application.api.user.reconciliation import run_reconciliation

        sid = "1a000000-0000-0000-0000-0000000000a1"
        _seed_ingest_progress(pg_conn, source_id=sid, embedded=9, total=907)
        before = _stack_logs_count(pg_conn, "reconciler_ingest_stalled")
        with _route_engine_to(pg_conn), caplog.at_level(
            logging.ERROR, logger="application.api.user.reconciliation",
        ):
            r = run_reconciliation()
        assert r["ingests_stalled"] == 1
        # Escalated to a terminal status so the next tick skips it.
        assert _ingest_status(pg_conn, sid) == "stalled"
        # Structured alert + stack_logs row both surface the failure.
        assert any(
            getattr(rec, "alert", None) == "reconciler_ingest_stalled"
            and rec.levelname == "ERROR"
            for rec in caplog.records
        )
        assert (
            _stack_logs_count(pg_conn, "reconciler_ingest_stalled")
            == before + 1
        )

    @pytest.mark.unit
    def test_stalled_ingest_alerts_once_not_every_tick(self, pg_conn):
        """The escalate-to-'stalled' write ends the re-alert loop: a
        second tick neither re-counts nor re-logs the same dead ingest.
        """
        from application.api.user.reconciliation import run_reconciliation

        sid = "1a000000-0000-0000-0000-0000000000a2"
        _seed_ingest_progress(pg_conn, source_id=sid, embedded=1, total=95)
        before = _stack_logs_count(pg_conn, "reconciler_ingest_stalled")
        with _route_engine_to(pg_conn):
            r1 = run_reconciliation()
            r2 = run_reconciliation()
        assert r1["ingests_stalled"] == 1
        assert r2["ingests_stalled"] == 0
        # Only the first tick wrote an alert row.
        assert (
            _stack_logs_count(pg_conn, "reconciler_ingest_stalled")
            == before + 1
        )

    @pytest.mark.unit
    def test_fresh_ingest_left_alone(self, pg_conn):
        from application.api.user.reconciliation import run_reconciliation

        sid = "1a000000-0000-0000-0000-0000000000a3"
        # 2 minutes old — well under the 30-minute staleness threshold.
        _seed_ingest_progress(
            pg_conn, source_id=sid, embedded=3, total=20, age_minutes=2,
        )
        with _route_engine_to(pg_conn):
            r = run_reconciliation()
        assert r["ingests_stalled"] == 0
        assert _ingest_status(pg_conn, sid) == "active"

    @pytest.mark.unit
    def test_completed_ingest_left_alone(self, pg_conn):
        """A stale checkpoint that finished embedding (embedded == total)
        is not a stall and must not be flagged.
        """
        from application.api.user.reconciliation import run_reconciliation

        sid = "1a000000-0000-0000-0000-0000000000a4"
        _seed_ingest_progress(pg_conn, source_id=sid, embedded=50, total=50)
        with _route_engine_to(pg_conn):
            r = run_reconciliation()
        assert r["ingests_stalled"] == 0
        assert _ingest_status(pg_conn, sid) == "active"
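Taken together, the four cases above imply a selection rule for a stalled ingest: the row is still `active`, embedding is unfinished, and the checkpoint has not moved for longer than the staleness threshold. The pure-Python sketch below restates that rule so it can be read in one place. The `IngestRow` dataclass, the `is_stalled` helper, and the `STALE_AFTER` name are hypothetical; the real check runs as SQL inside the reconciliation repository, which this diff does not show.

```python
from dataclasses import dataclass
from datetime import datetime, timedelta

STALE_AFTER = timedelta(minutes=30)  # the threshold named in the test comment


@dataclass
class IngestRow:
    """Hypothetical in-memory mirror of an ingest_chunk_progress row."""

    embedded_chunks: int
    total_chunks: int
    last_updated: datetime
    status: str = "active"


def is_stalled(row: IngestRow, now: datetime) -> bool:
    """True for an active, unfinished checkpoint that has gone quiet."""
    return (
        row.status == "active"
        and row.embedded_chunks < row.total_chunks
        and now - row.last_updated > STALE_AFTER
    )
```

Excluding rows whose status is already `stalled` is what makes the alert fire once: after the first tick rewrites the status, the same row no longer matches.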
# ---------------------------------------------------------------------------
# Q5 — stuck idempotency pending rows (lease expired + attempts exhausted)
# ---------------------------------------------------------------------------


@@ -546,63 +546,3 @@ class TestIngestIdempotency:
assert first == second
assert first == {"status": "ok", "directory": "dir"}
assert len(worker_calls) == 1
class TestIngestPoisonEvent:
    """The poison hook publishes a terminal source.ingest.failed so the
    upload toast resolves instead of hanging on "training".
    """

    @pytest.mark.unit
    def test_publishes_failed_event(self):
        from application.api.user.tasks import _emit_ingest_poison_event

        published = []

        def _fake_publish(user, event_type, payload, *, scope=None):
            published.append((user, event_type, payload, scope))

        with patch(
            "application.events.publisher.publish_user_event",
            side_effect=_fake_publish,
        ):
            _emit_ingest_poison_event(
                "ingest",
                {"user": "u1", "source_id": "src-9", "filename": "doc.pdf"},
            )
        assert len(published) == 1
        user, event_type, payload, scope = published[0]
        assert user == "u1"
        assert event_type == "source.ingest.failed"
        assert payload["source_id"] == "src-9"
        assert payload["filename"] == "doc.pdf"
        assert payload["operation"] == "upload"
        assert scope == {"kind": "source", "id": "src-9"}

    @pytest.mark.unit
    def test_skips_when_source_id_missing(self):
        from application.api.user.tasks import _emit_ingest_poison_event

        with patch(
            "application.events.publisher.publish_user_event",
        ) as mock_publish:
            _emit_ingest_poison_event("ingest", {"user": "u1"})
        mock_publish.assert_not_called()

    @pytest.mark.unit
    def test_reingest_uses_reingest_operation(self):
        from application.api.user.tasks import _emit_ingest_poison_event

        published = []
        with patch(
            "application.events.publisher.publish_user_event",
            side_effect=lambda *a, **k: published.append((a, k)),
        ):
            _emit_ingest_poison_event(
                "reingest_source_task",
                {"user": "u1", "source_id": "src-r"},
            )
        assert published[0][0][2]["operation"] == "reingest"
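The assertions above describe the event the poison hook is expected to emit. A sketch of a hook with that shape is given below, under stated assumptions: the `emit_ingest_poison_event_sketch` name is hypothetical, the publisher is passed in as a `publish` argument so the example stays self-contained, and the rule "task names starting with `reingest` map to the `reingest` operation" is inferred from the one task name the tests use.

```python
def emit_ingest_poison_event_sketch(task_name, bound, publish):
    """Publish a terminal failure event for a poisoned ingest task.

    `publish` is a stand-in for the project's publisher and is called as
    publish(user, event_type, payload, scope=...).
    """
    source_id = bound.get("source_id")
    if not source_id:
        return False  # without a source id there is nothing to scope the event to
    # Assumed mapping: only the reingest task reports the "reingest" operation.
    operation = "reingest" if task_name.startswith("reingest") else "upload"
    publish(
        bound.get("user"),
        "source.ingest.failed",
        {
            "source_id": source_id,
            "filename": bound.get("filename"),
            "operation": operation,
        },
        scope={"kind": "source", "id": source_id},
    )
    return True
```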


@@ -158,35 +158,6 @@ class TestSimpleDirectoryReaderLoadData:
for doc in docs:
assert isinstance(doc, Document)
    def test_load_data_progress_callback_fires_per_file(self, temp_dir):
        from application.parser.file.bulk import SimpleDirectoryReader

        reader = SimpleDirectoryReader(
            input_dir=str(temp_dir), recursive=False, exclude_hidden=True,
        )
        calls = []
        reader.load_data(progress_callback=lambda done, total: calls.append((done, total)))
        total_files = len(reader.input_files)
        assert total_files >= 1
        # One callback per file, monotonically increasing, ending at total.
        assert [c[0] for c in calls] == list(range(1, total_files + 1))
        assert all(c[1] == total_files for c in calls)

    def test_load_data_progress_callback_errors_swallowed(self, temp_dir):
        from application.parser.file.bulk import SimpleDirectoryReader

        reader = SimpleDirectoryReader(
            input_dir=str(temp_dir), recursive=False, exclude_hidden=True,
        )

        def _boom(done, total):
            raise RuntimeError("callback blew up")

        # A failing callback must not abort ingestion.
        docs = reader.load_data(progress_callback=_boom)
        assert len(docs) >= 1

    def test_load_data_concatenate(self, temp_dir):
        from application.parser.file.bulk import SimpleDirectoryReader


@@ -421,85 +421,3 @@ class TestDoclingParserGaps:
parser = DoclingCSVParser()
assert parser.export_format == "markdown"
assert parser.ocr_enabled is True
# =====================================================================
# Pipeline memory caps
# =====================================================================
@pytest.mark.unit
class TestApplyPipelineCaps:
    """_apply_pipeline_caps bounds docling's threaded-pipeline buffering."""

    def test_caps_threaded_pipeline_knobs(self, monkeypatch):
        from application.core.settings import settings
        from application.parser.file.docling_parser import _apply_pipeline_caps

        monkeypatch.setattr(
            settings, "DOCLING_PIPELINE_QUEUE_MAX_SIZE", 2, raising=False
        )

        class Opts:
            # docling >= 2.94 threaded pipeline — all knobs present.
            queue_max_size = 100
            layout_batch_size = 4
            table_batch_size = 4
            ocr_batch_size = 4

        opts = Opts()
        _apply_pipeline_caps(opts)
        assert opts.queue_max_size == 2
        assert opts.layout_batch_size == 1
        assert opts.table_batch_size == 1
        assert opts.ocr_batch_size == 1

    def test_queue_size_is_settings_driven(self, monkeypatch):
        from application.core.settings import settings
        from application.parser.file.docling_parser import _apply_pipeline_caps

        monkeypatch.setattr(
            settings, "DOCLING_PIPELINE_QUEUE_MAX_SIZE", 6, raising=False
        )

        class Opts:
            queue_max_size = 100

        opts = Opts()
        _apply_pipeline_caps(opts)
        assert opts.queue_max_size == 6

    def test_misconfigured_zero_floors_to_one(self, monkeypatch):
        """A 0 queue depth could deadlock the threaded pipeline — floor it."""
        from application.core.settings import settings
        from application.parser.file.docling_parser import _apply_pipeline_caps

        monkeypatch.setattr(
            settings, "DOCLING_PIPELINE_QUEUE_MAX_SIZE", 0, raising=False
        )

        class Opts:
            queue_max_size = 100

        opts = Opts()
        _apply_pipeline_caps(opts)
        assert opts.queue_max_size == 1

    def test_noop_on_docling_without_threaded_pipeline(self):
        """Builds predating the threaded pipeline lack the knobs — the cap
        must be a silent no-op, not an AttributeError."""
        from application.parser.file.docling_parser import _apply_pipeline_caps

        class LegacyOpts:
            __slots__ = ("do_ocr", "do_table_structure")

            def __init__(self):
                self.do_ocr = False
                self.do_table_structure = True

        opts = LegacyOpts()
        _apply_pipeline_caps(opts)  # must not raise
        assert not hasattr(opts, "queue_max_size")
        assert not hasattr(opts, "layout_batch_size")
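These tests describe a function that clamps a handful of options only when the options object actually exposes them. One way to get that behaviour is to guard each write with `hasattr`, as in the sketch below. The `apply_pipeline_caps_sketch` name and the explicit `queue_max_size` parameter are assumptions; the project's `_apply_pipeline_caps` reads the value from settings and its body is not part of this diff.

```python
BATCH_KNOBS = ("layout_batch_size", "table_batch_size", "ocr_batch_size")


def apply_pipeline_caps_sketch(opts, queue_max_size):
    """Clamp threaded-pipeline knobs on `opts`, skipping any it lacks."""
    if hasattr(opts, "queue_max_size"):
        # Floor at 1: a zero-depth queue could never hand work downstream.
        opts.queue_max_size = max(1, int(queue_max_size))
    for knob in BATCH_KNOBS:
        if hasattr(opts, knob):
            setattr(opts, knob, 1)
    return opts
```

Checking `hasattr` before each `setattr` is what keeps the legacy case silent: an object declared with `__slots__` would raise `AttributeError` on an unguarded assignment to an unknown name.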


@@ -94,35 +94,6 @@ def test_embed_and_store_documents_non_faiss(tmp_path, mock_settings, mock_vecto
assert folder_name.exists()
def test_embed_and_store_documents_progress_band(
    tmp_path, mock_settings, mock_vector_creator
):
    """progress_start/progress_end remap the embed loop into a sub-band
    so an earlier stage (parsing) can own the lower part of the bar.
    """
    mock_settings.VECTOR_STORE = "chromadb"
    docs = [MagicMock(page_content=f"d{i}", metadata={}) for i in range(4)]
    task_status = MagicMock()
    mock_vector_creator.create_vectorstore.return_value = MagicMock()
    embed_and_store_documents(
        docs, str(tmp_path / "store"), "sid", task_status,
        progress_start=50, progress_end=100,
    )
    currents = [
        call.kwargs["meta"]["current"]
        for call in task_status.update_state.call_args_list
        if "meta" in call.kwargs and "current" in call.kwargs["meta"]
    ]
    assert currents, "expected progress updates"
    # Embedding stays in the upper band and tops out at 100.
    assert min(currents) > 50
    assert max(currents) == 100
    assert currents == sorted(currents)
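The sub-band remapping this test checks comes down to a linear interpolation between `progress_start` and `progress_end`. The sketch below shows the arithmetic in isolation; `band_progress` is a hypothetical helper, and rounding down with `int` is an assumption about how the pipeline produces whole percentages.

```python
def band_progress(done, total, progress_start=0, progress_end=100):
    """Map `done` of `total` steps onto [progress_start, progress_end]."""
    if total <= 0:
        return progress_end  # nothing to embed: report the band as finished
    span = progress_end - progress_start
    return progress_start + int(span * done / total)


# Four documents embedded into the 50..100 band, as in the test above.
print([band_progress(i, 4, 50, 100) for i in range(1, 5)])
# prints [62, 75, 87, 100]
```

Because the first update is reported after one document is done, every value lands strictly above `progress_start`, which matches the `min(currents) > 50` assertion.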
@patch("application.parser.embedding_pipeline.add_text_to_store_with_retry")
def test_embed_and_store_documents_partial_failure_raises(
mock_add_retry, tmp_path, mock_settings, mock_vector_creator, caplog


@@ -1,6 +1,4 @@
"""Tests for application.parser.remote.remote_creator."""
import json
"""Tests for application.parser.remote.remote_creator covering lines 31-34."""
import pytest
from unittest.mock import MagicMock
@@ -40,92 +38,3 @@ class TestRemoteCreator:
mock_loader_cls.assert_called_once()
finally:
RemoteCreator.loaders = original_loaders
@pytest.mark.unit
class TestNormalizeRemoteData:
    """``normalize_remote_data`` maps a stored JSONB ``remote_data`` value
    back to the ``source_data`` shape each loader expects."""

    def test_none_passes_through(self):
        from application.parser.remote.remote_creator import normalize_remote_data

        assert normalize_remote_data("crawler", None) is None

    def test_crawler_dict_with_url_key(self):
        from application.parser.remote.remote_creator import normalize_remote_data

        result = normalize_remote_data(
            "crawler", {"url": "https://example.com", "provider": "crawler"}
        )
        assert result == "https://example.com"

    def test_url_dict_with_url_key(self):
        from application.parser.remote.remote_creator import normalize_remote_data

        result = normalize_remote_data("url", {"url": "https://example.com"})
        assert result == "https://example.com"

    def test_url_legacy_raw_key(self):
        """Legacy rows wrapped a bare URL string as ``{"raw": ...}``."""
        from application.parser.remote.remote_creator import normalize_remote_data

        result = normalize_remote_data("crawler", {"raw": "https://legacy.example.com"})
        assert result == "https://legacy.example.com"

    def test_url_dict_with_urls_list(self):
        from application.parser.remote.remote_creator import normalize_remote_data

        result = normalize_remote_data(
            "url", {"urls": ["https://a.example.com", "https://b.example.com"]}
        )
        assert result == ["https://a.example.com", "https://b.example.com"]

    def test_github_repo_url_key(self):
        from application.parser.remote.remote_creator import normalize_remote_data

        result = normalize_remote_data(
            "github", {"repo_url": "https://github.com/arc53/DocsGPT"}
        )
        assert result == "https://github.com/arc53/DocsGPT"

    def test_sitemap_dict_with_url_key(self):
        from application.parser.remote.remote_creator import normalize_remote_data

        result = normalize_remote_data("sitemap", {"url": "https://example.com/sitemap.xml"})
        assert result == "https://example.com/sitemap.xml"

    def test_plain_string_url_passes_through(self):
        from application.parser.remote.remote_creator import normalize_remote_data

        assert normalize_remote_data("crawler", "https://example.com") == "https://example.com"

    def test_url_dict_without_url_key_returns_none(self):
        """A URL-type loader must never receive a dict, even a malformed one."""
        from application.parser.remote.remote_creator import normalize_remote_data

        assert normalize_remote_data("crawler", {"provider": "crawler"}) is None

    def test_reddit_dict_serialized_to_json_string(self):
        """reddit's loader runs json.loads() — it needs a JSON string."""
        from application.parser.remote.remote_creator import normalize_remote_data

        result = normalize_remote_data(
            "reddit", {"client_id": "x", "search_queries": ["y"]}
        )
        assert isinstance(result, str)
        assert json.loads(result) == {"client_id": "x", "search_queries": ["y"]}

    def test_s3_dict_passes_through(self):
        """S3Loader.load_data() accepts a dict, so it is left untouched."""
        from application.parser.remote.remote_creator import normalize_remote_data

        data = {"bucket": "b", "prefix": "k"}
        assert normalize_remote_data("s3", data) == data

    def test_json_string_remote_data_is_parsed(self):
        """Legacy rows that stored the JSON itself as a string still resolve."""
        from application.parser.remote.remote_creator import normalize_remote_data

        result = normalize_remote_data("crawler", '{"url": "https://example.com"}')
        assert result == "https://example.com"
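Read as a specification, the cases above amount to a small dispatch on loader type. The sketch below is one function that satisfies all of them; it is not the project's `normalize_remote_data`. The `URL_LOADERS` set, the `URL_KEYS` lookup order, and the `_sketch` suffix are assumptions drawn only from the inputs these tests exercise.

```python
import json

URL_LOADERS = {"url", "crawler", "sitemap", "github"}  # assumed URL-type loaders
URL_KEYS = ("url", "urls", "repo_url", "raw")  # assumed lookup order


def normalize_remote_data_sketch(loader, remote_data):
    """Map a stored remote_data value back to the shape a loader expects."""
    if remote_data is None:
        return None
    if isinstance(remote_data, str):
        try:
            parsed = json.loads(remote_data)
        except ValueError:
            return remote_data  # a plain URL string is already usable
        if not isinstance(parsed, dict):
            return remote_data
        remote_data = parsed  # legacy row: JSON stored as a string
    if loader in URL_LOADERS:
        if isinstance(remote_data, dict):
            for key in URL_KEYS:
                if remote_data.get(key):
                    return remote_data[key]
            return None  # never hand a URL loader a dict
        return remote_data
    if loader == "reddit" and isinstance(remote_data, dict):
        return json.dumps(remote_data)  # this loader parses a JSON string itself
    return remote_data
```

Returning `None` for a URL-type dict with no usable key gives the caller a single sentinel to test before dispatching a sync, rather than letting a malformed dict reach the loader.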


@@ -1,74 +0,0 @@
"""Tests for IngestChunkProgressRepository against ephemeral Postgres."""
from __future__ import annotations

from sqlalchemy import text

from application.storage.db.repositories.ingest_chunk_progress import (
    IngestChunkProgressRepository,
)


def _status(conn, source_id: str) -> str:
    return conn.execute(
        text(
            "SELECT status FROM ingest_chunk_progress "
            "WHERE source_id = CAST(:sid AS uuid)"
        ),
        {"sid": source_id},
    ).scalar()


def _mark_stalled(conn, source_id: str) -> None:
    conn.execute(
        text(
            "UPDATE ingest_chunk_progress SET status = 'stalled' "
            "WHERE source_id = CAST(:sid AS uuid)"
        ),
        {"sid": source_id},
    )


class TestInitProgressStatus:
    def test_new_row_starts_active(self, pg_conn):
        sid = "3c000000-0000-0000-0000-0000000000c1"
        IngestChunkProgressRepository(pg_conn).init_progress(sid, 10, "att-1")
        assert _status(pg_conn, sid) == "active"

    def test_reingest_resets_stalled_to_active(self, pg_conn):
        """A reconciler-escalated 'stalled' row flips back to 'active'
        when the source is reingested under a fresh attempt id.
        """
        sid = "3c000000-0000-0000-0000-0000000000c2"
        repo = IngestChunkProgressRepository(pg_conn)
        repo.init_progress(sid, 10, "att-1")
        _mark_stalled(pg_conn, sid)
        repo.init_progress(sid, 10, "att-2")
        assert _status(pg_conn, sid) == "active"

    def test_same_attempt_retry_also_clears_stalled(self, pg_conn):
        """A same-attempt resume (Celery autoretry) also clears a stale
        'stalled' flag — the task is running again.
        """
        sid = "3c000000-0000-0000-0000-0000000000c3"
        repo = IngestChunkProgressRepository(pg_conn)
        repo.init_progress(sid, 10, "att-1")
        _mark_stalled(pg_conn, sid)
        repo.init_progress(sid, 10, "att-1")
        assert _status(pg_conn, sid) == "active"


class TestDelete:
    def test_delete_removes_row(self, pg_conn):
        sid = "3c000000-0000-0000-0000-0000000000d1"
        repo = IngestChunkProgressRepository(pg_conn)
        repo.init_progress(sid, 10, "att-1")
        assert repo.delete(sid) is True
        assert repo.get_progress(sid) is None

    def test_delete_missing_row_returns_false(self, pg_conn):
        repo = IngestChunkProgressRepository(pg_conn)
        assert repo.delete("3c000000-0000-0000-0000-0000000000df") is False


@@ -342,89 +342,6 @@ class TestMarkToolCallFailed:
assert row[1] == "oops"
def _seed_ingest_progress(
    conn,
    *,
    source_id: str,
    embedded: int,
    total: int,
    age_minutes: int = 31,
    status: str = "active",
) -> None:
    """Seed an ingest_chunk_progress row with a backdated last_updated."""
    conn.execute(
        text(
            """
            INSERT INTO ingest_chunk_progress (
                source_id, total_chunks, embedded_chunks, last_index,
                last_updated, status
            )
            VALUES (
                CAST(:sid AS uuid), :total, :embedded, :embedded - 1,
                clock_timestamp() - make_interval(mins => :age), :status
            )
            """
        ),
        {
            "sid": source_id, "total": total, "embedded": embedded,
            "age": age_minutes, "status": status,
        },
    )


class TestFindAndLockStalledIngests:
    def test_returns_stale_active_partial(self, pg_conn):
        sid = "2b000000-0000-0000-0000-0000000000b1"
        _seed_ingest_progress(pg_conn, source_id=sid, embedded=2, total=10)
        rows = ReconciliationRepository(pg_conn).find_and_lock_stalled_ingests()
        assert any(str(r["source_id"]) == sid for r in rows)

    def test_excludes_already_stalled(self, pg_conn):
        sid = "2b000000-0000-0000-0000-0000000000b2"
        _seed_ingest_progress(
            pg_conn, source_id=sid, embedded=2, total=10, status="stalled",
        )
        rows = ReconciliationRepository(pg_conn).find_and_lock_stalled_ingests()
        assert all(str(r["source_id"]) != sid for r in rows)

    def test_excludes_completed(self, pg_conn):
        sid = "2b000000-0000-0000-0000-0000000000b3"
        _seed_ingest_progress(pg_conn, source_id=sid, embedded=10, total=10)
        rows = ReconciliationRepository(pg_conn).find_and_lock_stalled_ingests()
        assert all(str(r["source_id"]) != sid for r in rows)

    def test_excludes_under_age_threshold(self, pg_conn):
        sid = "2b000000-0000-0000-0000-0000000000b4"
        _seed_ingest_progress(
            pg_conn, source_id=sid, embedded=2, total=10, age_minutes=2,
        )
        rows = ReconciliationRepository(pg_conn).find_and_lock_stalled_ingests()
        assert all(str(r["source_id"]) != sid for r in rows)


class TestMarkIngestStalled:
    def test_flips_status_to_stalled(self, pg_conn):
        sid = "2b000000-0000-0000-0000-0000000000b5"
        _seed_ingest_progress(pg_conn, source_id=sid, embedded=2, total=10)
        repo = ReconciliationRepository(pg_conn)
        assert repo.mark_ingest_stalled(sid) is True
        row = pg_conn.execute(
            text(
                "SELECT status FROM ingest_chunk_progress "
                "WHERE source_id = CAST(:sid AS uuid)"
            ),
            {"sid": sid},
        ).fetchone()
        assert row[0] == "stalled"

    def test_returns_false_for_missing_source(self, pg_conn):
        repo = ReconciliationRepository(pg_conn)
        assert (
            repo.mark_ingest_stalled("2b000000-0000-0000-0000-0000000000bf")
            is False
        )


def _seed_stuck_idempotency(
    conn,
    *,


@@ -148,130 +148,6 @@ class TestSyncWorker:
assert captured[0]["loader"] == "url"
assert captured[0]["doc_id"] == str(src["id"])
    def test_connector_sources_are_skipped(
        self,
        pg_conn,
        patch_worker_db,
        task_self,
        monkeypatch,
    ):
        """connector:* sources have no RemoteCreator loader — sync_worker
        must skip them, not dispatch them into sync()."""
        from application import worker

        SourcesRepository(pg_conn).create(
            "drive-folder",
            user_id="dave",
            type="connector:file",
            retriever="classic",
            sync_frequency="daily",
            remote_data={
                "provider": "google_drive",
                "file_ids": ["f1"],
                "folder_ids": [],
                "recursive": False,
            },
        )

        def _must_not_run(*args, **kwargs):
            raise AssertionError("sync() must not run for connector sources")

        monkeypatch.setattr(worker, "sync", _must_not_run)
        result = worker.sync_worker(task_self, "daily")
        assert result["total_sync_count"] == 1
        assert result["sync_skipped"] == 1
        assert result["sync_success"] == 0
        assert result["sync_failure"] == 0

    def test_dict_remote_data_is_normalized_before_loader(
        self,
        pg_conn,
        patch_worker_db,
        task_self,
        monkeypatch,
    ):
        """Regression: remote_data reads back as a dict; sync_worker must
        hand the loader the URL string, not the raw dict."""
        from application import worker

        SourcesRepository(pg_conn).create(
            "docs-crawl",
            user_id="erin",
            type="crawler",
            retriever="classic",
            sync_frequency="weekly",
            remote_data={"url": "https://example.com", "provider": "crawler"},
        )
        received: list = []
        fake_loader = MagicMock(name="remote_loader")

        def _capture(source_data):
            received.append(source_data)
            return [
                Document(
                    text="page body",
                    extra_info={"file_path": "index.md", "title": "home"},
                    doc_id="d1",
                )
            ]

        fake_loader.load_data.side_effect = _capture
        monkeypatch.setattr(
            worker.RemoteCreator, "create_loader", lambda loader: fake_loader
        )
        monkeypatch.setattr(
            worker,
            "embed_and_store_documents",
            lambda docs, full_path, source_id, task, **kw: None,
        )
        monkeypatch.setattr(
            worker, "upload_index", lambda full_path, file_data: None
        )
        result = worker.sync_worker(task_self, "weekly")
        assert result["total_sync_count"] == 1
        assert result["sync_success"] == 1
        assert result["sync_failure"] == 0
        assert received == ["https://example.com"], (
            "loader must receive the URL string, not the remote_data dict"
        )

    def test_unsyncable_remote_data_is_skipped(
        self,
        pg_conn,
        patch_worker_db,
        task_self,
        monkeypatch,
    ):
        """A URL source whose remote_data dict has no URL key normalizes
        to None — sync_worker must skip it, not dispatch a doomed sync()."""
        from application import worker

        SourcesRepository(pg_conn).create(
            "broken-feed",
            user_id="frank",
            type="url",
            retriever="classic",
            sync_frequency="monthly",
            remote_data={"provider": "url"},
        )

        def _must_not_run(*args, **kwargs):
            raise AssertionError("sync() must not run for unsyncable sources")

        monkeypatch.setattr(worker, "sync", _must_not_run)
        result = worker.sync_worker(task_self, "monthly")
        assert result["total_sync_count"] == 1
        assert result["sync_skipped"] == 1
        assert result["sync_failure"] == 0
        assert result["sync_success"] == 0
@pytest.mark.unit
class TestRemoteWorkerPathTraversal: