mirror of
https://github.com/datalab-to/chandra.git
synced 2025-11-29 08:33:13 +00:00
Add precommit
This commit is contained in:
10
.pre-commit-config.yaml
Normal file
10
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
# Ruff version.
|
||||||
|
rev: v0.14.1
|
||||||
|
hooks:
|
||||||
|
# Run the linter.
|
||||||
|
- id: ruff-check
|
||||||
|
args: [ --fix ]
|
||||||
|
# Run the formatter.
|
||||||
|
- id: ruff-format
|
||||||
@@ -16,11 +16,17 @@ class InferenceManager:
|
|||||||
else:
|
else:
|
||||||
self.model = None
|
self.model = None
|
||||||
|
|
||||||
def generate(self, batch: List[BatchInputItem], **kwargs) -> List[BatchOutputItem]:
|
def generate(
|
||||||
|
self, batch: List[BatchInputItem], max_output_tokens=None, **kwargs
|
||||||
|
) -> List[BatchOutputItem]:
|
||||||
if self.method == "vllm":
|
if self.method == "vllm":
|
||||||
results = generate_vllm(batch, **kwargs)
|
results = generate_vllm(
|
||||||
|
batch, max_output_tokens=max_output_tokens, **kwargs
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
results = generate_hf(batch, self.model, **kwargs)
|
results = generate_hf(
|
||||||
|
batch, self.model, max_output_tokens=max_output_tokens, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
output = []
|
output = []
|
||||||
for result, input_item in zip(results, batch):
|
for result, input_item in zip(results, batch):
|
||||||
@@ -31,7 +37,7 @@ class InferenceManager:
|
|||||||
chunks=parse_chunks(result.raw, input_item.image),
|
chunks=parse_chunks(result.raw, input_item.image),
|
||||||
raw=result.raw,
|
raw=result.raw,
|
||||||
page_box=[0, 0, input_item.image.width, input_item.image.height],
|
page_box=[0, 0, input_item.image.width, input_item.image.height],
|
||||||
token_count=result.token_count
|
token_count=result.token_count,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -9,7 +9,12 @@ from chandra.prompts import PROMPT_MAPPING
|
|||||||
from chandra.settings import settings
|
from chandra.settings import settings
|
||||||
|
|
||||||
|
|
||||||
def generate_hf(batch: List[BatchInputItem], model, **kwargs) -> List[GenerationResult]:
|
def generate_hf(
|
||||||
|
batch: List[BatchInputItem], model, max_output_tokens=None, **kwargs
|
||||||
|
) -> List[GenerationResult]:
|
||||||
|
if max_output_tokens is None:
|
||||||
|
max_output_tokens = settings.MAX_OUTPUT_TOKENS
|
||||||
|
|
||||||
messages = [process_batch_element(item, model.processor) for item in batch]
|
messages = [process_batch_element(item, model.processor) for item in batch]
|
||||||
text = model.processor.apply_chat_template(
|
text = model.processor.apply_chat_template(
|
||||||
messages, tokenize=False, add_generation_prompt=True
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
@@ -21,17 +26,20 @@ def generate_hf(batch: List[BatchInputItem], model, **kwargs) -> List[Generation
|
|||||||
images=image_inputs,
|
images=image_inputs,
|
||||||
padding=True,
|
padding=True,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
padding_side="left"
|
padding_side="left",
|
||||||
)
|
)
|
||||||
inputs = inputs.to("cuda")
|
inputs = inputs.to("cuda")
|
||||||
|
|
||||||
# Inference: Generation of the output
|
# Inference: Generation of the output
|
||||||
generated_ids = model.generate_hf(**inputs, max_new_tokens=settings.MAX_OUTPUT_TOKENS)
|
generated_ids = model.generate_hf(**inputs, max_new_tokens=max_output_tokens)
|
||||||
generated_ids_trimmed = [
|
generated_ids_trimmed = [
|
||||||
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
out_ids[len(in_ids) :]
|
||||||
|
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||||
]
|
]
|
||||||
output_text = model.processor.batch_decode(
|
output_text = model.processor.batch_decode(
|
||||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
generated_ids_trimmed,
|
||||||
|
skip_special_tokens=True,
|
||||||
|
clean_up_tokenization_spaces=False,
|
||||||
)
|
)
|
||||||
results = [
|
results = [
|
||||||
GenerationResult(raw=out, token_count=len(ids))
|
GenerationResult(raw=out, token_count=len(ids))
|
||||||
@@ -52,10 +60,7 @@ def process_batch_element(item: BatchInputItem, processor):
|
|||||||
content.append({"type": "image", "image": image})
|
content.append({"type": "image", "image": image})
|
||||||
|
|
||||||
content.append({"type": "text", "text": prompt})
|
content.append({"type": "text", "text": prompt})
|
||||||
message = {
|
message = {"role": "user", "content": content}
|
||||||
"role": "user",
|
|
||||||
"content": content
|
|
||||||
}
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,12 @@ def image_to_base64(image: Image.Image) -> str:
|
|||||||
return base64.b64encode(buffered.getvalue()).decode()
|
return base64.b64encode(buffered.getvalue()).decode()
|
||||||
|
|
||||||
|
|
||||||
def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_workers: int | None = None) -> List[GenerationResult]:
|
def generate_vllm(
|
||||||
|
batch: List[BatchInputItem],
|
||||||
|
max_output_tokens: int = None,
|
||||||
|
max_retries: int = None,
|
||||||
|
max_workers: int | None = None,
|
||||||
|
) -> List[GenerationResult]:
|
||||||
client = OpenAI(
|
client = OpenAI(
|
||||||
api_key=settings.VLLM_API_KEY,
|
api_key=settings.VLLM_API_KEY,
|
||||||
base_url=settings.VLLM_API_BASE,
|
base_url=settings.VLLM_API_BASE,
|
||||||
@@ -33,11 +38,16 @@ def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_work
|
|||||||
if max_workers is None:
|
if max_workers is None:
|
||||||
max_workers = min(64, len(batch))
|
max_workers = min(64, len(batch))
|
||||||
|
|
||||||
|
if max_output_tokens is None:
|
||||||
|
max_output_tokens = settings.MAX_OUTPUT_TOKENS
|
||||||
|
|
||||||
if model_name is None:
|
if model_name is None:
|
||||||
models = client.models.list()
|
models = client.models.list()
|
||||||
model_name = models.data[0].id
|
model_name = models.data[0].id
|
||||||
|
|
||||||
def _generate(item: BatchInputItem, temperature: float = 0, top_p: float = .1) -> GenerationResult:
|
def _generate(
|
||||||
|
item: BatchInputItem, temperature: float = 0, top_p: float = 0.1
|
||||||
|
) -> GenerationResult:
|
||||||
prompt = item.prompt
|
prompt = item.prompt
|
||||||
if not prompt:
|
if not prompt:
|
||||||
prompt = PROMPT_MAPPING[item.prompt_type]
|
prompt = PROMPT_MAPPING[item.prompt_type]
|
||||||
@@ -45,40 +55,41 @@ def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_work
|
|||||||
content = []
|
content = []
|
||||||
image = scale_to_fit(item.image)
|
image = scale_to_fit(item.image)
|
||||||
image_b64 = image_to_base64(image)
|
image_b64 = image_to_base64(image)
|
||||||
content.append({
|
content.append(
|
||||||
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {"url": f"data:image/png;base64,{image_b64}"},
|
||||||
"url": f"data:image/png;base64,{image_b64}"
|
|
||||||
}
|
}
|
||||||
})
|
)
|
||||||
|
|
||||||
content.append({
|
content.append({"type": "text", "text": prompt})
|
||||||
"type": "text",
|
|
||||||
"text": prompt
|
|
||||||
})
|
|
||||||
|
|
||||||
completion = client.chat.completions.create(
|
completion = client.chat.completions.create(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
messages=[{
|
messages=[{"role": "user", "content": content}],
|
||||||
"role": "user",
|
|
||||||
"content": content
|
|
||||||
}],
|
|
||||||
max_tokens=settings.MAX_OUTPUT_TOKENS,
|
max_tokens=settings.MAX_OUTPUT_TOKENS,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
)
|
)
|
||||||
return GenerationResult(
|
return GenerationResult(
|
||||||
raw=completion.choices[0].message.content,
|
raw=completion.choices[0].message.content,
|
||||||
token_count=completion.usage.completion_tokens
|
token_count=completion.usage.completion_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_item(item, max_retries):
|
def process_item(item, max_retries):
|
||||||
result = _generate(item)
|
result = _generate(item)
|
||||||
retries = 0
|
retries = 0
|
||||||
|
|
||||||
while retries < max_retries and (detect_repeat_token(result.raw) or
|
while retries < max_retries and (
|
||||||
(len(result.raw) > 50 and detect_repeat_token(result.raw, cut_from_end=50))):
|
detect_repeat_token(result.raw)
|
||||||
print(f"Detected repeat token, retrying generation (attempt {retries + 1})...")
|
or (
|
||||||
|
len(result.raw) > 50
|
||||||
|
and detect_repeat_token(result.raw, cut_from_end=50)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
print(
|
||||||
|
f"Detected repeat token, retrying generation (attempt {retries + 1})..."
|
||||||
|
)
|
||||||
result = _generate(item, temperature=0.3, top_p=0.95)
|
result = _generate(item, temperature=0.3, top_p=0.95)
|
||||||
retries += 1
|
retries += 1
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ def parse_html(html: str, include_headers_footers: bool = False):
|
|||||||
out_html += content
|
out_html += content
|
||||||
return out_html
|
return out_html
|
||||||
|
|
||||||
|
|
||||||
def escape_dollars(text):
|
def escape_dollars(text):
|
||||||
return text.replace("$", r"\$")
|
return text.replace("$", r"\$")
|
||||||
|
|
||||||
@@ -139,7 +140,11 @@ def parse_markdown(html: str, include_headers_footers: bool = False):
|
|||||||
inline_math_delimiters=("$", "$"),
|
inline_math_delimiters=("$", "$"),
|
||||||
block_math_delimiters=("$$", "$$"),
|
block_math_delimiters=("$$", "$$"),
|
||||||
)
|
)
|
||||||
|
try:
|
||||||
markdown = md_cls.convert(html)
|
markdown = md_cls.convert(html)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error converting HTML to Markdown: {e}")
|
||||||
|
markdown = ""
|
||||||
return markdown.strip()
|
return markdown.strip()
|
||||||
|
|
||||||
|
|
||||||
@@ -161,7 +166,7 @@ def parse_layout(html: str, image: Image.Image):
|
|||||||
bbox = div.get("data-bbox")
|
bbox = div.get("data-bbox")
|
||||||
try:
|
try:
|
||||||
bbox = json.loads(bbox)
|
bbox = json.loads(bbox)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
bbox = [0, 0, 1, 1] # Fallback to a default bbox if parsing fails
|
bbox = [0, 0, 1, 1] # Fallback to a default bbox if parsing fails
|
||||||
|
|
||||||
bbox = list(map(int, bbox))
|
bbox = list(map(int, bbox))
|
||||||
@@ -177,8 +182,8 @@ def parse_layout(html: str, image: Image.Image):
|
|||||||
layout_blocks.append(LayoutBlock(bbox=bbox, label=label, content=content))
|
layout_blocks.append(LayoutBlock(bbox=bbox, label=label, content=content))
|
||||||
return layout_blocks
|
return layout_blocks
|
||||||
|
|
||||||
|
|
||||||
def parse_chunks(html: str, image: Image.Image):
|
def parse_chunks(html: str, image: Image.Image):
|
||||||
layout = parse_layout(html, image)
|
layout = parse_layout(html, image)
|
||||||
chunks = [asdict(block) for block in layout]
|
chunks = [asdict(block) for block in layout]
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|||||||
@@ -24,3 +24,8 @@ dependencies = [
|
|||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
include = ["chandra*"]
|
include = ["chandra*"]
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
"pre-commit>=4.3.0",
|
||||||
|
]
|
||||||
|
|||||||
8
uv.lock
generated
8
uv.lock
generated
@@ -182,6 +182,11 @@ dependencies = [
|
|||||||
{ name = "transformers" },
|
{ name = "transformers" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[package.dev-dependencies]
|
||||||
|
dev = [
|
||||||
|
{ name = "pre-commit" },
|
||||||
|
]
|
||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "beautifulsoup4", specifier = ">=4.14.2" },
|
{ name = "beautifulsoup4", specifier = ">=4.14.2" },
|
||||||
@@ -201,6 +206,9 @@ requires-dist = [
|
|||||||
{ name = "transformers", specifier = ">=4.57.1" },
|
{ name = "transformers", specifier = ">=4.57.1" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[package.metadata.requires-dev]
|
||||||
|
dev = [{ name = "pre-commit", specifier = ">=4.3.0" }]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "charset-normalizer"
|
name = "charset-normalizer"
|
||||||
version = "3.4.3"
|
version = "3.4.3"
|
||||||
|
|||||||
Reference in New Issue
Block a user