diff --git a/seamless_integration.py b/seamless_integration.py new file mode 100644 index 0000000..f4e1a67 --- /dev/null +++ b/seamless_integration.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 + +import sys +import numpy as np + +# code extracted from https://github.com/facebookresearch/seamless_communication/blob/main/Seamless_Tutorial.ipynb : + +from simuleval.data.segments import SpeechSegment, EmptySegment +from simuleval.utils.arguments import cli_argument_list +from simuleval import options + + +from typing import Union, List +from simuleval.data.segments import Segment, TextSegment +from simuleval.agents.pipeline import TreeAgentPipeline +from simuleval.agents.states import AgentStates + + +SAMPLE_RATE = 16000 + +def reset_states(system, states): + if isinstance(system, TreeAgentPipeline): + states_iter = states.values() + else: + states_iter = states + for state in states_iter: + state.reset() + + +def get_states_root(system, states) -> AgentStates: + if isinstance(system, TreeAgentPipeline): + # self.states is a dict + return states[system.source_module] + else: + # self.states is a list + return system.states[0] + + +def build_streaming_system(model_configs, agent_class): + parser = options.general_parser() + parser.add_argument("-f", "--f", help="a dummy argument to fool ipython", default="1") + + agent_class.add_args(parser) + args, _ = parser.parse_known_args(cli_argument_list(model_configs)) + system = agent_class.from_args(args) + return system + +class OutputSegments: + def __init__(self, segments: Union[List[Segment], Segment]): + if isinstance(segments, Segment): + segments = [segments] + self.segments: List[Segment] = [s for s in segments] + + @property + def is_empty(self): + return all(segment.is_empty for segment in self.segments) + + @property + def finished(self): + return all(segment.finished for segment in self.segments) + + +###################### +# fixing DetokenizerAgent -- it strips output segment.content last space, but sometimes a word is split into more segments. Simple joining with spaces would be wrong. +from seamless_communication.streaming.agents.detokenizer import DetokenizerAgent +from seamless_communication.streaming.agents.offline_w2v_bert_encoder import ( + OfflineWav2VecBertEncoderAgent, +) +from seamless_communication.streaming.agents.online_feature_extractor import ( + OnlineFeatureExtractorAgent, +) +from seamless_communication.streaming.agents.online_text_decoder import ( + MMASpeechToTextDecoderAgent, +) +from seamless_communication.streaming.agents.silero_vad import SileroVADAgent +from seamless_communication.streaming.agents.unity_pipeline import UnitYAgentPipeline +class FixDetokenizerAgent(DetokenizerAgent): + def decode(self, x: str) -> str: + return x.replace(" ", "").replace("\u2581", " ") + +class FixSeamlessStreamingS2TVADAgent(UnitYAgentPipeline): + pipeline = [ + SileroVADAgent, + OnlineFeatureExtractorAgent, + OfflineWav2VecBertEncoderAgent, + MMASpeechToTextDecoderAgent, + FixDetokenizerAgent, + ] +################################## + +#class SeamlessProcessor(OnlineASRProcessorBase): # TODO: there should be a common base class +class SeamlessProcessor: + def __init__(self, tgt_lan, logfile=sys.stderr): + self.logfile = logfile + + agent_class = FixSeamlessStreamingS2TVADAgent + + model_configs = dict( + source_segment_size=320, + device="cuda:0", + dtype="fp16", + min_starting_wait_w2vbert=192, + decision_threshold=0.5, + min_unit_chunk_size=50, + no_early_stop=True, + max_len_a=0, + max_len_b=100, + task="s2tt", + tgt_lang=tgt_lan, + block_ngrams=True, + detokenize_only=True, + ) + self.tgt_lan = tgt_lan + + self.system = build_streaming_system(model_configs, agent_class) + + self.system_states = self.system.build_states() + + self.init() + + def init(self): + reset_states(self.system, self.system_states) + self.audio_buffer = np.array([],dtype=np.float32) + self.beg, self.end = 0, 0 + + def insert_audio_chunk(self, audio): + self.audio_buffer = np.append(self.audio_buffer, audio) + + def process_segment(self, input_segment): + output_segments = OutputSegments(self.system.pushpop(input_segment, self.system_states)) + out = [] + for segment in output_segments.segments: + if not segment.is_empty: + out.append(segment.content) + if output_segments.finished: + print("End of VAD segment",file=self.logfile) + reset_states(self.system, self.system_states) + if out: + b = self.beg + self.beg = self.end + o = "".join(out) + return (b, self.end, "".join(out)) + return (None, None, "") + + + def process_iter(self): + input_segment = SpeechSegment( + content=self.audio_buffer, + sample_rate=SAMPLE_RATE, + finished=False, + ) + self.audio_buffer = np.array([],dtype=np.float32) + input_segment.tgt_lang = self.tgt_lan + self.end += (len(input_segment.content)/SAMPLE_RATE) + return self.process_segment(input_segment) + + def finish(self): + segment = EmptySegment( + finished=True, + ) + return self.process_segment(segment) diff --git a/whisper_online.py b/whisper_online.py index 36bdbd6..5b2e53b 100644 --- a/whisper_online.py +++ b/whisper_online.py @@ -208,7 +208,17 @@ class HypothesisBuffer: def complete(self): return self.buffer -class OnlineASRProcessor: +class OnlineASRProcessorBase: + def init(self): + raise NotImplemented() + def insert_audio_chunk(self, audio): + raise NotImplemented() + def process_iter(self): + raise NotImplemented() + def finish(self): + raise NotImplemented() + +class OnlineASRProcessor(OnlineASRProcessorBase): SAMPLING_RATE = 16000 @@ -410,6 +420,7 @@ class OnlineASRProcessor: e = offset + sents[-1][1] return (b,e,t) + WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(",") def create_tokenizer(lan): @@ -453,7 +464,7 @@ def add_shared_args(parser): parser.add_argument('--model_dir', type=str, default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.") parser.add_argument('--lan', '--language', type=str, default='en', help="Language code for transcription, e.g. en,de,cs.") parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe","translate"],help="Transcribe or translate.") - parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped"],help='Load only this backend for Whisper processing.') + parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "seamless"],help='Load only this backend for Whisper processing, or SeamlessM4T Streaming backend.') parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.') parser.add_argument('--buffer_trimming', type=str, default="segment", choices=["sentence", "segment"],help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.') parser.add_argument('--buffer_trimming_sec', type=float, default=15, help='Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.') @@ -488,44 +499,49 @@ if __name__ == "__main__": size = args.model language = args.lan - t = time.time() - print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True) - - if args.backend == "faster-whisper": - asr_cls = FasterWhisperASR - else: - asr_cls = WhisperTimestampedASR - - asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir) - - if args.task == "translate": - asr.set_translate_task() - tgt_language = "en" # Whisper translates into English - else: - tgt_language = language # Whisper transcribes in this language - - - e = time.time() - print(f"done. It took {round(e-t,2)} seconds.",file=logfile) - - if args.vad: - print("setting VAD filter",file=logfile) - asr.use_vad() - - min_chunk = args.min_chunk_size - if args.buffer_trimming == "sentence": - tokenizer = create_tokenizer(tgt_language) + + if args.backend != "seamless": + # loading Whisper model + t = time.time() + print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True) + + if args.backend == "faster-whisper": + asr_cls = FasterWhisperASR + elif args.backend == "whisper_timestamped": + asr_cls = WhisperTimestampedASR + + asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir) + + e = time.time() + print(f"done. It took {round(e-t,2)} seconds.",file=logfile) + + if args.vad: + print("setting VAD filter",file=logfile) + asr.use_vad() + if args.task == "translate": + asr.set_translate_task() + tgt_language = "en" # Whisper translates into English + else: + tgt_language = language # Whisper transcribes in this language + + if args.buffer_trimming == "sentence": + tokenizer = create_tokenizer(tgt_language) + else: + tokenizer = None + + online = OnlineASRProcessor(asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec)) + # load the audio into the LRU cache before we start the timer + a = load_audio_chunk(audio_path,0,1) + + # warm up the ASR, because the very first transcribe takes much more time than the other + asr.transcribe(a) + else: - tokenizer = None - online = OnlineASRProcessor(asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec)) + print(f"Loading SeamlessM4T Streaming backend model",file=logfile,flush=True) - - # load the audio into the LRU cache before we start the timer - a = load_audio_chunk(audio_path,0,1) - - # warm up the ASR, because the very first transcribe takes much more time than the other - asr.transcribe(a) + from seamless_integration import SeamlessProcessor + online = SeamlessProcessor(language, logfile=logfile) beg = args.start_at start = time.time()-beg