Each SimulStreamingOnlineProcessor now contains its own PaddedAlignAttWhisper instance; SimulStreamingASR only holds the loaded Whisper model

This commit is contained in:
Quentin Fuxa
2025-08-11 08:24:14 +02:00
parent 4e56130a40
commit d098af3185
6 changed files with 55 additions and 65 deletions

View File

@@ -13,13 +13,16 @@
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p>
Built on [WhisperStreaming](https://github.com/ufal/whisper_streaming) and [SimulStreaming](https://github.com/ufal/SimulStreaming), WhisperLiveKit provides real-time speech transcription in your browser, with a ready-to-use backend and a simple, customizable frontend. ✨
WhisperLiveKit brings real-time speech transcription directly to your browser, with a ready-to-use backend+server and a simple frontend. ✨
Built on [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) and [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) for transcription, plus [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) and [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) for diarization.
### Key Features
- **Real-time Transcription** - Locally (or on-prem) convert speech to text instantly as you speak
- **Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart)
- **Speaker Diarization** - Identify different speakers in real-time. (⚠️ backend Streaming Sortformer in development)
- **Multi-User Support** - Handle multiple users simultaneously with a single backend/server
- **Automatic Silence Chunking** - Automatically chunks when no audio is detected to limit buffer size
- **Confidence Validation** - Immediately validate high-confidence tokens for faster inference (WhisperStreaming only)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 348 KiB

After

Width:  |  Height:  |  Size: 382 KiB

View File

@@ -109,7 +109,7 @@ class TranscriptionEngine:
else:
self.asr, self.tokenizer = backend_factory(self.args)
warmup_asr(self.asr, self.args.warmup_file)
warmup_asr(self.asr, self.args.warmup_file) #for simulstreaming, warmup should be done in the online class not here
if self.args.diarization:
from whisperlivekit.diarization.diarization_online import DiartDiarization

View File

@@ -5,6 +5,8 @@ from typing import List, Tuple, Optional
import logging
from whisperlivekit.timed_objects import ASRToken, Transcript
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
from .whisper import load_model, tokenizer
import os
logger = logging.getLogger(__name__)
try:
@@ -34,7 +36,11 @@ class SimulStreamingOnlineProcessor:
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
self.buffer_content = ""
self.model = PaddedAlignAttWhisper(
cfg=asr.cfg,
loaded_model=asr.whisper_model)
if asr.tokenizer:
self.model.tokenizer = asr.tokenizer
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
"""Append an audio chunk to be processed by SimulStreaming."""
@@ -50,17 +56,13 @@ class SimulStreamingOnlineProcessor:
self.end = audio_stream_end_time
else:
self.end = self.cumulative_audio_duration
self.asr.model.insert_audio(audio_tensor)
self.model.insert_audio(audio_tensor)
def get_buffer(self):
"""
Get the unvalidated buffer content.
"""
buffer_end = self.end if hasattr(self, 'end') else None
return Transcript(
start=None,
end=buffer_end,
text=self.buffer_content,
end=None,
text='',
probability=None
)
@@ -68,7 +70,7 @@ class SimulStreamingOnlineProcessor:
# From the simulstreaming repo. self.model to self.asr.model
pr = generation["progress"]
if "result" not in generation:
split_words, split_tokens = self.asr.model.tokenizer.split_to_word_tokens(tokens)
split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
else:
split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"]
@@ -96,7 +98,7 @@ class SimulStreamingOnlineProcessor:
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
"""
try:
tokens, generation_progress = self.asr.model.infer(is_last=self.is_last)
tokens, generation_progress = self.model.infer(is_last=self.is_last)
ts_words = self.timestamped_text(tokens, generation_progress)
new_tokens = []
@@ -162,15 +164,17 @@ class SimulStreamingASR():
}
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
self.model = self.load_model(modelsize, cache_dir, model_dir)
self.model = self.load_model(modelsize)
# Set up tokenizer for translation if needed
if self.task == "translate":
self.set_translate_task()
self.tokenizer = self.set_translate_task()
else:
self.tokenizer = None
def load_model(self, modelsize, cache_dir, model_dir):
try:
cfg = AlignAttConfig(
def load_model(self, modelsize):
self.cfg = AlignAttConfig(
model_path=self.model_path,
segment_length=self.segment_length,
frame_threshold=self.frame_threshold,
@@ -185,36 +189,29 @@ class SimulStreamingASR():
init_prompt=self.init_prompt,
max_context_tokens=self.max_context_tokens,
static_init_prompt=self.static_init_prompt,
)
model = PaddedAlignAttWhisper(cfg)
return model
except Exception as e:
logger.error(f"Failed to load SimulStreaming model: {e}")
raise
)
model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
self.whisper_model = load_model(name=model_name, download_root=model_path)
def set_translate_task(self):
"""Set up translation task."""
try:
self.model.tokenizer = tokenizer.get_tokenizer(
multilingual=True,
language=self.model.cfg.language,
num_languages=self.model.model.num_languages,
task="translate"
)
logger.info("SimulStreaming configured for translation task")
except Exception as e:
logger.error(f"Failed to configure SimulStreaming for translation: {e}")
raise
return tokenizer.get_tokenizer(
multilingual=True,
language=self.model.cfg.language,
num_languages=self.model.model.num_languages,
task="translate"
)
def warmup(self, audio, init_prompt=""):
"""Warmup the SimulStreaming model."""
try:
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio).float()
self.model.insert_audio(audio)
self.model.infer(True)
self.model.refresh_segment(complete=True)
logger.info("SimulStreaming model warmed up successfully")
except Exception as e:
logger.exception(f"SimulStreaming warmup failed: {e}")
# def warmup(self, audio, init_prompt=""):
# """Warmup the SimulStreaming model."""
# try:
# if isinstance(audio, np.ndarray):
# audio = torch.from_numpy(audio).float()
# self.model.insert_audio(audio)
# self.model.infer(True)
# self.model.refresh_segment(complete=True)
# logger.info("SimulStreaming model warmed up successfully")
# except Exception as e:
# logger.exception(f"SimulStreaming warmup failed: {e}")

View File

@@ -1,18 +1,5 @@
SIMULSTREAMING_LICENSE = f"""
{"*"*80}
SimulStreaming (https://github.com/ufal/SimulStreaming) is dual-licensed:
🔹 Non-Commercial Use
You may use SimulStreaming under the PolyForm Noncommercial License 1.0.0 if you obtain the code through the GitHub repository. This license is free of charge and comes with no obligations for non-commercial users.
🔸 Commercial Use
Understanding who uses SimulStreaming commercially helps us improve and
prioritize development. Therefore, we want to require registration of those who acquire a commercial licence.
We plan to make the commercial licenceses affordable to SMEs and individuals. We are considering to provide commercial licenses either for free or for symbolic one-time fee, and maybe also provide additional support. You can share your preference via the questionnaire https://forms.cloud.microsoft/e/7tCxb4gJfB.
You can also leave your contact there: https://forms.cloud.microsoft/e/7tCxb4gJfB to be notified when the commercial licenses become
available.
✉️ Contact
Dominik Macháček (https://ufal.mff.cuni.cz/dominik-machacek/), machacek@ufal.mff.cuni.cz
{"*"*80}
SimulStreaming backend is dual-licensed:
• Non-Commercial Use: PolyForm Noncommercial License 1.0.0.
• Commercial Use: Check SimulStreaming README (github.com/ufal/SimulStreaming) for more details.
"""

View File

@@ -33,11 +33,14 @@ import wave
# - prompt -- static vs. non-static
# - context
class PaddedAlignAttWhisper:
def __init__(self, cfg: AlignAttConfig) -> None:
def __init__(self, cfg: AlignAttConfig, loaded_model=None) -> None:
self.log_segments = 0
model_name = os.path.basename(cfg.model_path).replace(".pt", "")
model_path = os.path.dirname(os.path.abspath(cfg.model_path))
self.model = load_model(name=model_name, download_root=model_path)
if loaded_model:
self.model = loaded_model
else:
self.model = load_model(name=model_name, download_root=model_path)
logger.info(f"Model dimensions: {self.model.dims}")