Merge pull request #27 from SilasK/fix-sentencesegmenter

Fix sentence segmenter
This commit is contained in:
Quentin Fuxa
2025-01-31 22:54:33 +01:00
committed by GitHub
5 changed files with 392 additions and 257 deletions

0
src/__init__.py Normal file
View File

View File

@@ -17,8 +17,10 @@ class HypothesisBuffer:
self.logfile = logfile
def insert(self, new, offset):
# compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
# the new tail is added to self.new
"""
compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
The new tail is added to self.new
"""
new = [(a + offset, b + offset, t) for a, b, t in new]
self.new = [(a, b, t) for a, b, t in new if a > self.last_commited_time - 0.1]
@@ -77,6 +79,9 @@ class HypothesisBuffer:
return self.buffer
class OnlineASRProcessor:
SAMPLING_RATE = 16000
@@ -128,7 +133,9 @@ class OnlineASRProcessor:
if offset is not None:
self.buffer_time_offset = offset
self.transcript_buffer.last_commited_time = self.buffer_time_offset
self.commited = []
self.final_transcript = []
self.commited_not_final = []
def insert_audio_chunk(self, audio):
self.audio_buffer = np.append(self.audio_buffer, audio)
@@ -136,23 +143,42 @@ class OnlineASRProcessor:
def prompt(self):
"""Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer.
"context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons.
"""
k = max(0, len(self.commited) - 1)
while k > 0 and self.commited[k - 1][1] > self.buffer_time_offset:
k -= 1
p = self.commited[:k]
p = [t for _, _, t in p]
prompt = []
l = 0
while p and l < 200: # 200 characters prompt size
x = p.pop(-1)
l += len(x) + 1
prompt.append(x)
non_prompt = self.commited[k:]
return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(
t for _, _, t in non_prompt
)
"""
if len(self.final_transcript) == 0:
prompt=""
if len(self.final_transcript) == 1:
prompt = self.final_transcript[0][2][-200:]
else:
prompt = self.concatenate_tsw(self.final_transcript)[2][-200:]
# TODO: this is not ideal as we concatenate each time the whole transcript
# k = max(0, len(self.final_transcript) - 1)
# while k > 1 and self.final_transcript[k - 1][1] > self.buffer_time_offset:
# k -= 1
# p = self.final_transcript[:k]
# p = [t for _, _, t in p]
# prompt = []
# l = 0
# while p and l < 200: # 200 characters prompt size
# x = p.pop(-1)
# l += len(x) + 1
# prompt.append(x)
non_prompt = self.concatenate_tsw(self.commited_not_final)[2]
logger.debug(f"PROMPT(previous): {prompt[:20]}{prompt[-20:]} (length={len(prompt)}chars)")
logger.debug(f"CONTEXT: {non_prompt}")
return prompt, non_prompt
def process_iter(self):
"""Runs on the current audio buffer.
@@ -161,107 +187,137 @@ class OnlineASRProcessor:
"""
prompt, non_prompt = self.prompt()
logger.debug(f"PROMPT(previous): {prompt}")
logger.debug(f"CONTEXT: {non_prompt}")
logger.debug(
f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}"
)
## Transcribe and format the result to [(beg,end,"word1"), ...]
res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
# transform to [(beg,end,"word1"), ...]
tsw = self.asr.ts_words(res)
# insert into HypothesisBuffer
# insert into HypothesisBuffer, and get back the commited words
self.transcript_buffer.insert(tsw, self.buffer_time_offset)
o = self.transcript_buffer.flush()
# Completed words
self.commited.extend(o)
completed = self.concatenate_tsw(o) # This will be returned at the end of the function
logger.debug(f">>>>COMPLETE NOW: {completed[2]}")
## The rest is incomplete
the_rest = self.concatenate_tsw(self.transcript_buffer.complete())
logger.debug(f"INCOMPLETE: {the_rest[2]}")
commited_tsw = self.transcript_buffer.flush()
if len(commited_tsw) == 0:
return (None, None, "")
# there is a newly confirmed text
self.commited_not_final.extend(commited_tsw)
# Define `completed` and `the_rest` based on the buffer_trimming_way
# completed will be returned at the end of the function.
# completed is a transcribed text with (beg,end,"sentence ...") format.
completed = []
if self.buffer_trimming_way == "sentence":
sentences = self.words_to_sentences(self.commited_not_final)
self.chunk_completed_sentence(self.commited)
if len(sentences) < 2:
logger.debug(f"[Sentence-segmentation] no full sentence segmented, do not commit anything.")
else:
identified_sentence= "\n - ".join([f"{s[0]*1000:.0f}-{s[1]*1000:.0f} {s[2]}" for s in sentences])
logger.debug(f"[Sentence-segmentation] identified sentences:\n - {identified_sentence}")
# assume last sentence is incomplete, which is not always true
# we will continue with audio processing at this timestamp
chunk_at = sentences[-2][1]
self.chunk_at(chunk_at)
# TODO: here paragraph breaks can be added
self.commited_not_final = sentences[-1:]
completed= sentences[:-1]
# TODO: new words in `completed` should not be reterned unless they form a sentence
# TODO: only complete sentences should go to completed
# break audio buffer anyway if it is too long
if len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec :
if self.buffer_trimming_way == "sentence":
logger.warning(f"Chunck segment after {self.buffer_trimming_sec} seconds!"
" Even if no sentence was found!"
)
self.chunk_completed_segment(res)
)
# alternative: on any word
# l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
# let's find commited word that is less
# k = len(self.commited)-1
# while k>0 and self.commited[k][1] > l:
# k -= 1
# t = self.commited[k][1]
# self.chunk_at(t)
completed = self.chunk_completed_segment()
return completed
def chunk_completed_sentence(self):
if self.commited == []:
return
raw_text = self.asr.sep.join([s[2] for s in self.commited])
logger.debug(f"COMPLETED SENTENCE: {raw_text}")
sents = self.words_to_sentences(self.commited)
if len(sents) < 2:
logger.debug(f"[Sentence-segmentation] no sentence segmented.")
return
if len(completed) == 0:
return (None, None, "")
else:
self.final_transcript.extend(completed) # add whole time stamped sentences / or words to commited list
identified_sentence= "\n - ".join([f"{s[0]*1000:.0f}-{s[1]*1000:.0f} {s[2]}" for s in sents])
logger.debug(f"[Sentence-segmentation] identified sentences:\n - {identified_sentence}")
# we will continue with audio processing at this timestamp
chunk_at = sents[-2][1]
completed_text_segment= self.concatenate_tsw(completed)
the_rest = self.concatenate_tsw(self.transcript_buffer.complete())
commited_but_not_final = self.concatenate_tsw(self.commited_not_final)
logger.debug(f"\n COMPLETE NOW: {completed_text_segment[2]}\n"
f" COMMITTED (but not Final): {commited_but_not_final[2]}\n"
f" INCOMPLETE: {the_rest[2]}"
)
self.chunk_at(chunk_at)
return completed_text_segment
def chunk_completed_segment(self, res):
if self.commited == []:
return
ends = self.asr.segments_end_ts(res)
def chunk_completed_segment(self) -> list:
t = self.commited[-1][1]
ts_words = self.commited_not_final
if len(ends) <= 1:
if len(ts_words) <= 1:
logger.debug(f"--- not enough segments to chunk (<=1 words)")
return []
else:
e = ends[-2] + self.buffer_time_offset
ends = [w[1] for w in ts_words]
t = ts_words[-1][1] # start of the last word
e = ends[-2]
while len(ends) > 2 and e > t:
ends.pop(-1)
e = ends[-2] + self.buffer_time_offset
e = ends[-2]
if e <= t:
logger.debug(f"--- segment chunked at {e:2.2f}")
self.chunk_at(e)
n_commited_words = len(ends)-1
words_to_commit = ts_words[:n_commited_words]
self.final_transcript.extend(words_to_commit)
self.commited_not_final = ts_words[n_commited_words:]
return words_to_commit
else:
logger.debug(f"--- last segment not within commited area")
return []
def chunk_at(self, time):
@@ -287,9 +343,11 @@ class OnlineASRProcessor:
Returns: [(beg,end,"sentence 1"),...]
"""
cwords = [w for w in words]
t = self.asr.sep.join(o[2] for o in cwords)
logger.debug(f"[Sentence-segmentation] Raw Text: {t}")
s = self.tokenize([t])
out = []
while s:
@@ -302,11 +360,13 @@ class OnlineASRProcessor:
w = w.strip()
if beg is None and sent.startswith(w):
beg = b
elif end is None and sent == w:
if end is None and sent == w:
end = e
if beg is not None and end is not None:
out.append((beg, end, fsent))
break
sent = sent[len(w) :].strip()
return out
def finish(self):
@@ -315,7 +375,8 @@ class OnlineASRProcessor:
"""
o = self.transcript_buffer.complete()
f = self.concatenate_tsw(o)
logger.debug(f"last, noncommited: {f[0]*1000:.0f}-{f[1]*1000:.0f}: {f[2][0]*1000:.0f}-{f[1]*1000:.0f}: {f[2]}")
if f[1] is not None:
logger.debug(f"last, noncommited: {f[0]*1000:.0f}-{f[1]*1000:.0f}: {f[2]}")
self.buffer_time_offset += len(self.audio_buffer) / 16000
return f
@@ -330,7 +391,9 @@ class OnlineASRProcessor:
# return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
if sep is None:
sep = self.asr.sep
t = sep.join(s[2] for s in tsw)
if len(tsw) == 0:
b = None
@@ -349,6 +412,8 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
When it detects end of speech (non-voice for 500ms), it makes OnlineASRProcessor to end the utterance immediately.
"""
# TODO: VACOnlineASRProcessor does not break after chunch length is reached, this can lead to overflow!
def __init__(self, online_chunk_size, *a, **kw):
self.online_chunk_size = online_chunk_size

View File

@@ -5,23 +5,12 @@ import librosa
from functools import lru_cache
import time
import logging
from src.whisper_streaming.backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
from src.whisper_streaming.online_asr import OnlineASRProcessor, VACOnlineASRProcessor
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor
logger = logging.getLogger(__name__)
@lru_cache(10**6)
def load_audio(fname):
    """Decode *fname* to mono float32 samples at 16 kHz.

    Results are memoized so that repeated chunk reads of the same file
    (see load_audio_chunk) do not re-decode it each time.
    """
    samples, _sr = librosa.load(fname, sr=16000, dtype=np.float32)
    return samples
def load_audio_chunk(fname, beg, end):
    """Return the samples of *fname* between *beg* and *end* (in seconds)."""
    sr = 16000  # the whole pipeline runs at 16 kHz
    full = load_audio(fname)
    return full[int(beg * sr):int(end * sr)]
WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
","
@@ -244,163 +233,3 @@ def set_logging(args, logger, others=[]):
logging.getLogger(other).setLevel(args.log_level)
# logging.getLogger("whisper_online_server").setLevel(args.log_level)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--audio_path",
type=str,
default='samples_jfk.wav',
help="Filename of 16kHz mono channel wav, on which live streaming is simulated.",
)
add_shared_args(parser)
parser.add_argument(
"--start_at",
type=float,
default=0.0,
help="Start processing audio at this time.",
)
parser.add_argument(
"--offline", action="store_true", default=False, help="Offline mode."
)
parser.add_argument(
"--comp_unaware",
action="store_true",
default=False,
help="Computationally unaware simulation.",
)
args = parser.parse_args()
# reset to store stderr to different file stream, e.g. open(os.devnull,"w")
logfile = None # sys.stderr
if args.offline and args.comp_unaware:
logger.error(
"No or one option from --offline and --comp_unaware are available, not both. Exiting."
)
sys.exit(1)
# if args.log_level:
# logging.basicConfig(format='whisper-%(levelname)s:%(name)s: %(message)s',
# level=getattr(logging, args.log_level))
set_logging(args, logger,others=["src.whisper_streaming.online_asr"])
audio_path = args.audio_path
SAMPLING_RATE = 16000
duration = len(load_audio(audio_path)) / SAMPLING_RATE
logger.info("Audio duration is: %2.2f seconds" % duration)
asr, online = asr_factory(args, logfile=logfile)
if args.vac:
min_chunk = args.vac_chunk_size
else:
min_chunk = args.min_chunk_size
# load the audio into the LRU cache before we start the timer
a = load_audio_chunk(audio_path, 0, 1)
# warm up the ASR because the very first transcribe takes much more time than the other
asr.transcribe(a)
beg = args.start_at
start = time.time() - beg
def output_transcript(o, now=None):
# output format in stdout is like:
# 4186.3606 0 1720 Takhle to je
# - the first three words are:
# - emission time from beginning of processing, in milliseconds
# - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
# - the next words: segment transcript
if now is None:
now = time.time() - start
if o[0] is not None:
log_string = f"{now*1000:1.0f}, {o[0]*1000:1.0f}-{o[1]*1000:1.0f} ({(now-o[1]):+1.0f}s): {o[2]}"
logger.debug(
log_string
)
if logfile is not None:
print(
log_string,
file=logfile,
flush=True,
)
else:
# No text, so no output
pass
if args.offline: ## offline mode processing (for testing/debugging)
a = load_audio(audio_path)
online.insert_audio_chunk(a)
try:
o = online.process_iter()
except AssertionError as e:
logger.error(f"assertion error: {repr(e)}")
else:
output_transcript(o)
now = None
elif args.comp_unaware: # computational unaware mode
end = beg + min_chunk
while True:
a = load_audio_chunk(audio_path, beg, end)
online.insert_audio_chunk(a)
try:
o = online.process_iter()
except AssertionError as e:
logger.error(f"assertion error: {repr(e)}")
pass
else:
output_transcript(o, now=end)
logger.debug(f"## last processed {end:.2f}s")
if end >= duration:
break
beg = end
if end + min_chunk > duration:
end = duration
else:
end += min_chunk
now = duration
else: # online = simultaneous mode
end = 0
while True:
now = time.time() - start
if now < end + min_chunk:
time.sleep(min_chunk + end - now)
end = time.time() - start
a = load_audio_chunk(audio_path, beg, end)
beg = end
online.insert_audio_chunk(a)
try:
o = online.process_iter()
except AssertionError as e:
logger.error(f"assertion error: {e}")
pass
else:
output_transcript(o)
now = time.time() - start
logger.debug(
f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}"
)
if end >= duration:
break
now = None
o = online.finish()
output_transcript(o, now=now)

View File

@@ -11,6 +11,65 @@ from fastapi.middleware.cors import CORSMiddleware
from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
import logging
import logging.config
def setup_logging():
    """Install the application-wide logging configuration via dictConfig.

    A single stdout console handler (INFO threshold) is attached to the
    root logger (DEBUG level); uvicorn loggers keep their own handler and
    do not propagate; the whisper_streaming modules log at DEBUG.
    """
    console_handler = {
        'level': 'INFO',
        'class': 'logging.StreamHandler',
        'formatter': 'standard',
    }
    # Shared shape for project modules that should emit DEBUG records
    # directly on the console without double-logging through the root.
    debug_module = {
        'handlers': ['console'],
        'level': 'DEBUG',
        'propagate': False,
    }
    config = {
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': {
            'standard': {
                'format': '%(asctime)s %(levelname)s [%(name)s]: %(message)s',
            },
        },
        'handlers': {'console': console_handler},
        'root': {
            'handlers': ['console'],
            'level': 'DEBUG',
        },
        'loggers': {
            'uvicorn': {
                'handlers': ['console'],
                'level': 'INFO',
                'propagate': False,
            },
            'uvicorn.error': {'level': 'INFO'},
            'uvicorn.access': {'level': 'INFO'},
            # NOTE(review): 'src.whisper_streaming.whisper_streaming' looks like it
            # may be meant to match another module name — confirm against the package.
            'src.whisper_streaming.online_asr': dict(debug_module),
            'src.whisper_streaming.whisper_streaming': dict(debug_module),
        },
    }
    logging.config.dictConfig(config)
setup_logging()
logger = logging.getLogger(__name__)
app = FastAPI()
app.add_middleware(
CORSMiddleware,
@@ -238,5 +297,6 @@ if __name__ == "__main__":
import uvicorn
uvicorn.run(
"whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True
"whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True,
log_level="info"
)

181
whisper_noserver_test.py Normal file
View File

@@ -0,0 +1,181 @@
#!/usr/bin/env python3
import sys
import numpy as np
import librosa
from functools import lru_cache
import time
import logging
logger = logging.getLogger(__name__)
from src.whisper_streaming.whisper_online import *
@lru_cache(10**6)
def load_audio(fname):
    """Decode *fname* to mono float32 samples at 16 kHz (memoized)."""
    samples, _sr = librosa.load(fname, sr=16000, dtype=np.float32)
    return samples
def load_audio_chunk(fname, beg, end):
    """Return the samples of *fname* between *beg* and *end* (in seconds)."""
    sr = 16000  # the whole pipeline runs at 16 kHz
    full = load_audio(fname)
    return full[int(beg * sr):int(end * sr)]
if __name__ == "__main__":
    # CLI entry point: simulates live-streaming ASR over a single wav file in
    # one of three modes: --offline (whole file in one pass), --comp_unaware
    # (timing driven by audio position, not wall clock), or the default
    # real-time simulation that sleeps until enough audio has "arrived".
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--audio_path",
        type=str,
        default='samples_jfk.wav',
        help="Filename of 16kHz mono channel wav, on which live streaming is simulated.",
    )
    add_shared_args(parser)
    parser.add_argument(
        "--start_at",
        type=float,
        default=0.0,
        help="Start processing audio at this time.",
    )
    parser.add_argument(
        "--offline", action="store_true", default=False, help="Offline mode."
    )
    parser.add_argument(
        "--comp_unaware",
        action="store_true",
        default=False,
        help="Computationally unaware simulation.",
    )
    args = parser.parse_args()

    # reset to store stderr to different file stream, e.g. open(os.devnull,"w")
    logfile = None  # sys.stderr

    # The two simulation modes are mutually exclusive.
    if args.offline and args.comp_unaware:
        logger.error(
            "No or one option from --offline and --comp_unaware are available, not both. Exiting."
        )
        sys.exit(1)

    # if args.log_level:
    #     logging.basicConfig(format='whisper-%(levelname)s:%(name)s: %(message)s',
    #                         level=getattr(logging, args.log_level))
    set_logging(args, logger,others=["src.whisper_streaming.online_asr"])

    audio_path = args.audio_path

    SAMPLING_RATE = 16000
    duration = len(load_audio(audio_path)) / SAMPLING_RATE
    logger.info("Audio duration is: %2.2f seconds" % duration)

    asr, online = asr_factory(args, logfile=logfile)
    # Chunk size depends on whether voice-activity-controlled mode is enabled.
    if args.vac:
        min_chunk = args.vac_chunk_size
    else:
        min_chunk = args.min_chunk_size

    # load the audio into the LRU cache before we start the timer
    a = load_audio_chunk(audio_path, 0, 1)

    # warm up the ASR because the very first transcribe takes much more time than the other
    asr.transcribe(a)

    beg = args.start_at
    start = time.time() - beg

    def output_transcript(o, now=None):
        # Log one committed transcript tuple o = (beg, end, text); `now` is the
        # emission time relative to processing start (wall clock if omitted).
        # output format in stdout is like:
        # 4186.3606 0 1720 Takhle to je
        # - the first three words are:
        # - emission time from beginning of processing, in milliseconds
        # - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
        # - the next words: segment transcript
        if now is None:
            now = time.time() - start
        if o[0] is not None:
            log_string = f"{now*1000:1.0f}, {o[0]*1000:1.0f}-{o[1]*1000:1.0f} ({(now-o[1]):+1.0f}s): {o[2]}"
            logger.debug(
                log_string
            )
            if logfile is not None:
                print(
                    log_string,
                    file=logfile,
                    flush=True,
                )
        else:
            # No text, so no output
            pass

    if args.offline: ## offline mode processing (for testing/debugging)
        # Feed the entire file at once and run a single processing iteration.
        a = load_audio(audio_path)
        online.insert_audio_chunk(a)
        try:
            o = online.process_iter()
        except AssertionError as e:
            logger.error(f"assertion error: {repr(e)}")
        else:
            output_transcript(o)
        now = None
    elif args.comp_unaware: # computational unaware mode
        # Processing is treated as instantaneous: the "clock" passed to
        # output_transcript is the end timestamp of the audio consumed so far.
        end = beg + min_chunk
        while True:
            a = load_audio_chunk(audio_path, beg, end)
            online.insert_audio_chunk(a)
            try:
                o = online.process_iter()
            except AssertionError as e:
                logger.error(f"assertion error: {repr(e)}")
                pass
            else:
                output_transcript(o, now=end)
            logger.debug(f"## last processed {end:.2f}s")
            if end >= duration:
                break
            beg = end
            # Clamp the final chunk to the end of the file.
            if end + min_chunk > duration:
                end = duration
            else:
                end += min_chunk
        now = duration
    else: # online = simultaneous mode
        # Real-time simulation: sleep until at least min_chunk seconds of new
        # audio would have arrived, then process everything available.
        end = 0
        while True:
            now = time.time() - start
            if now < end + min_chunk:
                time.sleep(min_chunk + end - now)
            end = time.time() - start
            a = load_audio_chunk(audio_path, beg, end)
            beg = end
            online.insert_audio_chunk(a)
            try:
                o = online.process_iter()
            except AssertionError as e:
                logger.error(f"assertion error: {e}")
                pass
            else:
                output_transcript(o)
            now = time.time() - start
            logger.debug(
                f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}"
            )
            if end >= duration:
                break
        now = None

    # Flush whatever is still buffered in the online processor.
    o = online.finish()
    output_transcript(o, now=now)