diff --git a/whisper_online.py b/whisper_online.py
index 58dbe8b..51a9ab6 100644
--- a/whisper_online.py
+++ b/whisper_online.py
@@ -30,12 +30,8 @@ class ASRBase:
         self.transcribe_kargs = {}
         self.original_language = lan 
 
-        self.import_backend()
         self.model = self.load_model(modelsize, cache_dir, model_dir)
 
-    def import_backend(self):
-        raise NotImplemented("must be implemented in the child class")
-
     def load_model(self, modelsize, cache_dir):
         raise NotImplemented("must be implemented in the child class")
 
@@ -52,15 +48,13 @@ class WhisperTimestampedASR(ASRBase):
     """
 
     sep = " "
-
-    def import_backend(self):
-        global whisper, whisper_timestamped
-        import whisper
-        import whisper_timestamped
 
     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
+        global whisper_timestamped # has to be global as it is used at each `transcribe` call
+        import whisper
+        import whisper_timestamped
         if model_dir is not None:
-            print("ignoring model_dir, not implemented",file=self.output)
+            print("ignoring model_dir, not implemented",file=self.logfile)
         return whisper.load_model(modelsize, download_root=cache_dir)
 
     def transcribe(self, audio, init_prompt=""):
@@ -89,13 +83,10 @@ class FasterWhisperASR(ASRBase):
 
     sep = ""
 
-    def import_backend(self):
-        global faster_whisper
-        import faster_whisper
-
     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
+        from faster_whisper import WhisperModel
         if model_dir is not None:
-            print(f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.",file=self.output)
+            print(f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.",file=self.logfile)
             model_size_or_path = model_dir
         elif modelsize is not None:
             model_size_or_path = modelsize
@@ -143,7 +134,7 @@ class FasterWhisperASR(ASRBase):
 
 class HypothesisBuffer:
 
-    def __init__(self, output=sys.stderr):
+    def __init__(self, logfile=sys.stderr):
         """output: where to store the log. Leave it unchanged to print to terminal."""
         self.commited_in_buffer = []
         self.buffer = []
@@ -152,7 +143,7 @@ class HypothesisBuffer:
         self.last_commited_time = 0
         self.last_commited_word = None
 
-        self.output = output
+        self.logfile = logfile
 
     def insert(self, new, offset):
         # compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
@@ -172,9 +163,9 @@ class HypothesisBuffer:
                         c = " ".join([self.commited_in_buffer[-j][2] for j in range(1,i+1)][::-1])
                         tail = " ".join(self.new[j-1][2] for j in range(1,i+1))
                         if c == tail:
-                            print("removing last",i,"words:",file=self.output)
+                            print("removing last",i,"words:",file=self.logfile)
                             for j in range(i):
-                                print("\t",self.new.pop(0),file=self.output)
+                                print("\t",self.new.pop(0),file=self.logfile)
                             break
 
     def flush(self):
@@ -211,14 +202,14 @@ class OnlineASRProcessor:
 
     SAMPLING_RATE = 16000
 
-    def __init__(self, asr, tokenizer, output=sys.stderr):
+    def __init__(self, asr, tokenizer, logfile=sys.stderr):
         """asr: WhisperASR object
         tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer.
         output: where to store the log. Leave it unchanged to print to terminal. 
 
         """
         self.asr = asr
         self.tokenizer = tokenizer
-        self.output = output
+        self.logfile = logfile
 
         self.init()
@@ -227,7 +218,7 @@ class OnlineASRProcessor:
         self.audio_buffer = np.array([],dtype=np.float32)
         self.buffer_time_offset = 0
 
-        self.transcript_buffer = HypothesisBuffer(output=self.output)
+        self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
         self.commited = []
         self.last_chunked_at = 0
 
@@ -262,9 +253,9 @@ class OnlineASRProcessor:
         """
 
         prompt, non_prompt = self.prompt()
-        print("PROMPT:", prompt, file=self.output)
-        print("CONTEXT:", non_prompt, file=self.output)
-        print(f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}",file=self.output)
+        print("PROMPT:", prompt, file=self.logfile)
+        print("CONTEXT:", non_prompt, file=self.logfile)
+        print(f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}",file=self.logfile)
         res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
 
         # transform to [(beg,end,"word1"), ...]
@@ -273,8 +264,8 @@ class OnlineASRProcessor:
         self.transcript_buffer.insert(tsw, self.buffer_time_offset)
         o = self.transcript_buffer.flush()
         self.commited.extend(o)
-        print(">>>>COMPLETE NOW:",self.to_flush(o),file=self.output,flush=True)
-        print("INCOMPLETE:",self.to_flush(self.transcript_buffer.complete()),file=self.output,flush=True)
+        print(">>>>COMPLETE NOW:",self.to_flush(o),file=self.logfile,flush=True)
+        print("INCOMPLETE:",self.to_flush(self.transcript_buffer.complete()),file=self.logfile,flush=True)
 
         # there is a newly confirmed text
         if o:
@@ -293,14 +284,14 @@ class OnlineASRProcessor:
 
 #        elif self.transcript_buffer.complete():
 #            self.silence_iters = 0
 #        elif not self.transcript_buffer.complete():
-#            # print("NOT COMPLETE:",to_flush(self.transcript_buffer.complete()),file=self.output,flush=True)
+#            # print("NOT COMPLETE:",to_flush(self.transcript_buffer.complete()),file=self.logfile,flush=True)
 #            self.silence_iters += 1
 #            if self.silence_iters >= 3:
 #                n = self.last_chunked_at
 ##                self.chunk_completed_sentence()
 ##                if n == self.last_chunked_at:
 #                self.chunk_at(self.last_chunked_at+self.chunk)
-#                print(f"\tCHUNK: 3-times silence! chunk_at {n}+{self.chunk}",file=self.output)
+#                print(f"\tCHUNK: 3-times silence! chunk_at {n}+{self.chunk}",file=self.logfile)
 ##                self.silence_iters = 0
 
@@ -316,18 +307,18 @@ class OnlineASRProcessor:
             #while k>0 and self.commited[k][1] > l:
             #    k -= 1
             #t = self.commited[k][1]
-            print(f"chunking because of len",file=self.output)
+            print(f"chunking because of len",file=self.logfile)
             #self.chunk_at(t)
 
-        print(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}",file=self.output)
+        print(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}",file=self.logfile)
         return self.to_flush(o)
 
     def chunk_completed_sentence(self):
         if self.commited == []: return
-        print(self.commited,file=self.output)
+        print(self.commited,file=self.logfile)
         sents = self.words_to_sentences(self.commited)
         for s in sents:
-            print("\t\tSENT:",s,file=self.output)
+            print("\t\tSENT:",s,file=self.logfile)
         if len(sents) < 2:
             return
         while len(sents) > 2:
@@ -335,7 +326,7 @@ class OnlineASRProcessor:
 
         # we will continue with audio processing at this timestamp
         chunk_at = sents[-2][1]
-        print(f"--- sentence chunked at {chunk_at:2.2f}",file=self.output)
+        print(f"--- sentence chunked at {chunk_at:2.2f}",file=self.logfile)
         self.chunk_at(chunk_at)
 
     def chunk_completed_segment(self, res):
@@ -352,12 +343,12 @@ class OnlineASRProcessor:
                 ends.pop(-1)
                 e = ends[-2]+self.buffer_time_offset
             if e <= t:
-                print(f"--- segment chunked at {e:2.2f}",file=self.output)
+                print(f"--- segment chunked at {e:2.2f}",file=self.logfile)
                 self.chunk_at(e)
             else:
-                print(f"--- last segment not within commited area",file=self.output)
+                print(f"--- last segment not within commited area",file=self.logfile)
         else:
-            print(f"--- not enough segments to chunk",file=self.output)
+            print(f"--- not enough segments to chunk",file=self.logfile)
 
 
 
@@ -403,7 +394,7 @@ class OnlineASRProcessor:
         """
         o = self.transcript_buffer.complete()
         f = self.to_flush(o)
-        print("last, noncommited:",f,file=self.output)
+        print("last, noncommited:",f,file=self.logfile)
        return f
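
For reference, a minimal usage sketch of the refactored interface (not part of the patch). It assumes the FasterWhisperASR constructor takes the language and model size positionally, as the body of ASRBase.__init__ above suggests, and that the feeding methods insert_audio_chunk(), process_iter() and finish() keep the names used elsewhere in whisper_online.py; only the lazy backend import inside load_model() and the renamed logfile= parameter come directly from this diff.

# Usage sketch under the assumptions stated above.
import numpy as np
from whisper_online import FasterWhisperASR, OnlineASRProcessor

class DummyTokenizer:
    # stand-in for a MosesTokenizer-like object with a *split* method (see the __init__ docstring)
    def split(self, text):
        return [s for s in text.split(". ") if s]

logfile = open("online_asr.log", "w")           # log output now goes wherever `logfile` points
asr = FasterWhisperASR("en", "large-v2")        # backend is imported lazily inside load_model()
online = OnlineASRProcessor(asr, DummyTokenizer(), logfile=logfile)  # keyword renamed from output=

audio = np.zeros(OnlineASRProcessor.SAMPLING_RATE, dtype=np.float32)  # 1 s of 16 kHz float32 audio
online.insert_audio_chunk(audio)   # assumed feeder method from the rest of the file
print(online.process_iter())       # assumed name; emits the newly committed (beg, end, "text")
print(online.finish())             # assumed name; flushes the remaining non-committed words
logfile.close()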