Compare commits

..

3 Commits

Author SHA1 Message Date
Pavel
9a5ea8fe00 Harden protection with pinned requests and path-param encoding 2026-05-21 00:31:52 +04:00
Alex
1de82ca040 fix: batch limits and failed task reque limit (#2484) 2026-05-18 22:22:43 +01:00
Alex
8f7742c937 fix: better source upload status and fix reconciliation issue (#2482)
* fix: better source upload status and fix reconciliation issue

* fix: mini issues

* chore: locale coverage
2026-05-18 14:22:03 +01:00
16 changed files with 712 additions and 299 deletions

View File

@@ -2,7 +2,7 @@ import json
import logging
import re
from typing import Any, Dict, Optional
from urllib.parse import urlencode
from urllib.parse import quote, urlencode
import requests
@@ -11,7 +11,7 @@ from application.agents.tools.api_body_serializer import (
RequestBodySerializer,
)
from application.agents.tools.base import Tool
from application.core.url_validation import validate_url, SSRFError
from application.security.safe_url import UnsafeUserUrlError, pinned_request
logger = logging.getLogger(__name__)
@@ -70,18 +70,16 @@ class APITool(Tool):
Returns:
Dict with status_code, data, and message
"""
_VALID_METHODS = {"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}
request_url = url
request_headers = headers.copy() if headers else {}
response = None
# Validate URL to prevent SSRF attacks
try:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed: {e}")
if method.upper() not in _VALID_METHODS:
return {
"status_code": None,
"message": f"URL validation error: {e}",
"message": f"Unsupported HTTP method: {method}",
"data": None,
}
@@ -91,8 +89,9 @@ class APITool(Tool):
for match in re.finditer(r"\{([^}]+)\}", request_url):
param_name = match.group(1)
if param_name in query_params:
safe_value = quote(str(query_params[param_name]), safe="")
request_url = request_url.replace(
f"{{{param_name}}}", str(query_params[param_name])
f"{{{param_name}}}", safe_value
)
path_params_used.add(param_name)
remaining_params = {
@@ -103,19 +102,6 @@ class APITool(Tool):
separator = "&" if "?" in request_url else "?"
request_url = f"{request_url}{separator}{query_string}"
# Re-validate URL after parameter substitution to prevent SSRF via path params
try:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed after parameter substitution: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
# Serialize body based on content type
if body and body != {}:
try:
serialized_body, body_headers = RequestBodySerializer.serialize(
@@ -141,49 +127,13 @@ class APITool(Tool):
f"API Call: {method} {request_url} | Content-Type: {request_headers.get('Content-Type', 'N/A')}"
)
if method.upper() == "GET":
response = requests.get(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "POST":
response = requests.post(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "PUT":
response = requests.put(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "DELETE":
response = requests.delete(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "PATCH":
response = requests.patch(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "HEAD":
response = requests.head(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "OPTIONS":
response = requests.options(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
else:
return {
"status_code": None,
"message": f"Unsupported HTTP method: {method}",
"data": None,
}
response = pinned_request(
method,
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
response.raise_for_status()
data = self._parse_response(response)
@@ -193,6 +143,13 @@ class APITool(Tool):
"data": data,
"message": "API call successful.",
}
except UnsafeUserUrlError as e:
logger.error(f"URL validation failed: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
except requests.exceptions.Timeout:
logger.error(f"Request timeout for {request_url}")
return {

View File

@@ -1,5 +1,5 @@
import requests
from application.agents.tools.base import Tool
from application.security.safe_url import UnsafeUserUrlError, pinned_request
class NtfyTool(Tool):
"""
@@ -71,7 +71,12 @@ class NtfyTool(Tool):
if self.token:
headers["Authorization"] = f"Basic {self.token}"
data = message.encode("utf-8")
response = requests.post(url, headers=headers, data=data, timeout=100)
try:
response = pinned_request(
"POST", url, data=data, headers=headers, timeout=100,
)
except UnsafeUserUrlError as e:
return {"status_code": None, "message": f"URL validation error: {e}"}
return {"status_code": response.status_code, "message": "Message sent"}
def get_actions_metadata(self):

View File

@@ -1,7 +1,6 @@
import requests
from markdownify import markdownify
from application.agents.tools.base import Tool
from application.core.url_validation import validate_url, SSRFError
from application.security.safe_url import UnsafeUserUrlError, pinned_request
class ReadWebpageTool(Tool):
"""
@@ -31,28 +30,24 @@ class ReadWebpageTool(Tool):
if not url:
return "Error: URL parameter is missing."
# Validate URL to prevent SSRF attacks
try:
url = validate_url(url)
except SSRFError as e:
return f"Error: URL validation failed - {e}"
response = pinned_request(
"GET",
url,
headers={'User-Agent': 'DocsGPT-Agent/1.0'},
timeout=10,
)
response.raise_for_status()
try:
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
html_content = response.text
#soup = BeautifulSoup(html_content, 'html.parser')
markdown_content = markdownify(html_content, heading_style="ATX", newline_style="BACKSLASH")
return markdown_content
except requests.exceptions.RequestException as e:
return f"Error fetching URL {url}: {e}"
except UnsafeUserUrlError as e:
return f"Error: URL validation failed - {e}"
except Exception as e:
return f"Error processing URL {url}: {e}"
return f"Error fetching URL {url}: {e}"
def get_actions_metadata(self):
"""

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import functools
import inspect
import logging
import threading
import uuid
@@ -26,13 +27,20 @@ LEASE_HEARTBEAT_INTERVAL = 30
LEASE_RETRY_MAX = 10
def with_idempotency(task_name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def with_idempotency(
task_name: str,
*,
on_poison: Optional[Callable[[str, dict], None]] = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Short-circuit on completed key; gate concurrent runs via a lease.
The guard key is the caller's ``idempotency_key``, or one synthesized
from ``source_id`` so a keyless dispatch is still poison-guarded.
Entry short-circuits:
- completed row → return cached result
- live lease held → retry(countdown=LEASE_TTL_SECONDS)
- attempt_count > MAX_TASK_ATTEMPTS → poison-loop alert
- attempt_count > MAX_TASK_ATTEMPTS → poison alert; ``on_poison`` fires
Success writes ``completed``; exceptions leave ``pending`` for
autoretry until the poison-loop guard trips.
"""
@@ -40,7 +48,14 @@ def with_idempotency(task_name: str) -> Callable[[Callable[..., Any]], Callable[
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
@functools.wraps(fn)
def wrapper(self, *args: Any, idempotency_key: Any = None, **kwargs: Any) -> Any:
key = idempotency_key if isinstance(idempotency_key, str) and idempotency_key else None
explicit_key = (
idempotency_key
if isinstance(idempotency_key, str) and idempotency_key
else None
)
# A keyless dispatch still gets the guard via a synthesized key;
# None means no anchor exists — run unguarded, as before.
key = explicit_key or _synthesize_guard_key(task_name, kwargs)
if key is None:
return fn(self, *args, idempotency_key=idempotency_key, **kwargs)
@@ -88,6 +103,9 @@ def with_idempotency(task_name: str) -> Callable[[Callable[..., Any]], Callable[
"attempts": attempt,
}
_finalize(key, poisoned, status="failed")
_run_poison_hook(
on_poison, task_name, fn, self, args, kwargs, idempotency_key,
)
return poisoned
heartbeat_thread, heartbeat_stop = _start_lease_heartbeat(
@@ -109,6 +127,45 @@ def with_idempotency(task_name: str) -> Callable[[Callable[..., Any]], Callable[
return decorator
def _synthesize_guard_key(task_name: str, kwargs: dict) -> Optional[str]:
"""Derive a deterministic guard key from ``source_id`` for a keyless dispatch.
``source_id`` is stable across broker redeliveries and unique per
upload, so the poison-loop counter survives an OOM SIGKILL. Returns
``None`` when absent — the dispatch then runs unguarded as before.
"""
source_id = kwargs.get("source_id")
if source_id:
return f"auto:{task_name}:{source_id}"
return None
def _run_poison_hook(
on_poison: Optional[Callable[[str, dict], None]],
task_name: str,
fn: Callable[..., Any],
task_self: Any,
args: tuple,
kwargs: dict,
idempotency_key: Any,
) -> None:
"""Invoke a task's poison-path hook with named call args; swallow failures.
A hook failure must never change the poison-guard outcome.
"""
if on_poison is None:
return
try:
bound = inspect.signature(fn).bind_partial(
task_self, *args, idempotency_key=idempotency_key, **kwargs,
)
on_poison(task_name, dict(bound.arguments))
except Exception:
logger.exception(
"idempotency: poison hook failed for task=%s", task_name,
)
def _lookup_completed(key: str) -> Any:
"""Return cached ``result_json`` if a completed row exists for ``key``, else None."""
with db_readonly() as conn:

View File

@@ -27,8 +27,42 @@ DURABLE_TASK = dict(
)
# operation tag for the poison-path source.ingest.failed event, per task.
_INGEST_POISON_OPERATION = {
"ingest": "upload",
"ingest_remote": "upload",
"ingest_connector_task": "upload",
"reingest_source_task": "reingest",
}
def _emit_ingest_poison_event(task_name, bound):
"""Publish a terminal ``source.ingest.failed`` when the poison-guard trips.
The guard returns before the worker runs, so the worker's own failed
event never fires — without this the upload toast spins on "training".
"""
user = bound.get("user")
source_id = bound.get("source_id")
if not user or not source_id:
return
from application.events.publisher import publish_user_event
publish_user_event(
user,
"source.ingest.failed",
{
"source_id": str(source_id),
"filename": bound.get("filename") or "",
"operation": _INGEST_POISON_OPERATION.get(task_name, "upload"),
"error": "Ingestion stopped after repeated failures.",
},
scope={"kind": "source", "id": str(source_id)},
)
@celery.task(**DURABLE_TASK)
@with_idempotency(task_name="ingest")
@with_idempotency(task_name="ingest", on_poison=_emit_ingest_poison_event)
def ingest(
self,
directory,
@@ -57,7 +91,7 @@ def ingest(
@celery.task(**DURABLE_TASK)
@with_idempotency(task_name="ingest_remote")
@with_idempotency(task_name="ingest_remote", on_poison=_emit_ingest_poison_event)
def ingest_remote(
self, source_data, job_name, user, loader,
idempotency_key=None, source_id=None,
@@ -71,7 +105,9 @@ def ingest_remote(
@celery.task(**DURABLE_TASK)
@with_idempotency(task_name="reingest_source_task")
@with_idempotency(
task_name="reingest_source_task", on_poison=_emit_ingest_poison_event,
)
def reingest_source_task(self, source_id, user, idempotency_key=None):
from application.worker import reingest_source_worker
@@ -128,7 +164,9 @@ def process_agent_webhook(self, agent_id, payload, idempotency_key=None):
@celery.task(**DURABLE_TASK)
@with_idempotency(task_name="ingest_connector_task")
@with_idempotency(
task_name="ingest_connector_task", on_poison=_emit_ingest_poison_event,
)
def ingest_connector_task(
self,
job_name,

View File

@@ -66,6 +66,9 @@ class Settings(BaseSettings):
PARSE_IMAGE_REMOTE: bool = False
DOCLING_OCR_ENABLED: bool = False # Enable OCR for docling parsers (PDF, images)
DOCLING_OCR_ATTACHMENTS_ENABLED: bool = False # Enable OCR for docling when parsing attachments
# Pages docling's threaded pipeline buffers in flight; the library
# default (100) drives worker RSS to ~3 GB on a mid-size PDF.
DOCLING_PIPELINE_QUEUE_MAX_SIZE: int = 2
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
RETRIEVERS_ENABLED: list = ["classic_rag"]
AGENT_NAME: str = "classic"

View File

@@ -16,6 +16,29 @@ from application.parser.file.base_parser import BaseParser
logger = logging.getLogger(__name__)
# Per-stage batch size for docling's threaded pipeline; 1 holds the
# concurrent working set to a single page (see _apply_pipeline_caps).
_PIPELINE_BATCH_SIZE = 1
def _apply_pipeline_caps(pipeline_options) -> None:
"""Cap docling's threaded-pipeline queue depth and batch sizes in place.
hasattr-guarded so docling builds without these knobs are unaffected.
"""
from application.core.settings import settings
caps = {
"queue_max_size": max(1, settings.DOCLING_PIPELINE_QUEUE_MAX_SIZE),
"layout_batch_size": _PIPELINE_BATCH_SIZE,
"table_batch_size": _PIPELINE_BATCH_SIZE,
"ocr_batch_size": _PIPELINE_BATCH_SIZE,
}
for name, value in caps.items():
if hasattr(pipeline_options, name):
setattr(pipeline_options, name, value)
class DoclingParser(BaseParser):
"""Parser using docling for advanced document processing.
@@ -86,6 +109,7 @@ class DoclingParser(BaseParser):
do_ocr=self.ocr_enabled,
do_table_structure=self.table_structure,
)
_apply_pipeline_caps(pipeline_options)
if self.ocr_enabled:
ocr_options = self._get_ocr_options()

View File

@@ -1,11 +1,11 @@
import logging
import os
import requests
from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
from application.parser.schema.base import Document
from application.core.url_validation import validate_url, SSRFError
from application.security.safe_url import UnsafeUserUrlError, pinned_request
from langchain_community.document_loaders import WebBaseLoader
class CrawlerLoader(BaseRemote):
@@ -35,14 +35,7 @@ class CrawlerLoader(BaseRemote):
visited_urls.add(current_url)
try:
# Validate each URL before making requests
try:
validate_url(current_url)
except SSRFError as e:
logging.warning(f"Skipping URL due to validation failure: {current_url} - {e}")
continue
response = requests.get(current_url, timeout=30)
response = pinned_request("GET", current_url, timeout=30)
response.raise_for_status()
loader = self.loader([current_url])
docs = loader.load()

View File

@@ -1,8 +1,8 @@
import requests
from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
from application.core.url_validation import validate_url, SSRFError
from application.security.safe_url import UnsafeUserUrlError, pinned_request
import re
from markdownify import markdownify
from application.parser.schema.base import Document
@@ -20,7 +20,6 @@ class CrawlerLoader(BaseRemote):
"""
self.limit = limit
self.allow_subdomains = allow_subdomains
self.session = requests.Session()
def load_data(self, inputs):
url = inputs
@@ -91,15 +90,13 @@ class CrawlerLoader(BaseRemote):
def _fetch_page(self, url):
try:
# Validate URL before fetching to prevent SSRF
validate_url(url)
response = self.session.get(url, timeout=10)
response = pinned_request("GET", url, timeout=10)
response.raise_for_status()
return response.text
except SSRFError as e:
except UnsafeUserUrlError as e:
print(f"URL validation failed for {url}: {e}")
return None
except requests.exceptions.RequestException as e:
except Exception as e:
print(f"Error fetching URL {url}: {e}")
return None

View File

@@ -1,9 +1,9 @@
import logging
import requests
import re # Import regular expression library
import re
import defusedxml.ElementTree as ET
from application.parser.remote.base import BaseRemote
from application.core.url_validation import validate_url, SSRFError
from application.security.safe_url import UnsafeUserUrlError, pinned_request
class SitemapLoader(BaseRemote):
def __init__(self, limit=20):
@@ -53,14 +53,12 @@ class SitemapLoader(BaseRemote):
def _extract_urls(self, sitemap_url):
try:
# Validate URL before fetching to prevent SSRF
validate_url(sitemap_url)
response = requests.get(sitemap_url, timeout=30)
response.raise_for_status() # Raise an exception for HTTP errors
except SSRFError as e:
response = pinned_request("GET", sitemap_url, timeout=30)
response.raise_for_status()
except UnsafeUserUrlError as e:
print(f"URL validation failed for sitemap: {sitemap_url}. Error: {e}")
return []
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e:
except Exception as e:
print(f"Failed to fetch sitemap: {sitemap_url}. Error: {e}")
return []
@@ -97,13 +95,6 @@ class SitemapLoader(BaseRemote):
nested_sitemap_url = sitemap.text
if not nested_sitemap_url:
continue
try:
nested_sitemap_url = validate_url(nested_sitemap_url)
except SSRFError as e:
logging.error(
f"URL validation failed for nested sitemap {nested_sitemap_url}: {e}"
)
continue
urls.extend(self._extract_urls(nested_sitemap_url))
return urls

View File

@@ -291,6 +291,55 @@ def _ip_to_url_host(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> str:
return str(ip)
def pinned_request(
method: str,
url: str,
*,
data: Any = None,
json: Any = None,
headers: dict[str, str] | None = None,
timeout: float = 90.0,
allow_redirects: bool = False,
) -> requests.Response:
"""Send an HTTP request with the connection pinned to a validated IP,
closing the DNS-rebinding TOCTOU window left by the naive
validate-then-``requests`` pattern.
Raises:
UnsafeUserUrlError: If the URL fails the SSRF guard.
requests.RequestException: For network-level failures.
"""
host, ip, parts = _validate_and_pick_ip(url)
netloc = _ip_to_url_host(ip)
if parts.port is not None:
netloc = f"{netloc}:{parts.port}"
pinned_url = urlunsplit(
(parts.scheme, netloc, parts.path, parts.query, parts.fragment)
)
request_headers = dict(headers or {})
host_header = host if parts.port is None else f"{host}:{parts.port}"
request_headers["Host"] = host_header
session = requests.Session()
if parts.scheme == "https":
session.mount("https://", _PinnedHostAdapter(host))
try:
return session.request(
method=method.upper(),
url=pinned_url,
data=data,
json=json,
headers=request_headers,
timeout=timeout,
allow_redirects=allow_redirects,
)
finally:
session.close()
def pinned_post(
url: str,
*,
@@ -328,33 +377,15 @@ def pinned_post(
requests.RequestException: For network-level failures.
"""
host, ip, parts = _validate_and_pick_ip(url)
netloc = _ip_to_url_host(ip)
if parts.port is not None:
netloc = f"{netloc}:{parts.port}"
pinned_url = urlunsplit(
(parts.scheme, netloc, parts.path, parts.query, parts.fragment)
return pinned_request(
"POST",
url,
json=json,
headers=headers,
timeout=timeout,
allow_redirects=allow_redirects,
)
request_headers = dict(headers or {})
host_header = host if parts.port is None else f"{host}:{parts.port}"
request_headers["Host"] = host_header
session = requests.Session()
if parts.scheme == "https":
session.mount("https://", _PinnedHostAdapter(host))
try:
return session.post(
pinned_url,
json=json,
headers=request_headers,
timeout=timeout,
allow_redirects=allow_redirects,
)
finally:
session.close()
class _PinnedHTTPSTransport(httpx.HTTPTransport):
"""``httpx`` transport pinned to a single validated IP literal.

View File

@@ -45,15 +45,14 @@ class TestAPIToolInit:
@pytest.mark.unit
class TestMakeApiCall:
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_successful_get(self, mock_get, mock_validate, tool):
@patch("application.agents.tools.api_tool.pinned_request")
def test_successful_get(self, mock_pinned, tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"result": "ok"}
mock_resp.content = b'{"result":"ok"}'
mock_get.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("any_action")
@@ -61,54 +60,50 @@ class TestMakeApiCall:
assert result["data"] == {"result": "ok"}
assert result["message"] == "API call successful."
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_successful_post(self, mock_post, mock_validate, post_tool):
@patch("application.agents.tools.api_tool.pinned_request")
def test_successful_post(self, mock_pinned, post_tool):
mock_resp = MagicMock()
mock_resp.status_code = 201
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"id": 1}
mock_resp.content = b'{"id":1}'
mock_post.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = post_tool.execute_action("create", name="test")
assert result["status_code"] == 201
@patch("application.agents.tools.api_tool.validate_url")
def test_ssrf_blocked(self, mock_validate, tool):
from application.core.url_validation import SSRFError
@patch("application.agents.tools.api_tool.pinned_request")
def test_ssrf_blocked(self, mock_pinned, tool):
from application.security.safe_url import UnsafeUserUrlError
mock_validate.side_effect = SSRFError("blocked")
mock_pinned.side_effect = UnsafeUserUrlError("blocked")
result = tool.execute_action("any")
assert result["status_code"] is None
assert "URL validation error" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_timeout_error(self, mock_get, mock_validate, tool):
mock_get.side_effect = requests.exceptions.Timeout()
@patch("application.agents.tools.api_tool.pinned_request")
def test_timeout_error(self, mock_pinned, tool):
mock_pinned.side_effect = requests.exceptions.Timeout()
result = tool.execute_action("any")
assert result["status_code"] is None
assert "timeout" in result["message"].lower()
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_connection_error(self, mock_get, mock_validate, tool):
mock_get.side_effect = requests.exceptions.ConnectionError("refused")
@patch("application.agents.tools.api_tool.pinned_request")
def test_connection_error(self, mock_pinned, tool):
mock_pinned.side_effect = requests.exceptions.ConnectionError("refused")
result = tool.execute_action("any")
assert result["status_code"] is None
assert "Connection error" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_http_error(self, mock_get, mock_validate, tool):
@patch("application.agents.tools.api_tool.pinned_request")
def test_http_error(self, mock_pinned, tool):
mock_resp = MagicMock()
mock_resp.status_code = 404
mock_resp.text = "Not Found"
@@ -116,15 +111,14 @@ class TestMakeApiCall:
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
response=mock_resp
)
mock_get.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("any")
assert result["status_code"] == 404
assert "HTTP Error" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
def test_unsupported_method(self, mock_validate):
def test_unsupported_method(self):
tool = APITool(
config={"url": "https://example.com", "method": "CUSTOM"}
)
@@ -132,69 +126,64 @@ class TestMakeApiCall:
assert result["status_code"] is None
assert "Unsupported" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.put")
def test_put_method(self, mock_put, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_put_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com/item/1", "method": "PUT"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_put.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("update", name="new")
assert result["status_code"] == 200
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.delete")
def test_delete_method(self, mock_delete, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_delete_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com/item/1", "method": "DELETE"})
mock_resp = MagicMock()
mock_resp.status_code = 204
mock_resp.headers = {"Content-Type": "text/plain"}
mock_resp.content = b''
mock_delete.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("delete")
assert result["status_code"] == 204
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.patch")
def test_patch_method(self, mock_patch, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_patch_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com/item/1", "method": "PATCH"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"patched": True}
mock_resp.content = b'{"patched":true}'
mock_patch.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("patch", field="val")
assert result["status_code"] == 200
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.head")
def test_head_method(self, mock_head, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_head_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com", "method": "HEAD"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "text/html"}
mock_resp.content = b''
mock_head.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("check")
assert result["status_code"] == 200
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.options")
def test_options_method(self, mock_options, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_options_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com", "method": "OPTIONS"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "text/plain"}
mock_resp.content = b''
mock_options.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("options")
assert result["status_code"] == 200
@@ -202,9 +191,8 @@ class TestMakeApiCall:
@pytest.mark.unit
class TestPathParamSubstitution:
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_path_params_substituted(self, mock_get, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_path_params_substituted(self, mock_pinned):
tool = APITool(
config={
"url": "https://api.example.com/users/{user_id}/posts/{post_id}",
@@ -217,11 +205,11 @@ class TestPathParamSubstitution:
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = []
mock_resp.content = b'[]'
mock_get.return_value = mock_resp
mock_pinned.return_value = mock_resp
tool.execute_action("get")
called_url = mock_get.call_args[0][0]
called_url = mock_pinned.call_args[0][1]
assert "/users/42/posts/7" in called_url
assert "{user_id}" not in called_url

View File

@@ -81,104 +81,103 @@ class TestAPIToolInit:
@pytest.mark.unit
class TestMakeApiCall:
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_successful_get(self, mock_get, mock_validate, get_tool):
@patch("application.agents.tools.api_tool.pinned_request")
def test_successful_get(self, mock_pinned, get_tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"result": "ok"}
mock_resp.content = b'{"result":"ok"}'
mock_get.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = get_tool.execute_action("any_action")
assert result["status_code"] == 200
assert result["data"] == {"result": "ok"}
assert result["message"] == "API call successful."
assert mock_pinned.call_args[0][0] == "GET"
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_successful_post(self, mock_post, mock_validate, post_tool):
@patch("application.agents.tools.api_tool.pinned_request")
def test_successful_post(self, mock_pinned, post_tool):
mock_resp = MagicMock()
mock_resp.status_code = 201
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"id": 1}
mock_resp.content = b'{"id":1}'
mock_post.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = post_tool.execute_action("create", name="test")
assert result["status_code"] == 201
assert mock_pinned.call_args[0][0] == "POST"
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.put")
def test_put_method(self, mock_put, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_put_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com/item/1", "method": "PUT"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_put.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("update", name="new")
assert result["status_code"] == 200
assert mock_pinned.call_args[0][0] == "PUT"
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.delete")
def test_delete_method(self, mock_delete, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_delete_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com/item/1", "method": "DELETE"})
mock_resp = MagicMock()
mock_resp.status_code = 204
mock_resp.headers = {"Content-Type": "text/plain"}
mock_resp.content = b''
mock_delete.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("delete")
assert result["status_code"] == 204
assert mock_pinned.call_args[0][0] == "DELETE"
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.patch")
def test_patch_method(self, mock_patch, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_patch_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com/item/1", "method": "PATCH"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"patched": True}
mock_resp.content = b'{"patched":true}'
mock_patch.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("patch", field="val")
assert result["status_code"] == 200
assert mock_pinned.call_args[0][0] == "PATCH"
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.head")
def test_head_method(self, mock_head, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_head_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com", "method": "HEAD"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "text/html"}
mock_resp.content = b''
mock_head.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("check")
assert result["status_code"] == 200
assert mock_pinned.call_args[0][0] == "HEAD"
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.options")
def test_options_method(self, mock_options, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_options_method(self, mock_pinned):
tool = APITool(config={"url": "https://example.com", "method": "OPTIONS"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "text/plain"}
mock_resp.content = b''
mock_options.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("options")
assert result["status_code"] == 200
assert mock_pinned.call_args[0][0] == "OPTIONS"
@patch("application.agents.tools.api_tool.validate_url")
def test_unsupported_method(self, mock_validate):
def test_unsupported_method(self):
tool = APITool(config={"url": "https://example.com", "method": "CUSTOM"})
result = tool.execute_action("any")
assert result["status_code"] is None
@@ -193,19 +192,18 @@ class TestMakeApiCall:
@pytest.mark.unit
class TestSSRFValidation:
@patch("application.agents.tools.api_tool.validate_url")
def test_ssrf_blocked_initial_url(self, mock_validate, get_tool):
from application.core.url_validation import SSRFError
@patch("application.agents.tools.api_tool.pinned_request")
def test_ssrf_blocked(self, mock_pinned, get_tool):
from application.security.safe_url import UnsafeUserUrlError
mock_validate.side_effect = SSRFError("blocked")
mock_pinned.side_effect = UnsafeUserUrlError("blocked")
result = get_tool.execute_action("any")
assert result["status_code"] is None
assert "URL validation error" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_ssrf_blocked_after_param_substitution(self, mock_get, mock_validate):
from application.core.url_validation import SSRFError
@patch("application.agents.tools.api_tool.pinned_request")
def test_ssrf_blocked_with_path_params(self, mock_pinned):
from application.security.safe_url import UnsafeUserUrlError
tool = APITool(config={
"url": "https://api.example.com/{host}/data",
@@ -213,14 +211,7 @@ class TestSSRFValidation:
"query_params": {"host": "169.254.169.254"},
})
call_count = [0]
def side_effect(url):
call_count[0] += 1
if call_count[0] == 2:
raise SSRFError("blocked after substitution")
mock_validate.side_effect = side_effect
mock_pinned.side_effect = UnsafeUserUrlError("blocked")
result = tool.execute_action("any")
assert result["status_code"] is None
assert "URL validation error" in result["message"]
@@ -234,40 +225,36 @@ class TestSSRFValidation:
@pytest.mark.unit
class TestErrorHandling:
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_timeout_error(self, mock_get, mock_validate, get_tool):
mock_get.side_effect = requests.exceptions.Timeout()
@patch("application.agents.tools.api_tool.pinned_request")
def test_timeout_error(self, mock_pinned, get_tool):
mock_pinned.side_effect = requests.exceptions.Timeout()
result = get_tool.execute_action("any")
assert result["status_code"] is None
assert "timeout" in result["message"].lower()
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_connection_error(self, mock_get, mock_validate, get_tool):
mock_get.side_effect = requests.exceptions.ConnectionError("refused")
@patch("application.agents.tools.api_tool.pinned_request")
def test_connection_error(self, mock_pinned, get_tool):
mock_pinned.side_effect = requests.exceptions.ConnectionError("refused")
result = get_tool.execute_action("any")
assert result["status_code"] is None
assert "Connection error" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_http_error_with_json(self, mock_get, mock_validate, get_tool):
@patch("application.agents.tools.api_tool.pinned_request")
def test_http_error_with_json(self, mock_pinned, get_tool):
mock_resp = MagicMock()
mock_resp.status_code = 422
mock_resp.json.return_value = {"error": "invalid_field"}
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
response=mock_resp
)
mock_get.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = get_tool.execute_action("any")
assert result["status_code"] == 422
assert result["data"] == {"error": "invalid_field"}
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_http_error_non_json_body(self, mock_get, mock_validate, get_tool):
@patch("application.agents.tools.api_tool.pinned_request")
def test_http_error_non_json_body(self, mock_pinned, get_tool):
mock_resp = MagicMock()
mock_resp.status_code = 404
mock_resp.text = "Not Found"
@@ -275,29 +262,26 @@ class TestErrorHandling:
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
response=mock_resp
)
mock_get.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = get_tool.execute_action("any")
assert result["status_code"] == 404
assert result["data"] == "Not Found"
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_request_exception(self, mock_get, mock_validate, get_tool):
mock_get.side_effect = requests.exceptions.RequestException("something")
@patch("application.agents.tools.api_tool.pinned_request")
def test_request_exception(self, mock_pinned, get_tool):
mock_pinned.side_effect = requests.exceptions.RequestException("something")
result = get_tool.execute_action("any")
assert "API call failed" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_unexpected_exception(self, mock_get, mock_validate, get_tool):
mock_get.side_effect = RuntimeError("unexpected")
@patch("application.agents.tools.api_tool.pinned_request")
def test_unexpected_exception(self, mock_pinned, get_tool):
mock_pinned.side_effect = RuntimeError("unexpected")
result = get_tool.execute_action("any")
assert "Unexpected error" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_body_serialization_error(self, mock_post, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_body_serialization_error(self, mock_pinned):
tool = APITool(config={
"url": "https://example.com",
"method": "POST",
@@ -320,9 +304,8 @@ class TestErrorHandling:
@pytest.mark.unit
class TestPathParamSubstitution:
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_path_params_substituted(self, mock_get, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_path_params_substituted(self, mock_pinned):
tool = APITool(config={
"url": "https://api.example.com/users/{user_id}/posts/{post_id}",
"method": "GET",
@@ -333,17 +316,16 @@ class TestPathParamSubstitution:
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = []
mock_resp.content = b'[]'
mock_get.return_value = mock_resp
mock_pinned.return_value = mock_resp
tool.execute_action("get")
called_url = mock_get.call_args[0][0]
called_url = mock_pinned.call_args[0][1]
assert "/users/42/posts/7" in called_url
assert "{user_id}" not in called_url
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_remaining_query_params_appended(self, mock_get, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_remaining_query_params_appended(self, mock_pinned):
tool = APITool(config={
"url": "https://api.example.com/items",
"method": "GET",
@@ -354,19 +336,16 @@ class TestPathParamSubstitution:
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = []
mock_resp.content = b'[]'
mock_get.return_value = mock_resp
mock_pinned.return_value = mock_resp
tool.execute_action("get")
called_url = mock_get.call_args[0][0]
called_url = mock_pinned.call_args[0][1]
assert "page=2" in called_url
assert "limit=10" in called_url
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_query_params_append_with_existing_query_string(
self, mock_get, mock_validate
):
@patch("application.agents.tools.api_tool.pinned_request")
def test_query_params_append_with_existing_query_string(self, mock_pinned):
tool = APITool(config={
"url": "https://api.example.com/items?existing=true",
"method": "GET",
@@ -377,27 +356,65 @@ class TestPathParamSubstitution:
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = []
mock_resp.content = b'[]'
mock_get.return_value = mock_resp
mock_pinned.return_value = mock_resp
tool.execute_action("get")
called_url = mock_get.call_args[0][0]
called_url = mock_pinned.call_args[0][1]
assert "&page=1" in called_url
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_empty_body_no_serialization(self, mock_post, mock_validate):
@patch("application.agents.tools.api_tool.pinned_request")
def test_empty_body_no_serialization(self, mock_pinned):
tool = APITool(config={"url": "https://example.com", "method": "POST"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_post.return_value = mock_resp
mock_pinned.return_value = mock_resp
result = tool.execute_action("create")
assert result["status_code"] == 200
@patch("application.agents.tools.api_tool.pinned_request")
def test_path_params_are_url_encoded(self, mock_pinned):
tool = APITool(config={
"url": "https://api.example.com/users/{user_id}/profile",
"method": "GET",
"query_params": {"user_id": "../../admin"},
})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_pinned.return_value = mock_resp
tool.execute_action("get")
called_url = mock_pinned.call_args[0][1]
assert "../../admin" not in called_url
assert "%2F" in called_url or "%2f" in called_url
@patch("application.agents.tools.api_tool.pinned_request")
def test_path_params_query_injection_encoded(self, mock_pinned):
tool = APITool(config={
"url": "https://api.example.com/items/{item_id}",
"method": "GET",
"query_params": {"item_id": "x?admin=true"},
})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_pinned.return_value = mock_resp
tool.execute_action("get")
called_url = mock_pinned.call_args[0][1]
assert "x?admin=true" not in called_url
# =====================================================================
# Parse Response
@@ -494,11 +511,8 @@ class TestAPIToolMetadata:
def test_config_requirements_empty(self, get_tool):
assert get_tool.get_config_requirements() == {}
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_content_type_set_for_post_with_no_headers(
self, mock_post, mock_validate
):
@patch("application.agents.tools.api_tool.pinned_request")
def test_content_type_set_for_post_with_no_headers(self, mock_pinned):
tool = APITool(config={
"url": "https://example.com",
"method": "POST",
@@ -509,8 +523,8 @@ class TestAPIToolMetadata:
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_post.return_value = mock_resp
mock_pinned.return_value = mock_resp
tool.execute_action("create")
call_headers = mock_post.call_args[1]["headers"]
call_headers = mock_pinned.call_args.kwargs["headers"]
assert "Content-Type" in call_headers

View File

@@ -417,3 +417,181 @@ class TestSuccessfulRunClearsLease:
assert row[0] == "completed"
assert row[1] is None
assert row[2] is None
@pytest.mark.unit
class TestSynthesizedKeyGuardsKeylessDispatch:
"""A keyless dispatch carrying ``source_id`` is still poison-guarded:
the wrapper synthesizes a deterministic key from ``source_id``.
"""
def test_keyless_with_source_id_records_dedup_row(self, pg_conn):
from application.api.user.idempotency import with_idempotency
@with_idempotency(task_name="ingest")
def task(self, idempotency_key=None, source_id=None):
return {"ran": True}
with _patch_decorator_db(pg_conn):
result = task(_fake_celery_self(), source_id="src-abc")
assert result == {"ran": True}
row = _row_for(pg_conn, "auto:ingest:src-abc")
assert row is not None
assert row[0] == "ingest"
assert row[2] == "completed"
def test_synthesized_key_stable_across_redeliveries(self, pg_conn):
"""Same ``source_id`` → same key → a redelivery short-circuits to
the cached result instead of re-running the body.
"""
from application.api.user.idempotency import with_idempotency
runs = {"count": 0}
@with_idempotency(task_name="ingest")
def task(self, idempotency_key=None, source_id=None):
runs["count"] += 1
return {"n": runs["count"]}
with _patch_decorator_db(pg_conn):
first = task(_fake_celery_self(), source_id="src-1")
second = task(_fake_celery_self(), source_id="src-1")
assert first == second == {"n": 1}
assert runs["count"] == 1
def test_poison_guard_trips_for_keyless_dispatch(self, pg_conn):
"""The core fix: a keyless OOM-looping dispatch is bounded — the
guard trips after MAX_TASK_ATTEMPTS with no explicit key.
"""
from application.api.user.idempotency import (
MAX_TASK_ATTEMPTS, with_idempotency,
)
runs = {"count": 0}
@with_idempotency(task_name="ingest")
def task(self, idempotency_key=None, source_id=None):
runs["count"] += 1
raise RuntimeError("OOM-style failure")
with _patch_decorator_db(pg_conn):
for _ in range(MAX_TASK_ATTEMPTS):
with pytest.raises(RuntimeError):
task(_fake_celery_self(), source_id="src-poison")
result = task(_fake_celery_self(), source_id="src-poison")
assert runs["count"] == MAX_TASK_ATTEMPTS
assert result["success"] is False
assert "poison-loop" in result["error"]
assert _row_for(pg_conn, "auto:ingest:src-poison")[2] == "failed"
def test_no_source_id_no_key_runs_unguarded(self, pg_conn):
"""No explicit key and no ``source_id`` anchor → pass through with
no DB writes, exactly as before.
"""
from application.api.user.idempotency import with_idempotency
@with_idempotency(task_name="store_attachment")
def task(self, idempotency_key=None):
return {"ran": True}
with patch(
"application.api.user.idempotency.db_session"
) as mock_session, patch(
"application.api.user.idempotency.db_readonly"
) as mock_readonly:
result = task(_fake_celery_self())
assert result == {"ran": True}
assert mock_session.call_count == 0
assert mock_readonly.call_count == 0
def test_explicit_key_takes_precedence_over_source_id(self, pg_conn):
"""An explicit key wins; the synthesized ``auto:`` key is unused."""
from application.api.user.idempotency import with_idempotency
@with_idempotency(task_name="ingest")
def task(self, idempotency_key=None, source_id=None):
return {"ran": True}
with _patch_decorator_db(pg_conn):
task(
_fake_celery_self(),
idempotency_key="explicit-k",
source_id="src-x",
)
assert _row_for(pg_conn, "explicit-k") is not None
assert _row_for(pg_conn, "auto:ingest:src-x") is None
@pytest.mark.unit
class TestPoisonHook:
"""``on_poison`` fires on the poison-guard branch with the task's
bound arguments, and never on the success path.
"""
def test_hook_invoked_with_bound_args_on_poison(self, pg_conn):
from application.api.user.idempotency import (
MAX_TASK_ATTEMPTS, with_idempotency,
)
captured = []
def _hook(task_name, bound):
captured.append((task_name, bound))
@with_idempotency(task_name="ingest", on_poison=_hook)
def task(self, idempotency_key=None, source_id=None):
raise RuntimeError("never converges")
with _patch_decorator_db(pg_conn):
for _ in range(MAX_TASK_ATTEMPTS):
with pytest.raises(RuntimeError):
task(_fake_celery_self(), source_id="src-h")
task(_fake_celery_self(), source_id="src-h")
assert len(captured) == 1
task_name, bound = captured[0]
assert task_name == "ingest"
assert bound["source_id"] == "src-h"
def test_hook_not_invoked_on_success(self, pg_conn):
from application.api.user.idempotency import with_idempotency
calls = []
@with_idempotency(
task_name="ingest", on_poison=lambda *a: calls.append(a)
)
def task(self, idempotency_key=None, source_id=None):
return {"ok": True}
with _patch_decorator_db(pg_conn):
task(_fake_celery_self(), source_id="src-ok")
assert calls == []
def test_hook_failure_does_not_break_poison_return(self, pg_conn):
"""A throwing hook must not change the poison-guard outcome."""
from application.api.user.idempotency import (
MAX_TASK_ATTEMPTS, with_idempotency,
)
def _bad_hook(task_name, bound):
raise ValueError("hook blew up")
@with_idempotency(task_name="ingest", on_poison=_bad_hook)
def task(self, idempotency_key=None, source_id=None):
raise RuntimeError("never converges")
with _patch_decorator_db(pg_conn):
for _ in range(MAX_TASK_ATTEMPTS):
with pytest.raises(RuntimeError):
task(_fake_celery_self(), source_id="src-bad")
result = task(_fake_celery_self(), source_id="src-bad")
assert result["success"] is False
assert "poison-loop" in result["error"]

View File

@@ -546,3 +546,63 @@ class TestIngestIdempotency:
assert first == second
assert first == {"status": "ok", "directory": "dir"}
assert len(worker_calls) == 1
class TestIngestPoisonEvent:
"""The poison hook publishes a terminal source.ingest.failed so the
upload toast resolves instead of hanging on "training".
"""
@pytest.mark.unit
def test_publishes_failed_event(self):
from application.api.user.tasks import _emit_ingest_poison_event
published = []
def _fake_publish(user, event_type, payload, *, scope=None):
published.append((user, event_type, payload, scope))
with patch(
"application.events.publisher.publish_user_event",
side_effect=_fake_publish,
):
_emit_ingest_poison_event(
"ingest",
{"user": "u1", "source_id": "src-9", "filename": "doc.pdf"},
)
assert len(published) == 1
user, event_type, payload, scope = published[0]
assert user == "u1"
assert event_type == "source.ingest.failed"
assert payload["source_id"] == "src-9"
assert payload["filename"] == "doc.pdf"
assert payload["operation"] == "upload"
assert scope == {"kind": "source", "id": "src-9"}
@pytest.mark.unit
def test_skips_when_source_id_missing(self):
from application.api.user.tasks import _emit_ingest_poison_event
with patch(
"application.events.publisher.publish_user_event",
) as mock_publish:
_emit_ingest_poison_event("ingest", {"user": "u1"})
mock_publish.assert_not_called()
@pytest.mark.unit
def test_reingest_uses_reingest_operation(self):
from application.api.user.tasks import _emit_ingest_poison_event
published = []
with patch(
"application.events.publisher.publish_user_event",
side_effect=lambda *a, **k: published.append((a, k)),
):
_emit_ingest_poison_event(
"reingest_source_task",
{"user": "u1", "source_id": "src-r"},
)
assert published[0][0][2]["operation"] == "reingest"

View File

@@ -421,3 +421,85 @@ class TestDoclingParserGaps:
parser = DoclingCSVParser()
assert parser.export_format == "markdown"
assert parser.ocr_enabled is True
# =====================================================================
# Pipeline memory caps
# =====================================================================
@pytest.mark.unit
class TestApplyPipelineCaps:
"""_apply_pipeline_caps bounds docling's threaded-pipeline buffering."""
def test_caps_threaded_pipeline_knobs(self, monkeypatch):
from application.core.settings import settings
from application.parser.file.docling_parser import _apply_pipeline_caps
monkeypatch.setattr(
settings, "DOCLING_PIPELINE_QUEUE_MAX_SIZE", 2, raising=False
)
class Opts:
# docling >= 2.94 threaded pipeline — all knobs present.
queue_max_size = 100
layout_batch_size = 4
table_batch_size = 4
ocr_batch_size = 4
opts = Opts()
_apply_pipeline_caps(opts)
assert opts.queue_max_size == 2
assert opts.layout_batch_size == 1
assert opts.table_batch_size == 1
assert opts.ocr_batch_size == 1
def test_queue_size_is_settings_driven(self, monkeypatch):
from application.core.settings import settings
from application.parser.file.docling_parser import _apply_pipeline_caps
monkeypatch.setattr(
settings, "DOCLING_PIPELINE_QUEUE_MAX_SIZE", 6, raising=False
)
class Opts:
queue_max_size = 100
opts = Opts()
_apply_pipeline_caps(opts)
assert opts.queue_max_size == 6
def test_misconfigured_zero_floors_to_one(self, monkeypatch):
"""A 0 queue depth could deadlock the threaded pipeline — floor it."""
from application.core.settings import settings
from application.parser.file.docling_parser import _apply_pipeline_caps
monkeypatch.setattr(
settings, "DOCLING_PIPELINE_QUEUE_MAX_SIZE", 0, raising=False
)
class Opts:
queue_max_size = 100
opts = Opts()
_apply_pipeline_caps(opts)
assert opts.queue_max_size == 1
def test_noop_on_docling_without_threaded_pipeline(self):
"""Builds predating the threaded pipeline lack the knobs — the cap
must be a silent no-op, not an AttributeError."""
from application.parser.file.docling_parser import _apply_pipeline_caps
class LegacyOpts:
__slots__ = ("do_ocr", "do_table_structure")
def __init__(self):
self.do_ocr = False
self.do_table_structure = True
opts = LegacyOpts()
_apply_pipeline_caps(opts) # must not raise
assert not hasattr(opts, "queue_max_size")
assert not hasattr(opts, "layout_batch_size")