mirror of
https://github.com/docling-project/docling-serve.git
synced 2025-11-29 00:23:36 +00:00
337 lines
13 KiB
Python
337 lines
13 KiB
Python
import json
|
|
import logging
|
|
from functools import lru_cache
|
|
from typing import Any, Optional
|
|
|
|
import redis.asyncio as redis
|
|
|
|
from docling_jobkit.datamodel.task import Task
|
|
from docling_jobkit.datamodel.task_meta import TaskStatus
|
|
from docling_jobkit.orchestrators.base_orchestrator import (
|
|
BaseOrchestrator,
|
|
TaskNotFoundError,
|
|
)
|
|
|
|
from docling_serve.settings import AsyncEngine, docling_serve_settings
|
|
from docling_serve.storage import get_scratch
|
|
|
|
_log = logging.getLogger(__name__)
|
|
|
|
|
|
class RedisTaskStatusMixin:
    """
    Orchestrator mixin that mirrors task metadata and result keys into Redis.

    The mixin wraps the status/result/enqueue methods of the class that
    follows it in the MRO so that task information written by one service
    instance can be found by other instances.

    The host class is expected to provide:
      * ``tasks``: the in-memory task registry,
      * ``_task_result_keys``: task id -> result key mapping,
      * ``config``: an object exposing ``redis_url``,
      * the parent implementations reached through ``super()``
        (``task_status``, ``get_raw_task``, ``enqueue``, ``task_result``,
        ``_update_task_from_rq``).

    Keys written to Redis:
      * ``docling:tasks:<task_id>:metadata`` - JSON-encoded task metadata,
      * ``docling:tasks:<task_id>:result_key`` - the task's result key.
    """

    tasks: dict[str, Task]
    _task_result_keys: dict[str, str]
    config: Any

    # Expiry, in seconds, applied to every key this mixin writes to Redis.
    _REDIS_KEY_TTL_SECONDS: int = 86400

    def __init__(self, *args, **kwargs):
        """Initialise the parent class, then create the shared Redis pool."""
        super().__init__(*args, **kwargs)
        self.redis_prefix = "docling:tasks:"
        # One pool shared by all short-lived Redis clients created below.
        self._redis_pool = redis.ConnectionPool.from_url(
            self.config.redis_url,
            max_connections=10,
            socket_timeout=2.0,
        )

    async def task_status(self, task_id: str, wait: float = 0.0) -> Task:
        """
        Get task status by checking RQ first, then Redis, then the parent.

        Lookup order:
          1. RQ, through ``_get_task_from_rq_direct``. A hit refreshes both
             the in-memory registry and the Redis copy.
          2. The Redis metadata copy. If it reports PENDING or STARTED, RQ is
             queried once more; when RQ reports a different status, that
             fresher task replaces the cached one and is returned.
          3. The parent ``task_status`` implementation, whose result is also
             stored in Redis for other instances.

        Args:
            task_id: Identifier of the task to look up.
            wait: Only forwarded to the parent implementation (step 3); the
                RQ and Redis lookups do not use it.

        Returns:
            The task found by the first successful lookup.

        Raises:
            TaskNotFoundError: Re-raised from the parent implementation when
                no source knows the task.
        """
        _log.info(f"Task {task_id} status check")

        # Always check RQ directly first - this is the most reliable source
        rq_task = await self._get_task_from_rq_direct(task_id)
        if rq_task:
            _log.info(f"Task {task_id} in RQ: {rq_task.task_status}")

            # Update memory registry
            self.tasks[task_id] = rq_task

            # Store/update in Redis for other instances
            await self._store_task_in_redis(rq_task)
            return rq_task

        # If not in RQ, check Redis (maybe it's cached from another instance)
        task = await self._get_task_from_redis(task_id)
        if task:
            _log.info(f"Task {task_id} in Redis: {task.task_status}")

            # The cached status might be stale: PENDING/STARTED tasks might
            # have progressed since they were cached.
            if task.task_status in [TaskStatus.PENDING, TaskStatus.STARTED]:
                _log.debug(f"Task {task_id} verifying stale status")

                # Try to get fresh status from RQ
                fresh_rq_task = await self._get_task_from_rq_direct(task_id)
                if fresh_rq_task and fresh_rq_task.task_status != task.task_status:
                    _log.info(
                        f"Task {task_id} status updated: {fresh_rq_task.task_status}"
                    )

                    # Update memory and Redis with fresh status
                    self.tasks[task_id] = fresh_rq_task
                    await self._store_task_in_redis(fresh_rq_task)
                    return fresh_rq_task
                else:
                    _log.debug(f"Task {task_id} status consistent")

            return task

        # Fall back to parent implementation
        try:
            parent_task = await super().task_status(task_id, wait)  # type: ignore[misc]
            _log.debug(f"Task {task_id} from parent: {parent_task.task_status}")

            # Store in Redis for other instances to find
            await self._store_task_in_redis(parent_task)
            return parent_task
        except TaskNotFoundError:
            _log.warning(f"Task {task_id} not found")
            raise

    async def _get_task_from_redis(self, task_id: str) -> Optional[Task]:
        """
        Rebuild a ``Task`` from the metadata JSON stored in Redis.

        Missing counters in ``processing_meta`` are filled with 0.

        Returns:
            The reconstructed task, or ``None`` when the key is absent or
            any error occurs (errors are logged, never raised).
        """
        try:
            async with redis.Redis(connection_pool=self._redis_pool) as r:
                task_data = await r.get(f"{self.redis_prefix}{task_id}:metadata")
                if not task_data:
                    return None

                data: dict[str, Any] = json.loads(task_data)
                meta = data.get("processing_meta") or {}
                meta.setdefault("num_docs", 0)
                meta.setdefault("num_processed", 0)
                meta.setdefault("num_succeeded", 0)
                meta.setdefault("num_failed", 0)

                return Task(
                    task_id=data["task_id"],
                    task_type=data["task_type"],
                    task_status=TaskStatus(data["task_status"]),
                    processing_meta=meta,
                )
        except Exception as e:
            _log.error(f"Redis get task {task_id}: {e}")
            return None

    async def _get_task_from_rq_direct(self, task_id: str) -> Optional[Task]:
        """
        Query the parent's RQ update for ``task_id`` via a placeholder task.

        A placeholder PENDING task is temporarily registered under
        ``task_id``, the parent ``_update_task_from_rq`` is run against it,
        and the previous registry entry is restored afterwards (or the
        placeholder is removed if there was none).

        Returns:
            The task left in the registry by the parent update when its
            status is no longer PENDING; otherwise ``None``. Any exception
            is logged and also yields ``None``.
        """
        try:
            _log.debug(f"Checking RQ for task {task_id}")

            # NOTE(review): the placeholder always uses task_type "convert",
            # so a returned task may carry that type whatever the real task
            # type is - confirm this is acceptable for non-convert tasks.
            temp_task = Task(
                task_id=task_id,
                task_type="convert",
                task_status=TaskStatus.PENDING,
                processing_meta={
                    "num_docs": 0,
                    "num_processed": 0,
                    "num_succeeded": 0,
                    "num_failed": 0,
                },
            )

            original_task = self.tasks.get(task_id)
            self.tasks[task_id] = temp_task

            try:
                # super() deliberately bypasses this mixin's own
                # _update_task_from_rq override.
                await super()._update_task_from_rq(task_id)  # type: ignore[misc]

                updated_task = self.tasks.get(task_id)
                if updated_task and updated_task.task_status != TaskStatus.PENDING:
                    _log.debug(f"RQ task {task_id}: {updated_task.task_status}")

                    # Store result key if available
                    await self._store_result_key_in_redis(task_id)

                    return updated_task
                return None

            finally:
                # Restore original task state
                if original_task:
                    self.tasks[task_id] = original_task
                elif task_id in self.tasks and self.tasks[task_id] == temp_task:
                    # Only remove if it's still our temp task
                    del self.tasks[task_id]

        except Exception as e:
            _log.error(f"RQ check {task_id}: {e}")
            return None

    async def get_raw_task(self, task_id: str) -> Task:
        """
        Return the task from memory, then Redis, then the parent.

        A task found in Redis is added to the in-memory registry; a task
        found by the parent is stored in Redis. A ``TaskNotFoundError``
        raised by the parent propagates unchanged.
        """
        if task_id in self.tasks:
            return self.tasks[task_id]

        task = await self._get_task_from_redis(task_id)
        if task:
            self.tasks[task_id] = task
            return task

        parent_task = await super().get_raw_task(task_id)  # type: ignore[misc]
        await self._store_task_in_redis(parent_task)
        return parent_task

    async def _store_task_in_redis(self, task: Task) -> None:
        """
        Write the task's metadata to Redis as JSON (best effort).

        ``processing_meta`` is dumped via ``model_dump`` when available, kept
        as-is when it is a dict, and replaced by zeroed counters otherwise.
        Errors are logged and swallowed.
        """
        try:
            meta: Any = task.processing_meta
            if hasattr(meta, "model_dump"):
                meta = meta.model_dump()
            elif not isinstance(meta, dict):
                meta = {
                    "num_docs": 0,
                    "num_processed": 0,
                    "num_succeeded": 0,
                    "num_failed": 0,
                }

            data: dict[str, Any] = {
                "task_id": task.task_id,
                "task_type": task.task_type.value
                if hasattr(task.task_type, "value")
                else str(task.task_type),
                "task_status": task.task_status.value,
                "processing_meta": meta,
            }
            async with redis.Redis(connection_pool=self._redis_pool) as r:
                await r.set(
                    f"{self.redis_prefix}{task.task_id}:metadata",
                    json.dumps(data),
                    ex=self._REDIS_KEY_TTL_SECONDS,
                )
        except Exception as e:
            _log.error(f"Store task {task.task_id}: {e}")

    async def _store_result_key_in_redis(self, task_id: str) -> None:
        """
        Copy the in-memory result key of ``task_id`` to Redis (best effort).

        Does nothing when no result key is registered for the task. Errors
        are logged and swallowed.
        """
        if task_id not in self._task_result_keys:
            return

        try:
            async with redis.Redis(connection_pool=self._redis_pool) as r:
                await r.set(
                    f"{self.redis_prefix}{task_id}:result_key",
                    self._task_result_keys[task_id],
                    ex=self._REDIS_KEY_TTL_SECONDS,
                )
            _log.debug(f"Stored result key for {task_id}")
        except Exception as e:
            _log.error(f"Store result key {task_id}: {e}")

    async def enqueue(self, **kwargs):  # type: ignore[override]
        """Enqueue through the parent, then publish the new task to Redis."""
        task = await super().enqueue(**kwargs)  # type: ignore[misc]
        await self._store_task_in_redis(task)
        return task

    async def task_result(self, task_id: str):  # type: ignore[override]
        """
        Return the task result, recovering the result key from Redis if needed.

        The parent is asked first. When it returns ``None``, the result key
        stored in Redis (if any) is registered in ``_task_result_keys`` and
        the parent is asked again.

        Returns:
            The parent's result, or ``None`` when no result key is stored or
            the recovery attempt fails (failures are logged).
        """
        result = await super().task_result(task_id)  # type: ignore[misc]
        if result is not None:
            return result

        try:
            async with redis.Redis(connection_pool=self._redis_pool) as r:
                result_key = await r.get(f"{self.redis_prefix}{task_id}:result_key")
                if result_key:
                    # Assumes a bytes reply (responses are not decoded).
                    self._task_result_keys[task_id] = result_key.decode("utf-8")
                    return await super().task_result(task_id)  # type: ignore[misc]
        except Exception as e:
            _log.error(f"Redis result key {task_id}: {e}")

        return None

    async def _update_task_from_rq(self, task_id: str) -> None:
        """
        Run the parent's RQ update and publish any change to Redis.

        The task metadata is re-stored when the status changed during the
        parent update; the result key is stored whenever one is registered.
        """
        original_status = (
            self.tasks[task_id].task_status if task_id in self.tasks else None
        )

        await super()._update_task_from_rq(task_id)  # type: ignore[misc]

        if task_id in self.tasks:
            new_status = self.tasks[task_id].task_status
            if original_status != new_status:
                _log.debug(f"Task {task_id} status: {original_status} -> {new_status}")
                await self._store_task_in_redis(self.tasks[task_id])

        await self._store_result_key_in_redis(task_id)
|
|
|
|
|
|
@lru_cache
def get_async_orchestrator() -> BaseOrchestrator:
    """
    Build the orchestrator selected by ``docling_serve_settings.eng_kind``.

    The function takes no arguments and is wrapped in ``lru_cache``, so the
    orchestrator is constructed on the first call and the same instance is
    returned afterwards. Engine-specific imports are kept inside their
    branch, so only the selected engine's modules are imported.

    Returns:
        A ``LocalOrchestrator``, a Redis-aware ``RQOrchestrator`` subclass,
        or a ``KfpOrchestrator``, depending on the configured engine.

    Raises:
        RuntimeError: If the configured engine kind matches none of the
            supported engines.
    """
    cfg = docling_serve_settings

    # --- Local engine -----------------------------------------------------
    if cfg.eng_kind == AsyncEngine.LOCAL:
        from docling_jobkit.convert.manager import (
            DoclingConverterManager,
            DoclingConverterManagerConfig,
        )
        from docling_jobkit.orchestrators.local.orchestrator import (
            LocalOrchestrator,
            LocalOrchestratorConfig,
        )

        orchestrator_config = LocalOrchestratorConfig(
            num_workers=cfg.eng_loc_num_workers,
            shared_models=cfg.eng_loc_share_models,
            scratch_dir=get_scratch(),
        )

        manager_config = DoclingConverterManagerConfig(
            artifacts_path=cfg.artifacts_path,
            options_cache_size=cfg.options_cache_size,
            enable_remote_services=cfg.enable_remote_services,
            allow_external_plugins=cfg.allow_external_plugins,
            max_num_pages=cfg.max_num_pages,
            max_file_size=cfg.max_file_size,
            queue_max_size=cfg.queue_max_size,
            ocr_batch_size=cfg.ocr_batch_size,
            layout_batch_size=cfg.layout_batch_size,
            table_batch_size=cfg.table_batch_size,
            batch_polling_interval_seconds=cfg.batch_polling_interval_seconds,
        )
        converter_manager = DoclingConverterManager(config=manager_config)

        return LocalOrchestrator(
            config=orchestrator_config,
            converter_manager=converter_manager,
        )

    # --- RQ engine ----------------------------------------------------------
    if cfg.eng_kind == AsyncEngine.RQ:
        from docling_jobkit.orchestrators.rq.orchestrator import (
            RQOrchestrator,
            RQOrchestratorConfig,
        )

        # The mixin comes first in the MRO so its overrides wrap the RQ ones.
        class RedisAwareRQOrchestrator(RedisTaskStatusMixin, RQOrchestrator):  # type: ignore[misc]
            pass

        orchestrator_config = RQOrchestratorConfig(
            redis_url=cfg.eng_rq_redis_url,
            results_prefix=cfg.eng_rq_results_prefix,
            sub_channel=cfg.eng_rq_sub_channel,
            scratch_dir=get_scratch(),
        )

        return RedisAwareRQOrchestrator(config=orchestrator_config)

    # --- Kubeflow Pipelines engine ------------------------------------------
    if cfg.eng_kind == AsyncEngine.KFP:
        from docling_jobkit.orchestrators.kfp.orchestrator import (
            KfpOrchestrator,
            KfpOrchestratorConfig,
        )

        orchestrator_config = KfpOrchestratorConfig(
            endpoint=cfg.eng_kfp_endpoint,
            token=cfg.eng_kfp_token,
            ca_cert_path=cfg.eng_kfp_ca_cert_path,
            self_callback_endpoint=cfg.eng_kfp_self_callback_endpoint,
            self_callback_token_path=cfg.eng_kfp_self_callback_token_path,
            self_callback_ca_cert_path=cfg.eng_kfp_self_callback_ca_cert_path,
        )

        return KfpOrchestrator(config=orchestrator_config)

    raise RuntimeError(f"Engine {docling_serve_settings.eng_kind} not recognized.")
|