Support returning token count

This commit is contained in:
Vik Paruchuri
2025-10-16 20:19:43 -04:00
parent aa59df2996
commit bb1fd39a46
5 changed files with 29 additions and 14 deletions

View File

@@ -7,7 +7,7 @@ from typing import List
from PIL import Image
from openai import OpenAI
from chandra.model.schema import BatchInputItem
from chandra.model.schema import BatchInputItem, GenerationResult
from chandra.model.util import scale_to_fit, detect_repeat_token
from chandra.prompts import PROMPT_MAPPING
from chandra.settings import settings
@@ -20,7 +20,7 @@ def image_to_base64(image: Image.Image) -> str:
return base64.b64encode(buffered.getvalue()).decode()
def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_workers: int | None = None):
def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_workers: int | None = None) -> List[GenerationResult]:
client = OpenAI(
api_key=settings.VLLM_API_KEY,
base_url=settings.VLLM_API_BASE,
@@ -37,7 +37,7 @@ def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_work
models = client.models.list()
model_name = models.data[0].id
def _generate(item: BatchInputItem, temperature: float = 0, top_p: float = .1):
def _generate(item: BatchInputItem, temperature: float = 0, top_p: float = .1) -> GenerationResult:
prompt = item.prompt
if not prompt:
prompt = PROMPT_MAPPING[item.prompt_type]
@@ -67,14 +67,17 @@ def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_work
temperature=temperature,
top_p=top_p,
)
return completion.choices[0].message.content
return GenerationResult(
raw=completion.choices[0].message.content,
token_count=completion.usage.completion_tokens
)
def process_item(item, max_retries):
result = _generate(item)
retries = 0
while retries < max_retries and (detect_repeat_token(result) or
(len(result) > 50 and detect_repeat_token(result, cut_from_end=50))):
while retries < max_retries and (detect_repeat_token(result.raw) or
(len(result.raw) > 50 and detect_repeat_token(result.raw, cut_from_end=50))):
print(f"Detected repeat token, retrying generation (attempt {retries + 1})...")
result = _generate(item, temperature=0.3, top_p=0.95)
retries += 1