mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-04-26 16:45:46 +00:00
Merge branch 'main' into ScriptProcessorNode-to-AudioWorklet
This commit is contained in:
25
DEV_NOTES.md
25
DEV_NOTES.md
@@ -18,8 +18,29 @@ Decoder weights: 59110771 bytes
|
||||
Encoder weights: 15268874 bytes
|
||||
|
||||
|
||||
# 2. Translation: Faster model for each system
|
||||
|
||||
# 2. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm
|
||||
## Benchmark Results
|
||||
|
||||
Testing on MacBook M3 with NLLB-200-distilled-600M model:
|
||||
|
||||
### Standard Transformers vs CTranslate2
|
||||
|
||||
| Test Text | Standard Inference Time | CTranslate2 Inference Time | Speedup |
|
||||
|-----------|-------------------------|---------------------------|---------|
|
||||
| UN Chief says there is no military solution in Syria | 0.9395s | 2.0472s | 0.5x |
|
||||
| The rapid advancement of AI technology is transforming various industries | 0.7171s | 1.7516s | 0.4x |
|
||||
| Climate change poses a significant threat to global ecosystems | 0.8533s | 1.8323s | 0.5x |
|
||||
| International cooperation is essential for addressing global challenges | 0.7209s | 1.3575s | 0.5x |
|
||||
| The development of renewable energy sources is crucial for a sustainable future | 0.8760s | 1.5589s | 0.6x |
|
||||
|
||||
**Results:**
|
||||
- Total Standard time: 4.1068s
|
||||
- Total CTranslate2 time: 8.5476s
|
||||
- CTranslate2 is slower on this system --> Use Transformers, and ideally we would have an mlx implementation.
|
||||
|
||||
|
||||
# 3. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm
|
||||
|
||||
Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions.
|
||||
|
||||
@@ -67,4 +88,4 @@ ELSE:
|
||||
AS_2 ← B
|
||||
|
||||
to finish
|
||||
```
|
||||
```
|
||||
|
||||
@@ -198,6 +198,11 @@ An important list of parameters can be changed. But what *should* you change?
|
||||
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||
|
||||
|
||||
| Translation options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--nllb-backend` | `transformers` or `ctranslate2` | `ctranslate2` |
|
||||
| `--nllb-size` | `600M` or `1.3B` | `600M` |
|
||||
|
||||
> For diarization using Diart, you need access to pyannote.audio models:
|
||||
> 1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
|
||||
> 2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
|
||||
|
||||
@@ -179,12 +179,11 @@ class AudioProcessor:
|
||||
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
|
||||
if self.tokens:
|
||||
asr_processing_logs += f" | last_end = {self.tokens[-1].end} |"
|
||||
logger.info(asr_processing_logs)
|
||||
|
||||
if type(item) is Silence:
|
||||
logger.info(asr_processing_logs)
|
||||
cumulative_pcm_duration_stream_time += item.duration
|
||||
self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
|
||||
continue
|
||||
logger.info(asr_processing_logs)
|
||||
|
||||
if isinstance(item, np.ndarray):
|
||||
pcm_array = item
|
||||
@@ -223,7 +222,7 @@ class AudioProcessor:
|
||||
new_tokens, buffer_text, new_end_buffer
|
||||
)
|
||||
|
||||
if new_tokens and self.args.target_language and self.translation_queue:
|
||||
if self.translation_queue:
|
||||
for token in new_tokens:
|
||||
await self.translation_queue.put(token)
|
||||
|
||||
@@ -256,13 +255,11 @@ class AudioProcessor:
|
||||
logger.debug("Diarization processor received sentinel. Finishing.")
|
||||
self.diarization_queue.task_done()
|
||||
break
|
||||
|
||||
if type(item) is Silence:
|
||||
elif type(item) is Silence:
|
||||
cumulative_pcm_duration_stream_time += item.duration
|
||||
diarization_obj.insert_silence(item.duration)
|
||||
continue
|
||||
|
||||
if isinstance(item, np.ndarray):
|
||||
elif isinstance(item, np.ndarray):
|
||||
pcm_array = item
|
||||
else:
|
||||
raise Exception('item should be pcm_array')
|
||||
@@ -295,14 +292,17 @@ class AudioProcessor:
|
||||
# in the future we want to have different languages for each speaker etc, so it will be more complex.
|
||||
while True:
|
||||
try:
|
||||
token = await self.translation_queue.get() #block until at least 1 token
|
||||
if token is SENTINEL:
|
||||
item = await self.translation_queue.get() #block until at least 1 token
|
||||
if item is SENTINEL:
|
||||
logger.debug("Translation processor received sentinel. Finishing.")
|
||||
self.translation_queue.task_done()
|
||||
break
|
||||
elif type(item) is Silence:
|
||||
online_translation.insert_silence(item.duration)
|
||||
continue
|
||||
|
||||
# get all the available tokens for translation. The more words, the more precise
|
||||
tokens_to_process = [token]
|
||||
tokens_to_process = [item]
|
||||
additional_tokens = await get_all_from_queue(self.translation_queue)
|
||||
|
||||
sentinel_found = False
|
||||
@@ -326,7 +326,7 @@ class AudioProcessor:
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in translation_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
if 'token' in locals() and token is not SENTINEL:
|
||||
if 'token' in locals() and item is not SENTINEL:
|
||||
self.translation_queue.task_done()
|
||||
if 'additional_tokens' in locals():
|
||||
for _ in additional_tokens:
|
||||
@@ -367,7 +367,7 @@ class AudioProcessor:
|
||||
if not state.tokens and not buffer_transcription and not buffer_diarization:
|
||||
response_status = "no_audio_detected"
|
||||
lines = []
|
||||
elif response_status == "active_transcription" and not lines:
|
||||
elif not lines:
|
||||
lines = [Line(
|
||||
speaker=1,
|
||||
start=state.get("end_buffer", 0),
|
||||
@@ -528,6 +528,8 @@ class AudioProcessor:
|
||||
await self.transcription_queue.put(silence_buffer)
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(silence_buffer)
|
||||
if self.translation_queue:
|
||||
await self.translation_queue.put(silence_buffer)
|
||||
|
||||
if not self.silence:
|
||||
if self.args.transcription and self.transcription_queue:
|
||||
|
||||
@@ -43,10 +43,12 @@ class TranscriptionEngine:
|
||||
"transcription": True,
|
||||
"vad": True,
|
||||
"pcm_input": False,
|
||||
|
||||
# whisperstreaming params:
|
||||
"buffer_trimming": "segment",
|
||||
"confidence_validation": False,
|
||||
"buffer_trimming_sec": 15,
|
||||
|
||||
# simulstreaming params:
|
||||
"disable_fast_encoder": False,
|
||||
"frame_threshold": 25,
|
||||
@@ -61,10 +63,15 @@ class TranscriptionEngine:
|
||||
"max_context_tokens": None,
|
||||
"model_path": './base.pt',
|
||||
"diarization_backend": "sortformer",
|
||||
|
||||
# diarization params:
|
||||
"disable_punctuation_split" : False,
|
||||
"segmentation_model": "pyannote/segmentation-3.0",
|
||||
"embedding_model": "pyannote/embedding",
|
||||
"embedding_model": "pyannote/embedding",
|
||||
|
||||
# translation params:
|
||||
"nllb_backend": "ctranslate2",
|
||||
"nllb_size": "600M"
|
||||
}
|
||||
|
||||
config_dict = {**defaults, **kwargs}
|
||||
@@ -142,8 +149,7 @@ class TranscriptionEngine:
|
||||
raise Exception('Translation cannot be set with language auto')
|
||||
else:
|
||||
from whisperlivekit.translation.translation import load_model
|
||||
self.translation_model = load_model([self.args.lan]) #in the future we want to handle different languages for different speakers
|
||||
|
||||
self.translation_model = load_model([self.args.lan], backend=self.args.nllb_backend, model_size=self.args.nllb_size) #in the future we want to handle different languages for different speakers
|
||||
TranscriptionEngine._initialized = True
|
||||
|
||||
|
||||
|
||||
@@ -287,6 +287,20 @@ def parse_args():
|
||||
help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--nllb-backend",
|
||||
type=str,
|
||||
default="ctranslate2",
|
||||
help="transformers or ctranslate2",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--nllb-size",
|
||||
type=str,
|
||||
default="600M",
|
||||
help="600M or 1.3B",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
args.transcription = not args.no_transcription
|
||||
|
||||
@@ -39,7 +39,7 @@ def blank_to_silence(tokens):
|
||||
)
|
||||
else:
|
||||
if silence_token: #there was silence but no more
|
||||
if silence_token.end - silence_token.start >= MIN_SILENCE_DURATION:
|
||||
if silence_token.duration() >= MIN_SILENCE_DURATION:
|
||||
cleaned_tokens.append(
|
||||
silence_token
|
||||
)
|
||||
|
||||
@@ -123,14 +123,33 @@ def format_output(state, silence, current_time, args, debug, sep):
|
||||
|
||||
append_token_to_last_line(lines, sep, token, debug_info)
|
||||
if lines and translated_segments:
|
||||
cts_idx = 0 # current_translated_segment_idx
|
||||
for line in lines:
|
||||
while cts_idx < len(translated_segments):
|
||||
ts = translated_segments[cts_idx]
|
||||
if ts and ts.start and ts.start >= line.start and ts.end <= line.end:
|
||||
line.translation += ts.text + ' '
|
||||
cts_idx += 1
|
||||
else:
|
||||
break
|
||||
return lines, undiarized_text, buffer_transcription, ''
|
||||
|
||||
unassigned_translated_segments = []
|
||||
for ts in translated_segments:
|
||||
assigned = False
|
||||
for line in lines:
|
||||
if ts and ts.overlaps_with(line):
|
||||
if ts.is_within(line):
|
||||
line.translation += ts.text + ' '
|
||||
assigned = True
|
||||
break
|
||||
else:
|
||||
ts0, ts1 = ts.approximate_cut_at(line.end)
|
||||
if ts0 and line.overlaps_with(ts0):
|
||||
line.translation += ts0.text + ' '
|
||||
if ts1:
|
||||
unassigned_translated_segments.append(ts1)
|
||||
assigned = True
|
||||
break
|
||||
if not assigned:
|
||||
unassigned_translated_segments.append(ts)
|
||||
|
||||
if unassigned_translated_segments:
|
||||
for line in lines:
|
||||
remaining_segments = []
|
||||
for ts in unassigned_translated_segments:
|
||||
if ts and ts.overlaps_with(line):
|
||||
line.translation += ts.text + ' '
|
||||
else:
|
||||
remaining_segments.append(ts)
|
||||
unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
|
||||
return lines, undiarized_text, buffer_transcription, ''
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from typing import Optional, Any
|
||||
from datetime import timedelta
|
||||
|
||||
def format_time(seconds: float) -> str:
|
||||
@@ -15,6 +15,21 @@ class TimedText:
|
||||
speaker: Optional[int] = -1
|
||||
probability: Optional[float] = None
|
||||
is_dummy: Optional[bool] = False
|
||||
|
||||
def overlaps_with(self, other: 'TimedText') -> bool:
|
||||
return not (self.end <= other.start or other.end <= self.start)
|
||||
|
||||
def is_within(self, other: 'TimedText') -> bool:
|
||||
return other.contains_timespan(self)
|
||||
|
||||
def duration(self) -> float:
|
||||
return self.end - self.start
|
||||
|
||||
def contains_time(self, time: float) -> bool:
|
||||
return self.start <= time <= self.end
|
||||
|
||||
def contains_timespan(self, other: 'TimedText') -> bool:
|
||||
return self.start <= other.start and self.end >= other.end
|
||||
|
||||
@dataclass
|
||||
class ASRToken(TimedText):
|
||||
@@ -41,6 +56,34 @@ class SpeakerSegment(TimedText):
|
||||
class Translation(TimedText):
|
||||
pass
|
||||
|
||||
def approximate_cut_at(self, cut_time):
|
||||
"""
|
||||
Each word in text is considered to be of duration (end-start)/len(words in text)
|
||||
"""
|
||||
if not self.text or not self.contains_time(cut_time):
|
||||
return self, None
|
||||
|
||||
words = self.text.split()
|
||||
num_words = len(words)
|
||||
if num_words == 0:
|
||||
return self, None
|
||||
|
||||
duration_per_word = self.duration() / num_words
|
||||
|
||||
cut_word_index = int((cut_time - self.start) / duration_per_word)
|
||||
|
||||
if cut_word_index >= num_words:
|
||||
cut_word_index = num_words -1
|
||||
|
||||
text0 = " ".join(words[:cut_word_index])
|
||||
text1 = " ".join(words[cut_word_index:])
|
||||
|
||||
segment0 = Translation(start=self.start, end=cut_time, text=text0)
|
||||
segment1 = Translation(start=cut_time, end=self.end, text=text1)
|
||||
|
||||
return segment0, segment1
|
||||
|
||||
|
||||
@dataclass
|
||||
class Silence():
|
||||
duration: float
|
||||
@@ -91,4 +134,4 @@ class State():
|
||||
end_buffer: float
|
||||
end_attributed_speaker: float
|
||||
remaining_time_transcription: float
|
||||
remaining_time_diarization: float
|
||||
remaining_time_diarization: float
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import logging
|
||||
import time
|
||||
import ctranslate2
|
||||
import torch
|
||||
import transformers
|
||||
@@ -6,38 +8,42 @@ import huggingface_hub
|
||||
from whisperlivekit.translation.mapping_languages import get_nllb_code
|
||||
from whisperlivekit.timed_objects import Translation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
#In diarization case, we may want to translate just one speaker, or at least start the sentences there
|
||||
|
||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||
|
||||
MIN_SILENCE_DURATION_DEL_BUFFER = 3 #After a silence of x seconds, we consider the model should not use the buffer, even if the previous
|
||||
# sentence is not finished.
|
||||
|
||||
@dataclass
|
||||
class TranslationModel():
|
||||
translator: ctranslate2.Translator
|
||||
tokenizer: dict
|
||||
device: str
|
||||
backend_type: str = 'ctranslate2'
|
||||
|
||||
def load_model(src_langs):
|
||||
MODEL = 'nllb-200-distilled-600M-ctranslate2'
|
||||
MODEL_GUY = 'entai2965'
|
||||
huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL)
|
||||
def load_model(src_langs, backend='ctranslate2', model_size='600M'):
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
translator = ctranslate2.Translator(MODEL,device=device)
|
||||
MODEL = f'nllb-200-distilled-{model_size}-ctranslate2'
|
||||
if backend=='ctranslate2':
|
||||
MODEL_GUY = 'entai2965'
|
||||
huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL)
|
||||
translator = ctranslate2.Translator(MODEL,device=device)
|
||||
elif backend=='transformers':
|
||||
translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(f"facebook/nllb-200-distilled-{model_size}")
|
||||
tokenizer = dict()
|
||||
for src_lang in src_langs:
|
||||
tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
|
||||
|
||||
return TranslationModel(
|
||||
translator=translator,
|
||||
tokenizer=tokenizer
|
||||
tokenizer=tokenizer,
|
||||
backend_type=backend,
|
||||
device = device
|
||||
)
|
||||
|
||||
def translate(input, translation_model, tgt_lang):
|
||||
source = translation_model.tokenizer.convert_ids_to_tokens(translation_model.tokenizer.encode(input))
|
||||
target_prefix = [tgt_lang]
|
||||
results = translation_model.translator.translate_batch([source], target_prefix=[target_prefix])
|
||||
target = results[0].hypotheses[0][1:]
|
||||
return translation_model.tokenizer.decode(translation_model.tokenizer.convert_tokens_to_ids(target))
|
||||
|
||||
class OnlineTranslation:
|
||||
def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
|
||||
self.buffer = []
|
||||
@@ -68,12 +74,19 @@ class OnlineTranslation:
|
||||
output_lang = self.output_languages[0]
|
||||
nllb_output_lang = get_nllb_code(output_lang)
|
||||
|
||||
source = self.translation_model.tokenizer[input_lang].convert_ids_to_tokens(self.translation_model.tokenizer[input_lang].encode(input))
|
||||
results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]]) #we can use return_attention=True to try to optimize the stuff.
|
||||
target = results[0].hypotheses[0][1:]
|
||||
results = self.translation_model.tokenizer[input_lang].decode(self.translation_model.tokenizer[input_lang].convert_tokens_to_ids(target))
|
||||
return results
|
||||
|
||||
tokenizer = self.translation_model.tokenizer[input_lang]
|
||||
tokenizer_output = tokenizer(input, return_tensors="pt").to(self.translation_model.device)
|
||||
|
||||
if self.translation_model.backend_type == 'ctranslate2':
|
||||
source = tokenizer.convert_ids_to_tokens(tokenizer_output['input_ids'][0])
|
||||
results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]])
|
||||
target = results[0].hypotheses[0][1:]
|
||||
result = tokenizer.decode(tokenizer.convert_tokens_to_ids(target))
|
||||
else:
|
||||
translated_tokens = self.translation_model.translator.generate(**tokenizer_output, forced_bos_token_id=tokenizer.convert_tokens_to_ids(nllb_output_lang))
|
||||
result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
|
||||
return result
|
||||
|
||||
def translate_tokens(self, tokens):
|
||||
if tokens:
|
||||
text = ' '.join([token.text for token in tokens])
|
||||
@@ -88,7 +101,6 @@ class OnlineTranslation:
|
||||
return translation
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def insert_tokens(self, tokens):
|
||||
self.buffer.extend(tokens)
|
||||
@@ -109,7 +121,11 @@ class OnlineTranslation:
|
||||
self.translation_remaining = self.translate_tokens(self.buffer)
|
||||
self.len_processed_buffer = len(self.buffer)
|
||||
return self.validated + [self.translation_remaining]
|
||||
|
||||
|
||||
def insert_silence(self, silence_duration: float):
|
||||
if silence_duration >= MIN_SILENCE_DURATION_DEL_BUFFER:
|
||||
self.buffer = []
|
||||
self.validated += [self.translation_remaining]
|
||||
|
||||
if __name__ == '__main__':
|
||||
output_lang = 'fr'
|
||||
@@ -122,16 +138,13 @@ if __name__ == '__main__':
|
||||
test = test_string.split(' ')
|
||||
step = len(test) // 3
|
||||
|
||||
shared_model = load_model([input_lang])
|
||||
shared_model = load_model([input_lang], backend='ctranslate2')
|
||||
online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang])
|
||||
|
||||
|
||||
beg_inference = time.time()
|
||||
for id in range(5):
|
||||
val = test[id*step : (id+1)*step]
|
||||
val_str = ' '.join(val)
|
||||
result = online_translation.translate(val_str)
|
||||
print(result)
|
||||
|
||||
|
||||
|
||||
|
||||
# print(result)
|
||||
print('inference time:', time.time() - beg_inference)
|
||||
@@ -438,7 +438,6 @@ label {
|
||||
font-size: 13px;
|
||||
border-radius: 30px;
|
||||
padding: 2px 10px;
|
||||
display: none;
|
||||
}
|
||||
|
||||
.loading {
|
||||
|
||||
Reference in New Issue
Block a user