mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-08 06:44:09 +00:00
Compare commits
9 Commits
0.2.17.pos
...
0.2.18
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e7e82f7c19 | ||
|
|
8c799fa4d1 | ||
|
|
8923337380 | ||
|
|
aded1649ae | ||
|
|
3b535e857a | ||
|
|
d649250b9a | ||
|
|
7735478286 | ||
|
|
b9e72d2b9a | ||
|
|
e5b01033af |
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "whisperlivekit"
|
||||
version = "0.2.17.post1"
|
||||
version = "0.2.18"
|
||||
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
@@ -69,4 +69,5 @@ packages = [
|
||||
[tool.setuptools.package-data]
|
||||
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
|
||||
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
|
||||
"whisperlivekit.whisper.normalizers" = ["*.json"]
|
||||
"whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]
|
||||
|
||||
@@ -202,14 +202,14 @@ class DiartDiarization:
|
||||
def insert_silence(self, silence_duration):
|
||||
self.observer.global_time_offset += silence_duration
|
||||
|
||||
async def diarize(self, pcm_array: np.ndarray):
|
||||
"""
|
||||
Process audio data for diarization.
|
||||
Only used when working with WebSocketAudioSource.
|
||||
"""
|
||||
def insert_audio_chunk(self, pcm_array: np.ndarray):
|
||||
"""Buffer audio for the next diarization step."""
|
||||
if self.custom_source:
|
||||
self.custom_source.push_audio(pcm_array)
|
||||
# self.observer.clear_old_segments()
|
||||
self.custom_source.push_audio(pcm_array)
|
||||
|
||||
async def diarize(self):
|
||||
"""Return the current speaker segments from the diarization pipeline."""
|
||||
return self.observer.get_segments()
|
||||
|
||||
def close(self):
|
||||
"""Close the audio source."""
|
||||
|
||||
@@ -151,7 +151,7 @@ class FasterWhisperASR(ASRBase):
|
||||
if segment.no_speech_prob > 0.9:
|
||||
continue
|
||||
for word in segment.words:
|
||||
token = ASRToken(word.start, word.end, word.word)
|
||||
token = ASRToken(word.start, word.end, word.word, probability=word.probability)
|
||||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
@@ -249,6 +249,7 @@ class OpenaiApiASR(ASRBase):
|
||||
self.load_model()
|
||||
self.use_vad_opt = False
|
||||
self.direct_english_translation = False
|
||||
self.task = "transcribe"
|
||||
|
||||
def load_model(self, *args, **kwargs):
|
||||
from openai import OpenAI
|
||||
@@ -294,7 +295,8 @@ class OpenaiApiASR(ASRBase):
|
||||
params["language"] = self.original_language
|
||||
if prompt:
|
||||
params["prompt"] = prompt
|
||||
proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions
|
||||
task = self.transcribe_kargs.get("task", self.task)
|
||||
proc = self.client.audio.translations if task == "translate" else self.client.audio.transcriptions
|
||||
transcript = proc.create(**params)
|
||||
logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
|
||||
return transcript
|
||||
|
||||
@@ -146,6 +146,7 @@ def backend_factory(
|
||||
|
||||
if direct_english_translation:
|
||||
tgt_language = "en" # Whisper translates into English
|
||||
asr.transcribe_kargs["task"] = "translate"
|
||||
else:
|
||||
tgt_language = lan # Whisper transcribes in this language
|
||||
|
||||
@@ -154,9 +155,9 @@ def backend_factory(
|
||||
tokenizer = create_tokenizer(tgt_language)
|
||||
else:
|
||||
tokenizer = None
|
||||
|
||||
|
||||
warmup_asr(asr, warmup_file)
|
||||
|
||||
|
||||
asr.confidence_validation = confidence_validation
|
||||
asr.tokenizer = tokenizer
|
||||
asr.buffer_trimming = buffer_trimming
|
||||
|
||||
@@ -46,8 +46,6 @@ class SimulStreamingOnlineProcessor:
|
||||
self.logfile = logfile
|
||||
self.end = 0.0
|
||||
self.buffer = []
|
||||
self.committed: List[ASRToken] = []
|
||||
self.last_result_tokens: List[ASRToken] = []
|
||||
self.model = self._create_alignatt()
|
||||
|
||||
if asr.tokenizer:
|
||||
@@ -122,7 +120,6 @@ class SimulStreamingOnlineProcessor:
|
||||
self.buffer.extend(timestamped_words)
|
||||
return [], self.end
|
||||
|
||||
self.committed.extend(timestamped_words)
|
||||
self.buffer = []
|
||||
return timestamped_words, self.end
|
||||
except Exception as e:
|
||||
@@ -217,7 +214,7 @@ class SimulStreamingASR:
|
||||
cif_ckpt_path=self.cif_ckpt_path,
|
||||
decoder_type="beam",
|
||||
beam_size=self.beams,
|
||||
task=self.direct_english_translation,
|
||||
task="translate" if self.direct_english_translation else "transcribe",
|
||||
never_fire=self.never_fire,
|
||||
init_prompt=self.init_prompt,
|
||||
max_context_tokens=self.max_context_tokens,
|
||||
@@ -330,7 +327,7 @@ class SimulStreamingASR:
|
||||
lora_path = getattr(self, 'lora_path', None)
|
||||
whisper_model = load_model(
|
||||
name=model_ref,
|
||||
download_root=None,
|
||||
download_root=getattr(self, 'model_cache_dir', None),
|
||||
decoder_only=self.fast_encoder,
|
||||
custom_alignment_heads=self.custom_alignment_heads,
|
||||
lora_path=lora_path,
|
||||
|
||||
@@ -532,7 +532,9 @@ class MLXAlignAtt:
|
||||
accumulated_cross_attns = []
|
||||
|
||||
audio_duration_s = self.segments_len()
|
||||
max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0))
|
||||
# ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
|
||||
# is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
|
||||
max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
|
||||
tokens_produced_this_chunk = 0
|
||||
|
||||
while not completed and current_tokens.shape[1] < self.max_text_len:
|
||||
@@ -558,6 +560,8 @@ class MLXAlignAtt:
|
||||
mx.eval(logits)
|
||||
|
||||
accumulated_cross_attns.append(cross_qk)
|
||||
if len(accumulated_cross_attns) > 16:
|
||||
accumulated_cross_attns = accumulated_cross_attns[-16:]
|
||||
|
||||
if new_segment and self.tokenizer.no_speech is not None:
|
||||
probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1)
|
||||
|
||||
@@ -390,7 +390,6 @@ class AlignAtt:
|
||||
return []
|
||||
if not self._apply_minseglen():
|
||||
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
||||
input_segments = torch.cat(self.state.segments, dim=0)
|
||||
return []
|
||||
|
||||
# input_segments is concatenation of audio, it's one array
|
||||
@@ -485,7 +484,9 @@ class AlignAtt:
|
||||
accumulated_cross_attns = []
|
||||
|
||||
audio_duration_s = self.segments_len()
|
||||
max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0)) # 2x margin, min 50
|
||||
# ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
|
||||
# is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
|
||||
max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
|
||||
tokens_produced_this_chunk = 0
|
||||
|
||||
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
|
||||
@@ -506,8 +507,12 @@ class AlignAtt:
|
||||
result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True)
|
||||
logits, cross_attns = result
|
||||
|
||||
# Accumulate cross-attention from this forward pass
|
||||
# Accumulate cross-attention from this forward pass (rolling window to
|
||||
# bound VRAM — only the last entry matters for alignment, and the
|
||||
# median_filter kernel is 7, so 16 entries is more than enough).
|
||||
accumulated_cross_attns.append(cross_attns)
|
||||
if len(accumulated_cross_attns) > 16:
|
||||
accumulated_cross_attns = accumulated_cross_attns[-16:]
|
||||
|
||||
if new_segment and self.tokenizer.no_speech is not None:
|
||||
probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
|
||||
|
||||
@@ -39,10 +39,11 @@ class TimedText(Timed):
|
||||
|
||||
@dataclass()
|
||||
class ASRToken(TimedText):
|
||||
|
||||
probability: Optional[float] = None
|
||||
|
||||
def with_offset(self, offset: float) -> "ASRToken":
|
||||
"""Return a new token with the time offset added."""
|
||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language)
|
||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language, probability=self.probability)
|
||||
|
||||
def is_silence(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -53,7 +53,8 @@ class TokensAlignment:
|
||||
segment.translation = ''
|
||||
for ts in self.all_translation_segments:
|
||||
if ts.is_within(segment):
|
||||
segment.translation += ts.text + (self.sep if ts.text else '')
|
||||
if ts.text:
|
||||
segment.translation += ts.text + self.sep
|
||||
elif segment.translation:
|
||||
break
|
||||
|
||||
@@ -185,11 +186,11 @@ class TokensAlignment:
|
||||
else:
|
||||
diarization_buffer = ''
|
||||
for token in self.new_tokens:
|
||||
if token.is_silence():
|
||||
if isinstance(token, Silence):
|
||||
if self.current_line_tokens:
|
||||
self.validated_segments.append(Segment().from_tokens(self.current_line_tokens))
|
||||
self.validated_segments.append(Segment.from_tokens(self.current_line_tokens))
|
||||
self.current_line_tokens = []
|
||||
|
||||
|
||||
end_silence = token.end if token.has_ended else time() - self.beg_loop
|
||||
if self.validated_segments and self.validated_segments[-1].is_silence():
|
||||
self.validated_segments[-1].end = end_silence
|
||||
@@ -203,7 +204,7 @@ class TokensAlignment:
|
||||
|
||||
segments = list(self.validated_segments)
|
||||
if self.current_line_tokens:
|
||||
segments.append(Segment().from_tokens(self.current_line_tokens))
|
||||
segments.append(Segment.from_tokens(self.current_line_tokens))
|
||||
|
||||
if current_silence:
|
||||
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
|
||||
|
||||
Reference in New Issue
Block a user