10 Commits
0.1.8 ... 0.2.1

12 changed files with 973 additions and 98 deletions

111
README.md
View File

@@ -1,45 +1,42 @@
<h1 align="center">WhisperLiveKit</h1>
<p align="center">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
</p>
<p align="center"><b>Real-time, Fully Local Speech-to-Text with Speaker Diarization</b></p>
<p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=downloads"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT-dark_green"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=downloads"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT-dark_green"></a>
</p>
## 🚀 Overview
This project is based on [Whisper Streaming](https://github.com/ufal/whisper_streaming) and lets you transcribe audio directly from your browser. WhisperLiveKit provides a complete backend solution for real-time speech transcription with a functional and simple frontend that you can customize for your own needs. Everything runs locally on your machine ✨
This project is based on [WhisperStreaming](https://github.com/ufal/whisper_streaming) and [SimulStreaming](https://github.com/ufal/SimulStreaming), allowing you to transcribe audio directly from your browser. WhisperLiveKit provides a complete backend solution for real-time speech transcription with a functional, simple and customizable frontend. Everything runs locally on your machine ✨
### 🔄 Architecture
WhisperLiveKit consists of three main components:
- **Frontend**: A basic HTML & JavaScript interface that captures microphone audio and streams it to the backend via WebSockets. You can use and adapt the provided template at [whisperlivekit/web/live_transcription.html](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html) for your specific use case.
- **Frontend**: A basic html + JS interface that captures microphone audio and streams it to the backend via WebSockets. You can use and adapt the provided template at [whisperlivekit/web/live_transcription.html](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html).
- **Backend (Web Server)**: A FastAPI-based WebSocket server that receives streamed audio data, processes it in real time, and returns transcriptions to the frontend. This is where the WebSocket logic and routing live.
- **Core Backend (Library Logic)**: A server-agnostic core that handles audio processing, ASR, and diarization. It exposes reusable components that take in audio bytes and return transcriptions. This makes it easy to plug into any WebSocket or audio stream pipeline.
- **Core Backend (Library Logic)**: A server-agnostic core that handles audio processing, ASR, and diarization. It exposes reusable components that take in audio bytes and return transcriptions.
### ✨ Key Features
- **🎙️ Real-time Transcription** - Convert speech to text instantly as you speak
- **🎙️ Real-time Transcription** - Locally (or on-prem) convert speech to text instantly as you speak
- **👥 Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart)
- **🔒 Fully Local** - All processing happens on your machine - no data sent to external servers
- **📱 Multi-User Support** - Handle multiple users simultaneously with a single backend/server
### ⚙️ Core differences from [Whisper Streaming](https://github.com/ufal/whisper_streaming)
- **🌐 Multi-User Support** - Handle multiple users simultaneously with a single backend/server
- **🔇 Automatic Silence Chunking** Automatically chunks when no audio is detected to limit buffer size
- **✅ Confidence Validation** Immediately validate high-confidence tokens for faster inference (WhisperStreaming only)
- **👁️ Buffering Preview** Displays unvalidated transcription segments (not compatible with SimulStreaming yet)
- **✒️ Punctuation-Based Speaker Splitting [BETA]** - Align speaker changes with natural sentence boundaries for more readable transcripts
- **⚡ SimulStreaming Backend** - Ultra-low latency transcription using state-of-the-art AlignAtt policy. The code is not directly included in the repo : To use, please copy [simul_whisper](https://github.com/ufal/SimulStreaming/tree/main/simul_whisper) content into `whisperlivekit/simul_whisper` . ⚠️ You must comply with the [Polyform license](https://github.com/ufal/SimulStreaming/blob/main/LICENCE.txt)
- **Automatic Silence Chunking** Automatically chunks when no audio is detected to limit buffer size
- **Multi-User Support** Handles multiple users simultaneously by decoupling backend and online ASR
- **Confidence Validation** Immediately validate high-confidence tokens for faster inference
- **MLX Whisper Backend** Optimized for Apple Silicon for faster local processing
- **Buffering Preview** Displays unvalidated transcription segments
## 📖 Quick Start
@@ -50,15 +47,8 @@ pip install whisperlivekit
# Start the transcription server
whisperlivekit-server --model tiny.en
# Open your browser at http://localhost:8000
```
### Quick Start with SSL
```bash
# You must provide a certificate and key
whisperlivekit-server -ssl-certfile public.crt --ssl-keyfile private.key
# Open your browser at https://localhost:8000
# Open your browser at http://localhost:8000 to see the interface.
# Use -ssl-certfile public.crt --ssl-keyfile private.key parameters to use SSL
```
That's it! Start speaking and watch your words appear on screen.
@@ -112,6 +102,7 @@ pip install whisperlivekit[whisper] # Original Whisper
pip install whisperlivekit[whisper-timestamped] # Improved timestamps
pip install whisperlivekit[mlx-whisper] # Apple Silicon optimization
pip install whisperlivekit[openai] # OpenAI API
pip install whisperlivekit[simulstreaming]
```
### 🎹 Pyannote Models Setup
@@ -122,10 +113,10 @@ For diarization, you need access to pyannote.audio models:
2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
3. [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
4. Login with HuggingFace:
```bash
pip install huggingface_hub
huggingface-cli login
```
```bash
pip install huggingface_hub
huggingface-cli login
```
## 💻 Usage Examples
@@ -139,8 +130,12 @@ whisperlivekit-server --model tiny.en
# Advanced configuration with diarization
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language auto
# SimulStreaming backend for ultra-low latency
whisperlivekit-server --backend simulstreaming --model large-v3 --frame-threshold 20
```
### Python API Integration (Backend)
Check [basic_server.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a complete example.
@@ -225,11 +220,12 @@ WhisperLiveKit offers extensive configuration options:
|-----------|-------------|---------|
| `--host` | Server host address | `localhost` |
| `--port` | Server port | `8000` |
| `--model` | Whisper model size | `tiny` |
| `--model` | Whisper model size. Caution : '.en' models do not work with Simulstreaming | `tiny` |
| `--language` | Source language code or `auto` | `en` |
| `--task` | `transcribe` or `translate` | `transcribe` |
| `--backend` | Processing backend | `faster-whisper` |
| `--diarization` | Enable speaker identification | `False` |
| `--punctuation-split` | Use punctuation to improve speaker boundaries | `True` |
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
| `--vac` | Use Voice Activity Controller | `False` |
@@ -238,13 +234,27 @@ WhisperLiveKit offers extensive configuration options:
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
**SimulStreaming-specific Options:**
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
| `--audio-max-len` | Maximum audio buffer length (seconds) | `30.0` |
| `--audio-min-len` | Minimum audio length to process (seconds) | `0.0` |
| `--cif-ckpt-path` | Path to CIF model for word boundary detection | `None` |
| `--never-fire` | Never truncate incomplete words | `False` |
| `--init-prompt` | Initial prompt for the model | `None` |
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
| `--max-context-tokens` | Maximum context tokens | `None` |
| `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` |
## 🔧 How It Works
<p align="center">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit in Action" width="500">
</p>
1. **Audio Capture**: Browser's MediaRecorder API captures audio in webm/opus format
2. **Streaming**: Audio chunks are sent to the server via WebSocket
3. **Processing**: Server decodes audio with FFmpeg and streams into Whisper for transcription
@@ -276,14 +286,14 @@ To deploy WhisperLiveKit in production:
listen 80;
server_name your-domain.com;
location / {
proxy_pass http://localhost:8000;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
}
}
```
location / {
proxy_pass http://localhost:8000;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
}
}
```
4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in WebSocket URL
@@ -323,6 +333,12 @@ docker start -i whisperlivekit-base
- **Content Creation**: Transcribe podcasts or videos automatically
- **Customer Service**: Transcribe support calls with speaker identification
## 📄 License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
**⚠️ Important**: When using the SimulStreaming backend, you must also comply with the **PolyForm Noncommercial License 1.0.0** that governs SimulStreaming. For commercial use of the SimulStreaming backend, obtain a commercial license from the [SimulStreaming authors](https://github.com/ufal/SimulStreaming#-licence-and-contributions).
## 🤝 Contributing
Contributions are welcome! Here's how to get started:
@@ -337,17 +353,14 @@ Contributions are welcome! Here's how to get started:
This project builds upon the foundational work of:
- [Whisper Streaming](https://github.com/ufal/whisper_streaming)
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (BETA backend)
- [Diart](https://github.com/juanmc2005/diart)
- [OpenAI Whisper](https://github.com/openai/whisper)
We extend our gratitude to the original authors for their contributions.
## 📄 License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
## 🔗 Links
- [GitHub Repository](https://github.com/QuentinFuxa/WhisperLiveKit)
- [PyPI Package](https://pypi.org/project/whisperlivekit/)
- [Issue Tracker](https://github.com/QuentinFuxa/WhisperLiveKit/issues)
- [Issue Tracker](https://github.com/QuentinFuxa/WhisperLiveKit/issues)

View File

@@ -1,7 +1,7 @@
from setuptools import setup, find_packages
setup(
name="whisperlivekit",
version="0.1.8",
version="0.2.1",
description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
@@ -25,9 +25,16 @@ setup(
"whisper-timestamped": ["whisper-timestamped"],
"mlx-whisper": ["mlx-whisper"],
"openai": ["openai"],
"simulstreaming": [
"torch",
"tqdm",
"tiktoken",
"triton>=2.0.0,<3;platform_machine==\"x86_64\" and sys_platform==\"linux\" or sys_platform==\"linux2\"",
],
},
package_data={
'whisperlivekit': ['web/*.html'],
'whisperlivekit.simul_whisper': ['dual_license_simulstreaming.md'],
},
entry_points={
'console_scripts': [

View File

@@ -310,7 +310,7 @@ class AudioProcessor:
self.transcription_queue.task_done()
continue
asr_internal_buffer_duration_s = len(self.online.audio_buffer) / self.online.SAMPLING_RATE
asr_internal_buffer_duration_s = len(getattr(self.online, 'audio_buffer', [])) / self.online.SAMPLING_RATE
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
logger.info(
@@ -326,13 +326,17 @@ class AudioProcessor:
self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
new_tokens, current_audio_processed_upto = self.online.process_iter()
if new_tokens:
self.full_transcription += self.sep.join([t.text for t in new_tokens])
# Get buffer information
_buffer_transcript_obj = self.online.get_buffer()
buffer_text = _buffer_transcript_obj.text
if new_tokens:
validated_text = self.sep.join([t.text for t in new_tokens])
self.full_transcription += validated_text
if buffer_text.startswith(validated_text):
buffer_text = buffer_text[len(validated_text):].lstrip()
candidate_end_times = [self.end_buffer]
if new_tokens:
@@ -345,10 +349,6 @@ class AudioProcessor:
new_end_buffer = max(candidate_end_times)
# Avoid duplicating content
if buffer_text in self.full_transcription:
buffer_text = ""
await self.update_transcription(
new_tokens, buffer_text, new_end_buffer, self.full_transcription, self.sep
)
@@ -377,13 +377,16 @@ class AudioProcessor:
# Process diarization
await diarization_obj.diarize(pcm_array)
# Get current state and update speakers
state = await self.get_current_state()
new_end = diarization_obj.assign_speakers_to_tokens(
state["end_attributed_speaker"], state["tokens"]
)
async with self.lock:
new_end = diarization_obj.assign_speakers_to_tokens(
self.end_attributed_speaker,
self.tokens,
use_punctuation_split=self.args.punctuation_split
)
self.end_attributed_speaker = new_end
if buffer_diarization:
self.buffer_diarization = buffer_diarization
await self.update_diarization(new_end, buffer_diarization)
self.diarization_queue.task_done()
except Exception as e:

View File

@@ -24,6 +24,7 @@ class TranscriptionEngine:
"warmup_file": None,
"confidence_validation": False,
"diarization": False,
"punctuation_split": False,
"min_chunk_size": 0.5,
"model": "tiny",
"model_cache_dir": None,
@@ -38,8 +39,22 @@ class TranscriptionEngine:
"log_level": "DEBUG",
"ssl_certfile": None,
"ssl_keyfile": None,
"transcription": True,
"transcription": True,
"vad": True,
"segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding",
# simulstreaming params:
"frame_threshold": 25,
"beams": 1,
"decoder_type": None,
"audio_max_len": 30.0,
"audio_min_len": 0.0,
"cif_ckpt_path": None,
"never_fire": False,
"init_prompt": None,
"static_init_prompt": None,
"max_context_tokens": None,
"model_path": './base.pt',
}
config_dict = {**defaults, **kwargs}
@@ -68,6 +83,10 @@ class TranscriptionEngine:
if self.args.diarization:
from whisperlivekit.diarization.diarization_online import DiartDiarization
self.diarization = DiartDiarization()
self.diarization = DiartDiarization(
block_duration=self.args.min_chunk_size,
segmentation_model_name=self.args.segmentation_model,
embedding_model_name=self.args.embedding_model
)
TranscriptionEngine._initialized = True

View File

@@ -3,7 +3,8 @@ import re
import threading
import numpy as np
import logging
import time
from queue import SimpleQueue, Empty
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import StreamingInference
@@ -13,6 +14,7 @@ from diart.sources import MicrophoneAudioSource
from rx.core import Observer
from typing import Tuple, Any, List
from pyannote.core import Annotation
import diart.models as m
logger = logging.getLogger(__name__)
@@ -78,40 +80,114 @@ class DiarizationObserver(Observer):
class WebSocketAudioSource(AudioSource):
"""
Custom AudioSource that blocks in read() until close() is called.
Use push_audio() to inject PCM chunks.
Buffers incoming audio and releases it in fixed-size chunks at regular intervals.
"""
def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
def __init__(self, uri: str = "websocket", sample_rate: int = 16000, block_duration: float = 0.5):
super().__init__(uri, sample_rate)
self.block_duration = block_duration
self.block_size = int(np.rint(block_duration * sample_rate))
self._queue = SimpleQueue()
self._buffer = np.array([], dtype=np.float32)
self._buffer_lock = threading.Lock()
self._closed = False
self._close_event = threading.Event()
self._processing_thread = None
self._last_chunk_time = time.time()
def read(self):
"""Start processing buffered audio and emit fixed-size chunks."""
self._processing_thread = threading.Thread(target=self._process_chunks)
self._processing_thread.daemon = True
self._processing_thread.start()
self._close_event.wait()
if self._processing_thread:
self._processing_thread.join(timeout=2.0)
def _process_chunks(self):
"""Process audio from queue and emit fixed-size chunks at regular intervals."""
while not self._closed:
try:
audio_chunk = self._queue.get(timeout=0.1)
with self._buffer_lock:
self._buffer = np.concatenate([self._buffer, audio_chunk])
while len(self._buffer) >= self.block_size:
chunk = self._buffer[:self.block_size]
self._buffer = self._buffer[self.block_size:]
current_time = time.time()
time_since_last = current_time - self._last_chunk_time
if time_since_last < self.block_duration:
time.sleep(self.block_duration - time_since_last)
chunk_reshaped = chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self._last_chunk_time = time.time()
except Empty:
with self._buffer_lock:
if len(self._buffer) > 0 and time.time() - self._last_chunk_time > self.block_duration:
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
padded_chunk[:len(self._buffer)] = self._buffer
self._buffer = np.array([], dtype=np.float32)
chunk_reshaped = padded_chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self._last_chunk_time = time.time()
except Exception as e:
logger.error(f"Error in audio processing thread: {e}")
self.stream.on_error(e)
break
with self._buffer_lock:
if len(self._buffer) > 0:
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
padded_chunk[:len(self._buffer)] = self._buffer
chunk_reshaped = padded_chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self.stream.on_completed()
def close(self):
if not self._closed:
self._closed = True
self.stream.on_completed()
self._close_event.set()
def push_audio(self, chunk: np.ndarray):
"""Add audio chunk to the processing queue."""
if not self._closed:
new_audio = np.expand_dims(chunk, axis=0)
logger.debug('Add new chunk with shape:', new_audio.shape)
self.stream.on_next(new_audio)
if chunk.ndim > 1:
chunk = chunk.flatten()
self._queue.put(chunk)
logger.debug(f'Added chunk to queue with {len(chunk)} samples')
class DiartDiarization:
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False):
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "speechbrain/spkrec-ecapa-voxceleb"):
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
if config is None:
config = SpeakerDiarizationConfig(
segmentation=segmentation_model,
embedding=embedding_model,
)
self.pipeline = SpeakerDiarization(config=config)
self.observer = DiarizationObserver()
self.lag_diart = None
if use_microphone:
self.source = MicrophoneAudioSource()
self.source = MicrophoneAudioSource(block_duration=block_duration)
self.custom_source = None
else:
self.custom_source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
self.custom_source = WebSocketAudioSource(
uri="websocket_source",
sample_rate=sample_rate,
block_duration=block_duration
)
self.source = self.custom_source
self.inference = StreamingInference(
@@ -138,16 +214,102 @@ class DiartDiarization:
if self.custom_source:
self.custom_source.close()
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float:
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float:
"""
Assign speakers to tokens based on timing overlap with speaker segments.
Uses the segments collected by the observer.
If use_punctuation_split is True, uses punctuation marks to refine speaker boundaries.
"""
segments = self.observer.get_segments()
# Debug logging
logger.debug(f"assign_speakers_to_tokens called with {len(tokens)} tokens")
logger.debug(f"Available segments: {len(segments)}")
for i, seg in enumerate(segments[:5]): # Show first 5 segments
logger.debug(f" Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]")
if not self.lag_diart and segments and tokens:
self.lag_diart = segments[0].start - tokens[0].start
for token in tokens:
for segment in segments:
if not (segment.end <= token.start or segment.start >= token.end):
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
token.speaker = extract_number(segment.speaker) + 1
end_attributed_speaker = max(token.end, end_attributed_speaker)
return end_attributed_speaker
if use_punctuation_split and len(tokens) > 1:
punctuation_marks = {'.', '!', '?'}
print("Here are the tokens:",
[(t.text, t.start, t.end, t.speaker) for t in tokens[:10]])
segment_map = []
for segment in segments:
speaker_num = extract_number(segment.speaker) + 1
segment_map.append((segment.start, segment.end, speaker_num))
segment_map.sort(key=lambda x: x[0])
i = 0
while i < len(tokens):
current_token = tokens[i]
is_sentence_end = False
if current_token.text and current_token.text.strip():
text = current_token.text.strip()
if text[-1] in punctuation_marks:
is_sentence_end = True
logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
if is_sentence_end and current_token.speaker != -1:
punctuation_time = current_token.end
current_speaker = current_token.speaker
j = i + 1
next_sentence_tokens = []
while j < len(tokens):
next_token = tokens[j]
next_sentence_tokens.append(j)
# Check if this token ends the next sentence
if next_token.text and next_token.text.strip():
if next_token.text.strip()[-1] in punctuation_marks:
break
j += 1
if next_sentence_tokens:
speaker_times = {}
for idx in next_sentence_tokens:
token = tokens[idx]
# Find which segments overlap with this token
for seg_start, seg_end, seg_speaker in segment_map:
if not (seg_end <= token.start or seg_start >= token.end):
# Calculate overlap duration
overlap_start = max(seg_start, token.start)
overlap_end = min(seg_end, token.end)
overlap_duration = overlap_end - overlap_start
if seg_speaker not in speaker_times:
speaker_times[seg_speaker] = 0
speaker_times[seg_speaker] += overlap_duration
if speaker_times:
dominant_speaker = max(speaker_times.items(), key=lambda x: x[1])[0]
if dominant_speaker != current_speaker:
logger.debug(f" Speaker change after punctuation: {current_speaker}{dominant_speaker}")
for idx in next_sentence_tokens:
if tokens[idx].speaker != dominant_speaker:
logger.debug(f" Reassigning token {idx} ('{tokens[idx].text}') to Speaker {dominant_speaker}")
tokens[idx].speaker = dominant_speaker
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
else:
for idx in next_sentence_tokens:
if tokens[idx].speaker == -1:
tokens[idx].speaker = current_speaker
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
i += 1
return end_attributed_speaker

View File

@@ -37,6 +37,27 @@ def parse_args():
help="Enable speaker diarization.",
)
parser.add_argument(
"--punctuation-split",
action="store_true",
default=False,
help="Use punctuation marks from transcription to improve speaker boundary detection. Requires both transcription and diarization to be enabled.",
)
parser.add_argument(
"--segmentation-model",
type=str,
default="pyannote/segmentation-3.0",
help="Hugging Face model ID for pyannote.audio segmentation model.",
)
parser.add_argument(
"--embedding-model",
type=str,
default="pyannote/embedding",
help="Hugging Face model ID for pyannote.audio embedding model.",
)
parser.add_argument(
"--no-transcription",
action="store_true",
@@ -87,7 +108,7 @@ def parse_args():
"--backend",
type=str,
default="faster-whisper",
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api", "simulstreaming"],
help="Load only this backend for Whisper processing.",
)
parser.add_argument(
@@ -130,6 +151,97 @@ def parse_args():
parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
# SimulStreaming-specific arguments
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')
simulstreaming_group.add_argument(
"--frame-threshold",
type=int,
default=25,
dest="frame_threshold",
help="Threshold for the attention-guided decoding. The AlignAtt policy will decode only until this number of frames from the end of audio. In frames: one frame is 0.02 seconds for large-v3 model.",
)
simulstreaming_group.add_argument(
"--beams",
"-b",
type=int,
default=1,
help="Number of beams for beam search decoding. If 1, GreedyDecoder is used.",
)
simulstreaming_group.add_argument(
"--decoder",
type=str,
default=None,
dest="decoder_type",
choices=["beam", "greedy"],
help="Override automatic selection of beam or greedy decoder. If beams > 1 and greedy: invalid.",
)
simulstreaming_group.add_argument(
"--audio-max-len",
type=float,
default=30.0,
dest="audio_max_len",
help="Max length of the audio buffer, in seconds.",
)
simulstreaming_group.add_argument(
"--audio-min-len",
type=float,
default=0.0,
dest="audio_min_len",
help="Skip processing if the audio buffer is shorter than this length, in seconds. Useful when the --min-chunk-size is small.",
)
simulstreaming_group.add_argument(
"--cif-ckpt-path",
type=str,
default=None,
dest="cif_ckpt_path",
help="The file path to the Simul-Whisper's CIF model checkpoint that detects whether there is end of word at the end of the chunk. If not, the last decoded space-separated word is truncated because it is often wrong -- transcribing a word in the middle. The CIF model adapted for the Whisper model version should be used. Find the models in https://github.com/backspacetg/simul_whisper/tree/main/cif_models . Note that there is no model for large-v3.",
)
simulstreaming_group.add_argument(
"--never-fire",
action="store_true",
default=False,
dest="never_fire",
help="Override the CIF model. If True, the last word is NEVER truncated, no matter what the CIF model detects. If False: if CIF model path is set, the last word is SOMETIMES truncated, depending on the CIF detection. Otherwise, if the CIF model path is not set, the last word is ALWAYS trimmed.",
)
simulstreaming_group.add_argument(
"--init-prompt",
type=str,
default=None,
dest="init_prompt",
help="Init prompt for the model. It should be in the target language.",
)
simulstreaming_group.add_argument(
"--static-init-prompt",
type=str,
default=None,
dest="static_init_prompt",
help="Do not scroll over this text. It can contain terminology that should be relevant over all document.",
)
simulstreaming_group.add_argument(
"--max-context-tokens",
type=int,
default=None,
dest="max_context_tokens",
help="Max context tokens for the model. Default is 0.",
)
simulstreaming_group.add_argument(
"--model-path",
type=str,
default=None,
dest="model_path",
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
)
args = parser.parse_args()
@@ -138,4 +250,4 @@ def parse_args():
delattr(args, 'no_transcription')
delattr(args, 'no_vad')
return args
return args

View File

@@ -0,0 +1,27 @@
📄 SimulStreaming (https://github.com/ufal/SimulStreaming) Licence
SimulStreaming is dual-licensed:
🔹 Non-Commercial Use
You may use SimulStreaming under the **PolyForm Noncommercial License 1.0.0** if you
obtain the code through the GitHub repository. This license is **free of charge**
and comes with **no obligations** for non-commercial users.
🔸 Commercial Use
Understanding who uses SimulStreaming commercially helps us improve and
prioritize development. Therefore, we want to **require registration** of those who acquire a commercial licence.
We plan to make the commercial licenceses **affordable** to SMEs and individuals. We
are considering to provide commercial licenses either for free or for symbolic
one-time fee, and maybe also provide additional support. You can share your preference via the [questionnaire](https://forms.cloud.microsoft/e/7tCxb4gJfB).
You can also leave your contact [there](https://forms.cloud.microsoft/e/7tCxb4gJfB) to be notified when the commercial licenses become
available.
✉️ Contact
[Dominik Macháček](https://ufal.mff.cuni.cz/dominik-machacek/), machacek@ufal.mff.cuni.cz

View File

@@ -26,4 +26,7 @@ class Transcript(TimedText):
@dataclass
class SpeakerSegment(TimedText):
"""Represents a segment of audio attributed to a specific speaker.
No text nor probability is associated with this segment.
"""
pass

View File

@@ -0,0 +1,73 @@
import torch
import sys
class TokenBuffer:
def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):
self.text = text
self.prefix_token_ids = prefix_token_ids
self.tokenizer = tokenizer
self.device = device
def as_token_ids(self, tokenizer=None):
if tokenizer is None:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer is not set.")
return self.prefix_token_ids + tokenizer.encode(self.text)
def as_tensor(self, device=None):
if device is None:
device = self.device
if device is None:
raise ValueError("Device is not set.")
tok_ids = self.as_token_ids()
return torch.tensor(tok_ids,
dtype=torch.long, device=device).unsqueeze(0)
def as_tensor_beam(self, beam, device=None):
t = self.as_tensor(device=device)
return t.repeat_interleave(beam, dim=0)
def as_text(self):
return self.text
@staticmethod
def empty(*a, **kw):
return TokenBuffer(*a,**kw)
@staticmethod
def from_text(text, *a, **kw):
return TokenBuffer(*a, text=text, **kw)
def is_empty(self):
return self.text is None or self.text == ""
def trim_words(self, num=1, after=0):
'''
num: how many words to trim from the beginning
after: how many characters to skip (length of the static prompt)
'''
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
ids = tokenizer.encode(self.text[after:])
words, wids = self.tokenizer.split_to_word_tokens(ids)
print(words, file=sys.stderr)
print(wids, file=sys.stderr)
if not words:
return 0
self.text = self.text[:after] + "".join(words[num:])
return sum(len(wi) for wi in wids[:num])
def append_token_ids(self, token_ids):
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
self.text += self.tokenizer.decode(token_ids)
def as_split_word_tokens(self):
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
ids = tokenizer.encode(self.text)
return tokenizer.split_to_word_tokens(ids)

View File

@@ -13,6 +13,19 @@ from whisperlivekit.timed_objects import ASRToken
logger = logging.getLogger(__name__)
try:
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD
from whisperlivekit.simul_whisper.whisper import tokenizer
SIMULSTREAMING_AVAILABLE = True
except ImportError:
logger.warning("SimulStreaming dependencies not available. SimulStreaming backend will not be available.")
SIMULSTREAMING_AVAILABLE = False
AlignAttConfig = None
PaddedAlignAttWhisper = None
DEC_PAD = None
tokenizer = None
class ASRBase:
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
# "" for faster-whisper because it emits the spaces when needed)
@@ -293,4 +306,181 @@ class OpenaiApiASR(ASRBase):
self.use_vad_opt = True
def set_translate_task(self):
self.task = "translate"
self.task = "translate"
class SimulStreamingASR(ASRBase):
"""SimulStreaming backend with AlignAtt policy."""
sep = " "
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
if not SIMULSTREAMING_AVAILABLE:
raise ImportError("""SimulStreaming dependencies are not available. Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]". If you are building from source, you should also copy the content of the simul_whisper directory from the SimulStreaming repository into whisperlivekit/simul_whisper.""")
with open("whisperlivekit/simul_whisper/dual_license_simulstreaming.md", "r") as f:
print("*"*80 + f.read() + "*"*80)
self.logfile = logfile
self.transcribe_kargs = {}
self.original_language = None if lan == "auto" else lan
self.model_path = kwargs.get('model_path', './large-v3.pt')
self.frame_threshold = kwargs.get('frame_threshold', 25)
self.audio_max_len = kwargs.get('audio_max_len', 30.0)
self.audio_min_len = kwargs.get('audio_min_len', 0.0)
self.segment_length = kwargs.get('segment_length', 0.5)
self.beams = kwargs.get('beams', 1)
self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
self.task = kwargs.get('task', 'transcribe')
self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
self.never_fire = kwargs.get('never_fire', False)
self.init_prompt = kwargs.get('init_prompt', None)
self.static_init_prompt = kwargs.get('static_init_prompt', None)
self.max_context_tokens = kwargs.get('max_context_tokens', None)
if model_dir is not None:
self.model_path = model_dir
elif modelsize is not None: #For the moment the .en.pt models do not work!
model_mapping = {
'tiny': './tiny.pt',
'base': './base.pt',
'small': './small.pt',
'medium': './medium.pt',
'medium.en': './medium.en.pt',
'large-v1': './large-v1.pt',
'base.en': './base.en.pt',
'small.en': './small.en.pt',
'tiny.en': './tiny.en.pt',
'large-v2': './large-v2.pt',
'large-v3': './large-v3.pt',
'large': './large-v3.pt'
}
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
self.model = self.load_model(modelsize, cache_dir, model_dir)
# Set up tokenizer for translation if needed
if self.task == "translate":
self.set_translate_task()
def load_model(self, modelsize, cache_dir, model_dir):
try:
cfg = AlignAttConfig(
model_path=self.model_path,
segment_length=self.segment_length,
frame_threshold=self.frame_threshold,
language=self.original_language,
audio_max_len=self.audio_max_len,
audio_min_len=self.audio_min_len,
cif_ckpt_path=self.cif_ckpt_path,
decoder_type="beam",
beam_size=self.beams,
task=self.task,
never_fire=self.never_fire,
init_prompt=self.init_prompt,
max_context_tokens=self.max_context_tokens,
static_init_prompt=self.static_init_prompt,
)
logger.info(f"Loading SimulStreaming model with language: {self.original_language}")
model = PaddedAlignAttWhisper(cfg)
return model
except Exception as e:
logger.error(f"Failed to load SimulStreaming model: {e}")
raise
def transcribe(self, audio, init_prompt=""):
"""Transcribe audio using SimulStreaming."""
try:
if isinstance(audio, np.ndarray):
audio_tensor = torch.from_numpy(audio).float()
else:
audio_tensor = audio
prompt = init_prompt if init_prompt else (self.init_prompt or "")
result = self.model.infer(audio_tensor, init_prompt=prompt)
if torch.is_tensor(result):
result = result[result < DEC_PAD]
logger.debug(f"SimulStreaming transcription result: {result}")
return result
except Exception as e:
logger.error(f"SimulStreaming transcription failed: {e}")
raise
def ts_words(self, result) -> List[ASRToken]:
"""Convert SimulStreaming result to ASRToken list."""
tokens = []
try:
if torch.is_tensor(result):
text = self.model.tokenizer.decode(result.cpu().numpy())
else:
text = str(result)
if not text or len(text.strip()) == 0:
return tokens
# We dont have word-level timestamps here. 1rst approach, should be improved later.
words = text.strip().split()
if not words:
return tokens
duration_per_word = 0.1 # this will be modified based on actual audio duration
#with the SimulStreamingOnlineProcessor
for i, word in enumerate(words):
start_time = i * duration_per_word
end_time = (i + 1) * duration_per_word
token = ASRToken(
start=start_time,
end=end_time,
text=word,
probability=1.0
)
tokens.append(token)
except Exception as e:
logger.error(f"Error converting SimulStreaming result to tokens: {e}")
return tokens
def segments_end_ts(self, result) -> List[float]:
"""Get segment end timestamps."""
if torch.is_tensor(result):
num_tokens = len(result)
return [num_tokens * 0.1] # rough estimate
return [1.0]
def use_vad(self):
"""Enable VAD - SimulStreaming has different VAD handling."""
logger.info("VAD requested for SimulStreaming - handled internally by the model")
pass
def set_translate_task(self):
"""Set up translation task."""
try:
self.model.tokenizer = tokenizer.get_tokenizer(
multilingual=True,
language=self.model.cfg.language,
num_languages=self.model.model.num_languages,
task="translate"
)
logger.info("SimulStreaming configured for translation task")
except Exception as e:
logger.error(f"Failed to configure SimulStreaming for translation: {e}")
raise
def warmup(self, audio, init_prompt=""):
"""Warmup the SimulStreaming model."""
try:
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio).float()
self.model.infer(audio, True)
self.model.refresh_segment(complete=True)
logger.info("SimulStreaming model warmed up successfully")
except Exception as e:
logger.warning(f"SimulStreaming warmup failed: {e}")

View File

@@ -6,6 +6,17 @@ from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
logger = logging.getLogger(__name__)
# simulStreaming imports - we check if the files are here
try:
import torch
from simul_whisper.config import AlignAttConfig
SIMULSTREAMING_AVAILABLE = True
except ImportError:
logger.warning("SimulStreaming dependencies not available for online processor.")
SIMULSTREAMING_AVAILABLE = False
OnlineProcessorInterface = None
torch = None
class HypothesisBuffer:
"""
@@ -500,3 +511,207 @@ class VACOnlineASRProcessor:
Get the unvalidated buffer in string format.
"""
return self.online.concatenate_tokens(self.online.transcript_buffer.buffer)
class SimulStreamingOnlineProcessor:
SAMPLING_RATE = 16000
def __init__(
self,
asr,
tokenize_method: Optional[callable] = None,
buffer_trimming: Tuple[str, float] = ("segment", 15),
confidence_validation = False,
logfile=sys.stderr,
):
if not SIMULSTREAMING_AVAILABLE:
raise ImportError("SimulStreaming dependencies are not available.")
self.asr = asr
self.tokenize = tokenize_method
self.logfile = logfile
self.confidence_validation = confidence_validation
self.init()
# buffer does not work yet
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
def init(self, offset: Optional[float] = None):
"""Initialize or reset the processing state."""
self.audio_chunks = []
self.offset = offset if offset is not None else 0.0
self.is_last = False
self.beg = self.offset
self.end = self.offset
self.cumulative_audio_duration = 0.0
self.last_audio_stream_end_time = self.offset
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
self.buffer_content = ""
self.processed_audio_duration = 0.0
def get_audio_buffer_end_time(self) -> float:
"""Returns the absolute end time of the current audio buffer."""
return self.end
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
"""Append an audio chunk to be processed by SimulStreaming."""
if torch is None:
raise ImportError("PyTorch is required for SimulStreaming but not available")
# Convert numpy array to torch tensor
audio_tensor = torch.from_numpy(audio).float()
self.audio_chunks.append(audio_tensor)
# Update timing
chunk_duration = len(audio) / self.SAMPLING_RATE
self.cumulative_audio_duration += chunk_duration
if audio_stream_end_time is not None:
self.last_audio_stream_end_time = audio_stream_end_time
self.end = audio_stream_end_time
else:
self.end = self.offset + self.cumulative_audio_duration
def prompt(self) -> Tuple[str, str]:
"""
Returns a tuple: (prompt, context).
SimulStreaming handles prompting internally, so we return empty strings.
"""
return "", ""
def get_buffer(self):
"""
Get the unvalidated buffer content.
"""
buffer_end = self.end if hasattr(self, 'end') else None
return Transcript(
start=None,
end=buffer_end,
text=self.buffer_content,
probability=None
)
def process_iter(self) -> Tuple[List[ASRToken], float]:
"""
Process accumulated audio chunks using SimulStreaming.
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
"""
if not self.audio_chunks:
return [], self.end
try:
# concatenate all audio chunks
if len(self.audio_chunks) == 1:
audio = self.audio_chunks[0]
else:
audio = torch.cat(self.audio_chunks, dim=0)
audio_duration = audio.shape[0] / self.SAMPLING_RATE if audio.shape[0] > 0 else 0
self.processed_audio_duration += audio_duration
self.audio_chunks = []
logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s")
logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s")
result = self.asr.model.infer(audio, is_last=self.is_last)
if torch.is_tensor(result):
# we filter out padding tokens as it s done in simul whisper
from simul_whisper.simul_whisper import DEC_PAD
result = result[result < DEC_PAD]
# C/P from simul_whisper.simul_whisper.py
if len(result) > 0:
decoded_text = self.asr.model.tokenizer.decode(result.cpu().numpy())
logger.debug(f"SimulStreaming decoded: {decoded_text}")
if decoded_text.strip():
words = decoded_text.strip().split()
new_tokens = []
num_words = len(words)
if num_words > 0:
# distribute words evenly across the processed audio duration
# we NEED that for when we use diarization. Even if that s not perfect
start_time = self.end - audio_duration
time_per_word = audio_duration / num_words if num_words > 1 else audio_duration
for i, word in enumerate(words):
token_start = start_time + (i * time_per_word)
token_end = start_time + ((i + 1) * time_per_word)
token_end = min(token_end, self.end)
token = ASRToken(
start=token_start,
end=token_end,
text=word,
probability=0.95 # fake prob. Maybe we can extract it from the model?
)
new_tokens.append(token)
self.beg = self.end
self.committed.extend(new_tokens)
self.last_result_tokens = new_tokens
logger.debug(f"SimulStreaming generated {len(new_tokens)} tokens with end time: {self.end:.2f}s")
return new_tokens, self.end
return [], self.end
except Exception as e:
logger.error(f"SimulStreaming processing error: {e}")
logger.error(f"Error details: {type(e).__name__}: {str(e)}")
return [], self.end
def finish(self) -> Tuple[List[ASRToken], float]:
logger.debug("SimulStreaming finish() called")
self.is_last = True
final_tokens, final_time = self.process_iter()
self.is_last = False
return final_tokens, final_time
def concatenate_tokens(
self,
tokens: List[ASRToken],
sep: Optional[str] = None,
offset: float = 0
) -> Transcript:
"""Concatenate tokens into a Transcript object."""
sep = sep if sep is not None else self.asr.sep
text = sep.join(token.text for token in tokens)
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
if tokens:
start = offset + tokens[0].start
end = offset + tokens[-1].end
else:
start = None
end = None
return Transcript(start, end, text, probability=probability)
def chunk_at(self, time: float):
"""
useless but kept for compatibility
"""
logger.debug(f"SimulStreaming chunk_at({time:.2f}) - handled internally")
pass
def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
"""
Create simple sentences.
"""
if not tokens:
return []
full_text = " ".join(token.text for token in tokens)
sentence = Sentence(
start=tokens[0].start,
end=tokens[-1].end,
text=full_text
)
return [sentence]

View File

@@ -5,8 +5,8 @@ import librosa
from functools import lru_cache
import time
import logging
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR, SimulStreamingASR, SIMULSTREAMING_AVAILABLE
from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor, SimulStreamingOnlineProcessor, SIMULSTREAMING_AVAILABLE as SIMULSTREAMING_ONLINE_AVAILABLE
logger = logging.getLogger(__name__)
@@ -69,6 +69,37 @@ def backend_factory(args):
if backend == "openai-api":
logger.debug("Using OpenAI API.")
asr = OpenaiApiASR(lan=args.lan)
elif backend == "simulstreaming":
logger.debug("Using SimulStreaming backend.")
if not SIMULSTREAMING_AVAILABLE:
raise ImportError(
"SimulStreaming backend is not available. Please install SimulStreaming dependencies. "
"See the documentation for installation instructions."
)
simulstreaming_kwargs = {}
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
'max_context_tokens', 'model_path']:
if hasattr(args, attr):
simulstreaming_kwargs[attr] = getattr(args, attr)
# Add segment_length from min_chunk_size
simulstreaming_kwargs['segment_length'] = getattr(args, 'min_chunk_size', 0.5)
simulstreaming_kwargs['task'] = args.task
size = args.model
t = time.time()
logger.info(f"Loading SimulStreaming {size} model for language {args.lan}...")
asr = SimulStreamingASR(
modelsize=size,
lan=args.lan,
cache_dir=getattr(args, 'model_cache_dir', None),
model_dir=getattr(args, 'model_dir', None),
**simulstreaming_kwargs
)
e = time.time()
logger.info(f"done. It took {round(e-t,2)} seconds.")
else:
if backend == "faster-whisper":
asr_cls = FasterWhisperASR
@@ -84,8 +115,8 @@ def backend_factory(args):
asr = asr_cls(
modelsize=size,
lan=args.lan,
cache_dir=args.model_cache_dir,
model_dir=args.model_dir,
cache_dir=getattr(args, 'model_cache_dir', None),
model_dir=getattr(args, 'model_dir', None),
)
e = time.time()
logger.info(f"done. It took {round(e-t,2)} seconds.")
@@ -97,21 +128,33 @@ def backend_factory(args):
language = args.lan
if args.task == "translate":
asr.set_translate_task()
if backend != "simulstreaming":
asr.set_translate_task()
tgt_language = "en" # Whisper translates into English
else:
tgt_language = language # Whisper transcribes in this language
# Create the tokenizer
if args.buffer_trimming == "sentence":
tokenizer = create_tokenizer(tgt_language)
else:
tokenizer = None
return asr, tokenizer
def online_factory(args, asr, tokenizer, logfile=sys.stderr):
if args.vac:
if args.backend == "simulstreaming":
if not SIMULSTREAMING_ONLINE_AVAILABLE:
raise ImportError("SimulStreaming online processor is not available.")
logger.debug("Creating SimulStreaming online processor")
online = SimulStreamingOnlineProcessor(
asr,
tokenizer,
logfile=logfile,
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
confidence_validation=args.confidence_validation
)
elif args.vac:
online = VACOnlineASRProcessor(
args.min_chunk_size,
asr,
@@ -145,6 +188,7 @@ def warmup_asr(asr, warmup_file=None, timeout=5):
import os
import tempfile
is_simulstreaming = hasattr(asr, 'warmup') and callable(getattr(asr, 'warmup'))
if warmup_file is None:
# Download JFK sample if not already present
@@ -179,16 +223,23 @@ def warmup_asr(asr, warmup_file=None, timeout=5):
logger.warning(f"Warmup file {warmup_file} invalid or missing.")
return False
print(f"Warming up Whisper with {warmup_file}")
print(f"Warming up {'SimulStreaming' if is_simulstreaming else 'Whisper'} with {warmup_file}")
try:
import librosa
audio, sr = librosa.load(warmup_file, sr=16000)
except Exception as e:
logger.warning(f"Failed to load audio file: {e}")
return False
# Process the audio
asr.transcribe(audio)
logger.info("Whisper is warmed up")
try:
if is_simulstreaming:
asr.warmup(audio)
else:
asr.transcribe(audio)
logger.info(f"{'SimulStreaming' if is_simulstreaming else 'Whisper'} is warmed up")
return True
except Exception as e:
logger.warning(f"Warmup failed: {e}")
return False