mirror of
https://github.com/datalab-to/chandra.git
synced 2025-12-01 17:43:10 +00:00
Support returning token count
.gitignore (vendored): 2 changed lines
@@ -1,3 +1,5 @@
+local.env
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[codz]
@@ -26,11 +26,12 @@ class InferenceManager:
         for result, input_item in zip(results, batch):
             output.append(
                 BatchOutputItem(
-                    markdown=parse_markdown(result),
-                    html=parse_html(result),
-                    chunks=parse_chunks(result, input_item.image),
-                    raw=result,
-                    page_box=[0, 0, input_item.image.width, input_item.image.height]
+                    markdown=parse_markdown(result.raw),
+                    html=parse_html(result.raw),
+                    chunks=parse_chunks(result.raw, input_item.image),
+                    raw=result.raw,
+                    page_box=[0, 0, input_item.image.width, input_item.image.height],
+                    token_count=result.token_count
                 )
             )
         return output
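With this change, the inference manager propagates the completion token count from each GenerationResult into the corresponding BatchOutputItem. A minimal sketch of how calling code might read the new field (the helper and its argument are illustrative, not part of this commit):

    # Illustrative only: sum the per-page counts exposed by this commit.
    def total_tokens(outputs):
        """outputs: the List[BatchOutputItem] returned by the inference manager."""
        return sum(item.token_count for item in outputs)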
@@ -3,13 +3,13 @@ from typing import List
 from qwen_vl_utils import process_vision_info
 from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor

-from chandra.model.schema import BatchInputItem
+from chandra.model.schema import BatchInputItem, GenerationResult
 from chandra.model.util import scale_to_fit
 from chandra.prompts import PROMPT_MAPPING
 from chandra.settings import settings


-def generate_hf(batch: List[BatchInputItem], model, **kwargs):
+def generate_hf(batch: List[BatchInputItem], model, **kwargs) -> List[GenerationResult]:
     messages = [process_batch_element(item, model.processor) for item in batch]
     text = model.processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
@@ -33,7 +33,11 @@ def generate_hf(batch: List[BatchInputItem], model, **kwargs):
     output_text = model.processor.batch_decode(
         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
     )
-    return output_text
+    results = [
+        GenerationResult(raw=out, token_count=len(ids))
+        for out, ids in zip(output_text, generated_ids_trimmed)
+    ]
+    return results


 def process_batch_element(item: BatchInputItem, processor):
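On the Hugging Face path the count is simply the length of each trimmed id sequence, i.e. the ids left after the prompt portion is stripped, so it reflects completion tokens only. A standalone sketch of that pairing step, assuming `input_ids` and `generated_ids` come from the processor and `model.generate(...)` as in the usual Qwen2.5-VL example code (not necessarily this repository's exact variable names):

    # Sketch: strip the prompt ids, decode, and record how many new ids each
    # sample produced. `processor`, `input_ids`, and `generated_ids` are assumed
    # to exist as in the standard Qwen2.5-VL generation loop.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    results = [
        GenerationResult(raw=text, token_count=len(ids))
        for text, ids in zip(output_text, generated_ids_trimmed)
    ]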
@@ -3,6 +3,10 @@ from typing import List

 from PIL import Image

+@dataclass
+class GenerationResult:
+    raw: str
+    token_count: int

 @dataclass
 class BatchInputItem:
@@ -17,3 +21,4 @@ class BatchOutputItem:
     chunks: dict
     raw: str
     page_box: List[int]
+    token_count: int
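The schema now records the count in two places: GenerationResult pairs the raw model output with its completion token count, and BatchOutputItem carries the same count alongside the parsed page. A self-contained sketch of the new dataclass and a hand-built instance (the literal values are made up for illustration):

    from dataclasses import dataclass

    @dataclass
    class GenerationResult:
        raw: str          # unparsed model output for one page
        token_count: int  # completion tokens produced for that page

    # Example instance with made-up values:
    result = GenerationResult(raw="<html>...</html>", token_count=812)
    print(result.token_count)  # 812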
@@ -7,7 +7,7 @@ from typing import List
 from PIL import Image
 from openai import OpenAI

-from chandra.model.schema import BatchInputItem
+from chandra.model.schema import BatchInputItem, GenerationResult
 from chandra.model.util import scale_to_fit, detect_repeat_token
 from chandra.prompts import PROMPT_MAPPING
 from chandra.settings import settings
@@ -20,7 +20,7 @@ def image_to_base64(image: Image.Image) -> str:
     return base64.b64encode(buffered.getvalue()).decode()


-def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_workers: int | None = None):
+def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_workers: int | None = None) -> List[GenerationResult]:
     client = OpenAI(
         api_key=settings.VLLM_API_KEY,
         base_url=settings.VLLM_API_BASE,
@@ -37,7 +37,7 @@ def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_work
     models = client.models.list()
     model_name = models.data[0].id

-    def _generate(item: BatchInputItem, temperature: float = 0, top_p: float = .1):
+    def _generate(item: BatchInputItem, temperature: float = 0, top_p: float = .1) -> GenerationResult:
         prompt = item.prompt
         if not prompt:
             prompt = PROMPT_MAPPING[item.prompt_type]
@@ -67,14 +67,17 @@ def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_work
             temperature=temperature,
             top_p=top_p,
         )
-        return completion.choices[0].message.content
+        return GenerationResult(
+            raw=completion.choices[0].message.content,
+            token_count=completion.usage.completion_tokens
+        )

     def process_item(item, max_retries):
         result = _generate(item)
         retries = 0

-        while retries < max_retries and (detect_repeat_token(result) or
-                                          (len(result) > 50 and detect_repeat_token(result, cut_from_end=50))):
+        while retries < max_retries and (detect_repeat_token(result.raw) or
+                                          (len(result.raw) > 50 and detect_repeat_token(result.raw, cut_from_end=50))):
             print(f"Detected repeat token, retrying generation (attempt {retries + 1})...")
             result = _generate(item, temperature=0.3, top_p=0.95)
             retries += 1
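On the vLLM path the count is not computed locally; it is taken from the usage block that the OpenAI-compatible server returns with each chat completion. A standalone sketch of the same pattern against any OpenAI-compatible endpoint (the URL, API key, model name, and prompt are placeholders; some servers omit `usage`, hence the guard):

    from openai import OpenAI

    # Placeholder endpoint and credentials for illustration only.
    client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
    completion = client.chat.completions.create(
        model="some-model",  # placeholder model name
        messages=[{"role": "user", "content": "Transcribe this page."}],
        temperature=0,
        top_p=0.1,
    )
    raw = completion.choices[0].message.content
    # Completion token count as reported by the server; may be absent on some backends.
    token_count = completion.usage.completion_tokens if completion.usage else 0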