mirror of
https://github.com/datalab-to/chandra.git
synced 2026-05-01 22:46:16 +00:00
Merge remote-tracking branch 'origin/master'
This commit is contained in:
@@ -13,7 +13,23 @@ def flatten(page, flag=pdfium_c.FLAT_NORMALDISPLAY):
|
|||||||
print(f"Failed to flatten annotations / form fields on page {page}.")
|
print(f"Failed to flatten annotations / form fields on page {page}.")
|
||||||
|
|
||||||
|
|
||||||
def load_pdf_images(filepath: str, page_range: List[int]):
|
def load_image(
|
||||||
|
filepath: str, min_image_dim: int = settings.MIN_IMAGE_DIM
|
||||||
|
) -> Image.Image:
|
||||||
|
image = Image.open(filepath).convert("RGB")
|
||||||
|
if image.width < min_image_dim or image.height < min_image_dim:
|
||||||
|
scale = min_image_dim / min(image.width, image.height)
|
||||||
|
new_size = (int(image.width * scale), int(image.height * scale))
|
||||||
|
image = image.resize(new_size, Image.Resampling.LANCZOS)
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def load_pdf_images(
|
||||||
|
filepath: str,
|
||||||
|
page_range: List[int],
|
||||||
|
image_dpi: int = settings.IMAGE_DPI,
|
||||||
|
min_pdf_image_dim: int = settings.MIN_PDF_IMAGE_DIM,
|
||||||
|
) -> List[Image.Image]:
|
||||||
doc = pdfium.PdfDocument(filepath)
|
doc = pdfium.PdfDocument(filepath)
|
||||||
doc.init_forms()
|
doc.init_forms()
|
||||||
|
|
||||||
@@ -22,8 +38,8 @@ def load_pdf_images(filepath: str, page_range: List[int]):
|
|||||||
if not page_range or page in page_range:
|
if not page_range or page in page_range:
|
||||||
page_obj = doc[page]
|
page_obj = doc[page]
|
||||||
min_page_dim = min(page_obj.get_width(), page_obj.get_height())
|
min_page_dim = min(page_obj.get_width(), page_obj.get_height())
|
||||||
scale_dpi = (settings.MIN_IMAGE_DIM / min_page_dim) * 72
|
scale_dpi = (min_pdf_image_dim / min_page_dim) * 72
|
||||||
scale_dpi = max(scale_dpi, settings.IMAGE_DPI)
|
scale_dpi = max(scale_dpi, image_dpi)
|
||||||
page_obj = doc[page]
|
page_obj = doc[page]
|
||||||
flatten(page_obj)
|
flatten(page_obj)
|
||||||
page_obj = doc[page]
|
page_obj = doc[page]
|
||||||
@@ -56,5 +72,5 @@ def load_file(filepath: str, config: dict):
|
|||||||
if input_type and input_type.extension == "pdf":
|
if input_type and input_type.extension == "pdf":
|
||||||
images = load_pdf_images(filepath, page_range)
|
images = load_pdf_images(filepath, page_range)
|
||||||
else:
|
else:
|
||||||
images = [Image.open(filepath).convert("RGB")]
|
images = [load_image(filepath)]
|
||||||
return images
|
return images
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from chandra.model.hf import load_model, generate_hf
|
|||||||
from chandra.model.schema import BatchInputItem, BatchOutputItem
|
from chandra.model.schema import BatchInputItem, BatchOutputItem
|
||||||
from chandra.model.vllm import generate_vllm
|
from chandra.model.vllm import generate_vllm
|
||||||
from chandra.output import parse_markdown, parse_html, parse_chunks, extract_images
|
from chandra.output import parse_markdown, parse_html, parse_chunks, extract_images
|
||||||
|
from chandra.settings import settings
|
||||||
|
|
||||||
|
|
||||||
class InferenceManager:
|
class InferenceManager:
|
||||||
@@ -26,19 +27,29 @@ class InferenceManager:
|
|||||||
output_kwargs["include_headers_footers"] = kwargs.pop(
|
output_kwargs["include_headers_footers"] = kwargs.pop(
|
||||||
"include_headers_footers"
|
"include_headers_footers"
|
||||||
)
|
)
|
||||||
|
bbox_scale = kwargs.pop("bbox_scale", settings.BBOX_SCALE)
|
||||||
|
vllm_api_base = kwargs.pop("vllm_api_base", settings.VLLM_API_BASE)
|
||||||
|
|
||||||
if self.method == "vllm":
|
if self.method == "vllm":
|
||||||
results = generate_vllm(
|
results = generate_vllm(
|
||||||
batch, max_output_tokens=max_output_tokens, **kwargs
|
batch,
|
||||||
|
max_output_tokens=max_output_tokens,
|
||||||
|
bbox_scale=bbox_scale,
|
||||||
|
vllm_api_base=vllm_api_base,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
results = generate_hf(
|
results = generate_hf(
|
||||||
batch, self.model, max_output_tokens=max_output_tokens, **kwargs
|
batch,
|
||||||
|
self.model,
|
||||||
|
max_output_tokens=max_output_tokens,
|
||||||
|
bbox_scale=bbox_scale,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
output = []
|
output = []
|
||||||
for result, input_item in zip(results, batch):
|
for result, input_item in zip(results, batch):
|
||||||
chunks = parse_chunks(result.raw, input_item.image)
|
chunks = parse_chunks(result.raw, input_item.image, bbox_scale=bbox_scale)
|
||||||
output.append(
|
output.append(
|
||||||
BatchOutputItem(
|
BatchOutputItem(
|
||||||
markdown=parse_markdown(result.raw, **output_kwargs),
|
markdown=parse_markdown(result.raw, **output_kwargs),
|
||||||
@@ -48,6 +59,7 @@ class InferenceManager:
|
|||||||
page_box=[0, 0, input_item.image.width, input_item.image.height],
|
page_box=[0, 0, input_item.image.width, input_item.image.height],
|
||||||
token_count=result.token_count,
|
token_count=result.token_count,
|
||||||
images=extract_images(result.raw, chunks, input_item.image),
|
images=extract_images(result.raw, chunks, input_item.image),
|
||||||
|
error=result.error,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -1,8 +1,5 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from qwen_vl_utils import process_vision_info
|
|
||||||
from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
|
|
||||||
|
|
||||||
from chandra.model.schema import BatchInputItem, GenerationResult
|
from chandra.model.schema import BatchInputItem, GenerationResult
|
||||||
from chandra.model.util import scale_to_fit
|
from chandra.model.util import scale_to_fit
|
||||||
from chandra.prompts import PROMPT_MAPPING
|
from chandra.prompts import PROMPT_MAPPING
|
||||||
@@ -10,12 +7,20 @@ from chandra.settings import settings
|
|||||||
|
|
||||||
|
|
||||||
def generate_hf(
|
def generate_hf(
|
||||||
batch: List[BatchInputItem], model, max_output_tokens=None, **kwargs
|
batch: List[BatchInputItem],
|
||||||
|
model,
|
||||||
|
max_output_tokens=None,
|
||||||
|
bbox_scale: int = settings.BBOX_SCALE,
|
||||||
|
**kwargs,
|
||||||
) -> List[GenerationResult]:
|
) -> List[GenerationResult]:
|
||||||
|
from qwen_vl_utils import process_vision_info
|
||||||
|
|
||||||
if max_output_tokens is None:
|
if max_output_tokens is None:
|
||||||
max_output_tokens = settings.MAX_OUTPUT_TOKENS
|
max_output_tokens = settings.MAX_OUTPUT_TOKENS
|
||||||
|
|
||||||
messages = [process_batch_element(item, model.processor) for item in batch]
|
messages = [
|
||||||
|
process_batch_element(item, model.processor, bbox_scale) for item in batch
|
||||||
|
]
|
||||||
text = model.processor.apply_chat_template(
|
text = model.processor.apply_chat_template(
|
||||||
messages, tokenize=False, add_generation_prompt=True
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
)
|
)
|
||||||
@@ -48,12 +53,12 @@ def generate_hf(
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def process_batch_element(item: BatchInputItem, processor):
|
def process_batch_element(item: BatchInputItem, processor, bbox_scale: int):
|
||||||
prompt = item.prompt
|
prompt = item.prompt
|
||||||
prompt_type = item.prompt_type
|
prompt_type = item.prompt_type
|
||||||
|
|
||||||
if not prompt:
|
if not prompt:
|
||||||
prompt = PROMPT_MAPPING[prompt_type]
|
prompt = PROMPT_MAPPING[prompt_type].replace("{bbox_scale}", str(bbox_scale))
|
||||||
|
|
||||||
content = []
|
content = []
|
||||||
image = scale_to_fit(item.image) # Guarantee max size
|
image = scale_to_fit(item.image) # Guarantee max size
|
||||||
@@ -65,12 +70,15 @@ def process_batch_element(item: BatchInputItem, processor):
|
|||||||
|
|
||||||
|
|
||||||
def load_model():
|
def load_model():
|
||||||
|
import torch
|
||||||
|
from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
|
||||||
|
|
||||||
device_map = "auto"
|
device_map = "auto"
|
||||||
if settings.TORCH_DEVICE:
|
if settings.TORCH_DEVICE:
|
||||||
device_map = {"": settings.TORCH_DEVICE}
|
device_map = {"": settings.TORCH_DEVICE}
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"dtype": settings.TORCH_DTYPE,
|
"dtype": torch.bfloat16,
|
||||||
"device_map": device_map,
|
"device_map": device_map,
|
||||||
}
|
}
|
||||||
if settings.TORCH_ATTN:
|
if settings.TORCH_ATTN:
|
||||||
|
|||||||
@@ -27,3 +27,4 @@ class BatchOutputItem:
|
|||||||
page_box: List[int]
|
page_box: List[int]
|
||||||
token_count: int
|
token_count: int
|
||||||
images: dict
|
images: dict
|
||||||
|
error: bool
|
||||||
|
|||||||
@@ -44,9 +44,10 @@ def scale_to_fit(
|
|||||||
|
|
||||||
def detect_repeat_token(
|
def detect_repeat_token(
|
||||||
predicted_tokens: str,
|
predicted_tokens: str,
|
||||||
max_repeats: int = 4,
|
base_max_repeats: int = 4,
|
||||||
window_size: int = 500,
|
window_size: int = 500,
|
||||||
cut_from_end: int = 0,
|
cut_from_end: int = 0,
|
||||||
|
scaling_factor: float = 3.0,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
predicted_tokens = parse_markdown(predicted_tokens)
|
predicted_tokens = parse_markdown(predicted_tokens)
|
||||||
@@ -57,11 +58,13 @@ def detect_repeat_token(
|
|||||||
if cut_from_end > 0:
|
if cut_from_end > 0:
|
||||||
predicted_tokens = predicted_tokens[:-cut_from_end]
|
predicted_tokens = predicted_tokens[:-cut_from_end]
|
||||||
|
|
||||||
# Try different sequence lengths (1 to window_size//2)
|
|
||||||
for seq_len in range(1, window_size // 2 + 1):
|
for seq_len in range(1, window_size // 2 + 1):
|
||||||
# Extract the potential repeating sequence from the end
|
# Extract the potential repeating sequence from the end
|
||||||
candidate_seq = predicted_tokens[-seq_len:]
|
candidate_seq = predicted_tokens[-seq_len:]
|
||||||
|
|
||||||
|
# Inverse scaling: shorter sequences need more repeats
|
||||||
|
max_repeats = int(base_max_repeats * (1 + scaling_factor / seq_len))
|
||||||
|
|
||||||
# Count how many times this sequence appears consecutively at the end
|
# Count how many times this sequence appears consecutively at the end
|
||||||
repeat_count = 0
|
repeat_count = 0
|
||||||
pos = len(predicted_tokens) - seq_len
|
pos = len(predicted_tokens) - seq_len
|
||||||
@@ -75,7 +78,6 @@ def detect_repeat_token(
|
|||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
# If we found more than max_repeats consecutive occurrences
|
|
||||||
if repeat_count > max_repeats:
|
if repeat_count > max_repeats:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
|
import time
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from itertools import repeat
|
from itertools import repeat
|
||||||
from typing import List
|
from typing import List
|
||||||
@@ -25,10 +26,15 @@ def generate_vllm(
|
|||||||
max_output_tokens: int = None,
|
max_output_tokens: int = None,
|
||||||
max_retries: int = None,
|
max_retries: int = None,
|
||||||
max_workers: int | None = None,
|
max_workers: int | None = None,
|
||||||
|
custom_headers: dict | None = None,
|
||||||
|
max_failure_retries: int | None = None,
|
||||||
|
bbox_scale: int = settings.BBOX_SCALE,
|
||||||
|
vllm_api_base: str = settings.VLLM_API_BASE,
|
||||||
) -> List[GenerationResult]:
|
) -> List[GenerationResult]:
|
||||||
client = OpenAI(
|
client = OpenAI(
|
||||||
api_key=settings.VLLM_API_KEY,
|
api_key=settings.VLLM_API_KEY,
|
||||||
base_url=settings.VLLM_API_BASE,
|
base_url=vllm_api_base,
|
||||||
|
default_headers=custom_headers,
|
||||||
)
|
)
|
||||||
model_name = settings.VLLM_MODEL_NAME
|
model_name = settings.VLLM_MODEL_NAME
|
||||||
|
|
||||||
@@ -50,7 +56,9 @@ def generate_vllm(
|
|||||||
) -> GenerationResult:
|
) -> GenerationResult:
|
||||||
prompt = item.prompt
|
prompt = item.prompt
|
||||||
if not prompt:
|
if not prompt:
|
||||||
prompt = PROMPT_MAPPING[item.prompt_type]
|
prompt = PROMPT_MAPPING[item.prompt_type].replace(
|
||||||
|
"{bbox_scale}", str(bbox_scale)
|
||||||
|
)
|
||||||
|
|
||||||
content = []
|
content = []
|
||||||
image = scale_to_fit(item.image)
|
image = scale_to_fit(item.image)
|
||||||
@@ -68,41 +76,68 @@ def generate_vllm(
|
|||||||
completion = client.chat.completions.create(
|
completion = client.chat.completions.create(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
messages=[{"role": "user", "content": content}],
|
messages=[{"role": "user", "content": content}],
|
||||||
max_tokens=settings.MAX_OUTPUT_TOKENS,
|
max_tokens=max_output_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
)
|
)
|
||||||
|
raw = completion.choices[0].message.content
|
||||||
|
result = GenerationResult(
|
||||||
|
raw=raw,
|
||||||
|
token_count=completion.usage.completion_tokens,
|
||||||
|
error=False,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error during VLLM generation: {e}")
|
print(f"Error during VLLM generation: {e}")
|
||||||
return GenerationResult(raw="", token_count=0, error=True)
|
return GenerationResult(raw="", token_count=0, error=True)
|
||||||
|
|
||||||
return GenerationResult(
|
return result
|
||||||
raw=completion.choices[0].message.content,
|
|
||||||
token_count=completion.usage.completion_tokens,
|
|
||||||
error=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def process_item(item, max_retries):
|
def process_item(item, max_retries, max_failure_retries=None):
|
||||||
result = _generate(item)
|
result = _generate(item)
|
||||||
retries = 0
|
retries = 0
|
||||||
|
|
||||||
while retries < max_retries and (
|
while _should_retry(result, retries, max_retries, max_failure_retries):
|
||||||
detect_repeat_token(result.raw)
|
|
||||||
or (
|
|
||||||
len(result.raw) > 50
|
|
||||||
and detect_repeat_token(result.raw, cut_from_end=50)
|
|
||||||
)
|
|
||||||
or result.error
|
|
||||||
):
|
|
||||||
print(
|
|
||||||
f"Detected repeat token or error, retrying generation (attempt {retries + 1})..."
|
|
||||||
)
|
|
||||||
result = _generate(item, temperature=0.3, top_p=0.95)
|
result = _generate(item, temperature=0.3, top_p=0.95)
|
||||||
retries += 1
|
retries += 1
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _should_retry(result, retries, max_retries, max_failure_retries):
|
||||||
|
has_repeat = detect_repeat_token(result.raw) or (
|
||||||
|
len(result.raw) > 50 and detect_repeat_token(result.raw, cut_from_end=50)
|
||||||
|
)
|
||||||
|
|
||||||
|
if retries < max_retries and has_repeat:
|
||||||
|
print(
|
||||||
|
f"Detected repeat token, retrying generation (attempt {retries + 1})..."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
if retries < max_retries and result.error:
|
||||||
|
print(
|
||||||
|
f"Detected vllm error, retrying generation (attempt {retries + 1})..."
|
||||||
|
)
|
||||||
|
time.sleep(2 * (retries + 1)) # Sleeping can help under load
|
||||||
|
return True
|
||||||
|
|
||||||
|
if (
|
||||||
|
result.error
|
||||||
|
and max_failure_retries is not None
|
||||||
|
and retries < max_failure_retries
|
||||||
|
):
|
||||||
|
print(
|
||||||
|
f"Detected vllm error, retrying generation (attempt {retries + 1})..."
|
||||||
|
)
|
||||||
|
time.sleep(2 * (retries + 1)) # Sleeping can help under load
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
results = list(executor.map(process_item, batch, repeat(max_retries)))
|
results = list(
|
||||||
|
executor.map(
|
||||||
|
process_item, batch, repeat(max_retries), repeat(max_failure_retries)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -6,9 +6,11 @@ from functools import lru_cache
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from bs4 import BeautifulSoup, NavigableString
|
from bs4 import BeautifulSoup
|
||||||
from markdownify import MarkdownConverter, re_whitespace
|
from markdownify import MarkdownConverter, re_whitespace
|
||||||
|
|
||||||
|
from chandra.settings import settings
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def _hash_html(html: str):
|
def _hash_html(html: str):
|
||||||
@@ -30,7 +32,11 @@ def extract_images(html: str, chunks: dict, image: Image.Image):
|
|||||||
if not img:
|
if not img:
|
||||||
continue
|
continue
|
||||||
bbox = chunk["bbox"]
|
bbox = chunk["bbox"]
|
||||||
block_image = image.crop(bbox)
|
try:
|
||||||
|
block_image = image.crop(bbox)
|
||||||
|
except ValueError:
|
||||||
|
# Happens when bbox coordinates are invalid
|
||||||
|
continue
|
||||||
img_name = get_image_name(html, div_idx)
|
img_name = get_image_name(html, div_idx)
|
||||||
images[img_name] = block_image
|
images[img_name] = block_image
|
||||||
return images
|
return images
|
||||||
@@ -67,44 +73,22 @@ def parse_html(
|
|||||||
else:
|
else:
|
||||||
img = BeautifulSoup(f"<img src='{img_src}'/>", "html.parser")
|
img = BeautifulSoup(f"<img src='{img_src}'/>", "html.parser")
|
||||||
div.append(img)
|
div.append(img)
|
||||||
|
|
||||||
|
# Wrap text content in <p> tags if no inner HTML tags exist
|
||||||
|
if label in ["Text"] and not re.search(
|
||||||
|
"<.+>", str(div.decode_contents()).strip()
|
||||||
|
):
|
||||||
|
# Add inner p tags if missing for text blocks
|
||||||
|
text_content = str(div.decode_contents()).strip()
|
||||||
|
text_content = f"<p>{text_content}</p>"
|
||||||
|
div.clear()
|
||||||
|
div.append(BeautifulSoup(text_content, "html.parser"))
|
||||||
|
|
||||||
content = str(div.decode_contents())
|
content = str(div.decode_contents())
|
||||||
out_html += content
|
out_html += content
|
||||||
return out_html
|
return out_html
|
||||||
|
|
||||||
|
|
||||||
def escape_dollars(text):
|
|
||||||
return text.replace("$", r"\$")
|
|
||||||
|
|
||||||
|
|
||||||
def get_formatted_table_text(element):
|
|
||||||
text = []
|
|
||||||
for content in element.contents:
|
|
||||||
if content is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(content, NavigableString):
|
|
||||||
stripped = content.strip()
|
|
||||||
if stripped:
|
|
||||||
text.append(escape_dollars(stripped))
|
|
||||||
elif content.name == "br":
|
|
||||||
text.append("<br>")
|
|
||||||
elif content.name == "math":
|
|
||||||
text.append("$" + content.text + "$")
|
|
||||||
else:
|
|
||||||
content_str = escape_dollars(str(content))
|
|
||||||
text.append(content_str)
|
|
||||||
|
|
||||||
full_text = ""
|
|
||||||
for i, t in enumerate(text):
|
|
||||||
if t == "<br>":
|
|
||||||
full_text += t
|
|
||||||
elif i > 0 and text[i - 1] != "<br>":
|
|
||||||
full_text += " " + t
|
|
||||||
else:
|
|
||||||
full_text += t
|
|
||||||
return full_text
|
|
||||||
|
|
||||||
|
|
||||||
class Markdownify(MarkdownConverter):
|
class Markdownify(MarkdownConverter):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -204,19 +188,25 @@ class LayoutBlock:
|
|||||||
content: str
|
content: str
|
||||||
|
|
||||||
|
|
||||||
def parse_layout(html: str, image: Image.Image):
|
def parse_layout(html: str, image: Image.Image, bbox_scale=settings.BBOX_SCALE):
|
||||||
soup = BeautifulSoup(html, "html.parser")
|
soup = BeautifulSoup(html, "html.parser")
|
||||||
top_level_divs = soup.find_all("div", recursive=False)
|
top_level_divs = soup.find_all("div", recursive=False)
|
||||||
width, height = image.size
|
width, height = image.size
|
||||||
width_scaler = width / 1024
|
width_scaler = width / bbox_scale
|
||||||
height_scaler = height / 1024
|
height_scaler = height / bbox_scale
|
||||||
layout_blocks = []
|
layout_blocks = []
|
||||||
for div in top_level_divs:
|
for div in top_level_divs:
|
||||||
bbox = div.get("data-bbox")
|
bbox = div.get("data-bbox")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
bbox = json.loads(bbox)
|
bbox = json.loads(bbox)
|
||||||
|
assert len(bbox) == 4, "Invalid bbox length"
|
||||||
except Exception:
|
except Exception:
|
||||||
bbox = [0, 0, 1, 1] # Fallback to a default bbox if parsing fails
|
try:
|
||||||
|
bbox = bbox.split(" ")
|
||||||
|
assert len(bbox) == 4, "Invalid bbox length"
|
||||||
|
except Exception:
|
||||||
|
bbox = [0, 0, 1, 1]
|
||||||
|
|
||||||
bbox = list(map(int, bbox))
|
bbox = list(map(int, bbox))
|
||||||
# Normalize bbox
|
# Normalize bbox
|
||||||
@@ -232,7 +222,7 @@ def parse_layout(html: str, image: Image.Image):
|
|||||||
return layout_blocks
|
return layout_blocks
|
||||||
|
|
||||||
|
|
||||||
def parse_chunks(html: str, image: Image.Image):
|
def parse_chunks(html: str, image: Image.Image, bbox_scale=settings.BBOX_SCALE):
|
||||||
layout = parse_layout(html, image)
|
layout = parse_layout(html, image, bbox_scale=bbox_scale)
|
||||||
chunks = [asdict(block) for block in layout]
|
chunks = [asdict(block) for block in layout]
|
||||||
return chunks
|
return chunks
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ Guidelines:
|
|||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
OCR_LAYOUT_PROMPT = f"""
|
OCR_LAYOUT_PROMPT = f"""
|
||||||
OCR this image to HTML, arranged as layout blocks. Each layout block should be a div with the data-bbox attribute representing the bounding box of the block in [x0, y0, x1, y1] format. Bboxes are normalized 0-1024. The data-label attribute is the label for the block.
|
OCR this image to HTML, arranged as layout blocks. Each layout block should be a div with the data-bbox attribute representing the bounding box of the block in [x0, y0, x1, y1] format. Bboxes are normalized 0-{{bbox_scale}}. The data-label attribute is the label for the block.
|
||||||
|
|
||||||
Use the following labels:
|
Use the following labels:
|
||||||
- Caption
|
- Caption
|
||||||
|
|||||||
@@ -143,6 +143,7 @@ def process():
|
|||||||
"image_height": img_height,
|
"image_height": img_height,
|
||||||
"blocks": blocks_data,
|
"blocks": blocks_data,
|
||||||
"html": html_with_images,
|
"html": html_with_images,
|
||||||
|
"markdown": result.markdown,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -64,6 +64,20 @@
|
|||||||
cursor: not-allowed;
|
cursor: not-allowed;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.controls label {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 8px;
|
||||||
|
color: white;
|
||||||
|
font-size: 14px;
|
||||||
|
cursor: pointer;
|
||||||
|
user-select: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.controls input[type="checkbox"] {
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
|
||||||
.loading {
|
.loading {
|
||||||
display: none;
|
display: none;
|
||||||
color: #f39c12;
|
color: #f39c12;
|
||||||
@@ -75,6 +89,11 @@
|
|||||||
font-weight: bold;
|
font-weight: bold;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.success {
|
||||||
|
color: #27ae60;
|
||||||
|
font-weight: bold;
|
||||||
|
}
|
||||||
|
|
||||||
.screenshot-container {
|
.screenshot-container {
|
||||||
display: none;
|
display: none;
|
||||||
margin-top: 60px;
|
margin-top: 60px;
|
||||||
@@ -88,8 +107,18 @@
|
|||||||
display: flex;
|
display: flex;
|
||||||
}
|
}
|
||||||
|
|
||||||
.left-panel, .right-panel {
|
.left-panel {
|
||||||
flex: 1;
|
flex: 0 0 40%;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
background: white;
|
||||||
|
border-radius: 8px;
|
||||||
|
overflow: hidden;
|
||||||
|
box-shadow: 0 4px 12px rgba(0,0,0,0.3);
|
||||||
|
}
|
||||||
|
|
||||||
|
.right-panel {
|
||||||
|
flex: 0 0 60%;
|
||||||
display: flex;
|
display: flex;
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
background: white;
|
background: white;
|
||||||
@@ -137,6 +166,7 @@
|
|||||||
padding: 30px;
|
padding: 30px;
|
||||||
line-height: 1.6;
|
line-height: 1.6;
|
||||||
color: #333;
|
color: #333;
|
||||||
|
font-size: 24px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.markdown-content h1, .markdown-content h2, .markdown-content h3 {
|
.markdown-content h1, .markdown-content h2, .markdown-content h3 {
|
||||||
@@ -215,8 +245,14 @@
|
|||||||
<input type="text" id="filePath" placeholder="Enter file path (e.g., /path/to/document.pdf)">
|
<input type="text" id="filePath" placeholder="Enter file path (e.g., /path/to/document.pdf)">
|
||||||
<input type="number" id="pageNumber" placeholder="Page" value="0" min="0">
|
<input type="number" id="pageNumber" placeholder="Page" value="0" min="0">
|
||||||
<button id="processBtn" onclick="processFile()">Process</button>
|
<button id="processBtn" onclick="processFile()">Process</button>
|
||||||
|
<label>
|
||||||
|
<input type="checkbox" id="showLayoutBoxes" checked onchange="toggleLayoutBoxes()">
|
||||||
|
Show Layout Boxes
|
||||||
|
</label>
|
||||||
|
<button id="copyMarkdownBtn" onclick="copyMarkdown()" style="display: none;">Copy Markdown</button>
|
||||||
<span class="loading" id="loading">Processing...</span>
|
<span class="loading" id="loading">Processing...</span>
|
||||||
<span class="error" id="error"></span>
|
<span class="error" id="error"></span>
|
||||||
|
<span class="success" id="success"></span>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="screenshot-container" id="container">
|
<div class="screenshot-container" id="container">
|
||||||
@@ -242,6 +278,11 @@
|
|||||||
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
|
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
|
||||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/github-markdown-css/5.8.1/github-markdown.min.css" integrity="sha512-BrOPA520KmDMqieeM7XFe6a3u3Sb3F1JBaQnrIAmWg3EYrciJ+Qqe6ZcKCdfPv26rGcgTrJnZ/IdQEct8h3Zhw==" crossorigin="anonymous" referrerpolicy="no-referrer" />
|
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/github-markdown-css/5.8.1/github-markdown.min.css" integrity="sha512-BrOPA520KmDMqieeM7XFe6a3u3Sb3F1JBaQnrIAmWg3EYrciJ+Qqe6ZcKCdfPv26rGcgTrJnZ/IdQEct8h3Zhw==" crossorigin="anonymous" referrerpolicy="no-referrer" />
|
||||||
<script>
|
<script>
|
||||||
|
// Global state to store markdown and canvas data
|
||||||
|
let currentMarkdown = null;
|
||||||
|
let currentData = null;
|
||||||
|
let currentImageSrc = null;
|
||||||
|
|
||||||
async function processFile() {
|
async function processFile() {
|
||||||
const filePath = document.getElementById('filePath').value;
|
const filePath = document.getElementById('filePath').value;
|
||||||
const pageNumber = parseInt(document.getElementById('pageNumber').value) || 0;
|
const pageNumber = parseInt(document.getElementById('pageNumber').value) || 0;
|
||||||
@@ -285,6 +326,10 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
function renderResults(data) {
|
function renderResults(data) {
|
||||||
|
// Store data for toggle functionality
|
||||||
|
currentData = data;
|
||||||
|
currentImageSrc = data.image_base64;
|
||||||
|
|
||||||
const canvas = document.getElementById('layoutCanvas');
|
const canvas = document.getElementById('layoutCanvas');
|
||||||
const ctx = canvas.getContext('2d');
|
const ctx = canvas.getContext('2d');
|
||||||
const markdownContent = document.getElementById('markdownContent');
|
const markdownContent = document.getElementById('markdownContent');
|
||||||
@@ -292,51 +337,14 @@
|
|||||||
// Draw image with layout overlays
|
// Draw image with layout overlays
|
||||||
const img = new Image();
|
const img = new Image();
|
||||||
img.onload = function() {
|
img.onload = function() {
|
||||||
canvas.width = data.image_width;
|
drawCanvas(img, data, ctx);
|
||||||
canvas.height = data.image_height;
|
|
||||||
|
|
||||||
// Draw image
|
|
||||||
ctx.drawImage(img, 0, 0, data.image_width, data.image_height);
|
|
||||||
|
|
||||||
// Draw layout blocks
|
|
||||||
ctx.lineWidth = 3;
|
|
||||||
ctx.font = 'bold 14px -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif';
|
|
||||||
|
|
||||||
const labelCounts = {};
|
|
||||||
data.blocks.forEach((block) => {
|
|
||||||
const [x1, y1, x2, y2] = block.bbox;
|
|
||||||
const width = x2 - x1;
|
|
||||||
const height = y2 - y1;
|
|
||||||
|
|
||||||
// Draw rectangle with semi-transparent fill
|
|
||||||
ctx.strokeStyle = block.color;
|
|
||||||
ctx.fillStyle = block.color + '33';
|
|
||||||
ctx.fillRect(x1, y1, width, height);
|
|
||||||
ctx.strokeRect(x1, y1, width, height);
|
|
||||||
|
|
||||||
// Count labels for unique identification
|
|
||||||
labelCounts[block.label] = (labelCounts[block.label] || 0) + 1;
|
|
||||||
const labelWithCount = `${block.label} #${labelCounts[block.label]}`;
|
|
||||||
|
|
||||||
// Draw label with background
|
|
||||||
const textMetrics = ctx.measureText(labelWithCount);
|
|
||||||
const textWidth = textMetrics.width;
|
|
||||||
const textHeight = 16;
|
|
||||||
const padding = 6;
|
|
||||||
|
|
||||||
const labelX = x1;
|
|
||||||
const labelY = Math.max(y1 - textHeight - padding, textHeight);
|
|
||||||
|
|
||||||
ctx.fillStyle = block.color;
|
|
||||||
ctx.fillRect(labelX, labelY - textHeight, textWidth + padding * 2, textHeight + padding);
|
|
||||||
|
|
||||||
ctx.fillStyle = 'white';
|
|
||||||
ctx.textBaseline = 'top';
|
|
||||||
ctx.fillText(labelWithCount, labelX + padding, labelY - textHeight + padding/2);
|
|
||||||
});
|
|
||||||
};
|
};
|
||||||
img.src = data.image_base64;
|
img.src = data.image_base64;
|
||||||
|
|
||||||
|
// Store markdown and show copy button
|
||||||
|
currentMarkdown = data.markdown;
|
||||||
|
document.getElementById('copyMarkdownBtn').style.display = 'inline-block';
|
||||||
|
|
||||||
// Render HTML directly (with images embedded)
|
// Render HTML directly (with images embedded)
|
||||||
markdownContent.innerHTML = data.html;
|
markdownContent.innerHTML = data.html;
|
||||||
|
|
||||||
@@ -362,6 +370,85 @@
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function drawCanvas(img, data, ctx) {
|
||||||
|
const canvas = document.getElementById('layoutCanvas');
|
||||||
|
canvas.width = data.image_width;
|
||||||
|
canvas.height = data.image_height;
|
||||||
|
|
||||||
|
// Draw image
|
||||||
|
ctx.drawImage(img, 0, 0, data.image_width, data.image_height);
|
||||||
|
|
||||||
|
// Check if layout boxes should be shown
|
||||||
|
const showBoxes = document.getElementById('showLayoutBoxes').checked;
|
||||||
|
if (!showBoxes) return;
|
||||||
|
|
||||||
|
// Draw layout blocks
|
||||||
|
ctx.lineWidth = 3;
|
||||||
|
ctx.font = 'bold 14px -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif';
|
||||||
|
|
||||||
|
const labelCounts = {};
|
||||||
|
data.blocks.forEach((block) => {
|
||||||
|
const [x1, y1, x2, y2] = block.bbox;
|
||||||
|
const width = x2 - x1;
|
||||||
|
const height = y2 - y1;
|
||||||
|
|
||||||
|
// Draw rectangle with semi-transparent fill
|
||||||
|
ctx.strokeStyle = block.color;
|
||||||
|
ctx.fillStyle = block.color + '33';
|
||||||
|
ctx.fillRect(x1, y1, width, height);
|
||||||
|
ctx.strokeRect(x1, y1, width, height);
|
||||||
|
|
||||||
|
// Count labels for unique identification
|
||||||
|
labelCounts[block.label] = (labelCounts[block.label] || 0) + 1;
|
||||||
|
const labelWithCount = `${block.label} #${labelCounts[block.label]}`;
|
||||||
|
|
||||||
|
// Draw label with background
|
||||||
|
const textMetrics = ctx.measureText(labelWithCount);
|
||||||
|
const textWidth = textMetrics.width;
|
||||||
|
const textHeight = 16;
|
||||||
|
const padding = 6;
|
||||||
|
|
||||||
|
const labelX = x1;
|
||||||
|
const labelY = Math.max(y1 - textHeight - padding, textHeight);
|
||||||
|
|
||||||
|
ctx.fillStyle = block.color;
|
||||||
|
ctx.fillRect(labelX, labelY - textHeight, textWidth + padding * 2, textHeight + padding);
|
||||||
|
|
||||||
|
ctx.fillStyle = 'white';
|
||||||
|
ctx.textBaseline = 'top';
|
||||||
|
ctx.fillText(labelWithCount, labelX + padding, labelY - textHeight + padding/2);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function toggleLayoutBoxes() {
|
||||||
|
if (!currentData || !currentImageSrc) return;
|
||||||
|
|
||||||
|
const canvas = document.getElementById('layoutCanvas');
|
||||||
|
const ctx = canvas.getContext('2d');
|
||||||
|
const img = new Image();
|
||||||
|
img.onload = function() {
|
||||||
|
drawCanvas(img, currentData, ctx);
|
||||||
|
};
|
||||||
|
img.src = currentImageSrc;
|
||||||
|
}
|
||||||
|
|
||||||
|
function copyMarkdown() {
|
||||||
|
if (!currentMarkdown) {
|
||||||
|
document.getElementById('error').textContent = 'No markdown to copy';
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
navigator.clipboard.writeText(currentMarkdown).then(() => {
|
||||||
|
const success = document.getElementById('success');
|
||||||
|
success.textContent = 'Markdown copied!';
|
||||||
|
setTimeout(() => {
|
||||||
|
success.textContent = '';
|
||||||
|
}, 2000);
|
||||||
|
}).catch((err) => {
|
||||||
|
document.getElementById('error').textContent = 'Failed to copy: ' + err.message;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Allow Enter key to trigger processing
|
// Allow Enter key to trigger processing
|
||||||
document.getElementById('filePath').addEventListener('keypress', function(e) {
|
document.getElementById('filePath').addEventListener('keypress', function(e) {
|
||||||
if (e.key === 'Enter') processFile();
|
if (e.key === 'Enter') processFile();
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
from dotenv import find_dotenv
|
from dotenv import find_dotenv
|
||||||
from pydantic import computed_field
|
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
import torch
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
@@ -9,11 +7,13 @@ class Settings(BaseSettings):
|
|||||||
# Paths
|
# Paths
|
||||||
BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
IMAGE_DPI: int = 192
|
IMAGE_DPI: int = 192
|
||||||
MIN_IMAGE_DIM: int = 1024
|
MIN_PDF_IMAGE_DIM: int = 1024
|
||||||
|
MIN_IMAGE_DIM: int = 1536
|
||||||
MODEL_CHECKPOINT: str = "datalab-to/chandra"
|
MODEL_CHECKPOINT: str = "datalab-to/chandra"
|
||||||
TORCH_DEVICE: str | None = None
|
TORCH_DEVICE: str | None = None
|
||||||
MAX_OUTPUT_TOKENS: int = 8192
|
MAX_OUTPUT_TOKENS: int = 12384
|
||||||
TORCH_ATTN: str | None = None
|
TORCH_ATTN: str | None = None
|
||||||
|
BBOX_SCALE: int = 1024
|
||||||
|
|
||||||
# vLLM server settings
|
# vLLM server settings
|
||||||
VLLM_API_KEY: str = "EMPTY"
|
VLLM_API_KEY: str = "EMPTY"
|
||||||
@@ -22,11 +22,6 @@ class Settings(BaseSettings):
|
|||||||
VLLM_GPUS: str = "0"
|
VLLM_GPUS: str = "0"
|
||||||
MAX_VLLM_RETRIES: int = 6
|
MAX_VLLM_RETRIES: int = 6
|
||||||
|
|
||||||
@computed_field
|
|
||||||
@property
|
|
||||||
def TORCH_DTYPE(self) -> torch.dtype:
|
|
||||||
return torch.bfloat16
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = find_dotenv("local.env")
|
env_file = find_dotenv("local.env")
|
||||||
extra = "ignore"
|
extra = "ignore"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "chandra-ocr"
|
name = "chandra-ocr"
|
||||||
version = "0.1.7"
|
version = "0.1.9"
|
||||||
description = "OCR model that converts documents to markdown, HTML, or JSON."
|
description = "OCR model that converts documents to markdown, HTML, or JSON."
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|||||||
Reference in New Issue
Block a user