diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/whisper_streaming/online_asr.py b/src/whisper_streaming/online_asr.py index 1fd2c47..7154007 100644 --- a/src/whisper_streaming/online_asr.py +++ b/src/whisper_streaming/online_asr.py @@ -17,8 +17,10 @@ class HypothesisBuffer: self.logfile = logfile def insert(self, new, offset): - # compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content - # the new tail is added to self.new + """ + compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content + The new tail is added to self.new + """ new = [(a + offset, b + offset, t) for a, b, t in new] self.new = [(a, b, t) for a, b, t in new if a > self.last_commited_time - 0.1] @@ -77,6 +79,9 @@ class HypothesisBuffer: return self.buffer + + + class OnlineASRProcessor: SAMPLING_RATE = 16000 @@ -128,7 +133,9 @@ class OnlineASRProcessor: if offset is not None: self.buffer_time_offset = offset self.transcript_buffer.last_commited_time = self.buffer_time_offset - self.commited = [] + self.final_transcript = [] + self.commited_not_final = [] + def insert_audio_chunk(self, audio): self.audio_buffer = np.append(self.audio_buffer, audio) @@ -136,23 +143,42 @@ class OnlineASRProcessor: def prompt(self): """Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer. "context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons. - """ - k = max(0, len(self.commited) - 1) - while k > 0 and self.commited[k - 1][1] > self.buffer_time_offset: - k -= 1 - p = self.commited[:k] - p = [t for _, _, t in p] - prompt = [] - l = 0 - while p and l < 200: # 200 characters prompt size - x = p.pop(-1) - l += len(x) + 1 - prompt.append(x) - non_prompt = self.commited[k:] - return self.asr.sep.join(prompt[::-1]), self.asr.sep.join( - t for _, _, t in non_prompt - ) + + """ + + if len(self.final_transcript) == 0: + prompt="" + + if len(self.final_transcript) == 1: + prompt = self.final_transcript[0][2][-200:] + + else: + prompt = self.concatenate_tsw(self.final_transcript)[2][-200:] + # TODO: this is not ideal as we concatenate each time the whole transcript + + # k = max(0, len(self.final_transcript) - 1) + # while k > 1 and self.final_transcript[k - 1][1] > self.buffer_time_offset: + # k -= 1 + + # p = self.final_transcript[:k] + + + # p = [t for _, _, t in p] + # prompt = [] + # l = 0 + # while p and l < 200: # 200 characters prompt size + # x = p.pop(-1) + # l += len(x) + 1 + # prompt.append(x) + + non_prompt = self.concatenate_tsw(self.commited_not_final)[2] + + logger.debug(f"PROMPT(previous): {prompt[:20]}…{prompt[-20:]} (length={len(prompt)}chars)") + logger.debug(f"CONTEXT: {non_prompt}") + + return prompt, non_prompt + def process_iter(self): """Runs on the current audio buffer. @@ -161,107 +187,137 @@ class OnlineASRProcessor: """ prompt, non_prompt = self.prompt() - logger.debug(f"PROMPT(previous): {prompt}") - logger.debug(f"CONTEXT: {non_prompt}") + logger.debug( f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}" ) + + ## Transcribe and format the result to [(beg,end,"word1"), ...] res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt) - - # transform to [(beg,end,"word1"), ...] tsw = self.asr.ts_words(res) - # insert into HypothesisBuffer + + + # insert into HypothesisBuffer, and get back the commited words self.transcript_buffer.insert(tsw, self.buffer_time_offset) - o = self.transcript_buffer.flush() - # Completed words - self.commited.extend(o) - completed = self.concatenate_tsw(o) # This will be returned at the end of the function - logger.debug(f">>>>COMPLETE NOW: {completed[2]}") - ## The rest is incomplete - the_rest = self.concatenate_tsw(self.transcript_buffer.complete()) - logger.debug(f"INCOMPLETE: {the_rest[2]}") + commited_tsw = self.transcript_buffer.flush() + + if len(commited_tsw) == 0: + return (None, None, "") - # there is a newly confirmed text + self.commited_not_final.extend(commited_tsw) + + + # Define `completed` and `the_rest` based on the buffer_trimming_way + # completed will be returned at the end of the function. + # completed is a transcribed text with (beg,end,"sentence ...") format. + + + completed = [] if self.buffer_trimming_way == "sentence": + + sentences = self.words_to_sentences(self.commited_not_final) - self.chunk_completed_sentence(self.commited) + + + if len(sentences) < 2: + logger.debug(f"[Sentence-segmentation] no full sentence segmented, do not commit anything.") + + + else: + identified_sentence= "\n - ".join([f"{s[0]*1000:.0f}-{s[1]*1000:.0f} {s[2]}" for s in sentences]) + logger.debug(f"[Sentence-segmentation] identified sentences:\n - {identified_sentence}") + + # assume last sentence is incomplete, which is not always true + + # we will continue with audio processing at this timestamp + chunk_at = sentences[-2][1] + + self.chunk_at(chunk_at) + # TODO: here paragraph breaks can be added + self.commited_not_final = sentences[-1:] + + completed= sentences[:-1] - # TODO: new words in `completed` should not be reterned unless they form a sentence - # TODO: only complete sentences should go to completed + + + + # break audio buffer anyway if it is too long if len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec : - + if self.buffer_trimming_way == "sentence": logger.warning(f"Chunck segment after {self.buffer_trimming_sec} seconds!" " Even if no sentence was found!" - ) - - - self.chunk_completed_segment(res) - - + ) + + - # alternative: on any word - # l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10 - # let's find commited word that is less - # k = len(self.commited)-1 - # while k>0 and self.commited[k][1] > l: - # k -= 1 - # t = self.commited[k][1] - # self.chunk_at(t) + + completed = self.chunk_completed_segment() + + + - return completed - def chunk_completed_sentence(self): - if self.commited == []: - return - raw_text = self.asr.sep.join([s[2] for s in self.commited]) - logger.debug(f"COMPLETED SENTENCE: {raw_text}") - sents = self.words_to_sentences(self.commited) - - - if len(sents) < 2: - logger.debug(f"[Sentence-segmentation] no sentence segmented.") - return + if len(completed) == 0: + return (None, None, "") + else: + self.final_transcript.extend(completed) # add whole time stamped sentences / or words to commited list - - identified_sentence= "\n - ".join([f"{s[0]*1000:.0f}-{s[1]*1000:.0f} {s[2]}" for s in sents]) - logger.debug(f"[Sentence-segmentation] identified sentences:\n - {identified_sentence}") - - - # we will continue with audio processing at this timestamp - chunk_at = sents[-2][1] + completed_text_segment= self.concatenate_tsw(completed) + + the_rest = self.concatenate_tsw(self.transcript_buffer.complete()) + commited_but_not_final = self.concatenate_tsw(self.commited_not_final) + logger.debug(f"\n COMPLETE NOW: {completed_text_segment[2]}\n" + f" COMMITTED (but not Final): {commited_but_not_final[2]}\n" + f" INCOMPLETE: {the_rest[2]}" + ) - self.chunk_at(chunk_at) + return completed_text_segment - def chunk_completed_segment(self, res): - if self.commited == []: - return - ends = self.asr.segments_end_ts(res) + def chunk_completed_segment(self) -> list: - t = self.commited[-1][1] + + ts_words = self.commited_not_final - if len(ends) <= 1: + if len(ts_words) <= 1: logger.debug(f"--- not enough segments to chunk (<=1 words)") + return [] else: - e = ends[-2] + self.buffer_time_offset + ends = [w[1] for w in ts_words] + + t = ts_words[-1][1] # start of the last word + e = ends[-2] while len(ends) > 2 and e > t: ends.pop(-1) - e = ends[-2] + self.buffer_time_offset + e = ends[-2] + if e <= t: - logger.debug(f"--- segment chunked at {e:2.2f}") + self.chunk_at(e) + + n_commited_words = len(ends)-1 + + words_to_commit = ts_words[:n_commited_words] + self.final_transcript.extend(words_to_commit) + self.commited_not_final = ts_words[n_commited_words:] + + return words_to_commit + + + else: logger.debug(f"--- last segment not within commited area") + return [] def chunk_at(self, time): @@ -287,9 +343,11 @@ class OnlineASRProcessor: Returns: [(beg,end,"sentence 1"),...] """ + cwords = [w for w in words] t = self.asr.sep.join(o[2] for o in cwords) logger.debug(f"[Sentence-segmentation] Raw Text: {t}") + s = self.tokenize([t]) out = [] while s: @@ -302,11 +360,13 @@ class OnlineASRProcessor: w = w.strip() if beg is None and sent.startswith(w): beg = b - elif end is None and sent == w: + if end is None and sent == w: end = e + if beg is not None and end is not None: out.append((beg, end, fsent)) break sent = sent[len(w) :].strip() + return out def finish(self): @@ -315,7 +375,8 @@ class OnlineASRProcessor: """ o = self.transcript_buffer.complete() f = self.concatenate_tsw(o) - logger.debug(f"last, noncommited: {f[0]*1000:.0f}-{f[1]*1000:.0f}: {f[2][0]*1000:.0f}-{f[1]*1000:.0f}: {f[2]}") + if f[1] is not None: + logger.debug(f"last, noncommited: {f[0]*1000:.0f}-{f[1]*1000:.0f}: {f[2]}") self.buffer_time_offset += len(self.audio_buffer) / 16000 return f @@ -330,7 +391,9 @@ class OnlineASRProcessor: # return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty if sep is None: sep = self.asr.sep - + + + t = sep.join(s[2] for s in tsw) if len(tsw) == 0: b = None @@ -349,6 +412,8 @@ class VACOnlineASRProcessor(OnlineASRProcessor): When it detects end of speech (non-voice for 500ms), it makes OnlineASRProcessor to end the utterance immediately. """ +# TODO: VACOnlineASRProcessor does not break after chunch length is reached, this can lead to overflow! + def __init__(self, online_chunk_size, *a, **kw): self.online_chunk_size = online_chunk_size diff --git a/src/whisper_streaming/whisper_online.py b/src/whisper_streaming/whisper_online.py index cd5d005..bb00ae8 100644 --- a/src/whisper_streaming/whisper_online.py +++ b/src/whisper_streaming/whisper_online.py @@ -5,23 +5,12 @@ import librosa from functools import lru_cache import time import logging -from src.whisper_streaming.backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR -from src.whisper_streaming.online_asr import OnlineASRProcessor, VACOnlineASRProcessor +from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR +from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor logger = logging.getLogger(__name__) -@lru_cache(10**6) -def load_audio(fname): - a, _ = librosa.load(fname, sr=16000, dtype=np.float32) - return a - - -def load_audio_chunk(fname, beg, end): - audio = load_audio(fname) - beg_s = int(beg * 16000) - end_s = int(end * 16000) - return audio[beg_s:end_s] WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split( "," @@ -244,163 +233,3 @@ def set_logging(args, logger, others=[]): logging.getLogger(other).setLevel(args.log_level) -# logging.getLogger("whisper_online_server").setLevel(args.log_level) - - -if __name__ == "__main__": - - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "--audio_path", - type=str, - default='samples_jfk.wav', - help="Filename of 16kHz mono channel wav, on which live streaming is simulated.", - ) - add_shared_args(parser) - parser.add_argument( - "--start_at", - type=float, - default=0.0, - help="Start processing audio at this time.", - ) - parser.add_argument( - "--offline", action="store_true", default=False, help="Offline mode." - ) - parser.add_argument( - "--comp_unaware", - action="store_true", - default=False, - help="Computationally unaware simulation.", - ) - - args = parser.parse_args() - - # reset to store stderr to different file stream, e.g. open(os.devnull,"w") - logfile = None # sys.stderr - - if args.offline and args.comp_unaware: - logger.error( - "No or one option from --offline and --comp_unaware are available, not both. Exiting." - ) - sys.exit(1) - - # if args.log_level: - # logging.basicConfig(format='whisper-%(levelname)s:%(name)s: %(message)s', - # level=getattr(logging, args.log_level)) - - set_logging(args, logger,others=["src.whisper_streaming.online_asr"]) - - audio_path = args.audio_path - - SAMPLING_RATE = 16000 - duration = len(load_audio(audio_path)) / SAMPLING_RATE - logger.info("Audio duration is: %2.2f seconds" % duration) - - asr, online = asr_factory(args, logfile=logfile) - if args.vac: - min_chunk = args.vac_chunk_size - else: - min_chunk = args.min_chunk_size - - # load the audio into the LRU cache before we start the timer - a = load_audio_chunk(audio_path, 0, 1) - - # warm up the ASR because the very first transcribe takes much more time than the other - asr.transcribe(a) - - beg = args.start_at - start = time.time() - beg - - def output_transcript(o, now=None): - # output format in stdout is like: - # 4186.3606 0 1720 Takhle to je - # - the first three words are: - # - emission time from beginning of processing, in milliseconds - # - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway - # - the next words: segment transcript - if now is None: - now = time.time() - start - if o[0] is not None: - log_string = f"{now*1000:1.0f}, {o[0]*1000:1.0f}-{o[1]*1000:1.0f} ({(now-o[1]):+1.0f}s): {o[2]}" - - logger.debug( - log_string - ) - - if logfile is not None: - print( - log_string, - file=logfile, - flush=True, - ) - else: - # No text, so no output - pass - - if args.offline: ## offline mode processing (for testing/debugging) - a = load_audio(audio_path) - online.insert_audio_chunk(a) - try: - o = online.process_iter() - except AssertionError as e: - logger.error(f"assertion error: {repr(e)}") - else: - output_transcript(o) - now = None - elif args.comp_unaware: # computational unaware mode - end = beg + min_chunk - while True: - a = load_audio_chunk(audio_path, beg, end) - online.insert_audio_chunk(a) - try: - o = online.process_iter() - except AssertionError as e: - logger.error(f"assertion error: {repr(e)}") - pass - else: - output_transcript(o, now=end) - - logger.debug(f"## last processed {end:.2f}s") - - if end >= duration: - break - - beg = end - - if end + min_chunk > duration: - end = duration - else: - end += min_chunk - now = duration - - else: # online = simultaneous mode - end = 0 - while True: - now = time.time() - start - if now < end + min_chunk: - time.sleep(min_chunk + end - now) - end = time.time() - start - a = load_audio_chunk(audio_path, beg, end) - beg = end - online.insert_audio_chunk(a) - - try: - o = online.process_iter() - except AssertionError as e: - logger.error(f"assertion error: {e}") - pass - else: - output_transcript(o) - now = time.time() - start - logger.debug( - f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}" - ) - - if end >= duration: - break - now = None - - o = online.finish() - output_transcript(o, now=now) diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py index 33391be..4c06117 100644 --- a/whisper_fastapi_online_server.py +++ b/whisper_fastapi_online_server.py @@ -11,6 +11,65 @@ from fastapi.middleware.cors import CORSMiddleware from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args + +import logging +import logging.config + +def setup_logging(): + logging_config = { + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': { + 'standard': { + 'format': '%(asctime)s %(levelname)s [%(name)s]: %(message)s', + }, + }, + 'handlers': { + 'console': { + 'level': 'INFO', + 'class': 'logging.StreamHandler', + 'formatter': 'standard', + }, + }, + 'root': { + 'handlers': ['console'], + 'level': 'DEBUG', + }, + 'loggers': { + 'uvicorn': { + 'handlers': ['console'], + 'level': 'INFO', + 'propagate': False, + }, + 'uvicorn.error': { + 'level': 'INFO', + }, + 'uvicorn.access': { + 'level': 'INFO', + }, + 'src.whisper_streaming.online_asr': { # Add your specific module here + 'handlers': ['console'], + 'level': 'DEBUG', + 'propagate': False, + }, + 'src.whisper_streaming.whisper_streaming': { # Add your specific module here + 'handlers': ['console'], + 'level': 'DEBUG', + 'propagate': False, + }, + }, + } + + logging.config.dictConfig(logging_config) + +setup_logging() +logger = logging.getLogger(__name__) + + + + + + app = FastAPI() app.add_middleware( CORSMiddleware, @@ -238,5 +297,6 @@ if __name__ == "__main__": import uvicorn uvicorn.run( - "whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True + "whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True, + log_level="info" ) diff --git a/whisper_noserver_test.py b/whisper_noserver_test.py new file mode 100644 index 0000000..c36ec7b --- /dev/null +++ b/whisper_noserver_test.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +import sys +import numpy as np +import librosa +from functools import lru_cache +import time +import logging + +logger = logging.getLogger(__name__) + +from src.whisper_streaming.whisper_online import * + +@lru_cache(10**6) +def load_audio(fname): + a, _ = librosa.load(fname, sr=16000, dtype=np.float32) + return a + + +def load_audio_chunk(fname, beg, end): + audio = load_audio(fname) + beg_s = int(beg * 16000) + end_s = int(end * 16000) + return audio[beg_s:end_s] + +if __name__ == "__main__": + + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--audio_path", + type=str, + default='samples_jfk.wav', + help="Filename of 16kHz mono channel wav, on which live streaming is simulated.", + ) + add_shared_args(parser) + parser.add_argument( + "--start_at", + type=float, + default=0.0, + help="Start processing audio at this time.", + ) + parser.add_argument( + "--offline", action="store_true", default=False, help="Offline mode." + ) + parser.add_argument( + "--comp_unaware", + action="store_true", + default=False, + help="Computationally unaware simulation.", + ) + + args = parser.parse_args() + + # reset to store stderr to different file stream, e.g. open(os.devnull,"w") + logfile = None # sys.stderr + + if args.offline and args.comp_unaware: + logger.error( + "No or one option from --offline and --comp_unaware are available, not both. Exiting." + ) + sys.exit(1) + + # if args.log_level: + # logging.basicConfig(format='whisper-%(levelname)s:%(name)s: %(message)s', + # level=getattr(logging, args.log_level)) + + set_logging(args, logger,others=["src.whisper_streaming.online_asr"]) + + audio_path = args.audio_path + + SAMPLING_RATE = 16000 + duration = len(load_audio(audio_path)) / SAMPLING_RATE + logger.info("Audio duration is: %2.2f seconds" % duration) + + asr, online = asr_factory(args, logfile=logfile) + if args.vac: + min_chunk = args.vac_chunk_size + else: + min_chunk = args.min_chunk_size + + # load the audio into the LRU cache before we start the timer + a = load_audio_chunk(audio_path, 0, 1) + + # warm up the ASR because the very first transcribe takes much more time than the other + asr.transcribe(a) + + beg = args.start_at + start = time.time() - beg + + def output_transcript(o, now=None): + # output format in stdout is like: + # 4186.3606 0 1720 Takhle to je + # - the first three words are: + # - emission time from beginning of processing, in milliseconds + # - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway + # - the next words: segment transcript + if now is None: + now = time.time() - start + if o[0] is not None: + log_string = f"{now*1000:1.0f}, {o[0]*1000:1.0f}-{o[1]*1000:1.0f} ({(now-o[1]):+1.0f}s): {o[2]}" + + logger.debug( + log_string + ) + + if logfile is not None: + print( + log_string, + file=logfile, + flush=True, + ) + else: + # No text, so no output + pass + + if args.offline: ## offline mode processing (for testing/debugging) + a = load_audio(audio_path) + online.insert_audio_chunk(a) + try: + o = online.process_iter() + except AssertionError as e: + logger.error(f"assertion error: {repr(e)}") + else: + output_transcript(o) + now = None + elif args.comp_unaware: # computational unaware mode + end = beg + min_chunk + while True: + a = load_audio_chunk(audio_path, beg, end) + online.insert_audio_chunk(a) + try: + o = online.process_iter() + except AssertionError as e: + logger.error(f"assertion error: {repr(e)}") + pass + else: + output_transcript(o, now=end) + + logger.debug(f"## last processed {end:.2f}s") + + if end >= duration: + break + + beg = end + + if end + min_chunk > duration: + end = duration + else: + end += min_chunk + now = duration + + else: # online = simultaneous mode + end = 0 + while True: + now = time.time() - start + if now < end + min_chunk: + time.sleep(min_chunk + end - now) + end = time.time() - start + a = load_audio_chunk(audio_path, beg, end) + beg = end + online.insert_audio_chunk(a) + + try: + o = online.process_iter() + except AssertionError as e: + logger.error(f"assertion error: {e}") + pass + else: + output_transcript(o) + now = time.time() - start + logger.debug( + f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}" + ) + + if end >= duration: + break + now = None + + o = online.finish() + output_transcript(o, now=now)