mirror of
https://github.com/docling-project/docling-serve.git
synced 2025-12-01 17:43:18 +00:00
feat!: use orchestrators from jobkit (#248)
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
@@ -28,12 +28,18 @@ from fastapi.staticfiles import StaticFiles
|
||||
from scalar_fastapi import get_scalar_api_reference
|
||||
|
||||
from docling.datamodel.base_models import DocumentStream
|
||||
|
||||
from docling_serve.datamodel.callback import (
|
||||
from docling_jobkit.datamodel.callback import (
|
||||
ProgressCallbackRequest,
|
||||
ProgressCallbackResponse,
|
||||
)
|
||||
from docling_serve.datamodel.convert import ConvertDocumentsOptions
|
||||
from docling_jobkit.datamodel.task import Task, TaskSource
|
||||
from docling_jobkit.orchestrators.base_orchestrator import (
|
||||
BaseOrchestrator,
|
||||
ProgressInvalid,
|
||||
TaskNotFoundError,
|
||||
)
|
||||
|
||||
from docling_serve.datamodel.convert import ConvertDocumentsRequestOptions
|
||||
from docling_serve.datamodel.requests import (
|
||||
ConvertDocumentFileSourcesRequest,
|
||||
ConvertDocumentHttpSourcesRequest,
|
||||
@@ -47,17 +53,12 @@ from docling_serve.datamodel.responses import (
|
||||
TaskStatusResponse,
|
||||
WebsocketMessage,
|
||||
)
|
||||
from docling_serve.datamodel.task import Task, TaskSource
|
||||
from docling_serve.docling_conversion import _get_converter_from_hash
|
||||
from docling_serve.engines.async_orchestrator import (
|
||||
BaseAsyncOrchestrator,
|
||||
ProgressInvalid,
|
||||
)
|
||||
from docling_serve.engines.async_orchestrator_factory import get_async_orchestrator
|
||||
from docling_serve.engines.base_orchestrator import TaskNotFoundError
|
||||
from docling_serve.helper_functions import FormDepends
|
||||
from docling_serve.orchestrator_factory import get_async_orchestrator
|
||||
from docling_serve.response_preparation import prepare_response
|
||||
from docling_serve.settings import docling_serve_settings
|
||||
from docling_serve.storage import get_scratch
|
||||
from docling_serve.websocker_notifier import WebsocketNotifier
|
||||
|
||||
|
||||
# Set up custom logging as we'll be intermixes with FastAPI/Uvicorn's logging
|
||||
@@ -95,9 +96,12 @@ _log = logging.getLogger(__name__)
|
||||
# Context manager to initialize and clean up the lifespan of the FastAPI app
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
orchestrator = get_async_orchestrator()
|
||||
scratch_dir = get_scratch()
|
||||
|
||||
orchestrator = get_async_orchestrator()
|
||||
notifier = WebsocketNotifier(orchestrator)
|
||||
orchestrator.bind_notifier(notifier)
|
||||
|
||||
# Warm up processing cache
|
||||
if docling_serve_settings.load_models_at_boot:
|
||||
await orchestrator.warm_up_caches()
|
||||
@@ -230,7 +234,7 @@ def create_app(): # noqa: C901
|
||||
########################
|
||||
|
||||
async def _enque_source(
|
||||
orchestrator: BaseAsyncOrchestrator, conversion_request: ConvertDocumentsRequest
|
||||
orchestrator: BaseOrchestrator, conversion_request: ConvertDocumentsRequest
|
||||
) -> Task:
|
||||
sources: list[TaskSource] = []
|
||||
if isinstance(conversion_request, ConvertDocumentFileSourcesRequest):
|
||||
@@ -244,9 +248,9 @@ def create_app(): # noqa: C901
|
||||
return task
|
||||
|
||||
async def _enque_file(
|
||||
orchestrator: BaseAsyncOrchestrator,
|
||||
orchestrator: BaseOrchestrator,
|
||||
files: list[UploadFile],
|
||||
options: ConvertDocumentsOptions,
|
||||
options: ConvertDocumentsRequestOptions,
|
||||
) -> Task:
|
||||
_log.info(f"Received {len(files)} files for processing.")
|
||||
|
||||
@@ -261,9 +265,7 @@ def create_app(): # noqa: C901
|
||||
task = await orchestrator.enqueue(sources=file_sources, options=options)
|
||||
return task
|
||||
|
||||
async def _wait_task_complete(
|
||||
orchestrator: BaseAsyncOrchestrator, task_id: str
|
||||
) -> bool:
|
||||
async def _wait_task_complete(orchestrator: BaseOrchestrator, task_id: str) -> bool:
|
||||
start_time = time.monotonic()
|
||||
while True:
|
||||
task = await orchestrator.task_status(task_id=task_id)
|
||||
@@ -309,32 +311,28 @@ def create_app(): # noqa: C901
|
||||
)
|
||||
async def process_url(
|
||||
background_tasks: BackgroundTasks,
|
||||
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
|
||||
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
|
||||
conversion_request: ConvertDocumentsRequest,
|
||||
):
|
||||
task = await _enque_source(
|
||||
orchestrator=orchestrator, conversion_request=conversion_request
|
||||
)
|
||||
success = await _wait_task_complete(
|
||||
completed = await _wait_task_complete(
|
||||
orchestrator=orchestrator, task_id=task.task_id
|
||||
)
|
||||
|
||||
if not success:
|
||||
if not completed:
|
||||
# TODO: abort task!
|
||||
return HTTPException(
|
||||
status_code=504,
|
||||
detail=f"Conversion is taking too long. The maximum wait time is configure as DOCLING_SERVE_MAX_SYNC_WAIT={docling_serve_settings.max_sync_wait}.",
|
||||
)
|
||||
|
||||
result = await orchestrator.task_result(
|
||||
task_id=task.task_id, background_tasks=background_tasks
|
||||
task = await orchestrator.get_raw_task(task_id=task.task_id)
|
||||
response = await prepare_response(
|
||||
task=task, orchestrator=orchestrator, background_tasks=background_tasks
|
||||
)
|
||||
if result is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Task result not found. Please wait for a completion status.",
|
||||
)
|
||||
return result
|
||||
return response
|
||||
|
||||
# Convert a document from file(s)
|
||||
@app.post(
|
||||
@@ -348,35 +346,31 @@ def create_app(): # noqa: C901
|
||||
)
|
||||
async def process_file(
|
||||
background_tasks: BackgroundTasks,
|
||||
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
|
||||
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
|
||||
files: list[UploadFile],
|
||||
options: Annotated[
|
||||
ConvertDocumentsOptions, FormDepends(ConvertDocumentsOptions)
|
||||
ConvertDocumentsRequestOptions, FormDepends(ConvertDocumentsRequestOptions)
|
||||
],
|
||||
):
|
||||
task = await _enque_file(
|
||||
orchestrator=orchestrator, files=files, options=options
|
||||
)
|
||||
success = await _wait_task_complete(
|
||||
completed = await _wait_task_complete(
|
||||
orchestrator=orchestrator, task_id=task.task_id
|
||||
)
|
||||
|
||||
if not success:
|
||||
if not completed:
|
||||
# TODO: abort task!
|
||||
return HTTPException(
|
||||
status_code=504,
|
||||
detail=f"Conversion is taking too long. The maximum wait time is configure as DOCLING_SERVE_MAX_SYNC_WAIT={docling_serve_settings.max_sync_wait}.",
|
||||
)
|
||||
|
||||
result = await orchestrator.task_result(
|
||||
task_id=task.task_id, background_tasks=background_tasks
|
||||
task = await orchestrator.get_raw_task(task_id=task.task_id)
|
||||
response = await prepare_response(
|
||||
task=task, orchestrator=orchestrator, background_tasks=background_tasks
|
||||
)
|
||||
if result is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Task result not found. Please wait for a completion status.",
|
||||
)
|
||||
return result
|
||||
return response
|
||||
|
||||
# Convert a document from URL(s) using the async api
|
||||
@app.post(
|
||||
@@ -384,7 +378,7 @@ def create_app(): # noqa: C901
|
||||
response_model=TaskStatusResponse,
|
||||
)
|
||||
async def process_url_async(
|
||||
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
|
||||
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
|
||||
conversion_request: ConvertDocumentsRequest,
|
||||
):
|
||||
task = await _enque_source(
|
||||
@@ -406,11 +400,11 @@ def create_app(): # noqa: C901
|
||||
response_model=TaskStatusResponse,
|
||||
)
|
||||
async def process_file_async(
|
||||
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
|
||||
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
|
||||
background_tasks: BackgroundTasks,
|
||||
files: list[UploadFile],
|
||||
options: Annotated[
|
||||
ConvertDocumentsOptions, FormDepends(ConvertDocumentsOptions)
|
||||
ConvertDocumentsRequestOptions, FormDepends(ConvertDocumentsRequestOptions)
|
||||
],
|
||||
):
|
||||
task = await _enque_file(
|
||||
@@ -432,7 +426,7 @@ def create_app(): # noqa: C901
|
||||
response_model=TaskStatusResponse,
|
||||
)
|
||||
async def task_status_poll(
|
||||
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
|
||||
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
|
||||
task_id: str,
|
||||
wait: Annotated[
|
||||
float, Query(help="Number of seconds to wait for a completed status.")
|
||||
@@ -456,9 +450,10 @@ def create_app(): # noqa: C901
|
||||
)
|
||||
async def task_status_ws(
|
||||
websocket: WebSocket,
|
||||
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
|
||||
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
|
||||
task_id: str,
|
||||
):
|
||||
assert isinstance(orchestrator.notifier, WebsocketNotifier)
|
||||
await websocket.accept()
|
||||
|
||||
if task_id not in orchestrator.tasks:
|
||||
@@ -473,7 +468,7 @@ def create_app(): # noqa: C901
|
||||
task = orchestrator.tasks[task_id]
|
||||
|
||||
# Track active WebSocket connections for this job
|
||||
orchestrator.task_subscribers[task_id].add(websocket)
|
||||
orchestrator.notifier.task_subscribers[task_id].add(websocket)
|
||||
|
||||
try:
|
||||
task_queue_position = await orchestrator.get_queue_position(task_id=task_id)
|
||||
@@ -511,7 +506,7 @@ def create_app(): # noqa: C901
|
||||
_log.info(f"WebSocket disconnected for job {task_id}")
|
||||
|
||||
finally:
|
||||
orchestrator.task_subscribers[task_id].remove(websocket)
|
||||
orchestrator.notifier.task_subscribers[task_id].remove(websocket)
|
||||
|
||||
# Task result
|
||||
@app.get(
|
||||
@@ -524,19 +519,18 @@ def create_app(): # noqa: C901
|
||||
},
|
||||
)
|
||||
async def task_result(
|
||||
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
|
||||
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
|
||||
background_tasks: BackgroundTasks,
|
||||
task_id: str,
|
||||
):
|
||||
result = await orchestrator.task_result(
|
||||
task_id=task_id, background_tasks=background_tasks
|
||||
)
|
||||
if result is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Task result not found. Please wait for a completion status.",
|
||||
try:
|
||||
task = await orchestrator.get_raw_task(task_id=task_id)
|
||||
response = await prepare_response(
|
||||
task=task, orchestrator=orchestrator, background_tasks=background_tasks
|
||||
)
|
||||
return result
|
||||
return response
|
||||
except TaskNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Task not found.")
|
||||
|
||||
# Update task progress
|
||||
@app.post(
|
||||
@@ -544,7 +538,7 @@ def create_app(): # noqa: C901
|
||||
response_model=ProgressCallbackResponse,
|
||||
)
|
||||
async def callback_task_progress(
|
||||
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
|
||||
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
|
||||
request: ProgressCallbackRequest,
|
||||
):
|
||||
try:
|
||||
@@ -564,8 +558,10 @@ def create_app(): # noqa: C901
|
||||
"/v1alpha/clear/converters",
|
||||
response_model=ClearResponse,
|
||||
)
|
||||
async def clear_converters():
|
||||
_get_converter_from_hash.cache_clear()
|
||||
async def clear_converters(
|
||||
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
|
||||
):
|
||||
await orchestrator.clear_converters()
|
||||
return ClearResponse()
|
||||
|
||||
# Clean results
|
||||
@@ -574,7 +570,7 @@ def create_app(): # noqa: C901
|
||||
response_model=ClearResponse,
|
||||
)
|
||||
async def clear_results(
|
||||
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
|
||||
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
|
||||
older_then: float = 3600,
|
||||
):
|
||||
await orchestrator.clear_results(older_than=older_then)
|
||||
|
||||
Reference in New Issue
Block a user