mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-22 05:15:08 +00:00
Compare commits
3 Commits
fix-stuck-
...
hardening-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9a5ea8fe00 | ||
|
|
1de82ca040 | ||
|
|
8f7742c937 |
@@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import urlencode
|
||||
from urllib.parse import quote, urlencode
|
||||
|
||||
import requests
|
||||
|
||||
@@ -11,7 +11,7 @@ from application.agents.tools.api_body_serializer import (
|
||||
RequestBodySerializer,
|
||||
)
|
||||
from application.agents.tools.base import Tool
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
from application.security.safe_url import UnsafeUserUrlError, pinned_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -70,18 +70,16 @@ class APITool(Tool):
|
||||
Returns:
|
||||
Dict with status_code, data, and message
|
||||
"""
|
||||
_VALID_METHODS = {"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}
|
||||
|
||||
request_url = url
|
||||
request_headers = headers.copy() if headers else {}
|
||||
response = None
|
||||
|
||||
# Validate URL to prevent SSRF attacks
|
||||
try:
|
||||
validate_url(request_url)
|
||||
except SSRFError as e:
|
||||
logger.error(f"URL validation failed: {e}")
|
||||
if method.upper() not in _VALID_METHODS:
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"URL validation error: {e}",
|
||||
"message": f"Unsupported HTTP method: {method}",
|
||||
"data": None,
|
||||
}
|
||||
|
||||
@@ -91,8 +89,9 @@ class APITool(Tool):
|
||||
for match in re.finditer(r"\{([^}]+)\}", request_url):
|
||||
param_name = match.group(1)
|
||||
if param_name in query_params:
|
||||
safe_value = quote(str(query_params[param_name]), safe="")
|
||||
request_url = request_url.replace(
|
||||
f"{{{param_name}}}", str(query_params[param_name])
|
||||
f"{{{param_name}}}", safe_value
|
||||
)
|
||||
path_params_used.add(param_name)
|
||||
remaining_params = {
|
||||
@@ -103,19 +102,6 @@ class APITool(Tool):
|
||||
separator = "&" if "?" in request_url else "?"
|
||||
request_url = f"{request_url}{separator}{query_string}"
|
||||
|
||||
# Re-validate URL after parameter substitution to prevent SSRF via path params
|
||||
try:
|
||||
validate_url(request_url)
|
||||
except SSRFError as e:
|
||||
logger.error(f"URL validation failed after parameter substitution: {e}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"URL validation error: {e}",
|
||||
"data": None,
|
||||
}
|
||||
|
||||
# Serialize body based on content type
|
||||
|
||||
if body and body != {}:
|
||||
try:
|
||||
serialized_body, body_headers = RequestBodySerializer.serialize(
|
||||
@@ -141,49 +127,13 @@ class APITool(Tool):
|
||||
f"API Call: {method} {request_url} | Content-Type: {request_headers.get('Content-Type', 'N/A')}"
|
||||
)
|
||||
|
||||
if method.upper() == "GET":
|
||||
response = requests.get(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "POST":
|
||||
response = requests.post(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "PUT":
|
||||
response = requests.put(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "DELETE":
|
||||
response = requests.delete(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "PATCH":
|
||||
response = requests.patch(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "HEAD":
|
||||
response = requests.head(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "OPTIONS":
|
||||
response = requests.options(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
else:
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Unsupported HTTP method: {method}",
|
||||
"data": None,
|
||||
}
|
||||
response = pinned_request(
|
||||
method,
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
data = self._parse_response(response)
|
||||
@@ -193,6 +143,13 @@ class APITool(Tool):
|
||||
"data": data,
|
||||
"message": "API call successful.",
|
||||
}
|
||||
except UnsafeUserUrlError as e:
|
||||
logger.error(f"URL validation failed: {e}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"URL validation error: {e}",
|
||||
"data": None,
|
||||
}
|
||||
except requests.exceptions.Timeout:
|
||||
logger.error(f"Request timeout for {request_url}")
|
||||
return {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import requests
|
||||
from application.agents.tools.base import Tool
|
||||
from application.security.safe_url import UnsafeUserUrlError, pinned_request
|
||||
|
||||
class NtfyTool(Tool):
|
||||
"""
|
||||
@@ -71,7 +71,12 @@ class NtfyTool(Tool):
|
||||
if self.token:
|
||||
headers["Authorization"] = f"Basic {self.token}"
|
||||
data = message.encode("utf-8")
|
||||
response = requests.post(url, headers=headers, data=data, timeout=100)
|
||||
try:
|
||||
response = pinned_request(
|
||||
"POST", url, data=data, headers=headers, timeout=100,
|
||||
)
|
||||
except UnsafeUserUrlError as e:
|
||||
return {"status_code": None, "message": f"URL validation error: {e}"}
|
||||
return {"status_code": response.status_code, "message": "Message sent"}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import requests
|
||||
from markdownify import markdownify
|
||||
from application.agents.tools.base import Tool
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
from application.security.safe_url import UnsafeUserUrlError, pinned_request
|
||||
|
||||
class ReadWebpageTool(Tool):
|
||||
"""
|
||||
@@ -31,28 +30,24 @@ class ReadWebpageTool(Tool):
|
||||
if not url:
|
||||
return "Error: URL parameter is missing."
|
||||
|
||||
# Validate URL to prevent SSRF attacks
|
||||
try:
|
||||
url = validate_url(url)
|
||||
except SSRFError as e:
|
||||
return f"Error: URL validation failed - {e}"
|
||||
response = pinned_request(
|
||||
"GET",
|
||||
url,
|
||||
headers={'User-Agent': 'DocsGPT-Agent/1.0'},
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
try:
|
||||
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
|
||||
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
|
||||
|
||||
html_content = response.text
|
||||
#soup = BeautifulSoup(html_content, 'html.parser')
|
||||
|
||||
|
||||
markdown_content = markdownify(html_content, heading_style="ATX", newline_style="BACKSLASH")
|
||||
|
||||
|
||||
return markdown_content
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
return f"Error fetching URL {url}: {e}"
|
||||
except UnsafeUserUrlError as e:
|
||||
return f"Error: URL validation failed - {e}"
|
||||
except Exception as e:
|
||||
return f"Error processing URL {url}: {e}"
|
||||
return f"Error fetching URL {url}: {e}"
|
||||
|
||||
def get_actions_metadata(self):
|
||||
"""
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
@@ -26,13 +27,20 @@ LEASE_HEARTBEAT_INTERVAL = 30
|
||||
LEASE_RETRY_MAX = 10
|
||||
|
||||
|
||||
def with_idempotency(task_name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def with_idempotency(
|
||||
task_name: str,
|
||||
*,
|
||||
on_poison: Optional[Callable[[str, dict], None]] = None,
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
"""Short-circuit on completed key; gate concurrent runs via a lease.
|
||||
|
||||
The guard key is the caller's ``idempotency_key``, or one synthesized
|
||||
from ``source_id`` so a keyless dispatch is still poison-guarded.
|
||||
|
||||
Entry short-circuits:
|
||||
- completed row → return cached result
|
||||
- live lease held → retry(countdown=LEASE_TTL_SECONDS)
|
||||
- attempt_count > MAX_TASK_ATTEMPTS → poison-loop alert
|
||||
- attempt_count > MAX_TASK_ATTEMPTS → poison alert; ``on_poison`` fires
|
||||
Success writes ``completed``; exceptions leave ``pending`` for
|
||||
autoretry until the poison-loop guard trips.
|
||||
"""
|
||||
@@ -40,7 +48,14 @@ def with_idempotency(task_name: str) -> Callable[[Callable[..., Any]], Callable[
|
||||
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||
@functools.wraps(fn)
|
||||
def wrapper(self, *args: Any, idempotency_key: Any = None, **kwargs: Any) -> Any:
|
||||
key = idempotency_key if isinstance(idempotency_key, str) and idempotency_key else None
|
||||
explicit_key = (
|
||||
idempotency_key
|
||||
if isinstance(idempotency_key, str) and idempotency_key
|
||||
else None
|
||||
)
|
||||
# A keyless dispatch still gets the guard via a synthesized key;
|
||||
# None means no anchor exists — run unguarded, as before.
|
||||
key = explicit_key or _synthesize_guard_key(task_name, kwargs)
|
||||
if key is None:
|
||||
return fn(self, *args, idempotency_key=idempotency_key, **kwargs)
|
||||
|
||||
@@ -88,6 +103,9 @@ def with_idempotency(task_name: str) -> Callable[[Callable[..., Any]], Callable[
|
||||
"attempts": attempt,
|
||||
}
|
||||
_finalize(key, poisoned, status="failed")
|
||||
_run_poison_hook(
|
||||
on_poison, task_name, fn, self, args, kwargs, idempotency_key,
|
||||
)
|
||||
return poisoned
|
||||
|
||||
heartbeat_thread, heartbeat_stop = _start_lease_heartbeat(
|
||||
@@ -109,6 +127,45 @@ def with_idempotency(task_name: str) -> Callable[[Callable[..., Any]], Callable[
|
||||
return decorator
|
||||
|
||||
|
||||
def _synthesize_guard_key(task_name: str, kwargs: dict) -> Optional[str]:
|
||||
"""Derive a deterministic guard key from ``source_id`` for a keyless dispatch.
|
||||
|
||||
``source_id`` is stable across broker redeliveries and unique per
|
||||
upload, so the poison-loop counter survives an OOM SIGKILL. Returns
|
||||
``None`` when absent — the dispatch then runs unguarded as before.
|
||||
"""
|
||||
source_id = kwargs.get("source_id")
|
||||
if source_id:
|
||||
return f"auto:{task_name}:{source_id}"
|
||||
return None
|
||||
|
||||
|
||||
def _run_poison_hook(
|
||||
on_poison: Optional[Callable[[str, dict], None]],
|
||||
task_name: str,
|
||||
fn: Callable[..., Any],
|
||||
task_self: Any,
|
||||
args: tuple,
|
||||
kwargs: dict,
|
||||
idempotency_key: Any,
|
||||
) -> None:
|
||||
"""Invoke a task's poison-path hook with named call args; swallow failures.
|
||||
|
||||
A hook failure must never change the poison-guard outcome.
|
||||
"""
|
||||
if on_poison is None:
|
||||
return
|
||||
try:
|
||||
bound = inspect.signature(fn).bind_partial(
|
||||
task_self, *args, idempotency_key=idempotency_key, **kwargs,
|
||||
)
|
||||
on_poison(task_name, dict(bound.arguments))
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"idempotency: poison hook failed for task=%s", task_name,
|
||||
)
|
||||
|
||||
|
||||
def _lookup_completed(key: str) -> Any:
|
||||
"""Return cached ``result_json`` if a completed row exists for ``key``, else None."""
|
||||
with db_readonly() as conn:
|
||||
|
||||
@@ -27,8 +27,42 @@ DURABLE_TASK = dict(
|
||||
)
|
||||
|
||||
|
||||
# operation tag for the poison-path source.ingest.failed event, per task.
|
||||
_INGEST_POISON_OPERATION = {
|
||||
"ingest": "upload",
|
||||
"ingest_remote": "upload",
|
||||
"ingest_connector_task": "upload",
|
||||
"reingest_source_task": "reingest",
|
||||
}
|
||||
|
||||
|
||||
def _emit_ingest_poison_event(task_name, bound):
|
||||
"""Publish a terminal ``source.ingest.failed`` when the poison-guard trips.
|
||||
|
||||
The guard returns before the worker runs, so the worker's own failed
|
||||
event never fires — without this the upload toast spins on "training".
|
||||
"""
|
||||
user = bound.get("user")
|
||||
source_id = bound.get("source_id")
|
||||
if not user or not source_id:
|
||||
return
|
||||
from application.events.publisher import publish_user_event
|
||||
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.failed",
|
||||
{
|
||||
"source_id": str(source_id),
|
||||
"filename": bound.get("filename") or "",
|
||||
"operation": _INGEST_POISON_OPERATION.get(task_name, "upload"),
|
||||
"error": "Ingestion stopped after repeated failures.",
|
||||
},
|
||||
scope={"kind": "source", "id": str(source_id)},
|
||||
)
|
||||
|
||||
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="ingest")
|
||||
@with_idempotency(task_name="ingest", on_poison=_emit_ingest_poison_event)
|
||||
def ingest(
|
||||
self,
|
||||
directory,
|
||||
@@ -57,7 +91,7 @@ def ingest(
|
||||
|
||||
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="ingest_remote")
|
||||
@with_idempotency(task_name="ingest_remote", on_poison=_emit_ingest_poison_event)
|
||||
def ingest_remote(
|
||||
self, source_data, job_name, user, loader,
|
||||
idempotency_key=None, source_id=None,
|
||||
@@ -71,7 +105,9 @@ def ingest_remote(
|
||||
|
||||
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="reingest_source_task")
|
||||
@with_idempotency(
|
||||
task_name="reingest_source_task", on_poison=_emit_ingest_poison_event,
|
||||
)
|
||||
def reingest_source_task(self, source_id, user, idempotency_key=None):
|
||||
from application.worker import reingest_source_worker
|
||||
|
||||
@@ -128,7 +164,9 @@ def process_agent_webhook(self, agent_id, payload, idempotency_key=None):
|
||||
|
||||
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="ingest_connector_task")
|
||||
@with_idempotency(
|
||||
task_name="ingest_connector_task", on_poison=_emit_ingest_poison_event,
|
||||
)
|
||||
def ingest_connector_task(
|
||||
self,
|
||||
job_name,
|
||||
|
||||
@@ -66,6 +66,9 @@ class Settings(BaseSettings):
|
||||
PARSE_IMAGE_REMOTE: bool = False
|
||||
DOCLING_OCR_ENABLED: bool = False # Enable OCR for docling parsers (PDF, images)
|
||||
DOCLING_OCR_ATTACHMENTS_ENABLED: bool = False # Enable OCR for docling when parsing attachments
|
||||
# Pages docling's threaded pipeline buffers in flight; the library
|
||||
# default (100) drives worker RSS to ~3 GB on a mid-size PDF.
|
||||
DOCLING_PIPELINE_QUEUE_MAX_SIZE: int = 2
|
||||
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
|
||||
RETRIEVERS_ENABLED: list = ["classic_rag"]
|
||||
AGENT_NAME: str = "classic"
|
||||
|
||||
@@ -16,6 +16,29 @@ from application.parser.file.base_parser import BaseParser
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Per-stage batch size for docling's threaded pipeline; 1 holds the
|
||||
# concurrent working set to a single page (see _apply_pipeline_caps).
|
||||
_PIPELINE_BATCH_SIZE = 1
|
||||
|
||||
|
||||
def _apply_pipeline_caps(pipeline_options) -> None:
|
||||
"""Cap docling's threaded-pipeline queue depth and batch sizes in place.
|
||||
|
||||
hasattr-guarded so docling builds without these knobs are unaffected.
|
||||
"""
|
||||
from application.core.settings import settings
|
||||
|
||||
caps = {
|
||||
"queue_max_size": max(1, settings.DOCLING_PIPELINE_QUEUE_MAX_SIZE),
|
||||
"layout_batch_size": _PIPELINE_BATCH_SIZE,
|
||||
"table_batch_size": _PIPELINE_BATCH_SIZE,
|
||||
"ocr_batch_size": _PIPELINE_BATCH_SIZE,
|
||||
}
|
||||
for name, value in caps.items():
|
||||
if hasattr(pipeline_options, name):
|
||||
setattr(pipeline_options, name, value)
|
||||
|
||||
|
||||
class DoclingParser(BaseParser):
|
||||
"""Parser using docling for advanced document processing.
|
||||
|
||||
@@ -86,6 +109,7 @@ class DoclingParser(BaseParser):
|
||||
do_ocr=self.ocr_enabled,
|
||||
do_table_structure=self.table_structure,
|
||||
)
|
||||
_apply_pipeline_caps(pipeline_options)
|
||||
|
||||
if self.ocr_enabled:
|
||||
ocr_options = self._get_ocr_options()
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import logging
|
||||
import os
|
||||
import requests
|
||||
from urllib.parse import urlparse, urljoin
|
||||
from bs4 import BeautifulSoup
|
||||
from application.parser.remote.base import BaseRemote
|
||||
from application.parser.schema.base import Document
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
from application.security.safe_url import UnsafeUserUrlError, pinned_request
|
||||
from langchain_community.document_loaders import WebBaseLoader
|
||||
|
||||
class CrawlerLoader(BaseRemote):
|
||||
@@ -35,14 +35,7 @@ class CrawlerLoader(BaseRemote):
|
||||
visited_urls.add(current_url)
|
||||
|
||||
try:
|
||||
# Validate each URL before making requests
|
||||
try:
|
||||
validate_url(current_url)
|
||||
except SSRFError as e:
|
||||
logging.warning(f"Skipping URL due to validation failure: {current_url} - {e}")
|
||||
continue
|
||||
|
||||
response = requests.get(current_url, timeout=30)
|
||||
response = pinned_request("GET", current_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
loader = self.loader([current_url])
|
||||
docs = loader.load()
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import requests
|
||||
from urllib.parse import urlparse, urljoin
|
||||
from bs4 import BeautifulSoup
|
||||
from application.parser.remote.base import BaseRemote
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
from application.security.safe_url import UnsafeUserUrlError, pinned_request
|
||||
import re
|
||||
from markdownify import markdownify
|
||||
from application.parser.schema.base import Document
|
||||
@@ -20,7 +20,6 @@ class CrawlerLoader(BaseRemote):
|
||||
"""
|
||||
self.limit = limit
|
||||
self.allow_subdomains = allow_subdomains
|
||||
self.session = requests.Session()
|
||||
|
||||
def load_data(self, inputs):
|
||||
url = inputs
|
||||
@@ -91,15 +90,13 @@ class CrawlerLoader(BaseRemote):
|
||||
|
||||
def _fetch_page(self, url):
|
||||
try:
|
||||
# Validate URL before fetching to prevent SSRF
|
||||
validate_url(url)
|
||||
response = self.session.get(url, timeout=10)
|
||||
response = pinned_request("GET", url, timeout=10)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except SSRFError as e:
|
||||
except UnsafeUserUrlError as e:
|
||||
print(f"URL validation failed for {url}: {e}")
|
||||
return None
|
||||
except requests.exceptions.RequestException as e:
|
||||
except Exception as e:
|
||||
print(f"Error fetching URL {url}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import logging
|
||||
import requests
|
||||
import re # Import regular expression library
|
||||
import re
|
||||
import defusedxml.ElementTree as ET
|
||||
from application.parser.remote.base import BaseRemote
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
from application.security.safe_url import UnsafeUserUrlError, pinned_request
|
||||
|
||||
class SitemapLoader(BaseRemote):
|
||||
def __init__(self, limit=20):
|
||||
@@ -53,14 +53,12 @@ class SitemapLoader(BaseRemote):
|
||||
|
||||
def _extract_urls(self, sitemap_url):
|
||||
try:
|
||||
# Validate URL before fetching to prevent SSRF
|
||||
validate_url(sitemap_url)
|
||||
response = requests.get(sitemap_url, timeout=30)
|
||||
response.raise_for_status() # Raise an exception for HTTP errors
|
||||
except SSRFError as e:
|
||||
response = pinned_request("GET", sitemap_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
except UnsafeUserUrlError as e:
|
||||
print(f"URL validation failed for sitemap: {sitemap_url}. Error: {e}")
|
||||
return []
|
||||
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e:
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch sitemap: {sitemap_url}. Error: {e}")
|
||||
return []
|
||||
|
||||
@@ -97,13 +95,6 @@ class SitemapLoader(BaseRemote):
|
||||
nested_sitemap_url = sitemap.text
|
||||
if not nested_sitemap_url:
|
||||
continue
|
||||
try:
|
||||
nested_sitemap_url = validate_url(nested_sitemap_url)
|
||||
except SSRFError as e:
|
||||
logging.error(
|
||||
f"URL validation failed for nested sitemap {nested_sitemap_url}: {e}"
|
||||
)
|
||||
continue
|
||||
urls.extend(self._extract_urls(nested_sitemap_url))
|
||||
|
||||
return urls
|
||||
|
||||
@@ -291,6 +291,55 @@ def _ip_to_url_host(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> str:
|
||||
return str(ip)
|
||||
|
||||
|
||||
def pinned_request(
|
||||
method: str,
|
||||
url: str,
|
||||
*,
|
||||
data: Any = None,
|
||||
json: Any = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float = 90.0,
|
||||
allow_redirects: bool = False,
|
||||
) -> requests.Response:
|
||||
"""Send an HTTP request with the connection pinned to a validated IP,
|
||||
closing the DNS-rebinding TOCTOU window left by the naive
|
||||
validate-then-``requests`` pattern.
|
||||
|
||||
Raises:
|
||||
UnsafeUserUrlError: If the URL fails the SSRF guard.
|
||||
requests.RequestException: For network-level failures.
|
||||
"""
|
||||
|
||||
host, ip, parts = _validate_and_pick_ip(url)
|
||||
|
||||
netloc = _ip_to_url_host(ip)
|
||||
if parts.port is not None:
|
||||
netloc = f"{netloc}:{parts.port}"
|
||||
pinned_url = urlunsplit(
|
||||
(parts.scheme, netloc, parts.path, parts.query, parts.fragment)
|
||||
)
|
||||
|
||||
request_headers = dict(headers or {})
|
||||
host_header = host if parts.port is None else f"{host}:{parts.port}"
|
||||
request_headers["Host"] = host_header
|
||||
|
||||
session = requests.Session()
|
||||
if parts.scheme == "https":
|
||||
session.mount("https://", _PinnedHostAdapter(host))
|
||||
try:
|
||||
return session.request(
|
||||
method=method.upper(),
|
||||
url=pinned_url,
|
||||
data=data,
|
||||
json=json,
|
||||
headers=request_headers,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
)
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def pinned_post(
|
||||
url: str,
|
||||
*,
|
||||
@@ -328,33 +377,15 @@ def pinned_post(
|
||||
requests.RequestException: For network-level failures.
|
||||
"""
|
||||
|
||||
host, ip, parts = _validate_and_pick_ip(url)
|
||||
|
||||
netloc = _ip_to_url_host(ip)
|
||||
if parts.port is not None:
|
||||
netloc = f"{netloc}:{parts.port}"
|
||||
pinned_url = urlunsplit(
|
||||
(parts.scheme, netloc, parts.path, parts.query, parts.fragment)
|
||||
return pinned_request(
|
||||
"POST",
|
||||
url,
|
||||
json=json,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
)
|
||||
|
||||
request_headers = dict(headers or {})
|
||||
host_header = host if parts.port is None else f"{host}:{parts.port}"
|
||||
request_headers["Host"] = host_header
|
||||
|
||||
session = requests.Session()
|
||||
if parts.scheme == "https":
|
||||
session.mount("https://", _PinnedHostAdapter(host))
|
||||
try:
|
||||
return session.post(
|
||||
pinned_url,
|
||||
json=json,
|
||||
headers=request_headers,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
)
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
class _PinnedHTTPSTransport(httpx.HTTPTransport):
|
||||
"""``httpx`` transport pinned to a single validated IP literal.
|
||||
|
||||
@@ -45,15 +45,14 @@ class TestAPIToolInit:
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMakeApiCall:
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_successful_get(self, mock_get, mock_validate, tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_successful_get(self, mock_pinned, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"result": "ok"}
|
||||
mock_resp.content = b'{"result":"ok"}'
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("any_action")
|
||||
|
||||
@@ -61,54 +60,50 @@ class TestMakeApiCall:
|
||||
assert result["data"] == {"result": "ok"}
|
||||
assert result["message"] == "API call successful."
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_successful_post(self, mock_post, mock_validate, post_tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_successful_post(self, mock_pinned, post_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 201
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"id": 1}
|
||||
mock_resp.content = b'{"id":1}'
|
||||
mock_post.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = post_tool.execute_action("create", name="test")
|
||||
|
||||
assert result["status_code"] == 201
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
def test_ssrf_blocked(self, mock_validate, tool):
|
||||
from application.core.url_validation import SSRFError
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_ssrf_blocked(self, mock_pinned, tool):
|
||||
from application.security.safe_url import UnsafeUserUrlError
|
||||
|
||||
mock_validate.side_effect = SSRFError("blocked")
|
||||
mock_pinned.side_effect = UnsafeUserUrlError("blocked")
|
||||
|
||||
result = tool.execute_action("any")
|
||||
|
||||
assert result["status_code"] is None
|
||||
assert "URL validation error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_timeout_error(self, mock_get, mock_validate, tool):
|
||||
mock_get.side_effect = requests.exceptions.Timeout()
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_timeout_error(self, mock_pinned, tool):
|
||||
mock_pinned.side_effect = requests.exceptions.Timeout()
|
||||
|
||||
result = tool.execute_action("any")
|
||||
|
||||
assert result["status_code"] is None
|
||||
assert "timeout" in result["message"].lower()
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_connection_error(self, mock_get, mock_validate, tool):
|
||||
mock_get.side_effect = requests.exceptions.ConnectionError("refused")
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_connection_error(self, mock_pinned, tool):
|
||||
mock_pinned.side_effect = requests.exceptions.ConnectionError("refused")
|
||||
|
||||
result = tool.execute_action("any")
|
||||
|
||||
assert result["status_code"] is None
|
||||
assert "Connection error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_http_error(self, mock_get, mock_validate, tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_http_error(self, mock_pinned, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 404
|
||||
mock_resp.text = "Not Found"
|
||||
@@ -116,15 +111,14 @@ class TestMakeApiCall:
|
||||
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
|
||||
response=mock_resp
|
||||
)
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("any")
|
||||
|
||||
assert result["status_code"] == 404
|
||||
assert "HTTP Error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
def test_unsupported_method(self, mock_validate):
|
||||
def test_unsupported_method(self):
|
||||
tool = APITool(
|
||||
config={"url": "https://example.com", "method": "CUSTOM"}
|
||||
)
|
||||
@@ -132,69 +126,64 @@ class TestMakeApiCall:
|
||||
assert result["status_code"] is None
|
||||
assert "Unsupported" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.put")
|
||||
def test_put_method(self, mock_put, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_put_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "PUT"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_put.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("update", name="new")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.delete")
|
||||
def test_delete_method(self, mock_delete, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_delete_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "DELETE"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 204
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.content = b''
|
||||
mock_delete.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("delete")
|
||||
assert result["status_code"] == 204
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.patch")
|
||||
def test_patch_method(self, mock_patch, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_patch_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "PATCH"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"patched": True}
|
||||
mock_resp.content = b'{"patched":true}'
|
||||
mock_patch.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("patch", field="val")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.head")
|
||||
def test_head_method(self, mock_head, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_head_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "HEAD"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "text/html"}
|
||||
mock_resp.content = b''
|
||||
mock_head.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("check")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.options")
|
||||
def test_options_method(self, mock_options, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_options_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "OPTIONS"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.content = b''
|
||||
mock_options.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("options")
|
||||
assert result["status_code"] == 200
|
||||
@@ -202,9 +191,8 @@ class TestMakeApiCall:
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPathParamSubstitution:
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_path_params_substituted(self, mock_get, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_path_params_substituted(self, mock_pinned):
|
||||
tool = APITool(
|
||||
config={
|
||||
"url": "https://api.example.com/users/{user_id}/posts/{post_id}",
|
||||
@@ -217,11 +205,11 @@ class TestPathParamSubstitution:
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = []
|
||||
mock_resp.content = b'[]'
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
called_url = mock_pinned.call_args[0][1]
|
||||
assert "/users/42/posts/7" in called_url
|
||||
assert "{user_id}" not in called_url
|
||||
|
||||
|
||||
@@ -81,104 +81,103 @@ class TestAPIToolInit:
|
||||
@pytest.mark.unit
|
||||
class TestMakeApiCall:
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_successful_get(self, mock_get, mock_validate, get_tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_successful_get(self, mock_pinned, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"result": "ok"}
|
||||
mock_resp.content = b'{"result":"ok"}'
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = get_tool.execute_action("any_action")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["data"] == {"result": "ok"}
|
||||
assert result["message"] == "API call successful."
|
||||
assert mock_pinned.call_args[0][0] == "GET"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_successful_post(self, mock_post, mock_validate, post_tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_successful_post(self, mock_pinned, post_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 201
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"id": 1}
|
||||
mock_resp.content = b'{"id":1}'
|
||||
mock_post.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = post_tool.execute_action("create", name="test")
|
||||
assert result["status_code"] == 201
|
||||
assert mock_pinned.call_args[0][0] == "POST"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.put")
|
||||
def test_put_method(self, mock_put, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_put_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "PUT"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_put.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("update", name="new")
|
||||
assert result["status_code"] == 200
|
||||
assert mock_pinned.call_args[0][0] == "PUT"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.delete")
|
||||
def test_delete_method(self, mock_delete, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_delete_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "DELETE"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 204
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.content = b''
|
||||
mock_delete.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("delete")
|
||||
assert result["status_code"] == 204
|
||||
assert mock_pinned.call_args[0][0] == "DELETE"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.patch")
|
||||
def test_patch_method(self, mock_patch, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_patch_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "PATCH"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"patched": True}
|
||||
mock_resp.content = b'{"patched":true}'
|
||||
mock_patch.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("patch", field="val")
|
||||
assert result["status_code"] == 200
|
||||
assert mock_pinned.call_args[0][0] == "PATCH"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.head")
|
||||
def test_head_method(self, mock_head, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_head_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "HEAD"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "text/html"}
|
||||
mock_resp.content = b''
|
||||
mock_head.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("check")
|
||||
assert result["status_code"] == 200
|
||||
assert mock_pinned.call_args[0][0] == "HEAD"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.options")
|
||||
def test_options_method(self, mock_options, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_options_method(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "OPTIONS"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.content = b''
|
||||
mock_options.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("options")
|
||||
assert result["status_code"] == 200
|
||||
assert mock_pinned.call_args[0][0] == "OPTIONS"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
def test_unsupported_method(self, mock_validate):
|
||||
def test_unsupported_method(self):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "CUSTOM"})
|
||||
result = tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
@@ -193,19 +192,18 @@ class TestMakeApiCall:
|
||||
@pytest.mark.unit
|
||||
class TestSSRFValidation:
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
def test_ssrf_blocked_initial_url(self, mock_validate, get_tool):
|
||||
from application.core.url_validation import SSRFError
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_ssrf_blocked(self, mock_pinned, get_tool):
|
||||
from application.security.safe_url import UnsafeUserUrlError
|
||||
|
||||
mock_validate.side_effect = SSRFError("blocked")
|
||||
mock_pinned.side_effect = UnsafeUserUrlError("blocked")
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
assert "URL validation error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_ssrf_blocked_after_param_substitution(self, mock_get, mock_validate):
|
||||
from application.core.url_validation import SSRFError
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_ssrf_blocked_with_path_params(self, mock_pinned):
|
||||
from application.security.safe_url import UnsafeUserUrlError
|
||||
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/{host}/data",
|
||||
@@ -213,14 +211,7 @@ class TestSSRFValidation:
|
||||
"query_params": {"host": "169.254.169.254"},
|
||||
})
|
||||
|
||||
call_count = [0]
|
||||
|
||||
def side_effect(url):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 2:
|
||||
raise SSRFError("blocked after substitution")
|
||||
|
||||
mock_validate.side_effect = side_effect
|
||||
mock_pinned.side_effect = UnsafeUserUrlError("blocked")
|
||||
result = tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
assert "URL validation error" in result["message"]
|
||||
@@ -234,40 +225,36 @@ class TestSSRFValidation:
|
||||
@pytest.mark.unit
|
||||
class TestErrorHandling:
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_timeout_error(self, mock_get, mock_validate, get_tool):
|
||||
mock_get.side_effect = requests.exceptions.Timeout()
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_timeout_error(self, mock_pinned, get_tool):
|
||||
mock_pinned.side_effect = requests.exceptions.Timeout()
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
assert "timeout" in result["message"].lower()
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_connection_error(self, mock_get, mock_validate, get_tool):
|
||||
mock_get.side_effect = requests.exceptions.ConnectionError("refused")
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_connection_error(self, mock_pinned, get_tool):
|
||||
mock_pinned.side_effect = requests.exceptions.ConnectionError("refused")
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
assert "Connection error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_http_error_with_json(self, mock_get, mock_validate, get_tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_http_error_with_json(self, mock_pinned, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 422
|
||||
mock_resp.json.return_value = {"error": "invalid_field"}
|
||||
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
|
||||
response=mock_resp
|
||||
)
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] == 422
|
||||
assert result["data"] == {"error": "invalid_field"}
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_http_error_non_json_body(self, mock_get, mock_validate, get_tool):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_http_error_non_json_body(self, mock_pinned, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 404
|
||||
mock_resp.text = "Not Found"
|
||||
@@ -275,29 +262,26 @@ class TestErrorHandling:
|
||||
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
|
||||
response=mock_resp
|
||||
)
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] == 404
|
||||
assert result["data"] == "Not Found"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_request_exception(self, mock_get, mock_validate, get_tool):
|
||||
mock_get.side_effect = requests.exceptions.RequestException("something")
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_request_exception(self, mock_pinned, get_tool):
|
||||
mock_pinned.side_effect = requests.exceptions.RequestException("something")
|
||||
result = get_tool.execute_action("any")
|
||||
assert "API call failed" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_unexpected_exception(self, mock_get, mock_validate, get_tool):
|
||||
mock_get.side_effect = RuntimeError("unexpected")
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_unexpected_exception(self, mock_pinned, get_tool):
|
||||
mock_pinned.side_effect = RuntimeError("unexpected")
|
||||
result = get_tool.execute_action("any")
|
||||
assert "Unexpected error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_body_serialization_error(self, mock_post, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_body_serialization_error(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://example.com",
|
||||
"method": "POST",
|
||||
@@ -320,9 +304,8 @@ class TestErrorHandling:
|
||||
@pytest.mark.unit
|
||||
class TestPathParamSubstitution:
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_path_params_substituted(self, mock_get, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_path_params_substituted(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/users/{user_id}/posts/{post_id}",
|
||||
"method": "GET",
|
||||
@@ -333,17 +316,16 @@ class TestPathParamSubstitution:
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = []
|
||||
mock_resp.content = b'[]'
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
called_url = mock_pinned.call_args[0][1]
|
||||
assert "/users/42/posts/7" in called_url
|
||||
assert "{user_id}" not in called_url
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_remaining_query_params_appended(self, mock_get, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_remaining_query_params_appended(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/items",
|
||||
"method": "GET",
|
||||
@@ -354,19 +336,16 @@ class TestPathParamSubstitution:
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = []
|
||||
mock_resp.content = b'[]'
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
called_url = mock_pinned.call_args[0][1]
|
||||
assert "page=2" in called_url
|
||||
assert "limit=10" in called_url
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_query_params_append_with_existing_query_string(
|
||||
self, mock_get, mock_validate
|
||||
):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_query_params_append_with_existing_query_string(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/items?existing=true",
|
||||
"method": "GET",
|
||||
@@ -377,27 +356,65 @@ class TestPathParamSubstitution:
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = []
|
||||
mock_resp.content = b'[]'
|
||||
mock_get.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
called_url = mock_pinned.call_args[0][1]
|
||||
assert "&page=1" in called_url
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_empty_body_no_serialization(self, mock_post, mock_validate):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_empty_body_no_serialization(self, mock_pinned):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "POST"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_post.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("create")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_path_params_are_url_encoded(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/users/{user_id}/profile",
|
||||
"method": "GET",
|
||||
"query_params": {"user_id": "../../admin"},
|
||||
})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_pinned.call_args[0][1]
|
||||
assert "../../admin" not in called_url
|
||||
assert "%2F" in called_url or "%2f" in called_url
|
||||
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_path_params_query_injection_encoded(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/items/{item_id}",
|
||||
"method": "GET",
|
||||
"query_params": {"item_id": "x?admin=true"},
|
||||
})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_pinned.call_args[0][1]
|
||||
assert "x?admin=true" not in called_url
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Parse Response
|
||||
@@ -494,11 +511,8 @@ class TestAPIToolMetadata:
|
||||
def test_config_requirements_empty(self, get_tool):
|
||||
assert get_tool.get_config_requirements() == {}
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_content_type_set_for_post_with_no_headers(
|
||||
self, mock_post, mock_validate
|
||||
):
|
||||
@patch("application.agents.tools.api_tool.pinned_request")
|
||||
def test_content_type_set_for_post_with_no_headers(self, mock_pinned):
|
||||
tool = APITool(config={
|
||||
"url": "https://example.com",
|
||||
"method": "POST",
|
||||
@@ -509,8 +523,8 @@ class TestAPIToolMetadata:
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_post.return_value = mock_resp
|
||||
mock_pinned.return_value = mock_resp
|
||||
|
||||
tool.execute_action("create")
|
||||
call_headers = mock_post.call_args[1]["headers"]
|
||||
call_headers = mock_pinned.call_args.kwargs["headers"]
|
||||
assert "Content-Type" in call_headers
|
||||
|
||||
@@ -417,3 +417,181 @@ class TestSuccessfulRunClearsLease:
|
||||
assert row[0] == "completed"
|
||||
assert row[1] is None
|
||||
assert row[2] is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSynthesizedKeyGuardsKeylessDispatch:
|
||||
"""A keyless dispatch carrying ``source_id`` is still poison-guarded:
|
||||
the wrapper synthesizes a deterministic key from ``source_id``.
|
||||
"""
|
||||
|
||||
def test_keyless_with_source_id_records_dedup_row(self, pg_conn):
|
||||
from application.api.user.idempotency import with_idempotency
|
||||
|
||||
@with_idempotency(task_name="ingest")
|
||||
def task(self, idempotency_key=None, source_id=None):
|
||||
return {"ran": True}
|
||||
|
||||
with _patch_decorator_db(pg_conn):
|
||||
result = task(_fake_celery_self(), source_id="src-abc")
|
||||
|
||||
assert result == {"ran": True}
|
||||
row = _row_for(pg_conn, "auto:ingest:src-abc")
|
||||
assert row is not None
|
||||
assert row[0] == "ingest"
|
||||
assert row[2] == "completed"
|
||||
|
||||
def test_synthesized_key_stable_across_redeliveries(self, pg_conn):
|
||||
"""Same ``source_id`` → same key → a redelivery short-circuits to
|
||||
the cached result instead of re-running the body.
|
||||
"""
|
||||
from application.api.user.idempotency import with_idempotency
|
||||
|
||||
runs = {"count": 0}
|
||||
|
||||
@with_idempotency(task_name="ingest")
|
||||
def task(self, idempotency_key=None, source_id=None):
|
||||
runs["count"] += 1
|
||||
return {"n": runs["count"]}
|
||||
|
||||
with _patch_decorator_db(pg_conn):
|
||||
first = task(_fake_celery_self(), source_id="src-1")
|
||||
second = task(_fake_celery_self(), source_id="src-1")
|
||||
|
||||
assert first == second == {"n": 1}
|
||||
assert runs["count"] == 1
|
||||
|
||||
def test_poison_guard_trips_for_keyless_dispatch(self, pg_conn):
|
||||
"""The core fix: a keyless OOM-looping dispatch is bounded — the
|
||||
guard trips after MAX_TASK_ATTEMPTS with no explicit key.
|
||||
"""
|
||||
from application.api.user.idempotency import (
|
||||
MAX_TASK_ATTEMPTS, with_idempotency,
|
||||
)
|
||||
|
||||
runs = {"count": 0}
|
||||
|
||||
@with_idempotency(task_name="ingest")
|
||||
def task(self, idempotency_key=None, source_id=None):
|
||||
runs["count"] += 1
|
||||
raise RuntimeError("OOM-style failure")
|
||||
|
||||
with _patch_decorator_db(pg_conn):
|
||||
for _ in range(MAX_TASK_ATTEMPTS):
|
||||
with pytest.raises(RuntimeError):
|
||||
task(_fake_celery_self(), source_id="src-poison")
|
||||
result = task(_fake_celery_self(), source_id="src-poison")
|
||||
|
||||
assert runs["count"] == MAX_TASK_ATTEMPTS
|
||||
assert result["success"] is False
|
||||
assert "poison-loop" in result["error"]
|
||||
assert _row_for(pg_conn, "auto:ingest:src-poison")[2] == "failed"
|
||||
|
||||
def test_no_source_id_no_key_runs_unguarded(self, pg_conn):
|
||||
"""No explicit key and no ``source_id`` anchor → pass through with
|
||||
no DB writes, exactly as before.
|
||||
"""
|
||||
from application.api.user.idempotency import with_idempotency
|
||||
|
||||
@with_idempotency(task_name="store_attachment")
|
||||
def task(self, idempotency_key=None):
|
||||
return {"ran": True}
|
||||
|
||||
with patch(
|
||||
"application.api.user.idempotency.db_session"
|
||||
) as mock_session, patch(
|
||||
"application.api.user.idempotency.db_readonly"
|
||||
) as mock_readonly:
|
||||
result = task(_fake_celery_self())
|
||||
|
||||
assert result == {"ran": True}
|
||||
assert mock_session.call_count == 0
|
||||
assert mock_readonly.call_count == 0
|
||||
|
||||
def test_explicit_key_takes_precedence_over_source_id(self, pg_conn):
|
||||
"""An explicit key wins; the synthesized ``auto:`` key is unused."""
|
||||
from application.api.user.idempotency import with_idempotency
|
||||
|
||||
@with_idempotency(task_name="ingest")
|
||||
def task(self, idempotency_key=None, source_id=None):
|
||||
return {"ran": True}
|
||||
|
||||
with _patch_decorator_db(pg_conn):
|
||||
task(
|
||||
_fake_celery_self(),
|
||||
idempotency_key="explicit-k",
|
||||
source_id="src-x",
|
||||
)
|
||||
|
||||
assert _row_for(pg_conn, "explicit-k") is not None
|
||||
assert _row_for(pg_conn, "auto:ingest:src-x") is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPoisonHook:
|
||||
"""``on_poison`` fires on the poison-guard branch with the task's
|
||||
bound arguments, and never on the success path.
|
||||
"""
|
||||
|
||||
def test_hook_invoked_with_bound_args_on_poison(self, pg_conn):
|
||||
from application.api.user.idempotency import (
|
||||
MAX_TASK_ATTEMPTS, with_idempotency,
|
||||
)
|
||||
|
||||
captured = []
|
||||
|
||||
def _hook(task_name, bound):
|
||||
captured.append((task_name, bound))
|
||||
|
||||
@with_idempotency(task_name="ingest", on_poison=_hook)
|
||||
def task(self, idempotency_key=None, source_id=None):
|
||||
raise RuntimeError("never converges")
|
||||
|
||||
with _patch_decorator_db(pg_conn):
|
||||
for _ in range(MAX_TASK_ATTEMPTS):
|
||||
with pytest.raises(RuntimeError):
|
||||
task(_fake_celery_self(), source_id="src-h")
|
||||
task(_fake_celery_self(), source_id="src-h")
|
||||
|
||||
assert len(captured) == 1
|
||||
task_name, bound = captured[0]
|
||||
assert task_name == "ingest"
|
||||
assert bound["source_id"] == "src-h"
|
||||
|
||||
def test_hook_not_invoked_on_success(self, pg_conn):
|
||||
from application.api.user.idempotency import with_idempotency
|
||||
|
||||
calls = []
|
||||
|
||||
@with_idempotency(
|
||||
task_name="ingest", on_poison=lambda *a: calls.append(a)
|
||||
)
|
||||
def task(self, idempotency_key=None, source_id=None):
|
||||
return {"ok": True}
|
||||
|
||||
with _patch_decorator_db(pg_conn):
|
||||
task(_fake_celery_self(), source_id="src-ok")
|
||||
|
||||
assert calls == []
|
||||
|
||||
def test_hook_failure_does_not_break_poison_return(self, pg_conn):
|
||||
"""A throwing hook must not change the poison-guard outcome."""
|
||||
from application.api.user.idempotency import (
|
||||
MAX_TASK_ATTEMPTS, with_idempotency,
|
||||
)
|
||||
|
||||
def _bad_hook(task_name, bound):
|
||||
raise ValueError("hook blew up")
|
||||
|
||||
@with_idempotency(task_name="ingest", on_poison=_bad_hook)
|
||||
def task(self, idempotency_key=None, source_id=None):
|
||||
raise RuntimeError("never converges")
|
||||
|
||||
with _patch_decorator_db(pg_conn):
|
||||
for _ in range(MAX_TASK_ATTEMPTS):
|
||||
with pytest.raises(RuntimeError):
|
||||
task(_fake_celery_self(), source_id="src-bad")
|
||||
result = task(_fake_celery_self(), source_id="src-bad")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "poison-loop" in result["error"]
|
||||
|
||||
@@ -546,3 +546,63 @@ class TestIngestIdempotency:
|
||||
assert first == second
|
||||
assert first == {"status": "ok", "directory": "dir"}
|
||||
assert len(worker_calls) == 1
|
||||
|
||||
|
||||
class TestIngestPoisonEvent:
|
||||
"""The poison hook publishes a terminal source.ingest.failed so the
|
||||
upload toast resolves instead of hanging on "training".
|
||||
"""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_publishes_failed_event(self):
|
||||
from application.api.user.tasks import _emit_ingest_poison_event
|
||||
|
||||
published = []
|
||||
|
||||
def _fake_publish(user, event_type, payload, *, scope=None):
|
||||
published.append((user, event_type, payload, scope))
|
||||
|
||||
with patch(
|
||||
"application.events.publisher.publish_user_event",
|
||||
side_effect=_fake_publish,
|
||||
):
|
||||
_emit_ingest_poison_event(
|
||||
"ingest",
|
||||
{"user": "u1", "source_id": "src-9", "filename": "doc.pdf"},
|
||||
)
|
||||
|
||||
assert len(published) == 1
|
||||
user, event_type, payload, scope = published[0]
|
||||
assert user == "u1"
|
||||
assert event_type == "source.ingest.failed"
|
||||
assert payload["source_id"] == "src-9"
|
||||
assert payload["filename"] == "doc.pdf"
|
||||
assert payload["operation"] == "upload"
|
||||
assert scope == {"kind": "source", "id": "src-9"}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_skips_when_source_id_missing(self):
|
||||
from application.api.user.tasks import _emit_ingest_poison_event
|
||||
|
||||
with patch(
|
||||
"application.events.publisher.publish_user_event",
|
||||
) as mock_publish:
|
||||
_emit_ingest_poison_event("ingest", {"user": "u1"})
|
||||
|
||||
mock_publish.assert_not_called()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_reingest_uses_reingest_operation(self):
|
||||
from application.api.user.tasks import _emit_ingest_poison_event
|
||||
|
||||
published = []
|
||||
with patch(
|
||||
"application.events.publisher.publish_user_event",
|
||||
side_effect=lambda *a, **k: published.append((a, k)),
|
||||
):
|
||||
_emit_ingest_poison_event(
|
||||
"reingest_source_task",
|
||||
{"user": "u1", "source_id": "src-r"},
|
||||
)
|
||||
|
||||
assert published[0][0][2]["operation"] == "reingest"
|
||||
|
||||
@@ -421,3 +421,85 @@ class TestDoclingParserGaps:
|
||||
parser = DoclingCSVParser()
|
||||
assert parser.export_format == "markdown"
|
||||
assert parser.ocr_enabled is True
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Pipeline memory caps
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestApplyPipelineCaps:
|
||||
"""_apply_pipeline_caps bounds docling's threaded-pipeline buffering."""
|
||||
|
||||
def test_caps_threaded_pipeline_knobs(self, monkeypatch):
|
||||
from application.core.settings import settings
|
||||
from application.parser.file.docling_parser import _apply_pipeline_caps
|
||||
|
||||
monkeypatch.setattr(
|
||||
settings, "DOCLING_PIPELINE_QUEUE_MAX_SIZE", 2, raising=False
|
||||
)
|
||||
|
||||
class Opts:
|
||||
# docling >= 2.94 threaded pipeline — all knobs present.
|
||||
queue_max_size = 100
|
||||
layout_batch_size = 4
|
||||
table_batch_size = 4
|
||||
ocr_batch_size = 4
|
||||
|
||||
opts = Opts()
|
||||
_apply_pipeline_caps(opts)
|
||||
|
||||
assert opts.queue_max_size == 2
|
||||
assert opts.layout_batch_size == 1
|
||||
assert opts.table_batch_size == 1
|
||||
assert opts.ocr_batch_size == 1
|
||||
|
||||
def test_queue_size_is_settings_driven(self, monkeypatch):
|
||||
from application.core.settings import settings
|
||||
from application.parser.file.docling_parser import _apply_pipeline_caps
|
||||
|
||||
monkeypatch.setattr(
|
||||
settings, "DOCLING_PIPELINE_QUEUE_MAX_SIZE", 6, raising=False
|
||||
)
|
||||
|
||||
class Opts:
|
||||
queue_max_size = 100
|
||||
|
||||
opts = Opts()
|
||||
_apply_pipeline_caps(opts)
|
||||
assert opts.queue_max_size == 6
|
||||
|
||||
def test_misconfigured_zero_floors_to_one(self, monkeypatch):
|
||||
"""A 0 queue depth could deadlock the threaded pipeline — floor it."""
|
||||
from application.core.settings import settings
|
||||
from application.parser.file.docling_parser import _apply_pipeline_caps
|
||||
|
||||
monkeypatch.setattr(
|
||||
settings, "DOCLING_PIPELINE_QUEUE_MAX_SIZE", 0, raising=False
|
||||
)
|
||||
|
||||
class Opts:
|
||||
queue_max_size = 100
|
||||
|
||||
opts = Opts()
|
||||
_apply_pipeline_caps(opts)
|
||||
assert opts.queue_max_size == 1
|
||||
|
||||
def test_noop_on_docling_without_threaded_pipeline(self):
|
||||
"""Builds predating the threaded pipeline lack the knobs — the cap
|
||||
must be a silent no-op, not an AttributeError."""
|
||||
from application.parser.file.docling_parser import _apply_pipeline_caps
|
||||
|
||||
class LegacyOpts:
|
||||
__slots__ = ("do_ocr", "do_table_structure")
|
||||
|
||||
def __init__(self):
|
||||
self.do_ocr = False
|
||||
self.do_table_structure = True
|
||||
|
||||
opts = LegacyOpts()
|
||||
_apply_pipeline_caps(opts) # must not raise
|
||||
|
||||
assert not hasattr(opts, "queue_max_size")
|
||||
assert not hasattr(opts, "layout_batch_size")
|
||||
|
||||
Reference in New Issue
Block a user