From bb1fd39a46d3f7776b619b1767a635d0ebf97596 Mon Sep 17 00:00:00 2001
From: Vik Paruchuri
Date: Thu, 16 Oct 2025 20:19:43 -0400
Subject: [PATCH] Support returning token count

---
 .gitignore                |  2 ++
 chandra/model/__init__.py | 11 ++++++-----
 chandra/model/hf.py       | 10 +++++++---
 chandra/model/schema.py   |  5 +++++
 chandra/model/vllm.py     | 15 +++++++++------
 5 files changed, 29 insertions(+), 14 deletions(-)

diff --git a/.gitignore b/.gitignore
index d3d597a..b70abfd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+local.env
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[codz]
diff --git a/chandra/model/__init__.py b/chandra/model/__init__.py
index 63501bb..7a23ad9 100644
--- a/chandra/model/__init__.py
+++ b/chandra/model/__init__.py
@@ -26,11 +26,12 @@ class InferenceManager:
         for result, input_item in zip(results, batch):
             output.append(
                 BatchOutputItem(
-                    markdown=parse_markdown(result),
-                    html=parse_html(result),
-                    chunks=parse_chunks(result, input_item.image),
-                    raw=result,
-                    page_box=[0, 0, input_item.image.width, input_item.image.height]
+                    markdown=parse_markdown(result.raw),
+                    html=parse_html(result.raw),
+                    chunks=parse_chunks(result.raw, input_item.image),
+                    raw=result.raw,
+                    page_box=[0, 0, input_item.image.width, input_item.image.height],
+                    token_count=result.token_count
                 )
             )
         return output
diff --git a/chandra/model/hf.py b/chandra/model/hf.py
index 374b689..a52ad8c 100644
--- a/chandra/model/hf.py
+++ b/chandra/model/hf.py
@@ -3,13 +3,13 @@ from typing import List
 from qwen_vl_utils import process_vision_info
 from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor

-from chandra.model.schema import BatchInputItem
+from chandra.model.schema import BatchInputItem, GenerationResult
 from chandra.model.util import scale_to_fit
 from chandra.prompts import PROMPT_MAPPING
 from chandra.settings import settings


-def generate_hf(batch: List[BatchInputItem], model, **kwargs):
+def generate_hf(batch: List[BatchInputItem], model, **kwargs) -> List[GenerationResult]:
     messages = [process_batch_element(item, model.processor) for item in batch]
     text = model.processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
@@ -33,7 +33,11 @@
     output_text = model.processor.batch_decode(
         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
     )
-    return output_text
+    results = [
+        GenerationResult(raw=out, token_count=len(ids))
+        for out, ids in zip(output_text, generated_ids_trimmed)
+    ]
+    return results


 def process_batch_element(item: BatchInputItem, processor):
diff --git a/chandra/model/schema.py b/chandra/model/schema.py
index 536e05e..7349ac3 100644
--- a/chandra/model/schema.py
+++ b/chandra/model/schema.py
@@ -3,6 +3,10 @@ from typing import List

 from PIL import Image


+@dataclass
+class GenerationResult:
+    raw: str
+    token_count: int
 @dataclass
 class BatchInputItem:
@@ -17,3 +21,4 @@ class BatchOutputItem:
     chunks: dict
     raw: str
     page_box: List[int]
+    token_count: int
diff --git a/chandra/model/vllm.py b/chandra/model/vllm.py
index cd7440d..b1fbeae 100644
--- a/chandra/model/vllm.py
+++ b/chandra/model/vllm.py
@@ -7,7 +7,7 @@ from typing import List
 from PIL import Image
 from openai import OpenAI

-from chandra.model.schema import BatchInputItem
+from chandra.model.schema import BatchInputItem, GenerationResult
 from chandra.model.util import scale_to_fit, detect_repeat_token
 from chandra.prompts import PROMPT_MAPPING
 from chandra.settings import settings
@@ -20,7 +20,7 @@ def image_to_base64(image: Image.Image) -> str:
     return base64.b64encode(buffered.getvalue()).decode()


-def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_workers: int | None = None):
+def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_workers: int | None = None) -> List[GenerationResult]:
     client = OpenAI(
         api_key=settings.VLLM_API_KEY,
         base_url=settings.VLLM_API_BASE,
@@ -37,7 +37,7 @@ def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_work
     models = client.models.list()
     model_name = models.data[0].id

-    def _generate(item: BatchInputItem, temperature: float = 0, top_p: float = .1):
+    def _generate(item: BatchInputItem, temperature: float = 0, top_p: float = .1) -> GenerationResult:
         prompt = item.prompt
         if not prompt:
             prompt = PROMPT_MAPPING[item.prompt_type]
@@ -67,14 +67,17 @@ def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_work
             temperature=temperature,
             top_p=top_p,
         )
-        return completion.choices[0].message.content
+        return GenerationResult(
+            raw=completion.choices[0].message.content,
+            token_count=completion.usage.completion_tokens
+        )

     def process_item(item, max_retries):
         result = _generate(item)
         retries = 0
-        while retries < max_retries and (detect_repeat_token(result) or
-                (len(result) > 50 and detect_repeat_token(result, cut_from_end=50))):
+        while retries < max_retries and (detect_repeat_token(result.raw) or
+                (len(result.raw) > 50 and detect_repeat_token(result.raw, cut_from_end=50))):
             print(f"Detected repeat token, retrying generation (attempt {retries + 1})...")
             result = _generate(item, temperature=0.3, top_p=0.95)
             retries += 1
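
Note (not part of the patch): a minimal sketch of how the new token_count field might be consumed downstream. The helper total_completion_tokens is an assumption for illustration; the patch itself only adds GenerationResult and threads token_count from generate_hf / generate_vllm through BatchOutputItem.

    # Hypothetical helper (assumed, not in the repo): sum per-page completion tokens
    # after inference. token_count comes from len(generated_ids_trimmed[i]) in the HF
    # backend and from completion.usage.completion_tokens in the vLLM backend.
    from typing import List

    from chandra.model.schema import BatchOutputItem


    def total_completion_tokens(outputs: List[BatchOutputItem]) -> int:
        # Each BatchOutputItem now carries the decode-side token count for its page.
        return sum(item.token_count for item in outputs)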