Mirror of https://github.com/datalab-to/chandra.git, synced 2025-11-29 08:33:13 +00:00

This diff threads a per-call bbox_scale parameter (defaulting to settings.BBOX_SCALE) from InferenceManager through generate_hf, generate_vllm, and the output parsers, and turns the hardcoded normalization constant in the layout prompt into a {bbox_scale} placeholder filled in at request time.
@@ -4,6 +4,7 @@ from chandra.model.hf import load_model, generate_hf
 from chandra.model.schema import BatchInputItem, BatchOutputItem
 from chandra.model.vllm import generate_vllm
 from chandra.output import parse_markdown, parse_html, parse_chunks, extract_images
+from chandra.settings import settings


 class InferenceManager:
@@ -26,19 +27,27 @@ class InferenceManager:
             output_kwargs["include_headers_footers"] = kwargs.pop(
                 "include_headers_footers"
             )
+        bbox_scale = kwargs.get("bbox_scale", settings.BBOX_SCALE)

         if self.method == "vllm":
             results = generate_vllm(
-                batch, max_output_tokens=max_output_tokens, **kwargs
+                batch,
+                max_output_tokens=max_output_tokens,
+                bbox_scale=bbox_scale,
+                **kwargs,
             )
         else:
             results = generate_hf(
-                batch, self.model, max_output_tokens=max_output_tokens, **kwargs
+                batch,
+                self.model,
+                max_output_tokens=max_output_tokens,
+                bbox_scale=bbox_scale,
+                **kwargs,
             )

         output = []
         for result, input_item in zip(results, batch):
-            chunks = parse_chunks(result.raw, input_item.image)
+            chunks = parse_chunks(result.raw, input_item.image, bbox_scale=bbox_scale)
             output.append(
                 BatchOutputItem(
                     markdown=parse_markdown(result.raw, **output_kwargs),
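Taken together, this hunk lets one keyword steer both the prompt's normalization range and the later pixel back-projection. A minimal usage sketch, assuming this code sits in the manager's generate entry point and that the constructor takes the inference method; those names, and the prompt_type key, are assumptions from context, not confirmed by the diff:

    from chandra.model.schema import BatchInputItem

    # Hypothetical call site: InferenceManager(...) and .generate(...) names are
    # assumed; page_image is any PIL.Image page render, prompt_type illustrative.
    manager = InferenceManager(method="vllm")
    batch = [BatchInputItem(image=page_image, prompt_type="ocr_layout")]

    # bbox_scale is read via kwargs.get(...) above and forwarded to both the
    # generator and parse_chunks, so the prompt and the parser always agree.
    results = manager.generate(batch, bbox_scale=1024)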
@@ -10,12 +10,18 @@ from chandra.settings import settings


 def generate_hf(
-    batch: List[BatchInputItem], model, max_output_tokens=None, **kwargs
+    batch: List[BatchInputItem],
+    model,
+    max_output_tokens=None,
+    bbox_scale: int = settings.BBOX_SCALE,
+    **kwargs,
 ) -> List[GenerationResult]:
     if max_output_tokens is None:
         max_output_tokens = settings.MAX_OUTPUT_TOKENS

-    messages = [process_batch_element(item, model.processor) for item in batch]
+    messages = [
+        process_batch_element(item, model.processor, bbox_scale) for item in batch
+    ]
     text = model.processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
@@ -48,12 +54,12 @@ def generate_hf(
     return results


-def process_batch_element(item: BatchInputItem, processor):
+def process_batch_element(item: BatchInputItem, processor, bbox_scale: int):
     prompt = item.prompt
     prompt_type = item.prompt_type

     if not prompt:
-        prompt = PROMPT_MAPPING[prompt_type]
+        prompt = PROMPT_MAPPING[prompt_type].replace("{bbox_scale}", str(bbox_scale))

     content = []
     image = scale_to_fit(item.image)  # Guarantee max size
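This .replace-based templating pairs with the prompts change at the bottom of the diff: prompt text now carries a literal {bbox_scale} token instead of interpolating settings.BBOX_SCALE at import time. A small sketch of the mechanism with a stand-in template (the real strings live in PROMPT_MAPPING):

    # Stand-in template; the real prompt text comes from PROMPT_MAPPING.
    template = "Bboxes are normalized 0-{bbox_scale}."
    prompt = template.replace("{bbox_scale}", str(1024))
    assert prompt == "Bboxes are normalized 0-1024."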
@@ -28,6 +28,7 @@ def generate_vllm(
     max_workers: int | None = None,
     custom_headers: dict | None = None,
     max_failure_retries: int | None = None,
+    bbox_scale: int = settings.BBOX_SCALE,
 ) -> List[GenerationResult]:
     client = OpenAI(
         api_key=settings.VLLM_API_KEY,
@@ -54,7 +55,9 @@ def generate_vllm(
     ) -> GenerationResult:
         prompt = item.prompt
         if not prompt:
-            prompt = PROMPT_MAPPING[item.prompt_type]
+            prompt = PROMPT_MAPPING[item.prompt_type].replace(
+                "{bbox_scale}", str(bbox_scale)
+            )

         content = []
         image = scale_to_fit(item.image)
@@ -221,12 +221,12 @@ class LayoutBlock:
     content: str


-def parse_layout(html: str, image: Image.Image):
+def parse_layout(html: str, image: Image.Image, bbox_scale=settings.BBOX_SCALE):
     soup = BeautifulSoup(html, "html.parser")
     top_level_divs = soup.find_all("div", recursive=False)
     width, height = image.size
-    width_scaler = width / settings.BBOX_SCALE
-    height_scaler = height / settings.BBOX_SCALE
+    width_scaler = width / bbox_scale
+    height_scaler = height / bbox_scale
     layout_blocks = []
     for div in top_level_divs:
         bbox = div.get("data-bbox")
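The scalers map model coordinates back to pixels: each data-bbox value normalized to 0..bbox_scale is multiplied by width / bbox_scale for x values and height / bbox_scale for y values. A worked sketch with illustrative numbers:

    width, height = 1700, 2200           # page image size in pixels (example)
    bbox_scale = 1024                    # normalization range from the prompt
    width_scaler = width / bbox_scale    # ~1.66 px per normalized unit
    height_scaler = height / bbox_scale  # ~2.15 px per normalized unit

    x0, y0, x1, y1 = 100, 200, 500, 400  # model-emitted data-bbox (example)
    pixel_bbox = [
        x0 * width_scaler, y0 * height_scaler,
        x1 * width_scaler, y1 * height_scaler,
    ]  # ~[166.0, 429.7, 830.1, 859.4]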
@@ -255,7 +255,7 @@ def parse_layout(html: str, image: Image.Image):
     return layout_blocks


-def parse_chunks(html: str, image: Image.Image):
-    layout = parse_layout(html, image)
+def parse_chunks(html: str, image: Image.Image, bbox_scale=settings.BBOX_SCALE):
+    layout = parse_layout(html, image, bbox_scale=bbox_scale)
     chunks = [asdict(block) for block in layout]
     return chunks
@@ -1,5 +1,3 @@
-from chandra.settings import settings
-
 ALLOWED_TAGS = [
     "math",
     "br",
@@ -67,7 +65,7 @@ Guidelines:
 """.strip()

 OCR_LAYOUT_PROMPT = f"""
-OCR this image to HTML, arranged as layout blocks. Each layout block should be a div with the data-bbox attribute representing the bounding box of the block in [x0, y0, x1, y1] format. Bboxes are normalized 0-{settings.BBOX_SCALE}. The data-label attribute is the label for the block.
+OCR this image to HTML, arranged as layout blocks. Each layout block should be a div with the data-bbox attribute representing the bounding box of the block in [x0, y0, x1, y1] format. Bboxes are normalized 0-{{bbox_scale}}. The data-label attribute is the label for the block.

 Use the following labels:
 - Caption
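Because the braces are doubled inside the f-string, the rendered constant contains the literal token {bbox_scale}, which the .replace(...) calls in the hf and vllm paths fill in per request; that is also why the prompts module no longer needs its settings import (see the @@ -1,5 +1,3 @@ hunk above). A two-line check of the escaping:

    # Doubled braces in an f-string escape to literal braces:
    s = f"normalized 0-{{bbox_scale}}"
    assert s == "normalized 0-{bbox_scale}"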