46 Commits
0.2.2 ... 0.2.5

Author SHA1 Message Date
Quentin Fuxa
349c7dcb9e bump version ro 0.2.5 2025-08-13 10:04:31 +02:00
Quentin Fuxa
1c42b867cf Merge branch 'main' of https://github.com/QuentinFuxa/whisper_streaming_web 2025-08-13 10:04:04 +02:00
Quentin Fuxa
d4771e563e Increase END_SILENCE_DURATION to reduce false positives 2025-08-13 10:04:00 +02:00
Quentin Fuxa
b0a5fc0693 Merge pull request #155 from davidgumberg/keepawakescrolldown
frontend: Keep screen awake and scroll down when transcribing.
2025-08-13 10:02:52 +02:00
David Gumberg
3b96fb8776 frontend: Scroll down when appending transcription 2025-08-12 17:31:32 -07:00
David Gumberg
7f93c4b978 frontend: Don't let screen sleep when transcribing. 2025-08-12 17:30:57 -07:00
Quentin Fuxa
15c3df1cba warmup base whisper when using simulstreaming 2025-08-12 18:52:52 +02:00
Quentin Fuxa
7fb8e66c01 typo 2025-08-12 18:36:32 +02:00
Quentin Fuxa
728e1f1290 simulstreaming warmup is done for each instance of online, not for the backend 2025-08-12 18:35:04 +02:00
Quentin Fuxa
87b9ed6ecd nonspeech_prob from 1 to 0.5 2025-08-12 18:34:37 +02:00
Quentin Fuxa
38b4ebe8ba Handle 3 types of silences: Indicated by whisper, between tokens, and at the end of the input. Display them in the frontend 2025-08-11 17:56:57 +02:00
Quentin Fuxa
d098af3185 each SimulStreamingOnlineProcessor now contains PaddedAlignAttWhisper instance. SimulStreamingASR only contains loaded whisper model 2025-08-11 08:24:14 +02:00
Quentin Fuxa
4e56130a40 frontend supports dark theme 2025-08-11 08:22:23 +02:00
Quentin Fuxa
2bbdc70187 lags are now updated every 0.1s 2025-08-09 23:11:05 +02:00
Quentin Fuxa
b678a55f63 remove duplicate file 2025-08-09 23:10:34 +02:00
Quentin Fuxa
5491964e81 clean SimulStreamingOnlineProcessor initialization + audio processing 2025-08-09 20:16:27 +02:00
Quentin Fuxa
b05297a96d clean simulwhisper backend and online 2025-08-09 18:02:15 +02:00
Quentin Fuxa
197293e25e refactor(simulstreaming): extract backend + online module into separate files from whisper streaming 2025-08-08 18:07:51 +02:00
Quentin Fuxa
ba41c4ab56 Remove download_simulstreaming_backend 2025-08-08 18:06:40 +02:00
Quentin Fuxa
bda72b8bc0 setup.py to pyproject.toml. Remove <2.0.0 condition on numpy dep 2025-08-03 16:32:31 +02:00
Quentin Fuxa
bb6b9f4cb1 architecture diagram : available backends for whisper streaming & diarization 2025-08-03 12:25:36 +02:00
Quentin Fuxa
e40b5a3ea0 Update architecture diagram 2025-08-02 13:51:15 +02:00
Quentin Fuxa
4cfed6e98e in MultiHeadAttention and ResidualAttentionBlock include cache_id for compatibility with simulstreaming code 2025-08-02 13:16:58 +02:00
Quentin Fuxa
687e3dd5e2 update simulstreaming model.py to match the latest version of whisper sources 2025-08-02 13:16:10 +02:00
Quentin Fuxa
e4140cd299 Update Dockerfile to install build-essential and update PyTorch version 2025-08-02 13:08:43 +02:00
Quentin Fuxa
8e056cbdf2 Upgrade SimulStreaming Whisper core from version 20230918 to 20250625 2025-08-02 13:06:36 +02:00
Quentin Fuxa
9dcfb38967 Update README.md 2025-08-01 18:02:11 +02:00
Quentin Fuxa
47b9235d70 Update README.md 2025-08-01 17:55:40 +02:00
Quentin Fuxa
f3cd53a4db Update README.md 2025-08-01 16:53:22 +02:00
Quentin Fuxa
dbdb4ea66c Update README.md 2025-08-01 16:33:26 +02:00
Quentin Fuxa
00424d7ca3 latest version of simulstreaming 2025-07-31 16:44:23 +02:00
Quentin Fuxa
4b738d6f63 fix duplicate line 2025-07-31 16:29:35 +02:00
Quentin Fuxa
8a5e2adb1e simulstreaming: fixes token handling during warm-up phase 2025-07-31 16:25:34 +02:00
Quentin Fuxa
f85329e112 Merge branch 'main' of https://github.com/QuentinFuxa/whisper_streaming_web 2025-07-31 11:42:16 +02:00
Quentin Fuxa
46efbdf1d9 solves https://github.com/QuentinFuxa/WhisperLiveKit/issues/151 2025-07-31 11:42:06 +02:00
Quentin Fuxa
8885ade003 Merge pull request #153 from luisla-rivas/main
Fix README.md to view correctly Deployment Guide info
2025-07-31 07:10:35 +02:00
luisla-rivas
2564928d83 Fix README.md to view correctly Deployment Guide info 2025-07-30 14:11:19 +02:00
Quentin Fuxa
56114d3071 Remove end_attributed_speaker in diarization_online. handled in audio processor 2025-07-16 12:09:43 +02:00
Quentin Fuxa
5b9977c9af Enhanced use_punctuation_split for diarization. further improvements still needed 2025-07-16 12:06:17 +02:00
Quentin Fuxa
12a544164f Merge branch 'main' of https://github.com/QuentinFuxa/whisper_streaming_web 2025-07-16 12:05:01 +02:00
Quentin Fuxa
2ca1156b7e Merge pull request #147 from choomegan/diar_queue
Ensure diarization_queue receives only latest PCM chunk
2025-07-16 12:04:53 +02:00
Quentin Fuxa
3ad3683ca7 Refactor speaker assignment in DiartDiarization for clarity and punctuation awareness 2025-07-15 14:38:53 +02:00
Quentin Fuxa
1599bd87a0 work on punctuation_split 2025-07-15 12:04:54 +02:00
Quentin Fuxa
90623400a4 Remove automatic downloading of SimulStreaming dependencies on import failure 2025-07-15 12:04:17 +02:00
choomegan
64e44fb24f fix: logic of adding of pcm_array to diarization_queue 2025-07-15 15:33:41 +08:00
Quentin Fuxa
156b9a133f 0.2.2 2025-07-04 17:11:35 +02:00
34 changed files with 3160 additions and 1713 deletions

View File

@@ -21,10 +21,12 @@ RUN apt-get update && \
python3 \
python3-pip \
ffmpeg \
git && \
git \
build-essential \
python3-dev && \
rm -rf /var/lib/apt/lists/*
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
COPY . .

View File

@@ -13,23 +13,16 @@
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p>
## Overview
This project is based on [WhisperStreaming](https://github.com/ufal/whisper_streaming) and [SimulStreaming](https://github.com/ufal/SimulStreaming), allowing you to transcribe audio directly from your browser. WhisperLiveKit provides a complete backend solution for real-time speech transcription with a functional, simple and customizable frontend. Everything runs locally on your machine
WhisperLiveKit brings real-time speech transcription directly to your browser, with a ready-to-use backend+server and a simple frontend.
### Architecture
WhisperLiveKit consists of three main components:
- **Frontend**: A basic html + JS interface that captures microphone audio and streams it to the backend via WebSockets. You can use and adapt the [provided template](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html).
- **Backend (Web Server)**: A FastAPI-based WebSocket server that receives streamed audio data, processes it in real time, and returns transcriptions to the frontend. This is where the WebSocket logic and routing live.
- **Core Backend (Library Logic)**: A server-agnostic core that handles audio processing, ASR, and diarization. It exposes reusable components that take in audio bytes and return transcriptions.
Built on [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) and [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) for transcription, plus [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) and [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) for diarization.
### Key Features
- **Real-time Transcription** - Locally (or on-prem) convert speech to text instantly as you speak
- **Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart)
- **Speaker Diarization** - Identify different speakers in real-time. (⚠️ backend Streaming Sortformer in developement)
- **Multi-User Support** - Handle multiple users simultaneously with a single backend/server
- **Automatic Silence Chunking** Automatically chunks when no audio is detected to limit buffer size
- **Confidence Validation** Immediately validate high-confidence tokens for faster inference (WhisperStreaming only)
@@ -37,6 +30,11 @@ WhisperLiveKit consists of three main components:
- **Punctuation-Based Speaker Splitting [BETA]** - Align speaker changes with natural sentence boundaries for more readable transcripts
- **SimulStreaming Backend** - [Dual-licensed](https://github.com/ufal/SimulStreaming#-licence-and-contributions) - Ultra-low latency transcription using SOTA AlignAtt policy.
### Architecture
<img alt="Architecture" src="architecture.png" />
## Quick Start
```bash
@@ -247,7 +245,7 @@ To deploy WhisperLiveKit in production:
- Ensure WebSocket connection points to your server's address
3. **Nginx Configuration** (recommended for production):
```nginx
```nginx
server {
listen 80;
server_name your-domain.com;
@@ -258,6 +256,7 @@ To deploy WhisperLiveKit in production:
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
}}
```
4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in WebSocket URL
@@ -268,21 +267,21 @@ A basic Dockerfile is provided which allows re-use of Python package installatio
#### All defaults
- Create a reusable image with only the basics and then run as a named container:
```bash
docker build -t whisperlivekit-defaults .
docker create --gpus all --name whisperlivekit -p 8000:8000 whisperlivekit-defaults
docker start -i whisperlivekit
```
```bash
docker build -t whisperlivekit-defaults .
docker create --gpus all --name whisperlivekit -p 8000:8000 whisperlivekit-defaults
docker start -i whisperlivekit
```
> **Note**: If you're running on a system without NVIDIA GPU support (such as Mac with Apple Silicon or any system without CUDA capabilities), you need to **remove the `--gpus all` flag** from the `docker create` command. Without GPU acceleration, transcription will use CPU only, which may be significantly slower. Consider using small models for better performance on CPU-only systems.
> **Note**: If you're running on a system without NVIDIA GPU support (such as Mac with Apple Silicon or any system without CUDA capabilities), you need to **remove the `--gpus all` flag** from the `docker create` command. Without GPU acceleration, transcription will use CPU only, which may be significantly slower. Consider using small models for better performance on CPU-only systems.
#### Customization
- Customize the container options:
```bash
docker build -t whisperlivekit-defaults .
docker create --gpus all --name whisperlivekit-base -p 8000:8000 whisperlivekit-defaults --model base
docker start -i whisperlivekit-base
```
```bash
docker build -t whisperlivekit-defaults .
docker create --gpus all --name whisperlivekit-base -p 8000:8000 whisperlivekit-defaults --model base
docker start -i whisperlivekit-base
```
- `--build-arg` Options:
- `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!

BIN
architecture.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 382 KiB

59
pyproject.toml Normal file
View File

@@ -0,0 +1,59 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "whisperlivekit"
version = "0.2.5"
description = "Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization"
readme = "README.md"
authors = [
{ name = "Quentin Fuxa" }
]
license = { file = "LICENSE" }
requires-python = ">=3.9"
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Multimedia :: Sound/Audio :: Speech"
]
dependencies = [
"fastapi",
"librosa",
"soundfile",
"faster-whisper",
"uvicorn",
"websockets"
]
[project.optional-dependencies]
diarization = ["diart"]
vac = ["torch"]
sentence = ["mosestokenizer", "wtpsplit"]
whisper = ["whisper"]
whisper-timestamped = ["whisper-timestamped"]
mlx-whisper = ["mlx-whisper"]
openai = ["openai"]
simulstreaming = [
"torch",
"tqdm",
"tiktoken",
'triton>=2.0.0,<3; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
]
[project.urls]
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
[project.scripts]
whisperlivekit-server = "whisperlivekit.basic_server:main"
[tool.setuptools]
packages = ["whisperlivekit", "whisperlivekit.diarization", "whisperlivekit.simul_whisper", "whisperlivekit.simul_whisper.whisper", "whisperlivekit.simul_whisper.whisper.assets", "whisperlivekit.simul_whisper.whisper.normalizers", "whisperlivekit.web", "whisperlivekit.whisper_streaming_custom"]
[tool.setuptools.package-data]
whisperlivekit = ["web/*.html"]
"whisperlivekit.simul_whisper.whisper.assets" = ["*.tiktoken", "*.npz"]

View File

@@ -1,55 +0,0 @@
from setuptools import setup, find_packages
setup(
name="whisperlivekit",
version="0.2.1",
description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
author="Quentin Fuxa",
url="https://github.com/QuentinFuxa/WhisperLiveKit",
packages=find_packages(),
install_requires=[
"fastapi",
"librosa",
"soundfile",
"faster-whisper",
"uvicorn",
"websockets",
],
extras_require={
"diarization": ["diart"],
"vac": ["torch"],
"sentence": ["mosestokenizer", "wtpsplit"],
"whisper": ["whisper"],
"whisper-timestamped": ["whisper-timestamped"],
"mlx-whisper": ["mlx-whisper"],
"openai": ["openai"],
"simulstreaming": [
"torch",
"tqdm",
"tiktoken",
"numpy<2.0.0",
"triton>=2.0.0,<3;platform_machine==\"x86_64\" and sys_platform==\"linux\" or sys_platform==\"linux2\"",
],
},
package_data={
'whisperlivekit': ['web/*.html'],
'whisperlivekit.simul_whisper': ['dual_license_simulstreaming.md'],
'whisperlivekit.simul_whisper.whisper.assets': ['*.tiktoken', '*.npz'],
},
entry_points={
'console_scripts': [
'whisperlivekit-server=whisperlivekit.basic_server:main',
],
},
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Multimedia :: Sound/Audio :: Speech",
],
python_requires=">=3.9",
)

View File

@@ -1,4 +1,3 @@
from .download_simulstreaming_backend import download_simulstreaming_backend
from .audio_processor import AudioProcessor
from .core import TranscriptionEngine
from .parse_args import parse_args

View File

@@ -6,10 +6,9 @@ import logging
import traceback
from datetime import timedelta
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper_streaming_custom.whisper_online import online_factory
from whisperlivekit.core import TranscriptionEngine
from whisperlivekit.core import TranscriptionEngine, online_factory
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from .remove_silences import handle_silences
# Set up logging once
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
@@ -52,7 +51,6 @@ class AudioProcessor:
self.tokens = []
self.buffer_transcription = ""
self.buffer_diarization = ""
self.full_transcription = ""
self.end_buffer = 0
self.end_attributed_speaker = 0
self.lock = asyncio.Lock()
@@ -96,13 +94,12 @@ class AudioProcessor:
"""Convert PCM buffer in s16le format to normalized NumPy array."""
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep):
async def update_transcription(self, new_tokens, buffer, end_buffer, sep):
"""Thread-safe update of transcription with new data."""
async with self.lock:
self.tokens.extend(new_tokens)
self.buffer_transcription = buffer
self.end_buffer = end_buffer
self.full_transcription = full_transcription
self.sep = sep
async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
@@ -129,12 +126,12 @@ class AudioProcessor:
# Calculate remaining times
remaining_transcription = 0
if self.end_buffer > 0:
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2))
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 1))
remaining_diarization = 0
if self.tokens:
latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2))
remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 1))
return {
"tokens": self.tokens.copy(),
@@ -153,7 +150,6 @@ class AudioProcessor:
self.tokens = []
self.buffer_transcription = self.buffer_diarization = ""
self.end_buffer = self.end_attributed_speaker = 0
self.full_transcription = self.last_response_content = ""
self.beg_loop = time()
async def ffmpeg_stdout_reader(self):
@@ -192,12 +188,6 @@ class AudioProcessor:
continue
self.pcm_buffer.extend(chunk)
# Send to diarization if enabled
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(
self.convert_pcm_to_float(self.pcm_buffer).copy()
)
# Process when enough data
if len(self.pcm_buffer) >= self.bytes_per_sec:
@@ -214,7 +204,11 @@ class AudioProcessor:
# Send to transcription if enabled
if self.args.transcription and self.transcription_queue:
await self.transcription_queue.put(pcm_array.copy())
# Send to diarization if enabled
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(pcm_array.copy())
# Sleep if no processing is happening
if not self.args.transcription and not self.args.diarization:
await asyncio.sleep(0.1)
@@ -240,7 +234,6 @@ class AudioProcessor:
async def transcription_processor(self):
"""Process audio chunks for transcription."""
self.full_transcription = ""
self.sep = self.online.asr.sep
cumulative_pcm_duration_stream_time = 0.0
@@ -252,7 +245,7 @@ class AudioProcessor:
self.transcription_queue.task_done()
break
if not self.online: # Should not happen if queue is used
if not self.online:
logger.warning("Transcription processor: self.online not initialized.")
self.transcription_queue.task_done()
continue
@@ -279,8 +272,6 @@ class AudioProcessor:
if new_tokens:
validated_text = self.sep.join([t.text for t in new_tokens])
self.full_transcription += validated_text
if buffer_text.startswith(validated_text):
buffer_text = buffer_text[len(validated_text):].lstrip()
@@ -297,7 +288,7 @@ class AudioProcessor:
new_end_buffer = max(candidate_end_times)
await self.update_transcription(
new_tokens, buffer_text, new_end_buffer, self.full_transcription, self.sep
new_tokens, buffer_text, new_end_buffer, self.sep
)
self.transcription_queue.task_done()
@@ -325,12 +316,12 @@ class AudioProcessor:
await diarization_obj.diarize(pcm_array)
async with self.lock:
new_end = diarization_obj.assign_speakers_to_tokens(
self.end_attributed_speaker,
self.tokens = diarization_obj.assign_speakers_to_tokens(
self.tokens,
use_punctuation_split=self.args.punctuation_split
)
self.end_attributed_speaker = new_end
if len(self.tokens) > 0:
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
if buffer_diarization:
self.buffer_diarization = buffer_diarization
@@ -346,6 +337,8 @@ class AudioProcessor:
async def results_formatter(self):
"""Format processing results for output."""
last_sent_trans = None
last_sent_diar = None
while True:
try:
ffmpeg_state = await self.ffmpeg_manager.get_state()
@@ -383,8 +376,8 @@ class AudioProcessor:
lines = []
last_end_diarized = 0
undiarized_text = []
# Process each token
current_time = time() - self.beg_loop
tokens = handle_silences(tokens, current_time)
for token in tokens:
speaker = token.speaker
@@ -449,10 +442,19 @@ class AudioProcessor:
' '.join([f"{line['speaker']} {line['text']}" for line in final_lines_for_response]) + \
f" | {buffer_transcription} | {buffer_diarization}"
if current_response_signature != self.last_response_content and \
(final_lines_for_response or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"):
trans = state["remaining_time_transcription"]
diar = state["remaining_time_diarization"]
should_push = (
current_response_signature != self.last_response_content
or last_sent_trans is None
or round(trans, 1) != round(last_sent_trans, 1)
or round(diar, 1) != round(last_sent_diar, 1)
)
if should_push and (final_lines_for_response or buffer_transcription or buffer_diarization or response_status == "no_audio_detected" or trans > 0 or diar > 0):
yield response
self.last_response_content = current_response_signature
last_sent_trans = trans
last_sent_diar = diar
# Check for termination condition
if self.is_stopping:

View File

@@ -1,9 +1,12 @@
try:
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory
from whisperlivekit.whisper_streaming_custom.online_asr import VACOnlineASRProcessor, OnlineASRProcessor
except ImportError:
from .whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
from .whisper_streaming_custom.whisper_online import backend_factory
from .whisper_streaming_custom.online_asr import VACOnlineASRProcessor, OnlineASRProcessor
from whisperlivekit.warmup import warmup_asr, warmup_online
from argparse import Namespace
import sys
class TranscriptionEngine:
_instance = None
@@ -22,7 +25,6 @@ class TranscriptionEngine:
"host": "localhost",
"port": 8000,
"warmup_file": None,
"confidence_validation": False,
"diarization": False,
"punctuation_split": False,
"min_chunk_size": 0.5,
@@ -34,15 +36,15 @@ class TranscriptionEngine:
"backend": "faster-whisper",
"vac": False,
"vac_chunk_size": 0.04,
"buffer_trimming": "segment",
"buffer_trimming_sec": 15,
"log_level": "DEBUG",
"ssl_certfile": None,
"ssl_keyfile": None,
"transcription": True,
"vad": True,
"segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding",
# whisperstreaming params:
"buffer_trimming": "segment",
"confidence_validation": False,
"buffer_trimming_sec": 15,
# simulstreaming params:
"frame_threshold": 25,
"beams": 1,
@@ -55,6 +57,10 @@ class TranscriptionEngine:
"static_init_prompt": None,
"max_context_tokens": None,
"model_path": './base.pt',
# diart params:
"segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding",
}
config_dict = {**defaults, **kwargs}
@@ -78,8 +84,32 @@ class TranscriptionEngine:
self.diarization = None
if self.args.transcription:
self.asr, self.tokenizer = backend_factory(self.args)
warmup_asr(self.asr, self.args.warmup_file)
if self.args.backend == "simulstreaming":
from simul_whisper import SimulStreamingASR
self.tokenizer = None
simulstreaming_kwargs = {}
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
'max_context_tokens', 'model_path']:
if hasattr(self.args, attr):
simulstreaming_kwargs[attr] = getattr(self.args, attr)
# Add segment_length from min_chunk_size
simulstreaming_kwargs['segment_length'] = getattr(self.args, 'min_chunk_size', 0.5)
simulstreaming_kwargs['task'] = self.args.task
size = self.args.model
self.asr = SimulStreamingASR(
modelsize=size,
lan=self.args.lan,
cache_dir=getattr(self.args, 'model_cache_dir', None),
model_dir=getattr(self.args, 'model_dir', None),
**simulstreaming_kwargs
)
else:
self.asr, self.tokenizer = backend_factory(self.args)
warmup_asr(self.asr, self.args.warmup_file) #for simulstreaming, warmup should be done in the online class not here
if self.args.diarization:
from whisperlivekit.diarization.diarization_online import DiartDiarization
@@ -90,3 +120,33 @@ class TranscriptionEngine:
)
TranscriptionEngine._initialized = True
def online_factory(args, asr, tokenizer, logfile=sys.stderr):
if args.backend == "simulstreaming":
from simul_whisper import SimulStreamingOnlineProcessor
online = SimulStreamingOnlineProcessor(
asr,
logfile=logfile,
)
# warmup_online(online, args.warmup_file)
elif args.vac:
online = VACOnlineASRProcessor(
args.min_chunk_size,
asr,
tokenizer,
logfile=logfile,
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
confidence_validation = args.confidence_validation
)
else:
online = OnlineASRProcessor(
asr,
tokenizer,
logfile=logfile,
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
confidence_validation = args.confidence_validation
)
return online

View File

@@ -165,7 +165,7 @@ class WebSocketAudioSource(AudioSource):
class DiartDiarization:
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "speechbrain/spkrec-ecapa-voxceleb"):
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 1.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "pyannote/embedding"):
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
@@ -206,15 +206,14 @@ class DiartDiarization:
"""
if self.custom_source:
self.custom_source.push_audio(pcm_array)
self.observer.clear_old_segments()
return self.observer.get_segments()
# self.observer.clear_old_segments()
def close(self):
"""Close the audio source."""
if self.custom_source:
self.custom_source.close()
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float:
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> float:
"""
Assign speakers to tokens based on timing overlap with speaker segments.
Uses the segments collected by the observer.
@@ -231,85 +230,82 @@ class DiartDiarization:
if not self.lag_diart and segments and tokens:
self.lag_diart = segments[0].start - tokens[0].start
for token in tokens:
for segment in segments:
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
token.speaker = extract_number(segment.speaker) + 1
end_attributed_speaker = max(token.end, end_attributed_speaker)
if use_punctuation_split and len(tokens) > 1:
punctuation_marks = {'.', '!', '?'}
print("Here are the tokens:",
[(t.text, t.start, t.end, t.speaker) for t in tokens[:10]])
segment_map = []
for segment in segments:
speaker_num = extract_number(segment.speaker) + 1
segment_map.append((segment.start, segment.end, speaker_num))
segment_map.sort(key=lambda x: x[0])
i = 0
while i < len(tokens):
current_token = tokens[i]
is_sentence_end = False
if current_token.text and current_token.text.strip():
text = current_token.text.strip()
if text[-1] in punctuation_marks:
is_sentence_end = True
logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
if is_sentence_end and current_token.speaker != -1:
punctuation_time = current_token.end
current_speaker = current_token.speaker
j = i + 1
next_sentence_tokens = []
while j < len(tokens):
next_token = tokens[j]
next_sentence_tokens.append(j)
# Check if this token ends the next sentence
if next_token.text and next_token.text.strip():
if next_token.text.strip()[-1] in punctuation_marks:
break
j += 1
if next_sentence_tokens:
speaker_times = {}
for idx in next_sentence_tokens:
token = tokens[idx]
# Find which segments overlap with this token
for seg_start, seg_end, seg_speaker in segment_map:
if not (seg_end <= token.start or seg_start >= token.end):
# Calculate overlap duration
overlap_start = max(seg_start, token.start)
overlap_end = min(seg_end, token.end)
overlap_duration = overlap_end - overlap_start
if seg_speaker not in speaker_times:
speaker_times[seg_speaker] = 0
speaker_times[seg_speaker] += overlap_duration
if speaker_times:
dominant_speaker = max(speaker_times.items(), key=lambda x: x[1])[0]
if dominant_speaker != current_speaker:
logger.debug(f" Speaker change after punctuation: {current_speaker}{dominant_speaker}")
for idx in next_sentence_tokens:
if tokens[idx].speaker != dominant_speaker:
logger.debug(f" Reassigning token {idx} ('{tokens[idx].text}') to Speaker {dominant_speaker}")
tokens[idx].speaker = dominant_speaker
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
else:
for idx in next_sentence_tokens:
if tokens[idx].speaker == -1:
tokens[idx].speaker = current_speaker
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
i += 1
if not use_punctuation_split:
for token in tokens:
for segment in segments:
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
token.speaker = extract_number(segment.speaker) + 1
else:
tokens = add_speaker_to_tokens(segments, tokens)
return tokens
return end_attributed_speaker
def concatenate_speakers(segments):
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
for segment in segments:
speaker = extract_number(segment.speaker) + 1
if segments_concatenated[-1]['speaker'] != speaker:
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
else:
segments_concatenated[-1]['end'] = segment.end
# print("Segments concatenated:")
# for entry in segments_concatenated:
# print(f"Speaker {entry['speaker']}: {entry['begin']:.2f}s - {entry['end']:.2f}s")
return segments_concatenated
def add_speaker_to_tokens(segments, tokens):
"""
Assign speakers to tokens based on diarization segments, with punctuation-aware boundary adjustment.
"""
punctuation_marks = {'.', '!', '?'}
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
segments_concatenated = concatenate_speakers(segments)
for ind, segment in enumerate(segments_concatenated):
for i, punctuation_token in enumerate(punctuation_tokens):
if punctuation_token.start > segment['end']:
after_length = punctuation_token.start - segment['end']
before_length = segment['end'] - punctuation_tokens[i - 1].end
if before_length > after_length:
segment['end'] = punctuation_token.start
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
else:
segment['end'] = punctuation_tokens[i - 1].end
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
break
last_end = 0.0
for token in tokens:
start = max(last_end + 0.01, token.start)
token.start = start
token.end = max(start, token.end)
last_end = token.end
ind_last_speaker = 0
for segment in segments_concatenated:
for i, token in enumerate(tokens[ind_last_speaker:]):
if token.end <= segment['end']:
token.speaker = segment['speaker']
ind_last_speaker = i + 1
# print(
# f"Token '{token.text}' ('begin': {token.start:.2f}, 'end': {token.end:.2f}) "
# f"assigned to Speaker {segment['speaker']} ('segment': {segment['begin']:.2f}-{segment['end']:.2f})"
# )
elif token.start > segment['end']:
break
return tokens
def visualize_tokens(tokens):
conversation = [{"speaker": -1, "text": ""}]
for token in tokens:
speaker = conversation[-1]['speaker']
if token.speaker != speaker:
conversation.append({"speaker": token.speaker, "text": token.text})
else:
conversation[-1]['text'] += token.text
print("Conversation:")
for entry in conversation:
print(f"Speaker {entry['speaker']}: {entry['text']}")

View File

@@ -1,32 +0,0 @@
import os
import requests
import inspect
def get_module_path():
return os.path.dirname(inspect.getfile(inspect.currentframe()))
GITHUB_API_URL = "https://api.github.com/repos/ufal/SimulStreaming/contents/simul_whisper/whisper"
RAW_BASE_URL = "https://raw.githubusercontent.com/ufal/SimulStreaming/main/simul_whisper/whisper"
TARGET_DIR = os.path.join(get_module_path(), "simul_whisper", "whisper")
def download_files_from_github(api_url, local_dir):
os.makedirs(local_dir, exist_ok=True)
response = requests.get(api_url)
response.raise_for_status()
items = response.json()
for item in items:
if item['type'] == 'file':
download_url = item['download_url']
file_name = item['name']
file_response = requests.get(download_url)
file_response.raise_for_status()
with open(os.path.join(local_dir, file_name), 'wb') as f:
f.write(file_response.content)
elif item['type'] == 'dir':
# Recursive call for subdirectories
download_files_from_github(item['url'], os.path.join(local_dir, item['name']))
def download_simulstreaming_backend():
print(f"Downloading files into {TARGET_DIR} ...")
download_files_from_github(GITHUB_API_URL, TARGET_DIR)
print("✅ Download of SimulStreaming backend files completed successfully.")

View File

@@ -0,0 +1,103 @@
from whisperlivekit.timed_objects import ASRToken
import re
MIN_SILENCE_DURATION = 4 #in seconds
END_SILENCE_DURATION = 8 #in seconds. you should keep it important to not have false positive when the model lag is important
def blank_to_silence(tokens):
full_string = ''.join([t.text for t in tokens])
patterns = [re.compile(r'(?:\s*\[BLANK_AUDIO\]\s*)+'), re.compile(r'(?:\s*\[typing\]\s*)+')]
matches = []
for pattern in patterns:
for m in pattern.finditer(full_string):
matches.append({
'start': m.start(),
'end': m.end()
})
if matches:
# cleaned = pattern.sub(' ', full_string).strip()
# print("Cleaned:", cleaned)
cumulated_len = 0
silence_token = None
cleaned_tokens = []
for token in tokens:
if matches:
start = cumulated_len
end = cumulated_len + len(token.text)
cumulated_len = end
if start >= matches[0]['start'] and end <= matches[0]['end']:
if silence_token: #previous token was already silence
silence_token.start = min(silence_token.start, token.start)
silence_token.end = max(silence_token.end, token.end)
else: #new silence
silence_token = ASRToken(
start=token.start,
end=token.end,
speaker=-2,
probability=0.95
)
else:
if silence_token: #there was silence but no more
if silence_token.end - silence_token.start >= MIN_SILENCE_DURATION:
cleaned_tokens.append(
silence_token
)
silence_token = None
matches.pop(0)
cleaned_tokens.append(token)
# print(cleaned_tokens)
return cleaned_tokens
return tokens
def no_token_to_silence(tokens):
new_tokens = []
silence_token = None
for token in tokens:
if token.speaker == -2:
if new_tokens and new_tokens[-1].speaker == -2: #if token is silence and previous one too
new_tokens[-1].end = token.end
else:
new_tokens.append(token)
last_end = new_tokens[-1].end if new_tokens else 0.0
if token.start - last_end >= MIN_SILENCE_DURATION: #if token is not silence but important gap
if new_tokens and new_tokens[-1].speaker == -2:
new_tokens[-1].end = token.start
else:
silence_token = ASRToken(
start=last_end,
end=token.start,
speaker=-2,
probability=0.95
)
new_tokens.append(silence_token)
if token.speaker != -2:
new_tokens.append(token)
return new_tokens
def ends_with_silence(tokens, current_time):
if not tokens:
return []
last_token = tokens[-1]
if tokens and current_time - last_token.end >= END_SILENCE_DURATION:
if last_token.speaker == -2:
last_token.end = current_time
else:
tokens.append(
ASRToken(
start=tokens[-1].end,
end=current_time,
speaker=-2,
probability=0.95
)
)
return tokens
def handle_silences(tokens, current_time):
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
tokens = no_token_to_silence(tokens)
tokens = ends_with_silence(tokens, current_time)
return tokens

View File

@@ -0,0 +1,6 @@
from .backend import SimulStreamingASR, SimulStreamingOnlineProcessor
__all__ = [
"SimulStreamingASR",
"SimulStreamingOnlineProcessor",
]

View File

@@ -0,0 +1,223 @@
import sys
import numpy as np
import logging
from typing import List, Tuple, Optional
import logging
from whisperlivekit.timed_objects import ASRToken, Transcript
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
from .whisper import load_model, tokenizer
import os
logger = logging.getLogger(__name__)
try:
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
from whisperlivekit.simul_whisper.whisper import tokenizer
except ImportError as e:
raise ImportError(
"""SimulStreaming dependencies are not available.
Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]".""")
class SimulStreamingOnlineProcessor:
SAMPLING_RATE = 16000
def __init__(
self,
asr,
logfile=sys.stderr,
warmup_file=None
):
self.asr = asr
self.logfile = logfile
self.is_last = False
self.beg = 0.0
self.end = 0.0
self.cumulative_audio_duration = 0.0
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
self.model = PaddedAlignAttWhisper(
cfg=asr.cfg,
loaded_model=asr.whisper_model)
if asr.tokenizer:
self.model.tokenizer = asr.tokenizer
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
"""Append an audio chunk to be processed by SimulStreaming."""
# Convert numpy array to torch tensor
audio_tensor = torch.from_numpy(audio).float()
# Update timing
chunk_duration = len(audio) / self.SAMPLING_RATE
self.cumulative_audio_duration += chunk_duration
if audio_stream_end_time is not None:
self.end = audio_stream_end_time
else:
self.end = self.cumulative_audio_duration
self.model.insert_audio(audio_tensor)
def get_buffer(self):
return Transcript(
start=None,
end=None,
text='',
probability=None
)
def timestamped_text(self, tokens, generation):
# From the simulstreaming repo. self.model to self.asr.model
pr = generation["progress"]
if "result" not in generation:
split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
else:
split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"]
frames = [p["most_attended_frames"][0] for p in pr]
tokens = tokens.copy()
ret = []
for sw,st in zip(split_words,split_tokens):
b = None
for stt in st:
t,f = tokens.pop(0), frames.pop(0)
if t != stt:
raise ValueError(f"Token mismatch: {t} != {stt} at frame {f}.")
if b is None:
b = f
e = f
out = (b*0.02, e*0.02, sw)
ret.append(out)
logger.debug(f"TS-WORD:\t{' '.join(map(str, out))}")
return ret
def process_iter(self) -> Tuple[List[ASRToken], float]:
"""
Process accumulated audio chunks using SimulStreaming.
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
"""
try:
tokens, generation_progress = self.model.infer(is_last=self.is_last)
ts_words = self.timestamped_text(tokens, generation_progress)
new_tokens = []
for ts_word in ts_words:
start, end, word = ts_word
token = ASRToken(
start=start,
end=end,
text=word,
probability=0.95 # fake prob. Maybe we can extract it from the model?
)
new_tokens.append(token)
self.committed.extend(new_tokens)
return new_tokens, self.end
except Exception as e:
logger.exception(f"SimulStreaming processing error: {e}")
return [], self.end
def warmup(self, audio, init_prompt=""):
"""Warmup the SimulStreaming model."""
try:
self.model.insert_audio(audio)
self.model.infer(True)
self.model.refresh_segment(complete=True)
logger.info("SimulStreaming model warmed up successfully")
except Exception as e:
logger.exception(f"SimulStreaming warmup failed: {e}")
class SimulStreamingASR():
"""SimulStreaming backend with AlignAtt policy."""
sep = ""
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
logger.warning(SIMULSTREAMING_LICENSE)
self.logfile = logfile
self.transcribe_kargs = {}
self.original_language = None if lan == "auto" else lan
self.model_path = kwargs.get('model_path', './large-v3.pt')
self.frame_threshold = kwargs.get('frame_threshold', 25)
self.audio_max_len = kwargs.get('audio_max_len', 30.0)
self.audio_min_len = kwargs.get('audio_min_len', 0.0)
self.segment_length = kwargs.get('segment_length', 0.5)
self.beams = kwargs.get('beams', 1)
self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
self.task = kwargs.get('task', 'transcribe')
self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
self.never_fire = kwargs.get('never_fire', False)
self.init_prompt = kwargs.get('init_prompt', None)
self.static_init_prompt = kwargs.get('static_init_prompt', None)
self.max_context_tokens = kwargs.get('max_context_tokens', None)
if model_dir is not None:
self.model_path = model_dir
elif modelsize is not None:
model_mapping = {
'tiny': './tiny.pt',
'base': './base.pt',
'small': './small.pt',
'medium': './medium.pt',
'medium.en': './medium.en.pt',
'large-v1': './large-v1.pt',
'base.en': './base.en.pt',
'small.en': './small.en.pt',
'tiny.en': './tiny.en.pt',
'large-v2': './large-v2.pt',
'large-v3': './large-v3.pt',
'large': './large-v3.pt'
}
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
self.model = self.load_model(modelsize)
# Set up tokenizer for translation if needed
if self.task == "translate":
self.tokenizer = self.set_translate_task()
else:
self.tokenizer = None
def load_model(self, modelsize):
self.cfg = AlignAttConfig(
model_path=self.model_path,
segment_length=self.segment_length,
frame_threshold=self.frame_threshold,
language=self.original_language,
audio_max_len=self.audio_max_len,
audio_min_len=self.audio_min_len,
cif_ckpt_path=self.cif_ckpt_path,
decoder_type="beam",
beam_size=self.beams,
task=self.task,
never_fire=self.never_fire,
init_prompt=self.init_prompt,
max_context_tokens=self.max_context_tokens,
static_init_prompt=self.static_init_prompt,
)
model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
self.whisper_model = load_model(name=model_name, download_root=model_path)
def set_translate_task(self):
"""Set up translation task."""
return tokenizer.get_tokenizer(
multilingual=True,
language=self.model.cfg.language,
num_languages=self.model.model.num_languages,
task="translate"
)
def transcribe(self, audio):
"""
Only used for warmup. It's a direct whisper call, not a simulstreaming call
"""
self.whisper_model.transcribe(audio, language=self.original_language)

View File

@@ -8,7 +8,7 @@ class SimulWhisperConfig:
'''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
model_path: str
language: str = field(default="zh")
nonspeech_prob: float = 1.0
nonspeech_prob: float = 0.5
audio_min_len: float = 1.0
decoder_type: Literal["greedy","beam"] = "greedy"
beam_size: int = 5

View File

@@ -1,25 +0,0 @@
📄 SimulStreaming (https://github.com/ufal/SimulStreaming) Licence
SimulStreaming is dual-licensed:
🔹 Non-Commercial Use
You may use SimulStreaming under the **PolyForm Noncommercial License 1.0.0** if you
obtain the code through the GitHub repository. This license is **free of charge**
and comes with **no obligations** for non-commercial users.
🔸 Commercial Use
Understanding who uses SimulStreaming commercially helps us improve and
prioritize development. Therefore, we want to **require registration** of those who acquire a commercial licence.
We plan to make the commercial licenceses **affordable** to SMEs and individuals. We
are considering to provide commercial licenses either for free or for symbolic
one-time fee, and maybe also provide additional support. You can share your preference via the [questionnaire](https://forms.cloud.microsoft/e/7tCxb4gJfB).
You can also leave your contact [there](https://forms.cloud.microsoft/e/7tCxb4gJfB) to be notified when the commercial licenses become
available.
✉️ Contact
[Dominik Macháček](https://ufal.mff.cuni.cz/dominik-machacek/), machacek@ufal.mff.cuni.cz

View File

@@ -25,6 +25,9 @@ class BeamTokens(Tokens):
def __repr__(self):
return self.__str__()
def as_text(self, tokenizer):
return tokenizer.decode(self.tokens)
class Logits(Tokens):
def __init__(self, logits):
super().__init__(logits)

View File

@@ -0,0 +1,5 @@
SIMULSTREAMING_LICENSE = f"""
SimulStreaming backend is dual-licensed:
• Non-Commercial Use: PolyForm Noncommercial License 1.0.0.
• Commercial Use: Check SimulStreaming README (github.com/ufal/SimulStreaming) for more details.
"""

View File

@@ -10,12 +10,12 @@ from .whisper import load_model, DecodingOptions, tokenizer
from .config import AlignAttConfig
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
from .whisper.timing import median_filter
from .whisper.decoding import SuppressBlank, GreedyDecoder, BeamSearchDecoder, SuppressTokens
from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
from .beam import BeamPyTorchInference
from .eow_detection import fire_at_boundary, load_cif
import os
from whisperlivekit.simul_whisper.token_buffer import TokenBuffer
from .token_buffer import TokenBuffer
import numpy as np
from .generation_progress import *
@@ -24,6 +24,7 @@ DEC_PAD = 50257
logger = logging.getLogger(__name__)
import sys
import wave
# New features added to the original version of Simul-Whisper:
# - large-v3 model support
@@ -32,29 +33,30 @@ import sys
# - prompt -- static vs. non-static
# - context
class PaddedAlignAttWhisper:
def __init__(self, cfg: AlignAttConfig) -> None:
def __init__(self, cfg: AlignAttConfig, loaded_model=None) -> None:
self.log_segments = 0
model_name = os.path.basename(cfg.model_path).replace(".pt", "")
model_path = os.path.dirname(os.path.abspath(cfg.model_path))
self.model = load_model(name=model_name, download_root=model_path)
if loaded_model:
self.model = loaded_model
else:
self.model = load_model(name=model_name, download_root=model_path)
logger.info(f"Model dimensions: {self.model.dims}")
decode_options = DecodingOptions(
self.decode_options = DecodingOptions(
language = cfg.language,
without_timestamps = True,
task=cfg.task
)
self.tokenizer = tokenizer.get_tokenizer(
multilingual=not model_name.endswith(".en"),
language=cfg.language,
num_languages=self.model.num_languages,
task=decode_options.task
)
self.tokenizer_is_multilingual = not model_name.endswith(".en")
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
self.detected_language = cfg.language if cfg.language != "auto" else None
self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks)
self.cfg = cfg
# model to detect end-of-word boundary at the end of the segment
self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
n_audio_state=self.model.dims.n_audio_state,
@@ -95,14 +97,6 @@ class PaddedAlignAttWhisper:
self.num_align_heads += 1
# init tokens (mandatory prompt)
self.initial_tokens = torch.tensor(
self.tokenizer.sot_sequence_including_notimestamps,
dtype=torch.long,
device=self.model.device).unsqueeze(0)
self.initial_token_length = self.initial_tokens.shape[1]
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
# tokens to be suppressed from decoding, to prevent hallucinations
suppress_tokens = [
self.tokenizer.transcribe,
@@ -121,6 +115,17 @@ class PaddedAlignAttWhisper:
self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None)
# blank tokens are suppresed for new segments near the line 334
# it's going to be regenerated after lang id
self.segments = []
self.init_tokens()
self.last_attend_frame = -self.cfg.rewind_threshold
if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
else:
self.max_context_tokens = self.cfg.max_context_tokens
self.init_context()
# decoder type: greedy or beam
if cfg.decoder_type == "greedy":
@@ -135,16 +140,13 @@ class PaddedAlignAttWhisper:
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
# init state
self.segments = []
self.tokens = [self.initial_tokens]
self.last_attend_frame = -self.cfg.rewind_threshold
if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
else:
self.max_context_tokens = self.cfg.max_context_tokens
self.init_context()
def create_tokenizer(self, language=None):
self.tokenizer = tokenizer.get_tokenizer(
multilingual=self.tokenizer_is_multilingual,
language=language,
num_languages=self.model.num_languages,
task=self.decode_options.task
)
def init_context(self):
kw = {'tokenizer': self.tokenizer,
@@ -156,6 +158,19 @@ class PaddedAlignAttWhisper:
if self.cfg.init_prompt is not None:
self.context.text += self.cfg.init_prompt
def init_tokens(self):
logger.debug(f"init tokens, {len(self.segments)}")
# init tokens (mandatory prompt)
self.initial_tokens = torch.tensor(
self.tokenizer.sot_sequence_including_notimestamps,
dtype=torch.long,
device=self.model.device).unsqueeze(0)
self.initial_token_length = self.initial_tokens.shape[1]
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
# self.segments = []
logger.debug(f"init tokens after, {len(self.segments)}")
self.tokens = [self.initial_tokens]
def trim_context(self):
logger.info("Trimming context")
c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids)
@@ -191,15 +206,19 @@ class PaddedAlignAttWhisper:
def refresh_segment(self, complete=False):
logger.debug("Refreshing segment")
self.tokens = [self.initial_tokens]
logger.debug("Refreshing segment:")
self.init_tokens()
self.last_attend_frame = -self.cfg.rewind_threshold
self.detected_language = None
self.init_context()
logger.debug(f"Context: {self.context}")
if not complete and len(self.segments) > 2:
logger.debug("keeping last two segments because they are and it is not complete.")
self.segments = self.segments[-2:]
else:
logger.debug("removing all segments.")
self.segments = []
self.log_segments += 1
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
@@ -208,8 +227,6 @@ class PaddedAlignAttWhisper:
return fire_at_boundary(chunked_encoder_feature, self.CIFLinear)
def _current_tokens(self):
toks = self.tokens
@@ -256,16 +273,59 @@ class PaddedAlignAttWhisper:
removed_len = 0
# len of audio is bigger than buffer_len. Going to remove the first segment
segments_len = self.segments_len()
while segments_len > self.cfg.audio_max_len:
while len(self.segments) > 1 and segments_len > self.cfg.audio_max_len:
removed_len = self.segments[0].shape[0] / 16000
segments_len -= removed_len
self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
self.segments = self.segments[1:]
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}")
self.context.append_token_ids(self.tokens[1][0,:])
self.tokens = [self.initial_tokens] + self.tokens[2:]
if len(self.tokens) > 1:
self.context.append_token_ids(self.tokens[1][0,:])
self.tokens = [self.initial_tokens] + self.tokens[2:]
return removed_len
def _clean_cache(self):
'''clean the cache that stores the attention matrices and kv_cache.
It must be called every time after generation with the model.'''
# cleaning cache
self.dec_attns = []
self.kv_cache = {}
if self.decoder_type == "beam":
self.inference.kv_cache = self.kv_cache
self.token_decoder.reset()
@torch.no_grad()
def lang_id(self, encoder_features):
"""Language detection from encoder features.
This code is trimmed and copy-pasted from whisper.decoding.detect_language .
"""
# forward pass using a single token, startoftranscript
n_audio = encoder_features.shape[0]
x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1]
logits = self.model.logits(x, encoder_features)[:, 0]
# collect detected languages; suppress all non-language tokens
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
mask[list(self.tokenizer.all_language_tokens)] = False
logits[:, mask] = -np.inf
language_tokens = logits.argmax(dim=-1)
language_token_probs = logits.softmax(dim=-1).cpu()
language_probs = [
{
c: language_token_probs[i, j].item()
for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
}
for i in range(n_audio)
]
single = encoder_features.ndim == 2
if single:
language_tokens = language_tokens[0]
language_probs = language_probs[0]
self._clean_cache()
return language_tokens, language_probs
### transcription / translation
@@ -273,9 +333,12 @@ class PaddedAlignAttWhisper:
def infer(self, is_last=False):
new_segment = True
if len(self.segments) == 0:
return []
logger.debug("No segments, nothing to do")
return [], {}
if not self._apply_minseglen():
return []
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
input_segments = torch.cat(self.segments, dim=0)
return [], {}
# input_segments is concatenation of audio, it's one array
if len(self.segments) > 1:
@@ -283,8 +346,7 @@ class PaddedAlignAttWhisper:
else:
input_segments = self.segments[0]
self.trim_context()
current_tokens = self._current_tokens()
# mel + padding to 30s
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
@@ -295,18 +357,38 @@ class PaddedAlignAttWhisper:
# the len of actual audio
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
# encode
encoder_feature = self.model.encoder(mel)
sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
completed = False
# logger.debug(f"Encoder feature shape: {encoder_feature.shape}")
# if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
# logger.debug("mel ")
if self.cfg.language == "auto" and self.detected_language is None:
language_tokens, language_probs = self.lang_id(encoder_feature)
logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
logger.info(f"Detected language: {top_lan} with p={p:.4f}")
#self.tokenizer.language = top_lan
#self.tokenizer.__post_init__()
self.create_tokenizer(top_lan)
self.detected_language = top_lan
self.init_tokens()
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
self.trim_context()
current_tokens = self._current_tokens()
#
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
####################### Decoding loop
logger.info("Decoding loop starts\n")
sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
completed = False
attn_of_alignment_heads = None
miost_attended_frame = None
most_attended_frame = None
token_len_before_decoding = current_tokens.shape[1]
@@ -515,11 +597,6 @@ class PaddedAlignAttWhisper:
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
# cleaning cache
self.dec_attns = []
self.kv_cache = {}
if self.decoder_type == "beam":
self.inference.kv_cache = self.kv_cache
self.token_decoder.reset()
self._clean_cache()
return new_hypothesis, generation
return new_hypothesis, generation

View File

@@ -32,7 +32,9 @@ def detect_language(
list of dictionaries containing the probability distribution over all languages.
"""
if tokenizer is None:
tokenizer = get_tokenizer(model.is_multilingual)
tokenizer = get_tokenizer(
model.is_multilingual, num_languages=model.num_languages
)
if (
tokenizer.language is None
or tokenizer.language_token not in tokenizer.sot_sequence
@@ -111,9 +113,6 @@ class DecodingOptions:
# implementation details
fp16: bool = True # use fp16 for most of the calculation
# streaming
add_sot: Optional[bool] = True
@dataclass(frozen=True)
class DecodingResult:
@@ -513,19 +512,17 @@ class DecodingTask:
logit_filters: List[LogitFilter]
def __init__(self, model: "Whisper", options: DecodingOptions):
self.options: DecodingOptions = self._verify_options(options)
if self.options.fp16:
self.model = model.half()
else:
self.model = model
self.model = model
language = options.language or "en"
tokenizer = get_tokenizer(
model.is_multilingual, language=language, task=options.task
model.is_multilingual,
num_languages=model.num_languages,
language=language,
task=options.task,
)
self.tokenizer: Tokenizer = tokenizer
# print(self.options)
self.options: DecodingOptions = self._verify_options(options)
self.n_group: int = options.beam_size or options.best_of or 1
self.n_ctx: int = model.dims.n_text_ctx
@@ -589,7 +586,7 @@ class DecodingTask:
def _get_initial_tokens(self) -> Tuple[int]:
tokens = list(self.sot_sequence)
# print("prefix", prefix)
if prefix := self.options.prefix:
prefix_tokens = (
self.tokenizer.encode(" " + prefix.strip())
@@ -607,15 +604,12 @@ class DecodingTask:
if isinstance(prompt, str)
else prompt
)
# if self.options.add_sot:
tokens = (
[self.tokenizer.sot_prev]
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
+ tokens
)
#else:
# tokens = ([self.tokenizer.sot_prev] + tokens + prompt_tokens[-(self.n_ctx // 2 - 1) :])
# print("return", tokens)
return tuple(tokens)
def _get_suppress_tokens(self) -> Tuple[int]:
@@ -663,7 +657,7 @@ class DecodingTask:
if audio_features.dtype != (
torch.float16 if self.options.fp16 else torch.float32
):
raise TypeError(
return TypeError(
f"audio_features has an incorrect dtype: {audio_features.dtype}"
)
@@ -689,10 +683,9 @@ class DecodingTask:
no_speech_probs = [np.nan] * n_batch
try:
for i in range(self.sample_len): # 最多循环448次
# print("in decode main loop", i , tokens[0].tolist())
for i in range(self.sample_len):
logits = self.inference.logits(tokens, audio_features)
# print(logits)
if (
i == 0 and self.tokenizer.no_speech is not None
): # save no_speech_probs
@@ -724,7 +717,7 @@ class DecodingTask:
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
# print("initial_tokens", self.initial_tokens)
# detect language if requested, overwriting the language token
languages, language_probs = self._detect_language(audio_features, tokens)
if self.options.task == "lang_id":

View File

@@ -13,7 +13,6 @@ from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function
try:
from torch.nn.functional import scaled_dot_product_attention
@@ -37,26 +36,27 @@ class ModelDimensions:
n_text_layer: int
# class LayerNorm(nn.LayerNorm):
# def forward(self, x: Tensor) -> Tensor:
# return super().forward(x.float()).type(x.dtype)
# class Linear(nn.Linear):
# def forward(self, x: Tensor) -> Tensor:
# return F.linear(
# x,
# self.weight.to(x.dtype),
# None if self.bias is None else self.bias.to(x.dtype),
# )
class LayerNorm(nn.LayerNorm):
def forward(self, x: Tensor) -> Tensor:
return super().forward(x.float()).type(x.dtype)
# class Conv1d(nn.Conv1d):
# def _conv_forward(
# self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
# ) -> Tensor:
# return super()._conv_forward(
# x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
# )
class Linear(nn.Linear):
def forward(self, x: Tensor) -> Tensor:
return F.linear(
x,
self.weight.to(x.dtype),
None if self.bias is None else self.bias.to(x.dtype),
)
class Conv1d(nn.Conv1d):
def _conv_forward(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
) -> Tensor:
return super()._conv_forward(
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
)
def sinusoids(length, channels, max_timescale=10000):
@@ -67,21 +67,30 @@ def sinusoids(length, channels, max_timescale=10000):
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
import sys ## this is mine, for debugging
@contextmanager
def disable_sdpa():
prev_state = MultiHeadAttention.use_sdpa
try:
MultiHeadAttention.use_sdpa = False
yield
finally:
MultiHeadAttention.use_sdpa = prev_state
class MultiHeadAttention(nn.Module):
use_sdpa = False # Disable SDPA to ensure qk is always computed for hooks
use_sdpa = False # disabling: https://github.com/linto-ai/whisper-timestamped/issues/212
def __init__(self, n_state: int, n_head: int, cache_id: str):
def __init__(self, n_state: int, n_head: int, cache_id: str = ""):
super().__init__()
self.n_head = n_head
self.query = nn.Linear(n_state, n_state)
self.key = nn.Linear(n_state, n_state, bias=False)
self.key.cache_id = f"{cache_id}_key"
self.value = nn.Linear(n_state, n_state)
self.value.cache_id = f"{cache_id}_value"
self.out = nn.Linear(n_state, n_state)
self.query = Linear(n_state, n_state)
self.key = Linear(n_state, n_state, bias=False)
self.value = Linear(n_state, n_state)
self.out = Linear(n_state, n_state)
self.cache_id = cache_id
self.key.cache_id = f"{cache_id}_key"
self.value.cache_id = f"{cache_id}_value"
def forward(
self,
@@ -90,45 +99,21 @@ class MultiHeadAttention(nn.Module):
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
#print("MultiHeadAttention forward",file=sys.stderr)
q = self.query(x)
# print(q.shape, x is None, mask is None, list(kv_cache.keys()) if kv_cache is not None else None, file=sys.stderr)
# print(mask, kv_cache, xa, file=sys.stderr)
if kv_cache is None or xa is None or self.key.cache_id not in kv_cache:
if kv_cache is None or xa is None or self.key not in kv_cache:
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
# otherwise, perform key/value projections for self- or cross-attention as usual.
k = self.key(x if xa is None else xa)
v = self.value(x if xa is None else xa)
# print(self.key.cache_id, "cache miss") # , kv_cache is None, xa is None, self.key.cache_id not in kv_cache if kv_cache is not None else None, k.shape, x.shape)
# if kv_cache is not None:
# print(kv_cache.keys())
else:
# print(self.key.cache_id, "cache hit") #, kv_cache is None, xa is None, self.key.cache_id not in kv_cache)
# if kv_cache is not None:
# print(kv_cache.keys())
k = kv_cache[self.key.cache_id]
v = kv_cache[self.value.cache_id]
# print(self.key.cache_id, "qkv attention", q.shape, k.shape, v.shape)
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
k = kv_cache[self.key]
v = kv_cache[self.value]
wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), qk
# def qkv_attention(
# self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
# ):
# n_batch, n_ctx, n_state = q.shape
# scale = (n_state // self.n_head) ** -0.25
# q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
# k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
# v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
# qk = q @ k
# if mask is not None:
# qk = qk + mask[:n_ctx, :n_ctx]
# # qk = qk.float()
# w = F.softmax(qk, dim=-1) # .to(q.dtype)
# return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
def qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -158,21 +143,22 @@ class MultiHeadAttention(nn.Module):
class ResidualAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int, cache_id: str="", cross_attention: bool = False):
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, cache_id: str = ""):
super().__init__()
self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn")
self.attn_ln = nn.LayerNorm(n_state)
self.attn_ln = LayerNorm(n_state)
self.cross_attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
self.cross_attn = (
MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
)
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
self.mlp = nn.Sequential(
nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
)
self.mlp_ln = nn.LayerNorm(n_state)
self.mlp_ln = LayerNorm(n_state)
def forward(
self,
@@ -181,8 +167,6 @@ class ResidualAttentionBlock(nn.Module):
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
# print("ResidualAttentionBlock forward",file=sys.stderr)
# print(x.shape, file=sys.stderr)
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
@@ -195,44 +179,32 @@ class AudioEncoder(nn.Module):
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head, cache_id=f"enc_layer{i}") for i in range(n_layer)]
)
self.ln_post = nn.LayerNorm(n_state)
self.ln_post = LayerNorm(n_state)
def forward(self, x: Tensor, return_layer_results: bool=False):
def forward(self, x: Tensor):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1) # BDT -> BTD
x = x.permute(0, 2, 1)
# 两层卷积2倍降采样
# 最终剩下1500帧
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
x = (x + self.positional_embedding).to(x.dtype)
x = (x + self.positional_embedding[:x.shape[1], :]) #.to(x.dtype)
layer_results = []
i = 0
for block in self.blocks:
# print(f"encoder layer {i}")
x = block(x)
layer_results.append(x)
i += 1
x = self.ln_post(x)
if return_layer_results:
return x, layer_results
else:
return x
return x
class TextDecoder(nn.Module):
@@ -250,7 +222,7 @@ class TextDecoder(nn.Module):
for i in range(n_layer)
]
)
self.ln = nn.LayerNorm(n_state)
self.ln = LayerNorm(n_state)
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False)
@@ -262,22 +234,20 @@ class TextDecoder(nn.Module):
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
the encoded audio features to be attended on
"""
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = (
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
)
# x = x.to(xa.dtype)
x = x.to(xa.dtype)
i = 0
for block in self.blocks:
# print(f"decoder layer {i}")
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
i += 1
x = self.ln(x)
logits = x @ torch.transpose(self.token_embedding.weight, 0, 1)
logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()
return logits
@@ -300,7 +270,8 @@ class Whisper(nn.Module):
self.dims.n_text_head,
self.dims.n_text_layer,
)
# use the last half layers for alignment by default; see `set_alignment_heads()` below
# use the last half among the decoder layers for time alignment by default;
# to use a specific set of heads, see `set_alignment_heads()` below.
all_heads = torch.zeros(
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
)
@@ -320,15 +291,11 @@ class Whisper(nn.Module):
return self.encoder(mel)
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
# tokens = tokens.to(self.decoder.ln.weight.dtype)
# audio_features = audio_features.to(self.decoder.ln.weight.dtype)
return self.decoder(tokens, audio_features)
def forward(
self, mel: torch.Tensor, tokens: torch.Tensor
) -> Dict[str, torch.Tensor]:
# mel = mel.to(self.decoder.ln.weight.dtype)
# tokens = tokens.to(self.decoder.ln.weight.dtype)
return self.decoder(tokens, self.encoder(mel))
@property
@@ -343,7 +310,6 @@ class Whisper(nn.Module):
def num_languages(self):
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
# 为decoder加入缓存机制每次推理时保存上次的k和v下次推理无需重新计算
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
"""
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value

View File

@@ -30,15 +30,19 @@ def remove_symbols_and_diacritics(s: str, keep=""):
and drop any diacritics (category 'Mn' and some manual mappings)
"""
return "".join(
c
if c in keep
else ADDITIONAL_DIACRITICS[c]
if c in ADDITIONAL_DIACRITICS
else ""
if unicodedata.category(c) == "Mn"
else " "
if unicodedata.category(c)[0] in "MSP"
else c
(
c
if c in keep
else (
ADDITIONAL_DIACRITICS[c]
if c in ADDITIONAL_DIACRITICS
else (
""
if unicodedata.category(c) == "Mn"
else " " if unicodedata.category(c)[0] in "MSP" else c
)
)
)
for c in unicodedata.normalize("NFKD", s)
)

File diff suppressed because it is too large Load Diff

View File

@@ -56,9 +56,8 @@ def median_filter(x: torch.Tensor, filter_width: int):
@numba.jit(nopython=True)
def backtrace(trace: np.ndarray):
i = trace.shape[0] - 1 # trace: (N+1, M+1), i=N
j = trace.shape[1] - 1 # j=M
# 边界点其实无意义?
i = trace.shape[0] - 1
j = trace.shape[1] - 1
trace[0, :] = 2
trace[:, 0] = 1
@@ -83,8 +82,8 @@ def backtrace(trace: np.ndarray):
@numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray):
N, M = x.shape
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf # cost: x[0, 0]到x[i-1, j-1]的最小代价
trace = -np.ones((N + 1, M + 1), dtype=np.float32) # trace:
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
cost[0, 0] = 0
for j in range(1, M + 1):
@@ -118,7 +117,7 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
x_skew = x_skew.T.contiguous()
cost = torch.ones(N + M + 2, M + 2) * np.inf
cost[0, 0] = 0
cost = cost.cuda()
cost = cost.to(x.device)
trace = torch.zeros_like(cost, dtype=torch.int32)
dtw_kernel[(1,)](
@@ -192,21 +191,19 @@ def find_alignment(
for i, block in enumerate(model.decoder.blocks)
]
# 进行前传获得token概率
with torch.no_grad():
from .model import disable_sdpa
with torch.no_grad(), disable_sdpa():
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
token_probs = sampled_logits.softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
text_token_probs = text_token_probs.tolist()
# 移除钩子
for hook in hooks:
hook.remove()
# heads * tokens * frames
# print(model.alignment_heads)
# exit(0)
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
weights = weights[:, :, : num_frames // 2]
weights = (weights * qk_scale).softmax(dim=-1)
@@ -215,18 +212,9 @@ def find_alignment(
weights = median_filter(weights, medfilt_width)
matrix = weights.mean(axis=0)
print("attention", matrix.shape, matrix[:5, :5])
matrix = matrix[len(tokenizer.sot_sequence) : -1]
print("attention", matrix.shape, matrix[:5, :5])
text_indices, time_indices = dtw(-matrix)
print("num_frames", num_frames)
print("attention", matrix.shape, matrix[:5, :5])
print("text_indices", text_indices)
print("time", time_indices)
print("text_tokens", text_tokens, tokenizer.decode(text_tokens), len(text_tokens))
print("eot", tokenizer.eot)
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
if len(word_tokens) <= 1:
# return on eot only
@@ -238,9 +226,7 @@ def find_alignment(
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
# print("jumps", jumps, jumps.shape)
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
# print("jump_times", jump_times)
start_times = jump_times[word_boundaries[:-1]]
end_times = jump_times[word_boundaries[1:]]
word_probabilities = [
@@ -315,6 +301,7 @@ def add_word_timestamps(
word_durations = np.array([t.end - t.start for t in alignment])
word_durations = word_durations[word_durations.nonzero()]
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
median_duration = min(0.7, float(median_duration))
max_duration = median_duration * 2
# hack: truncate long words at sentence boundaries.

View File

@@ -1,501 +0,0 @@
import argparse
import os
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
import numpy as np
import torch
import tqdm
from whisper.audio import (
FRAMES_PER_SECOND,
HOP_LENGTH,
N_FRAMES,
N_SAMPLES,
SAMPLE_RATE,
log_mel_spectrogram,
pad_or_trim,
)
from whisper.decoding import DecodingOptions, DecodingResult
from whisper.timing import add_word_timestamps
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from whisper.utils import (
exact_div,
format_timestamp,
get_writer,
make_safe,
optional_float,
optional_int,
str2bool,
)
if TYPE_CHECKING:
from whisper.model import Whisper
def transcribe(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor],
*,
verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
**decode_options,
):
"""
Transcribe an audio file using Whisper
Parameters
----------
model: Whisper
The Whisper model instance
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
verbose: bool
Whether to display the text being decoded to the console. If True, displays all the details,
If False, displays minimal details. If None, does not display anything
temperature: Union[float, Tuple[float, ...]]
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
compression_ratio_threshold: float
If the gzip compression ratio is above this value, treat as failed
logprob_threshold: float
If the average log probability over sampled tokens is below this value, treat as failed
no_speech_threshold: float
If the no_speech probability is higher than this value AND the average log probability
over sampled tokens is below `logprob_threshold`, consider the segment as silent
condition_on_previous_text: bool
if True, the previous output of the model is provided as a prompt for the next window;
disabling may make the text inconsistent across windows, but the model becomes less prone to
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
word_timestamps: bool
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
and include the timestamps for each word in each segment.
prepend_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the next word
append_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the previous word
initial_prompt: Optional[str]
Optional text to provide as a prompt for the first window. This can be used to provide, or
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
to make it more likely to predict those word correctly.
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
# print("HACKED")
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
if model.device == torch.device("cpu"):
if torch.cuda.is_available():
warnings.warn("Performing inference on CPU when CUDA is available")
if dtype == torch.float16:
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
dtype = torch.float32
if dtype == torch.float32:
decode_options["fp16"] = False
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, padding=0) # log_mel_spectrogram(audio, padding=N_SAMPLES) # 添加16000*30 = 480000个点
# mel = pad_or_trim(mel, 3000)
content_frames = mel.shape[-1] # - N_FRAMES # 对应3000帧真正有内容的是去掉尾部3000的那些数据
# 判断语种
if decode_options.get("language", None) is None:
# 如果是单语种模型,直接设成英文
if not model.is_multilingual:
decode_options["language"] = "en"
# 否则需要前传一次
else:
if verbose:
print(
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
)
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
# print(mel_segment.shape)
_, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
)
language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe")
# 输出编码器
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
# 词级别时间戳
if word_timestamps and task == "translate":
warnings.warn("Word-level timestamps on translations may not be reliable.")
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
temperatures = (
[temperature] if isinstance(temperature, (int, float)) else temperature
)
decode_result = None
for t in temperatures:
kwargs = {**decode_options}
if t > 0:
# disable beam_size and patience when t > 0
kwargs.pop("beam_size", None)
kwargs.pop("patience", None)
else:
# disable best_of when t == 0
kwargs.pop("best_of", None)
options = DecodingOptions(**kwargs, temperature=t)
decode_result = model.decode(segment, options)
# 几种解码可能失败的情况。这些情况下会重复解码
# 感觉是一种KnowHow的东西 或许ChatGPT里有不少这种trick
needs_fallback = False
if (
compression_ratio_threshold is not None
and decode_result.compression_ratio > compression_ratio_threshold
):
needs_fallback = True # too repetitive
if (
logprob_threshold is not None
and decode_result.avg_logprob < logprob_threshold
):
needs_fallback = True # average log probability is too low
if (
no_speech_threshold is not None
and decode_result.no_speech_prob > no_speech_threshold
):
needs_fallback = False # silence
if not needs_fallback:
break
# print("decode with temperature {} compress rate {:.3f}/{:.3f}, log_prob {:.3f}/{:.3f}, {:.3f}/{:.3f}".format(
# t,
# decode_result.compression_ratio, compression_ratio_threshold,
# -decode_result.avg_logprob, -logprob_threshold,
# decode_result.no_speech_prob, no_speech_threshold
# ))
return decode_result
seek = 0
input_stride = exact_div(
N_FRAMES, model.dims.n_audio_ctx
) # mel frames per output token: 2
# 这里output token指的应该是CNN输出的那个东西
time_precision = (
input_stride * HOP_LENGTH / SAMPLE_RATE
) # time per output token: 0.02 (seconds)
all_tokens = []
all_segments = []
prompt_reset_since = 0
if initial_prompt is not None:
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt_tokens)
else:
initial_prompt_tokens = []
def new_segment(
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
):
tokens = tokens.tolist()
text_tokens = [token for token in tokens if token < tokenizer.eot]
return {
"seek": seek,
"start": start,
"end": end,
"text": tokenizer.decode(text_tokens),
"tokens": tokens,
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_speech_prob": result.no_speech_prob,
}
# show the progress bar when verbose is False (if True, transcribed text will be printed)
with tqdm.tqdm(
total=content_frames, unit="frames", disable=verbose is not False
) as pbar:
last_speech_timestamp = 0.0
while seek < content_frames: # seek标记mel频谱当前帧的位置 直接跳过Padding上的部分
# print("seek segments", seek, content_frames)
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) # 本片段的开始时间
# mel_segment = mel[:, seek : seek + N_FRAMES] # 获得当前片段的数据
mel_segment = mel[:, seek:]
segment_size = min(N_FRAMES, content_frames - seek) # segment_size: 排除padding的真的长度。content_frames有内容的段的真正长度 如果不够N_FRAMES的话就会截断
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE # 当前片段的时长
mel_segment = mel_segment.to(model.device).to(dtype) # pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) # 补到mel_segment帧
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(mel_segment)
tokens = torch.tensor(result.tokens)
# 跳过静音部分
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > no_speech_threshold
if (
logprob_threshold is not None
and result.avg_logprob > logprob_threshold
):
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
if should_skip:
seek += segment_size # fast-forward to the next segment boundary
continue
previous_seek = seek
current_segments = []
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) # timestamp begin是<|0.00|>的tokenbos比文字token大eos的值比bos还大所以是ge
timestamp_tokens[-1] = False
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] # 如果最后是[False,True]:本段里一个句子结束了
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
# torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).
# timestamp_token就是个一维向量吧 那为啥不直接nonzero
# 如果有两个连续的时间戳 这个会是一个一维tensor 是这两个连续时间戳的结尾位置
# 多个的话指向第二个 那如果有三个怎么办?
# 否则是个0维tensor
consecutive.add_(1) # 0维tensor+1还是0维 哪儿找的这些edge cases js是吧
if len(consecutive) > 0:
# if the output contains two consecutive timestamp tokens
slices = consecutive.tolist()
if single_timestamp_ending:
slices.append(len(tokens)) # 把最后一段的结尾也加进去
# print("many sentenses", consecutive)
last_slice = 0
for current_slice in slices:
sliced_tokens = tokens[last_slice:current_slice]
# 看起来语音开始帧、语音结束帧的位置会被编码到start_timestamp中
start_timestamp_pos = (
sliced_tokens[0].item() - tokenizer.timestamp_begin
)
end_timestamp_pos = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin
)
# 获取一个新的语音段
current_segments.append(
new_segment(
start=time_offset + start_timestamp_pos * time_precision,
end=time_offset + end_timestamp_pos * time_precision,
tokens=sliced_tokens,
result=result,
)
)
last_slice = current_slice
if single_timestamp_ending:
# single timestamp at the end means no speech after the last timestamp.
seek += segment_size
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
# 如果语音尚未结束那么seek变为上一个结束的语段的位置
# 换句话说就是针对30s长的chunk的语音设计的
last_timestamp_pos = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_pos * input_stride
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
# print(timestamps)
if (
len(timestamps) > 0
and timestamps[-1].item() != tokenizer.timestamp_begin
):
# no consecutive timestamps but it has a timestamp; use the last one.
# 取最后一个假设要么有一个结束的time stamp要么有一对儿
# 如果里面只有一个开始的timestamp 似乎后面的东西都会被丢掉?
last_timestamp_pos = (
timestamps[-1].item() - tokenizer.timestamp_begin
)
duration = last_timestamp_pos * time_precision
current_segments.append(
new_segment(
start=time_offset,
end=time_offset + duration,
tokens=tokens,
result=result,
)
)
seek += segment_size
# 每个token有自己的时间戳
if word_timestamps:
add_word_timestamps(
segments=current_segments,
model=model,
tokenizer=tokenizer,
mel=mel_segment,
num_frames=segment_size,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
last_speech_timestamp=last_speech_timestamp,
)
word_end_timestamps = [
w["end"] for s in current_segments for w in s["words"]
]
if len(word_end_timestamps) > 0:
last_speech_timestamp = word_end_timestamps[-1]
if not single_timestamp_ending and len(word_end_timestamps) > 0:
seek_shift = round(
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
)
if seek_shift > 0:
seek = previous_seek + seek_shift
if verbose:
for segment in current_segments:
start, end, text = segment["start"], segment["end"], segment["text"]
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
print(make_safe(line))
# if a segment is instantaneous or does not contain text, clear it
for i, segment in enumerate(current_segments):
if segment["start"] == segment["end"] or segment["text"].strip() == "":
segment["text"] = ""
segment["tokens"] = []
segment["words"] = []
# 更新结果
all_segments.extend(
[
{"id": i, **segment}
for i, segment in enumerate(
current_segments, start=len(all_segments)
)
]
)
all_tokens.extend(
[token for segment in current_segments for token in segment["tokens"]]
)
if not condition_on_previous_text or result.temperature > 0.5:
# do not feed the prompt tokens if a high temperature was used
prompt_reset_since = len(all_tokens)
# update progress bar
pbar.update(min(content_frames, seek) - previous_seek)
# print("太长了")
# break
return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
segments=all_segments,
language=language,
)
def cli():
from . import available_models
# fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,!?::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
# fmt: on
args = parser.parse_args().__dict__
model_name: str = args.pop("model")
model_dir: str = args.pop("model_dir")
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
device: str = args.pop("device")
os.makedirs(output_dir, exist_ok=True)
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None:
warnings.warn(
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
)
args["language"] = "en"
temperature = args.pop("temperature")
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
else:
temperature = [temperature]
if (threads := args.pop("threads")) > 0:
torch.set_num_threads(threads)
from . import load_model
model = load_model(model_name, device=device, download_root=model_dir)
writer = get_writer(output_format, output_dir)
word_options = ["highlight_words", "max_line_count", "max_line_width"]
if not args["word_timestamps"]:
for option in word_options:
if args[option]:
parser.error(f"--{option} requires --word_timestamps True")
if args["max_line_count"] and not args["max_line_width"]:
warnings.warn("--max_line_count has no effect without --max_line_width")
writer_args = {arg: args.pop(arg) for arg in word_options}
for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
writer(result, audio_path, writer_args)
if __name__ == "__main__":
cli()

View File

@@ -1,7 +1,8 @@
import argparse
import os
import traceback
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -22,6 +23,7 @@ from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import (
exact_div,
format_timestamp,
get_end,
get_writer,
make_safe,
optional_float,
@@ -44,9 +46,12 @@ def transcribe(
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
carry_initial_prompt: bool = False,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
**decode_options,
):
"""
@@ -98,15 +103,27 @@ def transcribe(
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
to make it more likely to predict those word correctly.
carry_initial_prompt: bool
If carry_initial_prompt is True, `initial_prompt` is prepended to the prompt of each internal
`decode()` call. If there is not enough context space at the start of the prompt, it is
left-sliced to make space.
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
clip_timestamps: Union[str, List[float]]
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
The last end timestamp defaults to the end of the file.
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
when a possible hallucination is detected
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
# print("transcribe")
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
if model.device == torch.device("cpu"):
if torch.cuda.is_available():
@@ -119,8 +136,9 @@ def transcribe(
decode_options["fp16"] = False
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
content_frames = mel.shape[-1] - N_FRAMES
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
if decode_options.get("language", None) is None:
if not model.is_multilingual:
@@ -131,7 +149,6 @@ def transcribe(
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
)
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
# print(mel_segment.shape)
_, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
@@ -141,7 +158,25 @@ def transcribe(
language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe")
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
tokenizer = get_tokenizer(
model.is_multilingual,
num_languages=model.num_languages,
language=language,
task=task,
)
if isinstance(clip_timestamps, str):
clip_timestamps = [
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
]
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
if len(seek_points) == 0:
seek_points.append(0)
if len(seek_points) % 2 == 1:
seek_points.append(content_frames)
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
punctuation = "\"'“¿([{-\"'.。,!?::”)]}、"
if word_timestamps and task == "translate":
warnings.warn("Word-level timestamps on translations may not be reliable.")
@@ -179,6 +214,8 @@ def transcribe(
if (
no_speech_threshold is not None
and decode_result.no_speech_prob > no_speech_threshold
and logprob_threshold is not None
and decode_result.avg_logprob < logprob_threshold
):
needs_fallback = False # silence
if not needs_fallback:
@@ -186,7 +223,8 @@ def transcribe(
return decode_result
seek = 0
clip_idx = 0
seek = seek_clips[clip_idx][0]
input_stride = exact_div(
N_FRAMES, model.dims.n_audio_ctx
) # mel frames per output token: 2
@@ -197,9 +235,11 @@ def transcribe(
all_segments = []
prompt_reset_since = 0
remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
if initial_prompt is not None:
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt_tokens)
remaining_prompt_length -= len(initial_prompt_tokens)
else:
initial_prompt_tokens = []
@@ -225,16 +265,33 @@ def transcribe(
total=content_frames, unit="frames", disable=verbose is not False
) as pbar:
last_speech_timestamp = 0.0
while seek < content_frames:
# NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
# for seek_clip_start, seek_clip_end in seek_clips:
# while seek < seek_clip_end
while clip_idx < len(seek_clips):
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
if seek < seek_clip_start:
seek = seek_clip_start
if seek >= seek_clip_end:
clip_idx += 1
if clip_idx < len(seek_clips):
seek = seek_clips[clip_idx][0]
continue
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
mel_segment = mel[:, seek : seek + N_FRAMES]
segment_size = min(N_FRAMES, content_frames - seek)
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
mel_segment = mel[:, seek : seek + segment_size]
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
# print("melshape", mel_segment.shape)
if carry_initial_prompt:
nignored = max(len(initial_prompt_tokens), prompt_reset_since)
remaining_prompt = all_tokens[nignored:][-remaining_prompt_length:]
decode_options["prompt"] = initial_prompt_tokens + remaining_prompt
else:
decode_options["prompt"] = all_tokens[prompt_reset_since:]
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(mel_segment)
tokens = torch.tensor(result.tokens)
@@ -255,6 +312,30 @@ def transcribe(
previous_seek = seek
current_segments = []
# anomalous words are very long/short/improbable
def word_anomaly_score(word: dict) -> float:
probability = word.get("probability", 0.0)
duration = word["end"] - word["start"]
score = 0.0
if probability < 0.15:
score += 1.0
if duration < 0.133:
score += (0.133 - duration) * 15
if duration > 2.0:
score += duration - 2.0
return score
def is_segment_anomaly(segment: Optional[dict]) -> bool:
if segment is None or not segment["words"]:
return False
words = [w for w in segment["words"] if w["word"] not in punctuation]
words = words[:8]
score = sum(word_anomaly_score(w) for w in words)
return score >= 3 or score + 0.01 >= len(words)
def next_words_segment(segments: List[dict]) -> Optional[dict]:
return next((s for s in segments if s["words"]), None)
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
@@ -317,9 +398,7 @@ def transcribe(
)
seek += segment_size
# print("word_timestamps, ", word_timestamps)
if word_timestamps:
# print("=========run timestamps here=========")
add_word_timestamps(
segments=current_segments,
model=model,
@@ -330,17 +409,71 @@ def transcribe(
append_punctuations=append_punctuations,
last_speech_timestamp=last_speech_timestamp,
)
word_end_timestamps = [
w["end"] for s in current_segments for w in s["words"]
]
if len(word_end_timestamps) > 0:
last_speech_timestamp = word_end_timestamps[-1]
if not single_timestamp_ending and len(word_end_timestamps) > 0:
seek_shift = round(
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
)
if seek_shift > 0:
seek = previous_seek + seek_shift
if not single_timestamp_ending:
last_word_end = get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
seek = round(last_word_end * FRAMES_PER_SECOND)
# skip silence before possible hallucinations
if hallucination_silence_threshold is not None:
threshold = hallucination_silence_threshold
if not single_timestamp_ending:
last_word_end = get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
remaining_duration = window_end_time - last_word_end
if remaining_duration > threshold:
seek = round(last_word_end * FRAMES_PER_SECOND)
else:
seek = previous_seek + segment_size
# if first segment might be a hallucination, skip leading silence
first_segment = next_words_segment(current_segments)
if first_segment is not None and is_segment_anomaly(first_segment):
gap = first_segment["start"] - time_offset
if gap > threshold:
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
continue
# skip silence before any possible hallucination that is surrounded
# by silence or more hallucinations
hal_last_end = last_speech_timestamp
for si in range(len(current_segments)):
segment = current_segments[si]
if not segment["words"]:
continue
if is_segment_anomaly(segment):
next_segment = next_words_segment(
current_segments[si + 1 :]
)
if next_segment is not None:
hal_next_start = next_segment["words"][0]["start"]
else:
hal_next_start = time_offset + segment_duration
silence_before = (
segment["start"] - hal_last_end > threshold
or segment["start"] < threshold
or segment["start"] - time_offset < 2.0
)
silence_after = (
hal_next_start - segment["end"] > threshold
or is_segment_anomaly(next_segment)
or window_end_time - segment["end"] < 2.0
)
if silence_before and silence_after:
seek = round(
max(time_offset + 1, segment["start"])
* FRAMES_PER_SECOND
)
if content_duration - segment["end"] < threshold:
seek = content_frames
current_segments[si:] = []
break
hal_last_end = segment["end"]
last_word_end = get_end(current_segments)
if last_word_end is not None:
last_speech_timestamp = last_word_end
if verbose:
for segment in current_segments:
@@ -384,10 +517,17 @@ def transcribe(
def cli():
from . import available_models
def valid_model_name(name):
if name in available_models() or os.path.exists(name):
return name
raise ValueError(
f"model should be one of {available_models()} or path to a model checkpoint"
)
# fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
@@ -405,6 +545,8 @@ def cli():
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
parser.add_argument("--carry_initial_prompt", type=str2bool, default=False, help="if True, prepend initial_prompt to every internal decode() call. May reduce the effectiveness of condition_on_previous_text")
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
@@ -418,7 +560,10 @@ def cli():
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
# fmt: on
args = parser.parse_args().__dict__
@@ -450,17 +595,28 @@ def cli():
model = load_model(model_name, device=device, download_root=model_dir)
writer = get_writer(output_format, output_dir)
word_options = ["highlight_words", "max_line_count", "max_line_width"]
word_options = [
"highlight_words",
"max_line_count",
"max_line_width",
"max_words_per_line",
]
if not args["word_timestamps"]:
for option in word_options:
if args[option]:
parser.error(f"--{option} requires --word_timestamps True")
if args["max_line_count"] and not args["max_line_width"]:
warnings.warn("--max_line_count has no effect without --max_line_width")
if args["max_words_per_line"] and args["max_line_width"]:
warnings.warn("--max_words_per_line has no effect with --max_line_width")
writer_args = {arg: args.pop(arg) for arg in word_options}
for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
writer(result, audio_path, writer_args)
try:
result = transcribe(model, audio_path, temperature=temperature, **args)
writer(result, audio_path, **writer_args)
except Exception as e:
traceback.print_exc()
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
if __name__ == "__main__":

View File

@@ -60,7 +60,7 @@ def median_kernel(filter_width: int):
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
kernel = triton.JITFunction(kernel.fn)
kernel.src = kernel.src.replace(
new_kernel = kernel.src.replace(
" LOAD_ALL_ROWS_HERE",
"\n".join(
[
@@ -69,7 +69,8 @@ def median_kernel(filter_width: int):
]
),
)
kernel.src = kernel.src.replace(
new_kernel = new_kernel.replace(
" BUBBLESORT_HERE",
"\n\n".join(
[
@@ -90,7 +91,14 @@ def median_kernel(filter_width: int):
]
),
)
kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
new_kernel = new_kernel.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
if hasattr(kernel, "_unsafe_update_src") is True:
kernel._unsafe_update_src(new_kernel)
kernel.hash = None
else:
kernel.src = new_kernel
return kernel

View File

@@ -3,7 +3,7 @@ import os
import re
import sys
import zlib
from typing import Callable, Optional, TextIO
from typing import Callable, List, Optional, TextIO
system_encoding = sys.getdefaultencoding()
@@ -68,13 +68,29 @@ def format_timestamp(
)
def get_start(segments: List[dict]) -> Optional[float]:
return next(
(w["start"] for s in segments for w in s["words"]),
segments[0]["start"] if segments else None,
)
def get_end(segments: List[dict]) -> Optional[float]:
return next(
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
segments[-1]["end"] if segments else None,
)
class ResultWriter:
extension: str
def __init__(self, output_dir: str):
self.output_dir = output_dir
def __call__(self, result: dict, audio_path: str, options: dict):
def __call__(
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
):
audio_basename = os.path.basename(audio_path)
audio_basename = os.path.splitext(audio_basename)[0]
output_path = os.path.join(
@@ -82,16 +98,20 @@ class ResultWriter:
)
with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f, options=options)
self.write_result(result, file=f, options=options, **kwargs)
def write_result(self, result: dict, file: TextIO, options: dict):
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
raise NotImplementedError
class WriteTXT(ResultWriter):
extension: str = "txt"
def write_result(self, result: dict, file: TextIO, options: dict):
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for segment in result["segments"]:
print(segment["text"].strip(), file=file, flush=True)
@@ -100,48 +120,76 @@ class SubtitlesWriter(ResultWriter):
always_include_hours: bool
decimal_marker: str
def iterate_result(self, result: dict, options: dict):
raw_max_line_width: Optional[int] = options["max_line_width"]
max_line_count: Optional[int] = options["max_line_count"]
highlight_words: bool = options["highlight_words"]
max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
preserve_segments = max_line_count is None or raw_max_line_width is None
def iterate_result(
self,
result: dict,
options: Optional[dict] = None,
*,
max_line_width: Optional[int] = None,
max_line_count: Optional[int] = None,
highlight_words: bool = False,
max_words_per_line: Optional[int] = None,
):
options = options or {}
max_line_width = max_line_width or options.get("max_line_width")
max_line_count = max_line_count or options.get("max_line_count")
highlight_words = highlight_words or options.get("highlight_words", False)
max_words_per_line = max_words_per_line or options.get("max_words_per_line")
preserve_segments = max_line_count is None or max_line_width is None
max_line_width = max_line_width or 1000
max_words_per_line = max_words_per_line or 1000
def iterate_subtitles():
line_len = 0
line_count = 1
# the next subtitle to yield (a list of word timings with whitespace)
subtitle: list[dict] = []
last = result["segments"][0]["words"][0]["start"]
subtitle: List[dict] = []
last: float = get_start(result["segments"]) or 0.0
for segment in result["segments"]:
for i, original_timing in enumerate(segment["words"]):
timing = original_timing.copy()
long_pause = not preserve_segments and timing["start"] - last > 3.0
has_room = line_len + len(timing["word"]) <= max_line_width
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
if line_len > 0 and has_room and not long_pause and not seg_break:
# line continuation
line_len += len(timing["word"])
else:
# new line
timing["word"] = timing["word"].strip()
chunk_index = 0
words_count = max_words_per_line
while chunk_index < len(segment["words"]):
remaining_words = len(segment["words"]) - chunk_index
if max_words_per_line > len(segment["words"]) - chunk_index:
words_count = remaining_words
for i, original_timing in enumerate(
segment["words"][chunk_index : chunk_index + words_count]
):
timing = original_timing.copy()
long_pause = (
not preserve_segments and timing["start"] - last > 3.0
)
has_room = line_len + len(timing["word"]) <= max_line_width
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
if (
len(subtitle) > 0
and max_line_count is not None
and (long_pause or line_count >= max_line_count)
or seg_break
line_len > 0
and has_room
and not long_pause
and not seg_break
):
# subtitle break
yield subtitle
subtitle = []
line_count = 1
elif line_len > 0:
# line break
line_count += 1
timing["word"] = "\n" + timing["word"]
line_len = len(timing["word"].strip())
subtitle.append(timing)
last = timing["start"]
# line continuation
line_len += len(timing["word"])
else:
# new line
timing["word"] = timing["word"].strip()
if (
len(subtitle) > 0
and max_line_count is not None
and (long_pause or line_count >= max_line_count)
or seg_break
):
# subtitle break
yield subtitle
subtitle = []
line_count = 1
elif line_len > 0:
# line break
line_count += 1
timing["word"] = "\n" + timing["word"]
line_len = len(timing["word"].strip())
subtitle.append(timing)
last = timing["start"]
chunk_index += max_words_per_line
if len(subtitle) > 0:
yield subtitle
@@ -161,9 +209,11 @@ class SubtitlesWriter(ResultWriter):
yield start, end, "".join(
[
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
if j == i
else word
(
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
if j == i
else word
)
for j, word in enumerate(all_words)
]
)
@@ -190,9 +240,11 @@ class WriteVTT(SubtitlesWriter):
always_include_hours: bool = False
decimal_marker: str = "."
def write_result(self, result: dict, file: TextIO, options: dict):
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
print("WEBVTT\n", file=file)
for start, end, text in self.iterate_result(result, options):
for start, end, text in self.iterate_result(result, options, **kwargs):
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
@@ -201,9 +253,11 @@ class WriteSRT(SubtitlesWriter):
always_include_hours: bool = True
decimal_marker: str = ","
def write_result(self, result: dict, file: TextIO, options: dict):
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for i, (start, end, text) in enumerate(
self.iterate_result(result, options), start=1
self.iterate_result(result, options, **kwargs), start=1
):
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
@@ -220,7 +274,9 @@ class WriteTSV(ResultWriter):
extension: str = "tsv"
def write_result(self, result: dict, file: TextIO, options: dict):
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment["start"]), file=file, end="\t")
@@ -231,7 +287,9 @@ class WriteTSV(ResultWriter):
class WriteJSON(ResultWriter):
extension: str = "json"
def write_result(self, result: dict, file: TextIO, options: dict):
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
json.dump(result, file)
@@ -249,9 +307,11 @@ def get_writer(
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]
def write_all(result: dict, file: TextIO, options: dict):
def write_all(
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for writer in all_writers:
writer(result, file, options)
writer(result, file, options, **kwargs)
return write_all

View File

@@ -1 +1 @@
__version__ = "20230918"
__version__ = "20250625"

View File

@@ -1,73 +0,0 @@
import torch
import sys
class TokenBuffer:
def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):
self.text = text
self.prefix_token_ids = prefix_token_ids
self.tokenizer = tokenizer
self.device = device
def as_token_ids(self, tokenizer=None):
if tokenizer is None:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer is not set.")
return self.prefix_token_ids + tokenizer.encode(self.text)
def as_tensor(self, device=None):
if device is None:
device = self.device
if device is None:
raise ValueError("Device is not set.")
tok_ids = self.as_token_ids()
return torch.tensor(tok_ids,
dtype=torch.long, device=device).unsqueeze(0)
def as_tensor_beam(self, beam, device=None):
t = self.as_tensor(device=device)
return t.repeat_interleave(beam, dim=0)
def as_text(self):
return self.text
@staticmethod
def empty(*a, **kw):
return TokenBuffer(*a,**kw)
@staticmethod
def from_text(text, *a, **kw):
return TokenBuffer(*a, text=text, **kw)
def is_empty(self):
return self.text is None or self.text == ""
def trim_words(self, num=1, after=0):
'''
num: how many words to trim from the beginning
after: how many characters to skip (length of the static prompt)
'''
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
ids = tokenizer.encode(self.text[after:])
words, wids = self.tokenizer.split_to_word_tokens(ids)
print(words, file=sys.stderr)
print(wids, file=sys.stderr)
if not words:
return 0
self.text = self.text[:after] + "".join(words[num:])
return sum(len(wi) for wi in wids[:num])
def append_token_ids(self, token_ids):
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
self.text += self.tokenizer.decode(token_ids)
def as_split_word_tokens(self):
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
ids = tokenizer.encode(self.text)
return tokenizer.split_to_word_tokens(ids)

62
whisperlivekit/warmup.py Normal file
View File

@@ -0,0 +1,62 @@
import logging
logger = logging.getLogger(__name__)
def load_file(warmup_file=None, timeout=5):
import os
import tempfile
import librosa
if warmup_file is None:
# Download JFK sample if not already present
jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
temp_dir = tempfile.gettempdir()
warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")
if not os.path.exists(warmup_file):
logger.debug(f"Downloading warmup file from {jfk_url}")
print(f"Downloading warmup file from {jfk_url}")
import time
import urllib.request
import urllib.error
import socket
original_timeout = socket.getdefaulttimeout()
socket.setdefaulttimeout(timeout)
start_time = time.time()
try:
urllib.request.urlretrieve(jfk_url, warmup_file)
logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
except (urllib.error.URLError, socket.timeout) as e:
logger.warning(f"Download failed: {e}. Proceeding without warmup.")
return False
finally:
socket.setdefaulttimeout(original_timeout)
elif not warmup_file:
return False
if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
logger.warning(f"Warmup file {warmup_file} invalid or missing.")
return False
try:
audio, sr = librosa.load(warmup_file, sr=16000)
except Exception as e:
logger.warning(f"Failed to load audio file: {e}")
return False
return audio
def warmup_asr(asr, warmup_file=None, timeout=5):
"""
Warmup the ASR model by transcribing a short audio file.
"""
audio = load_file(warmup_file=None, timeout=5)
asr.transcribe(audio)
logger.info("ASR model is warmed up")
def warmup_online(online, warmup_file=None, timeout=5):
audio = load_file(warmup_file=None, timeout=5)
online.warmup(audio)
logger.warning("ASR is warmed up")

View File

@@ -4,12 +4,87 @@
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Audio Transcription</title>
<title>WhisperLiveKit</title>
<style>
:root {
--bg: #ffffff;
--text: #111111;
--muted: #666666;
--border: #e5e5e5;
--chip-bg: rgba(0, 0, 0, 0.04);
--chip-text: #000000;
--spinner-border: #8d8d8d5c;
--spinner-top: #b0b0b0;
--silence-bg: #f3f3f3;
--loading-bg: rgba(255, 77, 77, 0.06);
--button-bg: #ffffff;
--button-border: #e9e9e9;
--wave-stroke: #000000;
--label-dia-text: #868686;
--label-trans-text: #111111;
}
@media (prefers-color-scheme: dark) {
:root:not([data-theme="light"]) {
--bg: #0b0b0b;
--text: #e6e6e6;
--muted: #9aa0a6;
--border: #333333;
--chip-bg: rgba(255, 255, 255, 0.08);
--chip-text: #e6e6e6;
--spinner-border: #555555;
--spinner-top: #dddddd;
--silence-bg: #1a1a1a;
--loading-bg: rgba(255, 77, 77, 0.12);
--button-bg: #111111;
--button-border: #333333;
--wave-stroke: #e6e6e6;
--label-dia-text: #b3b3b3;
--label-trans-text: #ffffff;
}
}
:root[data-theme="dark"] {
--bg: #0b0b0b;
--text: #e6e6e6;
--muted: #9aa0a6;
--border: #333333;
--chip-bg: rgba(255, 255, 255, 0.08);
--chip-text: #e6e6e6;
--spinner-border: #555555;
--spinner-top: #dddddd;
--silence-bg: #1a1a1a;
--loading-bg: rgba(255, 77, 77, 0.12);
--button-bg: #111111;
--button-border: #333333;
--wave-stroke: #e6e6e6;
--label-dia-text: #b3b3b3;
--label-trans-text: #ffffff;
}
:root[data-theme="light"] {
--bg: #ffffff;
--text: #111111;
--muted: #666666;
--border: #e5e5e5;
--chip-bg: rgba(0, 0, 0, 0.04);
--chip-text: #000000;
--spinner-border: #8d8d8d5c;
--spinner-top: #b0b0b0;
--silence-bg: #f3f3f3;
--loading-bg: rgba(255, 77, 77, 0.06);
--button-bg: #ffffff;
--button-border: #e9e9e9;
--wave-stroke: #000000;
--label-dia-text: #868686;
--label-trans-text: #111111;
}
body {
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
margin: 20px;
text-align: center;
background-color: var(--bg);
color: var(--text);
}
#recordButton {
@@ -17,10 +92,10 @@
height: 50px;
border: none;
border-radius: 50%;
background-color: white;
background-color: var(--button-bg);
cursor: pointer;
transition: all 0.3s ease;
border: 1px solid rgb(233, 233, 233);
border: 1px solid var(--button-border);
display: flex;
align-items: center;
justify-content: center;
@@ -94,14 +169,14 @@
.timer {
font-size: 14px;
font-weight: 500;
color: #333;
color: var(--text);
margin-left: 10px;
}
#status {
margin-top: 20px;
font-size: 16px;
color: #333;
color: var(--text);
}
.settings-container {
@@ -120,12 +195,14 @@
}
#chunkSelector,
#websocketInput {
#websocketInput,
#themeSelector {
font-size: 16px;
padding: 5px;
border-radius: 5px;
border: 1px solid #ddd;
background-color: #ffffff;
border: 1px solid var(--border);
background-color: var(--button-bg);
color: var(--text);
max-height: 30px;
}
@@ -134,7 +211,8 @@
}
#chunkSelector:focus,
#websocketInput:focus {
#websocketInput:focus,
#themeSelector:focus {
outline: none;
border-color: #007bff;
}
@@ -156,18 +234,18 @@
}
#linesTranscript strong {
color: #333;
color: var(--text);
}
#speaker {
border: 1px solid rgb(229, 229, 229);
border: 1px solid var(--border);
border-radius: 100px;
padding: 2px 10px;
font-size: 14px;
margin-bottom: 0px;
}
.label_diarization {
background-color: #ffffff66;
background-color: var(--chip-bg);
border-radius: 8px 8px 8px 8px;
padding: 2px 10px;
margin-left: 10px;
@@ -175,11 +253,11 @@
white-space: nowrap;
font-size: 14px;
margin-bottom: 0px;
color: rgb(134, 134, 134)
color: var(--label-dia-text)
}
.label_transcription {
background-color: #ffffff66;
background-color: var(--chip-bg);
border-radius: 8px 8px 8px 8px;
padding: 2px 10px;
display: inline-block;
@@ -187,11 +265,11 @@
margin-left: 10px;
font-size: 14px;
margin-bottom: 0px;
color: #000000
color: var(--label-trans-text)
}
#timeInfo {
color: #666;
color: var(--muted);
margin-left: 10px;
}
@@ -206,7 +284,7 @@
}
.buffer_diarization {
color: rgb(134, 134, 134);
color: var(--label-dia-text);
margin-left: 4px;
}
@@ -220,10 +298,10 @@
display: inline-block;
width: 8px;
height: 8px;
border: 2px solid #8d8d8d5c;
border-top: 2px solid #6c6c6ce5;
border: 2px solid var(--spinner-border);
border-top: 2px solid var(--spinner-top);
border-radius: 50%;
animation: spin 0.6s linear infinite;
animation: spin 0.7s linear infinite;
vertical-align: middle;
margin-bottom: 2px;
margin-right: 5px;
@@ -236,16 +314,16 @@
}
.silence {
color: #666;
background-color: #f3f3f3;
color: var(--muted);
background-color: var(--silence-bg);
font-size: 13px;
border-radius: 30px;
padding: 2px 10px;
}
.loading {
color: #666;
background-color: #ff4d4d0f;
color: var(--muted);
background-color: var(--loading-bg);
border-radius: 8px 8px 8px 0px;
padding: 2px 10px;
font-size: 14px;
@@ -284,6 +362,14 @@
<label for="websocketInput">WebSocket URL:</label>
<input id="websocketInput" type="text" />
</div>
<div>
<label for="themeSelector">Theme:</label>
<select id="themeSelector">
<option value="system" selected>System</option>
<option value="light">Light</option>
<option value="dark">Dark</option>
</select>
</div>
</div>
</div>
@@ -299,6 +385,7 @@
let chunkDuration = 1000;
let websocketUrl = "ws://localhost:8000/asr";
let userClosing = false;
let wakeLock = null;
let startTime = null;
let timerInterval = null;
let audioContext = null;
@@ -309,6 +396,7 @@
let animationFrame = null;
let waitingForStop = false;
let lastReceivedData = null;
let lastSignature = null;
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
@@ -319,6 +407,57 @@
const websocketInput = document.getElementById("websocketInput");
const linesTranscriptDiv = document.getElementById("linesTranscript");
const timerElement = document.querySelector(".timer");
const themeSelector = document.getElementById("themeSelector");
function getWaveStroke() {
const styles = getComputedStyle(document.documentElement);
const v = styles.getPropertyValue("--wave-stroke").trim();
return v || "#000";
}
let waveStroke = getWaveStroke();
function updateWaveStroke() {
waveStroke = getWaveStroke();
}
function applyTheme(pref) {
if (pref === "light") {
document.documentElement.setAttribute("data-theme", "light");
} else if (pref === "dark") {
document.documentElement.setAttribute("data-theme", "dark");
} else {
document.documentElement.removeAttribute("data-theme");
}
updateWaveStroke();
}
const savedThemePref = localStorage.getItem("themePreference") || "system";
applyTheme(savedThemePref);
if (themeSelector) {
themeSelector.value = savedThemePref;
themeSelector.addEventListener("change", () => {
const val = themeSelector.value;
localStorage.setItem("themePreference", val);
applyTheme(val);
});
}
const darkMq = window.matchMedia && window.matchMedia("(prefers-color-scheme: dark)");
const handleOsThemeChange = () => {
const pref = localStorage.getItem("themePreference") || "system";
if (pref === "system") updateWaveStroke();
};
if (darkMq && darkMq.addEventListener) {
darkMq.addEventListener("change", handleOsThemeChange);
} else if (darkMq && darkMq.addListener) {
darkMq.addListener(handleOsThemeChange);
}
function fmt1(x) {
const n = Number(x);
return Number.isFinite(n) ? n.toFixed(1) : x;
}
const host = window.location.hostname || "localhost";
const port = window.location.port;
@@ -446,10 +585,35 @@
function renderLinesWithBuffer(lines, buffer_diarization, buffer_transcription, remaining_time_diarization, remaining_time_transcription, isFinalizing = false, current_status = "active_transcription") {
if (current_status === "no_audio_detected") {
linesTranscriptDiv.innerHTML = "<p style='text-align: center; color: #666; margin-top: 20px;'><em>No audio detected...</em></p>";
linesTranscriptDiv.innerHTML = "<p style='text-align: center; color: var(--muted); margin-top: 20px;'><em>No audio detected...</em></p>";
return;
}
// try to keep stable DOM despite having updates every 0.1s. only update numeric lag values if structure hasn't changed
const showLoading = (!isFinalizing) && (lines || []).some(it => it.speaker == 0);
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
const signature = JSON.stringify({
lines: (lines || []).map(it => ({ speaker: it.speaker, text: it.text, beg: it.beg, end: it.end })),
buffer_transcription: buffer_transcription || "",
buffer_diarization: buffer_diarization || "",
status: current_status,
showLoading,
showTransLag,
showDiaLag,
isFinalizing: !!isFinalizing
});
if (lastSignature === signature) {
const t = document.querySelector(".lag-transcription-value");
if (t) t.textContent = fmt1(remaining_time_transcription);
const d = document.querySelector(".lag-diarization-value");
if (d) d.textContent = fmt1(remaining_time_diarization);
const ld = document.querySelector(".loading-diarization-value");
if (ld) ld.textContent = fmt1(remaining_time_diarization);
return;
}
lastSignature = signature;
const linesHtml = lines.map((item, idx) => {
let timeInfo = "";
if (item.beg !== undefined && item.end !== undefined) {
@@ -460,7 +624,7 @@
if (item.speaker === -2) {
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
} else if (item.speaker == 0 && !isFinalizing) {
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'>${remaining_time_diarization} second(s) of audio are undergoing diarization</span></span>`;
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(remaining_time_diarization)}</span> second(s) of audio are undergoing diarization</span></span>`;
} else if (item.speaker == -1) {
speakerLabel = `<span id="speaker">Speaker 1<span id='timeInfo'>${timeInfo}</span></span>`;
} else if (item.speaker !== -1 && item.speaker !== 0) {
@@ -471,12 +635,12 @@
let currentLineText = item.text || "";
if (idx === lines.length - 1) {
if (!isFinalizing) {
if (!isFinalizing && item.speaker !== -2) {
if (remaining_time_transcription > 0) {
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'>${remaining_time_transcription}s</span></span>`;
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(remaining_time_transcription)}</span>s</span></span>`;
}
if (buffer_diarization && remaining_time_diarization > 0) {
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'>${remaining_time_diarization}s</span></span>`;
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(remaining_time_diarization)}</span>s</span></span>`;
}
}
@@ -502,6 +666,7 @@
}).join("");
linesTranscriptDiv.innerHTML = linesHtml;
window.scrollTo({ top: document.body.scrollHeight, behavior: 'smooth' });
}
function updateTimer() {
@@ -522,7 +687,7 @@
waveCtx.clearRect(0, 0, waveCanvas.width / (window.devicePixelRatio || 1), waveCanvas.height / (window.devicePixelRatio || 1));
waveCtx.lineWidth = 1;
waveCtx.strokeStyle = 'rgb(0, 0, 0)';
waveCtx.strokeStyle = waveStroke;
waveCtx.beginPath();
const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength;
@@ -549,6 +714,16 @@
async function startRecording() {
try {
// https://developer.mozilla.org/en-US/docs/Web/API/Screen_Wake_Lock_API
// create an async function to request a wake lock
try {
wakeLock = await navigator.wakeLock.request("screen");
} catch (err) {
// The Wake Lock request has failed - usually system related, such as battery.
console.log("Error acquiring wake lock.")
}
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
audioContext = new (window.AudioContext || window.webkitAudioContext)();
@@ -578,6 +753,10 @@
}
async function stopRecording() {
wakeLock.release().then(() => {
wakeLock = null;
});
userClosing = true;
waitingForStop = true;

View File

@@ -3,43 +3,10 @@ import logging
import io
import soundfile as sf
import math
try:
import torch
except ImportError:
torch = None
from typing import List
import numpy as np
from whisperlivekit.timed_objects import ASRToken
logger = logging.getLogger(__name__)
SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS = ImportError(
"""SimulStreaming dependencies are not available.
Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]"
""")
try:
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD
from whisperlivekit.simul_whisper.whisper import tokenizer
SIMULSTREAMING_AVAILABLE = True
except ImportError:
logger.warning("⚠️ SimulStreaming dependencies not available. Attempting to download them.")
try:
from whisperlivekit import download_simulstreaming_backend
download_simulstreaming_backend()
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD
from whisperlivekit.simul_whisper.whisper import tokenizer
SIMULSTREAMING_AVAILABLE = True
logger.info("SimulStreaming dependencies downloaded successfully.")
except Exception as e:
logger.error(f"Failed to download or import SimulStreaming dependencies: {e}")
SIMULSTREAMING_AVAILABLE = False
AlignAttConfig = None
PaddedAlignAttWhisper = None
DEC_PAD = None
tokenizer = None
class ASRBase:
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
# "" for faster-whisper because it emits the spaces when needed)
@@ -320,182 +287,4 @@ class OpenaiApiASR(ASRBase):
self.use_vad_opt = True
def set_translate_task(self):
self.task = "translate"
class SimulStreamingASR(ASRBase):
"""SimulStreaming backend with AlignAtt policy."""
sep = ""
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
if not SIMULSTREAMING_AVAILABLE:
raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
with open("whisperlivekit/simul_whisper/dual_license_simulstreaming.md", "r") as f:
print("*"*80 + f.read() + "*"*80)
self.logfile = logfile
self.transcribe_kargs = {}
self.original_language = None if lan == "auto" else lan
self.model_path = kwargs.get('model_path', './large-v3.pt')
self.frame_threshold = kwargs.get('frame_threshold', 25)
self.audio_max_len = kwargs.get('audio_max_len', 30.0)
self.audio_min_len = kwargs.get('audio_min_len', 0.0)
self.segment_length = kwargs.get('segment_length', 0.5)
self.beams = kwargs.get('beams', 1)
self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
self.task = kwargs.get('task', 'transcribe')
self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
self.never_fire = kwargs.get('never_fire', False)
self.init_prompt = kwargs.get('init_prompt', None)
self.static_init_prompt = kwargs.get('static_init_prompt', None)
self.max_context_tokens = kwargs.get('max_context_tokens', None)
if model_dir is not None:
self.model_path = model_dir
elif modelsize is not None: #For the moment the .en.pt models do not work!
model_mapping = {
'tiny': './tiny.pt',
'base': './base.pt',
'small': './small.pt',
'medium': './medium.pt',
'medium.en': './medium.en.pt',
'large-v1': './large-v1.pt',
'base.en': './base.en.pt',
'small.en': './small.en.pt',
'tiny.en': './tiny.en.pt',
'large-v2': './large-v2.pt',
'large-v3': './large-v3.pt',
'large': './large-v3.pt'
}
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
self.model = self.load_model(modelsize, cache_dir, model_dir)
# Set up tokenizer for translation if needed
if self.task == "translate":
self.set_translate_task()
def load_model(self, modelsize, cache_dir, model_dir):
try:
cfg = AlignAttConfig(
model_path=self.model_path,
segment_length=self.segment_length,
frame_threshold=self.frame_threshold,
language=self.original_language,
audio_max_len=self.audio_max_len,
audio_min_len=self.audio_min_len,
cif_ckpt_path=self.cif_ckpt_path,
decoder_type="beam",
beam_size=self.beams,
task=self.task,
never_fire=self.never_fire,
init_prompt=self.init_prompt,
max_context_tokens=self.max_context_tokens,
static_init_prompt=self.static_init_prompt,
)
logger.info(f"Loading SimulStreaming model with language: {self.original_language}")
model = PaddedAlignAttWhisper(cfg)
return model
except Exception as e:
logger.error(f"Failed to load SimulStreaming model: {e}")
raise
def transcribe(self, audio, init_prompt=""):
"""Transcribe audio using SimulStreaming."""
try:
if isinstance(audio, np.ndarray):
audio_tensor = torch.from_numpy(audio).float()
else:
audio_tensor = audio
prompt = init_prompt if init_prompt else (self.init_prompt or "")
result = self.model.infer(audio_tensor, init_prompt=prompt)
if torch.is_tensor(result):
result = result[result < DEC_PAD]
logger.debug(f"SimulStreaming transcription result: {result}")
return result
except Exception as e:
logger.error(f"SimulStreaming transcription failed: {e}")
raise
def ts_words(self, result) -> List[ASRToken]:
"""Convert SimulStreaming result to ASRToken list."""
tokens = []
try:
if torch.is_tensor(result):
text = self.model.tokenizer.decode(result.cpu().numpy())
else:
text = str(result)
if not text or len(text.strip()) == 0:
return tokens
# We dont have word-level timestamps here. 1rst approach, should be improved later.
words = text.strip().split()
if not words:
return tokens
duration_per_word = 0.1 # this will be modified based on actual audio duration
#with the SimulStreamingOnlineProcessor
for i, word in enumerate(words):
start_time = i * duration_per_word
end_time = (i + 1) * duration_per_word
token = ASRToken(
start=start_time,
end=end_time,
text=word,
probability=1.0
)
tokens.append(token)
except Exception as e:
logger.error(f"Error converting SimulStreaming result to tokens: {e}")
return tokens
def segments_end_ts(self, result) -> List[float]:
"""Get segment end timestamps."""
if torch.is_tensor(result):
num_tokens = len(result)
return [num_tokens * 0.1] # rough estimate
return [1.0]
def use_vad(self):
"""Enable VAD - SimulStreaming has different VAD handling."""
logger.info("VAD requested for SimulStreaming - handled internally by the model")
pass
def set_translate_task(self):
"""Set up translation task."""
try:
self.model.tokenizer = tokenizer.get_tokenizer(
multilingual=True,
language=self.model.cfg.language,
num_languages=self.model.model.num_languages,
task="translate"
)
logger.info("SimulStreaming configured for translation task")
except Exception as e:
logger.error(f"Failed to configure SimulStreaming for translation: {e}")
raise
def warmup(self, audio, init_prompt=""):
"""Warmup the SimulStreaming model."""
try:
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio).float()
self.model.insert_audio(audio)
self.model.infer(True)
self.model.refresh_segment(complete=True)
logger.info("SimulStreaming model warmed up successfully")
except Exception as e:
logger.warning(f"SimulStreaming warmup failed: {e}")
self.task = "translate"

View File

@@ -6,18 +6,6 @@ from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
logger = logging.getLogger(__name__)
# simulStreaming imports - we check if the files are here
try:
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
SIMULSTREAMING_AVAILABLE = True
except ImportError:
logger.warning("SimulStreaming dependencies not available for online processor.")
SIMULSTREAMING_AVAILABLE = False
OnlineProcessorInterface = None
torch = None
class HypothesisBuffer:
"""
Buffer to store and process ASR hypothesis tokens.
@@ -528,205 +516,3 @@ class VACOnlineASRProcessor:
"""
return self.online.concatenate_tokens(self.online.transcript_buffer.buffer)
class SimulStreamingOnlineProcessor:
SAMPLING_RATE = 16000
def __init__(
self,
asr,
tokenize_method: Optional[callable] = None,
buffer_trimming: Tuple[str, float] = ("segment", 15),
confidence_validation = False,
logfile=sys.stderr,
):
if not SIMULSTREAMING_AVAILABLE:
raise ImportError("SimulStreaming dependencies are not available.")
self.asr = asr
self.tokenize = tokenize_method
self.logfile = logfile
self.confidence_validation = confidence_validation
self.init()
# buffer does not work yet
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
def init(self, offset: Optional[float] = None):
"""Initialize or reset the processing state."""
self.audio_chunks = []
self.offset = offset if offset is not None else 0.0
self.is_last = False
self.beg = self.offset
self.end = self.offset
self.cumulative_audio_duration = 0.0
self.last_audio_stream_end_time = self.offset
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
self.buffer_content = ""
self.processed_audio_duration = 0.0
def get_audio_buffer_end_time(self) -> float:
"""Returns the absolute end time of the current audio buffer."""
return self.end
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
"""Append an audio chunk to be processed by SimulStreaming."""
if torch is None:
raise ImportError("PyTorch is required for SimulStreaming but not available")
# Convert numpy array to torch tensor
audio_tensor = torch.from_numpy(audio).float()
self.audio_chunks.append(audio_tensor)
# Update timing
chunk_duration = len(audio) / self.SAMPLING_RATE
self.cumulative_audio_duration += chunk_duration
if audio_stream_end_time is not None:
self.last_audio_stream_end_time = audio_stream_end_time
self.end = audio_stream_end_time
else:
self.end = self.offset + self.cumulative_audio_duration
def prompt(self) -> Tuple[str, str]:
"""
Returns a tuple: (prompt, context).
SimulStreaming handles prompting internally, so we return empty strings.
"""
return "", ""
def get_buffer(self):
"""
Get the unvalidated buffer content.
"""
buffer_end = self.end if hasattr(self, 'end') else None
return Transcript(
start=None,
end=buffer_end,
text=self.buffer_content,
probability=None
)
def timestamped_text(self, tokens, generation):
# From the simulstreaming repo. self.model to self.asr.model
pr = generation["progress"]
if "result" not in generation:
split_words, split_tokens = self.asr.model.tokenizer.split_to_word_tokens(tokens)
else:
split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"]
frames = [p["most_attended_frames"][0] for p in pr]
tokens = tokens.copy()
ret = []
for sw,st in zip(split_words,split_tokens):
b = None
for stt in st:
t,f = tokens.pop(0), frames.pop(0)
if t != stt:
raise ValueError(f"Token mismatch: {t} != {stt} at frame {f}.")
if b is None:
b = f
e = f
out = (b*0.02, e*0.02, sw)
ret.append(out)
logger.debug(f"TS-WORD:\t{' '.join(map(str, out))}")
return ret
def process_iter(self) -> Tuple[List[ASRToken], float]:
"""
Process accumulated audio chunks using SimulStreaming.
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
"""
if not self.audio_chunks:
return [], self.end
try:
# concatenate all audio chunks
if len(self.audio_chunks) == 1:
audio = self.audio_chunks[0]
else:
audio = torch.cat(self.audio_chunks, dim=0)
audio_duration = audio.shape[0] / self.SAMPLING_RATE if audio.shape[0] > 0 else 0
self.processed_audio_duration += audio_duration
self.audio_chunks = []
logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s")
logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s")
self.asr.model.insert_audio(audio)
tokens, generation_progress = self.asr.model.infer(is_last=self.is_last)
ts_words = self.timestamped_text(tokens, generation_progress)
text = self.asr.model.tokenizer.decode(tokens)
new_tokens = []
for ts_word in ts_words:
start, end, word = ts_word
token = ASRToken(
start=start,
end=end,
text=word,
probability=0.95 # fake prob. Maybe we can extract it from the model?
)
new_tokens.append(token)
self.committed.extend(new_tokens)
return new_tokens, self.end
except Exception as e:
logger.error(f"SimulStreaming processing error: {e}")
logger.error(f"Error details: {type(e).__name__}: {str(e)}")
return [], self.end
def finish(self) -> Tuple[List[ASRToken], float]:
logger.debug("SimulStreaming finish() called")
self.is_last = True
final_tokens, final_time = self.process_iter()
self.is_last = False
return final_tokens, final_time
def concatenate_tokens(
self,
tokens: List[ASRToken],
sep: Optional[str] = None,
offset: float = 0
) -> Transcript:
"""Concatenate tokens into a Transcript object."""
sep = sep if sep is not None else self.asr.sep
text = sep.join(token.text for token in tokens)
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
if tokens:
start = offset + tokens[0].start
end = offset + tokens[-1].end
else:
start = None
end = None
return Transcript(start, end, text, probability=probability)
def chunk_at(self, time: float):
"""
useless but kept for compatibility
"""
logger.debug(f"SimulStreaming chunk_at({time:.2f}) - handled internally")
pass
def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
"""
Create simple sentences.
"""
if not tokens:
return []
full_text = " ".join(token.text for token in tokens)
sentence = Sentence(
start=tokens[0].start,
end=tokens[-1].end,
text=full_text
)
return [sentence]

View File

@@ -5,8 +5,7 @@ import librosa
from functools import lru_cache
import time
import logging
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR, SimulStreamingASR, SIMULSTREAMING_AVAILABLE, SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor, SimulStreamingOnlineProcessor, SIMULSTREAMING_AVAILABLE as SIMULSTREAMING_ONLINE_AVAILABLE
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
logger = logging.getLogger(__name__)
@@ -68,35 +67,7 @@ def backend_factory(args):
backend = args.backend
if backend == "openai-api":
logger.debug("Using OpenAI API.")
asr = OpenaiApiASR(lan=args.lan)
elif backend == "simulstreaming":
logger.debug("Using SimulStreaming backend.")
if not SIMULSTREAMING_AVAILABLE:
raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
simulstreaming_kwargs = {}
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
'max_context_tokens', 'model_path']:
if hasattr(args, attr):
simulstreaming_kwargs[attr] = getattr(args, attr)
# Add segment_length from min_chunk_size
simulstreaming_kwargs['segment_length'] = getattr(args, 'min_chunk_size', 0.5)
simulstreaming_kwargs['task'] = args.task
size = args.model
t = time.time()
logger.info(f"Loading SimulStreaming {size} model for language {args.lan}...")
asr = SimulStreamingASR(
modelsize=size,
lan=args.lan,
cache_dir=getattr(args, 'model_cache_dir', None),
model_dir=getattr(args, 'model_dir', None),
**simulstreaming_kwargs
)
e = time.time()
logger.info(f"done. It took {round(e-t,2)} seconds.")
asr = OpenaiApiASR(lan=args.lan)
else:
if backend == "faster-whisper":
asr_cls = FasterWhisperASR
@@ -136,107 +107,4 @@ def backend_factory(args):
tokenizer = create_tokenizer(tgt_language)
else:
tokenizer = None
return asr, tokenizer
def online_factory(args, asr, tokenizer, logfile=sys.stderr):
if args.backend == "simulstreaming":
if not SIMULSTREAMING_ONLINE_AVAILABLE:
raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
logger.debug("Creating SimulStreaming online processor")
online = SimulStreamingOnlineProcessor(
asr,
tokenizer,
logfile=logfile,
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
confidence_validation=args.confidence_validation
)
elif args.vac:
online = VACOnlineASRProcessor(
args.min_chunk_size,
asr,
tokenizer,
logfile=logfile,
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
confidence_validation = args.confidence_validation
)
else:
online = OnlineASRProcessor(
asr,
tokenizer,
logfile=logfile,
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
confidence_validation = args.confidence_validation
)
return online
def asr_factory(args, logfile=sys.stderr):
"""
Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
"""
asr, tokenizer = backend_factory(args)
online = online_factory(args, asr, tokenizer, logfile=logfile)
return asr, online
def warmup_asr(asr, warmup_file=None, timeout=5):
"""
Warmup the ASR model by transcribing a short audio file.
"""
import os
import tempfile
is_simulstreaming = hasattr(asr, 'warmup') and callable(getattr(asr, 'warmup'))
if warmup_file is None:
# Download JFK sample if not already present
jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
temp_dir = tempfile.gettempdir()
warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")
if not os.path.exists(warmup_file):
logger.debug(f"Downloading warmup file from {jfk_url}")
print(f"Downloading warmup file from {jfk_url}")
import time
import urllib.request
import urllib.error
import socket
original_timeout = socket.getdefaulttimeout()
socket.setdefaulttimeout(timeout)
start_time = time.time()
try:
urllib.request.urlretrieve(jfk_url, warmup_file)
logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
except (urllib.error.URLError, socket.timeout) as e:
logger.warning(f"Download failed: {e}. Proceeding without warmup.")
return False
finally:
socket.setdefaulttimeout(original_timeout)
elif not warmup_file:
return False
if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
logger.warning(f"Warmup file {warmup_file} invalid or missing.")
return False
print(f"Warming up {'SimulStreaming' if is_simulstreaming else 'Whisper'} with {warmup_file}")
try:
import librosa
audio, sr = librosa.load(warmup_file, sr=16000)
except Exception as e:
logger.warning(f"Failed to load audio file: {e}")
return False
try:
if is_simulstreaming:
asr.warmup(audio)
else:
asr.transcribe(audio)
logger.info(f"{'SimulStreaming' if is_simulstreaming else 'Whisper'} is warmed up")
return True
except Exception as e:
logger.warning(f"Warmup failed: {e}")
return False
return asr, tokenizer