diff --git a/chandra/input.py b/chandra/input.py
index d552d14..56829ab 100644
--- a/chandra/input.py
+++ b/chandra/input.py
@@ -13,7 +13,23 @@ def flatten(page, flag=pdfium_c.FLAT_NORMALDISPLAY):
         print(f"Failed to flatten annotations / form fields on page {page}.")
 
 
-def load_pdf_images(filepath: str, page_range: List[int]):
+def load_image(
+    filepath: str, min_image_dim: int = settings.MIN_IMAGE_DIM
+) -> Image.Image:
+    image = Image.open(filepath).convert("RGB")
+    if image.width < min_image_dim or image.height < min_image_dim:
+        scale = min_image_dim / min(image.width, image.height)
+        new_size = (int(image.width * scale), int(image.height * scale))
+        image = image.resize(new_size, Image.Resampling.LANCZOS)
+    return image
+
+
+def load_pdf_images(
+    filepath: str,
+    page_range: List[int],
+    image_dpi: int = settings.IMAGE_DPI,
+    min_pdf_image_dim: int = settings.MIN_PDF_IMAGE_DIM,
+) -> List[Image.Image]:
     doc = pdfium.PdfDocument(filepath)
     doc.init_forms()
 
@@ -22,8 +38,8 @@ def load_pdf_images(filepath: str, page_range: List[int]):
         if not page_range or page in page_range:
             page_obj = doc[page]
             min_page_dim = min(page_obj.get_width(), page_obj.get_height())
-            scale_dpi = (settings.MIN_IMAGE_DIM / min_page_dim) * 72
-            scale_dpi = max(scale_dpi, settings.IMAGE_DPI)
+            scale_dpi = (min_pdf_image_dim / min_page_dim) * 72
+            scale_dpi = max(scale_dpi, image_dpi)
             page_obj = doc[page]
             flatten(page_obj)
             page_obj = doc[page]
@@ -56,5 +72,5 @@ def load_file(filepath: str, config: dict):
     if input_type and input_type.extension == "pdf":
         images = load_pdf_images(filepath, page_range)
     else:
-        images = [Image.open(filepath).convert("RGB")]
+        images = [load_image(filepath)]
     return images
diff --git a/chandra/model/__init__.py b/chandra/model/__init__.py
index 0b4b9e8..12af5e4 100644
--- a/chandra/model/__init__.py
+++ b/chandra/model/__init__.py
@@ -4,6 +4,7 @@ from chandra.model.hf import load_model, generate_hf
 from chandra.model.schema import BatchInputItem, BatchOutputItem
 from chandra.model.vllm import generate_vllm
 from chandra.output import parse_markdown, parse_html, parse_chunks, extract_images
+from chandra.settings import settings
 
 
 class InferenceManager:
@@ -26,19 +27,29 @@ class InferenceManager:
             output_kwargs["include_headers_footers"] = kwargs.pop(
                 "include_headers_footers"
             )
+        bbox_scale = kwargs.pop("bbox_scale", settings.BBOX_SCALE)
+        vllm_api_base = kwargs.pop("vllm_api_base", settings.VLLM_API_BASE)
 
         if self.method == "vllm":
             results = generate_vllm(
-                batch, max_output_tokens=max_output_tokens, **kwargs
+                batch,
+                max_output_tokens=max_output_tokens,
+                bbox_scale=bbox_scale,
+                vllm_api_base=vllm_api_base,
+                **kwargs,
             )
         else:
             results = generate_hf(
-                batch, self.model, max_output_tokens=max_output_tokens, **kwargs
+                batch,
+                self.model,
+                max_output_tokens=max_output_tokens,
+                bbox_scale=bbox_scale,
+                **kwargs,
             )
 
         output = []
         for result, input_item in zip(results, batch):
-            chunks = parse_chunks(result.raw, input_item.image)
+            chunks = parse_chunks(result.raw, input_item.image, bbox_scale=bbox_scale)
             output.append(
                 BatchOutputItem(
                     markdown=parse_markdown(result.raw, **output_kwargs),
@@ -48,6 +59,7 @@ class InferenceManager:
                     page_box=[0, 0, input_item.image.width, input_item.image.height],
                     token_count=result.token_count,
                     images=extract_images(result.raw, chunks, input_item.image),
+                    error=result.error,
                 )
             )
         return output
diff --git a/chandra/model/hf.py b/chandra/model/hf.py
index 50aa883..6be4e47 100644
--- a/chandra/model/hf.py
+++ b/chandra/model/hf.py
@@ -1,8 +1,5 @@
 from typing import List
 
-from qwen_vl_utils import process_vision_info
-from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
-
 from chandra.model.schema import BatchInputItem, GenerationResult
 from chandra.model.util import scale_to_fit
 from chandra.prompts import PROMPT_MAPPING
@@ -10,12 +7,20 @@ from chandra.settings import settings
 
 
 def generate_hf(
-    batch: List[BatchInputItem], model, max_output_tokens=None, **kwargs
+    batch: List[BatchInputItem],
+    model,
+    max_output_tokens=None,
+    bbox_scale: int = settings.BBOX_SCALE,
+    **kwargs,
 ) -> List[GenerationResult]:
+    from qwen_vl_utils import process_vision_info
+
     if max_output_tokens is None:
         max_output_tokens = settings.MAX_OUTPUT_TOKENS
 
-    messages = [process_batch_element(item, model.processor) for item in batch]
+    messages = [
+        process_batch_element(item, model.processor, bbox_scale) for item in batch
+    ]
     text = model.processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
@@ -48,12 +53,12 @@ def generate_hf(
     return results
 
 
-def process_batch_element(item: BatchInputItem, processor):
+def process_batch_element(item: BatchInputItem, processor, bbox_scale: int):
     prompt = item.prompt
     prompt_type = item.prompt_type
 
     if not prompt:
-        prompt = PROMPT_MAPPING[prompt_type]
+        prompt = PROMPT_MAPPING[prompt_type].replace("{bbox_scale}", str(bbox_scale))
 
     content = []
     image = scale_to_fit(item.image)  # Guarantee max size
@@ -65,12 +70,15 @@ def process_batch_element(item: BatchInputItem, processor):
 
 
 def load_model():
+    import torch
+    from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
+
     device_map = "auto"
     if settings.TORCH_DEVICE:
         device_map = {"": settings.TORCH_DEVICE}
 
     kwargs = {
-        "dtype": settings.TORCH_DTYPE,
+        "dtype": torch.bfloat16,
         "device_map": device_map,
     }
     if settings.TORCH_ATTN:
diff --git a/chandra/model/schema.py b/chandra/model/schema.py
index 623a958..b6b75fa 100644
--- a/chandra/model/schema.py
+++ b/chandra/model/schema.py
@@ -27,3 +27,4 @@ class BatchOutputItem:
     page_box: List[int]
     token_count: int
     images: dict
+    error: bool
diff --git a/chandra/model/util.py b/chandra/model/util.py
index d43c1c6..7cef96f 100644
--- a/chandra/model/util.py
+++ b/chandra/model/util.py
@@ -44,9 +44,10 @@ def scale_to_fit(
 
 def detect_repeat_token(
     predicted_tokens: str,
-    max_repeats: int = 4,
+    base_max_repeats: int = 4,
     window_size: int = 500,
     cut_from_end: int = 0,
+    scaling_factor: float = 3.0,
 ):
     try:
         predicted_tokens = parse_markdown(predicted_tokens)
@@ -57,11 +58,13 @@ def detect_repeat_token(
     if cut_from_end > 0:
         predicted_tokens = predicted_tokens[:-cut_from_end]
 
-    # Try different sequence lengths (1 to window_size//2)
     for seq_len in range(1, window_size // 2 + 1):
         # Extract the potential repeating sequence from the end
         candidate_seq = predicted_tokens[-seq_len:]
 
+        # Inverse scaling: shorter sequences need more repeats
+        max_repeats = int(base_max_repeats * (1 + scaling_factor / seq_len))
+
         # Count how many times this sequence appears consecutively at the end
         repeat_count = 0
         pos = len(predicted_tokens) - seq_len
@@ -75,7 +78,6 @@ def detect_repeat_token(
             else:
                 break
 
-        # If we found more than max_repeats consecutive occurrences
         if repeat_count > max_repeats:
             return True
diff --git a/chandra/model/vllm.py b/chandra/model/vllm.py
index 40f1e7f..fce2ed6 100644
--- a/chandra/model/vllm.py
+++ b/chandra/model/vllm.py
@@ -1,5 +1,6 @@
 import base64
 import io
+import time
 from concurrent.futures import ThreadPoolExecutor
 from itertools import repeat
 from typing import List
@@ -25,10 +26,15 @@ def generate_vllm(
     max_output_tokens: int = None,
     max_retries: int = None,
     max_workers: int | None = None,
+    custom_headers: dict | None = None,
+    max_failure_retries: int | None = None,
+    bbox_scale: int = settings.BBOX_SCALE,
+    vllm_api_base: str = settings.VLLM_API_BASE,
 ) -> List[GenerationResult]:
     client = OpenAI(
         api_key=settings.VLLM_API_KEY,
-        base_url=settings.VLLM_API_BASE,
+        base_url=vllm_api_base,
+        default_headers=custom_headers,
     )
     model_name = settings.VLLM_MODEL_NAME
 
@@ -50,7 +56,9 @@ def generate_vllm(
     ) -> GenerationResult:
         prompt = item.prompt
         if not prompt:
-            prompt = PROMPT_MAPPING[item.prompt_type]
+            prompt = PROMPT_MAPPING[item.prompt_type].replace(
+                "{bbox_scale}", str(bbox_scale)
+            )
 
         content = []
         image = scale_to_fit(item.image)
@@ -68,41 +76,68 @@ def generate_vllm(
             completion = client.chat.completions.create(
                 model=model_name,
                 messages=[{"role": "user", "content": content}],
-                max_tokens=settings.MAX_OUTPUT_TOKENS,
+                max_tokens=max_output_tokens,
                 temperature=temperature,
                 top_p=top_p,
             )
+            raw = completion.choices[0].message.content
+            result = GenerationResult(
+                raw=raw,
+                token_count=completion.usage.completion_tokens,
+                error=False,
+            )
         except Exception as e:
             print(f"Error during VLLM generation: {e}")
             return GenerationResult(raw="", token_count=0, error=True)
 
-        return GenerationResult(
-            raw=completion.choices[0].message.content,
-            token_count=completion.usage.completion_tokens,
-            error=False,
-        )
+        return result
 
-    def process_item(item, max_retries):
+    def process_item(item, max_retries, max_failure_retries=None):
         result = _generate(item)
         retries = 0
-        while retries < max_retries and (
-            detect_repeat_token(result.raw)
-            or (
-                len(result.raw) > 50
-                and detect_repeat_token(result.raw, cut_from_end=50)
-            )
-            or result.error
-        ):
-            print(
-                f"Detected repeat token or error, retrying generation (attempt {retries + 1})..."
-            )
+        while _should_retry(result, retries, max_retries, max_failure_retries):
             result = _generate(item, temperature=0.3, top_p=0.95)
             retries += 1
         return result
 
+    def _should_retry(result, retries, max_retries, max_failure_retries):
+        has_repeat = detect_repeat_token(result.raw) or (
+            len(result.raw) > 50 and detect_repeat_token(result.raw, cut_from_end=50)
+        )
+
+        if retries < max_retries and has_repeat:
+            print(
+                f"Detected repeat token, retrying generation (attempt {retries + 1})..."
+            )
+            return True
+
+        if retries < max_retries and result.error:
+            print(
+                f"Detected vllm error, retrying generation (attempt {retries + 1})..."
+            )
+            time.sleep(2 * (retries + 1))  # Sleeping can help under load
+            return True
+
+        if (
+            result.error
+            and max_failure_retries is not None
+            and retries < max_failure_retries
+        ):
+            print(
+                f"Detected vllm error, retrying generation (attempt {retries + 1})..."
+            )
+            time.sleep(2 * (retries + 1))  # Sleeping can help under load
+            return True
+
+        return False
+
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        results = list(executor.map(process_item, batch, repeat(max_retries)))
+        results = list(
+            executor.map(
+                process_item, batch, repeat(max_retries), repeat(max_failure_retries)
+            )
+        )
     return results
diff --git a/chandra/output.py b/chandra/output.py
index 173dcd8..a8a8b18 100644
--- a/chandra/output.py
+++ b/chandra/output.py
@@ -6,9 +6,11 @@ from functools import lru_cache
 import six
 from PIL import Image
-from bs4 import BeautifulSoup, NavigableString
+from bs4 import BeautifulSoup
 from markdownify import MarkdownConverter, re_whitespace
 
+from chandra.settings import settings
+
 
 @lru_cache
 def _hash_html(html: str):
@@ -30,7 +32,11 @@ def extract_images(html: str, chunks: dict, image: Image.Image):
         if not img:
             continue
         bbox = chunk["bbox"]
-        block_image = image.crop(bbox)
+        try:
+            block_image = image.crop(bbox)
+        except ValueError:
+            # Happens when bbox coordinates are invalid
+            continue
         img_name = get_image_name(html, div_idx)
         images[img_name] = block_image
     return images
@@ -67,44 +73,22 @@ def parse_html(
         else:
             img = BeautifulSoup(f"", "html.parser")
             div.append(img)
+
+        # Wrap text content in <p> tags if no inner HTML tags exist
+        if label in ["Text"] and not re.search(
+            "<.+>", str(div.decode_contents()).strip()
+        ):
+            # Add inner p tags if missing for text blocks
+            text_content = str(div.decode_contents()).strip()
+            text_content = f"<p>{text_content}</p>"
+            div.clear()
+            div.append(BeautifulSoup(text_content, "html.parser"))
+
         content = str(div.decode_contents())
         out_html += content
     return out_html
 
 
-def escape_dollars(text):
-    return text.replace("$", r"\$")
-
-
-def get_formatted_table_text(element):
-    text = []
-    for content in element.contents:
-        if content is None:
-            continue
-
-        if isinstance(content, NavigableString):
-            stripped = content.strip()
-            if stripped:
-                text.append(escape_dollars(stripped))
-        elif content.name == "br":
-            text.append("<br>")
-        elif content.name == "math":
-            text.append("$" + content.text + "$")
-        else:
-            content_str = escape_dollars(str(content))
-            text.append(content_str)
-
-    full_text = ""
-    for i, t in enumerate(text):
-        if t == "<br>":
-            full_text += t
-        elif i > 0 and text[i - 1] != "<br>":
-            full_text += " " + t
-        else:
-            full_text += t
-    return full_text
-
-
 class Markdownify(MarkdownConverter):
     def __init__(
         self,
@@ -204,19 +188,25 @@ class LayoutBlock:
     content: str
 
 
-def parse_layout(html: str, image: Image.Image):
+def parse_layout(html: str, image: Image.Image, bbox_scale=settings.BBOX_SCALE):
     soup = BeautifulSoup(html, "html.parser")
     top_level_divs = soup.find_all("div", recursive=False)
     width, height = image.size
-    width_scaler = width / 1024
-    height_scaler = height / 1024
+    width_scaler = width / bbox_scale
+    height_scaler = height / bbox_scale
 
     layout_blocks = []
     for div in top_level_divs:
         bbox = div.get("data-bbox")
+
         try:
             bbox = json.loads(bbox)
+            assert len(bbox) == 4, "Invalid bbox length"
         except Exception:
-            bbox = [0, 0, 1, 1]  # Fallback to a default bbox if parsing fails
+            try:
+                bbox = bbox.split(" ")
+                assert len(bbox) == 4, "Invalid bbox length"
+            except Exception:
+                bbox = [0, 0, 1, 1]
         bbox = list(map(int, bbox))
 
         # Normalize bbox
@@ -232,7 +222,7 @@ def parse_layout(html: str, image: Image.Image):
     return layout_blocks
 
 
-def parse_chunks(html: str, image: Image.Image):
-    layout = parse_layout(html, image)
+def parse_chunks(html: str, image: Image.Image, bbox_scale=settings.BBOX_SCALE):
+    layout = parse_layout(html, image, bbox_scale=bbox_scale)
     chunks = [asdict(block) for block in layout]
     return chunks
diff --git a/chandra/prompts.py b/chandra/prompts.py
index 49d6b15..954897e 100644
--- a/chandra/prompts.py
+++ b/chandra/prompts.py
@@ -65,7 +65,7 @@ Guidelines:
 """.strip()
 
 OCR_LAYOUT_PROMPT = f"""
-OCR this image to HTML, arranged as layout blocks. Each layout block should be a div with the data-bbox attribute representing the bounding box of the block in [x0, y0, x1, y1] format. Bboxes are normalized 0-1024. The data-label attribute is the label for the block.
+OCR this image to HTML, arranged as layout blocks. Each layout block should be a div with the data-bbox attribute representing the bounding box of the block in [x0, y0, x1, y1] format. Bboxes are normalized 0-{{bbox_scale}}. The data-label attribute is the label for the block.
 
 Use the following labels:
 - Caption
diff --git a/chandra/scripts/screenshot_app.py b/chandra/scripts/screenshot_app.py
index e7ab0f2..4083feb 100644
--- a/chandra/scripts/screenshot_app.py
+++ b/chandra/scripts/screenshot_app.py
@@ -143,6 +143,7 @@ def process():
             "image_height": img_height,
             "blocks": blocks_data,
             "html": html_with_images,
+            "markdown": result.markdown,
         }
     )
diff --git a/chandra/scripts/templates/screenshot.html b/chandra/scripts/templates/screenshot.html
index 642fe94..92f6d4d 100644
--- a/chandra/scripts/templates/screenshot.html
+++ b/chandra/scripts/templates/screenshot.html
@@ -64,6 +64,20 @@
             cursor: not-allowed;
         }
 
+        .controls label {
+            display: flex;
+            align-items: center;
+            gap: 8px;
+            color: white;
+            font-size: 14px;
+            cursor: pointer;
+            user-select: none;
+        }
+
+        .controls input[type="checkbox"] {
+            cursor: pointer;
+        }
+
         .loading {
             display: none;
             color: #f39c12;
@@ -75,6 +89,11 @@
             font-weight: bold;
         }
 
+        .success {
+            color: #27ae60;
+            font-weight: bold;
+        }
+
         .screenshot-container {
             display: none;
             margin-top: 60px;
@@ -88,8 +107,18 @@
             display: flex;
         }
 
-        .left-panel, .right-panel {
-            flex: 1;
+        .left-panel {
+            flex: 0 0 40%;
+            display: flex;
+            flex-direction: column;
+            background: white;
+            border-radius: 8px;
+            overflow: hidden;
+            box-shadow: 0 4px 12px rgba(0,0,0,0.3);
+        }
+
+        .right-panel {
+            flex: 0 0 60%;
             display: flex;
             flex-direction: column;
             background: white;
@@ -137,6 +166,7 @@
             padding: 30px;
             line-height: 1.6;
             color: #333;
+            font-size: 24px;
         }
 
         .markdown-content h1, .markdown-content h2, .markdown-content h3 {
@@ -215,8 +245,14 @@
+
+            Processing...
+
@@ -242,6 +278,11 @@
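
For reference, a minimal usage sketch of the knobs this diff introduces (not part of the diff). It assumes `BatchInputItem` accepts `image`/`prompt`/`prompt_type` as keyword arguments and that `"ocr_layout"` is a valid `PROMPT_MAPPING` key; adjust those names to the real API if they differ.

```python
# Hypothetical usage sketch for the new options; the prompt_type key and the
# BatchInputItem constructor arguments are assumptions, not part of the diff.
from chandra.input import load_pdf_images
from chandra.model.schema import BatchInputItem
from chandra.model.vllm import generate_vllm
from chandra.output import parse_chunks, parse_markdown

images = load_pdf_images("doc.pdf", page_range=[0], image_dpi=150)
batch = [
    BatchInputItem(image=img, prompt=None, prompt_type="ocr_layout") for img in images
]

results = generate_vllm(
    batch,
    max_retries=2,
    bbox_scale=1024,  # substituted into {bbox_scale} in the layout prompt
    vllm_api_base="http://localhost:8000/v1",  # overrides settings.VLLM_API_BASE
    custom_headers={"x-trace-id": "example"},  # forwarded to the OpenAI client
    max_failure_retries=3,  # extra retries for plain vLLM errors beyond max_retries
)

for item, result in zip(batch, results):
    if result.error:  # also surfaced on BatchOutputItem via the new error field
        print("generation failed for this page")
        continue
    print(parse_markdown(result.raw))
    # bboxes are rescaled to the image using the same bbox_scale
    chunks = parse_chunks(result.raw, item.image, bbox_scale=1024)
```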