# File: docling-serve/docling_serve/engines/async_kfp/orchestrator.py
# Date: 2025-04-28 11:18:19 +02:00
# Size: 236 lines, 8.6 KiB (Python)

import datetime
import json
import logging
import uuid
from pathlib import Path
from typing import Optional
from kfp_server_api.models import V2beta1RuntimeState
from pydantic import BaseModel, TypeAdapter
from pydantic_settings import SettingsConfigDict
from docling_serve.datamodel.callback import (
ProgressCallbackRequest,
ProgressSetNumDocs,
ProgressUpdateProcessed,
)
from docling_serve.datamodel.convert import ConvertDocumentsOptions
from docling_serve.datamodel.engines import TaskStatus
from docling_serve.datamodel.kfp import CallbackSpec
from docling_serve.datamodel.requests import HttpSource
from docling_serve.datamodel.task import Task, TaskSource
from docling_serve.datamodel.task_meta import TaskProcessingMeta
from docling_serve.engines.async_kfp.kfp_pipeline import process
from docling_serve.engines.async_orchestrator import (
BaseAsyncOrchestrator,
ProgressInvalid,
)
from docling_serve.settings import docling_serve_settings
# Logger for this module, named after the module's import path.
_log = logging.getLogger(name=__name__)
class _RunItem(BaseModel):
    """Snapshot of the KFP run fields used when inspecting the pending queue.

    All field types (``str`` and ``datetime``) are natively supported by
    pydantic, so no ``arbitrary_types_allowed`` configuration is needed.
    """

    run_id: str
    state: str
    created_at: datetime.datetime
    # Optional so that a run which has not been scheduled / finished yet does
    # not fail validation if KFP omits these timestamps.
    # NOTE(review): confirm what KFP returns for unfinished runs.
    scheduled_at: Optional[datetime.datetime] = None
    finished_at: Optional[datetime.datetime] = None
class AsyncKfpOrchestrator(BaseAsyncOrchestrator):
    """Orchestrator that executes conversion tasks as Kubeflow Pipelines runs.

    Each enqueued task becomes one run of the ``process`` pipeline, and the KFP
    run id is used as the task id. Task status is derived on demand from the
    run state. Progress updates arrive through :meth:`receive_task_progress`.

    NOTE(review): every ``self._client`` call below is a synchronous HTTP
    request made from an ``async`` method, so it blocks the event loop while
    it runs. Consider offloading them (e.g. ``asyncio.to_thread``) once the
    thread-safety of ``kfp.Client`` has been confirmed.
    """

    # Value passed as the ``batch_size`` argument of the ``process`` pipeline.
    _kfp_batch_size: int = 10

    # KFP run state -> task status. Non-final (PAUSED) or not-yet-known
    # (RUNTIME_STATE_UNSPECIFIED) states are reported as PENDING rather than
    # FAILURE, so clients polling the status do not give up on a run that may
    # still complete. Every other state (SKIPPED, FAILED, CANCELING, CANCELED,
    # anything unexpected) is reported as FAILURE.
    _RUN_STATE_TO_TASK_STATUS = {
        V2beta1RuntimeState.SUCCEEDED: TaskStatus.SUCCESS,
        V2beta1RuntimeState.PENDING: TaskStatus.PENDING,
        V2beta1RuntimeState.RUNNING: TaskStatus.STARTED,
        V2beta1RuntimeState.PAUSED: TaskStatus.PENDING,
        V2beta1RuntimeState.RUNTIME_STATE_UNSPECIFIED: TaskStatus.PENDING,
    }

    def __init__(self):
        """Create the KFP client from the engine settings.

        If no CA certificate or token is configured, the pod's
        service-account CA bundle and token files are used when they exist.
        The CA bundle is used only for ``.svc`` hosts.

        Raises:
            ValueError: If no KFP endpoint is configured.
        """
        super().__init__()
        import kfp

        kfp_endpoint = docling_serve_settings.eng_kfp_endpoint
        if kfp_endpoint is None:
            raise ValueError("KFP endpoint is required when using the KFP engine.")

        # Service-account credentials mounted into the pod, used as fallbacks.
        kube_sa_token_path = Path("/run/secrets/kubernetes.io/serviceaccount/token")
        kube_sa_ca_cert_path = Path(
            "/run/secrets/kubernetes.io/serviceaccount/service-ca.crt"
        )

        ssl_ca_cert = docling_serve_settings.eng_kfp_ca_cert_path
        token = docling_serve_settings.eng_kfp_token

        # For in-cluster service hosts (``*.svc``) fall back to the mounted CA.
        if (
            ssl_ca_cert is None
            and ".svc" in (kfp_endpoint.host or "")
            and kube_sa_ca_cert_path.exists()
        ):
            ssl_ca_cert = str(kube_sa_ca_cert_path)
        if token is None and kube_sa_token_path.exists():
            # Strip whitespace so a trailing newline cannot reach the
            # Authorization header.
            token = kube_sa_token_path.read_text().strip()

        self._client = kfp.Client(
            host=str(kfp_endpoint),
            existing_token=token,
            ssl_ca_cert=ssl_ca_cert,
        )

    def _self_callbacks(self) -> list[CallbackSpec]:
        """Build the callback specs passed to the pipeline run.

        These presumably let the pipeline report progress back to this server.

        Returns:
            A single spec for the configured self-callback endpoint, or an
            empty list when no endpoint is configured.
        """
        endpoint = docling_serve_settings.eng_kfp_self_callback_endpoint
        if endpoint is None:
            return []

        headers: dict[str, str] = {}
        token_path = docling_serve_settings.eng_kfp_self_callback_token_path
        if token_path is not None:
            # Token files may end with a newline, which is not allowed inside
            # an HTTP header value.
            token = token_path.read_text().strip()
            headers["Authorization"] = f"Bearer {token}"

        ca_cert = ""
        ca_cert_path = docling_serve_settings.eng_kfp_self_callback_ca_cert_path
        if ca_cert_path is not None:
            ca_cert = ca_cert_path.read_text()

        return [CallbackSpec(url=endpoint, headers=headers, ca_cert=ca_cert)]

    async def enqueue(
        self, sources: list[TaskSource], options: ConvertDocumentsOptions
    ) -> Task:
        """Submit a conversion job as a new KFP run and start tracking it.

        Only ``HttpSource`` entries are forwarded to the pipeline. Other source
        kinds are not sent (a warning is logged) but are still recorded on the
        returned task.

        Args:
            sources: Documents to convert.
            options: Conversion options forwarded to the pipeline.

        Returns:
            The tracked task, whose ``task_id`` is the KFP run id.
        """
        callbacks = self._self_callbacks()

        http_sources = [s for s in sources if isinstance(s, HttpSource)]
        if len(http_sources) != len(sources):
            _log.warning(
                "The KFP engine only supports HTTP sources; ignoring %d of %d sources.",
                len(sources) - len(http_sources),
                len(sources),
            )

        callbacks_adapter = TypeAdapter(list[CallbackSpec])
        sources_adapter = TypeAdapter(list[HttpSource])

        # hack: since the current kfp backend is not resolving the job_id placeholder,
        # we set the run_name and pass it as argument to the job itself.
        run_name = f"docling-job-{uuid.uuid4()}"
        kfp_run = self._client.create_run_from_pipeline_func(
            process,
            arguments={
                "batch_size": self._kfp_batch_size,
                "sources": sources_adapter.dump_python(http_sources, mode="json"),
                "options": options.model_dump(mode="json"),
                "callbacks": callbacks_adapter.dump_python(callbacks, mode="json"),
                "run_name": run_name,
            },
            run_name=run_name,
        )

        task = Task(task_id=kfp_run.run_id, sources=sources, options=options)
        await self.init_task_tracking(task)
        return task

    async def _update_task_from_run(self, task_id: str, wait: float = 0.0):
        """Refresh the tracked task's status from the state of its KFP run.

        Args:
            task_id: Task id, which is also the KFP run id.
            wait: Currently unused.
        """
        run_info = self._client.get_run(run_id=task_id)
        task = await self.get_raw_task(task_id=task_id)

        # Treat a missing state like an explicitly unspecified one.
        state = run_info.state or V2beta1RuntimeState.RUNTIME_STATE_UNSPECIFIED
        task.task_status = self._RUN_STATE_TO_TASK_STATUS.get(
            state, TaskStatus.FAILURE
        )

    async def task_status(self, task_id: str, wait: float = 0.0) -> Task:
        """Return the task after refreshing its status from KFP."""
        await self._update_task_from_run(task_id=task_id, wait=wait)
        return await self.get_raw_task(task_id=task_id)

    async def _get_pending(self) -> list[_RunItem]:
        """Return every KFP run currently in the PENDING state, across all pages.

        Runs are returned in the order KFP lists them.
        """
        pending_filter = json.dumps(
            {
                "predicates": [
                    {
                        "operation": "EQUALS",
                        "key": "state",
                        "stringValue": "PENDING",
                    }
                ]
            }
        )
        runs: list[_RunItem] = []
        next_page: Optional[str] = None
        while True:
            res = self._client.list_runs(
                page_token=next_page,
                page_size=20,
                filter=pending_filter,
            )
            runs.extend(
                _RunItem(
                    run_id=run.run_id,
                    state=run.state,
                    created_at=run.created_at,
                    scheduled_at=run.scheduled_at,
                    finished_at=run.finished_at,
                )
                for run in res.runs or []
            )
            # The last page may carry no token or an empty one. An empty token
            # means "first page" to list_runs, so continuing would loop forever.
            if not res.next_page_token:
                break
            next_page = res.next_page_token
        return runs

    async def queue_size(self) -> int:
        """Return the number of KFP runs currently pending."""
        runs = await self._get_pending()
        return len(runs)

    async def get_queue_position(self, task_id: str) -> Optional[int]:
        """Return the 1-based position of ``task_id`` among pending runs.

        Returns:
            The position in KFP's listing order, or None if the run is not
            pending.
        """
        runs = await self._get_pending()
        return next(
            (pos for pos, run in enumerate(runs, start=1) if run.run_id == task_id),
            None,
        )

    async def process_queue(self):
        """No-op: runs are submitted to KFP directly in :meth:`enqueue`."""
        return

    async def warm_up_caches(self):
        """No-op for this engine."""
        return

    async def _get_run_id(self, run_name: str) -> str:
        """Resolve a KFP run name to its run id.

        Raises:
            RuntimeError: If no run with that name exists.
        """
        name_filter = json.dumps(
            {
                "predicates": [
                    {
                        "operation": "EQUALS",
                        "key": "name",
                        "stringValue": run_name,
                    }
                ]
            }
        )
        res = self._client.list_runs(filter=name_filter)
        if res.runs:
            return res.runs[0].run_id
        raise RuntimeError(f"Run with {run_name=} not found.")

    async def receive_task_progress(self, request: ProgressCallbackRequest):
        """Apply a progress callback from a pipeline run to the tracked task.

        ``request.task_id`` carries the run *name* set in :meth:`enqueue`. It is
        resolved to the KFP run id, which is the task id.

        Raises:
            ProgressInvalid: If processed-document updates arrive before the
                expected number of documents has been set.
        """
        task_id = await self._get_run_id(run_name=request.task_id)
        progress = request.progress
        task = await self.get_raw_task(task_id=task_id)

        if isinstance(progress, ProgressSetNumDocs):
            task.processing_meta = TaskProcessingMeta(num_docs=progress.num_docs)
            task.task_status = TaskStatus.STARTED

        elif isinstance(progress, ProgressUpdateProcessed):
            if task.processing_meta is None:
                raise ProgressInvalid(
                    "UpdateProcessed was called before setting the expected number of documents."
                )
            task.processing_meta.num_processed += progress.num_processed
            task.processing_meta.num_succeeded += progress.num_succeeded
            task.processing_meta.num_failed += progress.num_failed
            task.task_status = TaskStatus.STARTED

        # TODO: could be moved to BackgroundTask
        await self.notify_task_subscribers(task_id=task_id)