feat: Async api (#60)

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Michele Dolfi
2025-03-07 11:26:50 +01:00
committed by GitHub
parent ed851c95fe
commit 82f8900197
26 changed files with 919 additions and 367 deletions

View File

@@ -1,3 +1,4 @@
import asyncio
import importlib.metadata
import logging
import tempfile
@@ -6,23 +7,46 @@ from io import BytesIO
from pathlib import Path
from typing import Annotated, Any, Dict, List, Optional, Union
from docling.datamodel.base_models import DocumentStream, InputFormat
from docling.document_converter import DocumentConverter
from fastapi import BackgroundTasks, FastAPI, UploadFile
from fastapi import (
BackgroundTasks,
Depends,
FastAPI,
HTTPException,
Query,
UploadFile,
WebSocket,
WebSocketDisconnect,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from docling_serve.docling_conversion import (
from docling.datamodel.base_models import DocumentStream, InputFormat
from docling.document_converter import DocumentConverter
from docling_serve.datamodel.convert import ConvertDocumentsOptions
from docling_serve.datamodel.requests import (
ConvertDocumentFileSourcesRequest,
ConvertDocumentsOptions,
ConvertDocumentsRequest,
)
from docling_serve.datamodel.responses import (
ConvertDocumentResponse,
HealthCheckResponse,
MessageKind,
TaskStatusResponse,
WebsocketMessage,
)
from docling_serve.docling_conversion import (
convert_documents,
converters,
get_pdf_pipeline_opts,
)
from docling_serve.engines import get_orchestrator
from docling_serve.engines.async_local.orchestrator import (
AsyncLocalOrchestrator,
TaskNotFoundError,
)
from docling_serve.helper_functions import FormDepends
from docling_serve.response_preparation import ConvertDocumentResponse, process_results
from docling_serve.response_preparation import process_results
from docling_serve.settings import docling_serve_settings
@@ -72,9 +96,22 @@ async def lifespan(app: FastAPI):
converters[options_hash].initialize_pipeline(InputFormat.PDF)
orchestrator = get_orchestrator()
# Start the background queue processor
queue_task = asyncio.create_task(orchestrator.process_queue())
yield
# Cancel the background queue processor on shutdown
queue_task.cancel()
try:
await queue_task
except asyncio.CancelledError:
_log.info("Queue processor cancelled.")
converters.clear()
# if WITH_UI:
# gradio_ui.close()
@@ -84,7 +121,7 @@ async def lifespan(app: FastAPI):
##################################
def create_app():
def create_app(): # noqa: C901
try:
version = importlib.metadata.version("docling_serve")
except importlib.metadata.PackageNotFoundError:
@@ -145,10 +182,6 @@ def create_app():
)
return response
# Status
class HealthCheckResponse(BaseModel):
status: str = "ok"
@app.get("/health")
def health() -> HealthCheckResponse:
return HealthCheckResponse()
@@ -233,4 +266,129 @@ def create_app():
return response
# Convert a document from URL(s) using the async api
@app.post(
"/v1alpha/convert/source/async",
response_model=TaskStatusResponse,
)
async def process_url_async(
orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
conversion_request: ConvertDocumentsRequest,
):
task = await orchestrator.enqueue(request=conversion_request)
task_queue_position = await orchestrator.get_queue_position(
task_id=task.task_id
)
return TaskStatusResponse(
task_id=task.task_id,
task_status=task.task_status,
task_position=task_queue_position,
)
# Task status poll
@app.get(
"/v1alpha/status/poll/{task_id}",
response_model=TaskStatusResponse,
)
async def task_status_poll(
orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
task_id: str,
wait: Annotated[
float, Query(help="Number of seconds to wait for a completed status.")
] = 0.0,
):
try:
task = await orchestrator.task_status(task_id=task_id, wait=wait)
task_queue_position = await orchestrator.get_queue_position(task_id=task_id)
except TaskNotFoundError:
raise HTTPException(status_code=404, detail="Task not found.")
return TaskStatusResponse(
task_id=task.task_id,
task_status=task.task_status,
task_position=task_queue_position,
)
# Task status websocket
@app.websocket(
"/v1alpha/status/ws/{task_id}",
)
async def task_status_ws(
websocket: WebSocket,
orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
task_id: str,
):
await websocket.accept()
if task_id not in orchestrator.tasks:
await websocket.send_text(
WebsocketMessage(
message=MessageKind.ERROR, error="Task not found."
).model_dump_json()
)
await websocket.close()
return
task = orchestrator.tasks[task_id]
# Track active WebSocket connections for this job
orchestrator.task_subscribers[task_id].add(websocket)
try:
task_queue_position = await orchestrator.get_queue_position(task_id=task_id)
task_response = TaskStatusResponse(
task_id=task.task_id,
task_status=task.task_status,
task_position=task_queue_position,
)
await websocket.send_text(
WebsocketMessage(
message=MessageKind.CONNECTION, task=task_response
).model_dump_json()
)
while True:
task_queue_position = await orchestrator.get_queue_position(
task_id=task_id
)
task_response = TaskStatusResponse(
task_id=task.task_id,
task_status=task.task_status,
task_position=task_queue_position,
)
await websocket.send_text(
WebsocketMessage(
message=MessageKind.UPDATE, task=task_response
).model_dump_json()
)
# each client message will be interpreted as a request for update
msg = await websocket.receive_text()
_log.debug(f"Received message: {msg}")
except WebSocketDisconnect:
_log.info(f"WebSocket disconnected for job {task_id}")
finally:
orchestrator.task_subscribers[task_id].remove(websocket)
# Task result
@app.get(
"/v1alpha/result/{task_id}",
response_model=ConvertDocumentResponse,
responses={
200: {
"content": {"application/zip": {}},
}
},
)
async def task_result(
orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
task_id: str,
):
result = await orchestrator.task_result(task_id=task_id)
if result is None:
raise HTTPException(
status_code=404,
detail="Task result not found. Please wait for a completion status.",
)
return result
return app