Add pre-commit configuration (Ruff lint and format hooks)

This commit is contained in:
Vik Paruchuri
2025-10-16 20:26:19 -04:00
parent bb1fd39a46
commit 6d093af119
7 changed files with 86 additions and 36 deletions

10
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,10 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.14.1
hooks:
# Run the linter.
- id: ruff-check
args: [ --fix ]
# Run the formatter.
- id: ruff-format

View File

@@ -16,11 +16,17 @@ class InferenceManager:
else: else:
self.model = None self.model = None
def generate(self, batch: List[BatchInputItem], **kwargs) -> List[BatchOutputItem]: def generate(
self, batch: List[BatchInputItem], max_output_tokens=None, **kwargs
) -> List[BatchOutputItem]:
if self.method == "vllm": if self.method == "vllm":
results = generate_vllm(batch, **kwargs) results = generate_vllm(
batch, max_output_tokens=max_output_tokens, **kwargs
)
else: else:
results = generate_hf(batch, self.model, **kwargs) results = generate_hf(
batch, self.model, max_output_tokens=max_output_tokens, **kwargs
)
output = [] output = []
for result, input_item in zip(results, batch): for result, input_item in zip(results, batch):
@@ -31,7 +37,7 @@ class InferenceManager:
chunks=parse_chunks(result.raw, input_item.image), chunks=parse_chunks(result.raw, input_item.image),
raw=result.raw, raw=result.raw,
page_box=[0, 0, input_item.image.width, input_item.image.height], page_box=[0, 0, input_item.image.width, input_item.image.height],
token_count=result.token_count token_count=result.token_count,
) )
) )
return output return output

View File

@@ -9,7 +9,12 @@ from chandra.prompts import PROMPT_MAPPING
from chandra.settings import settings from chandra.settings import settings
def generate_hf(batch: List[BatchInputItem], model, **kwargs) -> List[GenerationResult]: def generate_hf(
batch: List[BatchInputItem], model, max_output_tokens=None, **kwargs
) -> List[GenerationResult]:
if max_output_tokens is None:
max_output_tokens = settings.MAX_OUTPUT_TOKENS
messages = [process_batch_element(item, model.processor) for item in batch] messages = [process_batch_element(item, model.processor) for item in batch]
text = model.processor.apply_chat_template( text = model.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, tokenize=False, add_generation_prompt=True
@@ -21,17 +26,20 @@ def generate_hf(batch: List[BatchInputItem], model, **kwargs) -> List[Generation
images=image_inputs, images=image_inputs,
padding=True, padding=True,
return_tensors="pt", return_tensors="pt",
padding_side="left" padding_side="left",
) )
inputs = inputs.to("cuda") inputs = inputs.to("cuda")
# Inference: Generation of the output # Inference: Generation of the output
generated_ids = model.generate_hf(**inputs, max_new_tokens=settings.MAX_OUTPUT_TOKENS) generated_ids = model.generate_hf(**inputs, max_new_tokens=max_output_tokens)
generated_ids_trimmed = [ generated_ids_trimmed = [
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
] ]
output_text = model.processor.batch_decode( output_text = model.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
) )
results = [ results = [
GenerationResult(raw=out, token_count=len(ids)) GenerationResult(raw=out, token_count=len(ids))
@@ -52,10 +60,7 @@ def process_batch_element(item: BatchInputItem, processor):
content.append({"type": "image", "image": image}) content.append({"type": "image", "image": image})
content.append({"type": "text", "text": prompt}) content.append({"type": "text", "text": prompt})
message = { message = {"role": "user", "content": content}
"role": "user",
"content": content
}
return message return message

View File

@@ -20,7 +20,12 @@ def image_to_base64(image: Image.Image) -> str:
return base64.b64encode(buffered.getvalue()).decode() return base64.b64encode(buffered.getvalue()).decode()
def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_workers: int | None = None) -> List[GenerationResult]: def generate_vllm(
batch: List[BatchInputItem],
max_output_tokens: int = None,
max_retries: int = None,
max_workers: int | None = None,
) -> List[GenerationResult]:
client = OpenAI( client = OpenAI(
api_key=settings.VLLM_API_KEY, api_key=settings.VLLM_API_KEY,
base_url=settings.VLLM_API_BASE, base_url=settings.VLLM_API_BASE,
@@ -33,11 +38,16 @@ def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_work
if max_workers is None: if max_workers is None:
max_workers = min(64, len(batch)) max_workers = min(64, len(batch))
if max_output_tokens is None:
max_output_tokens = settings.MAX_OUTPUT_TOKENS
if model_name is None: if model_name is None:
models = client.models.list() models = client.models.list()
model_name = models.data[0].id model_name = models.data[0].id
def _generate(item: BatchInputItem, temperature: float = 0, top_p: float = .1) -> GenerationResult: def _generate(
item: BatchInputItem, temperature: float = 0, top_p: float = 0.1
) -> GenerationResult:
prompt = item.prompt prompt = item.prompt
if not prompt: if not prompt:
prompt = PROMPT_MAPPING[item.prompt_type] prompt = PROMPT_MAPPING[item.prompt_type]
@@ -45,40 +55,41 @@ def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_work
content = [] content = []
image = scale_to_fit(item.image) image = scale_to_fit(item.image)
image_b64 = image_to_base64(image) image_b64 = image_to_base64(image)
content.append({ content.append(
"type": "image_url", {
"image_url": { "type": "image_url",
"url": f"data:image/png;base64,{image_b64}" "image_url": {"url": f"data:image/png;base64,{image_b64}"},
} }
}) )
content.append({ content.append({"type": "text", "text": prompt})
"type": "text",
"text": prompt
})
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model_name, model=model_name,
messages=[{ messages=[{"role": "user", "content": content}],
"role": "user",
"content": content
}],
max_tokens=settings.MAX_OUTPUT_TOKENS, max_tokens=settings.MAX_OUTPUT_TOKENS,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
) )
return GenerationResult( return GenerationResult(
raw=completion.choices[0].message.content, raw=completion.choices[0].message.content,
token_count=completion.usage.completion_tokens token_count=completion.usage.completion_tokens,
) )
def process_item(item, max_retries): def process_item(item, max_retries):
result = _generate(item) result = _generate(item)
retries = 0 retries = 0
while retries < max_retries and (detect_repeat_token(result.raw) or while retries < max_retries and (
(len(result.raw) > 50 and detect_repeat_token(result.raw, cut_from_end=50))): detect_repeat_token(result.raw)
print(f"Detected repeat token, retrying generation (attempt {retries + 1})...") or (
len(result.raw) > 50
and detect_repeat_token(result.raw, cut_from_end=50)
)
):
print(
f"Detected repeat token, retrying generation (attempt {retries + 1})..."
)
result = _generate(item, temperature=0.3, top_p=0.95) result = _generate(item, temperature=0.3, top_p=0.95)
retries += 1 retries += 1

View File

@@ -24,6 +24,7 @@ def parse_html(html: str, include_headers_footers: bool = False):
out_html += content out_html += content
return out_html return out_html
def escape_dollars(text): def escape_dollars(text):
return text.replace("$", r"\$") return text.replace("$", r"\$")
@@ -139,7 +140,11 @@ def parse_markdown(html: str, include_headers_footers: bool = False):
inline_math_delimiters=("$", "$"), inline_math_delimiters=("$", "$"),
block_math_delimiters=("$$", "$$"), block_math_delimiters=("$$", "$$"),
) )
markdown = md_cls.convert(html) try:
markdown = md_cls.convert(html)
except Exception as e:
print(f"Error converting HTML to Markdown: {e}")
markdown = ""
return markdown.strip() return markdown.strip()
@@ -161,8 +166,8 @@ def parse_layout(html: str, image: Image.Image):
bbox = div.get("data-bbox") bbox = div.get("data-bbox")
try: try:
bbox = json.loads(bbox) bbox = json.loads(bbox)
except Exception as e: except Exception:
bbox = [0, 0, 1, 1] # Fallback to a default bbox if parsing fails bbox = [0, 0, 1, 1] # Fallback to a default bbox if parsing fails
bbox = list(map(int, bbox)) bbox = list(map(int, bbox))
# Normalize bbox # Normalize bbox
@@ -177,8 +182,8 @@ def parse_layout(html: str, image: Image.Image):
layout_blocks.append(LayoutBlock(bbox=bbox, label=label, content=content)) layout_blocks.append(LayoutBlock(bbox=bbox, label=label, content=content))
return layout_blocks return layout_blocks
def parse_chunks(html: str, image: Image.Image): def parse_chunks(html: str, image: Image.Image):
layout = parse_layout(html, image) layout = parse_layout(html, image)
chunks = [asdict(block) for block in layout] chunks = [asdict(block) for block in layout]
return chunks return chunks

View File

@@ -24,3 +24,8 @@ dependencies = [
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
include = ["chandra*"] include = ["chandra*"]
[dependency-groups]
dev = [
"pre-commit>=4.3.0",
]

8
uv.lock generated
View File

@@ -182,6 +182,11 @@ dependencies = [
{ name = "transformers" }, { name = "transformers" },
] ]
[package.dev-dependencies]
dev = [
{ name = "pre-commit" },
]
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "beautifulsoup4", specifier = ">=4.14.2" }, { name = "beautifulsoup4", specifier = ">=4.14.2" },
@@ -201,6 +206,9 @@ requires-dist = [
{ name = "transformers", specifier = ">=4.57.1" }, { name = "transformers", specifier = ">=4.57.1" },
] ]
[package.metadata.requires-dev]
dev = [{ name = "pre-commit", specifier = ">=4.3.0" }]
[[package]] [[package]]
name = "charset-normalizer" name = "charset-normalizer"
version = "3.4.3" version = "3.4.3"