mirror of
https://github.com/docling-project/docling-serve.git
synced 2025-11-29 08:33:50 +00:00
feat: Async api (#60)
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import importlib.metadata
|
||||
import logging
|
||||
import tempfile
|
||||
@@ -6,23 +7,46 @@ from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Dict, List, Optional, Union
|
||||
|
||||
from docling.datamodel.base_models import DocumentStream, InputFormat
|
||||
from docling.document_converter import DocumentConverter
|
||||
from fastapi import BackgroundTasks, FastAPI, UploadFile
|
||||
from fastapi import (
|
||||
BackgroundTasks,
|
||||
Depends,
|
||||
FastAPI,
|
||||
HTTPException,
|
||||
Query,
|
||||
UploadFile,
|
||||
WebSocket,
|
||||
WebSocketDisconnect,
|
||||
)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from docling_serve.docling_conversion import (
|
||||
from docling.datamodel.base_models import DocumentStream, InputFormat
|
||||
from docling.document_converter import DocumentConverter
|
||||
|
||||
from docling_serve.datamodel.convert import ConvertDocumentsOptions
|
||||
from docling_serve.datamodel.requests import (
|
||||
ConvertDocumentFileSourcesRequest,
|
||||
ConvertDocumentsOptions,
|
||||
ConvertDocumentsRequest,
|
||||
)
|
||||
from docling_serve.datamodel.responses import (
|
||||
ConvertDocumentResponse,
|
||||
HealthCheckResponse,
|
||||
MessageKind,
|
||||
TaskStatusResponse,
|
||||
WebsocketMessage,
|
||||
)
|
||||
from docling_serve.docling_conversion import (
|
||||
convert_documents,
|
||||
converters,
|
||||
get_pdf_pipeline_opts,
|
||||
)
|
||||
from docling_serve.engines import get_orchestrator
|
||||
from docling_serve.engines.async_local.orchestrator import (
|
||||
AsyncLocalOrchestrator,
|
||||
TaskNotFoundError,
|
||||
)
|
||||
from docling_serve.helper_functions import FormDepends
|
||||
from docling_serve.response_preparation import ConvertDocumentResponse, process_results
|
||||
from docling_serve.response_preparation import process_results
|
||||
from docling_serve.settings import docling_serve_settings
|
||||
|
||||
|
||||
@@ -72,9 +96,22 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
converters[options_hash].initialize_pipeline(InputFormat.PDF)
|
||||
|
||||
orchestrator = get_orchestrator()
|
||||
|
||||
# Start the background queue processor
|
||||
queue_task = asyncio.create_task(orchestrator.process_queue())
|
||||
|
||||
yield
|
||||
|
||||
# Cancel the background queue processor on shutdown
|
||||
queue_task.cancel()
|
||||
try:
|
||||
await queue_task
|
||||
except asyncio.CancelledError:
|
||||
_log.info("Queue processor cancelled.")
|
||||
|
||||
converters.clear()
|
||||
|
||||
# if WITH_UI:
|
||||
# gradio_ui.close()
|
||||
|
||||
@@ -84,7 +121,7 @@ async def lifespan(app: FastAPI):
|
||||
##################################
|
||||
|
||||
|
||||
def create_app():
|
||||
def create_app(): # noqa: C901
|
||||
try:
|
||||
version = importlib.metadata.version("docling_serve")
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
@@ -145,10 +182,6 @@ def create_app():
|
||||
)
|
||||
return response
|
||||
|
||||
# Status
|
||||
class HealthCheckResponse(BaseModel):
|
||||
status: str = "ok"
|
||||
|
||||
@app.get("/health")
def health() -> HealthCheckResponse:
    # Liveness probe: builds a HealthCheckResponse with its default field
    # values and returns it unchanged.
    # (Documented with comments rather than a docstring so the generated
    # OpenAPI description of this route stays as it was.)
    health_status = HealthCheckResponse()
    return health_status
|
||||
@@ -233,4 +266,129 @@ def create_app():
|
||||
|
||||
return response
|
||||
|
||||
# Convert a document from URL(s) using the async API
|
||||
@app.post("/v1alpha/convert/source/async", response_model=TaskStatusResponse)
async def process_url_async(
    orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
    conversion_request: ConvertDocumentsRequest,
):
    # Enqueue the conversion request on the orchestrator and answer right
    # away with the task's id, its current status and its queue position.
    # (Documented with comments rather than a docstring so the generated
    # OpenAPI description of this route stays as it was.)
    queued_task = await orchestrator.enqueue(request=conversion_request)
    queue_position = await orchestrator.get_queue_position(
        task_id=queued_task.task_id
    )

    status_fields = {
        "task_id": queued_task.task_id,
        "task_status": queued_task.task_status,
        "task_position": queue_position,
    }
    return TaskStatusResponse(**status_fields)
|
||||
|
||||
# Task status poll
|
||||
@app.get(
    "/v1alpha/status/poll/{task_id}",
    response_model=TaskStatusResponse,
)
async def task_status_poll(
    orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
    task_id: str,
    wait: Annotated[
        float,
        # `description` is the keyword FastAPI publishes as the parameter's
        # help text in the OpenAPI schema; the former `help=` keyword is not
        # a documented `Query` argument, so the text never reached the docs.
        Query(description="Number of seconds to wait for a completed status."),
    ] = 0.0,
):
    # Poll the status of a previously enqueued task.
    #
    # Parameters:
    #   task_id: path parameter identifying the task.
    #   wait: forwarded to `orchestrator.task_status` as `wait`.
    # Returns a TaskStatusResponse with the task id, status and queue position.
    # Raises HTTPException(404) when the orchestrator raises TaskNotFoundError.
    try:
        task = await orchestrator.task_status(task_id=task_id, wait=wait)
        task_queue_position = await orchestrator.get_queue_position(task_id=task_id)
    except TaskNotFoundError:
        # The 404 is a deliberate translation of the lookup failure, so the
        # internal exception is not chained into the HTTP error.
        raise HTTPException(status_code=404, detail="Task not found.") from None

    return TaskStatusResponse(
        task_id=task.task_id,
        task_status=task.task_status,
        task_position=task_queue_position,
    )
|
||||
|
||||
# Task status websocket
|
||||
@app.websocket(
    "/v1alpha/status/ws/{task_id}",
)
async def task_status_ws(
    websocket: WebSocket,
    orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
    task_id: str,
):
    """Stream status updates for one task over a WebSocket.

    After accepting the connection, an unknown ``task_id`` is answered with a
    single ERROR message and the socket is closed. Otherwise the socket is
    added to ``orchestrator.task_subscribers[task_id]``, a CONNECTION message
    with the current status is sent, and then an UPDATE message is sent once
    immediately and again after every text message received from the client.
    The socket is unsubscribed when the handler exits for any reason.
    """
    await websocket.accept()

    if task_id not in orchestrator.tasks:
        await websocket.send_text(
            WebsocketMessage(
                message=MessageKind.ERROR, error="Task not found."
            ).model_dump_json()
        )
        await websocket.close()
        return

    task = orchestrator.tasks[task_id]

    async def _status_message(kind: MessageKind) -> str:
        # Serialize the task's current status and queue position as a
        # WebsocketMessage of the given kind.
        position = await orchestrator.get_queue_position(task_id=task_id)
        snapshot = TaskStatusResponse(
            task_id=task.task_id,
            task_status=task.task_status,
            task_position=position,
        )
        return WebsocketMessage(message=kind, task=snapshot).model_dump_json()

    # Track active WebSocket connections for this task
    orchestrator.task_subscribers[task_id].add(websocket)

    try:
        await websocket.send_text(await _status_message(MessageKind.CONNECTION))
        while True:
            await websocket.send_text(await _status_message(MessageKind.UPDATE))
            # each client message will be interpreted as a request for update
            msg = await websocket.receive_text()
            _log.debug(f"Received message: {msg}")

    except WebSocketDisconnect:
        _log.info(f"WebSocket disconnected for job {task_id}")

    finally:
        # `discard` instead of `remove`: if the socket is no longer in the
        # subscriber collection, `remove` would raise KeyError from the
        # cleanup path and hide whatever exception ended the loop.
        # NOTE(review): assumes the per-task subscriber collection is a set
        # (it is used with `.add` above) — confirm in the orchestrator.
        orchestrator.task_subscribers[task_id].discard(websocket)
|
||||
|
||||
# Task result
|
||||
@app.get(
    "/v1alpha/result/{task_id}",
    response_model=ConvertDocumentResponse,
    responses={
        200: {
            "content": {"application/zip": {}},
        }
    },
)
async def task_result(
    orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
    task_id: str,
):
    # Return the stored result of a task.
    #
    # Parameters:
    #   task_id: path parameter identifying the task.
    # Returns whatever `orchestrator.task_result` yields for the task.
    # Raises HTTPException(404) when the task is unknown or has no result yet.
    try:
        result = await orchestrator.task_result(task_id=task_id)
    except TaskNotFoundError:
        # Mirrors the handling in the status poll endpoint so an unknown id
        # is reported as 404 instead of an unhandled server error.
        # NOTE(review): presumes `task_result` can raise TaskNotFoundError
        # like `task_status` does — confirm in the orchestrator.
        raise HTTPException(status_code=404, detail="Task not found.") from None

    if result is None:
        raise HTTPException(
            status_code=404,
            detail="Task result not found. Please wait for a completion status.",
        )
    return result
|
||||
|
||||
return app
|
||||
|
||||
Reference in New Issue
Block a user