Add an optional max_failure_retries argument to generate_vllm.

process_item now delegates its retry decision to a new nested helper,
_should_retry(result, retries, max_retries, max_failure_retries):
  * repeat-token output is retried while retries < max_retries;
  * a result carrying an error is retried while retries < max_retries, or,
    when max_failure_retries is not None, while retries < max_failure_retries
    (so a max_failure_retries lower than max_retries has no effect);
  * otherwise no retry happens.
Retries still call _generate with temperature=0.3 and top_p=0.95, and the new
argument is forwarded to every worker via repeat(max_failure_retries).

diff --git a/chandra/model/vllm.py b/chandra/model/vllm.py
index 1aabf69..a74b852 100644
--- a/chandra/model/vllm.py
+++ b/chandra/model/vllm.py
@@ -26,6 +26,7 @@ def generate_vllm(
     max_retries: int = None,
     max_workers: int | None = None,
     custom_headers: dict | None = None,
+    max_failure_retries: int | None = None,
 ) -> List[GenerationResult]:
     client = OpenAI(
         api_key=settings.VLLM_API_KEY,
@@ -86,27 +87,50 @@ def generate_vllm(
 
         return result
 
-    def process_item(item, max_retries):
+    def process_item(item, max_retries, max_failure_retries=None):
         result = _generate(item)
         retries = 0
 
-        while retries < max_retries and (
-            detect_repeat_token(result.raw)
-            or (
-                len(result.raw) > 50
-                and detect_repeat_token(result.raw, cut_from_end=50)
-            )
-            or result.error
-        ):
-            print(
-                f"Detected repeat token or error, retrying generation (attempt {retries + 1})..."
-            )
+        while _should_retry(result, retries, max_retries, max_failure_retries):
             result = _generate(item, temperature=0.3, top_p=0.95)
             retries += 1
 
         return result
 
+    def _should_retry(result, retries, max_retries, max_failure_retries):
+        has_repeat = detect_repeat_token(result.raw) or (
+            len(result.raw) > 50 and detect_repeat_token(result.raw, cut_from_end=50)
+        )
+
+        if retries < max_retries and has_repeat:
+            print(
+                f"Detected repeat token, retrying generation (attempt {retries + 1})..."
+            )
+            return True
+
+        if retries < max_retries and result.error:
+            print(
+                f"Detected vllm error, retrying generation (attempt {retries + 1})..."
+            )
+            return True
+
+        if (
+            result.error
+            and max_failure_retries is not None
+            and retries < max_failure_retries
+        ):
+            print(
+                f"Detected vllm error, retrying generation (attempt {retries + 1})..."
+            )
+            return True
+
+        return False
+
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        results = list(executor.map(process_item, batch, repeat(max_retries)))
+        results = list(
+            executor.map(
+                process_item, batch, repeat(max_retries), repeat(max_failure_retries)
+            )
+        )
 
     return results