Enable robustness
(mirror of https://github.com/datalab-to/chandra.git)
.gitignore

@@ -1,4 +1,5 @@
 local.env
 experiments
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
@@ -42,7 +42,7 @@ def generate_hf(
         clean_up_tokenization_spaces=False,
     )
     results = [
-        GenerationResult(raw=out, token_count=len(ids))
+        GenerationResult(raw=out, token_count=len(ids), error=False)
        for out, ids in zip(output_text, generated_ids_trimmed)
     ]
     return results
@@ -3,10 +3,13 @@ from typing import List
 
 from PIL import Image
 
 
 @dataclass
 class GenerationResult:
     raw: str
     token_count: int
+    error: bool = False
 
 
 @dataclass
 class BatchInputItem:
@@ -14,6 +17,7 @@ class BatchInputItem:
     prompt: str | None = None
     prompt_type: str | None = None
 
 
 @dataclass
 class BatchOutputItem:
     markdown: str
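With the new error field, a failed generation can be told apart from one that merely produced empty text. A minimal sketch of how a caller might branch on the flag (the unwrap helper is illustrative, not part of this commit):

from dataclasses import dataclass


@dataclass
class GenerationResult:
    raw: str
    token_count: int
    error: bool = False


def unwrap(result: GenerationResult) -> str:
    # Failed calls come back as raw="" with error=True, so checking the
    # flag avoids conflating a failure with genuinely empty output.
    if result.error:
        raise RuntimeError("generation failed after retries")
    return result.raw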
@@ -64,16 +64,22 @@ def generate_vllm(
 
         content.append({"type": "text", "text": prompt})
 
-        completion = client.chat.completions.create(
-            model=model_name,
-            messages=[{"role": "user", "content": content}],
-            max_tokens=settings.MAX_OUTPUT_TOKENS,
-            temperature=temperature,
-            top_p=top_p,
-        )
+        try:
+            completion = client.chat.completions.create(
+                model=model_name,
+                messages=[{"role": "user", "content": content}],
+                max_tokens=settings.MAX_OUTPUT_TOKENS,
+                temperature=temperature,
+                top_p=top_p,
+            )
+        except Exception as e:
+            print(f"Error during VLLM generation: {e}")
+            return GenerationResult(raw="", token_count=0, error=True)
 
         return GenerationResult(
             raw=completion.choices[0].message.content,
             token_count=completion.usage.completion_tokens,
+            error=False,
         )
 
     def process_item(item, max_retries):
@@ -86,9 +92,10 @@ def generate_vllm(
                 len(result.raw) > 50
                 and detect_repeat_token(result.raw, cut_from_end=50)
             )
+            or result.error
         ):
             print(
-                f"Detected repeat token, retrying generation (attempt {retries + 1})..."
+                f"Detected repeat token or error, retrying generation (attempt {retries + 1})..."
             )
             result = _generate(item, temperature=0.3, top_p=0.95)
             retries += 1
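Taken together, the two hunks above make generate_vllm fail soft: client exceptions are converted into an error-flagged result instead of propagating, and the retry loop treats that flag the same as a detected repeat token. A condensed sketch of the resulting control flow (names follow the diff; the loop header and the max_retries default are assumptions, since the hunks show only the condition body):

def process_item(item, max_retries=2):
    result = _generate(item)  # first attempt with default sampling
    retries = 0
    while retries < max_retries and (
        (
            len(result.raw) > 50
            and detect_repeat_token(result.raw, cut_from_end=50)
        )
        or result.error  # new: retry on hard failures, not just repetition
    ):
        print(
            f"Detected repeat token or error, retrying generation (attempt {retries + 1})..."
        )
        # Mildly stochastic sampling helps break out of repetition loops.
        result = _generate(item, temperature=0.3, top_p=0.95)
        retries += 1
    return result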
@@ -71,7 +71,7 @@ Use the following labels:
 - Caption
 - Footnote
 - Equation-Block
 - List-Item
 - List-Group
 - Page-Header
 - Page-Footer
 - Image
@@ -96,4 +96,4 @@ OCR this image to HTML.
 PROMPT_MAPPING = {
     "ocr_layout": OCR_LAYOUT_PROMPT,
     "ocr": OCR_PROMPT,
-}
+}
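For context, PROMPT_MAPPING is what turns a BatchInputItem.prompt_type into an actual prompt string. A hypothetical resolution helper (the fallback to OCR_PROMPT is an assumption, not shown in this commit):

def resolve_prompt(item: BatchInputItem) -> str:
    # An explicit prompt wins; otherwise prompt_type selects from the mapping.
    if item.prompt is not None:
        return item.prompt
    # Assumed default: fall back to plain OCR when no type is given.
    return PROMPT_MAPPING.get(item.prompt_type or "ocr", OCR_PROMPT)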