backend import in child load_model method and expose logfile arg

This commit is contained in:
Luca
2023-11-03 11:33:03 +01:00
parent f97a253273
commit 18c1434f77

View File

@@ -30,12 +30,8 @@ class ASRBase:
self.transcribe_kargs = {}
self.original_language = lan
self.import_backend()
self.model = self.load_model(modelsize, cache_dir, model_dir)
def import_backend(self):
    """Import the backend ASR library.

    Abstract hook: each concrete subclass imports its own backend
    module(s) here (e.g. whisper_timestamped, faster_whisper).

    Raises:
        NotImplementedError: always, in this base class.
    """
    # Bug fix: `NotImplemented` is a non-callable comparison sentinel,
    # not an exception — `raise NotImplemented(...)` produces a confusing
    # TypeError. `NotImplementedError` is the correct exception type.
    raise NotImplementedError("must be implemented in the child class")
def load_model(self, modelsize, cache_dir, model_dir=None):
    """Load and return the backend ASR model object.

    Abstract hook for subclasses. The signature now matches the call
    site in __init__ and the child overrides, which all pass
    (modelsize, cache_dir, model_dir); `model_dir` defaults to None so
    any existing two-argument callers keep working.

    Raises:
        NotImplementedError: always, in this base class.
    """
    # Bug fix: `NotImplemented` is a non-callable sentinel value, so the
    # original `raise NotImplemented(...)` raised TypeError instead of
    # the intended NotImplementedError.
    raise NotImplementedError("must be implemented in the child class")
@@ -52,15 +48,13 @@ class WhisperTimestampedASR(ASRBase):
"""
sep = " "
def import_backend(self):
    """Import the whisper and whisper_timestamped backend modules.

    The names are declared global so the imported modules stay bound at
    module level after this method returns; whisper_timestamped is used
    again at each `transcribe` call (per the note in load_model), and
    `whisper.load_model` is called later through the same global binding.
    """
    global whisper, whisper_timestamped
    import whisper
    import whisper_timestamped
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
global whisper_timestamped # has to be global as it is used at each `transcribe` call
import whisper
import whisper_timestamped
if model_dir is not None:
print("ignoring model_dir, not implemented",file=self.output)
print("ignoring model_dir, not implemented",file=self.logfile)
return whisper.load_model(modelsize, download_root=cache_dir)
def transcribe(self, audio, init_prompt=""):
@@ -89,13 +83,10 @@ class FasterWhisperASR(ASRBase):
sep = ""
def import_backend(self):
    """Import the faster_whisper backend module.

    Declared global so the module remains bound at module level after
    this method returns and is usable by the rest of the class.
    """
    global faster_whisper
    import faster_whisper
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
from faster_whisper import WhisperModel
if model_dir is not None:
print(f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.",file=self.output)
print(f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.",file=self.logfile)
model_size_or_path = model_dir
elif modelsize is not None:
model_size_or_path = modelsize
@@ -143,7 +134,7 @@ class FasterWhisperASR(ASRBase):
class HypothesisBuffer:
def __init__(self, output=sys.stderr):
def __init__(self, logfile=sys.stderr):
"""output: where to store the log. Leave it unchanged to print to terminal."""
self.commited_in_buffer = []
self.buffer = []
@@ -152,7 +143,7 @@ class HypothesisBuffer:
self.last_commited_time = 0
self.last_commited_word = None
self.output = output
self.logfile = logfile
def insert(self, new, offset):
# compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
@@ -172,9 +163,9 @@ class HypothesisBuffer:
c = " ".join([self.commited_in_buffer[-j][2] for j in range(1,i+1)][::-1])
tail = " ".join(self.new[j-1][2] for j in range(1,i+1))
if c == tail:
print("removing last",i,"words:",file=self.output)
print("removing last",i,"words:",file=self.logfile)
for j in range(i):
print("\t",self.new.pop(0),file=self.output)
print("\t",self.new.pop(0),file=self.logfile)
break
def flush(self):
@@ -211,14 +202,14 @@ class OnlineASRProcessor:
SAMPLING_RATE = 16000
def __init__(self, asr, tokenizer, output=sys.stderr):
def __init__(self, asr, tokenizer, logfile=sys.stderr):
"""asr: WhisperASR object
tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer.
logfile: where to store the log. Leave it unchanged to print to terminal.
"""
self.asr = asr
self.tokenizer = tokenizer
self.output = output
self.logfile = logfile
self.init()
@@ -227,7 +218,7 @@ class OnlineASRProcessor:
self.audio_buffer = np.array([],dtype=np.float32)
self.buffer_time_offset = 0
self.transcript_buffer = HypothesisBuffer(output=self.output)
self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
self.commited = []
self.last_chunked_at = 0
@@ -262,9 +253,9 @@ class OnlineASRProcessor:
"""
prompt, non_prompt = self.prompt()
print("PROMPT:", prompt, file=self.output)
print("CONTEXT:", non_prompt, file=self.output)
print(f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}",file=self.output)
print("PROMPT:", prompt, file=self.logfile)
print("CONTEXT:", non_prompt, file=self.logfile)
print(f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}",file=self.logfile)
res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
# transform to [(beg,end,"word1"), ...]
@@ -273,8 +264,8 @@ class OnlineASRProcessor:
self.transcript_buffer.insert(tsw, self.buffer_time_offset)
o = self.transcript_buffer.flush()
self.commited.extend(o)
print(">>>>COMPLETE NOW:",self.to_flush(o),file=self.output,flush=True)
print("INCOMPLETE:",self.to_flush(self.transcript_buffer.complete()),file=self.output,flush=True)
print(">>>>COMPLETE NOW:",self.to_flush(o),file=self.logfile,flush=True)
print("INCOMPLETE:",self.to_flush(self.transcript_buffer.complete()),file=self.logfile,flush=True)
# there is a newly confirmed text
if o:
@@ -293,14 +284,14 @@ class OnlineASRProcessor:
# elif self.transcript_buffer.complete():
# self.silence_iters = 0
# elif not self.transcript_buffer.complete():
# # print("NOT COMPLETE:",to_flush(self.transcript_buffer.complete()),file=self.output,flush=True)
# # print("NOT COMPLETE:",to_flush(self.transcript_buffer.complete()),file=self.logfile,flush=True)
# self.silence_iters += 1
# if self.silence_iters >= 3:
# n = self.last_chunked_at
## self.chunk_completed_sentence()
## if n == self.last_chunked_at:
# self.chunk_at(self.last_chunked_at+self.chunk)
# print(f"\tCHUNK: 3-times silence! chunk_at {n}+{self.chunk}",file=self.output)
# print(f"\tCHUNK: 3-times silence! chunk_at {n}+{self.chunk}",file=self.logfile)
## self.silence_iters = 0
@@ -316,18 +307,18 @@ class OnlineASRProcessor:
#while k>0 and self.commited[k][1] > l:
# k -= 1
#t = self.commited[k][1]
print(f"chunking because of len",file=self.output)
print(f"chunking because of len",file=self.logfile)
#self.chunk_at(t)
print(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}",file=self.output)
print(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}",file=self.logfile)
return self.to_flush(o)
def chunk_completed_sentence(self):
if self.commited == []: return
print(self.commited,file=self.output)
print(self.commited,file=self.logfile)
sents = self.words_to_sentences(self.commited)
for s in sents:
print("\t\tSENT:",s,file=self.output)
print("\t\tSENT:",s,file=self.logfile)
if len(sents) < 2:
return
while len(sents) > 2:
@@ -335,7 +326,7 @@ class OnlineASRProcessor:
# we will continue with audio processing at this timestamp
chunk_at = sents[-2][1]
print(f"--- sentence chunked at {chunk_at:2.2f}",file=self.output)
print(f"--- sentence chunked at {chunk_at:2.2f}",file=self.logfile)
self.chunk_at(chunk_at)
def chunk_completed_segment(self, res):
@@ -352,12 +343,12 @@ class OnlineASRProcessor:
ends.pop(-1)
e = ends[-2]+self.buffer_time_offset
if e <= t:
print(f"--- segment chunked at {e:2.2f}",file=self.output)
print(f"--- segment chunked at {e:2.2f}",file=self.logfile)
self.chunk_at(e)
else:
print(f"--- last segment not within commited area",file=self.output)
print(f"--- last segment not within commited area",file=self.logfile)
else:
print(f"--- not enough segments to chunk",file=self.output)
print(f"--- not enough segments to chunk",file=self.logfile)
@@ -403,7 +394,7 @@ class OnlineASRProcessor:
"""
o = self.transcript_buffer.complete()
f = self.to_flush(o)
print("last, noncommited:",f,file=self.output)
print("last, noncommited:",f,file=self.logfile)
return f