Mirror of https://github.com/datalab-to/chandra.git, synced 2025-12-01 17:43:10 +00:00
Refactor
chandra/model/vllm.py | 80 (Normal file)
@@ -0,0 +1,80 @@
import base64
import io
from concurrent.futures import ThreadPoolExecutor
from typing import List

from PIL import Image
from openai import OpenAI

from chandra.model.schema import BatchInputItem
from chandra.model.util import scale_to_fit, detect_repeat_token
from chandra.prompts import PROMPT_MAPPING
from chandra.settings import settings


def image_to_base64(image: Image.Image) -> str:
    """Convert a PIL Image to a base64-encoded PNG string."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


def generate_vllm(batch: List[BatchInputItem], max_retries: int = 5):
    """Run a batch of image+prompt items against a vLLM OpenAI-compatible server."""
    client = OpenAI(
        api_key=settings.VLLM_API_KEY,
        base_url=settings.VLLM_API_BASE,
    )
    model_name = settings.VLLM_MODEL_NAME

    # If no model name is configured, fall back to the first model the
    # server advertises.
    if model_name is None:
        models = client.models.list()
        model_name = models.data[0].id

    def _generate(item: BatchInputItem, temperature: float = 0, top_p: float = 0.1):
        prompt = item.prompt
        if not prompt:
            prompt = PROMPT_MAPPING[item.prompt_type]

        # Build multimodal message content: the image as a base64 data URL,
        # followed by the text prompt.
        content = []
        image = scale_to_fit(item.image)
        image_b64 = image_to_base64(image)
        content.append({
            "type": "image_url",
            "image_url": {
                "url": f"data:image/png;base64,{image_b64}"
            }
        })
        content.append({
            "type": "text",
            "text": prompt
        })

        completion = client.chat.completions.create(
            model=model_name,
            messages=[{
                "role": "user",
                "content": content
            }],
            max_tokens=settings.MAX_OUTPUT_TOKENS,
            temperature=temperature,
            top_p=top_p,
        )
        return completion.choices[0].message.content

    def process_item(item, max_retries=max_retries):
        # Defaults to the max_retries passed to generate_vllm; the previous
        # hardcoded default of 3 silently shadowed the outer parameter.
        result = _generate(item)
        retries = 0

        # Retry with looser sampling when the output degenerates into
        # repeated tokens, also checking the text with its last 50
        # characters trimmed so a clean tail can't mask an earlier loop.
        while retries < max_retries and (detect_repeat_token(result) or
                                         (len(result) > 50 and detect_repeat_token(result[:-50]))):
            print(f"Detected repeat token, retrying generation (attempt {retries + 1})...")
            result = _generate(item, temperature=0.2, top_p=0.9)
            retries += 1

        return result

    # One worker per batch item; max() guards against an empty batch, which
    # ThreadPoolExecutor would otherwise reject.
    with ThreadPoolExecutor(max_workers=max(len(batch), 1)) as executor:
        results = list(executor.map(process_item, batch))

    return results
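For context, a minimal usage sketch follows. The BatchInputItem field names (image, prompt, prompt_type) are inferred from how _generate reads them above, and the "ocr" prompt type is a hypothetical key into PROMPT_MAPPING; both may differ from the actual schema. It assumes the VLLM_* settings point at a running vLLM OpenAI-compatible server.

# Hypothetical usage sketch -- the field names and the "ocr" prompt type
# are assumptions, not confirmed by this diff.
from PIL import Image

from chandra.model.schema import BatchInputItem
from chandra.model.vllm import generate_vllm

pages = [Image.open("page_1.png"), Image.open("page_2.png")]
batch = [BatchInputItem(image=page, prompt=None, prompt_type="ocr") for page in pages]

results = generate_vllm(batch)  # one thread per page, with repeat-token retries
for i, text in enumerate(results):
    print(f"--- page {i + 1} ---\n{text}")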
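The retry loop hinges on detect_repeat_token from chandra.model.util, whose source is not part of this diff. As a rough illustration of the kind of check it could perform, a heuristic might flag output whose tail is dominated by a single repeated token; the sketch below is hypothetical, not the library's actual implementation. Note also the design choice in the retry path: the first attempt samples near-greedily (temperature 0, top_p 0.1), and retries loosen both (0.2 and 0.9) to jolt the model out of a repetition loop.

# Hypothetical sketch of a repeat-token check -- NOT the real
# chandra.model.util.detect_repeat_token.
def detect_repeat_token_sketch(text: str, window: int = 200, threshold: float = 0.5) -> bool:
    """Flag text whose tail is dominated by one repeated token."""
    tail = text[-window:].split()
    if len(tail) < 10:
        return False
    most_common = max(set(tail), key=tail.count)
    # A single token filling most of the tail suggests the model has
    # collapsed into a repetition loop.
    return tail.count(most_common) / len(tail) > threshold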