Merge pull request #36 from datalab-to/vik/bbox

Fix retry settings
This commit is contained in:
Vik Paruchuri
2025-11-12 16:03:10 -05:00
committed by GitHub

View File

@@ -26,6 +26,7 @@ def generate_vllm(
max_retries: int = None,
max_workers: int | None = None,
custom_headers: dict | None = None,
max_failure_retries: int | None = None,
) -> List[GenerationResult]:
client = OpenAI(
api_key=settings.VLLM_API_KEY,
@@ -86,27 +87,50 @@ def generate_vllm(
return result
def process_item(item, max_retries):
def process_item(item, max_retries, max_failure_retries=None):
result = _generate(item)
retries = 0
while retries < max_retries and (
detect_repeat_token(result.raw)
or (
len(result.raw) > 50
and detect_repeat_token(result.raw, cut_from_end=50)
)
or result.error
):
print(
f"Detected repeat token or error, retrying generation (attempt {retries + 1})..."
)
while _should_retry(result, retries, max_retries, max_failure_retries):
result = _generate(item, temperature=0.3, top_p=0.95)
retries += 1
return result
def _should_retry(result, retries, max_retries, max_failure_retries):
has_repeat = detect_repeat_token(result.raw) or (
len(result.raw) > 50 and detect_repeat_token(result.raw, cut_from_end=50)
)
if retries < max_retries and has_repeat:
print(
f"Detected repeat token, retrying generation (attempt {retries + 1})..."
)
return True
if retries < max_retries and result.error:
print(
f"Detected vllm error, retrying generation (attempt {retries + 1})..."
)
return True
if (
result.error
and max_failure_retries is not None
and retries < max_failure_retries
):
print(
f"Detected vllm error, retrying generation (attempt {retries + 1})..."
)
return True
return False
with ThreadPoolExecutor(max_workers=max_workers) as executor:
results = list(executor.map(process_item, batch, repeat(max_retries)))
results = list(
executor.map(
process_item, batch, repeat(max_retries), repeat(max_failure_retries)
)
)
return results