mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
202 lines
7.8 KiB
Python
202 lines
7.8 KiB
Python
import logging
|
|
import sys
|
|
from argparse import Namespace
|
|
|
|
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
|
|
from whisperlivekit.local_agreement.whisper_online import backend_factory
|
|
from whisperlivekit.simul_whisper import SimulStreamingASR
|
|
|
|
|
|
def update_with_kwargs(_dict, kwargs):
|
|
_dict.update({
|
|
k: v for k, v in kwargs.items() if k in _dict
|
|
})
|
|
return _dict
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class TranscriptionEngine:
|
|
_instance = None
|
|
_initialized = False
|
|
|
|
def __new__(cls, *args, **kwargs):
|
|
if cls._instance is None:
|
|
cls._instance = super().__new__(cls)
|
|
return cls._instance
|
|
|
|
def __init__(self, **kwargs):
|
|
if TranscriptionEngine._initialized:
|
|
return
|
|
|
|
global_params = {
|
|
"host": "localhost",
|
|
"port": 8000,
|
|
"diarization": False,
|
|
"punctuation_split": False,
|
|
"target_language": "",
|
|
"vac": True,
|
|
"vac_onnx": False,
|
|
"vac_chunk_size": 0.04,
|
|
"log_level": "DEBUG",
|
|
"ssl_certfile": None,
|
|
"ssl_keyfile": None,
|
|
"forwarded_allow_ips": None,
|
|
"transcription": True,
|
|
"vad": True,
|
|
"pcm_input": False,
|
|
"disable_punctuation_split" : False,
|
|
"diarization_backend": "sortformer",
|
|
"backend_policy": "simulstreaming",
|
|
"backend": "auto",
|
|
}
|
|
global_params = update_with_kwargs(global_params, kwargs)
|
|
|
|
transcription_common_params = {
|
|
"warmup_file": None,
|
|
"min_chunk_size": 0.1,
|
|
"model_size": "base",
|
|
"model_cache_dir": None,
|
|
"model_dir": None,
|
|
"model_path": None,
|
|
"lan": "auto",
|
|
"direct_english_translation": False,
|
|
}
|
|
transcription_common_params = update_with_kwargs(transcription_common_params, kwargs)
|
|
|
|
if transcription_common_params['model_size'].endswith(".en"):
|
|
transcription_common_params["lan"] = "en"
|
|
if 'no_transcription' in kwargs:
|
|
global_params['transcription'] = not global_params['no_transcription']
|
|
if 'no_vad' in kwargs:
|
|
global_params['vad'] = not kwargs['no_vad']
|
|
if 'no_vac' in kwargs:
|
|
global_params['vac'] = not kwargs['no_vac']
|
|
|
|
self.args = Namespace(**{**global_params, **transcription_common_params})
|
|
|
|
self.asr = None
|
|
self.tokenizer = None
|
|
self.diarization = None
|
|
self.vac_model = None
|
|
|
|
if self.args.vac:
|
|
from whisperlivekit.silero_vad_iterator import load_silero_vad
|
|
|
|
# Use ONNX if specified, otherwise use JIT (default)
|
|
use_onnx = kwargs.get('vac_onnx', False)
|
|
self.vac_model = load_silero_vad(onnx=use_onnx)
|
|
|
|
backend_policy = self.args.backend_policy
|
|
if self.args.transcription:
|
|
if backend_policy == "simulstreaming":
|
|
simulstreaming_params = {
|
|
"disable_fast_encoder": False,
|
|
"custom_alignment_heads": None,
|
|
"frame_threshold": 25,
|
|
"beams": 1,
|
|
"decoder_type": None,
|
|
"audio_max_len": 20.0,
|
|
"audio_min_len": 0.0,
|
|
"cif_ckpt_path": None,
|
|
"never_fire": False,
|
|
"init_prompt": None,
|
|
"static_init_prompt": None,
|
|
"max_context_tokens": None,
|
|
"preload_model_count": 1,
|
|
}
|
|
simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
|
|
|
|
self.tokenizer = None
|
|
self.asr = SimulStreamingASR(
|
|
**transcription_common_params,
|
|
**simulstreaming_params,
|
|
backend=self.args.backend,
|
|
)
|
|
logger.info(
|
|
"Using SimulStreaming policy with %s backend",
|
|
getattr(self.asr, "encoder_backend", "whisper"),
|
|
)
|
|
else:
|
|
|
|
whisperstreaming_params = {
|
|
"buffer_trimming": "segment",
|
|
"confidence_validation": False,
|
|
"buffer_trimming_sec": 15,
|
|
}
|
|
whisperstreaming_params = update_with_kwargs(whisperstreaming_params, kwargs)
|
|
|
|
self.asr = backend_factory(
|
|
backend=self.args.backend,
|
|
**transcription_common_params,
|
|
**whisperstreaming_params,
|
|
)
|
|
logger.info(
|
|
"Using LocalAgreement policy with %s backend",
|
|
getattr(self.asr, "backend_choice", self.asr.__class__.__name__),
|
|
)
|
|
|
|
if self.args.diarization:
|
|
if self.args.diarization_backend == "diart":
|
|
from whisperlivekit.diarization.diart_backend import \
|
|
DiartDiarization
|
|
diart_params = {
|
|
"segmentation_model": "pyannote/segmentation-3.0",
|
|
"embedding_model": "pyannote/embedding",
|
|
}
|
|
diart_params = update_with_kwargs(diart_params, kwargs)
|
|
self.diarization_model = DiartDiarization(
|
|
block_duration=self.args.min_chunk_size,
|
|
**diart_params
|
|
)
|
|
elif self.args.diarization_backend == "sortformer":
|
|
from whisperlivekit.diarization.sortformer_backend import \
|
|
SortformerDiarization
|
|
self.diarization_model = SortformerDiarization()
|
|
|
|
self.translation_model = None
|
|
if self.args.target_language:
|
|
if self.args.lan == 'auto' and backend_policy != "simulstreaming":
|
|
raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming')
|
|
else:
|
|
try:
|
|
from nllw import load_model
|
|
except:
|
|
raise Exception('To use translation, you must install nllw: `pip install nllw`')
|
|
translation_params = {
|
|
"nllb_backend": "transformers",
|
|
"nllb_size": "600M"
|
|
}
|
|
translation_params = update_with_kwargs(translation_params, kwargs)
|
|
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
|
|
TranscriptionEngine._initialized = True
|
|
|
|
|
|
def online_factory(args, asr):
|
|
if args.backend_policy == "simulstreaming":
|
|
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
|
online = SimulStreamingOnlineProcessor(asr)
|
|
else:
|
|
online = OnlineASRProcessor(asr)
|
|
return online
|
|
|
|
|
|
def online_diarization_factory(args, diarization_backend):
|
|
if args.diarization_backend == "diart":
|
|
online = diarization_backend
|
|
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended
|
|
|
|
if args.diarization_backend == "sortformer":
|
|
from whisperlivekit.diarization.sortformer_backend import \
|
|
SortformerDiarizationOnline
|
|
online = SortformerDiarizationOnline(shared_model=diarization_backend)
|
|
return online
|
|
|
|
|
|
def online_translation_factory(args, translation_model):
|
|
#should be at speaker level in the future:
|
|
#one shared nllb model for all speaker
|
|
#one tokenizer per speaker/language
|
|
from nllw import OnlineTranslation
|
|
return OnlineTranslation(translation_model, [args.lan], [args.target_language])
|