mirror of
https://github.com/docling-project/docling-serve.git
synced 2026-03-07 22:33:44 +00:00
fix: zombie task cleanup — reconcile stale RQ/Redis state (#516)
Signed-off-by: Pawel Rein <pawel.rein@prezi.com>
This commit is contained in:
@@ -141,14 +141,30 @@ async def lifespan(app: FastAPI):
|
||||
# Start the background queue processor
|
||||
queue_task = asyncio.create_task(orchestrator.process_queue())
|
||||
|
||||
reaper_task = None
|
||||
if hasattr(orchestrator, "_reap_zombie_tasks"):
|
||||
reaper_task = asyncio.create_task(
|
||||
orchestrator._reap_zombie_tasks(
|
||||
interval=docling_serve_settings.zombie_reaper_interval,
|
||||
max_age=docling_serve_settings.zombie_reaper_max_age,
|
||||
)
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
# Cancel the background queue processor on shutdown
|
||||
queue_task.cancel()
|
||||
if reaper_task:
|
||||
reaper_task.cancel()
|
||||
try:
|
||||
await queue_task
|
||||
except asyncio.CancelledError:
|
||||
_log.info("Queue processor cancelled.")
|
||||
if reaper_task:
|
||||
try:
|
||||
await reaper_task
|
||||
except asyncio.CancelledError:
|
||||
_log.info("Zombie reaper cancelled.")
|
||||
|
||||
# Remove scratch directory in case it was a tempfile
|
||||
if docling_serve_settings.scratch_path is not None:
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
from typing import Any, Union
|
||||
|
||||
import redis.asyncio as redis
|
||||
from rq.exceptions import NoSuchJobError
|
||||
|
||||
from docling_jobkit.datamodel.task import Task
|
||||
from docling_jobkit.datamodel.task_meta import TaskStatus
|
||||
@@ -18,6 +21,13 @@ from docling_serve.storage import get_scratch
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _RQJobGone:
|
||||
"""Sentinel: the RQ job has been deleted / TTL-expired."""
|
||||
|
||||
|
||||
_RQ_JOB_GONE = _RQJobGone()
|
||||
|
||||
|
||||
class RedisTaskStatusMixin:
|
||||
tasks: dict[str, Task]
|
||||
_task_result_keys: dict[str, str]
|
||||
@@ -64,43 +74,66 @@ class RedisTaskStatusMixin:
|
||||
|
||||
async def task_status(self, task_id: str, wait: float = 0.0) -> Task:
|
||||
"""
|
||||
Get task status by checking Redis first, then falling back to RQ verification.
|
||||
Get task status with zombie task reconciliation.
|
||||
|
||||
When Redis shows 'pending' but RQ shows 'success', we update Redis
|
||||
and return the RQ status for cross-instance consistency.
|
||||
Checks RQ first (authoritative), then Redis cache, then in-memory.
|
||||
When the RQ job is definitively gone (NoSuchJobError), reconciles:
|
||||
- Terminal status in Redis -> return it, clean up tracking
|
||||
- Non-terminal status in Redis -> mark FAILURE (orphaned task)
|
||||
- Not in Redis at all -> raise TaskNotFoundError
|
||||
"""
|
||||
_log.info(f"Task {task_id} status check")
|
||||
|
||||
# Always check RQ directly first - this is the most reliable source
|
||||
rq_task = await self._get_task_from_rq_direct(task_id)
|
||||
if rq_task:
|
||||
_log.info(f"Task {task_id} in RQ: {rq_task.task_status}")
|
||||
rq_result = await self._get_task_from_rq_direct(task_id)
|
||||
|
||||
# Update memory registry
|
||||
self.tasks[task_id] = rq_task
|
||||
if isinstance(rq_result, Task):
|
||||
_log.info(f"Task {task_id} in RQ: {rq_result.task_status}")
|
||||
self.tasks[task_id] = rq_result
|
||||
await self._store_task_in_redis(rq_result)
|
||||
return rq_result
|
||||
|
||||
# Store/update in Redis for other instances
|
||||
await self._store_task_in_redis(rq_task)
|
||||
return rq_task
|
||||
job_is_gone = isinstance(rq_result, _RQJobGone)
|
||||
|
||||
# If not in RQ, check Redis (maybe it's cached from another instance)
|
||||
task = await self._get_task_from_redis(task_id)
|
||||
if task:
|
||||
_log.info(f"Task {task_id} in Redis: {task.task_status}")
|
||||
|
||||
# CRITICAL FIX: Check if Redis status might be stale
|
||||
# STARTED tasks might have completed since they were cached
|
||||
if job_is_gone:
|
||||
if task.is_completed():
|
||||
_log.info(
|
||||
f"Task {task_id} completed ({task.task_status}) "
|
||||
f"and RQ job expired — cleaning up tracking"
|
||||
)
|
||||
self.tasks.pop(task_id, None)
|
||||
self._task_result_keys.pop(task_id, None)
|
||||
return task
|
||||
else:
|
||||
_log.warning(
|
||||
f"Task {task_id} was {task.task_status} but RQ job is gone "
|
||||
f"— marking as FAILURE (orphaned)"
|
||||
)
|
||||
task.set_status(TaskStatus.FAILURE)
|
||||
if hasattr(task, "error_message"):
|
||||
task.error_message = (
|
||||
f"Task orphaned: RQ job expired while status was "
|
||||
f"{task.task_status}. Likely caused by worker restart or "
|
||||
f"Redis eviction."
|
||||
)
|
||||
self.tasks.pop(task_id, None)
|
||||
self._task_result_keys.pop(task_id, None)
|
||||
await self._store_task_in_redis(task)
|
||||
return task
|
||||
|
||||
if task.task_status in [TaskStatus.PENDING, TaskStatus.STARTED]:
|
||||
_log.debug(f"Task {task_id} verifying stale status")
|
||||
|
||||
# Try to get fresh status from RQ
|
||||
fresh_rq_task = await self._get_task_from_rq_direct(task_id)
|
||||
if fresh_rq_task and fresh_rq_task.task_status != task.task_status:
|
||||
if (
|
||||
isinstance(fresh_rq_task, Task)
|
||||
and fresh_rq_task.task_status != task.task_status
|
||||
):
|
||||
_log.info(
|
||||
f"Task {task_id} status updated: {fresh_rq_task.task_status}"
|
||||
)
|
||||
|
||||
# Update memory and Redis with fresh status
|
||||
self.tasks[task_id] = fresh_rq_task
|
||||
await self._store_task_in_redis(fresh_rq_task)
|
||||
return fresh_rq_task
|
||||
@@ -109,12 +142,14 @@ class RedisTaskStatusMixin:
|
||||
|
||||
return task
|
||||
|
||||
# Fall back to parent implementation
|
||||
if job_is_gone:
|
||||
_log.warning(f"Task {task_id} not in RQ or Redis — truly gone")
|
||||
self.tasks.pop(task_id, None)
|
||||
raise TaskNotFoundError(task_id)
|
||||
|
||||
try:
|
||||
parent_task = await super().task_status(task_id, wait) # type: ignore[misc]
|
||||
_log.debug(f"Task {task_id} from parent: {parent_task.task_status}")
|
||||
|
||||
# Store in Redis for other instances to find
|
||||
await self._store_task_in_redis(parent_task)
|
||||
return parent_task
|
||||
except TaskNotFoundError:
|
||||
@@ -135,18 +170,23 @@ class RedisTaskStatusMixin:
|
||||
meta.setdefault("num_succeeded", 0)
|
||||
meta.setdefault("num_failed", 0)
|
||||
|
||||
return Task(
|
||||
task_id=data["task_id"],
|
||||
task_type=data["task_type"],
|
||||
task_status=TaskStatus(data["task_status"]),
|
||||
processing_meta=meta,
|
||||
error_message=data.get("error_message"),
|
||||
)
|
||||
task_kwargs: dict[str, Any] = {
|
||||
"task_id": data["task_id"],
|
||||
"task_type": data["task_type"],
|
||||
"task_status": TaskStatus(data["task_status"]),
|
||||
"processing_meta": meta,
|
||||
}
|
||||
if data.get("error_message") and "error_message" in Task.model_fields:
|
||||
task_kwargs["error_message"] = data["error_message"]
|
||||
task = Task(**task_kwargs)
|
||||
return task
|
||||
except Exception as e:
|
||||
_log.error(f"Redis get task {task_id}: {e}")
|
||||
return None
|
||||
|
||||
async def _get_task_from_rq_direct(self, task_id: str) -> Task | None:
|
||||
async def _get_task_from_rq_direct(
|
||||
self, task_id: str
|
||||
) -> Union[Task, _RQJobGone, None]:
|
||||
try:
|
||||
_log.debug(f"Checking RQ for task {task_id}")
|
||||
|
||||
@@ -172,7 +212,7 @@ class RedisTaskStatusMixin:
|
||||
if updated_task and updated_task.task_status != TaskStatus.PENDING:
|
||||
_log.debug(f"RQ task {task_id}: {updated_task.task_status}")
|
||||
|
||||
# Store result key if available
|
||||
result_ttl = docling_serve_settings.eng_rq_results_ttl
|
||||
if task_id in self._task_result_keys:
|
||||
try:
|
||||
async with redis.Redis(
|
||||
@@ -181,7 +221,7 @@ class RedisTaskStatusMixin:
|
||||
await r.set(
|
||||
f"{self.redis_prefix}{task_id}:result_key",
|
||||
self._task_result_keys[task_id],
|
||||
ex=86400,
|
||||
ex=result_ttl,
|
||||
)
|
||||
_log.debug(f"Stored result key for {task_id}")
|
||||
except Exception as e:
|
||||
@@ -191,13 +231,14 @@ class RedisTaskStatusMixin:
|
||||
return None
|
||||
|
||||
finally:
|
||||
# Restore original task state
|
||||
if original_task:
|
||||
self.tasks[task_id] = original_task
|
||||
elif task_id in self.tasks and self.tasks[task_id] == temp_task:
|
||||
# Only remove if it's still our temp task
|
||||
del self.tasks[task_id]
|
||||
|
||||
except NoSuchJobError:
|
||||
_log.info(f"RQ job {task_id} no longer exists (TTL expired or deleted)")
|
||||
return _RQ_JOB_GONE
|
||||
except Exception as e:
|
||||
_log.error(f"RQ check {task_id}: {e}")
|
||||
return None
|
||||
@@ -242,11 +283,13 @@ class RedisTaskStatusMixin:
|
||||
"processing_meta": meta,
|
||||
"error_message": getattr(task, "error_message", None),
|
||||
}
|
||||
|
||||
metadata_ttl = docling_serve_settings.eng_rq_results_ttl
|
||||
async with redis.Redis(connection_pool=self._redis_pool) as r:
|
||||
await r.set(
|
||||
f"{self.redis_prefix}{task.task_id}:metadata",
|
||||
json.dumps(data),
|
||||
ex=86400,
|
||||
ex=metadata_ttl,
|
||||
)
|
||||
except Exception as e:
|
||||
_log.error(f"Store task {task.task_id}: {e}")
|
||||
@@ -286,16 +329,53 @@ class RedisTaskStatusMixin:
|
||||
await self._store_task_in_redis(self.tasks[task_id])
|
||||
|
||||
if task_id in self._task_result_keys:
|
||||
result_ttl = docling_serve_settings.eng_rq_results_ttl
|
||||
try:
|
||||
async with redis.Redis(connection_pool=self._redis_pool) as r:
|
||||
await r.set(
|
||||
f"{self.redis_prefix}{task_id}:result_key",
|
||||
self._task_result_keys[task_id],
|
||||
ex=86400,
|
||||
ex=result_ttl,
|
||||
)
|
||||
except Exception as e:
|
||||
_log.error(f"Store result key {task_id}: {e}")
|
||||
|
||||
async def _reap_zombie_tasks(
|
||||
self, interval: float = 300.0, max_age: float = 3600.0
|
||||
) -> None:
|
||||
"""
|
||||
Periodically remove completed tasks from in-memory tracking.
|
||||
|
||||
Args:
|
||||
interval: Seconds between sweeps (default 5 min)
|
||||
max_age: Remove completed tasks older than this (default 1h)
|
||||
"""
|
||||
while True:
|
||||
await asyncio.sleep(interval)
|
||||
try:
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
cutoff = now - datetime.timedelta(seconds=max_age)
|
||||
to_remove: list[str] = []
|
||||
|
||||
for task_id, task in list(self.tasks.items()):
|
||||
if (
|
||||
task.is_completed()
|
||||
and task.finished_at
|
||||
and task.finished_at < cutoff
|
||||
):
|
||||
to_remove.append(task_id)
|
||||
|
||||
for task_id in to_remove:
|
||||
self.tasks.pop(task_id, None)
|
||||
self._task_result_keys.pop(task_id, None)
|
||||
_log.debug(f"Reaped zombie task {task_id}")
|
||||
|
||||
if to_remove:
|
||||
_log.info(f"Reaped {len(to_remove)} zombie tasks from tracking")
|
||||
|
||||
except Exception as e:
|
||||
_log.error(f"Zombie reaper error: {e}")
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_async_orchestrator() -> BaseOrchestrator:
|
||||
|
||||
@@ -94,6 +94,8 @@ class DoclingServeSettings(BaseSettings):
|
||||
eng_rq_redis_socket_connect_timeout: Optional[float] = (
|
||||
None # Socket connect timeout in seconds
|
||||
)
|
||||
zombie_reaper_interval: float = 300.0
|
||||
zombie_reaper_max_age: float = 3600.0
|
||||
# KFP engine
|
||||
eng_kfp_endpoint: Optional[AnyUrl] = None
|
||||
eng_kfp_token: Optional[str] = None
|
||||
|
||||
428
tests/test_zombie_task_cleanup.py
Normal file
428
tests/test_zombie_task_cleanup.py
Normal file
@@ -0,0 +1,428 @@
|
||||
"""Tests for zombie task cleanup in RedisTaskStatusMixin.
|
||||
|
||||
Tests cover:
|
||||
- Layer A: _RQJobGone sentinel from _get_task_from_rq_direct when NoSuchJobError
|
||||
- Layer B: task_status() reconciliation for zombie scenarios
|
||||
- Layer C: Background zombie reaper
|
||||
- Layer E: TTL alignment (metadata TTL uses results_ttl)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from rq.exceptions import NoSuchJobError
|
||||
|
||||
from docling_jobkit.datamodel.task import Task
|
||||
from docling_jobkit.datamodel.task_meta import TaskStatus
|
||||
from docling_jobkit.orchestrators.base_orchestrator import TaskNotFoundError
|
||||
|
||||
from docling_serve.orchestrator_factory import (
|
||||
_RQ_JOB_GONE,
|
||||
RedisTaskStatusMixin,
|
||||
_RQJobGone,
|
||||
)
|
||||
|
||||
|
||||
def _make_task(
|
||||
task_id: str = "test-task-1",
|
||||
status: TaskStatus = TaskStatus.SUCCESS,
|
||||
error_message: str | None = None,
|
||||
finished_at: datetime.datetime | None = None,
|
||||
) -> Task:
|
||||
task = Task(
|
||||
task_id=task_id,
|
||||
task_type="convert",
|
||||
task_status=status,
|
||||
processing_meta={
|
||||
"num_docs": 0,
|
||||
"num_processed": 0,
|
||||
"num_succeeded": 0,
|
||||
"num_failed": 0,
|
||||
},
|
||||
)
|
||||
if error_message:
|
||||
task.error_message = error_message
|
||||
if finished_at:
|
||||
task.finished_at = finished_at
|
||||
return task
|
||||
|
||||
|
||||
class FakeParentOrchestrator:
|
||||
"""Minimal fake parent to satisfy MRO without real RQ/Redis."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
self.tasks: dict[str, Task] = {}
|
||||
self._task_result_keys: dict[str, str] = {}
|
||||
|
||||
async def _update_task_from_rq(self, task_id: str) -> None:
|
||||
raise NoSuchJobError(f"No such job: {task_id}")
|
||||
|
||||
async def task_status(self, task_id: str, wait: float = 0.0) -> Task:
|
||||
if task_id in self.tasks:
|
||||
return self.tasks[task_id]
|
||||
raise TaskNotFoundError(task_id)
|
||||
|
||||
|
||||
class FakeMixin(RedisTaskStatusMixin, FakeParentOrchestrator):
|
||||
"""Concrete class combining RedisTaskStatusMixin with a fake parent."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.tasks: dict[str, Task] = {}
|
||||
self._task_result_keys: dict[str, str] = {}
|
||||
self.redis_prefix = "docling:tasks:"
|
||||
self._redis_pool = MagicMock()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Layer A: Sentinel tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRQJobGoneSentinel:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_sentinel_on_no_such_job_error(self):
|
||||
mixin = FakeMixin()
|
||||
result = await mixin._get_task_from_rq_direct("missing-job")
|
||||
assert isinstance(result, _RQJobGone)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_generic_exception(self):
|
||||
mixin = FakeMixin()
|
||||
|
||||
async def raise_generic(self_inner, task_id: str) -> None:
|
||||
raise RuntimeError("Redis connection lost")
|
||||
|
||||
with patch.object(
|
||||
FakeParentOrchestrator, "_update_task_from_rq", raise_generic
|
||||
):
|
||||
result = await mixin._get_task_from_rq_direct("some-task")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_task_when_rq_has_job(self):
|
||||
mixin = FakeMixin()
|
||||
expected_task = _make_task("rq-task", TaskStatus.SUCCESS)
|
||||
|
||||
async def update_with_success(self_inner, task_id: str) -> None:
|
||||
mixin.tasks[task_id] = expected_task
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.__aenter__ = AsyncMock(return_value=mock_redis)
|
||||
mock_redis.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
FakeParentOrchestrator, "_update_task_from_rq", update_with_success
|
||||
),
|
||||
patch(
|
||||
"docling_serve.orchestrator_factory.redis.Redis",
|
||||
return_value=mock_redis,
|
||||
),
|
||||
):
|
||||
result = await mixin._get_task_from_rq_direct("rq-task")
|
||||
assert isinstance(result, Task)
|
||||
assert result.task_status == TaskStatus.SUCCESS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Layer B: task_status() reconciliation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTaskStatusReconciliation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_rq_gone_redis_success_cleans_up(self):
|
||||
"""RQ job gone + Redis has SUCCESS -> return task, clean up tracking."""
|
||||
mixin = FakeMixin()
|
||||
cached_task = _make_task("t1", TaskStatus.SUCCESS)
|
||||
mixin.tasks["t1"] = cached_task
|
||||
mixin._task_result_keys["t1"] = "some-key"
|
||||
|
||||
with (
|
||||
patch.object(mixin, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE),
|
||||
patch.object(mixin, "_get_task_from_redis", return_value=cached_task),
|
||||
):
|
||||
result = await mixin.task_status("t1")
|
||||
|
||||
assert result.task_status == TaskStatus.SUCCESS
|
||||
assert "t1" not in mixin.tasks
|
||||
assert "t1" not in mixin._task_result_keys
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rq_gone_redis_failure_cleans_up(self):
|
||||
"""RQ job gone + Redis has FAILURE -> return task, clean up tracking."""
|
||||
mixin = FakeMixin()
|
||||
cached_task = _make_task("t1", TaskStatus.FAILURE)
|
||||
mixin.tasks["t1"] = cached_task
|
||||
|
||||
with (
|
||||
patch.object(mixin, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE),
|
||||
patch.object(mixin, "_get_task_from_redis", return_value=cached_task),
|
||||
):
|
||||
result = await mixin.task_status("t1")
|
||||
|
||||
assert result.task_status == TaskStatus.FAILURE
|
||||
assert "t1" not in mixin.tasks
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rq_gone_redis_pending_marks_failure(self):
|
||||
"""RQ job gone + Redis has PENDING -> mark as FAILURE with error_message."""
|
||||
mixin = FakeMixin()
|
||||
cached_task = _make_task("t2", TaskStatus.PENDING)
|
||||
mixin.tasks["t2"] = cached_task
|
||||
|
||||
with (
|
||||
patch.object(mixin, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE),
|
||||
patch.object(mixin, "_get_task_from_redis", return_value=cached_task),
|
||||
patch.object(
|
||||
mixin, "_store_task_in_redis", new_callable=AsyncMock
|
||||
) as mock_store,
|
||||
):
|
||||
result = await mixin.task_status("t2")
|
||||
|
||||
assert result.task_status == TaskStatus.FAILURE
|
||||
assert "orphaned" in result.error_message.lower()
|
||||
assert "t2" not in mixin.tasks
|
||||
mock_store.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rq_gone_redis_started_marks_failure(self):
|
||||
"""RQ job gone + Redis has STARTED -> mark as FAILURE with error_message."""
|
||||
mixin = FakeMixin()
|
||||
cached_task = _make_task("t3", TaskStatus.STARTED)
|
||||
|
||||
with (
|
||||
patch.object(mixin, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE),
|
||||
patch.object(mixin, "_get_task_from_redis", return_value=cached_task),
|
||||
patch.object(mixin, "_store_task_in_redis", new_callable=AsyncMock),
|
||||
):
|
||||
result = await mixin.task_status("t3")
|
||||
|
||||
assert result.task_status == TaskStatus.FAILURE
|
||||
assert "orphaned" in result.error_message.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rq_gone_no_redis_raises_not_found(self):
|
||||
"""RQ job gone + not in Redis -> TaskNotFoundError."""
|
||||
mixin = FakeMixin()
|
||||
|
||||
with (
|
||||
patch.object(mixin, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE),
|
||||
patch.object(mixin, "_get_task_from_redis", return_value=None),
|
||||
):
|
||||
with pytest.raises(TaskNotFoundError):
|
||||
await mixin.task_status("ghost-task")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rq_transient_error_falls_through_to_redis(self):
|
||||
"""RQ returns None (transient) + Redis has SUCCESS -> return Redis task."""
|
||||
mixin = FakeMixin()
|
||||
cached_task = _make_task("t4", TaskStatus.SUCCESS)
|
||||
|
||||
with (
|
||||
patch.object(mixin, "_get_task_from_rq_direct", return_value=None),
|
||||
patch.object(mixin, "_get_task_from_redis", return_value=cached_task),
|
||||
):
|
||||
result = await mixin.task_status("t4")
|
||||
|
||||
assert result.task_status == TaskStatus.SUCCESS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rq_has_task_returns_directly(self):
|
||||
"""When RQ has the task, return it directly."""
|
||||
mixin = FakeMixin()
|
||||
rq_task = _make_task("t5", TaskStatus.SUCCESS)
|
||||
|
||||
with (
|
||||
patch.object(mixin, "_get_task_from_rq_direct", return_value=rq_task),
|
||||
patch.object(mixin, "_store_task_in_redis", new_callable=AsyncMock),
|
||||
):
|
||||
result = await mixin.task_status("t5")
|
||||
|
||||
assert result is rq_task
|
||||
assert mixin.tasks["t5"] is rq_task
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Layer C: Background zombie reaper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestZombieReaper:
|
||||
@pytest.mark.asyncio
|
||||
async def test_reaps_old_completed_tasks(self):
|
||||
mixin = FakeMixin()
|
||||
old_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
|
||||
hours=2
|
||||
)
|
||||
old_task = _make_task("old-1", TaskStatus.SUCCESS, finished_at=old_time)
|
||||
mixin.tasks["old-1"] = old_task
|
||||
mixin._task_result_keys["old-1"] = "key-1"
|
||||
|
||||
reaper = asyncio.create_task(
|
||||
mixin._reap_zombie_tasks(interval=0.01, max_age=3600.0)
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
reaper.cancel()
|
||||
try:
|
||||
await reaper
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
assert "old-1" not in mixin.tasks
|
||||
assert "old-1" not in mixin._task_result_keys
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keeps_recent_completed_tasks(self):
|
||||
mixin = FakeMixin()
|
||||
recent_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
|
||||
minutes=5
|
||||
)
|
||||
recent_task = _make_task(
|
||||
"recent-1", TaskStatus.SUCCESS, finished_at=recent_time
|
||||
)
|
||||
mixin.tasks["recent-1"] = recent_task
|
||||
|
||||
reaper = asyncio.create_task(
|
||||
mixin._reap_zombie_tasks(interval=0.01, max_age=3600.0)
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
reaper.cancel()
|
||||
try:
|
||||
await reaper
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
assert "recent-1" in mixin.tasks
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keeps_in_progress_tasks(self):
|
||||
mixin = FakeMixin()
|
||||
old_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
|
||||
hours=2
|
||||
)
|
||||
started_task = _make_task("started-1", TaskStatus.STARTED)
|
||||
started_task.started_at = old_time
|
||||
mixin.tasks["started-1"] = started_task
|
||||
|
||||
pending_task = _make_task("pending-1", TaskStatus.PENDING)
|
||||
mixin.tasks["pending-1"] = pending_task
|
||||
|
||||
reaper = asyncio.create_task(
|
||||
mixin._reap_zombie_tasks(interval=0.01, max_age=3600.0)
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
reaper.cancel()
|
||||
try:
|
||||
await reaper
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
assert "started-1" in mixin.tasks
|
||||
assert "pending-1" in mixin.tasks
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reaps_failed_tasks_with_finished_at(self):
|
||||
mixin = FakeMixin()
|
||||
old_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
|
||||
hours=2
|
||||
)
|
||||
failed_task = _make_task("failed-1", TaskStatus.FAILURE, finished_at=old_time)
|
||||
mixin.tasks["failed-1"] = failed_task
|
||||
|
||||
reaper = asyncio.create_task(
|
||||
mixin._reap_zombie_tasks(interval=0.01, max_age=3600.0)
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
reaper.cancel()
|
||||
try:
|
||||
await reaper
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
assert "failed-1" not in mixin.tasks
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Layer E: TTL alignment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTTLAlignment:
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_task_uses_results_ttl(self):
|
||||
"""Metadata TTL should match eng_rq_results_ttl, not 86400."""
|
||||
mixin = FakeMixin()
|
||||
task = _make_task("ttl-task", TaskStatus.SUCCESS)
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.__aenter__ = AsyncMock(return_value=mock_redis)
|
||||
mock_redis.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"docling_serve.orchestrator_factory.redis.Redis",
|
||||
return_value=mock_redis,
|
||||
),
|
||||
patch(
|
||||
"docling_serve.orchestrator_factory.docling_serve_settings"
|
||||
) as mock_settings,
|
||||
):
|
||||
mock_settings.eng_rq_results_ttl = 14400
|
||||
await mixin._store_task_in_redis(task)
|
||||
|
||||
mock_redis.set.assert_called_once()
|
||||
call_kwargs = mock_redis.set.call_args
|
||||
assert (
|
||||
call_kwargs.kwargs.get("ex") == 14400 or call_kwargs[1].get("ex") == 14400
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error message propagation through Redis
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestErrorMessagePropagation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_and_retrieve_error_message(self):
|
||||
"""error_message should round-trip through Redis store/get."""
|
||||
mixin = FakeMixin()
|
||||
task = _make_task("err-task", TaskStatus.FAILURE, error_message="Out of memory")
|
||||
|
||||
stored_data = {}
|
||||
|
||||
async def fake_set(key: str, value: Any, ex: int = 0) -> None:
|
||||
stored_data[key] = value
|
||||
|
||||
async def fake_get(key: str) -> bytes | None:
|
||||
val = stored_data.get(key)
|
||||
if val is None:
|
||||
return None
|
||||
return val.encode() if isinstance(val, str) else val
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.__aenter__ = AsyncMock(return_value=mock_redis)
|
||||
mock_redis.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_redis.set = fake_set
|
||||
mock_redis.get = fake_get
|
||||
|
||||
with (
|
||||
patch(
|
||||
"docling_serve.orchestrator_factory.redis.Redis",
|
||||
return_value=mock_redis,
|
||||
),
|
||||
patch(
|
||||
"docling_serve.orchestrator_factory.docling_serve_settings"
|
||||
) as mock_settings,
|
||||
):
|
||||
mock_settings.eng_rq_results_ttl = 14400
|
||||
await mixin._store_task_in_redis(task)
|
||||
retrieved = await mixin._get_task_from_redis("err-task")
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.error_message == "Out of memory"
|
||||
assert retrieved.task_status == TaskStatus.FAILURE
|
||||
Reference in New Issue
Block a user