Add precommit

This commit is contained in:
Vik Paruchuri
2025-10-16 20:26:19 -04:00
parent bb1fd39a46
commit 6d093af119
7 changed files with 86 additions and 36 deletions

View File

@@ -20,7 +20,12 @@ def image_to_base64(image: Image.Image) -> str:
return base64.b64encode(buffered.getvalue()).decode()
def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_workers: int | None = None) -> List[GenerationResult]:
def generate_vllm(
batch: List[BatchInputItem],
max_output_tokens: int = None,
max_retries: int = None,
max_workers: int | None = None,
) -> List[GenerationResult]:
client = OpenAI(
api_key=settings.VLLM_API_KEY,
base_url=settings.VLLM_API_BASE,
@@ -33,11 +38,16 @@ def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_work
if max_workers is None:
max_workers = min(64, len(batch))
if max_output_tokens is None:
max_output_tokens = settings.MAX_OUTPUT_TOKENS
if model_name is None:
models = client.models.list()
model_name = models.data[0].id
def _generate(item: BatchInputItem, temperature: float = 0, top_p: float = .1) -> GenerationResult:
def _generate(
item: BatchInputItem, temperature: float = 0, top_p: float = 0.1
) -> GenerationResult:
prompt = item.prompt
if not prompt:
prompt = PROMPT_MAPPING[item.prompt_type]
@@ -45,40 +55,41 @@ def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_work
content = []
image = scale_to_fit(item.image)
image_b64 = image_to_base64(image)
content.append({
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{image_b64}"
content.append(
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image_b64}"},
}
})
)
content.append({
"type": "text",
"text": prompt
})
content.append({"type": "text", "text": prompt})
completion = client.chat.completions.create(
model=model_name,
messages=[{
"role": "user",
"content": content
}],
messages=[{"role": "user", "content": content}],
max_tokens=settings.MAX_OUTPUT_TOKENS,
temperature=temperature,
top_p=top_p,
)
return GenerationResult(
raw=completion.choices[0].message.content,
token_count=completion.usage.completion_tokens
token_count=completion.usage.completion_tokens,
)
def process_item(item, max_retries):
result = _generate(item)
retries = 0
while retries < max_retries and (detect_repeat_token(result.raw) or
(len(result.raw) > 50 and detect_repeat_token(result.raw, cut_from_end=50))):
print(f"Detected repeat token, retrying generation (attempt {retries + 1})...")
while retries < max_retries and (
detect_repeat_token(result.raw)
or (
len(result.raw) > 50
and detect_repeat_token(result.raw, cut_from_end=50)
)
):
print(
f"Detected repeat token, retrying generation (attempt {retries + 1})..."
)
result = _generate(item, temperature=0.3, top_p=0.95)
retries += 1