Mirror of https://github.com/QuentinFuxa/WhisperLiveKit.git
Synced 2026-04-27 00:56:20 +00:00

Compare commits: 0.2.17.post1 ... 0.2.18 (9 commits)
| SHA1 |
|---|
| e7e82f7c19 |
| 8c799fa4d1 |
| 8923337380 |
| aded1649ae |
| 3b535e857a |
| d649250b9a |
| 7735478286 |
| b9e72d2b9a |
| e5b01033af |
```diff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "whisperlivekit"
-version = "0.2.17.post1"
+version = "0.2.18"
 description = "Real-time speech-to-text with speaker diarization using Whisper"
 readme = "README.md"
 authors = [
@@ -69,4 +69,5 @@ packages = [
 [tool.setuptools.package-data]
 whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
 "whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
+"whisperlivekit.whisper.normalizers" = ["*.json"]
 "whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]
```
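The new `"whisperlivekit.whisper.normalizers"` entry ships the normalizer's JSON data inside the wheel; without it, code that reads the vendored normalizer files would work from a source checkout but fail on an installed package. A minimal sketch of how such packaged data is read at runtime; the file name `english.json` is an assumption (it is what upstream Whisper's English normalizer uses), not something shown in this diff:

```python
import json
from importlib import resources

def load_normalizer_data() -> dict:
    # resources.files() resolves inside installed wheels and zipped
    # packages alike, unlike path arithmetic on __file__.
    ref = resources.files("whisperlivekit.whisper.normalizers") / "english.json"
    with ref.open("r", encoding="utf-8") as f:
        return json.load(f)
```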
```diff
@@ -202,14 +202,14 @@ class DiartDiarization:
     def insert_silence(self, silence_duration):
         self.observer.global_time_offset += silence_duration
 
-    async def diarize(self, pcm_array: np.ndarray):
-        """
-        Process audio data for diarization.
-        Only used when working with WebSocketAudioSource.
-        """
+    def insert_audio_chunk(self, pcm_array: np.ndarray):
+        """Buffer audio for the next diarization step."""
         if self.custom_source:
             self.custom_source.push_audio(pcm_array)
-        # self.observer.clear_old_segments()
+
+    async def diarize(self):
+        """Return the current speaker segments from the diarization pipeline."""
+        return self.observer.get_segments()
 
     def close(self):
         """Close the audio source."""
```
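The single `diarize(pcm_array)` coroutine is split into a synchronous `insert_audio_chunk` (buffering) and a parameter-less `async diarize` (reading results), mirroring the two-phase shape of the ASR backends. A minimal sketch of the resulting call pattern; the driving loop and `handle` consumer are illustrative, not the project's actual audio processor:

```python
async def run(diarization, chunks):
    for pcm_array in chunks:
        diarization.insert_audio_chunk(pcm_array)   # cheap, synchronous buffering
        segments = await diarization.diarize()      # current speaker segments
        handle(segments)                            # hypothetical consumer
```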
```diff
@@ -151,7 +151,7 @@ class FasterWhisperASR(ASRBase):
             if segment.no_speech_prob > 0.9:
                 continue
             for word in segment.words:
-                token = ASRToken(word.start, word.end, word.word)
+                token = ASRToken(word.start, word.end, word.word, probability=word.probability)
                 tokens.append(token)
         return tokens
 
```
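faster-whisper reports a per-word `probability` on its word objects; threading it into `ASRToken` gives downstream confidence checks real values to work with. A hedged sketch of one way a consumer could use it; the threshold and helper are illustrative assumptions, not project code:

```python
from typing import List

def confident_text(tokens: List["ASRToken"], threshold: float = 0.5) -> str:
    # Backends that don't report probabilities leave the field as None,
    # so treat None as "no evidence against the word".
    kept = [t for t in tokens if t.probability is None or t.probability >= threshold]
    return "".join(t.text for t in kept)
```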
```diff
@@ -249,6 +249,7 @@ class OpenaiApiASR(ASRBase):
         self.load_model()
         self.use_vad_opt = False
         self.direct_english_translation = False
+        self.task = "transcribe"
 
     def load_model(self, *args, **kwargs):
         from openai import OpenAI
```
```diff
@@ -294,7 +295,8 @@ class OpenaiApiASR(ASRBase):
             params["language"] = self.original_language
         if prompt:
             params["prompt"] = prompt
-        proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions
+        task = self.transcribe_kargs.get("task", self.task)
+        proc = self.client.audio.translations if task == "translate" else self.client.audio.transcriptions
         transcript = proc.create(**params)
         logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
         return transcript
```
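Together with the new `self.task = "transcribe"` default, this establishes a clear precedence: an explicit `transcribe_kargs["task"]` (set by `backend_factory` below when `direct_english_translation` is on) wins over the instance default. A small sketch of that lookup with illustrative values:

```python
transcribe_kargs = {"task": "translate"}  # what backend_factory sets for translation
default_task = "transcribe"               # the new __init__ default
task = transcribe_kargs.get("task", default_task)
assert task == "translate"                # without the key, falls back to "transcribe"
```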
```diff
@@ -146,6 +146,7 @@ def backend_factory(
 
    if direct_english_translation:
        tgt_language = "en" # Whisper translates into English
+        asr.transcribe_kargs["task"] = "translate"
    else:
        tgt_language = lan # Whisper transcribes in this language
 
```
```diff
@@ -154,9 +155,9 @@ def backend_factory(
        tokenizer = create_tokenizer(tgt_language)
    else:
        tokenizer = None
 
    warmup_asr(asr, warmup_file)
 
    asr.confidence_validation = confidence_validation
    asr.tokenizer = tokenizer
    asr.buffer_trimming = buffer_trimming
```
```diff
@@ -46,8 +46,6 @@ class SimulStreamingOnlineProcessor:
         self.logfile = logfile
         self.end = 0.0
         self.buffer = []
-        self.committed: List[ASRToken] = []
-        self.last_result_tokens: List[ASRToken] = []
         self.model = self._create_alignatt()
 
         if asr.tokenizer:
```
```diff
@@ -122,7 +120,6 @@ class SimulStreamingOnlineProcessor:
                 self.buffer.extend(timestamped_words)
                 return [], self.end
 
-            self.committed.extend(timestamped_words)
             self.buffer = []
             return timestamped_words, self.end
         except Exception as e:
```
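These two hunks drop the processor's `committed` and `last_result_tokens` bookkeeping. Judging from this diff alone the lists were written to but never read, so removing the fields and their only update looks like dead-state cleanup; whether any external caller relied on them is not visible here.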
```diff
@@ -217,7 +214,7 @@ class SimulStreamingASR:
             cif_ckpt_path=self.cif_ckpt_path,
             decoder_type="beam",
             beam_size=self.beams,
-            task=self.direct_english_translation,
+            task="translate" if self.direct_english_translation else "transcribe",
             never_fire=self.never_fire,
             init_prompt=self.init_prompt,
             max_context_tokens=self.max_context_tokens,
```
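The old call passed the raw boolean flag as `task`. If the consumer compares `task` against the string literals `"translate"`/`"transcribe"`, as the OpenAI-API path above does, a bool never matches and the flag was silently ignored; the fix maps the flag onto the expected literals.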
```diff
@@ -330,7 +327,7 @@ class SimulStreamingASR:
         lora_path = getattr(self, 'lora_path', None)
         whisper_model = load_model(
             name=model_ref,
-            download_root=None,
+            download_root=getattr(self, 'model_cache_dir', None),
             decoder_only=self.fast_encoder,
             custom_alignment_heads=self.custom_alignment_heads,
             lora_path=lora_path,
```
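`download_root` was hard-wired to `None` (the loader's default cache location); it now honors a `model_cache_dir` attribute when one has been set, using the same defensive `getattr` pattern as `lora_path` so call sites without the attribute keep working unchanged.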
```diff
@@ -532,7 +532,9 @@ class MLXAlignAtt:
         accumulated_cross_attns = []
 
         audio_duration_s = self.segments_len()
-        max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0))
+        # ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
+        # is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
+        max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
         tokens_produced_this_chunk = 0
 
         while not completed and current_tokens.shape[1] < self.max_text_len:
```
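Worked numbers for the budget change, using the constants from the diff and a 30-second window as the example:

```python
TOKENS_PER_SECOND = 50  # mel-frame rate, not a text-token rate
audio_duration_s = 30

old_cap = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0))  # 3000 tokens
new_cap = max(50, int(audio_duration_s * 15 * 1.5))                 # 675 tokens
# A runaway repetition loop now stops ~4.4x sooner, while 675 tokens still
# covers ~45 s of speech at the assumed 15 tokens/s, so normal decoding
# never hits the cap.
```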
```diff
@@ -558,6 +560,8 @@ class MLXAlignAtt:
             mx.eval(logits)
 
             accumulated_cross_attns.append(cross_qk)
+            if len(accumulated_cross_attns) > 16:
+                accumulated_cross_attns = accumulated_cross_attns[-16:]
 
             if new_segment and self.tokenizer.no_speech is not None:
                 probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1)
```
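The cap keeps the per-step cross-attention list from growing without bound over a long decode. A `collections.deque` would express the same rolling window declaratively; a sketch of that alternative, as a design note rather than the project's code:

```python
from collections import deque

# maxlen=16 evicts the oldest entry automatically on append, replacing the
# explicit `if len(...) > 16: ... = ...[-16:]` truncation.
accumulated_cross_attns = deque(maxlen=16)
```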
```diff
@@ -390,7 +390,6 @@ class AlignAtt:
             return []
         if not self._apply_minseglen():
             logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
-            input_segments = torch.cat(self.state.segments, dim=0)
             return []
 
         # input_segments is concatenation of audio, it's one array
```
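The removed `torch.cat` concatenated the buffered audio into a tensor that the very next line discarded via `return []`, a wasted allocation on every too-short segment; the real concatenation, per the surviving comment, happens further down the method.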
```diff
@@ -485,7 +484,9 @@ class AlignAtt:
         accumulated_cross_attns = []
 
         audio_duration_s = self.segments_len()
-        max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0)) # 2x margin, min 50
+        # ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
+        # is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
+        max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
         tokens_produced_this_chunk = 0
 
         while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
```
```diff
@@ -506,8 +507,12 @@ class AlignAtt:
             result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True)
             logits, cross_attns = result
 
-            # Accumulate cross-attention from this forward pass
+            # Accumulate cross-attention from this forward pass (rolling window to
+            # bound VRAM — only the last entry matters for alignment, and the
+            # median_filter kernel is 7, so 16 entries is more than enough).
             accumulated_cross_attns.append(cross_attns)
+            if len(accumulated_cross_attns) > 16:
+                accumulated_cross_attns = accumulated_cross_attns[-16:]
 
             if new_segment and self.tokenizer.no_speech is not None:
                 probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
```
```diff
@@ -39,10 +39,11 @@ class TimedText(Timed):
 
 @dataclass()
 class ASRToken(TimedText):
+    probability: Optional[float] = None
 
     def with_offset(self, offset: float) -> "ASRToken":
         """Return a new token with the time offset added."""
-        return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language)
+        return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language, probability=self.probability)
 
     def is_silence(self) -> bool:
         return False
```
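The second change is the one that closes the loop: without it, `with_offset` would quietly reset `probability` to `None` on every re-timestamped copy. A small usage sketch; the constructor argument order is inferred from the calls in this diff and the import path is an assumption:

```python
from whisperlivekit.timed_objects import ASRToken  # assumed module path

token = ASRToken(0.0, 0.42, "hello", probability=0.93)
shifted = token.with_offset(10.0)
assert (shifted.start, shifted.end) == (10.0, 10.42)
assert shifted.probability == 0.93  # survives the copy after this fix
```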
```diff
@@ -53,7 +53,8 @@ class TokensAlignment:
             segment.translation = ''
             for ts in self.all_translation_segments:
                 if ts.is_within(segment):
-                    segment.translation += ts.text + (self.sep if ts.text else '')
+                    if ts.text:
+                        segment.translation += ts.text + self.sep
                 elif segment.translation:
                     break
 
```
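Behavioral note: the two spellings produce the same string, since the old form appended `'' + ''` for an empty `ts.text`, a no-op; the rewrite simply states the intent, skipping empty translation segments, directly.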
```diff
@@ -185,11 +186,11 @@ class TokensAlignment:
         else:
             diarization_buffer = ''
         for token in self.new_tokens:
-            if token.is_silence():
+            if isinstance(token, Silence):
                 if self.current_line_tokens:
-                    self.validated_segments.append(Segment().from_tokens(self.current_line_tokens))
+                    self.validated_segments.append(Segment.from_tokens(self.current_line_tokens))
                     self.current_line_tokens = []
 
                 end_silence = token.end if token.has_ended else time() - self.beg_loop
                 if self.validated_segments and self.validated_segments[-1].is_silence():
                     self.validated_segments[-1].end = end_silence
```
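Two cleanups here: `isinstance(token, Silence)` narrows the type in a way static checkers understand where a truthy `is_silence()` cannot, and `Segment.from_tokens(...)` drops the throwaway `Segment()` instance the old spelling built just to call the method (which only works if `from_tokens` is a classmethod- or staticmethod-style constructor). A hedged sketch of the pattern this implies; the real `Segment` surely carries more fields:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Segment:
    start: float = 0.0
    end: float = 0.0
    text: str = ""

    @classmethod
    def from_tokens(cls, tokens: List["ASRToken"]) -> "Segment":
        # Construct directly on the class: no intermediate Segment() needed.
        return cls(start=tokens[0].start, end=tokens[-1].end,
                   text="".join(t.text for t in tokens))
```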
```diff
@@ -203,7 +204,7 @@ class TokensAlignment:
 
         segments = list(self.validated_segments)
         if self.current_line_tokens:
-            segments.append(Segment().from_tokens(self.current_line_tokens))
+            segments.append(Segment.from_tokens(self.current_line_tokens))
 
         if current_silence:
             end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
```