diff --git a/chandra/model/hf.py b/chandra/model/hf.py index b88eb9e..50aa883 100644 --- a/chandra/model/hf.py +++ b/chandra/model/hf.py @@ -5,7 +5,6 @@ from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor from chandra.model.schema import BatchInputItem, GenerationResult from chandra.model.util import scale_to_fit -from chandra.output import fix_raw from chandra.prompts import PROMPT_MAPPING from chandra.settings import settings @@ -43,7 +42,7 @@ def generate_hf( clean_up_tokenization_spaces=False, ) results = [ - GenerationResult(raw=fix_raw(out), token_count=len(ids), error=False) + GenerationResult(raw=out, token_count=len(ids), error=False) for out, ids in zip(output_text, generated_ids_trimmed) ] return results diff --git a/chandra/model/vllm.py b/chandra/model/vllm.py index 5528571..1aabf69 100644 --- a/chandra/model/vllm.py +++ b/chandra/model/vllm.py @@ -9,7 +9,6 @@ from openai import OpenAI from chandra.model.schema import BatchInputItem, GenerationResult from chandra.model.util import scale_to_fit, detect_repeat_token -from chandra.output import fix_raw from chandra.prompts import PROMPT_MAPPING from chandra.settings import settings @@ -76,7 +75,6 @@ def generate_vllm( top_p=top_p, ) raw = completion.choices[0].message.content - raw = fix_raw(raw) result = GenerationResult( raw=raw, token_count=completion.usage.completion_tokens, diff --git a/chandra/output.py b/chandra/output.py index 9afe54a..47e9f98 100644 --- a/chandra/output.py +++ b/chandra/output.py @@ -22,15 +22,6 @@ def get_image_name(html: str, div_idx: int): return f"{html_hash}_{div_idx}_img.webp" -def fix_raw(html: str): - def replace_group(match): - numbers = re.findall(r"\d+", match.group(0)) - return "[" + ",".join(numbers) + "]" - - result = re.sub(r"(?:\|BBOX\d+\|){4}", replace_group, html) - return result - - def extract_images(html: str, chunks: dict, image: Image.Image): images = {} div_idx = 0 @@ -242,8 +233,13 @@ def parse_layout(html: str, image: Image.Image): try: bbox = json.loads(bbox) + assert len(bbox) == 4, "Invalid bbox length" except Exception: - bbox = [0, 0, 1, 1] + try: + bbox = bbox.split(" ") + assert len(bbox) == 4, "Invalid bbox length" + except Exception: + bbox = [0, 0, 1, 1] bbox = list(map(int, bbox)) # Normalize bbox