mirror of
https://github.com/docling-project/docling-serve.git
synced 2025-12-01 09:33:18 +00:00
Signed-off-by: Pawel Rein <pawel.rein@prezi.com>
This commit is contained in:
@@ -30,26 +30,47 @@ class WebsocketNotifier(BaseNotifier):
|
||||
if task_id not in self.task_subscribers:
|
||||
raise RuntimeError(f"Task {task_id} does not have a subscribers list.")
|
||||
|
||||
task = await self.orchestrator.get_raw_task(task_id=task_id)
|
||||
task_queue_position = await self.orchestrator.get_queue_position(task_id)
|
||||
msg = TaskStatusResponse(
|
||||
task_id=task.task_id,
|
||||
task_type=task.task_type,
|
||||
task_status=task.task_status,
|
||||
task_position=task_queue_position,
|
||||
task_meta=task.processing_meta,
|
||||
)
|
||||
for websocket in self.task_subscribers[task_id]:
|
||||
await websocket.send_text(
|
||||
WebsocketMessage(message=MessageKind.UPDATE, task=msg).model_dump_json()
|
||||
try:
|
||||
# Get task status from Redis or RQ directly instead of in-memory registry
|
||||
task = await self.orchestrator.task_status(task_id=task_id)
|
||||
task_queue_position = await self.orchestrator.get_queue_position(task_id)
|
||||
msg = TaskStatusResponse(
|
||||
task_id=task.task_id,
|
||||
task_type=task.task_type,
|
||||
task_status=task.task_status,
|
||||
task_position=task_queue_position,
|
||||
task_meta=task.processing_meta,
|
||||
)
|
||||
if task.is_completed():
|
||||
await websocket.close()
|
||||
for websocket in self.task_subscribers[task_id]:
|
||||
await websocket.send_text(
|
||||
WebsocketMessage(
|
||||
message=MessageKind.UPDATE, task=msg
|
||||
).model_dump_json()
|
||||
)
|
||||
if task.is_completed():
|
||||
await websocket.close()
|
||||
except Exception as e:
|
||||
# Log the error but don't crash the notifier
|
||||
import logging
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
_log.error(f"Error notifying subscribers for task {task_id}: {e}")
|
||||
|
||||
async def notify_queue_positions(self):
|
||||
"""Notify all subscribers of pending tasks about queue position updates."""
|
||||
for task_id in self.task_subscribers.keys():
|
||||
# notify only pending tasks
|
||||
if self.orchestrator.tasks[task_id].task_status != TaskStatus.PENDING:
|
||||
continue
|
||||
try:
|
||||
# Check task status directly from Redis or RQ
|
||||
task = await self.orchestrator.task_status(task_id)
|
||||
|
||||
await self.notify_task_subscribers(task_id)
|
||||
# Notify only pending tasks
|
||||
if task.task_status == TaskStatus.PENDING:
|
||||
await self.notify_task_subscribers(task_id)
|
||||
except Exception as e:
|
||||
# Log the error but don't crash the notifier
|
||||
import logging
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
_log.error(
|
||||
f"Error checking task {task_id} status for queue position notification: {e}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user