Merge pull request #41 from datalab-to/dev

Enable piping through params
This commit is contained in:
Vik Paruchuri
2025-11-12 18:06:41 -05:00
committed by GitHub
3 changed files with 17 additions and 7 deletions

View File

@@ -13,16 +13,23 @@ def flatten(page, flag=pdfium_c.FLAT_NORMALDISPLAY):
print(f"Failed to flatten annotations / form fields on page {page}.")
def load_image(filepath: str):
def load_image(
filepath: str, min_image_dim: int = settings.MIN_IMAGE_DIM
) -> Image.Image:
image = Image.open(filepath).convert("RGB")
if image.width < settings.MIN_IMAGE_DIM or image.height < settings.MIN_IMAGE_DIM:
scale = settings.MIN_IMAGE_DIM / min(image.width, image.height)
if image.width < min_image_dim or image.height < min_image_dim:
scale = min_image_dim / min(image.width, image.height)
new_size = (int(image.width * scale), int(image.height * scale))
image = image.resize(new_size, Image.Resampling.LANCZOS)
return image
def load_pdf_images(filepath: str, page_range: List[int]):
def load_pdf_images(
filepath: str,
page_range: List[int],
image_dpi: int = settings.IMAGE_DPI,
min_pdf_image_dim: int = settings.MIN_PDF_IMAGE_DIM,
) -> List[Image.Image]:
doc = pdfium.PdfDocument(filepath)
doc.init_forms()
@@ -31,8 +38,8 @@ def load_pdf_images(filepath: str, page_range: List[int]):
if not page_range or page in page_range:
page_obj = doc[page]
min_page_dim = min(page_obj.get_width(), page_obj.get_height())
scale_dpi = (settings.MIN_PDF_IMAGE_DIM / min_page_dim) * 72
scale_dpi = max(scale_dpi, settings.IMAGE_DPI)
scale_dpi = (min_pdf_image_dim / min_page_dim) * 72
scale_dpi = max(scale_dpi, image_dpi)
page_obj = doc[page]
flatten(page_obj)
page_obj = doc[page]

View File

@@ -28,12 +28,14 @@ class InferenceManager:
"include_headers_footers"
)
bbox_scale = kwargs.pop("bbox_scale", settings.BBOX_SCALE)
vllm_api_base = kwargs.pop("vllm_api_base", settings.VLLM_API_BASE)
if self.method == "vllm":
results = generate_vllm(
batch,
max_output_tokens=max_output_tokens,
bbox_scale=bbox_scale,
vllm_api_base=vllm_api_base,
**kwargs,
)
else:

View File

@@ -29,10 +29,11 @@ def generate_vllm(
custom_headers: dict | None = None,
max_failure_retries: int | None = None,
bbox_scale: int = settings.BBOX_SCALE,
vllm_api_base: str = settings.VLLM_API_BASE,
) -> List[GenerationResult]:
client = OpenAI(
api_key=settings.VLLM_API_KEY,
base_url=settings.VLLM_API_BASE,
base_url=vllm_api_base,
default_headers=custom_headers,
)
model_name = settings.VLLM_MODEL_NAME