diff --git a/docling_serve/datamodel/convert.py b/docling_serve/datamodel/convert.py index fd58f16..f48438a 100644 --- a/docling_serve/datamodel/convert.py +++ b/docling_serve/datamodel/convert.py @@ -359,14 +359,24 @@ class ConvertDocumentsOptions(BaseModel): picture_description_local: Annotated[ Optional[PictureDescriptionLocal], Field( - description="Options for running a local vision-language model in the picture description. The parameters refer to a model hosted on Hugging Face. This parameter is mutually exclusive with picture_description_api." + description="Options for running a local vision-language model in the picture description. The parameters refer to a model hosted on Hugging Face. This parameter is mutually exclusive with picture_description_api.", + examples=[ + PictureDescriptionLocal(repo_id="ibm-granite/granite-vision-3.2-2b"), + PictureDescriptionLocal(repo_id="HuggingFaceTB/SmolVLM-256M-Instruct"), + ], ), ] = None picture_description_api: Annotated[ Optional[PictureDescriptionApi], Field( - description="API details for using a vision-language model in the picture description. This parameter is mutually exclusive with picture_description_local." + description="API details for using a vision-language model in the picture description. This parameter is mutually exclusive with picture_description_local.", + examples=[ + PictureDescriptionApi( + url="http://localhost:11434/v1/chat/completions", + params={"model": "granite3.2-vision:2b"}, + ) + ], ), ] = None diff --git a/docling_serve/helper_functions.py b/docling_serve/helper_functions.py index c42a391..20982d1 100644 --- a/docling_serve/helper_functions.py +++ b/docling_serve/helper_functions.py @@ -1,9 +1,30 @@ import inspect +import json import re -from typing import Union +from typing import Union, get_args, get_origin from fastapi import Depends, Form -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter + + +def is_pydantic_model(type_): + try: + if inspect.isclass(type_) and issubclass(type_, BaseModel): + return True + + origin = get_origin(type_) + if origin is Union: + args = get_args(type_) + return any( + inspect.isclass(arg) and issubclass(arg, BaseModel) + for arg in args + if arg is not type(None) + ) + + except Exception: + pass + + return False # Adapted from @@ -12,25 +33,62 @@ def FormDepends(cls: type[BaseModel]): new_parameters = [] for field_name, model_field in cls.model_fields.items(): + annotation = model_field.annotation + description = model_field.description + default = ( + Form(..., description=description) + if model_field.is_required() + else Form( + model_field.default, + examples=model_field.examples, + description=description, + ) + ) + + # Flatten nested Pydantic models by accepting them as JSON strings + if is_pydantic_model(annotation): + annotation = str + default = Form( + None + if model_field.default is None + else json.dumps(model_field.default.model_dump(mode="json")), + description=description, + examples=None + if not model_field.examples + else [ + json.dumps(ex.model_dump(mode="json")) + for ex in model_field.examples + ], + ) + new_parameters.append( inspect.Parameter( name=field_name, kind=inspect.Parameter.POSITIONAL_ONLY, - default=( - Form(...) - if model_field.is_required() - else Form(model_field.default) - ), - annotation=model_field.annotation, + default=default, + annotation=annotation, ) ) async def as_form_func(**data): + for field_name, model_field in cls.model_fields.items(): + value = data.get(field_name) + annotation = model_field.annotation + + # Parse nested models from JSON string + if value is not None and is_pydantic_model(annotation): + try: + validator = TypeAdapter(annotation) + data[field_name] = validator.validate_json(value) + except Exception as e: + raise ValueError(f"Invalid JSON for field '{field_name}': {e}") + return cls(**data) sig = inspect.signature(as_form_func) sig = sig.replace(parameters=new_parameters) as_form_func.__signature__ = sig # type: ignore + return Depends(as_form_func) diff --git a/tests/test_file_opts.py b/tests/test_file_opts.py new file mode 100644 index 0000000..288b314 --- /dev/null +++ b/tests/test_file_opts.py @@ -0,0 +1,77 @@ +import asyncio +import json +import os + +import pytest +import pytest_asyncio +from asgi_lifespan import LifespanManager +from httpx import ASGITransport, AsyncClient + +from docling_core.types import DoclingDocument +from docling_core.types.doc.document import PictureDescriptionData + +from docling_serve.app import create_app + + +@pytest.fixture(scope="session") +def event_loop(): + return asyncio.get_event_loop() + + +@pytest_asyncio.fixture(scope="session") +async def app(): + app = create_app() + + async with LifespanManager(app) as manager: + print("Launching lifespan of app.") + yield manager.app + + +@pytest_asyncio.fixture(scope="session") +async def client(app): + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://app.io" + ) as client: + print("Client is ready") + yield client + + +@pytest.mark.asyncio +async def test_convert_file(client: AsyncClient): + """Test convert single file to all outputs""" + + endpoint = "/v1alpha/convert/file" + options = { + "to_formats": ["md", "json"], + "image_export_mode": "placeholder", + "ocr": False, + "do_picture_description": True, + "picture_description_api": json.dumps( + { + "url": "http://localhost:11434/v1/chat/completions", # ollama + "params": {"model": "granite3.2-vision:2b"}, + "timeout": 60, + "prompt": "Describe this image in a few sentences. ", + } + ), + } + + current_dir = os.path.dirname(__file__) + file_path = os.path.join(current_dir, "2206.01062v1.pdf") + + files = { + "files": ("2206.01062v1.pdf", open(file_path, "rb"), "application/pdf"), + } + + response = await client.post(endpoint, files=files, data=options) + assert response.status_code == 200, "Response should be 200 OK" + + data = response.json() + + doc = DoclingDocument.model_validate(data["document"]["json_content"]) + + for pic in doc.pictures: + for ann in pic.annotations: + if isinstance(ann, PictureDescriptionData): + print(f"{pic.self_ref}") + print(ann.text)