feat!: use orchestrators from jobkit (#248)

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Michele Dolfi
2025-07-10 15:47:22 +02:00
committed by GitHub
parent e63197e89e
commit daa924a77e
30 changed files with 813 additions and 1997 deletions

View File

@@ -28,12 +28,18 @@ from fastapi.staticfiles import StaticFiles
from scalar_fastapi import get_scalar_api_reference
from docling.datamodel.base_models import DocumentStream
from docling_serve.datamodel.callback import (
from docling_jobkit.datamodel.callback import (
ProgressCallbackRequest,
ProgressCallbackResponse,
)
from docling_serve.datamodel.convert import ConvertDocumentsOptions
from docling_jobkit.datamodel.task import Task, TaskSource
from docling_jobkit.orchestrators.base_orchestrator import (
BaseOrchestrator,
ProgressInvalid,
TaskNotFoundError,
)
from docling_serve.datamodel.convert import ConvertDocumentsRequestOptions
from docling_serve.datamodel.requests import (
ConvertDocumentFileSourcesRequest,
ConvertDocumentHttpSourcesRequest,
@@ -47,17 +53,12 @@ from docling_serve.datamodel.responses import (
TaskStatusResponse,
WebsocketMessage,
)
from docling_serve.datamodel.task import Task, TaskSource
from docling_serve.docling_conversion import _get_converter_from_hash
from docling_serve.engines.async_orchestrator import (
BaseAsyncOrchestrator,
ProgressInvalid,
)
from docling_serve.engines.async_orchestrator_factory import get_async_orchestrator
from docling_serve.engines.base_orchestrator import TaskNotFoundError
from docling_serve.helper_functions import FormDepends
from docling_serve.orchestrator_factory import get_async_orchestrator
from docling_serve.response_preparation import prepare_response
from docling_serve.settings import docling_serve_settings
from docling_serve.storage import get_scratch
from docling_serve.websocker_notifier import WebsocketNotifier
# Set up custom logging as we'll be intermixed with FastAPI/Uvicorn's logging
@@ -95,9 +96,12 @@ _log = logging.getLogger(__name__)
# Context manager to initialize and clean up the lifespan of the FastAPI app
@asynccontextmanager
async def lifespan(app: FastAPI):
orchestrator = get_async_orchestrator()
scratch_dir = get_scratch()
orchestrator = get_async_orchestrator()
notifier = WebsocketNotifier(orchestrator)
orchestrator.bind_notifier(notifier)
# Warm up processing cache
if docling_serve_settings.load_models_at_boot:
await orchestrator.warm_up_caches()
@@ -230,7 +234,7 @@ def create_app(): # noqa: C901
########################
async def _enque_source(
orchestrator: BaseAsyncOrchestrator, conversion_request: ConvertDocumentsRequest
orchestrator: BaseOrchestrator, conversion_request: ConvertDocumentsRequest
) -> Task:
sources: list[TaskSource] = []
if isinstance(conversion_request, ConvertDocumentFileSourcesRequest):
@@ -244,9 +248,9 @@ def create_app(): # noqa: C901
return task
async def _enque_file(
orchestrator: BaseAsyncOrchestrator,
orchestrator: BaseOrchestrator,
files: list[UploadFile],
options: ConvertDocumentsOptions,
options: ConvertDocumentsRequestOptions,
) -> Task:
_log.info(f"Received {len(files)} files for processing.")
@@ -261,9 +265,7 @@ def create_app(): # noqa: C901
task = await orchestrator.enqueue(sources=file_sources, options=options)
return task
async def _wait_task_complete(
orchestrator: BaseAsyncOrchestrator, task_id: str
) -> bool:
async def _wait_task_complete(orchestrator: BaseOrchestrator, task_id: str) -> bool:
start_time = time.monotonic()
while True:
task = await orchestrator.task_status(task_id=task_id)
@@ -309,32 +311,28 @@ def create_app(): # noqa: C901
)
async def process_url(
background_tasks: BackgroundTasks,
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
conversion_request: ConvertDocumentsRequest,
):
task = await _enque_source(
orchestrator=orchestrator, conversion_request=conversion_request
)
success = await _wait_task_complete(
completed = await _wait_task_complete(
orchestrator=orchestrator, task_id=task.task_id
)
if not success:
if not completed:
# TODO: abort task!
return HTTPException(
status_code=504,
detail=f"Conversion is taking too long. The maximum wait time is configure as DOCLING_SERVE_MAX_SYNC_WAIT={docling_serve_settings.max_sync_wait}.",
)
result = await orchestrator.task_result(
task_id=task.task_id, background_tasks=background_tasks
task = await orchestrator.get_raw_task(task_id=task.task_id)
response = await prepare_response(
task=task, orchestrator=orchestrator, background_tasks=background_tasks
)
if result is None:
raise HTTPException(
status_code=404,
detail="Task result not found. Please wait for a completion status.",
)
return result
return response
# Convert a document from file(s)
@app.post(
@@ -348,35 +346,31 @@ def create_app(): # noqa: C901
)
async def process_file(
background_tasks: BackgroundTasks,
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
files: list[UploadFile],
options: Annotated[
ConvertDocumentsOptions, FormDepends(ConvertDocumentsOptions)
ConvertDocumentsRequestOptions, FormDepends(ConvertDocumentsRequestOptions)
],
):
task = await _enque_file(
orchestrator=orchestrator, files=files, options=options
)
success = await _wait_task_complete(
completed = await _wait_task_complete(
orchestrator=orchestrator, task_id=task.task_id
)
if not success:
if not completed:
# TODO: abort task!
return HTTPException(
status_code=504,
detail=f"Conversion is taking too long. The maximum wait time is configure as DOCLING_SERVE_MAX_SYNC_WAIT={docling_serve_settings.max_sync_wait}.",
)
result = await orchestrator.task_result(
task_id=task.task_id, background_tasks=background_tasks
task = await orchestrator.get_raw_task(task_id=task.task_id)
response = await prepare_response(
task=task, orchestrator=orchestrator, background_tasks=background_tasks
)
if result is None:
raise HTTPException(
status_code=404,
detail="Task result not found. Please wait for a completion status.",
)
return result
return response
# Convert a document from URL(s) using the async api
@app.post(
@@ -384,7 +378,7 @@ def create_app(): # noqa: C901
response_model=TaskStatusResponse,
)
async def process_url_async(
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
conversion_request: ConvertDocumentsRequest,
):
task = await _enque_source(
@@ -406,11 +400,11 @@ def create_app(): # noqa: C901
response_model=TaskStatusResponse,
)
async def process_file_async(
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
background_tasks: BackgroundTasks,
files: list[UploadFile],
options: Annotated[
ConvertDocumentsOptions, FormDepends(ConvertDocumentsOptions)
ConvertDocumentsRequestOptions, FormDepends(ConvertDocumentsRequestOptions)
],
):
task = await _enque_file(
@@ -432,7 +426,7 @@ def create_app(): # noqa: C901
response_model=TaskStatusResponse,
)
async def task_status_poll(
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
task_id: str,
wait: Annotated[
float, Query(help="Number of seconds to wait for a completed status.")
@@ -456,9 +450,10 @@ def create_app(): # noqa: C901
)
async def task_status_ws(
websocket: WebSocket,
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
task_id: str,
):
assert isinstance(orchestrator.notifier, WebsocketNotifier)
await websocket.accept()
if task_id not in orchestrator.tasks:
@@ -473,7 +468,7 @@ def create_app(): # noqa: C901
task = orchestrator.tasks[task_id]
# Track active WebSocket connections for this job
orchestrator.task_subscribers[task_id].add(websocket)
orchestrator.notifier.task_subscribers[task_id].add(websocket)
try:
task_queue_position = await orchestrator.get_queue_position(task_id=task_id)
@@ -511,7 +506,7 @@ def create_app(): # noqa: C901
_log.info(f"WebSocket disconnected for job {task_id}")
finally:
orchestrator.task_subscribers[task_id].remove(websocket)
orchestrator.notifier.task_subscribers[task_id].remove(websocket)
# Task result
@app.get(
@@ -524,19 +519,18 @@ def create_app(): # noqa: C901
},
)
async def task_result(
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
background_tasks: BackgroundTasks,
task_id: str,
):
result = await orchestrator.task_result(
task_id=task_id, background_tasks=background_tasks
)
if result is None:
raise HTTPException(
status_code=404,
detail="Task result not found. Please wait for a completion status.",
try:
task = await orchestrator.get_raw_task(task_id=task_id)
response = await prepare_response(
task=task, orchestrator=orchestrator, background_tasks=background_tasks
)
return result
return response
except TaskNotFoundError:
raise HTTPException(status_code=404, detail="Task not found.")
# Update task progress
@app.post(
@@ -544,7 +538,7 @@ def create_app(): # noqa: C901
response_model=ProgressCallbackResponse,
)
async def callback_task_progress(
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
request: ProgressCallbackRequest,
):
try:
@@ -564,8 +558,10 @@ def create_app(): # noqa: C901
"/v1alpha/clear/converters",
response_model=ClearResponse,
)
async def clear_converters():
_get_converter_from_hash.cache_clear()
async def clear_converters(
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
):
await orchestrator.clear_converters()
return ClearResponse()
# Clean results
@@ -574,7 +570,7 @@ def create_app(): # noqa: C901
response_model=ClearResponse,
)
async def clear_results(
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
older_then: float = 3600,
):
await orchestrator.clear_results(older_than=older_then)