Files
docling-serve/docling_serve/ui/app.py

279 lines
9.8 KiB
Python

import io
import json
import logging
from pathlib import Path
from typing import Annotated
from fastapi import (
BackgroundTasks,
Depends,
FastAPI,
Form,
HTTPException,
Request,
UploadFile,
status,
)
from fastapi.responses import HTMLResponse, RedirectResponse, Response
from fastapi.staticfiles import StaticFiles
from pydantic import AnyHttpUrl
from pyjsx import auto_setup # type: ignore
from starlette.exceptions import HTTPException as StarletteHTTPException
from docling.datamodel.base_models import OutputFormat
from docling_core.types.doc.document import (
FloatingItem,
PageItem,
RefItem,
)
from docling_jobkit.orchestrators.base_orchestrator import (
BaseOrchestrator,
)
from docling_serve.auth import APIKeyCookieAuth, AuthenticationResult
from docling_serve.datamodel.convert import ConvertDocumentsRequestOptions
from docling_serve.datamodel.requests import ConvertDocumentsRequest, HttpSourceRequest
from docling_serve.helper_functions import FormDepends
from docling_serve.orchestrator_factory import get_async_orchestrator
from docling_serve.settings import docling_serve_settings
from .convert import ConvertPage # type: ignore
from .pages import AuthPage, StatusPage, TaskPage, TasksPage # type: ignore
# Initialize JSX. Evaluating the bare name has no effect at runtime; presumably
# importing `pyjsx.auto_setup` above is what performs the setup, and this
# reference only keeps the import from being flagged as unused.
auto_setup  # noqa: B018

# Module-level logger for the UI application.
_log = logging.getLogger(name=__name__)
# TODO: Isolate passed functions into a controller?
def create_ui_app(process_file, process_url, task_result, task_status_poll) -> FastAPI:  # noqa: C901
    """Build the FastAPI sub-application that serves the HTML user interface.

    The UI does not convert documents itself; it delegates to the API
    handlers supplied by the caller.

    Args:
        process_file: Coroutine function awaited with the keyword arguments
            ``auth``, ``orchestrator``, ``background_tasks``, ``files`` and
            ``options``; its result must expose ``task_id``.
        process_url: Coroutine function awaited with the keyword arguments
            ``auth``, ``orchestrator`` and ``conversion_request``; its result
            must expose ``task_id``.
        task_result: Coroutine function awaited as
            ``task_result(auth, orchestrator, background_tasks, task_id)``.
        task_status_poll: Coroutine function awaited as
            ``task_status_poll(auth, orchestrator, task_id)``; its result must
            expose ``task_status``.

    Returns:
        The configured UI application.
    """
    ui_app = FastAPI()
    require_auth = APIKeyCookieAuth(docling_serve_settings.api_key)

    # Static assets from the `static` directory next to this module.
    ui_app.mount(
        "/static",
        StaticFiles(directory=Path(__file__).parent.absolute() / "static"),
        name="static",
    )

    # Convert page.
    @ui_app.get("/")
    async def get_root():
        return RedirectResponse(url="convert")

    @ui_app.get("/convert", response_class=HTMLResponse)
    async def get_convert(
        auth: Annotated[AuthenticationResult, Depends(require_auth)],
    ):
        return str(ConvertPage())

    @ui_app.post("/convert", response_class=HTMLResponse)
    async def post_convert(
        auth: Annotated[AuthenticationResult, Depends(require_auth)],
        orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
        background_tasks: BackgroundTasks,
        options: Annotated[
            ConvertDocumentsRequestOptions, FormDepends(ConvertDocumentsRequestOptions)
        ],
        files: Annotated[list[UploadFile], Form()],
        url: Annotated[str, Form()],
        page_min: Annotated[str, Form()],
        page_max: Annotated[str, Form()],
    ):
        """Submit uploaded files or a URL for conversion, then redirect to the task.

        Raises:
            HTTPException: 400 when the page range or the URL cannot be parsed.
        """
        # Refined model options and behavior: an empty field keeps the
        # corresponding bound of the existing page range.
        try:
            if page_min:
                options.page_range = (int(page_min), options.page_range[1])
            if page_max:
                options.page_range = (options.page_range[0], int(page_max))
        except ValueError as err:
            # Malformed user input is a client error, not an internal one.
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST, "Invalid page range."
            ) from err

        # Entries may hold comma-separated languages; flatten them and drop
        # empty entries.
        options.ocr_lang = [
            sub_lang.strip()
            for lang in options.ocr_lang or []
            for sub_lang in lang.split(",")
            if sub_lang.strip()
        ]

        # Drop uploads whose size is 0 or unknown (None).
        files = [f for f in files if f.size]
        url = url.strip()

        if files:
            # Directly uploaded documents.
            response = await process_file(
                auth=auth,
                orchestrator=orchestrator,
                background_tasks=background_tasks,
                files=files,
                options=options,
            )
        elif url:
            # URLs of documents.
            try:
                # pydantic's ValidationError is a ValueError subclass.
                source = HttpSourceRequest(url=AnyHttpUrl(url))
            except ValueError as err:
                raise HTTPException(
                    status.HTTP_400_BAD_REQUEST, "Invalid URL."
                ) from err
            request = ConvertDocumentsRequest(options=options, sources=[source])
            response = await process_url(
                auth=auth,
                orchestrator=orchestrator,
                conversion_request=request,
            )
        else:
            validation = {
                "files": "Upload files or enter a URL",
                "url": "Enter a URL or upload files",
            }
            return str(ConvertPage(options=options, validation=validation))

        return RedirectResponse(f"tasks/{response.task_id}/", status.HTTP_303_SEE_OTHER)

    # Task overview page, oldest task first.
    @ui_app.get("/tasks/", response_class=HTMLResponse)
    async def get_tasks(
        auth: Annotated[AuthenticationResult, Depends(require_auth)],
        orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
    ):
        tasks = sorted(orchestrator.tasks.values(), key=lambda t: t.created_at)
        return str(TasksPage(tasks))

    # Task specific page.
    @ui_app.get("/tasks/{task_id}/", response_class=HTMLResponse)
    async def get_task(
        auth: Annotated[AuthenticationResult, Depends(require_auth)],
        orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
        background_tasks: BackgroundTasks,
        task_id: str,
    ):
        poll = await task_status_poll(auth, orchestrator, task_id)
        result = None
        if poll.task_status in ("success", "failure"):
            try:
                result = await task_result(
                    auth, orchestrator, background_tasks, task_id
                )
            except Exception:
                # Best effort: the page is still rendered without a result, but
                # the failure is logged with its traceback on the module logger.
                _log.exception("Could not fetch the result of task %s.", task_id)
        return str(TaskPage(poll, result))

    # Poll task via HTTP status: 202 while the task is "started", 200 otherwise.
    # NOTE(review): "pending" tasks also get 200 — confirm the client expects that.
    @ui_app.get("/tasks/{task_id}/poll", response_class=Response)
    async def poll_task(
        auth: Annotated[AuthenticationResult, Depends(require_auth)],
        orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
        task_id: str,
    ):
        poll = await task_status_poll(auth, orchestrator, task_id)
        return Response(
            status_code=status.HTTP_202_ACCEPTED
            if poll.task_status == "started"
            else status.HTTP_200_OK
        )

    # Download the contents of zipped documents.
    @ui_app.get("/tasks/{task_id}/documents.zip")
    async def get_task_zip(
        auth: Annotated[AuthenticationResult, Depends(require_auth)],
        orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
        background_tasks: BackgroundTasks,
        task_id: str,
    ):
        return await task_result(auth, orchestrator, background_tasks, task_id)

    # Get the output of a task, as a converted document in a specific format.
    @ui_app.get("/tasks/{task_id}/document.{format}")
    async def get_task_document_format(
        auth: Annotated[AuthenticationResult, Depends(require_auth)],
        orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
        background_tasks: BackgroundTasks,
        task_id: str,
        format: str,
    ):
        """Return the task's document rendered in ``format``.

        Raises:
            HTTPException: 404 when ``format`` is not an ``OutputFormat`` value
                or the task's document has no content for it.
        """
        if format not in {f.value for f in OutputFormat}:
            raise HTTPException(status.HTTP_404_NOT_FOUND, "Output format not found.")

        response = await task_result(auth, orchestrator, background_tasks, task_id)
        document = response.document

        # TODO: Make this compatible with base_models FormatToMimeType?
        mimes = {
            "html": "text/html",
            "md": "text/markdown",
            "json": "application/json",
        }

        content: str | None
        if format == OutputFormat.JSON:
            json_doc = document.json_content
            # Serialize as real JSON; str() of a dict would yield a Python repr.
            content = (
                None
                if json_doc is None
                else json.dumps(json_doc.export_to_dict(), ensure_ascii=False)
            )
        else:
            # Read only the requested field instead of `.dict()`, which dumps the
            # whole response and raises KeyError for formats without a
            # `<format>_content` field. `format` is validated above, so this
            # attribute lookup is not user-controlled.
            text = getattr(document, f"{format}_content", None)
            content = None if text is None else str(text)

        if content is None:
            raise HTTPException(
                status.HTTP_404_NOT_FOUND, "Document not available in this format."
            )
        return Response(
            content=content,
            media_type=mimes.get(format, "text/plain"),
        )

    # Get one item of the converted document, or its image as PNG when the
    # client accepts images.
    @ui_app.get("/tasks/{task_id}/document/{cref:path}")
    async def get_task_document_item(
        request: Request,
        auth: Annotated[AuthenticationResult, Depends(require_auth)],
        orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
        background_tasks: BackgroundTasks,
        task_id: str,
        cref: str,
    ):
        """Resolve ``#/<cref>`` in the task's document.

        Raises:
            HTTPException: 404 when the reference cannot be resolved, or when an
                image is requested for an item that has none.
        """
        response = await task_result(auth, orchestrator, background_tasks, task_id)
        doc = response.document.json_content
        try:
            item = RefItem(cref=f"#/{cref}").resolve(doc)  # type: ignore
        except (
            AttributeError,
            IndexError,
            KeyError,
            RuntimeError,
            TypeError,
            ValueError,
        ) as err:
            # `cref` comes from the URL; a bad path is "not found", not a 500.
            raise HTTPException(
                status.HTTP_404_NOT_FOUND, "Document item not found."
            ) from err

        wants_image = "image/*" in (request.headers.get("Accept") or "")
        if not (wants_image and isinstance(item, FloatingItem | PageItem)):
            return item

        content = io.BytesIO()
        if isinstance(item, PageItem) and (img_ref := item.image) and img_ref.pil_image:
            img_ref.pil_image.save(content, format="PNG")
        elif isinstance(item, FloatingItem) and (img := item.get_image(doc)):
            img.save(content, format="PNG")
        else:
            # Previously an empty image/png body was returned here.
            raise HTTPException(status.HTTP_404_NOT_FOUND, "Image not found.")
        return Response(content=content.getvalue(), media_type="image/png")

    # Page not found; catch all.
    @ui_app.api_route("/{path_name:path}")
    def no_page(
        auth: Annotated[AuthenticationResult, Depends(require_auth)],
    ):
        raise HTTPException(status.HTTP_404_NOT_FOUND, "Page not found.")

    # Exception and auth pages.
    @ui_app.exception_handler(StarletteHTTPException)
    @ui_app.exception_handler(Exception)
    async def exception_page(request: Request, ex: Exception):
        """Render errors as HTML, handling API-key submission and 401 responses."""
        # Any non-HTTP exception is shown as a generic internal error.
        http_ex: StarletteHTTPException = (
            ex
            if isinstance(ex, StarletteHTTPException)
            else HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR)
        )

        if request.method == "POST":
            # An API key submitted through the auth dialog: store it via
            # `_set_api_key` on the redirect response and send the browser back
            # to the same URL with a 303.
            form = await request.form()
            form_api_key = form.get("api_key")
            if isinstance(form_api_key, str):
                response = RedirectResponse(request.url, status.HTTP_303_SEE_OTHER)
                require_auth._set_api_key(response, form_api_key)
                return response

        if http_ex.status_code == status.HTTP_401_UNAUTHORIZED:
            # Authorization required -> API key dialog.
            return HTMLResponse(str(AuthPage()), status.HTTP_401_UNAUTHORIZED)

        # HTTP exception page; avoid a referer loop back to this same URL.
        referer = request.headers.get("Referer")
        if referer == str(request.url):
            referer = None
        return HTMLResponse(str(StatusPage(http_ex, referer)), http_ex.status_code)

    return ui_app