mirror of
https://github.com/docling-project/docling-serve.git
synced 2025-11-29 08:33:50 +00:00
86 lines
3.0 KiB
Python
86 lines
3.0 KiB
Python
import shutil
|
|
from typing import Union
|
|
|
|
from fastapi import BackgroundTasks, WebSocket
|
|
from fastapi.responses import FileResponse
|
|
|
|
from docling_serve.datamodel.callback import ProgressCallbackRequest
|
|
from docling_serve.datamodel.engines import TaskStatus
|
|
from docling_serve.datamodel.responses import (
|
|
ConvertDocumentResponse,
|
|
MessageKind,
|
|
TaskStatusResponse,
|
|
WebsocketMessage,
|
|
)
|
|
from docling_serve.datamodel.task import Task
|
|
from docling_serve.engines.base_orchestrator import (
|
|
BaseOrchestrator,
|
|
OrchestratorError,
|
|
TaskNotFoundError,
|
|
)
|
|
from docling_serve.settings import docling_serve_settings
|
|
|
|
|
|
class ProgressInvalid(OrchestratorError):
    """OrchestratorError subclass; judging by its name, meant for invalid progress updates."""
|
|
|
|
|
|
class BaseAsyncOrchestrator(BaseOrchestrator):
    """Orchestrator base that keeps tasks and their websocket subscribers in memory.

    Tasks are stored in ``self.tasks`` and the websockets interested in each
    task in ``self.task_subscribers``; both are keyed by task id.

    ``get_queue_position`` is called by this class but not defined in it; it is
    presumably supplied by ``BaseOrchestrator`` or a concrete subclass.
    """

    def __init__(self):
        # task_id -> tracked Task object
        self.tasks: dict[str, Task] = {}
        # task_id -> websockets to notify about that task
        self.task_subscribers: dict[str, set[WebSocket]] = {}

    async def init_task_tracking(self, task: Task) -> None:
        """Register ``task`` and give it an empty subscriber set.

        An existing entry with the same task id is overwritten, including its
        subscriber set.
        """
        task_id = task.task_id
        self.tasks[task_id] = task
        self.task_subscribers[task_id] = set()

    async def get_raw_task(self, task_id: str) -> Task:
        """Return the tracked task for ``task_id``.

        Raises:
            TaskNotFoundError: if ``task_id`` is not tracked.
        """
        if task_id not in self.tasks:
            raise TaskNotFoundError()
        return self.tasks[task_id]

    async def task_status(self, task_id: str, wait: float = 0.0) -> Task:
        """Return the tracked task for ``task_id``.

        ``wait`` is not used by this implementation.

        Raises:
            TaskNotFoundError: if ``task_id`` is not tracked.
        """
        return await self.get_raw_task(task_id=task_id)

    async def task_result(
        self, task_id: str, background_tasks: BackgroundTasks
    ) -> Union[ConvertDocumentResponse, FileResponse, None]:
        """Return ``task.result`` for ``task_id``.

        When the task is completed, has a scratch directory, and the
        ``single_use_results`` setting is enabled, removal of the scratch
        directory is scheduled on ``background_tasks``.

        Raises:
            TaskNotFoundError: if ``task_id`` is not tracked.
        """
        task = await self.get_raw_task(task_id=task_id)
        if task.is_completed() and task.scratch_dir is not None:
            if docling_serve_settings.single_use_results:
                # Cleanup is deferred to the background task; errors during
                # removal are ignored (ignore_errors=True).
                background_tasks.add_task(
                    shutil.rmtree, task.scratch_dir, ignore_errors=True
                )
        return task.result

    async def notify_task_subscribers(self, task_id: str) -> None:
        """Send the current status of ``task_id`` to each subscribed websocket.

        A websocket is closed after its message is sent if the task reports
        itself as completed at that moment.

        Raises:
            RuntimeError: if ``task_id`` has no entry in ``task_subscribers``.
            TaskNotFoundError: if ``task_id`` is not tracked.
        """
        if task_id not in self.task_subscribers:
            raise RuntimeError(f"Task {task_id} does not have a subscribers list.")

        task = await self.get_raw_task(task_id=task_id)
        task_queue_position = await self.get_queue_position(task_id)
        msg = TaskStatusResponse(
            task_id=task.task_id,
            task_status=task.task_status,
            task_position=task_queue_position,
            task_meta=task.processing_meta,
        )
        # The payload is the same for every subscriber, so serialize it once.
        payload = WebsocketMessage(
            message=MessageKind.UPDATE, task=msg
        ).model_dump_json()

        # Iterate over a snapshot: each await below yields to the event loop,
        # and another coroutine adding/removing a subscriber meanwhile would
        # otherwise raise "Set changed size during iteration".
        for websocket in tuple(self.task_subscribers[task_id]):
            await websocket.send_text(payload)
            if task.is_completed():
                await websocket.close()

    async def notify_queue_positions(self) -> None:
        """Notify the subscribers of every task whose status is PENDING."""
        # Snapshot the ids: tasks may be registered or dropped while we await,
        # which would otherwise raise "dictionary changed size during iteration".
        for task_id in list(self.task_subscribers):
            task = self.tasks.get(task_id)
            # Skip entries that disappeared since the snapshot was taken.
            if task is None or task_id not in self.task_subscribers:
                continue

            # notify only pending tasks
            if task.task_status != TaskStatus.PENDING:
                continue

            await self.notify_task_subscribers(task_id)

    async def receive_task_progress(self, request: ProgressCallbackRequest):
        """Not implemented in this class; always raises ``NotImplementedError``."""
        raise NotImplementedError()
|