mirror of
https://github.com/datalab-to/chandra.git
synced 2025-11-29 08:33:13 +00:00
@@ -26,6 +26,7 @@ def generate_vllm(
|
|||||||
max_retries: int = None,
|
max_retries: int = None,
|
||||||
max_workers: int | None = None,
|
max_workers: int | None = None,
|
||||||
custom_headers: dict | None = None,
|
custom_headers: dict | None = None,
|
||||||
|
max_failure_retries: int | None = None,
|
||||||
) -> List[GenerationResult]:
|
) -> List[GenerationResult]:
|
||||||
client = OpenAI(
|
client = OpenAI(
|
||||||
api_key=settings.VLLM_API_KEY,
|
api_key=settings.VLLM_API_KEY,
|
||||||
@@ -86,27 +87,50 @@ def generate_vllm(
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def process_item(item, max_retries):
|
def process_item(item, max_retries, max_failure_retries=None):
|
||||||
result = _generate(item)
|
result = _generate(item)
|
||||||
retries = 0
|
retries = 0
|
||||||
|
|
||||||
while retries < max_retries and (
|
while _should_retry(result, retries, max_retries, max_failure_retries):
|
||||||
detect_repeat_token(result.raw)
|
|
||||||
or (
|
|
||||||
len(result.raw) > 50
|
|
||||||
and detect_repeat_token(result.raw, cut_from_end=50)
|
|
||||||
)
|
|
||||||
or result.error
|
|
||||||
):
|
|
||||||
print(
|
|
||||||
f"Detected repeat token or error, retrying generation (attempt {retries + 1})..."
|
|
||||||
)
|
|
||||||
result = _generate(item, temperature=0.3, top_p=0.95)
|
result = _generate(item, temperature=0.3, top_p=0.95)
|
||||||
retries += 1
|
retries += 1
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _should_retry(result, retries, max_retries, max_failure_retries):
|
||||||
|
has_repeat = detect_repeat_token(result.raw) or (
|
||||||
|
len(result.raw) > 50 and detect_repeat_token(result.raw, cut_from_end=50)
|
||||||
|
)
|
||||||
|
|
||||||
|
if retries < max_retries and has_repeat:
|
||||||
|
print(
|
||||||
|
f"Detected repeat token, retrying generation (attempt {retries + 1})..."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
if retries < max_retries and result.error:
|
||||||
|
print(
|
||||||
|
f"Detected vllm error, retrying generation (attempt {retries + 1})..."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
if (
|
||||||
|
result.error
|
||||||
|
and max_failure_retries is not None
|
||||||
|
and retries < max_failure_retries
|
||||||
|
):
|
||||||
|
print(
|
||||||
|
f"Detected vllm error, retrying generation (attempt {retries + 1})..."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
results = list(executor.map(process_item, batch, repeat(max_retries)))
|
results = list(
|
||||||
|
executor.map(
|
||||||
|
process_item, batch, repeat(max_retries), repeat(max_failure_retries)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
Reference in New Issue
Block a user