fix: zombie task cleanup — reconcile stale RQ/Redis state (#516)

Signed-off-by: Pawel Rein <pawel.rein@prezi.com>
This commit is contained in:
Paweł Rein
2026-02-24 22:23:56 +01:00
committed by GitHub
parent e1d8ea9278
commit 853003cf3b
4 changed files with 564 additions and 38 deletions

View File

@@ -141,14 +141,30 @@ async def lifespan(app: FastAPI):
# Start the background queue processor
queue_task = asyncio.create_task(orchestrator.process_queue())
reaper_task = None
if hasattr(orchestrator, "_reap_zombie_tasks"):
reaper_task = asyncio.create_task(
orchestrator._reap_zombie_tasks(
interval=docling_serve_settings.zombie_reaper_interval,
max_age=docling_serve_settings.zombie_reaper_max_age,
)
)
yield
# Cancel the background queue processor on shutdown
queue_task.cancel()
if reaper_task:
reaper_task.cancel()
try:
await queue_task
except asyncio.CancelledError:
_log.info("Queue processor cancelled.")
if reaper_task:
try:
await reaper_task
except asyncio.CancelledError:
_log.info("Zombie reaper cancelled.")
# Remove scratch directory in case it was a tempfile
if docling_serve_settings.scratch_path is not None:

View File

@@ -1,9 +1,12 @@
import asyncio
import datetime
import json
import logging
from functools import lru_cache
from typing import Any
from typing import Any, Union
import redis.asyncio as redis
from rq.exceptions import NoSuchJobError
from docling_jobkit.datamodel.task import Task
from docling_jobkit.datamodel.task_meta import TaskStatus
@@ -18,6 +21,13 @@ from docling_serve.storage import get_scratch
_log = logging.getLogger(__name__)
class _RQJobGone:
    """Sentinel type: the RQ job has been deleted or TTL-expired.

    Only the module-level ``_RQ_JOB_GONE`` instance is meaningful; callers
    distinguish it with ``isinstance(result, _RQJobGone)`` rather than a
    truthiness or ``is None`` check.
    """

    def __repr__(self) -> str:
        # Without this, the sentinel shows up in logs as an opaque
        # "<... object at 0x...>"; a named repr aids debugging.
        return "<RQ job gone (deleted or TTL-expired)>"


_RQ_JOB_GONE = _RQJobGone()
class RedisTaskStatusMixin:
tasks: dict[str, Task]
_task_result_keys: dict[str, str]
@@ -64,43 +74,66 @@ class RedisTaskStatusMixin:
async def task_status(self, task_id: str, wait: float = 0.0) -> Task:
"""
Get task status by checking Redis first, then falling back to RQ verification.
Get task status with zombie task reconciliation.
When Redis shows 'pending' but RQ shows 'success', we update Redis
and return the RQ status for cross-instance consistency.
Checks RQ first (authoritative), then Redis cache, then in-memory.
When the RQ job is definitively gone (NoSuchJobError), reconciles:
- Terminal status in Redis -> return it, clean up tracking
- Non-terminal status in Redis -> mark FAILURE (orphaned task)
- Not in Redis at all -> raise TaskNotFoundError
"""
_log.info(f"Task {task_id} status check")
# Always check RQ directly first - this is the most reliable source
rq_task = await self._get_task_from_rq_direct(task_id)
if rq_task:
_log.info(f"Task {task_id} in RQ: {rq_task.task_status}")
rq_result = await self._get_task_from_rq_direct(task_id)
# Update memory registry
self.tasks[task_id] = rq_task
if isinstance(rq_result, Task):
_log.info(f"Task {task_id} in RQ: {rq_result.task_status}")
self.tasks[task_id] = rq_result
await self._store_task_in_redis(rq_result)
return rq_result
# Store/update in Redis for other instances
await self._store_task_in_redis(rq_task)
return rq_task
job_is_gone = isinstance(rq_result, _RQJobGone)
# If not in RQ, check Redis (maybe it's cached from another instance)
task = await self._get_task_from_redis(task_id)
if task:
_log.info(f"Task {task_id} in Redis: {task.task_status}")
# CRITICAL FIX: Check if Redis status might be stale
# STARTED tasks might have completed since they were cached
if job_is_gone:
if task.is_completed():
_log.info(
f"Task {task_id} completed ({task.task_status}) "
f"and RQ job expired — cleaning up tracking"
)
self.tasks.pop(task_id, None)
self._task_result_keys.pop(task_id, None)
return task
else:
_log.warning(
f"Task {task_id} was {task.task_status} but RQ job is gone "
f"— marking as FAILURE (orphaned)"
)
task.set_status(TaskStatus.FAILURE)
if hasattr(task, "error_message"):
task.error_message = (
f"Task orphaned: RQ job expired while status was "
f"{task.task_status}. Likely caused by worker restart or "
f"Redis eviction."
)
self.tasks.pop(task_id, None)
self._task_result_keys.pop(task_id, None)
await self._store_task_in_redis(task)
return task
if task.task_status in [TaskStatus.PENDING, TaskStatus.STARTED]:
_log.debug(f"Task {task_id} verifying stale status")
# Try to get fresh status from RQ
fresh_rq_task = await self._get_task_from_rq_direct(task_id)
if fresh_rq_task and fresh_rq_task.task_status != task.task_status:
if (
isinstance(fresh_rq_task, Task)
and fresh_rq_task.task_status != task.task_status
):
_log.info(
f"Task {task_id} status updated: {fresh_rq_task.task_status}"
)
# Update memory and Redis with fresh status
self.tasks[task_id] = fresh_rq_task
await self._store_task_in_redis(fresh_rq_task)
return fresh_rq_task
@@ -109,12 +142,14 @@ class RedisTaskStatusMixin:
return task
# Fall back to parent implementation
if job_is_gone:
_log.warning(f"Task {task_id} not in RQ or Redis — truly gone")
self.tasks.pop(task_id, None)
raise TaskNotFoundError(task_id)
try:
parent_task = await super().task_status(task_id, wait) # type: ignore[misc]
_log.debug(f"Task {task_id} from parent: {parent_task.task_status}")
# Store in Redis for other instances to find
await self._store_task_in_redis(parent_task)
return parent_task
except TaskNotFoundError:
@@ -135,18 +170,23 @@ class RedisTaskStatusMixin:
meta.setdefault("num_succeeded", 0)
meta.setdefault("num_failed", 0)
return Task(
task_id=data["task_id"],
task_type=data["task_type"],
task_status=TaskStatus(data["task_status"]),
processing_meta=meta,
error_message=data.get("error_message"),
)
task_kwargs: dict[str, Any] = {
"task_id": data["task_id"],
"task_type": data["task_type"],
"task_status": TaskStatus(data["task_status"]),
"processing_meta": meta,
}
if data.get("error_message") and "error_message" in Task.model_fields:
task_kwargs["error_message"] = data["error_message"]
task = Task(**task_kwargs)
return task
except Exception as e:
_log.error(f"Redis get task {task_id}: {e}")
return None
async def _get_task_from_rq_direct(self, task_id: str) -> Task | None:
async def _get_task_from_rq_direct(
self, task_id: str
) -> Union[Task, _RQJobGone, None]:
try:
_log.debug(f"Checking RQ for task {task_id}")
@@ -172,7 +212,7 @@ class RedisTaskStatusMixin:
if updated_task and updated_task.task_status != TaskStatus.PENDING:
_log.debug(f"RQ task {task_id}: {updated_task.task_status}")
# Store result key if available
result_ttl = docling_serve_settings.eng_rq_results_ttl
if task_id in self._task_result_keys:
try:
async with redis.Redis(
@@ -181,7 +221,7 @@ class RedisTaskStatusMixin:
await r.set(
f"{self.redis_prefix}{task_id}:result_key",
self._task_result_keys[task_id],
ex=86400,
ex=result_ttl,
)
_log.debug(f"Stored result key for {task_id}")
except Exception as e:
@@ -191,13 +231,14 @@ class RedisTaskStatusMixin:
return None
finally:
# Restore original task state
if original_task:
self.tasks[task_id] = original_task
elif task_id in self.tasks and self.tasks[task_id] == temp_task:
# Only remove if it's still our temp task
del self.tasks[task_id]
except NoSuchJobError:
_log.info(f"RQ job {task_id} no longer exists (TTL expired or deleted)")
return _RQ_JOB_GONE
except Exception as e:
_log.error(f"RQ check {task_id}: {e}")
return None
@@ -242,11 +283,13 @@ class RedisTaskStatusMixin:
"processing_meta": meta,
"error_message": getattr(task, "error_message", None),
}
metadata_ttl = docling_serve_settings.eng_rq_results_ttl
async with redis.Redis(connection_pool=self._redis_pool) as r:
await r.set(
f"{self.redis_prefix}{task.task_id}:metadata",
json.dumps(data),
ex=86400,
ex=metadata_ttl,
)
except Exception as e:
_log.error(f"Store task {task.task_id}: {e}")
@@ -286,16 +329,53 @@ class RedisTaskStatusMixin:
await self._store_task_in_redis(self.tasks[task_id])
if task_id in self._task_result_keys:
result_ttl = docling_serve_settings.eng_rq_results_ttl
try:
async with redis.Redis(connection_pool=self._redis_pool) as r:
await r.set(
f"{self.redis_prefix}{task_id}:result_key",
self._task_result_keys[task_id],
ex=86400,
ex=result_ttl,
)
except Exception as e:
_log.error(f"Store result key {task_id}: {e}")
async def _reap_zombie_tasks(
    self, interval: float = 300.0, max_age: float = 3600.0
) -> None:
    """
    Background loop: drop completed tasks from in-memory tracking once
    they have been finished for longer than ``max_age`` seconds.

    Args:
        interval: Seconds to wait between sweeps (default 5 min).
        max_age: Completed tasks older than this are removed (default 1h).
    """
    while True:
        await asyncio.sleep(interval)
        try:
            cutoff = datetime.datetime.now(
                datetime.timezone.utc
            ) - datetime.timedelta(seconds=max_age)
            # Snapshot the items so the dict can be mutated safely below.
            expired = [
                tid
                for tid, entry in list(self.tasks.items())
                if entry.is_completed()
                and entry.finished_at
                and entry.finished_at < cutoff
            ]
            for task_id in expired:
                self.tasks.pop(task_id, None)
                self._task_result_keys.pop(task_id, None)
                _log.debug(f"Reaped zombie task {task_id}")
            if expired:
                _log.info(f"Reaped {len(expired)} zombie tasks from tracking")
        except Exception as e:
            # Never let a single failed sweep kill the background loop;
            # cancellation (CancelledError) still propagates normally.
            _log.error(f"Zombie reaper error: {e}")
@lru_cache
def get_async_orchestrator() -> BaseOrchestrator:

View File

@@ -94,6 +94,8 @@ class DoclingServeSettings(BaseSettings):
eng_rq_redis_socket_connect_timeout: Optional[float] = (
None # Socket connect timeout in seconds
)
zombie_reaper_interval: float = 300.0
zombie_reaper_max_age: float = 3600.0
# KFP engine
eng_kfp_endpoint: Optional[AnyUrl] = None
eng_kfp_token: Optional[str] = None

View File

@@ -0,0 +1,428 @@
"""Tests for zombie task cleanup in RedisTaskStatusMixin.
Tests cover:
- Layer A: _RQJobGone sentinel from _get_task_from_rq_direct when NoSuchJobError
- Layer B: task_status() reconciliation for zombie scenarios
- Layer C: Background zombie reaper
- Layer E: TTL alignment (metadata TTL uses results_ttl)
"""
import asyncio
import datetime
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from rq.exceptions import NoSuchJobError
from docling_jobkit.datamodel.task import Task
from docling_jobkit.datamodel.task_meta import TaskStatus
from docling_jobkit.orchestrators.base_orchestrator import TaskNotFoundError
from docling_serve.orchestrator_factory import (
_RQ_JOB_GONE,
RedisTaskStatusMixin,
_RQJobGone,
)
def _make_task(
    task_id: str = "test-task-1",
    status: TaskStatus = TaskStatus.SUCCESS,
    error_message: str | None = None,
    finished_at: datetime.datetime | None = None,
) -> Task:
    """Build a Task with zeroed processing meta for use in tests."""
    meta = {
        "num_docs": 0,
        "num_processed": 0,
        "num_succeeded": 0,
        "num_failed": 0,
    }
    task = Task(
        task_id=task_id,
        task_type="convert",
        task_status=status,
        processing_meta=meta,
    )
    # Optional fields are assigned after construction, mirroring how the
    # orchestrator mutates tasks at runtime.
    if error_message:
        task.error_message = error_message
    if finished_at:
        task.finished_at = finished_at
    return task
class FakeParentOrchestrator:
    """Minimal fake parent to satisfy MRO without real RQ/Redis."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Only the two registries the mixin reads/writes are needed.
        self.tasks: dict[str, Task] = {}
        self._task_result_keys: dict[str, str] = {}

    async def _update_task_from_rq(self, task_id: str) -> None:
        # Default behavior: pretend the job vanished from RQ entirely.
        raise NoSuchJobError(f"No such job: {task_id}")

    async def task_status(self, task_id: str, wait: float = 0.0) -> Task:
        task = self.tasks.get(task_id)
        if task is None:
            raise TaskNotFoundError(task_id)
        return task
class FakeMixin(RedisTaskStatusMixin, FakeParentOrchestrator):
    """Concrete class combining RedisTaskStatusMixin with a fake parent."""

    def __init__(self) -> None:
        # Deliberately does NOT call super().__init__(): only the
        # attributes the mixin touches are wired up, with a mock pool
        # standing in for the real Redis connection pool.
        self.redis_prefix = "docling:tasks:"
        self._redis_pool = MagicMock()
        self.tasks: dict[str, Task] = {}
        self._task_result_keys: dict[str, str] = {}
# ---------------------------------------------------------------------------
# Layer A: Sentinel tests
# ---------------------------------------------------------------------------
class TestRQJobGoneSentinel:
    """Layer A: `_get_task_from_rq_direct` returns the sentinel only for
    NoSuchJobError, None for transient errors, and a Task on success."""

    @pytest.mark.asyncio
    async def test_returns_sentinel_on_no_such_job_error(self):
        fake = FakeMixin()

        # FakeParentOrchestrator raises NoSuchJobError by default.
        outcome = await fake._get_task_from_rq_direct("missing-job")

        assert isinstance(outcome, _RQJobGone)

    @pytest.mark.asyncio
    async def test_returns_none_on_generic_exception(self):
        fake = FakeMixin()

        async def boom(self_inner, task_id: str) -> None:
            raise RuntimeError("Redis connection lost")

        with patch.object(FakeParentOrchestrator, "_update_task_from_rq", boom):
            outcome = await fake._get_task_from_rq_direct("some-task")

        # Transient failures must NOT look like "job gone".
        assert outcome is None

    @pytest.mark.asyncio
    async def test_returns_task_when_rq_has_job(self):
        fake = FakeMixin()
        expected = _make_task("rq-task", TaskStatus.SUCCESS)

        async def fake_update(self_inner, task_id: str) -> None:
            fake.tasks[task_id] = expected

        redis_cm = AsyncMock()
        redis_cm.__aenter__ = AsyncMock(return_value=redis_cm)
        redis_cm.__aexit__ = AsyncMock(return_value=False)

        with (
            patch.object(
                FakeParentOrchestrator, "_update_task_from_rq", fake_update
            ),
            patch(
                "docling_serve.orchestrator_factory.redis.Redis",
                return_value=redis_cm,
            ),
        ):
            outcome = await fake._get_task_from_rq_direct("rq-task")

        assert isinstance(outcome, Task)
        assert outcome.task_status == TaskStatus.SUCCESS
# ---------------------------------------------------------------------------
# Layer B: task_status() reconciliation
# ---------------------------------------------------------------------------
class TestTaskStatusReconciliation:
    """Layer B: task_status() reconciles Redis state when the RQ job is gone."""

    @pytest.mark.asyncio
    async def test_rq_gone_redis_success_cleans_up(self):
        """RQ job gone + Redis has SUCCESS -> return task, clean up tracking."""
        fake = FakeMixin()
        stored = _make_task("t1", TaskStatus.SUCCESS)
        fake.tasks["t1"] = stored
        fake._task_result_keys["t1"] = "some-key"

        with (
            patch.object(
                fake, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE
            ),
            patch.object(fake, "_get_task_from_redis", return_value=stored),
        ):
            outcome = await fake.task_status("t1")

        assert outcome.task_status == TaskStatus.SUCCESS
        assert "t1" not in fake.tasks
        assert "t1" not in fake._task_result_keys

    @pytest.mark.asyncio
    async def test_rq_gone_redis_failure_cleans_up(self):
        """RQ job gone + Redis has FAILURE -> return task, clean up tracking."""
        fake = FakeMixin()
        stored = _make_task("t1", TaskStatus.FAILURE)
        fake.tasks["t1"] = stored

        with (
            patch.object(
                fake, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE
            ),
            patch.object(fake, "_get_task_from_redis", return_value=stored),
        ):
            outcome = await fake.task_status("t1")

        assert outcome.task_status == TaskStatus.FAILURE
        assert "t1" not in fake.tasks

    @pytest.mark.asyncio
    async def test_rq_gone_redis_pending_marks_failure(self):
        """RQ job gone + Redis has PENDING -> mark as FAILURE with error_message."""
        fake = FakeMixin()
        stored = _make_task("t2", TaskStatus.PENDING)
        fake.tasks["t2"] = stored

        store_spy = AsyncMock()
        with (
            patch.object(
                fake, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE
            ),
            patch.object(fake, "_get_task_from_redis", return_value=stored),
            patch.object(fake, "_store_task_in_redis", store_spy),
        ):
            outcome = await fake.task_status("t2")

        assert outcome.task_status == TaskStatus.FAILURE
        assert "orphaned" in outcome.error_message.lower()
        assert "t2" not in fake.tasks
        store_spy.assert_called_once()

    @pytest.mark.asyncio
    async def test_rq_gone_redis_started_marks_failure(self):
        """RQ job gone + Redis has STARTED -> mark as FAILURE with error_message."""
        fake = FakeMixin()
        stored = _make_task("t3", TaskStatus.STARTED)

        with (
            patch.object(
                fake, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE
            ),
            patch.object(fake, "_get_task_from_redis", return_value=stored),
            patch.object(fake, "_store_task_in_redis", new_callable=AsyncMock),
        ):
            outcome = await fake.task_status("t3")

        assert outcome.task_status == TaskStatus.FAILURE
        assert "orphaned" in outcome.error_message.lower()

    @pytest.mark.asyncio
    async def test_rq_gone_no_redis_raises_not_found(self):
        """RQ job gone + not in Redis -> TaskNotFoundError."""
        fake = FakeMixin()

        with (
            patch.object(
                fake, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE
            ),
            patch.object(fake, "_get_task_from_redis", return_value=None),
        ):
            with pytest.raises(TaskNotFoundError):
                await fake.task_status("ghost-task")

    @pytest.mark.asyncio
    async def test_rq_transient_error_falls_through_to_redis(self):
        """RQ returns None (transient) + Redis has SUCCESS -> return Redis task."""
        fake = FakeMixin()
        stored = _make_task("t4", TaskStatus.SUCCESS)

        with (
            patch.object(fake, "_get_task_from_rq_direct", return_value=None),
            patch.object(fake, "_get_task_from_redis", return_value=stored),
        ):
            outcome = await fake.task_status("t4")

        assert outcome.task_status == TaskStatus.SUCCESS

    @pytest.mark.asyncio
    async def test_rq_has_task_returns_directly(self):
        """When RQ has the task, return it directly."""
        fake = FakeMixin()
        live = _make_task("t5", TaskStatus.SUCCESS)

        with (
            patch.object(fake, "_get_task_from_rq_direct", return_value=live),
            patch.object(fake, "_store_task_in_redis", new_callable=AsyncMock),
        ):
            outcome = await fake.task_status("t5")

        assert outcome is live
        assert fake.tasks["t5"] is live
# ---------------------------------------------------------------------------
# Layer C: Background zombie reaper
# ---------------------------------------------------------------------------
class TestZombieReaper:
    """Layer C: the background reaper removes only old, completed tasks."""

    async def _run_reaper_briefly(self, fake: "FakeMixin") -> None:
        # Let a few 10ms sweeps happen, then cancel the loop cleanly.
        job = asyncio.create_task(
            fake._reap_zombie_tasks(interval=0.01, max_age=3600.0)
        )
        await asyncio.sleep(0.05)
        job.cancel()
        try:
            await job
        except asyncio.CancelledError:
            pass

    @pytest.mark.asyncio
    async def test_reaps_old_completed_tasks(self):
        fake = FakeMixin()
        two_hours_ago = datetime.datetime.now(
            datetime.timezone.utc
        ) - datetime.timedelta(hours=2)
        fake.tasks["old-1"] = _make_task(
            "old-1", TaskStatus.SUCCESS, finished_at=two_hours_ago
        )
        fake._task_result_keys["old-1"] = "key-1"

        await self._run_reaper_briefly(fake)

        assert "old-1" not in fake.tasks
        assert "old-1" not in fake._task_result_keys

    @pytest.mark.asyncio
    async def test_keeps_recent_completed_tasks(self):
        fake = FakeMixin()
        five_min_ago = datetime.datetime.now(
            datetime.timezone.utc
        ) - datetime.timedelta(minutes=5)
        fake.tasks["recent-1"] = _make_task(
            "recent-1", TaskStatus.SUCCESS, finished_at=five_min_ago
        )

        await self._run_reaper_briefly(fake)

        assert "recent-1" in fake.tasks

    @pytest.mark.asyncio
    async def test_keeps_in_progress_tasks(self):
        fake = FakeMixin()
        two_hours_ago = datetime.datetime.now(
            datetime.timezone.utc
        ) - datetime.timedelta(hours=2)
        running = _make_task("started-1", TaskStatus.STARTED)
        running.started_at = two_hours_ago
        fake.tasks["started-1"] = running
        fake.tasks["pending-1"] = _make_task("pending-1", TaskStatus.PENDING)

        await self._run_reaper_briefly(fake)

        # In-progress tasks are never reaped, however old they are.
        assert "started-1" in fake.tasks
        assert "pending-1" in fake.tasks

    @pytest.mark.asyncio
    async def test_reaps_failed_tasks_with_finished_at(self):
        fake = FakeMixin()
        two_hours_ago = datetime.datetime.now(
            datetime.timezone.utc
        ) - datetime.timedelta(hours=2)
        fake.tasks["failed-1"] = _make_task(
            "failed-1", TaskStatus.FAILURE, finished_at=two_hours_ago
        )

        await self._run_reaper_briefly(fake)

        assert "failed-1" not in fake.tasks
# ---------------------------------------------------------------------------
# Layer E: TTL alignment
# ---------------------------------------------------------------------------
class TestTTLAlignment:
    """Layer E: metadata TTL follows the configured results TTL."""

    @pytest.mark.asyncio
    async def test_store_task_uses_results_ttl(self):
        """Metadata TTL should match eng_rq_results_ttl, not 86400."""
        fake = FakeMixin()
        payload = _make_task("ttl-task", TaskStatus.SUCCESS)

        redis_stub = AsyncMock()
        redis_stub.__aenter__ = AsyncMock(return_value=redis_stub)
        redis_stub.__aexit__ = AsyncMock(return_value=False)

        with (
            patch(
                "docling_serve.orchestrator_factory.redis.Redis",
                return_value=redis_stub,
            ),
            patch(
                "docling_serve.orchestrator_factory.docling_serve_settings"
            ) as settings_stub,
        ):
            settings_stub.eng_rq_results_ttl = 14400
            await fake._store_task_in_redis(payload)

        redis_stub.set.assert_called_once()
        # call_args.kwargs is the same mapping as call_args[1].
        assert redis_stub.set.call_args.kwargs.get("ex") == 14400
# ---------------------------------------------------------------------------
# Error message propagation through Redis
# ---------------------------------------------------------------------------
class TestErrorMessagePropagation:
@pytest.mark.asyncio
async def test_store_and_retrieve_error_message(self):
"""error_message should round-trip through Redis store/get."""
mixin = FakeMixin()
task = _make_task("err-task", TaskStatus.FAILURE, error_message="Out of memory")
stored_data = {}
async def fake_set(key: str, value: Any, ex: int = 0) -> None:
stored_data[key] = value
async def fake_get(key: str) -> bytes | None:
val = stored_data.get(key)
if val is None:
return None
return val.encode() if isinstance(val, str) else val
mock_redis = AsyncMock()
mock_redis.__aenter__ = AsyncMock(return_value=mock_redis)
mock_redis.__aexit__ = AsyncMock(return_value=False)
mock_redis.set = fake_set
mock_redis.get = fake_get
with (
patch(
"docling_serve.orchestrator_factory.redis.Redis",
return_value=mock_redis,
),
patch(
"docling_serve.orchestrator_factory.docling_serve_settings"
) as mock_settings,
):
mock_settings.eng_rq_results_ttl = 14400
await mixin._store_task_in_redis(task)
retrieved = await mixin._get_task_from_redis("err-task")
assert retrieved is not None
assert retrieved.error_message == "Out of memory"
assert retrieved.task_status == TaskStatus.FAILURE