mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
seamless streaming integrated
This commit is contained in:
161
seamless_integration.py
Normal file
161
seamless_integration.py
Normal file
@@ -0,0 +1,161 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
# code extracted from https://github.com/facebookresearch/seamless_communication/blob/main/Seamless_Tutorial.ipynb :
|
||||
|
||||
from simuleval.data.segments import SpeechSegment, EmptySegment
|
||||
from simuleval.utils.arguments import cli_argument_list
|
||||
from simuleval import options
|
||||
|
||||
|
||||
from typing import Union, List
|
||||
from simuleval.data.segments import Segment, TextSegment
|
||||
from simuleval.agents.pipeline import TreeAgentPipeline
|
||||
from simuleval.agents.states import AgentStates
|
||||
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
|
||||
def reset_states(system, states):
    """Call ``reset()`` on every agent state that belongs to ``system``.

    When ``system`` is a ``TreeAgentPipeline`` the state objects are taken
    from ``states.values()``; for any other system ``states`` is iterated
    directly.
    """
    per_agent_states = states.values() if isinstance(system, TreeAgentPipeline) else states
    for agent_state in per_agent_states:
        agent_state.reset()
|
||||
|
||||
|
||||
def get_states_root(system, states) -> AgentStates:
    """Return the root entry of ``states`` for ``system``.

    Args:
        system: the agent or agent pipeline the states were built for.
        states: the states container for ``system``; indexed by module for a
            ``TreeAgentPipeline`` and by position otherwise.

    Returns:
        ``states[system.source_module]`` for a ``TreeAgentPipeline``,
        otherwise the first element of ``states``.
    """
    if isinstance(system, TreeAgentPipeline):
        # Tree pipelines keep their states in a dict; the root is the entry
        # stored under the pipeline's source module.
        return states[system.source_module]
    # Other systems keep their states in a list. Index the container that was
    # passed in (not ``system.states``) so both branches honour the argument.
    return states[0]
|
||||
|
||||
|
||||
def build_streaming_system(model_configs, agent_class):
    """Build an ``agent_class`` instance configured from ``model_configs``.

    ``model_configs`` is converted to a command-line style argument list with
    ``cli_argument_list`` and parsed by the general simuleval parser extended
    with the agent's own options; the parsed namespace is handed to
    ``agent_class.from_args``. Unknown arguments are ignored.
    """
    arg_parser = options.general_parser()
    # Dummy option carried over from the notebook this code was extracted from.
    arg_parser.add_argument("-f", "--f", help="a dummy argument to fool ipython", default="1")
    agent_class.add_args(arg_parser)

    cli_args = cli_argument_list(model_configs)
    parsed_args, _unknown = arg_parser.parse_known_args(cli_args)
    return agent_class.from_args(parsed_args)
|
||||
|
||||
class OutputSegments:
    """Uniform list view over one segment or a list of segments."""

    def __init__(self, segments: Union[List[Segment], Segment]):
        # Accept a bare segment as well as a list; always store a fresh list.
        wrapped = [segments] if isinstance(segments, Segment) else segments
        self.segments: List[Segment] = list(wrapped)

    @property
    def is_empty(self):
        """True when every held segment reports ``is_empty`` (also when there are none)."""
        return not any(not seg.is_empty for seg in self.segments)

    @property
    def finished(self):
        """True when every held segment reports ``finished`` (also when there are none)."""
        for seg in self.segments:
            if not seg.finished:
                return False
        return True
|
||||
|
||||
|
||||
######################
|
||||
# Fix for DetokenizerAgent: it strips the trailing space from each output segment's content, but a word is sometimes split across several segments, so simply joining the segments with spaces would be wrong.
|
||||
from seamless_communication.streaming.agents.detokenizer import DetokenizerAgent
|
||||
from seamless_communication.streaming.agents.offline_w2v_bert_encoder import (
|
||||
OfflineWav2VecBertEncoderAgent,
|
||||
)
|
||||
from seamless_communication.streaming.agents.online_feature_extractor import (
|
||||
OnlineFeatureExtractorAgent,
|
||||
)
|
||||
from seamless_communication.streaming.agents.online_text_decoder import (
|
||||
MMASpeechToTextDecoderAgent,
|
||||
)
|
||||
from seamless_communication.streaming.agents.silero_vad import SileroVADAgent
|
||||
from seamless_communication.streaming.agents.unity_pipeline import UnitYAgentPipeline
|
||||
class FixDetokenizerAgent(DetokenizerAgent):
    """``DetokenizerAgent`` with a ``decode`` that rebuilds spacing from U+2581 markers."""

    def decode(self, x: str) -> str:
        """Drop every plain space from ``x``, then turn each U+2581 into a space."""
        without_spaces = x.replace(" ", "")
        return without_spaces.replace("\u2581", " ")
|
||||
|
||||
class FixSeamlessStreamingS2TVADAgent(UnitYAgentPipeline):
    """``UnitYAgentPipeline`` subclass whose last stage is ``FixDetokenizerAgent``."""

    # Agent classes in pipeline order; the final entry is this file's
    # FixDetokenizerAgent subclass of DetokenizerAgent.
    pipeline = [
        SileroVADAgent, OnlineFeatureExtractorAgent, OfflineWav2VecBertEncoderAgent,
        MMASpeechToTextDecoderAgent, FixDetokenizerAgent,
    ]
|
||||
##################################
|
||||
|
||||
#class SeamlessProcessor(OnlineASRProcessorBase): # TODO: there should be a common base class
|
||||
class SeamlessProcessor:
    """Online processor that feeds audio chunks to the Seamless streaming pipeline.

    Audio is buffered with ``insert_audio_chunk`` and consumed by
    ``process_iter``. ``process_iter`` and ``finish`` return a
    ``(beg, end, text)`` tuple, or ``(None, None, "")`` when the pipeline
    produced no non-empty output. ``beg`` and ``end`` are in seconds of audio
    consumed so far (sample count divided by ``SAMPLE_RATE``).
    """

    def __init__(self, tgt_lan, logfile=sys.stderr, device="cuda:0", dtype="fp16"):
        """Build the streaming system and reset the processor.

        Args:
            tgt_lan: target language, passed to the model as ``tgt_lang`` and
                attached to every input segment.
            logfile: file-like object that receives log messages.
            device: value of the ``device`` model option.
            dtype: value of the ``dtype`` model option.
        """
        self.logfile = logfile
        self.tgt_lan = tgt_lan

        model_configs = dict(
            source_segment_size=320,
            device=device,
            dtype=dtype,
            min_starting_wait_w2vbert=192,
            decision_threshold=0.5,
            min_unit_chunk_size=50,
            no_early_stop=True,
            max_len_a=0,
            max_len_b=100,
            task="s2tt",
            tgt_lang=tgt_lan,
            block_ngrams=True,
            detokenize_only=True,
        )

        self.system = build_streaming_system(model_configs, FixSeamlessStreamingS2TVADAgent)
        self.system_states = self.system.build_states()

        self.init()

    def init(self):
        """Reset the pipeline states, the audio buffer and the time offsets."""
        reset_states(self.system, self.system_states)
        self.audio_buffer = np.array([], dtype=np.float32)
        # beg: end time of the last emitted output; end: total audio consumed.
        self.beg, self.end = 0, 0

    def insert_audio_chunk(self, audio):
        """Append ``audio`` to the buffer consumed by the next ``process_iter``."""
        self.audio_buffer = np.append(self.audio_buffer, audio)

    def process_segment(self, input_segment):
        """Push one segment through the pipeline and collect its text output.

        Returns:
            ``(beg, end, text)`` where ``text`` is the concatenated content of
            all non-empty output segments, or ``(None, None, "")`` when there
            is none.
        """
        output_segments = OutputSegments(self.system.pushpop(input_segment, self.system_states))
        out = [segment.content for segment in output_segments.segments if not segment.is_empty]

        if output_segments.finished:
            # Every output segment is marked finished: start from clean states.
            print("End of VAD segment", file=self.logfile)
            reset_states(self.system, self.system_states)

        if out:
            b = self.beg
            # The next emitted output starts where this one ends.
            self.beg = self.end
            return (b, self.end, "".join(out))
        return (None, None, "")

    def process_iter(self):
        """Consume the buffered audio and return the resulting ``(beg, end, text)``."""
        input_segment = SpeechSegment(
            content=self.audio_buffer,
            sample_rate=SAMPLE_RATE,
            finished=False,
        )
        # The buffer now belongs to the segment; start a fresh one.
        self.audio_buffer = np.array([], dtype=np.float32)
        input_segment.tgt_lang = self.tgt_lan
        self.end += len(input_segment.content) / SAMPLE_RATE
        return self.process_segment(input_segment)

    def finish(self):
        """Send a finished ``EmptySegment`` and return the resulting ``(beg, end, text)``."""
        segment = EmptySegment(
            finished=True,
        )
        return self.process_segment(segment)
|
||||
@@ -208,7 +208,17 @@ class HypothesisBuffer:
|
||||
def complete(self):
|
||||
return self.buffer
|
||||
|
||||
class OnlineASRProcessor:
|
||||
class OnlineASRProcessorBase:
    """Interface shared by online processors.

    Every method raises ``NotImplementedError``; subclasses override them.
    """

    def init(self):
        """Reset the processor."""
        # NotImplemented is a non-callable constant; NotImplementedError is the
        # exception meant for methods a subclass must override.
        raise NotImplementedError

    def insert_audio_chunk(self, audio):
        """Accept the next chunk of audio."""
        raise NotImplementedError

    def process_iter(self):
        """Process the audio inserted so far."""
        raise NotImplementedError

    def finish(self):
        """Flush the processor at the end of the input."""
        raise NotImplementedError
|
||||
|
||||
class OnlineASRProcessor(OnlineASRProcessorBase):
|
||||
|
||||
SAMPLING_RATE = 16000
|
||||
|
||||
@@ -410,6 +420,7 @@ class OnlineASRProcessor:
|
||||
e = offset + sents[-1][1]
|
||||
return (b,e,t)
|
||||
|
||||
|
||||
WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(",")
|
||||
|
||||
def create_tokenizer(lan):
|
||||
@@ -453,7 +464,7 @@ def add_shared_args(parser):
|
||||
parser.add_argument('--model_dir', type=str, default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.")
|
||||
parser.add_argument('--lan', '--language', type=str, default='en', help="Language code for transcription, e.g. en,de,cs.")
|
||||
parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe","translate"],help="Transcribe or translate.")
|
||||
parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped"],help='Load only this backend for Whisper processing.')
|
||||
parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "seamless"],help='Load only this backend for Whisper processing, or SeamlessM4T Streaming backend.')
|
||||
parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
|
||||
parser.add_argument('--buffer_trimming', type=str, default="segment", choices=["sentence", "segment"],help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.')
|
||||
parser.add_argument('--buffer_trimming_sec', type=float, default=15, help='Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.')
|
||||
@@ -488,44 +499,49 @@ if __name__ == "__main__":
|
||||
size = args.model
|
||||
language = args.lan
|
||||
|
||||
t = time.time()
|
||||
print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
|
||||
|
||||
if args.backend == "faster-whisper":
|
||||
asr_cls = FasterWhisperASR
|
||||
else:
|
||||
asr_cls = WhisperTimestampedASR
|
||||
|
||||
asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
|
||||
|
||||
if args.task == "translate":
|
||||
asr.set_translate_task()
|
||||
tgt_language = "en" # Whisper translates into English
|
||||
else:
|
||||
tgt_language = language # Whisper transcribes in this language
|
||||
|
||||
|
||||
e = time.time()
|
||||
print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
|
||||
|
||||
if args.vad:
|
||||
print("setting VAD filter",file=logfile)
|
||||
asr.use_vad()
|
||||
|
||||
|
||||
min_chunk = args.min_chunk_size
|
||||
if args.buffer_trimming == "sentence":
|
||||
tokenizer = create_tokenizer(tgt_language)
|
||||
|
||||
if args.backend != "seamless":
|
||||
# loading Whisper model
|
||||
t = time.time()
|
||||
print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
|
||||
|
||||
if args.backend == "faster-whisper":
|
||||
asr_cls = FasterWhisperASR
|
||||
elif args.backend == "whisper_timestamped":
|
||||
asr_cls = WhisperTimestampedASR
|
||||
|
||||
asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
|
||||
|
||||
e = time.time()
|
||||
print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
|
||||
|
||||
if args.vad:
|
||||
print("setting VAD filter",file=logfile)
|
||||
asr.use_vad()
|
||||
if args.task == "translate":
|
||||
asr.set_translate_task()
|
||||
tgt_language = "en" # Whisper translates into English
|
||||
else:
|
||||
tgt_language = language # Whisper transcribes in this language
|
||||
|
||||
if args.buffer_trimming == "sentence":
|
||||
tokenizer = create_tokenizer(tgt_language)
|
||||
else:
|
||||
tokenizer = None
|
||||
|
||||
online = OnlineASRProcessor(asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))
|
||||
# load the audio into the LRU cache before we start the timer
|
||||
a = load_audio_chunk(audio_path,0,1)
|
||||
|
||||
# warm up the ASR, because the very first transcribe takes much more time than the other
|
||||
asr.transcribe(a)
|
||||
|
||||
else:
|
||||
tokenizer = None
|
||||
online = OnlineASRProcessor(asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))
|
||||
print(f"Loading SeamlessM4T Streaming backend model",file=logfile,flush=True)
|
||||
|
||||
|
||||
# load the audio into the LRU cache before we start the timer
|
||||
a = load_audio_chunk(audio_path,0,1)
|
||||
|
||||
# warm up the ASR, because the very first transcribe takes much more time than the other
|
||||
asr.transcribe(a)
|
||||
from seamless_integration import SeamlessProcessor
|
||||
online = SeamlessProcessor(language, logfile=logfile)
|
||||
|
||||
beg = args.start_at
|
||||
start = time.time()-beg
|
||||
|
||||
Reference in New Issue
Block a user