mirror of
https://github.com/datalab-to/chandra.git
synced 2025-11-29 08:33:13 +00:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d1cde9b608 | ||
|
|
aabfed2ed3 | ||
|
|
4b01146865 | ||
|
|
7cf96f3911 | ||
|
|
607205211a | ||
|
|
358358134e | ||
|
|
2d2d7ab331 | ||
|
|
528b58c16f | ||
|
|
5acfd8dc6a |
@@ -13,6 +13,15 @@ def flatten(page, flag=pdfium_c.FLAT_NORMALDISPLAY):
|
||||
print(f"Failed to flatten annotations / form fields on page {page}.")
|
||||
|
||||
|
||||
def load_image(filepath: str):
|
||||
image = Image.open(filepath).convert("RGB")
|
||||
if image.width < settings.MIN_IMAGE_DIM or image.height < settings.MIN_IMAGE_DIM:
|
||||
scale = settings.MIN_IMAGE_DIM / min(image.width, image.height)
|
||||
new_size = (int(image.width * scale), int(image.height * scale))
|
||||
image = image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
return image
|
||||
|
||||
|
||||
def load_pdf_images(filepath: str, page_range: List[int]):
|
||||
doc = pdfium.PdfDocument(filepath)
|
||||
doc.init_forms()
|
||||
@@ -22,7 +31,7 @@ def load_pdf_images(filepath: str, page_range: List[int]):
|
||||
if not page_range or page in page_range:
|
||||
page_obj = doc[page]
|
||||
min_page_dim = min(page_obj.get_width(), page_obj.get_height())
|
||||
scale_dpi = (settings.MIN_IMAGE_DIM / min_page_dim) * 72
|
||||
scale_dpi = (settings.MIN_PDF_IMAGE_DIM / min_page_dim) * 72
|
||||
scale_dpi = max(scale_dpi, settings.IMAGE_DPI)
|
||||
page_obj = doc[page]
|
||||
flatten(page_obj)
|
||||
@@ -56,5 +65,5 @@ def load_file(filepath: str, config: dict):
|
||||
if input_type and input_type.extension == "pdf":
|
||||
images = load_pdf_images(filepath, page_range)
|
||||
else:
|
||||
images = [Image.open(filepath).convert("RGB")]
|
||||
images = [load_image(filepath)]
|
||||
return images
|
||||
|
||||
@@ -48,6 +48,7 @@ class InferenceManager:
|
||||
page_box=[0, 0, input_item.image.width, input_item.image.height],
|
||||
token_count=result.token_count,
|
||||
images=extract_images(result.raw, chunks, input_item.image),
|
||||
error=result.error,
|
||||
)
|
||||
)
|
||||
return output
|
||||
|
||||
@@ -5,6 +5,7 @@ from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
|
||||
|
||||
from chandra.model.schema import BatchInputItem, GenerationResult
|
||||
from chandra.model.util import scale_to_fit
|
||||
from chandra.output import fix_raw
|
||||
from chandra.prompts import PROMPT_MAPPING
|
||||
from chandra.settings import settings
|
||||
|
||||
@@ -42,7 +43,7 @@ def generate_hf(
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
results = [
|
||||
GenerationResult(raw=out, token_count=len(ids), error=False)
|
||||
GenerationResult(raw=fix_raw(out), token_count=len(ids), error=False)
|
||||
for out, ids in zip(output_text, generated_ids_trimmed)
|
||||
]
|
||||
return results
|
||||
|
||||
@@ -27,3 +27,4 @@ class BatchOutputItem:
|
||||
page_box: List[int]
|
||||
token_count: int
|
||||
images: dict
|
||||
error: bool
|
||||
|
||||
@@ -44,9 +44,10 @@ def scale_to_fit(
|
||||
|
||||
def detect_repeat_token(
|
||||
predicted_tokens: str,
|
||||
max_repeats: int = 4,
|
||||
base_max_repeats: int = 4,
|
||||
window_size: int = 500,
|
||||
cut_from_end: int = 0,
|
||||
scaling_factor: float = 3.0,
|
||||
):
|
||||
try:
|
||||
predicted_tokens = parse_markdown(predicted_tokens)
|
||||
@@ -57,11 +58,13 @@ def detect_repeat_token(
|
||||
if cut_from_end > 0:
|
||||
predicted_tokens = predicted_tokens[:-cut_from_end]
|
||||
|
||||
# Try different sequence lengths (1 to window_size//2)
|
||||
for seq_len in range(1, window_size // 2 + 1):
|
||||
# Extract the potential repeating sequence from the end
|
||||
candidate_seq = predicted_tokens[-seq_len:]
|
||||
|
||||
# Inverse scaling: shorter sequences need more repeats
|
||||
max_repeats = int(base_max_repeats * (1 + scaling_factor / seq_len))
|
||||
|
||||
# Count how many times this sequence appears consecutively at the end
|
||||
repeat_count = 0
|
||||
pos = len(predicted_tokens) - seq_len
|
||||
@@ -75,7 +78,6 @@ def detect_repeat_token(
|
||||
else:
|
||||
break
|
||||
|
||||
# If we found more than max_repeats consecutive occurrences
|
||||
if repeat_count > max_repeats:
|
||||
return True
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from openai import OpenAI
|
||||
|
||||
from chandra.model.schema import BatchInputItem, GenerationResult
|
||||
from chandra.model.util import scale_to_fit, detect_repeat_token
|
||||
from chandra.output import fix_raw
|
||||
from chandra.prompts import PROMPT_MAPPING
|
||||
from chandra.settings import settings
|
||||
|
||||
@@ -25,10 +26,12 @@ def generate_vllm(
|
||||
max_output_tokens: int = None,
|
||||
max_retries: int = None,
|
||||
max_workers: int | None = None,
|
||||
custom_headers: dict | None = None,
|
||||
) -> List[GenerationResult]:
|
||||
client = OpenAI(
|
||||
api_key=settings.VLLM_API_KEY,
|
||||
base_url=settings.VLLM_API_BASE,
|
||||
default_headers=custom_headers,
|
||||
)
|
||||
model_name = settings.VLLM_MODEL_NAME
|
||||
|
||||
@@ -68,19 +71,22 @@ def generate_vllm(
|
||||
completion = client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=[{"role": "user", "content": content}],
|
||||
max_tokens=settings.MAX_OUTPUT_TOKENS,
|
||||
max_tokens=max_output_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
)
|
||||
raw = completion.choices[0].message.content
|
||||
raw = fix_raw(raw)
|
||||
result = GenerationResult(
|
||||
raw=raw,
|
||||
token_count=completion.usage.completion_tokens,
|
||||
error=False,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error during VLLM generation: {e}")
|
||||
return GenerationResult(raw="", token_count=0, error=True)
|
||||
|
||||
return GenerationResult(
|
||||
raw=completion.choices[0].message.content,
|
||||
token_count=completion.usage.completion_tokens,
|
||||
error=False,
|
||||
)
|
||||
return result
|
||||
|
||||
def process_item(item, max_retries):
|
||||
result = _generate(item)
|
||||
|
||||
@@ -20,6 +20,15 @@ def get_image_name(html: str, div_idx: int):
|
||||
return f"{html_hash}_{div_idx}_img.webp"
|
||||
|
||||
|
||||
def fix_raw(html: str):
|
||||
def replace_group(match):
|
||||
numbers = re.findall(r"\d+", match.group(0))
|
||||
return "[" + ",".join(numbers) + "]"
|
||||
|
||||
result = re.sub(r"(?:<BBOX\d+>){4}", replace_group, html)
|
||||
return result
|
||||
|
||||
|
||||
def extract_images(html: str, chunks: dict, image: Image.Image):
|
||||
images = {}
|
||||
div_idx = 0
|
||||
@@ -30,7 +39,11 @@ def extract_images(html: str, chunks: dict, image: Image.Image):
|
||||
if not img:
|
||||
continue
|
||||
bbox = chunk["bbox"]
|
||||
block_image = image.crop(bbox)
|
||||
try:
|
||||
block_image = image.crop(bbox)
|
||||
except ValueError:
|
||||
# Happens when bbox coordinates are invalid
|
||||
continue
|
||||
img_name = get_image_name(html, div_idx)
|
||||
images[img_name] = block_image
|
||||
return images
|
||||
@@ -67,6 +80,17 @@ def parse_html(
|
||||
else:
|
||||
img = BeautifulSoup(f"<img src='{img_src}'/>", "html.parser")
|
||||
div.append(img)
|
||||
|
||||
# Wrap text content in <p> tags if no inner HTML tags exist
|
||||
if label in ["Text"] and not re.search(
|
||||
"<.+>", str(div.decode_contents()).strip()
|
||||
):
|
||||
# Add inner p tags if missing for text blocks
|
||||
text_content = str(div.decode_contents()).strip()
|
||||
text_content = f"<p>{text_content}</p>"
|
||||
div.clear()
|
||||
div.append(BeautifulSoup(text_content, "html.parser"))
|
||||
|
||||
content = str(div.decode_contents())
|
||||
out_html += content
|
||||
return out_html
|
||||
@@ -213,10 +237,11 @@ def parse_layout(html: str, image: Image.Image):
|
||||
layout_blocks = []
|
||||
for div in top_level_divs:
|
||||
bbox = div.get("data-bbox")
|
||||
|
||||
try:
|
||||
bbox = json.loads(bbox)
|
||||
except Exception:
|
||||
bbox = [0, 0, 1, 1] # Fallback to a default bbox if parsing fails
|
||||
bbox = [0, 0, 1, 1]
|
||||
|
||||
bbox = list(map(int, bbox))
|
||||
# Normalize bbox
|
||||
|
||||
@@ -143,6 +143,7 @@ def process():
|
||||
"image_height": img_height,
|
||||
"blocks": blocks_data,
|
||||
"html": html_with_images,
|
||||
"markdown": result.markdown,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -64,6 +64,20 @@
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.controls label {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
color: white;
|
||||
font-size: 14px;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
.controls input[type="checkbox"] {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.loading {
|
||||
display: none;
|
||||
color: #f39c12;
|
||||
@@ -75,6 +89,11 @@
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.success {
|
||||
color: #27ae60;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.screenshot-container {
|
||||
display: none;
|
||||
margin-top: 60px;
|
||||
@@ -88,8 +107,18 @@
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.left-panel, .right-panel {
|
||||
flex: 1;
|
||||
.left-panel {
|
||||
flex: 0 0 40%;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
background: white;
|
||||
border-radius: 8px;
|
||||
overflow: hidden;
|
||||
box-shadow: 0 4px 12px rgba(0,0,0,0.3);
|
||||
}
|
||||
|
||||
.right-panel {
|
||||
flex: 0 0 60%;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
background: white;
|
||||
@@ -137,6 +166,7 @@
|
||||
padding: 30px;
|
||||
line-height: 1.6;
|
||||
color: #333;
|
||||
font-size: 24px;
|
||||
}
|
||||
|
||||
.markdown-content h1, .markdown-content h2, .markdown-content h3 {
|
||||
@@ -215,8 +245,14 @@
|
||||
<input type="text" id="filePath" placeholder="Enter file path (e.g., /path/to/document.pdf)">
|
||||
<input type="number" id="pageNumber" placeholder="Page" value="0" min="0">
|
||||
<button id="processBtn" onclick="processFile()">Process</button>
|
||||
<label>
|
||||
<input type="checkbox" id="showLayoutBoxes" checked onchange="toggleLayoutBoxes()">
|
||||
Show Layout Boxes
|
||||
</label>
|
||||
<button id="copyMarkdownBtn" onclick="copyMarkdown()" style="display: none;">Copy Markdown</button>
|
||||
<span class="loading" id="loading">Processing...</span>
|
||||
<span class="error" id="error"></span>
|
||||
<span class="success" id="success"></span>
|
||||
</div>
|
||||
|
||||
<div class="screenshot-container" id="container">
|
||||
@@ -242,6 +278,11 @@
|
||||
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/github-markdown-css/5.8.1/github-markdown.min.css" integrity="sha512-BrOPA520KmDMqieeM7XFe6a3u3Sb3F1JBaQnrIAmWg3EYrciJ+Qqe6ZcKCdfPv26rGcgTrJnZ/IdQEct8h3Zhw==" crossorigin="anonymous" referrerpolicy="no-referrer" />
|
||||
<script>
|
||||
// Global state to store markdown and canvas data
|
||||
let currentMarkdown = null;
|
||||
let currentData = null;
|
||||
let currentImageSrc = null;
|
||||
|
||||
async function processFile() {
|
||||
const filePath = document.getElementById('filePath').value;
|
||||
const pageNumber = parseInt(document.getElementById('pageNumber').value) || 0;
|
||||
@@ -285,6 +326,10 @@
|
||||
}
|
||||
|
||||
function renderResults(data) {
|
||||
// Store data for toggle functionality
|
||||
currentData = data;
|
||||
currentImageSrc = data.image_base64;
|
||||
|
||||
const canvas = document.getElementById('layoutCanvas');
|
||||
const ctx = canvas.getContext('2d');
|
||||
const markdownContent = document.getElementById('markdownContent');
|
||||
@@ -292,51 +337,14 @@
|
||||
// Draw image with layout overlays
|
||||
const img = new Image();
|
||||
img.onload = function() {
|
||||
canvas.width = data.image_width;
|
||||
canvas.height = data.image_height;
|
||||
|
||||
// Draw image
|
||||
ctx.drawImage(img, 0, 0, data.image_width, data.image_height);
|
||||
|
||||
// Draw layout blocks
|
||||
ctx.lineWidth = 3;
|
||||
ctx.font = 'bold 14px -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif';
|
||||
|
||||
const labelCounts = {};
|
||||
data.blocks.forEach((block) => {
|
||||
const [x1, y1, x2, y2] = block.bbox;
|
||||
const width = x2 - x1;
|
||||
const height = y2 - y1;
|
||||
|
||||
// Draw rectangle with semi-transparent fill
|
||||
ctx.strokeStyle = block.color;
|
||||
ctx.fillStyle = block.color + '33';
|
||||
ctx.fillRect(x1, y1, width, height);
|
||||
ctx.strokeRect(x1, y1, width, height);
|
||||
|
||||
// Count labels for unique identification
|
||||
labelCounts[block.label] = (labelCounts[block.label] || 0) + 1;
|
||||
const labelWithCount = `${block.label} #${labelCounts[block.label]}`;
|
||||
|
||||
// Draw label with background
|
||||
const textMetrics = ctx.measureText(labelWithCount);
|
||||
const textWidth = textMetrics.width;
|
||||
const textHeight = 16;
|
||||
const padding = 6;
|
||||
|
||||
const labelX = x1;
|
||||
const labelY = Math.max(y1 - textHeight - padding, textHeight);
|
||||
|
||||
ctx.fillStyle = block.color;
|
||||
ctx.fillRect(labelX, labelY - textHeight, textWidth + padding * 2, textHeight + padding);
|
||||
|
||||
ctx.fillStyle = 'white';
|
||||
ctx.textBaseline = 'top';
|
||||
ctx.fillText(labelWithCount, labelX + padding, labelY - textHeight + padding/2);
|
||||
});
|
||||
drawCanvas(img, data, ctx);
|
||||
};
|
||||
img.src = data.image_base64;
|
||||
|
||||
// Store markdown and show copy button
|
||||
currentMarkdown = data.markdown;
|
||||
document.getElementById('copyMarkdownBtn').style.display = 'inline-block';
|
||||
|
||||
// Render HTML directly (with images embedded)
|
||||
markdownContent.innerHTML = data.html;
|
||||
|
||||
@@ -362,6 +370,85 @@
|
||||
});
|
||||
}
|
||||
|
||||
function drawCanvas(img, data, ctx) {
|
||||
const canvas = document.getElementById('layoutCanvas');
|
||||
canvas.width = data.image_width;
|
||||
canvas.height = data.image_height;
|
||||
|
||||
// Draw image
|
||||
ctx.drawImage(img, 0, 0, data.image_width, data.image_height);
|
||||
|
||||
// Check if layout boxes should be shown
|
||||
const showBoxes = document.getElementById('showLayoutBoxes').checked;
|
||||
if (!showBoxes) return;
|
||||
|
||||
// Draw layout blocks
|
||||
ctx.lineWidth = 3;
|
||||
ctx.font = 'bold 14px -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif';
|
||||
|
||||
const labelCounts = {};
|
||||
data.blocks.forEach((block) => {
|
||||
const [x1, y1, x2, y2] = block.bbox;
|
||||
const width = x2 - x1;
|
||||
const height = y2 - y1;
|
||||
|
||||
// Draw rectangle with semi-transparent fill
|
||||
ctx.strokeStyle = block.color;
|
||||
ctx.fillStyle = block.color + '33';
|
||||
ctx.fillRect(x1, y1, width, height);
|
||||
ctx.strokeRect(x1, y1, width, height);
|
||||
|
||||
// Count labels for unique identification
|
||||
labelCounts[block.label] = (labelCounts[block.label] || 0) + 1;
|
||||
const labelWithCount = `${block.label} #${labelCounts[block.label]}`;
|
||||
|
||||
// Draw label with background
|
||||
const textMetrics = ctx.measureText(labelWithCount);
|
||||
const textWidth = textMetrics.width;
|
||||
const textHeight = 16;
|
||||
const padding = 6;
|
||||
|
||||
const labelX = x1;
|
||||
const labelY = Math.max(y1 - textHeight - padding, textHeight);
|
||||
|
||||
ctx.fillStyle = block.color;
|
||||
ctx.fillRect(labelX, labelY - textHeight, textWidth + padding * 2, textHeight + padding);
|
||||
|
||||
ctx.fillStyle = 'white';
|
||||
ctx.textBaseline = 'top';
|
||||
ctx.fillText(labelWithCount, labelX + padding, labelY - textHeight + padding/2);
|
||||
});
|
||||
}
|
||||
|
||||
function toggleLayoutBoxes() {
|
||||
if (!currentData || !currentImageSrc) return;
|
||||
|
||||
const canvas = document.getElementById('layoutCanvas');
|
||||
const ctx = canvas.getContext('2d');
|
||||
const img = new Image();
|
||||
img.onload = function() {
|
||||
drawCanvas(img, currentData, ctx);
|
||||
};
|
||||
img.src = currentImageSrc;
|
||||
}
|
||||
|
||||
function copyMarkdown() {
|
||||
if (!currentMarkdown) {
|
||||
document.getElementById('error').textContent = 'No markdown to copy';
|
||||
return;
|
||||
}
|
||||
|
||||
navigator.clipboard.writeText(currentMarkdown).then(() => {
|
||||
const success = document.getElementById('success');
|
||||
success.textContent = 'Markdown copied!';
|
||||
setTimeout(() => {
|
||||
success.textContent = '';
|
||||
}, 2000);
|
||||
}).catch((err) => {
|
||||
document.getElementById('error').textContent = 'Failed to copy: ' + err.message;
|
||||
});
|
||||
}
|
||||
|
||||
// Allow Enter key to trigger processing
|
||||
document.getElementById('filePath').addEventListener('keypress', function(e) {
|
||||
if (e.key === 'Enter') processFile();
|
||||
|
||||
@@ -9,10 +9,11 @@ class Settings(BaseSettings):
|
||||
# Paths
|
||||
BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
IMAGE_DPI: int = 192
|
||||
MIN_IMAGE_DIM: int = 1024
|
||||
MIN_PDF_IMAGE_DIM: int = 1024
|
||||
MIN_IMAGE_DIM: int = 1536
|
||||
MODEL_CHECKPOINT: str = "datalab-to/chandra"
|
||||
TORCH_DEVICE: str | None = None
|
||||
MAX_OUTPUT_TOKENS: int = 8192
|
||||
MAX_OUTPUT_TOKENS: int = 12384
|
||||
TORCH_ATTN: str | None = None
|
||||
|
||||
# vLLM server settings
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "chandra-ocr"
|
||||
version = "0.1.7"
|
||||
version = "0.1.9"
|
||||
description = "OCR model that converts documents to markdown, HTML, or JSON."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
Reference in New Issue
Block a user