diff --git a/whisper_online.py b/whisper_online.py index 0f743e0..26fe6db 100644 --- a/whisper_online.py +++ b/whisper_online.py @@ -4,7 +4,7 @@ import numpy as np import librosa from functools import lru_cache import time -from mosestokenizer import MosesTokenizer + @lru_cache @@ -207,14 +207,12 @@ class OnlineASRProcessor: SAMPLING_RATE = 16000 - def __init__(self, language, asr): - """language: lang. code that MosesTokenizer uses for sentence segmentation - asr: WhisperASR object - chunk: number of seconds for intended size of audio interval that is inserted and looped + def __init__(self, asr, tokenizer): + """asr: WhisperASR object + tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer. """ - self.language = language self.asr = asr - self.tokenizer = MosesTokenizer(self.language) + self.tokenizer = tokenizer self.init() @@ -369,7 +367,7 @@ class OnlineASRProcessor: self.last_chunked_at = time def words_to_sentences(self, words): - """Uses mosestokenizer for sentence segmentation of words. + """Uses self.tokenizer for sentence segmentation of words. Returns: [(beg,end,"sentence 1"),...] """ @@ -419,6 +417,15 @@ class OnlineASRProcessor: return (b,e,t) +def create_tokenizer(lan): + if lan == "uk": + import tokenize_uk + class UkrainianTokenizer: + def split(self, text): + return tokenize_uk.tokenize_sents(text) + return UkrainianTokenizer() + from mosestokenizer import MosesTokenizer + return MosesTokenizer(lan) ## main: @@ -482,8 +489,9 @@ if __name__ == "__main__": print("setting VAD filter",file=sys.stderr) asr.use_vad() + min_chunk = args.min_chunk_size - online = OnlineASRProcessor(tgt_language,asr) + online = OnlineASRProcessor(asr,create_tokenizer(tgt_language)) # load the audio into the LRU cache before we start the timer diff --git a/whisper_online_server.py b/whisper_online_server.py index 3e2a569..2df13a2 100644 --- a/whisper_online_server.py +++ b/whisper_online_server.py @@ -48,6 +48,9 @@ asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, mode if args.task == "translate": asr.set_translate_task() + tgt_language = "en" +else: + tgt_language = language e = time.time() print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr) @@ -58,7 +61,7 @@ if args.vad: min_chunk = args.min_chunk_size -online = OnlineASRProcessor(language,asr) +online = OnlineASRProcessor(asr,create_tokenizer(tgt_language))