Merge remote-tracking branch 'contrib/fix-sentencesegmenter'

This commit is contained in:
Quentin Fuxa
2025-01-26 15:34:41 +01:00
3 changed files with 54 additions and 30 deletions

3
.gitignore vendored
View File

@@ -127,3 +127,6 @@ dmypy.json
# Pyre type checker
.pyre/
*.wav
run_*.sh

View File

@@ -87,11 +87,20 @@ class OnlineASRProcessor:
buffer_trimming=("segment", 15),
logfile=sys.stderr,
):
"""asr: WhisperASR object
tokenize_method: sentence tokenizer function for the target language. Must be a callable that behaves like MosesTokenizer's. It can be None; if the "segment" buffer trimming option is used, the tokenizer is not used at all.
("segment", 15)
buffer_trimming: a pair of (option, seconds), where option is either "sentence" or "segment", and seconds is a number. Buffer is trimmed if it is longer than "seconds" threshold. Default is the most recommended option.
logfile: where to store the log.
"""
Initialize OnlineASRProcessor.
Args:
asr: WhisperASR object
tokenize_method: Sentence tokenizer function for the target language.
Must be a function that takes a list of texts as input, like MosesSentenceSplitter.
Can be None if using "segment" buffer trimming option.
buffer_trimming: Tuple of (option, seconds) where:
- option: Either "sentence" or "segment"
- seconds: Number of seconds threshold for buffer trimming
Default is ("segment", 15)
logfile: File to store logs
"""
self.asr = asr
self.tokenize = tokenize_method
@@ -142,7 +151,7 @@ class OnlineASRProcessor:
"""
prompt, non_prompt = self.prompt()
logger.debug(f"PROMPT: {prompt}")
logger.debug(f"PROMPT(previous): {prompt}")
logger.debug(f"CONTEXT: {non_prompt}")
logger.debug(
f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}"
@@ -199,22 +208,27 @@ class OnlineASRProcessor:
def chunk_completed_sentence(self):
    """Trim the audio buffer at the end of the penultimate committed sentence.

    Groups the committed words into sentences via ``words_to_sentences`` and,
    when at least two full sentences are available, chunks the buffer at the
    end of the second-to-last one so the still-growing last sentence is kept.
    No-op when nothing is committed or fewer than two sentences were found.
    """
    if self.commited == []:
        return
    # NOTE(review): a leftover `import pdb; pdb.set_trace()` + unconditional
    # `return` was removed here — it made everything below unreachable.
    raw_text = self.asr.sep.join([s[2] for s in self.commited])
    logger.debug(f"COMPLETED SENTENCE: {raw_text}")
    sents = self.words_to_sentences(self.commited)
    for s in sents:
        logger.debug(f"\t\tSENT: {s}")
    if len(sents) < 2:
        logger.debug(f"[Sentence-segmentation] no sentence segmented.")
        return
    # keep only the last two sentences; everything before them gets chunked away
    while len(sents) > 2:
        sents.pop(0)
    identified_sentence = "\n - ".join([f"{s[0]*1000:.0f}-{s[1]*1000:.0f} {s[2]}" for s in sents])
    logger.debug(f"[Sentence-segmentation] identified sentences:\n - {identified_sentence}")
    # we will continue with audio processing at this timestamp
    chunk_at = sents[-2][1]
    logger.debug(f"[Sentence-segmentation]: sentence will be chunked at {chunk_at:2.2f}")
    self.chunk_at(chunk_at)
def chunk_completed_segment(self, res):
@@ -253,7 +267,8 @@ class OnlineASRProcessor:
cwords = [w for w in words]
t = self.asr.sep.join(o[2] for o in cwords)
s = self.tokenize(t)
logger.debug(f"[Sentence-segmentation] Raw Text: {t}")
s = self.tokenize([t])
out = []
while s:
beg = None
@@ -278,7 +293,7 @@ class OnlineASRProcessor:
"""
o = self.transcript_buffer.complete()
f = self.to_flush(o)
logger.debug(f"last, noncommited: {f[0]*1000:.0f}-{f[1]*1000:.0f}: {f[2]}")
logger.debug(f"last, noncommited: {f[0]*1000:.0f}-{f[1]*1000:.0f}: {f[2][0]*1000:.0f}-{f[1]*1000:.0f}: {f[2]}")
self.buffer_time_offset += len(self.audio_buffer) / 16000
return f

View File

@@ -49,9 +49,9 @@ def create_tokenizer(lan):
lan
in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
):
from mosestokenizer import MosesTokenizer
from mosestokenizer import MosesSentenceSplitter
return MosesTokenizer(lan)
return MosesSentenceSplitter(lan)
# the following languages are in Whisper, but not in wtpsplit:
if (
@@ -204,6 +204,7 @@ def backend_factory(args):
# Create the tokenizer
if args.buffer_trimming == "sentence":
tokenizer = create_tokenizer(tgt_language)
else:
tokenizer = None
@@ -235,10 +236,12 @@ def asr_factory(args, logfile=sys.stderr):
online = online_factory(args, asr, tokenizer, logfile=logfile)
return asr, online
def set_logging(args, logger, others=()):
    """Configure log format and level for this logger and named sibling loggers.

    Args:
        args: Parsed CLI arguments; only ``args.log_level`` is read.
        logger: The main logger to configure.
        others: Iterable of additional logger names that should share the same
            level. Defaults to an empty tuple — a tuple rather than ``[]`` to
            avoid the mutable-default-argument pitfall; callers passing a list
            are unaffected.
    """
    logging.basicConfig(format="%(levelname)s\t%(message)s")  # format='%(name)s
    logger.setLevel(args.log_level)
    for other in others:
        logging.getLogger(other).setLevel(args.log_level)
    # logging.getLogger("whisper_online_server").setLevel(args.log_level)
@@ -275,7 +278,7 @@ if __name__ == "__main__":
args = parser.parse_args()
# reset to store stderr to different file stream, e.g. open(os.devnull,"w")
logfile = sys.stderr
logfile = None # sys.stderr
if args.offline and args.comp_unaware:
logger.error(
@@ -287,7 +290,7 @@ if __name__ == "__main__":
# logging.basicConfig(format='whisper-%(levelname)s:%(name)s: %(message)s',
# level=getattr(logging, args.log_level))
set_logging(args, logger)
set_logging(args, logger,others=["src.whisper_streaming.online_asr"])
audio_path = args.audio_path
@@ -320,15 +323,18 @@ if __name__ == "__main__":
if now is None:
now = time.time() - start
if o[0] is not None:
print(
"%1.4f %1.0f %1.0f %s" % (now * 1000, o[0] * 1000, o[1] * 1000, o[2]),
file=logfile,
flush=True,
)
print(
"%1.4f %1.0f %1.0f %s" % (now * 1000, o[0] * 1000, o[1] * 1000, o[2]),
flush=True,
log_string = f"{now*1000:1.0f}, {o[0]*1000:1.0f}-{o[1]*1000:1.0f} ({(now-o[1]):+1.0f}s): {o[2]}"
logger.debug(
log_string
)
if logfile is not None:
print(
log_string,
file=logfile,
flush=True,
)
else:
# No text, so no output
pass