mirror of
https://github.com/docling-project/docling-serve.git
synced 2026-03-07 22:33:44 +00:00
429 lines
15 KiB
Python
429 lines
15 KiB
Python
"""Tests for zombie task cleanup in RedisTaskStatusMixin.

Tests cover:
- Layer A: _RQJobGone sentinel returned by _get_task_from_rq_direct on NoSuchJobError
- Layer B: task_status() reconciliation for zombie scenarios
- Layer C: Background zombie reaper
- Layer E: TTL alignment (metadata TTL uses results_ttl)
- Error message propagation through Redis
"""
|
|
|
|
import asyncio
|
|
import datetime
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from rq.exceptions import NoSuchJobError
|
|
|
|
from docling_jobkit.datamodel.task import Task
|
|
from docling_jobkit.datamodel.task_meta import TaskStatus
|
|
from docling_jobkit.orchestrators.base_orchestrator import TaskNotFoundError
|
|
|
|
from docling_serve.orchestrator_factory import (
|
|
_RQ_JOB_GONE,
|
|
RedisTaskStatusMixin,
|
|
_RQJobGone,
|
|
)
|
|
|
|
|
|
def _make_task(
    task_id: str = "test-task-1",
    status: TaskStatus = TaskStatus.SUCCESS,
    error_message: str | None = None,
    finished_at: datetime.datetime | None = None,
) -> Task:
    """Build a ``convert`` Task with zeroed processing counters for tests.

    Args:
        task_id: Identifier given to the task.
        status: Value stored in ``task_status``.
        error_message: Assigned to ``task.error_message`` when not ``None``.
        finished_at: Assigned to ``task.finished_at`` when not ``None``.

    Returns:
        The newly constructed Task.
    """
    task = Task(
        task_id=task_id,
        task_type="convert",
        task_status=status,
        processing_meta={
            "num_docs": 0,
            "num_processed": 0,
            "num_succeeded": 0,
            "num_failed": 0,
        },
    )
    # Compare against None instead of relying on truthiness so that an
    # explicitly supplied (even empty) value is never silently dropped.
    if error_message is not None:
        task.error_message = error_message
    if finished_at is not None:
        task.finished_at = finished_at
    return task
|
|
|
|
|
|
class FakeParentOrchestrator:
    """Stand-in base class so the mixin's MRO resolves without real RQ/Redis."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Any constructor arguments are accepted and ignored.
        self._task_result_keys: dict[str, str] = {}
        self.tasks: dict[str, Task] = {}

    async def _update_task_from_rq(self, task_id: str) -> None:
        # By default every job id is reported as missing.
        message = f"No such job: {task_id}"
        raise NoSuchJobError(message)

    async def task_status(self, task_id: str, wait: float = 0.0) -> Task:
        # ``wait`` is part of the signature but is not used by this fake.
        if task_id not in self.tasks:
            raise TaskNotFoundError(task_id)
        return self.tasks[task_id]
|
|
|
|
|
|
class FakeMixin(RedisTaskStatusMixin, FakeParentOrchestrator):
    """Test double: RedisTaskStatusMixin layered over FakeParentOrchestrator."""

    def __init__(self) -> None:
        # No base-class __init__ is invoked; the attributes are set directly
        # here, with a MagicMock standing in for the Redis connection pool.
        self._redis_pool = MagicMock()
        self.redis_prefix = "docling:tasks:"
        self._task_result_keys: dict[str, str] = {}
        self.tasks: dict[str, Task] = {}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Layer A: Sentinel tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRQJobGoneSentinel:
    """Return values of ``_get_task_from_rq_direct`` for each RQ outcome."""

    @pytest.mark.asyncio
    async def test_returns_sentinel_on_no_such_job_error(self):
        # The unpatched fake parent raises NoSuchJobError for every job id.
        outcome = await FakeMixin()._get_task_from_rq_direct("missing-job")
        assert isinstance(outcome, _RQJobGone)

    @pytest.mark.asyncio
    async def test_returns_none_on_generic_exception(self):
        mixin = FakeMixin()

        async def fail_with_runtime_error(_parent, task_id: str) -> None:
            raise RuntimeError("Redis connection lost")

        parent_patch = patch.object(
            FakeParentOrchestrator, "_update_task_from_rq", fail_with_runtime_error
        )
        with parent_patch:
            outcome = await mixin._get_task_from_rq_direct("some-task")

        assert outcome is None

    @pytest.mark.asyncio
    async def test_returns_task_when_rq_has_job(self):
        mixin = FakeMixin()
        rq_task = _make_task("rq-task", TaskStatus.SUCCESS)

        async def record_task(_parent, task_id: str) -> None:
            # Simulate the parent refreshing its in-memory task from RQ.
            mixin.tasks[task_id] = rq_task

        # AsyncMock that can also be used as an async context manager.
        fake_redis = AsyncMock()
        fake_redis.__aenter__ = AsyncMock(return_value=fake_redis)
        fake_redis.__aexit__ = AsyncMock(return_value=False)

        with (
            patch.object(
                FakeParentOrchestrator, "_update_task_from_rq", record_task
            ),
            patch(
                "docling_serve.orchestrator_factory.redis.Redis",
                return_value=fake_redis,
            ),
        ):
            outcome = await mixin._get_task_from_rq_direct("rq-task")

        assert isinstance(outcome, Task)
        assert outcome.task_status == TaskStatus.SUCCESS
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Layer B: task_status() reconciliation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestTaskStatusReconciliation:
    """``task_status()`` outcomes for combinations of RQ and Redis state.

    Each test patches ``_get_task_from_rq_direct`` and ``_get_task_from_redis``
    on a ``FakeMixin`` instance to simulate one scenario, then checks the
    returned task and the mixin's in-memory tracking dictionaries.
    """

    @pytest.mark.asyncio
    async def test_rq_gone_redis_success_cleans_up(self):
        """RQ job gone + Redis has SUCCESS -> return task, clean up tracking."""
        mixin = FakeMixin()
        cached_task = _make_task("t1", TaskStatus.SUCCESS)
        mixin.tasks["t1"] = cached_task
        mixin._task_result_keys["t1"] = "some-key"

        with (
            patch.object(mixin, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE),
            patch.object(mixin, "_get_task_from_redis", return_value=cached_task),
        ):
            result = await mixin.task_status("t1")

        assert result.task_status == TaskStatus.SUCCESS
        assert "t1" not in mixin.tasks
        assert "t1" not in mixin._task_result_keys

    @pytest.mark.asyncio
    async def test_rq_gone_redis_failure_cleans_up(self):
        """RQ job gone + Redis has FAILURE -> return task, clean up tracking."""
        mixin = FakeMixin()
        cached_task = _make_task("t1", TaskStatus.FAILURE)
        mixin.tasks["t1"] = cached_task

        with (
            patch.object(mixin, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE),
            patch.object(mixin, "_get_task_from_redis", return_value=cached_task),
        ):
            result = await mixin.task_status("t1")

        assert result.task_status == TaskStatus.FAILURE
        assert "t1" not in mixin.tasks

    @pytest.mark.asyncio
    async def test_rq_gone_redis_pending_marks_failure(self):
        """RQ job gone + Redis has PENDING -> mark as FAILURE with error_message."""
        mixin = FakeMixin()
        cached_task = _make_task("t2", TaskStatus.PENDING)
        mixin.tasks["t2"] = cached_task

        with (
            patch.object(mixin, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE),
            patch.object(mixin, "_get_task_from_redis", return_value=cached_task),
            patch.object(
                mixin, "_store_task_in_redis", new_callable=AsyncMock
            ) as mock_store,
        ):
            result = await mixin.task_status("t2")

        assert result.task_status == TaskStatus.FAILURE
        # error_message is optional; check it is set first so a missing
        # message fails as an assertion rather than an AttributeError.
        assert result.error_message is not None
        assert "orphaned" in result.error_message.lower()
        assert "t2" not in mixin.tasks
        mock_store.assert_called_once()

    @pytest.mark.asyncio
    async def test_rq_gone_redis_started_marks_failure(self):
        """RQ job gone + Redis has STARTED -> mark as FAILURE with error_message."""
        mixin = FakeMixin()
        cached_task = _make_task("t3", TaskStatus.STARTED)

        with (
            patch.object(mixin, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE),
            patch.object(mixin, "_get_task_from_redis", return_value=cached_task),
            patch.object(mixin, "_store_task_in_redis", new_callable=AsyncMock),
        ):
            result = await mixin.task_status("t3")

        assert result.task_status == TaskStatus.FAILURE
        # Same guard as above: fail clearly if no error message was recorded.
        assert result.error_message is not None
        assert "orphaned" in result.error_message.lower()

    @pytest.mark.asyncio
    async def test_rq_gone_no_redis_raises_not_found(self):
        """RQ job gone + not in Redis -> TaskNotFoundError."""
        mixin = FakeMixin()

        with (
            patch.object(mixin, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE),
            patch.object(mixin, "_get_task_from_redis", return_value=None),
        ):
            with pytest.raises(TaskNotFoundError):
                await mixin.task_status("ghost-task")

    @pytest.mark.asyncio
    async def test_rq_transient_error_falls_through_to_redis(self):
        """RQ returns None (transient) + Redis has SUCCESS -> return Redis task."""
        mixin = FakeMixin()
        cached_task = _make_task("t4", TaskStatus.SUCCESS)

        with (
            patch.object(mixin, "_get_task_from_rq_direct", return_value=None),
            patch.object(mixin, "_get_task_from_redis", return_value=cached_task),
        ):
            result = await mixin.task_status("t4")

        assert result.task_status == TaskStatus.SUCCESS

    @pytest.mark.asyncio
    async def test_rq_has_task_returns_directly(self):
        """When RQ has the task, return it directly."""
        mixin = FakeMixin()
        rq_task = _make_task("t5", TaskStatus.SUCCESS)

        with (
            patch.object(mixin, "_get_task_from_rq_direct", return_value=rq_task),
            patch.object(mixin, "_store_task_in_redis", new_callable=AsyncMock),
        ):
            result = await mixin.task_status("t5")

        assert result is rq_task
        assert mixin.tasks["t5"] is rq_task
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Layer C: Background zombie reaper
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestZombieReaper:
    """Checks which in-memory tasks ``_reap_zombie_tasks`` removes or keeps.

    Every test runs the reaper with ``interval=0.01`` and ``max_age=3600.0``
    for a short time and then inspects ``mixin.tasks``.
    """

    @staticmethod
    async def _run_reaper_briefly(mixin: "FakeMixin", run_for: float = 0.05) -> None:
        """Run the reaper in the background for ``run_for`` seconds, then stop it.

        The reaper coroutine is scheduled as an asyncio task, cancelled after
        the sleep, and awaited so the cancellation completes before the caller
        makes its assertions. Only ``CancelledError`` is swallowed; any other
        exception raised by the reaper propagates and fails the test.
        """
        reaper = asyncio.create_task(
            mixin._reap_zombie_tasks(interval=0.01, max_age=3600.0)
        )
        await asyncio.sleep(run_for)
        reaper.cancel()
        try:
            await reaper
        except asyncio.CancelledError:
            pass

    @pytest.mark.asyncio
    async def test_reaps_old_completed_tasks(self):
        """A SUCCESS task finished two hours ago is dropped from tracking."""
        mixin = FakeMixin()
        old_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
            hours=2
        )
        old_task = _make_task("old-1", TaskStatus.SUCCESS, finished_at=old_time)
        mixin.tasks["old-1"] = old_task
        mixin._task_result_keys["old-1"] = "key-1"

        await self._run_reaper_briefly(mixin)

        assert "old-1" not in mixin.tasks
        assert "old-1" not in mixin._task_result_keys

    @pytest.mark.asyncio
    async def test_keeps_recent_completed_tasks(self):
        """A SUCCESS task finished five minutes ago stays tracked."""
        mixin = FakeMixin()
        recent_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
            minutes=5
        )
        recent_task = _make_task(
            "recent-1", TaskStatus.SUCCESS, finished_at=recent_time
        )
        mixin.tasks["recent-1"] = recent_task

        await self._run_reaper_briefly(mixin)

        assert "recent-1" in mixin.tasks

    @pytest.mark.asyncio
    async def test_keeps_in_progress_tasks(self):
        """STARTED and PENDING tasks stay tracked, even with an old start time."""
        mixin = FakeMixin()
        old_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
            hours=2
        )
        started_task = _make_task("started-1", TaskStatus.STARTED)
        started_task.started_at = old_time
        mixin.tasks["started-1"] = started_task

        pending_task = _make_task("pending-1", TaskStatus.PENDING)
        mixin.tasks["pending-1"] = pending_task

        await self._run_reaper_briefly(mixin)

        assert "started-1" in mixin.tasks
        assert "pending-1" in mixin.tasks

    @pytest.mark.asyncio
    async def test_reaps_failed_tasks_with_finished_at(self):
        """A FAILURE task finished two hours ago is dropped from tracking."""
        mixin = FakeMixin()
        old_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
            hours=2
        )
        failed_task = _make_task("failed-1", TaskStatus.FAILURE, finished_at=old_time)
        mixin.tasks["failed-1"] = failed_task

        await self._run_reaper_briefly(mixin)

        assert "failed-1" not in mixin.tasks
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Layer E: TTL alignment
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestTTLAlignment:
    """Checks the expiry passed to Redis when task metadata is stored."""

    @pytest.mark.asyncio
    async def test_store_task_uses_results_ttl(self):
        """Metadata TTL should match eng_rq_results_ttl, not 86400."""
        mixin = FakeMixin()
        task = _make_task("ttl-task", TaskStatus.SUCCESS)

        # AsyncMock that can also be used as an async context manager.
        mock_redis = AsyncMock()
        mock_redis.__aenter__ = AsyncMock(return_value=mock_redis)
        mock_redis.__aexit__ = AsyncMock(return_value=False)

        with (
            patch(
                "docling_serve.orchestrator_factory.redis.Redis",
                return_value=mock_redis,
            ),
            patch(
                "docling_serve.orchestrator_factory.docling_serve_settings"
            ) as mock_settings,
        ):
            mock_settings.eng_rq_results_ttl = 14400
            await mixin._store_task_in_redis(task)

        mock_redis.set.assert_called_once()
        # ``call_args.kwargs`` and ``call_args[1]`` are the same mapping, so a
        # single lookup of the ``ex`` keyword is sufficient.
        set_call = mock_redis.set.call_args
        assert set_call.kwargs.get("ex") == 14400
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Error message propagation through Redis
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestErrorMessagePropagation:
    """Round-trip of ``error_message`` through the Redis store/get helpers."""

    @pytest.mark.asyncio
    async def test_store_and_retrieve_error_message(self):
        """error_message should round-trip through Redis store/get."""
        mixin = FakeMixin()
        failed = _make_task(
            "err-task", TaskStatus.FAILURE, error_message="Out of memory"
        )

        # In-memory dictionary standing in for the Redis keyspace.
        backing_store: dict[str, Any] = {}

        async def fake_set(key: str, value: Any, ex: int = 0) -> None:
            # The expiry is accepted but not used by this fake.
            backing_store[key] = value

        async def fake_get(key: str) -> bytes | None:
            payload = backing_store.get(key)
            # Text is encoded to bytes on read; other values pass through
            # unchanged (including None for a missing key).
            if isinstance(payload, str):
                return payload.encode()
            return payload

        fake_redis = AsyncMock()
        fake_redis.__aenter__ = AsyncMock(return_value=fake_redis)
        fake_redis.__aexit__ = AsyncMock(return_value=False)
        fake_redis.set = fake_set
        fake_redis.get = fake_get

        redis_patch = patch(
            "docling_serve.orchestrator_factory.redis.Redis",
            return_value=fake_redis,
        )
        settings_patch = patch(
            "docling_serve.orchestrator_factory.docling_serve_settings"
        )
        with redis_patch, settings_patch as mock_settings:
            mock_settings.eng_rq_results_ttl = 14400
            await mixin._store_task_in_redis(failed)
            retrieved = await mixin._get_task_from_redis("err-task")

        assert retrieved is not None
        assert retrieved.error_message == "Out of memory"
        assert retrieved.task_status == TaskStatus.FAILURE
|