mirror of
https://github.com/docling-project/docling-serve.git
synced 2025-11-29 16:43:24 +00:00
138 lines
4.1 KiB
Python
138 lines
4.1 KiB
Python
# ruff: noqa: E402, UP006, UP035
|
|
|
|
from typing import Any, Dict, List
|
|
|
|
from kfp import dsl
|
|
|
|
# Container image passed as `base_image` to every `@dsl.component` in this file.
PYTHON_BASE_IMAGE = "python:3.12"
|
|
|
|
|
|
@dsl.component(
    base_image=PYTHON_BASE_IMAGE,
    packages_to_install=[
        "pydantic",
        "docling-serve @ git+https://github.com/docling-project/docling-serve@feat-kfp-engine",
    ],
    pip_index_urls=["https://download.pytorch.org/whl/cpu", "https://pypi.org/simple"],
)
def generate_chunks(
    run_name: str,
    request: Dict[str, Any],
    batch_size: int,
    callbacks: List[Dict[str, Any]],
) -> List[List[Dict[str, Any]]]:
    """Split the request's HTTP sources into batches and report the total.

    Args:
        run_name: Identifier sent as ``task_id`` in the progress callback.
        request: Request payload; only ``request["http_sources"]`` is read.
        batch_size: Maximum number of sources per batch.
        callbacks: Raw callback specs, validated as ``list[CallbackSpec]``
            before being handed to ``notify_callbacks``.

    Returns:
        Consecutive slices of ``request["http_sources"]``, each holding at
        most ``batch_size`` entries.
    """
    # Imports live inside the function body, as in the original component.
    from pydantic import TypeAdapter

    from docling_serve.datamodel.callback import (
        ProgressCallbackRequest,
        ProgressSetNumDocs,
    )
    from docling_serve.datamodel.kfp import CallbackSpec
    from docling_serve.engines.async_kfp.notify import notify_callbacks

    callbacks_adapter = TypeAdapter(list[CallbackSpec])

    http_sources = request["http_sources"]

    # Walk the sources in strides of `batch_size`, counting documents as we go.
    batches = []
    num_docs = 0
    for start in range(0, len(http_sources), batch_size):
        batch = http_sources[start : start + batch_size]
        batches.append(batch)
        num_docs += len(batch)

    # Announce the overall document count before any batch is processed.
    progress = ProgressSetNumDocs(num_docs=num_docs)
    message = ProgressCallbackRequest(task_id=run_name, progress=progress)
    validated_callbacks = callbacks_adapter.validate_python(callbacks)
    notify_callbacks(payload=message, callbacks=validated_callbacks)

    return batches
|
|
|
|
|
|
@dsl.component(
    base_image=PYTHON_BASE_IMAGE,
    packages_to_install=[
        "pydantic",
        "docling-serve @ git+https://github.com/docling-project/docling-serve@feat-kfp-engine",
    ],
    pip_index_urls=["https://download.pytorch.org/whl/cpu", "https://pypi.org/simple"],
)
def convert_batch(
    run_name: str,
    data_splits: List[Dict[str, Any]],
    request: Dict[str, Any],
    callbacks: List[Dict[str, Any]],
    output_path: dsl.OutputPath("Directory"),  # type: ignore
):
    """Process one batch of HTTP sources and report progress to the callbacks.

    For every entry of ``data_splits`` the entry is validated as an
    ``HttpSource`` and its JSON dump is written to a file under
    ``output_path``. The file is named after the last path component of the
    source URL. After the loop a single ``ProgressUpdateProcessed`` message is
    passed to ``notify_callbacks``.

    Args:
        run_name: Identifier sent as ``task_id`` in the progress callback.
        data_splits: Source dicts of this batch, each validated as ``HttpSource``.
        request: Request payload; only ``request["options"]`` is read here
            (validated as ``ConvertDocumentsOptions`` and printed).
        callbacks: Raw callback specs, validated as ``list[CallbackSpec]``
            before being handed to ``notify_callbacks``.
        output_path: Output directory; created if missing.
    """
    from pathlib import Path

    from pydantic import AnyUrl, TypeAdapter

    from docling_serve.datamodel.callback import (
        FailedDocsItem,
        ProgressCallbackRequest,
        ProgressUpdateProcessed,
        SucceededDocsItem,
    )
    from docling_serve.datamodel.convert import ConvertDocumentsOptions
    from docling_serve.datamodel.kfp import CallbackSpec
    from docling_serve.datamodel.requests import HttpSource
    from docling_serve.engines.async_kfp.notify import notify_callbacks

    CallbacksListType = TypeAdapter(list[CallbackSpec])

    convert_options = ConvertDocumentsOptions.model_validate(request["options"])
    print(convert_options)

    output_dir = Path(output_path)
    output_dir.mkdir(exist_ok=True, parents=True)
    docs_succeeded: list[SucceededDocsItem] = []
    # NOTE(review): nothing in this function appends to `docs_failed`, so the
    # reported `num_failed` is always 0; an exception in the loop aborts the
    # whole batch instead of being recorded per document.
    docs_failed: list[FailedDocsItem] = []
    for source_dict in data_splits:
        source = HttpSource.model_validate(source_dict)
        # NOTE(review): the file name is only the last URL path component, so
        # two sources sharing a basename write to the same file.
        filename = Path(str(AnyUrl(source.url).path)).name
        output_filename = output_dir / filename
        print(f"Writing {output_filename}")
        # Write the JSON explicitly as UTF-8 rather than in whatever encoding
        # the container locale happens to select.
        with output_filename.open("w", encoding="utf-8") as f:
            f.write(source.model_dump_json())
        docs_succeeded.append(SucceededDocsItem(source=source.url))

    payload = ProgressCallbackRequest(
        task_id=run_name,
        progress=ProgressUpdateProcessed(
            num_failed=len(docs_failed),
            num_processed=len(docs_succeeded) + len(docs_failed),
            num_succeeded=len(docs_succeeded),
            docs_succeeded=docs_succeeded,
            docs_failed=docs_failed,
        ),
    )

    print(payload)
    notify_callbacks(
        payload=payload,
        callbacks=CallbacksListType.validate_python(callbacks),
    )
|
|
|
|
|
|
@dsl.pipeline()
def process(
    batch_size: int,
    request: Dict[str, Any],
    callbacks: List[Dict[str, Any]] = [],
    run_name: str = "",
):
    """Chunk the request's sources, then run ``convert_batch`` on each chunk.

    Args:
        batch_size: Forwarded to ``generate_chunks`` as the batch size.
        request: Request payload forwarded to both components.
        callbacks: Callback specs forwarded to both components.
        run_name: Identifier forwarded to both components.
    """
    # Step 1: split the sources into batches; caching is disabled for this task.
    chunking = generate_chunks(
        run_name=run_name,
        request=request,
        batch_size=batch_size,
        callbacks=callbacks,
    )
    chunking.set_caching_options(False)

    # Step 2: fan out over the batches (ParallelFor with parallelism=4).
    batches = chunking.output
    with dsl.ParallelFor(batches, parallelism=4) as batch:
        convert_batch(
            run_name=run_name,
            data_splits=batch,
            request=request,
            callbacks=callbacks,
        )
|