From 853003cf3b14a763c12fc40cffe13d5e92154b01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Rein?= Date: Tue, 24 Feb 2026 22:23:56 +0100 Subject: [PATCH] =?UTF-8?q?fix:=20zombie=20task=20cleanup=20=E2=80=94=20re?= =?UTF-8?q?concile=20stale=20RQ/Redis=20state=20(#516)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Pawel Rein --- docling_serve/app.py | 16 + docling_serve/orchestrator_factory.py | 156 +++++++--- docling_serve/settings.py | 2 + tests/test_zombie_task_cleanup.py | 428 ++++++++++++++++++++++++++ 4 files changed, 564 insertions(+), 38 deletions(-) create mode 100644 tests/test_zombie_task_cleanup.py diff --git a/docling_serve/app.py b/docling_serve/app.py index 108fb1b..244e853 100644 --- a/docling_serve/app.py +++ b/docling_serve/app.py @@ -141,14 +141,30 @@ async def lifespan(app: FastAPI): # Start the background queue processor queue_task = asyncio.create_task(orchestrator.process_queue()) + reaper_task = None + if hasattr(orchestrator, "_reap_zombie_tasks"): + reaper_task = asyncio.create_task( + orchestrator._reap_zombie_tasks( + interval=docling_serve_settings.zombie_reaper_interval, + max_age=docling_serve_settings.zombie_reaper_max_age, + ) + ) + yield # Cancel the background queue processor on shutdown queue_task.cancel() + if reaper_task: + reaper_task.cancel() try: await queue_task except asyncio.CancelledError: _log.info("Queue processor cancelled.") + if reaper_task: + try: + await reaper_task + except asyncio.CancelledError: + _log.info("Zombie reaper cancelled.") # Remove scratch directory in case it was a tempfile if docling_serve_settings.scratch_path is not None: diff --git a/docling_serve/orchestrator_factory.py b/docling_serve/orchestrator_factory.py index ade23a3..65ef12b 100644 --- a/docling_serve/orchestrator_factory.py +++ b/docling_serve/orchestrator_factory.py @@ -1,9 +1,12 @@ +import asyncio +import datetime import json import logging from functools import lru_cache -from typing import Any +from typing import Any, Union import redis.asyncio as redis +from rq.exceptions import NoSuchJobError from docling_jobkit.datamodel.task import Task from docling_jobkit.datamodel.task_meta import TaskStatus @@ -18,6 +21,13 @@ from docling_serve.storage import get_scratch _log = logging.getLogger(__name__) +class _RQJobGone: + """Sentinel: the RQ job has been deleted / TTL-expired.""" + + +_RQ_JOB_GONE = _RQJobGone() + + class RedisTaskStatusMixin: tasks: dict[str, Task] _task_result_keys: dict[str, str] @@ -64,43 +74,66 @@ class RedisTaskStatusMixin: async def task_status(self, task_id: str, wait: float = 0.0) -> Task: """ - Get task status by checking Redis first, then falling back to RQ verification. + Get task status with zombie task reconciliation. - When Redis shows 'pending' but RQ shows 'success', we update Redis - and return the RQ status for cross-instance consistency. + Checks RQ first (authoritative), then Redis cache, then in-memory. 
+        When the RQ job is definitively gone (NoSuchJobError), reconciles:
+        - Terminal status in Redis -> return it, clean up tracking
+        - Non-terminal status in Redis -> mark FAILURE (orphaned task)
+        - Not in Redis at all -> raise TaskNotFoundError
         """
         _log.info(f"Task {task_id} status check")
 
-        # Always check RQ directly first - this is the most reliable source
-        rq_task = await self._get_task_from_rq_direct(task_id)
-        if rq_task:
-            _log.info(f"Task {task_id} in RQ: {rq_task.task_status}")
+        rq_result = await self._get_task_from_rq_direct(task_id)
 
-            # Update memory registry
-            self.tasks[task_id] = rq_task
+        if isinstance(rq_result, Task):
+            _log.info(f"Task {task_id} in RQ: {rq_result.task_status}")
+            self.tasks[task_id] = rq_result
+            await self._store_task_in_redis(rq_result)
+            return rq_result
 
-            # Store/update in Redis for other instances
-            await self._store_task_in_redis(rq_task)
-            return rq_task
+        job_is_gone = isinstance(rq_result, _RQJobGone)
 
-        # If not in RQ, check Redis (maybe it's cached from another instance)
         task = await self._get_task_from_redis(task_id)
         if task:
             _log.info(f"Task {task_id} in Redis: {task.task_status}")
 
-            # CRITICAL FIX: Check if Redis status might be stale
-            # STARTED tasks might have completed since they were cached
+            if job_is_gone:
+                if task.is_completed():
+                    _log.info(
+                        f"Task {task_id} completed ({task.task_status}) "
+                        f"and RQ job expired — cleaning up tracking"
+                    )
+                    self.tasks.pop(task_id, None)
+                    self._task_result_keys.pop(task_id, None)
+                    return task
+                else:
+                    prev_status = task.task_status
+                    _log.warning(
+                        f"Task {task_id} was {prev_status} but RQ job is gone "
+                        f"— marking as FAILURE (orphaned)"
+                    )
+                    task.set_status(TaskStatus.FAILURE)
+                    if hasattr(task, "error_message"):
+                        task.error_message = (
+                            f"Task orphaned: RQ job expired while status was "
+                            f"{prev_status}. Likely worker restart or Redis eviction."
+ ) + self.tasks.pop(task_id, None) + self._task_result_keys.pop(task_id, None) + await self._store_task_in_redis(task) + return task + if task.task_status in [TaskStatus.PENDING, TaskStatus.STARTED]: _log.debug(f"Task {task_id} verifying stale status") - - # Try to get fresh status from RQ fresh_rq_task = await self._get_task_from_rq_direct(task_id) - if fresh_rq_task and fresh_rq_task.task_status != task.task_status: + if ( + isinstance(fresh_rq_task, Task) + and fresh_rq_task.task_status != task.task_status + ): _log.info( f"Task {task_id} status updated: {fresh_rq_task.task_status}" ) - - # Update memory and Redis with fresh status self.tasks[task_id] = fresh_rq_task await self._store_task_in_redis(fresh_rq_task) return fresh_rq_task @@ -109,12 +142,14 @@ class RedisTaskStatusMixin: return task - # Fall back to parent implementation + if job_is_gone: + _log.warning(f"Task {task_id} not in RQ or Redis — truly gone") + self.tasks.pop(task_id, None) + raise TaskNotFoundError(task_id) + try: parent_task = await super().task_status(task_id, wait) # type: ignore[misc] _log.debug(f"Task {task_id} from parent: {parent_task.task_status}") - - # Store in Redis for other instances to find await self._store_task_in_redis(parent_task) return parent_task except TaskNotFoundError: @@ -135,18 +170,23 @@ class RedisTaskStatusMixin: meta.setdefault("num_succeeded", 0) meta.setdefault("num_failed", 0) - return Task( - task_id=data["task_id"], - task_type=data["task_type"], - task_status=TaskStatus(data["task_status"]), - processing_meta=meta, - error_message=data.get("error_message"), - ) + task_kwargs: dict[str, Any] = { + "task_id": data["task_id"], + "task_type": data["task_type"], + "task_status": TaskStatus(data["task_status"]), + "processing_meta": meta, + } + if data.get("error_message") and "error_message" in Task.model_fields: + task_kwargs["error_message"] = data["error_message"] + task = Task(**task_kwargs) + return task except Exception as e: _log.error(f"Redis get task {task_id}: {e}") return None - async def _get_task_from_rq_direct(self, task_id: str) -> Task | None: + async def _get_task_from_rq_direct( + self, task_id: str + ) -> Union[Task, _RQJobGone, None]: try: _log.debug(f"Checking RQ for task {task_id}") @@ -172,7 +212,7 @@ class RedisTaskStatusMixin: if updated_task and updated_task.task_status != TaskStatus.PENDING: _log.debug(f"RQ task {task_id}: {updated_task.task_status}") - # Store result key if available + result_ttl = docling_serve_settings.eng_rq_results_ttl if task_id in self._task_result_keys: try: async with redis.Redis( @@ -181,7 +221,7 @@ class RedisTaskStatusMixin: await r.set( f"{self.redis_prefix}{task_id}:result_key", self._task_result_keys[task_id], - ex=86400, + ex=result_ttl, ) _log.debug(f"Stored result key for {task_id}") except Exception as e: @@ -191,13 +231,14 @@ class RedisTaskStatusMixin: return None finally: - # Restore original task state if original_task: self.tasks[task_id] = original_task elif task_id in self.tasks and self.tasks[task_id] == temp_task: - # Only remove if it's still our temp task del self.tasks[task_id] + except NoSuchJobError: + _log.info(f"RQ job {task_id} no longer exists (TTL expired or deleted)") + return _RQ_JOB_GONE except Exception as e: _log.error(f"RQ check {task_id}: {e}") return None @@ -242,11 +283,13 @@ class RedisTaskStatusMixin: "processing_meta": meta, "error_message": getattr(task, "error_message", None), } + + metadata_ttl = docling_serve_settings.eng_rq_results_ttl async with 
redis.Redis(connection_pool=self._redis_pool) as r: await r.set( f"{self.redis_prefix}{task.task_id}:metadata", json.dumps(data), - ex=86400, + ex=metadata_ttl, ) except Exception as e: _log.error(f"Store task {task.task_id}: {e}") @@ -286,16 +329,53 @@ class RedisTaskStatusMixin: await self._store_task_in_redis(self.tasks[task_id]) if task_id in self._task_result_keys: + result_ttl = docling_serve_settings.eng_rq_results_ttl try: async with redis.Redis(connection_pool=self._redis_pool) as r: await r.set( f"{self.redis_prefix}{task_id}:result_key", self._task_result_keys[task_id], - ex=86400, + ex=result_ttl, ) except Exception as e: _log.error(f"Store result key {task_id}: {e}") + async def _reap_zombie_tasks( + self, interval: float = 300.0, max_age: float = 3600.0 + ) -> None: + """ + Periodically remove completed tasks from in-memory tracking. + + Args: + interval: Seconds between sweeps (default 5 min) + max_age: Remove completed tasks older than this (default 1h) + """ + while True: + await asyncio.sleep(interval) + try: + now = datetime.datetime.now(datetime.timezone.utc) + cutoff = now - datetime.timedelta(seconds=max_age) + to_remove: list[str] = [] + + for task_id, task in list(self.tasks.items()): + if ( + task.is_completed() + and task.finished_at + and task.finished_at < cutoff + ): + to_remove.append(task_id) + + for task_id in to_remove: + self.tasks.pop(task_id, None) + self._task_result_keys.pop(task_id, None) + _log.debug(f"Reaped zombie task {task_id}") + + if to_remove: + _log.info(f"Reaped {len(to_remove)} zombie tasks from tracking") + + except Exception as e: + _log.error(f"Zombie reaper error: {e}") + @lru_cache def get_async_orchestrator() -> BaseOrchestrator: diff --git a/docling_serve/settings.py b/docling_serve/settings.py index c5baeca..8bf493b 100644 --- a/docling_serve/settings.py +++ b/docling_serve/settings.py @@ -94,6 +94,8 @@ class DoclingServeSettings(BaseSettings): eng_rq_redis_socket_connect_timeout: Optional[float] = ( None # Socket connect timeout in seconds ) + zombie_reaper_interval: float = 300.0 + zombie_reaper_max_age: float = 3600.0 # KFP engine eng_kfp_endpoint: Optional[AnyUrl] = None eng_kfp_token: Optional[str] = None diff --git a/tests/test_zombie_task_cleanup.py b/tests/test_zombie_task_cleanup.py new file mode 100644 index 0000000..eb40bfb --- /dev/null +++ b/tests/test_zombie_task_cleanup.py @@ -0,0 +1,428 @@ +"""Tests for zombie task cleanup in RedisTaskStatusMixin. 
+ +Tests cover: +- Layer A: _RQJobGone sentinel from _get_task_from_rq_direct when NoSuchJobError +- Layer B: task_status() reconciliation for zombie scenarios +- Layer C: Background zombie reaper +- Layer E: TTL alignment (metadata TTL uses results_ttl) +""" + +import asyncio +import datetime +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from rq.exceptions import NoSuchJobError + +from docling_jobkit.datamodel.task import Task +from docling_jobkit.datamodel.task_meta import TaskStatus +from docling_jobkit.orchestrators.base_orchestrator import TaskNotFoundError + +from docling_serve.orchestrator_factory import ( + _RQ_JOB_GONE, + RedisTaskStatusMixin, + _RQJobGone, +) + + +def _make_task( + task_id: str = "test-task-1", + status: TaskStatus = TaskStatus.SUCCESS, + error_message: str | None = None, + finished_at: datetime.datetime | None = None, +) -> Task: + task = Task( + task_id=task_id, + task_type="convert", + task_status=status, + processing_meta={ + "num_docs": 0, + "num_processed": 0, + "num_succeeded": 0, + "num_failed": 0, + }, + ) + if error_message: + task.error_message = error_message + if finished_at: + task.finished_at = finished_at + return task + + +class FakeParentOrchestrator: + """Minimal fake parent to satisfy MRO without real RQ/Redis.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.tasks: dict[str, Task] = {} + self._task_result_keys: dict[str, str] = {} + + async def _update_task_from_rq(self, task_id: str) -> None: + raise NoSuchJobError(f"No such job: {task_id}") + + async def task_status(self, task_id: str, wait: float = 0.0) -> Task: + if task_id in self.tasks: + return self.tasks[task_id] + raise TaskNotFoundError(task_id) + + +class FakeMixin(RedisTaskStatusMixin, FakeParentOrchestrator): + """Concrete class combining RedisTaskStatusMixin with a fake parent.""" + + def __init__(self) -> None: + self.tasks: dict[str, Task] = {} + self._task_result_keys: dict[str, str] = {} + self.redis_prefix = "docling:tasks:" + self._redis_pool = MagicMock() + + +# --------------------------------------------------------------------------- +# Layer A: Sentinel tests +# --------------------------------------------------------------------------- + + +class TestRQJobGoneSentinel: + @pytest.mark.asyncio + async def test_returns_sentinel_on_no_such_job_error(self): + mixin = FakeMixin() + result = await mixin._get_task_from_rq_direct("missing-job") + assert isinstance(result, _RQJobGone) + + @pytest.mark.asyncio + async def test_returns_none_on_generic_exception(self): + mixin = FakeMixin() + + async def raise_generic(self_inner, task_id: str) -> None: + raise RuntimeError("Redis connection lost") + + with patch.object( + FakeParentOrchestrator, "_update_task_from_rq", raise_generic + ): + result = await mixin._get_task_from_rq_direct("some-task") + assert result is None + + @pytest.mark.asyncio + async def test_returns_task_when_rq_has_job(self): + mixin = FakeMixin() + expected_task = _make_task("rq-task", TaskStatus.SUCCESS) + + async def update_with_success(self_inner, task_id: str) -> None: + mixin.tasks[task_id] = expected_task + + mock_redis = AsyncMock() + mock_redis.__aenter__ = AsyncMock(return_value=mock_redis) + mock_redis.__aexit__ = AsyncMock(return_value=False) + + with ( + patch.object( + FakeParentOrchestrator, "_update_task_from_rq", update_with_success + ), + patch( + "docling_serve.orchestrator_factory.redis.Redis", + return_value=mock_redis, + ), + ): + result = await 
mixin._get_task_from_rq_direct("rq-task") + assert isinstance(result, Task) + assert result.task_status == TaskStatus.SUCCESS + + +# --------------------------------------------------------------------------- +# Layer B: task_status() reconciliation +# --------------------------------------------------------------------------- + + +class TestTaskStatusReconciliation: + @pytest.mark.asyncio + async def test_rq_gone_redis_success_cleans_up(self): + """RQ job gone + Redis has SUCCESS -> return task, clean up tracking.""" + mixin = FakeMixin() + cached_task = _make_task("t1", TaskStatus.SUCCESS) + mixin.tasks["t1"] = cached_task + mixin._task_result_keys["t1"] = "some-key" + + with ( + patch.object(mixin, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE), + patch.object(mixin, "_get_task_from_redis", return_value=cached_task), + ): + result = await mixin.task_status("t1") + + assert result.task_status == TaskStatus.SUCCESS + assert "t1" not in mixin.tasks + assert "t1" not in mixin._task_result_keys + + @pytest.mark.asyncio + async def test_rq_gone_redis_failure_cleans_up(self): + """RQ job gone + Redis has FAILURE -> return task, clean up tracking.""" + mixin = FakeMixin() + cached_task = _make_task("t1", TaskStatus.FAILURE) + mixin.tasks["t1"] = cached_task + + with ( + patch.object(mixin, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE), + patch.object(mixin, "_get_task_from_redis", return_value=cached_task), + ): + result = await mixin.task_status("t1") + + assert result.task_status == TaskStatus.FAILURE + assert "t1" not in mixin.tasks + + @pytest.mark.asyncio + async def test_rq_gone_redis_pending_marks_failure(self): + """RQ job gone + Redis has PENDING -> mark as FAILURE with error_message.""" + mixin = FakeMixin() + cached_task = _make_task("t2", TaskStatus.PENDING) + mixin.tasks["t2"] = cached_task + + with ( + patch.object(mixin, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE), + patch.object(mixin, "_get_task_from_redis", return_value=cached_task), + patch.object( + mixin, "_store_task_in_redis", new_callable=AsyncMock + ) as mock_store, + ): + result = await mixin.task_status("t2") + + assert result.task_status == TaskStatus.FAILURE + assert "orphaned" in result.error_message.lower() + assert "t2" not in mixin.tasks + mock_store.assert_called_once() + + @pytest.mark.asyncio + async def test_rq_gone_redis_started_marks_failure(self): + """RQ job gone + Redis has STARTED -> mark as FAILURE with error_message.""" + mixin = FakeMixin() + cached_task = _make_task("t3", TaskStatus.STARTED) + + with ( + patch.object(mixin, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE), + patch.object(mixin, "_get_task_from_redis", return_value=cached_task), + patch.object(mixin, "_store_task_in_redis", new_callable=AsyncMock), + ): + result = await mixin.task_status("t3") + + assert result.task_status == TaskStatus.FAILURE + assert "orphaned" in result.error_message.lower() + + @pytest.mark.asyncio + async def test_rq_gone_no_redis_raises_not_found(self): + """RQ job gone + not in Redis -> TaskNotFoundError.""" + mixin = FakeMixin() + + with ( + patch.object(mixin, "_get_task_from_rq_direct", return_value=_RQ_JOB_GONE), + patch.object(mixin, "_get_task_from_redis", return_value=None), + ): + with pytest.raises(TaskNotFoundError): + await mixin.task_status("ghost-task") + + @pytest.mark.asyncio + async def test_rq_transient_error_falls_through_to_redis(self): + """RQ returns None (transient) + Redis has SUCCESS -> return Redis task.""" + mixin = FakeMixin() + cached_task = 
_make_task("t4", TaskStatus.SUCCESS) + + with ( + patch.object(mixin, "_get_task_from_rq_direct", return_value=None), + patch.object(mixin, "_get_task_from_redis", return_value=cached_task), + ): + result = await mixin.task_status("t4") + + assert result.task_status == TaskStatus.SUCCESS + + @pytest.mark.asyncio + async def test_rq_has_task_returns_directly(self): + """When RQ has the task, return it directly.""" + mixin = FakeMixin() + rq_task = _make_task("t5", TaskStatus.SUCCESS) + + with ( + patch.object(mixin, "_get_task_from_rq_direct", return_value=rq_task), + patch.object(mixin, "_store_task_in_redis", new_callable=AsyncMock), + ): + result = await mixin.task_status("t5") + + assert result is rq_task + assert mixin.tasks["t5"] is rq_task + + +# --------------------------------------------------------------------------- +# Layer C: Background zombie reaper +# --------------------------------------------------------------------------- + + +class TestZombieReaper: + @pytest.mark.asyncio + async def test_reaps_old_completed_tasks(self): + mixin = FakeMixin() + old_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta( + hours=2 + ) + old_task = _make_task("old-1", TaskStatus.SUCCESS, finished_at=old_time) + mixin.tasks["old-1"] = old_task + mixin._task_result_keys["old-1"] = "key-1" + + reaper = asyncio.create_task( + mixin._reap_zombie_tasks(interval=0.01, max_age=3600.0) + ) + await asyncio.sleep(0.05) + reaper.cancel() + try: + await reaper + except asyncio.CancelledError: + pass + + assert "old-1" not in mixin.tasks + assert "old-1" not in mixin._task_result_keys + + @pytest.mark.asyncio + async def test_keeps_recent_completed_tasks(self): + mixin = FakeMixin() + recent_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta( + minutes=5 + ) + recent_task = _make_task( + "recent-1", TaskStatus.SUCCESS, finished_at=recent_time + ) + mixin.tasks["recent-1"] = recent_task + + reaper = asyncio.create_task( + mixin._reap_zombie_tasks(interval=0.01, max_age=3600.0) + ) + await asyncio.sleep(0.05) + reaper.cancel() + try: + await reaper + except asyncio.CancelledError: + pass + + assert "recent-1" in mixin.tasks + + @pytest.mark.asyncio + async def test_keeps_in_progress_tasks(self): + mixin = FakeMixin() + old_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta( + hours=2 + ) + started_task = _make_task("started-1", TaskStatus.STARTED) + started_task.started_at = old_time + mixin.tasks["started-1"] = started_task + + pending_task = _make_task("pending-1", TaskStatus.PENDING) + mixin.tasks["pending-1"] = pending_task + + reaper = asyncio.create_task( + mixin._reap_zombie_tasks(interval=0.01, max_age=3600.0) + ) + await asyncio.sleep(0.05) + reaper.cancel() + try: + await reaper + except asyncio.CancelledError: + pass + + assert "started-1" in mixin.tasks + assert "pending-1" in mixin.tasks + + @pytest.mark.asyncio + async def test_reaps_failed_tasks_with_finished_at(self): + mixin = FakeMixin() + old_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta( + hours=2 + ) + failed_task = _make_task("failed-1", TaskStatus.FAILURE, finished_at=old_time) + mixin.tasks["failed-1"] = failed_task + + reaper = asyncio.create_task( + mixin._reap_zombie_tasks(interval=0.01, max_age=3600.0) + ) + await asyncio.sleep(0.05) + reaper.cancel() + try: + await reaper + except asyncio.CancelledError: + pass + + assert "failed-1" not in mixin.tasks + + +# --------------------------------------------------------------------------- +# 
Layer E: TTL alignment +# --------------------------------------------------------------------------- + + +class TestTTLAlignment: + @pytest.mark.asyncio + async def test_store_task_uses_results_ttl(self): + """Metadata TTL should match eng_rq_results_ttl, not 86400.""" + mixin = FakeMixin() + task = _make_task("ttl-task", TaskStatus.SUCCESS) + + mock_redis = AsyncMock() + mock_redis.__aenter__ = AsyncMock(return_value=mock_redis) + mock_redis.__aexit__ = AsyncMock(return_value=False) + + with ( + patch( + "docling_serve.orchestrator_factory.redis.Redis", + return_value=mock_redis, + ), + patch( + "docling_serve.orchestrator_factory.docling_serve_settings" + ) as mock_settings, + ): + mock_settings.eng_rq_results_ttl = 14400 + await mixin._store_task_in_redis(task) + + mock_redis.set.assert_called_once() + call_kwargs = mock_redis.set.call_args + assert ( + call_kwargs.kwargs.get("ex") == 14400 or call_kwargs[1].get("ex") == 14400 + ) + + +# --------------------------------------------------------------------------- +# Error message propagation through Redis +# --------------------------------------------------------------------------- + + +class TestErrorMessagePropagation: + @pytest.mark.asyncio + async def test_store_and_retrieve_error_message(self): + """error_message should round-trip through Redis store/get.""" + mixin = FakeMixin() + task = _make_task("err-task", TaskStatus.FAILURE, error_message="Out of memory") + + stored_data = {} + + async def fake_set(key: str, value: Any, ex: int = 0) -> None: + stored_data[key] = value + + async def fake_get(key: str) -> bytes | None: + val = stored_data.get(key) + if val is None: + return None + return val.encode() if isinstance(val, str) else val + + mock_redis = AsyncMock() + mock_redis.__aenter__ = AsyncMock(return_value=mock_redis) + mock_redis.__aexit__ = AsyncMock(return_value=False) + mock_redis.set = fake_set + mock_redis.get = fake_get + + with ( + patch( + "docling_serve.orchestrator_factory.redis.Redis", + return_value=mock_redis, + ), + patch( + "docling_serve.orchestrator_factory.docling_serve_settings" + ) as mock_settings, + ): + mock_settings.eng_rq_results_ttl = 14400 + await mixin._store_task_in_redis(task) + retrieved = await mixin._get_task_from_redis("err-task") + + assert retrieved is not None + assert retrieved.error_message == "Out of memory" + assert retrieved.task_status == TaskStatus.FAILURE
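
Note on tuning the reaper: the two new fields follow the existing DoclingServeSettings
conventions, so they should be overridable through the environment like the other
settings. A minimal sketch, assuming the DOCLING_SERVE_ env prefix and the module-level
docling_serve_settings instance (both inferred from the project conventions, not shown
in this patch):

    # sketch: override the reaper cadence before the settings object is first imported
    import os

    os.environ["DOCLING_SERVE_ZOMBIE_REAPER_INTERVAL"] = "120"   # sweep every 2 minutes
    os.environ["DOCLING_SERVE_ZOMBIE_REAPER_MAX_AGE"] = "1800"   # reap completed tasks after 30 min

    from docling_serve.settings import docling_serve_settings

    assert docling_serve_settings.zombie_reaper_interval == 120.0

    # the new tests can be run in isolation with:
    #   python -m pytest tests/test_zombie_task_cleanup.py -v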