diff --git a/chandra/input.py b/chandra/input.py
index d552d14..56829ab 100644
--- a/chandra/input.py
+++ b/chandra/input.py
@@ -13,7 +13,23 @@ def flatten(page, flag=pdfium_c.FLAT_NORMALDISPLAY):
print(f"Failed to flatten annotations / form fields on page {page}.")
-def load_pdf_images(filepath: str, page_range: List[int]):
+def load_image(
+ filepath: str, min_image_dim: int = settings.MIN_IMAGE_DIM
+) -> Image.Image:
+ image = Image.open(filepath).convert("RGB")
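+ # Upscale small images so the shorter side reaches min_image_dim
+ # (e.g. min_image_dim=1024 turns a 300x600 input into 1024x2048).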
+ if image.width < min_image_dim or image.height < min_image_dim:
+ scale = min_image_dim / min(image.width, image.height)
+ new_size = (int(image.width * scale), int(image.height * scale))
+ image = image.resize(new_size, Image.Resampling.LANCZOS)
+ return image
+
+
+def load_pdf_images(
+ filepath: str,
+ page_range: List[int],
+ image_dpi: int = settings.IMAGE_DPI,
+ min_pdf_image_dim: int = settings.MIN_PDF_IMAGE_DIM,
+) -> List[Image.Image]:
doc = pdfium.PdfDocument(filepath)
doc.init_forms()
@@ -22,8 +38,8 @@ def load_pdf_images(filepath: str, page_range: List[int]):
if not page_range or page in page_range:
page_obj = doc[page]
min_page_dim = min(page_obj.get_width(), page_obj.get_height())
- scale_dpi = (settings.MIN_IMAGE_DIM / min_page_dim) * 72
- scale_dpi = max(scale_dpi, settings.IMAGE_DPI)
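+ # Page sizes are in PDF points (72 per inch), so this DPI renders the
+ # shorter page side at min_pdf_image_dim pixels without dropping below image_dpi.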
+ scale_dpi = (min_pdf_image_dim / min_page_dim) * 72
+ scale_dpi = max(scale_dpi, image_dpi)
page_obj = doc[page]
flatten(page_obj)
page_obj = doc[page]
@@ -56,5 +72,5 @@ def load_file(filepath: str, config: dict):
if input_type and input_type.extension == "pdf":
images = load_pdf_images(filepath, page_range)
else:
- images = [Image.open(filepath).convert("RGB")]
+ images = [load_image(filepath)]
return images
diff --git a/chandra/model/__init__.py b/chandra/model/__init__.py
index 0b4b9e8..12af5e4 100644
--- a/chandra/model/__init__.py
+++ b/chandra/model/__init__.py
@@ -4,6 +4,7 @@ from chandra.model.hf import load_model, generate_hf
from chandra.model.schema import BatchInputItem, BatchOutputItem
from chandra.model.vllm import generate_vllm
from chandra.output import parse_markdown, parse_html, parse_chunks, extract_images
+from chandra.settings import settings
class InferenceManager:
@@ -26,19 +27,29 @@ class InferenceManager:
output_kwargs["include_headers_footers"] = kwargs.pop(
"include_headers_footers"
)
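+ # Per-call overrides for the bbox scale and vLLM endpoint, defaulting to settings.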
+ bbox_scale = kwargs.pop("bbox_scale", settings.BBOX_SCALE)
+ vllm_api_base = kwargs.pop("vllm_api_base", settings.VLLM_API_BASE)
if self.method == "vllm":
results = generate_vllm(
- batch, max_output_tokens=max_output_tokens, **kwargs
+ batch,
+ max_output_tokens=max_output_tokens,
+ bbox_scale=bbox_scale,
+ vllm_api_base=vllm_api_base,
+ **kwargs,
)
else:
results = generate_hf(
- batch, self.model, max_output_tokens=max_output_tokens, **kwargs
+ batch,
+ self.model,
+ max_output_tokens=max_output_tokens,
+ bbox_scale=bbox_scale,
+ **kwargs,
)
output = []
for result, input_item in zip(results, batch):
- chunks = parse_chunks(result.raw, input_item.image)
+ chunks = parse_chunks(result.raw, input_item.image, bbox_scale=bbox_scale)
output.append(
BatchOutputItem(
markdown=parse_markdown(result.raw, **output_kwargs),
@@ -48,6 +59,7 @@ class InferenceManager:
page_box=[0, 0, input_item.image.width, input_item.image.height],
token_count=result.token_count,
images=extract_images(result.raw, chunks, input_item.image),
+ error=result.error,
)
)
return output
diff --git a/chandra/model/hf.py b/chandra/model/hf.py
index 50aa883..6be4e47 100644
--- a/chandra/model/hf.py
+++ b/chandra/model/hf.py
@@ -1,8 +1,5 @@
from typing import List
-from qwen_vl_utils import process_vision_info
-from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
-
from chandra.model.schema import BatchInputItem, GenerationResult
from chandra.model.util import scale_to_fit
from chandra.prompts import PROMPT_MAPPING
@@ -10,12 +7,20 @@ from chandra.settings import settings
def generate_hf(
- batch: List[BatchInputItem], model, max_output_tokens=None, **kwargs
+ batch: List[BatchInputItem],
+ model,
+ max_output_tokens=None,
+ bbox_scale: int = settings.BBOX_SCALE,
+ **kwargs,
) -> List[GenerationResult]:
+ from qwen_vl_utils import process_vision_info
+
if max_output_tokens is None:
max_output_tokens = settings.MAX_OUTPUT_TOKENS
- messages = [process_batch_element(item, model.processor) for item in batch]
+ messages = [
+ process_batch_element(item, model.processor, bbox_scale) for item in batch
+ ]
text = model.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
@@ -48,12 +53,12 @@ def generate_hf(
return results
-def process_batch_element(item: BatchInputItem, processor):
+def process_batch_element(item: BatchInputItem, processor, bbox_scale: int):
prompt = item.prompt
prompt_type = item.prompt_type
if not prompt:
- prompt = PROMPT_MAPPING[prompt_type]
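+ # Prompt templates carry a literal {bbox_scale} placeholder that is filled in here.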
+ prompt = PROMPT_MAPPING[prompt_type].replace("{bbox_scale}", str(bbox_scale))
content = []
image = scale_to_fit(item.image) # Guarantee max size
@@ -65,12 +70,15 @@ def process_batch_element(item: BatchInputItem, processor):
def load_model():
+ import torch
+ from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
+
device_map = "auto"
if settings.TORCH_DEVICE:
device_map = {"": settings.TORCH_DEVICE}
kwargs = {
- "dtype": settings.TORCH_DTYPE,
+ "dtype": torch.bfloat16,
"device_map": device_map,
}
if settings.TORCH_ATTN:
diff --git a/chandra/model/schema.py b/chandra/model/schema.py
index 623a958..b6b75fa 100644
--- a/chandra/model/schema.py
+++ b/chandra/model/schema.py
@@ -27,3 +27,4 @@ class BatchOutputItem:
page_box: List[int]
token_count: int
images: dict
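+ # True when generation failed for this page (propagated from GenerationResult).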
+ error: bool
diff --git a/chandra/model/util.py b/chandra/model/util.py
index d43c1c6..7cef96f 100644
--- a/chandra/model/util.py
+++ b/chandra/model/util.py
@@ -44,9 +44,10 @@ def scale_to_fit(
def detect_repeat_token(
predicted_tokens: str,
- max_repeats: int = 4,
+ base_max_repeats: int = 4,
window_size: int = 500,
cut_from_end: int = 0,
+ scaling_factor: float = 3.0,
):
try:
predicted_tokens = parse_markdown(predicted_tokens)
@@ -57,11 +58,13 @@ def detect_repeat_token(
if cut_from_end > 0:
predicted_tokens = predicted_tokens[:-cut_from_end]
- # Try different sequence lengths (1 to window_size//2)
for seq_len in range(1, window_size // 2 + 1):
# Extract the potential repeating sequence from the end
candidate_seq = predicted_tokens[-seq_len:]
+ # Inverse scaling: shorter sequences need more repeats
+ max_repeats = int(base_max_repeats * (1 + scaling_factor / seq_len))
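+ # e.g. with base_max_repeats=4 and scaling_factor=3.0: seq_len=1 allows 16
+ # repeats before flagging, while seq_len=500 allows the base 4.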
+
# Count how many times this sequence appears consecutively at the end
repeat_count = 0
pos = len(predicted_tokens) - seq_len
@@ -75,7 +78,6 @@ def detect_repeat_token(
else:
break
- # If we found more than max_repeats consecutive occurrences
if repeat_count > max_repeats:
return True
diff --git a/chandra/model/vllm.py b/chandra/model/vllm.py
index 40f1e7f..fce2ed6 100644
--- a/chandra/model/vllm.py
+++ b/chandra/model/vllm.py
@@ -1,5 +1,6 @@
import base64
import io
+import time
from concurrent.futures import ThreadPoolExecutor
from itertools import repeat
from typing import List
@@ -25,10 +26,15 @@ def generate_vllm(
max_output_tokens: int = None,
max_retries: int = None,
max_workers: int | None = None,
+ custom_headers: dict | None = None,
+ max_failure_retries: int | None = None,
+ bbox_scale: int = settings.BBOX_SCALE,
+ vllm_api_base: str = settings.VLLM_API_BASE,
) -> List[GenerationResult]:
client = OpenAI(
api_key=settings.VLLM_API_KEY,
- base_url=settings.VLLM_API_BASE,
+ base_url=vllm_api_base,
+ default_headers=custom_headers,
)
model_name = settings.VLLM_MODEL_NAME
@@ -50,7 +56,9 @@ def generate_vllm(
) -> GenerationResult:
prompt = item.prompt
if not prompt:
- prompt = PROMPT_MAPPING[item.prompt_type]
+ prompt = PROMPT_MAPPING[item.prompt_type].replace(
+ "{bbox_scale}", str(bbox_scale)
+ )
content = []
image = scale_to_fit(item.image)
@@ -68,41 +76,68 @@ def generate_vllm(
completion = client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": content}],
- max_tokens=settings.MAX_OUTPUT_TOKENS,
+ max_tokens=max_output_tokens,
temperature=temperature,
top_p=top_p,
)
+ raw = completion.choices[0].message.content
+ result = GenerationResult(
+ raw=raw,
+ token_count=completion.usage.completion_tokens,
+ error=False,
+ )
except Exception as e:
print(f"Error during VLLM generation: {e}")
return GenerationResult(raw="", token_count=0, error=True)
- return GenerationResult(
- raw=completion.choices[0].message.content,
- token_count=completion.usage.completion_tokens,
- error=False,
- )
+ return result
- def process_item(item, max_retries):
+ def process_item(item, max_retries, max_failure_retries=None):
result = _generate(item)
retries = 0
- while retries < max_retries and (
- detect_repeat_token(result.raw)
- or (
- len(result.raw) > 50
- and detect_repeat_token(result.raw, cut_from_end=50)
- )
- or result.error
- ):
- print(
- f"Detected repeat token or error, retrying generation (attempt {retries + 1})..."
- )
+ while _should_retry(result, retries, max_retries, max_failure_retries):
result = _generate(item, temperature=0.3, top_p=0.95)
retries += 1
return result
+ def _should_retry(result, retries, max_retries, max_failure_retries):
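+ # Repeat tokens retry within max_retries; vLLM errors retry within
+ # max_retries or, when set, max_failure_retries (whichever allows more).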
+ has_repeat = detect_repeat_token(result.raw) or (
+ len(result.raw) > 50 and detect_repeat_token(result.raw, cut_from_end=50)
+ )
+
+ if retries < max_retries and has_repeat:
+ print(
+ f"Detected repeat token, retrying generation (attempt {retries + 1})..."
+ )
+ return True
+
+ if retries < max_retries and result.error:
+ print(
+ f"Detected vllm error, retrying generation (attempt {retries + 1})..."
+ )
+ time.sleep(2 * (retries + 1)) # Sleeping can help under load
+ return True
+
+ if (
+ result.error
+ and max_failure_retries is not None
+ and retries < max_failure_retries
+ ):
+ print(
+ f"Detected vllm error, retrying generation (attempt {retries + 1})..."
+ )
+ time.sleep(2 * (retries + 1)) # Sleeping can help under load
+ return True
+
+ return False
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
- results = list(executor.map(process_item, batch, repeat(max_retries)))
+ results = list(
+ executor.map(
+ process_item, batch, repeat(max_retries), repeat(max_failure_retries)
+ )
+ )
return results
diff --git a/chandra/output.py b/chandra/output.py
index 173dcd8..a8a8b18 100644
--- a/chandra/output.py
+++ b/chandra/output.py
@@ -6,9 +6,11 @@ from functools import lru_cache
import six
from PIL import Image
-from bs4 import BeautifulSoup, NavigableString
+from bs4 import BeautifulSoup
from markdownify import MarkdownConverter, re_whitespace
+from chandra.settings import settings
+
@lru_cache
def _hash_html(html: str):
@@ -30,7 +32,11 @@ def extract_images(html: str, chunks: dict, image: Image.Image):
if not img:
continue
bbox = chunk["bbox"]
- block_image = image.crop(bbox)
+ try:
+ block_image = image.crop(bbox)
+ except ValueError:
+ # Happens when bbox coordinates are invalid
+ continue
img_name = get_image_name(html, div_idx)
images[img_name] = block_image
return images
@@ -67,44 +73,22 @@ def parse_html(
else:
img = BeautifulSoup(f"", "html.parser")
div.append(img)
+
+        # Wrap text content in <p> tags if no inner HTML tags exist
+        if label in ["Text"] and not re.search(
+            "<.+>", str(div.decode_contents()).strip()
+        ):
+            # Add inner p tags if missing for text blocks
+            text_content = str(div.decode_contents()).strip()
+            text_content = f"<p>{text_content}</p>"
+            div.clear()
+            div.append(BeautifulSoup(text_content, "html.parser"))
+
         content = str(div.decode_contents())
         out_html += content
     return out_html


-def escape_dollars(text):
-    return text.replace("$", r"\$")
-
-
-def get_formatted_table_text(element):
-    text = []
-    for content in element.contents:
-        if content is None:
-            continue
-
-        if isinstance(content, NavigableString):
-            stripped = content.strip()
-            if stripped:
-                text.append(escape_dollars(stripped))
-        elif content.name == "br":
-            text.append("<br>")