diff --git a/chandra/input.py b/chandra/input.py
index 0d793a7..56829ab 100644
--- a/chandra/input.py
+++ b/chandra/input.py
@@ -13,16 +13,23 @@ def flatten(page, flag=pdfium_c.FLAT_NORMALDISPLAY):
         print(f"Failed to flatten annotations / form fields on page {page}.")
 
 
-def load_image(filepath: str):
+def load_image(
+    filepath: str, min_image_dim: int = settings.MIN_IMAGE_DIM
+) -> Image.Image:
     image = Image.open(filepath).convert("RGB")
-    if image.width < settings.MIN_IMAGE_DIM or image.height < settings.MIN_IMAGE_DIM:
-        scale = settings.MIN_IMAGE_DIM / min(image.width, image.height)
+    if image.width < min_image_dim or image.height < min_image_dim:
+        scale = min_image_dim / min(image.width, image.height)
         new_size = (int(image.width * scale), int(image.height * scale))
         image = image.resize(new_size, Image.Resampling.LANCZOS)
     return image
 
 
-def load_pdf_images(filepath: str, page_range: List[int]):
+def load_pdf_images(
+    filepath: str,
+    page_range: List[int],
+    image_dpi: int = settings.IMAGE_DPI,
+    min_pdf_image_dim: int = settings.MIN_PDF_IMAGE_DIM,
+) -> List[Image.Image]:
     doc = pdfium.PdfDocument(filepath)
     doc.init_forms()
 
@@ -31,8 +38,8 @@ def load_pdf_images(filepath: str, page_range: List[int]):
         if not page_range or page in page_range:
             page_obj = doc[page]
             min_page_dim = min(page_obj.get_width(), page_obj.get_height())
-            scale_dpi = (settings.MIN_PDF_IMAGE_DIM / min_page_dim) * 72
-            scale_dpi = max(scale_dpi, settings.IMAGE_DPI)
+            scale_dpi = (min_pdf_image_dim / min_page_dim) * 72
+            scale_dpi = max(scale_dpi, image_dpi)
             page_obj = doc[page]
             flatten(page_obj)
             page_obj = doc[page]
diff --git a/chandra/model/__init__.py b/chandra/model/__init__.py
index 63d6df5..12af5e4 100644
--- a/chandra/model/__init__.py
+++ b/chandra/model/__init__.py
@@ -28,12 +28,14 @@ class InferenceManager:
             "include_headers_footers"
         )
         bbox_scale = kwargs.pop("bbox_scale", settings.BBOX_SCALE)
+        vllm_api_base = kwargs.pop("vllm_api_base", settings.VLLM_API_BASE)
 
         if self.method == "vllm":
             results = generate_vllm(
                 batch,
                 max_output_tokens=max_output_tokens,
                 bbox_scale=bbox_scale,
+                vllm_api_base=vllm_api_base,
                 **kwargs,
             )
         else:
diff --git a/chandra/model/vllm.py b/chandra/model/vllm.py
index bf081d0..fce2ed6 100644
--- a/chandra/model/vllm.py
+++ b/chandra/model/vllm.py
@@ -29,10 +29,11 @@ def generate_vllm(
     custom_headers: dict | None = None,
     max_failure_retries: int | None = None,
     bbox_scale: int = settings.BBOX_SCALE,
+    vllm_api_base: str = settings.VLLM_API_BASE,
 ) -> List[GenerationResult]:
     client = OpenAI(
         api_key=settings.VLLM_API_KEY,
-        base_url=settings.VLLM_API_BASE,
+        base_url=vllm_api_base,
         default_headers=custom_headers,
     )
     model_name = settings.VLLM_MODEL_NAME