mirror of
https://github.com/datalab-to/chandra.git
synced 2025-11-29 16:43:11 +00:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c049e7524f | ||
|
|
b96eb84094 | ||
|
|
0f5f3d485c | ||
|
|
1bab4bf73a | ||
|
|
34f825351c | ||
|
|
068db0311e | ||
|
|
aafbb70ce8 | ||
|
|
22639087e7 | ||
|
|
910bcf100f | ||
|
|
3958707a80 | ||
|
|
fe28f26fc2 | ||
|
|
4470243560 | ||
|
|
a3889b12fb | ||
|
|
d69d18d6e8 |
@@ -13,16 +13,23 @@ def flatten(page, flag=pdfium_c.FLAT_NORMALDISPLAY):
|
||||
print(f"Failed to flatten annotations / form fields on page {page}.")
|
||||
|
||||
|
||||
def load_image(filepath: str):
|
||||
def load_image(
|
||||
filepath: str, min_image_dim: int = settings.MIN_IMAGE_DIM
|
||||
) -> Image.Image:
|
||||
image = Image.open(filepath).convert("RGB")
|
||||
if image.width < settings.MIN_IMAGE_DIM or image.height < settings.MIN_IMAGE_DIM:
|
||||
scale = settings.MIN_IMAGE_DIM / min(image.width, image.height)
|
||||
if image.width < min_image_dim or image.height < min_image_dim:
|
||||
scale = min_image_dim / min(image.width, image.height)
|
||||
new_size = (int(image.width * scale), int(image.height * scale))
|
||||
image = image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
return image
|
||||
|
||||
|
||||
def load_pdf_images(filepath: str, page_range: List[int]):
|
||||
def load_pdf_images(
|
||||
filepath: str,
|
||||
page_range: List[int],
|
||||
image_dpi: int = settings.IMAGE_DPI,
|
||||
min_pdf_image_dim: int = settings.MIN_PDF_IMAGE_DIM,
|
||||
) -> List[Image.Image]:
|
||||
doc = pdfium.PdfDocument(filepath)
|
||||
doc.init_forms()
|
||||
|
||||
@@ -31,8 +38,8 @@ def load_pdf_images(filepath: str, page_range: List[int]):
|
||||
if not page_range or page in page_range:
|
||||
page_obj = doc[page]
|
||||
min_page_dim = min(page_obj.get_width(), page_obj.get_height())
|
||||
scale_dpi = (settings.MIN_PDF_IMAGE_DIM / min_page_dim) * 72
|
||||
scale_dpi = max(scale_dpi, settings.IMAGE_DPI)
|
||||
scale_dpi = (min_pdf_image_dim / min_page_dim) * 72
|
||||
scale_dpi = max(scale_dpi, image_dpi)
|
||||
page_obj = doc[page]
|
||||
flatten(page_obj)
|
||||
page_obj = doc[page]
|
||||
|
||||
@@ -4,6 +4,7 @@ from chandra.model.hf import load_model, generate_hf
|
||||
from chandra.model.schema import BatchInputItem, BatchOutputItem
|
||||
from chandra.model.vllm import generate_vllm
|
||||
from chandra.output import parse_markdown, parse_html, parse_chunks, extract_images
|
||||
from chandra.settings import settings
|
||||
|
||||
|
||||
class InferenceManager:
|
||||
@@ -26,19 +27,29 @@ class InferenceManager:
|
||||
output_kwargs["include_headers_footers"] = kwargs.pop(
|
||||
"include_headers_footers"
|
||||
)
|
||||
bbox_scale = kwargs.pop("bbox_scale", settings.BBOX_SCALE)
|
||||
vllm_api_base = kwargs.pop("vllm_api_base", settings.VLLM_API_BASE)
|
||||
|
||||
if self.method == "vllm":
|
||||
results = generate_vllm(
|
||||
batch, max_output_tokens=max_output_tokens, **kwargs
|
||||
batch,
|
||||
max_output_tokens=max_output_tokens,
|
||||
bbox_scale=bbox_scale,
|
||||
vllm_api_base=vllm_api_base,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
results = generate_hf(
|
||||
batch, self.model, max_output_tokens=max_output_tokens, **kwargs
|
||||
batch,
|
||||
self.model,
|
||||
max_output_tokens=max_output_tokens,
|
||||
bbox_scale=bbox_scale,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = []
|
||||
for result, input_item in zip(results, batch):
|
||||
chunks = parse_chunks(result.raw, input_item.image)
|
||||
chunks = parse_chunks(result.raw, input_item.image, bbox_scale=bbox_scale)
|
||||
output.append(
|
||||
BatchOutputItem(
|
||||
markdown=parse_markdown(result.raw, **output_kwargs),
|
||||
|
||||
@@ -1,22 +1,26 @@
|
||||
from typing import List
|
||||
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
|
||||
|
||||
from chandra.model.schema import BatchInputItem, GenerationResult
|
||||
from chandra.model.util import scale_to_fit
|
||||
from chandra.output import fix_raw
|
||||
from chandra.prompts import PROMPT_MAPPING
|
||||
from chandra.settings import settings
|
||||
|
||||
|
||||
def generate_hf(
|
||||
batch: List[BatchInputItem], model, max_output_tokens=None, **kwargs
|
||||
batch: List[BatchInputItem],
|
||||
model,
|
||||
max_output_tokens=None,
|
||||
bbox_scale: int = settings.BBOX_SCALE,
|
||||
**kwargs,
|
||||
) -> List[GenerationResult]:
|
||||
from qwen_vl_utils import process_vision_info
|
||||
|
||||
if max_output_tokens is None:
|
||||
max_output_tokens = settings.MAX_OUTPUT_TOKENS
|
||||
|
||||
messages = [process_batch_element(item, model.processor) for item in batch]
|
||||
messages = [
|
||||
process_batch_element(item, model.processor, bbox_scale) for item in batch
|
||||
]
|
||||
text = model.processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
@@ -43,18 +47,18 @@ def generate_hf(
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
results = [
|
||||
GenerationResult(raw=fix_raw(out), token_count=len(ids), error=False)
|
||||
GenerationResult(raw=out, token_count=len(ids), error=False)
|
||||
for out, ids in zip(output_text, generated_ids_trimmed)
|
||||
]
|
||||
return results
|
||||
|
||||
|
||||
def process_batch_element(item: BatchInputItem, processor):
|
||||
def process_batch_element(item: BatchInputItem, processor, bbox_scale: int):
|
||||
prompt = item.prompt
|
||||
prompt_type = item.prompt_type
|
||||
|
||||
if not prompt:
|
||||
prompt = PROMPT_MAPPING[prompt_type]
|
||||
prompt = PROMPT_MAPPING[prompt_type].replace("{bbox_scale}", str(bbox_scale))
|
||||
|
||||
content = []
|
||||
image = scale_to_fit(item.image) # Guarantee max size
|
||||
@@ -66,12 +70,15 @@ def process_batch_element(item: BatchInputItem, processor):
|
||||
|
||||
|
||||
def load_model():
|
||||
import torch
|
||||
from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
|
||||
|
||||
device_map = "auto"
|
||||
if settings.TORCH_DEVICE:
|
||||
device_map = {"": settings.TORCH_DEVICE}
|
||||
|
||||
kwargs = {
|
||||
"dtype": settings.TORCH_DTYPE,
|
||||
"dtype": torch.bfloat16,
|
||||
"device_map": device_map,
|
||||
}
|
||||
if settings.TORCH_ATTN:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import base64
|
||||
import io
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import repeat
|
||||
from typing import List
|
||||
@@ -9,7 +10,6 @@ from openai import OpenAI
|
||||
|
||||
from chandra.model.schema import BatchInputItem, GenerationResult
|
||||
from chandra.model.util import scale_to_fit, detect_repeat_token
|
||||
from chandra.output import fix_raw
|
||||
from chandra.prompts import PROMPT_MAPPING
|
||||
from chandra.settings import settings
|
||||
|
||||
@@ -27,10 +27,13 @@ def generate_vllm(
|
||||
max_retries: int = None,
|
||||
max_workers: int | None = None,
|
||||
custom_headers: dict | None = None,
|
||||
max_failure_retries: int | None = None,
|
||||
bbox_scale: int = settings.BBOX_SCALE,
|
||||
vllm_api_base: str = settings.VLLM_API_BASE,
|
||||
) -> List[GenerationResult]:
|
||||
client = OpenAI(
|
||||
api_key=settings.VLLM_API_KEY,
|
||||
base_url=settings.VLLM_API_BASE,
|
||||
base_url=vllm_api_base,
|
||||
default_headers=custom_headers,
|
||||
)
|
||||
model_name = settings.VLLM_MODEL_NAME
|
||||
@@ -53,7 +56,9 @@ def generate_vllm(
|
||||
) -> GenerationResult:
|
||||
prompt = item.prompt
|
||||
if not prompt:
|
||||
prompt = PROMPT_MAPPING[item.prompt_type]
|
||||
prompt = PROMPT_MAPPING[item.prompt_type].replace(
|
||||
"{bbox_scale}", str(bbox_scale)
|
||||
)
|
||||
|
||||
content = []
|
||||
image = scale_to_fit(item.image)
|
||||
@@ -76,7 +81,6 @@ def generate_vllm(
|
||||
top_p=top_p,
|
||||
)
|
||||
raw = completion.choices[0].message.content
|
||||
raw = fix_raw(raw)
|
||||
result = GenerationResult(
|
||||
raw=raw,
|
||||
token_count=completion.usage.completion_tokens,
|
||||
@@ -88,27 +92,52 @@ def generate_vllm(
|
||||
|
||||
return result
|
||||
|
||||
def process_item(item, max_retries):
|
||||
def process_item(item, max_retries, max_failure_retries=None):
|
||||
result = _generate(item)
|
||||
retries = 0
|
||||
|
||||
while retries < max_retries and (
|
||||
detect_repeat_token(result.raw)
|
||||
or (
|
||||
len(result.raw) > 50
|
||||
and detect_repeat_token(result.raw, cut_from_end=50)
|
||||
)
|
||||
or result.error
|
||||
):
|
||||
print(
|
||||
f"Detected repeat token or error, retrying generation (attempt {retries + 1})..."
|
||||
)
|
||||
while _should_retry(result, retries, max_retries, max_failure_retries):
|
||||
result = _generate(item, temperature=0.3, top_p=0.95)
|
||||
retries += 1
|
||||
|
||||
return result
|
||||
|
||||
def _should_retry(result, retries, max_retries, max_failure_retries):
|
||||
has_repeat = detect_repeat_token(result.raw) or (
|
||||
len(result.raw) > 50 and detect_repeat_token(result.raw, cut_from_end=50)
|
||||
)
|
||||
|
||||
if retries < max_retries and has_repeat:
|
||||
print(
|
||||
f"Detected repeat token, retrying generation (attempt {retries + 1})..."
|
||||
)
|
||||
return True
|
||||
|
||||
if retries < max_retries and result.error:
|
||||
print(
|
||||
f"Detected vllm error, retrying generation (attempt {retries + 1})..."
|
||||
)
|
||||
time.sleep(2 * (retries + 1)) # Sleeping can help under load
|
||||
return True
|
||||
|
||||
if (
|
||||
result.error
|
||||
and max_failure_retries is not None
|
||||
and retries < max_failure_retries
|
||||
):
|
||||
print(
|
||||
f"Detected vllm error, retrying generation (attempt {retries + 1})..."
|
||||
)
|
||||
time.sleep(2 * (retries + 1)) # Sleeping can help under load
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
results = list(executor.map(process_item, batch, repeat(max_retries)))
|
||||
results = list(
|
||||
executor.map(
|
||||
process_item, batch, repeat(max_retries), repeat(max_failure_retries)
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -6,9 +6,11 @@ from functools import lru_cache
|
||||
|
||||
import six
|
||||
from PIL import Image
|
||||
from bs4 import BeautifulSoup, NavigableString
|
||||
from bs4 import BeautifulSoup
|
||||
from markdownify import MarkdownConverter, re_whitespace
|
||||
|
||||
from chandra.settings import settings
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _hash_html(html: str):
|
||||
@@ -20,15 +22,6 @@ def get_image_name(html: str, div_idx: int):
|
||||
return f"{html_hash}_{div_idx}_img.webp"
|
||||
|
||||
|
||||
def fix_raw(html: str):
|
||||
def replace_group(match):
|
||||
numbers = re.findall(r"\d+", match.group(0))
|
||||
return "[" + ",".join(numbers) + "]"
|
||||
|
||||
result = re.sub(r"(?:<BBOX\d+>){4}", replace_group, html)
|
||||
return result
|
||||
|
||||
|
||||
def extract_images(html: str, chunks: dict, image: Image.Image):
|
||||
images = {}
|
||||
div_idx = 0
|
||||
@@ -96,39 +89,6 @@ def parse_html(
|
||||
return out_html
|
||||
|
||||
|
||||
def escape_dollars(text):
|
||||
return text.replace("$", r"\$")
|
||||
|
||||
|
||||
def get_formatted_table_text(element):
|
||||
text = []
|
||||
for content in element.contents:
|
||||
if content is None:
|
||||
continue
|
||||
|
||||
if isinstance(content, NavigableString):
|
||||
stripped = content.strip()
|
||||
if stripped:
|
||||
text.append(escape_dollars(stripped))
|
||||
elif content.name == "br":
|
||||
text.append("<br>")
|
||||
elif content.name == "math":
|
||||
text.append("$" + content.text + "$")
|
||||
else:
|
||||
content_str = escape_dollars(str(content))
|
||||
text.append(content_str)
|
||||
|
||||
full_text = ""
|
||||
for i, t in enumerate(text):
|
||||
if t == "<br>":
|
||||
full_text += t
|
||||
elif i > 0 and text[i - 1] != "<br>":
|
||||
full_text += " " + t
|
||||
else:
|
||||
full_text += t
|
||||
return full_text
|
||||
|
||||
|
||||
class Markdownify(MarkdownConverter):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -228,20 +188,25 @@ class LayoutBlock:
|
||||
content: str
|
||||
|
||||
|
||||
def parse_layout(html: str, image: Image.Image):
|
||||
def parse_layout(html: str, image: Image.Image, bbox_scale=settings.BBOX_SCALE):
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
top_level_divs = soup.find_all("div", recursive=False)
|
||||
width, height = image.size
|
||||
width_scaler = width / 1024
|
||||
height_scaler = height / 1024
|
||||
width_scaler = width / bbox_scale
|
||||
height_scaler = height / bbox_scale
|
||||
layout_blocks = []
|
||||
for div in top_level_divs:
|
||||
bbox = div.get("data-bbox")
|
||||
|
||||
try:
|
||||
bbox = json.loads(bbox)
|
||||
assert len(bbox) == 4, "Invalid bbox length"
|
||||
except Exception:
|
||||
bbox = [0, 0, 1, 1]
|
||||
try:
|
||||
bbox = bbox.split(" ")
|
||||
assert len(bbox) == 4, "Invalid bbox length"
|
||||
except Exception:
|
||||
bbox = [0, 0, 1, 1]
|
||||
|
||||
bbox = list(map(int, bbox))
|
||||
# Normalize bbox
|
||||
@@ -257,7 +222,7 @@ def parse_layout(html: str, image: Image.Image):
|
||||
return layout_blocks
|
||||
|
||||
|
||||
def parse_chunks(html: str, image: Image.Image):
|
||||
layout = parse_layout(html, image)
|
||||
def parse_chunks(html: str, image: Image.Image, bbox_scale=settings.BBOX_SCALE):
|
||||
layout = parse_layout(html, image, bbox_scale=bbox_scale)
|
||||
chunks = [asdict(block) for block in layout]
|
||||
return chunks
|
||||
|
||||
@@ -65,7 +65,7 @@ Guidelines:
|
||||
""".strip()
|
||||
|
||||
OCR_LAYOUT_PROMPT = f"""
|
||||
OCR this image to HTML, arranged as layout blocks. Each layout block should be a div with the data-bbox attribute representing the bounding box of the block in [x0, y0, x1, y1] format. Bboxes are normalized 0-1024. The data-label attribute is the label for the block.
|
||||
OCR this image to HTML, arranged as layout blocks. Each layout block should be a div with the data-bbox attribute representing the bounding box of the block in [x0, y0, x1, y1] format. Bboxes are normalized 0-{{bbox_scale}}. The data-label attribute is the label for the block.
|
||||
|
||||
Use the following labels:
|
||||
- Caption
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from dotenv import find_dotenv
|
||||
from pydantic import computed_field
|
||||
from pydantic_settings import BaseSettings
|
||||
import torch
|
||||
import os
|
||||
|
||||
|
||||
@@ -15,6 +13,7 @@ class Settings(BaseSettings):
|
||||
TORCH_DEVICE: str | None = None
|
||||
MAX_OUTPUT_TOKENS: int = 12384
|
||||
TORCH_ATTN: str | None = None
|
||||
BBOX_SCALE: int = 1024
|
||||
|
||||
# vLLM server settings
|
||||
VLLM_API_KEY: str = "EMPTY"
|
||||
@@ -23,11 +22,6 @@ class Settings(BaseSettings):
|
||||
VLLM_GPUS: str = "0"
|
||||
MAX_VLLM_RETRIES: int = 6
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def TORCH_DTYPE(self) -> torch.dtype:
|
||||
return torch.bfloat16
|
||||
|
||||
class Config:
|
||||
env_file = find_dotenv("local.env")
|
||||
extra = "ignore"
|
||||
|
||||
Reference in New Issue
Block a user