Mirror of https://github.com/datalab-to/chandra.git, synced 2026-01-20 05:50:42 +00:00
@@ -5,6 +5,7 @@ from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
 
 from chandra.model.schema import BatchInputItem, GenerationResult
 from chandra.model.util import scale_to_fit
+from chandra.output import fix_raw
 from chandra.prompts import PROMPT_MAPPING
 from chandra.settings import settings
@@ -42,7 +43,7 @@ def generate_hf(
         clean_up_tokenization_spaces=False,
     )
     results = [
-        GenerationResult(raw=out, token_count=len(ids), error=False)
+        GenerationResult(raw=fix_raw(out), token_count=len(ids), error=False)
         for out, ids in zip(output_text, generated_ids_trimmed)
     ]
     return results
@@ -44,9 +44,10 @@ def scale_to_fit(
 
 def detect_repeat_token(
     predicted_tokens: str,
-    max_repeats: int = 4,
+    base_max_repeats: int = 4,
     window_size: int = 500,
     cut_from_end: int = 0,
+    scaling_factor: float = 3.0,
 ):
     try:
         predicted_tokens = parse_markdown(predicted_tokens)
@@ -57,11 +58,13 @@ def detect_repeat_token(
     if cut_from_end > 0:
         predicted_tokens = predicted_tokens[:-cut_from_end]
 
     # Try different sequence lengths (1 to window_size//2)
     for seq_len in range(1, window_size // 2 + 1):
         # Extract the potential repeating sequence from the end
         candidate_seq = predicted_tokens[-seq_len:]
 
+        # Inverse scaling: shorter sequences need more repeats
+        max_repeats = int(base_max_repeats * (1 + scaling_factor / seq_len))
+
         # Count how many times this sequence appears consecutively at the end
         repeat_count = 0
         pos = len(predicted_tokens) - seq_len
@@ -75,7 +78,6 @@
             else:
                 break
 
         # If we found more than max_repeats consecutive occurrences
        if repeat_count > max_repeats:
            return True
-
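Note (not part of the diff): a minimal sketch of how the new inverse scaling behaves under the defaults shown above (base_max_repeats=4, scaling_factor=3.0). Short candidate sequences now need many consecutive repeats before detect_repeat_token flags them, while longer sequences fall back to roughly the base threshold; the seq_len values below are just illustrative picks.

# Illustrative only: effective repeat threshold for a few candidate lengths,
# using the defaults introduced in the diff (base_max_repeats=4, scaling_factor=3.0).
base_max_repeats = 4
scaling_factor = 3.0
for seq_len in (1, 2, 3, 10, 100):
    max_repeats = int(base_max_repeats * (1 + scaling_factor / seq_len))
    print(seq_len, max_repeats)  # 1 -> 16, 2 -> 10, 3 -> 8, 10 -> 5, 100 -> 4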
@@ -9,6 +9,7 @@ from openai import OpenAI
 
 from chandra.model.schema import BatchInputItem, GenerationResult
 from chandra.model.util import scale_to_fit, detect_repeat_token
+from chandra.output import fix_raw
 from chandra.prompts import PROMPT_MAPPING
 from chandra.settings import settings
@@ -74,8 +75,10 @@ def generate_vllm(
         temperature=temperature,
         top_p=top_p,
     )
+    raw = completion.choices[0].message.content
+    raw = fix_raw(raw)
     result = GenerationResult(
-        raw=completion.choices[0].message.content,
+        raw=raw,
         token_count=completion.usage.completion_tokens,
         error=False,
     )
@@ -20,6 +20,15 @@ def get_image_name(html: str, div_idx: int):
     return f"{html_hash}_{div_idx}_img.webp"
 
 
+def fix_raw(html: str):
+    def replace_group(match):
+        numbers = re.findall(r"\d+", match.group(0))
+        return "[" + ",".join(numbers) + "]"
+
+    result = re.sub(r"(?:<BBOX\d+>){4}", replace_group, html)
+    return result
+
+
 def extract_images(html: str, chunks: dict, image: Image.Image):
     images = {}
     div_idx = 0
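Note (not part of the diff): a small usage sketch of the new fix_raw helper. The input string here is hypothetical; the point is that any run of exactly four <BBOXn> tokens in the raw model output is rewritten as a bracketed list of the four numbers, which json.loads in parse_layout can then read from the data-bbox attribute.

import re

def fix_raw(html: str):
    # Same logic as the helper added above: collapse four <BBOXn> tokens into "[n1,n2,n3,n4]".
    def replace_group(match):
        numbers = re.findall(r"\d+", match.group(0))
        return "[" + ",".join(numbers) + "]"
    return re.sub(r"(?:<BBOX\d+>){4}", replace_group, html)

# Hypothetical raw output where the model emitted bbox tokens instead of a list:
raw = '<div data-bbox="<BBOX12><BBOX34><BBOX561><BBOX789>">Example</div>'
print(fix_raw(raw))  # -> <div data-bbox="[12,34,561,789]">Example</div>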
@@ -228,10 +237,11 @@ def parse_layout(html: str, image: Image.Image):
     layout_blocks = []
     for div in top_level_divs:
         bbox = div.get("data-bbox")
+
         try:
             bbox = json.loads(bbox)
         except Exception:
-            bbox = [0, 0, 1, 1]  # Fallback to a default bbox if parsing fails
+            bbox = [0, 0, 1, 1]
 
         bbox = list(map(int, bbox))
         # Normalize bbox
@@ -1,6 +1,6 @@
 [project]
 name = "chandra-ocr"
-version = "0.1.8"
+version = "0.1.9"
 description = "OCR model that converts documents to markdown, HTML, or JSON."
 readme = "README.md"
 requires-python = ">=3.10"