mirror of
https://github.com/datalab-to/chandra.git
synced 2025-11-29 00:23:12 +00:00
@@ -26,6 +26,7 @@ def generate_vllm(
|
||||
max_retries: int = None,
|
||||
max_workers: int | None = None,
|
||||
custom_headers: dict | None = None,
|
||||
max_failure_retries: int | None = None,
|
||||
) -> List[GenerationResult]:
|
||||
client = OpenAI(
|
||||
api_key=settings.VLLM_API_KEY,
|
||||
@@ -86,27 +87,50 @@ def generate_vllm(
|
||||
|
||||
return result
|
||||
|
||||
def process_item(item, max_retries):
|
||||
def process_item(item, max_retries, max_failure_retries=None):
|
||||
result = _generate(item)
|
||||
retries = 0
|
||||
|
||||
while retries < max_retries and (
|
||||
detect_repeat_token(result.raw)
|
||||
or (
|
||||
len(result.raw) > 50
|
||||
and detect_repeat_token(result.raw, cut_from_end=50)
|
||||
)
|
||||
or result.error
|
||||
):
|
||||
print(
|
||||
f"Detected repeat token or error, retrying generation (attempt {retries + 1})..."
|
||||
)
|
||||
while _should_retry(result, retries, max_retries, max_failure_retries):
|
||||
result = _generate(item, temperature=0.3, top_p=0.95)
|
||||
retries += 1
|
||||
|
||||
return result
|
||||
|
||||
def _should_retry(result, retries, max_retries, max_failure_retries):
    """Decide whether generation should be attempted again for one item.

    A retry is requested when either:

    * the output looks repetitive (``detect_repeat_token`` fires on the full
      text, or on the text with ``cut_from_end=50`` when it is longer than 50
      characters) and ``retries`` is still below ``max_retries``; or
    * ``result.error`` is truthy and ``retries`` is below ``max_retries`` or,
      when ``max_failure_retries`` is not ``None``, below
      ``max_failure_retries``.

    A one-line notice is printed whenever a retry is requested.

    Args:
        result: Generation result exposing ``raw`` (output text) and ``error``.
        retries: Number of retries already performed for this item.
        max_retries: Retry budget shared by repeat-token and error retries.
        max_failure_retries: Optional extra budget that applies to errors
            only; ``None`` disables it.

    Returns:
        ``True`` if the caller should generate again, ``False`` otherwise.
    """
    raw_output = result.raw
    repeated = detect_repeat_token(raw_output) or (
        len(raw_output) > 50 and detect_repeat_token(raw_output, cut_from_end=50)
    )
    within_retry_budget = retries < max_retries

    # Repetition is only retried inside the regular budget.
    if within_retry_budget and repeated:
        print(
            f"Detected repeat token, retrying generation (attempt {retries + 1})..."
        )
        return True

    # Everything past this point concerns failed generations only.
    if not result.error:
        return False

    # Errors may draw on the regular budget or on the dedicated failure budget.
    within_failure_budget = (
        max_failure_retries is not None and retries < max_failure_retries
    )
    if within_retry_budget or within_failure_budget:
        print(
            f"Detected vllm error, retrying generation (attempt {retries + 1})..."
        )
        return True

    return False
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
results = list(executor.map(process_item, batch, repeat(max_retries)))
|
||||
results = list(
|
||||
executor.map(
|
||||
process_item, batch, repeat(max_retries), repeat(max_failure_retries)
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
Reference in New Issue
Block a user