Mirror of https://github.com/datalab-to/chandra.git
Add precommit
@@ -20,7 +20,12 @@ def image_to_base64(image: Image.Image) -> str:
     return base64.b64encode(buffered.getvalue()).decode()
 
 
-def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_workers: int | None = None) -> List[GenerationResult]:
+def generate_vllm(
+    batch: List[BatchInputItem],
+    max_output_tokens: int = None,
+    max_retries: int = None,
+    max_workers: int | None = None,
+) -> List[GenerationResult]:
     client = OpenAI(
         api_key=settings.VLLM_API_KEY,
         base_url=settings.VLLM_API_BASE,
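For reference, a hypothetical call site under the widened signature could look like the sketch below. It is not taken from the repository: the argument values are placeholders, and it assumes generate_vllm has been imported from the module this diff touches (module path omitted here). Only the parameter names come from the diff.

# Hypothetical usage sketch (not from the repo); batch is a List[BatchInputItem]
# built elsewhere, and the numeric values are illustrative placeholders.
results = generate_vllm(
    batch,
    max_output_tokens=4096,  # new optional argument introduced in this commit
    max_retries=3,
    max_workers=8,
)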
@@ -33,11 +38,16 @@ def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_work
     if max_workers is None:
         max_workers = min(64, len(batch))
 
+    if max_output_tokens is None:
+        max_output_tokens = settings.MAX_OUTPUT_TOKENS
+
     if model_name is None:
         models = client.models.list()
         model_name = models.data[0].id
 
-    def _generate(item: BatchInputItem, temperature: float = 0, top_p: float = .1) -> GenerationResult:
+    def _generate(
+        item: BatchInputItem, temperature: float = 0, top_p: float = 0.1
+    ) -> GenerationResult:
         prompt = item.prompt
         if not prompt:
             prompt = PROMPT_MAPPING[item.prompt_type]
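The model_name fallback in this hunk relies on the OpenAI-compatible /v1/models endpoint that a vLLM server exposes. A minimal, self-contained sketch of that discovery step, with a placeholder URL and key rather than the project's settings values, might be:

# Minimal sketch of the fallback path, assuming a locally running vLLM server
# with the OpenAI-compatible API; the URL and key below are placeholders.
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
models = client.models.list()   # vLLM lists the model(s) it was launched with
model_name = models.data[0].id  # take the first (typically the only) entry
print(model_name)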
@@ -45,40 +55,41 @@ def generate_vllm(batch: List[BatchInputItem], max_retries: int = None, max_work
         content = []
         image = scale_to_fit(item.image)
         image_b64 = image_to_base64(image)
-        content.append({
-            "type": "image_url",
-            "image_url": {
-                "url": f"data:image/png;base64,{image_b64}"
+        content.append(
+            {
+                "type": "image_url",
+                "image_url": {"url": f"data:image/png;base64,{image_b64}"},
             }
-        })
+        )
 
-        content.append({
-            "type": "text",
-            "text": prompt
-        })
+        content.append({"type": "text", "text": prompt})
 
         completion = client.chat.completions.create(
             model=model_name,
-            messages=[{
-                "role": "user",
-                "content": content
-            }],
+            messages=[{"role": "user", "content": content}],
             max_tokens=settings.MAX_OUTPUT_TOKENS,
             temperature=temperature,
             top_p=top_p,
         )
         return GenerationResult(
             raw=completion.choices[0].message.content,
-            token_count=completion.usage.completion_tokens
+            token_count=completion.usage.completion_tokens,
         )
 
     def process_item(item, max_retries):
         result = _generate(item)
         retries = 0
 
-        while retries < max_retries and (detect_repeat_token(result.raw) or
-            (len(result.raw) > 50 and detect_repeat_token(result.raw, cut_from_end=50))):
-            print(f"Detected repeat token, retrying generation (attempt {retries + 1})...")
+        while retries < max_retries and (
+            detect_repeat_token(result.raw)
+            or (
+                len(result.raw) > 50
+                and detect_repeat_token(result.raw, cut_from_end=50)
+            )
+        ):
+            print(
+                f"Detected repeat token, retrying generation (attempt {retries + 1})..."
+            )
             result = _generate(item, temperature=0.3, top_p=0.95)
             retries += 1
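detect_repeat_token itself is defined elsewhere in the repository and is not part of this diff. Purely as an illustration of the kind of degenerate output the retry loop guards against (a short chunk repeated at the end of the generation, re-sampled above at higher temperature/top_p), a stand-in check could look like this:

# Stand-in sketch only; the real detect_repeat_token lives elsewhere in the repo.
# cut_from_end is guessed here to mean "drop the final N characters before checking".
def detect_repeat_token_sketch(text: str, cut_from_end: int = 0, min_repeats: int = 20) -> bool:
    if cut_from_end:
        text = text[:-cut_from_end]
    tail = text[-200:]
    for size in (1, 2, 4, 8):
        chunk = tail[-size:]
        if chunk and tail.endswith(chunk * min_repeats):
            return True
    return False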