mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
backend import in child load_model method and expose logfile arg
This commit is contained in:
@@ -30,12 +30,8 @@ class ASRBase:
|
||||
self.transcribe_kargs = {}
|
||||
self.original_language = lan
|
||||
|
||||
self.import_backend()
|
||||
self.model = self.load_model(modelsize, cache_dir, model_dir)
|
||||
|
||||
def import_backend(self):
    """Import the backend library needed by a concrete ASR class.

    The base class has no backend of its own, so this is a placeholder
    that child classes are expected to override.

    Raises:
        NotImplementedError: always, when the method is not overridden.
    """
    # `NotImplemented` is the binary-operator sentinel and is not callable;
    # `NotImplementedError` is the exception meant for abstract methods.
    raise NotImplementedError("must be implemented in the child class")
|
||||
|
||||
def load_model(self, modelsize, cache_dir, model_dir=None):
    """Load and return the backend model; must be overridden by child classes.

    The constructor calls this with ``(modelsize, cache_dir, model_dir)`` and
    stores the result as ``self.model``, so the signature accepts all three
    arguments, matching the child-class overrides.

    Args:
        modelsize: model size/name forwarded from the constructor.
        cache_dir: cache directory forwarded from the constructor.
        model_dir: model directory forwarded from the constructor; optional
            so that existing two-argument calls keep working.

    Raises:
        NotImplementedError: always, when the method is not overridden.
    """
    # `NotImplemented` is the binary-operator sentinel and is not callable;
    # `NotImplementedError` is the exception meant for abstract methods.
    raise NotImplementedError("must be implemented in the child class")
|
||||
|
||||
@@ -52,15 +48,13 @@ class WhisperTimestampedASR(ASRBase):
|
||||
"""
|
||||
|
||||
sep = " "
|
||||
|
||||
def import_backend(self):
|
||||
global whisper, whisper_timestamped
|
||||
import whisper
|
||||
import whisper_timestamped
|
||||
|
||||
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
||||
global whisper_timestamped # has to be global as it is used at each `transcribe` call
|
||||
import whisper
|
||||
import whisper_timestamped
|
||||
if model_dir is not None:
|
||||
print("ignoring model_dir, not implemented",file=self.output)
|
||||
print("ignoring model_dir, not implemented",file=self.logfile)
|
||||
return whisper.load_model(modelsize, download_root=cache_dir)
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
@@ -89,13 +83,10 @@ class FasterWhisperASR(ASRBase):
|
||||
|
||||
sep = ""
|
||||
|
||||
def import_backend(self):
|
||||
global faster_whisper
|
||||
import faster_whisper
|
||||
|
||||
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
||||
from faster_whisper import WhisperModel
|
||||
if model_dir is not None:
|
||||
print(f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.",file=self.output)
|
||||
print(f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.",file=self.logfile)
|
||||
model_size_or_path = model_dir
|
||||
elif modelsize is not None:
|
||||
model_size_or_path = modelsize
|
||||
@@ -143,7 +134,7 @@ class FasterWhisperASR(ASRBase):
|
||||
|
||||
class HypothesisBuffer:
|
||||
|
||||
def __init__(self, output=sys.stderr):
|
||||
def __init__(self, logfile=sys.stderr):
|
||||
"""output: where to store the log. Leave it unchanged to print to terminal."""
|
||||
self.commited_in_buffer = []
|
||||
self.buffer = []
|
||||
@@ -152,7 +143,7 @@ class HypothesisBuffer:
|
||||
self.last_commited_time = 0
|
||||
self.last_commited_word = None
|
||||
|
||||
self.output = output
|
||||
self.logfile = logfile
|
||||
|
||||
def insert(self, new, offset):
|
||||
# Compare self.commited_in_buffer with new. Only the words in new that extend commited_in_buffer are inserted, i.e. words that are roughly behind last_commited_time and new in content.
|
||||
@@ -172,9 +163,9 @@ class HypothesisBuffer:
|
||||
c = " ".join([self.commited_in_buffer[-j][2] for j in range(1,i+1)][::-1])
|
||||
tail = " ".join(self.new[j-1][2] for j in range(1,i+1))
|
||||
if c == tail:
|
||||
print("removing last",i,"words:",file=self.output)
|
||||
print("removing last",i,"words:",file=self.logfile)
|
||||
for j in range(i):
|
||||
print("\t",self.new.pop(0),file=self.output)
|
||||
print("\t",self.new.pop(0),file=self.logfile)
|
||||
break
|
||||
|
||||
def flush(self):
|
||||
@@ -211,14 +202,14 @@ class OnlineASRProcessor:
|
||||
|
||||
SAMPLING_RATE = 16000
|
||||
|
||||
def __init__(self, asr, tokenizer, output=sys.stderr):
|
||||
def __init__(self, asr, tokenizer, logfile=sys.stderr):
|
||||
"""asr: WhisperASR object
|
||||
tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer.
|
||||
logfile: where to store the log. Leave it unchanged to print to terminal.
|
||||
"""
|
||||
self.asr = asr
|
||||
self.tokenizer = tokenizer
|
||||
self.output = output
|
||||
self.logfile = logfile
|
||||
|
||||
self.init()
|
||||
|
||||
@@ -227,7 +218,7 @@ class OnlineASRProcessor:
|
||||
self.audio_buffer = np.array([],dtype=np.float32)
|
||||
self.buffer_time_offset = 0
|
||||
|
||||
self.transcript_buffer = HypothesisBuffer(output=self.output)
|
||||
self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
|
||||
self.commited = []
|
||||
self.last_chunked_at = 0
|
||||
|
||||
@@ -262,9 +253,9 @@ class OnlineASRProcessor:
|
||||
"""
|
||||
|
||||
prompt, non_prompt = self.prompt()
|
||||
print("PROMPT:", prompt, file=self.output)
|
||||
print("CONTEXT:", non_prompt, file=self.output)
|
||||
print(f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}",file=self.output)
|
||||
print("PROMPT:", prompt, file=self.logfile)
|
||||
print("CONTEXT:", non_prompt, file=self.logfile)
|
||||
print(f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}",file=self.logfile)
|
||||
res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
|
||||
|
||||
# transform to [(beg,end,"word1"), ...]
|
||||
@@ -273,8 +264,8 @@ class OnlineASRProcessor:
|
||||
self.transcript_buffer.insert(tsw, self.buffer_time_offset)
|
||||
o = self.transcript_buffer.flush()
|
||||
self.commited.extend(o)
|
||||
print(">>>>COMPLETE NOW:",self.to_flush(o),file=self.output,flush=True)
|
||||
print("INCOMPLETE:",self.to_flush(self.transcript_buffer.complete()),file=self.output,flush=True)
|
||||
print(">>>>COMPLETE NOW:",self.to_flush(o),file=self.logfile,flush=True)
|
||||
print("INCOMPLETE:",self.to_flush(self.transcript_buffer.complete()),file=self.logfile,flush=True)
|
||||
|
||||
# there is a newly confirmed text
|
||||
if o:
|
||||
@@ -293,14 +284,14 @@ class OnlineASRProcessor:
|
||||
# elif self.transcript_buffer.complete():
|
||||
# self.silence_iters = 0
|
||||
# elif not self.transcript_buffer.complete():
|
||||
# # print("NOT COMPLETE:",to_flush(self.transcript_buffer.complete()),file=self.output,flush=True)
|
||||
# # print("NOT COMPLETE:",to_flush(self.transcript_buffer.complete()),file=self.logfile,flush=True)
|
||||
# self.silence_iters += 1
|
||||
# if self.silence_iters >= 3:
|
||||
# n = self.last_chunked_at
|
||||
## self.chunk_completed_sentence()
|
||||
## if n == self.last_chunked_at:
|
||||
# self.chunk_at(self.last_chunked_at+self.chunk)
|
||||
# print(f"\tCHUNK: 3-times silence! chunk_at {n}+{self.chunk}",file=self.output)
|
||||
# print(f"\tCHUNK: 3-times silence! chunk_at {n}+{self.chunk}",file=self.logfile)
|
||||
## self.silence_iters = 0
|
||||
|
||||
|
||||
@@ -316,18 +307,18 @@ class OnlineASRProcessor:
|
||||
#while k>0 and self.commited[k][1] > l:
|
||||
# k -= 1
|
||||
#t = self.commited[k][1]
|
||||
print(f"chunking because of len",file=self.output)
|
||||
print(f"chunking because of len",file=self.logfile)
|
||||
#self.chunk_at(t)
|
||||
|
||||
print(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}",file=self.output)
|
||||
print(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}",file=self.logfile)
|
||||
return self.to_flush(o)
|
||||
|
||||
def chunk_completed_sentence(self):
|
||||
if self.commited == []: return
|
||||
print(self.commited,file=self.output)
|
||||
print(self.commited,file=self.logfile)
|
||||
sents = self.words_to_sentences(self.commited)
|
||||
for s in sents:
|
||||
print("\t\tSENT:",s,file=self.output)
|
||||
print("\t\tSENT:",s,file=self.logfile)
|
||||
if len(sents) < 2:
|
||||
return
|
||||
while len(sents) > 2:
|
||||
@@ -335,7 +326,7 @@ class OnlineASRProcessor:
|
||||
# we will continue with audio processing at this timestamp
|
||||
chunk_at = sents[-2][1]
|
||||
|
||||
print(f"--- sentence chunked at {chunk_at:2.2f}",file=self.output)
|
||||
print(f"--- sentence chunked at {chunk_at:2.2f}",file=self.logfile)
|
||||
self.chunk_at(chunk_at)
|
||||
|
||||
def chunk_completed_segment(self, res):
|
||||
@@ -352,12 +343,12 @@ class OnlineASRProcessor:
|
||||
ends.pop(-1)
|
||||
e = ends[-2]+self.buffer_time_offset
|
||||
if e <= t:
|
||||
print(f"--- segment chunked at {e:2.2f}",file=self.output)
|
||||
print(f"--- segment chunked at {e:2.2f}",file=self.logfile)
|
||||
self.chunk_at(e)
|
||||
else:
|
||||
print(f"--- last segment not within commited area",file=self.output)
|
||||
print(f"--- last segment not within commited area",file=self.logfile)
|
||||
else:
|
||||
print(f"--- not enough segments to chunk",file=self.output)
|
||||
print(f"--- not enough segments to chunk",file=self.logfile)
|
||||
|
||||
|
||||
|
||||
@@ -403,7 +394,7 @@ class OnlineASRProcessor:
|
||||
"""
|
||||
o = self.transcript_buffer.complete()
|
||||
f = self.to_flush(o)
|
||||
print("last, noncommited:",f,file=self.output)
|
||||
print("last, noncommited:",f,file=self.logfile)
|
||||
return f
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user