diff --git a/whisperlivekit/results_formater.py b/whisperlivekit/results_formater.py index 1556ac9..577a1c5 100644 --- a/whisperlivekit/results_formater.py +++ b/whisperlivekit/results_formater.py @@ -6,11 +6,10 @@ from whisperlivekit.timed_objects import Line, format_time logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'} CHECK_AROUND = 4 def is_punctuation(token): - if token.text.strip() in PUNCTUATION_MARKS: + if token.is_punctuation(): return True return False diff --git a/whisperlivekit/timed_objects.py b/whisperlivekit/timed_objects.py index c3954ea..722673f 100644 --- a/whisperlivekit/timed_objects.py +++ b/whisperlivekit/timed_objects.py @@ -2,6 +2,8 @@ from dataclasses import dataclass, field from typing import Optional, Any from datetime import timedelta +PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'} + def format_time(seconds: float) -> str: """Format seconds as HH:MM:SS.""" return str(timedelta(seconds=int(seconds))) @@ -16,6 +18,9 @@ class TimedText: probability: Optional[float] = None is_dummy: Optional[bool] = False + def is_punctuation(self): + return self.text.strip() in PUNCTUATION_MARKS + def overlaps_with(self, other: 'TimedText') -> bool: return not (self.end <= other.start or other.end <= self.start) diff --git a/whisperlivekit/translation/translation.py b/whisperlivekit/translation/translation.py index c08f190..bb144c8 100644 --- a/whisperlivekit/translation/translation.py +++ b/whisperlivekit/translation/translation.py @@ -12,8 +12,6 @@ logger = logging.getLogger(__name__) #In diarization case, we may want to translate just one speaker, or at least start the sentences there -PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'} - MIN_SILENCE_DURATION_DEL_BUFFER = 3 #After a silence of x seconds, we consider the model should not use the buffer, even if the previous # sentence is not finished. @@ -111,7 +109,7 @@ class OnlineTranslation: if len(self.buffer) < self.len_processed_buffer + 3: #nothing new to process return self.validated + [self.translation_remaining] while i < len(self.buffer): - if self.buffer[i].text in PUNCTUATION_MARKS: + if self.buffer[i].is_punctuation(): translation_sentence = self.translate_tokens(self.buffer[:i+1]) self.validated.append(translation_sentence) self.buffer = self.buffer[i+1:]