mirror of
https://github.com/datalab-to/chandra.git
synced 2025-11-29 08:33:13 +00:00
Add precommit
This commit is contained in:
10
.pre-commit-config.yaml
Normal file
10
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.14.1
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff-check
|
||||
args: [ --fix ]
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
@@ -16,11 +16,17 @@ class InferenceManager:
|
||||
else:
|
||||
self.model = None
|
||||
|
||||
def generate(self, batch: List[BatchInputItem], **kwargs) -> List[BatchOutputItem]:
|
||||
def generate(
|
||||
self, batch: List[BatchInputItem], max_output_tokens=None, **kwargs
|
||||
) -> List[BatchOutputItem]:
|
||||
if self.method == "vllm":
|
||||
results = generate_vllm(batch, **kwargs)
|
||||
results = generate_vllm(
|
||||
batch, max_output_tokens=max_output_tokens, **kwargs
|
||||
)
|
||||
else:
|
||||
results = generate_hf(batch, self.model, **kwargs)
|
||||
results = generate_hf(
|
||||
batch, self.model, max_output_tokens=max_output_tokens, **kwargs
|
||||
)
|
||||
|
||||
output = []
|
||||
for result, input_item in zip(results, batch):
|
||||
@@ -31,7 +37,7 @@ class InferenceManager:
|
||||
chunks=parse_chunks(result.raw, input_item.image),
|
||||
raw=result.raw,
|
||||
page_box=[0, 0, input_item.image.width, input_item.image.height],
|
||||
token_count=result.token_count
|
||||
token_count=result.token_count,
|
||||
)
|
||||
)
|
||||
return output
|
||||
|
||||
@@ -9,7 +9,12 @@ from chandra.prompts import PROMPT_MAPPING
|
||||
from chandra.settings import settings
|
||||
|
||||
|
||||
def generate_hf(batch: List[BatchInputItem], model, **kwargs) -> List[GenerationResult]:
|
||||
def generate_hf(
|
||||
batch: List[BatchInputItem], model, max_output_tokens=None, **kwargs
|
||||
) -> List[GenerationResult]:
|
||||
if max_output_tokens is None:
|
||||
max_output_tokens = settings.MAX_OUTPUT_TOKENS
|
||||
|
||||
messages = [process_batch_element(item, model.processor) for item in batch]
|
||||
text = model.processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
@@ -21,17 +26,20 @@ def generate_hf(batch: List[BatchInputItem], model, **kwargs) -> List[Generation
|
||||
images=image_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
padding_side="left"
|
||||
padding_side="left",
|
||||
)
|
||||
inputs = inputs.to("cuda")
|
||||
|
||||
# Inference: Generation of the output
|
||||
generated_ids = model.generate_hf(**inputs, max_new_tokens=settings.MAX_OUTPUT_TOKENS)
|
||||
generated_ids = model.generate_hf(**inputs, max_new_tokens=max_output_tokens)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
out_ids[len(in_ids) :]
|
||||
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
output_text = model.processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
generated_ids_trimmed,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
results = [
|
||||
GenerationResult(raw=out, token_count=len(ids))
|
||||
@@ -52,10 +60,7 @@ def process_batch_element(item: BatchInputItem, processor):
|
||||
content.append({"type": "image", "image": image})
|
||||
|
||||
content.append({"type": "text", "text": prompt})
|
||||
message = {
|
||||
"role": "user",
|
||||
"content": content
|
||||
}
|
||||
message = {"role": "user", "content": content}
|
||||
return message
|
||||
|
||||
|
||||
|
||||
@@ -20,7 +20,12 @@ def image_to_base64(image: Image.Image) -> str:
|
||||
return base64.b64encode(buffered.getvalue()).decode()
|
||||
|
||||
|
||||
def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_workers: int | None = None) -> List[GenerationResult]:
|
||||
def generate_vllm(
|
||||
batch: List[BatchInputItem],
|
||||
max_output_tokens: int = None,
|
||||
max_retries: int = None,
|
||||
max_workers: int | None = None,
|
||||
) -> List[GenerationResult]:
|
||||
client = OpenAI(
|
||||
api_key=settings.VLLM_API_KEY,
|
||||
base_url=settings.VLLM_API_BASE,
|
||||
@@ -33,11 +38,16 @@ def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_work
|
||||
if max_workers is None:
|
||||
max_workers = min(64, len(batch))
|
||||
|
||||
if max_output_tokens is None:
|
||||
max_output_tokens = settings.MAX_OUTPUT_TOKENS
|
||||
|
||||
if model_name is None:
|
||||
models = client.models.list()
|
||||
model_name = models.data[0].id
|
||||
|
||||
def _generate(item: BatchInputItem, temperature: float = 0, top_p: float = .1) -> GenerationResult:
|
||||
def _generate(
|
||||
item: BatchInputItem, temperature: float = 0, top_p: float = 0.1
|
||||
) -> GenerationResult:
|
||||
prompt = item.prompt
|
||||
if not prompt:
|
||||
prompt = PROMPT_MAPPING[item.prompt_type]
|
||||
@@ -45,40 +55,41 @@ def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_work
|
||||
content = []
|
||||
image = scale_to_fit(item.image)
|
||||
image_b64 = image_to_base64(image)
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{image_b64}"
|
||||
content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image_b64}"},
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
content.append({
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
})
|
||||
content.append({"type": "text", "text": prompt})
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": content
|
||||
}],
|
||||
messages=[{"role": "user", "content": content}],
|
||||
max_tokens=settings.MAX_OUTPUT_TOKENS,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
)
|
||||
return GenerationResult(
|
||||
raw=completion.choices[0].message.content,
|
||||
token_count=completion.usage.completion_tokens
|
||||
token_count=completion.usage.completion_tokens,
|
||||
)
|
||||
|
||||
def process_item(item, max_retries):
|
||||
result = _generate(item)
|
||||
retries = 0
|
||||
|
||||
while retries < max_retries and (detect_repeat_token(result.raw) or
|
||||
(len(result.raw) > 50 and detect_repeat_token(result.raw, cut_from_end=50))):
|
||||
print(f"Detected repeat token, retrying generation (attempt {retries + 1})...")
|
||||
while retries < max_retries and (
|
||||
detect_repeat_token(result.raw)
|
||||
or (
|
||||
len(result.raw) > 50
|
||||
and detect_repeat_token(result.raw, cut_from_end=50)
|
||||
)
|
||||
):
|
||||
print(
|
||||
f"Detected repeat token, retrying generation (attempt {retries + 1})..."
|
||||
)
|
||||
result = _generate(item, temperature=0.3, top_p=0.95)
|
||||
retries += 1
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ def parse_html(html: str, include_headers_footers: bool = False):
|
||||
out_html += content
|
||||
return out_html
|
||||
|
||||
|
||||
def escape_dollars(text):
|
||||
return text.replace("$", r"\$")
|
||||
|
||||
@@ -139,7 +140,11 @@ def parse_markdown(html: str, include_headers_footers: bool = False):
|
||||
inline_math_delimiters=("$", "$"),
|
||||
block_math_delimiters=("$$", "$$"),
|
||||
)
|
||||
markdown = md_cls.convert(html)
|
||||
try:
|
||||
markdown = md_cls.convert(html)
|
||||
except Exception as e:
|
||||
print(f"Error converting HTML to Markdown: {e}")
|
||||
markdown = ""
|
||||
return markdown.strip()
|
||||
|
||||
|
||||
@@ -161,8 +166,8 @@ def parse_layout(html: str, image: Image.Image):
|
||||
bbox = div.get("data-bbox")
|
||||
try:
|
||||
bbox = json.loads(bbox)
|
||||
except Exception as e:
|
||||
bbox = [0, 0, 1, 1] # Fallback to a default bbox if parsing fails
|
||||
except Exception:
|
||||
bbox = [0, 0, 1, 1] # Fallback to a default bbox if parsing fails
|
||||
|
||||
bbox = list(map(int, bbox))
|
||||
# Normalize bbox
|
||||
@@ -177,8 +182,8 @@ def parse_layout(html: str, image: Image.Image):
|
||||
layout_blocks.append(LayoutBlock(bbox=bbox, label=label, content=content))
|
||||
return layout_blocks
|
||||
|
||||
|
||||
def parse_chunks(html: str, image: Image.Image):
|
||||
layout = parse_layout(html, image)
|
||||
chunks = [asdict(block) for block in layout]
|
||||
return chunks
|
||||
|
||||
|
||||
@@ -24,3 +24,8 @@ dependencies = [
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["chandra*"]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pre-commit>=4.3.0",
|
||||
]
|
||||
|
||||
8
uv.lock
generated
8
uv.lock
generated
@@ -182,6 +182,11 @@ dependencies = [
|
||||
{ name = "transformers" },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "pre-commit" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "beautifulsoup4", specifier = ">=4.14.2" },
|
||||
@@ -201,6 +206,9 @@ requires-dist = [
|
||||
{ name = "transformers", specifier = ">=4.57.1" },
|
||||
]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [{ name = "pre-commit", specifier = ">=4.3.0" }]
|
||||
|
||||
[[package]]
|
||||
name = "charset-normalizer"
|
||||
version = "3.4.3"
|
||||
|
||||
Reference in New Issue
Block a user