diff --git a/docling_serve/app.py b/docling_serve/app.py
index 5f93b30..2774416 100644
--- a/docling_serve/app.py
+++ b/docling_serve/app.py
@@ -869,7 +869,10 @@ def create_app():  # noqa: C901
         assert isinstance(orchestrator.notifier, WebsocketNotifier)
         await websocket.accept()
 
-        if task_id not in orchestrator.tasks:
+        try:
+            # Get task status from Redis or RQ directly instead of checking the in-memory registry
+            task = await orchestrator.task_status(task_id=task_id)
+        except TaskNotFoundError:
             await websocket.send_text(
                 WebsocketMessage(
                     message=MessageKind.ERROR, error="Task not found."
@@ -878,8 +881,6 @@
             await websocket.close()
             return
 
-        task = orchestrator.tasks[task_id]
-
         # Track active WebSocket connections for this job
         orchestrator.notifier.task_subscribers[task_id].add(websocket)
 
diff --git a/docling_serve/orchestrator_factory.py b/docling_serve/orchestrator_factory.py
index 7e1c29d..04f80c5 100644
--- a/docling_serve/orchestrator_factory.py
+++ b/docling_serve/orchestrator_factory.py
@@ -1,10 +1,267 @@
+import json
+import logging
 from functools import lru_cache
+from typing import Any, Optional
 
-from docling_jobkit.orchestrators.base_orchestrator import BaseOrchestrator
+import redis.asyncio as redis
+
+from docling_jobkit.datamodel.task import Task
+from docling_jobkit.datamodel.task_meta import TaskStatus
+from docling_jobkit.orchestrators.base_orchestrator import (
+    BaseOrchestrator,
+    TaskNotFoundError,
+)
 
 from docling_serve.settings import AsyncEngine, docling_serve_settings
 from docling_serve.storage import get_scratch
 
+_log = logging.getLogger(__name__)
+
+
+class RedisTaskStatusMixin:
+    tasks: dict[str, Task]
+    _task_result_keys: dict[str, str]
+    config: Any
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.redis_prefix = "docling:tasks:"
+        self._redis_pool = redis.ConnectionPool.from_url(
+            self.config.redis_url,
+            max_connections=10,
+            socket_timeout=2.0,
+        )
+
+    async def task_status(self, task_id: str, wait: float = 0.0) -> Task:
+        """
+        Get task status by checking RQ directly first, then falling back to
+        the shared Redis cache and finally the parent implementation.
+
+        When Redis still shows 'pending' or 'started' but RQ reports a newer
+        status, we update Redis and return the RQ status for cross-instance
+        consistency.
+ """ + _log.info(f"Task {task_id} status check") + + # Always check RQ directly first - this is the most reliable source + rq_task = await self._get_task_from_rq_direct(task_id) + if rq_task: + _log.info(f"Task {task_id} in RQ: {rq_task.task_status}") + + # Update memory registry + self.tasks[task_id] = rq_task + + # Store/update in Redis for other instances + await self._store_task_in_redis(rq_task) + return rq_task + + # If not in RQ, check Redis (maybe it's cached from another instance) + task = await self._get_task_from_redis(task_id) + if task: + _log.info(f"Task {task_id} in Redis: {task.task_status}") + + # CRITICAL FIX: Check if Redis status might be stale + # STARTED tasks might have completed since they were cached + if task.task_status in [TaskStatus.PENDING, TaskStatus.STARTED]: + _log.debug(f"Task {task_id} verifying stale status") + + # Try to get fresh status from RQ + fresh_rq_task = await self._get_task_from_rq_direct(task_id) + if fresh_rq_task and fresh_rq_task.task_status != task.task_status: + _log.info( + f"Task {task_id} status updated: {fresh_rq_task.task_status}" + ) + + # Update memory and Redis with fresh status + self.tasks[task_id] = fresh_rq_task + await self._store_task_in_redis(fresh_rq_task) + return fresh_rq_task + else: + _log.debug(f"Task {task_id} status consistent") + + return task + + # Fall back to parent implementation + try: + parent_task = await super().task_status(task_id, wait) # type: ignore[misc] + _log.debug(f"Task {task_id} from parent: {parent_task.task_status}") + + # Store in Redis for other instances to find + await self._store_task_in_redis(parent_task) + return parent_task + except TaskNotFoundError: + _log.warning(f"Task {task_id} not found") + raise + + async def _get_task_from_redis(self, task_id: str) -> Optional[Task]: + try: + async with redis.Redis(connection_pool=self._redis_pool) as r: + task_data = await r.get(f"{self.redis_prefix}{task_id}:metadata") + if not task_data: + return None + + data: dict[str, Any] = json.loads(task_data) + meta = data.get("processing_meta") or {} + meta.setdefault("num_docs", 0) + meta.setdefault("num_processed", 0) + meta.setdefault("num_succeeded", 0) + meta.setdefault("num_failed", 0) + + return Task( + task_id=data["task_id"], + task_type=data["task_type"], + task_status=TaskStatus(data["task_status"]), + processing_meta=meta, + ) + except Exception as e: + _log.error(f"Redis get task {task_id}: {e}") + return None + + async def _get_task_from_rq_direct(self, task_id: str) -> Optional[Task]: + try: + _log.debug(f"Checking RQ for task {task_id}") + + temp_task = Task( + task_id=task_id, + task_type="convert", + task_status=TaskStatus.PENDING, + processing_meta={ + "num_docs": 0, + "num_processed": 0, + "num_succeeded": 0, + "num_failed": 0, + }, + ) + + original_task = self.tasks.get(task_id) + self.tasks[task_id] = temp_task + + try: + await super()._update_task_from_rq(task_id) # type: ignore[misc] + + updated_task = self.tasks.get(task_id) + if updated_task and updated_task.task_status != TaskStatus.PENDING: + _log.debug(f"RQ task {task_id}: {updated_task.task_status}") + + # Store result key if available + if task_id in self._task_result_keys: + try: + async with redis.Redis( + connection_pool=self._redis_pool + ) as r: + await r.set( + f"{self.redis_prefix}{task_id}:result_key", + self._task_result_keys[task_id], + ex=86400, + ) + _log.debug(f"Stored result key for {task_id}") + except Exception as e: + _log.error(f"Store result key {task_id}: {e}") + + return updated_task + return None + 
+            finally:
+                # Restore original task state
+                if original_task:
+                    self.tasks[task_id] = original_task
+                elif task_id in self.tasks and self.tasks[task_id] == temp_task:
+                    # Only remove if it's still our temp task
+                    del self.tasks[task_id]
+
+        except Exception as e:
+            _log.error(f"RQ check {task_id}: {e}")
+            return None
+
+    async def get_raw_task(self, task_id: str) -> Task:
+        if task_id in self.tasks:
+            return self.tasks[task_id]
+
+        task = await self._get_task_from_redis(task_id)
+        if task:
+            self.tasks[task_id] = task
+            return task
+
+        try:
+            parent_task = await super().get_raw_task(task_id)  # type: ignore[misc]
+            await self._store_task_in_redis(parent_task)
+            return parent_task
+        except TaskNotFoundError:
+            raise
+
+    async def _store_task_in_redis(self, task: Task) -> None:
+        try:
+            meta: Any = task.processing_meta
+            if hasattr(meta, "model_dump"):
+                meta = meta.model_dump()
+            elif not isinstance(meta, dict):
+                meta = {
+                    "num_docs": 0,
+                    "num_processed": 0,
+                    "num_succeeded": 0,
+                    "num_failed": 0,
+                }
+
+            data: dict[str, Any] = {
+                "task_id": task.task_id,
+                "task_type": task.task_type.value
+                if hasattr(task.task_type, "value")
+                else str(task.task_type),
+                "task_status": task.task_status.value,
+                "processing_meta": meta,
+            }
+            async with redis.Redis(connection_pool=self._redis_pool) as r:
+                await r.set(
+                    f"{self.redis_prefix}{task.task_id}:metadata",
+                    json.dumps(data),
+                    ex=86400,
+                )
+        except Exception as e:
+            _log.error(f"Store task {task.task_id}: {e}")
+
+    async def enqueue(self, **kwargs):  # type: ignore[override]
+        task = await super().enqueue(**kwargs)  # type: ignore[misc]
+        await self._store_task_in_redis(task)
+        return task
+
+    async def task_result(self, task_id: str):  # type: ignore[override]
+        result = await super().task_result(task_id)  # type: ignore[misc]
+        if result is not None:
+            return result
+
+        try:
+            async with redis.Redis(connection_pool=self._redis_pool) as r:
+                result_key = await r.get(f"{self.redis_prefix}{task_id}:result_key")
+                if result_key:
+                    self._task_result_keys[task_id] = result_key.decode("utf-8")
+                    return await super().task_result(task_id)  # type: ignore[misc]
+        except Exception as e:
+            _log.error(f"Redis result key {task_id}: {e}")
+
+        return None
+
+    async def _update_task_from_rq(self, task_id: str) -> None:
+        original_status = (
+            self.tasks[task_id].task_status if task_id in self.tasks else None
+        )
+
+        await super()._update_task_from_rq(task_id)  # type: ignore[misc]
+
+        if task_id in self.tasks:
+            new_status = self.tasks[task_id].task_status
+            if original_status != new_status:
+                _log.debug(f"Task {task_id} status: {original_status} -> {new_status}")
+                await self._store_task_in_redis(self.tasks[task_id])
+
+            if task_id in self._task_result_keys:
+                try:
+                    async with redis.Redis(connection_pool=self._redis_pool) as r:
+                        await r.set(
+                            f"{self.redis_prefix}{task_id}:result_key",
+                            self._task_result_keys[task_id],
+                            ex=86400,
+                        )
+                except Exception as e:
+                    _log.error(f"Store result key {task_id}: {e}")
+
 
 @lru_cache
 def get_async_orchestrator() -> BaseOrchestrator:
@@ -35,12 +292,16 @@ def get_async_orchestrator() -> BaseOrchestrator:
         cm = DoclingConverterManager(config=cm_config)
 
         return LocalOrchestrator(config=local_config, converter_manager=cm)
+
     elif docling_serve_settings.eng_kind == AsyncEngine.RQ:
         from docling_jobkit.orchestrators.rq.orchestrator import (
             RQOrchestrator,
             RQOrchestratorConfig,
         )
 
+        class RedisAwareRQOrchestrator(RedisTaskStatusMixin, RQOrchestrator):  # type: ignore[misc]
+            pass
+
         rq_config = RQOrchestratorConfig(
             redis_url=docling_serve_settings.eng_rq_redis_url,
             results_prefix=docling_serve_settings.eng_rq_results_prefix,
@@ -48,7 +309,8 @@ def get_async_orchestrator() -> BaseOrchestrator:
             scratch_dir=get_scratch(),
         )
 
-        return RQOrchestrator(config=rq_config)
+        return RedisAwareRQOrchestrator(config=rq_config)
+
     elif docling_serve_settings.eng_kind == AsyncEngine.KFP:
         from docling_jobkit.orchestrators.kfp.orchestrator import (
             KfpOrchestrator,
diff --git a/docling_serve/websocket_notifier.py b/docling_serve/websocket_notifier.py
index 9e7644e..779c5af 100644
--- a/docling_serve/websocket_notifier.py
+++ b/docling_serve/websocket_notifier.py
@@ -30,26 +30,47 @@ class WebsocketNotifier(BaseNotifier):
         if task_id not in self.task_subscribers:
             raise RuntimeError(f"Task {task_id} does not have a subscribers list.")
 
-        task = await self.orchestrator.get_raw_task(task_id=task_id)
-        task_queue_position = await self.orchestrator.get_queue_position(task_id)
-        msg = TaskStatusResponse(
-            task_id=task.task_id,
-            task_type=task.task_type,
-            task_status=task.task_status,
-            task_position=task_queue_position,
-            task_meta=task.processing_meta,
-        )
-        for websocket in self.task_subscribers[task_id]:
-            await websocket.send_text(
-                WebsocketMessage(message=MessageKind.UPDATE, task=msg).model_dump_json()
+        try:
+            # Get task status from Redis or RQ directly instead of in-memory registry
+            task = await self.orchestrator.task_status(task_id=task_id)
+            task_queue_position = await self.orchestrator.get_queue_position(task_id)
+            msg = TaskStatusResponse(
+                task_id=task.task_id,
+                task_type=task.task_type,
+                task_status=task.task_status,
+                task_position=task_queue_position,
+                task_meta=task.processing_meta,
             )
-            if task.is_completed():
-                await websocket.close()
+            for websocket in self.task_subscribers[task_id]:
+                await websocket.send_text(
+                    WebsocketMessage(
+                        message=MessageKind.UPDATE, task=msg
+                    ).model_dump_json()
+                )
+                if task.is_completed():
+                    await websocket.close()
+        except Exception as e:
+            # Log the error but don't crash the notifier
+            import logging
+
+            _log = logging.getLogger(__name__)
+            _log.error(f"Error notifying subscribers for task {task_id}: {e}")
 
     async def notify_queue_positions(self):
+        """Notify all subscribers of pending tasks about queue position updates."""
         for task_id in self.task_subscribers.keys():
-            # notify only pending tasks
-            if self.orchestrator.tasks[task_id].task_status != TaskStatus.PENDING:
-                continue
+            try:
+                # Check task status directly from Redis or RQ
+                task = await self.orchestrator.task_status(task_id)
 
-            await self.notify_task_subscribers(task_id)
+                # Notify only pending tasks
+                if task.task_status == TaskStatus.PENDING:
+                    await self.notify_task_subscribers(task_id)
+            except Exception as e:
+                # Log the error but don't crash the notifier
+                import logging
+
+                _log = logging.getLogger(__name__)
+                _log.error(
+                    f"Error checking task {task_id} status for queue position notification: {e}"
+                )
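
For reference, the class created in the RQ branch relies on Python's cooperative multiple inheritance: because RedisAwareRQOrchestrator lists the mixin before RQOrchestrator, the mixin's super() calls dispatch to the concrete orchestrator at runtime even though RedisTaskStatusMixin itself declares no base class. The sketch below is a minimal, self-contained illustration of that mechanism; StubOrchestrator, SharedStatusMixin, SharedAwareOrchestrator, and the plain dict standing in for Redis are hypothetical names for illustration only, not docling-serve or docling-jobkit APIs.

import asyncio


class StubOrchestrator:
    """Stand-in for RQOrchestrator: only knows tasks it enqueued itself."""

    def __init__(self):
        self.tasks: dict[str, str] = {}

    async def task_status(self, task_id: str, wait: float = 0.0) -> str:
        if task_id not in self.tasks:
            raise KeyError(task_id)
        return self.tasks[task_id]


class SharedStatusMixin:
    """Stand-in for RedisTaskStatusMixin: consults a shared store first."""

    # Class attribute shared by all instances, playing the role of Redis.
    shared_store: dict[str, str] = {}

    async def task_status(self, task_id: str, wait: float = 0.0) -> str:
        if task_id in self.shared_store:
            return self.shared_store[task_id]
        # The MRO routes this super() call to StubOrchestrator.task_status.
        status = await super().task_status(task_id, wait)
        # Publish the result so other instances can answer for this task too.
        self.shared_store[task_id] = status
        return status


class SharedAwareOrchestrator(SharedStatusMixin, StubOrchestrator):
    pass


async def main() -> None:
    a = SharedAwareOrchestrator()
    b = SharedAwareOrchestrator()
    a.tasks["t1"] = "success"
    print(await a.task_status("t1"))  # resolved locally, then published
    print(await b.task_status("t1"))  # resolved via the shared store


asyncio.run(main())

As in the real mixin, the MRO is subclass -> mixin -> concrete orchestrator, so super().task_status(...) inside the mixin lands on the orchestrator's own implementation (just as super()._update_task_from_rq(...) does above), and whatever it learns is published to the shared store where a second instance that never saw the task can still find it.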