# chandra/chandra/model/hf.py

from typing import List

from qwen_vl_utils import process_vision_info
from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor

from chandra.model.schema import BatchInputItem, GenerationResult
from chandra.model.util import scale_to_fit
from chandra.prompts import PROMPT_MAPPING
from chandra.settings import settings


def generate_hf(
    batch: List[BatchInputItem], model, max_output_tokens=None, **kwargs
) -> List[GenerationResult]:
    if max_output_tokens is None:
        max_output_tokens = settings.MAX_OUTPUT_TOKENS

    # Build one chat message per input, then batch-tokenize with left padding
    messages = [process_batch_element(item, model.processor) for item in batch]
    text = model.processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, _ = process_vision_info(messages)
    inputs = model.processor(
        text=text,
        images=image_inputs,
        padding=True,
        return_tensors="pt",
        padding_side="left",
    )
    # Use the model's device rather than hardcoding "cuda", so the configured
    # TORCH_DEVICE (e.g. cpu, mps) is respected
    inputs = inputs.to(model.device)

    # Inference: generate, then strip the echoed prompt tokens from each sequence
    generated_ids = model.generate(**inputs, max_new_tokens=max_output_tokens)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = model.processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    results = [
        GenerationResult(raw=out, token_count=len(ids), error=False)
        for out, ids in zip(output_text, generated_ids_trimmed)
    ]
    return results


def process_batch_element(item: BatchInputItem, processor):
    prompt = item.prompt
    prompt_type = item.prompt_type
    if not prompt:
        prompt = PROMPT_MAPPING[prompt_type]

    content = []
    image = scale_to_fit(item.image)  # Guarantee max size
    content.append({"type": "image", "image": image})
    content.append({"type": "text", "text": prompt})

    message = {"role": "user", "content": content}
    return message
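
# For reference, each element produced above follows the standard Qwen-VL chat
# format consumed by apply_chat_template and process_vision_info:
#
#   {
#       "role": "user",
#       "content": [
#           {"type": "image", "image": <PIL image, already scaled>},
#           {"type": "text", "text": "<resolved prompt>"},
#       ],
#   }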


def load_model():
    # Pin the whole model to one device if configured; otherwise let
    # device_map="auto" place it automatically
    device_map = "auto"
    if settings.TORCH_DEVICE:
        device_map = {"": settings.TORCH_DEVICE}

    kwargs = {
        "dtype": settings.TORCH_DTYPE,
        "device_map": device_map,
    }
    if settings.TORCH_ATTN:
        kwargs["attn_implementation"] = settings.TORCH_ATTN

    model = Qwen3VLForConditionalGeneration.from_pretrained(
        settings.MODEL_CHECKPOINT, **kwargs
    )
    model = model.eval()

    processor = Qwen3VLProcessor.from_pretrained(settings.MODEL_CHECKPOINT)
    model.processor = processor
    return model
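

# Usage (a minimal sketch, not part of the original module): assumes a configured
# MODEL_CHECKPOINT, a hypothetical "page.png" input, and an "ocr" key in
# PROMPT_MAPPING; check BatchInputItem's schema for the exact field names.
#
#   from PIL import Image
#   from chandra.model.hf import load_model, generate_hf
#   from chandra.model.schema import BatchInputItem
#
#   model = load_model()
#   batch = [BatchInputItem(image=Image.open("page.png"), prompt=None, prompt_type="ocr")]
#   results = generate_hf(batch, model)
#   print(results[0].raw)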