mirror of
https://github.com/datalab-to/chandra.git
synced 2025-11-29 08:33:13 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
22639087e7 | ||
|
|
3958707a80 | ||
|
|
fe28f26fc2 | ||
|
|
4470243560 | ||
|
|
a3889b12fb | ||
|
|
d69d18d6e8 |
@@ -5,7 +5,6 @@ from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
|
||||
|
||||
from chandra.model.schema import BatchInputItem, GenerationResult
|
||||
from chandra.model.util import scale_to_fit
|
||||
from chandra.output import fix_raw
|
||||
from chandra.prompts import PROMPT_MAPPING
|
||||
from chandra.settings import settings
|
||||
|
||||
@@ -43,7 +42,7 @@ def generate_hf(
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
results = [
|
||||
GenerationResult(raw=fix_raw(out), token_count=len(ids), error=False)
|
||||
GenerationResult(raw=out, token_count=len(ids), error=False)
|
||||
for out, ids in zip(output_text, generated_ids_trimmed)
|
||||
]
|
||||
return results
|
||||
|
||||
@@ -9,7 +9,6 @@ from openai import OpenAI
|
||||
|
||||
from chandra.model.schema import BatchInputItem, GenerationResult
|
||||
from chandra.model.util import scale_to_fit, detect_repeat_token
|
||||
from chandra.output import fix_raw
|
||||
from chandra.prompts import PROMPT_MAPPING
|
||||
from chandra.settings import settings
|
||||
|
||||
@@ -27,6 +26,7 @@ def generate_vllm(
|
||||
max_retries: int = None,
|
||||
max_workers: int | None = None,
|
||||
custom_headers: dict | None = None,
|
||||
max_failure_retries: int | None = None,
|
||||
) -> List[GenerationResult]:
|
||||
client = OpenAI(
|
||||
api_key=settings.VLLM_API_KEY,
|
||||
@@ -76,7 +76,6 @@ def generate_vllm(
|
||||
top_p=top_p,
|
||||
)
|
||||
raw = completion.choices[0].message.content
|
||||
raw = fix_raw(raw)
|
||||
result = GenerationResult(
|
||||
raw=raw,
|
||||
token_count=completion.usage.completion_tokens,
|
||||
@@ -88,27 +87,50 @@ def generate_vllm(
|
||||
|
||||
return result
|
||||
|
||||
def process_item(item, max_retries):
|
||||
def process_item(item, max_retries, max_failure_retries=None):
|
||||
result = _generate(item)
|
||||
retries = 0
|
||||
|
||||
while retries < max_retries and (
|
||||
detect_repeat_token(result.raw)
|
||||
or (
|
||||
len(result.raw) > 50
|
||||
and detect_repeat_token(result.raw, cut_from_end=50)
|
||||
)
|
||||
or result.error
|
||||
):
|
||||
print(
|
||||
f"Detected repeat token or error, retrying generation (attempt {retries + 1})..."
|
||||
)
|
||||
while _should_retry(result, retries, max_retries, max_failure_retries):
|
||||
result = _generate(item, temperature=0.3, top_p=0.95)
|
||||
retries += 1
|
||||
|
||||
return result
|
||||
|
||||
def _should_retry(result, retries, max_retries, max_failure_retries):
    """Decide whether generation should be attempted again for one item.

    Args:
        result: The latest generation result; its ``raw`` text and ``error``
            flag are inspected.
        retries: Number of retries already performed for this item.
        max_retries: Retry budget that applies to both repeat-token output
            and errored results.
        max_failure_retries: Optional separate budget that applies only to
            errored results; ``None`` disables it.

    Returns:
        ``True`` when another attempt should be made, ``False`` otherwise.
        A message is printed whenever a retry is requested.
    """
    raw = result.raw

    # Check the full output first; for outputs longer than 50 characters,
    # also check with ``cut_from_end=50`` (semantics of that option live in
    # ``detect_repeat_token`` — presumably it ignores the tail; confirm there).
    repeated = detect_repeat_token(raw)
    if not repeated and len(raw) > 50:
        repeated = detect_repeat_token(raw, cut_from_end=50)

    within_budget = retries < max_retries

    if within_budget and repeated:
        print(
            f"Detected repeat token, retrying generation (attempt {retries + 1})..."
        )
        return True

    if result.error:
        # Errors may be retried under either the general budget or the
        # dedicated failure budget, whichever still has attempts left.
        if within_budget or (
            max_failure_retries is not None and retries < max_failure_retries
        ):
            print(
                f"Detected vllm error, retrying generation (attempt {retries + 1})..."
            )
            return True

    return False
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
results = list(executor.map(process_item, batch, repeat(max_retries)))
|
||||
results = list(
|
||||
executor.map(
|
||||
process_item, batch, repeat(max_retries), repeat(max_failure_retries)
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -9,6 +9,8 @@ from PIL import Image
|
||||
from bs4 import BeautifulSoup, NavigableString
|
||||
from markdownify import MarkdownConverter, re_whitespace
|
||||
|
||||
from chandra.settings import settings
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _hash_html(html: str):
|
||||
@@ -20,15 +22,6 @@ def get_image_name(html: str, div_idx: int):
|
||||
return f"{html_hash}_{div_idx}_img.webp"
|
||||
|
||||
|
||||
def fix_raw(html: str):
    """Rewrite runs of four ``<BBOXn>`` tokens as a ``[a,b,c,d]`` list.

    Every occurrence of exactly four consecutive ``<BBOX<digits>>`` tokens is
    replaced by the four numbers joined with commas inside square brackets.
    Runs of fewer than four tokens are left untouched; longer runs are
    consumed four tokens at a time from the left.

    Args:
        html: Raw model output. ``None`` is tolerated because a chat
            completion's message content may be null.

    Returns:
        The rewritten string, or ``""`` when ``html`` is ``None``.
    """
    if html is None:
        # re.sub would raise TypeError on None; treat a missing completion
        # as empty output so callers always receive a string.
        return ""

    # Each token carries exactly one digit run, so capturing the four runs
    # directly is equivalent to extracting all digits from the whole match.
    return re.sub(
        r"<BBOX(\d+)><BBOX(\d+)><BBOX(\d+)><BBOX(\d+)>",
        r"[\1,\2,\3,\4]",
        html,
    )
|
||||
|
||||
|
||||
def extract_images(html: str, chunks: dict, image: Image.Image):
|
||||
images = {}
|
||||
div_idx = 0
|
||||
@@ -232,16 +225,21 @@ def parse_layout(html: str, image: Image.Image):
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
top_level_divs = soup.find_all("div", recursive=False)
|
||||
width, height = image.size
|
||||
width_scaler = width / 1024
|
||||
height_scaler = height / 1024
|
||||
width_scaler = width / settings.BBOX_SCALE
|
||||
height_scaler = height / settings.BBOX_SCALE
|
||||
layout_blocks = []
|
||||
for div in top_level_divs:
|
||||
bbox = div.get("data-bbox")
|
||||
|
||||
try:
|
||||
bbox = json.loads(bbox)
|
||||
assert len(bbox) == 4, "Invalid bbox length"
|
||||
except Exception:
|
||||
bbox = [0, 0, 1, 1]
|
||||
try:
|
||||
bbox = bbox.split(" ")
|
||||
assert len(bbox) == 4, "Invalid bbox length"
|
||||
except Exception:
|
||||
bbox = [0, 0, 1, 1]
|
||||
|
||||
bbox = list(map(int, bbox))
|
||||
# Normalize bbox
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from chandra.settings import settings
|
||||
|
||||
ALLOWED_TAGS = [
|
||||
"math",
|
||||
"br",
|
||||
@@ -65,7 +67,7 @@ Guidelines:
|
||||
""".strip()
|
||||
|
||||
OCR_LAYOUT_PROMPT = f"""
|
||||
OCR this image to HTML, arranged as layout blocks. Each layout block should be a div with the data-bbox attribute representing the bounding box of the block in [x0, y0, x1, y1] format. Bboxes are normalized 0-1024. The data-label attribute is the label for the block.
|
||||
OCR this image to HTML, arranged as layout blocks. Each layout block should be a div with the data-bbox attribute representing the bounding box of the block in [x0, y0, x1, y1] format. Bboxes are normalized 0-{settings.BBOX_SCALE}. The data-label attribute is the label for the block.
|
||||
|
||||
Use the following labels:
|
||||
- Caption
|
||||
|
||||
@@ -15,6 +15,7 @@ class Settings(BaseSettings):
|
||||
TORCH_DEVICE: str | None = None
|
||||
MAX_OUTPUT_TOKENS: int = 12384
|
||||
TORCH_ATTN: str | None = None
|
||||
BBOX_SCALE: int = 1024
|
||||
|
||||
# vLLM server settings
|
||||
VLLM_API_KEY: str = "EMPTY"
|
||||
|
||||
Reference in New Issue
Block a user