mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-08 06:44:09 +00:00
Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
349c7dcb9e | ||
|
|
1c42b867cf | ||
|
|
d4771e563e | ||
|
|
b0a5fc0693 | ||
|
|
3b96fb8776 | ||
|
|
7f93c4b978 | ||
|
|
15c3df1cba | ||
|
|
7fb8e66c01 | ||
|
|
728e1f1290 | ||
|
|
87b9ed6ecd | ||
|
|
38b4ebe8ba | ||
|
|
d098af3185 | ||
|
|
4e56130a40 | ||
|
|
2bbdc70187 | ||
|
|
b678a55f63 | ||
|
|
5491964e81 | ||
|
|
b05297a96d | ||
|
|
197293e25e | ||
|
|
ba41c4ab56 | ||
|
|
bda72b8bc0 | ||
|
|
bb6b9f4cb1 | ||
|
|
e40b5a3ea0 | ||
|
|
4cfed6e98e | ||
|
|
687e3dd5e2 | ||
|
|
e4140cd299 | ||
|
|
8e056cbdf2 | ||
|
|
9dcfb38967 | ||
|
|
47b9235d70 | ||
|
|
f3cd53a4db | ||
|
|
dbdb4ea66c | ||
|
|
00424d7ca3 | ||
|
|
4b738d6f63 | ||
|
|
8a5e2adb1e | ||
|
|
f85329e112 | ||
|
|
46efbdf1d9 | ||
|
|
8885ade003 | ||
|
|
2564928d83 | ||
|
|
56114d3071 | ||
|
|
5b9977c9af | ||
|
|
12a544164f | ||
|
|
2ca1156b7e | ||
|
|
3ad3683ca7 | ||
|
|
1599bd87a0 | ||
|
|
90623400a4 | ||
|
|
64e44fb24f | ||
|
|
156b9a133f |
@@ -21,10 +21,12 @@ RUN apt-get update && \
|
|||||||
python3 \
|
python3 \
|
||||||
python3-pip \
|
python3-pip \
|
||||||
ffmpeg \
|
ffmpeg \
|
||||||
git && \
|
git \
|
||||||
|
build-essential \
|
||||||
|
python3-dev && \
|
||||||
rm -rf /var/lib/apt/lists/*
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
|
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
|
||||||
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
|
|||||||
43
README.md
43
README.md
@@ -13,23 +13,16 @@
|
|||||||
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
|
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
This project is based on [WhisperStreaming](https://github.com/ufal/whisper_streaming) and [SimulStreaming](https://github.com/ufal/SimulStreaming), allowing you to transcribe audio directly from your browser. WhisperLiveKit provides a complete backend solution for real-time speech transcription with a functional, simple and customizable frontend. Everything runs locally on your machine ✨
|
WhisperLiveKit brings real-time speech transcription directly to your browser, with a ready-to-use backend+server and a simple frontend. ✨
|
||||||
|
|
||||||
### Architecture
|
Built on [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) and [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) for transcription, plus [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) and [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) for diarization.
|
||||||
|
|
||||||
WhisperLiveKit consists of three main components:
|
|
||||||
|
|
||||||
- **Frontend**: A basic html + JS interface that captures microphone audio and streams it to the backend via WebSockets. You can use and adapt the [provided template](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html).
|
|
||||||
- **Backend (Web Server)**: A FastAPI-based WebSocket server that receives streamed audio data, processes it in real time, and returns transcriptions to the frontend. This is where the WebSocket logic and routing live.
|
|
||||||
- **Core Backend (Library Logic)**: A server-agnostic core that handles audio processing, ASR, and diarization. It exposes reusable components that take in audio bytes and return transcriptions.
|
|
||||||
|
|
||||||
|
|
||||||
### Key Features
|
### Key Features
|
||||||
|
|
||||||
- **Real-time Transcription** - Locally (or on-prem) convert speech to text instantly as you speak
|
- **Real-time Transcription** - Locally (or on-prem) convert speech to text instantly as you speak
|
||||||
- **Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart)
|
- **Speaker Diarization** - Identify different speakers in real-time. (⚠️ backend Streaming Sortformer in developement)
|
||||||
- **Multi-User Support** - Handle multiple users simultaneously with a single backend/server
|
- **Multi-User Support** - Handle multiple users simultaneously with a single backend/server
|
||||||
- **Automatic Silence Chunking** – Automatically chunks when no audio is detected to limit buffer size
|
- **Automatic Silence Chunking** – Automatically chunks when no audio is detected to limit buffer size
|
||||||
- **Confidence Validation** – Immediately validate high-confidence tokens for faster inference (WhisperStreaming only)
|
- **Confidence Validation** – Immediately validate high-confidence tokens for faster inference (WhisperStreaming only)
|
||||||
@@ -37,6 +30,11 @@ WhisperLiveKit consists of three main components:
|
|||||||
- **Punctuation-Based Speaker Splitting [BETA]** - Align speaker changes with natural sentence boundaries for more readable transcripts
|
- **Punctuation-Based Speaker Splitting [BETA]** - Align speaker changes with natural sentence boundaries for more readable transcripts
|
||||||
- **SimulStreaming Backend** - [Dual-licensed](https://github.com/ufal/SimulStreaming#-licence-and-contributions) - Ultra-low latency transcription using SOTA AlignAtt policy.
|
- **SimulStreaming Backend** - [Dual-licensed](https://github.com/ufal/SimulStreaming#-licence-and-contributions) - Ultra-low latency transcription using SOTA AlignAtt policy.
|
||||||
|
|
||||||
|
### Architecture
|
||||||
|
|
||||||
|
<img alt="Architecture" src="architecture.png" />
|
||||||
|
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -247,7 +245,7 @@ To deploy WhisperLiveKit in production:
|
|||||||
- Ensure WebSocket connection points to your server's address
|
- Ensure WebSocket connection points to your server's address
|
||||||
|
|
||||||
3. **Nginx Configuration** (recommended for production):
|
3. **Nginx Configuration** (recommended for production):
|
||||||
```nginx
|
```nginx
|
||||||
server {
|
server {
|
||||||
listen 80;
|
listen 80;
|
||||||
server_name your-domain.com;
|
server_name your-domain.com;
|
||||||
@@ -258,6 +256,7 @@ To deploy WhisperLiveKit in production:
|
|||||||
proxy_set_header Connection "upgrade";
|
proxy_set_header Connection "upgrade";
|
||||||
proxy_set_header Host $host;
|
proxy_set_header Host $host;
|
||||||
}}
|
}}
|
||||||
|
```
|
||||||
|
|
||||||
4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in WebSocket URL
|
4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in WebSocket URL
|
||||||
|
|
||||||
@@ -268,21 +267,21 @@ A basic Dockerfile is provided which allows re-use of Python package installatio
|
|||||||
|
|
||||||
#### All defaults
|
#### All defaults
|
||||||
- Create a reusable image with only the basics and then run as a named container:
|
- Create a reusable image with only the basics and then run as a named container:
|
||||||
```bash
|
```bash
|
||||||
docker build -t whisperlivekit-defaults .
|
docker build -t whisperlivekit-defaults .
|
||||||
docker create --gpus all --name whisperlivekit -p 8000:8000 whisperlivekit-defaults
|
docker create --gpus all --name whisperlivekit -p 8000:8000 whisperlivekit-defaults
|
||||||
docker start -i whisperlivekit
|
docker start -i whisperlivekit
|
||||||
```
|
```
|
||||||
|
|
||||||
> **Note**: If you're running on a system without NVIDIA GPU support (such as Mac with Apple Silicon or any system without CUDA capabilities), you need to **remove the `--gpus all` flag** from the `docker create` command. Without GPU acceleration, transcription will use CPU only, which may be significantly slower. Consider using small models for better performance on CPU-only systems.
|
> **Note**: If you're running on a system without NVIDIA GPU support (such as Mac with Apple Silicon or any system without CUDA capabilities), you need to **remove the `--gpus all` flag** from the `docker create` command. Without GPU acceleration, transcription will use CPU only, which may be significantly slower. Consider using small models for better performance on CPU-only systems.
|
||||||
|
|
||||||
#### Customization
|
#### Customization
|
||||||
- Customize the container options:
|
- Customize the container options:
|
||||||
```bash
|
```bash
|
||||||
docker build -t whisperlivekit-defaults .
|
docker build -t whisperlivekit-defaults .
|
||||||
docker create --gpus all --name whisperlivekit-base -p 8000:8000 whisperlivekit-defaults --model base
|
docker create --gpus all --name whisperlivekit-base -p 8000:8000 whisperlivekit-defaults --model base
|
||||||
docker start -i whisperlivekit-base
|
docker start -i whisperlivekit-base
|
||||||
```
|
```
|
||||||
|
|
||||||
- `--build-arg` Options:
|
- `--build-arg` Options:
|
||||||
- `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
|
- `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
|
||||||
|
|||||||
BIN
architecture.png
Normal file
BIN
architecture.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 382 KiB |
59
pyproject.toml
Normal file
59
pyproject.toml
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=61.0"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "whisperlivekit"
|
||||||
|
version = "0.2.5"
|
||||||
|
description = "Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization"
|
||||||
|
readme = "README.md"
|
||||||
|
authors = [
|
||||||
|
{ name = "Quentin Fuxa" }
|
||||||
|
]
|
||||||
|
license = { file = "LICENSE" }
|
||||||
|
requires-python = ">=3.9"
|
||||||
|
classifiers = [
|
||||||
|
"Development Status :: 4 - Beta",
|
||||||
|
"Intended Audience :: Developers",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
|
"Programming Language :: Python :: 3.10",
|
||||||
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
|
"Topic :: Multimedia :: Sound/Audio :: Speech"
|
||||||
|
]
|
||||||
|
dependencies = [
|
||||||
|
"fastapi",
|
||||||
|
"librosa",
|
||||||
|
"soundfile",
|
||||||
|
"faster-whisper",
|
||||||
|
"uvicorn",
|
||||||
|
"websockets"
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
diarization = ["diart"]
|
||||||
|
vac = ["torch"]
|
||||||
|
sentence = ["mosestokenizer", "wtpsplit"]
|
||||||
|
whisper = ["whisper"]
|
||||||
|
whisper-timestamped = ["whisper-timestamped"]
|
||||||
|
mlx-whisper = ["mlx-whisper"]
|
||||||
|
openai = ["openai"]
|
||||||
|
simulstreaming = [
|
||||||
|
"torch",
|
||||||
|
"tqdm",
|
||||||
|
"tiktoken",
|
||||||
|
'triton>=2.0.0,<3; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.urls]
|
||||||
|
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
whisperlivekit-server = "whisperlivekit.basic_server:main"
|
||||||
|
|
||||||
|
[tool.setuptools]
|
||||||
|
packages = ["whisperlivekit", "whisperlivekit.diarization", "whisperlivekit.simul_whisper", "whisperlivekit.simul_whisper.whisper", "whisperlivekit.simul_whisper.whisper.assets", "whisperlivekit.simul_whisper.whisper.normalizers", "whisperlivekit.web", "whisperlivekit.whisper_streaming_custom"]
|
||||||
|
|
||||||
|
[tool.setuptools.package-data]
|
||||||
|
whisperlivekit = ["web/*.html"]
|
||||||
|
"whisperlivekit.simul_whisper.whisper.assets" = ["*.tiktoken", "*.npz"]
|
||||||
55
setup.py
55
setup.py
@@ -1,55 +0,0 @@
|
|||||||
from setuptools import setup, find_packages
|
|
||||||
setup(
|
|
||||||
name="whisperlivekit",
|
|
||||||
version="0.2.1",
|
|
||||||
description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
|
|
||||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
|
||||||
long_description_content_type="text/markdown",
|
|
||||||
author="Quentin Fuxa",
|
|
||||||
url="https://github.com/QuentinFuxa/WhisperLiveKit",
|
|
||||||
packages=find_packages(),
|
|
||||||
install_requires=[
|
|
||||||
"fastapi",
|
|
||||||
"librosa",
|
|
||||||
"soundfile",
|
|
||||||
"faster-whisper",
|
|
||||||
"uvicorn",
|
|
||||||
"websockets",
|
|
||||||
],
|
|
||||||
extras_require={
|
|
||||||
"diarization": ["diart"],
|
|
||||||
"vac": ["torch"],
|
|
||||||
"sentence": ["mosestokenizer", "wtpsplit"],
|
|
||||||
"whisper": ["whisper"],
|
|
||||||
"whisper-timestamped": ["whisper-timestamped"],
|
|
||||||
"mlx-whisper": ["mlx-whisper"],
|
|
||||||
"openai": ["openai"],
|
|
||||||
"simulstreaming": [
|
|
||||||
"torch",
|
|
||||||
"tqdm",
|
|
||||||
"tiktoken",
|
|
||||||
"numpy<2.0.0",
|
|
||||||
"triton>=2.0.0,<3;platform_machine==\"x86_64\" and sys_platform==\"linux\" or sys_platform==\"linux2\"",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
package_data={
|
|
||||||
'whisperlivekit': ['web/*.html'],
|
|
||||||
'whisperlivekit.simul_whisper': ['dual_license_simulstreaming.md'],
|
|
||||||
'whisperlivekit.simul_whisper.whisper.assets': ['*.tiktoken', '*.npz'],
|
|
||||||
},
|
|
||||||
entry_points={
|
|
||||||
'console_scripts': [
|
|
||||||
'whisperlivekit-server=whisperlivekit.basic_server:main',
|
|
||||||
],
|
|
||||||
},
|
|
||||||
classifiers=[
|
|
||||||
"Development Status :: 4 - Beta",
|
|
||||||
"Intended Audience :: Developers",
|
|
||||||
"License :: OSI Approved :: MIT License",
|
|
||||||
"Programming Language :: Python :: 3.9",
|
|
||||||
"Programming Language :: Python :: 3.10",
|
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
||||||
"Topic :: Multimedia :: Sound/Audio :: Speech",
|
|
||||||
],
|
|
||||||
python_requires=">=3.9",
|
|
||||||
)
|
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
from .download_simulstreaming_backend import download_simulstreaming_backend
|
|
||||||
from .audio_processor import AudioProcessor
|
from .audio_processor import AudioProcessor
|
||||||
from .core import TranscriptionEngine
|
from .core import TranscriptionEngine
|
||||||
from .parse_args import parse_args
|
from .parse_args import parse_args
|
||||||
|
|||||||
@@ -6,10 +6,9 @@ import logging
|
|||||||
import traceback
|
import traceback
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from whisperlivekit.timed_objects import ASRToken
|
from whisperlivekit.timed_objects import ASRToken
|
||||||
from whisperlivekit.whisper_streaming_custom.whisper_online import online_factory
|
from whisperlivekit.core import TranscriptionEngine, online_factory
|
||||||
from whisperlivekit.core import TranscriptionEngine
|
|
||||||
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||||
|
from .remove_silences import handle_silences
|
||||||
# Set up logging once
|
# Set up logging once
|
||||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -52,7 +51,6 @@ class AudioProcessor:
|
|||||||
self.tokens = []
|
self.tokens = []
|
||||||
self.buffer_transcription = ""
|
self.buffer_transcription = ""
|
||||||
self.buffer_diarization = ""
|
self.buffer_diarization = ""
|
||||||
self.full_transcription = ""
|
|
||||||
self.end_buffer = 0
|
self.end_buffer = 0
|
||||||
self.end_attributed_speaker = 0
|
self.end_attributed_speaker = 0
|
||||||
self.lock = asyncio.Lock()
|
self.lock = asyncio.Lock()
|
||||||
@@ -96,13 +94,12 @@ class AudioProcessor:
|
|||||||
"""Convert PCM buffer in s16le format to normalized NumPy array."""
|
"""Convert PCM buffer in s16le format to normalized NumPy array."""
|
||||||
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
|
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
|
||||||
|
|
||||||
async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep):
|
async def update_transcription(self, new_tokens, buffer, end_buffer, sep):
|
||||||
"""Thread-safe update of transcription with new data."""
|
"""Thread-safe update of transcription with new data."""
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
self.tokens.extend(new_tokens)
|
self.tokens.extend(new_tokens)
|
||||||
self.buffer_transcription = buffer
|
self.buffer_transcription = buffer
|
||||||
self.end_buffer = end_buffer
|
self.end_buffer = end_buffer
|
||||||
self.full_transcription = full_transcription
|
|
||||||
self.sep = sep
|
self.sep = sep
|
||||||
|
|
||||||
async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
|
async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
|
||||||
@@ -129,12 +126,12 @@ class AudioProcessor:
|
|||||||
# Calculate remaining times
|
# Calculate remaining times
|
||||||
remaining_transcription = 0
|
remaining_transcription = 0
|
||||||
if self.end_buffer > 0:
|
if self.end_buffer > 0:
|
||||||
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2))
|
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 1))
|
||||||
|
|
||||||
remaining_diarization = 0
|
remaining_diarization = 0
|
||||||
if self.tokens:
|
if self.tokens:
|
||||||
latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
|
latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
|
||||||
remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2))
|
remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 1))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"tokens": self.tokens.copy(),
|
"tokens": self.tokens.copy(),
|
||||||
@@ -153,7 +150,6 @@ class AudioProcessor:
|
|||||||
self.tokens = []
|
self.tokens = []
|
||||||
self.buffer_transcription = self.buffer_diarization = ""
|
self.buffer_transcription = self.buffer_diarization = ""
|
||||||
self.end_buffer = self.end_attributed_speaker = 0
|
self.end_buffer = self.end_attributed_speaker = 0
|
||||||
self.full_transcription = self.last_response_content = ""
|
|
||||||
self.beg_loop = time()
|
self.beg_loop = time()
|
||||||
|
|
||||||
async def ffmpeg_stdout_reader(self):
|
async def ffmpeg_stdout_reader(self):
|
||||||
@@ -192,12 +188,6 @@ class AudioProcessor:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
self.pcm_buffer.extend(chunk)
|
self.pcm_buffer.extend(chunk)
|
||||||
|
|
||||||
# Send to diarization if enabled
|
|
||||||
if self.args.diarization and self.diarization_queue:
|
|
||||||
await self.diarization_queue.put(
|
|
||||||
self.convert_pcm_to_float(self.pcm_buffer).copy()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Process when enough data
|
# Process when enough data
|
||||||
if len(self.pcm_buffer) >= self.bytes_per_sec:
|
if len(self.pcm_buffer) >= self.bytes_per_sec:
|
||||||
@@ -214,7 +204,11 @@ class AudioProcessor:
|
|||||||
# Send to transcription if enabled
|
# Send to transcription if enabled
|
||||||
if self.args.transcription and self.transcription_queue:
|
if self.args.transcription and self.transcription_queue:
|
||||||
await self.transcription_queue.put(pcm_array.copy())
|
await self.transcription_queue.put(pcm_array.copy())
|
||||||
|
|
||||||
|
# Send to diarization if enabled
|
||||||
|
if self.args.diarization and self.diarization_queue:
|
||||||
|
await self.diarization_queue.put(pcm_array.copy())
|
||||||
|
|
||||||
# Sleep if no processing is happening
|
# Sleep if no processing is happening
|
||||||
if not self.args.transcription and not self.args.diarization:
|
if not self.args.transcription and not self.args.diarization:
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
@@ -240,7 +234,6 @@ class AudioProcessor:
|
|||||||
|
|
||||||
async def transcription_processor(self):
|
async def transcription_processor(self):
|
||||||
"""Process audio chunks for transcription."""
|
"""Process audio chunks for transcription."""
|
||||||
self.full_transcription = ""
|
|
||||||
self.sep = self.online.asr.sep
|
self.sep = self.online.asr.sep
|
||||||
cumulative_pcm_duration_stream_time = 0.0
|
cumulative_pcm_duration_stream_time = 0.0
|
||||||
|
|
||||||
@@ -252,7 +245,7 @@ class AudioProcessor:
|
|||||||
self.transcription_queue.task_done()
|
self.transcription_queue.task_done()
|
||||||
break
|
break
|
||||||
|
|
||||||
if not self.online: # Should not happen if queue is used
|
if not self.online:
|
||||||
logger.warning("Transcription processor: self.online not initialized.")
|
logger.warning("Transcription processor: self.online not initialized.")
|
||||||
self.transcription_queue.task_done()
|
self.transcription_queue.task_done()
|
||||||
continue
|
continue
|
||||||
@@ -279,8 +272,6 @@ class AudioProcessor:
|
|||||||
|
|
||||||
if new_tokens:
|
if new_tokens:
|
||||||
validated_text = self.sep.join([t.text for t in new_tokens])
|
validated_text = self.sep.join([t.text for t in new_tokens])
|
||||||
self.full_transcription += validated_text
|
|
||||||
|
|
||||||
if buffer_text.startswith(validated_text):
|
if buffer_text.startswith(validated_text):
|
||||||
buffer_text = buffer_text[len(validated_text):].lstrip()
|
buffer_text = buffer_text[len(validated_text):].lstrip()
|
||||||
|
|
||||||
@@ -297,7 +288,7 @@ class AudioProcessor:
|
|||||||
new_end_buffer = max(candidate_end_times)
|
new_end_buffer = max(candidate_end_times)
|
||||||
|
|
||||||
await self.update_transcription(
|
await self.update_transcription(
|
||||||
new_tokens, buffer_text, new_end_buffer, self.full_transcription, self.sep
|
new_tokens, buffer_text, new_end_buffer, self.sep
|
||||||
)
|
)
|
||||||
self.transcription_queue.task_done()
|
self.transcription_queue.task_done()
|
||||||
|
|
||||||
@@ -325,12 +316,12 @@ class AudioProcessor:
|
|||||||
await diarization_obj.diarize(pcm_array)
|
await diarization_obj.diarize(pcm_array)
|
||||||
|
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
new_end = diarization_obj.assign_speakers_to_tokens(
|
self.tokens = diarization_obj.assign_speakers_to_tokens(
|
||||||
self.end_attributed_speaker,
|
|
||||||
self.tokens,
|
self.tokens,
|
||||||
use_punctuation_split=self.args.punctuation_split
|
use_punctuation_split=self.args.punctuation_split
|
||||||
)
|
)
|
||||||
self.end_attributed_speaker = new_end
|
if len(self.tokens) > 0:
|
||||||
|
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
|
||||||
if buffer_diarization:
|
if buffer_diarization:
|
||||||
self.buffer_diarization = buffer_diarization
|
self.buffer_diarization = buffer_diarization
|
||||||
|
|
||||||
@@ -346,6 +337,8 @@ class AudioProcessor:
|
|||||||
|
|
||||||
async def results_formatter(self):
|
async def results_formatter(self):
|
||||||
"""Format processing results for output."""
|
"""Format processing results for output."""
|
||||||
|
last_sent_trans = None
|
||||||
|
last_sent_diar = None
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
ffmpeg_state = await self.ffmpeg_manager.get_state()
|
ffmpeg_state = await self.ffmpeg_manager.get_state()
|
||||||
@@ -383,8 +376,8 @@ class AudioProcessor:
|
|||||||
lines = []
|
lines = []
|
||||||
last_end_diarized = 0
|
last_end_diarized = 0
|
||||||
undiarized_text = []
|
undiarized_text = []
|
||||||
|
current_time = time() - self.beg_loop
|
||||||
# Process each token
|
tokens = handle_silences(tokens, current_time)
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
speaker = token.speaker
|
speaker = token.speaker
|
||||||
|
|
||||||
@@ -449,10 +442,19 @@ class AudioProcessor:
|
|||||||
' '.join([f"{line['speaker']} {line['text']}" for line in final_lines_for_response]) + \
|
' '.join([f"{line['speaker']} {line['text']}" for line in final_lines_for_response]) + \
|
||||||
f" | {buffer_transcription} | {buffer_diarization}"
|
f" | {buffer_transcription} | {buffer_diarization}"
|
||||||
|
|
||||||
if current_response_signature != self.last_response_content and \
|
trans = state["remaining_time_transcription"]
|
||||||
(final_lines_for_response or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"):
|
diar = state["remaining_time_diarization"]
|
||||||
|
should_push = (
|
||||||
|
current_response_signature != self.last_response_content
|
||||||
|
or last_sent_trans is None
|
||||||
|
or round(trans, 1) != round(last_sent_trans, 1)
|
||||||
|
or round(diar, 1) != round(last_sent_diar, 1)
|
||||||
|
)
|
||||||
|
if should_push and (final_lines_for_response or buffer_transcription or buffer_diarization or response_status == "no_audio_detected" or trans > 0 or diar > 0):
|
||||||
yield response
|
yield response
|
||||||
self.last_response_content = current_response_signature
|
self.last_response_content = current_response_signature
|
||||||
|
last_sent_trans = trans
|
||||||
|
last_sent_diar = diar
|
||||||
|
|
||||||
# Check for termination condition
|
# Check for termination condition
|
||||||
if self.is_stopping:
|
if self.is_stopping:
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
try:
|
try:
|
||||||
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
|
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory
|
||||||
|
from whisperlivekit.whisper_streaming_custom.online_asr import VACOnlineASRProcessor, OnlineASRProcessor
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from .whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
|
from .whisper_streaming_custom.whisper_online import backend_factory
|
||||||
|
from .whisper_streaming_custom.online_asr import VACOnlineASRProcessor, OnlineASRProcessor
|
||||||
|
from whisperlivekit.warmup import warmup_asr, warmup_online
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
import sys
|
||||||
|
|
||||||
class TranscriptionEngine:
|
class TranscriptionEngine:
|
||||||
_instance = None
|
_instance = None
|
||||||
@@ -22,7 +25,6 @@ class TranscriptionEngine:
|
|||||||
"host": "localhost",
|
"host": "localhost",
|
||||||
"port": 8000,
|
"port": 8000,
|
||||||
"warmup_file": None,
|
"warmup_file": None,
|
||||||
"confidence_validation": False,
|
|
||||||
"diarization": False,
|
"diarization": False,
|
||||||
"punctuation_split": False,
|
"punctuation_split": False,
|
||||||
"min_chunk_size": 0.5,
|
"min_chunk_size": 0.5,
|
||||||
@@ -34,15 +36,15 @@ class TranscriptionEngine:
|
|||||||
"backend": "faster-whisper",
|
"backend": "faster-whisper",
|
||||||
"vac": False,
|
"vac": False,
|
||||||
"vac_chunk_size": 0.04,
|
"vac_chunk_size": 0.04,
|
||||||
"buffer_trimming": "segment",
|
|
||||||
"buffer_trimming_sec": 15,
|
|
||||||
"log_level": "DEBUG",
|
"log_level": "DEBUG",
|
||||||
"ssl_certfile": None,
|
"ssl_certfile": None,
|
||||||
"ssl_keyfile": None,
|
"ssl_keyfile": None,
|
||||||
"transcription": True,
|
"transcription": True,
|
||||||
"vad": True,
|
"vad": True,
|
||||||
"segmentation_model": "pyannote/segmentation-3.0",
|
# whisperstreaming params:
|
||||||
"embedding_model": "pyannote/embedding",
|
"buffer_trimming": "segment",
|
||||||
|
"confidence_validation": False,
|
||||||
|
"buffer_trimming_sec": 15,
|
||||||
# simulstreaming params:
|
# simulstreaming params:
|
||||||
"frame_threshold": 25,
|
"frame_threshold": 25,
|
||||||
"beams": 1,
|
"beams": 1,
|
||||||
@@ -55,6 +57,10 @@ class TranscriptionEngine:
|
|||||||
"static_init_prompt": None,
|
"static_init_prompt": None,
|
||||||
"max_context_tokens": None,
|
"max_context_tokens": None,
|
||||||
"model_path": './base.pt',
|
"model_path": './base.pt',
|
||||||
|
# diart params:
|
||||||
|
"segmentation_model": "pyannote/segmentation-3.0",
|
||||||
|
"embedding_model": "pyannote/embedding",
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
config_dict = {**defaults, **kwargs}
|
config_dict = {**defaults, **kwargs}
|
||||||
@@ -78,8 +84,32 @@ class TranscriptionEngine:
|
|||||||
self.diarization = None
|
self.diarization = None
|
||||||
|
|
||||||
if self.args.transcription:
|
if self.args.transcription:
|
||||||
self.asr, self.tokenizer = backend_factory(self.args)
|
if self.args.backend == "simulstreaming":
|
||||||
warmup_asr(self.asr, self.args.warmup_file)
|
from simul_whisper import SimulStreamingASR
|
||||||
|
self.tokenizer = None
|
||||||
|
simulstreaming_kwargs = {}
|
||||||
|
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
|
||||||
|
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
|
||||||
|
'max_context_tokens', 'model_path']:
|
||||||
|
if hasattr(self.args, attr):
|
||||||
|
simulstreaming_kwargs[attr] = getattr(self.args, attr)
|
||||||
|
|
||||||
|
# Add segment_length from min_chunk_size
|
||||||
|
simulstreaming_kwargs['segment_length'] = getattr(self.args, 'min_chunk_size', 0.5)
|
||||||
|
simulstreaming_kwargs['task'] = self.args.task
|
||||||
|
|
||||||
|
size = self.args.model
|
||||||
|
self.asr = SimulStreamingASR(
|
||||||
|
modelsize=size,
|
||||||
|
lan=self.args.lan,
|
||||||
|
cache_dir=getattr(self.args, 'model_cache_dir', None),
|
||||||
|
model_dir=getattr(self.args, 'model_dir', None),
|
||||||
|
**simulstreaming_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.asr, self.tokenizer = backend_factory(self.args)
|
||||||
|
warmup_asr(self.asr, self.args.warmup_file) #for simulstreaming, warmup should be done in the online class not here
|
||||||
|
|
||||||
if self.args.diarization:
|
if self.args.diarization:
|
||||||
from whisperlivekit.diarization.diarization_online import DiartDiarization
|
from whisperlivekit.diarization.diarization_online import DiartDiarization
|
||||||
@@ -90,3 +120,33 @@ class TranscriptionEngine:
|
|||||||
)
|
)
|
||||||
|
|
||||||
TranscriptionEngine._initialized = True
|
TranscriptionEngine._initialized = True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def online_factory(args, asr, tokenizer, logfile=sys.stderr):
|
||||||
|
if args.backend == "simulstreaming":
|
||||||
|
from simul_whisper import SimulStreamingOnlineProcessor
|
||||||
|
online = SimulStreamingOnlineProcessor(
|
||||||
|
asr,
|
||||||
|
logfile=logfile,
|
||||||
|
)
|
||||||
|
# warmup_online(online, args.warmup_file)
|
||||||
|
elif args.vac:
|
||||||
|
online = VACOnlineASRProcessor(
|
||||||
|
args.min_chunk_size,
|
||||||
|
asr,
|
||||||
|
tokenizer,
|
||||||
|
logfile=logfile,
|
||||||
|
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
|
||||||
|
confidence_validation = args.confidence_validation
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
online = OnlineASRProcessor(
|
||||||
|
asr,
|
||||||
|
tokenizer,
|
||||||
|
logfile=logfile,
|
||||||
|
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
|
||||||
|
confidence_validation = args.confidence_validation
|
||||||
|
)
|
||||||
|
return online
|
||||||
|
|
||||||
@@ -165,7 +165,7 @@ class WebSocketAudioSource(AudioSource):
|
|||||||
|
|
||||||
|
|
||||||
class DiartDiarization:
|
class DiartDiarization:
|
||||||
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "speechbrain/spkrec-ecapa-voxceleb"):
|
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 1.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "pyannote/embedding"):
|
||||||
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
|
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
|
||||||
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
|
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
|
||||||
|
|
||||||
@@ -206,15 +206,14 @@ class DiartDiarization:
|
|||||||
"""
|
"""
|
||||||
if self.custom_source:
|
if self.custom_source:
|
||||||
self.custom_source.push_audio(pcm_array)
|
self.custom_source.push_audio(pcm_array)
|
||||||
self.observer.clear_old_segments()
|
# self.observer.clear_old_segments()
|
||||||
return self.observer.get_segments()
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Close the audio source."""
|
"""Close the audio source."""
|
||||||
if self.custom_source:
|
if self.custom_source:
|
||||||
self.custom_source.close()
|
self.custom_source.close()
|
||||||
|
|
||||||
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float:
|
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> float:
|
||||||
"""
|
"""
|
||||||
Assign speakers to tokens based on timing overlap with speaker segments.
|
Assign speakers to tokens based on timing overlap with speaker segments.
|
||||||
Uses the segments collected by the observer.
|
Uses the segments collected by the observer.
|
||||||
@@ -231,85 +230,82 @@ class DiartDiarization:
|
|||||||
|
|
||||||
if not self.lag_diart and segments and tokens:
|
if not self.lag_diart and segments and tokens:
|
||||||
self.lag_diart = segments[0].start - tokens[0].start
|
self.lag_diart = segments[0].start - tokens[0].start
|
||||||
for token in tokens:
|
|
||||||
for segment in segments:
|
|
||||||
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
|
|
||||||
token.speaker = extract_number(segment.speaker) + 1
|
|
||||||
end_attributed_speaker = max(token.end, end_attributed_speaker)
|
|
||||||
|
|
||||||
if use_punctuation_split and len(tokens) > 1:
|
if not use_punctuation_split:
|
||||||
punctuation_marks = {'.', '!', '?'}
|
for token in tokens:
|
||||||
|
for segment in segments:
|
||||||
print("Here are the tokens:",
|
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
|
||||||
[(t.text, t.start, t.end, t.speaker) for t in tokens[:10]])
|
token.speaker = extract_number(segment.speaker) + 1
|
||||||
|
else:
|
||||||
segment_map = []
|
tokens = add_speaker_to_tokens(segments, tokens)
|
||||||
for segment in segments:
|
return tokens
|
||||||
speaker_num = extract_number(segment.speaker) + 1
|
|
||||||
segment_map.append((segment.start, segment.end, speaker_num))
|
|
||||||
segment_map.sort(key=lambda x: x[0])
|
|
||||||
|
|
||||||
i = 0
|
|
||||||
while i < len(tokens):
|
|
||||||
current_token = tokens[i]
|
|
||||||
|
|
||||||
is_sentence_end = False
|
|
||||||
if current_token.text and current_token.text.strip():
|
|
||||||
text = current_token.text.strip()
|
|
||||||
if text[-1] in punctuation_marks:
|
|
||||||
is_sentence_end = True
|
|
||||||
logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
|
|
||||||
|
|
||||||
if is_sentence_end and current_token.speaker != -1:
|
|
||||||
punctuation_time = current_token.end
|
|
||||||
current_speaker = current_token.speaker
|
|
||||||
|
|
||||||
j = i + 1
|
|
||||||
next_sentence_tokens = []
|
|
||||||
while j < len(tokens):
|
|
||||||
next_token = tokens[j]
|
|
||||||
next_sentence_tokens.append(j)
|
|
||||||
|
|
||||||
# Check if this token ends the next sentence
|
|
||||||
if next_token.text and next_token.text.strip():
|
|
||||||
if next_token.text.strip()[-1] in punctuation_marks:
|
|
||||||
break
|
|
||||||
j += 1
|
|
||||||
|
|
||||||
if next_sentence_tokens:
|
|
||||||
speaker_times = {}
|
|
||||||
|
|
||||||
for idx in next_sentence_tokens:
|
|
||||||
token = tokens[idx]
|
|
||||||
# Find which segments overlap with this token
|
|
||||||
for seg_start, seg_end, seg_speaker in segment_map:
|
|
||||||
if not (seg_end <= token.start or seg_start >= token.end):
|
|
||||||
# Calculate overlap duration
|
|
||||||
overlap_start = max(seg_start, token.start)
|
|
||||||
overlap_end = min(seg_end, token.end)
|
|
||||||
overlap_duration = overlap_end - overlap_start
|
|
||||||
|
|
||||||
if seg_speaker not in speaker_times:
|
|
||||||
speaker_times[seg_speaker] = 0
|
|
||||||
speaker_times[seg_speaker] += overlap_duration
|
|
||||||
|
|
||||||
if speaker_times:
|
|
||||||
dominant_speaker = max(speaker_times.items(), key=lambda x: x[1])[0]
|
|
||||||
|
|
||||||
if dominant_speaker != current_speaker:
|
|
||||||
logger.debug(f" Speaker change after punctuation: {current_speaker} → {dominant_speaker}")
|
|
||||||
|
|
||||||
for idx in next_sentence_tokens:
|
|
||||||
if tokens[idx].speaker != dominant_speaker:
|
|
||||||
logger.debug(f" Reassigning token {idx} ('{tokens[idx].text}') to Speaker {dominant_speaker}")
|
|
||||||
tokens[idx].speaker = dominant_speaker
|
|
||||||
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
|
|
||||||
else:
|
|
||||||
for idx in next_sentence_tokens:
|
|
||||||
if tokens[idx].speaker == -1:
|
|
||||||
tokens[idx].speaker = current_speaker
|
|
||||||
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
|
|
||||||
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
return end_attributed_speaker
|
def concatenate_speakers(segments):
|
||||||
|
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
|
||||||
|
for segment in segments:
|
||||||
|
speaker = extract_number(segment.speaker) + 1
|
||||||
|
if segments_concatenated[-1]['speaker'] != speaker:
|
||||||
|
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
|
||||||
|
else:
|
||||||
|
segments_concatenated[-1]['end'] = segment.end
|
||||||
|
# print("Segments concatenated:")
|
||||||
|
# for entry in segments_concatenated:
|
||||||
|
# print(f"Speaker {entry['speaker']}: {entry['begin']:.2f}s - {entry['end']:.2f}s")
|
||||||
|
return segments_concatenated
|
||||||
|
|
||||||
|
|
||||||
|
def add_speaker_to_tokens(segments, tokens):
|
||||||
|
"""
|
||||||
|
Assign speakers to tokens based on diarization segments, with punctuation-aware boundary adjustment.
|
||||||
|
"""
|
||||||
|
punctuation_marks = {'.', '!', '?'}
|
||||||
|
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
|
||||||
|
segments_concatenated = concatenate_speakers(segments)
|
||||||
|
for ind, segment in enumerate(segments_concatenated):
|
||||||
|
for i, punctuation_token in enumerate(punctuation_tokens):
|
||||||
|
if punctuation_token.start > segment['end']:
|
||||||
|
after_length = punctuation_token.start - segment['end']
|
||||||
|
before_length = segment['end'] - punctuation_tokens[i - 1].end
|
||||||
|
if before_length > after_length:
|
||||||
|
segment['end'] = punctuation_token.start
|
||||||
|
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
|
||||||
|
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
|
||||||
|
else:
|
||||||
|
segment['end'] = punctuation_tokens[i - 1].end
|
||||||
|
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
|
||||||
|
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
|
||||||
|
break
|
||||||
|
|
||||||
|
last_end = 0.0
|
||||||
|
for token in tokens:
|
||||||
|
start = max(last_end + 0.01, token.start)
|
||||||
|
token.start = start
|
||||||
|
token.end = max(start, token.end)
|
||||||
|
last_end = token.end
|
||||||
|
|
||||||
|
ind_last_speaker = 0
|
||||||
|
for segment in segments_concatenated:
|
||||||
|
for i, token in enumerate(tokens[ind_last_speaker:]):
|
||||||
|
if token.end <= segment['end']:
|
||||||
|
token.speaker = segment['speaker']
|
||||||
|
ind_last_speaker = i + 1
|
||||||
|
# print(
|
||||||
|
# f"Token '{token.text}' ('begin': {token.start:.2f}, 'end': {token.end:.2f}) "
|
||||||
|
# f"assigned to Speaker {segment['speaker']} ('segment': {segment['begin']:.2f}-{segment['end']:.2f})"
|
||||||
|
# )
|
||||||
|
elif token.start > segment['end']:
|
||||||
|
break
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
def visualize_tokens(tokens):
|
||||||
|
conversation = [{"speaker": -1, "text": ""}]
|
||||||
|
for token in tokens:
|
||||||
|
speaker = conversation[-1]['speaker']
|
||||||
|
if token.speaker != speaker:
|
||||||
|
conversation.append({"speaker": token.speaker, "text": token.text})
|
||||||
|
else:
|
||||||
|
conversation[-1]['text'] += token.text
|
||||||
|
print("Conversation:")
|
||||||
|
for entry in conversation:
|
||||||
|
print(f"Speaker {entry['speaker']}: {entry['text']}")
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
import os
|
|
||||||
import requests
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
def get_module_path():
|
|
||||||
return os.path.dirname(inspect.getfile(inspect.currentframe()))
|
|
||||||
|
|
||||||
GITHUB_API_URL = "https://api.github.com/repos/ufal/SimulStreaming/contents/simul_whisper/whisper"
|
|
||||||
RAW_BASE_URL = "https://raw.githubusercontent.com/ufal/SimulStreaming/main/simul_whisper/whisper"
|
|
||||||
TARGET_DIR = os.path.join(get_module_path(), "simul_whisper", "whisper")
|
|
||||||
|
|
||||||
def download_files_from_github(api_url, local_dir):
|
|
||||||
os.makedirs(local_dir, exist_ok=True)
|
|
||||||
response = requests.get(api_url)
|
|
||||||
response.raise_for_status()
|
|
||||||
items = response.json()
|
|
||||||
for item in items:
|
|
||||||
if item['type'] == 'file':
|
|
||||||
download_url = item['download_url']
|
|
||||||
file_name = item['name']
|
|
||||||
file_response = requests.get(download_url)
|
|
||||||
file_response.raise_for_status()
|
|
||||||
with open(os.path.join(local_dir, file_name), 'wb') as f:
|
|
||||||
f.write(file_response.content)
|
|
||||||
elif item['type'] == 'dir':
|
|
||||||
# Recursive call for subdirectories
|
|
||||||
download_files_from_github(item['url'], os.path.join(local_dir, item['name']))
|
|
||||||
|
|
||||||
def download_simulstreaming_backend():
|
|
||||||
print(f"Downloading files into {TARGET_DIR} ...")
|
|
||||||
download_files_from_github(GITHUB_API_URL, TARGET_DIR)
|
|
||||||
print("✅ Download of SimulStreaming backend files completed successfully.")
|
|
||||||
103
whisperlivekit/remove_silences.py
Normal file
103
whisperlivekit/remove_silences.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
from whisperlivekit.timed_objects import ASRToken
|
||||||
|
import re
|
||||||
|
|
||||||
|
MIN_SILENCE_DURATION = 4 #in seconds
|
||||||
|
END_SILENCE_DURATION = 8 #in seconds. you should keep it important to not have false positive when the model lag is important
|
||||||
|
|
||||||
|
def blank_to_silence(tokens):
|
||||||
|
full_string = ''.join([t.text for t in tokens])
|
||||||
|
patterns = [re.compile(r'(?:\s*\[BLANK_AUDIO\]\s*)+'), re.compile(r'(?:\s*\[typing\]\s*)+')]
|
||||||
|
matches = []
|
||||||
|
for pattern in patterns:
|
||||||
|
for m in pattern.finditer(full_string):
|
||||||
|
matches.append({
|
||||||
|
'start': m.start(),
|
||||||
|
'end': m.end()
|
||||||
|
})
|
||||||
|
if matches:
|
||||||
|
# cleaned = pattern.sub(' ', full_string).strip()
|
||||||
|
# print("Cleaned:", cleaned)
|
||||||
|
cumulated_len = 0
|
||||||
|
silence_token = None
|
||||||
|
cleaned_tokens = []
|
||||||
|
for token in tokens:
|
||||||
|
if matches:
|
||||||
|
start = cumulated_len
|
||||||
|
end = cumulated_len + len(token.text)
|
||||||
|
cumulated_len = end
|
||||||
|
if start >= matches[0]['start'] and end <= matches[0]['end']:
|
||||||
|
if silence_token: #previous token was already silence
|
||||||
|
silence_token.start = min(silence_token.start, token.start)
|
||||||
|
silence_token.end = max(silence_token.end, token.end)
|
||||||
|
else: #new silence
|
||||||
|
silence_token = ASRToken(
|
||||||
|
start=token.start,
|
||||||
|
end=token.end,
|
||||||
|
speaker=-2,
|
||||||
|
probability=0.95
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if silence_token: #there was silence but no more
|
||||||
|
if silence_token.end - silence_token.start >= MIN_SILENCE_DURATION:
|
||||||
|
cleaned_tokens.append(
|
||||||
|
silence_token
|
||||||
|
)
|
||||||
|
silence_token = None
|
||||||
|
matches.pop(0)
|
||||||
|
cleaned_tokens.append(token)
|
||||||
|
# print(cleaned_tokens)
|
||||||
|
return cleaned_tokens
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def no_token_to_silence(tokens):
|
||||||
|
new_tokens = []
|
||||||
|
silence_token = None
|
||||||
|
for token in tokens:
|
||||||
|
if token.speaker == -2:
|
||||||
|
if new_tokens and new_tokens[-1].speaker == -2: #if token is silence and previous one too
|
||||||
|
new_tokens[-1].end = token.end
|
||||||
|
else:
|
||||||
|
new_tokens.append(token)
|
||||||
|
|
||||||
|
last_end = new_tokens[-1].end if new_tokens else 0.0
|
||||||
|
if token.start - last_end >= MIN_SILENCE_DURATION: #if token is not silence but important gap
|
||||||
|
if new_tokens and new_tokens[-1].speaker == -2:
|
||||||
|
new_tokens[-1].end = token.start
|
||||||
|
else:
|
||||||
|
silence_token = ASRToken(
|
||||||
|
start=last_end,
|
||||||
|
end=token.start,
|
||||||
|
speaker=-2,
|
||||||
|
probability=0.95
|
||||||
|
)
|
||||||
|
new_tokens.append(silence_token)
|
||||||
|
|
||||||
|
if token.speaker != -2:
|
||||||
|
new_tokens.append(token)
|
||||||
|
return new_tokens
|
||||||
|
|
||||||
|
def ends_with_silence(tokens, current_time):
|
||||||
|
if not tokens:
|
||||||
|
return []
|
||||||
|
last_token = tokens[-1]
|
||||||
|
if tokens and current_time - last_token.end >= END_SILENCE_DURATION:
|
||||||
|
if last_token.speaker == -2:
|
||||||
|
last_token.end = current_time
|
||||||
|
else:
|
||||||
|
tokens.append(
|
||||||
|
ASRToken(
|
||||||
|
start=tokens[-1].end,
|
||||||
|
end=current_time,
|
||||||
|
speaker=-2,
|
||||||
|
probability=0.95
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
def handle_silences(tokens, current_time):
|
||||||
|
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
|
||||||
|
tokens = no_token_to_silence(tokens)
|
||||||
|
tokens = ends_with_silence(tokens, current_time)
|
||||||
|
return tokens
|
||||||
|
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
from .backend import SimulStreamingASR, SimulStreamingOnlineProcessor
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SimulStreamingASR",
|
||||||
|
"SimulStreamingOnlineProcessor",
|
||||||
|
]
|
||||||
|
|||||||
223
whisperlivekit/simul_whisper/backend.py
Normal file
223
whisperlivekit/simul_whisper/backend.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
import logging
|
||||||
|
from typing import List, Tuple, Optional
|
||||||
|
import logging
|
||||||
|
from whisperlivekit.timed_objects import ASRToken, Transcript
|
||||||
|
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
|
||||||
|
from .whisper import load_model, tokenizer
|
||||||
|
import os
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||||
|
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
|
||||||
|
from whisperlivekit.simul_whisper.whisper import tokenizer
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"""SimulStreaming dependencies are not available.
|
||||||
|
Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]".""")
|
||||||
|
|
||||||
|
class SimulStreamingOnlineProcessor:
|
||||||
|
SAMPLING_RATE = 16000
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
asr,
|
||||||
|
logfile=sys.stderr,
|
||||||
|
warmup_file=None
|
||||||
|
):
|
||||||
|
self.asr = asr
|
||||||
|
self.logfile = logfile
|
||||||
|
self.is_last = False
|
||||||
|
self.beg = 0.0
|
||||||
|
self.end = 0.0
|
||||||
|
self.cumulative_audio_duration = 0.0
|
||||||
|
|
||||||
|
self.committed: List[ASRToken] = []
|
||||||
|
self.last_result_tokens: List[ASRToken] = []
|
||||||
|
self.model = PaddedAlignAttWhisper(
|
||||||
|
cfg=asr.cfg,
|
||||||
|
loaded_model=asr.whisper_model)
|
||||||
|
if asr.tokenizer:
|
||||||
|
self.model.tokenizer = asr.tokenizer
|
||||||
|
|
||||||
|
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
|
||||||
|
"""Append an audio chunk to be processed by SimulStreaming."""
|
||||||
|
|
||||||
|
# Convert numpy array to torch tensor
|
||||||
|
audio_tensor = torch.from_numpy(audio).float()
|
||||||
|
|
||||||
|
# Update timing
|
||||||
|
chunk_duration = len(audio) / self.SAMPLING_RATE
|
||||||
|
self.cumulative_audio_duration += chunk_duration
|
||||||
|
|
||||||
|
if audio_stream_end_time is not None:
|
||||||
|
self.end = audio_stream_end_time
|
||||||
|
else:
|
||||||
|
self.end = self.cumulative_audio_duration
|
||||||
|
self.model.insert_audio(audio_tensor)
|
||||||
|
|
||||||
|
def get_buffer(self):
|
||||||
|
return Transcript(
|
||||||
|
start=None,
|
||||||
|
end=None,
|
||||||
|
text='',
|
||||||
|
probability=None
|
||||||
|
)
|
||||||
|
|
||||||
|
def timestamped_text(self, tokens, generation):
|
||||||
|
# From the simulstreaming repo. self.model to self.asr.model
|
||||||
|
pr = generation["progress"]
|
||||||
|
if "result" not in generation:
|
||||||
|
split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
|
||||||
|
else:
|
||||||
|
split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"]
|
||||||
|
|
||||||
|
frames = [p["most_attended_frames"][0] for p in pr]
|
||||||
|
tokens = tokens.copy()
|
||||||
|
ret = []
|
||||||
|
for sw,st in zip(split_words,split_tokens):
|
||||||
|
b = None
|
||||||
|
for stt in st:
|
||||||
|
t,f = tokens.pop(0), frames.pop(0)
|
||||||
|
if t != stt:
|
||||||
|
raise ValueError(f"Token mismatch: {t} != {stt} at frame {f}.")
|
||||||
|
if b is None:
|
||||||
|
b = f
|
||||||
|
e = f
|
||||||
|
out = (b*0.02, e*0.02, sw)
|
||||||
|
ret.append(out)
|
||||||
|
logger.debug(f"TS-WORD:\t{' '.join(map(str, out))}")
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def process_iter(self) -> Tuple[List[ASRToken], float]:
|
||||||
|
"""
|
||||||
|
Process accumulated audio chunks using SimulStreaming.
|
||||||
|
|
||||||
|
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
tokens, generation_progress = self.model.infer(is_last=self.is_last)
|
||||||
|
ts_words = self.timestamped_text(tokens, generation_progress)
|
||||||
|
|
||||||
|
new_tokens = []
|
||||||
|
for ts_word in ts_words:
|
||||||
|
|
||||||
|
start, end, word = ts_word
|
||||||
|
token = ASRToken(
|
||||||
|
start=start,
|
||||||
|
end=end,
|
||||||
|
text=word,
|
||||||
|
probability=0.95 # fake prob. Maybe we can extract it from the model?
|
||||||
|
)
|
||||||
|
new_tokens.append(token)
|
||||||
|
self.committed.extend(new_tokens)
|
||||||
|
|
||||||
|
return new_tokens, self.end
|
||||||
|
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"SimulStreaming processing error: {e}")
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
def warmup(self, audio, init_prompt=""):
|
||||||
|
"""Warmup the SimulStreaming model."""
|
||||||
|
try:
|
||||||
|
self.model.insert_audio(audio)
|
||||||
|
self.model.infer(True)
|
||||||
|
self.model.refresh_segment(complete=True)
|
||||||
|
logger.info("SimulStreaming model warmed up successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"SimulStreaming warmup failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class SimulStreamingASR():
|
||||||
|
"""SimulStreaming backend with AlignAtt policy."""
|
||||||
|
sep = ""
|
||||||
|
|
||||||
|
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
|
||||||
|
logger.warning(SIMULSTREAMING_LICENSE)
|
||||||
|
self.logfile = logfile
|
||||||
|
self.transcribe_kargs = {}
|
||||||
|
self.original_language = None if lan == "auto" else lan
|
||||||
|
|
||||||
|
self.model_path = kwargs.get('model_path', './large-v3.pt')
|
||||||
|
self.frame_threshold = kwargs.get('frame_threshold', 25)
|
||||||
|
self.audio_max_len = kwargs.get('audio_max_len', 30.0)
|
||||||
|
self.audio_min_len = kwargs.get('audio_min_len', 0.0)
|
||||||
|
self.segment_length = kwargs.get('segment_length', 0.5)
|
||||||
|
self.beams = kwargs.get('beams', 1)
|
||||||
|
self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
|
||||||
|
self.task = kwargs.get('task', 'transcribe')
|
||||||
|
self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
|
||||||
|
self.never_fire = kwargs.get('never_fire', False)
|
||||||
|
self.init_prompt = kwargs.get('init_prompt', None)
|
||||||
|
self.static_init_prompt = kwargs.get('static_init_prompt', None)
|
||||||
|
self.max_context_tokens = kwargs.get('max_context_tokens', None)
|
||||||
|
|
||||||
|
if model_dir is not None:
|
||||||
|
self.model_path = model_dir
|
||||||
|
elif modelsize is not None:
|
||||||
|
model_mapping = {
|
||||||
|
'tiny': './tiny.pt',
|
||||||
|
'base': './base.pt',
|
||||||
|
'small': './small.pt',
|
||||||
|
'medium': './medium.pt',
|
||||||
|
'medium.en': './medium.en.pt',
|
||||||
|
'large-v1': './large-v1.pt',
|
||||||
|
'base.en': './base.en.pt',
|
||||||
|
'small.en': './small.en.pt',
|
||||||
|
'tiny.en': './tiny.en.pt',
|
||||||
|
'large-v2': './large-v2.pt',
|
||||||
|
'large-v3': './large-v3.pt',
|
||||||
|
'large': './large-v3.pt'
|
||||||
|
}
|
||||||
|
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
|
||||||
|
|
||||||
|
self.model = self.load_model(modelsize)
|
||||||
|
|
||||||
|
# Set up tokenizer for translation if needed
|
||||||
|
if self.task == "translate":
|
||||||
|
self.tokenizer = self.set_translate_task()
|
||||||
|
else:
|
||||||
|
self.tokenizer = None
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(self, modelsize):
|
||||||
|
self.cfg = AlignAttConfig(
|
||||||
|
model_path=self.model_path,
|
||||||
|
segment_length=self.segment_length,
|
||||||
|
frame_threshold=self.frame_threshold,
|
||||||
|
language=self.original_language,
|
||||||
|
audio_max_len=self.audio_max_len,
|
||||||
|
audio_min_len=self.audio_min_len,
|
||||||
|
cif_ckpt_path=self.cif_ckpt_path,
|
||||||
|
decoder_type="beam",
|
||||||
|
beam_size=self.beams,
|
||||||
|
task=self.task,
|
||||||
|
never_fire=self.never_fire,
|
||||||
|
init_prompt=self.init_prompt,
|
||||||
|
max_context_tokens=self.max_context_tokens,
|
||||||
|
static_init_prompt=self.static_init_prompt,
|
||||||
|
)
|
||||||
|
model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
|
||||||
|
model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
|
||||||
|
self.whisper_model = load_model(name=model_name, download_root=model_path)
|
||||||
|
|
||||||
|
|
||||||
|
def set_translate_task(self):
|
||||||
|
"""Set up translation task."""
|
||||||
|
return tokenizer.get_tokenizer(
|
||||||
|
multilingual=True,
|
||||||
|
language=self.model.cfg.language,
|
||||||
|
num_languages=self.model.model.num_languages,
|
||||||
|
task="translate"
|
||||||
|
)
|
||||||
|
|
||||||
|
def transcribe(self, audio):
|
||||||
|
"""
|
||||||
|
Only used for warmup. It's a direct whisper call, not a simulstreaming call
|
||||||
|
"""
|
||||||
|
self.whisper_model.transcribe(audio, language=self.original_language)
|
||||||
@@ -8,7 +8,7 @@ class SimulWhisperConfig:
|
|||||||
'''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
|
'''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
|
||||||
model_path: str
|
model_path: str
|
||||||
language: str = field(default="zh")
|
language: str = field(default="zh")
|
||||||
nonspeech_prob: float = 1.0
|
nonspeech_prob: float = 0.5
|
||||||
audio_min_len: float = 1.0
|
audio_min_len: float = 1.0
|
||||||
decoder_type: Literal["greedy","beam"] = "greedy"
|
decoder_type: Literal["greedy","beam"] = "greedy"
|
||||||
beam_size: int = 5
|
beam_size: int = 5
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
📄 SimulStreaming (https://github.com/ufal/SimulStreaming) Licence
|
|
||||||
|
|
||||||
SimulStreaming is dual-licensed:
|
|
||||||
|
|
||||||
🔹 Non-Commercial Use
|
|
||||||
|
|
||||||
You may use SimulStreaming under the **PolyForm Noncommercial License 1.0.0** if you
|
|
||||||
obtain the code through the GitHub repository. This license is **free of charge**
|
|
||||||
and comes with **no obligations** for non-commercial users.
|
|
||||||
|
|
||||||
🔸 Commercial Use
|
|
||||||
|
|
||||||
Understanding who uses SimulStreaming commercially helps us improve and
|
|
||||||
prioritize development. Therefore, we want to **require registration** of those who acquire a commercial licence.
|
|
||||||
|
|
||||||
We plan to make the commercial licenceses **affordable** to SMEs and individuals. We
|
|
||||||
are considering to provide commercial licenses either for free or for symbolic
|
|
||||||
one-time fee, and maybe also provide additional support. You can share your preference via the [questionnaire](https://forms.cloud.microsoft/e/7tCxb4gJfB).
|
|
||||||
|
|
||||||
You can also leave your contact [there](https://forms.cloud.microsoft/e/7tCxb4gJfB) to be notified when the commercial licenses become
|
|
||||||
available.
|
|
||||||
|
|
||||||
✉️ Contact
|
|
||||||
|
|
||||||
[Dominik Macháček](https://ufal.mff.cuni.cz/dominik-machacek/), machacek@ufal.mff.cuni.cz
|
|
||||||
@@ -25,6 +25,9 @@ class BeamTokens(Tokens):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return self.__str__()
|
return self.__str__()
|
||||||
|
|
||||||
|
def as_text(self, tokenizer):
|
||||||
|
return tokenizer.decode(self.tokens)
|
||||||
|
|
||||||
class Logits(Tokens):
|
class Logits(Tokens):
|
||||||
def __init__(self, logits):
|
def __init__(self, logits):
|
||||||
super().__init__(logits)
|
super().__init__(logits)
|
||||||
|
|||||||
5
whisperlivekit/simul_whisper/license_simulstreaming.py
Normal file
5
whisperlivekit/simul_whisper/license_simulstreaming.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
SIMULSTREAMING_LICENSE = f"""
|
||||||
|
SimulStreaming backend is dual-licensed:
|
||||||
|
• Non-Commercial Use: PolyForm Noncommercial License 1.0.0.
|
||||||
|
• Commercial Use: Check SimulStreaming README (github.com/ufal/SimulStreaming) for more details.
|
||||||
|
"""
|
||||||
@@ -10,12 +10,12 @@ from .whisper import load_model, DecodingOptions, tokenizer
|
|||||||
from .config import AlignAttConfig
|
from .config import AlignAttConfig
|
||||||
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
|
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
|
||||||
from .whisper.timing import median_filter
|
from .whisper.timing import median_filter
|
||||||
from .whisper.decoding import SuppressBlank, GreedyDecoder, BeamSearchDecoder, SuppressTokens
|
from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
|
||||||
from .beam import BeamPyTorchInference
|
from .beam import BeamPyTorchInference
|
||||||
from .eow_detection import fire_at_boundary, load_cif
|
from .eow_detection import fire_at_boundary, load_cif
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from whisperlivekit.simul_whisper.token_buffer import TokenBuffer
|
from .token_buffer import TokenBuffer
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .generation_progress import *
|
from .generation_progress import *
|
||||||
@@ -24,6 +24,7 @@ DEC_PAD = 50257
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
import wave
|
||||||
|
|
||||||
# New features added to the original version of Simul-Whisper:
|
# New features added to the original version of Simul-Whisper:
|
||||||
# - large-v3 model support
|
# - large-v3 model support
|
||||||
@@ -32,29 +33,30 @@ import sys
|
|||||||
# - prompt -- static vs. non-static
|
# - prompt -- static vs. non-static
|
||||||
# - context
|
# - context
|
||||||
class PaddedAlignAttWhisper:
|
class PaddedAlignAttWhisper:
|
||||||
def __init__(self, cfg: AlignAttConfig) -> None:
|
def __init__(self, cfg: AlignAttConfig, loaded_model=None) -> None:
|
||||||
|
self.log_segments = 0
|
||||||
model_name = os.path.basename(cfg.model_path).replace(".pt", "")
|
model_name = os.path.basename(cfg.model_path).replace(".pt", "")
|
||||||
model_path = os.path.dirname(os.path.abspath(cfg.model_path))
|
model_path = os.path.dirname(os.path.abspath(cfg.model_path))
|
||||||
self.model = load_model(name=model_name, download_root=model_path)
|
if loaded_model:
|
||||||
|
self.model = loaded_model
|
||||||
|
else:
|
||||||
|
self.model = load_model(name=model_name, download_root=model_path)
|
||||||
|
|
||||||
logger.info(f"Model dimensions: {self.model.dims}")
|
logger.info(f"Model dimensions: {self.model.dims}")
|
||||||
|
|
||||||
decode_options = DecodingOptions(
|
self.decode_options = DecodingOptions(
|
||||||
language = cfg.language,
|
language = cfg.language,
|
||||||
without_timestamps = True,
|
without_timestamps = True,
|
||||||
task=cfg.task
|
task=cfg.task
|
||||||
)
|
)
|
||||||
self.tokenizer = tokenizer.get_tokenizer(
|
self.tokenizer_is_multilingual = not model_name.endswith(".en")
|
||||||
multilingual=not model_name.endswith(".en"),
|
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
||||||
language=cfg.language,
|
self.detected_language = cfg.language if cfg.language != "auto" else None
|
||||||
num_languages=self.model.num_languages,
|
|
||||||
task=decode_options.task
|
|
||||||
)
|
|
||||||
self.max_text_len = self.model.dims.n_text_ctx
|
self.max_text_len = self.model.dims.n_text_ctx
|
||||||
self.num_decoder_layers = len(self.model.decoder.blocks)
|
self.num_decoder_layers = len(self.model.decoder.blocks)
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
|
||||||
|
|
||||||
# model to detect end-of-word boundary at the end of the segment
|
# model to detect end-of-word boundary at the end of the segment
|
||||||
self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
|
self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
|
||||||
n_audio_state=self.model.dims.n_audio_state,
|
n_audio_state=self.model.dims.n_audio_state,
|
||||||
@@ -95,14 +97,6 @@ class PaddedAlignAttWhisper:
|
|||||||
self.num_align_heads += 1
|
self.num_align_heads += 1
|
||||||
|
|
||||||
|
|
||||||
# init tokens (mandatory prompt)
|
|
||||||
self.initial_tokens = torch.tensor(
|
|
||||||
self.tokenizer.sot_sequence_including_notimestamps,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=self.model.device).unsqueeze(0)
|
|
||||||
self.initial_token_length = self.initial_tokens.shape[1]
|
|
||||||
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
|
|
||||||
|
|
||||||
# tokens to be suppressed from decoding, to prevent hallucinations
|
# tokens to be suppressed from decoding, to prevent hallucinations
|
||||||
suppress_tokens = [
|
suppress_tokens = [
|
||||||
self.tokenizer.transcribe,
|
self.tokenizer.transcribe,
|
||||||
@@ -121,6 +115,17 @@ class PaddedAlignAttWhisper:
|
|||||||
self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None)
|
self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None)
|
||||||
# blank tokens are suppresed for new segments near the line 334
|
# blank tokens are suppresed for new segments near the line 334
|
||||||
|
|
||||||
|
# it's going to be regenerated after lang id
|
||||||
|
self.segments = []
|
||||||
|
self.init_tokens()
|
||||||
|
|
||||||
|
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||||
|
|
||||||
|
if self.cfg.max_context_tokens is None:
|
||||||
|
self.max_context_tokens = self.max_text_len
|
||||||
|
else:
|
||||||
|
self.max_context_tokens = self.cfg.max_context_tokens
|
||||||
|
self.init_context()
|
||||||
|
|
||||||
# decoder type: greedy or beam
|
# decoder type: greedy or beam
|
||||||
if cfg.decoder_type == "greedy":
|
if cfg.decoder_type == "greedy":
|
||||||
@@ -135,16 +140,13 @@ class PaddedAlignAttWhisper:
|
|||||||
|
|
||||||
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
|
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
|
||||||
|
|
||||||
# init state
|
def create_tokenizer(self, language=None):
|
||||||
self.segments = []
|
self.tokenizer = tokenizer.get_tokenizer(
|
||||||
self.tokens = [self.initial_tokens]
|
multilingual=self.tokenizer_is_multilingual,
|
||||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
language=language,
|
||||||
|
num_languages=self.model.num_languages,
|
||||||
if self.cfg.max_context_tokens is None:
|
task=self.decode_options.task
|
||||||
self.max_context_tokens = self.max_text_len
|
)
|
||||||
else:
|
|
||||||
self.max_context_tokens = self.cfg.max_context_tokens
|
|
||||||
self.init_context()
|
|
||||||
|
|
||||||
def init_context(self):
|
def init_context(self):
|
||||||
kw = {'tokenizer': self.tokenizer,
|
kw = {'tokenizer': self.tokenizer,
|
||||||
@@ -156,6 +158,19 @@ class PaddedAlignAttWhisper:
|
|||||||
if self.cfg.init_prompt is not None:
|
if self.cfg.init_prompt is not None:
|
||||||
self.context.text += self.cfg.init_prompt
|
self.context.text += self.cfg.init_prompt
|
||||||
|
|
||||||
|
def init_tokens(self):
|
||||||
|
logger.debug(f"init tokens, {len(self.segments)}")
|
||||||
|
# init tokens (mandatory prompt)
|
||||||
|
self.initial_tokens = torch.tensor(
|
||||||
|
self.tokenizer.sot_sequence_including_notimestamps,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.model.device).unsqueeze(0)
|
||||||
|
self.initial_token_length = self.initial_tokens.shape[1]
|
||||||
|
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
|
||||||
|
# self.segments = []
|
||||||
|
logger.debug(f"init tokens after, {len(self.segments)}")
|
||||||
|
self.tokens = [self.initial_tokens]
|
||||||
|
|
||||||
def trim_context(self):
|
def trim_context(self):
|
||||||
logger.info("Trimming context")
|
logger.info("Trimming context")
|
||||||
c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids)
|
c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids)
|
||||||
@@ -191,15 +206,19 @@ class PaddedAlignAttWhisper:
|
|||||||
|
|
||||||
def refresh_segment(self, complete=False):
|
def refresh_segment(self, complete=False):
|
||||||
|
|
||||||
logger.debug("Refreshing segment")
|
logger.debug("Refreshing segment:")
|
||||||
self.tokens = [self.initial_tokens]
|
self.init_tokens()
|
||||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||||
|
self.detected_language = None
|
||||||
self.init_context()
|
self.init_context()
|
||||||
logger.debug(f"Context: {self.context}")
|
logger.debug(f"Context: {self.context}")
|
||||||
if not complete and len(self.segments) > 2:
|
if not complete and len(self.segments) > 2:
|
||||||
|
logger.debug("keeping last two segments because they are and it is not complete.")
|
||||||
self.segments = self.segments[-2:]
|
self.segments = self.segments[-2:]
|
||||||
else:
|
else:
|
||||||
|
logger.debug("removing all segments.")
|
||||||
self.segments = []
|
self.segments = []
|
||||||
|
self.log_segments += 1
|
||||||
|
|
||||||
|
|
||||||
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
|
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
|
||||||
@@ -208,8 +227,6 @@ class PaddedAlignAttWhisper:
|
|||||||
return fire_at_boundary(chunked_encoder_feature, self.CIFLinear)
|
return fire_at_boundary(chunked_encoder_feature, self.CIFLinear)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _current_tokens(self):
|
def _current_tokens(self):
|
||||||
|
|
||||||
toks = self.tokens
|
toks = self.tokens
|
||||||
@@ -256,16 +273,59 @@ class PaddedAlignAttWhisper:
|
|||||||
removed_len = 0
|
removed_len = 0
|
||||||
# len of audio is bigger than buffer_len. Going to remove the first segment
|
# len of audio is bigger than buffer_len. Going to remove the first segment
|
||||||
segments_len = self.segments_len()
|
segments_len = self.segments_len()
|
||||||
while segments_len > self.cfg.audio_max_len:
|
while len(self.segments) > 1 and segments_len > self.cfg.audio_max_len:
|
||||||
removed_len = self.segments[0].shape[0] / 16000
|
removed_len = self.segments[0].shape[0] / 16000
|
||||||
segments_len -= removed_len
|
segments_len -= removed_len
|
||||||
self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
|
self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
|
||||||
self.segments = self.segments[1:]
|
self.segments = self.segments[1:]
|
||||||
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}")
|
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}")
|
||||||
self.context.append_token_ids(self.tokens[1][0,:])
|
if len(self.tokens) > 1:
|
||||||
self.tokens = [self.initial_tokens] + self.tokens[2:]
|
self.context.append_token_ids(self.tokens[1][0,:])
|
||||||
|
self.tokens = [self.initial_tokens] + self.tokens[2:]
|
||||||
return removed_len
|
return removed_len
|
||||||
|
|
||||||
|
def _clean_cache(self):
|
||||||
|
'''clean the cache that stores the attention matrices and kv_cache.
|
||||||
|
It must be called every time after generation with the model.'''
|
||||||
|
# cleaning cache
|
||||||
|
self.dec_attns = []
|
||||||
|
self.kv_cache = {}
|
||||||
|
if self.decoder_type == "beam":
|
||||||
|
self.inference.kv_cache = self.kv_cache
|
||||||
|
self.token_decoder.reset()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def lang_id(self, encoder_features):
|
||||||
|
"""Language detection from encoder features.
|
||||||
|
This code is trimmed and copy-pasted from whisper.decoding.detect_language .
|
||||||
|
"""
|
||||||
|
|
||||||
|
# forward pass using a single token, startoftranscript
|
||||||
|
n_audio = encoder_features.shape[0]
|
||||||
|
x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1]
|
||||||
|
logits = self.model.logits(x, encoder_features)[:, 0]
|
||||||
|
|
||||||
|
# collect detected languages; suppress all non-language tokens
|
||||||
|
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
||||||
|
mask[list(self.tokenizer.all_language_tokens)] = False
|
||||||
|
logits[:, mask] = -np.inf
|
||||||
|
language_tokens = logits.argmax(dim=-1)
|
||||||
|
language_token_probs = logits.softmax(dim=-1).cpu()
|
||||||
|
language_probs = [
|
||||||
|
{
|
||||||
|
c: language_token_probs[i, j].item()
|
||||||
|
for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
|
||||||
|
}
|
||||||
|
for i in range(n_audio)
|
||||||
|
]
|
||||||
|
|
||||||
|
single = encoder_features.ndim == 2
|
||||||
|
if single:
|
||||||
|
language_tokens = language_tokens[0]
|
||||||
|
language_probs = language_probs[0]
|
||||||
|
|
||||||
|
self._clean_cache()
|
||||||
|
return language_tokens, language_probs
|
||||||
|
|
||||||
### transcription / translation
|
### transcription / translation
|
||||||
|
|
||||||
@@ -273,9 +333,12 @@ class PaddedAlignAttWhisper:
|
|||||||
def infer(self, is_last=False):
|
def infer(self, is_last=False):
|
||||||
new_segment = True
|
new_segment = True
|
||||||
if len(self.segments) == 0:
|
if len(self.segments) == 0:
|
||||||
return []
|
logger.debug("No segments, nothing to do")
|
||||||
|
return [], {}
|
||||||
if not self._apply_minseglen():
|
if not self._apply_minseglen():
|
||||||
return []
|
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
||||||
|
input_segments = torch.cat(self.segments, dim=0)
|
||||||
|
return [], {}
|
||||||
|
|
||||||
# input_segments is concatenation of audio, it's one array
|
# input_segments is concatenation of audio, it's one array
|
||||||
if len(self.segments) > 1:
|
if len(self.segments) > 1:
|
||||||
@@ -283,8 +346,7 @@ class PaddedAlignAttWhisper:
|
|||||||
else:
|
else:
|
||||||
input_segments = self.segments[0]
|
input_segments = self.segments[0]
|
||||||
|
|
||||||
self.trim_context()
|
|
||||||
current_tokens = self._current_tokens()
|
|
||||||
|
|
||||||
# mel + padding to 30s
|
# mel + padding to 30s
|
||||||
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
|
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
|
||||||
@@ -295,18 +357,38 @@ class PaddedAlignAttWhisper:
|
|||||||
# the len of actual audio
|
# the len of actual audio
|
||||||
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
|
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
|
||||||
|
|
||||||
|
# encode
|
||||||
encoder_feature = self.model.encoder(mel)
|
encoder_feature = self.model.encoder(mel)
|
||||||
sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
|
|
||||||
completed = False
|
|
||||||
|
|
||||||
|
# logger.debug(f"Encoder feature shape: {encoder_feature.shape}")
|
||||||
|
# if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
|
||||||
|
# logger.debug("mel ")
|
||||||
|
if self.cfg.language == "auto" and self.detected_language is None:
|
||||||
|
language_tokens, language_probs = self.lang_id(encoder_feature)
|
||||||
|
logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
|
||||||
|
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
||||||
|
logger.info(f"Detected language: {top_lan} with p={p:.4f}")
|
||||||
|
#self.tokenizer.language = top_lan
|
||||||
|
#self.tokenizer.__post_init__()
|
||||||
|
self.create_tokenizer(top_lan)
|
||||||
|
self.detected_language = top_lan
|
||||||
|
self.init_tokens()
|
||||||
|
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
|
||||||
|
|
||||||
|
self.trim_context()
|
||||||
|
current_tokens = self._current_tokens()
|
||||||
|
#
|
||||||
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
|
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
|
||||||
|
|
||||||
|
|
||||||
####################### Decoding loop
|
####################### Decoding loop
|
||||||
logger.info("Decoding loop starts\n")
|
logger.info("Decoding loop starts\n")
|
||||||
|
|
||||||
|
sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
|
||||||
|
completed = False
|
||||||
|
|
||||||
attn_of_alignment_heads = None
|
attn_of_alignment_heads = None
|
||||||
miost_attended_frame = None
|
most_attended_frame = None
|
||||||
|
|
||||||
token_len_before_decoding = current_tokens.shape[1]
|
token_len_before_decoding = current_tokens.shape[1]
|
||||||
|
|
||||||
@@ -515,11 +597,6 @@ class PaddedAlignAttWhisper:
|
|||||||
|
|
||||||
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
|
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
|
||||||
|
|
||||||
# cleaning cache
|
self._clean_cache()
|
||||||
self.dec_attns = []
|
|
||||||
self.kv_cache = {}
|
|
||||||
if self.decoder_type == "beam":
|
|
||||||
self.inference.kv_cache = self.kv_cache
|
|
||||||
self.token_decoder.reset()
|
|
||||||
|
|
||||||
return new_hypothesis, generation
|
return new_hypothesis, generation
|
||||||
@@ -32,7 +32,9 @@ def detect_language(
|
|||||||
list of dictionaries containing the probability distribution over all languages.
|
list of dictionaries containing the probability distribution over all languages.
|
||||||
"""
|
"""
|
||||||
if tokenizer is None:
|
if tokenizer is None:
|
||||||
tokenizer = get_tokenizer(model.is_multilingual)
|
tokenizer = get_tokenizer(
|
||||||
|
model.is_multilingual, num_languages=model.num_languages
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
tokenizer.language is None
|
tokenizer.language is None
|
||||||
or tokenizer.language_token not in tokenizer.sot_sequence
|
or tokenizer.language_token not in tokenizer.sot_sequence
|
||||||
@@ -111,9 +113,6 @@ class DecodingOptions:
|
|||||||
# implementation details
|
# implementation details
|
||||||
fp16: bool = True # use fp16 for most of the calculation
|
fp16: bool = True # use fp16 for most of the calculation
|
||||||
|
|
||||||
# streaming
|
|
||||||
add_sot: Optional[bool] = True
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class DecodingResult:
|
class DecodingResult:
|
||||||
@@ -513,19 +512,17 @@ class DecodingTask:
|
|||||||
logit_filters: List[LogitFilter]
|
logit_filters: List[LogitFilter]
|
||||||
|
|
||||||
def __init__(self, model: "Whisper", options: DecodingOptions):
|
def __init__(self, model: "Whisper", options: DecodingOptions):
|
||||||
self.options: DecodingOptions = self._verify_options(options)
|
self.model = model
|
||||||
if self.options.fp16:
|
|
||||||
self.model = model.half()
|
|
||||||
else:
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
language = options.language or "en"
|
language = options.language or "en"
|
||||||
tokenizer = get_tokenizer(
|
tokenizer = get_tokenizer(
|
||||||
model.is_multilingual, language=language, task=options.task
|
model.is_multilingual,
|
||||||
|
num_languages=model.num_languages,
|
||||||
|
language=language,
|
||||||
|
task=options.task,
|
||||||
)
|
)
|
||||||
self.tokenizer: Tokenizer = tokenizer
|
self.tokenizer: Tokenizer = tokenizer
|
||||||
|
self.options: DecodingOptions = self._verify_options(options)
|
||||||
# print(self.options)
|
|
||||||
|
|
||||||
self.n_group: int = options.beam_size or options.best_of or 1
|
self.n_group: int = options.beam_size or options.best_of or 1
|
||||||
self.n_ctx: int = model.dims.n_text_ctx
|
self.n_ctx: int = model.dims.n_text_ctx
|
||||||
@@ -589,7 +586,7 @@ class DecodingTask:
|
|||||||
|
|
||||||
def _get_initial_tokens(self) -> Tuple[int]:
|
def _get_initial_tokens(self) -> Tuple[int]:
|
||||||
tokens = list(self.sot_sequence)
|
tokens = list(self.sot_sequence)
|
||||||
# print("prefix", prefix)
|
|
||||||
if prefix := self.options.prefix:
|
if prefix := self.options.prefix:
|
||||||
prefix_tokens = (
|
prefix_tokens = (
|
||||||
self.tokenizer.encode(" " + prefix.strip())
|
self.tokenizer.encode(" " + prefix.strip())
|
||||||
@@ -607,15 +604,12 @@ class DecodingTask:
|
|||||||
if isinstance(prompt, str)
|
if isinstance(prompt, str)
|
||||||
else prompt
|
else prompt
|
||||||
)
|
)
|
||||||
# if self.options.add_sot:
|
|
||||||
tokens = (
|
tokens = (
|
||||||
[self.tokenizer.sot_prev]
|
[self.tokenizer.sot_prev]
|
||||||
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
|
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
|
||||||
+ tokens
|
+ tokens
|
||||||
)
|
)
|
||||||
#else:
|
|
||||||
# tokens = ([self.tokenizer.sot_prev] + tokens + prompt_tokens[-(self.n_ctx // 2 - 1) :])
|
|
||||||
# print("return", tokens)
|
|
||||||
return tuple(tokens)
|
return tuple(tokens)
|
||||||
|
|
||||||
def _get_suppress_tokens(self) -> Tuple[int]:
|
def _get_suppress_tokens(self) -> Tuple[int]:
|
||||||
@@ -663,7 +657,7 @@ class DecodingTask:
|
|||||||
if audio_features.dtype != (
|
if audio_features.dtype != (
|
||||||
torch.float16 if self.options.fp16 else torch.float32
|
torch.float16 if self.options.fp16 else torch.float32
|
||||||
):
|
):
|
||||||
raise TypeError(
|
return TypeError(
|
||||||
f"audio_features has an incorrect dtype: {audio_features.dtype}"
|
f"audio_features has an incorrect dtype: {audio_features.dtype}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -689,10 +683,9 @@ class DecodingTask:
|
|||||||
no_speech_probs = [np.nan] * n_batch
|
no_speech_probs = [np.nan] * n_batch
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for i in range(self.sample_len): # 最多循环448次
|
for i in range(self.sample_len):
|
||||||
# print("in decode main loop", i , tokens[0].tolist())
|
|
||||||
logits = self.inference.logits(tokens, audio_features)
|
logits = self.inference.logits(tokens, audio_features)
|
||||||
# print(logits)
|
|
||||||
if (
|
if (
|
||||||
i == 0 and self.tokenizer.no_speech is not None
|
i == 0 and self.tokenizer.no_speech is not None
|
||||||
): # save no_speech_probs
|
): # save no_speech_probs
|
||||||
@@ -724,7 +717,7 @@ class DecodingTask:
|
|||||||
|
|
||||||
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
|
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
|
||||||
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
|
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
|
||||||
# print("initial_tokens", self.initial_tokens)
|
|
||||||
# detect language if requested, overwriting the language token
|
# detect language if requested, overwriting the language token
|
||||||
languages, language_probs = self._detect_language(audio_features, tokens)
|
languages, language_probs = self._detect_language(audio_features, tokens)
|
||||||
if self.options.task == "lang_id":
|
if self.options.task == "lang_id":
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from .decoding import decode as decode_function
|
|||||||
from .decoding import detect_language as detect_language_function
|
from .decoding import detect_language as detect_language_function
|
||||||
from .transcribe import transcribe as transcribe_function
|
from .transcribe import transcribe as transcribe_function
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.nn.functional import scaled_dot_product_attention
|
from torch.nn.functional import scaled_dot_product_attention
|
||||||
|
|
||||||
@@ -37,26 +36,27 @@ class ModelDimensions:
|
|||||||
n_text_layer: int
|
n_text_layer: int
|
||||||
|
|
||||||
|
|
||||||
# class LayerNorm(nn.LayerNorm):
|
class LayerNorm(nn.LayerNorm):
|
||||||
# def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
# return super().forward(x.float()).type(x.dtype)
|
return super().forward(x.float()).type(x.dtype)
|
||||||
|
|
||||||
# class Linear(nn.Linear):
|
|
||||||
# def forward(self, x: Tensor) -> Tensor:
|
|
||||||
# return F.linear(
|
|
||||||
# x,
|
|
||||||
# self.weight.to(x.dtype),
|
|
||||||
# None if self.bias is None else self.bias.to(x.dtype),
|
|
||||||
# )
|
|
||||||
|
|
||||||
|
|
||||||
# class Conv1d(nn.Conv1d):
|
class Linear(nn.Linear):
|
||||||
# def _conv_forward(
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
# self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
|
return F.linear(
|
||||||
# ) -> Tensor:
|
x,
|
||||||
# return super()._conv_forward(
|
self.weight.to(x.dtype),
|
||||||
# x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
None if self.bias is None else self.bias.to(x.dtype),
|
||||||
# )
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Conv1d(nn.Conv1d):
|
||||||
|
def _conv_forward(
|
||||||
|
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
|
||||||
|
) -> Tensor:
|
||||||
|
return super()._conv_forward(
|
||||||
|
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def sinusoids(length, channels, max_timescale=10000):
|
def sinusoids(length, channels, max_timescale=10000):
|
||||||
@@ -67,21 +67,30 @@ def sinusoids(length, channels, max_timescale=10000):
|
|||||||
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
||||||
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
||||||
|
|
||||||
import sys ## this is mine, for debugging
|
|
||||||
|
@contextmanager
|
||||||
|
def disable_sdpa():
|
||||||
|
prev_state = MultiHeadAttention.use_sdpa
|
||||||
|
try:
|
||||||
|
MultiHeadAttention.use_sdpa = False
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
MultiHeadAttention.use_sdpa = prev_state
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(nn.Module):
|
class MultiHeadAttention(nn.Module):
|
||||||
|
use_sdpa = False # Disable SDPA to ensure qk is always computed for hooks
|
||||||
|
|
||||||
use_sdpa = False # disabling: https://github.com/linto-ai/whisper-timestamped/issues/212
|
def __init__(self, n_state: int, n_head: int, cache_id: str = ""):
|
||||||
|
|
||||||
def __init__(self, n_state: int, n_head: int, cache_id: str):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.n_head = n_head
|
self.n_head = n_head
|
||||||
self.query = nn.Linear(n_state, n_state)
|
self.query = Linear(n_state, n_state)
|
||||||
self.key = nn.Linear(n_state, n_state, bias=False)
|
self.key = Linear(n_state, n_state, bias=False)
|
||||||
self.key.cache_id = f"{cache_id}_key"
|
self.value = Linear(n_state, n_state)
|
||||||
self.value = nn.Linear(n_state, n_state)
|
self.out = Linear(n_state, n_state)
|
||||||
self.value.cache_id = f"{cache_id}_value"
|
|
||||||
self.out = nn.Linear(n_state, n_state)
|
|
||||||
self.cache_id = cache_id
|
self.cache_id = cache_id
|
||||||
|
self.key.cache_id = f"{cache_id}_key"
|
||||||
|
self.value.cache_id = f"{cache_id}_value"
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -90,45 +99,21 @@ class MultiHeadAttention(nn.Module):
|
|||||||
mask: Optional[Tensor] = None,
|
mask: Optional[Tensor] = None,
|
||||||
kv_cache: Optional[dict] = None,
|
kv_cache: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
#print("MultiHeadAttention forward",file=sys.stderr)
|
|
||||||
q = self.query(x)
|
q = self.query(x)
|
||||||
# print(q.shape, x is None, mask is None, list(kv_cache.keys()) if kv_cache is not None else None, file=sys.stderr)
|
|
||||||
# print(mask, kv_cache, xa, file=sys.stderr)
|
|
||||||
|
|
||||||
if kv_cache is None or xa is None or self.key.cache_id not in kv_cache:
|
if kv_cache is None or xa is None or self.key not in kv_cache:
|
||||||
|
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
||||||
|
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
||||||
k = self.key(x if xa is None else xa)
|
k = self.key(x if xa is None else xa)
|
||||||
v = self.value(x if xa is None else xa)
|
v = self.value(x if xa is None else xa)
|
||||||
# print(self.key.cache_id, "cache miss") # , kv_cache is None, xa is None, self.key.cache_id not in kv_cache if kv_cache is not None else None, k.shape, x.shape)
|
|
||||||
# if kv_cache is not None:
|
|
||||||
# print(kv_cache.keys())
|
|
||||||
else:
|
else:
|
||||||
# print(self.key.cache_id, "cache hit") #, kv_cache is None, xa is None, self.key.cache_id not in kv_cache)
|
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
||||||
# if kv_cache is not None:
|
k = kv_cache[self.key]
|
||||||
# print(kv_cache.keys())
|
v = kv_cache[self.value]
|
||||||
k = kv_cache[self.key.cache_id]
|
|
||||||
v = kv_cache[self.value.cache_id]
|
|
||||||
# print(self.key.cache_id, "qkv attention", q.shape, k.shape, v.shape)
|
|
||||||
wv, qk = self.qkv_attention(q, k, v, mask)
|
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||||
return self.out(wv), qk
|
return self.out(wv), qk
|
||||||
|
|
||||||
# def qkv_attention(
|
|
||||||
# self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
|
||||||
# ):
|
|
||||||
# n_batch, n_ctx, n_state = q.shape
|
|
||||||
# scale = (n_state // self.n_head) ** -0.25
|
|
||||||
# q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
|
||||||
# k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
|
|
||||||
# v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
|
||||||
|
|
||||||
# qk = q @ k
|
|
||||||
# if mask is not None:
|
|
||||||
# qk = qk + mask[:n_ctx, :n_ctx]
|
|
||||||
# # qk = qk.float()
|
|
||||||
|
|
||||||
# w = F.softmax(qk, dim=-1) # .to(q.dtype)
|
|
||||||
# return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
|
|
||||||
|
|
||||||
|
|
||||||
def qkv_attention(
|
def qkv_attention(
|
||||||
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
@@ -158,21 +143,22 @@ class MultiHeadAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class ResidualAttentionBlock(nn.Module):
|
class ResidualAttentionBlock(nn.Module):
|
||||||
def __init__(self, n_state: int, n_head: int, cache_id: str="", cross_attention: bool = False):
|
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, cache_id: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn")
|
self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn")
|
||||||
self.attn_ln = nn.LayerNorm(n_state)
|
self.attn_ln = LayerNorm(n_state)
|
||||||
|
|
||||||
self.cross_attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
|
self.cross_attn = (
|
||||||
|
MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
|
||||||
self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
|
)
|
||||||
|
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
||||||
|
|
||||||
n_mlp = n_state * 4
|
n_mlp = n_state * 4
|
||||||
self.mlp = nn.Sequential(
|
self.mlp = nn.Sequential(
|
||||||
nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
|
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
|
||||||
)
|
)
|
||||||
self.mlp_ln = nn.LayerNorm(n_state)
|
self.mlp_ln = LayerNorm(n_state)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -181,8 +167,6 @@ class ResidualAttentionBlock(nn.Module):
|
|||||||
mask: Optional[Tensor] = None,
|
mask: Optional[Tensor] = None,
|
||||||
kv_cache: Optional[dict] = None,
|
kv_cache: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
# print("ResidualAttentionBlock forward",file=sys.stderr)
|
|
||||||
# print(x.shape, file=sys.stderr)
|
|
||||||
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
||||||
if self.cross_attn:
|
if self.cross_attn:
|
||||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
||||||
@@ -195,44 +179,32 @@ class AudioEncoder(nn.Module):
|
|||||||
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||||
self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
||||||
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
||||||
|
|
||||||
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||||
[ResidualAttentionBlock(n_state, n_head, cache_id=f"enc_layer{i}") for i in range(n_layer)]
|
[ResidualAttentionBlock(n_state, n_head, cache_id=f"enc_layer{i}") for i in range(n_layer)]
|
||||||
)
|
)
|
||||||
self.ln_post = nn.LayerNorm(n_state)
|
self.ln_post = LayerNorm(n_state)
|
||||||
|
|
||||||
def forward(self, x: Tensor, return_layer_results: bool=False):
|
def forward(self, x: Tensor):
|
||||||
"""
|
"""
|
||||||
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
||||||
the mel spectrogram of the audio
|
the mel spectrogram of the audio
|
||||||
"""
|
"""
|
||||||
|
|
||||||
x = F.gelu(self.conv1(x))
|
x = F.gelu(self.conv1(x))
|
||||||
x = F.gelu(self.conv2(x))
|
x = F.gelu(self.conv2(x))
|
||||||
x = x.permute(0, 2, 1) # BDT -> BTD
|
x = x.permute(0, 2, 1)
|
||||||
|
|
||||||
# 两层卷积,2倍降采样
|
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
||||||
# 最终剩下1500帧
|
x = (x + self.positional_embedding).to(x.dtype)
|
||||||
|
|
||||||
x = (x + self.positional_embedding[:x.shape[1], :]) #.to(x.dtype)
|
|
||||||
|
|
||||||
layer_results = []
|
|
||||||
i = 0
|
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
# print(f"encoder layer {i}")
|
|
||||||
x = block(x)
|
x = block(x)
|
||||||
layer_results.append(x)
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
x = self.ln_post(x)
|
x = self.ln_post(x)
|
||||||
|
return x
|
||||||
if return_layer_results:
|
|
||||||
return x, layer_results
|
|
||||||
else:
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class TextDecoder(nn.Module):
|
class TextDecoder(nn.Module):
|
||||||
@@ -250,7 +222,7 @@ class TextDecoder(nn.Module):
|
|||||||
for i in range(n_layer)
|
for i in range(n_layer)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.ln = nn.LayerNorm(n_state)
|
self.ln = LayerNorm(n_state)
|
||||||
|
|
||||||
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
|
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
|
||||||
self.register_buffer("mask", mask, persistent=False)
|
self.register_buffer("mask", mask, persistent=False)
|
||||||
@@ -262,22 +234,20 @@ class TextDecoder(nn.Module):
|
|||||||
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
|
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
|
||||||
the encoded audio features to be attended on
|
the encoded audio features to be attended on
|
||||||
"""
|
"""
|
||||||
|
|
||||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||||
x = (
|
x = (
|
||||||
self.token_embedding(x)
|
self.token_embedding(x)
|
||||||
+ self.positional_embedding[offset : offset + x.shape[-1]]
|
+ self.positional_embedding[offset : offset + x.shape[-1]]
|
||||||
)
|
)
|
||||||
# x = x.to(xa.dtype)
|
x = x.to(xa.dtype)
|
||||||
|
|
||||||
i = 0
|
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
# print(f"decoder layer {i}")
|
|
||||||
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||||
i += 1
|
|
||||||
|
|
||||||
x = self.ln(x)
|
x = self.ln(x)
|
||||||
logits = x @ torch.transpose(self.token_embedding.weight, 0, 1)
|
logits = (
|
||||||
|
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
|
||||||
|
).float()
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
@@ -300,7 +270,8 @@ class Whisper(nn.Module):
|
|||||||
self.dims.n_text_head,
|
self.dims.n_text_head,
|
||||||
self.dims.n_text_layer,
|
self.dims.n_text_layer,
|
||||||
)
|
)
|
||||||
# use the last half layers for alignment by default; see `set_alignment_heads()` below
|
# use the last half among the decoder layers for time alignment by default;
|
||||||
|
# to use a specific set of heads, see `set_alignment_heads()` below.
|
||||||
all_heads = torch.zeros(
|
all_heads = torch.zeros(
|
||||||
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
|
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
|
||||||
)
|
)
|
||||||
@@ -320,15 +291,11 @@ class Whisper(nn.Module):
|
|||||||
return self.encoder(mel)
|
return self.encoder(mel)
|
||||||
|
|
||||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
||||||
# tokens = tokens.to(self.decoder.ln.weight.dtype)
|
|
||||||
# audio_features = audio_features.to(self.decoder.ln.weight.dtype)
|
|
||||||
return self.decoder(tokens, audio_features)
|
return self.decoder(tokens, audio_features)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, mel: torch.Tensor, tokens: torch.Tensor
|
self, mel: torch.Tensor, tokens: torch.Tensor
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, torch.Tensor]:
|
||||||
# mel = mel.to(self.decoder.ln.weight.dtype)
|
|
||||||
# tokens = tokens.to(self.decoder.ln.weight.dtype)
|
|
||||||
return self.decoder(tokens, self.encoder(mel))
|
return self.decoder(tokens, self.encoder(mel))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -343,7 +310,6 @@ class Whisper(nn.Module):
|
|||||||
def num_languages(self):
|
def num_languages(self):
|
||||||
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
|
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
|
||||||
|
|
||||||
# 为decoder加入缓存机制,每次推理时保存上次的k和v,下次推理无需重新计算
|
|
||||||
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
||||||
"""
|
"""
|
||||||
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
||||||
|
|||||||
@@ -30,15 +30,19 @@ def remove_symbols_and_diacritics(s: str, keep=""):
|
|||||||
and drop any diacritics (category 'Mn' and some manual mappings)
|
and drop any diacritics (category 'Mn' and some manual mappings)
|
||||||
"""
|
"""
|
||||||
return "".join(
|
return "".join(
|
||||||
c
|
(
|
||||||
if c in keep
|
c
|
||||||
else ADDITIONAL_DIACRITICS[c]
|
if c in keep
|
||||||
if c in ADDITIONAL_DIACRITICS
|
else (
|
||||||
else ""
|
ADDITIONAL_DIACRITICS[c]
|
||||||
if unicodedata.category(c) == "Mn"
|
if c in ADDITIONAL_DIACRITICS
|
||||||
else " "
|
else (
|
||||||
if unicodedata.category(c)[0] in "MSP"
|
""
|
||||||
else c
|
if unicodedata.category(c) == "Mn"
|
||||||
|
else " " if unicodedata.category(c)[0] in "MSP" else c
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
for c in unicodedata.normalize("NFKD", s)
|
for c in unicodedata.normalize("NFKD", s)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
1741
whisperlivekit/simul_whisper/whisper/normalizers/english.json
Normal file
1741
whisperlivekit/simul_whisper/whisper/normalizers/english.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -56,9 +56,8 @@ def median_filter(x: torch.Tensor, filter_width: int):
|
|||||||
|
|
||||||
@numba.jit(nopython=True)
|
@numba.jit(nopython=True)
|
||||||
def backtrace(trace: np.ndarray):
|
def backtrace(trace: np.ndarray):
|
||||||
i = trace.shape[0] - 1 # trace: (N+1, M+1), i=N
|
i = trace.shape[0] - 1
|
||||||
j = trace.shape[1] - 1 # j=M
|
j = trace.shape[1] - 1
|
||||||
# 边界点其实无意义?
|
|
||||||
trace[0, :] = 2
|
trace[0, :] = 2
|
||||||
trace[:, 0] = 1
|
trace[:, 0] = 1
|
||||||
|
|
||||||
@@ -83,8 +82,8 @@ def backtrace(trace: np.ndarray):
|
|||||||
@numba.jit(nopython=True, parallel=True)
|
@numba.jit(nopython=True, parallel=True)
|
||||||
def dtw_cpu(x: np.ndarray):
|
def dtw_cpu(x: np.ndarray):
|
||||||
N, M = x.shape
|
N, M = x.shape
|
||||||
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf # cost: x[0, 0]到x[i-1, j-1]的最小代价
|
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
|
||||||
trace = -np.ones((N + 1, M + 1), dtype=np.float32) # trace:
|
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
|
||||||
|
|
||||||
cost[0, 0] = 0
|
cost[0, 0] = 0
|
||||||
for j in range(1, M + 1):
|
for j in range(1, M + 1):
|
||||||
@@ -118,7 +117,7 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
|
|||||||
x_skew = x_skew.T.contiguous()
|
x_skew = x_skew.T.contiguous()
|
||||||
cost = torch.ones(N + M + 2, M + 2) * np.inf
|
cost = torch.ones(N + M + 2, M + 2) * np.inf
|
||||||
cost[0, 0] = 0
|
cost[0, 0] = 0
|
||||||
cost = cost.cuda()
|
cost = cost.to(x.device)
|
||||||
trace = torch.zeros_like(cost, dtype=torch.int32)
|
trace = torch.zeros_like(cost, dtype=torch.int32)
|
||||||
|
|
||||||
dtw_kernel[(1,)](
|
dtw_kernel[(1,)](
|
||||||
@@ -192,21 +191,19 @@ def find_alignment(
|
|||||||
for i, block in enumerate(model.decoder.blocks)
|
for i, block in enumerate(model.decoder.blocks)
|
||||||
]
|
]
|
||||||
|
|
||||||
# 进行前传,获得token概率
|
from .model import disable_sdpa
|
||||||
with torch.no_grad():
|
|
||||||
|
with torch.no_grad(), disable_sdpa():
|
||||||
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
||||||
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
||||||
token_probs = sampled_logits.softmax(dim=-1)
|
token_probs = sampled_logits.softmax(dim=-1)
|
||||||
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
|
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
|
||||||
text_token_probs = text_token_probs.tolist()
|
text_token_probs = text_token_probs.tolist()
|
||||||
|
|
||||||
# 移除钩子
|
|
||||||
for hook in hooks:
|
for hook in hooks:
|
||||||
hook.remove()
|
hook.remove()
|
||||||
|
|
||||||
# heads * tokens * frames
|
# heads * tokens * frames
|
||||||
# print(model.alignment_heads)
|
|
||||||
# exit(0)
|
|
||||||
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
|
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
|
||||||
weights = weights[:, :, : num_frames // 2]
|
weights = weights[:, :, : num_frames // 2]
|
||||||
weights = (weights * qk_scale).softmax(dim=-1)
|
weights = (weights * qk_scale).softmax(dim=-1)
|
||||||
@@ -215,18 +212,9 @@ def find_alignment(
|
|||||||
weights = median_filter(weights, medfilt_width)
|
weights = median_filter(weights, medfilt_width)
|
||||||
|
|
||||||
matrix = weights.mean(axis=0)
|
matrix = weights.mean(axis=0)
|
||||||
print("attention", matrix.shape, matrix[:5, :5])
|
|
||||||
matrix = matrix[len(tokenizer.sot_sequence) : -1]
|
matrix = matrix[len(tokenizer.sot_sequence) : -1]
|
||||||
print("attention", matrix.shape, matrix[:5, :5])
|
|
||||||
text_indices, time_indices = dtw(-matrix)
|
text_indices, time_indices = dtw(-matrix)
|
||||||
|
|
||||||
print("num_frames", num_frames)
|
|
||||||
print("attention", matrix.shape, matrix[:5, :5])
|
|
||||||
print("text_indices", text_indices)
|
|
||||||
print("time", time_indices)
|
|
||||||
print("text_tokens", text_tokens, tokenizer.decode(text_tokens), len(text_tokens))
|
|
||||||
print("eot", tokenizer.eot)
|
|
||||||
|
|
||||||
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
|
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
|
||||||
if len(word_tokens) <= 1:
|
if len(word_tokens) <= 1:
|
||||||
# return on eot only
|
# return on eot only
|
||||||
@@ -238,9 +226,7 @@ def find_alignment(
|
|||||||
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
|
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
|
||||||
|
|
||||||
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
||||||
# print("jumps", jumps, jumps.shape)
|
|
||||||
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
|
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
|
||||||
# print("jump_times", jump_times)
|
|
||||||
start_times = jump_times[word_boundaries[:-1]]
|
start_times = jump_times[word_boundaries[:-1]]
|
||||||
end_times = jump_times[word_boundaries[1:]]
|
end_times = jump_times[word_boundaries[1:]]
|
||||||
word_probabilities = [
|
word_probabilities = [
|
||||||
@@ -315,6 +301,7 @@ def add_word_timestamps(
|
|||||||
word_durations = np.array([t.end - t.start for t in alignment])
|
word_durations = np.array([t.end - t.start for t in alignment])
|
||||||
word_durations = word_durations[word_durations.nonzero()]
|
word_durations = word_durations[word_durations.nonzero()]
|
||||||
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
|
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
|
||||||
|
median_duration = min(0.7, float(median_duration))
|
||||||
max_duration = median_duration * 2
|
max_duration = median_duration * 2
|
||||||
|
|
||||||
# hack: truncate long words at sentence boundaries.
|
# hack: truncate long words at sentence boundaries.
|
||||||
|
|||||||
@@ -1,501 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import os
|
|
||||||
import warnings
|
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import tqdm
|
|
||||||
|
|
||||||
from whisper.audio import (
|
|
||||||
FRAMES_PER_SECOND,
|
|
||||||
HOP_LENGTH,
|
|
||||||
N_FRAMES,
|
|
||||||
N_SAMPLES,
|
|
||||||
SAMPLE_RATE,
|
|
||||||
log_mel_spectrogram,
|
|
||||||
pad_or_trim,
|
|
||||||
)
|
|
||||||
from whisper.decoding import DecodingOptions, DecodingResult
|
|
||||||
from whisper.timing import add_word_timestamps
|
|
||||||
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
|
||||||
from whisper.utils import (
|
|
||||||
exact_div,
|
|
||||||
format_timestamp,
|
|
||||||
get_writer,
|
|
||||||
make_safe,
|
|
||||||
optional_float,
|
|
||||||
optional_int,
|
|
||||||
str2bool,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from whisper.model import Whisper
|
|
||||||
|
|
||||||
|
|
||||||
def transcribe(
|
|
||||||
model: "Whisper",
|
|
||||||
audio: Union[str, np.ndarray, torch.Tensor],
|
|
||||||
*,
|
|
||||||
verbose: Optional[bool] = None,
|
|
||||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
|
||||||
compression_ratio_threshold: Optional[float] = 2.4,
|
|
||||||
logprob_threshold: Optional[float] = -1.0,
|
|
||||||
no_speech_threshold: Optional[float] = 0.6,
|
|
||||||
condition_on_previous_text: bool = True,
|
|
||||||
initial_prompt: Optional[str] = None,
|
|
||||||
word_timestamps: bool = False,
|
|
||||||
prepend_punctuations: str = "\"'“¿([{-",
|
|
||||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
|
||||||
**decode_options,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Transcribe an audio file using Whisper
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
model: Whisper
|
|
||||||
The Whisper model instance
|
|
||||||
|
|
||||||
audio: Union[str, np.ndarray, torch.Tensor]
|
|
||||||
The path to the audio file to open, or the audio waveform
|
|
||||||
|
|
||||||
verbose: bool
|
|
||||||
Whether to display the text being decoded to the console. If True, displays all the details,
|
|
||||||
If False, displays minimal details. If None, does not display anything
|
|
||||||
|
|
||||||
temperature: Union[float, Tuple[float, ...]]
|
|
||||||
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
|
|
||||||
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
|
|
||||||
|
|
||||||
compression_ratio_threshold: float
|
|
||||||
If the gzip compression ratio is above this value, treat as failed
|
|
||||||
|
|
||||||
logprob_threshold: float
|
|
||||||
If the average log probability over sampled tokens is below this value, treat as failed
|
|
||||||
|
|
||||||
no_speech_threshold: float
|
|
||||||
If the no_speech probability is higher than this value AND the average log probability
|
|
||||||
over sampled tokens is below `logprob_threshold`, consider the segment as silent
|
|
||||||
|
|
||||||
condition_on_previous_text: bool
|
|
||||||
if True, the previous output of the model is provided as a prompt for the next window;
|
|
||||||
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
|
||||||
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
|
||||||
|
|
||||||
word_timestamps: bool
|
|
||||||
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
|
|
||||||
and include the timestamps for each word in each segment.
|
|
||||||
|
|
||||||
prepend_punctuations: str
|
|
||||||
If word_timestamps is True, merge these punctuation symbols with the next word
|
|
||||||
|
|
||||||
append_punctuations: str
|
|
||||||
If word_timestamps is True, merge these punctuation symbols with the previous word
|
|
||||||
|
|
||||||
initial_prompt: Optional[str]
|
|
||||||
Optional text to provide as a prompt for the first window. This can be used to provide, or
|
|
||||||
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
|
||||||
to make it more likely to predict those word correctly.
|
|
||||||
|
|
||||||
decode_options: dict
|
|
||||||
Keyword arguments to construct `DecodingOptions` instances
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
|
||||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
|
||||||
"""
|
|
||||||
# print("HACKED")
|
|
||||||
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
|
||||||
if model.device == torch.device("cpu"):
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
warnings.warn("Performing inference on CPU when CUDA is available")
|
|
||||||
if dtype == torch.float16:
|
|
||||||
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
|
|
||||||
dtype = torch.float32
|
|
||||||
|
|
||||||
if dtype == torch.float32:
|
|
||||||
decode_options["fp16"] = False
|
|
||||||
|
|
||||||
# Pad 30-seconds of silence to the input audio, for slicing
|
|
||||||
mel = log_mel_spectrogram(audio, padding=0) # log_mel_spectrogram(audio, padding=N_SAMPLES) # 添加16000*30 = 480000个点
|
|
||||||
# mel = pad_or_trim(mel, 3000)
|
|
||||||
content_frames = mel.shape[-1] # - N_FRAMES # 对应3000帧;真正有内容的是去掉尾部3000的那些数据
|
|
||||||
|
|
||||||
# 判断语种
|
|
||||||
if decode_options.get("language", None) is None:
|
|
||||||
# 如果是单语种模型,直接设成英文
|
|
||||||
if not model.is_multilingual:
|
|
||||||
decode_options["language"] = "en"
|
|
||||||
# 否则需要前传一次
|
|
||||||
else:
|
|
||||||
if verbose:
|
|
||||||
print(
|
|
||||||
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
|
||||||
)
|
|
||||||
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
|
||||||
# print(mel_segment.shape)
|
|
||||||
_, probs = model.detect_language(mel_segment)
|
|
||||||
decode_options["language"] = max(probs, key=probs.get)
|
|
||||||
if verbose is not None:
|
|
||||||
print(
|
|
||||||
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
language: str = decode_options["language"]
|
|
||||||
task: str = decode_options.get("task", "transcribe")
|
|
||||||
# 输出编码器
|
|
||||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
|
|
||||||
|
|
||||||
# 词级别时间戳
|
|
||||||
if word_timestamps and task == "translate":
|
|
||||||
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
|
||||||
|
|
||||||
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
|
|
||||||
temperatures = (
|
|
||||||
[temperature] if isinstance(temperature, (int, float)) else temperature
|
|
||||||
)
|
|
||||||
decode_result = None
|
|
||||||
|
|
||||||
for t in temperatures:
|
|
||||||
kwargs = {**decode_options}
|
|
||||||
if t > 0:
|
|
||||||
# disable beam_size and patience when t > 0
|
|
||||||
kwargs.pop("beam_size", None)
|
|
||||||
kwargs.pop("patience", None)
|
|
||||||
else:
|
|
||||||
# disable best_of when t == 0
|
|
||||||
kwargs.pop("best_of", None)
|
|
||||||
|
|
||||||
options = DecodingOptions(**kwargs, temperature=t)
|
|
||||||
decode_result = model.decode(segment, options)
|
|
||||||
|
|
||||||
# 几种解码可能失败的情况。这些情况下会重复解码
|
|
||||||
# 感觉是一种KnowHow的东西 或许ChatGPT里有不少这种trick
|
|
||||||
needs_fallback = False
|
|
||||||
if (
|
|
||||||
compression_ratio_threshold is not None
|
|
||||||
and decode_result.compression_ratio > compression_ratio_threshold
|
|
||||||
):
|
|
||||||
needs_fallback = True # too repetitive
|
|
||||||
if (
|
|
||||||
logprob_threshold is not None
|
|
||||||
and decode_result.avg_logprob < logprob_threshold
|
|
||||||
):
|
|
||||||
needs_fallback = True # average log probability is too low
|
|
||||||
if (
|
|
||||||
no_speech_threshold is not None
|
|
||||||
and decode_result.no_speech_prob > no_speech_threshold
|
|
||||||
):
|
|
||||||
needs_fallback = False # silence
|
|
||||||
if not needs_fallback:
|
|
||||||
break
|
|
||||||
# print("decode with temperature {} compress rate {:.3f}/{:.3f}, log_prob {:.3f}/{:.3f}, {:.3f}/{:.3f}".format(
|
|
||||||
# t,
|
|
||||||
# decode_result.compression_ratio, compression_ratio_threshold,
|
|
||||||
# -decode_result.avg_logprob, -logprob_threshold,
|
|
||||||
# decode_result.no_speech_prob, no_speech_threshold
|
|
||||||
# ))
|
|
||||||
|
|
||||||
return decode_result
|
|
||||||
|
|
||||||
seek = 0
|
|
||||||
input_stride = exact_div(
|
|
||||||
N_FRAMES, model.dims.n_audio_ctx
|
|
||||||
) # mel frames per output token: 2
|
|
||||||
# 这里output token指的应该是CNN输出的那个东西
|
|
||||||
|
|
||||||
time_precision = (
|
|
||||||
input_stride * HOP_LENGTH / SAMPLE_RATE
|
|
||||||
) # time per output token: 0.02 (seconds)
|
|
||||||
all_tokens = []
|
|
||||||
all_segments = []
|
|
||||||
prompt_reset_since = 0
|
|
||||||
|
|
||||||
if initial_prompt is not None:
|
|
||||||
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
|
||||||
all_tokens.extend(initial_prompt_tokens)
|
|
||||||
else:
|
|
||||||
initial_prompt_tokens = []
|
|
||||||
|
|
||||||
def new_segment(
|
|
||||||
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
|
|
||||||
):
|
|
||||||
tokens = tokens.tolist()
|
|
||||||
text_tokens = [token for token in tokens if token < tokenizer.eot]
|
|
||||||
return {
|
|
||||||
"seek": seek,
|
|
||||||
"start": start,
|
|
||||||
"end": end,
|
|
||||||
"text": tokenizer.decode(text_tokens),
|
|
||||||
"tokens": tokens,
|
|
||||||
"temperature": result.temperature,
|
|
||||||
"avg_logprob": result.avg_logprob,
|
|
||||||
"compression_ratio": result.compression_ratio,
|
|
||||||
"no_speech_prob": result.no_speech_prob,
|
|
||||||
}
|
|
||||||
|
|
||||||
# show the progress bar when verbose is False (if True, transcribed text will be printed)
|
|
||||||
with tqdm.tqdm(
|
|
||||||
total=content_frames, unit="frames", disable=verbose is not False
|
|
||||||
) as pbar:
|
|
||||||
last_speech_timestamp = 0.0
|
|
||||||
while seek < content_frames: # seek:标记mel频谱当前帧的位置 直接跳过Padding上的部分
|
|
||||||
# print("seek segments", seek, content_frames)
|
|
||||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) # 本片段的开始时间
|
|
||||||
# mel_segment = mel[:, seek : seek + N_FRAMES] # 获得当前片段的数据
|
|
||||||
mel_segment = mel[:, seek:]
|
|
||||||
segment_size = min(N_FRAMES, content_frames - seek) # segment_size: 排除padding的真的长度。content_frames:有内容的段的真正长度 如果不够N_FRAMES的话就会截断
|
|
||||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE # 当前片段的时长
|
|
||||||
mel_segment = mel_segment.to(model.device).to(dtype) # pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) # 补到mel_segment帧
|
|
||||||
|
|
||||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
|
||||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
|
||||||
tokens = torch.tensor(result.tokens)
|
|
||||||
|
|
||||||
# 跳过静音部分
|
|
||||||
if no_speech_threshold is not None:
|
|
||||||
# no voice activity check
|
|
||||||
should_skip = result.no_speech_prob > no_speech_threshold
|
|
||||||
if (
|
|
||||||
logprob_threshold is not None
|
|
||||||
and result.avg_logprob > logprob_threshold
|
|
||||||
):
|
|
||||||
# don't skip if the logprob is high enough, despite the no_speech_prob
|
|
||||||
should_skip = False
|
|
||||||
|
|
||||||
if should_skip:
|
|
||||||
seek += segment_size # fast-forward to the next segment boundary
|
|
||||||
continue
|
|
||||||
|
|
||||||
previous_seek = seek
|
|
||||||
current_segments = []
|
|
||||||
|
|
||||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) # timestamp begin是<|0.00|>的token;bos比文字token大,eos的值比bos还大,所以是ge
|
|
||||||
timestamp_tokens[-1] = False
|
|
||||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] # 如果最后是[False,True]:本段里一个句子结束了
|
|
||||||
|
|
||||||
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
|
|
||||||
# torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).
|
|
||||||
# timestamp_token就是个一维向量吧 那为啥不直接nonzero
|
|
||||||
# 如果有两个连续的时间戳 这个会是一个一维tensor 是这两个连续时间戳的结尾位置
|
|
||||||
# 多个的话指向第二个 那如果有三个怎么办?
|
|
||||||
# 否则是个0维tensor
|
|
||||||
|
|
||||||
consecutive.add_(1) # 0维tensor+1还是0维 哪儿找的这些edge cases js是吧
|
|
||||||
if len(consecutive) > 0:
|
|
||||||
# if the output contains two consecutive timestamp tokens
|
|
||||||
slices = consecutive.tolist()
|
|
||||||
if single_timestamp_ending:
|
|
||||||
slices.append(len(tokens)) # 把最后一段的结尾也加进去
|
|
||||||
# print("many sentenses", consecutive)
|
|
||||||
last_slice = 0
|
|
||||||
for current_slice in slices:
|
|
||||||
sliced_tokens = tokens[last_slice:current_slice]
|
|
||||||
# 看起来语音开始帧、语音结束帧的位置会被编码到start_timestamp中
|
|
||||||
start_timestamp_pos = (
|
|
||||||
sliced_tokens[0].item() - tokenizer.timestamp_begin
|
|
||||||
)
|
|
||||||
end_timestamp_pos = (
|
|
||||||
sliced_tokens[-1].item() - tokenizer.timestamp_begin
|
|
||||||
)
|
|
||||||
# 获取一个新的语音段
|
|
||||||
current_segments.append(
|
|
||||||
new_segment(
|
|
||||||
start=time_offset + start_timestamp_pos * time_precision,
|
|
||||||
end=time_offset + end_timestamp_pos * time_precision,
|
|
||||||
tokens=sliced_tokens,
|
|
||||||
result=result,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
last_slice = current_slice
|
|
||||||
|
|
||||||
if single_timestamp_ending:
|
|
||||||
# single timestamp at the end means no speech after the last timestamp.
|
|
||||||
seek += segment_size
|
|
||||||
else:
|
|
||||||
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
|
||||||
# 如果语音尚未结束,那么seek变为上一个结束的语段的位置
|
|
||||||
# 换句话说就是针对30s长的chunk的语音设计的
|
|
||||||
last_timestamp_pos = (
|
|
||||||
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
|
||||||
)
|
|
||||||
seek += last_timestamp_pos * input_stride
|
|
||||||
else:
|
|
||||||
duration = segment_duration
|
|
||||||
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
|
||||||
# print(timestamps)
|
|
||||||
if (
|
|
||||||
len(timestamps) > 0
|
|
||||||
and timestamps[-1].item() != tokenizer.timestamp_begin
|
|
||||||
):
|
|
||||||
# no consecutive timestamps but it has a timestamp; use the last one.
|
|
||||||
# 取最后一个;假设要么有一个结束的time stamp;要么有一对儿?
|
|
||||||
# 如果里面只有一个开始的timestamp 似乎后面的东西都会被丢掉?
|
|
||||||
last_timestamp_pos = (
|
|
||||||
timestamps[-1].item() - tokenizer.timestamp_begin
|
|
||||||
)
|
|
||||||
duration = last_timestamp_pos * time_precision
|
|
||||||
|
|
||||||
current_segments.append(
|
|
||||||
new_segment(
|
|
||||||
start=time_offset,
|
|
||||||
end=time_offset + duration,
|
|
||||||
tokens=tokens,
|
|
||||||
result=result,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
seek += segment_size
|
|
||||||
|
|
||||||
# 每个token有自己的时间戳
|
|
||||||
if word_timestamps:
|
|
||||||
add_word_timestamps(
|
|
||||||
segments=current_segments,
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
mel=mel_segment,
|
|
||||||
num_frames=segment_size,
|
|
||||||
prepend_punctuations=prepend_punctuations,
|
|
||||||
append_punctuations=append_punctuations,
|
|
||||||
last_speech_timestamp=last_speech_timestamp,
|
|
||||||
)
|
|
||||||
word_end_timestamps = [
|
|
||||||
w["end"] for s in current_segments for w in s["words"]
|
|
||||||
]
|
|
||||||
if len(word_end_timestamps) > 0:
|
|
||||||
last_speech_timestamp = word_end_timestamps[-1]
|
|
||||||
if not single_timestamp_ending and len(word_end_timestamps) > 0:
|
|
||||||
seek_shift = round(
|
|
||||||
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
|
|
||||||
)
|
|
||||||
if seek_shift > 0:
|
|
||||||
seek = previous_seek + seek_shift
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
for segment in current_segments:
|
|
||||||
start, end, text = segment["start"], segment["end"], segment["text"]
|
|
||||||
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
|
|
||||||
print(make_safe(line))
|
|
||||||
|
|
||||||
# if a segment is instantaneous or does not contain text, clear it
|
|
||||||
for i, segment in enumerate(current_segments):
|
|
||||||
if segment["start"] == segment["end"] or segment["text"].strip() == "":
|
|
||||||
segment["text"] = ""
|
|
||||||
segment["tokens"] = []
|
|
||||||
segment["words"] = []
|
|
||||||
|
|
||||||
# 更新结果
|
|
||||||
all_segments.extend(
|
|
||||||
[
|
|
||||||
{"id": i, **segment}
|
|
||||||
for i, segment in enumerate(
|
|
||||||
current_segments, start=len(all_segments)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
all_tokens.extend(
|
|
||||||
[token for segment in current_segments for token in segment["tokens"]]
|
|
||||||
)
|
|
||||||
|
|
||||||
if not condition_on_previous_text or result.temperature > 0.5:
|
|
||||||
# do not feed the prompt tokens if a high temperature was used
|
|
||||||
prompt_reset_since = len(all_tokens)
|
|
||||||
|
|
||||||
# update progress bar
|
|
||||||
pbar.update(min(content_frames, seek) - previous_seek)
|
|
||||||
|
|
||||||
# print("太长了")
|
|
||||||
# break
|
|
||||||
|
|
||||||
return dict(
|
|
||||||
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
|
||||||
segments=all_segments,
|
|
||||||
language=language,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def cli():
|
|
||||||
from . import available_models
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
||||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
|
||||||
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
|
|
||||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
|
||||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
|
||||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
|
||||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
|
|
||||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
|
||||||
|
|
||||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
|
||||||
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
|
||||||
|
|
||||||
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
|
||||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
|
||||||
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
|
||||||
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
|
||||||
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
|
||||||
|
|
||||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
|
||||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
|
||||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
|
||||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
|
||||||
|
|
||||||
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
|
|
||||||
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
|
||||||
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
|
||||||
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
|
||||||
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
|
|
||||||
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
|
|
||||||
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
|
|
||||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
|
||||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
|
|
||||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
|
|
||||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
args = parser.parse_args().__dict__
|
|
||||||
model_name: str = args.pop("model")
|
|
||||||
model_dir: str = args.pop("model_dir")
|
|
||||||
output_dir: str = args.pop("output_dir")
|
|
||||||
output_format: str = args.pop("output_format")
|
|
||||||
device: str = args.pop("device")
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
|
|
||||||
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
|
||||||
if args["language"] is not None:
|
|
||||||
warnings.warn(
|
|
||||||
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
|
|
||||||
)
|
|
||||||
args["language"] = "en"
|
|
||||||
|
|
||||||
temperature = args.pop("temperature")
|
|
||||||
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
|
|
||||||
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
|
|
||||||
else:
|
|
||||||
temperature = [temperature]
|
|
||||||
|
|
||||||
if (threads := args.pop("threads")) > 0:
|
|
||||||
torch.set_num_threads(threads)
|
|
||||||
|
|
||||||
from . import load_model
|
|
||||||
|
|
||||||
model = load_model(model_name, device=device, download_root=model_dir)
|
|
||||||
|
|
||||||
writer = get_writer(output_format, output_dir)
|
|
||||||
word_options = ["highlight_words", "max_line_count", "max_line_width"]
|
|
||||||
if not args["word_timestamps"]:
|
|
||||||
for option in word_options:
|
|
||||||
if args[option]:
|
|
||||||
parser.error(f"--{option} requires --word_timestamps True")
|
|
||||||
if args["max_line_count"] and not args["max_line_width"]:
|
|
||||||
warnings.warn("--max_line_count has no effect without --max_line_width")
|
|
||||||
writer_args = {arg: args.pop(arg) for arg in word_options}
|
|
||||||
for audio_path in args.pop("audio"):
|
|
||||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
|
||||||
writer(result, audio_path, writer_args)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
cli()
|
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -22,6 +23,7 @@ from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
|||||||
from .utils import (
|
from .utils import (
|
||||||
exact_div,
|
exact_div,
|
||||||
format_timestamp,
|
format_timestamp,
|
||||||
|
get_end,
|
||||||
get_writer,
|
get_writer,
|
||||||
make_safe,
|
make_safe,
|
||||||
optional_float,
|
optional_float,
|
||||||
@@ -44,9 +46,12 @@ def transcribe(
|
|||||||
no_speech_threshold: Optional[float] = 0.6,
|
no_speech_threshold: Optional[float] = 0.6,
|
||||||
condition_on_previous_text: bool = True,
|
condition_on_previous_text: bool = True,
|
||||||
initial_prompt: Optional[str] = None,
|
initial_prompt: Optional[str] = None,
|
||||||
|
carry_initial_prompt: bool = False,
|
||||||
word_timestamps: bool = False,
|
word_timestamps: bool = False,
|
||||||
prepend_punctuations: str = "\"'“¿([{-",
|
prepend_punctuations: str = "\"'“¿([{-",
|
||||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||||
|
clip_timestamps: Union[str, List[float]] = "0",
|
||||||
|
hallucination_silence_threshold: Optional[float] = None,
|
||||||
**decode_options,
|
**decode_options,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -98,15 +103,27 @@ def transcribe(
|
|||||||
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
||||||
to make it more likely to predict those word correctly.
|
to make it more likely to predict those word correctly.
|
||||||
|
|
||||||
|
carry_initial_prompt: bool
|
||||||
|
If carry_initial_prompt is True, `initial_prompt` is prepended to the prompt of each internal
|
||||||
|
`decode()` call. If there is not enough context space at the start of the prompt, it is
|
||||||
|
left-sliced to make space.
|
||||||
|
|
||||||
decode_options: dict
|
decode_options: dict
|
||||||
Keyword arguments to construct `DecodingOptions` instances
|
Keyword arguments to construct `DecodingOptions` instances
|
||||||
|
|
||||||
|
clip_timestamps: Union[str, List[float]]
|
||||||
|
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
|
||||||
|
The last end timestamp defaults to the end of the file.
|
||||||
|
|
||||||
|
hallucination_silence_threshold: Optional[float]
|
||||||
|
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
|
||||||
|
when a possible hallucination is detected
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
||||||
"""
|
"""
|
||||||
# print("transcribe")
|
|
||||||
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
||||||
if model.device == torch.device("cpu"):
|
if model.device == torch.device("cpu"):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@@ -119,8 +136,9 @@ def transcribe(
|
|||||||
decode_options["fp16"] = False
|
decode_options["fp16"] = False
|
||||||
|
|
||||||
# Pad 30-seconds of silence to the input audio, for slicing
|
# Pad 30-seconds of silence to the input audio, for slicing
|
||||||
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
|
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
|
||||||
content_frames = mel.shape[-1] - N_FRAMES
|
content_frames = mel.shape[-1] - N_FRAMES
|
||||||
|
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
|
||||||
|
|
||||||
if decode_options.get("language", None) is None:
|
if decode_options.get("language", None) is None:
|
||||||
if not model.is_multilingual:
|
if not model.is_multilingual:
|
||||||
@@ -131,7 +149,6 @@ def transcribe(
|
|||||||
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
||||||
)
|
)
|
||||||
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||||
# print(mel_segment.shape)
|
|
||||||
_, probs = model.detect_language(mel_segment)
|
_, probs = model.detect_language(mel_segment)
|
||||||
decode_options["language"] = max(probs, key=probs.get)
|
decode_options["language"] = max(probs, key=probs.get)
|
||||||
if verbose is not None:
|
if verbose is not None:
|
||||||
@@ -141,7 +158,25 @@ def transcribe(
|
|||||||
|
|
||||||
language: str = decode_options["language"]
|
language: str = decode_options["language"]
|
||||||
task: str = decode_options.get("task", "transcribe")
|
task: str = decode_options.get("task", "transcribe")
|
||||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
|
tokenizer = get_tokenizer(
|
||||||
|
model.is_multilingual,
|
||||||
|
num_languages=model.num_languages,
|
||||||
|
language=language,
|
||||||
|
task=task,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(clip_timestamps, str):
|
||||||
|
clip_timestamps = [
|
||||||
|
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
|
||||||
|
]
|
||||||
|
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
|
||||||
|
if len(seek_points) == 0:
|
||||||
|
seek_points.append(0)
|
||||||
|
if len(seek_points) % 2 == 1:
|
||||||
|
seek_points.append(content_frames)
|
||||||
|
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
|
||||||
|
|
||||||
|
punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
|
||||||
|
|
||||||
if word_timestamps and task == "translate":
|
if word_timestamps and task == "translate":
|
||||||
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
||||||
@@ -179,6 +214,8 @@ def transcribe(
|
|||||||
if (
|
if (
|
||||||
no_speech_threshold is not None
|
no_speech_threshold is not None
|
||||||
and decode_result.no_speech_prob > no_speech_threshold
|
and decode_result.no_speech_prob > no_speech_threshold
|
||||||
|
and logprob_threshold is not None
|
||||||
|
and decode_result.avg_logprob < logprob_threshold
|
||||||
):
|
):
|
||||||
needs_fallback = False # silence
|
needs_fallback = False # silence
|
||||||
if not needs_fallback:
|
if not needs_fallback:
|
||||||
@@ -186,7 +223,8 @@ def transcribe(
|
|||||||
|
|
||||||
return decode_result
|
return decode_result
|
||||||
|
|
||||||
seek = 0
|
clip_idx = 0
|
||||||
|
seek = seek_clips[clip_idx][0]
|
||||||
input_stride = exact_div(
|
input_stride = exact_div(
|
||||||
N_FRAMES, model.dims.n_audio_ctx
|
N_FRAMES, model.dims.n_audio_ctx
|
||||||
) # mel frames per output token: 2
|
) # mel frames per output token: 2
|
||||||
@@ -197,9 +235,11 @@ def transcribe(
|
|||||||
all_segments = []
|
all_segments = []
|
||||||
prompt_reset_since = 0
|
prompt_reset_since = 0
|
||||||
|
|
||||||
|
remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
|
||||||
if initial_prompt is not None:
|
if initial_prompt is not None:
|
||||||
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
||||||
all_tokens.extend(initial_prompt_tokens)
|
all_tokens.extend(initial_prompt_tokens)
|
||||||
|
remaining_prompt_length -= len(initial_prompt_tokens)
|
||||||
else:
|
else:
|
||||||
initial_prompt_tokens = []
|
initial_prompt_tokens = []
|
||||||
|
|
||||||
@@ -225,16 +265,33 @@ def transcribe(
|
|||||||
total=content_frames, unit="frames", disable=verbose is not False
|
total=content_frames, unit="frames", disable=verbose is not False
|
||||||
) as pbar:
|
) as pbar:
|
||||||
last_speech_timestamp = 0.0
|
last_speech_timestamp = 0.0
|
||||||
while seek < content_frames:
|
# NOTE: This loop is obscurely flattened to make the diff readable.
|
||||||
|
# A later commit should turn this into a simpler nested loop.
|
||||||
|
# for seek_clip_start, seek_clip_end in seek_clips:
|
||||||
|
# while seek < seek_clip_end
|
||||||
|
while clip_idx < len(seek_clips):
|
||||||
|
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
|
||||||
|
if seek < seek_clip_start:
|
||||||
|
seek = seek_clip_start
|
||||||
|
if seek >= seek_clip_end:
|
||||||
|
clip_idx += 1
|
||||||
|
if clip_idx < len(seek_clips):
|
||||||
|
seek = seek_clips[clip_idx][0]
|
||||||
|
continue
|
||||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||||
mel_segment = mel[:, seek : seek + N_FRAMES]
|
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
|
||||||
segment_size = min(N_FRAMES, content_frames - seek)
|
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
|
||||||
|
mel_segment = mel[:, seek : seek + segment_size]
|
||||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
||||||
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
||||||
|
|
||||||
# print("melshape", mel_segment.shape)
|
if carry_initial_prompt:
|
||||||
|
nignored = max(len(initial_prompt_tokens), prompt_reset_since)
|
||||||
|
remaining_prompt = all_tokens[nignored:][-remaining_prompt_length:]
|
||||||
|
decode_options["prompt"] = initial_prompt_tokens + remaining_prompt
|
||||||
|
else:
|
||||||
|
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||||
|
|
||||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
|
||||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
result: DecodingResult = decode_with_fallback(mel_segment)
|
||||||
tokens = torch.tensor(result.tokens)
|
tokens = torch.tensor(result.tokens)
|
||||||
|
|
||||||
@@ -255,6 +312,30 @@ def transcribe(
|
|||||||
previous_seek = seek
|
previous_seek = seek
|
||||||
current_segments = []
|
current_segments = []
|
||||||
|
|
||||||
|
# anomalous words are very long/short/improbable
|
||||||
|
def word_anomaly_score(word: dict) -> float:
|
||||||
|
probability = word.get("probability", 0.0)
|
||||||
|
duration = word["end"] - word["start"]
|
||||||
|
score = 0.0
|
||||||
|
if probability < 0.15:
|
||||||
|
score += 1.0
|
||||||
|
if duration < 0.133:
|
||||||
|
score += (0.133 - duration) * 15
|
||||||
|
if duration > 2.0:
|
||||||
|
score += duration - 2.0
|
||||||
|
return score
|
||||||
|
|
||||||
|
def is_segment_anomaly(segment: Optional[dict]) -> bool:
|
||||||
|
if segment is None or not segment["words"]:
|
||||||
|
return False
|
||||||
|
words = [w for w in segment["words"] if w["word"] not in punctuation]
|
||||||
|
words = words[:8]
|
||||||
|
score = sum(word_anomaly_score(w) for w in words)
|
||||||
|
return score >= 3 or score + 0.01 >= len(words)
|
||||||
|
|
||||||
|
def next_words_segment(segments: List[dict]) -> Optional[dict]:
|
||||||
|
return next((s for s in segments if s["words"]), None)
|
||||||
|
|
||||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
||||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
||||||
|
|
||||||
@@ -317,9 +398,7 @@ def transcribe(
|
|||||||
)
|
)
|
||||||
seek += segment_size
|
seek += segment_size
|
||||||
|
|
||||||
# print("word_timestamps, ", word_timestamps)
|
|
||||||
if word_timestamps:
|
if word_timestamps:
|
||||||
# print("=========run timestamps here=========")
|
|
||||||
add_word_timestamps(
|
add_word_timestamps(
|
||||||
segments=current_segments,
|
segments=current_segments,
|
||||||
model=model,
|
model=model,
|
||||||
@@ -330,17 +409,71 @@ def transcribe(
|
|||||||
append_punctuations=append_punctuations,
|
append_punctuations=append_punctuations,
|
||||||
last_speech_timestamp=last_speech_timestamp,
|
last_speech_timestamp=last_speech_timestamp,
|
||||||
)
|
)
|
||||||
word_end_timestamps = [
|
|
||||||
w["end"] for s in current_segments for w in s["words"]
|
if not single_timestamp_ending:
|
||||||
]
|
last_word_end = get_end(current_segments)
|
||||||
if len(word_end_timestamps) > 0:
|
if last_word_end is not None and last_word_end > time_offset:
|
||||||
last_speech_timestamp = word_end_timestamps[-1]
|
seek = round(last_word_end * FRAMES_PER_SECOND)
|
||||||
if not single_timestamp_ending and len(word_end_timestamps) > 0:
|
|
||||||
seek_shift = round(
|
# skip silence before possible hallucinations
|
||||||
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
|
if hallucination_silence_threshold is not None:
|
||||||
)
|
threshold = hallucination_silence_threshold
|
||||||
if seek_shift > 0:
|
if not single_timestamp_ending:
|
||||||
seek = previous_seek + seek_shift
|
last_word_end = get_end(current_segments)
|
||||||
|
if last_word_end is not None and last_word_end > time_offset:
|
||||||
|
remaining_duration = window_end_time - last_word_end
|
||||||
|
if remaining_duration > threshold:
|
||||||
|
seek = round(last_word_end * FRAMES_PER_SECOND)
|
||||||
|
else:
|
||||||
|
seek = previous_seek + segment_size
|
||||||
|
|
||||||
|
# if first segment might be a hallucination, skip leading silence
|
||||||
|
first_segment = next_words_segment(current_segments)
|
||||||
|
if first_segment is not None and is_segment_anomaly(first_segment):
|
||||||
|
gap = first_segment["start"] - time_offset
|
||||||
|
if gap > threshold:
|
||||||
|
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# skip silence before any possible hallucination that is surrounded
|
||||||
|
# by silence or more hallucinations
|
||||||
|
hal_last_end = last_speech_timestamp
|
||||||
|
for si in range(len(current_segments)):
|
||||||
|
segment = current_segments[si]
|
||||||
|
if not segment["words"]:
|
||||||
|
continue
|
||||||
|
if is_segment_anomaly(segment):
|
||||||
|
next_segment = next_words_segment(
|
||||||
|
current_segments[si + 1 :]
|
||||||
|
)
|
||||||
|
if next_segment is not None:
|
||||||
|
hal_next_start = next_segment["words"][0]["start"]
|
||||||
|
else:
|
||||||
|
hal_next_start = time_offset + segment_duration
|
||||||
|
silence_before = (
|
||||||
|
segment["start"] - hal_last_end > threshold
|
||||||
|
or segment["start"] < threshold
|
||||||
|
or segment["start"] - time_offset < 2.0
|
||||||
|
)
|
||||||
|
silence_after = (
|
||||||
|
hal_next_start - segment["end"] > threshold
|
||||||
|
or is_segment_anomaly(next_segment)
|
||||||
|
or window_end_time - segment["end"] < 2.0
|
||||||
|
)
|
||||||
|
if silence_before and silence_after:
|
||||||
|
seek = round(
|
||||||
|
max(time_offset + 1, segment["start"])
|
||||||
|
* FRAMES_PER_SECOND
|
||||||
|
)
|
||||||
|
if content_duration - segment["end"] < threshold:
|
||||||
|
seek = content_frames
|
||||||
|
current_segments[si:] = []
|
||||||
|
break
|
||||||
|
hal_last_end = segment["end"]
|
||||||
|
|
||||||
|
last_word_end = get_end(current_segments)
|
||||||
|
if last_word_end is not None:
|
||||||
|
last_speech_timestamp = last_word_end
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
for segment in current_segments:
|
for segment in current_segments:
|
||||||
@@ -384,10 +517,17 @@ def transcribe(
|
|||||||
def cli():
|
def cli():
|
||||||
from . import available_models
|
from . import available_models
|
||||||
|
|
||||||
|
def valid_model_name(name):
|
||||||
|
if name in available_models() or os.path.exists(name):
|
||||||
|
return name
|
||||||
|
raise ValueError(
|
||||||
|
f"model should be one of {available_models()} or path to a model checkpoint"
|
||||||
|
)
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||||
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
|
parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
|
||||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||||
@@ -405,6 +545,8 @@ def cli():
|
|||||||
|
|
||||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||||
|
parser.add_argument("--carry_initial_prompt", type=str2bool, default=False, help="if True, prepend initial_prompt to every internal decode() call. May reduce the effectiveness of condition_on_previous_text")
|
||||||
|
|
||||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||||
|
|
||||||
@@ -418,7 +560,10 @@ def cli():
|
|||||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
||||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
|
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
|
||||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
|
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
|
||||||
|
parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
|
||||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||||
|
parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
|
||||||
|
parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
args = parser.parse_args().__dict__
|
args = parser.parse_args().__dict__
|
||||||
@@ -450,17 +595,28 @@ def cli():
|
|||||||
model = load_model(model_name, device=device, download_root=model_dir)
|
model = load_model(model_name, device=device, download_root=model_dir)
|
||||||
|
|
||||||
writer = get_writer(output_format, output_dir)
|
writer = get_writer(output_format, output_dir)
|
||||||
word_options = ["highlight_words", "max_line_count", "max_line_width"]
|
word_options = [
|
||||||
|
"highlight_words",
|
||||||
|
"max_line_count",
|
||||||
|
"max_line_width",
|
||||||
|
"max_words_per_line",
|
||||||
|
]
|
||||||
if not args["word_timestamps"]:
|
if not args["word_timestamps"]:
|
||||||
for option in word_options:
|
for option in word_options:
|
||||||
if args[option]:
|
if args[option]:
|
||||||
parser.error(f"--{option} requires --word_timestamps True")
|
parser.error(f"--{option} requires --word_timestamps True")
|
||||||
if args["max_line_count"] and not args["max_line_width"]:
|
if args["max_line_count"] and not args["max_line_width"]:
|
||||||
warnings.warn("--max_line_count has no effect without --max_line_width")
|
warnings.warn("--max_line_count has no effect without --max_line_width")
|
||||||
|
if args["max_words_per_line"] and args["max_line_width"]:
|
||||||
|
warnings.warn("--max_words_per_line has no effect with --max_line_width")
|
||||||
writer_args = {arg: args.pop(arg) for arg in word_options}
|
writer_args = {arg: args.pop(arg) for arg in word_options}
|
||||||
for audio_path in args.pop("audio"):
|
for audio_path in args.pop("audio"):
|
||||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
try:
|
||||||
writer(result, audio_path, writer_args)
|
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||||
|
writer(result, audio_path, **writer_args)
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ def median_kernel(filter_width: int):
|
|||||||
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
|
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
|
||||||
|
|
||||||
kernel = triton.JITFunction(kernel.fn)
|
kernel = triton.JITFunction(kernel.fn)
|
||||||
kernel.src = kernel.src.replace(
|
new_kernel = kernel.src.replace(
|
||||||
" LOAD_ALL_ROWS_HERE",
|
" LOAD_ALL_ROWS_HERE",
|
||||||
"\n".join(
|
"\n".join(
|
||||||
[
|
[
|
||||||
@@ -69,7 +69,8 @@ def median_kernel(filter_width: int):
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
kernel.src = kernel.src.replace(
|
|
||||||
|
new_kernel = new_kernel.replace(
|
||||||
" BUBBLESORT_HERE",
|
" BUBBLESORT_HERE",
|
||||||
"\n\n".join(
|
"\n\n".join(
|
||||||
[
|
[
|
||||||
@@ -90,7 +91,14 @@ def median_kernel(filter_width: int):
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
|
|
||||||
|
new_kernel = new_kernel.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
|
||||||
|
|
||||||
|
if hasattr(kernel, "_unsafe_update_src") is True:
|
||||||
|
kernel._unsafe_update_src(new_kernel)
|
||||||
|
kernel.hash = None
|
||||||
|
else:
|
||||||
|
kernel.src = new_kernel
|
||||||
|
|
||||||
return kernel
|
return kernel
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import zlib
|
import zlib
|
||||||
from typing import Callable, Optional, TextIO
|
from typing import Callable, List, Optional, TextIO
|
||||||
|
|
||||||
system_encoding = sys.getdefaultencoding()
|
system_encoding = sys.getdefaultencoding()
|
||||||
|
|
||||||
@@ -68,13 +68,29 @@ def format_timestamp(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_start(segments: List[dict]) -> Optional[float]:
|
||||||
|
return next(
|
||||||
|
(w["start"] for s in segments for w in s["words"]),
|
||||||
|
segments[0]["start"] if segments else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_end(segments: List[dict]) -> Optional[float]:
|
||||||
|
return next(
|
||||||
|
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
|
||||||
|
segments[-1]["end"] if segments else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ResultWriter:
|
class ResultWriter:
|
||||||
extension: str
|
extension: str
|
||||||
|
|
||||||
def __init__(self, output_dir: str):
|
def __init__(self, output_dir: str):
|
||||||
self.output_dir = output_dir
|
self.output_dir = output_dir
|
||||||
|
|
||||||
def __call__(self, result: dict, audio_path: str, options: dict):
|
def __call__(
|
||||||
|
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
audio_basename = os.path.basename(audio_path)
|
audio_basename = os.path.basename(audio_path)
|
||||||
audio_basename = os.path.splitext(audio_basename)[0]
|
audio_basename = os.path.splitext(audio_basename)[0]
|
||||||
output_path = os.path.join(
|
output_path = os.path.join(
|
||||||
@@ -82,16 +98,20 @@ class ResultWriter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with open(output_path, "w", encoding="utf-8") as f:
|
with open(output_path, "w", encoding="utf-8") as f:
|
||||||
self.write_result(result, file=f, options=options)
|
self.write_result(result, file=f, options=options, **kwargs)
|
||||||
|
|
||||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class WriteTXT(ResultWriter):
|
class WriteTXT(ResultWriter):
|
||||||
extension: str = "txt"
|
extension: str = "txt"
|
||||||
|
|
||||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
for segment in result["segments"]:
|
for segment in result["segments"]:
|
||||||
print(segment["text"].strip(), file=file, flush=True)
|
print(segment["text"].strip(), file=file, flush=True)
|
||||||
|
|
||||||
@@ -100,48 +120,76 @@ class SubtitlesWriter(ResultWriter):
|
|||||||
always_include_hours: bool
|
always_include_hours: bool
|
||||||
decimal_marker: str
|
decimal_marker: str
|
||||||
|
|
||||||
def iterate_result(self, result: dict, options: dict):
|
def iterate_result(
|
||||||
raw_max_line_width: Optional[int] = options["max_line_width"]
|
self,
|
||||||
max_line_count: Optional[int] = options["max_line_count"]
|
result: dict,
|
||||||
highlight_words: bool = options["highlight_words"]
|
options: Optional[dict] = None,
|
||||||
max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
|
*,
|
||||||
preserve_segments = max_line_count is None or raw_max_line_width is None
|
max_line_width: Optional[int] = None,
|
||||||
|
max_line_count: Optional[int] = None,
|
||||||
|
highlight_words: bool = False,
|
||||||
|
max_words_per_line: Optional[int] = None,
|
||||||
|
):
|
||||||
|
options = options or {}
|
||||||
|
max_line_width = max_line_width or options.get("max_line_width")
|
||||||
|
max_line_count = max_line_count or options.get("max_line_count")
|
||||||
|
highlight_words = highlight_words or options.get("highlight_words", False)
|
||||||
|
max_words_per_line = max_words_per_line or options.get("max_words_per_line")
|
||||||
|
preserve_segments = max_line_count is None or max_line_width is None
|
||||||
|
max_line_width = max_line_width or 1000
|
||||||
|
max_words_per_line = max_words_per_line or 1000
|
||||||
|
|
||||||
def iterate_subtitles():
|
def iterate_subtitles():
|
||||||
line_len = 0
|
line_len = 0
|
||||||
line_count = 1
|
line_count = 1
|
||||||
# the next subtitle to yield (a list of word timings with whitespace)
|
# the next subtitle to yield (a list of word timings with whitespace)
|
||||||
subtitle: list[dict] = []
|
subtitle: List[dict] = []
|
||||||
last = result["segments"][0]["words"][0]["start"]
|
last: float = get_start(result["segments"]) or 0.0
|
||||||
for segment in result["segments"]:
|
for segment in result["segments"]:
|
||||||
for i, original_timing in enumerate(segment["words"]):
|
chunk_index = 0
|
||||||
timing = original_timing.copy()
|
words_count = max_words_per_line
|
||||||
long_pause = not preserve_segments and timing["start"] - last > 3.0
|
while chunk_index < len(segment["words"]):
|
||||||
has_room = line_len + len(timing["word"]) <= max_line_width
|
remaining_words = len(segment["words"]) - chunk_index
|
||||||
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
|
if max_words_per_line > len(segment["words"]) - chunk_index:
|
||||||
if line_len > 0 and has_room and not long_pause and not seg_break:
|
words_count = remaining_words
|
||||||
# line continuation
|
for i, original_timing in enumerate(
|
||||||
line_len += len(timing["word"])
|
segment["words"][chunk_index : chunk_index + words_count]
|
||||||
else:
|
):
|
||||||
# new line
|
timing = original_timing.copy()
|
||||||
timing["word"] = timing["word"].strip()
|
long_pause = (
|
||||||
|
not preserve_segments and timing["start"] - last > 3.0
|
||||||
|
)
|
||||||
|
has_room = line_len + len(timing["word"]) <= max_line_width
|
||||||
|
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
|
||||||
if (
|
if (
|
||||||
len(subtitle) > 0
|
line_len > 0
|
||||||
and max_line_count is not None
|
and has_room
|
||||||
and (long_pause or line_count >= max_line_count)
|
and not long_pause
|
||||||
or seg_break
|
and not seg_break
|
||||||
):
|
):
|
||||||
# subtitle break
|
# line continuation
|
||||||
yield subtitle
|
line_len += len(timing["word"])
|
||||||
subtitle = []
|
else:
|
||||||
line_count = 1
|
# new line
|
||||||
elif line_len > 0:
|
timing["word"] = timing["word"].strip()
|
||||||
# line break
|
if (
|
||||||
line_count += 1
|
len(subtitle) > 0
|
||||||
timing["word"] = "\n" + timing["word"]
|
and max_line_count is not None
|
||||||
line_len = len(timing["word"].strip())
|
and (long_pause or line_count >= max_line_count)
|
||||||
subtitle.append(timing)
|
or seg_break
|
||||||
last = timing["start"]
|
):
|
||||||
|
# subtitle break
|
||||||
|
yield subtitle
|
||||||
|
subtitle = []
|
||||||
|
line_count = 1
|
||||||
|
elif line_len > 0:
|
||||||
|
# line break
|
||||||
|
line_count += 1
|
||||||
|
timing["word"] = "\n" + timing["word"]
|
||||||
|
line_len = len(timing["word"].strip())
|
||||||
|
subtitle.append(timing)
|
||||||
|
last = timing["start"]
|
||||||
|
chunk_index += max_words_per_line
|
||||||
if len(subtitle) > 0:
|
if len(subtitle) > 0:
|
||||||
yield subtitle
|
yield subtitle
|
||||||
|
|
||||||
@@ -161,9 +209,11 @@ class SubtitlesWriter(ResultWriter):
|
|||||||
|
|
||||||
yield start, end, "".join(
|
yield start, end, "".join(
|
||||||
[
|
[
|
||||||
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
(
|
||||||
if j == i
|
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
||||||
else word
|
if j == i
|
||||||
|
else word
|
||||||
|
)
|
||||||
for j, word in enumerate(all_words)
|
for j, word in enumerate(all_words)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -190,9 +240,11 @@ class WriteVTT(SubtitlesWriter):
|
|||||||
always_include_hours: bool = False
|
always_include_hours: bool = False
|
||||||
decimal_marker: str = "."
|
decimal_marker: str = "."
|
||||||
|
|
||||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
print("WEBVTT\n", file=file)
|
print("WEBVTT\n", file=file)
|
||||||
for start, end, text in self.iterate_result(result, options):
|
for start, end, text in self.iterate_result(result, options, **kwargs):
|
||||||
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||||
|
|
||||||
|
|
||||||
@@ -201,9 +253,11 @@ class WriteSRT(SubtitlesWriter):
|
|||||||
always_include_hours: bool = True
|
always_include_hours: bool = True
|
||||||
decimal_marker: str = ","
|
decimal_marker: str = ","
|
||||||
|
|
||||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
for i, (start, end, text) in enumerate(
|
for i, (start, end, text) in enumerate(
|
||||||
self.iterate_result(result, options), start=1
|
self.iterate_result(result, options, **kwargs), start=1
|
||||||
):
|
):
|
||||||
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
|
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||||
|
|
||||||
@@ -220,7 +274,9 @@ class WriteTSV(ResultWriter):
|
|||||||
|
|
||||||
extension: str = "tsv"
|
extension: str = "tsv"
|
||||||
|
|
||||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
print("start", "end", "text", sep="\t", file=file)
|
print("start", "end", "text", sep="\t", file=file)
|
||||||
for segment in result["segments"]:
|
for segment in result["segments"]:
|
||||||
print(round(1000 * segment["start"]), file=file, end="\t")
|
print(round(1000 * segment["start"]), file=file, end="\t")
|
||||||
@@ -231,7 +287,9 @@ class WriteTSV(ResultWriter):
|
|||||||
class WriteJSON(ResultWriter):
|
class WriteJSON(ResultWriter):
|
||||||
extension: str = "json"
|
extension: str = "json"
|
||||||
|
|
||||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
json.dump(result, file)
|
json.dump(result, file)
|
||||||
|
|
||||||
|
|
||||||
@@ -249,9 +307,11 @@ def get_writer(
|
|||||||
if output_format == "all":
|
if output_format == "all":
|
||||||
all_writers = [writer(output_dir) for writer in writers.values()]
|
all_writers = [writer(output_dir) for writer in writers.values()]
|
||||||
|
|
||||||
def write_all(result: dict, file: TextIO, options: dict):
|
def write_all(
|
||||||
|
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
for writer in all_writers:
|
for writer in all_writers:
|
||||||
writer(result, file, options)
|
writer(result, file, options, **kwargs)
|
||||||
|
|
||||||
return write_all
|
return write_all
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = "20230918"
|
__version__ = "20250625"
|
||||||
|
|||||||
@@ -1,73 +0,0 @@
|
|||||||
import torch
|
|
||||||
import sys
|
|
||||||
class TokenBuffer:
|
|
||||||
|
|
||||||
def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):
|
|
||||||
self.text = text
|
|
||||||
self.prefix_token_ids = prefix_token_ids
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
def as_token_ids(self, tokenizer=None):
|
|
||||||
|
|
||||||
if tokenizer is None:
|
|
||||||
tokenizer = self.tokenizer
|
|
||||||
if tokenizer is None:
|
|
||||||
raise ValueError("Tokenizer is not set.")
|
|
||||||
return self.prefix_token_ids + tokenizer.encode(self.text)
|
|
||||||
|
|
||||||
def as_tensor(self, device=None):
|
|
||||||
if device is None:
|
|
||||||
device = self.device
|
|
||||||
if device is None:
|
|
||||||
raise ValueError("Device is not set.")
|
|
||||||
tok_ids = self.as_token_ids()
|
|
||||||
return torch.tensor(tok_ids,
|
|
||||||
dtype=torch.long, device=device).unsqueeze(0)
|
|
||||||
|
|
||||||
def as_tensor_beam(self, beam, device=None):
|
|
||||||
t = self.as_tensor(device=device)
|
|
||||||
return t.repeat_interleave(beam, dim=0)
|
|
||||||
|
|
||||||
|
|
||||||
def as_text(self):
|
|
||||||
return self.text
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def empty(*a, **kw):
|
|
||||||
return TokenBuffer(*a,**kw)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_text(text, *a, **kw):
|
|
||||||
return TokenBuffer(*a, text=text, **kw)
|
|
||||||
|
|
||||||
def is_empty(self):
|
|
||||||
return self.text is None or self.text == ""
|
|
||||||
|
|
||||||
def trim_words(self, num=1, after=0):
|
|
||||||
'''
|
|
||||||
num: how many words to trim from the beginning
|
|
||||||
after: how many characters to skip (length of the static prompt)
|
|
||||||
'''
|
|
||||||
tokenizer = self.tokenizer
|
|
||||||
assert tokenizer is not None, "Tokenizer is not set."
|
|
||||||
|
|
||||||
ids = tokenizer.encode(self.text[after:])
|
|
||||||
words, wids = self.tokenizer.split_to_word_tokens(ids)
|
|
||||||
print(words, file=sys.stderr)
|
|
||||||
print(wids, file=sys.stderr)
|
|
||||||
if not words:
|
|
||||||
return 0
|
|
||||||
self.text = self.text[:after] + "".join(words[num:])
|
|
||||||
return sum(len(wi) for wi in wids[:num])
|
|
||||||
|
|
||||||
def append_token_ids(self, token_ids):
|
|
||||||
tokenizer = self.tokenizer
|
|
||||||
assert tokenizer is not None, "Tokenizer is not set."
|
|
||||||
self.text += self.tokenizer.decode(token_ids)
|
|
||||||
|
|
||||||
def as_split_word_tokens(self):
|
|
||||||
tokenizer = self.tokenizer
|
|
||||||
assert tokenizer is not None, "Tokenizer is not set."
|
|
||||||
ids = tokenizer.encode(self.text)
|
|
||||||
return tokenizer.split_to_word_tokens(ids)
|
|
||||||
62
whisperlivekit/warmup.py
Normal file
62
whisperlivekit/warmup.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def load_file(warmup_file=None, timeout=5):
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import librosa
|
||||||
|
|
||||||
|
if warmup_file is None:
|
||||||
|
# Download JFK sample if not already present
|
||||||
|
jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
|
||||||
|
temp_dir = tempfile.gettempdir()
|
||||||
|
warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")
|
||||||
|
|
||||||
|
if not os.path.exists(warmup_file):
|
||||||
|
logger.debug(f"Downloading warmup file from {jfk_url}")
|
||||||
|
print(f"Downloading warmup file from {jfk_url}")
|
||||||
|
import time
|
||||||
|
import urllib.request
|
||||||
|
import urllib.error
|
||||||
|
import socket
|
||||||
|
|
||||||
|
original_timeout = socket.getdefaulttimeout()
|
||||||
|
socket.setdefaulttimeout(timeout)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
urllib.request.urlretrieve(jfk_url, warmup_file)
|
||||||
|
logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
|
||||||
|
except (urllib.error.URLError, socket.timeout) as e:
|
||||||
|
logger.warning(f"Download failed: {e}. Proceeding without warmup.")
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
socket.setdefaulttimeout(original_timeout)
|
||||||
|
elif not warmup_file:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
|
||||||
|
logger.warning(f"Warmup file {warmup_file} invalid or missing.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
audio, sr = librosa.load(warmup_file, sr=16000)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load audio file: {e}")
|
||||||
|
return False
|
||||||
|
return audio
|
||||||
|
|
||||||
|
def warmup_asr(asr, warmup_file=None, timeout=5):
|
||||||
|
"""
|
||||||
|
Warmup the ASR model by transcribing a short audio file.
|
||||||
|
"""
|
||||||
|
audio = load_file(warmup_file=None, timeout=5)
|
||||||
|
asr.transcribe(audio)
|
||||||
|
logger.info("ASR model is warmed up")
|
||||||
|
|
||||||
|
def warmup_online(online, warmup_file=None, timeout=5):
|
||||||
|
audio = load_file(warmup_file=None, timeout=5)
|
||||||
|
online.warmup(audio)
|
||||||
|
logger.warning("ASR is warmed up")
|
||||||
@@ -4,12 +4,87 @@
|
|||||||
<head>
|
<head>
|
||||||
<meta charset="UTF-8" />
|
<meta charset="UTF-8" />
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
<title>Audio Transcription</title>
|
<title>WhisperLiveKit</title>
|
||||||
<style>
|
<style>
|
||||||
|
:root {
|
||||||
|
--bg: #ffffff;
|
||||||
|
--text: #111111;
|
||||||
|
--muted: #666666;
|
||||||
|
--border: #e5e5e5;
|
||||||
|
--chip-bg: rgba(0, 0, 0, 0.04);
|
||||||
|
--chip-text: #000000;
|
||||||
|
--spinner-border: #8d8d8d5c;
|
||||||
|
--spinner-top: #b0b0b0;
|
||||||
|
--silence-bg: #f3f3f3;
|
||||||
|
--loading-bg: rgba(255, 77, 77, 0.06);
|
||||||
|
--button-bg: #ffffff;
|
||||||
|
--button-border: #e9e9e9;
|
||||||
|
--wave-stroke: #000000;
|
||||||
|
--label-dia-text: #868686;
|
||||||
|
--label-trans-text: #111111;
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (prefers-color-scheme: dark) {
|
||||||
|
:root:not([data-theme="light"]) {
|
||||||
|
--bg: #0b0b0b;
|
||||||
|
--text: #e6e6e6;
|
||||||
|
--muted: #9aa0a6;
|
||||||
|
--border: #333333;
|
||||||
|
--chip-bg: rgba(255, 255, 255, 0.08);
|
||||||
|
--chip-text: #e6e6e6;
|
||||||
|
--spinner-border: #555555;
|
||||||
|
--spinner-top: #dddddd;
|
||||||
|
--silence-bg: #1a1a1a;
|
||||||
|
--loading-bg: rgba(255, 77, 77, 0.12);
|
||||||
|
--button-bg: #111111;
|
||||||
|
--button-border: #333333;
|
||||||
|
--wave-stroke: #e6e6e6;
|
||||||
|
--label-dia-text: #b3b3b3;
|
||||||
|
--label-trans-text: #ffffff;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
:root[data-theme="dark"] {
|
||||||
|
--bg: #0b0b0b;
|
||||||
|
--text: #e6e6e6;
|
||||||
|
--muted: #9aa0a6;
|
||||||
|
--border: #333333;
|
||||||
|
--chip-bg: rgba(255, 255, 255, 0.08);
|
||||||
|
--chip-text: #e6e6e6;
|
||||||
|
--spinner-border: #555555;
|
||||||
|
--spinner-top: #dddddd;
|
||||||
|
--silence-bg: #1a1a1a;
|
||||||
|
--loading-bg: rgba(255, 77, 77, 0.12);
|
||||||
|
--button-bg: #111111;
|
||||||
|
--button-border: #333333;
|
||||||
|
--wave-stroke: #e6e6e6;
|
||||||
|
--label-dia-text: #b3b3b3;
|
||||||
|
--label-trans-text: #ffffff;
|
||||||
|
}
|
||||||
|
|
||||||
|
:root[data-theme="light"] {
|
||||||
|
--bg: #ffffff;
|
||||||
|
--text: #111111;
|
||||||
|
--muted: #666666;
|
||||||
|
--border: #e5e5e5;
|
||||||
|
--chip-bg: rgba(0, 0, 0, 0.04);
|
||||||
|
--chip-text: #000000;
|
||||||
|
--spinner-border: #8d8d8d5c;
|
||||||
|
--spinner-top: #b0b0b0;
|
||||||
|
--silence-bg: #f3f3f3;
|
||||||
|
--loading-bg: rgba(255, 77, 77, 0.06);
|
||||||
|
--button-bg: #ffffff;
|
||||||
|
--button-border: #e9e9e9;
|
||||||
|
--wave-stroke: #000000;
|
||||||
|
--label-dia-text: #868686;
|
||||||
|
--label-trans-text: #111111;
|
||||||
|
}
|
||||||
body {
|
body {
|
||||||
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
|
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
|
||||||
margin: 20px;
|
margin: 20px;
|
||||||
text-align: center;
|
text-align: center;
|
||||||
|
background-color: var(--bg);
|
||||||
|
color: var(--text);
|
||||||
}
|
}
|
||||||
|
|
||||||
#recordButton {
|
#recordButton {
|
||||||
@@ -17,10 +92,10 @@
|
|||||||
height: 50px;
|
height: 50px;
|
||||||
border: none;
|
border: none;
|
||||||
border-radius: 50%;
|
border-radius: 50%;
|
||||||
background-color: white;
|
background-color: var(--button-bg);
|
||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
transition: all 0.3s ease;
|
transition: all 0.3s ease;
|
||||||
border: 1px solid rgb(233, 233, 233);
|
border: 1px solid var(--button-border);
|
||||||
display: flex;
|
display: flex;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
justify-content: center;
|
justify-content: center;
|
||||||
@@ -94,14 +169,14 @@
|
|||||||
.timer {
|
.timer {
|
||||||
font-size: 14px;
|
font-size: 14px;
|
||||||
font-weight: 500;
|
font-weight: 500;
|
||||||
color: #333;
|
color: var(--text);
|
||||||
margin-left: 10px;
|
margin-left: 10px;
|
||||||
}
|
}
|
||||||
|
|
||||||
#status {
|
#status {
|
||||||
margin-top: 20px;
|
margin-top: 20px;
|
||||||
font-size: 16px;
|
font-size: 16px;
|
||||||
color: #333;
|
color: var(--text);
|
||||||
}
|
}
|
||||||
|
|
||||||
.settings-container {
|
.settings-container {
|
||||||
@@ -120,12 +195,14 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
#chunkSelector,
|
#chunkSelector,
|
||||||
#websocketInput {
|
#websocketInput,
|
||||||
|
#themeSelector {
|
||||||
font-size: 16px;
|
font-size: 16px;
|
||||||
padding: 5px;
|
padding: 5px;
|
||||||
border-radius: 5px;
|
border-radius: 5px;
|
||||||
border: 1px solid #ddd;
|
border: 1px solid var(--border);
|
||||||
background-color: #ffffff;
|
background-color: var(--button-bg);
|
||||||
|
color: var(--text);
|
||||||
max-height: 30px;
|
max-height: 30px;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -134,7 +211,8 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
#chunkSelector:focus,
|
#chunkSelector:focus,
|
||||||
#websocketInput:focus {
|
#websocketInput:focus,
|
||||||
|
#themeSelector:focus {
|
||||||
outline: none;
|
outline: none;
|
||||||
border-color: #007bff;
|
border-color: #007bff;
|
||||||
}
|
}
|
||||||
@@ -156,18 +234,18 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
#linesTranscript strong {
|
#linesTranscript strong {
|
||||||
color: #333;
|
color: var(--text);
|
||||||
}
|
}
|
||||||
|
|
||||||
#speaker {
|
#speaker {
|
||||||
border: 1px solid rgb(229, 229, 229);
|
border: 1px solid var(--border);
|
||||||
border-radius: 100px;
|
border-radius: 100px;
|
||||||
padding: 2px 10px;
|
padding: 2px 10px;
|
||||||
font-size: 14px;
|
font-size: 14px;
|
||||||
margin-bottom: 0px;
|
margin-bottom: 0px;
|
||||||
}
|
}
|
||||||
.label_diarization {
|
.label_diarization {
|
||||||
background-color: #ffffff66;
|
background-color: var(--chip-bg);
|
||||||
border-radius: 8px 8px 8px 8px;
|
border-radius: 8px 8px 8px 8px;
|
||||||
padding: 2px 10px;
|
padding: 2px 10px;
|
||||||
margin-left: 10px;
|
margin-left: 10px;
|
||||||
@@ -175,11 +253,11 @@
|
|||||||
white-space: nowrap;
|
white-space: nowrap;
|
||||||
font-size: 14px;
|
font-size: 14px;
|
||||||
margin-bottom: 0px;
|
margin-bottom: 0px;
|
||||||
color: rgb(134, 134, 134)
|
color: var(--label-dia-text)
|
||||||
}
|
}
|
||||||
|
|
||||||
.label_transcription {
|
.label_transcription {
|
||||||
background-color: #ffffff66;
|
background-color: var(--chip-bg);
|
||||||
border-radius: 8px 8px 8px 8px;
|
border-radius: 8px 8px 8px 8px;
|
||||||
padding: 2px 10px;
|
padding: 2px 10px;
|
||||||
display: inline-block;
|
display: inline-block;
|
||||||
@@ -187,11 +265,11 @@
|
|||||||
margin-left: 10px;
|
margin-left: 10px;
|
||||||
font-size: 14px;
|
font-size: 14px;
|
||||||
margin-bottom: 0px;
|
margin-bottom: 0px;
|
||||||
color: #000000
|
color: var(--label-trans-text)
|
||||||
}
|
}
|
||||||
|
|
||||||
#timeInfo {
|
#timeInfo {
|
||||||
color: #666;
|
color: var(--muted);
|
||||||
margin-left: 10px;
|
margin-left: 10px;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -206,7 +284,7 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
.buffer_diarization {
|
.buffer_diarization {
|
||||||
color: rgb(134, 134, 134);
|
color: var(--label-dia-text);
|
||||||
margin-left: 4px;
|
margin-left: 4px;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -220,10 +298,10 @@
|
|||||||
display: inline-block;
|
display: inline-block;
|
||||||
width: 8px;
|
width: 8px;
|
||||||
height: 8px;
|
height: 8px;
|
||||||
border: 2px solid #8d8d8d5c;
|
border: 2px solid var(--spinner-border);
|
||||||
border-top: 2px solid #6c6c6ce5;
|
border-top: 2px solid var(--spinner-top);
|
||||||
border-radius: 50%;
|
border-radius: 50%;
|
||||||
animation: spin 0.6s linear infinite;
|
animation: spin 0.7s linear infinite;
|
||||||
vertical-align: middle;
|
vertical-align: middle;
|
||||||
margin-bottom: 2px;
|
margin-bottom: 2px;
|
||||||
margin-right: 5px;
|
margin-right: 5px;
|
||||||
@@ -236,16 +314,16 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
.silence {
|
.silence {
|
||||||
color: #666;
|
color: var(--muted);
|
||||||
background-color: #f3f3f3;
|
background-color: var(--silence-bg);
|
||||||
font-size: 13px;
|
font-size: 13px;
|
||||||
border-radius: 30px;
|
border-radius: 30px;
|
||||||
padding: 2px 10px;
|
padding: 2px 10px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.loading {
|
.loading {
|
||||||
color: #666;
|
color: var(--muted);
|
||||||
background-color: #ff4d4d0f;
|
background-color: var(--loading-bg);
|
||||||
border-radius: 8px 8px 8px 0px;
|
border-radius: 8px 8px 8px 0px;
|
||||||
padding: 2px 10px;
|
padding: 2px 10px;
|
||||||
font-size: 14px;
|
font-size: 14px;
|
||||||
@@ -284,6 +362,14 @@
|
|||||||
<label for="websocketInput">WebSocket URL:</label>
|
<label for="websocketInput">WebSocket URL:</label>
|
||||||
<input id="websocketInput" type="text" />
|
<input id="websocketInput" type="text" />
|
||||||
</div>
|
</div>
|
||||||
|
<div>
|
||||||
|
<label for="themeSelector">Theme:</label>
|
||||||
|
<select id="themeSelector">
|
||||||
|
<option value="system" selected>System</option>
|
||||||
|
<option value="light">Light</option>
|
||||||
|
<option value="dark">Dark</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -299,6 +385,7 @@
|
|||||||
let chunkDuration = 1000;
|
let chunkDuration = 1000;
|
||||||
let websocketUrl = "ws://localhost:8000/asr";
|
let websocketUrl = "ws://localhost:8000/asr";
|
||||||
let userClosing = false;
|
let userClosing = false;
|
||||||
|
let wakeLock = null;
|
||||||
let startTime = null;
|
let startTime = null;
|
||||||
let timerInterval = null;
|
let timerInterval = null;
|
||||||
let audioContext = null;
|
let audioContext = null;
|
||||||
@@ -309,6 +396,7 @@
|
|||||||
let animationFrame = null;
|
let animationFrame = null;
|
||||||
let waitingForStop = false;
|
let waitingForStop = false;
|
||||||
let lastReceivedData = null;
|
let lastReceivedData = null;
|
||||||
|
let lastSignature = null;
|
||||||
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
|
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
|
||||||
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
|
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
|
||||||
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
|
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
|
||||||
@@ -319,6 +407,57 @@
|
|||||||
const websocketInput = document.getElementById("websocketInput");
|
const websocketInput = document.getElementById("websocketInput");
|
||||||
const linesTranscriptDiv = document.getElementById("linesTranscript");
|
const linesTranscriptDiv = document.getElementById("linesTranscript");
|
||||||
const timerElement = document.querySelector(".timer");
|
const timerElement = document.querySelector(".timer");
|
||||||
|
const themeSelector = document.getElementById("themeSelector");
|
||||||
|
|
||||||
|
function getWaveStroke() {
|
||||||
|
const styles = getComputedStyle(document.documentElement);
|
||||||
|
const v = styles.getPropertyValue("--wave-stroke").trim();
|
||||||
|
return v || "#000";
|
||||||
|
}
|
||||||
|
|
||||||
|
let waveStroke = getWaveStroke();
|
||||||
|
|
||||||
|
function updateWaveStroke() {
|
||||||
|
waveStroke = getWaveStroke();
|
||||||
|
}
|
||||||
|
|
||||||
|
function applyTheme(pref) {
|
||||||
|
if (pref === "light") {
|
||||||
|
document.documentElement.setAttribute("data-theme", "light");
|
||||||
|
} else if (pref === "dark") {
|
||||||
|
document.documentElement.setAttribute("data-theme", "dark");
|
||||||
|
} else {
|
||||||
|
document.documentElement.removeAttribute("data-theme");
|
||||||
|
}
|
||||||
|
updateWaveStroke();
|
||||||
|
}
|
||||||
|
|
||||||
|
const savedThemePref = localStorage.getItem("themePreference") || "system";
|
||||||
|
applyTheme(savedThemePref);
|
||||||
|
if (themeSelector) {
|
||||||
|
themeSelector.value = savedThemePref;
|
||||||
|
themeSelector.addEventListener("change", () => {
|
||||||
|
const val = themeSelector.value;
|
||||||
|
localStorage.setItem("themePreference", val);
|
||||||
|
applyTheme(val);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
const darkMq = window.matchMedia && window.matchMedia("(prefers-color-scheme: dark)");
|
||||||
|
const handleOsThemeChange = () => {
|
||||||
|
const pref = localStorage.getItem("themePreference") || "system";
|
||||||
|
if (pref === "system") updateWaveStroke();
|
||||||
|
};
|
||||||
|
if (darkMq && darkMq.addEventListener) {
|
||||||
|
darkMq.addEventListener("change", handleOsThemeChange);
|
||||||
|
} else if (darkMq && darkMq.addListener) {
|
||||||
|
darkMq.addListener(handleOsThemeChange);
|
||||||
|
}
|
||||||
|
|
||||||
|
function fmt1(x) {
|
||||||
|
const n = Number(x);
|
||||||
|
return Number.isFinite(n) ? n.toFixed(1) : x;
|
||||||
|
}
|
||||||
|
|
||||||
const host = window.location.hostname || "localhost";
|
const host = window.location.hostname || "localhost";
|
||||||
const port = window.location.port;
|
const port = window.location.port;
|
||||||
@@ -446,10 +585,35 @@
|
|||||||
|
|
||||||
function renderLinesWithBuffer(lines, buffer_diarization, buffer_transcription, remaining_time_diarization, remaining_time_transcription, isFinalizing = false, current_status = "active_transcription") {
|
function renderLinesWithBuffer(lines, buffer_diarization, buffer_transcription, remaining_time_diarization, remaining_time_transcription, isFinalizing = false, current_status = "active_transcription") {
|
||||||
if (current_status === "no_audio_detected") {
|
if (current_status === "no_audio_detected") {
|
||||||
linesTranscriptDiv.innerHTML = "<p style='text-align: center; color: #666; margin-top: 20px;'><em>No audio detected...</em></p>";
|
linesTranscriptDiv.innerHTML = "<p style='text-align: center; color: var(--muted); margin-top: 20px;'><em>No audio detected...</em></p>";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// try to keep stable DOM despite having updates every 0.1s. only update numeric lag values if structure hasn't changed
|
||||||
|
const showLoading = (!isFinalizing) && (lines || []).some(it => it.speaker == 0);
|
||||||
|
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
|
||||||
|
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
|
||||||
|
const signature = JSON.stringify({
|
||||||
|
lines: (lines || []).map(it => ({ speaker: it.speaker, text: it.text, beg: it.beg, end: it.end })),
|
||||||
|
buffer_transcription: buffer_transcription || "",
|
||||||
|
buffer_diarization: buffer_diarization || "",
|
||||||
|
status: current_status,
|
||||||
|
showLoading,
|
||||||
|
showTransLag,
|
||||||
|
showDiaLag,
|
||||||
|
isFinalizing: !!isFinalizing
|
||||||
|
});
|
||||||
|
if (lastSignature === signature) {
|
||||||
|
const t = document.querySelector(".lag-transcription-value");
|
||||||
|
if (t) t.textContent = fmt1(remaining_time_transcription);
|
||||||
|
const d = document.querySelector(".lag-diarization-value");
|
||||||
|
if (d) d.textContent = fmt1(remaining_time_diarization);
|
||||||
|
const ld = document.querySelector(".loading-diarization-value");
|
||||||
|
if (ld) ld.textContent = fmt1(remaining_time_diarization);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
lastSignature = signature;
|
||||||
|
|
||||||
const linesHtml = lines.map((item, idx) => {
|
const linesHtml = lines.map((item, idx) => {
|
||||||
let timeInfo = "";
|
let timeInfo = "";
|
||||||
if (item.beg !== undefined && item.end !== undefined) {
|
if (item.beg !== undefined && item.end !== undefined) {
|
||||||
@@ -460,7 +624,7 @@
|
|||||||
if (item.speaker === -2) {
|
if (item.speaker === -2) {
|
||||||
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
|
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||||
} else if (item.speaker == 0 && !isFinalizing) {
|
} else if (item.speaker == 0 && !isFinalizing) {
|
||||||
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'>${remaining_time_diarization} second(s) of audio are undergoing diarization</span></span>`;
|
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(remaining_time_diarization)}</span> second(s) of audio are undergoing diarization</span></span>`;
|
||||||
} else if (item.speaker == -1) {
|
} else if (item.speaker == -1) {
|
||||||
speakerLabel = `<span id="speaker">Speaker 1<span id='timeInfo'>${timeInfo}</span></span>`;
|
speakerLabel = `<span id="speaker">Speaker 1<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||||
} else if (item.speaker !== -1 && item.speaker !== 0) {
|
} else if (item.speaker !== -1 && item.speaker !== 0) {
|
||||||
@@ -471,12 +635,12 @@
|
|||||||
let currentLineText = item.text || "";
|
let currentLineText = item.text || "";
|
||||||
|
|
||||||
if (idx === lines.length - 1) {
|
if (idx === lines.length - 1) {
|
||||||
if (!isFinalizing) {
|
if (!isFinalizing && item.speaker !== -2) {
|
||||||
if (remaining_time_transcription > 0) {
|
if (remaining_time_transcription > 0) {
|
||||||
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'>${remaining_time_transcription}s</span></span>`;
|
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(remaining_time_transcription)}</span>s</span></span>`;
|
||||||
}
|
}
|
||||||
if (buffer_diarization && remaining_time_diarization > 0) {
|
if (buffer_diarization && remaining_time_diarization > 0) {
|
||||||
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'>${remaining_time_diarization}s</span></span>`;
|
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(remaining_time_diarization)}</span>s</span></span>`;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -502,6 +666,7 @@
|
|||||||
}).join("");
|
}).join("");
|
||||||
|
|
||||||
linesTranscriptDiv.innerHTML = linesHtml;
|
linesTranscriptDiv.innerHTML = linesHtml;
|
||||||
|
window.scrollTo({ top: document.body.scrollHeight, behavior: 'smooth' });
|
||||||
}
|
}
|
||||||
|
|
||||||
function updateTimer() {
|
function updateTimer() {
|
||||||
@@ -522,7 +687,7 @@
|
|||||||
|
|
||||||
waveCtx.clearRect(0, 0, waveCanvas.width / (window.devicePixelRatio || 1), waveCanvas.height / (window.devicePixelRatio || 1));
|
waveCtx.clearRect(0, 0, waveCanvas.width / (window.devicePixelRatio || 1), waveCanvas.height / (window.devicePixelRatio || 1));
|
||||||
waveCtx.lineWidth = 1;
|
waveCtx.lineWidth = 1;
|
||||||
waveCtx.strokeStyle = 'rgb(0, 0, 0)';
|
waveCtx.strokeStyle = waveStroke;
|
||||||
waveCtx.beginPath();
|
waveCtx.beginPath();
|
||||||
|
|
||||||
const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength;
|
const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength;
|
||||||
@@ -549,6 +714,16 @@
|
|||||||
|
|
||||||
async function startRecording() {
|
async function startRecording() {
|
||||||
try {
|
try {
|
||||||
|
|
||||||
|
// https://developer.mozilla.org/en-US/docs/Web/API/Screen_Wake_Lock_API
|
||||||
|
// create an async function to request a wake lock
|
||||||
|
try {
|
||||||
|
wakeLock = await navigator.wakeLock.request("screen");
|
||||||
|
} catch (err) {
|
||||||
|
// The Wake Lock request has failed - usually system related, such as battery.
|
||||||
|
console.log("Error acquiring wake lock.")
|
||||||
|
}
|
||||||
|
|
||||||
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||||
|
|
||||||
audioContext = new (window.AudioContext || window.webkitAudioContext)();
|
audioContext = new (window.AudioContext || window.webkitAudioContext)();
|
||||||
@@ -578,6 +753,10 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function stopRecording() {
|
async function stopRecording() {
|
||||||
|
wakeLock.release().then(() => {
|
||||||
|
wakeLock = null;
|
||||||
|
});
|
||||||
|
|
||||||
userClosing = true;
|
userClosing = true;
|
||||||
waitingForStop = true;
|
waitingForStop = true;
|
||||||
|
|
||||||
|
|||||||
@@ -3,43 +3,10 @@ import logging
|
|||||||
import io
|
import io
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
import math
|
import math
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
except ImportError:
|
|
||||||
torch = None
|
|
||||||
from typing import List
|
from typing import List
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from whisperlivekit.timed_objects import ASRToken
|
from whisperlivekit.timed_objects import ASRToken
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS = ImportError(
|
|
||||||
"""SimulStreaming dependencies are not available.
|
|
||||||
Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]"
|
|
||||||
""")
|
|
||||||
|
|
||||||
try:
|
|
||||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
|
||||||
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD
|
|
||||||
from whisperlivekit.simul_whisper.whisper import tokenizer
|
|
||||||
SIMULSTREAMING_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
logger.warning("⚠️ SimulStreaming dependencies not available. Attempting to download them.")
|
|
||||||
try:
|
|
||||||
from whisperlivekit import download_simulstreaming_backend
|
|
||||||
download_simulstreaming_backend()
|
|
||||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
|
||||||
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD
|
|
||||||
from whisperlivekit.simul_whisper.whisper import tokenizer
|
|
||||||
SIMULSTREAMING_AVAILABLE = True
|
|
||||||
logger.info("SimulStreaming dependencies downloaded successfully.")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to download or import SimulStreaming dependencies: {e}")
|
|
||||||
SIMULSTREAMING_AVAILABLE = False
|
|
||||||
AlignAttConfig = None
|
|
||||||
PaddedAlignAttWhisper = None
|
|
||||||
DEC_PAD = None
|
|
||||||
tokenizer = None
|
|
||||||
|
|
||||||
class ASRBase:
|
class ASRBase:
|
||||||
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
||||||
# "" for faster-whisper because it emits the spaces when needed)
|
# "" for faster-whisper because it emits the spaces when needed)
|
||||||
@@ -320,182 +287,4 @@ class OpenaiApiASR(ASRBase):
|
|||||||
self.use_vad_opt = True
|
self.use_vad_opt = True
|
||||||
|
|
||||||
def set_translate_task(self):
|
def set_translate_task(self):
|
||||||
self.task = "translate"
|
self.task = "translate"
|
||||||
|
|
||||||
|
|
||||||
class SimulStreamingASR(ASRBase):
|
|
||||||
"""SimulStreaming backend with AlignAtt policy."""
|
|
||||||
sep = ""
|
|
||||||
|
|
||||||
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
|
|
||||||
if not SIMULSTREAMING_AVAILABLE:
|
|
||||||
raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
|
|
||||||
with open("whisperlivekit/simul_whisper/dual_license_simulstreaming.md", "r") as f:
|
|
||||||
print("*"*80 + f.read() + "*"*80)
|
|
||||||
self.logfile = logfile
|
|
||||||
self.transcribe_kargs = {}
|
|
||||||
self.original_language = None if lan == "auto" else lan
|
|
||||||
|
|
||||||
self.model_path = kwargs.get('model_path', './large-v3.pt')
|
|
||||||
self.frame_threshold = kwargs.get('frame_threshold', 25)
|
|
||||||
self.audio_max_len = kwargs.get('audio_max_len', 30.0)
|
|
||||||
self.audio_min_len = kwargs.get('audio_min_len', 0.0)
|
|
||||||
self.segment_length = kwargs.get('segment_length', 0.5)
|
|
||||||
self.beams = kwargs.get('beams', 1)
|
|
||||||
self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
|
|
||||||
self.task = kwargs.get('task', 'transcribe')
|
|
||||||
self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
|
|
||||||
self.never_fire = kwargs.get('never_fire', False)
|
|
||||||
self.init_prompt = kwargs.get('init_prompt', None)
|
|
||||||
self.static_init_prompt = kwargs.get('static_init_prompt', None)
|
|
||||||
self.max_context_tokens = kwargs.get('max_context_tokens', None)
|
|
||||||
|
|
||||||
if model_dir is not None:
|
|
||||||
self.model_path = model_dir
|
|
||||||
elif modelsize is not None: #For the moment the .en.pt models do not work!
|
|
||||||
model_mapping = {
|
|
||||||
'tiny': './tiny.pt',
|
|
||||||
'base': './base.pt',
|
|
||||||
'small': './small.pt',
|
|
||||||
'medium': './medium.pt',
|
|
||||||
'medium.en': './medium.en.pt',
|
|
||||||
'large-v1': './large-v1.pt',
|
|
||||||
'base.en': './base.en.pt',
|
|
||||||
'small.en': './small.en.pt',
|
|
||||||
'tiny.en': './tiny.en.pt',
|
|
||||||
'large-v2': './large-v2.pt',
|
|
||||||
'large-v3': './large-v3.pt',
|
|
||||||
'large': './large-v3.pt'
|
|
||||||
}
|
|
||||||
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
|
|
||||||
|
|
||||||
self.model = self.load_model(modelsize, cache_dir, model_dir)
|
|
||||||
|
|
||||||
# Set up tokenizer for translation if needed
|
|
||||||
if self.task == "translate":
|
|
||||||
self.set_translate_task()
|
|
||||||
|
|
||||||
def load_model(self, modelsize, cache_dir, model_dir):
|
|
||||||
try:
|
|
||||||
cfg = AlignAttConfig(
|
|
||||||
model_path=self.model_path,
|
|
||||||
segment_length=self.segment_length,
|
|
||||||
frame_threshold=self.frame_threshold,
|
|
||||||
language=self.original_language,
|
|
||||||
audio_max_len=self.audio_max_len,
|
|
||||||
audio_min_len=self.audio_min_len,
|
|
||||||
cif_ckpt_path=self.cif_ckpt_path,
|
|
||||||
decoder_type="beam",
|
|
||||||
beam_size=self.beams,
|
|
||||||
task=self.task,
|
|
||||||
never_fire=self.never_fire,
|
|
||||||
init_prompt=self.init_prompt,
|
|
||||||
max_context_tokens=self.max_context_tokens,
|
|
||||||
static_init_prompt=self.static_init_prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Loading SimulStreaming model with language: {self.original_language}")
|
|
||||||
model = PaddedAlignAttWhisper(cfg)
|
|
||||||
return model
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to load SimulStreaming model: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def transcribe(self, audio, init_prompt=""):
|
|
||||||
"""Transcribe audio using SimulStreaming."""
|
|
||||||
try:
|
|
||||||
if isinstance(audio, np.ndarray):
|
|
||||||
audio_tensor = torch.from_numpy(audio).float()
|
|
||||||
else:
|
|
||||||
audio_tensor = audio
|
|
||||||
|
|
||||||
prompt = init_prompt if init_prompt else (self.init_prompt or "")
|
|
||||||
|
|
||||||
result = self.model.infer(audio_tensor, init_prompt=prompt)
|
|
||||||
|
|
||||||
if torch.is_tensor(result):
|
|
||||||
result = result[result < DEC_PAD]
|
|
||||||
|
|
||||||
logger.debug(f"SimulStreaming transcription result: {result}")
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"SimulStreaming transcription failed: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def ts_words(self, result) -> List[ASRToken]:
|
|
||||||
"""Convert SimulStreaming result to ASRToken list."""
|
|
||||||
tokens = []
|
|
||||||
|
|
||||||
try:
|
|
||||||
if torch.is_tensor(result):
|
|
||||||
text = self.model.tokenizer.decode(result.cpu().numpy())
|
|
||||||
else:
|
|
||||||
text = str(result)
|
|
||||||
|
|
||||||
if not text or len(text.strip()) == 0:
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
# We dont have word-level timestamps here. 1rst approach, should be improved later.
|
|
||||||
words = text.strip().split()
|
|
||||||
if not words:
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
duration_per_word = 0.1 # this will be modified based on actual audio duration
|
|
||||||
#with the SimulStreamingOnlineProcessor
|
|
||||||
|
|
||||||
for i, word in enumerate(words):
|
|
||||||
start_time = i * duration_per_word
|
|
||||||
end_time = (i + 1) * duration_per_word
|
|
||||||
|
|
||||||
token = ASRToken(
|
|
||||||
start=start_time,
|
|
||||||
end=end_time,
|
|
||||||
text=word,
|
|
||||||
probability=1.0
|
|
||||||
)
|
|
||||||
tokens.append(token)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error converting SimulStreaming result to tokens: {e}")
|
|
||||||
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
def segments_end_ts(self, result) -> List[float]:
|
|
||||||
"""Get segment end timestamps."""
|
|
||||||
if torch.is_tensor(result):
|
|
||||||
num_tokens = len(result)
|
|
||||||
return [num_tokens * 0.1] # rough estimate
|
|
||||||
return [1.0]
|
|
||||||
|
|
||||||
def use_vad(self):
|
|
||||||
"""Enable VAD - SimulStreaming has different VAD handling."""
|
|
||||||
logger.info("VAD requested for SimulStreaming - handled internally by the model")
|
|
||||||
pass
|
|
||||||
|
|
||||||
def set_translate_task(self):
|
|
||||||
"""Set up translation task."""
|
|
||||||
try:
|
|
||||||
self.model.tokenizer = tokenizer.get_tokenizer(
|
|
||||||
multilingual=True,
|
|
||||||
language=self.model.cfg.language,
|
|
||||||
num_languages=self.model.model.num_languages,
|
|
||||||
task="translate"
|
|
||||||
)
|
|
||||||
logger.info("SimulStreaming configured for translation task")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to configure SimulStreaming for translation: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def warmup(self, audio, init_prompt=""):
|
|
||||||
"""Warmup the SimulStreaming model."""
|
|
||||||
try:
|
|
||||||
if isinstance(audio, np.ndarray):
|
|
||||||
audio = torch.from_numpy(audio).float()
|
|
||||||
self.model.insert_audio(audio)
|
|
||||||
self.model.infer(True)
|
|
||||||
self.model.refresh_segment(complete=True)
|
|
||||||
logger.info("SimulStreaming model warmed up successfully")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"SimulStreaming warmup failed: {e}")
|
|
||||||
@@ -6,18 +6,6 @@ from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# simulStreaming imports - we check if the files are here
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
|
||||||
SIMULSTREAMING_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
logger.warning("SimulStreaming dependencies not available for online processor.")
|
|
||||||
SIMULSTREAMING_AVAILABLE = False
|
|
||||||
OnlineProcessorInterface = None
|
|
||||||
torch = None
|
|
||||||
|
|
||||||
|
|
||||||
class HypothesisBuffer:
|
class HypothesisBuffer:
|
||||||
"""
|
"""
|
||||||
Buffer to store and process ASR hypothesis tokens.
|
Buffer to store and process ASR hypothesis tokens.
|
||||||
@@ -528,205 +516,3 @@ class VACOnlineASRProcessor:
|
|||||||
"""
|
"""
|
||||||
return self.online.concatenate_tokens(self.online.transcript_buffer.buffer)
|
return self.online.concatenate_tokens(self.online.transcript_buffer.buffer)
|
||||||
|
|
||||||
|
|
||||||
class SimulStreamingOnlineProcessor:
|
|
||||||
SAMPLING_RATE = 16000
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
asr,
|
|
||||||
tokenize_method: Optional[callable] = None,
|
|
||||||
buffer_trimming: Tuple[str, float] = ("segment", 15),
|
|
||||||
confidence_validation = False,
|
|
||||||
logfile=sys.stderr,
|
|
||||||
):
|
|
||||||
if not SIMULSTREAMING_AVAILABLE:
|
|
||||||
raise ImportError("SimulStreaming dependencies are not available.")
|
|
||||||
|
|
||||||
self.asr = asr
|
|
||||||
self.tokenize = tokenize_method
|
|
||||||
self.logfile = logfile
|
|
||||||
self.confidence_validation = confidence_validation
|
|
||||||
self.init()
|
|
||||||
|
|
||||||
# buffer does not work yet
|
|
||||||
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
|
|
||||||
|
|
||||||
def init(self, offset: Optional[float] = None):
|
|
||||||
"""Initialize or reset the processing state."""
|
|
||||||
self.audio_chunks = []
|
|
||||||
self.offset = offset if offset is not None else 0.0
|
|
||||||
self.is_last = False
|
|
||||||
self.beg = self.offset
|
|
||||||
self.end = self.offset
|
|
||||||
self.cumulative_audio_duration = 0.0
|
|
||||||
self.last_audio_stream_end_time = self.offset
|
|
||||||
|
|
||||||
self.committed: List[ASRToken] = []
|
|
||||||
self.last_result_tokens: List[ASRToken] = []
|
|
||||||
self.buffer_content = ""
|
|
||||||
self.processed_audio_duration = 0.0
|
|
||||||
|
|
||||||
def get_audio_buffer_end_time(self) -> float:
|
|
||||||
"""Returns the absolute end time of the current audio buffer."""
|
|
||||||
return self.end
|
|
||||||
|
|
||||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
|
|
||||||
"""Append an audio chunk to be processed by SimulStreaming."""
|
|
||||||
if torch is None:
|
|
||||||
raise ImportError("PyTorch is required for SimulStreaming but not available")
|
|
||||||
|
|
||||||
# Convert numpy array to torch tensor
|
|
||||||
audio_tensor = torch.from_numpy(audio).float()
|
|
||||||
self.audio_chunks.append(audio_tensor)
|
|
||||||
|
|
||||||
# Update timing
|
|
||||||
chunk_duration = len(audio) / self.SAMPLING_RATE
|
|
||||||
self.cumulative_audio_duration += chunk_duration
|
|
||||||
|
|
||||||
if audio_stream_end_time is not None:
|
|
||||||
self.last_audio_stream_end_time = audio_stream_end_time
|
|
||||||
self.end = audio_stream_end_time
|
|
||||||
else:
|
|
||||||
self.end = self.offset + self.cumulative_audio_duration
|
|
||||||
|
|
||||||
def prompt(self) -> Tuple[str, str]:
|
|
||||||
"""
|
|
||||||
Returns a tuple: (prompt, context).
|
|
||||||
SimulStreaming handles prompting internally, so we return empty strings.
|
|
||||||
"""
|
|
||||||
return "", ""
|
|
||||||
|
|
||||||
def get_buffer(self):
|
|
||||||
"""
|
|
||||||
Get the unvalidated buffer content.
|
|
||||||
"""
|
|
||||||
buffer_end = self.end if hasattr(self, 'end') else None
|
|
||||||
return Transcript(
|
|
||||||
start=None,
|
|
||||||
end=buffer_end,
|
|
||||||
text=self.buffer_content,
|
|
||||||
probability=None
|
|
||||||
)
|
|
||||||
|
|
||||||
def timestamped_text(self, tokens, generation):
|
|
||||||
# From the simulstreaming repo. self.model to self.asr.model
|
|
||||||
pr = generation["progress"]
|
|
||||||
if "result" not in generation:
|
|
||||||
split_words, split_tokens = self.asr.model.tokenizer.split_to_word_tokens(tokens)
|
|
||||||
else:
|
|
||||||
split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"]
|
|
||||||
|
|
||||||
frames = [p["most_attended_frames"][0] for p in pr]
|
|
||||||
tokens = tokens.copy()
|
|
||||||
ret = []
|
|
||||||
for sw,st in zip(split_words,split_tokens):
|
|
||||||
b = None
|
|
||||||
for stt in st:
|
|
||||||
t,f = tokens.pop(0), frames.pop(0)
|
|
||||||
if t != stt:
|
|
||||||
raise ValueError(f"Token mismatch: {t} != {stt} at frame {f}.")
|
|
||||||
if b is None:
|
|
||||||
b = f
|
|
||||||
e = f
|
|
||||||
out = (b*0.02, e*0.02, sw)
|
|
||||||
ret.append(out)
|
|
||||||
logger.debug(f"TS-WORD:\t{' '.join(map(str, out))}")
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def process_iter(self) -> Tuple[List[ASRToken], float]:
|
|
||||||
"""
|
|
||||||
Process accumulated audio chunks using SimulStreaming.
|
|
||||||
|
|
||||||
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
|
||||||
"""
|
|
||||||
if not self.audio_chunks:
|
|
||||||
return [], self.end
|
|
||||||
|
|
||||||
try:
|
|
||||||
# concatenate all audio chunks
|
|
||||||
if len(self.audio_chunks) == 1:
|
|
||||||
audio = self.audio_chunks[0]
|
|
||||||
else:
|
|
||||||
audio = torch.cat(self.audio_chunks, dim=0)
|
|
||||||
|
|
||||||
audio_duration = audio.shape[0] / self.SAMPLING_RATE if audio.shape[0] > 0 else 0
|
|
||||||
self.processed_audio_duration += audio_duration
|
|
||||||
|
|
||||||
self.audio_chunks = []
|
|
||||||
|
|
||||||
logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s")
|
|
||||||
logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s")
|
|
||||||
|
|
||||||
self.asr.model.insert_audio(audio)
|
|
||||||
tokens, generation_progress = self.asr.model.infer(is_last=self.is_last)
|
|
||||||
ts_words = self.timestamped_text(tokens, generation_progress)
|
|
||||||
text = self.asr.model.tokenizer.decode(tokens)
|
|
||||||
|
|
||||||
new_tokens = []
|
|
||||||
for ts_word in ts_words:
|
|
||||||
|
|
||||||
start, end, word = ts_word
|
|
||||||
token = ASRToken(
|
|
||||||
start=start,
|
|
||||||
end=end,
|
|
||||||
text=word,
|
|
||||||
probability=0.95 # fake prob. Maybe we can extract it from the model?
|
|
||||||
)
|
|
||||||
new_tokens.append(token)
|
|
||||||
self.committed.extend(new_tokens)
|
|
||||||
|
|
||||||
return new_tokens, self.end
|
|
||||||
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"SimulStreaming processing error: {e}")
|
|
||||||
logger.error(f"Error details: {type(e).__name__}: {str(e)}")
|
|
||||||
return [], self.end
|
|
||||||
|
|
||||||
def finish(self) -> Tuple[List[ASRToken], float]:
|
|
||||||
logger.debug("SimulStreaming finish() called")
|
|
||||||
self.is_last = True
|
|
||||||
final_tokens, final_time = self.process_iter()
|
|
||||||
self.is_last = False
|
|
||||||
return final_tokens, final_time
|
|
||||||
|
|
||||||
def concatenate_tokens(
|
|
||||||
self,
|
|
||||||
tokens: List[ASRToken],
|
|
||||||
sep: Optional[str] = None,
|
|
||||||
offset: float = 0
|
|
||||||
) -> Transcript:
|
|
||||||
"""Concatenate tokens into a Transcript object."""
|
|
||||||
sep = sep if sep is not None else self.asr.sep
|
|
||||||
text = sep.join(token.text for token in tokens)
|
|
||||||
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
|
||||||
if tokens:
|
|
||||||
start = offset + tokens[0].start
|
|
||||||
end = offset + tokens[-1].end
|
|
||||||
else:
|
|
||||||
start = None
|
|
||||||
end = None
|
|
||||||
return Transcript(start, end, text, probability=probability)
|
|
||||||
|
|
||||||
def chunk_at(self, time: float):
|
|
||||||
"""
|
|
||||||
useless but kept for compatibility
|
|
||||||
"""
|
|
||||||
logger.debug(f"SimulStreaming chunk_at({time:.2f}) - handled internally")
|
|
||||||
pass
|
|
||||||
|
|
||||||
def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
|
|
||||||
"""
|
|
||||||
Create simple sentences.
|
|
||||||
"""
|
|
||||||
if not tokens:
|
|
||||||
return []
|
|
||||||
|
|
||||||
full_text = " ".join(token.text for token in tokens)
|
|
||||||
sentence = Sentence(
|
|
||||||
start=tokens[0].start,
|
|
||||||
end=tokens[-1].end,
|
|
||||||
text=full_text
|
|
||||||
)
|
|
||||||
return [sentence]
|
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ import librosa
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR, SimulStreamingASR, SIMULSTREAMING_AVAILABLE, SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
|
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
|
||||||
from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor, SimulStreamingOnlineProcessor, SIMULSTREAMING_AVAILABLE as SIMULSTREAMING_ONLINE_AVAILABLE
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -68,35 +67,7 @@ def backend_factory(args):
|
|||||||
backend = args.backend
|
backend = args.backend
|
||||||
if backend == "openai-api":
|
if backend == "openai-api":
|
||||||
logger.debug("Using OpenAI API.")
|
logger.debug("Using OpenAI API.")
|
||||||
asr = OpenaiApiASR(lan=args.lan)
|
asr = OpenaiApiASR(lan=args.lan)
|
||||||
elif backend == "simulstreaming":
|
|
||||||
logger.debug("Using SimulStreaming backend.")
|
|
||||||
if not SIMULSTREAMING_AVAILABLE:
|
|
||||||
raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
|
|
||||||
|
|
||||||
simulstreaming_kwargs = {}
|
|
||||||
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
|
|
||||||
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
|
|
||||||
'max_context_tokens', 'model_path']:
|
|
||||||
if hasattr(args, attr):
|
|
||||||
simulstreaming_kwargs[attr] = getattr(args, attr)
|
|
||||||
|
|
||||||
# Add segment_length from min_chunk_size
|
|
||||||
simulstreaming_kwargs['segment_length'] = getattr(args, 'min_chunk_size', 0.5)
|
|
||||||
simulstreaming_kwargs['task'] = args.task
|
|
||||||
|
|
||||||
size = args.model
|
|
||||||
t = time.time()
|
|
||||||
logger.info(f"Loading SimulStreaming {size} model for language {args.lan}...")
|
|
||||||
asr = SimulStreamingASR(
|
|
||||||
modelsize=size,
|
|
||||||
lan=args.lan,
|
|
||||||
cache_dir=getattr(args, 'model_cache_dir', None),
|
|
||||||
model_dir=getattr(args, 'model_dir', None),
|
|
||||||
**simulstreaming_kwargs
|
|
||||||
)
|
|
||||||
e = time.time()
|
|
||||||
logger.info(f"done. It took {round(e-t,2)} seconds.")
|
|
||||||
else:
|
else:
|
||||||
if backend == "faster-whisper":
|
if backend == "faster-whisper":
|
||||||
asr_cls = FasterWhisperASR
|
asr_cls = FasterWhisperASR
|
||||||
@@ -136,107 +107,4 @@ def backend_factory(args):
|
|||||||
tokenizer = create_tokenizer(tgt_language)
|
tokenizer = create_tokenizer(tgt_language)
|
||||||
else:
|
else:
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
return asr, tokenizer
|
return asr, tokenizer
|
||||||
|
|
||||||
def online_factory(args, asr, tokenizer, logfile=sys.stderr):
|
|
||||||
if args.backend == "simulstreaming":
|
|
||||||
if not SIMULSTREAMING_ONLINE_AVAILABLE:
|
|
||||||
raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
|
|
||||||
|
|
||||||
logger.debug("Creating SimulStreaming online processor")
|
|
||||||
online = SimulStreamingOnlineProcessor(
|
|
||||||
asr,
|
|
||||||
tokenizer,
|
|
||||||
logfile=logfile,
|
|
||||||
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
|
|
||||||
confidence_validation=args.confidence_validation
|
|
||||||
)
|
|
||||||
elif args.vac:
|
|
||||||
online = VACOnlineASRProcessor(
|
|
||||||
args.min_chunk_size,
|
|
||||||
asr,
|
|
||||||
tokenizer,
|
|
||||||
logfile=logfile,
|
|
||||||
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
|
|
||||||
confidence_validation = args.confidence_validation
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
online = OnlineASRProcessor(
|
|
||||||
asr,
|
|
||||||
tokenizer,
|
|
||||||
logfile=logfile,
|
|
||||||
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
|
|
||||||
confidence_validation = args.confidence_validation
|
|
||||||
)
|
|
||||||
return online
|
|
||||||
|
|
||||||
def asr_factory(args, logfile=sys.stderr):
|
|
||||||
"""
|
|
||||||
Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
|
|
||||||
"""
|
|
||||||
asr, tokenizer = backend_factory(args)
|
|
||||||
online = online_factory(args, asr, tokenizer, logfile=logfile)
|
|
||||||
return asr, online
|
|
||||||
|
|
||||||
def warmup_asr(asr, warmup_file=None, timeout=5):
|
|
||||||
"""
|
|
||||||
Warmup the ASR model by transcribing a short audio file.
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
is_simulstreaming = hasattr(asr, 'warmup') and callable(getattr(asr, 'warmup'))
|
|
||||||
|
|
||||||
if warmup_file is None:
|
|
||||||
# Download JFK sample if not already present
|
|
||||||
jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
|
|
||||||
temp_dir = tempfile.gettempdir()
|
|
||||||
warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")
|
|
||||||
|
|
||||||
if not os.path.exists(warmup_file):
|
|
||||||
logger.debug(f"Downloading warmup file from {jfk_url}")
|
|
||||||
print(f"Downloading warmup file from {jfk_url}")
|
|
||||||
import time
|
|
||||||
import urllib.request
|
|
||||||
import urllib.error
|
|
||||||
import socket
|
|
||||||
|
|
||||||
original_timeout = socket.getdefaulttimeout()
|
|
||||||
socket.setdefaulttimeout(timeout)
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
try:
|
|
||||||
urllib.request.urlretrieve(jfk_url, warmup_file)
|
|
||||||
logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
|
|
||||||
except (urllib.error.URLError, socket.timeout) as e:
|
|
||||||
logger.warning(f"Download failed: {e}. Proceeding without warmup.")
|
|
||||||
return False
|
|
||||||
finally:
|
|
||||||
socket.setdefaulttimeout(original_timeout)
|
|
||||||
elif not warmup_file:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
|
|
||||||
logger.warning(f"Warmup file {warmup_file} invalid or missing.")
|
|
||||||
return False
|
|
||||||
|
|
||||||
print(f"Warming up {'SimulStreaming' if is_simulstreaming else 'Whisper'} with {warmup_file}")
|
|
||||||
try:
|
|
||||||
import librosa
|
|
||||||
audio, sr = librosa.load(warmup_file, sr=16000)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to load audio file: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
if is_simulstreaming:
|
|
||||||
asr.warmup(audio)
|
|
||||||
else:
|
|
||||||
asr.transcribe(audio)
|
|
||||||
|
|
||||||
logger.info(f"{'SimulStreaming' if is_simulstreaming else 'Whisper'} is warmed up")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Warmup failed: {e}")
|
|
||||||
return False
|
|
||||||
Reference in New Issue
Block a user