# docling-serve/docling_serve/websocket_notifier.py
# (Source listing metadata: 77 lines, 2.9 KiB, Python.)

import logging

from fastapi import WebSocket

from docling_jobkit.datamodel.task_meta import TaskStatus
from docling_jobkit.orchestrators.base_notifier import BaseNotifier
from docling_jobkit.orchestrators.base_orchestrator import BaseOrchestrator

from docling_serve.datamodel.responses import (
    MessageKind,
    TaskStatusResponse,
    WebsocketMessage,
)
_log = logging.getLogger(__name__)


class WebsocketNotifier(BaseNotifier):
    """Push task status updates to websocket clients, grouped by task id.

    ``task_subscribers`` maps each registered task id to the set of websockets
    that receive status updates for that task.
    """

    def __init__(self, orchestrator: BaseOrchestrator) -> None:
        super().__init__(orchestrator)
        # task_id -> websockets subscribed to that task's updates.
        self.task_subscribers: dict[str, set[WebSocket]] = {}

    async def add_task(self, task_id: str) -> None:
        """Register ``task_id`` with an empty subscriber set.

        Re-registering an existing id replaces its subscriber set with a new,
        empty one.
        """
        self.task_subscribers[task_id] = set()

    async def remove_task(self, task_id: str) -> None:
        """Close every websocket subscribed to ``task_id`` and unregister it.

        Unknown task ids are ignored. A websocket that fails to close is logged
        and skipped, so the remaining sockets are still closed and the entry is
        always removed.
        """
        subscribers = self.task_subscribers.get(task_id)
        if subscribers is None:
            return

        # Iterate over a snapshot: other coroutines may add/remove subscribers
        # while we await close(), which would invalidate a live set iterator.
        for websocket in list(subscribers):
            try:
                await websocket.close()
            except Exception as e:
                # notify_task_subscribers() closes sockets of completed tasks
                # without removing them from the set, so this may be a second
                # close(); whether that raises depends on the websocket
                # implementation.
                _log.warning("Error closing websocket for task %s: %s", task_id, e)

        # pop() instead of del: a concurrent remove_task() may already have
        # dropped the entry while we were awaiting.
        self.task_subscribers.pop(task_id, None)

    async def notify_task_subscribers(self, task_id: str) -> None:
        """Send the current status of ``task_id`` to all of its subscribers.

        The update contains the task id, type, status, processing metadata and
        queue position as reported by the orchestrator. When the task reports
        completion, each websocket is closed right after receiving the update.
        Errors while querying the orchestrator or talking to a websocket are
        logged, not raised; a failing websocket does not prevent delivery to
        the other subscribers.

        Raises:
            RuntimeError: if ``task_id`` has no subscriber set (never added,
                or already removed).
        """
        if task_id not in self.task_subscribers:
            raise RuntimeError(f"Task {task_id} does not have a subscribers list.")

        try:
            task = await self.orchestrator.task_status(task_id=task_id)
            task_queue_position = await self.orchestrator.get_queue_position(task_id)
            msg = TaskStatusResponse(
                task_id=task.task_id,
                task_type=task.task_type,
                task_status=task.task_status,
                task_position=task_queue_position,
                task_meta=task.processing_meta,
            )
            # Serialize once; every subscriber receives the same payload.
            payload = WebsocketMessage(
                message=MessageKind.UPDATE, task=msg
            ).model_dump_json()
            completed = task.is_completed()
        except Exception:
            # Best effort: a failed status lookup must not crash the notifier.
            _log.exception("Error notifying subscribers for task %s", task_id)
            return

        # The task may have been removed while we awaited the orchestrator, and
        # subscribers may (un)register while we await send/close, so work on a
        # snapshot of the current set.
        subscribers = list(self.task_subscribers.get(task_id, ()))
        for websocket in subscribers:
            try:
                await websocket.send_text(payload)
                if completed:
                    await websocket.close()
            except Exception as e:
                # One broken connection must not block delivery to the others.
                _log.warning(
                    "Error sending update for task %s to a subscriber: %s", task_id, e
                )

    async def notify_queue_positions(self) -> None:
        """Notify all subscribers of pending tasks about queue position updates.

        A failure for one task is logged and does not stop processing of the
        remaining tasks.
        """
        # Snapshot the ids: add_task()/remove_task() can run while we await
        # below, and mutating a dict during iteration raises RuntimeError.
        for task_id in list(self.task_subscribers):
            if task_id not in self.task_subscribers:
                # Unregistered by remove_task() during an earlier await.
                continue
            try:
                task = await self.orchestrator.task_status(task_id)
                # Notify only pending tasks, and re-check registration since the
                # task may have been removed while the status query was running.
                if (
                    task.task_status == TaskStatus.PENDING
                    and task_id in self.task_subscribers
                ):
                    await self.notify_task_subscribers(task_id)
            except Exception:
                # Log the error but don't crash the notifier.
                _log.exception(
                    "Error checking task %s status for queue position notification",
                    task_id,
                )