diff --git a/chandra/model/__init__.py b/chandra/model/__init__.py
index b703a18..0f0f761 100644
--- a/chandra/model/__init__.py
+++ b/chandra/model/__init__.py
@@ -4,6 +4,7 @@ from chandra.model.hf import load_model, generate_hf
 from chandra.model.schema import BatchInputItem, BatchOutputItem
 from chandra.model.vllm import generate_vllm
 from chandra.output import parse_markdown, parse_html, parse_chunks, extract_images
+from chandra.settings import settings
 
 
 class InferenceManager:
@@ -26,19 +27,27 @@ class InferenceManager:
             output_kwargs["include_headers_footers"] = kwargs.pop(
                 "include_headers_footers"
             )
+        bbox_scale = kwargs.pop("bbox_scale", settings.BBOX_SCALE)
         if self.method == "vllm":
             results = generate_vllm(
-                batch, max_output_tokens=max_output_tokens, **kwargs
+                batch,
+                max_output_tokens=max_output_tokens,
+                bbox_scale=bbox_scale,
+                **kwargs,
             )
         else:
             results = generate_hf(
-                batch, self.model, max_output_tokens=max_output_tokens, **kwargs
+                batch,
+                self.model,
+                max_output_tokens=max_output_tokens,
+                bbox_scale=bbox_scale,
+                **kwargs,
             )
 
         output = []
         for result, input_item in zip(results, batch):
-            chunks = parse_chunks(result.raw, input_item.image)
+            chunks = parse_chunks(result.raw, input_item.image, bbox_scale=bbox_scale)
             output.append(
                 BatchOutputItem(
                     markdown=parse_markdown(result.raw, **output_kwargs),
diff --git a/chandra/model/hf.py b/chandra/model/hf.py
index 50aa883..2ef603f 100644
--- a/chandra/model/hf.py
+++ b/chandra/model/hf.py
@@ -10,12 +10,18 @@ from chandra.settings import settings
 
 
 def generate_hf(
-    batch: List[BatchInputItem], model, max_output_tokens=None, **kwargs
+    batch: List[BatchInputItem],
+    model,
+    max_output_tokens=None,
+    bbox_scale: int = settings.BBOX_SCALE,
+    **kwargs,
 ) -> List[GenerationResult]:
     if max_output_tokens is None:
         max_output_tokens = settings.MAX_OUTPUT_TOKENS
 
-    messages = [process_batch_element(item, model.processor) for item in batch]
+    messages = [
+        process_batch_element(item, model.processor, bbox_scale) for item in batch
+    ]
     text = model.processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
@@ -48,12 +54,12 @@ def generate_hf(
     return results
 
 
-def process_batch_element(item: BatchInputItem, processor):
+def process_batch_element(item: BatchInputItem, processor, bbox_scale: int):
     prompt = item.prompt
     prompt_type = item.prompt_type
 
     if not prompt:
-        prompt = PROMPT_MAPPING[prompt_type]
+        prompt = PROMPT_MAPPING[prompt_type].replace("{bbox_scale}", str(bbox_scale))
 
     content = []
     image = scale_to_fit(item.image)  # Guarantee max size
diff --git a/chandra/model/vllm.py b/chandra/model/vllm.py
index a2e5816..bf081d0 100644
--- a/chandra/model/vllm.py
+++ b/chandra/model/vllm.py
@@ -28,6 +28,7 @@ def generate_vllm(
     max_workers: int | None = None,
     custom_headers: dict | None = None,
     max_failure_retries: int | None = None,
+    bbox_scale: int = settings.BBOX_SCALE,
 ) -> List[GenerationResult]:
     client = OpenAI(
         api_key=settings.VLLM_API_KEY,
@@ -54,7 +55,9 @@
     ) -> GenerationResult:
         prompt = item.prompt
         if not prompt:
-            prompt = PROMPT_MAPPING[item.prompt_type]
+            prompt = PROMPT_MAPPING[item.prompt_type].replace(
+                "{bbox_scale}", str(bbox_scale)
+            )
 
         content = []
         image = scale_to_fit(item.image)
diff --git a/chandra/output.py b/chandra/output.py
index 47e9f98..5cadb89 100644
--- a/chandra/output.py
+++ b/chandra/output.py
@@ -221,12 +221,12 @@ class LayoutBlock:
     content: str
 
 
-def parse_layout(html: str, image: Image.Image):
+def parse_layout(html: str, image: Image.Image, bbox_scale=settings.BBOX_SCALE):
     soup = BeautifulSoup(html, "html.parser")
     top_level_divs = soup.find_all("div", recursive=False)
     width, height = image.size
-    width_scaler = width / settings.BBOX_SCALE
-    height_scaler = height / settings.BBOX_SCALE
+    width_scaler = width / bbox_scale
+    height_scaler = height / bbox_scale
     layout_blocks = []
     for div in top_level_divs:
         bbox = div.get("data-bbox")
@@ -255,7 +255,7 @@ def parse_layout(html: str, image: Image.Image):
     return layout_blocks
 
 
-def parse_chunks(html: str, image: Image.Image):
-    layout = parse_layout(html, image)
+def parse_chunks(html: str, image: Image.Image, bbox_scale=settings.BBOX_SCALE):
+    layout = parse_layout(html, image, bbox_scale=bbox_scale)
     chunks = [asdict(block) for block in layout]
     return chunks
diff --git a/chandra/prompts.py b/chandra/prompts.py
index f5a17bb..954897e 100644
--- a/chandra/prompts.py
+++ b/chandra/prompts.py
@@ -1,5 +1,3 @@
-from chandra.settings import settings
-
 ALLOWED_TAGS = [
     "math",
     "br",
@@ -67,7 +65,7 @@ Guidelines:
 """.strip()
 
 OCR_LAYOUT_PROMPT = f"""
-OCR this image to HTML, arranged as layout blocks. Each layout block should be a div with the data-bbox attribute representing the bounding box of the block in [x0, y0, x1, y1] format. Bboxes are normalized 0-{settings.BBOX_SCALE}. The data-label attribute is the label for the block.
+OCR this image to HTML, arranged as layout blocks. Each layout block should be a div with the data-bbox attribute representing the bounding box of the block in [x0, y0, x1, y1] format. Bboxes are normalized 0-{{bbox_scale}}. The data-label attribute is the label for the block.
 
 Use the following labels:
 - Caption
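
Below is a minimal, self-contained sketch of the bbox_scale round trip this patch introduces: the integer is substituted into the prompt's literal "{bbox_scale}" placeholder at request time, and the same value later denormalizes the model's data-bbox coordinates back to pixels. It is an illustration, not library code: the helper names (build_prompt, denormalize_bbox) and the example scale of 1024 are hypothetical; in the patch the substitution lives in process_batch_element / generate_vllm and the scaling inline in parse_layout.

# Hypothetical standalone sketch of the bbox_scale round trip (not library code).

# Excerpt of the template: after this patch the prompt carries a literal
# "{bbox_scale}" placeholder instead of baking settings.BBOX_SCALE in at import.
OCR_LAYOUT_PROMPT_TEMPLATE = "Bboxes are normalized 0-{bbox_scale}."


def build_prompt(bbox_scale: int) -> str:
    # Mirrors PROMPT_MAPPING[prompt_type].replace("{bbox_scale}", str(bbox_scale))
    # in process_batch_element (hf) and the worker inside generate_vllm.
    return OCR_LAYOUT_PROMPT_TEMPLATE.replace("{bbox_scale}", str(bbox_scale))


def denormalize_bbox(bbox, image_size, bbox_scale: int):
    # Mirrors parse_layout: model coordinates are normalized to 0..bbox_scale,
    # so scale each axis by (pixel_extent / bbox_scale) to recover pixels.
    width, height = image_size
    width_scaler = width / bbox_scale
    height_scaler = height / bbox_scale
    x0, y0, x1, y1 = bbox
    return [x0 * width_scaler, y0 * height_scaler, x1 * width_scaler, y1 * height_scaler]


if __name__ == "__main__":
    scale = 1024  # assumed example value; the library default is settings.BBOX_SCALE
    print(build_prompt(scale))
    # A block spanning the top-left quarter of a 2048x1536 page image:
    print(denormalize_bbox([0, 0, 512, 512], (2048, 1536), scale))
    # -> [0.0, 0.0, 1024.0, 768.0]

Substituting at request time rather than in an import-time f-string is what allows a per-call override of the scale; the same value must then be threaded through to parse_chunks, as the patch does, or decoded boxes would be rescaled by the wrong denominator.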