46 Commits
0.2.2 ... 0.2.5

Author SHA1 Message Date
Quentin Fuxa
349c7dcb9e bump version ro 0.2.5 2025-08-13 10:04:31 +02:00
Quentin Fuxa
1c42b867cf Merge branch 'main' of https://github.com/QuentinFuxa/whisper_streaming_web 2025-08-13 10:04:04 +02:00
Quentin Fuxa
d4771e563e Increase END_SILENCE_DURATION to reduce false positives 2025-08-13 10:04:00 +02:00
Quentin Fuxa
b0a5fc0693 Merge pull request #155 from davidgumberg/keepawakescrolldown
frontend: Keep screen awake and scroll down when transcribing.
2025-08-13 10:02:52 +02:00
David Gumberg
3b96fb8776 frontend: Scroll down when appending transcription 2025-08-12 17:31:32 -07:00
David Gumberg
7f93c4b978 frontend: Don't let screen sleep when transcribing. 2025-08-12 17:30:57 -07:00
Quentin Fuxa
15c3df1cba warmup base whisper when using simulstreaming 2025-08-12 18:52:52 +02:00
Quentin Fuxa
7fb8e66c01 typo 2025-08-12 18:36:32 +02:00
Quentin Fuxa
728e1f1290 simulstreaming warmup is done for each instance of online, not for the backend 2025-08-12 18:35:04 +02:00
Quentin Fuxa
87b9ed6ecd nonspeech_prob from 1 to 0.5 2025-08-12 18:34:37 +02:00
Quentin Fuxa
38b4ebe8ba Handle 3 types of silences: Indicated by whisper, between tokens, and at the end of the input. Display them in the frontend 2025-08-11 17:56:57 +02:00
Quentin Fuxa
d098af3185 each SimulStreamingOnlineProcessor now contains PaddedAlignAttWhisper instance. SimulStreamingASR only contains loaded whisper model 2025-08-11 08:24:14 +02:00
Quentin Fuxa
4e56130a40 frontend supports dark theme 2025-08-11 08:22:23 +02:00
Quentin Fuxa
2bbdc70187 lags are now updated every 0.1s 2025-08-09 23:11:05 +02:00
Quentin Fuxa
b678a55f63 remove duplicate file 2025-08-09 23:10:34 +02:00
Quentin Fuxa
5491964e81 clean SimulStreamingOnlineProcessor initialization + audio processing 2025-08-09 20:16:27 +02:00
Quentin Fuxa
b05297a96d clean simulwhisper backend and online 2025-08-09 18:02:15 +02:00
Quentin Fuxa
197293e25e refactor(simulstreaming): extract backend + online module into separate files from whisper streaming 2025-08-08 18:07:51 +02:00
Quentin Fuxa
ba41c4ab56 Remove download_simulstreaming_backend 2025-08-08 18:06:40 +02:00
Quentin Fuxa
bda72b8bc0 setup.py to pyproject.toml. Remove <2.0.0 condition on numpy dep 2025-08-03 16:32:31 +02:00
Quentin Fuxa
bb6b9f4cb1 architecture diagram : available backends for whisper streaming & diarization 2025-08-03 12:25:36 +02:00
Quentin Fuxa
e40b5a3ea0 Update architecture diagram 2025-08-02 13:51:15 +02:00
Quentin Fuxa
4cfed6e98e in MultiHeadAttention and ResidualAttentionBlock include cache_id for compatibility with simulstreaming code 2025-08-02 13:16:58 +02:00
Quentin Fuxa
687e3dd5e2 update simulstreaming model.py to match the latest version of whisper sources 2025-08-02 13:16:10 +02:00
Quentin Fuxa
e4140cd299 Update Dockerfile to install build-essential and update PyTorch version 2025-08-02 13:08:43 +02:00
Quentin Fuxa
8e056cbdf2 Upgrade SimulStreaming Whisper core from version 20230918 to 20250625 2025-08-02 13:06:36 +02:00
Quentin Fuxa
9dcfb38967 Update README.md 2025-08-01 18:02:11 +02:00
Quentin Fuxa
47b9235d70 Update README.md 2025-08-01 17:55:40 +02:00
Quentin Fuxa
f3cd53a4db Update README.md 2025-08-01 16:53:22 +02:00
Quentin Fuxa
dbdb4ea66c Update README.md 2025-08-01 16:33:26 +02:00
Quentin Fuxa
00424d7ca3 latest version of simulstreaming 2025-07-31 16:44:23 +02:00
Quentin Fuxa
4b738d6f63 fix duplicate line 2025-07-31 16:29:35 +02:00
Quentin Fuxa
8a5e2adb1e simulstreaming: fixes token handling during warm-up phase 2025-07-31 16:25:34 +02:00
Quentin Fuxa
f85329e112 Merge branch 'main' of https://github.com/QuentinFuxa/whisper_streaming_web 2025-07-31 11:42:16 +02:00
Quentin Fuxa
46efbdf1d9 solves https://github.com/QuentinFuxa/WhisperLiveKit/issues/151 2025-07-31 11:42:06 +02:00
Quentin Fuxa
8885ade003 Merge pull request #153 from luisla-rivas/main
Fix README.md to view correctly Deployment Guide info
2025-07-31 07:10:35 +02:00
luisla-rivas
2564928d83 Fix README.md to view correctly Deployment Guide info 2025-07-30 14:11:19 +02:00
Quentin Fuxa
56114d3071 Remove end_attributed_speaker in diarization_online. handled in audio processor 2025-07-16 12:09:43 +02:00
Quentin Fuxa
5b9977c9af Enhanced use_punctuation_split for diarization. further improvements still needed 2025-07-16 12:06:17 +02:00
Quentin Fuxa
12a544164f Merge branch 'main' of https://github.com/QuentinFuxa/whisper_streaming_web 2025-07-16 12:05:01 +02:00
Quentin Fuxa
2ca1156b7e Merge pull request #147 from choomegan/diar_queue
Ensure diarization_queue receives only latest PCM chunk
2025-07-16 12:04:53 +02:00
Quentin Fuxa
3ad3683ca7 Refactor speaker assignment in DiartDiarization for clarity and punctuation awareness 2025-07-15 14:38:53 +02:00
Quentin Fuxa
1599bd87a0 work on punctuation_split 2025-07-15 12:04:54 +02:00
Quentin Fuxa
90623400a4 Remove automatic downloading of SimulStreaming dependencies on import failure 2025-07-15 12:04:17 +02:00
choomegan
64e44fb24f fix: logic of adding of pcm_array to diarization_queue 2025-07-15 15:33:41 +08:00
Quentin Fuxa
156b9a133f 0.2.2 2025-07-04 17:11:35 +02:00
34 changed files with 3160 additions and 1713 deletions

View File

@@ -21,10 +21,12 @@ RUN apt-get update && \
python3 \ python3 \
python3-pip \ python3-pip \
ffmpeg \ ffmpeg \
git && \ git \
build-essential \
python3-dev && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
COPY . . COPY . .

View File

@@ -13,23 +13,16 @@
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a> <a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p> </p>
## Overview
This project is based on [WhisperStreaming](https://github.com/ufal/whisper_streaming) and [SimulStreaming](https://github.com/ufal/SimulStreaming), allowing you to transcribe audio directly from your browser. WhisperLiveKit provides a complete backend solution for real-time speech transcription with a functional, simple and customizable frontend. Everything runs locally on your machine WhisperLiveKit brings real-time speech transcription directly to your browser, with a ready-to-use backend+server and a simple frontend.
### Architecture Built on [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) and [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) for transcription, plus [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) and [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) for diarization.
WhisperLiveKit consists of three main components:
- **Frontend**: A basic html + JS interface that captures microphone audio and streams it to the backend via WebSockets. You can use and adapt the [provided template](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html).
- **Backend (Web Server)**: A FastAPI-based WebSocket server that receives streamed audio data, processes it in real time, and returns transcriptions to the frontend. This is where the WebSocket logic and routing live.
- **Core Backend (Library Logic)**: A server-agnostic core that handles audio processing, ASR, and diarization. It exposes reusable components that take in audio bytes and return transcriptions.
### Key Features ### Key Features
- **Real-time Transcription** - Locally (or on-prem) convert speech to text instantly as you speak - **Real-time Transcription** - Locally (or on-prem) convert speech to text instantly as you speak
- **Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart) - **Speaker Diarization** - Identify different speakers in real-time. (⚠️ backend Streaming Sortformer in developement)
- **Multi-User Support** - Handle multiple users simultaneously with a single backend/server - **Multi-User Support** - Handle multiple users simultaneously with a single backend/server
- **Automatic Silence Chunking** Automatically chunks when no audio is detected to limit buffer size - **Automatic Silence Chunking** Automatically chunks when no audio is detected to limit buffer size
- **Confidence Validation** Immediately validate high-confidence tokens for faster inference (WhisperStreaming only) - **Confidence Validation** Immediately validate high-confidence tokens for faster inference (WhisperStreaming only)
@@ -37,6 +30,11 @@ WhisperLiveKit consists of three main components:
- **Punctuation-Based Speaker Splitting [BETA]** - Align speaker changes with natural sentence boundaries for more readable transcripts - **Punctuation-Based Speaker Splitting [BETA]** - Align speaker changes with natural sentence boundaries for more readable transcripts
- **SimulStreaming Backend** - [Dual-licensed](https://github.com/ufal/SimulStreaming#-licence-and-contributions) - Ultra-low latency transcription using SOTA AlignAtt policy. - **SimulStreaming Backend** - [Dual-licensed](https://github.com/ufal/SimulStreaming#-licence-and-contributions) - Ultra-low latency transcription using SOTA AlignAtt policy.
### Architecture
<img alt="Architecture" src="architecture.png" />
## Quick Start ## Quick Start
```bash ```bash
@@ -247,7 +245,7 @@ To deploy WhisperLiveKit in production:
- Ensure WebSocket connection points to your server's address - Ensure WebSocket connection points to your server's address
3. **Nginx Configuration** (recommended for production): 3. **Nginx Configuration** (recommended for production):
```nginx ```nginx
server { server {
listen 80; listen 80;
server_name your-domain.com; server_name your-domain.com;
@@ -258,6 +256,7 @@ To deploy WhisperLiveKit in production:
proxy_set_header Connection "upgrade"; proxy_set_header Connection "upgrade";
proxy_set_header Host $host; proxy_set_header Host $host;
}} }}
```
4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in WebSocket URL 4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in WebSocket URL
@@ -268,21 +267,21 @@ A basic Dockerfile is provided which allows re-use of Python package installatio
#### All defaults #### All defaults
- Create a reusable image with only the basics and then run as a named container: - Create a reusable image with only the basics and then run as a named container:
```bash ```bash
docker build -t whisperlivekit-defaults . docker build -t whisperlivekit-defaults .
docker create --gpus all --name whisperlivekit -p 8000:8000 whisperlivekit-defaults docker create --gpus all --name whisperlivekit -p 8000:8000 whisperlivekit-defaults
docker start -i whisperlivekit docker start -i whisperlivekit
``` ```
> **Note**: If you're running on a system without NVIDIA GPU support (such as Mac with Apple Silicon or any system without CUDA capabilities), you need to **remove the `--gpus all` flag** from the `docker create` command. Without GPU acceleration, transcription will use CPU only, which may be significantly slower. Consider using small models for better performance on CPU-only systems. > **Note**: If you're running on a system without NVIDIA GPU support (such as Mac with Apple Silicon or any system without CUDA capabilities), you need to **remove the `--gpus all` flag** from the `docker create` command. Without GPU acceleration, transcription will use CPU only, which may be significantly slower. Consider using small models for better performance on CPU-only systems.
#### Customization #### Customization
- Customize the container options: - Customize the container options:
```bash ```bash
docker build -t whisperlivekit-defaults . docker build -t whisperlivekit-defaults .
docker create --gpus all --name whisperlivekit-base -p 8000:8000 whisperlivekit-defaults --model base docker create --gpus all --name whisperlivekit-base -p 8000:8000 whisperlivekit-defaults --model base
docker start -i whisperlivekit-base docker start -i whisperlivekit-base
``` ```
- `--build-arg` Options: - `--build-arg` Options:
- `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options! - `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!

BIN
architecture.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 382 KiB

59
pyproject.toml Normal file
View File

@@ -0,0 +1,59 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "whisperlivekit"
version = "0.2.5"
description = "Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization"
readme = "README.md"
authors = [
{ name = "Quentin Fuxa" }
]
license = { file = "LICENSE" }
requires-python = ">=3.9"
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Multimedia :: Sound/Audio :: Speech"
]
dependencies = [
"fastapi",
"librosa",
"soundfile",
"faster-whisper",
"uvicorn",
"websockets"
]
[project.optional-dependencies]
diarization = ["diart"]
vac = ["torch"]
sentence = ["mosestokenizer", "wtpsplit"]
whisper = ["whisper"]
whisper-timestamped = ["whisper-timestamped"]
mlx-whisper = ["mlx-whisper"]
openai = ["openai"]
simulstreaming = [
"torch",
"tqdm",
"tiktoken",
'triton>=2.0.0,<3; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
]
[project.urls]
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
[project.scripts]
whisperlivekit-server = "whisperlivekit.basic_server:main"
[tool.setuptools]
packages = ["whisperlivekit", "whisperlivekit.diarization", "whisperlivekit.simul_whisper", "whisperlivekit.simul_whisper.whisper", "whisperlivekit.simul_whisper.whisper.assets", "whisperlivekit.simul_whisper.whisper.normalizers", "whisperlivekit.web", "whisperlivekit.whisper_streaming_custom"]
[tool.setuptools.package-data]
whisperlivekit = ["web/*.html"]
"whisperlivekit.simul_whisper.whisper.assets" = ["*.tiktoken", "*.npz"]

View File

@@ -1,55 +0,0 @@
from setuptools import setup, find_packages
setup(
name="whisperlivekit",
version="0.2.1",
description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
author="Quentin Fuxa",
url="https://github.com/QuentinFuxa/WhisperLiveKit",
packages=find_packages(),
install_requires=[
"fastapi",
"librosa",
"soundfile",
"faster-whisper",
"uvicorn",
"websockets",
],
extras_require={
"diarization": ["diart"],
"vac": ["torch"],
"sentence": ["mosestokenizer", "wtpsplit"],
"whisper": ["whisper"],
"whisper-timestamped": ["whisper-timestamped"],
"mlx-whisper": ["mlx-whisper"],
"openai": ["openai"],
"simulstreaming": [
"torch",
"tqdm",
"tiktoken",
"numpy<2.0.0",
"triton>=2.0.0,<3;platform_machine==\"x86_64\" and sys_platform==\"linux\" or sys_platform==\"linux2\"",
],
},
package_data={
'whisperlivekit': ['web/*.html'],
'whisperlivekit.simul_whisper': ['dual_license_simulstreaming.md'],
'whisperlivekit.simul_whisper.whisper.assets': ['*.tiktoken', '*.npz'],
},
entry_points={
'console_scripts': [
'whisperlivekit-server=whisperlivekit.basic_server:main',
],
},
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Multimedia :: Sound/Audio :: Speech",
],
python_requires=">=3.9",
)

View File

@@ -1,4 +1,3 @@
from .download_simulstreaming_backend import download_simulstreaming_backend
from .audio_processor import AudioProcessor from .audio_processor import AudioProcessor
from .core import TranscriptionEngine from .core import TranscriptionEngine
from .parse_args import parse_args from .parse_args import parse_args

View File

@@ -6,10 +6,9 @@ import logging
import traceback import traceback
from datetime import timedelta from datetime import timedelta
from whisperlivekit.timed_objects import ASRToken from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper_streaming_custom.whisper_online import online_factory from whisperlivekit.core import TranscriptionEngine, online_factory
from whisperlivekit.core import TranscriptionEngine
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from .remove_silences import handle_silences
# Set up logging once # Set up logging once
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -52,7 +51,6 @@ class AudioProcessor:
self.tokens = [] self.tokens = []
self.buffer_transcription = "" self.buffer_transcription = ""
self.buffer_diarization = "" self.buffer_diarization = ""
self.full_transcription = ""
self.end_buffer = 0 self.end_buffer = 0
self.end_attributed_speaker = 0 self.end_attributed_speaker = 0
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
@@ -96,13 +94,12 @@ class AudioProcessor:
"""Convert PCM buffer in s16le format to normalized NumPy array.""" """Convert PCM buffer in s16le format to normalized NumPy array."""
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0 return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep): async def update_transcription(self, new_tokens, buffer, end_buffer, sep):
"""Thread-safe update of transcription with new data.""" """Thread-safe update of transcription with new data."""
async with self.lock: async with self.lock:
self.tokens.extend(new_tokens) self.tokens.extend(new_tokens)
self.buffer_transcription = buffer self.buffer_transcription = buffer
self.end_buffer = end_buffer self.end_buffer = end_buffer
self.full_transcription = full_transcription
self.sep = sep self.sep = sep
async def update_diarization(self, end_attributed_speaker, buffer_diarization=""): async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
@@ -129,12 +126,12 @@ class AudioProcessor:
# Calculate remaining times # Calculate remaining times
remaining_transcription = 0 remaining_transcription = 0
if self.end_buffer > 0: if self.end_buffer > 0:
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2)) remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 1))
remaining_diarization = 0 remaining_diarization = 0
if self.tokens: if self.tokens:
latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0) latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2)) remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 1))
return { return {
"tokens": self.tokens.copy(), "tokens": self.tokens.copy(),
@@ -153,7 +150,6 @@ class AudioProcessor:
self.tokens = [] self.tokens = []
self.buffer_transcription = self.buffer_diarization = "" self.buffer_transcription = self.buffer_diarization = ""
self.end_buffer = self.end_attributed_speaker = 0 self.end_buffer = self.end_attributed_speaker = 0
self.full_transcription = self.last_response_content = ""
self.beg_loop = time() self.beg_loop = time()
async def ffmpeg_stdout_reader(self): async def ffmpeg_stdout_reader(self):
@@ -192,12 +188,6 @@ class AudioProcessor:
continue continue
self.pcm_buffer.extend(chunk) self.pcm_buffer.extend(chunk)
# Send to diarization if enabled
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(
self.convert_pcm_to_float(self.pcm_buffer).copy()
)
# Process when enough data # Process when enough data
if len(self.pcm_buffer) >= self.bytes_per_sec: if len(self.pcm_buffer) >= self.bytes_per_sec:
@@ -214,7 +204,11 @@ class AudioProcessor:
# Send to transcription if enabled # Send to transcription if enabled
if self.args.transcription and self.transcription_queue: if self.args.transcription and self.transcription_queue:
await self.transcription_queue.put(pcm_array.copy()) await self.transcription_queue.put(pcm_array.copy())
# Send to diarization if enabled
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(pcm_array.copy())
# Sleep if no processing is happening # Sleep if no processing is happening
if not self.args.transcription and not self.args.diarization: if not self.args.transcription and not self.args.diarization:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
@@ -240,7 +234,6 @@ class AudioProcessor:
async def transcription_processor(self): async def transcription_processor(self):
"""Process audio chunks for transcription.""" """Process audio chunks for transcription."""
self.full_transcription = ""
self.sep = self.online.asr.sep self.sep = self.online.asr.sep
cumulative_pcm_duration_stream_time = 0.0 cumulative_pcm_duration_stream_time = 0.0
@@ -252,7 +245,7 @@ class AudioProcessor:
self.transcription_queue.task_done() self.transcription_queue.task_done()
break break
if not self.online: # Should not happen if queue is used if not self.online:
logger.warning("Transcription processor: self.online not initialized.") logger.warning("Transcription processor: self.online not initialized.")
self.transcription_queue.task_done() self.transcription_queue.task_done()
continue continue
@@ -279,8 +272,6 @@ class AudioProcessor:
if new_tokens: if new_tokens:
validated_text = self.sep.join([t.text for t in new_tokens]) validated_text = self.sep.join([t.text for t in new_tokens])
self.full_transcription += validated_text
if buffer_text.startswith(validated_text): if buffer_text.startswith(validated_text):
buffer_text = buffer_text[len(validated_text):].lstrip() buffer_text = buffer_text[len(validated_text):].lstrip()
@@ -297,7 +288,7 @@ class AudioProcessor:
new_end_buffer = max(candidate_end_times) new_end_buffer = max(candidate_end_times)
await self.update_transcription( await self.update_transcription(
new_tokens, buffer_text, new_end_buffer, self.full_transcription, self.sep new_tokens, buffer_text, new_end_buffer, self.sep
) )
self.transcription_queue.task_done() self.transcription_queue.task_done()
@@ -325,12 +316,12 @@ class AudioProcessor:
await diarization_obj.diarize(pcm_array) await diarization_obj.diarize(pcm_array)
async with self.lock: async with self.lock:
new_end = diarization_obj.assign_speakers_to_tokens( self.tokens = diarization_obj.assign_speakers_to_tokens(
self.end_attributed_speaker,
self.tokens, self.tokens,
use_punctuation_split=self.args.punctuation_split use_punctuation_split=self.args.punctuation_split
) )
self.end_attributed_speaker = new_end if len(self.tokens) > 0:
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
if buffer_diarization: if buffer_diarization:
self.buffer_diarization = buffer_diarization self.buffer_diarization = buffer_diarization
@@ -346,6 +337,8 @@ class AudioProcessor:
async def results_formatter(self): async def results_formatter(self):
"""Format processing results for output.""" """Format processing results for output."""
last_sent_trans = None
last_sent_diar = None
while True: while True:
try: try:
ffmpeg_state = await self.ffmpeg_manager.get_state() ffmpeg_state = await self.ffmpeg_manager.get_state()
@@ -383,8 +376,8 @@ class AudioProcessor:
lines = [] lines = []
last_end_diarized = 0 last_end_diarized = 0
undiarized_text = [] undiarized_text = []
current_time = time() - self.beg_loop
# Process each token tokens = handle_silences(tokens, current_time)
for token in tokens: for token in tokens:
speaker = token.speaker speaker = token.speaker
@@ -449,10 +442,19 @@ class AudioProcessor:
' '.join([f"{line['speaker']} {line['text']}" for line in final_lines_for_response]) + \ ' '.join([f"{line['speaker']} {line['text']}" for line in final_lines_for_response]) + \
f" | {buffer_transcription} | {buffer_diarization}" f" | {buffer_transcription} | {buffer_diarization}"
if current_response_signature != self.last_response_content and \ trans = state["remaining_time_transcription"]
(final_lines_for_response or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"): diar = state["remaining_time_diarization"]
should_push = (
current_response_signature != self.last_response_content
or last_sent_trans is None
or round(trans, 1) != round(last_sent_trans, 1)
or round(diar, 1) != round(last_sent_diar, 1)
)
if should_push and (final_lines_for_response or buffer_transcription or buffer_diarization or response_status == "no_audio_detected" or trans > 0 or diar > 0):
yield response yield response
self.last_response_content = current_response_signature self.last_response_content = current_response_signature
last_sent_trans = trans
last_sent_diar = diar
# Check for termination condition # Check for termination condition
if self.is_stopping: if self.is_stopping:

View File

@@ -1,9 +1,12 @@
try: try:
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory, warmup_asr from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory
from whisperlivekit.whisper_streaming_custom.online_asr import VACOnlineASRProcessor, OnlineASRProcessor
except ImportError: except ImportError:
from .whisper_streaming_custom.whisper_online import backend_factory, warmup_asr from .whisper_streaming_custom.whisper_online import backend_factory
from .whisper_streaming_custom.online_asr import VACOnlineASRProcessor, OnlineASRProcessor
from whisperlivekit.warmup import warmup_asr, warmup_online
from argparse import Namespace from argparse import Namespace
import sys
class TranscriptionEngine: class TranscriptionEngine:
_instance = None _instance = None
@@ -22,7 +25,6 @@ class TranscriptionEngine:
"host": "localhost", "host": "localhost",
"port": 8000, "port": 8000,
"warmup_file": None, "warmup_file": None,
"confidence_validation": False,
"diarization": False, "diarization": False,
"punctuation_split": False, "punctuation_split": False,
"min_chunk_size": 0.5, "min_chunk_size": 0.5,
@@ -34,15 +36,15 @@ class TranscriptionEngine:
"backend": "faster-whisper", "backend": "faster-whisper",
"vac": False, "vac": False,
"vac_chunk_size": 0.04, "vac_chunk_size": 0.04,
"buffer_trimming": "segment",
"buffer_trimming_sec": 15,
"log_level": "DEBUG", "log_level": "DEBUG",
"ssl_certfile": None, "ssl_certfile": None,
"ssl_keyfile": None, "ssl_keyfile": None,
"transcription": True, "transcription": True,
"vad": True, "vad": True,
"segmentation_model": "pyannote/segmentation-3.0", # whisperstreaming params:
"embedding_model": "pyannote/embedding", "buffer_trimming": "segment",
"confidence_validation": False,
"buffer_trimming_sec": 15,
# simulstreaming params: # simulstreaming params:
"frame_threshold": 25, "frame_threshold": 25,
"beams": 1, "beams": 1,
@@ -55,6 +57,10 @@ class TranscriptionEngine:
"static_init_prompt": None, "static_init_prompt": None,
"max_context_tokens": None, "max_context_tokens": None,
"model_path": './base.pt', "model_path": './base.pt',
# diart params:
"segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding",
} }
config_dict = {**defaults, **kwargs} config_dict = {**defaults, **kwargs}
@@ -78,8 +84,32 @@ class TranscriptionEngine:
self.diarization = None self.diarization = None
if self.args.transcription: if self.args.transcription:
self.asr, self.tokenizer = backend_factory(self.args) if self.args.backend == "simulstreaming":
warmup_asr(self.asr, self.args.warmup_file) from simul_whisper import SimulStreamingASR
self.tokenizer = None
simulstreaming_kwargs = {}
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
'max_context_tokens', 'model_path']:
if hasattr(self.args, attr):
simulstreaming_kwargs[attr] = getattr(self.args, attr)
# Add segment_length from min_chunk_size
simulstreaming_kwargs['segment_length'] = getattr(self.args, 'min_chunk_size', 0.5)
simulstreaming_kwargs['task'] = self.args.task
size = self.args.model
self.asr = SimulStreamingASR(
modelsize=size,
lan=self.args.lan,
cache_dir=getattr(self.args, 'model_cache_dir', None),
model_dir=getattr(self.args, 'model_dir', None),
**simulstreaming_kwargs
)
else:
self.asr, self.tokenizer = backend_factory(self.args)
warmup_asr(self.asr, self.args.warmup_file) #for simulstreaming, warmup should be done in the online class not here
if self.args.diarization: if self.args.diarization:
from whisperlivekit.diarization.diarization_online import DiartDiarization from whisperlivekit.diarization.diarization_online import DiartDiarization
@@ -90,3 +120,33 @@ class TranscriptionEngine:
) )
TranscriptionEngine._initialized = True TranscriptionEngine._initialized = True
def online_factory(args, asr, tokenizer, logfile=sys.stderr):
if args.backend == "simulstreaming":
from simul_whisper import SimulStreamingOnlineProcessor
online = SimulStreamingOnlineProcessor(
asr,
logfile=logfile,
)
# warmup_online(online, args.warmup_file)
elif args.vac:
online = VACOnlineASRProcessor(
args.min_chunk_size,
asr,
tokenizer,
logfile=logfile,
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
confidence_validation = args.confidence_validation
)
else:
online = OnlineASRProcessor(
asr,
tokenizer,
logfile=logfile,
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
confidence_validation = args.confidence_validation
)
return online

View File

@@ -165,7 +165,7 @@ class WebSocketAudioSource(AudioSource):
class DiartDiarization: class DiartDiarization:
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "speechbrain/spkrec-ecapa-voxceleb"): def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 1.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "pyannote/embedding"):
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name) segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name) embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
@@ -206,15 +206,14 @@ class DiartDiarization:
""" """
if self.custom_source: if self.custom_source:
self.custom_source.push_audio(pcm_array) self.custom_source.push_audio(pcm_array)
self.observer.clear_old_segments() # self.observer.clear_old_segments()
return self.observer.get_segments()
def close(self): def close(self):
"""Close the audio source.""" """Close the audio source."""
if self.custom_source: if self.custom_source:
self.custom_source.close() self.custom_source.close()
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float: def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> float:
""" """
Assign speakers to tokens based on timing overlap with speaker segments. Assign speakers to tokens based on timing overlap with speaker segments.
Uses the segments collected by the observer. Uses the segments collected by the observer.
@@ -231,85 +230,82 @@ class DiartDiarization:
if not self.lag_diart and segments and tokens: if not self.lag_diart and segments and tokens:
self.lag_diart = segments[0].start - tokens[0].start self.lag_diart = segments[0].start - tokens[0].start
for token in tokens:
for segment in segments:
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
token.speaker = extract_number(segment.speaker) + 1
end_attributed_speaker = max(token.end, end_attributed_speaker)
if use_punctuation_split and len(tokens) > 1: if not use_punctuation_split:
punctuation_marks = {'.', '!', '?'} for token in tokens:
for segment in segments:
print("Here are the tokens:", if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
[(t.text, t.start, t.end, t.speaker) for t in tokens[:10]]) token.speaker = extract_number(segment.speaker) + 1
else:
segment_map = [] tokens = add_speaker_to_tokens(segments, tokens)
for segment in segments: return tokens
speaker_num = extract_number(segment.speaker) + 1
segment_map.append((segment.start, segment.end, speaker_num))
segment_map.sort(key=lambda x: x[0])
i = 0
while i < len(tokens):
current_token = tokens[i]
is_sentence_end = False
if current_token.text and current_token.text.strip():
text = current_token.text.strip()
if text[-1] in punctuation_marks:
is_sentence_end = True
logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
if is_sentence_end and current_token.speaker != -1:
punctuation_time = current_token.end
current_speaker = current_token.speaker
j = i + 1
next_sentence_tokens = []
while j < len(tokens):
next_token = tokens[j]
next_sentence_tokens.append(j)
# Check if this token ends the next sentence
if next_token.text and next_token.text.strip():
if next_token.text.strip()[-1] in punctuation_marks:
break
j += 1
if next_sentence_tokens:
speaker_times = {}
for idx in next_sentence_tokens:
token = tokens[idx]
# Find which segments overlap with this token
for seg_start, seg_end, seg_speaker in segment_map:
if not (seg_end <= token.start or seg_start >= token.end):
# Calculate overlap duration
overlap_start = max(seg_start, token.start)
overlap_end = min(seg_end, token.end)
overlap_duration = overlap_end - overlap_start
if seg_speaker not in speaker_times:
speaker_times[seg_speaker] = 0
speaker_times[seg_speaker] += overlap_duration
if speaker_times:
dominant_speaker = max(speaker_times.items(), key=lambda x: x[1])[0]
if dominant_speaker != current_speaker:
logger.debug(f" Speaker change after punctuation: {current_speaker}{dominant_speaker}")
for idx in next_sentence_tokens:
if tokens[idx].speaker != dominant_speaker:
logger.debug(f" Reassigning token {idx} ('{tokens[idx].text}') to Speaker {dominant_speaker}")
tokens[idx].speaker = dominant_speaker
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
else:
for idx in next_sentence_tokens:
if tokens[idx].speaker == -1:
tokens[idx].speaker = current_speaker
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
i += 1
def concatenate_speakers(segments):
    """Merge consecutive diarization segments that belong to the same speaker.

    Returns a list of dicts {'speaker', 'begin', 'end'}. The list starts with a
    placeholder entry for speaker 1 covering time 0, so there is always a
    previous entry to compare against.
    """
    segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
    for segment in segments:
        speaker = extract_number(segment.speaker) + 1
        if segments_concatenated[-1]['speaker'] != speaker:
            # Speaker changed: open a new merged segment.
            segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
        else:
            # Same speaker: just extend the current merged segment.
            segments_concatenated[-1]['end'] = segment.end
    return segments_concatenated
def add_speaker_to_tokens(segments, tokens):
    """
    Assign speakers to tokens based on diarization segments, with punctuation-aware boundary adjustment.
    """
    punctuation_marks = {'.', '!', '?'}
    punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
    segments_concatenated = concatenate_speakers(segments)

    # Snap each merged-segment end to the nearest sentence-ending punctuation so
    # that speaker changes never land in the middle of a sentence.
    for ind, segment in enumerate(segments_concatenated):
        for i, punctuation_token in enumerate(punctuation_tokens):
            if punctuation_token.start > segment['end']:
                after_length = punctuation_token.start - segment['end']
                # NOTE(review): for i == 0 this reads punctuation_tokens[-1]
                # (the last punctuation) — confirm this wrap-around is intended.
                before_length = segment['end'] - punctuation_tokens[i - 1].end
                if before_length > after_length:
                    # Closer to the following punctuation: extend this segment to it.
                    segment['end'] = punctuation_token.start
                    if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
                        segments_concatenated[ind + 1]['begin'] = punctuation_token.start
                else:
                    # Closer to the preceding punctuation: shrink back to it.
                    segment['end'] = punctuation_tokens[i - 1].end
                    if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
                        segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
                break

    # Enforce strictly increasing, non-overlapping token times.
    last_end = 0.0
    for token in tokens:
        start = max(last_end + 0.01, token.start)
        token.start = start
        token.end = max(start, token.end)
        last_end = token.end

    # Walk segments and tokens in parallel, tagging each token with the speaker
    # of the segment that contains the token's end time.
    ind_last_speaker = 0
    for segment in segments_concatenated:
        for i, token in enumerate(tokens[ind_last_speaker:]):
            if token.end <= segment['end']:
                token.speaker = segment['speaker']
                # NOTE(review): 'i' is relative to the slice; an absolute resume
                # index would be ind_last_speaker + i + 1 — verify upstream.
                ind_last_speaker = i + 1
            elif token.start > segment['end']:
                break
    return tokens
def visualize_tokens(tokens):
    """Pretty-print tokens grouped into per-speaker conversation turns."""
    conversation = [{"speaker": -1, "text": ""}]
    for token in tokens:
        if token.speaker != conversation[-1]['speaker']:
            # New speaker: start a fresh turn.
            conversation.append({"speaker": token.speaker, "text": token.text})
        else:
            # Same speaker: extend the current turn.
            conversation[-1]['text'] += token.text
    print("Conversation:")
    for entry in conversation:
        print(f"Speaker {entry['speaker']}: {entry['text']}")

View File

@@ -1,32 +0,0 @@
import os
import requests
import inspect
def get_module_path():
    """Return the directory containing this module's source file."""
    frame = inspect.currentframe()
    return os.path.dirname(inspect.getfile(frame))
# GitHub contents-API endpoint listing the SimulStreaming whisper sources.
GITHUB_API_URL = "https://api.github.com/repos/ufal/SimulStreaming/contents/simul_whisper/whisper"
# Raw-file base URL for the same tree (kept for reference; the downloader uses
# each item's 'download_url' from the contents API instead).
RAW_BASE_URL = "https://raw.githubusercontent.com/ufal/SimulStreaming/main/simul_whisper/whisper"
# Local destination: <this package dir>/simul_whisper/whisper
TARGET_DIR = os.path.join(get_module_path(), "simul_whisper", "whisper")
def download_files_from_github(api_url, local_dir):
    """Recursively mirror a GitHub directory tree into *local_dir*.

    *api_url* is a GitHub contents-API endpoint. Files are fetched through the
    'download_url' of each listed item; subdirectories are walked recursively.

    Raises requests.HTTPError on any failed request.
    """
    os.makedirs(local_dir, exist_ok=True)
    # timeout added: requests without a timeout can hang forever on a stalled connection
    response = requests.get(api_url, timeout=30)
    response.raise_for_status()
    items = response.json()
    for item in items:
        if item['type'] == 'file':
            download_url = item['download_url']
            file_name = item['name']
            file_response = requests.get(download_url, timeout=30)
            file_response.raise_for_status()
            with open(os.path.join(local_dir, file_name), 'wb') as f:
                f.write(file_response.content)
        elif item['type'] == 'dir':
            # Recursive call for subdirectories
            download_files_from_github(item['url'], os.path.join(local_dir, item['name']))
def download_simulstreaming_backend():
    """Fetch the SimulStreaming whisper backend sources from GitHub into TARGET_DIR."""
    print(f"Downloading files into {TARGET_DIR} ...")
    download_files_from_github(GITHUB_API_URL, TARGET_DIR)
    print("✅ Download of SimulStreaming backend files completed successfully.")

View File

@@ -0,0 +1,103 @@
from whisperlivekit.timed_objects import ASRToken
import re
MIN_SILENCE_DURATION = 4 # in seconds; gaps shorter than this are not reported as silence
END_SILENCE_DURATION = 8 # in seconds; keep this large to avoid false positives when the model lags behind the stream
def blank_to_silence(tokens):
# Replace runs of "[BLANK_AUDIO]" / "[typing]" tokens with silence tokens
# (speaker == -2). Matches are located on the concatenated token text, then
# mapped back to tokens via cumulative character offsets.
full_string = ''.join([t.text for t in tokens])
patterns = [re.compile(r'(?:\s*\[BLANK_AUDIO\]\s*)+'), re.compile(r'(?:\s*\[typing\]\s*)+')]
matches = []
for pattern in patterns:
for m in pattern.finditer(full_string):
matches.append({
'start': m.start(),
'end': m.end()
})
if matches:
# cleaned = pattern.sub(' ', full_string).strip()
# print("Cleaned:", cleaned)
cumulated_len = 0
silence_token = None
cleaned_tokens = []
for token in tokens:
if matches:
# character span of this token within full_string
start = cumulated_len
end = cumulated_len + len(token.text)
cumulated_len = end
if start >= matches[0]['start'] and end <= matches[0]['end']:
# token lies entirely inside the current blank/typing run
if silence_token: #previous token was already silence
silence_token.start = min(silence_token.start, token.start)
silence_token.end = max(silence_token.end, token.end)
else: #new silence
silence_token = ASRToken(
start=token.start,
end=token.end,
speaker=-2,
probability=0.95
)
else:
if silence_token: #there was silence but no more
# only keep silences that are long enough to matter
if silence_token.end - silence_token.start >= MIN_SILENCE_DURATION:
cleaned_tokens.append(
silence_token
)
silence_token = None
matches.pop(0)
cleaned_tokens.append(token)
# print(cleaned_tokens)
return cleaned_tokens
return tokens
def no_token_to_silence(tokens):
    """Insert silence tokens (speaker == -2) into gaps between tokens.

    Consecutive silence tokens are merged, and a new silence token is inserted
    whenever the gap between the previous token's end and the current token's
    start is at least MIN_SILENCE_DURATION seconds.
    """
    new_tokens = []
    silence_token = None
    for token in tokens:
        if token.speaker == -2:
            if new_tokens and new_tokens[-1].speaker == -2: #if token is silence and previous one too
                new_tokens[-1].end = token.end
            else:
                new_tokens.append(token)
        last_end = new_tokens[-1].end if new_tokens else 0.0
        if token.start - last_end >= MIN_SILENCE_DURATION: #if token is not silence but important gap
            if new_tokens and new_tokens[-1].speaker == -2:
                # extend the trailing silence to cover the gap
                new_tokens[-1].end = token.start
            else:
                silence_token = ASRToken(
                    start=last_end,
                    end=token.start,
                    speaker=-2,
                    probability=0.95
                )
                new_tokens.append(silence_token)
        if token.speaker != -2:
            new_tokens.append(token)
    return new_tokens
def ends_with_silence(tokens, current_time):
    """Append or extend a trailing silence token when the stream has been quiet.

    If more than END_SILENCE_DURATION seconds separate the last token's end from
    *current_time*, the trailing silence token is extended (or a new one with
    speaker == -2 is appended). Returns the token list; [] for empty input.
    """
    if not tokens:
        return []
    last_token = tokens[-1]
    # redundant 'tokens and' guard removed: the empty case returned above
    if current_time - last_token.end >= END_SILENCE_DURATION:
        if last_token.speaker == -2:
            last_token.end = current_time
        else:
            tokens.append(
                ASRToken(
                    start=tokens[-1].end,
                    end=current_time,
                    speaker=-2,
                    probability=0.95
                )
            )
    return tokens
def handle_silences(tokens, current_time):
    """Detect and normalize silences, returning tokens with silence markers (speaker == -2)."""
    # [BLANK_AUDIO]/[typing] text -> silence tokens (useful for simulstreaming
    # backend which tends to generate [BLANK_AUDIO] text)
    with_blanks = blank_to_silence(tokens)
    # long gaps between tokens -> silence tokens
    with_gaps = no_token_to_silence(with_blanks)
    # trailing quiet period at the end of the input
    return ends_with_silence(with_gaps, current_time)

View File

@@ -0,0 +1,6 @@
from .backend import SimulStreamingASR, SimulStreamingOnlineProcessor
__all__ = [
"SimulStreamingASR",
"SimulStreamingOnlineProcessor",
]

View File

@@ -0,0 +1,223 @@
import sys
import numpy as np
import logging
from typing import List, Tuple, Optional
import logging
from whisperlivekit.timed_objects import ASRToken, Transcript
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
from .whisper import load_model, tokenizer
import os
logger = logging.getLogger(__name__)
try:
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
from whisperlivekit.simul_whisper.whisper import tokenizer
except ImportError as e:
raise ImportError(
"""SimulStreaming dependencies are not available.
Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]".""")
class SimulStreamingOnlineProcessor:
    """Per-connection online processor for the SimulStreaming (AlignAtt) backend.

    Each instance owns its own PaddedAlignAttWhisper wrapper built around the
    shared Whisper model held by SimulStreamingASR, so several sessions can run
    concurrently against one loaded checkpoint.
    """

    SAMPLING_RATE = 16000  # expected sample rate of incoming PCM chunks (Hz)

    def __init__(
        self,
        asr,
        logfile=sys.stderr,
        warmup_file=None
    ):
        self.asr = asr
        self.logfile = logfile
        self.is_last = False
        self.beg = 0.0
        self.end = 0.0
        self.cumulative_audio_duration = 0.0
        self.committed: List[ASRToken] = []
        self.last_result_tokens: List[ASRToken] = []
        # One inference wrapper per session; the heavy model weights are shared.
        self.model = PaddedAlignAttWhisper(
            cfg=asr.cfg,
            loaded_model=asr.whisper_model)
        if asr.tokenizer:
            # The backend prepared a task-specific tokenizer (e.g. translation).
            self.model.tokenizer = asr.tokenizer

    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
        """Append an audio chunk to be processed by SimulStreaming."""
        # Convert numpy array to torch tensor
        audio_tensor = torch.from_numpy(audio).float()
        # Update timing: prefer the caller-provided stream clock, otherwise
        # fall back to the sum of chunk durations seen so far.
        chunk_duration = len(audio) / self.SAMPLING_RATE
        self.cumulative_audio_duration += chunk_duration
        if audio_stream_end_time is not None:
            self.end = audio_stream_end_time
        else:
            self.end = self.cumulative_audio_duration
        self.model.insert_audio(audio_tensor)

    def get_buffer(self):
        """Return the uncommitted-buffer transcript (always empty for this backend)."""
        return Transcript(
            start=None,
            end=None,
            text='',
            probability=None
        )

    def timestamped_text(self, tokens, generation):
        """Convert decoded tokens + generation progress into (start, end, word) triples.

        Word boundaries come from the tokenizer split; timestamps come from the
        most-attended encoder frame of each token (frame index * 0.02 s).
        """
        # From the simulstreaming repo. self.model to self.asr.model
        pr = generation["progress"]
        if "result" not in generation:
            split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
        else:
            split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"]
        frames = [p["most_attended_frames"][0] for p in pr]
        tokens = tokens.copy()
        ret = []
        for sw, st in zip(split_words, split_tokens):
            b = None
            for stt in st:
                t, f = tokens.pop(0), frames.pop(0)
                if t != stt:
                    raise ValueError(f"Token mismatch: {t} != {stt} at frame {f}.")
                if b is None:
                    b = f
                e = f
            out = (b*0.02, e*0.02, sw)
            ret.append(out)
            logger.debug(f"TS-WORD:\t{' '.join(map(str, out))}")
        return ret

    def process_iter(self) -> Tuple[List[ASRToken], float]:
        """
        Process accumulated audio chunks using SimulStreaming.
        Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
        """
        try:
            tokens, generation_progress = self.model.infer(is_last=self.is_last)
            ts_words = self.timestamped_text(tokens, generation_progress)
            new_tokens = []
            for ts_word in ts_words:
                start, end, word = ts_word
                token = ASRToken(
                    start=start,
                    end=end,
                    text=word,
                    probability=0.95  # fake prob. Maybe we can extract it from the model?
                )
                new_tokens.append(token)
            self.committed.extend(new_tokens)
            return new_tokens, self.end
        except Exception as e:
            # A failed step must not kill the streaming session: log and move on.
            logger.exception(f"SimulStreaming processing error: {e}")
            return [], self.end

    def warmup(self, audio, init_prompt=""):
        """Warmup the SimulStreaming model."""
        try:
            self.model.insert_audio(audio)
            self.model.infer(True)
            # Reset all decoding state so warmup audio never leaks into real input.
            self.model.refresh_segment(complete=True)
            logger.info("SimulStreaming model warmed up successfully")
        except Exception as e:
            logger.exception(f"SimulStreaming warmup failed: {e}")
class SimulStreamingASR():
    """SimulStreaming backend with AlignAtt policy.

    Loads the Whisper checkpoint once; per-connection decoding state lives in
    SimulStreamingOnlineProcessor instances that share this loaded model.
    """

    sep = ""  # SimulStreaming output words already carry their own spacing

    def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
        logger.warning(SIMULSTREAMING_LICENSE)
        self.logfile = logfile
        self.transcribe_kargs = {}
        self.original_language = None if lan == "auto" else lan

        # SimulStreaming tuning knobs (defaults match upstream).
        self.model_path = kwargs.get('model_path', './large-v3.pt')
        self.frame_threshold = kwargs.get('frame_threshold', 25)
        self.audio_max_len = kwargs.get('audio_max_len', 30.0)
        self.audio_min_len = kwargs.get('audio_min_len', 0.0)
        self.segment_length = kwargs.get('segment_length', 0.5)
        self.beams = kwargs.get('beams', 1)
        self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
        self.task = kwargs.get('task', 'transcribe')
        self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
        self.never_fire = kwargs.get('never_fire', False)
        self.init_prompt = kwargs.get('init_prompt', None)
        self.static_init_prompt = kwargs.get('static_init_prompt', None)
        self.max_context_tokens = kwargs.get('max_context_tokens', None)

        # Resolve the checkpoint path: an explicit model_dir wins, then a known
        # size name, then the model_path kwarg/default.
        if model_dir is not None:
            self.model_path = model_dir
        elif modelsize is not None:
            model_mapping = {
                'tiny': './tiny.pt',
                'base': './base.pt',
                'small': './small.pt',
                'medium': './medium.pt',
                'medium.en': './medium.en.pt',
                'large-v1': './large-v1.pt',
                'base.en': './base.en.pt',
                'small.en': './small.en.pt',
                'tiny.en': './tiny.en.pt',
                'large-v2': './large-v2.pt',
                'large-v3': './large-v3.pt',
                'large': './large-v3.pt'
            }
            self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')

        self.model = self.load_model(modelsize)

        # Set up tokenizer for translation if needed
        if self.task == "translate":
            self.tokenizer = self.set_translate_task()
        else:
            self.tokenizer = None

    def load_model(self, modelsize):
        """Build the AlignAtt config and load the Whisper checkpoint.

        Returns the loaded model (also kept on self.whisper_model).
        """
        self.cfg = AlignAttConfig(
            model_path=self.model_path,
            segment_length=self.segment_length,
            frame_threshold=self.frame_threshold,
            language=self.original_language,
            audio_max_len=self.audio_max_len,
            audio_min_len=self.audio_min_len,
            cif_ckpt_path=self.cif_ckpt_path,
            # NOTE(review): decoder_type is hard-coded to "beam" here although
            # self.decoder_type may be 'greedy' — confirm this is intentional.
            decoder_type="beam",
            beam_size=self.beams,
            task=self.task,
            never_fire=self.never_fire,
            init_prompt=self.init_prompt,
            max_context_tokens=self.max_context_tokens,
            static_init_prompt=self.static_init_prompt,
        )
        model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
        model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
        self.whisper_model = load_model(name=model_name, download_root=model_path)
        # Bug fix: previously returned None implicitly, leaving self.model unset.
        return self.whisper_model

    def set_translate_task(self):
        """Set up translation task: build a multilingual tokenizer for 'translate'."""
        # Bug fix: read language/num_languages from self.cfg/self.whisper_model;
        # the old code dereferenced self.model, which was None at this point.
        return tokenizer.get_tokenizer(
            multilingual=True,
            language=self.cfg.language,
            num_languages=self.whisper_model.num_languages,
            task="translate"
        )

    def transcribe(self, audio):
        """
        Only used for warmup. It's a direct whisper call, not a simulstreaming call
        """
        self.whisper_model.transcribe(audio, language=self.original_language)

View File

@@ -8,7 +8,7 @@ class SimulWhisperConfig:
'''Options that are common for all simul policies that could be implemented in SimulWhisper.''' '''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
model_path: str model_path: str
language: str = field(default="zh") language: str = field(default="zh")
nonspeech_prob: float = 1.0 nonspeech_prob: float = 0.5
audio_min_len: float = 1.0 audio_min_len: float = 1.0
decoder_type: Literal["greedy","beam"] = "greedy" decoder_type: Literal["greedy","beam"] = "greedy"
beam_size: int = 5 beam_size: int = 5

View File

@@ -1,25 +0,0 @@
📄 SimulStreaming (https://github.com/ufal/SimulStreaming) Licence
SimulStreaming is dual-licensed:
🔹 Non-Commercial Use
You may use SimulStreaming under the **PolyForm Noncommercial License 1.0.0** if you
obtain the code through the GitHub repository. This license is **free of charge**
and comes with **no obligations** for non-commercial users.
🔸 Commercial Use
Understanding who uses SimulStreaming commercially helps us improve and
prioritize development. Therefore, we want to **require registration** of those who acquire a commercial licence.
We plan to make the commercial licenses **affordable** to SMEs and individuals. We
are considering providing commercial licenses either for free or for a symbolic
one-time fee, and maybe also providing additional support. You can share your preference via the [questionnaire](https://forms.cloud.microsoft/e/7tCxb4gJfB).
You can also leave your contact [there](https://forms.cloud.microsoft/e/7tCxb4gJfB) to be notified when the commercial licenses become
available.
✉️ Contact
[Dominik Macháček](https://ufal.mff.cuni.cz/dominik-machacek/), machacek@ufal.mff.cuni.cz

View File

@@ -25,6 +25,9 @@ class BeamTokens(Tokens):
def __repr__(self): def __repr__(self):
return self.__str__() return self.__str__()
def as_text(self, tokenizer):
    """Decode the stored token ids into a plain-text string using *tokenizer*."""
    return tokenizer.decode(self.tokens)
class Logits(Tokens): class Logits(Tokens):
def __init__(self, logits): def __init__(self, logits):
super().__init__(logits) super().__init__(logits)

View File

@@ -0,0 +1,5 @@
SIMULSTREAMING_LICENSE = f"""
SimulStreaming backend is dual-licensed:
• Non-Commercial Use: PolyForm Noncommercial License 1.0.0.
• Commercial Use: Check SimulStreaming README (github.com/ufal/SimulStreaming) for more details.
"""

View File

@@ -10,12 +10,12 @@ from .whisper import load_model, DecodingOptions, tokenizer
from .config import AlignAttConfig from .config import AlignAttConfig
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
from .whisper.timing import median_filter from .whisper.timing import median_filter
from .whisper.decoding import SuppressBlank, GreedyDecoder, BeamSearchDecoder, SuppressTokens from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
from .beam import BeamPyTorchInference from .beam import BeamPyTorchInference
from .eow_detection import fire_at_boundary, load_cif from .eow_detection import fire_at_boundary, load_cif
import os import os
from whisperlivekit.simul_whisper.token_buffer import TokenBuffer from .token_buffer import TokenBuffer
import numpy as np import numpy as np
from .generation_progress import * from .generation_progress import *
@@ -24,6 +24,7 @@ DEC_PAD = 50257
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
import sys import sys
import wave
# New features added to the original version of Simul-Whisper: # New features added to the original version of Simul-Whisper:
# - large-v3 model support # - large-v3 model support
@@ -32,29 +33,30 @@ import sys
# - prompt -- static vs. non-static # - prompt -- static vs. non-static
# - context # - context
class PaddedAlignAttWhisper: class PaddedAlignAttWhisper:
def __init__(self, cfg: AlignAttConfig) -> None: def __init__(self, cfg: AlignAttConfig, loaded_model=None) -> None:
self.log_segments = 0
model_name = os.path.basename(cfg.model_path).replace(".pt", "") model_name = os.path.basename(cfg.model_path).replace(".pt", "")
model_path = os.path.dirname(os.path.abspath(cfg.model_path)) model_path = os.path.dirname(os.path.abspath(cfg.model_path))
self.model = load_model(name=model_name, download_root=model_path) if loaded_model:
self.model = loaded_model
else:
self.model = load_model(name=model_name, download_root=model_path)
logger.info(f"Model dimensions: {self.model.dims}") logger.info(f"Model dimensions: {self.model.dims}")
decode_options = DecodingOptions( self.decode_options = DecodingOptions(
language = cfg.language, language = cfg.language,
without_timestamps = True, without_timestamps = True,
task=cfg.task task=cfg.task
) )
self.tokenizer = tokenizer.get_tokenizer( self.tokenizer_is_multilingual = not model_name.endswith(".en")
multilingual=not model_name.endswith(".en"), self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
language=cfg.language, self.detected_language = cfg.language if cfg.language != "auto" else None
num_languages=self.model.num_languages,
task=decode_options.task
)
self.max_text_len = self.model.dims.n_text_ctx self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks) self.num_decoder_layers = len(self.model.decoder.blocks)
self.cfg = cfg self.cfg = cfg
# model to detect end-of-word boundary at the end of the segment # model to detect end-of-word boundary at the end of the segment
self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg, self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
n_audio_state=self.model.dims.n_audio_state, n_audio_state=self.model.dims.n_audio_state,
@@ -95,14 +97,6 @@ class PaddedAlignAttWhisper:
self.num_align_heads += 1 self.num_align_heads += 1
# init tokens (mandatory prompt)
self.initial_tokens = torch.tensor(
self.tokenizer.sot_sequence_including_notimestamps,
dtype=torch.long,
device=self.model.device).unsqueeze(0)
self.initial_token_length = self.initial_tokens.shape[1]
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
# tokens to be suppressed from decoding, to prevent hallucinations # tokens to be suppressed from decoding, to prevent hallucinations
suppress_tokens = [ suppress_tokens = [
self.tokenizer.transcribe, self.tokenizer.transcribe,
@@ -121,6 +115,17 @@ class PaddedAlignAttWhisper:
self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None) self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None)
# blank tokens are suppresed for new segments near the line 334 # blank tokens are suppresed for new segments near the line 334
# it's going to be regenerated after lang id
self.segments = []
self.init_tokens()
self.last_attend_frame = -self.cfg.rewind_threshold
if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
else:
self.max_context_tokens = self.cfg.max_context_tokens
self.init_context()
# decoder type: greedy or beam # decoder type: greedy or beam
if cfg.decoder_type == "greedy": if cfg.decoder_type == "greedy":
@@ -135,16 +140,13 @@ class PaddedAlignAttWhisper:
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size) self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
# init state def create_tokenizer(self, language=None):
self.segments = [] self.tokenizer = tokenizer.get_tokenizer(
self.tokens = [self.initial_tokens] multilingual=self.tokenizer_is_multilingual,
self.last_attend_frame = -self.cfg.rewind_threshold language=language,
num_languages=self.model.num_languages,
if self.cfg.max_context_tokens is None: task=self.decode_options.task
self.max_context_tokens = self.max_text_len )
else:
self.max_context_tokens = self.cfg.max_context_tokens
self.init_context()
def init_context(self): def init_context(self):
kw = {'tokenizer': self.tokenizer, kw = {'tokenizer': self.tokenizer,
@@ -156,6 +158,19 @@ class PaddedAlignAttWhisper:
if self.cfg.init_prompt is not None: if self.cfg.init_prompt is not None:
self.context.text += self.cfg.init_prompt self.context.text += self.cfg.init_prompt
def init_tokens(self):
logger.debug(f"init tokens, {len(self.segments)}")
# init tokens (mandatory prompt)
self.initial_tokens = torch.tensor(
self.tokenizer.sot_sequence_including_notimestamps,
dtype=torch.long,
device=self.model.device).unsqueeze(0)
self.initial_token_length = self.initial_tokens.shape[1]
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
# self.segments = []
logger.debug(f"init tokens after, {len(self.segments)}")
self.tokens = [self.initial_tokens]
def trim_context(self): def trim_context(self):
logger.info("Trimming context") logger.info("Trimming context")
c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids) c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids)
@@ -191,15 +206,19 @@ class PaddedAlignAttWhisper:
def refresh_segment(self, complete=False): def refresh_segment(self, complete=False):
logger.debug("Refreshing segment") logger.debug("Refreshing segment:")
self.tokens = [self.initial_tokens] self.init_tokens()
self.last_attend_frame = -self.cfg.rewind_threshold self.last_attend_frame = -self.cfg.rewind_threshold
self.detected_language = None
self.init_context() self.init_context()
logger.debug(f"Context: {self.context}") logger.debug(f"Context: {self.context}")
if not complete and len(self.segments) > 2: if not complete and len(self.segments) > 2:
logger.debug("keeping last two segments because they are and it is not complete.")
self.segments = self.segments[-2:] self.segments = self.segments[-2:]
else: else:
logger.debug("removing all segments.")
self.segments = [] self.segments = []
self.log_segments += 1
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor): def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
@@ -208,8 +227,6 @@ class PaddedAlignAttWhisper:
return fire_at_boundary(chunked_encoder_feature, self.CIFLinear) return fire_at_boundary(chunked_encoder_feature, self.CIFLinear)
def _current_tokens(self): def _current_tokens(self):
toks = self.tokens toks = self.tokens
@@ -256,16 +273,59 @@ class PaddedAlignAttWhisper:
removed_len = 0 removed_len = 0
# len of audio is bigger than buffer_len. Going to remove the first segment # len of audio is bigger than buffer_len. Going to remove the first segment
segments_len = self.segments_len() segments_len = self.segments_len()
while segments_len > self.cfg.audio_max_len: while len(self.segments) > 1 and segments_len > self.cfg.audio_max_len:
removed_len = self.segments[0].shape[0] / 16000 removed_len = self.segments[0].shape[0] / 16000
segments_len -= removed_len segments_len -= removed_len
self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len) self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
self.segments = self.segments[1:] self.segments = self.segments[1:]
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}") logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}")
self.context.append_token_ids(self.tokens[1][0,:]) if len(self.tokens) > 1:
self.tokens = [self.initial_tokens] + self.tokens[2:] self.context.append_token_ids(self.tokens[1][0,:])
self.tokens = [self.initial_tokens] + self.tokens[2:]
return removed_len return removed_len
def _clean_cache(self):
    '''clean the cache that stores the attention matrices and kv_cache.
    It must be called every time after generation with the model.'''
    # cleaning cache
    self.dec_attns = []
    self.kv_cache = {}
    if self.decoder_type == "beam":
        # Beam search keeps extra state: repoint the inference wrapper at the
        # fresh kv_cache and reset the stateful beam decoder.
        self.inference.kv_cache = self.kv_cache
        self.token_decoder.reset()
@torch.no_grad()
def lang_id(self, encoder_features):
    """Language detection from encoder features.
    This code is trimmed and copy-pasted from whisper.decoding.detect_language .

    Returns (language_tokens, language_probs); both are unbatched when the
    input features are a single 2-D tensor.
    """
    # forward pass using a single token, startoftranscript
    n_audio = encoder_features.shape[0]
    x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device)  # [n_audio, 1]
    logits = self.model.logits(x, encoder_features)[:, 0]

    # collect detected languages; suppress all non-language tokens
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    mask[list(self.tokenizer.all_language_tokens)] = False
    logits[:, mask] = -np.inf
    language_tokens = logits.argmax(dim=-1)
    language_token_probs = logits.softmax(dim=-1).cpu()
    language_probs = [
        {
            c: language_token_probs[i, j].item()
            for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
        }
        for i in range(n_audio)
    ]

    single = encoder_features.ndim == 2
    if single:
        language_tokens = language_tokens[0]
        language_probs = language_probs[0]
    # The forward pass above populated kv/attention caches; drop them so the
    # next generation starts from a clean state.
    self._clean_cache()
    return language_tokens, language_probs
### transcription / translation ### transcription / translation
@@ -273,9 +333,12 @@ class PaddedAlignAttWhisper:
def infer(self, is_last=False): def infer(self, is_last=False):
new_segment = True new_segment = True
if len(self.segments) == 0: if len(self.segments) == 0:
return [] logger.debug("No segments, nothing to do")
return [], {}
if not self._apply_minseglen(): if not self._apply_minseglen():
return [] logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
input_segments = torch.cat(self.segments, dim=0)
return [], {}
# input_segments is concatenation of audio, it's one array # input_segments is concatenation of audio, it's one array
if len(self.segments) > 1: if len(self.segments) > 1:
@@ -283,8 +346,7 @@ class PaddedAlignAttWhisper:
else: else:
input_segments = self.segments[0] input_segments = self.segments[0]
self.trim_context()
current_tokens = self._current_tokens()
# mel + padding to 30s # mel + padding to 30s
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES, mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
@@ -295,18 +357,38 @@ class PaddedAlignAttWhisper:
# the len of actual audio # the len of actual audio
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2) content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
# encode
encoder_feature = self.model.encoder(mel) encoder_feature = self.model.encoder(mel)
sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
completed = False
# logger.debug(f"Encoder feature shape: {encoder_feature.shape}")
# if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
# logger.debug("mel ")
if self.cfg.language == "auto" and self.detected_language is None:
language_tokens, language_probs = self.lang_id(encoder_feature)
logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
logger.info(f"Detected language: {top_lan} with p={p:.4f}")
#self.tokenizer.language = top_lan
#self.tokenizer.__post_init__()
self.create_tokenizer(top_lan)
self.detected_language = top_lan
self.init_tokens()
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
self.trim_context()
current_tokens = self._current_tokens()
#
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :]) fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
####################### Decoding loop ####################### Decoding loop
logger.info("Decoding loop starts\n") logger.info("Decoding loop starts\n")
sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
completed = False
attn_of_alignment_heads = None attn_of_alignment_heads = None
miost_attended_frame = None most_attended_frame = None
token_len_before_decoding = current_tokens.shape[1] token_len_before_decoding = current_tokens.shape[1]
@@ -515,11 +597,6 @@ class PaddedAlignAttWhisper:
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}") logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
# cleaning cache self._clean_cache()
self.dec_attns = []
self.kv_cache = {}
if self.decoder_type == "beam":
self.inference.kv_cache = self.kv_cache
self.token_decoder.reset()
return new_hypothesis, generation return new_hypothesis, generation

View File

@@ -32,7 +32,9 @@ def detect_language(
list of dictionaries containing the probability distribution over all languages. list of dictionaries containing the probability distribution over all languages.
""" """
if tokenizer is None: if tokenizer is None:
tokenizer = get_tokenizer(model.is_multilingual) tokenizer = get_tokenizer(
model.is_multilingual, num_languages=model.num_languages
)
if ( if (
tokenizer.language is None tokenizer.language is None
or tokenizer.language_token not in tokenizer.sot_sequence or tokenizer.language_token not in tokenizer.sot_sequence
@@ -111,9 +113,6 @@ class DecodingOptions:
# implementation details # implementation details
fp16: bool = True # use fp16 for most of the calculation fp16: bool = True # use fp16 for most of the calculation
# streaming
add_sot: Optional[bool] = True
@dataclass(frozen=True) @dataclass(frozen=True)
class DecodingResult: class DecodingResult:
@@ -513,19 +512,17 @@ class DecodingTask:
logit_filters: List[LogitFilter] logit_filters: List[LogitFilter]
def __init__(self, model: "Whisper", options: DecodingOptions): def __init__(self, model: "Whisper", options: DecodingOptions):
self.options: DecodingOptions = self._verify_options(options) self.model = model
if self.options.fp16:
self.model = model.half()
else:
self.model = model
language = options.language or "en" language = options.language or "en"
tokenizer = get_tokenizer( tokenizer = get_tokenizer(
model.is_multilingual, language=language, task=options.task model.is_multilingual,
num_languages=model.num_languages,
language=language,
task=options.task,
) )
self.tokenizer: Tokenizer = tokenizer self.tokenizer: Tokenizer = tokenizer
self.options: DecodingOptions = self._verify_options(options)
# print(self.options)
self.n_group: int = options.beam_size or options.best_of or 1 self.n_group: int = options.beam_size or options.best_of or 1
self.n_ctx: int = model.dims.n_text_ctx self.n_ctx: int = model.dims.n_text_ctx
@@ -589,7 +586,7 @@ class DecodingTask:
def _get_initial_tokens(self) -> Tuple[int]: def _get_initial_tokens(self) -> Tuple[int]:
tokens = list(self.sot_sequence) tokens = list(self.sot_sequence)
# print("prefix", prefix)
if prefix := self.options.prefix: if prefix := self.options.prefix:
prefix_tokens = ( prefix_tokens = (
self.tokenizer.encode(" " + prefix.strip()) self.tokenizer.encode(" " + prefix.strip())
@@ -607,15 +604,12 @@ class DecodingTask:
if isinstance(prompt, str) if isinstance(prompt, str)
else prompt else prompt
) )
# if self.options.add_sot:
tokens = ( tokens = (
[self.tokenizer.sot_prev] [self.tokenizer.sot_prev]
+ prompt_tokens[-(self.n_ctx // 2 - 1) :] + prompt_tokens[-(self.n_ctx // 2 - 1) :]
+ tokens + tokens
) )
#else:
# tokens = ([self.tokenizer.sot_prev] + tokens + prompt_tokens[-(self.n_ctx // 2 - 1) :])
# print("return", tokens)
return tuple(tokens) return tuple(tokens)
def _get_suppress_tokens(self) -> Tuple[int]: def _get_suppress_tokens(self) -> Tuple[int]:
@@ -663,7 +657,7 @@ class DecodingTask:
if audio_features.dtype != ( if audio_features.dtype != (
torch.float16 if self.options.fp16 else torch.float32 torch.float16 if self.options.fp16 else torch.float32
): ):
raise TypeError( return TypeError(
f"audio_features has an incorrect dtype: {audio_features.dtype}" f"audio_features has an incorrect dtype: {audio_features.dtype}"
) )
@@ -689,10 +683,9 @@ class DecodingTask:
no_speech_probs = [np.nan] * n_batch no_speech_probs = [np.nan] * n_batch
try: try:
for i in range(self.sample_len): # 最多循环448次 for i in range(self.sample_len):
# print("in decode main loop", i , tokens[0].tolist())
logits = self.inference.logits(tokens, audio_features) logits = self.inference.logits(tokens, audio_features)
# print(logits)
if ( if (
i == 0 and self.tokenizer.no_speech is not None i == 0 and self.tokenizer.no_speech is not None
): # save no_speech_probs ): # save no_speech_probs
@@ -724,7 +717,7 @@ class DecodingTask:
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1) tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
# print("initial_tokens", self.initial_tokens)
# detect language if requested, overwriting the language token # detect language if requested, overwriting the language token
languages, language_probs = self._detect_language(audio_features, tokens) languages, language_probs = self._detect_language(audio_features, tokens)
if self.options.task == "lang_id": if self.options.task == "lang_id":

View File

@@ -13,7 +13,6 @@ from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function from .transcribe import transcribe as transcribe_function
try: try:
from torch.nn.functional import scaled_dot_product_attention from torch.nn.functional import scaled_dot_product_attention
@@ -37,26 +36,27 @@ class ModelDimensions:
n_text_layer: int n_text_layer: int
# class LayerNorm(nn.LayerNorm): class LayerNorm(nn.LayerNorm):
# def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
# return super().forward(x.float()).type(x.dtype) return super().forward(x.float()).type(x.dtype)
# class Linear(nn.Linear):
# def forward(self, x: Tensor) -> Tensor:
# return F.linear(
# x,
# self.weight.to(x.dtype),
# None if self.bias is None else self.bias.to(x.dtype),
# )
# class Conv1d(nn.Conv1d): class Linear(nn.Linear):
# def _conv_forward( def forward(self, x: Tensor) -> Tensor:
# self, x: Tensor, weight: Tensor, bias: Optional[Tensor] return F.linear(
# ) -> Tensor: x,
# return super()._conv_forward( self.weight.to(x.dtype),
# x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) None if self.bias is None else self.bias.to(x.dtype),
# ) )
class Conv1d(nn.Conv1d):
def _conv_forward(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
) -> Tensor:
return super()._conv_forward(
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
)
def sinusoids(length, channels, max_timescale=10000): def sinusoids(length, channels, max_timescale=10000):
@@ -67,21 +67,30 @@ def sinusoids(length, channels, max_timescale=10000):
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
import sys ## this is mine, for debugging
@contextmanager
def disable_sdpa():
prev_state = MultiHeadAttention.use_sdpa
try:
MultiHeadAttention.use_sdpa = False
yield
finally:
MultiHeadAttention.use_sdpa = prev_state
class MultiHeadAttention(nn.Module): class MultiHeadAttention(nn.Module):
use_sdpa = False # Disable SDPA to ensure qk is always computed for hooks
use_sdpa = False # disabling: https://github.com/linto-ai/whisper-timestamped/issues/212 def __init__(self, n_state: int, n_head: int, cache_id: str = ""):
def __init__(self, n_state: int, n_head: int, cache_id: str):
super().__init__() super().__init__()
self.n_head = n_head self.n_head = n_head
self.query = nn.Linear(n_state, n_state) self.query = Linear(n_state, n_state)
self.key = nn.Linear(n_state, n_state, bias=False) self.key = Linear(n_state, n_state, bias=False)
self.key.cache_id = f"{cache_id}_key" self.value = Linear(n_state, n_state)
self.value = nn.Linear(n_state, n_state) self.out = Linear(n_state, n_state)
self.value.cache_id = f"{cache_id}_value"
self.out = nn.Linear(n_state, n_state)
self.cache_id = cache_id self.cache_id = cache_id
self.key.cache_id = f"{cache_id}_key"
self.value.cache_id = f"{cache_id}_value"
def forward( def forward(
self, self,
@@ -90,45 +99,21 @@ class MultiHeadAttention(nn.Module):
mask: Optional[Tensor] = None, mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None, kv_cache: Optional[dict] = None,
): ):
#print("MultiHeadAttention forward",file=sys.stderr)
q = self.query(x) q = self.query(x)
# print(q.shape, x is None, mask is None, list(kv_cache.keys()) if kv_cache is not None else None, file=sys.stderr)
# print(mask, kv_cache, xa, file=sys.stderr)
if kv_cache is None or xa is None or self.key.cache_id not in kv_cache: if kv_cache is None or xa is None or self.key not in kv_cache:
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
# otherwise, perform key/value projections for self- or cross-attention as usual.
k = self.key(x if xa is None else xa) k = self.key(x if xa is None else xa)
v = self.value(x if xa is None else xa) v = self.value(x if xa is None else xa)
# print(self.key.cache_id, "cache miss") # , kv_cache is None, xa is None, self.key.cache_id not in kv_cache if kv_cache is not None else None, k.shape, x.shape)
# if kv_cache is not None:
# print(kv_cache.keys())
else: else:
# print(self.key.cache_id, "cache hit") #, kv_cache is None, xa is None, self.key.cache_id not in kv_cache) # for cross-attention, calculate keys and values once and reuse in subsequent calls.
# if kv_cache is not None: k = kv_cache[self.key]
# print(kv_cache.keys()) v = kv_cache[self.value]
k = kv_cache[self.key.cache_id]
v = kv_cache[self.value.cache_id]
# print(self.key.cache_id, "qkv attention", q.shape, k.shape, v.shape)
wv, qk = self.qkv_attention(q, k, v, mask) wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), qk return self.out(wv), qk
# def qkv_attention(
# self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
# ):
# n_batch, n_ctx, n_state = q.shape
# scale = (n_state // self.n_head) ** -0.25
# q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
# k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
# v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
# qk = q @ k
# if mask is not None:
# qk = qk + mask[:n_ctx, :n_ctx]
# # qk = qk.float()
# w = F.softmax(qk, dim=-1) # .to(q.dtype)
# return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
def qkv_attention( def qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -158,21 +143,22 @@ class MultiHeadAttention(nn.Module):
class ResidualAttentionBlock(nn.Module): class ResidualAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int, cache_id: str="", cross_attention: bool = False): def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, cache_id: str = ""):
super().__init__() super().__init__()
self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn") self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn")
self.attn_ln = nn.LayerNorm(n_state) self.attn_ln = LayerNorm(n_state)
self.cross_attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None self.cross_attn = (
MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None )
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4 n_mlp = n_state * 4
self.mlp = nn.Sequential( self.mlp = nn.Sequential(
nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state) Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
) )
self.mlp_ln = nn.LayerNorm(n_state) self.mlp_ln = LayerNorm(n_state)
def forward( def forward(
self, self,
@@ -181,8 +167,6 @@ class ResidualAttentionBlock(nn.Module):
mask: Optional[Tensor] = None, mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None, kv_cache: Optional[dict] = None,
): ):
# print("ResidualAttentionBlock forward",file=sys.stderr)
# print(x.shape, file=sys.stderr)
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
if self.cross_attn: if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
@@ -195,44 +179,32 @@ class AudioEncoder(nn.Module):
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
): ):
super().__init__() super().__init__()
self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1) self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head, cache_id=f"enc_layer{i}") for i in range(n_layer)] [ResidualAttentionBlock(n_state, n_head, cache_id=f"enc_layer{i}") for i in range(n_layer)]
) )
self.ln_post = nn.LayerNorm(n_state) self.ln_post = LayerNorm(n_state)
def forward(self, x: Tensor, return_layer_results: bool=False): def forward(self, x: Tensor):
""" """
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio the mel spectrogram of the audio
""" """
x = F.gelu(self.conv1(x)) x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x)) x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1) # BDT -> BTD x = x.permute(0, 2, 1)
# 两层卷积2倍降采样 assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
# 最终剩下1500帧 x = (x + self.positional_embedding).to(x.dtype)
x = (x + self.positional_embedding[:x.shape[1], :]) #.to(x.dtype)
layer_results = []
i = 0
for block in self.blocks: for block in self.blocks:
# print(f"encoder layer {i}")
x = block(x) x = block(x)
layer_results.append(x)
i += 1
x = self.ln_post(x) x = self.ln_post(x)
return x
if return_layer_results:
return x, layer_results
else:
return x
class TextDecoder(nn.Module): class TextDecoder(nn.Module):
@@ -250,7 +222,7 @@ class TextDecoder(nn.Module):
for i in range(n_layer) for i in range(n_layer)
] ]
) )
self.ln = nn.LayerNorm(n_state) self.ln = LayerNorm(n_state)
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False) self.register_buffer("mask", mask, persistent=False)
@@ -262,22 +234,20 @@ class TextDecoder(nn.Module):
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
the encoded audio features to be attended on the encoded audio features to be attended on
""" """
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = ( x = (
self.token_embedding(x) self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]] + self.positional_embedding[offset : offset + x.shape[-1]]
) )
# x = x.to(xa.dtype) x = x.to(xa.dtype)
i = 0
for block in self.blocks: for block in self.blocks:
# print(f"decoder layer {i}")
x = block(x, xa, mask=self.mask, kv_cache=kv_cache) x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
i += 1
x = self.ln(x) x = self.ln(x)
logits = x @ torch.transpose(self.token_embedding.weight, 0, 1) logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()
return logits return logits
@@ -300,7 +270,8 @@ class Whisper(nn.Module):
self.dims.n_text_head, self.dims.n_text_head,
self.dims.n_text_layer, self.dims.n_text_layer,
) )
# use the last half layers for alignment by default; see `set_alignment_heads()` below # use the last half among the decoder layers for time alignment by default;
# to use a specific set of heads, see `set_alignment_heads()` below.
all_heads = torch.zeros( all_heads = torch.zeros(
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
) )
@@ -320,15 +291,11 @@ class Whisper(nn.Module):
return self.encoder(mel) return self.encoder(mel)
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
# tokens = tokens.to(self.decoder.ln.weight.dtype)
# audio_features = audio_features.to(self.decoder.ln.weight.dtype)
return self.decoder(tokens, audio_features) return self.decoder(tokens, audio_features)
def forward( def forward(
self, mel: torch.Tensor, tokens: torch.Tensor self, mel: torch.Tensor, tokens: torch.Tensor
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
# mel = mel.to(self.decoder.ln.weight.dtype)
# tokens = tokens.to(self.decoder.ln.weight.dtype)
return self.decoder(tokens, self.encoder(mel)) return self.decoder(tokens, self.encoder(mel))
@property @property
@@ -343,7 +310,6 @@ class Whisper(nn.Module):
def num_languages(self): def num_languages(self):
return self.dims.n_vocab - 51765 - int(self.is_multilingual) return self.dims.n_vocab - 51765 - int(self.is_multilingual)
# 为decoder加入缓存机制每次推理时保存上次的k和v下次推理无需重新计算
def install_kv_cache_hooks(self, cache: Optional[dict] = None): def install_kv_cache_hooks(self, cache: Optional[dict] = None):
""" """
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value

View File

@@ -30,15 +30,19 @@ def remove_symbols_and_diacritics(s: str, keep=""):
and drop any diacritics (category 'Mn' and some manual mappings) and drop any diacritics (category 'Mn' and some manual mappings)
""" """
return "".join( return "".join(
c (
if c in keep c
else ADDITIONAL_DIACRITICS[c] if c in keep
if c in ADDITIONAL_DIACRITICS else (
else "" ADDITIONAL_DIACRITICS[c]
if unicodedata.category(c) == "Mn" if c in ADDITIONAL_DIACRITICS
else " " else (
if unicodedata.category(c)[0] in "MSP" ""
else c if unicodedata.category(c) == "Mn"
else " " if unicodedata.category(c)[0] in "MSP" else c
)
)
)
for c in unicodedata.normalize("NFKD", s) for c in unicodedata.normalize("NFKD", s)
) )

File diff suppressed because it is too large Load Diff

View File

@@ -56,9 +56,8 @@ def median_filter(x: torch.Tensor, filter_width: int):
@numba.jit(nopython=True) @numba.jit(nopython=True)
def backtrace(trace: np.ndarray): def backtrace(trace: np.ndarray):
i = trace.shape[0] - 1 # trace: (N+1, M+1), i=N i = trace.shape[0] - 1
j = trace.shape[1] - 1 # j=M j = trace.shape[1] - 1
# 边界点其实无意义?
trace[0, :] = 2 trace[0, :] = 2
trace[:, 0] = 1 trace[:, 0] = 1
@@ -83,8 +82,8 @@ def backtrace(trace: np.ndarray):
@numba.jit(nopython=True, parallel=True) @numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray): def dtw_cpu(x: np.ndarray):
N, M = x.shape N, M = x.shape
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf # cost: x[0, 0]到x[i-1, j-1]的最小代价 cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
trace = -np.ones((N + 1, M + 1), dtype=np.float32) # trace: trace = -np.ones((N + 1, M + 1), dtype=np.float32)
cost[0, 0] = 0 cost[0, 0] = 0
for j in range(1, M + 1): for j in range(1, M + 1):
@@ -118,7 +117,7 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
x_skew = x_skew.T.contiguous() x_skew = x_skew.T.contiguous()
cost = torch.ones(N + M + 2, M + 2) * np.inf cost = torch.ones(N + M + 2, M + 2) * np.inf
cost[0, 0] = 0 cost[0, 0] = 0
cost = cost.cuda() cost = cost.to(x.device)
trace = torch.zeros_like(cost, dtype=torch.int32) trace = torch.zeros_like(cost, dtype=torch.int32)
dtw_kernel[(1,)]( dtw_kernel[(1,)](
@@ -192,21 +191,19 @@ def find_alignment(
for i, block in enumerate(model.decoder.blocks) for i, block in enumerate(model.decoder.blocks)
] ]
# 进行前传获得token概率 from .model import disable_sdpa
with torch.no_grad():
with torch.no_grad(), disable_sdpa():
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0] logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot] sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
token_probs = sampled_logits.softmax(dim=-1) token_probs = sampled_logits.softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens] text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
text_token_probs = text_token_probs.tolist() text_token_probs = text_token_probs.tolist()
# 移除钩子
for hook in hooks: for hook in hooks:
hook.remove() hook.remove()
# heads * tokens * frames # heads * tokens * frames
# print(model.alignment_heads)
# exit(0)
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T]) weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
weights = weights[:, :, : num_frames // 2] weights = weights[:, :, : num_frames // 2]
weights = (weights * qk_scale).softmax(dim=-1) weights = (weights * qk_scale).softmax(dim=-1)
@@ -215,18 +212,9 @@ def find_alignment(
weights = median_filter(weights, medfilt_width) weights = median_filter(weights, medfilt_width)
matrix = weights.mean(axis=0) matrix = weights.mean(axis=0)
print("attention", matrix.shape, matrix[:5, :5])
matrix = matrix[len(tokenizer.sot_sequence) : -1] matrix = matrix[len(tokenizer.sot_sequence) : -1]
print("attention", matrix.shape, matrix[:5, :5])
text_indices, time_indices = dtw(-matrix) text_indices, time_indices = dtw(-matrix)
print("num_frames", num_frames)
print("attention", matrix.shape, matrix[:5, :5])
print("text_indices", text_indices)
print("time", time_indices)
print("text_tokens", text_tokens, tokenizer.decode(text_tokens), len(text_tokens))
print("eot", tokenizer.eot)
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot]) words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
if len(word_tokens) <= 1: if len(word_tokens) <= 1:
# return on eot only # return on eot only
@@ -238,9 +226,7 @@ def find_alignment(
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
# print("jumps", jumps, jumps.shape)
jump_times = time_indices[jumps] / TOKENS_PER_SECOND jump_times = time_indices[jumps] / TOKENS_PER_SECOND
# print("jump_times", jump_times)
start_times = jump_times[word_boundaries[:-1]] start_times = jump_times[word_boundaries[:-1]]
end_times = jump_times[word_boundaries[1:]] end_times = jump_times[word_boundaries[1:]]
word_probabilities = [ word_probabilities = [
@@ -315,6 +301,7 @@ def add_word_timestamps(
word_durations = np.array([t.end - t.start for t in alignment]) word_durations = np.array([t.end - t.start for t in alignment])
word_durations = word_durations[word_durations.nonzero()] word_durations = word_durations[word_durations.nonzero()]
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0 median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
median_duration = min(0.7, float(median_duration))
max_duration = median_duration * 2 max_duration = median_duration * 2
# hack: truncate long words at sentence boundaries. # hack: truncate long words at sentence boundaries.

View File

@@ -1,501 +0,0 @@
import argparse
import os
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
import numpy as np
import torch
import tqdm
from whisper.audio import (
FRAMES_PER_SECOND,
HOP_LENGTH,
N_FRAMES,
N_SAMPLES,
SAMPLE_RATE,
log_mel_spectrogram,
pad_or_trim,
)
from whisper.decoding import DecodingOptions, DecodingResult
from whisper.timing import add_word_timestamps
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from whisper.utils import (
exact_div,
format_timestamp,
get_writer,
make_safe,
optional_float,
optional_int,
str2bool,
)
if TYPE_CHECKING:
from whisper.model import Whisper
def transcribe(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor],
*,
verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
**decode_options,
):
"""
Transcribe an audio file using Whisper
Parameters
----------
model: Whisper
The Whisper model instance
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
verbose: bool
Whether to display the text being decoded to the console. If True, displays all the details,
If False, displays minimal details. If None, does not display anything
temperature: Union[float, Tuple[float, ...]]
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
compression_ratio_threshold: float
If the gzip compression ratio is above this value, treat as failed
logprob_threshold: float
If the average log probability over sampled tokens is below this value, treat as failed
no_speech_threshold: float
If the no_speech probability is higher than this value AND the average log probability
over sampled tokens is below `logprob_threshold`, consider the segment as silent
condition_on_previous_text: bool
if True, the previous output of the model is provided as a prompt for the next window;
disabling may make the text inconsistent across windows, but the model becomes less prone to
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
word_timestamps: bool
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
and include the timestamps for each word in each segment.
prepend_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the next word
append_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the previous word
initial_prompt: Optional[str]
Optional text to provide as a prompt for the first window. This can be used to provide, or
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
to make it more likely to predict those word correctly.
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
# print("HACKED")
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
if model.device == torch.device("cpu"):
if torch.cuda.is_available():
warnings.warn("Performing inference on CPU when CUDA is available")
if dtype == torch.float16:
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
dtype = torch.float32
if dtype == torch.float32:
decode_options["fp16"] = False
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, padding=0) # log_mel_spectrogram(audio, padding=N_SAMPLES) # 添加16000*30 = 480000个点
# mel = pad_or_trim(mel, 3000)
content_frames = mel.shape[-1] # - N_FRAMES # 对应3000帧真正有内容的是去掉尾部3000的那些数据
# 判断语种
if decode_options.get("language", None) is None:
# 如果是单语种模型,直接设成英文
if not model.is_multilingual:
decode_options["language"] = "en"
# 否则需要前传一次
else:
if verbose:
print(
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
)
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
# print(mel_segment.shape)
_, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
)
language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe")
# 输出编码器
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
# 词级别时间戳
if word_timestamps and task == "translate":
warnings.warn("Word-level timestamps on translations may not be reliable.")
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
temperatures = (
[temperature] if isinstance(temperature, (int, float)) else temperature
)
decode_result = None
for t in temperatures:
kwargs = {**decode_options}
if t > 0:
# disable beam_size and patience when t > 0
kwargs.pop("beam_size", None)
kwargs.pop("patience", None)
else:
# disable best_of when t == 0
kwargs.pop("best_of", None)
options = DecodingOptions(**kwargs, temperature=t)
decode_result = model.decode(segment, options)
# 几种解码可能失败的情况。这些情况下会重复解码
# 感觉是一种KnowHow的东西 或许ChatGPT里有不少这种trick
needs_fallback = False
if (
compression_ratio_threshold is not None
and decode_result.compression_ratio > compression_ratio_threshold
):
needs_fallback = True # too repetitive
if (
logprob_threshold is not None
and decode_result.avg_logprob < logprob_threshold
):
needs_fallback = True # average log probability is too low
if (
no_speech_threshold is not None
and decode_result.no_speech_prob > no_speech_threshold
):
needs_fallback = False # silence
if not needs_fallback:
break
# print("decode with temperature {} compress rate {:.3f}/{:.3f}, log_prob {:.3f}/{:.3f}, {:.3f}/{:.3f}".format(
# t,
# decode_result.compression_ratio, compression_ratio_threshold,
# -decode_result.avg_logprob, -logprob_threshold,
# decode_result.no_speech_prob, no_speech_threshold
# ))
return decode_result
seek = 0
input_stride = exact_div(
N_FRAMES, model.dims.n_audio_ctx
) # mel frames per output token: 2
# 这里output token指的应该是CNN输出的那个东西
time_precision = (
input_stride * HOP_LENGTH / SAMPLE_RATE
) # time per output token: 0.02 (seconds)
all_tokens = []
all_segments = []
prompt_reset_since = 0
if initial_prompt is not None:
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt_tokens)
else:
initial_prompt_tokens = []
def new_segment(
    *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
):
    """Package one decoded span as a segment dict.

    Records the span's timing, its decoded text (special tokens stripped),
    the raw token ids, and the quality metrics of the decoding result that
    produced it. `seek` is taken from the enclosing scope.
    """
    token_ids = tokens.tolist()
    # Anything at or above the end-of-text id is a special token; keep only
    # real text tokens for decoding into a string.
    text_token_ids = [tid for tid in token_ids if tid < tokenizer.eot]
    return {
        "seek": seek,
        "start": start,
        "end": end,
        "text": tokenizer.decode(text_token_ids),
        "tokens": token_ids,
        "temperature": result.temperature,
        "avg_logprob": result.avg_logprob,
        "compression_ratio": result.compression_ratio,
        "no_speech_prob": result.no_speech_prob,
    }
# show the progress bar when verbose is False (if True, transcribed text will be printed)
with tqdm.tqdm(
total=content_frames, unit="frames", disable=verbose is not False
) as pbar:
last_speech_timestamp = 0.0
while seek < content_frames: # seek标记mel频谱当前帧的位置 直接跳过Padding上的部分
# print("seek segments", seek, content_frames)
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) # 本片段的开始时间
# mel_segment = mel[:, seek : seek + N_FRAMES] # 获得当前片段的数据
mel_segment = mel[:, seek:]
segment_size = min(N_FRAMES, content_frames - seek) # segment_size: 排除padding的真的长度。content_frames有内容的段的真正长度 如果不够N_FRAMES的话就会截断
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE # 当前片段的时长
mel_segment = mel_segment.to(model.device).to(dtype) # pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) # 补到mel_segment帧
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(mel_segment)
tokens = torch.tensor(result.tokens)
# 跳过静音部分
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > no_speech_threshold
if (
logprob_threshold is not None
and result.avg_logprob > logprob_threshold
):
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
if should_skip:
seek += segment_size # fast-forward to the next segment boundary
continue
previous_seek = seek
current_segments = []
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) # timestamp begin是<|0.00|>的tokenbos比文字token大eos的值比bos还大所以是ge
timestamp_tokens[-1] = False
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] # 如果最后是[False,True]:本段里一个句子结束了
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
# torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).
# timestamp_token就是个一维向量吧 那为啥不直接nonzero
# 如果有两个连续的时间戳 这个会是一个一维tensor 是这两个连续时间戳的结尾位置
# 多个的话指向第二个 那如果有三个怎么办?
# 否则是个0维tensor
consecutive.add_(1) # 0维tensor+1还是0维 哪儿找的这些edge cases js是吧
if len(consecutive) > 0:
# if the output contains two consecutive timestamp tokens
slices = consecutive.tolist()
if single_timestamp_ending:
slices.append(len(tokens)) # 把最后一段的结尾也加进去
# print("many sentenses", consecutive)
last_slice = 0
for current_slice in slices:
sliced_tokens = tokens[last_slice:current_slice]
# 看起来语音开始帧、语音结束帧的位置会被编码到start_timestamp中
start_timestamp_pos = (
sliced_tokens[0].item() - tokenizer.timestamp_begin
)
end_timestamp_pos = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin
)
# 获取一个新的语音段
current_segments.append(
new_segment(
start=time_offset + start_timestamp_pos * time_precision,
end=time_offset + end_timestamp_pos * time_precision,
tokens=sliced_tokens,
result=result,
)
)
last_slice = current_slice
if single_timestamp_ending:
# single timestamp at the end means no speech after the last timestamp.
seek += segment_size
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
# 如果语音尚未结束那么seek变为上一个结束的语段的位置
# 换句话说就是针对30s长的chunk的语音设计的
last_timestamp_pos = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_pos * input_stride
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
# print(timestamps)
if (
len(timestamps) > 0
and timestamps[-1].item() != tokenizer.timestamp_begin
):
# no consecutive timestamps but it has a timestamp; use the last one.
# 取最后一个假设要么有一个结束的time stamp要么有一对儿
# 如果里面只有一个开始的timestamp 似乎后面的东西都会被丢掉?
last_timestamp_pos = (
timestamps[-1].item() - tokenizer.timestamp_begin
)
duration = last_timestamp_pos * time_precision
current_segments.append(
new_segment(
start=time_offset,
end=time_offset + duration,
tokens=tokens,
result=result,
)
)
seek += segment_size
# 每个token有自己的时间戳
if word_timestamps:
add_word_timestamps(
segments=current_segments,
model=model,
tokenizer=tokenizer,
mel=mel_segment,
num_frames=segment_size,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
last_speech_timestamp=last_speech_timestamp,
)
word_end_timestamps = [
w["end"] for s in current_segments for w in s["words"]
]
if len(word_end_timestamps) > 0:
last_speech_timestamp = word_end_timestamps[-1]
if not single_timestamp_ending and len(word_end_timestamps) > 0:
seek_shift = round(
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
)
if seek_shift > 0:
seek = previous_seek + seek_shift
if verbose:
for segment in current_segments:
start, end, text = segment["start"], segment["end"], segment["text"]
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
print(make_safe(line))
# if a segment is instantaneous or does not contain text, clear it
for i, segment in enumerate(current_segments):
if segment["start"] == segment["end"] or segment["text"].strip() == "":
segment["text"] = ""
segment["tokens"] = []
segment["words"] = []
# 更新结果
all_segments.extend(
[
{"id": i, **segment}
for i, segment in enumerate(
current_segments, start=len(all_segments)
)
]
)
all_tokens.extend(
[token for segment in current_segments for token in segment["tokens"]]
)
if not condition_on_previous_text or result.temperature > 0.5:
# do not feed the prompt tokens if a high temperature was used
prompt_reset_since = len(all_tokens)
# update progress bar
pbar.update(min(content_frames, seek) - previous_seek)
# print("太长了")
# break
return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
segments=all_segments,
language=language,
)
def cli():
    """Command-line entry point: transcribe one or more audio files.

    Parses arguments, loads the requested Whisper model, runs `transcribe`
    on each input file, and writes the results via the selected writer(s).
    """
    from . import available_models

    # fmt: off
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
    parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
    parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
    parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
    parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
    parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
    parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
    parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
    parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
    parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
    parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
    parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
    parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
    parser.add_argument("--append_punctuations", type=str, default="\"\'.。,!?::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
    parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
    parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
    parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
    parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
    # fmt: on

    args = parser.parse_args().__dict__
    model_name: str = args.pop("model")
    model_dir: str = args.pop("model_dir")
    output_dir: str = args.pop("output_dir")
    output_format: str = args.pop("output_format")
    device: str = args.pop("device")
    os.makedirs(output_dir, exist_ok=True)

    # English-only checkpoints (".en") cannot transcribe other languages;
    # warn if the user asked for one anyway, then force English.
    if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
        if args["language"] is not None:
            warnings.warn(
                f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
            )
        args["language"] = "en"

    # Expand a scalar temperature into the fallback schedule
    # (temperature, temperature + increment, ..., 1.0).
    temperature = args.pop("temperature")
    if (increment := args.pop("temperature_increment_on_fallback")) is not None:
        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
    else:
        temperature = [temperature]

    if (threads := args.pop("threads")) > 0:
        torch.set_num_threads(threads)

    from . import load_model

    model = load_model(model_name, device=device, download_root=model_dir)
    writer = get_writer(output_format, output_dir)

    # Options that only make sense together with word-level timestamps.
    word_options = ["highlight_words", "max_line_count", "max_line_width"]
    if not args["word_timestamps"]:
        for option in word_options:
            if args[option]:
                parser.error(f"--{option} requires --word_timestamps True")
    if args["max_line_count"] and not args["max_line_width"]:
        warnings.warn("--max_line_count has no effect without --max_line_width")
    writer_args = {arg: args.pop(arg) for arg in word_options}

    for audio_path in args.pop("audio"):
        result = transcribe(model, audio_path, temperature=temperature, **args)
        # NOTE(review): passes writer_args as a single positional dict; this
        # assumes the writer's __call__ takes an options dict (older API).
        # Confirm against the ResultWriter signature in utils.
        writer(result, audio_path, writer_args)


if __name__ == "__main__":
    cli()

View File

@@ -1,7 +1,8 @@
import argparse import argparse
import os import os
import traceback
import warnings import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@@ -22,6 +23,7 @@ from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import ( from .utils import (
exact_div, exact_div,
format_timestamp, format_timestamp,
get_end,
get_writer, get_writer,
make_safe, make_safe,
optional_float, optional_float,
@@ -44,9 +46,12 @@ def transcribe(
no_speech_threshold: Optional[float] = 0.6, no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True, condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None, initial_prompt: Optional[str] = None,
carry_initial_prompt: bool = False,
word_timestamps: bool = False, word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-", prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、", append_punctuations: str = "\"'.。,!?::”)]}、",
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
**decode_options, **decode_options,
): ):
""" """
@@ -98,15 +103,27 @@ def transcribe(
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
to make it more likely to predict those word correctly. to make it more likely to predict those word correctly.
carry_initial_prompt: bool
If carry_initial_prompt is True, `initial_prompt` is prepended to the prompt of each internal
`decode()` call. If there is not enough context space at the start of the prompt, it is
left-sliced to make space.
decode_options: dict decode_options: dict
Keyword arguments to construct `DecodingOptions` instances Keyword arguments to construct `DecodingOptions` instances
clip_timestamps: Union[str, List[float]]
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
The last end timestamp defaults to the end of the file.
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
when a possible hallucination is detected
Returns Returns
------- -------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None. the spoken language ("language"), which is detected when `decode_options["language"]` is None.
""" """
# print("transcribe")
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32 dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
if model.device == torch.device("cpu"): if model.device == torch.device("cpu"):
if torch.cuda.is_available(): if torch.cuda.is_available():
@@ -119,8 +136,9 @@ def transcribe(
decode_options["fp16"] = False decode_options["fp16"] = False
# Pad 30-seconds of silence to the input audio, for slicing # Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, padding=N_SAMPLES) mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
content_frames = mel.shape[-1] - N_FRAMES content_frames = mel.shape[-1] - N_FRAMES
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
if decode_options.get("language", None) is None: if decode_options.get("language", None) is None:
if not model.is_multilingual: if not model.is_multilingual:
@@ -131,7 +149,6 @@ def transcribe(
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language" "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
) )
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
# print(mel_segment.shape)
_, probs = model.detect_language(mel_segment) _, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get) decode_options["language"] = max(probs, key=probs.get)
if verbose is not None: if verbose is not None:
@@ -141,7 +158,25 @@ def transcribe(
language: str = decode_options["language"] language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe") task: str = decode_options.get("task", "transcribe")
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task) tokenizer = get_tokenizer(
model.is_multilingual,
num_languages=model.num_languages,
language=language,
task=task,
)
if isinstance(clip_timestamps, str):
clip_timestamps = [
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
]
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
if len(seek_points) == 0:
seek_points.append(0)
if len(seek_points) % 2 == 1:
seek_points.append(content_frames)
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
punctuation = "\"'“¿([{-\"'.。,!?::”)]}、"
if word_timestamps and task == "translate": if word_timestamps and task == "translate":
warnings.warn("Word-level timestamps on translations may not be reliable.") warnings.warn("Word-level timestamps on translations may not be reliable.")
@@ -179,6 +214,8 @@ def transcribe(
if ( if (
no_speech_threshold is not None no_speech_threshold is not None
and decode_result.no_speech_prob > no_speech_threshold and decode_result.no_speech_prob > no_speech_threshold
and logprob_threshold is not None
and decode_result.avg_logprob < logprob_threshold
): ):
needs_fallback = False # silence needs_fallback = False # silence
if not needs_fallback: if not needs_fallback:
@@ -186,7 +223,8 @@ def transcribe(
return decode_result return decode_result
seek = 0 clip_idx = 0
seek = seek_clips[clip_idx][0]
input_stride = exact_div( input_stride = exact_div(
N_FRAMES, model.dims.n_audio_ctx N_FRAMES, model.dims.n_audio_ctx
) # mel frames per output token: 2 ) # mel frames per output token: 2
@@ -197,9 +235,11 @@ def transcribe(
all_segments = [] all_segments = []
prompt_reset_since = 0 prompt_reset_since = 0
remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
if initial_prompt is not None: if initial_prompt is not None:
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip()) initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt_tokens) all_tokens.extend(initial_prompt_tokens)
remaining_prompt_length -= len(initial_prompt_tokens)
else: else:
initial_prompt_tokens = [] initial_prompt_tokens = []
@@ -225,16 +265,33 @@ def transcribe(
total=content_frames, unit="frames", disable=verbose is not False total=content_frames, unit="frames", disable=verbose is not False
) as pbar: ) as pbar:
last_speech_timestamp = 0.0 last_speech_timestamp = 0.0
while seek < content_frames: # NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
# for seek_clip_start, seek_clip_end in seek_clips:
# while seek < seek_clip_end
while clip_idx < len(seek_clips):
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
if seek < seek_clip_start:
seek = seek_clip_start
if seek >= seek_clip_end:
clip_idx += 1
if clip_idx < len(seek_clips):
seek = seek_clips[clip_idx][0]
continue
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
mel_segment = mel[:, seek : seek + N_FRAMES] window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
segment_size = min(N_FRAMES, content_frames - seek) segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
mel_segment = mel[:, seek : seek + segment_size]
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
# print("melshape", mel_segment.shape) if carry_initial_prompt:
nignored = max(len(initial_prompt_tokens), prompt_reset_since)
remaining_prompt = all_tokens[nignored:][-remaining_prompt_length:]
decode_options["prompt"] = initial_prompt_tokens + remaining_prompt
else:
decode_options["prompt"] = all_tokens[prompt_reset_since:]
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(mel_segment) result: DecodingResult = decode_with_fallback(mel_segment)
tokens = torch.tensor(result.tokens) tokens = torch.tensor(result.tokens)
@@ -255,6 +312,30 @@ def transcribe(
previous_seek = seek previous_seek = seek
current_segments = [] current_segments = []
# anomalous words are very long/short/improbable
def word_anomaly_score(word: dict) -> float:
probability = word.get("probability", 0.0)
duration = word["end"] - word["start"]
score = 0.0
if probability < 0.15:
score += 1.0
if duration < 0.133:
score += (0.133 - duration) * 15
if duration > 2.0:
score += duration - 2.0
return score
def is_segment_anomaly(segment: Optional[dict]) -> bool:
if segment is None or not segment["words"]:
return False
words = [w for w in segment["words"] if w["word"] not in punctuation]
words = words[:8]
score = sum(word_anomaly_score(w) for w in words)
return score >= 3 or score + 0.01 >= len(words)
def next_words_segment(segments: List[dict]) -> Optional[dict]:
return next((s for s in segments if s["words"]), None)
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
@@ -317,9 +398,7 @@ def transcribe(
) )
seek += segment_size seek += segment_size
# print("word_timestamps, ", word_timestamps)
if word_timestamps: if word_timestamps:
# print("=========run timestamps here=========")
add_word_timestamps( add_word_timestamps(
segments=current_segments, segments=current_segments,
model=model, model=model,
@@ -330,17 +409,71 @@ def transcribe(
append_punctuations=append_punctuations, append_punctuations=append_punctuations,
last_speech_timestamp=last_speech_timestamp, last_speech_timestamp=last_speech_timestamp,
) )
word_end_timestamps = [
w["end"] for s in current_segments for w in s["words"] if not single_timestamp_ending:
] last_word_end = get_end(current_segments)
if len(word_end_timestamps) > 0: if last_word_end is not None and last_word_end > time_offset:
last_speech_timestamp = word_end_timestamps[-1] seek = round(last_word_end * FRAMES_PER_SECOND)
if not single_timestamp_ending and len(word_end_timestamps) > 0:
seek_shift = round( # skip silence before possible hallucinations
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND if hallucination_silence_threshold is not None:
) threshold = hallucination_silence_threshold
if seek_shift > 0: if not single_timestamp_ending:
seek = previous_seek + seek_shift last_word_end = get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
remaining_duration = window_end_time - last_word_end
if remaining_duration > threshold:
seek = round(last_word_end * FRAMES_PER_SECOND)
else:
seek = previous_seek + segment_size
# if first segment might be a hallucination, skip leading silence
first_segment = next_words_segment(current_segments)
if first_segment is not None and is_segment_anomaly(first_segment):
gap = first_segment["start"] - time_offset
if gap > threshold:
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
continue
# skip silence before any possible hallucination that is surrounded
# by silence or more hallucinations
hal_last_end = last_speech_timestamp
for si in range(len(current_segments)):
segment = current_segments[si]
if not segment["words"]:
continue
if is_segment_anomaly(segment):
next_segment = next_words_segment(
current_segments[si + 1 :]
)
if next_segment is not None:
hal_next_start = next_segment["words"][0]["start"]
else:
hal_next_start = time_offset + segment_duration
silence_before = (
segment["start"] - hal_last_end > threshold
or segment["start"] < threshold
or segment["start"] - time_offset < 2.0
)
silence_after = (
hal_next_start - segment["end"] > threshold
or is_segment_anomaly(next_segment)
or window_end_time - segment["end"] < 2.0
)
if silence_before and silence_after:
seek = round(
max(time_offset + 1, segment["start"])
* FRAMES_PER_SECOND
)
if content_duration - segment["end"] < threshold:
seek = content_frames
current_segments[si:] = []
break
hal_last_end = segment["end"]
last_word_end = get_end(current_segments)
if last_word_end is not None:
last_speech_timestamp = last_word_end
if verbose: if verbose:
for segment in current_segments: for segment in current_segments:
@@ -384,10 +517,17 @@ def transcribe(
def cli(): def cli():
from . import available_models from . import available_models
def valid_model_name(name):
if name in available_models() or os.path.exists(name):
return name
raise ValueError(
f"model should be one of {available_models()} or path to a model checkpoint"
)
# fmt: off # fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
@@ -405,6 +545,8 @@ def cli():
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
parser.add_argument("--carry_initial_prompt", type=str2bool, default=False, help="if True, prepend initial_prompt to every internal decode() call. May reduce the effectiveness of condition_on_previous_text")
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
@@ -418,7 +560,10 @@ def cli():
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt") parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line") parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment") parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
# fmt: on # fmt: on
args = parser.parse_args().__dict__ args = parser.parse_args().__dict__
@@ -450,17 +595,28 @@ def cli():
model = load_model(model_name, device=device, download_root=model_dir) model = load_model(model_name, device=device, download_root=model_dir)
writer = get_writer(output_format, output_dir) writer = get_writer(output_format, output_dir)
word_options = ["highlight_words", "max_line_count", "max_line_width"] word_options = [
"highlight_words",
"max_line_count",
"max_line_width",
"max_words_per_line",
]
if not args["word_timestamps"]: if not args["word_timestamps"]:
for option in word_options: for option in word_options:
if args[option]: if args[option]:
parser.error(f"--{option} requires --word_timestamps True") parser.error(f"--{option} requires --word_timestamps True")
if args["max_line_count"] and not args["max_line_width"]: if args["max_line_count"] and not args["max_line_width"]:
warnings.warn("--max_line_count has no effect without --max_line_width") warnings.warn("--max_line_count has no effect without --max_line_width")
if args["max_words_per_line"] and args["max_line_width"]:
warnings.warn("--max_words_per_line has no effect with --max_line_width")
writer_args = {arg: args.pop(arg) for arg in word_options} writer_args = {arg: args.pop(arg) for arg in word_options}
for audio_path in args.pop("audio"): for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args) try:
writer(result, audio_path, writer_args) result = transcribe(model, audio_path, temperature=temperature, **args)
writer(result, audio_path, **writer_args)
except Exception as e:
traceback.print_exc()
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -60,7 +60,7 @@ def median_kernel(filter_width: int):
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821 tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
kernel = triton.JITFunction(kernel.fn) kernel = triton.JITFunction(kernel.fn)
kernel.src = kernel.src.replace( new_kernel = kernel.src.replace(
" LOAD_ALL_ROWS_HERE", " LOAD_ALL_ROWS_HERE",
"\n".join( "\n".join(
[ [
@@ -69,7 +69,8 @@ def median_kernel(filter_width: int):
] ]
), ),
) )
kernel.src = kernel.src.replace(
new_kernel = new_kernel.replace(
" BUBBLESORT_HERE", " BUBBLESORT_HERE",
"\n\n".join( "\n\n".join(
[ [
@@ -90,7 +91,14 @@ def median_kernel(filter_width: int):
] ]
), ),
) )
kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
new_kernel = new_kernel.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
if hasattr(kernel, "_unsafe_update_src") is True:
kernel._unsafe_update_src(new_kernel)
kernel.hash = None
else:
kernel.src = new_kernel
return kernel return kernel

View File

@@ -3,7 +3,7 @@ import os
import re import re
import sys import sys
import zlib import zlib
from typing import Callable, Optional, TextIO from typing import Callable, List, Optional, TextIO
system_encoding = sys.getdefaultencoding() system_encoding = sys.getdefaultencoding()
@@ -68,13 +68,29 @@ def format_timestamp(
) )
def get_start(segments: List[dict]) -> Optional[float]:
return next(
(w["start"] for s in segments for w in s["words"]),
segments[0]["start"] if segments else None,
)
def get_end(segments: List[dict]) -> Optional[float]:
return next(
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
segments[-1]["end"] if segments else None,
)
class ResultWriter: class ResultWriter:
extension: str extension: str
def __init__(self, output_dir: str): def __init__(self, output_dir: str):
self.output_dir = output_dir self.output_dir = output_dir
def __call__(self, result: dict, audio_path: str, options: dict): def __call__(
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
):
audio_basename = os.path.basename(audio_path) audio_basename = os.path.basename(audio_path)
audio_basename = os.path.splitext(audio_basename)[0] audio_basename = os.path.splitext(audio_basename)[0]
output_path = os.path.join( output_path = os.path.join(
@@ -82,16 +98,20 @@ class ResultWriter:
) )
with open(output_path, "w", encoding="utf-8") as f: with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f, options=options) self.write_result(result, file=f, options=options, **kwargs)
def write_result(self, result: dict, file: TextIO, options: dict): def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
raise NotImplementedError raise NotImplementedError
class WriteTXT(ResultWriter): class WriteTXT(ResultWriter):
extension: str = "txt" extension: str = "txt"
def write_result(self, result: dict, file: TextIO, options: dict): def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for segment in result["segments"]: for segment in result["segments"]:
print(segment["text"].strip(), file=file, flush=True) print(segment["text"].strip(), file=file, flush=True)
@@ -100,48 +120,76 @@ class SubtitlesWriter(ResultWriter):
always_include_hours: bool always_include_hours: bool
decimal_marker: str decimal_marker: str
def iterate_result(self, result: dict, options: dict): def iterate_result(
raw_max_line_width: Optional[int] = options["max_line_width"] self,
max_line_count: Optional[int] = options["max_line_count"] result: dict,
highlight_words: bool = options["highlight_words"] options: Optional[dict] = None,
max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width *,
preserve_segments = max_line_count is None or raw_max_line_width is None max_line_width: Optional[int] = None,
max_line_count: Optional[int] = None,
highlight_words: bool = False,
max_words_per_line: Optional[int] = None,
):
options = options or {}
max_line_width = max_line_width or options.get("max_line_width")
max_line_count = max_line_count or options.get("max_line_count")
highlight_words = highlight_words or options.get("highlight_words", False)
max_words_per_line = max_words_per_line or options.get("max_words_per_line")
preserve_segments = max_line_count is None or max_line_width is None
max_line_width = max_line_width or 1000
max_words_per_line = max_words_per_line or 1000
def iterate_subtitles(): def iterate_subtitles():
line_len = 0 line_len = 0
line_count = 1 line_count = 1
# the next subtitle to yield (a list of word timings with whitespace) # the next subtitle to yield (a list of word timings with whitespace)
subtitle: list[dict] = [] subtitle: List[dict] = []
last = result["segments"][0]["words"][0]["start"] last: float = get_start(result["segments"]) or 0.0
for segment in result["segments"]: for segment in result["segments"]:
for i, original_timing in enumerate(segment["words"]): chunk_index = 0
timing = original_timing.copy() words_count = max_words_per_line
long_pause = not preserve_segments and timing["start"] - last > 3.0 while chunk_index < len(segment["words"]):
has_room = line_len + len(timing["word"]) <= max_line_width remaining_words = len(segment["words"]) - chunk_index
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments if max_words_per_line > len(segment["words"]) - chunk_index:
if line_len > 0 and has_room and not long_pause and not seg_break: words_count = remaining_words
# line continuation for i, original_timing in enumerate(
line_len += len(timing["word"]) segment["words"][chunk_index : chunk_index + words_count]
else: ):
# new line timing = original_timing.copy()
timing["word"] = timing["word"].strip() long_pause = (
not preserve_segments and timing["start"] - last > 3.0
)
has_room = line_len + len(timing["word"]) <= max_line_width
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
if ( if (
len(subtitle) > 0 line_len > 0
and max_line_count is not None and has_room
and (long_pause or line_count >= max_line_count) and not long_pause
or seg_break and not seg_break
): ):
# subtitle break # line continuation
yield subtitle line_len += len(timing["word"])
subtitle = [] else:
line_count = 1 # new line
elif line_len > 0: timing["word"] = timing["word"].strip()
# line break if (
line_count += 1 len(subtitle) > 0
timing["word"] = "\n" + timing["word"] and max_line_count is not None
line_len = len(timing["word"].strip()) and (long_pause or line_count >= max_line_count)
subtitle.append(timing) or seg_break
last = timing["start"] ):
# subtitle break
yield subtitle
subtitle = []
line_count = 1
elif line_len > 0:
# line break
line_count += 1
timing["word"] = "\n" + timing["word"]
line_len = len(timing["word"].strip())
subtitle.append(timing)
last = timing["start"]
chunk_index += max_words_per_line
if len(subtitle) > 0: if len(subtitle) > 0:
yield subtitle yield subtitle
@@ -161,9 +209,11 @@ class SubtitlesWriter(ResultWriter):
yield start, end, "".join( yield start, end, "".join(
[ [
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word) (
if j == i re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
else word if j == i
else word
)
for j, word in enumerate(all_words) for j, word in enumerate(all_words)
] ]
) )
@@ -190,9 +240,11 @@ class WriteVTT(SubtitlesWriter):
always_include_hours: bool = False always_include_hours: bool = False
decimal_marker: str = "." decimal_marker: str = "."
def write_result(self, result: dict, file: TextIO, options: dict): def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
print("WEBVTT\n", file=file) print("WEBVTT\n", file=file)
for start, end, text in self.iterate_result(result, options): for start, end, text in self.iterate_result(result, options, **kwargs):
print(f"{start} --> {end}\n{text}\n", file=file, flush=True) print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
@@ -201,9 +253,11 @@ class WriteSRT(SubtitlesWriter):
always_include_hours: bool = True always_include_hours: bool = True
decimal_marker: str = "," decimal_marker: str = ","
def write_result(self, result: dict, file: TextIO, options: dict): def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for i, (start, end, text) in enumerate( for i, (start, end, text) in enumerate(
self.iterate_result(result, options), start=1 self.iterate_result(result, options, **kwargs), start=1
): ):
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
@@ -220,7 +274,9 @@ class WriteTSV(ResultWriter):
extension: str = "tsv" extension: str = "tsv"
def write_result(self, result: dict, file: TextIO, options: dict): def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
print("start", "end", "text", sep="\t", file=file) print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]: for segment in result["segments"]:
print(round(1000 * segment["start"]), file=file, end="\t") print(round(1000 * segment["start"]), file=file, end="\t")
@@ -231,7 +287,9 @@ class WriteTSV(ResultWriter):
class WriteJSON(ResultWriter): class WriteJSON(ResultWriter):
extension: str = "json" extension: str = "json"
def write_result(self, result: dict, file: TextIO, options: dict): def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
json.dump(result, file) json.dump(result, file)
@@ -249,9 +307,11 @@ def get_writer(
if output_format == "all": if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()] all_writers = [writer(output_dir) for writer in writers.values()]
def write_all(result: dict, file: TextIO, options: dict): def write_all(
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for writer in all_writers: for writer in all_writers:
writer(result, file, options) writer(result, file, options, **kwargs)
return write_all return write_all

View File

@@ -1 +1 @@
__version__ = "20230918" __version__ = "20250625"

View File

@@ -1,73 +0,0 @@
import torch
import sys
class TokenBuffer:
    """Mutable buffer of decoded text plus an optional fixed prefix of token ids.

    Wraps a tokenizer so the buffered text can be viewed on demand as token
    ids, as a tensor (optionally replicated per beam), or split into word
    tokens. The tokenizer and device may be bound at construction or supplied
    per call.
    """

    def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=None):
        # Bug fix: prefix_token_ids previously defaulted to a shared mutable
        # list ([]), so all instances created without the argument aliased the
        # same list. Use a None sentinel and create a fresh list per instance.
        self.text = text
        self.prefix_token_ids = [] if prefix_token_ids is None else prefix_token_ids
        self.tokenizer = tokenizer
        self.device = device

    def as_token_ids(self, tokenizer=None):
        """Return prefix token ids followed by the encoded buffered text.

        Raises ValueError when no tokenizer was bound or passed.
        """
        if tokenizer is None:
            tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer is not set.")
        return self.prefix_token_ids + tokenizer.encode(self.text)

    def as_tensor(self, device=None):
        """Return the token ids as a (1, n) long tensor on the given device.

        Raises ValueError when no device was bound or passed.
        """
        if device is None:
            device = self.device
        if device is None:
            raise ValueError("Device is not set.")
        tok_ids = self.as_token_ids()
        return torch.tensor(tok_ids,
                            dtype=torch.long, device=device).unsqueeze(0)

    def as_tensor_beam(self, beam, device=None):
        """Return the token tensor repeated `beam` times along dim 0."""
        t = self.as_tensor(device=device)
        return t.repeat_interleave(beam, dim=0)

    def as_text(self):
        """Return the raw buffered text (prefix token ids are not decoded)."""
        return self.text

    @staticmethod
    def empty(*a, **kw):
        """Alternate constructor: a buffer with the default (empty) text."""
        return TokenBuffer(*a, **kw)

    @staticmethod
    def from_text(text, *a, **kw):
        """Alternate constructor: a buffer initialized with `text`."""
        return TokenBuffer(*a, text=text, **kw)

    def is_empty(self):
        """True when the buffered text is None or the empty string."""
        return self.text is None or self.text == ""

    def trim_words(self, num=1, after=0):
        '''Drop `num` words from the start of the text (past `after` chars).

        num: how many words to trim from the beginning
        after: how many characters to skip (length of the static prompt)

        Returns the number of token ids removed (0 if there were no words).
        '''
        tokenizer = self.tokenizer
        assert tokenizer is not None, "Tokenizer is not set."
        ids = tokenizer.encode(self.text[after:])
        words, wids = self.tokenizer.split_to_word_tokens(ids)
        # Debug output left from development — consider removing.
        print(words, file=sys.stderr)
        print(wids, file=sys.stderr)
        if not words:
            return 0
        self.text = self.text[:after] + "".join(words[num:])
        return sum(len(wi) for wi in wids[:num])

    def append_token_ids(self, token_ids):
        """Decode `token_ids` with the bound tokenizer and append to the text."""
        tokenizer = self.tokenizer
        assert tokenizer is not None, "Tokenizer is not set."
        self.text += self.tokenizer.decode(token_ids)

    def as_split_word_tokens(self):
        """Return (words, word_token_ids) for the buffered text."""
        tokenizer = self.tokenizer
        assert tokenizer is not None, "Tokenizer is not set."
        ids = tokenizer.encode(self.text)
        return tokenizer.split_to_word_tokens(ids)

62
whisperlivekit/warmup.py Normal file
View File

@@ -0,0 +1,62 @@
import logging
logger = logging.getLogger(__name__)
def load_file(warmup_file=None, timeout=5):
    """Load a short warmup audio clip as a 16 kHz mono float array.

    When `warmup_file` is None, the JFK sample from whisper.cpp is downloaded
    once into the system temp directory and reused on later calls. An
    explicitly falsy `warmup_file` (e.g. "") opts out of warmup entirely.

    Returns the decoded audio (numpy array) on success, or False when no
    usable audio is available (download failed, file missing/empty, or
    decoding failed).
    """
    import os
    import tempfile
    import librosa

    if warmup_file is None:
        # Download the JFK sample if not already cached in the temp dir.
        jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
        temp_dir = tempfile.gettempdir()
        warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")
        if not os.path.exists(warmup_file):
            # Previously this both logged and printed the same message;
            # log once via the module logger.
            logger.debug(f"Downloading warmup file from {jfk_url}")
            import time
            import urllib.request
            import urllib.error
            import socket

            # urlretrieve has no timeout argument; bound it via the
            # process-wide default socket timeout, restored afterwards.
            original_timeout = socket.getdefaulttimeout()
            socket.setdefaulttimeout(timeout)
            start_time = time.time()
            try:
                urllib.request.urlretrieve(jfk_url, warmup_file)
                logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
            except (urllib.error.URLError, socket.timeout) as e:
                logger.warning(f"Download failed: {e}. Proceeding without warmup.")
                return False
            finally:
                socket.setdefaulttimeout(original_timeout)
    elif not warmup_file:
        # Caller explicitly disabled warmup.
        return False

    # warmup_file is guaranteed truthy here, so only existence/size need
    # checking (the original re-tested `not warmup_file` redundantly).
    if not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
        logger.warning(f"Warmup file {warmup_file} invalid or missing.")
        return False
    try:
        # Resample to 16 kHz mono — the rate Whisper models expect.
        audio, sr = librosa.load(warmup_file, sr=16000)
    except Exception as e:
        logger.warning(f"Failed to load audio file: {e}")
        return False
    return audio
def warmup_asr(asr, warmup_file=None, timeout=5):
    """
    Warmup the ASR model by transcribing a short audio file.

    Args:
        asr: backend object exposing a ``transcribe(audio)`` method.
        warmup_file: optional path to a local audio file; when None the
            bundled JFK sample is downloaded.
        timeout: download timeout in seconds for the sample file.
    """
    # Bug fix: the caller-supplied arguments were previously ignored —
    # load_file was always invoked with warmup_file=None, timeout=5.
    audio = load_file(warmup_file=warmup_file, timeout=timeout)
    if audio is False:
        # load_file returns False when no usable audio could be obtained;
        # transcribing it would crash the backend.
        logger.warning("Skipping ASR warmup: no audio available.")
        return
    asr.transcribe(audio)
    logger.info("ASR model is warmed up")
def warmup_online(online, warmup_file=None, timeout=5):
    """Warmup an online processor by feeding it a short audio file.

    Args:
        online: processor object exposing a ``warmup(audio)`` method.
        warmup_file: optional path to a local audio file; when None the
            bundled JFK sample is downloaded.
        timeout: download timeout in seconds for the sample file.
    """
    # Bug fix: the caller-supplied arguments were previously ignored —
    # load_file was always invoked with warmup_file=None, timeout=5.
    audio = load_file(warmup_file=warmup_file, timeout=timeout)
    if audio is False:
        # No usable audio — skip instead of passing False downstream.
        logger.warning("Skipping online warmup: no audio available.")
        return
    online.warmup(audio)
    # Bug fix: success was logged at WARNING level with a misleading
    # "ASR is warmed up" message copied from warmup_asr.
    logger.info("Online processor is warmed up")

View File

@@ -4,12 +4,87 @@
<head> <head>
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Audio Transcription</title> <title>WhisperLiveKit</title>
<style> <style>
:root {
--bg: #ffffff;
--text: #111111;
--muted: #666666;
--border: #e5e5e5;
--chip-bg: rgba(0, 0, 0, 0.04);
--chip-text: #000000;
--spinner-border: #8d8d8d5c;
--spinner-top: #b0b0b0;
--silence-bg: #f3f3f3;
--loading-bg: rgba(255, 77, 77, 0.06);
--button-bg: #ffffff;
--button-border: #e9e9e9;
--wave-stroke: #000000;
--label-dia-text: #868686;
--label-trans-text: #111111;
}
@media (prefers-color-scheme: dark) {
:root:not([data-theme="light"]) {
--bg: #0b0b0b;
--text: #e6e6e6;
--muted: #9aa0a6;
--border: #333333;
--chip-bg: rgba(255, 255, 255, 0.08);
--chip-text: #e6e6e6;
--spinner-border: #555555;
--spinner-top: #dddddd;
--silence-bg: #1a1a1a;
--loading-bg: rgba(255, 77, 77, 0.12);
--button-bg: #111111;
--button-border: #333333;
--wave-stroke: #e6e6e6;
--label-dia-text: #b3b3b3;
--label-trans-text: #ffffff;
}
}
:root[data-theme="dark"] {
--bg: #0b0b0b;
--text: #e6e6e6;
--muted: #9aa0a6;
--border: #333333;
--chip-bg: rgba(255, 255, 255, 0.08);
--chip-text: #e6e6e6;
--spinner-border: #555555;
--spinner-top: #dddddd;
--silence-bg: #1a1a1a;
--loading-bg: rgba(255, 77, 77, 0.12);
--button-bg: #111111;
--button-border: #333333;
--wave-stroke: #e6e6e6;
--label-dia-text: #b3b3b3;
--label-trans-text: #ffffff;
}
:root[data-theme="light"] {
--bg: #ffffff;
--text: #111111;
--muted: #666666;
--border: #e5e5e5;
--chip-bg: rgba(0, 0, 0, 0.04);
--chip-text: #000000;
--spinner-border: #8d8d8d5c;
--spinner-top: #b0b0b0;
--silence-bg: #f3f3f3;
--loading-bg: rgba(255, 77, 77, 0.06);
--button-bg: #ffffff;
--button-border: #e9e9e9;
--wave-stroke: #000000;
--label-dia-text: #868686;
--label-trans-text: #111111;
}
body { body {
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji'; font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
margin: 20px; margin: 20px;
text-align: center; text-align: center;
background-color: var(--bg);
color: var(--text);
} }
#recordButton { #recordButton {
@@ -17,10 +92,10 @@
height: 50px; height: 50px;
border: none; border: none;
border-radius: 50%; border-radius: 50%;
background-color: white; background-color: var(--button-bg);
cursor: pointer; cursor: pointer;
transition: all 0.3s ease; transition: all 0.3s ease;
border: 1px solid rgb(233, 233, 233); border: 1px solid var(--button-border);
display: flex; display: flex;
align-items: center; align-items: center;
justify-content: center; justify-content: center;
@@ -94,14 +169,14 @@
.timer { .timer {
font-size: 14px; font-size: 14px;
font-weight: 500; font-weight: 500;
color: #333; color: var(--text);
margin-left: 10px; margin-left: 10px;
} }
#status { #status {
margin-top: 20px; margin-top: 20px;
font-size: 16px; font-size: 16px;
color: #333; color: var(--text);
} }
.settings-container { .settings-container {
@@ -120,12 +195,14 @@
} }
#chunkSelector, #chunkSelector,
#websocketInput { #websocketInput,
#themeSelector {
font-size: 16px; font-size: 16px;
padding: 5px; padding: 5px;
border-radius: 5px; border-radius: 5px;
border: 1px solid #ddd; border: 1px solid var(--border);
background-color: #ffffff; background-color: var(--button-bg);
color: var(--text);
max-height: 30px; max-height: 30px;
} }
@@ -134,7 +211,8 @@
} }
#chunkSelector:focus, #chunkSelector:focus,
#websocketInput:focus { #websocketInput:focus,
#themeSelector:focus {
outline: none; outline: none;
border-color: #007bff; border-color: #007bff;
} }
@@ -156,18 +234,18 @@
} }
#linesTranscript strong { #linesTranscript strong {
color: #333; color: var(--text);
} }
#speaker { #speaker {
border: 1px solid rgb(229, 229, 229); border: 1px solid var(--border);
border-radius: 100px; border-radius: 100px;
padding: 2px 10px; padding: 2px 10px;
font-size: 14px; font-size: 14px;
margin-bottom: 0px; margin-bottom: 0px;
} }
.label_diarization { .label_diarization {
background-color: #ffffff66; background-color: var(--chip-bg);
border-radius: 8px 8px 8px 8px; border-radius: 8px 8px 8px 8px;
padding: 2px 10px; padding: 2px 10px;
margin-left: 10px; margin-left: 10px;
@@ -175,11 +253,11 @@
white-space: nowrap; white-space: nowrap;
font-size: 14px; font-size: 14px;
margin-bottom: 0px; margin-bottom: 0px;
color: rgb(134, 134, 134) color: var(--label-dia-text)
} }
.label_transcription { .label_transcription {
background-color: #ffffff66; background-color: var(--chip-bg);
border-radius: 8px 8px 8px 8px; border-radius: 8px 8px 8px 8px;
padding: 2px 10px; padding: 2px 10px;
display: inline-block; display: inline-block;
@@ -187,11 +265,11 @@
margin-left: 10px; margin-left: 10px;
font-size: 14px; font-size: 14px;
margin-bottom: 0px; margin-bottom: 0px;
color: #000000 color: var(--label-trans-text)
} }
#timeInfo { #timeInfo {
color: #666; color: var(--muted);
margin-left: 10px; margin-left: 10px;
} }
@@ -206,7 +284,7 @@
} }
.buffer_diarization { .buffer_diarization {
color: rgb(134, 134, 134); color: var(--label-dia-text);
margin-left: 4px; margin-left: 4px;
} }
@@ -220,10 +298,10 @@
display: inline-block; display: inline-block;
width: 8px; width: 8px;
height: 8px; height: 8px;
border: 2px solid #8d8d8d5c; border: 2px solid var(--spinner-border);
border-top: 2px solid #6c6c6ce5; border-top: 2px solid var(--spinner-top);
border-radius: 50%; border-radius: 50%;
animation: spin 0.6s linear infinite; animation: spin 0.7s linear infinite;
vertical-align: middle; vertical-align: middle;
margin-bottom: 2px; margin-bottom: 2px;
margin-right: 5px; margin-right: 5px;
@@ -236,16 +314,16 @@
} }
.silence { .silence {
color: #666; color: var(--muted);
background-color: #f3f3f3; background-color: var(--silence-bg);
font-size: 13px; font-size: 13px;
border-radius: 30px; border-radius: 30px;
padding: 2px 10px; padding: 2px 10px;
} }
.loading { .loading {
color: #666; color: var(--muted);
background-color: #ff4d4d0f; background-color: var(--loading-bg);
border-radius: 8px 8px 8px 0px; border-radius: 8px 8px 8px 0px;
padding: 2px 10px; padding: 2px 10px;
font-size: 14px; font-size: 14px;
@@ -284,6 +362,14 @@
<label for="websocketInput">WebSocket URL:</label> <label for="websocketInput">WebSocket URL:</label>
<input id="websocketInput" type="text" /> <input id="websocketInput" type="text" />
</div> </div>
<div>
<label for="themeSelector">Theme:</label>
<select id="themeSelector">
<option value="system" selected>System</option>
<option value="light">Light</option>
<option value="dark">Dark</option>
</select>
</div>
</div> </div>
</div> </div>
@@ -299,6 +385,7 @@
let chunkDuration = 1000; let chunkDuration = 1000;
let websocketUrl = "ws://localhost:8000/asr"; let websocketUrl = "ws://localhost:8000/asr";
let userClosing = false; let userClosing = false;
let wakeLock = null;
let startTime = null; let startTime = null;
let timerInterval = null; let timerInterval = null;
let audioContext = null; let audioContext = null;
@@ -309,6 +396,7 @@
let animationFrame = null; let animationFrame = null;
let waitingForStop = false; let waitingForStop = false;
let lastReceivedData = null; let lastReceivedData = null;
let lastSignature = null;
waveCanvas.width = 60 * (window.devicePixelRatio || 1); waveCanvas.width = 60 * (window.devicePixelRatio || 1);
waveCanvas.height = 30 * (window.devicePixelRatio || 1); waveCanvas.height = 30 * (window.devicePixelRatio || 1);
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1); waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
@@ -319,6 +407,57 @@
const websocketInput = document.getElementById("websocketInput"); const websocketInput = document.getElementById("websocketInput");
const linesTranscriptDiv = document.getElementById("linesTranscript"); const linesTranscriptDiv = document.getElementById("linesTranscript");
const timerElement = document.querySelector(".timer"); const timerElement = document.querySelector(".timer");
const themeSelector = document.getElementById("themeSelector");
function getWaveStroke() {
const styles = getComputedStyle(document.documentElement);
const v = styles.getPropertyValue("--wave-stroke").trim();
return v || "#000";
}
let waveStroke = getWaveStroke();
function updateWaveStroke() {
waveStroke = getWaveStroke();
}
function applyTheme(pref) {
if (pref === "light") {
document.documentElement.setAttribute("data-theme", "light");
} else if (pref === "dark") {
document.documentElement.setAttribute("data-theme", "dark");
} else {
document.documentElement.removeAttribute("data-theme");
}
updateWaveStroke();
}
const savedThemePref = localStorage.getItem("themePreference") || "system";
applyTheme(savedThemePref);
if (themeSelector) {
themeSelector.value = savedThemePref;
themeSelector.addEventListener("change", () => {
const val = themeSelector.value;
localStorage.setItem("themePreference", val);
applyTheme(val);
});
}
const darkMq = window.matchMedia && window.matchMedia("(prefers-color-scheme: dark)");
const handleOsThemeChange = () => {
const pref = localStorage.getItem("themePreference") || "system";
if (pref === "system") updateWaveStroke();
};
if (darkMq && darkMq.addEventListener) {
darkMq.addEventListener("change", handleOsThemeChange);
} else if (darkMq && darkMq.addListener) {
darkMq.addListener(handleOsThemeChange);
}
function fmt1(x) {
const n = Number(x);
return Number.isFinite(n) ? n.toFixed(1) : x;
}
const host = window.location.hostname || "localhost"; const host = window.location.hostname || "localhost";
const port = window.location.port; const port = window.location.port;
@@ -446,10 +585,35 @@
function renderLinesWithBuffer(lines, buffer_diarization, buffer_transcription, remaining_time_diarization, remaining_time_transcription, isFinalizing = false, current_status = "active_transcription") { function renderLinesWithBuffer(lines, buffer_diarization, buffer_transcription, remaining_time_diarization, remaining_time_transcription, isFinalizing = false, current_status = "active_transcription") {
if (current_status === "no_audio_detected") { if (current_status === "no_audio_detected") {
linesTranscriptDiv.innerHTML = "<p style='text-align: center; color: #666; margin-top: 20px;'><em>No audio detected...</em></p>"; linesTranscriptDiv.innerHTML = "<p style='text-align: center; color: var(--muted); margin-top: 20px;'><em>No audio detected...</em></p>";
return; return;
} }
// try to keep stable DOM despite having updates every 0.1s. only update numeric lag values if structure hasn't changed
const showLoading = (!isFinalizing) && (lines || []).some(it => it.speaker == 0);
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
const signature = JSON.stringify({
lines: (lines || []).map(it => ({ speaker: it.speaker, text: it.text, beg: it.beg, end: it.end })),
buffer_transcription: buffer_transcription || "",
buffer_diarization: buffer_diarization || "",
status: current_status,
showLoading,
showTransLag,
showDiaLag,
isFinalizing: !!isFinalizing
});
if (lastSignature === signature) {
const t = document.querySelector(".lag-transcription-value");
if (t) t.textContent = fmt1(remaining_time_transcription);
const d = document.querySelector(".lag-diarization-value");
if (d) d.textContent = fmt1(remaining_time_diarization);
const ld = document.querySelector(".loading-diarization-value");
if (ld) ld.textContent = fmt1(remaining_time_diarization);
return;
}
lastSignature = signature;
const linesHtml = lines.map((item, idx) => { const linesHtml = lines.map((item, idx) => {
let timeInfo = ""; let timeInfo = "";
if (item.beg !== undefined && item.end !== undefined) { if (item.beg !== undefined && item.end !== undefined) {
@@ -460,7 +624,7 @@
if (item.speaker === -2) { if (item.speaker === -2) {
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`; speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
} else if (item.speaker == 0 && !isFinalizing) { } else if (item.speaker == 0 && !isFinalizing) {
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'>${remaining_time_diarization} second(s) of audio are undergoing diarization</span></span>`; speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(remaining_time_diarization)}</span> second(s) of audio are undergoing diarization</span></span>`;
} else if (item.speaker == -1) { } else if (item.speaker == -1) {
speakerLabel = `<span id="speaker">Speaker 1<span id='timeInfo'>${timeInfo}</span></span>`; speakerLabel = `<span id="speaker">Speaker 1<span id='timeInfo'>${timeInfo}</span></span>`;
} else if (item.speaker !== -1 && item.speaker !== 0) { } else if (item.speaker !== -1 && item.speaker !== 0) {
@@ -471,12 +635,12 @@
let currentLineText = item.text || ""; let currentLineText = item.text || "";
if (idx === lines.length - 1) { if (idx === lines.length - 1) {
if (!isFinalizing) { if (!isFinalizing && item.speaker !== -2) {
if (remaining_time_transcription > 0) { if (remaining_time_transcription > 0) {
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'>${remaining_time_transcription}s</span></span>`; speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(remaining_time_transcription)}</span>s</span></span>`;
} }
if (buffer_diarization && remaining_time_diarization > 0) { if (buffer_diarization && remaining_time_diarization > 0) {
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'>${remaining_time_diarization}s</span></span>`; speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(remaining_time_diarization)}</span>s</span></span>`;
} }
} }
@@ -502,6 +666,7 @@
}).join(""); }).join("");
linesTranscriptDiv.innerHTML = linesHtml; linesTranscriptDiv.innerHTML = linesHtml;
window.scrollTo({ top: document.body.scrollHeight, behavior: 'smooth' });
} }
function updateTimer() { function updateTimer() {
@@ -522,7 +687,7 @@
waveCtx.clearRect(0, 0, waveCanvas.width / (window.devicePixelRatio || 1), waveCanvas.height / (window.devicePixelRatio || 1)); waveCtx.clearRect(0, 0, waveCanvas.width / (window.devicePixelRatio || 1), waveCanvas.height / (window.devicePixelRatio || 1));
waveCtx.lineWidth = 1; waveCtx.lineWidth = 1;
waveCtx.strokeStyle = 'rgb(0, 0, 0)'; waveCtx.strokeStyle = waveStroke;
waveCtx.beginPath(); waveCtx.beginPath();
const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength; const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength;
@@ -549,6 +714,16 @@
async function startRecording() { async function startRecording() {
try { try {
// https://developer.mozilla.org/en-US/docs/Web/API/Screen_Wake_Lock_API
// create an async function to request a wake lock
try {
wakeLock = await navigator.wakeLock.request("screen");
} catch (err) {
// The Wake Lock request has failed - usually system related, such as battery.
console.log("Error acquiring wake lock.")
}
const stream = await navigator.mediaDevices.getUserMedia({ audio: true }); const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
audioContext = new (window.AudioContext || window.webkitAudioContext)(); audioContext = new (window.AudioContext || window.webkitAudioContext)();
@@ -578,6 +753,10 @@
} }
async function stopRecording() { async function stopRecording() {
wakeLock.release().then(() => {
wakeLock = null;
});
userClosing = true; userClosing = true;
waitingForStop = true; waitingForStop = true;

View File

@@ -3,43 +3,10 @@ import logging
import io import io
import soundfile as sf import soundfile as sf
import math import math
try:
import torch
except ImportError:
torch = None
from typing import List from typing import List
import numpy as np import numpy as np
from whisperlivekit.timed_objects import ASRToken from whisperlivekit.timed_objects import ASRToken
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS = ImportError(
"""SimulStreaming dependencies are not available.
Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]"
""")
try:
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD
from whisperlivekit.simul_whisper.whisper import tokenizer
SIMULSTREAMING_AVAILABLE = True
except ImportError:
logger.warning("⚠️ SimulStreaming dependencies not available. Attempting to download them.")
try:
from whisperlivekit import download_simulstreaming_backend
download_simulstreaming_backend()
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD
from whisperlivekit.simul_whisper.whisper import tokenizer
SIMULSTREAMING_AVAILABLE = True
logger.info("SimulStreaming dependencies downloaded successfully.")
except Exception as e:
logger.error(f"Failed to download or import SimulStreaming dependencies: {e}")
SIMULSTREAMING_AVAILABLE = False
AlignAttConfig = None
PaddedAlignAttWhisper = None
DEC_PAD = None
tokenizer = None
class ASRBase: class ASRBase:
sep = " " # join transcribe words with this character (" " for whisper_timestamped, sep = " " # join transcribe words with this character (" " for whisper_timestamped,
# "" for faster-whisper because it emits the spaces when needed) # "" for faster-whisper because it emits the spaces when needed)
@@ -320,182 +287,4 @@ class OpenaiApiASR(ASRBase):
self.use_vad_opt = True self.use_vad_opt = True
def set_translate_task(self): def set_translate_task(self):
self.task = "translate" self.task = "translate"
class SimulStreamingASR(ASRBase):
    """SimulStreaming backend with AlignAtt policy.

    Holds the loaded PaddedAlignAttWhisper model; per-connection streaming
    state lives in SimulStreamingOnlineProcessor. Tokens are joined with an
    empty separator because the Whisper tokenizer already emits leading
    spaces inside token text.
    """
    sep = ""

    def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
        """Configure and load the SimulStreaming model.

        Args:
            lan: language code, or "auto" for automatic detection.
            modelsize: Whisper model name, mapped to a local ``.pt`` path.
            cache_dir: kept for signature parity with other backends (unused here).
            model_dir: explicit checkpoint path; takes precedence over modelsize.
            logfile: stream used by the base class for logging output.
            **kwargs: AlignAtt tuning knobs (frame_threshold, beams, prompts, ...).
        """
        if not SIMULSTREAMING_AVAILABLE:
            raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
        # NOTE(review): relative path assumes the process CWD is the repo root;
        # consider resolving the license path from __file__ instead.
        with open("whisperlivekit/simul_whisper/dual_license_simulstreaming.md", "r") as f:
            print("*"*80 + f.read() + "*"*80)
        self.logfile = logfile
        self.transcribe_kargs = {}
        self.original_language = None if lan == "auto" else lan
        # AlignAtt tuning parameters; defaults mirror the upstream SimulStreaming CLI.
        self.model_path = kwargs.get('model_path', './large-v3.pt')
        self.frame_threshold = kwargs.get('frame_threshold', 25)
        self.audio_max_len = kwargs.get('audio_max_len', 30.0)
        self.audio_min_len = kwargs.get('audio_min_len', 0.0)
        self.segment_length = kwargs.get('segment_length', 0.5)
        self.beams = kwargs.get('beams', 1)
        # A single beam decodes greedily unless the caller overrides decoder_type.
        self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
        self.task = kwargs.get('task', 'transcribe')
        self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
        self.never_fire = kwargs.get('never_fire', False)
        self.init_prompt = kwargs.get('init_prompt', None)
        self.static_init_prompt = kwargs.get('static_init_prompt', None)
        self.max_context_tokens = kwargs.get('max_context_tokens', None)
        if model_dir is not None:
            self.model_path = model_dir
        elif modelsize is not None:  # For the moment the .en.pt models do not work!
            model_mapping = {
                'tiny': './tiny.pt',
                'base': './base.pt',
                'small': './small.pt',
                'medium': './medium.pt',
                'medium.en': './medium.en.pt',
                'large-v1': './large-v1.pt',
                'base.en': './base.en.pt',
                'small.en': './small.en.pt',
                'tiny.en': './tiny.en.pt',
                'large-v2': './large-v2.pt',
                'large-v3': './large-v3.pt',
                'large': './large-v3.pt'
            }
            self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
        self.model = self.load_model(modelsize, cache_dir, model_dir)
        # Set up tokenizer for translation if needed
        if self.task == "translate":
            self.set_translate_task()

    def load_model(self, modelsize, cache_dir, model_dir):
        """Build an AlignAttConfig and instantiate PaddedAlignAttWhisper.

        Raises:
            Exception: re-raised after logging when model construction fails.
        """
        try:
            cfg = AlignAttConfig(
                model_path=self.model_path,
                segment_length=self.segment_length,
                frame_threshold=self.frame_threshold,
                language=self.original_language,
                audio_max_len=self.audio_max_len,
                audio_min_len=self.audio_min_len,
                cif_ckpt_path=self.cif_ckpt_path,
                # Fix: honor the configured decoder. This was hard-coded to
                # "beam", silently ignoring the greedy default for beams == 1
                # and any caller-supplied decoder_type.
                decoder_type=self.decoder_type,
                beam_size=self.beams,
                task=self.task,
                never_fire=self.never_fire,
                init_prompt=self.init_prompt,
                max_context_tokens=self.max_context_tokens,
                static_init_prompt=self.static_init_prompt,
            )
            logger.info(f"Loading SimulStreaming model with language: {self.original_language}")
            model = PaddedAlignAttWhisper(cfg)
            return model
        except Exception as e:
            logger.error(f"Failed to load SimulStreaming model: {e}")
            raise

    def transcribe(self, audio, init_prompt=""):
        """Transcribe audio using SimulStreaming.

        Args:
            audio: 1-D float waveform as numpy array or torch tensor.
            init_prompt: optional prompt; falls back to the configured one.

        Returns:
            The raw model result; when it is a token tensor, padding tokens
            (>= DEC_PAD) are stripped before returning.
        """
        try:
            if isinstance(audio, np.ndarray):
                audio_tensor = torch.from_numpy(audio).float()
            else:
                audio_tensor = audio
            prompt = init_prompt if init_prompt else (self.init_prompt or "")
            result = self.model.infer(audio_tensor, init_prompt=prompt)
            if torch.is_tensor(result):
                # Drop decoder padding tokens.
                result = result[result < DEC_PAD]
            logger.debug(f"SimulStreaming transcription result: {result}")
            return result
        except Exception as e:
            logger.error(f"SimulStreaming transcription failed: {e}")
            raise

    def ts_words(self, result) -> List[ASRToken]:
        """Convert a SimulStreaming result to a list of ASRToken.

        Word timings are synthetic (fixed 0.1 s per word) and are expected to
        be rescaled by SimulStreamingOnlineProcessor.
        """
        tokens = []
        try:
            if torch.is_tensor(result):
                text = self.model.tokenizer.decode(result.cpu().numpy())
            else:
                text = str(result)
            if not text or len(text.strip()) == 0:
                return tokens
            # We dont have word-level timestamps here. 1rst approach, should be improved later.
            words = text.strip().split()
            if not words:
                return tokens
            duration_per_word = 0.1  # this will be modified based on actual audio duration
            # with the SimulStreamingOnlineProcessor
            for i, word in enumerate(words):
                start_time = i * duration_per_word
                end_time = (i + 1) * duration_per_word
                token = ASRToken(
                    start=start_time,
                    end=end_time,
                    text=word,
                    probability=1.0  # placeholder; the model exposes no per-word prob here
                )
                tokens.append(token)
        except Exception as e:
            logger.error(f"Error converting SimulStreaming result to tokens: {e}")
        return tokens

    def segments_end_ts(self, result) -> List[float]:
        """Get segment end timestamps (rough estimate, 0.1 s per token)."""
        if torch.is_tensor(result):
            num_tokens = len(result)
            return [num_tokens * 0.1]  # rough estimate
        return [1.0]

    def use_vad(self):
        """Enable VAD - SimulStreaming has different VAD handling."""
        logger.info("VAD requested for SimulStreaming - handled internally by the model")

    def set_translate_task(self):
        """Rebuild the tokenizer with task="translate" for translation output."""
        try:
            self.model.tokenizer = tokenizer.get_tokenizer(
                multilingual=True,
                language=self.model.cfg.language,
                num_languages=self.model.model.num_languages,
                task="translate"
            )
            logger.info("SimulStreaming configured for translation task")
        except Exception as e:
            logger.error(f"Failed to configure SimulStreaming for translation: {e}")
            raise

    def warmup(self, audio, init_prompt=""):
        """Warmup the SimulStreaming model.

        Runs one inference pass on the given audio, then resets the model's
        segment state so warmup audio does not leak into real sessions.
        Failures are logged but never raised.
        """
        try:
            if isinstance(audio, np.ndarray):
                audio = torch.from_numpy(audio).float()
            self.model.insert_audio(audio)
            self.model.infer(True)
            self.model.refresh_segment(complete=True)
            logger.info("SimulStreaming model warmed up successfully")
        except Exception as e:
            logger.warning(f"SimulStreaming warmup failed: {e}")

View File

@@ -6,18 +6,6 @@ from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# simulStreaming imports - we check if the files are here
# On failure we only warn: the classic (non-SimulStreaming) processors in this
# module must keep working without torch or the simul_whisper package.
try:
    import torch
    from whisperlivekit.simul_whisper.config import AlignAttConfig
    SIMULSTREAMING_AVAILABLE = True
except ImportError:
    logger.warning("SimulStreaming dependencies not available for online processor.")
    SIMULSTREAMING_AVAILABLE = False
    OnlineProcessorInterface = None
    # Bound to None so later "torch is None" guards can detect the absence.
    torch = None
class HypothesisBuffer: class HypothesisBuffer:
""" """
Buffer to store and process ASR hypothesis tokens. Buffer to store and process ASR hypothesis tokens.
@@ -528,205 +516,3 @@ class VACOnlineASRProcessor:
""" """
return self.online.concatenate_tokens(self.online.transcript_buffer.buffer) return self.online.concatenate_tokens(self.online.transcript_buffer.buffer)
class SimulStreamingOnlineProcessor:
    """Online processor that delegates the streaming policy to SimulStreaming.

    Audio chunks are accumulated between calls; process_iter() concatenates
    them, feeds them to the backend's PaddedAlignAttWhisper model, and turns
    the committed tokens into timestamped ASRToken objects.
    """
    SAMPLING_RATE = 16000  # Hz; all incoming audio is assumed at this rate

    def __init__(
        self,
        asr,
        tokenize_method: Optional[callable] = None,
        buffer_trimming: Tuple[str, float] = ("segment", 15),
        confidence_validation = False,
        logfile=sys.stderr,
    ):
        """
        Args:
            asr: a SimulStreamingASR instance exposing .model and .sep.
            tokenize_method: sentence tokenizer (kept for interface parity).
            buffer_trimming: (mode, seconds) tuple; currently unused here.
            confidence_validation: kept for interface parity with other processors.
            logfile: destination stream for diagnostics.
        """
        if not SIMULSTREAMING_AVAILABLE:
            raise ImportError("SimulStreaming dependencies are not available.")
        self.asr = asr
        self.tokenize = tokenize_method
        self.logfile = logfile
        self.confidence_validation = confidence_validation
        self.init()
        # buffer does not work yet
        self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming

    def init(self, offset: Optional[float] = None):
        """Initialize or reset the processing state."""
        self.audio_chunks = []  # pending torch tensors, drained by process_iter()
        self.offset = offset if offset is not None else 0.0
        self.is_last = False  # set by finish() to flush the decoder
        self.beg = self.offset
        self.end = self.offset  # absolute end time of audio received so far
        self.cumulative_audio_duration = 0.0
        self.last_audio_stream_end_time = self.offset
        self.committed: List[ASRToken] = []
        self.last_result_tokens: List[ASRToken] = []
        self.buffer_content = ""
        self.processed_audio_duration = 0.0

    def get_audio_buffer_end_time(self) -> float:
        """Returns the absolute end time of the current audio buffer."""
        return self.end

    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
        """Append an audio chunk to be processed by SimulStreaming."""
        if torch is None:
            raise ImportError("PyTorch is required for SimulStreaming but not available")
        # Convert numpy array to torch tensor
        audio_tensor = torch.from_numpy(audio).float()
        self.audio_chunks.append(audio_tensor)
        # Update timing: prefer the caller-provided stream clock when available.
        chunk_duration = len(audio) / self.SAMPLING_RATE
        self.cumulative_audio_duration += chunk_duration
        if audio_stream_end_time is not None:
            self.last_audio_stream_end_time = audio_stream_end_time
            self.end = audio_stream_end_time
        else:
            self.end = self.offset + self.cumulative_audio_duration

    def prompt(self) -> Tuple[str, str]:
        """
        Returns a tuple: (prompt, context).
        SimulStreaming handles prompting internally, so we return empty strings.
        """
        return "", ""

    def get_buffer(self):
        """
        Get the unvalidated buffer content.
        """
        # self.end is always set by init(); hasattr guard kept for safety
        # against subclasses bypassing init().
        buffer_end = self.end if hasattr(self, 'end') else None
        return Transcript(
            start=None,
            end=buffer_end,
            text=self.buffer_content,
            probability=None
        )

    def timestamped_text(self, tokens, generation):
        # From the simulstreaming repo. self.model to self.asr.model
        pr = generation["progress"]
        if "result" not in generation:
            split_words, split_tokens = self.asr.model.tokenizer.split_to_word_tokens(tokens)
        else:
            split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"]
        # One most-attended encoder frame per decoded token.
        frames = [p["most_attended_frames"][0] for p in pr]
        tokens = tokens.copy()
        ret = []
        for sw, st in zip(split_words, split_tokens):
            b = None
            for stt in st:
                t, f = tokens.pop(0), frames.pop(0)
                if t != stt:
                    raise ValueError(f"Token mismatch: {t} != {stt} at frame {f}.")
                if b is None:
                    b = f
                e = f
            # 0.02 s per encoder frame (presumably Whisper's 20 ms hop -- TODO confirm).
            out = (b*0.02, e*0.02, sw)
            ret.append(out)
            logger.debug(f"TS-WORD:\t{' '.join(map(str, out))}")
        return ret

    def process_iter(self) -> Tuple[List[ASRToken], float]:
        """
        Process accumulated audio chunks using SimulStreaming.
        Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
        """
        if not self.audio_chunks:
            return [], self.end
        try:
            # concatenate all audio chunks
            if len(self.audio_chunks) == 1:
                audio = self.audio_chunks[0]
            else:
                audio = torch.cat(self.audio_chunks, dim=0)
            audio_duration = audio.shape[0] / self.SAMPLING_RATE if audio.shape[0] > 0 else 0
            self.processed_audio_duration += audio_duration
            self.audio_chunks = []
            logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s")
            logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s")
            self.asr.model.insert_audio(audio)
            tokens, generation_progress = self.asr.model.infer(is_last=self.is_last)
            ts_words = self.timestamped_text(tokens, generation_progress)
            text = self.asr.model.tokenizer.decode(tokens)
            new_tokens = []
            for ts_word in ts_words:
                start, end, word = ts_word
                token = ASRToken(
                    start=start,
                    end=end,
                    text=word,
                    probability=0.95  # fake prob. Maybe we can extract it from the model?
                )
                new_tokens.append(token)
            self.committed.extend(new_tokens)
            return new_tokens, self.end
        except Exception as e:
            # Swallow per-iteration failures so one bad chunk does not kill the stream.
            logger.error(f"SimulStreaming processing error: {e}")
            logger.error(f"Error details: {type(e).__name__}: {str(e)}")
            return [], self.end

    def finish(self) -> Tuple[List[ASRToken], float]:
        """Flush the decoder: run one last iteration with is_last=True."""
        logger.debug("SimulStreaming finish() called")
        self.is_last = True
        final_tokens, final_time = self.process_iter()
        self.is_last = False
        return final_tokens, final_time

    def concatenate_tokens(
        self,
        tokens: List[ASRToken],
        sep: Optional[str] = None,
        offset: float = 0
    ) -> Transcript:
        """Concatenate tokens into a Transcript object."""
        sep = sep if sep is not None else self.asr.sep
        text = sep.join(token.text for token in tokens)
        # Fix: average only over tokens that carry a probability. The previous
        # version divided that sum by len(tokens), understating the mean
        # whenever some tokens had probability None (or 0).
        probs = [token.probability for token in tokens if token.probability]
        probability = sum(probs) / len(probs) if probs else None
        if tokens:
            start = offset + tokens[0].start
            end = offset + tokens[-1].end
        else:
            start = None
            end = None
        return Transcript(start, end, text, probability=probability)

    def chunk_at(self, time: float):
        """
        useless but kept for compatibility
        """
        logger.debug(f"SimulStreaming chunk_at({time:.2f}) - handled internally")

    def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
        """
        Create simple sentences.
        """
        if not tokens:
            return []
        # NOTE(review): joins with " " although asr.sep is "" -- verify this
        # is the intended display format before changing.
        full_text = " ".join(token.text for token in tokens)
        sentence = Sentence(
            start=tokens[0].start,
            end=tokens[-1].end,
            text=full_text
        )
        return [sentence]

View File

@@ -5,8 +5,7 @@ import librosa
from functools import lru_cache from functools import lru_cache
import time import time
import logging import logging
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR, SimulStreamingASR, SIMULSTREAMING_AVAILABLE, SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor, SimulStreamingOnlineProcessor, SIMULSTREAMING_AVAILABLE as SIMULSTREAMING_ONLINE_AVAILABLE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -68,35 +67,7 @@ def backend_factory(args):
backend = args.backend backend = args.backend
if backend == "openai-api": if backend == "openai-api":
logger.debug("Using OpenAI API.") logger.debug("Using OpenAI API.")
asr = OpenaiApiASR(lan=args.lan) asr = OpenaiApiASR(lan=args.lan)
elif backend == "simulstreaming":
logger.debug("Using SimulStreaming backend.")
if not SIMULSTREAMING_AVAILABLE:
raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
simulstreaming_kwargs = {}
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
'max_context_tokens', 'model_path']:
if hasattr(args, attr):
simulstreaming_kwargs[attr] = getattr(args, attr)
# Add segment_length from min_chunk_size
simulstreaming_kwargs['segment_length'] = getattr(args, 'min_chunk_size', 0.5)
simulstreaming_kwargs['task'] = args.task
size = args.model
t = time.time()
logger.info(f"Loading SimulStreaming {size} model for language {args.lan}...")
asr = SimulStreamingASR(
modelsize=size,
lan=args.lan,
cache_dir=getattr(args, 'model_cache_dir', None),
model_dir=getattr(args, 'model_dir', None),
**simulstreaming_kwargs
)
e = time.time()
logger.info(f"done. It took {round(e-t,2)} seconds.")
else: else:
if backend == "faster-whisper": if backend == "faster-whisper":
asr_cls = FasterWhisperASR asr_cls = FasterWhisperASR
@@ -136,107 +107,4 @@ def backend_factory(args):
tokenizer = create_tokenizer(tgt_language) tokenizer = create_tokenizer(tgt_language)
else: else:
tokenizer = None tokenizer = None
return asr, tokenizer return asr, tokenizer
def online_factory(args, asr, tokenizer, logfile=sys.stderr):
    """Build the online processor that matches the configured backend.

    SimulStreaming gets its dedicated processor; otherwise the choice is
    between the VAC-wrapped and the plain online ASR processor.
    """
    shared_kwargs = dict(
        logfile=logfile,
        buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
        confidence_validation=args.confidence_validation,
    )
    if args.backend == "simulstreaming":
        if not SIMULSTREAMING_ONLINE_AVAILABLE:
            raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
        logger.debug("Creating SimulStreaming online processor")
        return SimulStreamingOnlineProcessor(asr, tokenizer, **shared_kwargs)
    if args.vac:
        return VACOnlineASRProcessor(args.min_chunk_size, asr, tokenizer, **shared_kwargs)
    return OnlineASRProcessor(asr, tokenizer, **shared_kwargs)
def asr_factory(args, logfile=sys.stderr):
    """
    Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
    """
    asr, tokenizer = backend_factory(args)
    return asr, online_factory(args, asr, tokenizer, logfile=logfile)
def warmup_asr(asr, warmup_file=None, timeout=5):
    """
    Warmup the ASR model by transcribing a short audio file.

    Args:
        asr: backend instance; if it exposes a callable ``warmup`` attribute
            (SimulStreaming), that is used, otherwise ``transcribe``.
        warmup_file: path to a wav file. None downloads the JFK sample into
            the temp dir; an empty string skips warmup entirely.
        timeout: per-request download timeout in seconds.

    Returns:
        True when warmup ran successfully, False when skipped or failed.
    """
    import os
    import tempfile
    is_simulstreaming = callable(getattr(asr, 'warmup', None))
    if warmup_file is None:
        # Download JFK sample if not already present
        jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
        warmup_file = os.path.join(tempfile.gettempdir(), "whisper_warmup_jfk.wav")
        if not os.path.exists(warmup_file) and not _download_warmup_file(jfk_url, warmup_file, timeout):
            return False
    elif not warmup_file:
        # Explicit empty path means "skip warmup".
        return False
    if not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
        logger.warning(f"Warmup file {warmup_file} invalid or missing.")
        return False
    backend_name = 'SimulStreaming' if is_simulstreaming else 'Whisper'
    print(f"Warming up {backend_name} with {warmup_file}")
    try:
        import librosa
        audio, sr = librosa.load(warmup_file, sr=16000)
    except Exception as e:
        logger.warning(f"Failed to load audio file: {e}")
        return False
    try:
        if is_simulstreaming:
            asr.warmup(audio)
        else:
            asr.transcribe(audio)
        logger.info(f"{backend_name} is warmed up")
        return True
    except Exception as e:
        logger.warning(f"Warmup failed: {e}")
        return False


def _download_warmup_file(url, dest, timeout):
    """Best-effort download of the warmup sample; returns True on success.

    Uses a per-request timeout instead of mutating the process-global socket
    default (the previous socket.setdefaulttimeout approach silently changed
    the timeout of every other socket created concurrently).
    """
    import shutil
    import socket
    import time
    import urllib.error
    import urllib.request
    logger.debug(f"Downloading warmup file from {url}")
    print(f"Downloading warmup file from {url}")
    start_time = time.time()
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp, open(dest, "wb") as out:
            shutil.copyfileobj(resp, out)
    except (urllib.error.URLError, socket.timeout, OSError) as e:
        logger.warning(f"Download failed: {e}. Proceeding without warmup.")
        return False
    logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
    return True