Vik Paruchuri
2025-10-15 16:06:57 -04:00
parent d511d5f9a6
commit a9ffa789c6
15 changed files with 724 additions and 177 deletions

chandra/model/vllm.py Normal file

@@ -0,0 +1,80 @@
import base64
import io
from concurrent.futures import ThreadPoolExecutor
from typing import List

from PIL import Image
from openai import OpenAI

from chandra.model.schema import BatchInputItem
from chandra.model.util import scale_to_fit, detect_repeat_token
from chandra.prompts import PROMPT_MAPPING
from chandra.settings import settings


def image_to_base64(image: Image.Image) -> str:
"""Convert PIL Image to base64 string."""
buffered = io.BytesIO()
image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


def generate_vllm(batch: List[BatchInputItem], max_retries: int = 5):
    """Generate completions for a batch of items via an OpenAI-compatible vLLM server."""
    client = OpenAI(
        api_key=settings.VLLM_API_KEY,
        base_url=settings.VLLM_API_BASE,
    )

    # If no model name is configured, fall back to the first model the server advertises.
    model_name = settings.VLLM_MODEL_NAME
    if model_name is None:
        models = client.models.list()
        model_name = models.data[0].id
    def _generate(item: BatchInputItem, temperature: float = 0, top_p: float = 0.1):
        prompt = item.prompt
        if not prompt:
            prompt = PROMPT_MAPPING[item.prompt_type]

        # Build multimodal message content: the image as a base64 data URL, then the text prompt.
        content = []
        image = scale_to_fit(item.image)
        image_b64 = image_to_base64(image)
        content.append({
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{image_b64}"
}
})
content.append({
"type": "text",
"text": prompt
})
completion = client.chat.completions.create(
model=model_name,
messages=[{
"role": "user",
"content": content
}],
max_tokens=settings.MAX_OUTPUT_TOKENS,
temperature=temperature,
top_p=top_p,
)
return completion.choices[0].message.content
    def process_item(item):
        result = _generate(item)
        retries = 0
        # Retry with a higher temperature/top_p when the output looks like a
        # repeat-token loop, checked on the full text and again with the last
        # 50 characters trimmed off. Uses the outer max_retries argument; the
        # previous inner default of 3 silently shadowed it.
        while retries < max_retries and (detect_repeat_token(result) or
                                         (len(result) > 50 and detect_repeat_token(result[:-50]))):
            print(f"Detected repeat token, retrying generation (attempt {retries + 1})...")
            result = _generate(item, temperature=0.2, top_p=0.9)
            retries += 1
        return result
    # One worker per item; cap at a minimum of 1 so an empty batch doesn't make
    # ThreadPoolExecutor raise on max_workers=0.
    with ThreadPoolExecutor(max_workers=max(1, len(batch))) as executor:
        results = list(executor.map(process_item, batch))
    return results