Merge remote-tracking branch 'origin/master'

Vik Paruchuri
2026-01-12 17:59:38 -05:00
12 changed files with 282 additions and 135 deletions

View File

@@ -13,7 +13,23 @@ def flatten(page, flag=pdfium_c.FLAT_NORMALDISPLAY):
         print(f"Failed to flatten annotations / form fields on page {page}.")


-def load_pdf_images(filepath: str, page_range: List[int]):
+def load_image(
+    filepath: str, min_image_dim: int = settings.MIN_IMAGE_DIM
+) -> Image.Image:
+    image = Image.open(filepath).convert("RGB")
+    if image.width < min_image_dim or image.height < min_image_dim:
+        scale = min_image_dim / min(image.width, image.height)
+        new_size = (int(image.width * scale), int(image.height * scale))
+        image = image.resize(new_size, Image.Resampling.LANCZOS)
+    return image
+
+
+def load_pdf_images(
+    filepath: str,
+    page_range: List[int],
+    image_dpi: int = settings.IMAGE_DPI,
+    min_pdf_image_dim: int = settings.MIN_PDF_IMAGE_DIM,
+) -> List[Image.Image]:
     doc = pdfium.PdfDocument(filepath)
     doc.init_forms()
@@ -22,8 +38,8 @@ def load_pdf_images(filepath: str, page_range: List[int]):
         if not page_range or page in page_range:
             page_obj = doc[page]
             min_page_dim = min(page_obj.get_width(), page_obj.get_height())
-            scale_dpi = (settings.MIN_IMAGE_DIM / min_page_dim) * 72
-            scale_dpi = max(scale_dpi, settings.IMAGE_DPI)
+            scale_dpi = (min_pdf_image_dim / min_page_dim) * 72
+            scale_dpi = max(scale_dpi, image_dpi)
             page_obj = doc[page]
             flatten(page_obj)
             page_obj = doc[page]
@@ -56,5 +72,5 @@ def load_file(filepath: str, config: dict):
     if input_type and input_type.extension == "pdf":
         images = load_pdf_images(filepath, page_range)
     else:
-        images = [Image.open(filepath).convert("RGB")]
+        images = [load_image(filepath)]
     return images

View File

@@ -4,6 +4,7 @@ from chandra.model.hf import load_model, generate_hf
 from chandra.model.schema import BatchInputItem, BatchOutputItem
 from chandra.model.vllm import generate_vllm
 from chandra.output import parse_markdown, parse_html, parse_chunks, extract_images
+from chandra.settings import settings


 class InferenceManager:
@@ -26,19 +27,29 @@ class InferenceManager:
             output_kwargs["include_headers_footers"] = kwargs.pop(
                 "include_headers_footers"
             )
+        bbox_scale = kwargs.pop("bbox_scale", settings.BBOX_SCALE)
+        vllm_api_base = kwargs.pop("vllm_api_base", settings.VLLM_API_BASE)
         if self.method == "vllm":
             results = generate_vllm(
-                batch, max_output_tokens=max_output_tokens, **kwargs
+                batch,
+                max_output_tokens=max_output_tokens,
+                bbox_scale=bbox_scale,
+                vllm_api_base=vllm_api_base,
+                **kwargs,
             )
         else:
             results = generate_hf(
-                batch, self.model, max_output_tokens=max_output_tokens, **kwargs
+                batch,
+                self.model,
+                max_output_tokens=max_output_tokens,
+                bbox_scale=bbox_scale,
+                **kwargs,
             )

         output = []
         for result, input_item in zip(results, batch):
-            chunks = parse_chunks(result.raw, input_item.image)
+            chunks = parse_chunks(result.raw, input_item.image, bbox_scale=bbox_scale)
             output.append(
                 BatchOutputItem(
                     markdown=parse_markdown(result.raw, **output_kwargs),
@@ -48,6 +59,7 @@ class InferenceManager:
                     page_box=[0, 0, input_item.image.width, input_item.image.height],
                     token_count=result.token_count,
                     images=extract_images(result.raw, chunks, input_item.image),
+                    error=result.error,
                 )
             )
         return output
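The kwargs.pop calls above let a caller override bbox_scale and vllm_api_base per request while stripping the keys before **kwargs is forwarded. A minimal sketch of the pattern, with generic names rather than the project's API:

# Pop-with-default pattern; names here are illustrative only.
DEFAULT_BBOX_SCALE = 1024

def generate(batch, **kwargs):
    # Consume the override so the remaining **kwargs can be forwarded safely.
    bbox_scale = kwargs.pop("bbox_scale", DEFAULT_BBOX_SCALE)
    return bbox_scale, kwargs

print(generate([], bbox_scale=512, temperature=0.3))  # (512, {'temperature': 0.3})
print(generate([]))                                   # (1024, {})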

View File

@@ -1,8 +1,5 @@
 from typing import List

-from qwen_vl_utils import process_vision_info
-from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
-
 from chandra.model.schema import BatchInputItem, GenerationResult
 from chandra.model.util import scale_to_fit
 from chandra.prompts import PROMPT_MAPPING
@@ -10,12 +7,20 @@ from chandra.settings import settings


 def generate_hf(
-    batch: List[BatchInputItem], model, max_output_tokens=None, **kwargs
+    batch: List[BatchInputItem],
+    model,
+    max_output_tokens=None,
+    bbox_scale: int = settings.BBOX_SCALE,
+    **kwargs,
 ) -> List[GenerationResult]:
+    from qwen_vl_utils import process_vision_info
+
     if max_output_tokens is None:
         max_output_tokens = settings.MAX_OUTPUT_TOKENS
-    messages = [process_batch_element(item, model.processor) for item in batch]
+    messages = [
+        process_batch_element(item, model.processor, bbox_scale) for item in batch
+    ]
     text = model.processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
@@ -48,12 +53,12 @@ def generate_hf(
     return results


-def process_batch_element(item: BatchInputItem, processor):
+def process_batch_element(item: BatchInputItem, processor, bbox_scale: int):
     prompt = item.prompt
     prompt_type = item.prompt_type
     if not prompt:
-        prompt = PROMPT_MAPPING[prompt_type]
+        prompt = PROMPT_MAPPING[prompt_type].replace("{bbox_scale}", str(bbox_scale))

     content = []
     image = scale_to_fit(item.image)  # Guarantee max size
@@ -65,12 +70,15 @@ def process_batch_element(item: BatchInputItem, processor):


 def load_model():
+    import torch
+    from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
+
     device_map = "auto"
     if settings.TORCH_DEVICE:
         device_map = {"": settings.TORCH_DEVICE}
     kwargs = {
-        "dtype": settings.TORCH_DTYPE,
+        "dtype": torch.bfloat16,
         "device_map": device_map,
     }
     if settings.TORCH_ATTN:
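Moving the torch/transformers/qwen_vl_utils imports inside generate_hf and load_model defers those heavy dependencies until the HF backend is actually used. A generic sketch of the lazy-import pattern (the module and names are stand-ins, not from this repo):

def run_hf_path(data):
    # Imported only on first call, so importing the enclosing module stays
    # cheap for vLLM-only users. numpy stands in for a heavy dep like torch.
    import numpy as np
    return np.asarray(data).sum()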

View File

@@ -27,3 +27,4 @@ class BatchOutputItem:
     page_box: List[int]
     token_count: int
     images: dict
+    error: bool

View File

@@ -44,9 +44,10 @@ def scale_to_fit(

 def detect_repeat_token(
     predicted_tokens: str,
-    max_repeats: int = 4,
+    base_max_repeats: int = 4,
     window_size: int = 500,
     cut_from_end: int = 0,
+    scaling_factor: float = 3.0,
 ):
     try:
         predicted_tokens = parse_markdown(predicted_tokens)
@@ -57,11 +58,13 @@ def detect_repeat_token(
     if cut_from_end > 0:
         predicted_tokens = predicted_tokens[:-cut_from_end]

+    # Try different sequence lengths (1 to window_size//2)
     for seq_len in range(1, window_size // 2 + 1):
         # Extract the potential repeating sequence from the end
         candidate_seq = predicted_tokens[-seq_len:]
+        # Inverse scaling: shorter sequences need more repeats
+        max_repeats = int(base_max_repeats * (1 + scaling_factor / seq_len))

         # Count how many times this sequence appears consecutively at the end
         repeat_count = 0
         pos = len(predicted_tokens) - seq_len
@@ -75,7 +78,6 @@ def detect_repeat_token(
             else:
                 break

-        # If we found more than max_repeats consecutive occurrences
         if repeat_count > max_repeats:
             return True
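With the inverse scaling above, short candidate sequences must repeat far more often than long ones before detect_repeat_token fires. The thresholds at the new defaults (base_max_repeats=4, scaling_factor=3.0):

# Thresholds implied by max_repeats = int(base_max_repeats * (1 + scaling_factor / seq_len)).
base_max_repeats, scaling_factor = 4, 3.0
for seq_len in (1, 2, 3, 10, 100):
    print(seq_len, int(base_max_repeats * (1 + scaling_factor / seq_len)))
# 1 -> 16, 2 -> 10, 3 -> 8, 10 -> 5, 100 -> 4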

View File

@@ -1,5 +1,6 @@
 import base64
 import io
+import time
 from concurrent.futures import ThreadPoolExecutor
 from itertools import repeat
 from typing import List
@@ -25,10 +26,15 @@ def generate_vllm(
     max_output_tokens: int = None,
     max_retries: int = None,
     max_workers: int | None = None,
+    custom_headers: dict | None = None,
+    max_failure_retries: int | None = None,
+    bbox_scale: int = settings.BBOX_SCALE,
+    vllm_api_base: str = settings.VLLM_API_BASE,
 ) -> List[GenerationResult]:
     client = OpenAI(
         api_key=settings.VLLM_API_KEY,
-        base_url=settings.VLLM_API_BASE,
+        base_url=vllm_api_base,
+        default_headers=custom_headers,
     )
     model_name = settings.VLLM_MODEL_NAME
@@ -50,7 +56,9 @@ def generate_vllm(
     ) -> GenerationResult:
         prompt = item.prompt
         if not prompt:
-            prompt = PROMPT_MAPPING[item.prompt_type]
+            prompt = PROMPT_MAPPING[item.prompt_type].replace(
+                "{bbox_scale}", str(bbox_scale)
+            )

         content = []
         image = scale_to_fit(item.image)
@@ -68,41 +76,68 @@ def generate_vllm(
             completion = client.chat.completions.create(
                 model=model_name,
                 messages=[{"role": "user", "content": content}],
-                max_tokens=settings.MAX_OUTPUT_TOKENS,
+                max_tokens=max_output_tokens,
                 temperature=temperature,
                 top_p=top_p,
             )
+            raw = completion.choices[0].message.content
+            result = GenerationResult(
+                raw=raw,
+                token_count=completion.usage.completion_tokens,
+                error=False,
+            )
         except Exception as e:
             print(f"Error during VLLM generation: {e}")
             return GenerationResult(raw="", token_count=0, error=True)
-        return GenerationResult(
-            raw=completion.choices[0].message.content,
-            token_count=completion.usage.completion_tokens,
-            error=False,
-        )
+        return result

-    def process_item(item, max_retries):
+    def process_item(item, max_retries, max_failure_retries=None):
         result = _generate(item)
         retries = 0
-        while retries < max_retries and (
-            detect_repeat_token(result.raw)
-            or (
-                len(result.raw) > 50
-                and detect_repeat_token(result.raw, cut_from_end=50)
-            )
-            or result.error
-        ):
-            print(
-                f"Detected repeat token or error, retrying generation (attempt {retries + 1})..."
-            )
+        while _should_retry(result, retries, max_retries, max_failure_retries):
             result = _generate(item, temperature=0.3, top_p=0.95)
             retries += 1
         return result

+    def _should_retry(result, retries, max_retries, max_failure_retries):
+        has_repeat = detect_repeat_token(result.raw) or (
+            len(result.raw) > 50 and detect_repeat_token(result.raw, cut_from_end=50)
+        )
+        if retries < max_retries and has_repeat:
+            print(
+                f"Detected repeat token, retrying generation (attempt {retries + 1})..."
+            )
+            return True
+        if retries < max_retries and result.error:
+            print(
+                f"Detected vllm error, retrying generation (attempt {retries + 1})..."
+            )
+            time.sleep(2 * (retries + 1))  # Sleeping can help under load
+            return True
+        if (
+            result.error
+            and max_failure_retries is not None
+            and retries < max_failure_retries
+        ):
+            print(
+                f"Detected vllm error, retrying generation (attempt {retries + 1})..."
+            )
+            time.sleep(2 * (retries + 1))  # Sleeping can help under load
+            return True
+        return False
+
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        results = list(executor.map(process_item, batch, repeat(max_retries)))
+        results = list(
+            executor.map(
+                process_item, batch, repeat(max_retries), repeat(max_failure_retries)
+            )
+        )
     return results
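The error path now sleeps 2 * (retries + 1) seconds before each retry, a linear backoff intended to ease pressure on a loaded vLLM server. The schedule it produces:

# Linear backoff delays for the first six retries (cf. MAX_VLLM_RETRIES = 6).
print([2 * (retries + 1) for retries in range(6)])  # [2, 4, 6, 8, 10, 12] seconds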

View File

@@ -6,9 +6,11 @@ from functools import lru_cache
 import six
 from PIL import Image
-from bs4 import BeautifulSoup, NavigableString
+from bs4 import BeautifulSoup
 from markdownify import MarkdownConverter, re_whitespace

+from chandra.settings import settings
+

 @lru_cache
 def _hash_html(html: str):
@@ -30,7 +32,11 @@ def extract_images(html: str, chunks: dict, image: Image.Image):
         if not img:
             continue
         bbox = chunk["bbox"]
-        block_image = image.crop(bbox)
+        try:
+            block_image = image.crop(bbox)
+        except ValueError:
+            # Happens when bbox coordinates are invalid
+            continue
         img_name = get_image_name(html, div_idx)
         images[img_name] = block_image
     return images
@@ -67,44 +73,22 @@ def parse_html(
         else:
             img = BeautifulSoup(f"<img src='{img_src}'/>", "html.parser")
             div.append(img)

+        # Wrap text content in <p> tags if no inner HTML tags exist
+        if label in ["Text"] and not re.search(
+            "<.+>", str(div.decode_contents()).strip()
+        ):
+            # Add inner p tags if missing for text blocks
+            text_content = str(div.decode_contents()).strip()
+            text_content = f"<p>{text_content}</p>"
+            div.clear()
+            div.append(BeautifulSoup(text_content, "html.parser"))
+
         content = str(div.decode_contents())
         out_html += content
     return out_html


-def escape_dollars(text):
-    return text.replace("$", r"\$")
-
-
-def get_formatted_table_text(element):
-    text = []
-    for content in element.contents:
-        if content is None:
-            continue
-        if isinstance(content, NavigableString):
-            stripped = content.strip()
-            if stripped:
-                text.append(escape_dollars(stripped))
-        elif content.name == "br":
-            text.append("<br>")
-        elif content.name == "math":
-            text.append("$" + content.text + "$")
-        else:
-            content_str = escape_dollars(str(content))
-            text.append(content_str)
-
-    full_text = ""
-    for i, t in enumerate(text):
-        if t == "<br>":
-            full_text += t
-        elif i > 0 and text[i - 1] != "<br>":
-            full_text += " " + t
-        else:
-            full_text += t
-    return full_text
-
-
 class Markdownify(MarkdownConverter):
     def __init__(
         self,
@@ -204,19 +188,25 @@ class LayoutBlock:
     content: str


-def parse_layout(html: str, image: Image.Image):
+def parse_layout(html: str, image: Image.Image, bbox_scale=settings.BBOX_SCALE):
     soup = BeautifulSoup(html, "html.parser")
     top_level_divs = soup.find_all("div", recursive=False)
     width, height = image.size
-    width_scaler = width / 1024
-    height_scaler = height / 1024
+    width_scaler = width / bbox_scale
+    height_scaler = height / bbox_scale
     layout_blocks = []
     for div in top_level_divs:
         bbox = div.get("data-bbox")
         try:
             bbox = json.loads(bbox)
+            assert len(bbox) == 4, "Invalid bbox length"
         except Exception:
-            bbox = [0, 0, 1, 1]  # Fallback to a default bbox if parsing fails
+            try:
+                bbox = bbox.split(" ")
+                assert len(bbox) == 4, "Invalid bbox length"
+            except Exception:
+                bbox = [0, 0, 1, 1]
         bbox = list(map(int, bbox))

         # Normalize bbox
@@ -232,7 +222,7 @@ def parse_layout(html: str, image: Image.Image):
     return layout_blocks


-def parse_chunks(html: str, image: Image.Image):
-    layout = parse_layout(html, image)
+def parse_chunks(html: str, image: Image.Image, bbox_scale=settings.BBOX_SCALE):
+    layout = parse_layout(html, image, bbox_scale=bbox_scale)
     chunks = [asdict(block) for block in layout]
     return chunks
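parse_layout maps model bboxes, normalized to 0..bbox_scale, into pixel space by multiplying by width / bbox_scale and height / bbox_scale. A worked example with hypothetical page dimensions:

# Denormalizing a model bbox as parse_layout does above; numbers are hypothetical.
bbox_scale = 1024
width, height = 1700, 2200        # rendered page size in pixels
bbox = [256, 128, 768, 896]       # model output, normalized [x0, y0, x1, y1]
width_scaler, height_scaler = width / bbox_scale, height / bbox_scale
print([int(bbox[0] * width_scaler), int(bbox[1] * height_scaler),
       int(bbox[2] * width_scaler), int(bbox[3] * height_scaler)])
# [425, 275, 1275, 1925]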

View File

@@ -65,7 +65,7 @@ Guidelines:
 """.strip()

 OCR_LAYOUT_PROMPT = f"""
-OCR this image to HTML, arranged as layout blocks. Each layout block should be a div with the data-bbox attribute representing the bounding box of the block in [x0, y0, x1, y1] format. Bboxes are normalized 0-1024. The data-label attribute is the label for the block.
+OCR this image to HTML, arranged as layout blocks. Each layout block should be a div with the data-bbox attribute representing the bounding box of the block in [x0, y0, x1, y1] format. Bboxes are normalized 0-{{bbox_scale}}. The data-label attribute is the label for the block.

 Use the following labels:
 - Caption

View File

@@ -143,6 +143,7 @@ def process():
                 "image_height": img_height,
                 "blocks": blocks_data,
                 "html": html_with_images,
+                "markdown": result.markdown,
             }
         )

View File

@@ -64,6 +64,20 @@
             cursor: not-allowed;
         }

+        .controls label {
+            display: flex;
+            align-items: center;
+            gap: 8px;
+            color: white;
+            font-size: 14px;
+            cursor: pointer;
+            user-select: none;
+        }
+
+        .controls input[type="checkbox"] {
+            cursor: pointer;
+        }
+
         .loading {
             display: none;
             color: #f39c12;
@@ -75,6 +89,11 @@
             font-weight: bold;
         }

+        .success {
+            color: #27ae60;
+            font-weight: bold;
+        }
+
         .screenshot-container {
             display: none;
             margin-top: 60px;
@@ -88,8 +107,18 @@
             display: flex;
         }

-        .left-panel, .right-panel {
-            flex: 1;
+        .left-panel {
+            flex: 0 0 40%;
+            display: flex;
+            flex-direction: column;
+            background: white;
+            border-radius: 8px;
+            overflow: hidden;
+            box-shadow: 0 4px 12px rgba(0,0,0,0.3);
+        }
+
+        .right-panel {
+            flex: 0 0 60%;
             display: flex;
             flex-direction: column;
             background: white;
@@ -137,6 +166,7 @@
             padding: 30px;
             line-height: 1.6;
             color: #333;
+            font-size: 24px;
         }

         .markdown-content h1, .markdown-content h2, .markdown-content h3 {
@@ -215,8 +245,14 @@
         <input type="text" id="filePath" placeholder="Enter file path (e.g., /path/to/document.pdf)">
         <input type="number" id="pageNumber" placeholder="Page" value="0" min="0">
         <button id="processBtn" onclick="processFile()">Process</button>
+        <label>
+            <input type="checkbox" id="showLayoutBoxes" checked onchange="toggleLayoutBoxes()">
+            Show Layout Boxes
+        </label>
+        <button id="copyMarkdownBtn" onclick="copyMarkdown()" style="display: none;">Copy Markdown</button>
         <span class="loading" id="loading">Processing...</span>
         <span class="error" id="error"></span>
+        <span class="success" id="success"></span>
     </div>

     <div class="screenshot-container" id="container">
@@ -242,6 +278,11 @@
     <script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
     <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/github-markdown-css/5.8.1/github-markdown.min.css" integrity="sha512-BrOPA520KmDMqieeM7XFe6a3u3Sb3F1JBaQnrIAmWg3EYrciJ+Qqe6ZcKCdfPv26rGcgTrJnZ/IdQEct8h3Zhw==" crossorigin="anonymous" referrerpolicy="no-referrer" />
     <script>
+        // Global state to store markdown and canvas data
+        let currentMarkdown = null;
+        let currentData = null;
+        let currentImageSrc = null;
+
         async function processFile() {
             const filePath = document.getElementById('filePath').value;
             const pageNumber = parseInt(document.getElementById('pageNumber').value) || 0;
@@ -285,6 +326,10 @@
         }

         function renderResults(data) {
+            // Store data for toggle functionality
+            currentData = data;
+            currentImageSrc = data.image_base64;
+
             const canvas = document.getElementById('layoutCanvas');
             const ctx = canvas.getContext('2d');
             const markdownContent = document.getElementById('markdownContent');
@@ -292,51 +337,14 @@
             // Draw image with layout overlays
             const img = new Image();
             img.onload = function() {
-                canvas.width = data.image_width;
-                canvas.height = data.image_height;
-
-                // Draw image
-                ctx.drawImage(img, 0, 0, data.image_width, data.image_height);
-
-                // Draw layout blocks
-                ctx.lineWidth = 3;
-                ctx.font = 'bold 14px -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif';
-
-                const labelCounts = {};
-                data.blocks.forEach((block) => {
-                    const [x1, y1, x2, y2] = block.bbox;
-                    const width = x2 - x1;
-                    const height = y2 - y1;
-
-                    // Draw rectangle with semi-transparent fill
-                    ctx.strokeStyle = block.color;
-                    ctx.fillStyle = block.color + '33';
-                    ctx.fillRect(x1, y1, width, height);
-                    ctx.strokeRect(x1, y1, width, height);
-
-                    // Count labels for unique identification
-                    labelCounts[block.label] = (labelCounts[block.label] || 0) + 1;
-                    const labelWithCount = `${block.label} #${labelCounts[block.label]}`;
-
-                    // Draw label with background
-                    const textMetrics = ctx.measureText(labelWithCount);
-                    const textWidth = textMetrics.width;
-                    const textHeight = 16;
-                    const padding = 6;
-                    const labelX = x1;
-                    const labelY = Math.max(y1 - textHeight - padding, textHeight);
-
-                    ctx.fillStyle = block.color;
-                    ctx.fillRect(labelX, labelY - textHeight, textWidth + padding * 2, textHeight + padding);
-
-                    ctx.fillStyle = 'white';
-                    ctx.textBaseline = 'top';
-                    ctx.fillText(labelWithCount, labelX + padding, labelY - textHeight + padding/2);
-                });
+                drawCanvas(img, data, ctx);
             };
             img.src = data.image_base64;

+            // Store markdown and show copy button
+            currentMarkdown = data.markdown;
+            document.getElementById('copyMarkdownBtn').style.display = 'inline-block';
+
             // Render HTML directly (with images embedded)
             markdownContent.innerHTML = data.html;
@@ -362,6 +370,85 @@
             });
         }

+        function drawCanvas(img, data, ctx) {
+            const canvas = document.getElementById('layoutCanvas');
+            canvas.width = data.image_width;
+            canvas.height = data.image_height;
+
+            // Draw image
+            ctx.drawImage(img, 0, 0, data.image_width, data.image_height);
+
+            // Check if layout boxes should be shown
+            const showBoxes = document.getElementById('showLayoutBoxes').checked;
+            if (!showBoxes) return;
+
+            // Draw layout blocks
+            ctx.lineWidth = 3;
+            ctx.font = 'bold 14px -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif';
+
+            const labelCounts = {};
+            data.blocks.forEach((block) => {
+                const [x1, y1, x2, y2] = block.bbox;
+                const width = x2 - x1;
+                const height = y2 - y1;
+
+                // Draw rectangle with semi-transparent fill
+                ctx.strokeStyle = block.color;
+                ctx.fillStyle = block.color + '33';
+                ctx.fillRect(x1, y1, width, height);
+                ctx.strokeRect(x1, y1, width, height);
+
+                // Count labels for unique identification
+                labelCounts[block.label] = (labelCounts[block.label] || 0) + 1;
+                const labelWithCount = `${block.label} #${labelCounts[block.label]}`;
+
+                // Draw label with background
+                const textMetrics = ctx.measureText(labelWithCount);
+                const textWidth = textMetrics.width;
+                const textHeight = 16;
+                const padding = 6;
+                const labelX = x1;
+                const labelY = Math.max(y1 - textHeight - padding, textHeight);
+
+                ctx.fillStyle = block.color;
+                ctx.fillRect(labelX, labelY - textHeight, textWidth + padding * 2, textHeight + padding);
+
+                ctx.fillStyle = 'white';
+                ctx.textBaseline = 'top';
+                ctx.fillText(labelWithCount, labelX + padding, labelY - textHeight + padding/2);
+            });
+        }
+
+        function toggleLayoutBoxes() {
+            if (!currentData || !currentImageSrc) return;
+            const canvas = document.getElementById('layoutCanvas');
+            const ctx = canvas.getContext('2d');
+            const img = new Image();
+            img.onload = function() {
+                drawCanvas(img, currentData, ctx);
+            };
+            img.src = currentImageSrc;
+        }
+
+        function copyMarkdown() {
+            if (!currentMarkdown) {
+                document.getElementById('error').textContent = 'No markdown to copy';
+                return;
+            }
+            navigator.clipboard.writeText(currentMarkdown).then(() => {
+                const success = document.getElementById('success');
+                success.textContent = 'Markdown copied!';
+                setTimeout(() => {
+                    success.textContent = '';
+                }, 2000);
+            }).catch((err) => {
+                document.getElementById('error').textContent = 'Failed to copy: ' + err.message;
+            });
+        }
+
         // Allow Enter key to trigger processing
         document.getElementById('filePath').addEventListener('keypress', function(e) {
             if (e.key === 'Enter') processFile();

View File

@@ -1,7 +1,5 @@
 from dotenv import find_dotenv
-from pydantic import computed_field
 from pydantic_settings import BaseSettings
-import torch
 import os
@@ -9,11 +7,13 @@ class Settings(BaseSettings):
     # Paths
     BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
     IMAGE_DPI: int = 192
-    MIN_IMAGE_DIM: int = 1024
+    MIN_PDF_IMAGE_DIM: int = 1024
+    MIN_IMAGE_DIM: int = 1536
     MODEL_CHECKPOINT: str = "datalab-to/chandra"
     TORCH_DEVICE: str | None = None
-    MAX_OUTPUT_TOKENS: int = 8192
+    MAX_OUTPUT_TOKENS: int = 12384
     TORCH_ATTN: str | None = None
+    BBOX_SCALE: int = 1024

     # vLLM server settings
     VLLM_API_KEY: str = "EMPTY"
@@ -22,11 +22,6 @@ class Settings(BaseSettings):
     VLLM_GPUS: str = "0"
     MAX_VLLM_RETRIES: int = 6

-    @computed_field
-    @property
-    def TORCH_DTYPE(self) -> torch.dtype:
-        return torch.bfloat16
-
     class Config:
         env_file = find_dotenv("local.env")
         extra = "ignore"

View File

@@ -1,6 +1,6 @@
 [project]
 name = "chandra-ocr"
-version = "0.1.7"
+version = "0.1.9"
 description = "OCR model that converts documents to markdown, HTML, or JSON."
 readme = "README.md"
 requires-python = ">=3.10"