Enable robustness

This commit is contained in:
Vik Paruchuri
2025-10-18 14:30:18 -04:00
parent 6d093af119
commit 313f9c71b8
5 changed files with 23 additions and 11 deletions

1
.gitignore vendored
View File

@@ -1,4 +1,5 @@
local.env
experiments
# Byte-compiled / optimized / DLL files
__pycache__/

View File

@@ -42,7 +42,7 @@ def generate_hf(
clean_up_tokenization_spaces=False,
)
results = [
GenerationResult(raw=out, token_count=len(ids))
GenerationResult(raw=out, token_count=len(ids), error=False)
for out, ids in zip(output_text, generated_ids_trimmed)
]
return results

View File

@@ -3,10 +3,13 @@ from typing import List
from PIL import Image
@dataclass
class GenerationResult:
raw: str
token_count: int
error: bool = False
@dataclass
class BatchInputItem:
@@ -14,6 +17,7 @@ class BatchInputItem:
prompt: str | None = None
prompt_type: str | None = None
@dataclass
class BatchOutputItem:
markdown: str

View File

@@ -64,16 +64,22 @@ def generate_vllm(
content.append({"type": "text", "text": prompt})
completion = client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": content}],
max_tokens=settings.MAX_OUTPUT_TOKENS,
temperature=temperature,
top_p=top_p,
)
try:
completion = client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": content}],
max_tokens=settings.MAX_OUTPUT_TOKENS,
temperature=temperature,
top_p=top_p,
)
except Exception as e:
print(f"Error during VLLM generation: {e}")
return GenerationResult(raw="", token_count=0, error=True)
return GenerationResult(
raw=completion.choices[0].message.content,
token_count=completion.usage.completion_tokens,
error=False,
)
def process_item(item, max_retries):
@@ -86,9 +92,10 @@ def generate_vllm(
len(result.raw) > 50
and detect_repeat_token(result.raw, cut_from_end=50)
)
or result.error
):
print(
f"Detected repeat token, retrying generation (attempt {retries + 1})..."
f"Detected repeat token or error, retrying generation (attempt {retries + 1})..."
)
result = _generate(item, temperature=0.3, top_p=0.95)
retries += 1

View File

@@ -71,7 +71,7 @@ Use the following labels:
- Caption
- Footnote
- Equation-Block
- List-Item
- List-Group
- Page-Header
- Page-Footer
- Image
@@ -96,4 +96,4 @@ OCR this image to HTML.
PROMPT_MAPPING = {
"ocr_layout": OCR_LAYOUT_PROMPT,
"ocr": OCR_PROMPT,
}
}