fix: prevent stale RQ STARTED from overwriting watchdog FAILURE (#523)

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Co-authored-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Christoph Auer
2026-03-03 15:39:11 +01:00
committed by GitHub
parent 0de3b4fa1c
commit f4c42f4a82
2 changed files with 86 additions and 63 deletions

View File

@@ -7,9 +7,10 @@ from typing import Any, Union
import redis.asyncio as redis
from rq.exceptions import NoSuchJobError
from rq.job import Job
from docling_jobkit.datamodel.task import Task
from docling_jobkit.datamodel.task_meta import TaskStatus
from docling_jobkit.datamodel.task_meta import TaskProcessingMeta, TaskStatus
from docling_jobkit.orchestrators.base_orchestrator import (
BaseOrchestrator,
TaskNotFoundError,
@@ -32,6 +33,7 @@ class RedisTaskStatusMixin:
tasks: dict[str, Task]
_task_result_keys: dict[str, str]
config: Any
_redis_conn: Any
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
@@ -76,7 +78,16 @@ class RedisTaskStatusMixin:
"""
Get task status with zombie task reconciliation.
Checks RQ first (authoritative), then Redis cache, then in-memory.
Resolution order:
1. Redis (terminal-state gate): if Redis already holds a completed state,
return it immediately without consulting RQ. Prevents stale STARTED
in RQ from overwriting a watchdog-published FAILURE.
2. RQ: authoritative source for non-terminal states. Returns a Task only
when the job is non-PENDING; returns None for PENDING, _RQJobGone on
NoSuchJobError.
3. Redis (fallback): reached only when RQ had no useful answer (PENDING
or job expired). Handles job-gone reconciliation and stale-status
cross-checks. Same Redis key as step 1, different role.
When the RQ job is definitively gone (NoSuchJobError), reconciles:
- Terminal status in Redis -> return it, clean up tracking
- Non-terminal status in Redis -> mark FAILURE (orphaned task)
@@ -84,6 +95,39 @@ class RedisTaskStatusMixin:
"""
_log.info(f"Task {task_id} status check")
# Before consulting RQ (which can report stale STARTED for up to 4 hours
# after a worker kill), check Redis for a terminal state written by
# _on_task_status_changed() or a previous poll. A terminal state in Redis
# is authoritative: written either by the watchdog (after heartbeat expiry
# + grace period) or by the normal success/failure path, neither of which
# can be a false positive for a still-running job.
task_from_redis = await self._get_task_from_redis(task_id)
if task_from_redis is not None and task_from_redis.is_completed():
_log.info(
f"Task {task_id} terminal in Redis ({task_from_redis.task_status}), "
f"skipping RQ check"
)
try:
job_exists = await asyncio.to_thread(
Job.exists, task_id, self._redis_conn
)
except Exception as e:
_log.warning(
f"Task {task_id} terminal in Redis, but RQ existence check "
f"failed: {e}"
)
job_exists = True
if job_exists:
self.tasks[task_id] = task_from_redis
else:
_log.info(
f"Task {task_id} terminal in Redis and RQ job is gone — "
f"cleaning up tracking"
)
self.tasks.pop(task_id, None)
self._task_result_keys.pop(task_id, None)
return task_from_redis
rq_result = await self._get_task_from_rq_direct(task_id)
if isinstance(rq_result, Task):
@@ -113,12 +157,11 @@ class RedisTaskStatusMixin:
f"— marking as FAILURE (orphaned)"
)
task.set_status(TaskStatus.FAILURE)
if hasattr(task, "error_message"):
task.error_message = (
f"Task orphaned: RQ job expired while status was "
f"{task.task_status}. Likely caused by worker restart or "
f"Redis eviction."
)
task.error_message = (
f"Task orphaned: RQ job expired while status was "
f"{task.task_status}. Likely caused by worker restart or "
f"Redis eviction."
)
self.tasks.pop(task_id, None)
self._task_result_keys.pop(task_id, None)
await self._store_task_in_redis(task)
@@ -164,7 +207,7 @@ class RedisTaskStatusMixin:
return None
data: dict[str, Any] = json.loads(task_data)
meta = data.get("processing_meta") or {}
meta = data["processing_meta"]
meta.setdefault("num_docs", 0)
meta.setdefault("num_processed", 0)
meta.setdefault("num_succeeded", 0)
@@ -175,9 +218,12 @@ class RedisTaskStatusMixin:
"task_type": data["task_type"],
"task_status": TaskStatus(data["task_status"]),
"processing_meta": meta,
"error_message": data["error_message"],
"created_at": data["created_at"],
"started_at": data["started_at"],
"finished_at": data["finished_at"],
"last_update_at": data["last_update_at"],
}
if data.get("error_message") and "error_message" in Task.model_fields:
task_kwargs["error_message"] = data["error_message"]
task = Task(**task_kwargs)
return task
except Exception as e:
@@ -190,6 +236,18 @@ class RedisTaskStatusMixin:
try:
_log.debug(f"Checking RQ for task {task_id}")
# Do not consult RQ for tasks already in a terminal state. The temp-task
# swap below would replace self.tasks[task_id] with a PENDING task, making
# the base class's is_completed() guard ineffective and allowing a stale
# RQ STARTED status to overwrite a watchdog-published FAILURE.
original_task = self.tasks.get(task_id)
if original_task is not None and original_task.is_completed():
_log.debug(
f"Task {task_id} already terminal ({original_task.task_status}), "
f"skipping RQ direct check"
)
return original_task
temp_task = Task(
task_id=task_id,
task_type="convert",
@@ -211,22 +269,6 @@ class RedisTaskStatusMixin:
updated_task = self.tasks.get(task_id)
if updated_task and updated_task.task_status != TaskStatus.PENDING:
_log.debug(f"RQ task {task_id}: {updated_task.task_status}")
result_ttl = docling_serve_settings.eng_rq_results_ttl
if task_id in self._task_result_keys:
try:
async with redis.Redis(
connection_pool=self._redis_pool
) as r:
await r.set(
f"{self.redis_prefix}{task_id}:result_key",
self._task_result_keys[task_id],
ex=result_ttl,
)
_log.debug(f"Stored result key for {task_id}")
except Exception as e:
_log.error(f"Store result key {task_id}: {e}")
return updated_task
return None
@@ -262,7 +304,7 @@ class RedisTaskStatusMixin:
async def _store_task_in_redis(self, task: Task) -> None:
try:
meta: Any = task.processing_meta
if hasattr(meta, "model_dump"):
if isinstance(meta, TaskProcessingMeta):
meta = meta.model_dump()
elif not isinstance(meta, dict):
meta = {
@@ -274,14 +316,20 @@ class RedisTaskStatusMixin:
data: dict[str, Any] = {
"task_id": task.task_id,
"task_type": (
task.task_type.value
if hasattr(task.task_type, "value")
else str(task.task_type)
),
"task_type": task.task_type.value,
"task_status": task.task_status.value,
"processing_meta": meta,
"error_message": getattr(task, "error_message", None),
"error_message": task.error_message,
"created_at": task.created_at.isoformat(),
"started_at": (
task.started_at.isoformat() if task.started_at is not None else None
),
"finished_at": (
task.finished_at.isoformat()
if task.finished_at is not None
else None
),
"last_update_at": task.last_update_at.isoformat(),
}
metadata_ttl = docling_serve_settings.eng_rq_results_ttl
@@ -294,27 +342,14 @@ class RedisTaskStatusMixin:
except Exception as e:
_log.error(f"Store task {task.task_id}: {e}")
async def _on_task_status_changed(self, task: Task) -> None:
await self._store_task_in_redis(task)
async def enqueue(self, **kwargs): # type: ignore[override]
task = await super().enqueue(**kwargs) # type: ignore[misc]
await self._store_task_in_redis(task)
return task
async def task_result(self, task_id: str): # type: ignore[override]
result = await super().task_result(task_id) # type: ignore[misc]
if result is not None:
return result
try:
async with redis.Redis(connection_pool=self._redis_pool) as r:
result_key = await r.get(f"{self.redis_prefix}{task_id}:result_key")
if result_key:
self._task_result_keys[task_id] = result_key.decode("utf-8")
return await super().task_result(task_id) # type: ignore[misc]
except Exception as e:
_log.error(f"Redis result key {task_id}: {e}")
return None
async def _update_task_from_rq(self, task_id: str) -> None:
original_status = (
self.tasks[task_id].task_status if task_id in self.tasks else None
@@ -328,18 +363,6 @@ class RedisTaskStatusMixin:
_log.debug(f"Task {task_id} status: {original_status} -> {new_status}")
await self._store_task_in_redis(self.tasks[task_id])
if task_id in self._task_result_keys:
result_ttl = docling_serve_settings.eng_rq_results_ttl
try:
async with redis.Redis(connection_pool=self._redis_pool) as r:
await r.set(
f"{self.redis_prefix}{task_id}:result_key",
self._task_result_keys[task_id],
ex=result_ttl,
)
except Exception as e:
_log.error(f"Store result key {task_id}: {e}")
async def _reap_zombie_tasks(
self, interval: float = 300.0, max_age: float = 3600.0
) -> None:

View File

@@ -337,4 +337,4 @@ branch = "main"
# (note that they must be a subset of the configured allowed types):
parser_angular_allowed_types = "build,chore,ci,docs,feat,fix,perf,style,refactor,test"
parser_angular_minor_types = "feat"
parser_angular_patch_types = "fix,perf"
parser_angular_patch_types = "fix,perf"