Add precommit

Vik Paruchuri
2025-10-16 20:26:19 -04:00
parent bb1fd39a46
commit 6d093af119
7 changed files with 86 additions and 36 deletions

.pre-commit-config.yaml Normal file

@@ -0,0 +1,10 @@
+repos:
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.14.1
+    hooks:
+      # Run the linter.
+      - id: ruff-check
+        args: [ --fix ]
+      # Run the formatter.
+      - id: ruff-format
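
With this config committed, pre-commit still has to be wired into each clone. A minimal sketch of that workflow, wrapped in Python purely for illustration (the underlying `pre-commit install` and `pre-commit run --all-files` commands are the standard CLI entry points):

import subprocess

# Register the git hook so ruff lints and formats staged files on every commit.
subprocess.run(["pre-commit", "install"], check=True)

# One-off pass over the whole repository, useful right after adding the config.
subprocess.run(["pre-commit", "run", "--all-files"], check=True)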


@@ -16,11 +16,17 @@ class InferenceManager:
         else:
             self.model = None
 
-    def generate(self, batch: List[BatchInputItem], **kwargs) -> List[BatchOutputItem]:
+    def generate(
+        self, batch: List[BatchInputItem], max_output_tokens=None, **kwargs
+    ) -> List[BatchOutputItem]:
         if self.method == "vllm":
-            results = generate_vllm(batch, **kwargs)
+            results = generate_vllm(
+                batch, max_output_tokens=max_output_tokens, **kwargs
+            )
         else:
-            results = generate_hf(batch, self.model, **kwargs)
+            results = generate_hf(
+                batch, self.model, max_output_tokens=max_output_tokens, **kwargs
+            )
 
         output = []
         for result, input_item in zip(results, batch):
@@ -31,7 +37,7 @@ class InferenceManager:
                     chunks=parse_chunks(result.raw, input_item.image),
                     raw=result.raw,
                     page_box=[0, 0, input_item.image.width, input_item.image.height],
-                    token_count=result.token_count
+                    token_count=result.token_count,
                 )
             )
         return output
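
For context, a hypothetical call site for the new keyword. The constructor arguments and the "ocr" prompt type are assumptions; only generate(), the BatchInputItem fields, and the output fields appear in this diff:

from PIL import Image

image = Image.new("RGB", (640, 480), "white")  # stand-in page image
manager = InferenceManager(method="vllm")  # constructor shape assumed
batch = [BatchInputItem(image=image, prompt=None, prompt_type="ocr")]  # "ocr" assumed
# Caps generation at 1024 new tokens, overriding settings.MAX_OUTPUT_TOKENS
# on both the vllm and hf paths.
outputs = manager.generate(batch, max_output_tokens=1024)
for item in outputs:
    print(item.token_count, item.page_box)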


@@ -9,7 +9,12 @@ from chandra.prompts import PROMPT_MAPPING
 from chandra.settings import settings
 
 
-def generate_hf(batch: List[BatchInputItem], model, **kwargs) -> List[GenerationResult]:
+def generate_hf(
+    batch: List[BatchInputItem], model, max_output_tokens=None, **kwargs
+) -> List[GenerationResult]:
+    if max_output_tokens is None:
+        max_output_tokens = settings.MAX_OUTPUT_TOKENS
     messages = [process_batch_element(item, model.processor) for item in batch]
     text = model.processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
@@ -21,17 +26,20 @@ def generate_hf(batch: List[BatchInputItem], model, **kwargs) -> List[Generation
         images=image_inputs,
         padding=True,
         return_tensors="pt",
-        padding_side="left"
+        padding_side="left",
     )
     inputs = inputs.to("cuda")
 
     # Inference: Generation of the output
-    generated_ids = model.generate_hf(**inputs, max_new_tokens=settings.MAX_OUTPUT_TOKENS)
+    generated_ids = model.generate_hf(**inputs, max_new_tokens=max_output_tokens)
     generated_ids_trimmed = [
-        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        out_ids[len(in_ids) :]
+        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
     ]
     output_text = model.processor.batch_decode(
-        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        generated_ids_trimmed,
+        skip_special_tokens=True,
+        clean_up_tokenization_spaces=False,
     )
     results = [
         GenerationResult(raw=out, token_count=len(ids))
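
The slice reformatted above, `out_ids[len(in_ids) :]`, is what strips the prompt from each generated row. Because the batch is tokenized with padding_side="left", every row of input_ids has the same length, so a self-contained toy example behaves like this:

# 0 stands in for a pad token; the prompt is left-padded to the batch max length.
in_ids = [0, 0, 101, 102]
out_ids = [0, 0, 101, 102, 7, 8, 9]  # prompt echoed back plus new tokens
new_tokens = out_ids[len(in_ids):]
assert new_tokens == [7, 8, 9]  # only the newly generated tokens remain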
@@ -52,10 +60,7 @@ def process_batch_element(item: BatchInputItem, processor):
     content.append({"type": "image", "image": image})
     content.append({"type": "text", "text": prompt})
 
-    message = {
-        "role": "user",
-        "content": content
-    }
+    message = {"role": "user", "content": content}
     return message


@@ -20,7 +20,12 @@ def image_to_base64(image: Image.Image) -> str:
     return base64.b64encode(buffered.getvalue()).decode()
 
 
-def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_workers: int | None = None) -> List[GenerationResult]:
+def generate_vllm(
+    batch: List[BatchInputItem],
+    max_output_tokens: int = None,
+    max_retries: int = None,
+    max_workers: int | None = None,
+) -> List[GenerationResult]:
     client = OpenAI(
         api_key=settings.VLLM_API_KEY,
         base_url=settings.VLLM_API_BASE,
@@ -33,11 +38,16 @@ def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_work
     if max_workers is None:
         max_workers = min(64, len(batch))
 
+    if max_output_tokens is None:
+        max_output_tokens = settings.MAX_OUTPUT_TOKENS
+
     if model_name is None:
         models = client.models.list()
         model_name = models.data[0].id
 
-    def _generate(item: BatchInputItem, temperature: float = 0, top_p: float = .1) -> GenerationResult:
+    def _generate(
+        item: BatchInputItem, temperature: float = 0, top_p: float = 0.1
+    ) -> GenerationResult:
         prompt = item.prompt
         if not prompt:
             prompt = PROMPT_MAPPING[item.prompt_type]
@@ -45,40 +55,41 @@ def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_work
         content = []
         image = scale_to_fit(item.image)
         image_b64 = image_to_base64(image)
-        content.append({
-            "type": "image_url",
-            "image_url": {
-                "url": f"data:image/png;base64,{image_b64}"
-            }
-        })
-        content.append({
-            "type": "text",
-            "text": prompt
-        })
+        content.append(
+            {
+                "type": "image_url",
+                "image_url": {"url": f"data:image/png;base64,{image_b64}"},
+            }
+        )
+        content.append({"type": "text", "text": prompt})
 
         completion = client.chat.completions.create(
             model=model_name,
-            messages=[{
-                "role": "user",
-                "content": content
-            }],
+            messages=[{"role": "user", "content": content}],
             max_tokens=settings.MAX_OUTPUT_TOKENS,
             temperature=temperature,
             top_p=top_p,
         )
         return GenerationResult(
             raw=completion.choices[0].message.content,
-            token_count=completion.usage.completion_tokens
+            token_count=completion.usage.completion_tokens,
         )
 
     def process_item(item, max_retries):
         result = _generate(item)
         retries = 0
-        while retries < max_retries and (detect_repeat_token(result.raw) or
-                (len(result.raw) > 50 and detect_repeat_token(result.raw, cut_from_end=50))):
-            print(f"Detected repeat token, retrying generation (attempt {retries + 1})...")
+        while retries < max_retries and (
+            detect_repeat_token(result.raw)
+            or (
+                len(result.raw) > 50
+                and detect_repeat_token(result.raw, cut_from_end=50)
+            )
+        ):
+            print(
+                f"Detected repeat token, retrying generation (attempt {retries + 1})..."
+            )
             result = _generate(item, temperature=0.3, top_p=0.95)
             retries += 1
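
detect_repeat_token is defined elsewhere in the repo, so the following is only a guess at its contract consistent with how it is called here (a boolean check over the raw text, optionally ignoring the last cut_from_end characters), not the actual implementation:

def detect_repeat_token_sketch(text: str, cut_from_end: int = 0) -> bool:
    # Hypothetical: flag output whose tail is a single token repeated many
    # times, a common degenerate-generation failure mode that the retry loop
    # above works around with a higher temperature and top_p.
    if cut_from_end:
        text = text[:-cut_from_end]
    tail = text.split()[-20:]
    return len(tail) == 20 and len(set(tail)) == 1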


@@ -24,6 +24,7 @@ def parse_html(html: str, include_headers_footers: bool = False):
             out_html += content
 
     return out_html
 
+
 def escape_dollars(text):
     return text.replace("$", r"\$")
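
A quick check of the helper's behavior: literal dollar signs are escaped so they are not later mistaken for the Markdown math delimiters configured below.

assert escape_dollars("costs $5") == r"costs \$5"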
@@ -139,7 +140,11 @@ def parse_markdown(html: str, include_headers_footers: bool = False):
         inline_math_delimiters=("$", "$"),
         block_math_delimiters=("$$", "$$"),
     )
-    markdown = md_cls.convert(html)
+    try:
+        markdown = md_cls.convert(html)
+    except Exception as e:
+        print(f"Error converting HTML to Markdown: {e}")
+        markdown = ""
     return markdown.strip()
@@ -161,7 +166,7 @@ def parse_layout(html: str, image: Image.Image):
         bbox = div.get("data-bbox")
         try:
             bbox = json.loads(bbox)
-        except Exception as e:
+        except Exception:
            bbox = [0, 0, 1, 1]  # Fallback to a default bbox if parsing fails
 
         bbox = list(map(int, bbox))
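
The defensive parse above accepts anything json.loads can handle and falls back to the unit box otherwise; a hypothetical standalone mirror of that logic:

import json

def safe_bbox(raw):  # illustration only, not a function from the repo
    try:
        bbox = json.loads(raw)
    except Exception:
        bbox = [0, 0, 1, 1]  # same fallback as parse_layout
    return list(map(int, bbox))

assert safe_bbox("[10, 20, 300, 400]") == [10, 20, 300, 400]
assert safe_bbox(None) == [0, 0, 1, 1]  # json.loads(None) raises TypeError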
@@ -177,8 +182,8 @@ def parse_layout(html: str, image: Image.Image):
             layout_blocks.append(LayoutBlock(bbox=bbox, label=label, content=content))
 
     return layout_blocks
 
 def parse_chunks(html: str, image: Image.Image):
     layout = parse_layout(html, image)
     chunks = [asdict(block) for block in layout]
     return chunks
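
parse_chunks therefore returns plain dicts rather than LayoutBlock instances; dataclasses.asdict does the conversion, roughly like this (the stand-in class below only mirrors the bbox/label/content fields used in parse_layout):

from dataclasses import asdict, dataclass

@dataclass
class LayoutBlockExample:  # stand-in, not the repo's LayoutBlock
    bbox: list
    label: str
    content: str

block = LayoutBlockExample(bbox=[0, 0, 1, 1], label="Text", content="hello")
assert asdict(block) == {"bbox": [0, 0, 1, 1], "label": "Text", "content": "hello"}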


@@ -24,3 +24,8 @@ dependencies = [
 
 [tool.setuptools.packages.find]
 include = ["chandra*"]
+
+[dependency-groups]
+dev = [
+    "pre-commit>=4.3.0",
+]

uv.lock generated

@@ -182,6 +182,11 @@ dependencies = [
     { name = "transformers" },
 ]
 
+[package.dev-dependencies]
+dev = [
+    { name = "pre-commit" },
+]
+
 [package.metadata]
 requires-dist = [
     { name = "beautifulsoup4", specifier = ">=4.14.2" },
@@ -201,6 +206,9 @@ requires-dist = [
     { name = "transformers", specifier = ">=4.57.1" },
 ]
 
+[package.metadata.requires-dev]
+dev = [{ name = "pre-commit", specifier = ">=4.3.0" }]
+
 [[package]]
 name = "charset-normalizer"
 version = "3.4.3"