diff --git a/chandra/model/hf.py b/chandra/model/hf.py index 50aa883..b88eb9e 100644 --- a/chandra/model/hf.py +++ b/chandra/model/hf.py @@ -5,6 +5,7 @@ from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor from chandra.model.schema import BatchInputItem, GenerationResult from chandra.model.util import scale_to_fit +from chandra.output import fix_raw from chandra.prompts import PROMPT_MAPPING from chandra.settings import settings @@ -42,7 +43,7 @@ def generate_hf( clean_up_tokenization_spaces=False, ) results = [ - GenerationResult(raw=out, token_count=len(ids), error=False) + GenerationResult(raw=fix_raw(out), token_count=len(ids), error=False) for out, ids in zip(output_text, generated_ids_trimmed) ] return results diff --git a/chandra/model/util.py b/chandra/model/util.py index d43c1c6..7cef96f 100644 --- a/chandra/model/util.py +++ b/chandra/model/util.py @@ -44,9 +44,10 @@ def scale_to_fit( def detect_repeat_token( predicted_tokens: str, - max_repeats: int = 4, + base_max_repeats: int = 4, window_size: int = 500, cut_from_end: int = 0, + scaling_factor: float = 3.0, ): try: predicted_tokens = parse_markdown(predicted_tokens) @@ -57,11 +58,13 @@ def detect_repeat_token( if cut_from_end > 0: predicted_tokens = predicted_tokens[:-cut_from_end] - # Try different sequence lengths (1 to window_size//2) for seq_len in range(1, window_size // 2 + 1): # Extract the potential repeating sequence from the end candidate_seq = predicted_tokens[-seq_len:] + # Inverse scaling: shorter sequences need more repeats + max_repeats = int(base_max_repeats * (1 + scaling_factor / seq_len)) + # Count how many times this sequence appears consecutively at the end repeat_count = 0 pos = len(predicted_tokens) - seq_len @@ -75,7 +78,6 @@ def detect_repeat_token( else: break - # If we found more than max_repeats consecutive occurrences if repeat_count > max_repeats: return True diff --git a/chandra/model/vllm.py b/chandra/model/vllm.py index 044d9d5..4e36f0e 100644 --- 
a/chandra/model/vllm.py +++ b/chandra/model/vllm.py @@ -9,6 +9,7 @@ from openai import OpenAI from chandra.model.schema import BatchInputItem, GenerationResult from chandra.model.util import scale_to_fit, detect_repeat_token +from chandra.output import fix_raw from chandra.prompts import PROMPT_MAPPING from chandra.settings import settings @@ -74,8 +75,10 @@ def generate_vllm( temperature=temperature, top_p=top_p, ) + raw = completion.choices[0].message.content + raw = fix_raw(raw) result = GenerationResult( - raw=completion.choices[0].message.content, + raw=raw, token_count=completion.usage.completion_tokens, error=False, ) diff --git a/chandra/output.py b/chandra/output.py index 7d4d1c6..b174ef5 100644 --- a/chandra/output.py +++ b/chandra/output.py @@ -20,6 +20,15 @@ def get_image_name(html: str, div_idx: int): return f"{html_hash}_{div_idx}_img.webp" +def fix_raw(html: str): + def replace_group(match): + numbers = re.findall(r"\d+", match.group(0)) + return "[" + ",".join(numbers) + "]" + + result = re.sub(r"(?:<\d+>){4}", replace_group, html) + return result + + def extract_images(html: str, chunks: dict, image: Image.Image): images = {} div_idx = 0 @@ -228,10 +237,11 @@ def parse_layout(html: str, image: Image.Image): layout_blocks = [] for div in top_level_divs: bbox = div.get("data-bbox") + try: bbox = json.loads(bbox) except Exception: - bbox = [0, 0, 1, 1] # Fallback to a default bbox if parsing fails + bbox = [0, 0, 1, 1] bbox = list(map(int, bbox)) # Normalize bbox diff --git a/pyproject.toml b/pyproject.toml index 699d710..6be56bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "chandra-ocr" -version = "0.1.8" +version = "0.1.9" description = "OCR model that converts documents to markdown, HTML, or JSON." readme = "README.md" requires-python = ">=3.10"