Compare commits
142 Commits
SHA1: babe93b99a, a4e9f3cab7, b06866877a, 967cdfebc8, 3c11c60126, 2963e8a757, cb2d4ea88a, add7ea07ee, da8726b2cb, 3358877054,
1f7798c7c1, c7b3bb5e58, f661f21675, b6164aa59b, 4209d7f7c0, 334b338ab0, 72f33be6f2, 84890b8e61, c6668adcf3, a178ed5c22,
7601c74c9c, fad9ee4d21, d1a9913c47, e4ca2623cb, 9c1bf37960, f46528471b, 191680940b, ee02afec56, a458028de2, abd8f2c269,
f3ad4e39e4, e0a5cbf0e7, 953697cd86, 3bd2122eb4, 50b0527858, b044fcdec2, b0508fcf2c, ce89b0aebc, d5008ed828, d467716e26,
199e21b3ef, 1d926f2e67, 4a71a391b8, d3ed4e46e2, 057a1026d7, 1ba171a58d, 1adac67155, 42be1a3773, 0a49fafa0d, 4a5d5e1f3b,
583a2ec2e4, 19765e89e9, 9895bc83bf, ab98c31f16, f9c9c4188a, c21d2302e7, 4ed62e181d, 52a755a08c, 9a8d3cbd90, b101ce06bd,
c83fd179a8, 5258305745, ce781831ee, 58297daf6d, 3393a08f7e, 5b2ddeccdb, 26cc1072dd, 12973711f6, 909ac9dd41, d94a07d417,
b32dd8bfc4, 9feb0e597b, 9dab84a573, d089c7fce0, 253a080df5, 0c6e4b2aee, e14bbde77d, 7496163467, 696a94d1ce, 2699b0974c,
90c0250ba4, eb96153ffd, 47e3eb9b5b, b8b07adeef, d0e9e37ef6, 820f92d8cb, e42523af84, e2184d5e06, 7fe0353260, 0f2eba507e,
55e08474f3, 28bdc52e1d, e4221fa6c3, 1652db9a2d, 601f17653a, 7718190fcd, 349c7dcb9e, 1c42b867cf, d4771e563e, b0a5fc0693,
3b96fb8776, 7f93c4b978, 15c3df1cba, 7fb8e66c01, 728e1f1290, 87b9ed6ecd, 38b4ebe8ba, d098af3185, 4e56130a40, 2bbdc70187,
b678a55f63, 5491964e81, b05297a96d, 197293e25e, ba41c4ab56, bda72b8bc0, bb6b9f4cb1, e40b5a3ea0, 4cfed6e98e, 687e3dd5e2,
e4140cd299, 8e056cbdf2, 9dcfb38967, 47b9235d70, f3cd53a4db, dbdb4ea66c, 00424d7ca3, 4b738d6f63, 8a5e2adb1e, f85329e112,
46efbdf1d9, 8885ade003, 2564928d83, 56114d3071, 5b9977c9af, 12a544164f, 2ca1156b7e, 3ad3683ca7, 1599bd87a0, 90623400a4,
64e44fb24f, 156b9a133f
.gitignore (vendored, 3 lines changed)

```
@@ -137,4 +137,5 @@ run_*.sh
test_*.py
launch.json
.DS_Store
test/*
nllb-200-distilled-600M-ctranslate2/*
```
CONTRIBUTING.md

@@ -15,7 +15,7 @@ Thank you for considering contributing ! We appreciate your time and effort to h

## Opening Issues

If you encounter a problem with diart or want to suggest an improvement, please follow these guidelines when opening an issue:
If you encounter a problem with WhisperLiveKit or want to suggest an improvement, please follow these guidelines when opening an issue:

- **Bug Reports:**
  - Clearly describe the error. **Please indicate the parameters you use, especially the model(s)**

@@ -43,4 +43,4 @@ We welcome and appreciate contributions! To ensure a smooth review process, plea

## Thank You

Your contributions make diart better for everyone. Thank you for your time and dedication!
Your contributions make WhisperLiveKit better for everyone. Thank you for your time and dedication!
DEV_NOTES.md (new file, 70 lines)

# 1. Simulstreaming: Decouple the encoder for faster inference

Encoder timing experiments for SimulStreaming (whisperlivekit/simul_whisper/simul_whisper.py, l. 397):

On macOS, Apple Silicon M4:

| Encoder | base.en | small |
|--------|---------|-------|
| WHISPER (no modification) | 0.35s | 1.09s |
| FASTER_WHISPER | 0.4s | 1.20s |
| MLX_WHISPER | 0.07s | 0.20s |

Memory saved by loading only the encoder from the optimized framework, for tiny.en with MLX Whisper:

- Decoder weights: 59110771 bytes
- Encoder weights: 15268874 bytes
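The timings above cover only the encoder forward pass. A rough harness along these lines can reproduce the comparison; the `encode` callable is a hypothetical wrapper around whichever backend encoder is being tested (original Whisper, Faster-Whisper, or MLX Whisper), not an API of this repository:

```python
import time
import numpy as np

def time_encoder(encode, n_mels=80, n_frames=3000, repeats=5):
    """Average wall-clock time of one encoder pass over a 30 s mel spectrogram."""
    mel = np.zeros((n_mels, n_frames), dtype=np.float32)  # 30 s of audio at a 10 ms hop
    encode(mel)  # warm-up: model load, JIT / graph compilation
    start = time.perf_counter()
    for _ in range(repeats):
        encode(mel)
    return (time.perf_counter() - start) / repeats
```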
# 2. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm

Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions.

## Problem Statement

- Input: `self.total_preds` with shape `(x, x, 4)` - predictions for 4 speakers
- Output: Constrained predictions with shape `(x, x, 2)` - predictions for 2 speakers

### Initial Setup

For each time step `i`, we have a ranking of 4 speaker predictions (1-4). When only 2 speakers are present, the model will have close predictions for the 2 active speaker positions.

Instead of `np.argmax(preds_np, axis=1)`, we take the top 2 predictions and build a dynamic 4→2 mapping that can evolve over time.

### Algorithm

```python
top_2_speakers = np.argsort(preds_np, axis=1)[:, -2:]
```

- `DS_a_{i}`: Top detected speaker for prediction i
- `DS_b_{i}`: Second detected speaker for prediction i
- `AS_{i}`: Attributed speaker for prediction i
- `GTS_A`: Ground truth speaker A
- `GTS_B`: Ground truth speaker B
- `DIST(a, b)`: Distance between detected speakers a and b

3. **Attribution Logic**

```
AS_0 ← A
AS_1 ← B

IF DIST(DS_a_0, DS_a_1) < DIST(DS_a_0, DS_a_2) AND
   DIST(DS_a_0, DS_a_1) < DIST(DS_a_1, DS_a_2):
    # Likely that DS_a_0 = DS_a_1 (same speaker)
    AS_1 ← A
    AS_2 ← B

ELIF DIST(DS_a_0, DS_a_2) < DIST(DS_a_0, DS_a_1) AND
     DIST(DS_a_0, DS_a_2) < DIST(DS_a_1, DS_a_2):
    AS_2 ← A

ELSE:
    AS_2 ← B

to finish
```
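One concrete way to realize the 4-to-2 constraint in NumPy is sketched below. Since the attribution pseudocode above is not finished, this is only one possible interpretation rather than the algorithm shipped in WhisperLiveKit: it keeps the two globally most active speaker positions and folds each remaining position into whichever kept position has the closer activation profile (a stand-in for `DIST`).

```python
import numpy as np

def constrain_to_two_speakers(total_preds: np.ndarray) -> np.ndarray:
    """Map per-frame predictions for up to 4 speakers onto 2 output channels.

    total_preds: (batch, frames, 4) speaker-activity scores.
    Returns:     (batch, frames, 2) constrained scores.
    """
    batch, frames, _ = total_preds.shape
    out = np.zeros((batch, frames, 2), dtype=total_preds.dtype)
    for b in range(batch):
        preds = total_preds[b]                 # (frames, 4)
        activity = preds.sum(axis=0)           # overall activity per speaker position
        keep = np.argsort(activity)[-2:]       # the two most active positions
        out[b, :, 0] = preds[:, keep[0]]
        out[b, :, 1] = preds[:, keep[1]]
        for k in range(4):
            if k in keep:
                continue
            # Stand-in DIST(): Euclidean distance between activation profiles
            d0 = np.linalg.norm(preds[:, k] - preds[:, keep[0]])
            d1 = np.linalg.norm(preds[:, k] - preds[:, keep[1]])
            target = 0 if d0 <= d1 else 1
            out[b, :, target] = np.maximum(out[b, :, target], preds[:, k])
    return out
```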
Dockerfile (49 lines changed)

```dockerfile
@@ -1,4 +1,4 @@
FROM nvidia/cuda:12.8.1-cudnn-runtime-ubuntu22.04
FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04

ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1

@@ -9,46 +9,50 @@ ARG EXTRAS
ARG HF_PRECACHE_DIR
ARG HF_TKN_FILE

# Install system dependencies
#RUN apt-get update && \
#    apt-get install -y ffmpeg git && \
#    apt-get clean && \
#    rm -rf /var/lib/apt/lists/*

# 2) Install system dependencies + Python + pip
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    python3 \
    python3-pip \
    python3-venv \
    ffmpeg \
    git && \
    git \
    build-essential \
    python3-dev \
    ca-certificates && \
    rm -rf /var/lib/apt/lists/*

RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
RUN python3 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"

# timeout/retries for large torch wheels
RUN pip3 install --upgrade pip setuptools wheel && \
    pip3 --disable-pip-version-check install --timeout=120 --retries=5 \
    --index-url https://download.pytorch.org/whl/cu129 \
    torch torchaudio \
    || (echo "Initial install failed — retrying with extended timeout..." && \
    pip3 --disable-pip-version-check install --timeout=300 --retries=3 \
    --index-url https://download.pytorch.org/whl/cu129 \
    torch torchvision torchaudio)

COPY . .

# Install WhisperLiveKit directly, allowing for optional dependencies
# Note: For gated models, you need to add your HF token. See README.md
# for more details.
RUN if [ -n "$EXTRAS" ]; then \
    echo "Installing with extras: [$EXTRAS]"; \
    pip install --no-cache-dir .[$EXTRAS]; \
    pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
    else \
    echo "Installing base package only"; \
    pip install --no-cache-dir .; \
    pip install --no-cache-dir whisperlivekit; \
    fi

# Enable in-container caching for Hugging Face models by:
# Note: If running multiple containers, better to map a shared
# bucket.
#
# In-container caching for Hugging Face models by:
# A) Make the cache directory persistent via an anonymous volume.
# Note: This only persists for a single, named container. This is
# only for convenience at dev/test stage.
# For prod, it is better to use a named volume via host mount/k8s.
VOLUME ["/root/.cache/huggingface/hub"]

# or
# B) Conditionally copy a local pre-cache from the build context to the
# container's cache via the HF_PRECACHE_DIR build-arg.
@@ -63,8 +67,7 @@ RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
    echo "No local Hugging Face cache specified, skipping copy"; \
    fi

# Conditionally copy a Hugging Face token if provided
# Conditionally copy a Hugging Face token if provided. Useful for Diart backend (pyannote audio models)
RUN if [ -n "$HF_TKN_FILE" ]; then \
    echo "Copying Hugging Face token from $HF_TKN_FILE"; \
    mkdir -p /root/.cache/huggingface && \
@@ -72,11 +75,9 @@ RUN if [ -n "$HF_TKN_FILE" ]; then \
    else \
    echo "No Hugging Face token file specified, skipping token setup"; \
    fi

# Expose port for the transcription server
EXPOSE 8000

ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]

# Default args
CMD ["--model", "tiny.en"]
CMD ["--model", "medium"]
```
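If you plan to use the `HF_PRECACHE_DIR` build-arg, the directory you pass in has to already contain a Hugging Face hub cache. A minimal sketch of one way to populate it is below; the script name, the use of `snapshot_download`, and the chosen model IDs (the Diart diarization models mentioned in the README; the gated pyannote model additionally requires accepted terms and a token) are assumptions, not part of this repository:

```python
# populate_precache.py - build a local Hugging Face cache to pass as HF_PRECACHE_DIR
from huggingface_hub import snapshot_download

CACHE_DIR = "./.cache"  # later: docker build --build-arg HF_PRECACHE_DIR=./.cache/ ...

for repo_id in (
    "pyannote/segmentation-3.0",          # Diart segmentation model (gated: accept terms + token)
    "speechbrain/spkrec-ecapa-voxceleb",  # Diart embedding model
):
    snapshot_download(repo_id=repo_id, cache_dir=CACHE_DIR)
```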
Dockerfile.cpu (new file, 61 lines)

```dockerfile
FROM python:3.13-slim

ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1

WORKDIR /app

ARG EXTRAS
ARG HF_PRECACHE_DIR
ARG HF_TKN_FILE

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    ffmpeg \
    git \
    build-essential \
    python3-dev && \
    rm -rf /var/lib/apt/lists/*

# Install CPU-only PyTorch
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

COPY . .

# Install WhisperLiveKit directly, allowing for optional dependencies
RUN if [ -n "$EXTRAS" ]; then \
    echo "Installing with extras: [$EXTRAS]"; \
    pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
    else \
    echo "Installing base package only"; \
    pip install --no-cache-dir whisperlivekit; \
    fi

# Enable in-container caching for Hugging Face models
VOLUME ["/root/.cache/huggingface/hub"]

# Conditionally copy a local pre-cache from the build context
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
    echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
    mkdir -p /root/.cache/huggingface/hub && \
    cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
    else \
    echo "No local Hugging Face cache specified, skipping copy"; \
    fi

# Conditionally copy a Hugging Face token if provided
RUN if [ -n "$HF_TKN_FILE" ]; then \
    echo "Copying Hugging Face token from $HF_TKN_FILE"; \
    mkdir -p /root/.cache/huggingface && \
    cp $HF_TKN_FILE /root/.cache/huggingface/token; \
    else \
    echo "No Hugging Face token file specified, skipping token setup"; \
    fi

# Expose port for the transcription server
EXPOSE 8000

ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]

# Default args - you might want to use a smaller model for CPU
CMD ["--model", "tiny"]
```
README.md (322 lines changed)

@@ -4,133 +4,97 @@
  <img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
</p>

<p align="center"><b>Real-time, Fully Local Speech-to-Text with Speaker Diarization</b></p>
<p align="center"><b>Real-time, Fully Local Speech-to-Text with Speaker Identification</b></p>

<p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=downloads"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p>

## Overview

This project is based on [WhisperStreaming](https://github.com/ufal/whisper_streaming) and [SimulStreaming](https://github.com/ufal/SimulStreaming), allowing you to transcribe audio directly from your browser. WhisperLiveKit provides a complete backend solution for real-time speech transcription with a functional, simple and customizable frontend. Everything runs locally on your machine ✨
Real-time speech transcription directly to your browser, with a ready-to-use backend+server and a simple frontend. ✨

#### Powered by Leading Research:

- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription with AlignAtt policy
- [NLLB](https://arxiv.org/abs/2207.04672) ([distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2)) (2024) - Translation to more than 100 languages.
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription with LocalAgreement policy
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - Enterprise-grade Voice Activity Detection

> **Why not just run a simple Whisper model on every audio batch?** Whisper is designed for complete utterances, not real-time chunks. Processing small segments loses context, cuts off words mid-syllable, and produces poor transcription. WhisperLiveKit uses state-of-the-art simultaneous speech research for intelligent buffering and incremental processing.

### Architecture

WhisperLiveKit consists of three main components:
<img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />

- **Frontend**: A basic HTML + JS interface that captures microphone audio and streams it to the backend via WebSockets. You can use and adapt the [provided template](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html).
- **Backend (Web Server)**: A FastAPI-based WebSocket server that receives streamed audio data, processes it in real time, and returns transcriptions to the frontend. This is where the WebSocket logic and routing live.
- **Core Backend (Library Logic)**: A server-agnostic core that handles audio processing, ASR, and diarization. It exposes reusable components that take in audio bytes and return transcriptions.

*The backend supports multiple concurrent users. Voice Activity Detection reduces overhead when no voice is detected.*
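The core backend can also be driven without the bundled FastAPI server. The sketch below reuses only the calls shown in the Python API example further down (`TranscriptionEngine`, `AudioProcessor`, `create_tasks`, `process_audio`); feeding a local webm/opus file in fixed-size chunks, the pacing, and the shutdown handling are illustrative assumptions rather than a documented API.

```python
import asyncio
from whisperlivekit import TranscriptionEngine, AudioProcessor

async def transcribe_file(path: str):
    engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
    processor = AudioProcessor(transcription_engine=engine)
    results = await processor.create_tasks()

    async def print_results():
        async for response in results:   # same dicts the WebSocket frontend receives
            print(response)

    printer = asyncio.create_task(print_results())
    with open(path, "rb") as f:          # webm/opus stream, as the browser would send it
        while chunk := f.read(4096):
            await processor.process_audio(chunk)
            await asyncio.sleep(0.05)    # rough pacing, purely illustrative
    await asyncio.sleep(2)               # let the final hypothesis flush
    printer.cancel()

asyncio.run(transcribe_file("example.webm"))
```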
### Key Features

- **Real-time Transcription** - Locally (or on-prem) convert speech to text instantly as you speak
- **Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart)
- **Multi-User Support** - Handle multiple users simultaneously with a single backend/server
- **Automatic Silence Chunking** – Automatically chunks when no audio is detected to limit buffer size
- **Confidence Validation** – Immediately validate high-confidence tokens for faster inference (WhisperStreaming only)
- **Buffering Preview** – Displays unvalidated transcription segments (not compatible with SimulStreaming yet)
- **Punctuation-Based Speaker Splitting [BETA]** - Align speaker changes with natural sentence boundaries for more readable transcripts
- **SimulStreaming Backend** - [Dual-licensed](https://github.com/ufal/SimulStreaming#-licence-and-contributions) - Ultra-low latency transcription using the SOTA AlignAtt policy.

## Quick Start
### Installation & Quick Start

```bash
# Install the package
pip install whisperlivekit

# Start the transcription server
whisperlivekit-server --model tiny.en

# Open your browser at http://localhost:8000 to see the interface.
# Use the --ssl-certfile public.crt --ssl-keyfile private.key parameters to use SSL
```
> You can also clone the repo and `pip install -e .` for the latest version.

That's it! Start speaking and watch your words appear on screen.

## Installation

> **FFmpeg is required** and must be installed before using WhisperLiveKit
>
> | OS | How to install |
> |-----------|-------------|
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
> | MacOS | `brew install ffmpeg` |
> | Windows | Download the .exe from https://ffmpeg.org/download.html and add it to PATH |

#### Quick Start
1. **Start the transcription server:**
```bash
whisperlivekit-server --model base --language en
```

2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time!

> - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.

#### Optional Dependencies

| Optional | `pip install` |
|-----------|-------------|
| **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| **Apple Silicon optimized backend** | `mlx-whisper` |
| **NLLB Translation** | `huggingface_hub` & `transformers` |
| *[Not recommended]* Speaker diarization with Diart | `diart` |
| *[Not recommended]* Original Whisper backend | `whisper` |
| *[Not recommended]* Improved timestamps backend | `whisper-timestamped` |
| OpenAI API backend | `openai` |

See **Parameters & Configuration** below on how to use them.

### Usage Examples

**Command-line Interface**: Start the transcription server with various options:

```bash
# Install from PyPI (Recommended)
pip install whisperlivekit
# Use a better model than the default (small)
whisperlivekit-server --model large-v3

# Install from source
git clone https://github.com/QuentinFuxa/WhisperLiveKit
cd WhisperLiveKit
pip install -e .
```
### FFmpeg Dependency

```bash
# Ubuntu/Debian
sudo apt install ffmpeg

# macOS
brew install ffmpeg

# Windows
# Download from https://ffmpeg.org/download.html and add to PATH
```

### Optional Dependencies

```bash
# Voice Activity Controller (prevents hallucinations)
pip install torch

# Sentence-based buffer trimming
pip install mosestokenizer wtpsplit
pip install tokenize_uk # If you work with Ukrainian text

# Speaker diarization
pip install diart

# Alternative Whisper backends (default is faster-whisper)
pip install whisperlivekit[whisper] # Original Whisper
pip install whisperlivekit[whisper-timestamped] # Improved timestamps
pip install whisperlivekit[mlx-whisper] # Apple Silicon optimization
pip install whisperlivekit[openai] # OpenAI API
pip install whisperlivekit[simulstreaming]
```

### 🎹 Pyannote Models Setup

For diarization, you need access to pyannote.audio models:

1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
3. [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
4. Login with HuggingFace:
```bash
pip install huggingface_hub
huggingface-cli login
```

## 💻 Usage Examples

### Command-line Interface

Start the transcription server with various options:

```bash
# Basic server with English model
whisperlivekit-server --model tiny.en

# Advanced configuration with diarization
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language auto

# SimulStreaming backend for ultra-low latency
whisperlivekit-server --backend simulstreaming --model large-v3 --frame-threshold 20
# Advanced configuration with diarization and language
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
```

### Python API Integration (Backend)
Check [basic_server.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a complete example.
**Python API Integration**: Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes.

```python
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
@@ -145,14 +109,10 @@ transcription_engine = None
async def lifespan(app: FastAPI):
    global transcription_engine
    transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
    # You can also load from command-line arguments using parse_args()
    # args = parse_args()
    # transcription_engine = TranscriptionEngine(**vars(args))
    yield

app = FastAPI(lifespan=lifespan)

# Process WebSocket connections
async def handle_websocket_results(websocket: WebSocket, results_generator):
    async for response in results_generator:
        await websocket.send_json(response)
@@ -172,44 +132,44 @@ async def websocket_endpoint(websocket: WebSocket):
    await audio_processor.process_audio(message)
```

### Frontend Implementation
**Frontend Implementation**: The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it using `from whisperlivekit import get_inline_ui_html` & `page = get_inline_ui_html()`

The package includes a simple HTML/JavaScript implementation that you can adapt for your project. You can find it [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html), or load its content using `get_web_interface_html()`:

```python
from whisperlivekit import get_web_interface_html
html_content = get_web_interface_html()
```
## Parameters & Configuration

## ⚙️ Configuration Reference
A large number of parameters can be changed. But what *should* you change?
- the `--model` size. List and recommendations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
- the `--language`. List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English.
- the `--backend`? You can switch to `--backend faster-whisper` if `simulstreaming` does not work correctly or if you prefer to avoid the dual-license requirements.
- `--warmup-file`, if you have one
- `--task translate`, to translate to English
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`, if you set up a server
- `--diarization`, if you want to use it.
- [BETA] `--target-language`, to translate using NLLB. [118 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/translation/mapping_languages.py). If you want to translate to English, use `--task translate` instead, since Whisper can do it directly.

WhisperLiveKit offers extensive configuration options:
### Full list of parameters:

| Parameter | Description | Default |
|-----------|-------------|---------|
| `--model` | Whisper model size. | `small` |
| `--language` | Source language code or `auto` | `auto` |
| `--task` | Set to `translate` to translate to English | `transcribe` |
| `--target-language` | [BETA] Translation language target. Ex: `fr` | `None` |
| `--backend` | Processing backend | `simulstreaming` |
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
| `--no-vac` | Disable Voice Activity Controller | `False` |
| `--no-vad` | Disable Voice Activity Detection | `False` |
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
| `--host` | Server host address | `localhost` |
| `--port` | Server port | `8000` |
| `--model` | Whisper model size. Caution: '.en' models do not work with Simulstreaming | `tiny` |
| `--language` | Source language code or `auto` | `en` |
| `--task` | `transcribe` or `translate` | `transcribe` |
| `--backend` | Processing backend | `faster-whisper` |
| `--diarization` | Enable speaker identification | `False` |
| `--punctuation-split` | Use punctuation to improve speaker boundaries | `True` |
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
| `--vac` | Use Voice Activity Controller | `False` |
| `--no-vad` | Disable Voice Activity Detection | `False` |
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
| `--segmentation-model` | Hugging Face model ID for the pyannote.audio segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for the pyannote.audio embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
| `--pcm-input` | Raw PCM (s16le) data is expected as input and FFmpeg is bypassed. | `False` |

**SimulStreaming-specific Options:**

| Parameter | Description | Default |
| SimulStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--disable-fast-encoder` | Disable the Faster-Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower, but helpful when GPU memory is limited | `False` |
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
@@ -221,68 +181,91 @@ WhisperLiveKit offers extensive configuration options:
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
| `--max-context-tokens` | Maximum context tokens | `None` |
| `--model-path` | Direct path to the .pt model file. Downloaded if not found | `./base.pt` |
| `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |

## 🔧 How It Works

1. **Audio Capture**: Browser's MediaRecorder API captures audio in webm/opus format
2. **Streaming**: Audio chunks are sent to the server via WebSocket
3. **Processing**: Server decodes audio with FFmpeg and streams it into the model for transcription
4. **Real-time Output**: Partial transcriptions appear immediately in light gray (the 'aperçu') and finalized text appears in normal color
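As a concrete illustration of steps 2-4, the minimal client below streams raw 16-bit PCM from a WAV file to the `/asr` WebSocket endpoint and prints the JSON responses. It assumes the server was started with `--pcm-input` (so FFmpeg is bypassed), a mono 16 kHz input file, and the third-party `websockets` package; none of these specifics are mandated by the documentation above.

```python
import asyncio
import json
import wave
import websockets  # pip install websockets

async def stream_wav(path: str, url: str = "ws://localhost:8000/asr"):
    async with websockets.connect(url) as ws:
        async def receive():
            async for message in ws:
                print(json.loads(message))   # transcription updates from the server

        receiver = asyncio.create_task(receive())
        with wave.open(path, "rb") as wav:   # expects s16le PCM, e.g. mono 16 kHz
            frames_per_chunk = wav.getframerate() // 10   # ~100 ms of audio
            while chunk := wav.readframes(frames_per_chunk):
                await ws.send(chunk)         # raw PCM bytes, matching --pcm-input
                await asyncio.sleep(0.1)     # pace roughly in real time
        await asyncio.sleep(2)               # give the server time to flush final results
        receiver.cancel()

asyncio.run(stream_wav("sample.wav"))
```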
| WhisperStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |

## 🚀 Deployment Guide
| Diarization options | Description | Default |
|-----------|-------------|---------|
| `--diarization` | Enable speaker identification | `False` |
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
| `--disable-punctuation-split` | Disable punctuation based splits. See #214 | `False` |
| `--segmentation-model` | Hugging Face model ID for the Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for the Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |

> For diarization using Diart, you need access to pyannote.audio models:
> 1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
> 2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
> 3. [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
> 4. Login with HuggingFace: `huggingface-cli login`

### 🚀 Deployment Guide

To deploy WhisperLiveKit in production:

1. **Server Setup** (Backend):
1. **Server Setup**: Install a production ASGI server & launch with multiple workers
```bash
# Install production ASGI server
pip install uvicorn gunicorn

# Launch with multiple workers
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
```

2. **Frontend Integration**:
   - Host your customized version of the example HTML/JS in your web application
   - Ensure the WebSocket connection points to your server's address
2. **Frontend**: Host your customized version of the `html` example & ensure the WebSocket connection points correctly

3. **Nginx Configuration** (recommended for production):
```nginx
server {
    listen 80;
    server_name your-domain.com;

    location / {
        proxy_pass http://localhost:8000;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";
        proxy_set_header Host $host;
    }
}
```

4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in the WebSocket URL

### 🐋 Docker
## 🐋 Docker

A basic Dockerfile is provided which allows re-use of Python package installation options. ⚠️ For **large** models, ensure that your **docker runtime** has enough **memory** available. See the usage examples below:
Deploy the application easily using Docker with GPU or CPU support.

### Prerequisites
- Docker installed on your system
- For GPU support: NVIDIA Docker runtime installed

#### All defaults
- Create a reusable image with only the basics and then run it as a named container:
### Quick Start

**With GPU acceleration (recommended):**
```bash
docker build -t whisperlivekit-defaults .
docker create --gpus all --name whisperlivekit -p 8000:8000 whisperlivekit-defaults
docker start -i whisperlivekit
docker build -t wlk .
docker run --gpus all -p 8000:8000 --name wlk wlk
```

> **Note**: If you're running on a system without NVIDIA GPU support (such as a Mac with Apple Silicon or any system without CUDA capabilities), you need to **remove the `--gpus all` flag** from the `docker create` command. Without GPU acceleration, transcription will use the CPU only, which may be significantly slower. Consider using small models for better performance on CPU-only systems.
**CPU only:**
```bash
docker build -f Dockerfile.cpu -t wlk .
docker run -p 8000:8000 --name wlk wlk
```

### Advanced Usage

**Custom configuration:**
```bash
# Example with custom model and language
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
```

### Memory Requirements
- **Large models**: Ensure your Docker runtime has sufficient memory allocated

#### Customization
- Customize the container options:
```bash
docker build -t whisperlivekit-defaults .
docker create --gpus all --name whisperlivekit-base -p 8000:8000 whisperlivekit-defaults --model base
docker start -i whisperlivekit-base
```

- `--build-arg` Options:
  - `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set the necessary container options!
@@ -291,10 +274,3 @@ docker start -i whisperlivekit-base

## 🔮 Use Cases
Capture discussions in real-time for meeting transcription, help hearing-impaired users follow conversations through accessibility tools, transcribe podcasts or videos automatically for content creation, transcribe support calls with speaker identification for customer service...

## 🙏 Acknowledgments

We extend our gratitude to the original authors of:

| [Whisper Streaming](https://github.com/ufal/whisper_streaming) | [SimulStreaming](https://github.com/ufal/SimulStreaming) | [Diart](https://github.com/juanmc2005/diart) | [OpenAI Whisper](https://github.com/openai/whisper) |
| -------- | ------- | -------- | ------- |
258
ReadmeJP.md
Normal file
@@ -0,0 +1,258 @@
|
||||
<h1 align="center">WhisperLiveKit</h1>
|
||||
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
|
||||
</p>
|
||||
|
||||
<p align="center"><b>話者識別機能付き、リアルタイム、完全ローカルな音声テキスト変換</b></p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
|
||||
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
|
||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
|
||||
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
|
||||
</p>
|
||||
|
||||
すぐに使えるバックエンド+サーバーとシンプルなフロントエンドで、リアルタイムの音声文字起こしをブラウザに直接提供します。✨
|
||||
|
||||
#### 主要な研究による技術:
|
||||
|
||||
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - AlignAttポリシーによる超低遅延文字起こし
|
||||
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - LocalAgreementポリシーによる低遅延文字起こし
|
||||
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - 高度なリアルタイム話者ダイアライゼーション
|
||||
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - リアルタイム話者ダイアライゼーション
|
||||
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - エンタープライズグレードの音声区間検出
|
||||
|
||||
> **なぜ各音声バッチで単純なWhisperモデルを実行しないのか?** Whisperは完全な発話向けに設計されており、リアルタイムのチャンク向けではありません。小さなセグメントを処理するとコンテキストが失われ、単語が音節の途中で途切れ、質の悪い文字起こしになります。WhisperLiveKitは、インテリジェントなバッファリングとインクリメンタルな処理のために、最先端の同時音声研究を利用しています。
|
||||
|
||||
### アーキテクチャ
|
||||
|
||||
<img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />
|
||||
|
||||
*バックエンドは複数の同時ユーザーをサポートします。音声が検出されない場合、音声区間検出がオーバーヘッドを削減します。*
|
||||
|
||||
### インストールとクイックスタート
|
||||
|
||||
```bash
|
||||
pip install whisperlivekit
|
||||
```
|
||||
|
||||
> **FFmpegが必要です** WhisperLiveKitを使用する前にインストールする必要があります。
|
||||
>
|
||||
> | OS | インストール方法 |
|
||||
> |-----------|-------------|
|
||||
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
|
||||
> | MacOS | `brew install ffmpeg` |
|
||||
> | Windows | https://ffmpeg.org/download.html から.exeをダウンロードし、PATHに追加 |
|
||||
|
||||
#### クイックスタート
|
||||
1. **文字起こしサーバーを起動します:**
|
||||
```bash
|
||||
whisperlivekit-server --model base --language en
|
||||
```
|
||||
|
||||
2. **ブラウザを開き** `http://localhost:8000` にアクセスします。話し始めると、あなたの言葉がリアルタイムで表示されます!
|
||||
|
||||
|
||||
> - 利用可能なすべての言語のリストについては、[tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) を参照してください。
|
||||
> - HTTPSの要件については、**パラメータ**セクションのSSL設定オプションを参照してください。
|
||||
|
||||
#### オプションの依存関係
|
||||
|
||||
| オプション | `pip install` |
|
||||
|-----------|-------------|
|
||||
| **Sortformerによる話者ダイアライゼーション** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
|
||||
| Diartによる話者ダイアライゼーション | `diart` |
|
||||
| オリジナルのWhisperバックエンド | `whisper` |
|
||||
| タイムスタンプ改善バックエンド | `whisper-timestamped` |
|
||||
| Apple Silicon最適化バックエンド | `mlx-whisper` |
|
||||
| OpenAI APIバックエンド | `openai` |
|
||||
|
||||
それらの使用方法については、以下の**パラメータと設定**を参照してください。
|
||||
|
||||
### 使用例
|
||||
|
||||
**コマンドラインインターフェース**: 様々なオプションで文字起こしサーバーを起動します:
|
||||
|
||||
```bash
|
||||
# デフォルト(small)より良いモデルを使用
|
||||
whisperlivekit-server --model large-v3
|
||||
|
||||
# ダイアライゼーションと言語を指定した高度な設定
|
||||
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
|
||||
```
|
||||
|
||||
**Python API連携**: 関数やクラスの使用方法のより完全な例については、[basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) を確認してください。
|
||||
|
||||
```python
|
||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
|
||||
transcription_engine = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global transcription_engine
|
||||
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
|
||||
yield
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
async def handle_websocket_results(websocket: WebSocket, results_generator):
|
||||
async for response in results_generator:
|
||||
await websocket.send_json(response)
|
||||
await websocket.send_json({"type": "ready_to_stop"})
|
||||
|
||||
@app.websocket("/asr")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
global transcription_engine
|
||||
|
||||
# 接続ごとに新しいAudioProcessorを作成し、共有エンジンを渡す
|
||||
audio_processor = AudioProcessor(transcription_engine=transcription_engine)
|
||||
results_generator = await audio_processor.create_tasks()
|
||||
results_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
|
||||
await websocket.accept()
|
||||
while True:
|
||||
message = await websocket.receive_bytes()
|
||||
await audio_processor.process_audio(message)
|
||||
```
|
||||
|
||||
**フロントエンド実装**: パッケージにはHTML/JavaScript実装が[ここ](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html)に含まれています。`from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()` を使ってインポートすることもできます。
|
||||
|
||||
|
||||
## パラメータと設定
|
||||
|
||||
重要なパラメータのリストを変更できます。しかし、何を*変更すべき*でしょうか?
|
||||
- `--model` サイズ。リストと推奨事項は[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
|
||||
- `--language`。リストは[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py)。`auto`を使用すると、モデルは自動的に言語を検出しようとしますが、英語に偏る傾向があります。
|
||||
- `--backend`? `simulstreaming`が正しく動作しない場合や、デュアルライセンス要件を避けたい場合は`--backend faster-whisper`に切り替えることができます。
|
||||
- `--warmup-file`、もしあれば
|
||||
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`、サーバーをセットアップする場合
|
||||
- `--diarization`、使用したい場合。
|
||||
|
||||
残りは推奨しません。しかし、以下があなたのオプションです。
|
||||
|
||||
| パラメータ | 説明 | デフォルト |
|
||||
|-----------|-------------|---------|
|
||||
| `--model` | Whisperモデルのサイズ。 | `small` |
|
||||
| `--language` | ソース言語コードまたは`auto` | `auto` |
|
||||
| `--task` | `transcribe`または`translate` | `transcribe` |
|
||||
| `--backend` | 処理バックエンド | `simulstreaming` |
|
||||
| `--min-chunk-size` | 最小音声チャンクサイズ(秒) | `1.0` |
|
||||
| `--no-vac` | 音声アクティビティコントローラーを無効化 | `False` |
|
||||
| `--no-vad` | 音声区間検出を無効化 | `False` |
|
||||
| `--warmup-file` | モデルのウォームアップ用音声ファイルパス | `jfk.wav` |
|
||||
| `--host` | サーバーホストアドレス | `localhost` |
|
||||
| `--port` | サーバーポート | `8000` |
|
||||
| `--ssl-certfile` | SSL証明書ファイルへのパス(HTTPSサポート用) | `None` |
|
||||
| `--ssl-keyfile` | SSL秘密鍵ファイルへのパス(HTTPSサポート用) | `None` |
|
||||
|
||||
|
||||
| WhisperStreamingバックエンドオプション | 説明 | デフォルト |
|
||||
|-----------|-------------|---------|
|
||||
| `--confidence-validation` | 高速な検証のために信頼スコアを使用 | `False` |
|
||||
| `--buffer_trimming` | バッファトリミング戦略(`sentence`または`segment`) | `segment` |
|
||||
|
||||
|
||||
| SimulStreamingバックエンドオプション | 説明 | デフォルト |
|
||||
|-----------|-------------|---------|
|
||||
| `--frame-threshold` | AlignAttフレームしきい値(低いほど速く、高いほど正確) | `25` |
|
||||
| `--beams` | ビームサーチのビーム数(1 = 貪欲デコーディング) | `1` |
|
||||
| `--decoder` | デコーダタイプを強制(`beam`または`greedy`) | `auto` |
|
||||
| `--audio-max-len` | 最大音声バッファ長(秒) | `30.0` |
|
||||
| `--audio-min-len` | 処理する最小音声長(秒) | `0.0` |
|
||||
| `--cif-ckpt-path` | 単語境界検出用CIFモデルへのパス | `None` |
|
||||
| `--never-fire` | 未完了の単語を決して切り捨てない | `False` |
|
||||
| `--init-prompt` | モデルの初期プロンプト | `None` |
|
||||
| `--static-init-prompt` | スクロールしない静的プロンプト | `None` |
|
||||
| `--max-context-tokens` | 最大コンテキストトークン数 | `None` |
|
||||
| `--model-path` | .ptモデルファイルへの直接パス。見つからない場合はダウンロード | `./base.pt` |
|
||||
| `--preloaded-model-count` | オプション。メモリにプリロードするモデルの数(予想される同時ユーザー数まで設定) | `1` |
|
||||
|
||||
| ダイアライゼーションオプション | 説明 | デフォルト |
|
||||
|-----------|-------------|---------|
|
||||
| `--diarization` | 話者識別を有効化 | `False` |
|
||||
| `--diarization-backend` | `diart`または`sortformer` | `sortformer` |
|
||||
| `--segmentation-model` | DiartセグメンテーションモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
||||
| `--embedding-model` | Diart埋め込みモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||
|
||||
|
||||
> Diartを使用したダイアライゼーションには、pyannote.audioモデルへのアクセスが必要です:
|
||||
> 1. `pyannote/segmentation`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation)
|
||||
> 2. `pyannote/segmentation-3.0`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation-3.0)
|
||||
> 3. `pyannote/embedding`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/embedding)
|
||||
>4. HuggingFaceでログイン: `huggingface-cli login`
|
||||
|
||||
### 🚀 デプロイガイド
|
||||
|
||||
WhisperLiveKitを本番環境にデプロイするには:
|
||||
|
||||
1. **サーバーセットアップ**: 本番用ASGIサーバーをインストールし、複数のワーカーで起動します
|
||||
```bash
|
||||
pip install uvicorn gunicorn
|
||||
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
|
||||
```
|
||||
|
||||
2. **フロントエンド**: カスタマイズした`html`のバージョンをホストし、WebSocket接続が正しくポイントするようにします
|
||||
|
||||
3. **Nginx設定** (本番環境で推奨):
|
||||
```nginx
|
||||
server {
|
||||
listen 80;
|
||||
server_name your-domain.com;
|
||||
location / {
|
||||
proxy_pass http://localhost:8000;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "upgrade";
|
||||
proxy_set_header Host $host;
|
||||
}}
|
||||
```
|
||||
|
||||
4. **HTTPSサポート**: 安全なデプロイメントのために、WebSocket URLで "ws://" の代わりに "wss://" を使用します
|
||||
|
||||
## 🐋 Docker
|
||||
|
||||
GPUまたはCPUサポート付きでDockerを使用してアプリケーションを簡単にデプロイします。
|
||||
|
||||
### 前提条件
|
||||
- Dockerがシステムにインストールされていること
|
||||
- GPUサポートの場合: NVIDIA Dockerランタイムがインストールされていること
|
||||
|
||||
### クイックスタート
|
||||
|
||||
**GPUアクセラレーション付き (推奨):**
|
||||
```bash
|
||||
docker build -t wlk .
|
||||
docker run --gpus all -p 8000:8000 --name wlk wlk
|
||||
```
|
||||
|
||||
**CPUのみ:**
|
||||
```bash
|
||||
docker build -f Dockerfile.cpu -t wlk .
|
||||
docker run -p 8000:8000 --name wlk wlk
|
||||
```
|
||||
|
||||
### 高度な使用法
|
||||
|
||||
**カスタム設定:**
|
||||
```bash
|
||||
# カスタムモデルと言語の例
|
||||
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
|
||||
```
|
||||
|
||||
### メモリ要件
|
||||
- **大規模モデル**: Dockerランタイムに十分なメモリが割り当てられていることを確認してください
|
||||
|
||||
|
||||
#### カスタマイズ
|
||||
|
||||
- `--build-arg` オプション:
|
||||
- `EXTRAS="whisper-timestamped"` - イメージのインストールにエクストラを追加します(スペースなし)。必要なコンテナオプションを設定することを忘れないでください!
|
||||
- `HF_PRECACHE_DIR="./.cache/"` - 初回起動を高速化するためにモデルキャッシュをプリロードします
|
||||
- `HF_TKN_FILE="./token"` - ゲート付きモデルをダウンロードするためにHugging Face Hubアクセストークンを追加します
|
||||
|
||||
## 🔮 ユースケース
|
||||
会議の文字起こしのためにリアルタイムで議論をキャプチャする、聴覚障害のあるユーザーがアクセシビリティツールを通じて会話を追うのを助ける、コンテンツ作成のためにポッドキャストやビデオを自動的に文字起こしする、カスタマーサービスのために話者識別付きでサポートコールを文字起こしする...
|
||||
architecture.png (new binary file, 368 KiB)
available_models.md (new file, 73 lines)

# Available model sizes:

- tiny.en (english only)
- tiny
- base.en (english only)
- base
- small.en (english only)
- small
- medium.en (english only)
- medium
- large-v1
- large-v2
- large-v3
- large-v3-turbo

## How to choose?

### Language Support
- **English only**: Use `.en` models for better accuracy and faster processing when you only need English transcription
- **Multilingual**: Do not use `.en` models.

### Resource Constraints
- **Limited GPU/CPU or need for very low latency**: Choose `small` or smaller models
  - `tiny`: Fastest, lowest resource usage, acceptable quality for simple audio
  - `base`: Good balance of speed and accuracy for basic use cases
  - `small`: Better accuracy while still being resource-efficient
- **Good resources available**: Use `large` models for best accuracy
  - `large-v2`: Excellent accuracy, good multilingual support
  - `large-v3`: Best overall accuracy and language support

### Special Cases
- **No translation needed**: Use `large-v3-turbo`
  - Same transcription quality as `large-v2` but significantly faster
  - **Important**: Does not translate correctly, only transcribes

### Model Comparison Table

| Model | Speed | Accuracy | Multilingual | Translation | Best Use Case |
|-------|--------|----------|--------------|-------------|---------------|
| tiny(.en) | Fastest | Basic | Yes/No | Yes/No | Real-time, low resources |
| base(.en) | Fast | Good | Yes/No | Yes/No | Balanced performance |
| small(.en) | Medium | Better | Yes/No | Yes/No | Quality on limited hardware |
| medium(.en) | Slow | High | Yes/No | Yes/No | High quality, moderate resources |
| large-v2 | Slowest | Excellent | Yes | Yes | Best overall quality |
| large-v3 | Slowest | Excellent | Yes | Yes | Maximum accuracy |
| large-v3-turbo | Fast | Excellent | Yes | No | Fast, high-quality transcription |

### Additional Considerations

**Model Performance**:
- Accuracy improves significantly from tiny to large models
- English-only models are ~10-15% more accurate for English audio
- Newer versions (v2, v3) have better punctuation and formatting

**Hardware Requirements**:
- `tiny`: ~1GB VRAM
- `base`: ~1GB VRAM
- `small`: ~2GB VRAM
- `medium`: ~5GB VRAM
- `large`: ~10GB VRAM
- `large-v3-turbo`: ~6GB VRAM

**Audio Quality Impact**:
- Clean, clear audio: smaller models may suffice
- Noisy, accented, or technical audio: larger models recommended
- Phone/low-quality audio: use at least the `small` model

### Quick Decision Tree
1. English only? → Add `.en` to your choice
2. Limited resources or need speed? → `small` or smaller
3. Good hardware and want best quality? → `large-v3`
4. Need fast, high-quality transcription without translation? → `large-v3-turbo`
5. Need translation capabilities? → `large-v2` or `large-v3` (avoid turbo)
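The decision tree can also be written down directly in code. The helper below simply encodes the rules above; the function name and flags are illustrative, not part of WhisperLiveKit:

```python
def pick_model(english_only: bool, limited_resources: bool, need_translation: bool) -> str:
    """Apply the quick decision tree above and return a --model value."""
    if limited_resources:
        base = "small"                # or smaller: base / tiny
    elif need_translation:
        base = "large-v3"             # large-v2 also works; avoid the turbo variant
    else:
        base = "large-v3-turbo"       # fast, high quality, transcription only
    if english_only and base in {"tiny", "base", "small", "medium"}:
        base += ".en"                 # .en variants exist only up to medium
    return base

print(pick_model(english_only=True, limited_resources=True, need_translation=False))  # small.en
```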
chrome-extension/README.md (new file, 17 lines)

## WhisperLiveKit Chrome Extension v0.1.0
Capture the audio of your current tab and transcribe or translate it using WhisperLiveKit. **Still unstable.**

<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">

## Running this extension
1. Clone this repository.
2. Load this directory in Chrome as an unpacked extension.

## Dev notes
- Capturing tab audio is unfortunately impossible if the extension runs as a panel:
  - https://issues.chromium.org/issues/40926394
  - https://groups.google.com/a/chromium.org/g/chromium-extensions/c/DET2SXCFnDg
  - https://issues.chromium.org/issues/40916430
- To capture the microphone in an extension, there are workarounds: https://github.com/justinmann/sidepanel-audio-issue and https://medium.com/@lynchee.owo/how-to-enable-microphone-access-in-chrome-extensions-by-code-924295170080 (see the comments)
chrome-extension/background.js (new file, 9 lines)

```js
chrome.runtime.onInstalled.addListener((details) => {
    if (details.reason.search(/install/g) === -1) {
        return
    }
    chrome.tabs.create({
        url: chrome.runtime.getURL("welcome.html"),
        active: true
    })
})
```
chrome-extension/demo-extension.png (new binary file, 1.2 MiB)
chrome-extension/icons/icon128.png (new binary file, 5.8 KiB)
chrome-extension/icons/icon16.png (new binary file, 376 B)
chrome-extension/icons/icon32.png (new binary file, 823 B)
chrome-extension/icons/icon48.png (new binary file, 1.4 KiB)
chrome-extension/live_transcription.js (new file, 669 lines)
|
||||
/* Theme, WebSocket, recording, rendering logic extracted from inline script and adapted for segmented theme control and WS caption */
|
||||
let isRecording = false;
|
||||
let websocket = null;
|
||||
let recorder = null;
|
||||
let chunkDuration = 100;
|
||||
let websocketUrl = "ws://localhost:8000/asr";
|
||||
let userClosing = false;
|
||||
let wakeLock = null;
|
||||
let startTime = null;
|
||||
let timerInterval = null;
|
||||
let audioContext = null;
|
||||
let analyser = null;
|
||||
let microphone = null;
|
||||
let waveCanvas = document.getElementById("waveCanvas");
|
||||
let waveCtx = waveCanvas.getContext("2d");
|
||||
let animationFrame = null;
|
||||
let waitingForStop = false;
|
||||
let lastReceivedData = null;
|
||||
let lastSignature = null;
|
||||
let availableMicrophones = [];
|
||||
let selectedMicrophoneId = null;
|
||||
|
||||
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
|
||||
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
|
||||
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
|
||||
|
||||
const statusText = document.getElementById("status");
|
||||
const recordButton = document.getElementById("recordButton");
|
||||
const chunkSelector = document.getElementById("chunkSelector");
|
||||
const websocketInput = document.getElementById("websocketInput");
|
||||
const websocketDefaultSpan = document.getElementById("wsDefaultUrl");
|
||||
const linesTranscriptDiv = document.getElementById("linesTranscript");
|
||||
const timerElement = document.querySelector(".timer");
|
||||
const themeRadios = document.querySelectorAll('input[name="theme"]');
|
||||
const microphoneSelect = document.getElementById("microphoneSelect");
|
||||
const settingsToggle = document.getElementById("settingsToggle");
|
||||
const settingsDiv = document.querySelector(".settings");
|
||||
|
||||
|
||||
|
||||
chrome.runtime.onInstalled.addListener((details) => {
|
||||
if (details.reason.search(/install/g) === -1) {
|
||||
return
|
||||
}
|
||||
chrome.tabs.create({
|
||||
url: chrome.runtime.getURL("welcome.html"),
|
||||
active: true
|
||||
})
|
||||
})
|
||||
|
||||
function getWaveStroke() {
|
||||
const styles = getComputedStyle(document.documentElement);
|
||||
const v = styles.getPropertyValue("--wave-stroke").trim();
|
||||
return v || "#000";
|
||||
}
|
||||
|
||||
let waveStroke = getWaveStroke();
|
||||
function updateWaveStroke() {
|
||||
waveStroke = getWaveStroke();
|
||||
}
|
||||
|
||||
function applyTheme(pref) {
|
||||
if (pref === "light") {
|
||||
document.documentElement.setAttribute("data-theme", "light");
|
||||
} else if (pref === "dark") {
|
||||
document.documentElement.setAttribute("data-theme", "dark");
|
||||
} else {
|
||||
document.documentElement.removeAttribute("data-theme");
|
||||
}
|
||||
updateWaveStroke();
|
||||
}
|
||||
|
||||
// Persisted theme preference
|
||||
const savedThemePref = localStorage.getItem("themePreference") || "system";
|
||||
applyTheme(savedThemePref);
|
||||
if (themeRadios.length) {
|
||||
themeRadios.forEach((r) => {
|
||||
r.checked = r.value === savedThemePref;
|
||||
r.addEventListener("change", () => {
|
||||
if (r.checked) {
|
||||
localStorage.setItem("themePreference", r.value);
|
||||
applyTheme(r.value);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// React to OS theme changes when in "system" mode
|
||||
const darkMq = window.matchMedia && window.matchMedia("(prefers-color-scheme: dark)");
|
||||
const handleOsThemeChange = () => {
|
||||
const pref = localStorage.getItem("themePreference") || "system";
|
||||
if (pref === "system") updateWaveStroke();
|
||||
};
|
||||
if (darkMq && darkMq.addEventListener) {
|
||||
darkMq.addEventListener("change", handleOsThemeChange);
|
||||
} else if (darkMq && darkMq.addListener) {
|
||||
// deprecated, but included for Safari compatibility
|
||||
darkMq.addListener(handleOsThemeChange);
|
||||
}
|
||||
|
||||
async function enumerateMicrophones() {
|
||||
try {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
|
||||
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||
stream.getTracks().forEach(track => track.stop());
|
||||
|
||||
const devices = await navigator.mediaDevices.enumerateDevices();
|
||||
availableMicrophones = devices.filter(device => device.kind === 'audioinput');
|
||||
|
||||
populateMicrophoneSelect();
|
||||
console.log(`Found ${availableMicrophones.length} microphone(s)`);
|
||||
} catch (error) {
|
||||
console.error('Error enumerating microphones:', error);
|
||||
statusText.textContent = "Error accessing microphones. Please grant permission.";
|
||||
}
|
||||
}
|
||||
|
||||
function populateMicrophoneSelect() {
|
||||
if (!microphoneSelect) return;
|
||||
|
||||
microphoneSelect.innerHTML = '<option value="">Default Microphone</option>';
|
||||
|
||||
availableMicrophones.forEach((device, index) => {
|
||||
const option = document.createElement('option');
|
||||
option.value = device.deviceId;
|
||||
option.textContent = device.label || `Microphone ${index + 1}`;
|
||||
microphoneSelect.appendChild(option);
|
||||
});
|
||||
|
||||
const savedMicId = localStorage.getItem('selectedMicrophone');
|
||||
if (savedMicId && availableMicrophones.some(mic => mic.deviceId === savedMicId)) {
|
||||
microphoneSelect.value = savedMicId;
|
||||
selectedMicrophoneId = savedMicId;
|
||||
}
|
||||
}
|
||||
|
||||
function handleMicrophoneChange() {
|
||||
selectedMicrophoneId = microphoneSelect.value || null;
|
||||
localStorage.setItem('selectedMicrophone', selectedMicrophoneId || '');
|
||||
|
||||
const selectedDevice = availableMicrophones.find(mic => mic.deviceId === selectedMicrophoneId);
|
||||
const deviceName = selectedDevice ? selectedDevice.label : 'Default Microphone';
|
||||
|
||||
console.log(`Selected microphone: ${deviceName}`);
|
||||
statusText.textContent = `Microphone changed to: ${deviceName}`;
|
||||
|
||||
if (isRecording) {
|
||||
statusText.textContent = "Switching microphone... Please wait.";
|
||||
stopRecording().then(() => {
|
||||
setTimeout(() => {
|
||||
toggleRecording();
|
||||
}, 1000);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Helpers
|
||||
function fmt1(x) {
|
||||
const n = Number(x);
|
||||
return Number.isFinite(n) ? n.toFixed(1) : x;
|
||||
}
|
||||
|
||||
// Default WebSocket URL computation
|
||||
const host = window.location.hostname || "localhost";
|
||||
const port = window.location.port;
|
||||
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
|
||||
const defaultWebSocketUrl = websocketUrl;
|
||||
|
||||
// Populate default caption and input
|
||||
if (websocketDefaultSpan) websocketDefaultSpan.textContent = defaultWebSocketUrl;
|
||||
websocketInput.value = defaultWebSocketUrl;
|
||||
websocketUrl = defaultWebSocketUrl;
|
||||
|
||||
// Optional chunk selector (guard for presence)
|
||||
if (chunkSelector) {
|
||||
chunkSelector.addEventListener("change", () => {
|
||||
chunkDuration = parseInt(chunkSelector.value);
|
||||
});
|
||||
}
|
||||
|
||||
// WebSocket input change handling
|
||||
websocketInput.addEventListener("change", () => {
|
||||
const urlValue = websocketInput.value.trim();
|
||||
if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
|
||||
statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
|
||||
return;
|
||||
}
|
||||
websocketUrl = urlValue;
|
||||
statusText.textContent = "WebSocket URL updated. Ready to connect.";
|
||||
});
|
||||
|
||||
function setupWebSocket() {
|
||||
return new Promise((resolve, reject) => {
|
||||
try {
|
||||
websocket = new WebSocket(websocketUrl);
|
||||
} catch (error) {
|
||||
statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
|
||||
reject(error);
|
||||
return;
|
||||
}
|
||||
|
||||
websocket.onopen = () => {
|
||||
statusText.textContent = "Connected to server.";
|
||||
resolve();
|
||||
};
|
||||
|
||||
websocket.onclose = () => {
|
||||
if (userClosing) {
|
||||
if (waitingForStop) {
|
||||
statusText.textContent = "Processing finalized or connection closed.";
|
||||
if (lastReceivedData) {
|
||||
renderLinesWithBuffer(
|
||||
lastReceivedData.lines || [],
|
||||
lastReceivedData.buffer_diarization || "",
|
||||
lastReceivedData.buffer_transcription || "",
|
||||
0,
|
||||
0,
|
||||
true
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
statusText.textContent = "Disconnected from the WebSocket server. (Check logs if model is loading.)";
|
||||
if (isRecording) {
|
||||
stopRecording();
|
||||
}
|
||||
}
|
||||
isRecording = false;
|
||||
waitingForStop = false;
|
||||
userClosing = false;
|
||||
lastReceivedData = null;
|
||||
websocket = null;
|
||||
updateUI();
|
||||
};
|
||||
|
||||
websocket.onerror = () => {
|
||||
statusText.textContent = "Error connecting to WebSocket.";
|
||||
reject(new Error("Error connecting to WebSocket"));
|
||||
};
|
||||
|
||||
websocket.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
|
||||
if (data.type === "ready_to_stop") {
|
||||
console.log("Ready to stop received, finalizing display and closing WebSocket.");
|
||||
waitingForStop = false;
|
||||
|
||||
if (lastReceivedData) {
|
||||
renderLinesWithBuffer(
|
||||
lastReceivedData.lines || [],
|
||||
lastReceivedData.buffer_diarization || "",
|
||||
lastReceivedData.buffer_transcription || "",
|
||||
0,
|
||||
0,
|
||||
true
|
||||
);
|
||||
}
|
||||
statusText.textContent = "Finished processing audio! Ready to record again.";
|
||||
recordButton.disabled = false;
|
||||
|
||||
if (websocket) {
|
||||
websocket.close();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
lastReceivedData = data;
|
||||
|
||||
const {
|
||||
lines = [],
|
||||
buffer_transcription = "",
|
||||
buffer_diarization = "",
|
||||
remaining_time_transcription = 0,
|
||||
remaining_time_diarization = 0,
|
||||
status = "active_transcription",
|
||||
} = data;
|
||||
|
||||
renderLinesWithBuffer(
|
||||
lines,
|
||||
buffer_diarization,
|
||||
buffer_transcription,
|
||||
remaining_time_diarization,
|
||||
remaining_time_transcription,
|
||||
false,
|
||||
status
|
||||
);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
function renderLinesWithBuffer(
|
||||
lines,
|
||||
buffer_diarization,
|
||||
buffer_transcription,
|
||||
remaining_time_diarization,
|
||||
remaining_time_transcription,
|
||||
isFinalizing = false,
|
||||
current_status = "active_transcription"
|
||||
) {
|
||||
if (current_status === "no_audio_detected") {
|
||||
linesTranscriptDiv.innerHTML =
|
||||
"<p style='text-align: center; color: var(--muted); margin-top: 20px;'><em>No audio detected...</em></p>";
|
||||
return;
|
||||
}
|
||||
|
||||
const showLoading = !isFinalizing && (lines || []).some((it) => it.speaker == 0);
|
||||
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
|
||||
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
|
||||
const signature = JSON.stringify({
|
||||
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end })),
|
||||
buffer_transcription: buffer_transcription || "",
|
||||
buffer_diarization: buffer_diarization || "",
|
||||
status: current_status,
|
||||
showLoading,
|
||||
showTransLag,
|
||||
showDiaLag,
|
||||
isFinalizing: !!isFinalizing,
|
||||
});
|
||||
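// If the rendered content is unchanged, skip the full re-render and only refresh the lag/loading counters
|
||||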
if (lastSignature === signature) {
|
||||
const t = document.querySelector(".lag-transcription-value");
|
||||
if (t) t.textContent = fmt1(remaining_time_transcription);
|
||||
const d = document.querySelector(".lag-diarization-value");
|
||||
if (d) d.textContent = fmt1(remaining_time_diarization);
|
||||
const ld = document.querySelector(".loading-diarization-value");
|
||||
if (ld) ld.textContent = fmt1(remaining_time_diarization);
|
||||
return;
|
||||
}
|
||||
lastSignature = signature;
|
||||
|
||||
const linesHtml = (lines || [])
|
||||
.map((item, idx) => {
|
||||
let timeInfo = "";
|
||||
if (item.start !== undefined && item.end !== undefined) {
|
||||
timeInfo = ` ${item.start} - ${item.end}`;
|
||||
}
|
||||
|
||||
let speakerLabel = "";
|
||||
if (item.speaker === -2) {
|
||||
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
} else if (item.speaker == 0 && !isFinalizing) {
|
||||
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
|
||||
remaining_time_diarization
|
||||
)}</span> second(s) of audio are undergoing diarization</span></span>`;
|
||||
} else if (item.speaker !== 0) {
|
||||
speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
}
|
||||
|
||||
let currentLineText = item.text || "";
|
||||
|
||||
if (idx === lines.length - 1) {
|
||||
if (!isFinalizing && item.speaker !== -2) {
|
||||
if (remaining_time_transcription > 0) {
|
||||
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(
|
||||
remaining_time_transcription
|
||||
)}</span>s</span></span>`;
|
||||
}
|
||||
if (buffer_diarization && remaining_time_diarization > 0) {
|
||||
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(
|
||||
remaining_time_diarization
|
||||
)}</span>s</span></span>`;
|
||||
}
|
||||
}
|
||||
|
||||
if (buffer_diarization) {
|
||||
if (isFinalizing) {
|
||||
currentLineText +=
|
||||
(currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
|
||||
} else {
|
||||
currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
|
||||
}
|
||||
}
|
||||
if (buffer_transcription) {
|
||||
if (isFinalizing) {
|
||||
currentLineText +=
|
||||
(currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") +
|
||||
buffer_transcription.trim();
|
||||
} else {
|
||||
currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return currentLineText.trim().length > 0 || speakerLabel.length > 0
|
||||
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
|
||||
: `<p>${speakerLabel}<br/></p>`;
|
||||
})
|
||||
.join("");
|
||||
|
||||
linesTranscriptDiv.innerHTML = linesHtml;
|
||||
window.scrollTo({ top: document.body.scrollHeight, behavior: "smooth" });
|
||||
}
|
||||
|
||||
function updateTimer() {
|
||||
if (!startTime) return;
|
||||
|
||||
const elapsed = Math.floor((Date.now() - startTime) / 1000);
|
||||
const minutes = Math.floor(elapsed / 60).toString().padStart(2, "0");
|
||||
const seconds = (elapsed % 60).toString().padStart(2, "0");
|
||||
timerElement.textContent = `${minutes}:${seconds}`;
|
||||
}
|
||||
|
||||
function drawWaveform() {
|
||||
if (!analyser) return;
|
||||
|
||||
const bufferLength = analyser.frequencyBinCount;
|
||||
const dataArray = new Uint8Array(bufferLength);
|
||||
analyser.getByteTimeDomainData(dataArray);
|
||||
|
||||
waveCtx.clearRect(
|
||||
0,
|
||||
0,
|
||||
waveCanvas.width / (window.devicePixelRatio || 1),
|
||||
waveCanvas.height / (window.devicePixelRatio || 1)
|
||||
);
|
||||
waveCtx.lineWidth = 1;
|
||||
waveCtx.strokeStyle = waveStroke;
|
||||
waveCtx.beginPath();
|
||||
|
||||
const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength;
|
||||
let x = 0;
|
||||
|
||||
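// Plot each time-domain sample across the canvas width (values near 128 map to the vertical midline)
|
||||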
for (let i = 0; i < bufferLength; i++) {
|
||||
const v = dataArray[i] / 128.0;
|
||||
const y = (v * (waveCanvas.height / (window.devicePixelRatio || 1))) / 2;
|
||||
|
||||
if (i === 0) {
|
||||
waveCtx.moveTo(x, y);
|
||||
} else {
|
||||
waveCtx.lineTo(x, y);
|
||||
}
|
||||
|
||||
x += sliceWidth;
|
||||
}
|
||||
|
||||
waveCtx.lineTo(
|
||||
waveCanvas.width / (window.devicePixelRatio || 1),
|
||||
(waveCanvas.height / (window.devicePixelRatio || 1)) / 2
|
||||
);
|
||||
waveCtx.stroke();
|
||||
|
||||
animationFrame = requestAnimationFrame(drawWaveform);
|
||||
}
|
||||
|
||||
async function startRecording() {
|
||||
try {
|
||||
try {
|
||||
wakeLock = await navigator.wakeLock.request("screen");
|
||||
} catch (err) {
|
||||
console.log("Error acquiring wake lock.");
|
||||
}
|
||||
|
||||
let stream;
|
||||
try {
|
||||
// Try tab capture first
|
||||
stream = await new Promise((resolve, reject) => {
|
||||
chrome.tabCapture.capture({audio: true}, (s) => {
|
||||
if (s) {
|
||||
resolve(s);
|
||||
} else {
|
||||
reject(new Error('Tab capture failed or not available'));
|
||||
}
|
||||
});
|
||||
});
|
||||
statusText.textContent = "Using tab audio capture.";
|
||||
} catch (tabError) {
|
||||
console.log('Tab capture not available, falling back to microphone', tabError);
|
||||
// Fallback to microphone
|
||||
const audioConstraints = selectedMicrophoneId
|
||||
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
|
||||
: { audio: true };
|
||||
stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
|
||||
statusText.textContent = "Using microphone audio.";
|
||||
}
|
||||
|
||||
audioContext = new (window.AudioContext || window.webkitAudioContext)();
|
||||
analyser = audioContext.createAnalyser();
|
||||
analyser.fftSize = 256;
|
||||
microphone = audioContext.createMediaStreamSource(stream);
|
||||
microphone.connect(analyser);
|
||||
|
||||
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
|
||||
recorder.ondataavailable = (e) => {
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
websocket.send(e.data);
|
||||
}
|
||||
};
|
||||
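// Emit an audio chunk every chunkDuration milliseconds; each chunk is forwarded over the WebSocket
|
||||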
recorder.start(chunkDuration);
|
||||
|
||||
startTime = Date.now();
|
||||
timerInterval = setInterval(updateTimer, 1000);
|
||||
drawWaveform();
|
||||
|
||||
isRecording = true;
|
||||
updateUI();
|
||||
} catch (err) {
|
||||
if (window.location.hostname === "0.0.0.0") {
|
||||
statusText.textContent =
|
||||
"Error accessing audio input. Browsers may block audio access on 0.0.0.0. Try using localhost:8000 instead.";
|
||||
} else {
|
||||
statusText.textContent = "Error accessing audio input. Please check permissions.";
|
||||
}
|
||||
console.error(err);
|
||||
}
|
||||
}
|
||||
|
||||
async function stopRecording() {
|
||||
if (wakeLock) {
|
||||
try {
|
||||
await wakeLock.release();
|
||||
} catch (e) {
|
||||
// ignore
|
||||
}
|
||||
wakeLock = null;
|
||||
}
|
||||
|
||||
userClosing = true;
|
||||
waitingForStop = true;
|
||||
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
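// Send an empty blob to signal the end of the audio stream so the server can finalize processing
|
||||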
const emptyBlob = new Blob([], { type: "audio/webm" });
|
||||
websocket.send(emptyBlob);
|
||||
statusText.textContent = "Recording stopped. Processing final audio...";
|
||||
}
|
||||
|
||||
if (recorder) {
|
||||
recorder.stop();
|
||||
recorder = null;
|
||||
}
|
||||
|
||||
if (microphone) {
|
||||
microphone.disconnect();
|
||||
microphone = null;
|
||||
}
|
||||
|
||||
if (analyser) {
|
||||
analyser = null;
|
||||
}
|
||||
|
||||
if (audioContext && audioContext.state !== "closed") {
|
||||
try {
|
||||
await audioContext.close();
|
||||
} catch (e) {
|
||||
console.warn("Could not close audio context:", e);
|
||||
}
|
||||
audioContext = null;
|
||||
}
|
||||
|
||||
if (animationFrame) {
|
||||
cancelAnimationFrame(animationFrame);
|
||||
animationFrame = null;
|
||||
}
|
||||
|
||||
if (timerInterval) {
|
||||
clearInterval(timerInterval);
|
||||
timerInterval = null;
|
||||
}
|
||||
timerElement.textContent = "00:00";
|
||||
startTime = null;
|
||||
|
||||
isRecording = false;
|
||||
updateUI();
|
||||
}
|
||||
|
||||
async function toggleRecording() {
|
||||
if (!isRecording) {
|
||||
if (waitingForStop) {
|
||||
console.log("Waiting for stop, early return");
|
||||
return;
|
||||
}
|
||||
console.log("Connecting to WebSocket");
|
||||
try {
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
await startRecording();
|
||||
} else {
|
||||
await setupWebSocket();
|
||||
await startRecording();
|
||||
}
|
||||
} catch (err) {
|
||||
statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
|
||||
console.error(err);
|
||||
}
|
||||
} else {
|
||||
console.log("Stopping recording");
|
||||
stopRecording();
|
||||
}
|
||||
}
|
||||
|
||||
function updateUI() {
|
||||
recordButton.classList.toggle("recording", isRecording);
|
||||
recordButton.disabled = waitingForStop;
|
||||
|
||||
if (waitingForStop) {
|
||||
if (statusText.textContent !== "Recording stopped. Processing final audio...") {
|
||||
statusText.textContent = "Please wait for processing to complete...";
|
||||
}
|
||||
} else if (isRecording) {
|
||||
statusText.textContent = "Recording...";
|
||||
} else {
|
||||
if (
|
||||
statusText.textContent !== "Finished processing audio! Ready to record again." &&
|
||||
statusText.textContent !== "Processing finalized or connection closed."
|
||||
) {
|
||||
statusText.textContent = "Click to start transcription";
|
||||
}
|
||||
}
|
||||
if (!waitingForStop) {
|
||||
recordButton.disabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
recordButton.addEventListener("click", toggleRecording);
|
||||
|
||||
if (microphoneSelect) {
|
||||
microphoneSelect.addEventListener("change", handleMicrophoneChange);
|
||||
}
|
||||
|
||||
// Settings toggle functionality
|
||||
settingsToggle.addEventListener("click", () => {
|
||||
settingsDiv.classList.toggle("visible");
|
||||
settingsToggle.classList.toggle("active");
|
||||
});
|
||||
|
||||
document.addEventListener('DOMContentLoaded', async () => {
|
||||
try {
|
||||
await enumerateMicrophones();
|
||||
} catch (error) {
|
||||
console.log("Could not enumerate microphones on load:", error);
|
||||
}
|
||||
});
|
||||
navigator.mediaDevices.addEventListener('devicechange', async () => {
|
||||
console.log('Device change detected, re-enumerating microphones');
|
||||
try {
|
||||
await enumerateMicrophones();
|
||||
} catch (error) {
|
||||
console.log("Error re-enumerating microphones:", error);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
async function run() {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
|
||||
document.getElementById(
|
||||
"audioPermission"
|
||||
).innerText = `MICROPHONE: ${micPermission.state}`;
|
||||
|
||||
if (micPermission.state !== "granted") {
|
||||
chrome.tabs.create({ url: "welcome.html" });
|
||||
}
|
||||
|
||||
const intervalId = setInterval(async () => {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
if (micPermission.state === "granted") {
|
||||
document.getElementById(
|
||||
"audioPermission"
|
||||
).innerText = `MICROPHONE: ${micPermission.state}`;
|
||||
clearInterval(intervalId);
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
|
||||
void run();
|
||||
37
chrome-extension/manifest.json
Normal file
@@ -0,0 +1,37 @@
|
||||
{
|
||||
"manifest_version": 3,
|
||||
"name": "WhisperLiveKit Tab Capture",
|
||||
"version": "1.0",
|
||||
"description": "Capture and transcribe audio from browser tabs using WhisperLiveKit.",
|
||||
"background": {
|
||||
"service_worker": "background.js"
|
||||
},
|
||||
"icons": {
|
||||
"16": "icons/icon16.png",
|
||||
"32": "icons/icon32.png",
|
||||
"48": "icons/icon48.png",
|
||||
"128": "icons/icon128.png"
|
||||
},
|
||||
"action": {
|
||||
"default_title": "WhisperLiveKit Tab Capture",
|
||||
"default_popup": "popup.html"
|
||||
},
|
||||
"permissions": [
|
||||
"scripting",
|
||||
"tabCapture",
|
||||
"offscreen",
|
||||
"activeTab",
|
||||
"storage"
|
||||
],
|
||||
"web_accessible_resources": [
|
||||
{
|
||||
"resources": [
|
||||
"requestPermissions.html",
|
||||
"requestPermissions.js"
|
||||
],
|
||||
"matches": [
|
||||
"<all_urls>"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
78
chrome-extension/popup.html
Normal file
@@ -0,0 +1,78 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>WhisperLiveKit</title>
|
||||
<link rel="stylesheet" href="/web/live_transcription.css" />
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div class="settings-container">
|
||||
<button id="recordButton">
|
||||
<div class="shape-container">
|
||||
<div class="shape"></div>
|
||||
</div>
|
||||
<div class="recording-info">
|
||||
<div class="wave-container">
|
||||
<canvas id="waveCanvas"></canvas>
|
||||
</div>
|
||||
<div class="timer">00:00</div>
|
||||
</div>
|
||||
</button>
|
||||
|
||||
<button id="settingsToggle" class="settings-toggle" title="Show/hide settings">
|
||||
<img src="/web/src/settings.svg" alt="Settings" />
|
||||
</button>
|
||||
|
||||
<div class="settings">
|
||||
<div class="field">
|
||||
<label for="websocketInput">Websocket URL</label>
|
||||
<input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
|
||||
</div>
|
||||
|
||||
<div class="field">
|
||||
<label id="microphoneSelectLabel" for="microphoneSelect">Select Microphone</label>
|
||||
<select id="microphoneSelect">
|
||||
<option value="">Default Microphone</option>
|
||||
</select>
|
||||
<div id="audioPermission"></div>
|
||||
|
||||
</div>
|
||||
|
||||
<div class="theme-selector-container">
|
||||
<div class="segmented" role="radiogroup" aria-label="Theme selector">
|
||||
<input type="radio" id="theme-system" name="theme" value="system" />
|
||||
<label for="theme-system" title="System">
|
||||
<img src="/web/src/system_mode.svg" alt="" />
|
||||
<!-- <span>System</span> -->
|
||||
</label>
|
||||
|
||||
<input type="radio" id="theme-light" name="theme" value="light" />
|
||||
<label for="theme-light" title="Light">
|
||||
<img src="/web/src/light_mode.svg" alt="" />
|
||||
<!-- <span>Light</span> -->
|
||||
</label>
|
||||
|
||||
<input type="radio" id="theme-dark" name="theme" value="dark" />
|
||||
<label for="theme-dark" title="Dark">
|
||||
<img src="/web/src/dark_mode.svg" alt="" />
|
||||
<!-- <span>Dark</span> -->
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
<p id="status"></p>
|
||||
|
||||
<div id="linesTranscript"></div>
|
||||
|
||||
<script src="live_transcription.js"></script>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
12
chrome-extension/requestPermissions.html
Normal file
@@ -0,0 +1,12 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Request Permissions</title>
|
||||
<script src="requestPermissions.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
This page exists to work around an issue with Chrome that blocks permission
|
||||
requests from Chrome extensions.
|
||||
<button id="requestMicrophone">Request Microphone</button>
|
||||
</body>
|
||||
</html>
|
||||
17
chrome-extension/requestPermissions.js
Normal file
@@ -0,0 +1,17 @@
|
||||
/**
|
||||
* Requests user permission for microphone access.
|
||||
* @returns {Promise<void>} A Promise that resolves when permission is granted or rejects with an error.
|
||||
*/
|
||||
async function getUserPermission() {
|
||||
console.log("Getting user permission for microphone access...");
|
||||
await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
if (micPermission.state == "granted") {
|
||||
window.close();
|
||||
}
|
||||
}
|
||||
|
||||
// Call the function to request microphone permission
|
||||
getUserPermission();
|
||||
29
chrome-extension/sidepanel.js
Normal file
@@ -0,0 +1,29 @@
|
||||
console.log("sidepanel.js");
|
||||
|
||||
async function run() {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
|
||||
document.getElementById(
|
||||
"audioPermission"
|
||||
).innerText = `MICROPHONE: ${micPermission.state}`;
|
||||
|
||||
if (micPermission.state !== "granted") {
|
||||
chrome.tabs.create({ url: "requestPermissions.html" });
|
||||
}
|
||||
|
||||
const intervalId = setInterval(async () => {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
if (micPermission.state === "granted") {
|
||||
document.getElementById(
|
||||
"audioPermission"
|
||||
).innerText = `MICROPHONE: ${micPermission.state}`;
|
||||
clearInterval(intervalId);
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
|
||||
void run();
|
||||
539
chrome-extension/web/live_transcription.css
Normal file
@@ -0,0 +1,539 @@
|
||||
:root {
|
||||
--bg: #ffffff;
|
||||
--text: #111111;
|
||||
--muted: #666666;
|
||||
--border: #e5e5e5;
|
||||
--chip-bg: rgba(0, 0, 0, 0.04);
|
||||
--chip-text: #000000;
|
||||
--spinner-border: #8d8d8d5c;
|
||||
--spinner-top: #b0b0b0;
|
||||
--silence-bg: #f3f3f3;
|
||||
--loading-bg: rgba(255, 77, 77, 0.06);
|
||||
--button-bg: #ffffff;
|
||||
--button-border: #e9e9e9;
|
||||
--wave-stroke: #000000;
|
||||
--label-dia-text: #868686;
|
||||
--label-trans-text: #111111;
|
||||
}
|
||||
|
||||
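/* Follow the OS dark scheme unless the user has explicitly selected the light theme */
|
||||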
@media (prefers-color-scheme: dark) {
|
||||
:root:not([data-theme="light"]) {
|
||||
--bg: #0b0b0b;
|
||||
--text: #e6e6e6;
|
||||
--muted: #9aa0a6;
|
||||
--border: #333333;
|
||||
--chip-bg: rgba(255, 255, 255, 0.08);
|
||||
--chip-text: #e6e6e6;
|
||||
--spinner-border: #555555;
|
||||
--spinner-top: #dddddd;
|
||||
--silence-bg: #1a1a1a;
|
||||
--loading-bg: rgba(255, 77, 77, 0.12);
|
||||
--button-bg: #111111;
|
||||
--button-border: #333333;
|
||||
--wave-stroke: #e6e6e6;
|
||||
--label-dia-text: #b3b3b3;
|
||||
--label-trans-text: #ffffff;
|
||||
}
|
||||
}
|
||||
|
||||
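/* Explicit dark theme selected by the user overrides the OS preference */
|
||||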
:root[data-theme="dark"] {
|
||||
--bg: #0b0b0b;
|
||||
--text: #e6e6e6;
|
||||
--muted: #9aa0a6;
|
||||
--border: #333333;
|
||||
--chip-bg: rgba(255, 255, 255, 0.08);
|
||||
--chip-text: #e6e6e6;
|
||||
--spinner-border: #555555;
|
||||
--spinner-top: #dddddd;
|
||||
--silence-bg: #1a1a1a;
|
||||
--loading-bg: rgba(255, 77, 77, 0.12);
|
||||
--button-bg: #111111;
|
||||
--button-border: #333333;
|
||||
--wave-stroke: #e6e6e6;
|
||||
--label-dia-text: #b3b3b3;
|
||||
--label-trans-text: #ffffff;
|
||||
}
|
||||
|
||||
:root[data-theme="light"] {
|
||||
--bg: #ffffff;
|
||||
--text: #111111;
|
||||
--muted: #666666;
|
||||
--border: #e5e5e5;
|
||||
--chip-bg: rgba(0, 0, 0, 0.04);
|
||||
--chip-text: #000000;
|
||||
--spinner-border: #8d8d8d5c;
|
||||
--spinner-top: #b0b0b0;
|
||||
--silence-bg: #f3f3f3;
|
||||
--loading-bg: rgba(255, 77, 77, 0.06);
|
||||
--button-bg: #ffffff;
|
||||
--button-border: #e9e9e9;
|
||||
--wave-stroke: #000000;
|
||||
--label-dia-text: #868686;
|
||||
--label-trans-text: #111111;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
|
||||
margin: 20px;
|
||||
text-align: center;
|
||||
background-color: var(--bg);
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.settings-toggle {
|
||||
margin-top: 4px;
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
border: none;
|
||||
border-radius: 50%;
|
||||
background-color: var(--button-bg);
|
||||
cursor: pointer;
|
||||
transition: all 0.3s ease;
|
||||
/* border: 1px solid var(--button-border); */
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.settings-toggle:hover {
|
||||
background-color: var(--chip-bg);
|
||||
}
|
||||
|
||||
.settings-toggle img {
|
||||
width: 24px;
|
||||
height: 24px;
|
||||
opacity: 0.7;
|
||||
transition: opacity 0.2s ease, transform 0.3s ease;
|
||||
}
|
||||
|
||||
.settings-toggle:hover img {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.settings-toggle.active img {
|
||||
transform: rotate(80deg);
|
||||
}
|
||||
|
||||
/* Record button */
|
||||
#recordButton {
|
||||
width: 50px;
|
||||
height: 50px;
|
||||
border: none;
|
||||
border-radius: 50%;
|
||||
background-color: var(--button-bg);
|
||||
cursor: pointer;
|
||||
transition: all 0.3s ease;
|
||||
border: 1px solid var(--button-border);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
#recordButton.recording {
|
||||
width: 180px;
|
||||
border-radius: 40px;
|
||||
justify-content: flex-start;
|
||||
padding-left: 20px;
|
||||
}
|
||||
|
||||
#recordButton:active {
|
||||
transform: scale(0.95);
|
||||
}
|
||||
|
||||
.shape-container {
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.shape {
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
background-color: rgb(209, 61, 53);
|
||||
border-radius: 50%;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
#recordButton:disabled .shape {
|
||||
background-color: #6e6d6d;
|
||||
}
|
||||
|
||||
#recordButton.recording .shape {
|
||||
border-radius: 5px;
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
}
|
||||
|
||||
/* Recording elements */
|
||||
.recording-info {
|
||||
display: none;
|
||||
align-items: center;
|
||||
margin-left: 15px;
|
||||
flex-grow: 1;
|
||||
}
|
||||
|
||||
#recordButton.recording .recording-info {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.wave-container {
|
||||
width: 60px;
|
||||
height: 30px;
|
||||
position: relative;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
#waveCanvas {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
.timer {
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
color: var(--text);
|
||||
margin-left: 10px;
|
||||
}
|
||||
|
||||
#status {
|
||||
margin-top: 20px;
|
||||
font-size: 16px;
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
/* Settings */
|
||||
.settings-container {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: flex-start;
|
||||
gap: 15px;
|
||||
margin-top: 20px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.settings {
|
||||
display: none;
|
||||
flex-wrap: wrap;
|
||||
align-items: flex-start;
|
||||
gap: 12px;
|
||||
transition: opacity 0.3s ease;
|
||||
}
|
||||
|
||||
.settings.visible {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.field {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: flex-start;
|
||||
gap: 3px;
|
||||
}
|
||||
|
||||
#chunkSelector,
|
||||
#websocketInput,
|
||||
#themeSelector,
|
||||
#microphoneSelect {
|
||||
font-size: 16px;
|
||||
padding: 5px 8px;
|
||||
border-radius: 8px;
|
||||
border: 1px solid var(--border);
|
||||
background-color: var(--button-bg);
|
||||
color: var(--text);
|
||||
max-height: 30px;
|
||||
}
|
||||
|
||||
#microphoneSelect {
|
||||
width: 100%;
|
||||
max-width: 190px;
|
||||
min-width: 120px;
|
||||
}
|
||||
|
||||
#chunkSelector:focus,
|
||||
#websocketInput:focus,
|
||||
#themeSelector:focus,
|
||||
#microphoneSelect:focus {
|
||||
outline: none;
|
||||
border-color: #007bff;
|
||||
box-shadow: 0 0 0 3px rgba(0, 123, 255, 0.15);
|
||||
}
|
||||
|
||||
label {
|
||||
font-size: 13px;
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
.ws-default {
|
||||
font-size: 12px;
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
/* Segmented pill control for Theme */
|
||||
.segmented {
|
||||
display: inline-flex;
|
||||
align-items: stretch;
|
||||
border: 1px solid var(--button-border);
|
||||
background-color: var(--button-bg);
|
||||
border-radius: 999px;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
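/* Hide the native radio inputs; their labels act as the segmented buttons */
|
||||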
.segmented input[type="radio"] {
|
||||
position: absolute;
|
||||
opacity: 0;
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.theme-selector-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-top: 17px;
|
||||
}
|
||||
|
||||
.segmented label {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 6px 12px;
|
||||
font-size: 14px;
|
||||
color: var(--muted);
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
transition: background-color 0.2s ease, color 0.2s ease;
|
||||
}
|
||||
|
||||
.segmented label span {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.segmented label:hover span {
|
||||
display: inline;
|
||||
}
|
||||
|
||||
.segmented label:hover {
|
||||
background-color: var(--chip-bg);
|
||||
}
|
||||
|
||||
.segmented img {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
}
|
||||
|
||||
.segmented input[type="radio"]:checked + label {
|
||||
background-color: var(--chip-bg);
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.segmented input[type="radio"]:focus-visible + label,
|
||||
.segmented input[type="radio"]:focus + label {
|
||||
outline: 2px solid #007bff;
|
||||
outline-offset: 2px;
|
||||
border-radius: 999px;
|
||||
}
|
||||
|
||||
/* Transcript area */
|
||||
#linesTranscript {
|
||||
margin: 20px auto;
|
||||
max-width: 700px;
|
||||
text-align: left;
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
#linesTranscript p {
|
||||
margin: 0px 0;
|
||||
}
|
||||
|
||||
#linesTranscript strong {
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
#speaker {
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 100px;
|
||||
padding: 2px 10px;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
}
|
||||
|
||||
.label_diarization {
|
||||
background-color: var(--chip-bg);
|
||||
border-radius: 8px 8px 8px 8px;
|
||||
padding: 2px 10px;
|
||||
margin-left: 10px;
|
||||
display: inline-block;
|
||||
white-space: nowrap;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
color: var(--label-dia-text);
|
||||
}
|
||||
|
||||
.label_transcription {
|
||||
background-color: var(--chip-bg);
|
||||
border-radius: 8px 8px 8px 8px;
|
||||
padding: 2px 10px;
|
||||
display: inline-block;
|
||||
white-space: nowrap;
|
||||
margin-left: 10px;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
color: var(--label-trans-text);
|
||||
}
|
||||
|
||||
#timeInfo {
|
||||
color: var(--muted);
|
||||
margin-left: 10px;
|
||||
}
|
||||
|
||||
.textcontent {
|
||||
font-size: 16px;
|
||||
padding-left: 10px;
|
||||
margin-bottom: 10px;
|
||||
margin-top: 1px;
|
||||
padding-top: 5px;
|
||||
border-radius: 0px 0px 0px 10px;
|
||||
}
|
||||
|
||||
.buffer_diarization {
|
||||
color: var(--label-dia-text);
|
||||
margin-left: 4px;
|
||||
}
|
||||
|
||||
.buffer_transcription {
|
||||
color: #7474748c;
|
||||
margin-left: 4px;
|
||||
}
|
||||
|
||||
.spinner {
|
||||
display: inline-block;
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border: 2px solid var(--spinner-border);
|
||||
border-top: 2px solid var(--spinner-top);
|
||||
border-radius: 50%;
|
||||
animation: spin 0.7s linear infinite;
|
||||
vertical-align: middle;
|
||||
margin-bottom: 2px;
|
||||
margin-right: 5px;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
to {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
|
||||
.silence {
|
||||
color: var(--muted);
|
||||
background-color: var(--silence-bg);
|
||||
font-size: 13px;
|
||||
border-radius: 30px;
|
||||
padding: 2px 10px;
|
||||
}
|
||||
|
||||
.loading {
|
||||
color: var(--muted);
|
||||
background-color: var(--loading-bg);
|
||||
border-radius: 8px 8px 8px 0px;
|
||||
padding: 2px 10px;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
}
|
||||
|
||||
/* for smaller screens */
|
||||
/* @media (max-width: 450px) {
|
||||
.settings-container {
|
||||
flex-direction: column;
|
||||
gap: 10px;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.settings {
|
||||
justify-content: center;
|
||||
gap: 8px;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.field {
|
||||
align-items: center;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
#websocketInput,
|
||||
#microphoneSelect {
|
||||
min-width: 200px;
|
||||
max-width: 100%;
|
||||
}
|
||||
|
||||
.theme-selector-container {
|
||||
margin-top: 10px;
|
||||
}
|
||||
} */
|
||||
|
||||
/* @media (max-width: 768px) and (min-width: 451px) {
|
||||
.settings-container {
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.settings {
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
#websocketInput,
|
||||
#microphoneSelect {
|
||||
min-width: 150px;
|
||||
max-width: 300px;
|
||||
}
|
||||
} */
|
||||
|
||||
/* @media (max-width: 480px) {
|
||||
body {
|
||||
margin: 10px;
|
||||
}
|
||||
|
||||
.settings-toggle {
|
||||
width: 35px;
|
||||
height: 35px;
|
||||
}
|
||||
|
||||
.settings-toggle img {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
}
|
||||
|
||||
.settings {
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
}
|
||||
|
||||
#websocketInput,
|
||||
#microphoneSelect {
|
||||
max-width: 400px;
|
||||
}
|
||||
|
||||
.segmented label {
|
||||
padding: 4px 8px;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
.segmented img {
|
||||
width: 14px;
|
||||
height: 14px;
|
||||
}
|
||||
} */
|
||||
|
||||
|
||||
html
|
||||
{
|
||||
width: 400px; /* max: 800px */
|
||||
height: 600px; /* max: 600px */
|
||||
border-radius: 10px;
|
||||
|
||||
}
|
||||
1
chrome-extension/web/src/dark_mode.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-120q-151 0-255.5-104.5T120-480q0-138 90-239.5T440-838q13-2 23 3.5t16 14.5q6 9 6.5 21t-7.5 23q-17 26-25.5 55t-8.5 61q0 90 63 153t153 63q31 0 61.5-9t54.5-25q11-7 22.5-6.5T819-479q10 5 15.5 15t3.5 24q-14 138-117.5 229T480-120Zm0-80q88 0 158-48.5T740-375q-20 5-40 8t-40 3q-123 0-209.5-86.5T364-660q0-20 3-40t8-40q-78 32-126.5 102T200-480q0 116 82 198t198 82Zm-10-270Z"/></svg>
|
||||
|
After Width: | Height: | Size: 493 B |
1
chrome-extension/web/src/light_mode.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-360q50 0 85-35t35-85q0-50-35-85t-85-35q-50 0-85 35t-35 85q0 50 35 85t85 35Zm0 80q-83 0-141.5-58.5T280-480q0-83 58.5-141.5T480-680q83 0 141.5 58.5T680-480q0 83-58.5 141.5T480-280ZM80-440q-17 0-28.5-11.5T40-480q0-17 11.5-28.5T80-520h80q17 0 28.5 11.5T200-480q0 17-11.5 28.5T160-440H80Zm720 0q-17 0-28.5-11.5T760-480q0-17 11.5-28.5T800-520h80q17 0 28.5 11.5T920-480q0 17-11.5 28.5T880-440h-80ZM480-760q-17 0-28.5-11.5T440-800v-80q0-17 11.5-28.5T480-920q17 0 28.5 11.5T520-880v80q0 17-11.5 28.5T480-760Zm0 720q-17 0-28.5-11.5T440-80v-80q0-17 11.5-28.5T480-200q17 0 28.5 11.5T520-160v80q0 17-11.5 28.5T480-40ZM226-678l-43-42q-12-11-11.5-28t11.5-29q12-12 29-12t28 12l42 43q11 12 11 28t-11 28q-11 12-27.5 11.5T226-678Zm494 495-42-43q-11-12-11-28.5t11-27.5q11-12 27.5-11.5T734-282l43 42q12 11 11.5 28T777-183q-12 12-29 12t-28-12Zm-42-495q-12-11-11.5-27.5T678-734l42-43q11-12 28-11.5t29 11.5q12 12 12 29t-12 28l-43 42q-12 11-28 11t-28-11ZM183-183q-12-12-12-29t12-28l43-42q12-11 28.5-11t27.5 11q12 11 11.5 27.5T282-226l-42 43q-11 12-28 11.5T183-183Zm297-297Z"/></svg>
|
||||
|
After Width: | Height: | Size: 1.2 KiB |
1
chrome-extension/web/src/settings.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M433-80q-27 0-46.5-18T363-142l-9-66q-13-5-24.5-12T307-235l-62 26q-25 11-50 2t-39-32l-47-82q-14-23-8-49t27-43l53-40q-1-7-1-13.5v-27q0-6.5 1-13.5l-53-40q-21-17-27-43t8-49l47-82q14-23 39-32t50 2l62 26q11-8 23-15t24-12l9-66q4-26 23.5-44t46.5-18h94q27 0 46.5 18t23.5 44l9 66q13 5 24.5 12t22.5 15l62-26q25-11 50-2t39 32l47 82q14 23 8 49t-27 43l-53 40q1 7 1 13.5v27q0 6.5-2 13.5l53 40q21 17 27 43t-8 49l-48 82q-14 23-39 32t-50-2l-60-26q-11 8-23 15t-24 12l-9 66q-4 26-23.5 44T527-80h-94Zm7-80h79l14-106q31-8 57.5-23.5T639-327l99 41 39-68-86-65q5-14 7-29.5t2-31.5q0-16-2-31.5t-7-29.5l86-65-39-68-99 42q-22-23-48.5-38.5T533-694l-13-106h-79l-14 106q-31 8-57.5 23.5T321-633l-99-41-39 68 86 64q-5 15-7 30t-2 32q0 16 2 31t7 30l-86 65 39 68 99-42q22 23 48.5 38.5T427-266l13 106Zm42-180q58 0 99-41t41-99q0-58-41-99t-99-41q-59 0-99.5 41T342-480q0 58 40.5 99t99.5 41Zm-2-140Z"/></svg>
|
||||
|
After Width: | Height: | Size: 982 B |
1
chrome-extension/web/src/system_mode.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M396-396q-32-32-58.5-67T289-537q-5 14-6.5 28.5T281-480q0 83 58 141t141 58q14 0 28.5-2t28.5-6q-39-22-74-48.5T396-396Zm85 196q-56 0-107-21t-91-61q-40-40-61-91t-21-107q0-51 17-97.5t50-84.5q13-14 32-9.5t27 24.5q21 55 52.5 104t73.5 91q42 42 91 73.5T648-326q20 8 24.5 27t-9.5 32q-38 33-84.5 50T481-200Zm223-192q-16-5-23-20.5t-4-32.5q9-48-6-94.5T621-621q-35-35-80.5-49.5T448-677q-17 3-32-4t-21-23q-6-16 1.5-31t23.5-19q69-15 138 4.5T679-678q51 51 71 120t5 138q-4 17-19 25t-32 3ZM480-840q-17 0-28.5-11.5T440-880v-40q0-17 11.5-28.5T480-960q17 0 28.5 11.5T520-920v40q0 17-11.5 28.5T480-840Zm0 840q-17 0-28.5-11.5T440-40v-40q0-17 11.5-28.5T480-120q17 0 28.5 11.5T520-80v40q0 17-11.5 28.5T480 0Zm255-734q-12-12-12-28.5t12-28.5l28-28q11-11 27.5-11t28.5 11q12 12 12 28.5T819-762l-28 28q-12 12-28 12t-28-12ZM141-141q-12-12-12-28.5t12-28.5l28-28q12-12 28-12t28 12q12 12 12 28.5T225-169l-28 28q-11 11-27.5 11T141-141Zm739-299q-17 0-28.5-11.5T840-480q0-17 11.5-28.5T880-520h40q17 0 28.5 11.5T960-480q0 17-11.5 28.5T920-440h-40Zm-840 0q-17 0-28.5-11.5T0-480q0-17 11.5-28.5T40-520h40q17 0 28.5 11.5T120-480q0 17-11.5 28.5T80-440H40Zm779 299q-12 12-28.5 12T762-141l-28-28q-12-12-12-28t12-28q12-12 28.5-12t28.5 12l28 28q11 11 11 27.5T819-141ZM226-735q-12 12-28.5 12T169-735l-28-28q-11-11-11-27.5t11-28.5q12-12 28.5-12t28.5 12l28 28q12 12 12 28t-12 28Zm170 339Z"/></svg>
|
||||
|
After Width: | Height: | Size: 1.4 KiB |
12
chrome-extension/welcome.html
Normal file
@@ -0,0 +1,12 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Welcome</title>
|
||||
<script src="welcome.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
This page exists to work around an issue with Chrome that blocks permission
|
||||
requests from Chrome extensions.
|
||||
<!-- <button id="requestMicrophone">Request Microphone</button> -->
|
||||
</body>
|
||||
</html>
|
||||
BIN
demo.png
|
Before Width: | Height: | Size: 438 KiB After Width: | Height: | Size: 449 KiB |
57
pyproject.toml
Normal file
@@ -0,0 +1,57 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "whisperlivekit"
|
||||
version = "0.2.9"
|
||||
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Quentin Fuxa" }
|
||||
]
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Programming Language :: Python :: 3.14",
|
||||
"Programming Language :: Python :: 3.15",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: Multimedia :: Sound/Audio :: Speech"
|
||||
]
|
||||
dependencies = [
|
||||
"fastapi",
|
||||
"librosa",
|
||||
"soundfile",
|
||||
"faster-whisper",
|
||||
"uvicorn",
|
||||
"websockets",
|
||||
"torchaudio>=2.0.0",
|
||||
"torch>=2.0.0",
|
||||
"tqdm",
|
||||
"tiktoken",
|
||||
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
sentence = ["mosestokenizer", "wtpsplit"]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
|
||||
|
||||
[project.scripts]
|
||||
whisperlivekit-server = "whisperlivekit.basic_server:main"
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["whisperlivekit", "whisperlivekit.diarization", "whisperlivekit.simul_whisper", "whisperlivekit.simul_whisper.whisper", "whisperlivekit.simul_whisper.whisper.assets", "whisperlivekit.simul_whisper.whisper.normalizers", "whisperlivekit.web", "whisperlivekit.whisper_streaming_custom"]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
|
||||
"whisperlivekit.simul_whisper.whisper.assets" = ["*.tiktoken", "*.npz"]
|
||||
55
setup.py
@@ -1,55 +0,0 @@
|
||||
from setuptools import setup, find_packages
|
||||
setup(
|
||||
name="whisperlivekit",
|
||||
version="0.2.1",
|
||||
description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
author="Quentin Fuxa",
|
||||
url="https://github.com/QuentinFuxa/WhisperLiveKit",
|
||||
packages=find_packages(),
|
||||
install_requires=[
|
||||
"fastapi",
|
||||
"librosa",
|
||||
"soundfile",
|
||||
"faster-whisper",
|
||||
"uvicorn",
|
||||
"websockets",
|
||||
],
|
||||
extras_require={
|
||||
"diarization": ["diart"],
|
||||
"vac": ["torch"],
|
||||
"sentence": ["mosestokenizer", "wtpsplit"],
|
||||
"whisper": ["whisper"],
|
||||
"whisper-timestamped": ["whisper-timestamped"],
|
||||
"mlx-whisper": ["mlx-whisper"],
|
||||
"openai": ["openai"],
|
||||
"simulstreaming": [
|
||||
"torch",
|
||||
"tqdm",
|
||||
"tiktoken",
|
||||
"numpy<2.0.0",
|
||||
"triton>=2.0.0,<3;platform_machine==\"x86_64\" and sys_platform==\"linux\" or sys_platform==\"linux2\"",
|
||||
],
|
||||
},
|
||||
package_data={
|
||||
'whisperlivekit': ['web/*.html'],
|
||||
'whisperlivekit.simul_whisper': ['dual_license_simulstreaming.md'],
|
||||
'whisperlivekit.simul_whisper.whisper.assets': ['*.tiktoken', '*.npz'],
|
||||
},
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'whisperlivekit-server=whisperlivekit.basic_server:main',
|
||||
],
|
||||
},
|
||||
classifiers=[
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: Multimedia :: Sound/Audio :: Speech",
|
||||
],
|
||||
python_requires=">=3.9",
|
||||
)
|
||||
@@ -1,13 +1,13 @@
|
||||
from .download_simulstreaming_backend import download_simulstreaming_backend
|
||||
from .audio_processor import AudioProcessor
|
||||
from .core import TranscriptionEngine
|
||||
from .parse_args import parse_args
|
||||
from .web.web_interface import get_web_interface_html
|
||||
from .web.web_interface import get_web_interface_html, get_inline_ui_html
|
||||
|
||||
__all__ = [
|
||||
"TranscriptionEngine",
|
||||
"AudioProcessor",
|
||||
"parse_args",
|
||||
"get_web_interface_html",
|
||||
"get_inline_ui_html",
|
||||
"download_simulstreaming_backend",
|
||||
]
|
||||
|
||||
@@ -4,12 +4,11 @@ from time import time, sleep
|
||||
import math
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import timedelta
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from whisperlivekit.whisper_streaming_custom.whisper_online import online_factory
|
||||
from whisperlivekit.core import TranscriptionEngine
|
||||
from whisperlivekit.timed_objects import ASRToken, Silence, Line
|
||||
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
|
||||
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||
|
||||
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
||||
from whisperlivekit.results_formater import format_output
|
||||
# Set up logging once
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -17,9 +16,16 @@ logger.setLevel(logging.DEBUG)
|
||||
|
||||
SENTINEL = object() # unique sentinel object for end of stream marker
|
||||
|
||||
def format_time(seconds: float) -> str:
|
||||
"""Format seconds as HH:MM:SS."""
|
||||
return str(timedelta(seconds=int(seconds)))
|
||||
|
||||
async def get_all_from_queue(queue):
|
||||
items = []
|
||||
try:
|
||||
while True:
|
||||
item = queue.get_nowait()
|
||||
items.append(item)
|
||||
except asyncio.QueueEmpty:
|
||||
pass
|
||||
return items
|
||||
|
||||
class AudioProcessor:
|
||||
"""
|
||||
@@ -46,25 +52,33 @@ class AudioProcessor:
|
||||
self.last_ffmpeg_activity = time()
|
||||
self.ffmpeg_health_check_interval = 5
|
||||
self.ffmpeg_max_idle_time = 10
|
||||
self.is_pcm_input = self.args.pcm_input
|
||||
self.debug = False
|
||||
|
||||
# State management
|
||||
self.is_stopping = False
|
||||
self.silence = False
|
||||
self.silence_duration = 0.0
|
||||
self.tokens = []
|
||||
self.translated_segments = []
|
||||
self.buffer_transcription = ""
|
||||
self.buffer_diarization = ""
|
||||
self.full_transcription = ""
|
||||
self.end_buffer = 0
|
||||
self.end_attributed_speaker = 0
|
||||
self.lock = asyncio.Lock()
|
||||
self.beg_loop = time()
|
||||
self.beg_loop = None  # set in process_audio to account for potential lag during WebSocket initialization
|
||||
self.sep = " " # Default separator
|
||||
self.last_response_content = ""
|
||||
|
||||
# Models and processing
|
||||
self.asr = models.asr
|
||||
self.tokenizer = models.tokenizer
|
||||
self.diarization = models.diarization
|
||||
|
||||
self.vac_model = models.vac_model
|
||||
if self.args.vac:
|
||||
self.vac = FixedVADIterator(models.vac_model)
|
||||
else:
|
||||
self.vac = None
|
||||
|
||||
self.ffmpeg_manager = FFmpegManager(
|
||||
sample_rate=self.sample_rate,
|
||||
channels=self.channels
|
||||
@@ -79,30 +93,32 @@ class AudioProcessor:
|
||||
|
||||
self.transcription_queue = asyncio.Queue() if self.args.transcription else None
|
||||
self.diarization_queue = asyncio.Queue() if self.args.diarization else None
|
||||
self.translation_queue = asyncio.Queue() if self.args.target_language else None
|
||||
self.pcm_buffer = bytearray()
|
||||
|
||||
# Task references
|
||||
self.transcription_task = None
|
||||
self.diarization_task = None
|
||||
self.ffmpeg_reader_task = None
|
||||
self.watchdog_task = None
|
||||
self.all_tasks_for_cleanup = []
|
||||
|
||||
# Initialize transcription engine if enabled
|
||||
if self.args.transcription:
|
||||
self.online = online_factory(self.args, models.asr, models.tokenizer)
|
||||
self.online = online_factory(self.args, models.asr, models.tokenizer)
|
||||
if self.args.diarization:
|
||||
self.diarization = online_diarization_factory(self.args, models.diarization_model)
|
||||
if self.args.target_language:
|
||||
self.online_translation = online_translation_factory(self.args, models.translation_model)
|
||||
|
||||
def convert_pcm_to_float(self, pcm_buffer):
|
||||
"""Convert PCM buffer in s16le format to normalized NumPy array."""
|
||||
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
|
||||
async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep):
|
||||
async def update_transcription(self, new_tokens, buffer, end_buffer, sep):
|
||||
"""Thread-safe update of transcription with new data."""
|
||||
async with self.lock:
|
||||
self.tokens.extend(new_tokens)
|
||||
self.buffer_transcription = buffer
|
||||
self.end_buffer = end_buffer
|
||||
self.full_transcription = full_transcription
|
||||
self.sep = sep
|
||||
|
||||
async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
|
||||
@@ -115,7 +131,7 @@ class AudioProcessor:
|
||||
async def add_dummy_token(self):
|
||||
"""Placeholder token when no transcription is available."""
|
||||
async with self.lock:
|
||||
current_time = time() - self.beg_loop
|
||||
current_time = time() - self.beg_loop if self.beg_loop else 0
|
||||
self.tokens.append(ASRToken(
|
||||
start=current_time, end=current_time + 1,
|
||||
text=".", speaker=-1, is_dummy=True
|
||||
@@ -129,15 +145,16 @@ class AudioProcessor:
|
||||
# Calculate remaining times
|
||||
remaining_transcription = 0
|
||||
if self.end_buffer > 0:
|
||||
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2))
|
||||
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 1))
|
||||
|
||||
remaining_diarization = 0
|
||||
if self.tokens:
|
||||
latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
|
||||
remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2))
|
||||
remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 1))
|
||||
|
||||
return {
|
||||
"tokens": self.tokens.copy(),
|
||||
"translated_segments": self.translated_segments.copy(),
|
||||
"buffer_transcription": self.buffer_transcription,
|
||||
"buffer_diarization": self.buffer_diarization,
|
||||
"end_buffer": self.end_buffer,
|
||||
@@ -151,9 +168,9 @@ class AudioProcessor:
|
||||
"""Reset all state variables to initial values."""
|
||||
async with self.lock:
|
||||
self.tokens = []
|
||||
self.translated_segments = []
|
||||
self.buffer_transcription = self.buffer_diarization = ""
|
||||
self.end_buffer = self.end_attributed_speaker = 0
|
||||
self.full_transcription = self.last_response_content = ""
|
||||
self.beg_loop = time()
|
||||
|
||||
async def ffmpeg_stdout_reader(self):
|
||||
@@ -192,32 +209,9 @@ class AudioProcessor:
|
||||
continue
|
||||
|
||||
self.pcm_buffer.extend(chunk)
|
||||
|
||||
# Send to diarization if enabled
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(
|
||||
self.convert_pcm_to_float(self.pcm_buffer).copy()
|
||||
)
|
||||
|
||||
# Process when enough data
|
||||
if len(self.pcm_buffer) >= self.bytes_per_sec:
|
||||
if len(self.pcm_buffer) > self.max_bytes_per_sec:
|
||||
logger.warning(
|
||||
f"Audio buffer too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f}s. "
|
||||
f"Consider using a smaller model."
|
||||
)
|
||||
|
||||
# Process audio chunk
|
||||
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
|
||||
self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
|
||||
await self.handle_pcm_data()
|
||||
|
||||
# Send to transcription if enabled
|
||||
if self.args.transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(pcm_array.copy())
|
||||
|
||||
# Sleep if no processing is happening
|
||||
if not self.args.transcription and not self.args.diarization:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
|
||||
@@ -236,37 +230,48 @@ class AudioProcessor:
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(SENTINEL)
|
||||
logger.debug("Sentinel put into diarization_queue.")
|
||||
if self.args.target_language and self.translation_queue:
|
||||
await self.translation_queue.put(SENTINEL)
|
||||
|
||||
|
||||
async def transcription_processor(self):
|
||||
"""Process audio chunks for transcription."""
|
||||
self.full_transcription = ""
|
||||
self.sep = self.online.asr.sep
|
||||
cumulative_pcm_duration_stream_time = 0.0
|
||||
|
||||
while True:
|
||||
try:
|
||||
pcm_array = await self.transcription_queue.get()
|
||||
if pcm_array is SENTINEL:
|
||||
item = await self.transcription_queue.get()
|
||||
if item is SENTINEL:
|
||||
logger.debug("Transcription processor received sentinel. Finishing.")
|
||||
self.transcription_queue.task_done()
|
||||
break
|
||||
|
||||
if not self.online: # Should not happen if queue is used
|
||||
if not self.online:
|
||||
logger.warning("Transcription processor: self.online not initialized.")
|
||||
self.transcription_queue.task_done()
|
||||
continue
|
||||
|
||||
asr_internal_buffer_duration_s = len(getattr(self.online, 'audio_buffer', [])) / self.online.SAMPLING_RATE
|
||||
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
|
||||
|
||||
logger.info(
|
||||
f"ASR processing: internal_buffer={asr_internal_buffer_duration_s:.2f}s, "
|
||||
f"lag={transcription_lag_s:.2f}s."
|
||||
)
|
||||
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
|
||||
if type(item) is Silence:
|
||||
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
|
||||
if self.tokens:
|
||||
asr_processing_logs += f" | last_end = {self.tokens[-1].end} |"
|
||||
logger.info(asr_processing_logs)
|
||||
|
||||
# Process transcription
|
||||
duration_this_chunk = len(pcm_array) / self.sample_rate if isinstance(pcm_array, np.ndarray) else 0
|
||||
if type(item) is Silence:
|
||||
cumulative_pcm_duration_stream_time += item.duration
|
||||
self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
|
||||
continue
|
||||
|
||||
if isinstance(item, np.ndarray):
|
||||
pcm_array = item
|
||||
else:
|
||||
raise TypeError(f'Unexpected item in transcription queue: expected np.ndarray, got {type(item)}')
|
||||
|
||||
duration_this_chunk = len(pcm_array) / self.sample_rate
|
||||
cumulative_pcm_duration_stream_time += duration_this_chunk
|
||||
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
|
||||
|
||||
@@ -279,8 +284,6 @@ class AudioProcessor:
|
||||
|
||||
if new_tokens:
|
||||
validated_text = self.sep.join([t.text for t in new_tokens])
|
||||
self.full_transcription += validated_text
|
||||
|
||||
if buffer_text.startswith(validated_text):
|
||||
buffer_text = buffer_text[len(validated_text):].lstrip()
|
||||
|
||||
@@ -297,8 +300,13 @@ class AudioProcessor:
|
||||
new_end_buffer = max(candidate_end_times)
|
||||
|
||||
await self.update_transcription(
|
||||
new_tokens, buffer_text, new_end_buffer, self.full_transcription, self.sep
|
||||
new_tokens, buffer_text, new_end_buffer, self.sep
|
||||
)
|
||||
|
||||
if new_tokens and self.args.target_language and self.translation_queue:
|
||||
for token in new_tokens:
|
||||
await self.translation_queue.put(token)
|
||||
|
||||
self.transcription_queue.task_done()
|
||||
|
||||
except Exception as e:
|
||||
@@ -312,25 +320,35 @@ class AudioProcessor:
|
||||
async def diarization_processor(self, diarization_obj):
|
||||
"""Process audio chunks for speaker diarization."""
|
||||
buffer_diarization = ""
|
||||
|
||||
cumulative_pcm_duration_stream_time = 0.0
|
||||
while True:
|
||||
try:
|
||||
pcm_array = await self.diarization_queue.get()
|
||||
if pcm_array is SENTINEL:
|
||||
item = await self.diarization_queue.get()
|
||||
if item is SENTINEL:
|
||||
logger.debug("Diarization processor received sentinel. Finishing.")
|
||||
self.diarization_queue.task_done()
|
||||
break
|
||||
|
||||
if type(item) is Silence:
|
||||
cumulative_pcm_duration_stream_time += item.duration
|
||||
diarization_obj.insert_silence(item.duration)
|
||||
continue
|
||||
|
||||
if isinstance(item, np.ndarray):
|
||||
pcm_array = item
|
||||
else:
|
||||
raise TypeError(f'Unexpected item in diarization queue: expected np.ndarray, got {type(item)}')
|
||||
|
||||
# Process diarization
|
||||
await diarization_obj.diarize(pcm_array)
|
||||
|
||||
async with self.lock:
|
||||
new_end = diarization_obj.assign_speakers_to_tokens(
|
||||
self.end_attributed_speaker,
|
||||
self.tokens = diarization_obj.assign_speakers_to_tokens(
|
||||
self.tokens,
|
||||
use_punctuation_split=self.args.punctuation_split
|
||||
)
|
||||
self.end_attributed_speaker = new_end
|
||||
if len(self.tokens) > 0:
|
||||
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
|
||||
if buffer_diarization:
|
||||
self.buffer_diarization = buffer_diarization
|
||||
|
||||
@@ -343,9 +361,54 @@ class AudioProcessor:
|
||||
self.diarization_queue.task_done()
|
||||
logger.info("Diarization processor task finished.")
|
||||
|
||||
async def translation_processor(self, online_translation):
|
||||
# For now we ignore diarization here and use only transcription tokens.
|
||||
# The speaker is attributed based on the segments used for the translation.
|
||||
# In the future we want different languages per speaker etc., so it will be more complex.
|
||||
while True:
|
||||
try:
|
||||
token = await self.translation_queue.get() #block until at least 1 token
|
||||
if token is SENTINEL:
|
||||
logger.debug("Translation processor received sentinel. Finishing.")
|
||||
self.translation_queue.task_done()
|
||||
break
|
||||
|
||||
# Get all currently available tokens for translation: the more words, the more precise the translation.
|
||||
tokens_to_process = [token]
|
||||
additional_tokens = await get_all_from_queue(self.translation_queue)
|
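# Note: get_all_from_queue is assumed here to be a small helper that drains the
# asyncio queue without blocking, roughly:
#     items = []
#     while True:
#         try:
#             items.append(queue.get_nowait())
#         except asyncio.QueueEmpty:
#             break
#     return items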
||||
|
||||
sentinel_found = False
|
||||
for additional_token in additional_tokens:
|
||||
if additional_token is SENTINEL:
|
||||
sentinel_found = True
|
||||
break
|
||||
tokens_to_process.append(additional_token)
|
||||
if tokens_to_process:
|
||||
online_translation.insert_tokens(tokens_to_process)
|
||||
self.translated_segments = online_translation.process()
|
||||
|
||||
self.translation_queue.task_done()
|
||||
for _ in additional_tokens:
|
||||
self.translation_queue.task_done()
|
||||
|
||||
if sentinel_found:
|
||||
logger.debug("Translation processor received sentinel in batch. Finishing.")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in translation_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
if 'token' in locals() and token is not SENTINEL:
|
||||
self.translation_queue.task_done()
|
||||
if 'additional_tokens' in locals():
|
||||
for _ in additional_tokens:
|
||||
self.translation_queue.task_done()
|
||||
logger.info("Translation processor task finished.")
|
||||
|
||||
async def results_formatter(self):
|
||||
"""Format processing results for output."""
|
||||
last_sent_trans = None
|
||||
last_sent_diar = None
|
||||
while True:
|
||||
try:
|
||||
ffmpeg_state = await self.ffmpeg_manager.get_state()
|
||||
@@ -370,7 +433,7 @@ class AudioProcessor:
|
||||
buffer_diarization = state["buffer_diarization"]
|
||||
end_attributed_speaker = state["end_attributed_speaker"]
|
||||
sep = state["sep"]
|
||||
|
||||
|
||||
# Add dummy tokens if needed
|
||||
if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
|
||||
await self.add_dummy_token()
|
||||
@@ -379,40 +442,13 @@ class AudioProcessor:
|
||||
tokens = state["tokens"]
|
||||
|
||||
# Format output
|
||||
previous_speaker = -1
|
||||
lines = []
|
||||
last_end_diarized = 0
|
||||
undiarized_text = []
|
||||
|
||||
# Process each token
|
||||
for token in tokens:
|
||||
speaker = token.speaker
|
||||
|
||||
# Handle diarization
|
||||
if self.args.diarization:
|
||||
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
|
||||
undiarized_text.append(token.text)
|
||||
continue
|
||||
elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
|
||||
speaker = previous_speaker
|
||||
if speaker not in [-1, 0]:
|
||||
last_end_diarized = max(token.end, last_end_diarized)
|
||||
|
||||
# Group by speaker
|
||||
if speaker != previous_speaker or not lines:
|
||||
lines.append({
|
||||
"speaker": speaker,
|
||||
"text": token.text,
|
||||
"beg": format_time(token.start),
|
||||
"end": format_time(token.end),
|
||||
"diff": round(token.end - last_end_diarized, 2)
|
||||
})
|
||||
previous_speaker = speaker
|
||||
elif token.text: # Only append if text isn't empty
|
||||
lines[-1]["text"] += sep + token.text
|
||||
lines[-1]["end"] = format_time(token.end)
|
||||
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
|
||||
|
||||
lines, undiarized_text, buffer_transcription, buffer_diarization = format_output(
|
||||
state,
|
||||
self.silence,
|
||||
current_time = time() - self.beg_loop if self.beg_loop else None,
|
||||
args = self.args,
|
||||
debug = self.debug
|
||||
)
|
||||
# Handle undiarized text
|
||||
if undiarized_text:
|
||||
combined = sep.join(undiarized_text)
|
||||
@@ -422,37 +458,42 @@ class AudioProcessor:
|
||||
buffer_diarization = combined
|
||||
|
||||
response_status = "active_transcription"
|
||||
final_lines_for_response = lines.copy()
|
||||
|
||||
if not tokens and not buffer_transcription and not buffer_diarization:
|
||||
response_status = "no_audio_detected"
|
||||
final_lines_for_response = []
|
||||
elif response_status == "active_transcription" and not final_lines_for_response:
|
||||
final_lines_for_response = [{
|
||||
"speaker": 1,
|
||||
"text": "",
|
||||
"beg": format_time(state.get("end_buffer", 0)),
|
||||
"end": format_time(state.get("end_buffer", 0)),
|
||||
"diff": 0
|
||||
}]
|
||||
lines = []
|
||||
elif response_status == "active_transcription" and not lines:
|
||||
lines = [Line(
|
||||
speaker=1,
|
||||
start=state.get("end_buffer", 0),
|
||||
end=state.get("end_buffer", 0)
|
||||
)]
|
||||
|
||||
response = {
|
||||
"status": response_status,
|
||||
"lines": final_lines_for_response,
|
||||
"lines": [line.to_dict() for line in lines],
|
||||
"buffer_transcription": buffer_transcription,
|
||||
"buffer_diarization": buffer_diarization,
|
||||
"remaining_time_transcription": state["remaining_time_transcription"],
|
||||
"remaining_time_diarization": state["remaining_time_diarization"]
|
||||
"remaining_time_diarization": state["remaining_time_diarization"] if self.args.diarization else 0
|
||||
}
|
||||
|
||||
current_response_signature = f"{response_status} | " + \
|
||||
' '.join([f"{line['speaker']} {line['text']}" for line in final_lines_for_response]) + \
|
||||
' '.join([f"{line.speaker} {line.text}" for line in lines]) + \
|
||||
f" | {buffer_transcription} | {buffer_diarization}"
|
||||
|
||||
if current_response_signature != self.last_response_content and \
|
||||
(final_lines_for_response or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"):
|
||||
trans = state["remaining_time_transcription"]
|
||||
diar = state["remaining_time_diarization"]
|
||||
should_push = (
|
||||
current_response_signature != self.last_response_content
|
||||
or last_sent_trans is None
|
||||
or round(trans, 1) != round(last_sent_trans, 1)
|
||||
or round(diar, 1) != round(last_sent_diar, 1)
|
||||
)
|
||||
if should_push and (lines or buffer_transcription or buffer_diarization or response_status == "no_audio_detected" or trans > 0 or diar > 0):
|
||||
yield response
|
||||
self.last_response_content = current_response_signature
|
||||
last_sent_trans = trans
|
||||
last_sent_diar = diar
|
||||
|
||||
# Check for termination condition
|
||||
if self.is_stopping:
|
||||
@@ -464,7 +505,6 @@ class AudioProcessor:
|
||||
|
||||
if all_processors_done:
|
||||
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
|
||||
final_state = await self.get_current_state()
|
||||
return
|
||||
|
||||
await asyncio.sleep(0.1) # Avoid overwhelming the client
|
||||
@@ -504,6 +544,11 @@ class AudioProcessor:
|
||||
self.all_tasks_for_cleanup.append(self.diarization_task)
|
||||
processing_tasks_for_watchdog.append(self.diarization_task)
|
||||
|
||||
if self.args.target_language and self.args.lan != 'auto':
|
||||
self.translation_task = asyncio.create_task(self.translation_processor(self.online_translation))
|
||||
self.all_tasks_for_cleanup.append(self.translation_task)
|
||||
processing_tasks_for_watchdog.append(self.translation_task)
|
||||
|
||||
self.ffmpeg_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
|
||||
self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task)
|
||||
processing_tasks_for_watchdog.append(self.ffmpeg_reader_task)
|
||||
@@ -546,24 +591,29 @@ class AudioProcessor:
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources when processing is complete."""
|
||||
logger.info("Starting cleanup of AudioProcessor resources.")
|
||||
logger.info("Starting cleanup of AudioProcessor resources.")
|
||||
self.is_stopping = True
|
||||
for task in self.all_tasks_for_cleanup:
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
|
||||
created_tasks = [t for t in self.all_tasks_for_cleanup if t]
|
||||
if created_tasks:
|
||||
await asyncio.gather(*created_tasks, return_exceptions=True)
|
||||
logger.info("All processing tasks cancelled or finished.")
|
||||
await self.ffmpeg_manager.stop()
|
||||
logger.info("FFmpeg manager stopped.")
|
||||
if self.args.diarization and hasattr(self, 'diarization') and hasattr(self.diarization, 'close'):
|
||||
self.diarization.close()
|
||||
logger.info("AudioProcessor cleanup complete.")
|
||||
|
||||
created_tasks = [t for t in self.all_tasks_for_cleanup if t]
|
||||
if created_tasks:
|
||||
await asyncio.gather(*created_tasks, return_exceptions=True)
|
||||
logger.info("All processing tasks cancelled or finished.")
|
||||
await self.ffmpeg_manager.stop()
|
||||
logger.info("FFmpeg manager stopped.")
|
||||
if self.args.diarization and hasattr(self, 'diarization') and hasattr(self.diarization, 'close'):
|
||||
self.diarization.close()
|
||||
logger.info("AudioProcessor cleanup complete.")
|
||||
|
||||
|
||||
async def process_audio(self, message):
|
||||
"""Process incoming audio data."""
|
||||
|
||||
if not self.beg_loop:
|
||||
self.beg_loop = time()
|
||||
|
||||
if not message:
|
||||
logger.info("Empty audio message received, initiating stop sequence.")
|
||||
self.is_stopping = True
|
||||
@@ -575,10 +625,65 @@ class AudioProcessor:
|
||||
logger.warning("AudioProcessor is stopping. Ignoring incoming audio.")
|
||||
return
|
||||
|
||||
success = await self.ffmpeg_manager.write_data(message)
|
||||
if not success:
|
||||
ffmpeg_state = await self.ffmpeg_manager.get_state()
|
||||
if ffmpeg_state == FFmpegState.FAILED:
|
||||
logger.error("FFmpeg is in FAILED state, cannot process audio")
|
||||
else:
|
||||
logger.warning("Failed to write audio data to FFmpeg")
|
||||
if self.is_pcm_input:
|
||||
self.pcm_buffer.extend(message)
|
||||
await self.handle_pcm_data()
|
||||
else:
|
||||
success = await self.ffmpeg_manager.write_data(message)
|
||||
if not success:
|
||||
ffmpeg_state = await self.ffmpeg_manager.get_state()
|
||||
if ffmpeg_state == FFmpegState.FAILED:
|
||||
logger.error("FFmpeg is in FAILED state, cannot process audio")
|
||||
else:
|
||||
logger.warning("Failed to write audio data to FFmpeg")
|
||||
|
||||
async def handle_pcm_data(self):
|
||||
# Process when enough data
|
||||
if len(self.pcm_buffer) < self.bytes_per_sec:
|
||||
return
|
||||
|
||||
if len(self.pcm_buffer) > self.max_bytes_per_sec:
|
||||
logger.warning(
|
||||
f"Audio buffer too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f}s. "
|
||||
f"Consider using a smaller model."
|
||||
)
|
||||
|
||||
# Process audio chunk
|
||||
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
|
||||
self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
|
||||
|
||||
res = None
|
||||
end_of_audio = False
|
||||
silence_buffer = None
|
||||
|
||||
if self.args.vac:
|
||||
res = self.vac(pcm_array)
|
||||
|
||||
if res is not None:
|
||||
if res.get("end", 0) > res.get("start", 0):
|
||||
end_of_audio = True
|
||||
elif self.silence: #end of silence
|
||||
self.silence = False
|
||||
silence_buffer = Silence(duration=time() - self.start_silence)
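# When speech resumes, this Silence marker (carrying the elapsed silent duration)
# is pushed to the transcription and diarization queues below, so the downstream
# processors can shift their timelines via insert_silence() instead of processing
# the silent audio itself.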
|
||||
|
||||
if silence_buffer:
|
||||
if self.args.transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(silence_buffer)
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(silence_buffer)
|
||||
|
||||
if not self.silence:
|
||||
if self.args.transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(pcm_array.copy())
|
||||
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(pcm_array.copy())
|
||||
|
||||
self.silence_duration = 0.0
|
||||
|
||||
if end_of_audio:
|
||||
self.silence = True
|
||||
self.start_silence = time()
|
||||
|
||||
if not self.args.transcription and not self.args.diarization:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
@@ -2,9 +2,12 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args
|
||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args
|
||||
import asyncio
|
||||
import logging
|
||||
from starlette.staticfiles import StaticFiles
|
||||
import pathlib
|
||||
import whisperlivekit.web as webpkg
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
@@ -16,6 +19,15 @@ transcription_engine = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
|
||||
#to remove after 0.2.8
|
||||
if args.backend == "simulstreaming" and not args.disable_fast_encoder:
|
||||
logger.warning(f"""
|
||||
{'='*50}
|
||||
WhisperLiveKit 0.2.8 introduces a new fast encoder feature using MLX Whisper or Faster Whisper for improved speed. Use --disable-fast-encoder to disable it if you encounter issues.
|
||||
{'='*50}
|
||||
""")
|
||||
|
||||
global transcription_engine
|
||||
transcription_engine = TranscriptionEngine(
|
||||
**vars(args),
|
||||
@@ -30,10 +42,12 @@ app.add_middleware(
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
web_dir = pathlib.Path(webpkg.__file__).parent
|
||||
app.mount("/web", StaticFiles(directory=str(web_dir)), name="web")
|
||||
|
||||
@app.get("/")
|
||||
async def get():
|
||||
return HTMLResponse(get_web_interface_html())
|
||||
return HTMLResponse(get_inline_ui_html())
|
||||
|
||||
|
||||
async def handle_websocket_results(websocket, results_generator):
|
||||
@@ -47,7 +61,7 @@ async def handle_websocket_results(websocket, results_generator):
|
||||
except WebSocketDisconnect:
|
||||
logger.info("WebSocket disconnected while handling results (client likely closed connection).")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in WebSocket results handler: {e}")
|
||||
logger.exception(f"Error in WebSocket results handler: {e}")
|
||||
|
||||
|
||||
@app.websocket("/asr")
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
try:
|
||||
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
|
||||
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory
|
||||
from whisperlivekit.whisper_streaming_custom.online_asr import OnlineASRProcessor
|
||||
except ImportError:
|
||||
from .whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
|
||||
from .whisper_streaming_custom.whisper_online import backend_factory
|
||||
from .whisper_streaming_custom.online_asr import OnlineASRProcessor
|
||||
from whisperlivekit.warmup import warmup_asr, warmup_online
|
||||
from argparse import Namespace
|
||||
|
||||
import sys
|
||||
|
||||
class TranscriptionEngine:
|
||||
_instance = None
|
||||
@@ -22,7 +25,6 @@ class TranscriptionEngine:
|
||||
"host": "localhost",
|
||||
"port": 8000,
|
||||
"warmup_file": None,
|
||||
"confidence_validation": False,
|
||||
"diarization": False,
|
||||
"punctuation_split": False,
|
||||
"min_chunk_size": 0.5,
|
||||
@@ -31,23 +33,26 @@ class TranscriptionEngine:
|
||||
"model_dir": None,
|
||||
"lan": "auto",
|
||||
"task": "transcribe",
|
||||
"target_language": "",
|
||||
"backend": "faster-whisper",
|
||||
"vac": False,
|
||||
"vac": True,
|
||||
"vac_chunk_size": 0.04,
|
||||
"buffer_trimming": "segment",
|
||||
"buffer_trimming_sec": 15,
|
||||
"log_level": "DEBUG",
|
||||
"ssl_certfile": None,
|
||||
"ssl_keyfile": None,
|
||||
"transcription": True,
|
||||
"vad": True,
|
||||
"segmentation_model": "pyannote/segmentation-3.0",
|
||||
"embedding_model": "pyannote/embedding",
|
||||
"pcm_input": False,
|
||||
# whisperstreaming params:
|
||||
"buffer_trimming": "segment",
|
||||
"confidence_validation": False,
|
||||
"buffer_trimming_sec": 15,
|
||||
# simulstreaming params:
|
||||
"disable_fast_encoder": False,
|
||||
"frame_threshold": 25,
|
||||
"beams": 1,
|
||||
"decoder_type": None,
|
||||
"audio_max_len": 30.0,
|
||||
"audio_max_len": 20.0,
|
||||
"audio_min_len": 0.0,
|
||||
"cif_ckpt_path": None,
|
||||
"never_fire": False,
|
||||
@@ -55,6 +60,11 @@ class TranscriptionEngine:
|
||||
"static_init_prompt": None,
|
||||
"max_context_tokens": None,
|
||||
"model_path": './base.pt',
|
||||
"diarization_backend": "sortformer",
|
||||
# diarization params:
|
||||
"disable_punctuation_split" : False,
|
||||
"segmentation_model": "pyannote/segmentation-3.0",
|
||||
"embedding_model": "pyannote/embedding",
|
||||
}
|
||||
|
||||
config_dict = {**defaults, **kwargs}
|
||||
@@ -63,6 +73,8 @@ class TranscriptionEngine:
|
||||
config_dict['transcription'] = not kwargs['no_transcription']
|
||||
if 'no_vad' in kwargs:
|
||||
config_dict['vad'] = not kwargs['no_vad']
|
||||
if 'no_vac' in kwargs:
|
||||
config_dict['vac'] = not kwargs['no_vac']
|
||||
|
||||
config_dict.pop('no_transcription', None)
|
||||
config_dict.pop('no_vad', None)
|
||||
@@ -76,17 +88,99 @@ class TranscriptionEngine:
|
||||
self.asr = None
|
||||
self.tokenizer = None
|
||||
self.diarization = None
|
||||
self.vac_model = None
|
||||
|
||||
if self.args.vac:
|
||||
import torch
|
||||
self.vac_model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
|
||||
|
||||
if self.args.transcription:
|
||||
self.asr, self.tokenizer = backend_factory(self.args)
|
||||
warmup_asr(self.asr, self.args.warmup_file)
|
||||
if self.args.backend == "simulstreaming":
|
||||
from whisperlivekit.simul_whisper import SimulStreamingASR
|
||||
self.tokenizer = None
|
||||
simulstreaming_kwargs = {}
|
||||
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
|
||||
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
|
||||
'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count', 'disable_fast_encoder']:
|
||||
if hasattr(self.args, attr):
|
||||
simulstreaming_kwargs[attr] = getattr(self.args, attr)
|
||||
|
||||
# Add segment_length from min_chunk_size
|
||||
simulstreaming_kwargs['segment_length'] = getattr(self.args, 'min_chunk_size', 0.5)
|
||||
simulstreaming_kwargs['task'] = self.args.task
|
||||
|
||||
size = self.args.model
|
||||
self.asr = SimulStreamingASR(
|
||||
modelsize=size,
|
||||
lan=self.args.lan,
|
||||
cache_dir=getattr(self.args, 'model_cache_dir', None),
|
||||
model_dir=getattr(self.args, 'model_dir', None),
|
||||
**simulstreaming_kwargs
|
||||
)
|
||||
|
||||
else:
|
||||
self.asr, self.tokenizer = backend_factory(self.args)
|
||||
warmup_asr(self.asr, self.args.warmup_file)  # for simulstreaming, warmup should be done in the online class, not here
|
||||
|
||||
if self.args.diarization:
|
||||
from whisperlivekit.diarization.diarization_online import DiartDiarization
|
||||
self.diarization = DiartDiarization(
|
||||
block_duration=self.args.min_chunk_size,
|
||||
segmentation_model_name=self.args.segmentation_model,
|
||||
embedding_model_name=self.args.embedding_model
|
||||
)
|
||||
if self.args.diarization_backend == "diart":
|
||||
from whisperlivekit.diarization.diart_backend import DiartDiarization
|
||||
self.diarization_model = DiartDiarization(
|
||||
block_duration=self.args.min_chunk_size,
|
||||
segmentation_model_name=self.args.segmentation_model,
|
||||
embedding_model_name=self.args.embedding_model
|
||||
)
|
||||
elif self.args.diarization_backend == "sortformer":
|
||||
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
|
||||
self.diarization_model = SortformerDiarization()
|
||||
else:
|
||||
raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}")
|
||||
|
||||
self.translation_model = None
|
||||
if self.args.target_language:
|
||||
if self.args.lan == 'auto':
|
||||
raise ValueError("Translation cannot be used when the source language is 'auto'; set a specific source language")
|
||||
else:
|
||||
from whisperlivekit.translation.translation import load_model
|
||||
self.translation_model = load_model([self.args.lan]) #in the future we want to handle different languages for different speakers
|
||||
|
||||
TranscriptionEngine._initialized = True
|
||||
|
||||
|
||||
|
||||
def online_factory(args, asr, tokenizer, logfile=sys.stderr):
|
||||
if args.backend == "simulstreaming":
|
||||
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
||||
online = SimulStreamingOnlineProcessor(
|
||||
asr,
|
||||
logfile=logfile,
|
||||
)
|
||||
# warmup_online(online, args.warmup_file)
|
||||
else:
|
||||
online = OnlineASRProcessor(
|
||||
asr,
|
||||
tokenizer,
|
||||
logfile=logfile,
|
||||
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
|
||||
confidence_validation = args.confidence_validation
|
||||
)
|
||||
return online
|
||||
|
||||
|
||||
def online_diarization_factory(args, diarization_backend):
|
||||
if args.diarization_backend == "diart":
|
||||
online = diarization_backend
|
||||
# Not ideal, since several users/instances will share the same backend, but diart is no longer SOTA and sortformer is recommended
|
||||
|
||||
if args.diarization_backend == "sortformer":
|
||||
from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline
|
||||
online = SortformerDiarizationOnline(shared_model=diarization_backend)
|
||||
return online
|
||||
|
||||
|
||||
def online_translation_factory(args, translation_model):
|
||||
# Should be handled at speaker level in the future:
|
||||
# one shared NLLB model for all speakers
|
||||
# one tokenizer per speaker/language
|
||||
from whisperlivekit.translation.translation import OnlineTranslation
|
||||
return OnlineTranslation(translation_model, [args.lan], [args.target_language])
|
||||
@@ -29,6 +29,7 @@ class DiarizationObserver(Observer):
|
||||
self.speaker_segments = []
|
||||
self.processed_time = 0
|
||||
self.segment_lock = threading.Lock()
|
||||
self.global_time_offset = 0.0
|
||||
|
||||
def on_next(self, value: Tuple[Annotation, Any]):
|
||||
annotation, audio = value
|
||||
@@ -49,8 +50,8 @@ class DiarizationObserver(Observer):
|
||||
print(f" {speaker}: {start:.2f}s-{end:.2f}s")
|
||||
self.speaker_segments.append(SpeakerSegment(
|
||||
speaker=speaker,
|
||||
start=start,
|
||||
end=end
|
||||
start=start + self.global_time_offset,
|
||||
end=end + self.global_time_offset
|
||||
))
|
||||
else:
|
||||
logger.debug("\nNo speakers detected in this segment")
|
||||
@@ -165,7 +166,7 @@ class WebSocketAudioSource(AudioSource):
|
||||
|
||||
|
||||
class DiartDiarization:
|
||||
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "speechbrain/spkrec-ecapa-voxceleb"):
|
||||
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 1.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "pyannote/embedding"):
|
||||
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
|
||||
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
|
||||
|
||||
@@ -199,6 +200,9 @@ class DiartDiarization:
|
||||
self.inference.attach_observers(self.observer)
|
||||
asyncio.get_event_loop().run_in_executor(None, self.inference)
|
||||
|
||||
def insert_silence(self, silence_duration):
|
||||
self.observer.global_time_offset += silence_duration
|
||||
|
||||
async def diarize(self, pcm_array: np.ndarray):
|
||||
"""
|
||||
Process audio data for diarization.
|
||||
@@ -206,15 +210,14 @@ class DiartDiarization:
|
||||
"""
|
||||
if self.custom_source:
|
||||
self.custom_source.push_audio(pcm_array)
|
||||
self.observer.clear_old_segments()
|
||||
return self.observer.get_segments()
|
||||
# self.observer.clear_old_segments()
|
||||
|
||||
def close(self):
|
||||
"""Close the audio source."""
|
||||
if self.custom_source:
|
||||
self.custom_source.close()
|
||||
|
||||
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float:
|
||||
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> float:
|
||||
"""
|
||||
Assign speakers to tokens based on timing overlap with speaker segments.
|
||||
Uses the segments collected by the observer.
|
||||
@@ -231,85 +234,82 @@ class DiartDiarization:
|
||||
|
||||
if not self.lag_diart and segments and tokens:
|
||||
self.lag_diart = segments[0].start - tokens[0].start
|
||||
for token in tokens:
|
||||
for segment in segments:
|
||||
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
|
||||
token.speaker = extract_number(segment.speaker) + 1
|
||||
end_attributed_speaker = max(token.end, end_attributed_speaker)
|
||||
|
||||
if use_punctuation_split and len(tokens) > 1:
|
||||
punctuation_marks = {'.', '!', '?'}
|
||||
|
||||
print("Here are the tokens:",
|
||||
[(t.text, t.start, t.end, t.speaker) for t in tokens[:10]])
|
||||
|
||||
segment_map = []
|
||||
for segment in segments:
|
||||
speaker_num = extract_number(segment.speaker) + 1
|
||||
segment_map.append((segment.start, segment.end, speaker_num))
|
||||
segment_map.sort(key=lambda x: x[0])
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
current_token = tokens[i]
|
||||
|
||||
is_sentence_end = False
|
||||
if current_token.text and current_token.text.strip():
|
||||
text = current_token.text.strip()
|
||||
if text[-1] in punctuation_marks:
|
||||
is_sentence_end = True
|
||||
logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
|
||||
|
||||
if is_sentence_end and current_token.speaker != -1:
|
||||
punctuation_time = current_token.end
|
||||
current_speaker = current_token.speaker
|
||||
|
||||
j = i + 1
|
||||
next_sentence_tokens = []
|
||||
while j < len(tokens):
|
||||
next_token = tokens[j]
|
||||
next_sentence_tokens.append(j)
|
||||
|
||||
# Check if this token ends the next sentence
|
||||
if next_token.text and next_token.text.strip():
|
||||
if next_token.text.strip()[-1] in punctuation_marks:
|
||||
break
|
||||
j += 1
|
||||
|
||||
if next_sentence_tokens:
|
||||
speaker_times = {}
|
||||
|
||||
for idx in next_sentence_tokens:
|
||||
token = tokens[idx]
|
||||
# Find which segments overlap with this token
|
||||
for seg_start, seg_end, seg_speaker in segment_map:
|
||||
if not (seg_end <= token.start or seg_start >= token.end):
|
||||
# Calculate overlap duration
|
||||
overlap_start = max(seg_start, token.start)
|
||||
overlap_end = min(seg_end, token.end)
|
||||
overlap_duration = overlap_end - overlap_start
|
||||
|
||||
if seg_speaker not in speaker_times:
|
||||
speaker_times[seg_speaker] = 0
|
||||
speaker_times[seg_speaker] += overlap_duration
|
||||
|
||||
if speaker_times:
|
||||
dominant_speaker = max(speaker_times.items(), key=lambda x: x[1])[0]
|
||||
|
||||
if dominant_speaker != current_speaker:
|
||||
logger.debug(f" Speaker change after punctuation: {current_speaker} → {dominant_speaker}")
|
||||
|
||||
for idx in next_sentence_tokens:
|
||||
if tokens[idx].speaker != dominant_speaker:
|
||||
logger.debug(f" Reassigning token {idx} ('{tokens[idx].text}') to Speaker {dominant_speaker}")
|
||||
tokens[idx].speaker = dominant_speaker
|
||||
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
|
||||
else:
|
||||
for idx in next_sentence_tokens:
|
||||
if tokens[idx].speaker == -1:
|
||||
tokens[idx].speaker = current_speaker
|
||||
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
|
||||
|
||||
i += 1
|
||||
if not use_punctuation_split:
|
||||
for token in tokens:
|
||||
for segment in segments:
|
||||
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
|
||||
token.speaker = extract_number(segment.speaker) + 1
|
||||
else:
|
||||
tokens = add_speaker_to_tokens(segments, tokens)
|
||||
return tokens
|
||||
|
||||
return end_attributed_speaker
|
||||
def concatenate_speakers(segments):
|
||||
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
|
||||
for segment in segments:
|
||||
speaker = extract_number(segment.speaker) + 1
|
||||
if segments_concatenated[-1]['speaker'] != speaker:
|
||||
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
|
||||
else:
|
||||
segments_concatenated[-1]['end'] = segment.end
|
||||
# print("Segments concatenated:")
|
||||
# for entry in segments_concatenated:
|
||||
# print(f"Speaker {entry['speaker']}: {entry['begin']:.2f}s - {entry['end']:.2f}s")
|
||||
return segments_concatenated
|
||||
|
||||
|
||||
def add_speaker_to_tokens(segments, tokens):
|
||||
"""
|
||||
Assign speakers to tokens based on diarization segments, with punctuation-aware boundary adjustment.
|
||||
"""
|
||||
punctuation_marks = {'.', '!', '?'}
|
||||
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
|
||||
segments_concatenated = concatenate_speakers(segments)
|
||||
for ind, segment in enumerate(segments_concatenated):
|
||||
for i, punctuation_token in enumerate(punctuation_tokens):
|
||||
if punctuation_token.start > segment['end']:
|
||||
after_length = punctuation_token.start - segment['end']
|
||||
before_length = segment['end'] - punctuation_tokens[i - 1].end
|
||||
if before_length > after_length:
|
||||
segment['end'] = punctuation_token.start
|
||||
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
|
||||
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
|
||||
else:
|
||||
segment['end'] = punctuation_tokens[i - 1].end
|
||||
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
|
||||
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
|
||||
break
|
||||
|
||||
last_end = 0.0
|
||||
for token in tokens:
|
||||
start = max(last_end + 0.01, token.start)
|
||||
token.start = start
|
||||
token.end = max(start, token.end)
|
||||
last_end = token.end
|
||||
|
||||
ind_last_speaker = 0
|
||||
for segment in segments_concatenated:
|
||||
for i, token in enumerate(tokens[ind_last_speaker:]):
|
||||
if token.end <= segment['end']:
|
||||
token.speaker = segment['speaker']
|
||||
ind_last_speaker = i + 1
|
||||
# print(
|
||||
# f"Token '{token.text}' ('begin': {token.start:.2f}, 'end': {token.end:.2f}) "
|
||||
# f"assigned to Speaker {segment['speaker']} ('segment': {segment['begin']:.2f}-{segment['end']:.2f})"
|
||||
# )
|
||||
elif token.start > segment['end']:
|
||||
break
|
||||
return tokens
|
||||
|
||||
|
||||
def visualize_tokens(tokens):
|
||||
conversation = [{"speaker": -1, "text": ""}]
|
||||
for token in tokens:
|
||||
speaker = conversation[-1]['speaker']
|
||||
if token.speaker != speaker:
|
||||
conversation.append({"speaker": token.speaker, "text": token.text})
|
||||
else:
|
||||
conversation[-1]['text'] += token.text
|
||||
print("Conversation:")
|
||||
for entry in conversation:
|
||||
print(f"Speaker {entry['speaker']}: {entry['text']}")
|
||||
465
whisperlivekit/diarization/sortformer_backend.py
Normal file
@@ -0,0 +1,465 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import wave
|
||||
from typing import List, Optional
|
||||
from queue import SimpleQueue, Empty
|
||||
|
||||
from whisperlivekit.timed_objects import SpeakerSegment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from nemo.collections.asr.models import SortformerEncLabelModel
|
||||
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
|
||||
except ImportError:
|
||||
raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
|
||||
|
||||
|
||||
class StreamingSortformerState:
|
||||
"""
|
||||
This class creates a class instance that will be used to store the state of the
|
||||
streaming Sortformer model.
|
||||
|
||||
Attributes:
|
||||
spkcache (torch.Tensor): Speaker cache to store embeddings from start
|
||||
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
|
||||
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
|
||||
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
|
||||
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
|
||||
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
|
||||
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
|
||||
mean_sil_emb (torch.Tensor): Mean silence embedding
|
||||
n_sil_frames (torch.Tensor): Number of silence frames
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.spkcache = None # Speaker cache to store embeddings from start
|
||||
self.spkcache_lengths = None
|
||||
self.spkcache_preds = None # speaker cache predictions
|
||||
self.fifo = None # to save the embedding from the latest chunks
|
||||
self.fifo_lengths = None
|
||||
self.fifo_preds = None
|
||||
self.spk_perm = None
|
||||
self.mean_sil_emb = None
|
||||
self.n_sil_frames = None
|
||||
|
||||
|
||||
class SortformerDiarization:
|
||||
def __init__(self, model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2"):
|
||||
"""
|
||||
Stores the shared streaming Sortformer diarization model. Used when a new online_diarization is initialized.
|
||||
"""
|
||||
self._load_model(model_name)
|
||||
|
||||
def _load_model(self, model_name: str):
|
||||
"""Load and configure the Sortformer model for streaming."""
|
||||
try:
|
||||
self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
|
||||
self.diar_model.eval()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.diar_model.to(device)
|
||||
|
||||
## to test
|
||||
# for name, param in self.diar_model.named_parameters():
|
||||
# if param.device != device:
|
||||
# raise RuntimeError(f"Parameter {name} is on {param.device} but should be on {device}")
|
||||
|
||||
logger.info(f"Using {device.type.upper()} for Sortformer model")
|
||||
|
||||
self.diar_model.sortformer_modules.chunk_len = 10
|
||||
self.diar_model.sortformer_modules.subsampling_factor = 10
|
||||
self.diar_model.sortformer_modules.chunk_right_context = 0
|
||||
self.diar_model.sortformer_modules.chunk_left_context = 10
|
||||
self.diar_model.sortformer_modules.spkcache_len = 188
|
||||
self.diar_model.sortformer_modules.fifo_len = 188
|
||||
self.diar_model.sortformer_modules.spkcache_update_period = 144
|
||||
self.diar_model.sortformer_modules.log = False
|
||||
self.diar_model.sortformer_modules._check_streaming_parameters()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load Sortformer model: {e}")
|
||||
raise
|
||||
|
||||
class SortformerDiarizationOnline:
|
||||
def __init__(self, shared_model, sample_rate: int = 16000):
|
||||
"""
|
||||
Initialize the streaming Sortformer diarization system.
|
||||
|
||||
Args:
|
||||
sample_rate: Audio sample rate (default: 16000)
|
||||
model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2")
|
||||
"""
|
||||
self.sample_rate = sample_rate
|
||||
self.speaker_segments = []
|
||||
self.buffer_audio = np.array([], dtype=np.float32)
|
||||
self.segment_lock = threading.Lock()
|
||||
self.global_time_offset = 0.0
|
||||
self.processed_time = 0.0
|
||||
self.debug = False
|
||||
|
||||
self.diar_model = shared_model.diar_model
|
||||
|
||||
self.audio2mel = AudioToMelSpectrogramPreprocessor(
|
||||
window_size=0.025,
|
||||
normalize="NA",
|
||||
n_fft=512,
|
||||
features=128,
|
||||
pad_to=0
|
||||
)
|
||||
self.audio2mel.to(self.diar_model.device)
|
||||
|
||||
self.chunk_duration_seconds = (
|
||||
self.diar_model.sortformer_modules.chunk_len *
|
||||
self.diar_model.sortformer_modules.subsampling_factor *
|
||||
self.diar_model.preprocessor._cfg.window_stride
|
||||
)
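# With the streaming settings configured in SortformerDiarization._load_model
# (chunk_len=10, subsampling_factor=10) and the usual 0.01s window stride, this
# works out to 10 * 10 * 0.01 = 1.0 second of audio per processing chunk.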
|
||||
|
||||
self._init_streaming_state()
|
||||
|
||||
self._previous_chunk_features = None
|
||||
self._chunk_index = 0
|
||||
self._len_prediction = None
|
||||
|
||||
# Audio buffer to store PCM chunks for debugging
|
||||
self.audio_buffer = []
|
||||
|
||||
# Buffer for accumulating audio chunks until reaching chunk_duration_seconds
|
||||
self.audio_chunk_buffer = []
|
||||
self.accumulated_duration = 0.0
|
||||
|
||||
logger.info("SortformerDiarization initialized successfully")
|
||||
|
||||
|
||||
def _init_streaming_state(self):
|
||||
"""Initialize the streaming state for the model."""
|
||||
batch_size = 1
|
||||
device = self.diar_model.device
|
||||
|
||||
self.streaming_state = StreamingSortformerState()
|
||||
self.streaming_state.spkcache = torch.zeros(
|
||||
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.fc_d_model),
|
||||
device=device
|
||||
)
|
||||
self.streaming_state.spkcache_preds = torch.zeros(
|
||||
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.n_spk),
|
||||
device=device
|
||||
)
|
||||
self.streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
self.streaming_state.fifo = torch.zeros(
|
||||
(batch_size, self.diar_model.sortformer_modules.fifo_len, self.diar_model.sortformer_modules.fc_d_model),
|
||||
device=device
|
||||
)
|
||||
self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device)
|
||||
self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
|
||||
# Initialize total predictions tensor
|
||||
self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)
|
||||
|
||||
def insert_silence(self, silence_duration: float):
|
||||
"""
|
||||
Insert silence period by adjusting the global time offset.
|
||||
|
||||
Args:
|
||||
silence_duration: Duration of silence in seconds
|
||||
"""
|
||||
with self.segment_lock:
|
||||
self.global_time_offset += silence_duration
|
||||
logger.debug(f"Inserted silence of {silence_duration:.2f}s, new offset: {self.global_time_offset:.2f}s")
|
||||
|
||||
async def diarize(self, pcm_array: np.ndarray):
|
||||
"""
|
||||
Process audio data for diarization in streaming fashion.
|
||||
|
||||
Args:
|
||||
pcm_array: Audio data as numpy array
|
||||
"""
|
||||
try:
|
||||
if self.debug:
|
||||
self.audio_buffer.append(pcm_array.copy())
|
||||
|
||||
threshold = int(self.chunk_duration_seconds * self.sample_rate)
|
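# At 16 kHz and a chunk duration of roughly 1.0s (see __init__), this threshold
# is about 16000 samples before a chunk is handed to the model.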
||||
|
||||
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
|
||||
if len(self.buffer_audio) < threshold:
|
||||
return
|
||||
|
||||
audio = self.buffer_audio[:threshold]
|
||||
self.buffer_audio = self.buffer_audio[threshold:]
|
||||
|
||||
device = self.diar_model.device
|
||||
audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0)
|
||||
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device)
|
||||
|
||||
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
|
||||
audio_signal_chunk, audio_signal_length_chunk
|
||||
)
|
||||
processed_signal_chunk = processed_signal_chunk.to(device)
|
||||
processed_signal_length_chunk = processed_signal_length_chunk.to(device)
|
||||
|
||||
if self._previous_chunk_features is not None:
|
||||
to_add = self._previous_chunk_features[:, :, -99:].to(device)
|
||||
total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device)
|
||||
else:
|
||||
total_features = processed_signal_chunk.to(device)
|
||||
|
||||
self._previous_chunk_features = processed_signal_chunk.to(device)
|
||||
|
||||
chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device)
|
||||
|
||||
with torch.inference_mode():
|
||||
left_offset = 8 if self._chunk_index > 0 else 0
|
||||
right_offset = 8
|
||||
|
||||
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
|
||||
processed_signal=chunk_feat_seq_t,
|
||||
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device),
|
||||
streaming_state=self.streaming_state,
|
||||
total_preds=self.total_preds,
|
||||
left_offset=left_offset,
|
||||
right_offset=right_offset,
|
||||
)
|
||||
|
||||
# Convert predictions to speaker segments
|
||||
self._process_predictions()
|
||||
|
||||
self._chunk_index += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in diarize: {e}")
|
||||
raise
|
||||
|
||||
# TODO: Handle case when stream ends with partial buffer (accumulated_duration > 0 but < chunk_duration_seconds)
|
||||
|
||||
def _process_predictions(self):
|
||||
"""Process model predictions and convert to speaker segments."""
|
||||
try:
|
||||
preds_np = self.total_preds[0].cpu().numpy()
|
||||
active_speakers = np.argmax(preds_np, axis=1)
|
||||
|
||||
if self._len_prediction is None:
|
||||
self._len_prediction = len(active_speakers)
|
||||
|
||||
# Get predictions for current chunk
|
||||
frame_duration = self.chunk_duration_seconds / self._len_prediction
|
||||
current_chunk_preds = active_speakers[-self._len_prediction:]
|
||||
|
||||
with self.segment_lock:
|
||||
# Process predictions into segments
|
||||
base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset
|
||||
|
||||
for idx, spk in enumerate(current_chunk_preds):
|
||||
start_time = base_time + idx * frame_duration
|
||||
end_time = base_time + (idx + 1) * frame_duration
|
||||
|
||||
# Check if this continues the last segment or starts a new one
|
||||
if (self.speaker_segments and
|
||||
self.speaker_segments[-1].speaker == spk and
|
||||
abs(self.speaker_segments[-1].end - start_time) < frame_duration * 0.5):
|
||||
# Continue existing segment
|
||||
self.speaker_segments[-1].end = end_time
|
||||
else:
|
||||
|
||||
# Create new segment
|
||||
self.speaker_segments.append(SpeakerSegment(
|
||||
speaker=spk,
|
||||
start=start_time,
|
||||
end=end_time
|
||||
))
|
||||
|
||||
# Update processed time
|
||||
self.processed_time = max(self.processed_time, base_time + self.chunk_duration_seconds)
|
||||
|
||||
logger.debug(f"Processed chunk {self._chunk_index}, total segments: {len(self.speaker_segments)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing predictions: {e}")
|
||||
|
||||
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> list:
|
||||
"""
|
||||
Assign speakers to tokens based on timing overlap with speaker segments.
|
||||
|
||||
Args:
|
||||
tokens: List of tokens with timing information
|
||||
use_punctuation_split: Whether to use punctuation for boundary refinement
|
||||
|
||||
Returns:
|
||||
List of tokens with speaker assignments
|
||||
"""
|
||||
with self.segment_lock:
|
||||
segments = self.speaker_segments.copy()
|
||||
|
||||
if not segments or not tokens:
|
||||
logger.debug("No segments or tokens available for speaker assignment")
|
||||
return tokens
|
||||
|
||||
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
|
||||
use_punctuation_split = False  # punctuation-aware splitting is forced off here for now
|
||||
if not use_punctuation_split:
|
||||
# Simple overlap-based assignment
|
||||
for token in tokens:
|
||||
token.speaker = -1 # Default to no speaker
|
||||
for segment in segments:
|
||||
# Check for timing overlap
|
||||
if not (segment.end <= token.start or segment.start >= token.end):
|
||||
token.speaker = segment.speaker + 1 # Convert to 1-based indexing
|
||||
break
|
||||
else:
|
||||
# Use punctuation-aware assignment (similar to diart_backend)
|
||||
tokens = self._add_speaker_to_tokens_with_punctuation(segments, tokens)
|
||||
|
||||
return tokens
|
||||
|
||||
def _add_speaker_to_tokens_with_punctuation(self, segments: List[SpeakerSegment], tokens: list) -> list:
|
||||
"""
|
||||
Assign speakers to tokens with punctuation-aware boundary adjustment.
|
||||
|
||||
Args:
|
||||
segments: List of speaker segments
|
||||
tokens: List of tokens to assign speakers to
|
||||
|
||||
Returns:
|
||||
List of tokens with speaker assignments
|
||||
"""
|
||||
punctuation_marks = {'.', '!', '?'}
|
||||
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
|
||||
|
||||
# Convert segments to concatenated format
|
||||
segments_concatenated = self._concatenate_speakers(segments)
|
||||
|
||||
# Adjust segment boundaries based on punctuation
|
||||
for ind, segment in enumerate(segments_concatenated):
|
||||
for i, punctuation_token in enumerate(punctuation_tokens):
|
||||
if punctuation_token.start > segment['end']:
|
||||
after_length = punctuation_token.start - segment['end']
|
||||
before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
|
||||
|
||||
if before_length > after_length:
|
||||
segment['end'] = punctuation_token.start
|
||||
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
|
||||
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
|
||||
else:
|
||||
segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
|
||||
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
|
||||
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
|
||||
break
|
||||
|
||||
# Ensure non-overlapping tokens
|
||||
last_end = 0.0
|
||||
for token in tokens:
|
||||
start = max(last_end + 0.01, token.start)
|
||||
token.start = start
|
||||
token.end = max(start, token.end)
|
||||
last_end = token.end
|
||||
|
||||
# Assign speakers based on adjusted segments
|
||||
ind_last_speaker = 0
|
||||
for segment in segments_concatenated:
|
||||
for i, token in enumerate(tokens[ind_last_speaker:]):
|
||||
if token.end <= segment['end']:
|
||||
token.speaker = segment['speaker']
|
||||
ind_last_speaker = i + 1
|
||||
elif token.start > segment['end']:
|
||||
break
|
||||
|
||||
return tokens
|
||||
|
||||
def _concatenate_speakers(self, segments: List[SpeakerSegment]) -> List[dict]:
|
||||
"""
|
||||
Concatenate consecutive segments from the same speaker.
|
||||
|
||||
Args:
|
||||
segments: List of speaker segments
|
||||
|
||||
Returns:
|
||||
List of concatenated speaker segments
|
||||
"""
|
||||
if not segments:
|
||||
return []
|
||||
|
||||
segments_concatenated = [{"speaker": segments[0].speaker + 1, "begin": segments[0].start, "end": segments[0].end}]
|
||||
|
||||
for segment in segments[1:]:
|
||||
speaker = segment.speaker + 1
|
||||
if segments_concatenated[-1]['speaker'] != speaker:
|
||||
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
|
||||
else:
|
||||
segments_concatenated[-1]['end'] = segment.end
|
||||
|
||||
return segments_concatenated
|
||||
|
||||
def get_segments(self) -> List[SpeakerSegment]:
|
||||
"""Get a copy of the current speaker segments."""
|
||||
with self.segment_lock:
|
||||
return self.speaker_segments.copy()
|
||||
|
||||
def clear_old_segments(self, older_than: float = 30.0):
|
||||
"""Clear segments older than the specified time."""
|
||||
with self.segment_lock:
|
||||
current_time = self.processed_time
|
||||
self.speaker_segments = [
|
||||
segment for segment in self.speaker_segments
|
||||
if current_time - segment.end < older_than
|
||||
]
|
||||
logger.debug(f"Cleared old segments, remaining: {len(self.speaker_segments)}")
|
||||
|
||||
def close(self):
|
||||
"""Close the diarization system and clean up resources."""
|
||||
logger.info("Closing SortformerDiarization")
|
||||
with self.segment_lock:
|
||||
self.speaker_segments.clear()
|
||||
|
||||
if self.debug:
|
||||
concatenated_audio = np.concatenate(self.audio_buffer)
|
||||
audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
|
||||
with wave.open("diarization_audio.wav", "wb") as wav_file:
|
||||
wav_file.setnchannels(1) # mono audio
|
||||
wav_file.setsampwidth(2) # 2 bytes per sample (int16)
|
||||
wav_file.setframerate(self.sample_rate)
|
||||
wav_file.writeframes(audio_data_int16.tobytes())
|
||||
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
|
||||
|
||||
|
||||
def extract_number(s: str) -> int:
|
||||
"""Extract number from speaker string (compatibility function)."""
|
||||
import re
|
||||
m = re.search(r'\d+', s)
|
||||
return int(m.group()) if m else 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import asyncio
|
||||
import librosa
|
||||
|
||||
async def main():
|
||||
"""TEST ONLY."""
|
||||
an4_audio = 'audio_test.mp3'
|
||||
signal, sr = librosa.load(an4_audio, sr=16000)
|
||||
signal = signal[:16000*30]
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("ground truth:")
|
||||
print("Speaker 0: 0:00 - 0:09")
|
||||
print("Speaker 1: 0:09 - 0:19")
|
||||
print("Speaker 2: 0:19 - 0:25")
|
||||
print("Speaker 0: 0:25 - 0:30")
|
||||
print("=" * 50)
|
||||
|
||||
shared_model = SortformerDiarization()
diarization = SortformerDiarizationOnline(shared_model=shared_model, sample_rate=16000)
|
||||
chunk_size = 1600
|
||||
|
||||
for i in range(0, len(signal), chunk_size):
|
||||
chunk = signal[i:i+chunk_size]
|
||||
await diarization.diarize(chunk)
|
||||
print(f"Processed chunk {i // chunk_size + 1}")
|
||||
|
||||
segments = diarization.get_segments()
|
||||
print("\nDiarization results:")
|
||||
for segment in segments:
|
||||
print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
|
||||
|
||||
asyncio.run(main())
|
||||
205
whisperlivekit/diarization/sortformer_backend_offline.py
Normal file
@@ -0,0 +1,205 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from nemo.collections.asr.models import SortformerEncLabelModel
|
||||
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
|
||||
import librosa
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def load_model():
|
||||
|
||||
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
|
||||
diar_model.eval()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
diar_model.to(torch.device("cuda"))
|
||||
|
||||
# We target a 1 second lag for the moment; chunk_len could be reduced.
|
||||
diar_model.sortformer_modules.chunk_len = 10
|
||||
diar_model.sortformer_modules.subsampling_factor = 10 #8 would be better ideally
|
||||
|
||||
diar_model.sortformer_modules.chunk_right_context = 0  # no right context
|
||||
diar_model.sortformer_modules.chunk_left_context = 10  # large, so it compensates for the lack of padding later
|
||||
|
||||
diar_model.sortformer_modules.spkcache_len = 188
|
||||
diar_model.sortformer_modules.fifo_len = 188
|
||||
diar_model.sortformer_modules.spkcache_update_period = 144
|
||||
diar_model.sortformer_modules.log = False
|
||||
diar_model.sortformer_modules._check_streaming_parameters()
|
||||
|
||||
|
||||
audio2mel = AudioToMelSpectrogramPreprocessor(
|
||||
window_size= 0.025,
|
||||
normalize="NA",
|
||||
n_fft=512,
|
||||
features=128,
|
||||
pad_to=0)  # pad_to=16 works better than 0: on test audio, a third speaker is detected for 1 second with pad_to=0. To mitigate this, increase the left context to 10.
|
||||
|
||||
return diar_model, audio2mel
|
||||
|
||||
diar_model, audio2mel = load_model()
|
||||
|
||||
class StreamingSortformerState:
|
||||
"""
|
||||
This class creates a class instance that will be used to store the state of the
|
||||
streaming Sortformer model.
|
||||
|
||||
Attributes:
|
||||
spkcache (torch.Tensor): Speaker cache to store embeddings from start
|
||||
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
|
||||
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
|
||||
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
|
||||
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
|
||||
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
|
||||
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
|
||||
mean_sil_emb (torch.Tensor): Mean silence embedding
|
||||
n_sil_frames (torch.Tensor): Number of silence frames
|
||||
"""
|
||||
|
||||
spkcache = None # Speaker cache to store embeddings from start
|
||||
spkcache_lengths = None  # lengths of the speaker cache
|
||||
spkcache_preds = None # speaker cache predictions
|
||||
fifo = None # to save the embedding from the latest chunks
|
||||
fifo_lengths = None
|
||||
fifo_preds = None
|
||||
spk_perm = None
|
||||
mean_sil_emb = None
|
||||
n_sil_frames = None
|
||||
|
||||
|
||||
def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None):
|
||||
"""
|
||||
Initializes StreamingSortformerState with empty tensors or zero-valued tensors.
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size for tensors in streaming state
|
||||
async_streaming (bool): True for asynchronous update, False for synchronous update
|
||||
device (torch.device): Device for tensors in streaming state
|
||||
|
||||
Returns:
|
||||
streaming_state (SortformerStreamingState): initialized streaming state
|
||||
"""
|
||||
streaming_state = StreamingSortformerState()
|
||||
if async_streaming:
|
||||
streaming_state.spkcache = torch.zeros((batch_size, self.spkcache_len, self.fc_d_model), device=device)
|
||||
streaming_state.spkcache_preds = torch.zeros((batch_size, self.spkcache_len, self.n_spk), device=device)
|
||||
streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
streaming_state.fifo = torch.zeros((batch_size, self.fifo_len, self.fc_d_model), device=device)
|
||||
streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
else:
|
||||
streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
|
||||
streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
|
||||
streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device)
|
||||
streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
return streaming_state
|
||||
|
||||
|
||||
def process_diarization(chunks):
|
||||
"""
|
||||
What audio2mel.get_features does to each chunk:
|
||||
1. Preprocessing: Applies dithering and pre-emphasis (high-pass filter) if enabled
|
||||
2. STFT: Computes the Short-Time Fourier Transform using:
|
||||
- a window of window_size=0.025 s --> 400 samples per window
|
||||
- a hop of n_window_stride=0.01 s --> a new window every 160 samples
|
||||
3. Magnitude Calculation: Converts complex STFT output to magnitude spectrogram
|
||||
4. Mel Conversion: Applies Mel filterbanks (128 filters in this case) to get Mel spectrogram
|
||||
5. Logarithm: Takes the log of the Mel spectrogram (if `log=True`)
|
||||
6. Normalization: Skips normalization since `normalize="NA"`
|
||||
7. Padding: Pads the time dimension to a multiple of `pad_to` (default 16)
|
||||
"""
|
||||
previous_chunk = None
|
||||
l_chunk_feat_seq_t = []
|
||||
for chunk in chunks:
|
||||
audio_signal_chunk = torch.tensor(chunk).unsqueeze(0).to(diar_model.device)
|
||||
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(diar_model.device)
|
||||
processed_signal_chunk, processed_signal_length_chunk = audio2mel.get_features(audio_signal_chunk, audio_signal_length_chunk)
|
||||
if previous_chunk is not None:
|
||||
to_add = previous_chunk[:, :, -99:]
|
||||
total = torch.concat([to_add, processed_signal_chunk], dim=2)
|
||||
else:
|
||||
total = processed_signal_chunk
|
||||
previous_chunk = processed_signal_chunk
|
||||
l_chunk_feat_seq_t.append(torch.transpose(total, 1, 2))
|
||||
|
||||
batch_size = 1
|
||||
streaming_state = init_streaming_state(diar_model.sortformer_modules,
|
||||
batch_size = batch_size,
|
||||
async_streaming = True,
|
||||
device = diar_model.device
|
||||
)
|
||||
total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)
|
||||
|
||||
chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
|
||||
|
||||
l_speakers = [
|
||||
{'start_time': 0,
|
||||
'end_time': 0,
|
||||
'speaker': 0
|
||||
}
|
||||
]
|
||||
len_prediction = None
|
||||
left_offset = 0
|
||||
right_offset = 8
|
||||
for i, chunk_feat_seq_t in enumerate(l_chunk_feat_seq_t):
|
||||
with torch.inference_mode():
|
||||
streaming_state, total_preds = diar_model.forward_streaming_step(
|
||||
processed_signal=chunk_feat_seq_t,
|
||||
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
|
||||
streaming_state=streaming_state,
|
||||
total_preds=total_preds,
|
||||
left_offset=left_offset,
|
||||
right_offset=right_offset,
|
||||
)
|
||||
left_offset = 8
|
||||
preds_np = total_preds[0].cpu().numpy()
|
||||
active_speakers = np.argmax(preds_np, axis=1)
|
||||
if len_prediction is None:
|
||||
len_prediction = len(active_speakers) # number of prediction frames produced for a single chunk
|
||||
frame_duration = chunk_duration_seconds / len_prediction
|
||||
active_speakers = active_speakers[-len_prediction:]
|
||||
for idx, spk in enumerate(active_speakers):
|
||||
if spk != l_speakers[-1]['speaker']:
|
||||
l_speakers.append(
|
||||
{'start_time': (i * chunk_duration_seconds + idx * frame_duration),
|
||||
'end_time': (i * chunk_duration_seconds + (idx + 1) * frame_duration),
|
||||
'speaker': spk
|
||||
})
|
||||
else:
|
||||
l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration
|
||||
|
||||
|
||||
"""
|
||||
Should print
|
||||
[{'start_time': 0, 'end_time': 8.72, 'speaker': 0},
|
||||
{'start_time': 8.72, 'end_time': 18.88, 'speaker': 1},
|
||||
{'start_time': 18.88, 'end_time': 24.96, 'speaker': 2},
|
||||
{'start_time': 24.96, 'end_time': 31.68, 'speaker': 0}]
|
||||
"""
|
||||
for speaker in l_speakers:
|
||||
print(f"Speaker {speaker['speaker']}: {speaker['start_time']:.2f}s - {speaker['end_time']:.2f}s")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
an4_audio = 'audio_test.mp3'
|
||||
signal, sr = librosa.load(an4_audio, sr=16000)
|
||||
signal = signal[:16000*30]
|
||||
# signal = signal[:-(len(signal)%16000)]
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("Expected ground truth:")
|
||||
print("Speaker 0: 0:00 - 0:09")
|
||||
print("Speaker 1: 0:09 - 0:19")
|
||||
print("Speaker 2: 0:19 - 0:25")
|
||||
print("Speaker 0: 0:25 - 0:30")
|
||||
print("=" * 50)
|
||||
|
||||
chunk_size = 16000 # 1 second
|
||||
chunks = []
|
||||
for i in range(0, len(signal), chunk_size):
|
||||
chunk = signal[i:i+chunk_size]
|
||||
chunks.append(chunk)
|
||||
|
||||
process_diarization(chunks)
|
||||
@@ -1,32 +0,0 @@
|
||||
import os
|
||||
import requests
|
||||
import inspect
|
||||
|
||||
def get_module_path():
|
||||
return os.path.dirname(inspect.getfile(inspect.currentframe()))
|
||||
|
||||
GITHUB_API_URL = "https://api.github.com/repos/ufal/SimulStreaming/contents/simul_whisper/whisper"
|
||||
RAW_BASE_URL = "https://raw.githubusercontent.com/ufal/SimulStreaming/main/simul_whisper/whisper"
|
||||
TARGET_DIR = os.path.join(get_module_path(), "simul_whisper", "whisper")
|
||||
|
||||
def download_files_from_github(api_url, local_dir):
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
response = requests.get(api_url)
|
||||
response.raise_for_status()
|
||||
items = response.json()
|
||||
for item in items:
|
||||
if item['type'] == 'file':
|
||||
download_url = item['download_url']
|
||||
file_name = item['name']
|
||||
file_response = requests.get(download_url)
|
||||
file_response.raise_for_status()
|
||||
with open(os.path.join(local_dir, file_name), 'wb') as f:
|
||||
f.write(file_response.content)
|
||||
elif item['type'] == 'dir':
|
||||
# Recursive call for subdirectories
|
||||
download_files_from_github(item['url'], os.path.join(local_dir, item['name']))
|
||||
|
||||
def download_simulstreaming_backend():
|
||||
print(f"Downloading files into {TARGET_DIR} ...")
|
||||
download_files_from_github(GITHUB_API_URL, TARGET_DIR)
|
||||
print("✅ Download of SimulStreaming backend files completed successfully.")
|
||||
@@ -143,7 +143,7 @@ class FFmpegManager:
|
||||
try:
|
||||
data = await asyncio.wait_for(
|
||||
self.process.stdout.read(size),
|
||||
timeout=5.0
|
||||
timeout=20.0
|
||||
)
|
||||
return data
|
||||
except asyncio.TimeoutError:
|
||||
|
||||
@@ -58,12 +58,26 @@ def parse_args():
|
||||
help="Hugging Face model ID for pyannote.audio embedding model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--diarization-backend",
|
||||
type=str,
|
||||
default="sortformer",
|
||||
choices=["sortformer", "diart"],
|
||||
help="The diarization backend to use.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-transcription",
|
||||
action="store_true",
|
||||
help="Disable transcription to only see live diarization results.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--disable-punctuation-split",
|
||||
action="store_true",
|
||||
help="Disable the split parameter.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--min-chunk-size",
|
||||
type=float,
|
||||
@@ -74,7 +88,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="tiny",
|
||||
default="small",
|
||||
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
|
||||
)
|
||||
|
||||
@@ -104,18 +118,27 @@ def parse_args():
|
||||
choices=["transcribe", "translate"],
|
||||
help="Transcribe or translate.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--target-language",
|
||||
type=str,
|
||||
default="",
|
||||
dest="target_language",
|
||||
help="Target language for translation. Not functional yet.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="faster-whisper",
|
||||
default="simulstreaming",
|
||||
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api", "simulstreaming"],
|
||||
help="Load only this backend for Whisper processing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vac",
|
||||
"--no-vac",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use VAC = voice activity controller. Recommended. Requires torch.",
|
||||
help="Disable VAC = voice activity controller.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
|
||||
@@ -150,9 +173,22 @@ def parse_args():
|
||||
)
|
||||
parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
|
||||
parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
|
||||
|
||||
parser.add_argument(
|
||||
"--pcm-input",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="If set, raw PCM (s16le) data is expected as input and FFmpeg will be bypassed."
|
||||
)
|
||||
# SimulStreaming-specific arguments
|
||||
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--disable-fast-encoder",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="disable_fast_encoder",
|
||||
help="Disable Faster Whisper or MLX Whisper backends for encoding (if installed). Slower but helpful when GPU memory is limited",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--frame-threshold",
|
||||
@@ -242,6 +278,14 @@ def parse_args():
|
||||
dest="model_path",
|
||||
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--preload-model-count",
|
||||
type=int,
|
||||
default=1,
|
||||
dest="preload_model_count",
|
||||
help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
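# Example invocation using the new options defined above (the entry-point name is an assumption;
# use whatever command you normally launch the server with):
#   <server-entrypoint> --backend simulstreaming --model small --diarization-backend sortformer --pcm-input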
|
||||
|
||||
|
||||
110
whisperlivekit/remove_silences.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
import re
|
||||
|
||||
MIN_SILENCE_DURATION = 4 #in seconds
|
||||
END_SILENCE_DURATION = 8 #in seconds. Keep it large to avoid false positives when the model lag is significant
|
||||
END_SILENCE_DURATION_VAC = 3 #VAC is good at detecting silences, but we want to skip the smallest silences
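# Rough illustration of how these thresholds are used below (hypothetical timings):
#   - a run of "[BLANK_AUDIO]" tokens covering 12.0 s - 17.0 s (>= MIN_SILENCE_DURATION) is collapsed
#     by blank_to_silence() into a single speaker=-2 silence token;
#   - if no token has ended for 8 s (END_SILENCE_DURATION), or for 3 s while the VAC reports silence,
#     ends_with_silence() appends or extends a trailing speaker=-2 token.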
|
||||
|
||||
def blank_to_silence(tokens):
|
||||
full_string = ''.join([t.text for t in tokens])
|
||||
patterns = [re.compile(r'(?:\s*\[BLANK_AUDIO\]\s*)+'), re.compile(r'(?:\s*\[typing\]\s*)+')]
|
||||
matches = []
|
||||
for pattern in patterns:
|
||||
for m in pattern.finditer(full_string):
|
||||
matches.append({
|
||||
'start': m.start(),
|
||||
'end': m.end()
|
||||
})
|
||||
if matches:
|
||||
# cleaned = pattern.sub(' ', full_string).strip()
|
||||
# print("Cleaned:", cleaned)
|
||||
cumulated_len = 0
|
||||
silence_token = None
|
||||
cleaned_tokens = []
|
||||
for token in tokens:
|
||||
if matches:
|
||||
start = cumulated_len
|
||||
end = cumulated_len + len(token.text)
|
||||
cumulated_len = end
|
||||
if start >= matches[0]['start'] and end <= matches[0]['end']:
|
||||
if silence_token: #previous token was already silence
|
||||
silence_token.start = min(silence_token.start, token.start)
|
||||
silence_token.end = max(silence_token.end, token.end)
|
||||
else: #new silence
|
||||
silence_token = ASRToken(
|
||||
start=token.start,
|
||||
end=token.end,
|
||||
speaker=-2,
|
||||
probability=0.95
|
||||
)
|
||||
else:
|
||||
if silence_token: #there was silence but no more
|
||||
if silence_token.end - silence_token.start >= MIN_SILENCE_DURATION:
|
||||
cleaned_tokens.append(
|
||||
silence_token
|
||||
)
|
||||
silence_token = None
|
||||
matches.pop(0)
|
||||
cleaned_tokens.append(token)
|
||||
# print(cleaned_tokens)
|
||||
return cleaned_tokens
|
||||
return tokens
|
||||
|
||||
def no_token_to_silence(tokens):
|
||||
new_tokens = []
|
||||
silence_token = None
|
||||
for token in tokens:
|
||||
if token.speaker == -2:
|
||||
if new_tokens and new_tokens[-1].speaker == -2: #if token is silence and previous one too
|
||||
new_tokens[-1].end = token.end
|
||||
else:
|
||||
new_tokens.append(token)
|
||||
|
||||
last_end = new_tokens[-1].end if new_tokens else 0.0
|
||||
if token.start - last_end >= MIN_SILENCE_DURATION: #if token is not silence but important gap
|
||||
if new_tokens and new_tokens[-1].speaker == -2:
|
||||
new_tokens[-1].end = token.start
|
||||
else:
|
||||
silence_token = ASRToken(
|
||||
start=last_end,
|
||||
end=token.start,
|
||||
speaker=-2,
|
||||
probability=0.95
|
||||
)
|
||||
new_tokens.append(silence_token)
|
||||
|
||||
if token.speaker != -2:
|
||||
new_tokens.append(token)
|
||||
return new_tokens
|
||||
|
||||
def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
|
||||
if not tokens:
|
||||
return [], buffer_transcription, buffer_diarization
|
||||
last_token = tokens[-1]
|
||||
if tokens and current_time and (
|
||||
current_time - last_token.end >= END_SILENCE_DURATION
|
||||
or
|
||||
(current_time - last_token.end >= 3 and vac_detected_silence)
|
||||
):
|
||||
if last_token.speaker == -2:
|
||||
last_token.end = current_time
|
||||
else:
|
||||
tokens.append(
|
||||
ASRToken(
|
||||
start=tokens[-1].end,
|
||||
end=current_time,
|
||||
speaker=-2,
|
||||
probability=0.95
|
||||
)
|
||||
)
|
||||
buffer_transcription = "" # for whisperstreaming backend, we should probably validate the buffer has because of the silence
|
||||
buffer_diarization = ""
|
||||
return tokens, buffer_transcription, buffer_diarization
|
||||
|
||||
|
||||
def handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
|
||||
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
|
||||
tokens = no_token_to_silence(tokens)
|
||||
tokens, buffer_transcription, buffer_diarization = ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence)
|
||||
return tokens, buffer_transcription, buffer_diarization
|
||||
|
||||
137
whisperlivekit/results_formater.py
Normal file
@@ -0,0 +1,137 @@
|
||||
|
||||
import logging
|
||||
from whisperlivekit.remove_silences import handle_silences
|
||||
from whisperlivekit.timed_objects import Line, format_time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||
CHECK_AROUND = 4
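# number of neighbouring tokens inspected when deciding where a speaker change should be placed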
|
||||
|
||||
def is_punctuation(token):
|
||||
if token.text.strip() in PUNCTUATION_MARKS:
|
||||
return True
|
||||
return False
|
||||
|
||||
def next_punctuation_change(i, tokens):
|
||||
for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)):
|
||||
if is_punctuation(tokens[ind]):
|
||||
return ind
|
||||
return None
|
||||
|
||||
def next_speaker_change(i, tokens, speaker):
|
||||
for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1):
|
||||
token = tokens[ind]
|
||||
if is_punctuation(token):
|
||||
break
|
||||
if token.speaker != speaker:
|
||||
return ind, token.speaker
|
||||
return None, speaker
|
||||
|
||||
def new_line(
|
||||
token,
|
||||
speaker,
|
||||
debug_info = ""
|
||||
):
|
||||
return Line(
|
||||
speaker = speaker,
|
||||
text = token.text + debug_info,
|
||||
start = token.start,
|
||||
end = token.end,
|
||||
)
|
||||
|
||||
def append_token_to_last_line(lines, sep, token, debug_info):
|
||||
if token.text:
|
||||
lines[-1].text += sep + token.text + debug_info
|
||||
lines[-1].end = token.end
|
||||
|
||||
def format_output(state, silence, current_time, args, debug):
|
||||
diarization = args.diarization
|
||||
disable_punctuation_split = args.disable_punctuation_split
|
||||
tokens = state["tokens"]
|
||||
translated_segments = state["translated_segments"] # Here we will attribute the speakers only based on the timestamps of the segments
|
||||
buffer_transcription = state["buffer_transcription"]
|
||||
buffer_diarization = state["buffer_diarization"]
|
||||
end_attributed_speaker = state["end_attributed_speaker"]
|
||||
sep = state["sep"]
|
||||
|
||||
previous_speaker = -1
|
||||
lines = []
|
||||
undiarized_text = []
|
||||
tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, silence)
|
||||
last_punctuation = None
|
||||
for i, token in enumerate(tokens):
|
||||
speaker = token.speaker
|
||||
if not diarization and speaker == -1: #Speaker -1 means not attributed by diarization. In the frontend, it should appear as 'Speaker 1'
|
||||
speaker = 1
|
||||
if diarization and not tokens[-1].speaker == -2:
|
||||
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
|
||||
undiarized_text.append(token.text)
|
||||
continue
|
||||
elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
|
||||
speaker = previous_speaker
|
||||
debug_info = ""
|
||||
if debug:
|
||||
debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
|
||||
|
||||
if not lines:
|
||||
lines.append(new_line(token, speaker, debug_info = ""))
|
||||
continue
|
||||
else:
|
||||
previous_speaker = lines[-1].speaker
|
||||
|
||||
if is_punctuation(token):
|
||||
last_punctuation = i
|
||||
|
||||
|
||||
if last_punctuation == i-1:
|
||||
if speaker != previous_speaker:
|
||||
# perfect, diarization perfectly aligned
|
||||
lines.append(new_line(token, speaker, debug_info = ""))
|
||||
last_punctuation, next_punctuation = None, None
|
||||
continue
|
||||
|
||||
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
|
||||
if speaker_change_pos:
|
||||
# Corrects delay:
|
||||
# That was the idea. Okay haha |SPLIT SPEAKER| that's a good one
|
||||
# should become:
|
||||
# That was the idea. |SPLIT SPEAKER| Okay haha that's a good one
|
||||
lines.append(new_line(token, new_speaker, debug_info = ""))
|
||||
else:
|
||||
# No speaker change to come
|
||||
append_token_to_last_line(lines, sep, token, debug_info)
|
||||
continue
|
||||
|
||||
|
||||
if speaker != previous_speaker:
|
||||
if speaker == -2 or previous_speaker == -2: #silences can happen anytime
|
||||
lines.append(new_line(token, speaker, debug_info = ""))
|
||||
continue
|
||||
elif next_punctuation_change(i, tokens):
|
||||
# Corrects advance:
|
||||
# Are you |SPLIT SPEAKER| okay? yeah, sure. Absolutely
|
||||
# should become:
|
||||
# Are you okay? |SPLIT SPEAKER| yeah, sure. Absolutely
|
||||
append_token_to_last_line(lines, sep, token, debug_info)
|
||||
continue
|
||||
else: #we would create a new speaker line here, but that's not ideal: we are not sure about the split, so we prefer to append to the previous line
|
||||
if disable_punctuation_split:
|
||||
lines.append(new_line(token, speaker, debug_info = ""))
|
||||
continue
|
||||
pass
|
||||
|
||||
append_token_to_last_line(lines, sep, token, debug_info)
|
||||
if lines and translated_segments:
|
||||
cts_idx = 0 # current_translated_segment_idx
|
||||
for line in lines:
|
||||
while cts_idx < len(translated_segments):
|
||||
ts = translated_segments[cts_idx]
|
||||
if ts.start and ts.start >= line.start and ts.end <= line.end:
|
||||
line.translation += ts.text + ' '
|
||||
cts_idx += 1
|
||||
else:
|
||||
break
|
||||
return lines, undiarized_text, buffer_transcription, ''
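# Minimal usage sketch (hypothetical values; in practice the state dict and args come from the
# processing pipeline and parse_args respectively):
#   state = {"tokens": tokens, "translated_segments": [], "buffer_transcription": "",
#            "buffer_diarization": "", "end_attributed_speaker": 0.0, "sep": " "}
#   lines, undiarized, buffer, _ = format_output(state, silence=False, current_time=12.3,
#                                                args=args, debug=False)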
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
from .backend import SimulStreamingASR, SimulStreamingOnlineProcessor
|
||||
|
||||
__all__ = [
|
||||
"SimulStreamingASR",
|
||||
"SimulStreamingOnlineProcessor",
|
||||
]
|
||||
|
||||
365
whisperlivekit/simul_whisper/backend.py
Normal file
@@ -0,0 +1,365 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
import logging
|
||||
from typing import List, Tuple, Optional
|
||||
import logging
|
||||
import platform
|
||||
from whisperlivekit.timed_objects import ASRToken, Transcript
|
||||
from whisperlivekit.warmup import load_file
|
||||
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
|
||||
from .whisper import load_model, tokenizer
|
||||
from .whisper.audio import TOKENS_PER_SECOND
|
||||
import os
|
||||
import gc
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import torch
|
||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
|
||||
from whisperlivekit.simul_whisper.whisper import tokenizer
|
||||
|
||||
try:
|
||||
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
|
||||
HAS_MLX_WHISPER = True
|
||||
except ImportError:
|
||||
if platform.system() == "Darwin" and platform.machine() == "arm64":
|
||||
print('MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper')
|
||||
HAS_MLX_WHISPER = False
|
||||
if HAS_MLX_WHISPER:
|
||||
HAS_FASTER_WHISPER = False
|
||||
else:
|
||||
try:
|
||||
from faster_whisper import WhisperModel
|
||||
HAS_FASTER_WHISPER = True
|
||||
except ImportError:
|
||||
HAS_FASTER_WHISPER = False
|
||||
|
||||
|
||||
# TOO_MANY_REPETITIONS = 3
|
||||
|
||||
class SimulStreamingOnlineProcessor:
|
||||
SAMPLING_RATE = 16000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
asr,
|
||||
logfile=sys.stderr,
|
||||
warmup_file=None
|
||||
):
|
||||
self.asr = asr
|
||||
self.logfile = logfile
|
||||
self.end = 0.0
|
||||
self.global_time_offset = 0.0
|
||||
|
||||
self.committed: List[ASRToken] = []
|
||||
self.last_result_tokens: List[ASRToken] = []
|
||||
self.load_new_backend()
|
||||
|
||||
#can be moved
|
||||
if asr.tokenizer:
|
||||
self.model.tokenizer = asr.tokenizer
|
||||
|
||||
def load_new_backend(self):
|
||||
model = self.asr.get_new_model_instance()
|
||||
self.model = PaddedAlignAttWhisper(
|
||||
cfg=self.asr.cfg,
|
||||
loaded_model=model,
|
||||
mlx_encoder=self.asr.mlx_encoder,
|
||||
fw_encoder=self.asr.fw_encoder,
|
||||
)
|
||||
|
||||
def insert_silence(self, silence_duration, offset):
|
||||
"""
|
||||
If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame
|
||||
"""
|
||||
if silence_duration < 5:
|
||||
gap_silence = torch.zeros(int(16000*silence_duration))
|
||||
self.model.insert_audio(gap_silence)
|
||||
# self.global_time_offset += silence_duration
|
||||
else:
|
||||
self.process_iter(is_last=True) #we want to totally process what remains in the buffer.
|
||||
self.model.refresh_segment(complete=True)
|
||||
self.global_time_offset = silence_duration + offset
|
||||
|
||||
|
||||
|
||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
|
||||
"""Append an audio chunk to be processed by SimulStreaming."""
|
||||
|
||||
# Convert numpy array to torch tensor
|
||||
audio_tensor = torch.from_numpy(audio).float()
|
||||
self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend.
|
||||
self.model.insert_audio(audio_tensor)
|
||||
|
||||
def get_buffer(self):
|
||||
return Transcript(
|
||||
start=None,
|
||||
end=None,
|
||||
text='',
|
||||
probability=None
|
||||
)
|
||||
|
||||
def timestamped_text(self, tokens, generation):
|
||||
"""
|
||||
generate timestamped text from tokens and generation data.
|
||||
|
||||
args:
|
||||
tokens: List of tokens to process
|
||||
generation: Dictionary containing generation progress and optionally results
|
||||
|
||||
returns:
|
||||
List of tuples containing (start_time, end_time, word) for each word
|
||||
"""
|
||||
FRAME_DURATION = 0.02
|
||||
if "result" in generation:
|
||||
split_words = generation["result"]["split_words"]
|
||||
split_tokens = generation["result"]["split_tokens"]
|
||||
else:
|
||||
split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
|
||||
progress = generation["progress"]
|
||||
frames = [p["most_attended_frames"][0] for p in progress]
|
||||
absolute_timestamps = [p["absolute_timestamps"][0] for p in progress]
|
||||
tokens_queue = tokens.copy()
|
||||
timestamped_words = []
|
||||
|
||||
for word, word_tokens in zip(split_words, split_tokens):
|
||||
# start_frame = None
|
||||
# end_frame = None
|
||||
for expected_token in word_tokens:
|
||||
if not tokens_queue or not frames:
|
||||
raise ValueError(f"Insufficient tokens or frames for word '{word}'")
|
||||
|
||||
actual_token = tokens_queue.pop(0)
|
||||
current_frame = frames.pop(0)
|
||||
current_timestamp = absolute_timestamps.pop(0)
|
||||
if actual_token != expected_token:
|
||||
raise ValueError(
|
||||
f"Token mismatch: expected '{expected_token}', "
|
||||
f"got '{actual_token}' at frame {current_frame}"
|
||||
)
|
||||
# if start_frame is None:
|
||||
# start_frame = current_frame
|
||||
# end_frame = current_frame
|
||||
# start_time = start_frame * FRAME_DURATION
|
||||
# end_time = end_frame * FRAME_DURATION
|
||||
start_time = current_timestamp
|
||||
end_time = current_timestamp + 0.1
|
||||
timestamp_entry = (start_time, end_time, word)
|
||||
timestamped_words.append(timestamp_entry)
|
||||
logger.debug(f"TS-WORD:\t{start_time:.2f}\t{end_time:.2f}\t{word}")
|
||||
return timestamped_words
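# Shape of the returned list (illustrative values only):
#   [(12.00, 12.10, " Hello"), (12.10, 12.20, " world"), ...]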
|
||||
|
||||
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
||||
"""
|
||||
Process accumulated audio chunks using SimulStreaming.
|
||||
|
||||
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
||||
"""
|
||||
try:
|
||||
tokens, generation_progress = self.model.infer(is_last=is_last)
|
||||
ts_words = self.timestamped_text(tokens, generation_progress)
|
||||
|
||||
new_tokens = []
|
||||
for ts_word in ts_words:
|
||||
|
||||
start, end, word = ts_word
|
||||
token = ASRToken(
|
||||
start=start,
|
||||
end=end,
|
||||
text=word,
|
||||
probability=0.95 # fake prob. Maybe we can extract it from the model?
|
||||
).with_offset(
|
||||
self.global_time_offset
|
||||
)
|
||||
new_tokens.append(token)
|
||||
|
||||
# identical_tokens = 0
|
||||
# n_new_tokens = len(new_tokens)
|
||||
# if n_new_tokens:
|
||||
|
||||
self.committed.extend(new_tokens)
|
||||
|
||||
# if token in self.committed:
|
||||
# pos = len(self.committed) - 1 - self.committed[::-1].index(token)
|
||||
# if pos:
|
||||
# for i in range(len(self.committed) - n_new_tokens, -1, -n_new_tokens):
|
||||
# commited_segment = self.committed[i:i+n_new_tokens]
|
||||
# if commited_segment == new_tokens:
|
||||
# identical_segments +=1
|
||||
# if identical_tokens >= TOO_MANY_REPETITIONS:
|
||||
# logger.warning('Too many repetition, model is stuck. Load a new one')
|
||||
# self.committed = self.committed[:i]
|
||||
# self.load_new_backend()
|
||||
# return [], self.end
|
||||
|
||||
# pos = self.committed.rindex(token)
|
||||
|
||||
|
||||
|
||||
return new_tokens, self.end
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"SimulStreaming processing error: {e}")
|
||||
return [], self.end
|
||||
|
||||
def warmup(self, audio, init_prompt=""):
|
||||
"""Warmup the SimulStreaming model."""
|
||||
try:
|
||||
self.model.insert_audio(audio)
|
||||
self.model.infer(True)
|
||||
self.model.refresh_segment(complete=True)
|
||||
logger.info("SimulStreaming model warmed up successfully")
|
||||
except Exception as e:
|
||||
logger.exception(f"SimulStreaming warmup failed: {e}")
|
||||
|
||||
def __del__(self):
|
||||
# free the model and add a new model to stack.
|
||||
# del self.model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
# self.asr.new_model_to_stack()
|
||||
self.model.remove_hooks()
|
||||
|
||||
class SimulStreamingASR():
|
||||
"""SimulStreaming backend with AlignAtt policy."""
|
||||
sep = ""
|
||||
|
||||
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
|
||||
logger.warning(SIMULSTREAMING_LICENSE)
|
||||
self.logfile = logfile
|
||||
self.transcribe_kargs = {}
|
||||
self.original_language = lan
|
||||
|
||||
self.model_path = kwargs.get('model_path', './large-v3.pt')
|
||||
self.frame_threshold = kwargs.get('frame_threshold', 25)
|
||||
self.audio_max_len = kwargs.get('audio_max_len', 20.0)
|
||||
self.audio_min_len = kwargs.get('audio_min_len', 0.0)
|
||||
self.segment_length = kwargs.get('segment_length', 0.5)
|
||||
self.beams = kwargs.get('beams', 1)
|
||||
self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
|
||||
self.task = kwargs.get('task', 'transcribe')
|
||||
self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
|
||||
self.never_fire = kwargs.get('never_fire', False)
|
||||
self.init_prompt = kwargs.get('init_prompt', None)
|
||||
self.static_init_prompt = kwargs.get('static_init_prompt', None)
|
||||
self.max_context_tokens = kwargs.get('max_context_tokens', None)
|
||||
self.warmup_file = kwargs.get('warmup_file', None)
|
||||
self.preload_model_count = kwargs.get('preload_model_count', 1)
|
||||
self.disable_fast_encoder = kwargs.get('disable_fast_encoder', False)
|
||||
self.fast_encoder = False
|
||||
if model_dir is not None:
|
||||
self.model_path = model_dir
|
||||
elif modelsize is not None:
|
||||
model_mapping = {
|
||||
'tiny': './tiny.pt',
|
||||
'base': './base.pt',
|
||||
'small': './small.pt',
|
||||
'medium': './medium.pt',
|
||||
'medium.en': './medium.en.pt',
|
||||
'large-v1': './large-v1.pt',
|
||||
'base.en': './base.en.pt',
|
||||
'small.en': './small.en.pt',
|
||||
'tiny.en': './tiny.en.pt',
|
||||
'large-v2': './large-v2.pt',
|
||||
'large-v3': './large-v3.pt',
|
||||
'large': './large-v3.pt'
|
||||
}
|
||||
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
|
||||
|
||||
self.cfg = AlignAttConfig(
|
||||
model_path=self.model_path,
|
||||
segment_length=self.segment_length,
|
||||
frame_threshold=self.frame_threshold,
|
||||
language=self.original_language,
|
||||
audio_max_len=self.audio_max_len,
|
||||
audio_min_len=self.audio_min_len,
|
||||
cif_ckpt_path=self.cif_ckpt_path,
|
||||
decoder_type="beam",
|
||||
beam_size=self.beams,
|
||||
task=self.task,
|
||||
never_fire=self.never_fire,
|
||||
init_prompt=self.init_prompt,
|
||||
max_context_tokens=self.max_context_tokens,
|
||||
static_init_prompt=self.static_init_prompt,
|
||||
)
|
||||
|
||||
# Set up tokenizer for translation if needed
|
||||
if self.task == "translate":
|
||||
self.tokenizer = self.set_translate_task()
|
||||
else:
|
||||
self.tokenizer = None
|
||||
|
||||
self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
|
||||
self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
|
||||
|
||||
self.mlx_encoder, self.fw_encoder = None, None
|
||||
if not self.disable_fast_encoder:
|
||||
if HAS_MLX_WHISPER:
|
||||
print('Simulstreaming will use MLX whisper for a faster encoder.')
|
||||
mlx_model_name = mlx_model_mapping[self.model_name]
|
||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name)
|
||||
self.fast_encoder = True
|
||||
elif HAS_FASTER_WHISPER:
|
||||
print('Simulstreaming will use Faster Whisper for the encoder.')
|
||||
self.fw_encoder = WhisperModel(
|
||||
self.model_name,
|
||||
device='auto',
|
||||
compute_type='auto',
|
||||
)
|
||||
self.fast_encoder = True
|
||||
|
||||
self.models = [self.load_model() for i in range(self.preload_model_count)]
|
||||
|
||||
|
||||
def load_model(self):
|
||||
whisper_model = load_model(name=self.model_name, download_root=self.model_path, decoder_only=self.fast_encoder)
|
||||
warmup_audio = load_file(self.warmup_file)
|
||||
if warmup_audio is not None:
|
||||
warmup_audio = torch.from_numpy(warmup_audio).float()
|
||||
if self.fast_encoder:
|
||||
temp_model = PaddedAlignAttWhisper(
|
||||
cfg=self.cfg,
|
||||
loaded_model=whisper_model,
|
||||
mlx_encoder=self.mlx_encoder,
|
||||
fw_encoder=self.fw_encoder,
|
||||
)
|
||||
temp_model.warmup(warmup_audio)
|
||||
temp_model.remove_hooks()
|
||||
else:
|
||||
# For standard encoder, use the original transcribe warmup
|
||||
warmup_audio = load_file(self.warmup_file)
|
||||
whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None)
|
||||
return whisper_model
|
||||
|
||||
def get_new_model_instance(self):
|
||||
"""
|
||||
SimulStreaming cannot share the same backend because it uses global forward hooks on the attention layers.
|
||||
Therefore, each user requires a separate model instance, which can be memory-intensive. To maintain speed, we preload the models into memory.
|
||||
"""
|
||||
if len(self.models) == 0:
|
||||
self.models.append(self.load_model())
|
||||
new_model = self.models.pop()
|
||||
return new_model
|
||||
# self.models[0]
|
||||
|
||||
def new_model_to_stack(self):
|
||||
self.models.append(self.load_model())
|
||||
|
||||
|
||||
def set_translate_task(self):
|
||||
"""Set up translation task."""
|
||||
if self.cfg.language == 'auto':
|
||||
raise Exception('Translation cannot be done with language = auto')
|
||||
return tokenizer.get_tokenizer(
|
||||
multilingual=True,
|
||||
language=self.cfg.language,
|
||||
num_languages=99,
|
||||
task="translate"
|
||||
)
|
||||
|
||||
def transcribe(self, audio):
|
||||
"""
|
||||
Warmup is done directly in load_model
|
||||
"""
|
||||
pass
|
||||
@@ -8,7 +8,7 @@ class SimulWhisperConfig:
|
||||
'''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
|
||||
model_path: str
|
||||
language: str = field(default="zh")
|
||||
nonspeech_prob: float = 1.0
|
||||
nonspeech_prob: float = 0.5
|
||||
audio_min_len: float = 1.0
|
||||
decoder_type: Literal["greedy","beam"] = "greedy"
|
||||
beam_size: int = 5
|
||||
@@ -24,6 +24,6 @@ class AlignAttConfig(SimulWhisperConfig):
|
||||
segment_length: float = field(default=1.0, metadata = {"help": "in second"})
|
||||
frame_threshold: int = 4
|
||||
rewind_threshold: int = 200
|
||||
audio_max_len: float = 30.0
|
||||
audio_max_len: float = 20.0
|
||||
cif_ckpt_path: str = ""
|
||||
never_fire: bool = False
|
||||
@@ -1,25 +0,0 @@
|
||||
📄 SimulStreaming (https://github.com/ufal/SimulStreaming) Licence
|
||||
|
||||
SimulStreaming is dual-licensed:
|
||||
|
||||
🔹 Non-Commercial Use
|
||||
|
||||
You may use SimulStreaming under the **PolyForm Noncommercial License 1.0.0** if you
|
||||
obtain the code through the GitHub repository. This license is **free of charge**
|
||||
and comes with **no obligations** for non-commercial users.
|
||||
|
||||
🔸 Commercial Use
|
||||
|
||||
Understanding who uses SimulStreaming commercially helps us improve and
|
||||
prioritize development. Therefore, we want to **require registration** of those who acquire a commercial licence.
|
||||
|
||||
We plan to make the commercial licenceses **affordable** to SMEs and individuals. We
|
||||
are considering to provide commercial licenses either for free or for symbolic
|
||||
one-time fee, and maybe also provide additional support. You can share your preference via the [questionnaire](https://forms.cloud.microsoft/e/7tCxb4gJfB).
|
||||
|
||||
You can also leave your contact [there](https://forms.cloud.microsoft/e/7tCxb4gJfB) to be notified when the commercial licenses become
|
||||
available.
|
||||
|
||||
✉️ Contact
|
||||
|
||||
[Dominik Macháček](https://ufal.mff.cuni.cz/dominik-machacek/), machacek@ufal.mff.cuni.cz
|
||||
@@ -25,6 +25,9 @@ class BeamTokens(Tokens):
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def as_text(self, tokenizer):
|
||||
return tokenizer.decode(self.tokens)
|
||||
|
||||
class Logits(Tokens):
|
||||
def __init__(self, logits):
|
||||
super().__init__(logits)
|
||||
|
||||
5
whisperlivekit/simul_whisper/license_simulstreaming.py
Normal file
@@ -0,0 +1,5 @@
|
||||
SIMULSTREAMING_LICENSE = f"""
|
||||
SimulStreaming backend is dual-licensed:
|
||||
• Non-Commercial Use: PolyForm Noncommercial License 1.0.0.
|
||||
• Commercial Use: Check SimulStreaming README (github.com/ufal/SimulStreaming) for more details.
|
||||
"""
|
||||
72
whisperlivekit/simul_whisper/mlx_encoder.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
from mlx_whisper import whisper
|
||||
|
||||
mlx_model_mapping = {
|
||||
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
|
||||
"tiny": "mlx-community/whisper-tiny-mlx",
|
||||
"base.en": "mlx-community/whisper-base.en-mlx",
|
||||
"base": "mlx-community/whisper-base-mlx",
|
||||
"small.en": "mlx-community/whisper-small.en-mlx",
|
||||
"small": "mlx-community/whisper-small-mlx",
|
||||
"medium.en": "mlx-community/whisper-medium.en-mlx",
|
||||
"medium": "mlx-community/whisper-medium-mlx",
|
||||
"large-v1": "mlx-community/whisper-large-v1-mlx",
|
||||
"large-v2": "mlx-community/whisper-large-v2-mlx",
|
||||
"large-v3": "mlx-community/whisper-large-v3-mlx",
|
||||
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
|
||||
"large": "mlx-community/whisper-large-mlx",
|
||||
}
|
||||
|
||||
def load_mlx_encoder(
|
||||
path_or_hf_repo: str,
|
||||
dtype: mx.Dtype = mx.float32,
|
||||
) -> whisper.Whisper:
|
||||
model_path = Path(path_or_hf_repo)
|
||||
if not model_path.exists():
|
||||
model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
|
||||
|
||||
with open(str(model_path / "config.json"), "r") as f:
|
||||
config = json.loads(f.read())
|
||||
config.pop("model_type", None)
|
||||
quantization = config.pop("quantization", None)
|
||||
|
||||
model_args = whisper.ModelDimensions(**config)
|
||||
|
||||
wf = model_path / "weights.safetensors"
|
||||
if not wf.exists():
|
||||
wf = model_path / "weights.npz"
|
||||
weights = mx.load(str(wf))
|
||||
|
||||
model = whisper.Whisper(model_args, dtype)
|
||||
|
||||
if quantization is not None:
|
||||
class_predicate = (
|
||||
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
|
||||
and f"{p}.scales" in weights
|
||||
)
|
||||
nn.quantize(model, **quantization, class_predicate=class_predicate)
|
||||
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
|
||||
# we only want to load the encoder weights here.
|
||||
# Size examples: for tiny.en,
|
||||
# Decoder weights: 59110771 bytes
|
||||
# Encoder weights: 15268874 bytes
|
||||
|
||||
|
||||
encoder_weights = {}
|
||||
encoder_weights['encoder'] = weights['encoder']
|
||||
del(weights)
|
||||
|
||||
|
||||
|
||||
model.update(encoder_weights)
|
||||
mx.eval(model.parameters())
|
||||
return model
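# Usage sketch (assumes mlx-whisper is installed and the HF repo is available; only encoder
# weights are loaded, so only the encoder should be called on the returned model):
#   enc = load_mlx_encoder("mlx-community/whisper-small-mlx")
#   features = enc.encoder(mel[None])   # mel computed with mlx_whisper.audio.log_mel_spectrogram,
#                                       # as done in simul_whisper.py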
|
||||
@@ -10,12 +10,12 @@ from .whisper import load_model, DecodingOptions, tokenizer
|
||||
from .config import AlignAttConfig
|
||||
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
|
||||
from .whisper.timing import median_filter
|
||||
from .whisper.decoding import SuppressBlank, GreedyDecoder, BeamSearchDecoder, SuppressTokens
|
||||
from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
|
||||
from .beam import BeamPyTorchInference
|
||||
from .eow_detection import fire_at_boundary, load_cif
|
||||
import os
|
||||
|
||||
from whisperlivekit.simul_whisper.token_buffer import TokenBuffer
|
||||
from time import time
|
||||
from .token_buffer import TokenBuffer
|
||||
|
||||
import numpy as np
|
||||
from .generation_progress import *
|
||||
@@ -23,7 +23,22 @@ from .generation_progress import *
|
||||
DEC_PAD = 50257
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import sys
|
||||
|
||||
try:
|
||||
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
|
||||
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
||||
HAS_MLX_WHISPER = True
|
||||
except ImportError:
|
||||
HAS_MLX_WHISPER = False
|
||||
if HAS_MLX_WHISPER:
|
||||
HAS_FASTER_WHISPER = False
|
||||
else:
|
||||
try:
|
||||
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
|
||||
from faster_whisper.feature_extractor import FeatureExtractor
|
||||
HAS_FASTER_WHISPER = True
|
||||
except ImportError:
|
||||
HAS_FASTER_WHISPER = False
|
||||
|
||||
# New features added to the original version of Simul-Whisper:
|
||||
# - large-v3 model support
|
||||
@@ -32,28 +47,43 @@ import sys
|
||||
# - prompt -- static vs. non-static
|
||||
# - context
|
||||
class PaddedAlignAttWhisper:
|
||||
def __init__(self, cfg: AlignAttConfig) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
cfg: AlignAttConfig,
|
||||
loaded_model=None,
|
||||
mlx_encoder=None,
|
||||
fw_encoder=None,
|
||||
) -> None:
|
||||
self.log_segments = 0
|
||||
model_name = os.path.basename(cfg.model_path).replace(".pt", "")
|
||||
model_path = os.path.dirname(os.path.abspath(cfg.model_path))
|
||||
self.model = load_model(name=model_name, download_root=model_path)
|
||||
if loaded_model:
|
||||
self.model = loaded_model
|
||||
else:
|
||||
self.model = load_model(name=model_name, download_root=model_path)
|
||||
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
self.mlx_encoder = mlx_encoder
|
||||
self.fw_encoder = fw_encoder
|
||||
if fw_encoder:
|
||||
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
|
||||
|
||||
logger.info(f"Model dimensions: {self.model.dims}")
|
||||
|
||||
decode_options = DecodingOptions(
|
||||
self.decode_options = DecodingOptions(
|
||||
language = cfg.language,
|
||||
without_timestamps = True,
|
||||
task=cfg.task
|
||||
)
|
||||
self.tokenizer = tokenizer.get_tokenizer(
|
||||
multilingual=not model_name.endswith(".en"),
|
||||
language=cfg.language,
|
||||
num_languages=self.model.num_languages,
|
||||
task=decode_options.task
|
||||
)
|
||||
self.tokenizer_is_multilingual = not model_name.endswith(".en")
|
||||
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
||||
self.detected_language = cfg.language if cfg.language != "auto" else None
|
||||
|
||||
self.max_text_len = self.model.dims.n_text_ctx
|
||||
self.num_decoder_layers = len(self.model.decoder.blocks)
|
||||
self.cfg = cfg
|
||||
|
||||
self.l_hooks = []
|
||||
|
||||
# model to detect end-of-word boundary at the end of the segment
|
||||
self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
|
||||
@@ -67,7 +97,8 @@ class PaddedAlignAttWhisper:
|
||||
t = F.softmax(net_output[1], dim=-1)
|
||||
self.dec_attns.append(t.squeeze(0))
|
||||
for b in self.model.decoder.blocks:
|
||||
b.cross_attn.register_forward_hook(layer_hook)
|
||||
hook = b.cross_attn.register_forward_hook(layer_hook)
|
||||
self.l_hooks.append(hook)
|
||||
|
||||
self.kv_cache = {}
|
||||
def kv_hook(module: torch.nn.Linear, _, net_output: torch.Tensor):
|
||||
@@ -80,10 +111,13 @@ class PaddedAlignAttWhisper:
|
||||
return self.kv_cache[module.cache_id]
|
||||
|
||||
for i,b in enumerate(self.model.decoder.blocks):
|
||||
b.attn.key.register_forward_hook(kv_hook)
|
||||
b.attn.value.register_forward_hook(kv_hook)
|
||||
b.cross_attn.key.register_forward_hook(kv_hook)
|
||||
b.cross_attn.value.register_forward_hook(kv_hook)
|
||||
hooks = [
|
||||
b.attn.key.register_forward_hook(kv_hook),
|
||||
b.attn.value.register_forward_hook(kv_hook),
|
||||
b.cross_attn.key.register_forward_hook(kv_hook),
|
||||
b.cross_attn.value.register_forward_hook(kv_hook),
|
||||
]
|
||||
self.l_hooks.extend(hooks)
|
||||
|
||||
self.align_source = {}
|
||||
self.num_align_heads = 0
|
||||
@@ -95,14 +129,6 @@ class PaddedAlignAttWhisper:
|
||||
self.num_align_heads += 1
|
||||
|
||||
|
||||
# init tokens (mandatory prompt)
|
||||
self.initial_tokens = torch.tensor(
|
||||
self.tokenizer.sot_sequence_including_notimestamps,
|
||||
dtype=torch.long,
|
||||
device=self.model.device).unsqueeze(0)
|
||||
self.initial_token_length = self.initial_tokens.shape[1]
|
||||
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
|
||||
|
||||
# tokens to be suppressed from decoding, to prevent hallucinations
|
||||
suppress_tokens = [
|
||||
self.tokenizer.transcribe,
|
||||
@@ -121,6 +147,18 @@ class PaddedAlignAttWhisper:
|
||||
self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None)
|
||||
# blank tokens are suppressed for new segments near line 334
|
||||
|
||||
# it's going to be regenerated after lang id
|
||||
self.segments = []
|
||||
self.init_tokens()
|
||||
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.cumulative_time_offset = 0.0
|
||||
|
||||
if self.cfg.max_context_tokens is None:
|
||||
self.max_context_tokens = self.max_text_len
|
||||
else:
|
||||
self.max_context_tokens = self.cfg.max_context_tokens
|
||||
self.init_context()
|
||||
|
||||
# decoder type: greedy or beam
|
||||
if cfg.decoder_type == "greedy":
|
||||
@@ -134,17 +172,27 @@ class PaddedAlignAttWhisper:
|
||||
self.inference.kv_cache = self.kv_cache
|
||||
|
||||
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
|
||||
|
||||
def remove_hooks(self):
|
||||
for hook in self.l_hooks:
|
||||
hook.remove()
|
||||
|
||||
# init state
|
||||
self.segments = []
|
||||
self.tokens = [self.initial_tokens]
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
def warmup(self, audio):
|
||||
try:
|
||||
self.insert_audio(audio)
|
||||
self.infer(is_last=True)
|
||||
self.refresh_segment(complete=True)
|
||||
logger.info("Model warmed up successfully")
|
||||
except Exception as e:
|
||||
logger.exception(f"Model warmup failed: {e}")
|
||||
|
||||
if self.cfg.max_context_tokens is None:
|
||||
self.max_context_tokens = self.max_text_len
|
||||
else:
|
||||
self.max_context_tokens = self.cfg.max_context_tokens
|
||||
self.init_context()
|
||||
def create_tokenizer(self, language=None):
|
||||
self.tokenizer = tokenizer.get_tokenizer(
|
||||
multilingual=self.tokenizer_is_multilingual,
|
||||
language=language,
|
||||
num_languages=self.model.num_languages,
|
||||
task=self.decode_options.task
|
||||
)
|
||||
|
||||
def init_context(self):
|
||||
kw = {'tokenizer': self.tokenizer,
|
||||
@@ -156,6 +204,19 @@ class PaddedAlignAttWhisper:
|
||||
if self.cfg.init_prompt is not None:
|
||||
self.context.text += self.cfg.init_prompt
|
||||
|
||||
def init_tokens(self):
|
||||
logger.debug(f"init tokens, {len(self.segments)}")
|
||||
# init tokens (mandatory prompt)
|
||||
self.initial_tokens = torch.tensor(
|
||||
self.tokenizer.sot_sequence_including_notimestamps,
|
||||
dtype=torch.long,
|
||||
device=self.model.device).unsqueeze(0)
|
||||
self.initial_token_length = self.initial_tokens.shape[1]
|
||||
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
|
||||
# self.segments = []
|
||||
logger.debug(f"init tokens after, {len(self.segments)}")
|
||||
self.tokens = [self.initial_tokens]
|
||||
|
||||
def trim_context(self):
|
||||
logger.info("Trimming context")
|
||||
c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids)
|
||||
@@ -191,15 +252,20 @@ class PaddedAlignAttWhisper:
|
||||
|
||||
def refresh_segment(self, complete=False):
|
||||
|
||||
logger.debug("Refreshing segment")
|
||||
self.tokens = [self.initial_tokens]
|
||||
logger.debug("Refreshing segment:")
|
||||
self.init_tokens()
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.detected_language = None
|
||||
self.cumulative_time_offset = 0.0
|
||||
self.init_context()
|
||||
logger.debug(f"Context: {self.context}")
|
||||
if not complete and len(self.segments) > 2:
|
||||
logger.debug("keeping last two segments because they are and it is not complete.")
|
||||
self.segments = self.segments[-2:]
|
||||
else:
|
||||
logger.debug("removing all segments.")
|
||||
self.segments = []
|
||||
self.log_segments += 1
|
||||
|
||||
|
||||
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
|
||||
@@ -208,8 +274,6 @@ class PaddedAlignAttWhisper:
|
||||
return fire_at_boundary(chunked_encoder_feature, self.CIFLinear)
|
||||
|
||||
|
||||
|
||||
|
||||
def _current_tokens(self):
|
||||
|
||||
toks = self.tokens
|
||||
@@ -256,16 +320,60 @@ class PaddedAlignAttWhisper:
|
||||
removed_len = 0
|
||||
# len of audio is bigger than buffer_len. Going to remove the first segment
|
||||
segments_len = self.segments_len()
|
||||
while segments_len > self.cfg.audio_max_len:
|
||||
while len(self.segments) > 1 and segments_len > self.cfg.audio_max_len:
|
||||
removed_len = self.segments[0].shape[0] / 16000
|
||||
segments_len -= removed_len
|
||||
self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
|
||||
self.cumulative_time_offset += removed_len # Track cumulative time removed
|
||||
self.segments = self.segments[1:]
|
||||
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}")
|
||||
self.context.append_token_ids(self.tokens[1][0,:])
|
||||
self.tokens = [self.initial_tokens] + self.tokens[2:]
|
||||
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s")
|
||||
if len(self.tokens) > 1:
|
||||
self.context.append_token_ids(self.tokens[1][0,:])
|
||||
self.tokens = [self.initial_tokens] + self.tokens[2:]
|
||||
return removed_len
|
||||
|
||||
def _clean_cache(self):
|
||||
'''clean the cache that stores the attention matrices and kv_cache.
|
||||
It must be called every time after generation with the model.'''
|
||||
# cleaning cache
|
||||
self.dec_attns = []
|
||||
self.kv_cache = {}
|
||||
if self.decoder_type == "beam":
|
||||
self.inference.kv_cache = self.kv_cache
|
||||
self.token_decoder.reset()
|
||||
|
||||
@torch.no_grad()
|
||||
def lang_id(self, encoder_features):
|
||||
"""Language detection from encoder features.
|
||||
This code is trimmed and copy-pasted from whisper.decoding.detect_language .
|
||||
"""
|
||||
|
||||
# forward pass using a single token, startoftranscript
|
||||
n_audio = encoder_features.shape[0]
|
||||
x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1]
|
||||
logits = self.model.logits(x, encoder_features)[:, 0]
|
||||
|
||||
# collect detected languages; suppress all non-language tokens
|
||||
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
||||
mask[list(self.tokenizer.all_language_tokens)] = False
|
||||
logits[:, mask] = -np.inf
|
||||
language_tokens = logits.argmax(dim=-1)
|
||||
language_token_probs = logits.softmax(dim=-1).cpu()
|
||||
language_probs = [
|
||||
{
|
||||
c: language_token_probs[i, j].item()
|
||||
for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
|
||||
}
|
||||
for i in range(n_audio)
|
||||
]
|
||||
|
||||
single = encoder_features.ndim == 2
|
||||
if single:
|
||||
language_tokens = language_tokens[0]
|
||||
language_probs = language_probs[0]
|
||||
|
||||
self._clean_cache()
|
||||
return language_tokens, language_probs
|
||||
|
||||
### transcription / translation
|
||||
|
||||
@@ -273,9 +381,12 @@ class PaddedAlignAttWhisper:
|
||||
def infer(self, is_last=False):
|
||||
new_segment = True
|
||||
if len(self.segments) == 0:
|
||||
return []
|
||||
logger.debug("No segments, nothing to do")
|
||||
return [], {}
|
||||
if not self._apply_minseglen():
|
||||
return []
|
||||
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
||||
input_segments = torch.cat(self.segments, dim=0)
|
||||
return [], {}
|
||||
|
||||
# input_segments is concatenation of audio, it's one array
|
||||
if len(self.segments) > 1:
|
||||
@@ -283,30 +394,67 @@ class PaddedAlignAttWhisper:
|
||||
else:
|
||||
input_segments = self.segments[0]
|
||||
|
||||
# NEW: we can use a different (faster) encoder here, before the standard Whisper decoder computes the cross-attention captured by the hooks
|
||||
beg_encode = time()
|
||||
if self.mlx_encoder:
|
||||
mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
|
||||
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
|
||||
mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
|
||||
encoder_feature = torch.as_tensor(mlx_encoder_feature)
|
||||
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
|
||||
elif self.fw_encoder:
|
||||
audio_length_seconds = len(input_segments) / 16000
|
||||
content_mel_len = int(audio_length_seconds * 100)//2
|
||||
mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
|
||||
mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
|
||||
encoder_feature_ctranslate = self.fw_encoder.encode(mel)
|
||||
if self.device == 'cpu': #it seems that on GPU, passing a StorageView to torch.as_tensor fails, while wrapping it in a numpy array works
|
||||
encoder_feature_ctranslate = np.array(encoder_feature_ctranslate)
|
||||
try:
|
||||
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
|
||||
except TypeError: # Normally the cpu condition should prevent having exceptions, but just in case:
|
||||
encoder_feature = torch.as_tensor(np.array(encoder_feature_ctranslate), device=self.device)
|
||||
else:
|
||||
# mel + padding to 30s
|
||||
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
|
||||
device=self.device).unsqueeze(0)
|
||||
# trim to 3000
|
||||
mel = pad_or_trim(mel_padded, N_FRAMES)
|
||||
# the len of actual audio
|
||||
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
|
||||
encoder_feature = self.model.encoder(mel)
|
||||
end_encode = time()
|
||||
# print('Encoder duration:', end_encode-beg_encode)
|
||||
|
||||
# logger.debug(f"Encoder feature shape: {encoder_feature.shape}")
|
||||
# if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
|
||||
# logger.debug("mel ")
|
||||
if self.cfg.language == "auto" and self.detected_language is None:
|
||||
language_tokens, language_probs = self.lang_id(encoder_feature)
|
||||
logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
|
||||
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
||||
logger.info(f"Detected language: {top_lan} with p={p:.4f}")
|
||||
#self.tokenizer.language = top_lan
|
||||
#self.tokenizer.__post_init__()
|
||||
self.create_tokenizer(top_lan)
|
||||
self.detected_language = top_lan
|
||||
self.init_tokens()
|
||||
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
|
||||
|
||||
self.trim_context()
|
||||
current_tokens = self._current_tokens()
|
||||
|
||||
# mel + padding to 30s
|
||||
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
|
||||
device=self.model.device).unsqueeze(0)
|
||||
# trim to 3000
|
||||
mel = pad_or_trim(mel_padded, N_FRAMES)
|
||||
|
||||
# the len of actual audio
|
||||
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
|
||||
|
||||
encoder_feature = self.model.encoder(mel)
|
||||
sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
|
||||
completed = False
|
||||
|
||||
#
|
||||
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
|
||||
|
||||
|
||||
####################### Decoding loop
|
||||
logger.info("Decoding loop starts\n")
|
||||
|
||||
sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device)
|
||||
completed = False
|
||||
|
||||
attn_of_alignment_heads = None
|
||||
miost_attended_frame = None
|
||||
most_attended_frame = None
|
||||
|
||||
token_len_before_decoding = current_tokens.shape[1]
|
||||
|
||||
@@ -412,7 +560,13 @@ class PaddedAlignAttWhisper:
|
||||
# for each beam, the most attended frame is:
|
||||
most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
|
||||
generation_progress_loop.append(("most_attended_frames",most_attended_frames.clone().tolist()))
|
||||
|
||||
# Calculate absolute timestamps accounting for cumulative offset
|
||||
absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()]
|
||||
generation_progress_loop.append(("absolute_timestamps", absolute_timestamps))
|
||||
|
||||
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
|
||||
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)")
|
||||
|
||||
most_attended_frame = most_attended_frames[0].item()
|
||||
|
||||
@@ -507,7 +661,7 @@ class PaddedAlignAttWhisper:
|
||||
### new hypothesis
|
||||
logger.debug(f"new_hypothesis: {new_hypothesis}")
|
||||
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
|
||||
device=self.model.device,
|
||||
device=self.device,
|
||||
)
|
||||
self.tokens.append(new_tokens)
|
||||
# TODO: test if this is redundant or not
|
||||
@@ -515,11 +669,6 @@ class PaddedAlignAttWhisper:
|
||||
|
||||
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
|
||||
|
||||
# cleaning cache
|
||||
self.dec_attns = []
|
||||
self.kv_cache = {}
|
||||
if self.decoder_type == "beam":
|
||||
self.inference.kv_cache = self.kv_cache
|
||||
self.token_decoder.reset()
|
||||
self._clean_cache()
|
||||
|
||||
return new_hypothesis, generation
|
||||
return new_hypothesis, generation
|
||||
@@ -105,6 +105,7 @@ def load_model(
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
download_root: str = None,
|
||||
in_memory: bool = False,
|
||||
decoder_only=False
|
||||
) -> Whisper:
|
||||
"""
|
||||
Load a Whisper ASR model
|
||||
@@ -151,7 +152,14 @@ def load_model(
|
||||
del checkpoint_file
|
||||
|
||||
dims = ModelDimensions(**checkpoint["dims"])
|
||||
model = Whisper(dims)
|
||||
model = Whisper(dims, decoder_only=decoder_only)
|
||||
|
||||
if decoder_only:
|
||||
checkpoint["model_state_dict"] = {
|
||||
k: v for k, v in checkpoint["model_state_dict"].items()
|
||||
if 'encoder' not in k
|
||||
}
|
||||
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
if alignment_heads is not None:
|
||||
|
||||
@@ -32,7 +32,9 @@ def detect_language(
|
||||
list of dictionaries containing the probability distribution over all languages.
|
||||
"""
|
||||
if tokenizer is None:
|
||||
tokenizer = get_tokenizer(model.is_multilingual)
|
||||
tokenizer = get_tokenizer(
|
||||
model.is_multilingual, num_languages=model.num_languages
|
||||
)
|
||||
if (
|
||||
tokenizer.language is None
|
||||
or tokenizer.language_token not in tokenizer.sot_sequence
|
||||
@@ -111,9 +113,6 @@ class DecodingOptions:
|
||||
# implementation details
|
||||
fp16: bool = True # use fp16 for most of the calculation
|
||||
|
||||
# streaming
|
||||
add_sot: Optional[bool] = True
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DecodingResult:
|
||||
@@ -513,19 +512,17 @@ class DecodingTask:
|
||||
logit_filters: List[LogitFilter]
|
||||
|
||||
def __init__(self, model: "Whisper", options: DecodingOptions):
|
||||
self.options: DecodingOptions = self._verify_options(options)
|
||||
if self.options.fp16:
|
||||
self.model = model.half()
|
||||
else:
|
||||
self.model = model
|
||||
self.model = model
|
||||
|
||||
language = options.language or "en"
|
||||
tokenizer = get_tokenizer(
|
||||
model.is_multilingual, language=language, task=options.task
|
||||
model.is_multilingual,
|
||||
num_languages=model.num_languages,
|
||||
language=language,
|
||||
task=options.task,
|
||||
)
|
||||
self.tokenizer: Tokenizer = tokenizer
|
||||
|
||||
# print(self.options)
|
||||
self.options: DecodingOptions = self._verify_options(options)
|
||||
|
||||
self.n_group: int = options.beam_size or options.best_of or 1
|
||||
self.n_ctx: int = model.dims.n_text_ctx
|
||||
@@ -589,7 +586,7 @@ class DecodingTask:
|
||||
|
||||
def _get_initial_tokens(self) -> Tuple[int]:
|
||||
tokens = list(self.sot_sequence)
|
||||
# print("prefix", prefix)
|
||||
|
||||
if prefix := self.options.prefix:
|
||||
prefix_tokens = (
|
||||
self.tokenizer.encode(" " + prefix.strip())
|
||||
@@ -607,15 +604,12 @@ class DecodingTask:
|
||||
if isinstance(prompt, str)
|
||||
else prompt
|
||||
)
|
||||
# if self.options.add_sot:
|
||||
tokens = (
|
||||
[self.tokenizer.sot_prev]
|
||||
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
|
||||
+ tokens
|
||||
)
|
||||
#else:
|
||||
# tokens = ([self.tokenizer.sot_prev] + tokens + prompt_tokens[-(self.n_ctx // 2 - 1) :])
|
||||
# print("return", tokens)
|
||||
|
||||
return tuple(tokens)
|
||||
|
||||
def _get_suppress_tokens(self) -> Tuple[int]:
|
||||
@@ -663,7 +657,7 @@ class DecodingTask:
|
||||
if audio_features.dtype != (
|
||||
torch.float16 if self.options.fp16 else torch.float32
|
||||
):
|
||||
raise TypeError(
|
||||
return TypeError(
|
||||
f"audio_features has an incorrect dtype: {audio_features.dtype}"
|
||||
)
|
||||
|
||||
@@ -689,10 +683,9 @@ class DecodingTask:
|
||||
no_speech_probs = [np.nan] * n_batch
|
||||
|
||||
try:
|
||||
for i in range(self.sample_len): # loops at most 448 times
|
||||
# print("in decode main loop", i , tokens[0].tolist())
|
||||
for i in range(self.sample_len):
|
||||
logits = self.inference.logits(tokens, audio_features)
|
||||
# print(logits)
|
||||
|
||||
if (
|
||||
i == 0 and self.tokenizer.no_speech is not None
|
||||
): # save no_speech_probs
|
||||
@@ -724,7 +717,7 @@ class DecodingTask:
|
||||
|
||||
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
|
||||
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
|
||||
# print("initial_tokens", self.initial_tokens)
|
||||
|
||||
# detect language if requested, overwriting the language token
|
||||
languages, language_probs = self._detect_language(audio_features, tokens)
|
||||
if self.options.task == "lang_id":
|
||||
|
||||
@@ -13,7 +13,6 @@ from .decoding import decode as decode_function
|
||||
from .decoding import detect_language as detect_language_function
|
||||
from .transcribe import transcribe as transcribe_function
|
||||
|
||||
|
||||
try:
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
@@ -37,26 +36,27 @@ class ModelDimensions:
|
||||
n_text_layer: int
|
||||
|
||||
|
||||
# class LayerNorm(nn.LayerNorm):
|
||||
# def forward(self, x: Tensor) -> Tensor:
|
||||
# return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
# class Linear(nn.Linear):
|
||||
# def forward(self, x: Tensor) -> Tensor:
|
||||
# return F.linear(
|
||||
# x,
|
||||
# self.weight.to(x.dtype),
|
||||
# None if self.bias is None else self.bias.to(x.dtype),
|
||||
# )
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
# class Conv1d(nn.Conv1d):
|
||||
# def _conv_forward(
|
||||
# self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
|
||||
# ) -> Tensor:
|
||||
# return super()._conv_forward(
|
||||
# x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
||||
# )
|
||||
class Linear(nn.Linear):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.linear(
|
||||
x,
|
||||
self.weight.to(x.dtype),
|
||||
None if self.bias is None else self.bias.to(x.dtype),
|
||||
)
|
||||
|
||||
|
||||
class Conv1d(nn.Conv1d):
|
||||
def _conv_forward(
|
||||
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
|
||||
) -> Tensor:
|
||||
return super()._conv_forward(
|
||||
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
||||
)
|
||||
|
||||
|
||||
def sinusoids(length, channels, max_timescale=10000):
|
||||
@@ -67,21 +67,30 @@ def sinusoids(length, channels, max_timescale=10000):
|
||||
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
||||
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
||||
|
||||
import sys ## this is mine, for debugging
|
||||
|
||||
@contextmanager
|
||||
def disable_sdpa():
|
||||
prev_state = MultiHeadAttention.use_sdpa
|
||||
try:
|
||||
MultiHeadAttention.use_sdpa = False
|
||||
yield
|
||||
finally:
|
||||
MultiHeadAttention.use_sdpa = prev_state
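A minimal usage sketch of this context manager, mirroring its later use in timing.find_alignment (the model/mel/tokens names are assumed to be in scope):
with torch.no_grad(), disable_sdpa():
    # with SDPA disabled, qkv_attention returns qk, so forward hooks can capture the attention weights
    logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]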
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
use_sdpa = False # Disable SDPA to ensure qk is always computed for hooks
|
||||
|
||||
use_sdpa = False # disabling: https://github.com/linto-ai/whisper-timestamped/issues/212
|
||||
|
||||
def __init__(self, n_state: int, n_head: int, cache_id: str):
|
||||
def __init__(self, n_state: int, n_head: int, cache_id: str = ""):
|
||||
super().__init__()
|
||||
self.n_head = n_head
|
||||
self.query = nn.Linear(n_state, n_state)
|
||||
self.key = nn.Linear(n_state, n_state, bias=False)
|
||||
self.key.cache_id = f"{cache_id}_key"
|
||||
self.value = nn.Linear(n_state, n_state)
|
||||
self.value.cache_id = f"{cache_id}_value"
|
||||
self.out = nn.Linear(n_state, n_state)
|
||||
self.query = Linear(n_state, n_state)
|
||||
self.key = Linear(n_state, n_state, bias=False)
|
||||
self.value = Linear(n_state, n_state)
|
||||
self.out = Linear(n_state, n_state)
|
||||
self.cache_id = cache_id
|
||||
self.key.cache_id = f"{cache_id}_key"
|
||||
self.value.cache_id = f"{cache_id}_value"
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -90,45 +99,21 @@ class MultiHeadAttention(nn.Module):
|
||||
mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[dict] = None,
|
||||
):
|
||||
#print("MultiHeadAttention forward",file=sys.stderr)
|
||||
q = self.query(x)
|
||||
# print(q.shape, x is None, mask is None, list(kv_cache.keys()) if kv_cache is not None else None, file=sys.stderr)
|
||||
# print(mask, kv_cache, xa, file=sys.stderr)
|
||||
|
||||
if kv_cache is None or xa is None or self.key.cache_id not in kv_cache:
|
||||
if kv_cache is None or xa is None or self.key not in kv_cache:
|
||||
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
||||
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
||||
k = self.key(x if xa is None else xa)
|
||||
v = self.value(x if xa is None else xa)
|
||||
# print(self.key.cache_id, "cache miss") # , kv_cache is None, xa is None, self.key.cache_id not in kv_cache if kv_cache is not None else None, k.shape, x.shape)
|
||||
# if kv_cache is not None:
|
||||
# print(kv_cache.keys())
|
||||
else:
|
||||
# print(self.key.cache_id, "cache hit") #, kv_cache is None, xa is None, self.key.cache_id not in kv_cache)
|
||||
# if kv_cache is not None:
|
||||
# print(kv_cache.keys())
|
||||
k = kv_cache[self.key.cache_id]
|
||||
v = kv_cache[self.value.cache_id]
|
||||
# print(self.key.cache_id, "qkv attention", q.shape, k.shape, v.shape)
|
||||
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
||||
k = kv_cache[self.key]
|
||||
v = kv_cache[self.value]
|
||||
|
||||
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||
return self.out(wv), qk
|
||||
|
||||
# def qkv_attention(
|
||||
# self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||
# ):
|
||||
# n_batch, n_ctx, n_state = q.shape
|
||||
# scale = (n_state // self.n_head) ** -0.25
|
||||
# q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
||||
# k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
|
||||
# v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
|
||||
# qk = q @ k
|
||||
# if mask is not None:
|
||||
# qk = qk + mask[:n_ctx, :n_ctx]
|
||||
# # qk = qk.float()
|
||||
|
||||
# w = F.softmax(qk, dim=-1) # .to(q.dtype)
|
||||
# return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
|
||||
|
||||
|
||||
def qkv_attention(
|
||||
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
@@ -158,21 +143,22 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(self, n_state: int, n_head: int, cache_id: str="", cross_attention: bool = False):
|
||||
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, cache_id: str = ""):
|
||||
super().__init__()
|
||||
|
||||
self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn")
|
||||
self.attn_ln = nn.LayerNorm(n_state)
|
||||
self.attn_ln = LayerNorm(n_state)
|
||||
|
||||
self.cross_attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
|
||||
|
||||
self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
|
||||
self.cross_attn = (
|
||||
MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
|
||||
)
|
||||
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
||||
|
||||
n_mlp = n_state * 4
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
|
||||
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
|
||||
)
|
||||
self.mlp_ln = nn.LayerNorm(n_state)
|
||||
self.mlp_ln = LayerNorm(n_state)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -181,8 +167,6 @@ class ResidualAttentionBlock(nn.Module):
|
||||
mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[dict] = None,
|
||||
):
|
||||
# print("ResidualAttentionBlock forward",file=sys.stderr)
|
||||
# print(x.shape, file=sys.stderr)
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
||||
if self.cross_attn:
|
||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
||||
@@ -195,44 +179,32 @@ class AudioEncoder(nn.Module):
|
||||
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
||||
):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||
self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
||||
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
||||
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
||||
|
||||
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||
[ResidualAttentionBlock(n_state, n_head, cache_id=f"enc_layer{i}") for i in range(n_layer)]
|
||||
)
|
||||
self.ln_post = nn.LayerNorm(n_state)
|
||||
self.ln_post = LayerNorm(n_state)
|
||||
|
||||
def forward(self, x: Tensor, return_layer_results: bool=False):
|
||||
def forward(self, x: Tensor):
|
||||
"""
|
||||
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
||||
the mel spectrogram of the audio
|
||||
"""
|
||||
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
x = x.permute(0, 2, 1) # BDT -> BTD
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
# two convolution layers, 2x downsampling
# 1500 frames remain at the end
|
||||
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
||||
x = (x + self.positional_embedding).to(x.dtype)
|
||||
|
||||
x = (x + self.positional_embedding[:x.shape[1], :]) #.to(x.dtype)
|
||||
|
||||
layer_results = []
|
||||
i = 0
|
||||
for block in self.blocks:
|
||||
# print(f"encoder layer {i}")
|
||||
x = block(x)
|
||||
layer_results.append(x)
|
||||
i += 1
|
||||
|
||||
x = self.ln_post(x)
|
||||
|
||||
if return_layer_results:
|
||||
return x, layer_results
|
||||
else:
|
||||
return x
|
||||
return x
|
||||
|
||||
|
||||
class TextDecoder(nn.Module):
|
||||
@@ -250,7 +222,7 @@ class TextDecoder(nn.Module):
|
||||
for i in range(n_layer)
|
||||
]
|
||||
)
|
||||
self.ln = nn.LayerNorm(n_state)
|
||||
self.ln = LayerNorm(n_state)
|
||||
|
||||
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
|
||||
self.register_buffer("mask", mask, persistent=False)
|
||||
@@ -262,37 +234,37 @@ class TextDecoder(nn.Module):
|
||||
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
|
||||
the encoded audio features to be attended on
|
||||
"""
|
||||
|
||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||
x = (
|
||||
self.token_embedding(x)
|
||||
+ self.positional_embedding[offset : offset + x.shape[-1]]
|
||||
)
|
||||
# x = x.to(xa.dtype)
|
||||
x = x.to(xa.dtype)
|
||||
|
||||
i = 0
|
||||
for block in self.blocks:
|
||||
# print(f"decoder layer {i}")
|
||||
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||
i += 1
|
||||
|
||||
x = self.ln(x)
|
||||
logits = x @ torch.transpose(self.token_embedding.weight, 0, 1)
|
||||
logits = (
|
||||
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
|
||||
).float()
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
class Whisper(nn.Module):
|
||||
def __init__(self, dims: ModelDimensions):
|
||||
def __init__(self, dims: ModelDimensions, decoder_only: bool = False):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
self.encoder = AudioEncoder(
|
||||
self.dims.n_mels,
|
||||
self.dims.n_audio_ctx,
|
||||
self.dims.n_audio_state,
|
||||
self.dims.n_audio_head,
|
||||
self.dims.n_audio_layer,
|
||||
)
|
||||
|
||||
if not decoder_only:
|
||||
self.encoder = AudioEncoder(
|
||||
self.dims.n_mels,
|
||||
self.dims.n_audio_ctx,
|
||||
self.dims.n_audio_state,
|
||||
self.dims.n_audio_head,
|
||||
self.dims.n_audio_layer,
|
||||
)
|
||||
self.decoder = TextDecoder(
|
||||
self.dims.n_vocab,
|
||||
self.dims.n_text_ctx,
|
||||
@@ -300,7 +272,8 @@ class Whisper(nn.Module):
|
||||
self.dims.n_text_head,
|
||||
self.dims.n_text_layer,
|
||||
)
|
||||
# use the last half layers for alignment by default; see `set_alignment_heads()` below
|
||||
# use the last half among the decoder layers for time alignment by default;
|
||||
# to use a specific set of heads, see `set_alignment_heads()` below.
|
||||
all_heads = torch.zeros(
|
||||
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
|
||||
)
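For illustration, marking "the last half of the decoder layers" presumably comes down to flipping the corresponding rows of this boolean mask and registering it as a sparse buffer, roughly as in upstream Whisper (a sketch; these lines are not shown in this hunk):
all_heads[self.dims.n_text_layer // 2 :] = True  # heads of the later decoder layers
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)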
|
||||
@@ -320,15 +293,11 @@ class Whisper(nn.Module):
|
||||
return self.encoder(mel)
|
||||
|
||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
||||
# tokens = tokens.to(self.decoder.ln.weight.dtype)
|
||||
# audio_features = audio_features.to(self.decoder.ln.weight.dtype)
|
||||
return self.decoder(tokens, audio_features)
|
||||
|
||||
def forward(
|
||||
self, mel: torch.Tensor, tokens: torch.Tensor
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
# mel = mel.to(self.decoder.ln.weight.dtype)
|
||||
# tokens = tokens.to(self.decoder.ln.weight.dtype)
|
||||
return self.decoder(tokens, self.encoder(mel))
|
||||
|
||||
@property
|
||||
@@ -343,7 +312,6 @@ class Whisper(nn.Module):
|
||||
def num_languages(self):
|
||||
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
|
||||
|
||||
# 为decoder加入缓存机制,每次推理时保存上次的k和v,下次推理无需重新计算
|
||||
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
||||
"""
|
||||
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
||||
|
||||
@@ -30,15 +30,19 @@ def remove_symbols_and_diacritics(s: str, keep=""):
|
||||
and drop any diacritics (category 'Mn' and some manual mappings)
|
||||
"""
|
||||
return "".join(
|
||||
c
|
||||
if c in keep
|
||||
else ADDITIONAL_DIACRITICS[c]
|
||||
if c in ADDITIONAL_DIACRITICS
|
||||
else ""
|
||||
if unicodedata.category(c) == "Mn"
|
||||
else " "
|
||||
if unicodedata.category(c)[0] in "MSP"
|
||||
else c
|
||||
(
|
||||
c
|
||||
if c in keep
|
||||
else (
|
||||
ADDITIONAL_DIACRITICS[c]
|
||||
if c in ADDITIONAL_DIACRITICS
|
||||
else (
|
||||
""
|
||||
if unicodedata.category(c) == "Mn"
|
||||
else " " if unicodedata.category(c)[0] in "MSP" else c
|
||||
)
|
||||
)
|
||||
)
|
||||
for c in unicodedata.normalize("NFKD", s)
|
||||
)
|
||||
|
||||
|
||||
1741
whisperlivekit/simul_whisper/whisper/normalizers/english.json
Normal file
@@ -56,9 +56,8 @@ def median_filter(x: torch.Tensor, filter_width: int):
|
||||
|
||||
@numba.jit(nopython=True)
|
||||
def backtrace(trace: np.ndarray):
|
||||
i = trace.shape[0] - 1 # trace: (N+1, M+1), i=N
|
||||
j = trace.shape[1] - 1 # j=M
|
||||
# the boundary points are actually meaningless?
|
||||
i = trace.shape[0] - 1
|
||||
j = trace.shape[1] - 1
|
||||
trace[0, :] = 2
|
||||
trace[:, 0] = 1
|
||||
|
||||
@@ -83,8 +82,8 @@ def backtrace(trace: np.ndarray):
|
||||
@numba.jit(nopython=True, parallel=True)
|
||||
def dtw_cpu(x: np.ndarray):
|
||||
N, M = x.shape
|
||||
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf # cost: minimum cost from x[0, 0] to x[i-1, j-1]
|
||||
trace = -np.ones((N + 1, M + 1), dtype=np.float32) # trace:
|
||||
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
|
||||
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
|
||||
|
||||
cost[0, 0] = 0
|
||||
for j in range(1, M + 1):
|
||||
@@ -118,7 +117,7 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
|
||||
x_skew = x_skew.T.contiguous()
|
||||
cost = torch.ones(N + M + 2, M + 2) * np.inf
|
||||
cost[0, 0] = 0
|
||||
cost = cost.cuda()
|
||||
cost = cost.to(x.device)
|
||||
trace = torch.zeros_like(cost, dtype=torch.int32)
|
||||
|
||||
dtw_kernel[(1,)](
|
||||
@@ -192,21 +191,19 @@ def find_alignment(
|
||||
for i, block in enumerate(model.decoder.blocks)
|
||||
]
|
||||
|
||||
# forward pass to obtain the token probabilities
|
||||
with torch.no_grad():
|
||||
from .model import disable_sdpa
|
||||
|
||||
with torch.no_grad(), disable_sdpa():
|
||||
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
||||
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
||||
token_probs = sampled_logits.softmax(dim=-1)
|
||||
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
|
||||
text_token_probs = text_token_probs.tolist()
|
||||
|
||||
# remove the hooks
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
# heads * tokens * frames
|
||||
# print(model.alignment_heads)
|
||||
# exit(0)
|
||||
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
|
||||
weights = weights[:, :, : num_frames // 2]
|
||||
weights = (weights * qk_scale).softmax(dim=-1)
|
||||
@@ -215,18 +212,9 @@ def find_alignment(
|
||||
weights = median_filter(weights, medfilt_width)
|
||||
|
||||
matrix = weights.mean(axis=0)
|
||||
print("attention", matrix.shape, matrix[:5, :5])
|
||||
matrix = matrix[len(tokenizer.sot_sequence) : -1]
|
||||
print("attention", matrix.shape, matrix[:5, :5])
|
||||
text_indices, time_indices = dtw(-matrix)
|
||||
|
||||
print("num_frames", num_frames)
|
||||
print("attention", matrix.shape, matrix[:5, :5])
|
||||
print("text_indices", text_indices)
|
||||
print("time", time_indices)
|
||||
print("text_tokens", text_tokens, tokenizer.decode(text_tokens), len(text_tokens))
|
||||
print("eot", tokenizer.eot)
|
||||
|
||||
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
|
||||
if len(word_tokens) <= 1:
|
||||
# return on eot only
|
||||
@@ -238,9 +226,7 @@ def find_alignment(
|
||||
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
|
||||
|
||||
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
||||
# print("jumps", jumps, jumps.shape)
|
||||
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
|
||||
# print("jump_times", jump_times)
|
||||
start_times = jump_times[word_boundaries[:-1]]
|
||||
end_times = jump_times[word_boundaries[1:]]
|
||||
word_probabilities = [
|
||||
@@ -315,6 +301,7 @@ def add_word_timestamps(
|
||||
word_durations = np.array([t.end - t.start for t in alignment])
|
||||
word_durations = word_durations[word_durations.nonzero()]
|
||||
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
|
||||
median_duration = min(0.7, float(median_duration))
|
||||
max_duration = median_duration * 2
|
||||
|
||||
# hack: truncate long words at sentence boundaries.
|
||||
|
||||
@@ -1,501 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from whisper.audio import (
|
||||
FRAMES_PER_SECOND,
|
||||
HOP_LENGTH,
|
||||
N_FRAMES,
|
||||
N_SAMPLES,
|
||||
SAMPLE_RATE,
|
||||
log_mel_spectrogram,
|
||||
pad_or_trim,
|
||||
)
|
||||
from whisper.decoding import DecodingOptions, DecodingResult
|
||||
from whisper.timing import add_word_timestamps
|
||||
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||
from whisper.utils import (
|
||||
exact_div,
|
||||
format_timestamp,
|
||||
get_writer,
|
||||
make_safe,
|
||||
optional_float,
|
||||
optional_int,
|
||||
str2bool,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from whisper.model import Whisper
|
||||
|
||||
|
||||
def transcribe(
|
||||
model: "Whisper",
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
*,
|
||||
verbose: Optional[bool] = None,
|
||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||
compression_ratio_threshold: Optional[float] = 2.4,
|
||||
logprob_threshold: Optional[float] = -1.0,
|
||||
no_speech_threshold: Optional[float] = 0.6,
|
||||
condition_on_previous_text: bool = True,
|
||||
initial_prompt: Optional[str] = None,
|
||||
word_timestamps: bool = False,
|
||||
prepend_punctuations: str = "\"'“¿([{-",
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
**decode_options,
|
||||
):
|
||||
"""
|
||||
Transcribe an audio file using Whisper
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Whisper
|
||||
The Whisper model instance
|
||||
|
||||
audio: Union[str, np.ndarray, torch.Tensor]
|
||||
The path to the audio file to open, or the audio waveform
|
||||
|
||||
verbose: bool
|
||||
Whether to display the text being decoded to the console. If True, displays all the details,
|
||||
If False, displays minimal details. If None, does not display anything
|
||||
|
||||
temperature: Union[float, Tuple[float, ...]]
|
||||
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
|
||||
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
|
||||
|
||||
compression_ratio_threshold: float
|
||||
If the gzip compression ratio is above this value, treat as failed
|
||||
|
||||
logprob_threshold: float
|
||||
If the average log probability over sampled tokens is below this value, treat as failed
|
||||
|
||||
no_speech_threshold: float
|
||||
If the no_speech probability is higher than this value AND the average log probability
|
||||
over sampled tokens is below `logprob_threshold`, consider the segment as silent
|
||||
|
||||
condition_on_previous_text: bool
|
||||
if True, the previous output of the model is provided as a prompt for the next window;
|
||||
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
||||
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
||||
|
||||
word_timestamps: bool
|
||||
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
|
||||
and include the timestamps for each word in each segment.
|
||||
|
||||
prepend_punctuations: str
|
||||
If word_timestamps is True, merge these punctuation symbols with the next word
|
||||
|
||||
append_punctuations: str
|
||||
If word_timestamps is True, merge these punctuation symbols with the previous word
|
||||
|
||||
initial_prompt: Optional[str]
|
||||
Optional text to provide as a prompt for the first window. This can be used to provide, or
|
||||
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
||||
to make it more likely to predict those word correctly.
|
||||
|
||||
decode_options: dict
|
||||
Keyword arguments to construct `DecodingOptions` instances
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
||||
"""
|
||||
# print("HACKED")
|
||||
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
||||
if model.device == torch.device("cpu"):
|
||||
if torch.cuda.is_available():
|
||||
warnings.warn("Performing inference on CPU when CUDA is available")
|
||||
if dtype == torch.float16:
|
||||
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
|
||||
dtype = torch.float32
|
||||
|
||||
if dtype == torch.float32:
|
||||
decode_options["fp16"] = False
|
||||
|
||||
# Pad 30-seconds of silence to the input audio, for slicing
|
||||
mel = log_mel_spectrogram(audio, padding=0) # log_mel_spectrogram(audio, padding=N_SAMPLES) # adds 16000*30 = 480000 padding samples
|
||||
# mel = pad_or_trim(mel, 3000)
|
||||
content_frames = mel.shape[-1] # - N_FRAMES # corresponds to 3000 frames; the real content is what remains after dropping the trailing 3000
|
||||
|
||||
# determine the language
|
||||
if decode_options.get("language", None) is None:
|
||||
# for a monolingual model, set it to English directly
|
||||
if not model.is_multilingual:
|
||||
decode_options["language"] = "en"
|
||||
# otherwise one forward pass is needed
|
||||
else:
|
||||
if verbose:
|
||||
print(
|
||||
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
||||
)
|
||||
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||
# print(mel_segment.shape)
|
||||
_, probs = model.detect_language(mel_segment)
|
||||
decode_options["language"] = max(probs, key=probs.get)
|
||||
if verbose is not None:
|
||||
print(
|
||||
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
|
||||
)
|
||||
|
||||
language: str = decode_options["language"]
|
||||
task: str = decode_options.get("task", "transcribe")
|
||||
# tokenizer for the output
|
||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
|
||||
|
||||
# word-level timestamps
|
||||
if word_timestamps and task == "translate":
|
||||
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
||||
|
||||
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
|
||||
temperatures = (
|
||||
[temperature] if isinstance(temperature, (int, float)) else temperature
|
||||
)
|
||||
decode_result = None
|
||||
|
||||
for t in temperatures:
|
||||
kwargs = {**decode_options}
|
||||
if t > 0:
|
||||
# disable beam_size and patience when t > 0
|
||||
kwargs.pop("beam_size", None)
|
||||
kwargs.pop("patience", None)
|
||||
else:
|
||||
# disable best_of when t == 0
|
||||
kwargs.pop("best_of", None)
|
||||
|
||||
options = DecodingOptions(**kwargs, temperature=t)
|
||||
decode_result = model.decode(segment, options)
|
||||
|
||||
# several cases where decoding may fail; decoding is retried in those cases
# feels like know-how; ChatGPT probably relies on plenty of tricks like this
|
||||
needs_fallback = False
|
||||
if (
|
||||
compression_ratio_threshold is not None
|
||||
and decode_result.compression_ratio > compression_ratio_threshold
|
||||
):
|
||||
needs_fallback = True # too repetitive
|
||||
if (
|
||||
logprob_threshold is not None
|
||||
and decode_result.avg_logprob < logprob_threshold
|
||||
):
|
||||
needs_fallback = True # average log probability is too low
|
||||
if (
|
||||
no_speech_threshold is not None
|
||||
and decode_result.no_speech_prob > no_speech_threshold
|
||||
):
|
||||
needs_fallback = False # silence
|
||||
if not needs_fallback:
|
||||
break
|
||||
# print("decode with temperature {} compress rate {:.3f}/{:.3f}, log_prob {:.3f}/{:.3f}, {:.3f}/{:.3f}".format(
|
||||
# t,
|
||||
# decode_result.compression_ratio, compression_ratio_threshold,
|
||||
# -decode_result.avg_logprob, -logprob_threshold,
|
||||
# decode_result.no_speech_prob, no_speech_threshold
|
||||
# ))
|
||||
|
||||
return decode_result
|
||||
|
||||
seek = 0
|
||||
input_stride = exact_div(
|
||||
N_FRAMES, model.dims.n_audio_ctx
|
||||
) # mel frames per output token: 2
|
||||
# here "output token" presumably refers to the CNN output frames
|
||||
|
||||
time_precision = (
|
||||
input_stride * HOP_LENGTH / SAMPLE_RATE
|
||||
) # time per output token: 0.02 (seconds)
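A quick numeric check of these two constants, assuming the standard Whisper audio dimensions:
# N_FRAMES = 3000 mel frames (30 s), n_audio_ctx = 1500
# input_stride   = 3000 // 1500     = 2 mel frames per output position
# time_precision = 2 * 160 / 16000  = 0.02 s per output position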
|
||||
all_tokens = []
|
||||
all_segments = []
|
||||
prompt_reset_since = 0
|
||||
|
||||
if initial_prompt is not None:
|
||||
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
||||
all_tokens.extend(initial_prompt_tokens)
|
||||
else:
|
||||
initial_prompt_tokens = []
|
||||
|
||||
def new_segment(
|
||||
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
|
||||
):
|
||||
tokens = tokens.tolist()
|
||||
text_tokens = [token for token in tokens if token < tokenizer.eot]
|
||||
return {
|
||||
"seek": seek,
|
||||
"start": start,
|
||||
"end": end,
|
||||
"text": tokenizer.decode(text_tokens),
|
||||
"tokens": tokens,
|
||||
"temperature": result.temperature,
|
||||
"avg_logprob": result.avg_logprob,
|
||||
"compression_ratio": result.compression_ratio,
|
||||
"no_speech_prob": result.no_speech_prob,
|
||||
}
|
||||
|
||||
# show the progress bar when verbose is False (if True, transcribed text will be printed)
|
||||
with tqdm.tqdm(
|
||||
total=content_frames, unit="frames", disable=verbose is not False
|
||||
) as pbar:
|
||||
last_speech_timestamp = 0.0
|
||||
while seek < content_frames: # seek marks the current frame position in the mel spectrogram; padding is skipped directly
|
||||
# print("seek segments", seek, content_frames)
|
||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) # start time of this segment
|
||||
# mel_segment = mel[:, seek : seek + N_FRAMES] # take the data for the current segment
|
||||
mel_segment = mel[:, seek:]
|
||||
segment_size = min(N_FRAMES, content_frames - seek) # segment_size: real length excluding padding; content_frames: true length of the content, truncated if shorter than N_FRAMES
|
||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE # duration of the current segment
|
||||
mel_segment = mel_segment.to(model.device).to(dtype) # pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) # pad up to the full frame count
|
||||
|
||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
||||
tokens = torch.tensor(result.tokens)
|
||||
|
||||
# skip the silent parts
|
||||
if no_speech_threshold is not None:
|
||||
# no voice activity check
|
||||
should_skip = result.no_speech_prob > no_speech_threshold
|
||||
if (
|
||||
logprob_threshold is not None
|
||||
and result.avg_logprob > logprob_threshold
|
||||
):
|
||||
# don't skip if the logprob is high enough, despite the no_speech_prob
|
||||
should_skip = False
|
||||
|
||||
if should_skip:
|
||||
seek += segment_size # fast-forward to the next segment boundary
|
||||
continue
|
||||
|
||||
previous_seek = seek
|
||||
current_segments = []
|
||||
|
||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) # timestamp_begin is the <|0.00|> token; bos is larger than the text tokens and eos is larger than bos, hence ge
|
||||
timestamp_tokens[-1] = False
|
||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] # if the ending is [False, True], a sentence finished within this segment
|
||||
|
||||
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
|
||||
# torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).
|
||||
# timestamp_tokens is just a 1-D vector, so why not use nonzero directly?
# if there are two consecutive timestamps, this is a 1-D tensor with the positions where each pair ends
# with multiple pairs it points at the second one; what if there are three?
# otherwise it is a 0-D tensor

consecutive.add_(1) # a 0-D tensor plus 1 is still 0-D; where do these edge cases come from, is this JavaScript?
|
||||
if len(consecutive) > 0:
|
||||
# if the output contains two consecutive timestamp tokens
|
||||
slices = consecutive.tolist()
|
||||
if single_timestamp_ending:
|
||||
slices.append(len(tokens)) # also include the end of the final slice
|
||||
# print("many sentenses", consecutive)
|
||||
last_slice = 0
|
||||
for current_slice in slices:
|
||||
sliced_tokens = tokens[last_slice:current_slice]
|
||||
# it looks like the speech start/end frame positions are encoded into start_timestamp
|
||||
start_timestamp_pos = (
|
||||
sliced_tokens[0].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
end_timestamp_pos = (
|
||||
sliced_tokens[-1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
# build a new speech segment
|
||||
current_segments.append(
|
||||
new_segment(
|
||||
start=time_offset + start_timestamp_pos * time_precision,
|
||||
end=time_offset + end_timestamp_pos * time_precision,
|
||||
tokens=sliced_tokens,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
last_slice = current_slice
|
||||
|
||||
if single_timestamp_ending:
|
||||
# single timestamp at the end means no speech after the last timestamp.
|
||||
seek += segment_size
|
||||
else:
|
||||
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
||||
# if the speech has not finished, seek moves to the position of the last completed segment
# in other words, this is designed around 30-second audio chunks
|
||||
last_timestamp_pos = (
|
||||
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
seek += last_timestamp_pos * input_stride
|
||||
else:
|
||||
duration = segment_duration
|
||||
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
||||
# print(timestamps)
|
||||
if (
|
||||
len(timestamps) > 0
|
||||
and timestamps[-1].item() != tokenizer.timestamp_begin
|
||||
):
|
||||
# no consecutive timestamps but it has a timestamp; use the last one.
|
||||
# take the last one; assumes there is either a single closing timestamp or a pair?
# if there is only an opening timestamp, everything after it is apparently dropped?
|
||||
last_timestamp_pos = (
|
||||
timestamps[-1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
duration = last_timestamp_pos * time_precision
|
||||
|
||||
current_segments.append(
|
||||
new_segment(
|
||||
start=time_offset,
|
||||
end=time_offset + duration,
|
||||
tokens=tokens,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
seek += segment_size
|
||||
|
||||
# each token gets its own timestamp
|
||||
if word_timestamps:
|
||||
add_word_timestamps(
|
||||
segments=current_segments,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
mel=mel_segment,
|
||||
num_frames=segment_size,
|
||||
prepend_punctuations=prepend_punctuations,
|
||||
append_punctuations=append_punctuations,
|
||||
last_speech_timestamp=last_speech_timestamp,
|
||||
)
|
||||
word_end_timestamps = [
|
||||
w["end"] for s in current_segments for w in s["words"]
|
||||
]
|
||||
if len(word_end_timestamps) > 0:
|
||||
last_speech_timestamp = word_end_timestamps[-1]
|
||||
if not single_timestamp_ending and len(word_end_timestamps) > 0:
|
||||
seek_shift = round(
|
||||
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
|
||||
)
|
||||
if seek_shift > 0:
|
||||
seek = previous_seek + seek_shift
|
||||
|
||||
if verbose:
|
||||
for segment in current_segments:
|
||||
start, end, text = segment["start"], segment["end"], segment["text"]
|
||||
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
|
||||
print(make_safe(line))
|
||||
|
||||
# if a segment is instantaneous or does not contain text, clear it
|
||||
for i, segment in enumerate(current_segments):
|
||||
if segment["start"] == segment["end"] or segment["text"].strip() == "":
|
||||
segment["text"] = ""
|
||||
segment["tokens"] = []
|
||||
segment["words"] = []
|
||||
|
||||
# update the results
|
||||
all_segments.extend(
|
||||
[
|
||||
{"id": i, **segment}
|
||||
for i, segment in enumerate(
|
||||
current_segments, start=len(all_segments)
|
||||
)
|
||||
]
|
||||
)
|
||||
all_tokens.extend(
|
||||
[token for segment in current_segments for token in segment["tokens"]]
|
||||
)
|
||||
|
||||
if not condition_on_previous_text or result.temperature > 0.5:
|
||||
# do not feed the prompt tokens if a high temperature was used
|
||||
prompt_reset_since = len(all_tokens)
|
||||
|
||||
# update progress bar
|
||||
pbar.update(min(content_frames, seek) - previous_seek)
|
||||
|
||||
# print("太长了")
|
||||
# break
|
||||
|
||||
return dict(
|
||||
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
||||
segments=all_segments,
|
||||
language=language,
|
||||
)
|
||||
|
||||
|
||||
def cli():
|
||||
from . import available_models
|
||||
|
||||
# fmt: off
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
|
||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
|
||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||
|
||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
||||
|
||||
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
||||
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
||||
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
||||
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
||||
|
||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||
|
||||
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
|
||||
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
||||
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
||||
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
||||
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
|
||||
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
|
||||
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
|
||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
|
||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
# fmt: on
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
model_name: str = args.pop("model")
|
||||
model_dir: str = args.pop("model_dir")
|
||||
output_dir: str = args.pop("output_dir")
|
||||
output_format: str = args.pop("output_format")
|
||||
device: str = args.pop("device")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
||||
if args["language"] is not None:
|
||||
warnings.warn(
|
||||
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
|
||||
)
|
||||
args["language"] = "en"
|
||||
|
||||
temperature = args.pop("temperature")
|
||||
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
|
||||
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
|
||||
else:
|
||||
temperature = [temperature]
|
||||
|
||||
if (threads := args.pop("threads")) > 0:
|
||||
torch.set_num_threads(threads)
|
||||
|
||||
from . import load_model
|
||||
|
||||
model = load_model(model_name, device=device, download_root=model_dir)
|
||||
|
||||
writer = get_writer(output_format, output_dir)
|
||||
word_options = ["highlight_words", "max_line_count", "max_line_width"]
|
||||
if not args["word_timestamps"]:
|
||||
for option in word_options:
|
||||
if args[option]:
|
||||
parser.error(f"--{option} requires --word_timestamps True")
|
||||
if args["max_line_count"] and not args["max_line_width"]:
|
||||
warnings.warn("--max_line_count has no effect without --max_line_width")
|
||||
writer_args = {arg: args.pop(arg) for arg in word_options}
|
||||
for audio_path in args.pop("audio"):
|
||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||
writer(result, audio_path, writer_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
@@ -1,7 +1,8 @@
|
||||
import argparse
|
||||
import os
|
||||
import traceback
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -22,6 +23,7 @@ from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||
from .utils import (
|
||||
exact_div,
|
||||
format_timestamp,
|
||||
get_end,
|
||||
get_writer,
|
||||
make_safe,
|
||||
optional_float,
|
||||
@@ -44,9 +46,12 @@ def transcribe(
|
||||
no_speech_threshold: Optional[float] = 0.6,
|
||||
condition_on_previous_text: bool = True,
|
||||
initial_prompt: Optional[str] = None,
|
||||
carry_initial_prompt: bool = False,
|
||||
word_timestamps: bool = False,
|
||||
prepend_punctuations: str = "\"'“¿([{-",
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
clip_timestamps: Union[str, List[float]] = "0",
|
||||
hallucination_silence_threshold: Optional[float] = None,
|
||||
**decode_options,
|
||||
):
|
||||
"""
|
||||
@@ -98,15 +103,27 @@ def transcribe(
|
||||
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
||||
to make it more likely to predict those word correctly.
|
||||
|
||||
carry_initial_prompt: bool
|
||||
If carry_initial_prompt is True, `initial_prompt` is prepended to the prompt of each internal
|
||||
`decode()` call. If there is not enough context space at the start of the prompt, it is
|
||||
left-sliced to make space.
|
||||
|
||||
decode_options: dict
|
||||
Keyword arguments to construct `DecodingOptions` instances
|
||||
|
||||
clip_timestamps: Union[str, List[float]]
|
||||
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
|
||||
The last end timestamp defaults to the end of the file.
|
||||
|
||||
hallucination_silence_threshold: Optional[float]
|
||||
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
|
||||
when a possible hallucination is detected
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
||||
"""
|
||||
# print("transcribe")
|
||||
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
||||
if model.device == torch.device("cpu"):
|
||||
if torch.cuda.is_available():
|
||||
@@ -119,8 +136,9 @@ def transcribe(
|
||||
decode_options["fp16"] = False
|
||||
|
||||
# Pad 30-seconds of silence to the input audio, for slicing
|
||||
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
|
||||
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
|
||||
content_frames = mel.shape[-1] - N_FRAMES
|
||||
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
|
||||
|
||||
if decode_options.get("language", None) is None:
|
||||
if not model.is_multilingual:
|
||||
@@ -131,7 +149,6 @@ def transcribe(
|
||||
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
||||
)
|
||||
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||
# print(mel_segment.shape)
|
||||
_, probs = model.detect_language(mel_segment)
|
||||
decode_options["language"] = max(probs, key=probs.get)
|
||||
if verbose is not None:
|
||||
@@ -141,7 +158,25 @@ def transcribe(
|
||||
|
||||
language: str = decode_options["language"]
|
||||
task: str = decode_options.get("task", "transcribe")
|
||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
|
||||
tokenizer = get_tokenizer(
|
||||
model.is_multilingual,
|
||||
num_languages=model.num_languages,
|
||||
language=language,
|
||||
task=task,
|
||||
)
|
||||
|
||||
if isinstance(clip_timestamps, str):
|
||||
clip_timestamps = [
|
||||
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
|
||||
]
|
||||
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
|
||||
if len(seek_points) == 0:
|
||||
seek_points.append(0)
|
||||
if len(seek_points) % 2 == 1:
|
||||
seek_points.append(content_frames)
|
||||
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
|
||||
|
||||
punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
|
||||
|
||||
if word_timestamps and task == "translate":
|
||||
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
||||
@@ -179,6 +214,8 @@ def transcribe(
|
||||
if (
|
||||
no_speech_threshold is not None
|
||||
and decode_result.no_speech_prob > no_speech_threshold
|
||||
and logprob_threshold is not None
|
||||
and decode_result.avg_logprob < logprob_threshold
|
||||
):
|
||||
needs_fallback = False # silence
|
||||
if not needs_fallback:
|
||||
@@ -186,7 +223,8 @@ def transcribe(
|
||||
|
||||
return decode_result
|
||||
|
||||
seek = 0
|
||||
clip_idx = 0
|
||||
seek = seek_clips[clip_idx][0]
|
||||
input_stride = exact_div(
|
||||
N_FRAMES, model.dims.n_audio_ctx
|
||||
) # mel frames per output token: 2
|
||||
@@ -197,9 +235,11 @@ def transcribe(
|
||||
all_segments = []
|
||||
prompt_reset_since = 0
|
||||
|
||||
remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
|
||||
if initial_prompt is not None:
|
||||
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
||||
all_tokens.extend(initial_prompt_tokens)
|
||||
remaining_prompt_length -= len(initial_prompt_tokens)
|
||||
else:
|
||||
initial_prompt_tokens = []
|
||||
|
||||
@@ -225,16 +265,33 @@ def transcribe(
|
||||
total=content_frames, unit="frames", disable=verbose is not False
|
||||
) as pbar:
|
||||
last_speech_timestamp = 0.0
|
||||
while seek < content_frames:
|
||||
# NOTE: This loop is obscurely flattened to make the diff readable.
|
||||
# A later commit should turn this into a simpler nested loop.
|
||||
# for seek_clip_start, seek_clip_end in seek_clips:
|
||||
# while seek < seek_clip_end
|
||||
while clip_idx < len(seek_clips):
|
||||
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
|
||||
if seek < seek_clip_start:
|
||||
seek = seek_clip_start
|
||||
if seek >= seek_clip_end:
|
||||
clip_idx += 1
|
||||
if clip_idx < len(seek_clips):
|
||||
seek = seek_clips[clip_idx][0]
|
||||
continue
|
||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||
mel_segment = mel[:, seek : seek + N_FRAMES]
|
||||
segment_size = min(N_FRAMES, content_frames - seek)
|
||||
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
|
||||
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
|
||||
mel_segment = mel[:, seek : seek + segment_size]
|
||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
||||
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
||||
|
||||
# print("melshape", mel_segment.shape)
|
||||
if carry_initial_prompt:
|
||||
nignored = max(len(initial_prompt_tokens), prompt_reset_since)
|
||||
remaining_prompt = all_tokens[nignored:][-remaining_prompt_length:]
|
||||
decode_options["prompt"] = initial_prompt_tokens + remaining_prompt
|
||||
else:
|
||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||
|
||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
||||
tokens = torch.tensor(result.tokens)
|
||||
|
||||
@@ -255,6 +312,30 @@ def transcribe(
|
||||
previous_seek = seek
|
||||
current_segments = []
|
||||
|
||||
# anomalous words are very long/short/improbable
|
||||
def word_anomaly_score(word: dict) -> float:
|
||||
probability = word.get("probability", 0.0)
|
||||
duration = word["end"] - word["start"]
|
||||
score = 0.0
|
||||
if probability < 0.15:
|
||||
score += 1.0
|
||||
if duration < 0.133:
|
||||
score += (0.133 - duration) * 15
|
||||
if duration > 2.0:
|
||||
score += duration - 2.0
|
||||
return score
|
||||
|
||||
def is_segment_anomaly(segment: Optional[dict]) -> bool:
|
||||
if segment is None or not segment["words"]:
|
||||
return False
|
||||
words = [w for w in segment["words"] if w["word"] not in punctuation]
|
||||
words = words[:8]
|
||||
score = sum(word_anomaly_score(w) for w in words)
|
||||
return score >= 3 or score + 0.01 >= len(words)
|
||||
|
||||
def next_words_segment(segments: List[dict]) -> Optional[dict]:
|
||||
return next((s for s in segments if s["words"]), None)
|
||||
|
||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
||||
|
||||
@@ -317,9 +398,7 @@ def transcribe(
|
||||
)
|
||||
seek += segment_size
|
||||
|
||||
# print("word_timestamps, ", word_timestamps)
|
||||
if word_timestamps:
|
||||
# print("=========run timestamps here=========")
|
||||
add_word_timestamps(
|
||||
segments=current_segments,
|
||||
model=model,
|
||||
@@ -330,17 +409,71 @@ def transcribe(
|
||||
append_punctuations=append_punctuations,
|
||||
last_speech_timestamp=last_speech_timestamp,
|
||||
)
|
||||
word_end_timestamps = [
|
||||
w["end"] for s in current_segments for w in s["words"]
|
||||
]
|
||||
if len(word_end_timestamps) > 0:
|
||||
last_speech_timestamp = word_end_timestamps[-1]
|
||||
if not single_timestamp_ending and len(word_end_timestamps) > 0:
|
||||
seek_shift = round(
|
||||
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
|
||||
)
|
||||
if seek_shift > 0:
|
||||
seek = previous_seek + seek_shift
|
||||
|
||||
if not single_timestamp_ending:
|
||||
last_word_end = get_end(current_segments)
|
||||
if last_word_end is not None and last_word_end > time_offset:
|
||||
seek = round(last_word_end * FRAMES_PER_SECOND)
|
||||
|
||||
# skip silence before possible hallucinations
|
||||
if hallucination_silence_threshold is not None:
|
||||
threshold = hallucination_silence_threshold
|
||||
if not single_timestamp_ending:
|
||||
last_word_end = get_end(current_segments)
|
||||
if last_word_end is not None and last_word_end > time_offset:
|
||||
remaining_duration = window_end_time - last_word_end
|
||||
if remaining_duration > threshold:
|
||||
seek = round(last_word_end * FRAMES_PER_SECOND)
|
||||
else:
|
||||
seek = previous_seek + segment_size
|
||||
|
||||
# if first segment might be a hallucination, skip leading silence
|
||||
first_segment = next_words_segment(current_segments)
|
||||
if first_segment is not None and is_segment_anomaly(first_segment):
|
||||
gap = first_segment["start"] - time_offset
|
||||
if gap > threshold:
|
||||
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
|
||||
continue
|
||||
|
||||
# skip silence before any possible hallucination that is surrounded
|
||||
# by silence or more hallucinations
|
||||
hal_last_end = last_speech_timestamp
|
||||
for si in range(len(current_segments)):
|
||||
segment = current_segments[si]
|
||||
if not segment["words"]:
|
||||
continue
|
||||
if is_segment_anomaly(segment):
|
||||
next_segment = next_words_segment(
|
||||
current_segments[si + 1 :]
|
||||
)
|
||||
if next_segment is not None:
|
||||
hal_next_start = next_segment["words"][0]["start"]
|
||||
else:
|
||||
hal_next_start = time_offset + segment_duration
|
||||
silence_before = (
|
||||
segment["start"] - hal_last_end > threshold
|
||||
or segment["start"] < threshold
|
||||
or segment["start"] - time_offset < 2.0
|
||||
)
|
||||
silence_after = (
|
||||
hal_next_start - segment["end"] > threshold
|
||||
or is_segment_anomaly(next_segment)
|
||||
or window_end_time - segment["end"] < 2.0
|
||||
)
|
||||
if silence_before and silence_after:
|
||||
seek = round(
|
||||
max(time_offset + 1, segment["start"])
|
||||
* FRAMES_PER_SECOND
|
||||
)
|
||||
if content_duration - segment["end"] < threshold:
|
||||
seek = content_frames
|
||||
current_segments[si:] = []
|
||||
break
|
||||
hal_last_end = segment["end"]
|
||||
|
||||
last_word_end = get_end(current_segments)
|
||||
if last_word_end is not None:
|
||||
last_speech_timestamp = last_word_end
|
||||
|
||||
if verbose:
|
||||
for segment in current_segments:
|
||||
@@ -384,10 +517,17 @@ def transcribe(
|
||||
def cli():
|
||||
from . import available_models
|
||||
|
||||
def valid_model_name(name):
|
||||
if name in available_models() or os.path.exists(name):
|
||||
return name
|
||||
raise ValueError(
|
||||
f"model should be one of {available_models()} or path to a model checkpoint"
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
|
||||
parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
|
||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
@@ -405,6 +545,8 @@ def cli():
|
||||
|
||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||
parser.add_argument("--carry_initial_prompt", type=str2bool, default=False, help="if True, prepend initial_prompt to every internal decode() call. May reduce the effectiveness of condition_on_previous_text")
|
||||
|
||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||
|
||||
@@ -418,7 +560,10 @@ def cli():
|
||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
|
||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
|
||||
parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
|
||||
parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
|
||||
# fmt: on
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
@@ -450,17 +595,28 @@ def cli():
|
||||
model = load_model(model_name, device=device, download_root=model_dir)
|
||||
|
||||
writer = get_writer(output_format, output_dir)
|
||||
word_options = ["highlight_words", "max_line_count", "max_line_width"]
|
||||
word_options = [
|
||||
"highlight_words",
|
||||
"max_line_count",
|
||||
"max_line_width",
|
||||
"max_words_per_line",
|
||||
]
|
||||
if not args["word_timestamps"]:
|
||||
for option in word_options:
|
||||
if args[option]:
|
||||
parser.error(f"--{option} requires --word_timestamps True")
|
||||
if args["max_line_count"] and not args["max_line_width"]:
|
||||
warnings.warn("--max_line_count has no effect without --max_line_width")
|
||||
if args["max_words_per_line"] and args["max_line_width"]:
|
||||
warnings.warn("--max_words_per_line has no effect with --max_line_width")
|
||||
writer_args = {arg: args.pop(arg) for arg in word_options}
|
||||
for audio_path in args.pop("audio"):
|
||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||
writer(result, audio_path, writer_args)
|
||||
try:
|
||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||
writer(result, audio_path, **writer_args)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -60,7 +60,7 @@ def median_kernel(filter_width: int):
|
||||
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
|
||||
|
||||
kernel = triton.JITFunction(kernel.fn)
|
||||
kernel.src = kernel.src.replace(
|
||||
new_kernel = kernel.src.replace(
|
||||
" LOAD_ALL_ROWS_HERE",
|
||||
"\n".join(
|
||||
[
|
||||
@@ -69,7 +69,8 @@ def median_kernel(filter_width: int):
|
||||
]
|
||||
),
|
||||
)
|
||||
kernel.src = kernel.src.replace(
|
||||
|
||||
new_kernel = new_kernel.replace(
|
||||
" BUBBLESORT_HERE",
|
||||
"\n\n".join(
|
||||
[
|
||||
@@ -90,7 +91,14 @@ def median_kernel(filter_width: int):
|
||||
]
|
||||
),
|
||||
)
|
||||
kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
|
||||
|
||||
new_kernel = new_kernel.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
|
||||
|
||||
    if hasattr(kernel, "_unsafe_update_src"):
|
||||
kernel._unsafe_update_src(new_kernel)
|
||||
kernel.hash = None
|
||||
else:
|
||||
kernel.src = new_kernel
|
||||
|
||||
return kernel
|
||||
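The hunk above stops mutating `kernel.src` in place and instead builds `new_kernel`, applying it through `_unsafe_update_src` when the running Triton version provides it. A minimal sketch of that compatibility pattern, assuming a Triton `JITFunction`-like object:

```python
# Hedged sketch: apply patched kernel source across Triton versions.
def update_kernel_source(kernel, new_src: str) -> None:
    if hasattr(kernel, "_unsafe_update_src"):
        kernel._unsafe_update_src(new_src)  # newer Triton: explicit updater
        kernel.hash = None                  # invalidate the cached hash, as the diff does
    else:
        kernel.src = new_src                # older Triton: plain attribute assignment
```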
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import re
|
||||
import sys
|
||||
import zlib
|
||||
from typing import Callable, Optional, TextIO
|
||||
from typing import Callable, List, Optional, TextIO
|
||||
|
||||
system_encoding = sys.getdefaultencoding()
|
||||
|
||||
@@ -68,13 +68,29 @@ def format_timestamp(
|
||||
)
|
||||
|
||||
|
||||
def get_start(segments: List[dict]) -> Optional[float]:
|
||||
return next(
|
||||
(w["start"] for s in segments for w in s["words"]),
|
||||
segments[0]["start"] if segments else None,
|
||||
)
|
||||
|
||||
|
||||
def get_end(segments: List[dict]) -> Optional[float]:
|
||||
return next(
|
||||
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
|
||||
segments[-1]["end"] if segments else None,
|
||||
)
|
||||
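These helpers fall back to segment-level timestamps only when no word timings exist anywhere. A quick illustration on a hypothetical result:

```python
# Hedged example of get_start / get_end on a minimal segments structure.
segments = [
    {"start": 0.0, "end": 2.5, "words": [{"word": " hi", "start": 0.2, "end": 0.5}]},
    {"start": 2.5, "end": 5.0, "words": []},
]
print(get_start(segments))  # 0.2  (first available word start)
print(get_end(segments))    # 0.5  (last available word end; segment end is only the fallback)
```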
|
||||
|
||||
class ResultWriter:
|
||||
extension: str
|
||||
|
||||
def __init__(self, output_dir: str):
|
||||
self.output_dir = output_dir
|
||||
|
||||
def __call__(self, result: dict, audio_path: str, options: dict):
|
||||
def __call__(
|
||||
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
audio_basename = os.path.basename(audio_path)
|
||||
audio_basename = os.path.splitext(audio_basename)[0]
|
||||
output_path = os.path.join(
|
||||
@@ -82,16 +98,20 @@ class ResultWriter:
|
||||
)
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
self.write_result(result, file=f, options=options)
|
||||
self.write_result(result, file=f, options=options, **kwargs)
|
||||
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class WriteTXT(ResultWriter):
|
||||
extension: str = "txt"
|
||||
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
for segment in result["segments"]:
|
||||
print(segment["text"].strip(), file=file, flush=True)
|
||||
|
||||
@@ -100,48 +120,76 @@ class SubtitlesWriter(ResultWriter):
|
||||
always_include_hours: bool
|
||||
decimal_marker: str
|
||||
|
||||
def iterate_result(self, result: dict, options: dict):
|
||||
raw_max_line_width: Optional[int] = options["max_line_width"]
|
||||
max_line_count: Optional[int] = options["max_line_count"]
|
||||
highlight_words: bool = options["highlight_words"]
|
||||
max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
|
||||
preserve_segments = max_line_count is None or raw_max_line_width is None
|
||||
def iterate_result(
|
||||
self,
|
||||
result: dict,
|
||||
options: Optional[dict] = None,
|
||||
*,
|
||||
max_line_width: Optional[int] = None,
|
||||
max_line_count: Optional[int] = None,
|
||||
highlight_words: bool = False,
|
||||
max_words_per_line: Optional[int] = None,
|
||||
):
|
||||
options = options or {}
|
||||
max_line_width = max_line_width or options.get("max_line_width")
|
||||
max_line_count = max_line_count or options.get("max_line_count")
|
||||
highlight_words = highlight_words or options.get("highlight_words", False)
|
||||
max_words_per_line = max_words_per_line or options.get("max_words_per_line")
|
||||
preserve_segments = max_line_count is None or max_line_width is None
|
||||
max_line_width = max_line_width or 1000
|
||||
max_words_per_line = max_words_per_line or 1000
|
||||
|
||||
def iterate_subtitles():
|
||||
line_len = 0
|
||||
line_count = 1
|
||||
# the next subtitle to yield (a list of word timings with whitespace)
|
||||
subtitle: list[dict] = []
|
||||
last = result["segments"][0]["words"][0]["start"]
|
||||
subtitle: List[dict] = []
|
||||
last: float = get_start(result["segments"]) or 0.0
|
||||
for segment in result["segments"]:
|
||||
for i, original_timing in enumerate(segment["words"]):
|
||||
timing = original_timing.copy()
|
||||
long_pause = not preserve_segments and timing["start"] - last > 3.0
|
||||
has_room = line_len + len(timing["word"]) <= max_line_width
|
||||
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
|
||||
if line_len > 0 and has_room and not long_pause and not seg_break:
|
||||
# line continuation
|
||||
line_len += len(timing["word"])
|
||||
else:
|
||||
# new line
|
||||
timing["word"] = timing["word"].strip()
|
||||
chunk_index = 0
|
||||
words_count = max_words_per_line
|
||||
while chunk_index < len(segment["words"]):
|
||||
remaining_words = len(segment["words"]) - chunk_index
|
||||
if max_words_per_line > len(segment["words"]) - chunk_index:
|
||||
words_count = remaining_words
|
||||
for i, original_timing in enumerate(
|
||||
segment["words"][chunk_index : chunk_index + words_count]
|
||||
):
|
||||
timing = original_timing.copy()
|
||||
long_pause = (
|
||||
not preserve_segments and timing["start"] - last > 3.0
|
||||
)
|
||||
has_room = line_len + len(timing["word"]) <= max_line_width
|
||||
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
|
||||
if (
|
||||
len(subtitle) > 0
|
||||
and max_line_count is not None
|
||||
and (long_pause or line_count >= max_line_count)
|
||||
or seg_break
|
||||
line_len > 0
|
||||
and has_room
|
||||
and not long_pause
|
||||
and not seg_break
|
||||
):
|
||||
# subtitle break
|
||||
yield subtitle
|
||||
subtitle = []
|
||||
line_count = 1
|
||||
elif line_len > 0:
|
||||
# line break
|
||||
line_count += 1
|
||||
timing["word"] = "\n" + timing["word"]
|
||||
line_len = len(timing["word"].strip())
|
||||
subtitle.append(timing)
|
||||
last = timing["start"]
|
||||
# line continuation
|
||||
line_len += len(timing["word"])
|
||||
else:
|
||||
# new line
|
||||
timing["word"] = timing["word"].strip()
|
||||
if (
|
||||
len(subtitle) > 0
|
||||
and max_line_count is not None
|
||||
and (long_pause or line_count >= max_line_count)
|
||||
or seg_break
|
||||
):
|
||||
# subtitle break
|
||||
yield subtitle
|
||||
subtitle = []
|
||||
line_count = 1
|
||||
elif line_len > 0:
|
||||
# line break
|
||||
line_count += 1
|
||||
timing["word"] = "\n" + timing["word"]
|
||||
line_len = len(timing["word"].strip())
|
||||
subtitle.append(timing)
|
||||
last = timing["start"]
|
||||
chunk_index += max_words_per_line
|
||||
if len(subtitle) > 0:
|
||||
yield subtitle
|
||||
|
||||
@@ -161,9 +209,11 @@ class SubtitlesWriter(ResultWriter):
|
||||
|
||||
yield start, end, "".join(
|
||||
[
|
||||
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
||||
if j == i
|
||||
else word
|
||||
(
|
||||
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
||||
if j == i
|
||||
else word
|
||||
)
|
||||
for j, word in enumerate(all_words)
|
||||
]
|
||||
)
|
||||
@@ -190,9 +240,11 @@ class WriteVTT(SubtitlesWriter):
|
||||
always_include_hours: bool = False
|
||||
decimal_marker: str = "."
|
||||
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
print("WEBVTT\n", file=file)
|
||||
for start, end, text in self.iterate_result(result, options):
|
||||
for start, end, text in self.iterate_result(result, options, **kwargs):
|
||||
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||
|
||||
|
||||
@@ -201,9 +253,11 @@ class WriteSRT(SubtitlesWriter):
|
||||
always_include_hours: bool = True
|
||||
decimal_marker: str = ","
|
||||
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
for i, (start, end, text) in enumerate(
|
||||
self.iterate_result(result, options), start=1
|
||||
self.iterate_result(result, options, **kwargs), start=1
|
||||
):
|
||||
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||
|
||||
@@ -220,7 +274,9 @@ class WriteTSV(ResultWriter):
|
||||
|
||||
extension: str = "tsv"
|
||||
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
print("start", "end", "text", sep="\t", file=file)
|
||||
for segment in result["segments"]:
|
||||
print(round(1000 * segment["start"]), file=file, end="\t")
|
||||
@@ -231,7 +287,9 @@ class WriteTSV(ResultWriter):
|
||||
class WriteJSON(ResultWriter):
|
||||
extension: str = "json"
|
||||
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
json.dump(result, file)
|
||||
|
||||
|
||||
@@ -249,9 +307,11 @@ def get_writer(
|
||||
if output_format == "all":
|
||||
all_writers = [writer(output_dir) for writer in writers.values()]
|
||||
|
||||
def write_all(result: dict, file: TextIO, options: dict):
|
||||
def write_all(
|
||||
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
for writer in all_writers:
|
||||
writer(result, file, options)
|
||||
writer(result, file, options, **kwargs)
|
||||
|
||||
return write_all
|
||||
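With these signatures, the CLI can forward the word-level options as keyword arguments instead of a positional dict. A hedged usage sketch (paths and option values are illustrative):

```python
# Hedged sketch of how a writer is obtained and invoked after this change.
writer = get_writer("srt", ".")  # or "all" to emit every supported format
writer_args = {
    "highlight_words": False,
    "max_line_count": None,
    "max_line_width": None,
    "max_words_per_line": 8,
}
writer(result, "audio.wav", **writer_args)  # result: the dict returned by transcribe()
```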
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "20230918"
|
||||
__version__ = "20250625"
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from datetime import timedelta
|
||||
|
||||
def format_time(seconds: float) -> str:
|
||||
"""Format seconds as HH:MM:SS."""
|
||||
return str(timedelta(seconds=int(seconds)))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimedText:
|
||||
start: Optional[float]
|
||||
end: Optional[float]
|
||||
start: Optional[float] = 0
|
||||
end: Optional[float] = 0
|
||||
text: Optional[str] = ''
|
||||
speaker: Optional[int] = -1
|
||||
probability: Optional[float] = None
|
||||
@@ -29,4 +35,26 @@ class SpeakerSegment(TimedText):
|
||||
"""Represents a segment of audio attributed to a specific speaker.
|
||||
No text nor probability is associated with this segment.
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class Translation(TimedText):
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class Silence():
|
||||
duration: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class Line(TimedText):
|
||||
translation: str = ''
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'speaker': int(self.speaker),
|
||||
'text': self.text,
|
||||
'translation': self.translation,
|
||||
'start': format_time(self.start),
|
||||
'end': format_time(self.end),
|
||||
}
|
||||
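A small sketch of how these dataclasses are meant to be consumed by the frontend (the field values are made up):

```python
# Hedged example: building a Line and serializing it.
line = Line(start=3.5, end=7.2, text="Bonjour tout le monde",
            speaker=1, translation="Hello everyone")
print(line.to_dict())
# {'speaker': 1, 'text': 'Bonjour tout le monde', 'translation': 'Hello everyone',
#  'start': '0:00:03', 'end': '0:00:07'}
```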
@@ -1,73 +0,0 @@
|
||||
import torch
|
||||
import sys
|
||||
class TokenBuffer:
|
||||
|
||||
def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):
|
||||
self.text = text
|
||||
self.prefix_token_ids = prefix_token_ids
|
||||
self.tokenizer = tokenizer
|
||||
self.device = device
|
||||
|
||||
def as_token_ids(self, tokenizer=None):
|
||||
|
||||
if tokenizer is None:
|
||||
tokenizer = self.tokenizer
|
||||
if tokenizer is None:
|
||||
raise ValueError("Tokenizer is not set.")
|
||||
return self.prefix_token_ids + tokenizer.encode(self.text)
|
||||
|
||||
def as_tensor(self, device=None):
|
||||
if device is None:
|
||||
device = self.device
|
||||
if device is None:
|
||||
raise ValueError("Device is not set.")
|
||||
tok_ids = self.as_token_ids()
|
||||
return torch.tensor(tok_ids,
|
||||
dtype=torch.long, device=device).unsqueeze(0)
|
||||
|
||||
def as_tensor_beam(self, beam, device=None):
|
||||
t = self.as_tensor(device=device)
|
||||
return t.repeat_interleave(beam, dim=0)
|
||||
|
||||
|
||||
def as_text(self):
|
||||
return self.text
|
||||
|
||||
@staticmethod
|
||||
def empty(*a, **kw):
|
||||
return TokenBuffer(*a,**kw)
|
||||
|
||||
@staticmethod
|
||||
def from_text(text, *a, **kw):
|
||||
return TokenBuffer(*a, text=text, **kw)
|
||||
|
||||
def is_empty(self):
|
||||
return self.text is None or self.text == ""
|
||||
|
||||
def trim_words(self, num=1, after=0):
|
||||
'''
|
||||
num: how many words to trim from the beginning
|
||||
after: how many characters to skip (length of the static prompt)
|
||||
'''
|
||||
tokenizer = self.tokenizer
|
||||
assert tokenizer is not None, "Tokenizer is not set."
|
||||
|
||||
ids = tokenizer.encode(self.text[after:])
|
||||
words, wids = self.tokenizer.split_to_word_tokens(ids)
|
||||
print(words, file=sys.stderr)
|
||||
print(wids, file=sys.stderr)
|
||||
if not words:
|
||||
return 0
|
||||
self.text = self.text[:after] + "".join(words[num:])
|
||||
return sum(len(wi) for wi in wids[:num])
|
||||
|
||||
def append_token_ids(self, token_ids):
|
||||
tokenizer = self.tokenizer
|
||||
assert tokenizer is not None, "Tokenizer is not set."
|
||||
self.text += self.tokenizer.decode(token_ids)
|
||||
|
||||
def as_split_word_tokens(self):
|
||||
tokenizer = self.tokenizer
|
||||
assert tokenizer is not None, "Tokenizer is not set."
|
||||
ids = tokenizer.encode(self.text)
|
||||
return tokenizer.split_to_word_tokens(ids)
|
||||
60
whisperlivekit/trail_repetition.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from typing import Sequence, Callable, Any, Optional, Dict
|
||||
|
||||
def _detect_tail_repetition(
|
||||
seq: Sequence[Any],
|
||||
key: Callable[[Any], Any] = lambda x: x, # extract comparable value
|
||||
min_block: int = 1, # set to 2 to ignore 1-token loops like "."
|
||||
max_tail: int = 300, # search window from the end for speed
|
||||
prefer: str = "longest", # "longest" coverage or "smallest" block
|
||||
) -> Optional[Dict]:
|
||||
vals = [key(x) for x in seq][-max_tail:]
|
||||
n = len(vals)
|
||||
best = None
|
||||
|
||||
# try every possible block length
|
||||
for b in range(min_block, n // 2 + 1):
|
||||
block = vals[-b:]
|
||||
# count how many times this block repeats contiguously at the very end
|
||||
count, i = 0, n
|
||||
while i - b >= 0 and vals[i - b:i] == block:
|
||||
count += 1
|
||||
i -= b
|
||||
|
||||
if count >= 2:
|
||||
cand = {
|
||||
"block_size": b,
|
||||
"count": count,
|
||||
"start_index": len(seq) - count * b, # in original seq
|
||||
"end_index": len(seq),
|
||||
}
|
||||
if (best is None or
|
||||
(prefer == "longest" and count * b > best["count"] * best["block_size"]) or
|
||||
(prefer == "smallest" and b < best["block_size"])):
|
||||
best = cand
|
||||
return best
|
||||
|
||||
def trim_tail_repetition(
|
||||
seq: Sequence[Any],
|
||||
key: Callable[[Any], Any] = lambda x: x,
|
||||
min_block: int = 1,
|
||||
max_tail: int = 300,
|
||||
prefer: str = "longest",
|
||||
keep: int = 1, # how many copies of the repeating block to keep at the end (0 or 1 are common)
|
||||
):
|
||||
"""
|
||||
Returns a new sequence with repeated tail trimmed.
|
||||
keep=1 -> keep a single copy of the repeated block.
|
||||
keep=0 -> remove all copies of the repeated block.
|
||||
"""
|
||||
rep = _detect_tail_repetition(seq, key, min_block, max_tail, prefer)
|
||||
if not rep:
|
||||
return seq, False # nothing to trim
|
||||
|
||||
b, c = rep["block_size"], rep["count"]
|
||||
if keep < 0:
|
||||
keep = 0
|
||||
if keep >= c:
|
||||
return seq, False # nothing to trim (already <= keep copies)
|
||||
# new length = total - (copies_to_remove * block_size)
|
||||
new_len = len(seq) - (c - keep) * b
|
||||
return seq[:new_len], True
|
||||
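A quick example of the trimming behaviour on a made-up word list:

```python
# Hedged example: collapse a stuttered tail down to a single copy of the block.
words = ["so", "I", "think", "that", "that", "that", "that"]
trimmed, changed = trim_tail_repetition(words, keep=1)
print(trimmed, changed)  # ['so', 'I', 'think', 'that'] True
```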
0
whisperlivekit/translation/__init__.py
Normal file
182
whisperlivekit/translation/mapping_languages.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
adapted from https://store.crowdin.com/custom-mt
|
||||
"""
|
||||
|
||||
LANGUAGES = [
|
||||
{"name": "Afrikaans", "nllb": "afr_Latn", "crowdin": "af"},
|
||||
{"name": "Akan", "nllb": "aka_Latn", "crowdin": "ak"},
|
||||
{"name": "Amharic", "nllb": "amh_Ethi", "crowdin": "am"},
|
||||
{"name": "Assamese", "nllb": "asm_Beng", "crowdin": "as"},
|
||||
{"name": "Asturian", "nllb": "ast_Latn", "crowdin": "ast"},
|
||||
{"name": "Bashkir", "nllb": "bak_Cyrl", "crowdin": "ba"},
|
||||
{"name": "Bambara", "nllb": "bam_Latn", "crowdin": "bm"},
|
||||
{"name": "Balinese", "nllb": "ban_Latn", "crowdin": "ban"},
|
||||
{"name": "Belarusian", "nllb": "bel_Cyrl", "crowdin": "be"},
|
||||
{"name": "Bengali", "nllb": "ben_Beng", "crowdin": "bn"},
|
||||
{"name": "Bosnian", "nllb": "bos_Latn", "crowdin": "bs"},
|
||||
{"name": "Bulgarian", "nllb": "bul_Cyrl", "crowdin": "bg"},
|
||||
{"name": "Catalan", "nllb": "cat_Latn", "crowdin": "ca"},
|
||||
{"name": "Cebuano", "nllb": "ceb_Latn", "crowdin": "ceb"},
|
||||
{"name": "Czech", "nllb": "ces_Latn", "crowdin": "cs"},
|
||||
{"name": "Welsh", "nllb": "cym_Latn", "crowdin": "cy"},
|
||||
{"name": "Danish", "nllb": "dan_Latn", "crowdin": "da"},
|
||||
{"name": "German", "nllb": "deu_Latn", "crowdin": "de"},
|
||||
{"name": "Dzongkha", "nllb": "dzo_Tibt", "crowdin": "dz"},
|
||||
{"name": "Greek", "nllb": "ell_Grek", "crowdin": "el"},
|
||||
{"name": "English", "nllb": "eng_Latn", "crowdin": "en"},
|
||||
{"name": "Esperanto", "nllb": "epo_Latn", "crowdin": "eo"},
|
||||
{"name": "Estonian", "nllb": "est_Latn", "crowdin": "et"},
|
||||
{"name": "Basque", "nllb": "eus_Latn", "crowdin": "eu"},
|
||||
{"name": "Ewe", "nllb": "ewe_Latn", "crowdin": "ee"},
|
||||
{"name": "Faroese", "nllb": "fao_Latn", "crowdin": "fo"},
|
||||
{"name": "Fijian", "nllb": "fij_Latn", "crowdin": "fj"},
|
||||
{"name": "Finnish", "nllb": "fin_Latn", "crowdin": "fi"},
|
||||
{"name": "French", "nllb": "fra_Latn", "crowdin": "fr"},
|
||||
{"name": "Friulian", "nllb": "fur_Latn", "crowdin": "fur-IT"},
|
||||
{"name": "Scottish Gaelic", "nllb": "gla_Latn", "crowdin": "gd"},
|
||||
{"name": "Irish", "nllb": "gle_Latn", "crowdin": "ga-IE"},
|
||||
{"name": "Galician", "nllb": "glg_Latn", "crowdin": "gl"},
|
||||
{"name": "Guarani", "nllb": "grn_Latn", "crowdin": "gn"},
|
||||
{"name": "Gujarati", "nllb": "guj_Gujr", "crowdin": "gu-IN"},
|
||||
{"name": "Haitian Creole", "nllb": "hat_Latn", "crowdin": "ht"},
|
||||
{"name": "Hausa", "nllb": "hau_Latn", "crowdin": "ha"},
|
||||
{"name": "Hebrew", "nllb": "heb_Hebr", "crowdin": "he"},
|
||||
{"name": "Hindi", "nllb": "hin_Deva", "crowdin": "hi"},
|
||||
{"name": "Croatian", "nllb": "hrv_Latn", "crowdin": "hr"},
|
||||
{"name": "Hungarian", "nllb": "hun_Latn", "crowdin": "hu"},
|
||||
{"name": "Armenian", "nllb": "hye_Armn", "crowdin": "hy-AM"},
|
||||
{"name": "Igbo", "nllb": "ibo_Latn", "crowdin": "ig"},
|
||||
{"name": "Indonesian", "nllb": "ind_Latn", "crowdin": "id"},
|
||||
{"name": "Icelandic", "nllb": "isl_Latn", "crowdin": "is"},
|
||||
{"name": "Italian", "nllb": "ita_Latn", "crowdin": "it"},
|
||||
{"name": "Javanese", "nllb": "jav_Latn", "crowdin": "jv"},
|
||||
{"name": "Japanese", "nllb": "jpn_Jpan", "crowdin": "ja"},
|
||||
{"name": "Kabyle", "nllb": "kab_Latn", "crowdin": "kab"},
|
||||
{"name": "Kannada", "nllb": "kan_Knda", "crowdin": "kn"},
|
||||
{"name": "Georgian", "nllb": "kat_Geor", "crowdin": "ka"},
|
||||
{"name": "Kazakh", "nllb": "kaz_Cyrl", "crowdin": "kk"},
|
||||
{"name": "Khmer", "nllb": "khm_Khmr", "crowdin": "km"},
|
||||
{"name": "Kinyarwanda", "nllb": "kin_Latn", "crowdin": "rw"},
|
||||
{"name": "Kyrgyz", "nllb": "kir_Cyrl", "crowdin": "ky"},
|
||||
{"name": "Korean", "nllb": "kor_Hang", "crowdin": "ko"},
|
||||
{"name": "Lao", "nllb": "lao_Laoo", "crowdin": "lo"},
|
||||
{"name": "Ligurian", "nllb": "lij_Latn", "crowdin": "lij"},
|
||||
{"name": "Limburgish", "nllb": "lim_Latn", "crowdin": "li"},
|
||||
{"name": "Lingala", "nllb": "lin_Latn", "crowdin": "ln"},
|
||||
{"name": "Lithuanian", "nllb": "lit_Latn", "crowdin": "lt"},
|
||||
{"name": "Luxembourgish", "nllb": "ltz_Latn", "crowdin": "lb"},
|
||||
{"name": "Maithili", "nllb": "mai_Deva", "crowdin": "mai"},
|
||||
{"name": "Malayalam", "nllb": "mal_Mlym", "crowdin": "ml-IN"},
|
||||
{"name": "Marathi", "nllb": "mar_Deva", "crowdin": "mr"},
|
||||
{"name": "Macedonian", "nllb": "mkd_Cyrl", "crowdin": "mk"},
|
||||
{"name": "Maltese", "nllb": "mlt_Latn", "crowdin": "mt"},
|
||||
{"name": "Mossi", "nllb": "mos_Latn", "crowdin": "mos"},
|
||||
{"name": "Maori", "nllb": "mri_Latn", "crowdin": "mi"},
|
||||
{"name": "Burmese", "nllb": "mya_Mymr", "crowdin": "my"},
|
||||
{"name": "Dutch", "nllb": "nld_Latn", "crowdin": "nl"},
|
||||
{"name": "Norwegian Nynorsk", "nllb": "nno_Latn", "crowdin": "nn-NO"},
|
||||
{"name": "Nepali", "nllb": "npi_Deva", "crowdin": "ne-NP"},
|
||||
{"name": "Northern Sotho", "nllb": "nso_Latn", "crowdin": "nso"},
|
||||
{"name": "Occitan", "nllb": "oci_Latn", "crowdin": "oc"},
|
||||
{"name": "Odia", "nllb": "ory_Orya", "crowdin": "or"},
|
||||
{"name": "Papiamento", "nllb": "pap_Latn", "crowdin": "pap"},
|
||||
{"name": "Polish", "nllb": "pol_Latn", "crowdin": "pl"},
|
||||
{"name": "Portuguese", "nllb": "por_Latn", "crowdin": "pt-PT"},
|
||||
{"name": "Dari", "nllb": "prs_Arab", "crowdin": "fa-AF"},
|
||||
{"name": "Romanian", "nllb": "ron_Latn", "crowdin": "ro"},
|
||||
{"name": "Rundi", "nllb": "run_Latn", "crowdin": "rn"},
|
||||
{"name": "Russian", "nllb": "rus_Cyrl", "crowdin": "ru"},
|
||||
{"name": "Sango", "nllb": "sag_Latn", "crowdin": "sg"},
|
||||
{"name": "Sanskrit", "nllb": "san_Deva", "crowdin": "sa"},
|
||||
{"name": "Santali", "nllb": "sat_Olck", "crowdin": "sat"},
|
||||
{"name": "Sinhala", "nllb": "sin_Sinh", "crowdin": "si-LK"},
|
||||
{"name": "Slovak", "nllb": "slk_Latn", "crowdin": "sk"},
|
||||
{"name": "Slovenian", "nllb": "slv_Latn", "crowdin": "sl"},
|
||||
{"name": "Shona", "nllb": "sna_Latn", "crowdin": "sn"},
|
||||
{"name": "Sindhi", "nllb": "snd_Arab", "crowdin": "sd"},
|
||||
{"name": "Somali", "nllb": "som_Latn", "crowdin": "so"},
|
||||
{"name": "Southern Sotho", "nllb": "sot_Latn", "crowdin": "st"},
|
||||
{"name": "Spanish", "nllb": "spa_Latn", "crowdin": "es-ES"},
|
||||
{"name": "Sardinian", "nllb": "srd_Latn", "crowdin": "sc"},
|
||||
{"name": "Swati", "nllb": "ssw_Latn", "crowdin": "ss"},
|
||||
{"name": "Sundanese", "nllb": "sun_Latn", "crowdin": "su"},
|
||||
{"name": "Swedish", "nllb": "swe_Latn", "crowdin": "sv-SE"},
|
||||
{"name": "Swahili", "nllb": "swh_Latn", "crowdin": "sw"},
|
||||
{"name": "Tamil", "nllb": "tam_Taml", "crowdin": "ta"},
|
||||
{"name": "Tatar", "nllb": "tat_Cyrl", "crowdin": "tt-RU"},
|
||||
{"name": "Telugu", "nllb": "tel_Telu", "crowdin": "te"},
|
||||
{"name": "Tajik", "nllb": "tgk_Cyrl", "crowdin": "tg"},
|
||||
{"name": "Tagalog", "nllb": "tgl_Latn", "crowdin": "tl"},
|
||||
{"name": "Thai", "nllb": "tha_Thai", "crowdin": "th"},
|
||||
{"name": "Tigrinya", "nllb": "tir_Ethi", "crowdin": "ti"},
|
||||
{"name": "Tswana", "nllb": "tsn_Latn", "crowdin": "tn"},
|
||||
{"name": "Tsonga", "nllb": "tso_Latn", "crowdin": "ts"},
|
||||
{"name": "Turkmen", "nllb": "tuk_Latn", "crowdin": "tk"},
|
||||
{"name": "Turkish", "nllb": "tur_Latn", "crowdin": "tr"},
|
||||
{"name": "Uyghur", "nllb": "uig_Arab", "crowdin": "ug"},
|
||||
{"name": "Ukrainian", "nllb": "ukr_Cyrl", "crowdin": "uk"},
|
||||
{"name": "Venetian", "nllb": "vec_Latn", "crowdin": "vec"},
|
||||
{"name": "Vietnamese", "nllb": "vie_Latn", "crowdin": "vi"},
|
||||
{"name": "Wolof", "nllb": "wol_Latn", "crowdin": "wo"},
|
||||
{"name": "Xhosa", "nllb": "xho_Latn", "crowdin": "xh"},
|
||||
{"name": "Yoruba", "nllb": "yor_Latn", "crowdin": "yo"},
|
||||
{"name": "Zulu", "nllb": "zul_Latn", "crowdin": "zu"},
|
||||
]
|
||||
|
||||
NAME_TO_NLLB = {lang["name"]: lang["nllb"] for lang in LANGUAGES}
|
||||
NAME_TO_CROWDIN = {lang["name"]: lang["crowdin"] for lang in LANGUAGES}
|
||||
CROWDIN_TO_NLLB = {lang["crowdin"]: lang["nllb"] for lang in LANGUAGES}
|
||||
NLLB_TO_CROWDIN = {lang["nllb"]: lang["crowdin"] for lang in LANGUAGES}
|
||||
CROWDIN_TO_NAME = {lang["crowdin"]: lang["name"] for lang in LANGUAGES}
|
||||
NLLB_TO_NAME = {lang["nllb"]: lang["name"] for lang in LANGUAGES}
|
||||
|
||||
|
||||
def get_nllb_code(crowdin_code):
|
||||
return CROWDIN_TO_NLLB.get(crowdin_code, None)
|
||||
|
||||
|
||||
def get_crowdin_code(nllb_code):
|
||||
return NLLB_TO_CROWDIN.get(nllb_code)
|
||||
|
||||
|
||||
def get_language_name_by_crowdin(crowdin_code):
|
||||
return CROWDIN_TO_NAME.get(crowdin_code)
|
||||
|
||||
|
||||
def get_language_name_by_nllb(nllb_code):
|
||||
return NLLB_TO_NAME.get(nllb_code)
|
||||
|
||||
|
||||
def get_language_info(identifier, identifier_type="auto"):
|
||||
if identifier_type == "auto":
|
||||
for lang in LANGUAGES:
|
||||
if (lang["name"].lower() == identifier.lower() or
|
||||
lang["nllb"] == identifier or
|
||||
lang["crowdin"] == identifier):
|
||||
return lang
|
||||
elif identifier_type == "name":
|
||||
for lang in LANGUAGES:
|
||||
if lang["name"].lower() == identifier.lower():
|
||||
return lang
|
||||
elif identifier_type == "nllb":
|
||||
for lang in LANGUAGES:
|
||||
if lang["nllb"] == identifier:
|
||||
return lang
|
||||
elif identifier_type == "crowdin":
|
||||
for lang in LANGUAGES:
|
||||
if lang["crowdin"] == identifier:
|
||||
return lang
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def list_all_languages():
|
||||
return [lang["name"] for lang in LANGUAGES]
|
||||
|
||||
|
||||
def list_all_nllb_codes():
|
||||
return [lang["nllb"] for lang in LANGUAGES]
|
||||
|
||||
|
||||
def list_all_crowdin_codes():
|
||||
return [lang["crowdin"] for lang in LANGUAGES]
|
||||
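Example lookups against this mapping (outputs follow the table above):

```python
# Hedged usage example for the language-code helpers.
print(get_nllb_code("fr"))                    # 'fra_Latn'
print(get_crowdin_code("deu_Latn"))           # 'de'
print(get_language_info("Japanese")["nllb"])  # 'jpn_Jpan'
```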
137
whisperlivekit/translation/translation.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import ctranslate2
|
||||
import torch
|
||||
import transformers
|
||||
from dataclasses import dataclass
|
||||
import huggingface_hub
|
||||
from whisperlivekit.translation.mapping_languages import get_nllb_code
|
||||
from whisperlivekit.timed_objects import Translation
|
||||
|
||||
|
||||
# In the diarization case, we may want to translate just one speaker, or at least start sentences at speaker turns
|
||||
|
||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranslationModel():
|
||||
translator: ctranslate2.Translator
|
||||
tokenizer: dict
|
||||
|
||||
def load_model(src_langs):
|
||||
MODEL = 'nllb-200-distilled-600M-ctranslate2'
|
||||
MODEL_GUY = 'entai2965'
|
||||
huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
translator = ctranslate2.Translator(MODEL,device=device)
|
||||
tokenizer = dict()
|
||||
for src_lang in src_langs:
|
||||
tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
|
||||
return TranslationModel(
|
||||
translator=translator,
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
|
||||
def translate(input, translation_model, tgt_lang):
|
||||
source = translation_model.tokenizer.convert_ids_to_tokens(translation_model.tokenizer.encode(input))
|
||||
target_prefix = [tgt_lang]
|
||||
results = translation_model.translator.translate_batch([source], target_prefix=[target_prefix])
|
||||
target = results[0].hypotheses[0][1:]
|
||||
return translation_model.tokenizer.decode(translation_model.tokenizer.convert_tokens_to_ids(target))
|
||||
|
||||
class OnlineTranslation:
|
||||
def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
|
||||
self.buffer = []
|
||||
self.len_processed_buffer = 0
|
||||
self.translation_remaining = Translation()
|
||||
self.validated = []
|
||||
self.translation_pending_validation = ''
|
||||
self.translation_model = translation_model
|
||||
self.input_languages = input_languages
|
||||
self.output_languages = output_languages
|
||||
|
||||
def compute_common_prefix(self, results):
|
||||
# we don't want to prune the result for the moment.
|
||||
if not self.buffer:
|
||||
self.buffer = results
|
||||
else:
|
||||
for i in range(min(len(self.buffer), len(results))):
|
||||
if self.buffer[i] != results[i]:
|
||||
self.validated.extend(self.buffer[:i])
|
||||
self.buffer = results[i:]
|
||||
|
||||
def translate(self, input, input_lang=None, output_lang=None):
|
||||
if not input:
|
||||
return ""
|
||||
if input_lang is None:
|
||||
input_lang = self.input_languages[0]
|
||||
if output_lang is None:
|
||||
output_lang = self.output_languages[0]
|
||||
nllb_output_lang = get_nllb_code(output_lang)
|
||||
|
||||
source = self.translation_model.tokenizer[input_lang].convert_ids_to_tokens(self.translation_model.tokenizer[input_lang].encode(input))
|
||||
results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]])  # return_attention=True could be used here later as an optimization
|
||||
target = results[0].hypotheses[0][1:]
|
||||
results = self.translation_model.tokenizer[input_lang].decode(self.translation_model.tokenizer[input_lang].convert_tokens_to_ids(target))
|
||||
return results
|
||||
|
||||
def translate_tokens(self, tokens):
|
||||
if tokens:
|
||||
text = ' '.join([token.text for token in tokens])
|
||||
start = tokens[0].start
|
||||
end = tokens[-1].end
|
||||
translated_text = self.translate(text)
|
||||
translation = Translation(
|
||||
text=translated_text,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
return translation
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def insert_tokens(self, tokens):
|
||||
self.buffer.extend(tokens)
|
||||
pass
|
||||
|
||||
def process(self):
|
||||
i = 0
|
||||
if len(self.buffer) < self.len_processed_buffer + 3: #nothing new to process
|
||||
return self.validated + [self.translation_remaining]
|
||||
while i < len(self.buffer):
|
||||
if self.buffer[i].text in PUNCTUATION_MARKS:
|
||||
translation_sentence = self.translate_tokens(self.buffer[:i+1])
|
||||
self.validated.append(translation_sentence)
|
||||
self.buffer = self.buffer[i+1:]
|
||||
i = 0
|
||||
else:
|
||||
i+=1
|
||||
self.translation_remaining = self.translate_tokens(self.buffer)
|
||||
self.len_processed_buffer = len(self.buffer)
|
||||
return self.validated + [self.translation_remaining]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
output_lang = 'fr'
|
||||
input_lang = "en"
|
||||
|
||||
|
||||
test_string = """
|
||||
Transcription technology has improved so much in the past few years. Have you noticed how accurate real-time speech-to-text is now?
|
||||
"""
|
||||
test = test_string.split(' ')
|
||||
step = len(test) // 3
|
||||
|
||||
shared_model = load_model([input_lang])
|
||||
online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang])
|
||||
|
||||
for id in range(5):
|
||||
val = test[id*step : (id+1)*step]
|
||||
val_str = ' '.join(val)
|
||||
result = online_translation.translate(val_str)
|
||||
print(result)
|
||||
|
||||
|
||||
|
||||
|
||||
# print(result)
|
||||
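Besides the plain-string demo above, `OnlineTranslation` is built around timed tokens: `insert_tokens` buffers them and `process` validates a translation each time a sentence-final punctuation mark arrives. A hedged sketch with a stand-in token type (the real token class comes from the ASR side and is not defined in this file):

```python
# Hedged sketch: Tok is a hypothetical stand-in; only .text/.start/.end are assumed.
from dataclasses import dataclass

@dataclass
class Tok:
    text: str
    start: float
    end: float

ot = OnlineTranslation(load_model(["en"]), input_languages=["en"], output_languages=["fr"])
ot.insert_tokens([Tok("Hello", 0.0, 0.4), Tok("everyone", 0.4, 0.9), Tok(".", 0.9, 1.0)])
lines = ot.process()  # validated Translation objects plus the still-tentative remainder
```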
62
whisperlivekit/warmup.py
Normal file
@@ -0,0 +1,62 @@
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def load_file(warmup_file=None, timeout=5):
|
||||
import os
|
||||
import tempfile
|
||||
import librosa
|
||||
|
||||
if warmup_file is None:
|
||||
# Download JFK sample if not already present
|
||||
jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
|
||||
temp_dir = tempfile.gettempdir()
|
||||
warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")
|
||||
|
||||
if not os.path.exists(warmup_file):
|
||||
logger.debug(f"Downloading warmup file from {jfk_url}")
|
||||
print(f"Downloading warmup file from {jfk_url}")
|
||||
import time
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
import socket
|
||||
|
||||
original_timeout = socket.getdefaulttimeout()
|
||||
socket.setdefaulttimeout(timeout)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
urllib.request.urlretrieve(jfk_url, warmup_file)
|
||||
logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
|
||||
except (urllib.error.URLError, socket.timeout) as e:
|
||||
logger.warning(f"Download failed: {e}. Proceeding without warmup.")
|
||||
return None
|
||||
finally:
|
||||
socket.setdefaulttimeout(original_timeout)
|
||||
elif not warmup_file:
|
||||
return None
|
||||
|
||||
if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
|
||||
logger.warning(f"Warmup file {warmup_file} invalid or missing.")
|
||||
return None
|
||||
|
||||
try:
|
||||
audio, sr = librosa.load(warmup_file, sr=16000)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load audio file: {e}")
|
||||
return None
|
||||
return audio
|
||||
|
||||
def warmup_asr(asr, warmup_file=None, timeout=5):
    """
    Warmup the ASR model by transcribing a short audio file.
    """
    audio = load_file(warmup_file=warmup_file, timeout=timeout)
    if audio is None:
        return
    asr.transcribe(audio)
    logger.info("ASR model is warmed up")

def warmup_online(online, warmup_file=None, timeout=5):
    """
    Warmup the online processor by feeding it a short audio file.
    """
    audio = load_file(warmup_file=warmup_file, timeout=timeout)
    if audio is None:
        return
    online.warmup(audio)
    logger.info("Online processor is warmed up")
|
||||
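Typical use is a one-off call at server start, before the first client connects (the `asr` backend object is assumed to expose a `transcribe` method, as in the function above):

```python
# Hedged usage sketch: warm the model once at startup.
from whisperlivekit.warmup import warmup_asr

warmup_asr(asr)  # downloads the JFK sample to a temp dir when no file is given
warmup_asr(asr, warmup_file="my_short_clip.wav", timeout=10)  # or pass a local file
```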
517
whisperlivekit/web/live_transcription.css
Normal file
@@ -0,0 +1,517 @@
|
||||
:root {
|
||||
--bg: #ffffff;
|
||||
--text: #111111;
|
||||
--muted: #666666;
|
||||
--border: #e5e5e5;
|
||||
--chip-bg: rgba(0, 0, 0, 0.04);
|
||||
--chip-text: #000000;
|
||||
--spinner-border: #8d8d8d5c;
|
||||
--spinner-top: #b0b0b0;
|
||||
--silence-bg: #f3f3f3;
|
||||
--loading-bg: rgba(255, 77, 77, 0.06);
|
||||
--button-bg: #ffffff;
|
||||
--button-border: #e9e9e9;
|
||||
--wave-stroke: #000000;
|
||||
--label-dia-text: #868686;
|
||||
--label-trans-text: #111111;
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: dark) {
|
||||
:root:not([data-theme="light"]) {
|
||||
--bg: #0b0b0b;
|
||||
--text: #e6e6e6;
|
||||
--muted: #9aa0a6;
|
||||
--border: #333333;
|
||||
--chip-bg: rgba(255, 255, 255, 0.08);
|
||||
--chip-text: #e6e6e6;
|
||||
--spinner-border: #555555;
|
||||
--spinner-top: #dddddd;
|
||||
--silence-bg: #1a1a1a;
|
||||
--loading-bg: rgba(255, 77, 77, 0.12);
|
||||
--button-bg: #111111;
|
||||
--button-border: #333333;
|
||||
--wave-stroke: #e6e6e6;
|
||||
--label-dia-text: #b3b3b3;
|
||||
--label-trans-text: #ffffff;
|
||||
}
|
||||
}
|
||||
|
||||
:root[data-theme="dark"] {
|
||||
--bg: #0b0b0b;
|
||||
--text: #e6e6e6;
|
||||
--muted: #9aa0a6;
|
||||
--border: #333333;
|
||||
--chip-bg: rgba(255, 255, 255, 0.08);
|
||||
--chip-text: #e6e6e6;
|
||||
--spinner-border: #555555;
|
||||
--spinner-top: #dddddd;
|
||||
--silence-bg: #1a1a1a;
|
||||
--loading-bg: rgba(255, 77, 77, 0.12);
|
||||
--button-bg: #111111;
|
||||
--button-border: #333333;
|
||||
--wave-stroke: #e6e6e6;
|
||||
--label-dia-text: #b3b3b3;
|
||||
--label-trans-text: #ffffff;
|
||||
}
|
||||
|
||||
:root[data-theme="light"] {
|
||||
--bg: #ffffff;
|
||||
--text: #111111;
|
||||
--muted: #666666;
|
||||
--border: #e5e5e5;
|
||||
--chip-bg: rgba(0, 0, 0, 0.04);
|
||||
--chip-text: #000000;
|
||||
--spinner-border: #8d8d8d5c;
|
||||
--spinner-top: #b0b0b0;
|
||||
--silence-bg: #f3f3f3;
|
||||
--loading-bg: rgba(255, 77, 77, 0.06);
|
||||
--button-bg: #ffffff;
|
||||
--button-border: #e9e9e9;
|
||||
--wave-stroke: #000000;
|
||||
--label-dia-text: #868686;
|
||||
--label-trans-text: #111111;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
|
||||
margin: 0;
|
||||
text-align: center;
|
||||
background-color: var(--bg);
|
||||
color: var(--text);
|
||||
height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
/* Record button */
|
||||
#recordButton {
|
||||
width: 50px;
|
||||
height: 50px;
|
||||
border: none;
|
||||
border-radius: 50%;
|
||||
background-color: var(--button-bg);
|
||||
cursor: pointer;
|
||||
transition: all 0.3s ease;
|
||||
border: 1px solid var(--button-border);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
#recordButton.recording {
|
||||
width: 180px;
|
||||
border-radius: 40px;
|
||||
justify-content: flex-start;
|
||||
padding-left: 20px;
|
||||
}
|
||||
|
||||
#recordButton:active {
|
||||
transform: scale(0.95);
|
||||
}
|
||||
|
||||
.shape-container {
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.shape {
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
background-color: rgb(209, 61, 53);
|
||||
border-radius: 50%;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
#recordButton:disabled .shape {
|
||||
background-color: #6e6d6d;
|
||||
}
|
||||
|
||||
#recordButton.recording .shape {
|
||||
border-radius: 5px;
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
}
|
||||
|
||||
/* Recording elements */
|
||||
.recording-info {
|
||||
display: none;
|
||||
align-items: center;
|
||||
margin-left: 15px;
|
||||
flex-grow: 1;
|
||||
}
|
||||
|
||||
#recordButton.recording .recording-info {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.wave-container {
|
||||
width: 60px;
|
||||
height: 30px;
|
||||
position: relative;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
#waveCanvas {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
.timer {
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
color: var(--text);
|
||||
margin-left: 10px;
|
||||
}
|
||||
|
||||
#status {
|
||||
margin-top: 15px;
|
||||
font-size: 16px;
|
||||
color: var(--text);
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.header-container {
|
||||
position: sticky;
|
||||
top: 0;
|
||||
background-color: var(--bg);
|
||||
z-index: 100;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
/* Settings */
|
||||
.settings-container {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
gap: 15px;
|
||||
}
|
||||
|
||||
.settings {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
align-items: flex-start;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.field {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: flex-start;
|
||||
gap: 3px;
|
||||
}
|
||||
|
||||
#chunkSelector,
|
||||
#websocketInput,
|
||||
#themeSelector,
|
||||
#microphoneSelect {
|
||||
font-size: 16px;
|
||||
padding: 5px 8px;
|
||||
border-radius: 8px;
|
||||
border: 1px solid var(--border);
|
||||
background-color: var(--button-bg);
|
||||
color: var(--text);
|
||||
max-height: 30px;
|
||||
}
|
||||
|
||||
#microphoneSelect {
|
||||
width: 100%;
|
||||
max-width: 190px;
|
||||
min-width: 120px;
|
||||
}
|
||||
|
||||
#chunkSelector:focus,
|
||||
#websocketInput:focus,
|
||||
#themeSelector:focus,
|
||||
#microphoneSelect:focus {
|
||||
outline: none;
|
||||
border-color: #007bff;
|
||||
box-shadow: 0 0 0 3px rgba(0, 123, 255, 0.15);
|
||||
}
|
||||
|
||||
label {
|
||||
font-size: 13px;
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
.ws-default {
|
||||
font-size: 12px;
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
/* Segmented pill control for Theme */
|
||||
.segmented {
|
||||
display: inline-flex;
|
||||
align-items: stretch;
|
||||
border: 1px solid var(--button-border);
|
||||
background-color: var(--button-bg);
|
||||
border-radius: 999px;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.segmented input[type="radio"] {
|
||||
position: absolute;
|
||||
opacity: 0;
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.theme-selector-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-top: 17px;
|
||||
}
|
||||
|
||||
.segmented label {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 6px 12px;
|
||||
font-size: 14px;
|
||||
color: var(--muted);
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
transition: background-color 0.2s ease, color 0.2s ease;
|
||||
}
|
||||
|
||||
.segmented label span {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.segmented label:hover span {
|
||||
display: inline;
|
||||
}
|
||||
|
||||
.segmented label:hover {
|
||||
background-color: var(--chip-bg);
|
||||
}
|
||||
|
||||
.segmented img {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
}
|
||||
|
||||
.segmented input[type="radio"]:checked + label {
|
||||
background-color: var(--chip-bg);
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.segmented input[type="radio"]:focus-visible + label,
|
||||
.segmented input[type="radio"]:focus + label {
|
||||
outline: 2px solid #007bff;
|
||||
outline-offset: 2px;
|
||||
border-radius: 999px;
|
||||
}
|
||||
|
||||
.transcript-container {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
padding: 20px;
|
||||
scrollbar-width: none;
|
||||
-ms-overflow-style: none;
|
||||
}
|
||||
|
||||
.transcript-container::-webkit-scrollbar {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* Transcript area */
|
||||
#linesTranscript {
|
||||
margin: 0 auto;
|
||||
max-width: 700px;
|
||||
text-align: left;
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
#linesTranscript p {
|
||||
margin: 0px 0;
|
||||
}
|
||||
|
||||
#linesTranscript strong {
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
#speaker {
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 100px;
|
||||
padding: 2px 10px;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
}
|
||||
|
||||
.label_diarization {
|
||||
background-color: var(--chip-bg);
|
||||
border-radius: 8px 8px 8px 8px;
|
||||
padding: 2px 10px;
|
||||
margin-left: 10px;
|
||||
display: inline-block;
|
||||
white-space: nowrap;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
color: var(--label-dia-text);
|
||||
}
|
||||
|
||||
.label_transcription {
|
||||
background-color: var(--chip-bg);
|
||||
border-radius: 8px 8px 8px 8px;
|
||||
padding: 2px 10px;
|
||||
display: inline-block;
|
||||
white-space: nowrap;
|
||||
margin-left: 10px;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
color: var(--label-trans-text);
|
||||
}
|
||||
|
||||
.label_translation {
|
||||
background-color: var(--chip-bg);
|
||||
border-radius: 10px;
|
||||
padding: 4px 8px;
|
||||
margin-top: 4px;
|
||||
font-size: 14px;
|
||||
color: var(--text);
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.label_translation img {
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
.label_translation img {
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
}
|
||||
|
||||
#timeInfo {
|
||||
color: var(--muted);
|
||||
margin-left: 10px;
|
||||
}
|
||||
|
||||
.textcontent {
|
||||
font-size: 16px;
|
||||
padding-left: 10px;
|
||||
margin-bottom: 10px;
|
||||
margin-top: 1px;
|
||||
padding-top: 5px;
|
||||
border-radius: 0px 0px 0px 10px;
|
||||
}
|
||||
|
||||
.buffer_diarization {
|
||||
color: var(--label-dia-text);
|
||||
margin-left: 4px;
|
||||
}
|
||||
|
||||
.buffer_transcription {
|
||||
color: #7474748c;
|
||||
margin-left: 4px;
|
||||
}
|
||||
|
||||
.spinner {
|
||||
display: inline-block;
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border: 2px solid var(--spinner-border);
|
||||
border-top: 2px solid var(--spinner-top);
|
||||
border-radius: 50%;
|
||||
animation: spin 0.7s linear infinite;
|
||||
vertical-align: middle;
|
||||
margin-bottom: 2px;
|
||||
margin-right: 5px;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
to {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
|
||||
.silence {
|
||||
color: var(--muted);
|
||||
background-color: var(--silence-bg);
|
||||
font-size: 13px;
|
||||
border-radius: 30px;
|
||||
padding: 2px 10px;
|
||||
display: none;
|
||||
}
|
||||
|
||||
.loading {
|
||||
color: var(--muted);
|
||||
background-color: var(--loading-bg);
|
||||
border-radius: 8px 8px 8px 0px;
|
||||
padding: 2px 10px;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
}
|
||||
|
||||
/* for smaller screens */
|
||||
@media (max-width: 768px) {
|
||||
.header-container {
|
||||
padding: 15px;
|
||||
}
|
||||
|
||||
.settings-container {
|
||||
flex-direction: column;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.settings {
|
||||
justify-content: center;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.field {
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
#websocketInput,
|
||||
#microphoneSelect {
|
||||
min-width: 100px;
|
||||
max-width: 160px;
|
||||
}
|
||||
|
||||
.theme-selector-container {
|
||||
margin-top: 10px;
|
||||
}
|
||||
|
||||
.transcript-container {
|
||||
padding: 15px;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 480px) {
|
||||
.header-container {
|
||||
padding: 10px;
|
||||
}
|
||||
|
||||
.settings {
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
}
|
||||
|
||||
#websocketInput,
|
||||
#microphoneSelect {
|
||||
max-width: 140px;
|
||||
}
|
||||
|
||||
.segmented label {
|
||||
padding: 4px 8px;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
.segmented img {
|
||||
width: 14px;
|
||||
height: 14px;
|
||||
}
|
||||
|
||||
.transcript-container {
|
||||
padding: 10px;
|
||||
}
|
||||
}
|
||||
@@ -4,679 +4,70 @@
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Audio Transcription</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
|
||||
margin: 20px;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
#recordButton {
|
||||
width: 50px;
|
||||
height: 50px;
|
||||
border: none;
|
||||
border-radius: 50%;
|
||||
background-color: white;
|
||||
cursor: pointer;
|
||||
transition: all 0.3s ease;
|
||||
border: 1px solid rgb(233, 233, 233);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
#recordButton.recording {
|
||||
width: 180px;
|
||||
border-radius: 40px;
|
||||
justify-content: flex-start;
|
||||
padding-left: 20px;
|
||||
}
|
||||
|
||||
#recordButton:active {
|
||||
transform: scale(0.95);
|
||||
}
|
||||
|
||||
.shape-container {
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.shape {
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
background-color: rgb(209, 61, 53);
|
||||
border-radius: 50%;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
#recordButton:disabled .shape {
|
||||
background-color: #6e6d6d;
|
||||
}
|
||||
|
||||
#recordButton.recording .shape {
|
||||
border-radius: 5px;
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
}
|
||||
|
||||
/* Recording elements */
|
||||
.recording-info {
|
||||
display: none;
|
||||
align-items: center;
|
||||
margin-left: 15px;
|
||||
flex-grow: 1;
|
||||
}
|
||||
|
||||
#recordButton.recording .recording-info {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.wave-container {
|
||||
width: 60px;
|
||||
height: 30px;
|
||||
position: relative;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
#waveCanvas {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
.timer {
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
color: #333;
|
||||
margin-left: 10px;
|
||||
}
|
||||
|
||||
#status {
|
||||
margin-top: 20px;
|
||||
font-size: 16px;
|
||||
color: #333;
|
||||
}
|
||||
|
||||
.settings-container {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
gap: 15px;
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
.settings {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: flex-start;
|
||||
gap: 5px;
|
||||
}
|
||||
|
||||
#chunkSelector,
|
||||
#websocketInput {
|
||||
font-size: 16px;
|
||||
padding: 5px;
|
||||
border-radius: 5px;
|
||||
border: 1px solid #ddd;
|
||||
background-color: #ffffff;
|
||||
max-height: 30px;
|
||||
}
|
||||
|
||||
#websocketInput {
|
||||
width: 200px;
|
||||
}
|
||||
|
||||
#chunkSelector:focus,
|
||||
#websocketInput:focus {
|
||||
outline: none;
|
||||
border-color: #007bff;
|
||||
}
|
||||
|
||||
label {
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
/* Speaker-labeled transcript area */
|
||||
#linesTranscript {
|
||||
margin: 20px auto;
|
||||
max-width: 700px;
|
||||
text-align: left;
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
#linesTranscript p {
|
||||
margin: 0px 0;
|
||||
}
|
||||
|
||||
#linesTranscript strong {
|
||||
color: #333;
|
||||
}
|
||||
|
||||
#speaker {
|
||||
border: 1px solid rgb(229, 229, 229);
|
||||
border-radius: 100px;
|
||||
padding: 2px 10px;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
}
|
||||
.label_diarization {
|
||||
background-color: #ffffff66;
|
||||
border-radius: 8px 8px 8px 8px;
|
||||
padding: 2px 10px;
|
||||
margin-left: 10px;
|
||||
display: inline-block;
|
||||
white-space: nowrap;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
color: rgb(134, 134, 134)
|
||||
}
|
||||
|
||||
.label_transcription {
|
||||
background-color: #ffffff66;
|
||||
border-radius: 8px 8px 8px 8px;
|
||||
padding: 2px 10px;
|
||||
display: inline-block;
|
||||
white-space: nowrap;
|
||||
margin-left: 10px;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
color: #000000
|
||||
}
|
||||
|
||||
#timeInfo {
|
||||
color: #666;
|
||||
margin-left: 10px;
|
||||
}
|
||||
|
||||
.textcontent {
|
||||
font-size: 16px;
|
||||
/* margin-left: 10px; */
|
||||
padding-left: 10px;
|
||||
margin-bottom: 10px;
|
||||
margin-top: 1px;
|
||||
padding-top: 5px;
|
||||
border-radius: 0px 0px 0px 10px;
|
||||
}
|
||||
|
||||
.buffer_diarization {
|
||||
color: rgb(134, 134, 134);
|
||||
margin-left: 4px;
|
||||
}
|
||||
|
||||
.buffer_transcription {
|
||||
color: #7474748c;
|
||||
margin-left: 4px;
|
||||
}
|
||||
|
||||
|
||||
.spinner {
|
||||
display: inline-block;
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border: 2px solid #8d8d8d5c;
|
||||
border-top: 2px solid #6c6c6ce5;
|
||||
border-radius: 50%;
|
||||
animation: spin 0.6s linear infinite;
|
||||
vertical-align: middle;
|
||||
margin-bottom: 2px;
|
||||
margin-right: 5px;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
to {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
|
||||
.silence {
|
||||
color: #666;
|
||||
background-color: #f3f3f3;
|
||||
font-size: 13px;
|
||||
border-radius: 30px;
|
||||
padding: 2px 10px;
|
||||
}
|
||||
|
||||
.loading {
|
||||
color: #666;
|
||||
background-color: #ff4d4d0f;
|
||||
border-radius: 8px 8px 8px 0px;
|
||||
padding: 2px 10px;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
}
|
||||
</style>
|
||||
<title>WhisperLiveKit</title>
|
||||
<link rel="stylesheet" href="/web/live_transcription.css" />
|
||||
</head>
|
||||
|
||||
<body>
|
||||
|
||||
<div class="settings-container">
|
||||
<button id="recordButton">
|
||||
<div class="shape-container">
|
||||
<div class="shape"></div>
|
||||
</div>
|
||||
<div class="recording-info">
|
||||
<div class="wave-container">
|
||||
<canvas id="waveCanvas"></canvas>
|
||||
<div class="header-container">
|
||||
<div class="settings-container">
|
||||
<button id="recordButton">
|
||||
<div class="shape-container">
|
||||
<div class="shape"></div>
|
||||
</div>
|
||||
<div class="recording-info">
|
||||
<div class="wave-container">
|
||||
<canvas id="waveCanvas"></canvas>
|
||||
</div>
|
||||
<div class="timer">00:00</div>
|
||||
</div>
|
||||
</button>
|
||||
|
||||
<div class="settings">
|
||||
<div class="field">
|
||||
<label for="websocketInput">Websocket URL</label>
|
||||
<input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
|
||||
</div>
|
||||
|
||||
<div class="field">
|
||||
<label id="microphoneSelectLabel" for="microphoneSelect">Select Microphone</label>
|
||||
<select id="microphoneSelect">
|
||||
<option value="">Default Microphone</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div class="theme-selector-container">
|
||||
<div class="segmented" role="radiogroup" aria-label="Theme selector">
|
||||
<input type="radio" id="theme-system" name="theme" value="system" />
|
||||
<label for="theme-system" title="System">
|
||||
<img src="/web/src/system_mode.svg" alt="" />
|
||||
<span>System</span>
|
||||
</label>
|
||||
|
||||
<input type="radio" id="theme-light" name="theme" value="light" />
|
||||
<label for="theme-light" title="Light">
|
||||
<img src="/web/src/light_mode.svg" alt="" />
|
||||
<span>Light</span>
|
||||
</label>
|
||||
|
||||
<input type="radio" id="theme-dark" name="theme" value="dark" />
|
||||
<label for="theme-dark" title="Dark">
|
||||
<img src="/web/src/dark_mode.svg" alt="" />
|
||||
<span>Dark</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
<div class="timer">00:00</div>
|
||||
</div>
|
||||
</button>
|
||||
<div class="settings">
|
||||
<div>
|
||||
<label for="chunkSelector">Chunk size (ms):</label>
|
||||
<select id="chunkSelector">
|
||||
<option value="500">500 ms</option>
|
||||
<option value="1000" selected>1000 ms</option>
|
||||
<option value="2000">2000 ms</option>
|
||||
<option value="3000">3000 ms</option>
|
||||
<option value="4000">4000 ms</option>
|
||||
<option value="5000">5000 ms</option>
|
||||
</select>
|
||||
</div>
|
||||
<div>
|
||||
<label for="websocketInput">WebSocket URL:</label>
|
||||
<input id="websocketInput" type="text" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<p id="status"></p>
|
||||
</div>
|
||||
|
||||
<p id="status"></p>
|
||||
<div class="transcript-container">
|
||||
<div id="linesTranscript"></div>
|
||||
</div>
|
||||
|
||||
<!-- Speaker-labeled transcript -->
|
||||
<div id="linesTranscript"></div>
|
||||
|
||||
<script>
|
||||
let isRecording = false;
|
||||
let websocket = null;
|
||||
let recorder = null;
|
||||
let chunkDuration = 1000;
|
||||
let websocketUrl = "ws://localhost:8000/asr";
|
||||
let userClosing = false;
|
||||
let startTime = null;
|
||||
let timerInterval = null;
|
||||
let audioContext = null;
|
||||
let analyser = null;
|
||||
let microphone = null;
|
||||
let waveCanvas = document.getElementById("waveCanvas");
|
||||
let waveCtx = waveCanvas.getContext("2d");
|
||||
let animationFrame = null;
|
||||
let waitingForStop = false;
|
||||
let lastReceivedData = null;
|
||||
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
|
||||
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
|
||||
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
|
||||
|
||||
const statusText = document.getElementById("status");
|
||||
const recordButton = document.getElementById("recordButton");
|
||||
const chunkSelector = document.getElementById("chunkSelector");
|
||||
const websocketInput = document.getElementById("websocketInput");
|
||||
const linesTranscriptDiv = document.getElementById("linesTranscript");
|
||||
const timerElement = document.querySelector(".timer");
|
||||
|
||||
const host = window.location.hostname || "localhost";
|
||||
const port = window.location.port;
|
||||
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
|
||||
const defaultWebSocketUrl = `${protocol}://${host}:${port}/asr`;
|
||||
websocketInput.value = defaultWebSocketUrl;
|
||||
websocketUrl = defaultWebSocketUrl;
|
||||
|
||||
chunkSelector.addEventListener("change", () => {
|
||||
chunkDuration = parseInt(chunkSelector.value);
|
||||
});
|
||||
|
||||
websocketInput.addEventListener("change", () => {
|
||||
const urlValue = websocketInput.value.trim();
|
||||
if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
|
||||
statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
|
||||
return;
|
||||
}
|
||||
websocketUrl = urlValue;
|
||||
statusText.textContent = "WebSocket URL updated. Ready to connect.";
|
||||
});
|
||||
|
||||
function setupWebSocket() {
|
||||
return new Promise((resolve, reject) => {
|
||||
try {
|
||||
websocket = new WebSocket(websocketUrl);
|
||||
} catch (error) {
|
||||
statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
|
||||
reject(error);
|
||||
return;
|
||||
}
|
||||
|
||||
websocket.onopen = () => {
|
||||
statusText.textContent = "Connected to server.";
|
||||
resolve();
|
||||
};
|
||||
|
||||
websocket.onclose = () => {
|
||||
if (userClosing) {
|
||||
if (waitingForStop) {
|
||||
statusText.textContent = "Processing finalized or connection closed.";
|
||||
if (lastReceivedData) {
|
||||
renderLinesWithBuffer(
|
||||
lastReceivedData.lines || [],
|
||||
lastReceivedData.buffer_diarization || "",
|
||||
lastReceivedData.buffer_transcription || "",
|
||||
0, 0, true // isFinalizing = true
|
||||
);
|
||||
}
|
||||
}
|
||||
// If ready_to_stop was received, statusText is already "Finished processing..."
|
||||
// and waitingForStop is false.
|
||||
} else {
|
||||
statusText.textContent = "Disconnected from the WebSocket server. (Check logs if model is loading.)";
|
||||
if (isRecording) {
|
||||
stopRecording();
|
||||
}
|
||||
}
|
||||
isRecording = false;
|
||||
waitingForStop = false;
|
||||
userClosing = false;
|
||||
lastReceivedData = null;
|
||||
websocket = null;
|
||||
updateUI();
|
||||
};
|
||||
|
||||
websocket.onerror = () => {
|
||||
statusText.textContent = "Error connecting to WebSocket.";
|
||||
reject(new Error("Error connecting to WebSocket"));
|
||||
};
|
||||
|
||||
// Handle messages from server
|
||||
websocket.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
|
||||
// Check for status messages
|
||||
if (data.type === "ready_to_stop") {
|
||||
console.log("Ready to stop received, finalizing display and closing WebSocket.");
|
||||
waitingForStop = false;
|
||||
|
||||
if (lastReceivedData) {
|
||||
renderLinesWithBuffer(
|
||||
lastReceivedData.lines || [],
|
||||
lastReceivedData.buffer_diarization || "",
|
||||
lastReceivedData.buffer_transcription || "",
|
||||
0, // No more lag
|
||||
0, // No more lag
|
||||
true // isFinalizing = true
|
||||
);
|
||||
}
|
||||
statusText.textContent = "Finished processing audio! Ready to record again.";
|
||||
recordButton.disabled = false;
|
||||
|
||||
if (websocket) {
|
||||
websocket.close(); // will trigger onclose
|
||||
// websocket = null; // onclose handles setting websocket to null
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
lastReceivedData = data;
|
||||
|
||||
// Handle normal transcription updates
|
||||
const {
|
||||
lines = [],
|
||||
buffer_transcription = "",
|
||||
buffer_diarization = "",
|
||||
remaining_time_transcription = 0,
|
||||
remaining_time_diarization = 0,
|
||||
status = "active_transcription"
|
||||
} = data;
|
||||
|
||||
renderLinesWithBuffer(
|
||||
lines,
|
||||
buffer_diarization,
|
||||
buffer_transcription,
|
||||
remaining_time_diarization,
|
||||
remaining_time_transcription,
|
||||
false,
|
||||
status
|
||||
);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
function renderLinesWithBuffer(lines, buffer_diarization, buffer_transcription, remaining_time_diarization, remaining_time_transcription, isFinalizing = false, current_status = "active_transcription") {
|
||||
if (current_status === "no_audio_detected") {
|
||||
linesTranscriptDiv.innerHTML = "<p style='text-align: center; color: #666; margin-top: 20px;'><em>No audio detected...</em></p>";
|
||||
return;
|
||||
}
|
||||
|
||||
const linesHtml = lines.map((item, idx) => {
|
||||
let timeInfo = "";
|
||||
if (item.beg !== undefined && item.end !== undefined) {
|
||||
timeInfo = ` ${item.beg} - ${item.end}`;
|
||||
}
|
||||
|
||||
let speakerLabel = "";
|
||||
if (item.speaker === -2) {
|
||||
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
} else if (item.speaker == 0 && !isFinalizing) {
|
||||
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'>${remaining_time_diarization} second(s) of audio are undergoing diarization</span></span>`;
|
||||
} else if (item.speaker == -1) {
|
||||
speakerLabel = `<span id="speaker">Speaker 1<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
} else if (item.speaker !== -1 && item.speaker !== 0) {
|
||||
speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
}
|
||||
|
||||
|
||||
let currentLineText = item.text || "";
|
||||
|
||||
if (idx === lines.length - 1) {
|
||||
if (!isFinalizing) {
|
||||
if (remaining_time_transcription > 0) {
|
||||
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'>${remaining_time_transcription}s</span></span>`;
|
||||
}
|
||||
if (buffer_diarization && remaining_time_diarization > 0) {
|
||||
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'>${remaining_time_diarization}s</span></span>`;
|
||||
}
|
||||
}
|
||||
|
||||
if (buffer_diarization) {
|
||||
if (isFinalizing) {
|
||||
currentLineText += (currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
|
||||
} else {
|
||||
currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
|
||||
}
|
||||
}
|
||||
if (buffer_transcription) {
|
||||
if (isFinalizing) {
|
||||
currentLineText += (currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") + buffer_transcription.trim();
|
||||
} else {
|
||||
currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return currentLineText.trim().length > 0 || speakerLabel.length > 0
|
||||
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
|
||||
: `<p>${speakerLabel}<br/></p>`;
|
||||
}).join("");
|
||||
|
||||
linesTranscriptDiv.innerHTML = linesHtml;
|
||||
}
|
||||
|
||||
function updateTimer() {
|
||||
if (!startTime) return;
|
||||
|
||||
const elapsed = Math.floor((Date.now() - startTime) / 1000);
|
||||
const minutes = Math.floor(elapsed / 60).toString().padStart(2, "0");
|
||||
const seconds = (elapsed % 60).toString().padStart(2, "0");
|
||||
timerElement.textContent = `${minutes}:${seconds}`;
|
||||
}
|
||||
|
||||
function drawWaveform() {
|
||||
if (!analyser) return;
|
||||
|
||||
const bufferLength = analyser.frequencyBinCount;
|
||||
const dataArray = new Uint8Array(bufferLength);
|
||||
analyser.getByteTimeDomainData(dataArray);
|
||||
|
||||
waveCtx.clearRect(0, 0, waveCanvas.width / (window.devicePixelRatio || 1), waveCanvas.height / (window.devicePixelRatio || 1));
|
||||
waveCtx.lineWidth = 1;
|
||||
waveCtx.strokeStyle = 'rgb(0, 0, 0)';
|
||||
waveCtx.beginPath();
|
||||
|
||||
const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength;
|
||||
let x = 0;
|
||||
|
||||
for (let i = 0; i < bufferLength; i++) {
|
||||
const v = dataArray[i] / 128.0;
|
||||
const y = v * (waveCanvas.height / (window.devicePixelRatio || 1)) / 2;
|
||||
|
||||
if (i === 0) {
|
||||
waveCtx.moveTo(x, y);
|
||||
} else {
|
||||
waveCtx.lineTo(x, y);
|
||||
}
|
||||
|
||||
x += sliceWidth;
|
||||
}
|
||||
|
||||
waveCtx.lineTo(waveCanvas.width / (window.devicePixelRatio || 1), waveCanvas.height / (window.devicePixelRatio || 1) / 2);
|
||||
waveCtx.stroke();
|
||||
|
||||
animationFrame = requestAnimationFrame(drawWaveform);
|
||||
}
|
||||
|
||||
async function startRecording() {
|
||||
try {
|
||||
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||
|
||||
audioContext = new (window.AudioContext || window.webkitAudioContext)();
|
||||
analyser = audioContext.createAnalyser();
|
||||
analyser.fftSize = 256;
|
||||
microphone = audioContext.createMediaStreamSource(stream);
|
||||
microphone.connect(analyser);
|
||||
|
||||
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
|
||||
recorder.ondataavailable = (e) => {
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
websocket.send(e.data);
|
||||
}
|
||||
};
|
||||
recorder.start(chunkDuration);
|
||||
|
||||
startTime = Date.now();
|
||||
timerInterval = setInterval(updateTimer, 1000);
|
||||
drawWaveform();
|
||||
|
||||
isRecording = true;
|
||||
updateUI();
|
||||
} catch (err) {
|
||||
statusText.textContent = "Error accessing microphone. Please allow microphone access.";
|
||||
console.error(err);
|
||||
}
|
||||
}
|
||||
|
||||
async function stopRecording() {
|
||||
userClosing = true;
|
||||
waitingForStop = true;
|
||||
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
// Send empty audio buffer as stop signal
|
||||
const emptyBlob = new Blob([], { type: 'audio/webm' });
|
||||
websocket.send(emptyBlob);
|
||||
statusText.textContent = "Recording stopped. Processing final audio...";
|
||||
}
|
||||
|
||||
if (recorder) {
|
||||
recorder.stop();
|
||||
recorder = null;
|
||||
}
|
||||
|
||||
if (microphone) {
|
||||
microphone.disconnect();
|
||||
microphone = null;
|
||||
}
|
||||
|
||||
if (analyser) {
|
||||
analyser = null;
|
||||
}
|
||||
|
||||
if (audioContext && audioContext.state !== 'closed') {
|
||||
try {
|
||||
audioContext.close();
|
||||
} catch (e) {
|
||||
console.warn("Could not close audio context:", e);
|
||||
}
|
||||
audioContext = null;
|
||||
}
|
||||
|
||||
if (animationFrame) {
|
||||
cancelAnimationFrame(animationFrame);
|
||||
animationFrame = null;
|
||||
}
|
||||
|
||||
if (timerInterval) {
|
||||
clearInterval(timerInterval);
|
||||
timerInterval = null;
|
||||
}
|
||||
timerElement.textContent = "00:00";
|
||||
startTime = null;
|
||||
|
||||
|
||||
isRecording = false;
|
||||
updateUI();
|
||||
}
|
||||
|
||||
async function toggleRecording() {
|
||||
if (!isRecording) {
|
||||
if (waitingForStop) {
|
||||
console.log("Waiting for stop, early return");
|
||||
return; // Early return, UI is already updated
|
||||
}
|
||||
console.log("Connecting to WebSocket");
|
||||
try {
|
||||
// If we have an active WebSocket that's still processing, just restart audio capture
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
await startRecording();
|
||||
} else {
|
||||
// If no active WebSocket or it's closed, create new one
|
||||
await setupWebSocket();
|
||||
await startRecording();
|
||||
}
|
||||
} catch (err) {
|
||||
statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
|
||||
console.error(err);
|
||||
}
|
||||
} else {
|
||||
console.log("Stopping recording");
|
||||
stopRecording();
|
||||
}
|
||||
}
|
||||
|
||||
function updateUI() {
|
||||
recordButton.classList.toggle("recording", isRecording);
|
||||
recordButton.disabled = waitingForStop;
|
||||
|
||||
if (waitingForStop) {
|
||||
if (statusText.textContent !== "Recording stopped. Processing final audio...") {
|
||||
statusText.textContent = "Please wait for processing to complete...";
|
||||
}
|
||||
} else if (isRecording) {
|
||||
statusText.textContent = "Recording...";
|
||||
} else {
|
||||
if (statusText.textContent !== "Finished processing audio! Ready to record again." &&
|
||||
statusText.textContent !== "Processing finalized or connection closed.") {
|
||||
statusText.textContent = "Click to start transcription";
|
||||
}
|
||||
}
|
||||
if (!waitingForStop) {
|
||||
recordButton.disabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
recordButton.addEventListener("click", toggleRecording);
|
||||
</script>
|
||||
<script src="/web/live_transcription.js"></script>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
</html>
|
||||
609
whisperlivekit/web/live_transcription.js
Normal file
@@ -0,0 +1,609 @@
|
||||
/* Theme, WebSocket, recording, rendering logic extracted from inline script and adapted for segmented theme control and WS caption */
|
||||
|
||||
let isRecording = false;
|
||||
let websocket = null;
|
||||
let recorder = null;
|
||||
let chunkDuration = 100;
|
||||
let websocketUrl = "ws://localhost:8000/asr";
|
||||
let userClosing = false;
|
||||
let wakeLock = null;
|
||||
let startTime = null;
|
||||
let timerInterval = null;
|
||||
let audioContext = null;
|
||||
let analyser = null;
|
||||
let microphone = null;
|
||||
let waveCanvas = document.getElementById("waveCanvas");
|
||||
let waveCtx = waveCanvas.getContext("2d");
|
||||
let animationFrame = null;
|
||||
let waitingForStop = false;
|
||||
let lastReceivedData = null;
|
||||
let lastSignature = null;
|
||||
let availableMicrophones = [];
|
||||
let selectedMicrophoneId = null;
|
||||
|
||||
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
|
||||
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
|
||||
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
|
||||
|
||||
const statusText = document.getElementById("status");
|
||||
const recordButton = document.getElementById("recordButton");
|
||||
const chunkSelector = document.getElementById("chunkSelector");
|
||||
const websocketInput = document.getElementById("websocketInput");
|
||||
const websocketDefaultSpan = document.getElementById("wsDefaultUrl");
|
||||
const linesTranscriptDiv = document.getElementById("linesTranscript");
|
||||
const timerElement = document.querySelector(".timer");
|
||||
const themeRadios = document.querySelectorAll('input[name="theme"]');
|
||||
const microphoneSelect = document.getElementById("microphoneSelect");
|
||||
|
||||
function getWaveStroke() {
|
||||
const styles = getComputedStyle(document.documentElement);
|
||||
const v = styles.getPropertyValue("--wave-stroke").trim();
|
||||
return v || "#000";
|
||||
}
|
||||
|
||||
let waveStroke = getWaveStroke();
|
||||
function updateWaveStroke() {
|
||||
waveStroke = getWaveStroke();
|
||||
}
|
||||
|
||||
function applyTheme(pref) {
|
||||
if (pref === "light") {
|
||||
document.documentElement.setAttribute("data-theme", "light");
|
||||
} else if (pref === "dark") {
|
||||
document.documentElement.setAttribute("data-theme", "dark");
|
||||
} else {
|
||||
document.documentElement.removeAttribute("data-theme");
|
||||
}
|
||||
updateWaveStroke();
|
||||
}
|
||||
|
||||
// Persisted theme preference
|
||||
const savedThemePref = localStorage.getItem("themePreference") || "system";
|
||||
applyTheme(savedThemePref);
|
||||
if (themeRadios.length) {
|
||||
themeRadios.forEach((r) => {
|
||||
r.checked = r.value === savedThemePref;
|
||||
r.addEventListener("change", () => {
|
||||
if (r.checked) {
|
||||
localStorage.setItem("themePreference", r.value);
|
||||
applyTheme(r.value);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// React to OS theme changes when in "system" mode
|
||||
const darkMq = window.matchMedia && window.matchMedia("(prefers-color-scheme: dark)");
|
||||
const handleOsThemeChange = () => {
|
||||
const pref = localStorage.getItem("themePreference") || "system";
|
||||
if (pref === "system") updateWaveStroke();
|
||||
};
|
||||
if (darkMq && darkMq.addEventListener) {
|
||||
darkMq.addEventListener("change", handleOsThemeChange);
|
||||
} else if (darkMq && darkMq.addListener) {
|
||||
// deprecated, but included for Safari compatibility
|
||||
darkMq.addListener(handleOsThemeChange);
|
||||
}
|
||||
|
||||
async function enumerateMicrophones() {
|
||||
try {
|
||||
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||
stream.getTracks().forEach(track => track.stop());
|
||||
|
||||
const devices = await navigator.mediaDevices.enumerateDevices();
|
||||
availableMicrophones = devices.filter(device => device.kind === 'audioinput');
|
||||
|
||||
populateMicrophoneSelect();
|
||||
console.log(`Found ${availableMicrophones.length} microphone(s)`);
|
||||
} catch (error) {
|
||||
console.error('Error enumerating microphones:', error);
|
||||
statusText.textContent = "Error accessing microphones. Please grant permission.";
|
||||
}
|
||||
}
|
||||
|
||||
function populateMicrophoneSelect() {
|
||||
if (!microphoneSelect) return;
|
||||
|
||||
microphoneSelect.innerHTML = '<option value="">Default Microphone</option>';
|
||||
|
||||
availableMicrophones.forEach((device, index) => {
|
||||
const option = document.createElement('option');
|
||||
option.value = device.deviceId;
|
||||
option.textContent = device.label || `Microphone ${index + 1}`;
|
||||
microphoneSelect.appendChild(option);
|
||||
});
|
||||
|
||||
const savedMicId = localStorage.getItem('selectedMicrophone');
|
||||
if (savedMicId && availableMicrophones.some(mic => mic.deviceId === savedMicId)) {
|
||||
microphoneSelect.value = savedMicId;
|
||||
selectedMicrophoneId = savedMicId;
|
||||
}
|
||||
}
|
||||
|
||||
function handleMicrophoneChange() {
|
||||
selectedMicrophoneId = microphoneSelect.value || null;
|
||||
localStorage.setItem('selectedMicrophone', selectedMicrophoneId || '');
|
||||
|
||||
const selectedDevice = availableMicrophones.find(mic => mic.deviceId === selectedMicrophoneId);
|
||||
const deviceName = selectedDevice ? selectedDevice.label : 'Default Microphone';
|
||||
|
||||
console.log(`Selected microphone: ${deviceName}`);
|
||||
statusText.textContent = `Microphone changed to: ${deviceName}`;
|
||||
|
||||
if (isRecording) {
|
||||
statusText.textContent = "Switching microphone... Please wait.";
|
||||
stopRecording().then(() => {
|
||||
setTimeout(() => {
|
||||
toggleRecording();
|
||||
}, 1000);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Helpers
|
||||
function fmt1(x) {
|
||||
const n = Number(x);
|
||||
return Number.isFinite(n) ? n.toFixed(1) : x;
|
||||
}
|
||||
|
||||
// Default WebSocket URL computation
|
||||
const host = window.location.hostname || "localhost";
|
||||
const port = window.location.port;
|
||||
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
|
||||
const defaultWebSocketUrl = `${protocol}://${host}${port ? ":" + port : ""}/asr`;
|
||||
|
||||
// Populate default caption and input
|
||||
if (websocketDefaultSpan) websocketDefaultSpan.textContent = defaultWebSocketUrl;
|
||||
websocketInput.value = defaultWebSocketUrl;
|
||||
websocketUrl = defaultWebSocketUrl;
|
||||
|
||||
// Optional chunk selector (guard for presence)
|
||||
if (chunkSelector) {
|
||||
chunkSelector.addEventListener("change", () => {
|
||||
chunkDuration = parseInt(chunkSelector.value);
|
||||
});
|
||||
}
|
||||
|
||||
// WebSocket input change handling
|
||||
websocketInput.addEventListener("change", () => {
|
||||
const urlValue = websocketInput.value.trim();
|
||||
if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
|
||||
statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
|
||||
return;
|
||||
}
|
||||
websocketUrl = urlValue;
|
||||
statusText.textContent = "WebSocket URL updated. Ready to connect.";
|
||||
});
|
||||
|
||||
function setupWebSocket() {
|
||||
return new Promise((resolve, reject) => {
|
||||
try {
|
||||
websocket = new WebSocket(websocketUrl);
|
||||
} catch (error) {
|
||||
statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
|
||||
reject(error);
|
||||
return;
|
||||
}
|
||||
|
||||
websocket.onopen = () => {
|
||||
statusText.textContent = "Connected to server.";
|
||||
resolve();
|
||||
};
|
||||
|
||||
websocket.onclose = () => {
|
||||
if (userClosing) {
|
||||
if (waitingForStop) {
|
||||
statusText.textContent = "Processing finalized or connection closed.";
|
||||
if (lastReceivedData) {
|
||||
renderLinesWithBuffer(
|
||||
lastReceivedData.lines || [],
|
||||
lastReceivedData.buffer_diarization || "",
|
||||
lastReceivedData.buffer_transcription || "",
|
||||
0,
|
||||
0,
|
||||
true
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
statusText.textContent = "Disconnected from the WebSocket server. (Check logs if model is loading.)";
|
||||
if (isRecording) {
|
||||
stopRecording();
|
||||
}
|
||||
}
|
||||
isRecording = false;
|
||||
waitingForStop = false;
|
||||
userClosing = false;
|
||||
lastReceivedData = null;
|
||||
websocket = null;
|
||||
updateUI();
|
||||
};
|
||||
|
||||
websocket.onerror = () => {
|
||||
statusText.textContent = "Error connecting to WebSocket.";
|
||||
reject(new Error("Error connecting to WebSocket"));
|
||||
};
|
||||
|
||||
websocket.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
|
||||
if (data.type === "ready_to_stop") {
|
||||
console.log("Ready to stop received, finalizing display and closing WebSocket.");
|
||||
waitingForStop = false;
|
||||
|
||||
if (lastReceivedData) {
|
||||
renderLinesWithBuffer(
|
||||
lastReceivedData.lines || [],
|
||||
lastReceivedData.buffer_diarization || "",
|
||||
lastReceivedData.buffer_transcription || "",
|
||||
0,
|
||||
0,
|
||||
true
|
||||
);
|
||||
}
|
||||
statusText.textContent = "Finished processing audio! Ready to record again.";
|
||||
recordButton.disabled = false;
|
||||
|
||||
if (websocket) {
|
||||
websocket.close();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
lastReceivedData = data;
|
||||
|
||||
const {
|
||||
lines = [],
|
||||
buffer_transcription = "",
|
||||
buffer_diarization = "",
|
||||
remaining_time_transcription = 0,
|
||||
remaining_time_diarization = 0,
|
||||
status = "active_transcription",
|
||||
} = data;
|
||||
|
||||
renderLinesWithBuffer(
|
||||
lines,
|
||||
buffer_diarization,
|
||||
buffer_transcription,
|
||||
remaining_time_diarization,
|
||||
remaining_time_transcription,
|
||||
false,
|
||||
status
|
||||
);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
function renderLinesWithBuffer(
|
||||
lines,
|
||||
buffer_diarization,
|
||||
buffer_transcription,
|
||||
remaining_time_diarization,
|
||||
remaining_time_transcription,
|
||||
isFinalizing = false,
|
||||
current_status = "active_transcription"
|
||||
) {
|
||||
if (current_status === "no_audio_detected") {
|
||||
linesTranscriptDiv.innerHTML =
|
||||
"<p style='text-align: center; color: var(--muted); margin-top: 20px;'><em>No audio detected...</em></p>";
|
||||
return;
|
||||
}
|
||||
|
||||
const showLoading = !isFinalizing && (lines || []).some((it) => it.speaker == 0);
|
||||
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
|
||||
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
|
||||
const signature = JSON.stringify({
|
||||
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end })),
|
||||
buffer_transcription: buffer_transcription || "",
|
||||
buffer_diarization: buffer_diarization || "",
|
||||
status: current_status,
|
||||
showLoading,
|
||||
showTransLag,
|
||||
showDiaLag,
|
||||
isFinalizing: !!isFinalizing,
|
||||
});
|
||||
if (lastSignature === signature) {
|
||||
const t = document.querySelector(".lag-transcription-value");
|
||||
if (t) t.textContent = fmt1(remaining_time_transcription);
|
||||
const d = document.querySelector(".lag-diarization-value");
|
||||
if (d) d.textContent = fmt1(remaining_time_diarization);
|
||||
const ld = document.querySelector(".loading-diarization-value");
|
||||
if (ld) ld.textContent = fmt1(remaining_time_diarization);
|
||||
return;
|
||||
}
|
||||
lastSignature = signature;
|
||||
|
||||
const linesHtml = (lines || [])
|
||||
.map((item, idx) => {
|
||||
let timeInfo = "";
|
||||
if (item.start !== undefined && item.end !== undefined) {
|
||||
timeInfo = ` ${item.start} - ${item.end}`;
|
||||
}
|
||||
|
||||
let speakerLabel = "";
|
||||
if (item.speaker === -2) {
|
||||
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
} else if (item.speaker == 0 && !isFinalizing) {
|
||||
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
|
||||
remaining_time_diarization
|
||||
)}</span> second(s) of audio are undergoing diarization</span></span>`;
|
||||
} else if (item.speaker !== 0) {
|
||||
speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
}
|
||||
|
||||
let currentLineText = item.text || "";
|
||||
|
||||
if (item.translation) {
|
||||
currentLineText += `<div class="label_translation">
|
||||
<img src="/web/src/translate.svg" alt="Translation" width="12" height="12" />
|
||||
<span>${item.translation}</span>
|
||||
</div>`;
|
||||
}
|
||||
|
||||
if (idx === lines.length - 1) {
|
||||
if (!isFinalizing && item.speaker !== -2) {
|
||||
if (remaining_time_transcription > 0) {
|
||||
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(
|
||||
remaining_time_transcription
|
||||
)}</span>s</span></span>`;
|
||||
}
|
||||
if (buffer_diarization && remaining_time_diarization > 0) {
|
||||
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(
|
||||
remaining_time_diarization
|
||||
)}</span>s</span></span>`;
|
||||
}
|
||||
}
|
||||
|
||||
if (buffer_diarization) {
|
||||
if (isFinalizing) {
|
||||
currentLineText +=
|
||||
(currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
|
||||
} else {
|
||||
currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
|
||||
}
|
||||
}
|
||||
if (buffer_transcription) {
|
||||
if (isFinalizing) {
|
||||
currentLineText +=
|
||||
(currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") +
|
||||
buffer_transcription.trim();
|
||||
} else {
|
||||
currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return currentLineText.trim().length > 0 || speakerLabel.length > 0
|
||||
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
|
||||
: `<p>${speakerLabel}<br/></p>`;
|
||||
})
|
||||
.join("");
|
||||
|
||||
linesTranscriptDiv.innerHTML = linesHtml;
|
||||
const transcriptContainer = document.querySelector('.transcript-container');
|
||||
if (transcriptContainer) {
|
||||
transcriptContainer.scrollTo({ top: transcriptContainer.scrollHeight, behavior: "smooth" });
|
||||
}
|
||||
}
|
||||
|
||||
function updateTimer() {
|
||||
if (!startTime) return;
|
||||
|
||||
const elapsed = Math.floor((Date.now() - startTime) / 1000);
|
||||
const minutes = Math.floor(elapsed / 60).toString().padStart(2, "0");
|
||||
const seconds = (elapsed % 60).toString().padStart(2, "0");
|
||||
timerElement.textContent = `${minutes}:${seconds}`;
|
||||
}
|
||||
|
||||
function drawWaveform() {
|
||||
if (!analyser) return;
|
||||
|
||||
const bufferLength = analyser.frequencyBinCount;
|
||||
const dataArray = new Uint8Array(bufferLength);
|
||||
analyser.getByteTimeDomainData(dataArray);
|
||||
|
||||
waveCtx.clearRect(
|
||||
0,
|
||||
0,
|
||||
waveCanvas.width / (window.devicePixelRatio || 1),
|
||||
waveCanvas.height / (window.devicePixelRatio || 1)
|
||||
);
|
||||
waveCtx.lineWidth = 1;
|
||||
waveCtx.strokeStyle = waveStroke;
|
||||
waveCtx.beginPath();
|
||||
|
||||
const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength;
|
||||
let x = 0;
|
||||
|
||||
for (let i = 0; i < bufferLength; i++) {
|
||||
const v = dataArray[i] / 128.0;
|
||||
const y = (v * (waveCanvas.height / (window.devicePixelRatio || 1))) / 2;
|
||||
|
||||
if (i === 0) {
|
||||
waveCtx.moveTo(x, y);
|
||||
} else {
|
||||
waveCtx.lineTo(x, y);
|
||||
}
|
||||
|
||||
x += sliceWidth;
|
||||
}
|
||||
|
||||
waveCtx.lineTo(
|
||||
waveCanvas.width / (window.devicePixelRatio || 1),
|
||||
(waveCanvas.height / (window.devicePixelRatio || 1)) / 2
|
||||
);
|
||||
waveCtx.stroke();
|
||||
|
||||
animationFrame = requestAnimationFrame(drawWaveform);
|
||||
}
|
||||
|
||||
async function startRecording() {
|
||||
try {
|
||||
try {
|
||||
wakeLock = await navigator.wakeLock.request("screen");
|
||||
} catch (err) {
|
||||
console.log("Error acquiring wake lock.");
|
||||
}
|
||||
|
||||
const audioConstraints = selectedMicrophoneId
|
||||
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
|
||||
: { audio: true };
|
||||
|
||||
const stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
|
||||
|
||||
audioContext = new (window.AudioContext || window.webkitAudioContext)();
|
||||
analyser = audioContext.createAnalyser();
|
||||
analyser.fftSize = 256;
|
||||
microphone = audioContext.createMediaStreamSource(stream);
|
||||
microphone.connect(analyser);
|
||||
|
||||
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
|
||||
recorder.ondataavailable = (e) => {
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
websocket.send(e.data);
|
||||
}
|
||||
};
|
||||
recorder.start(chunkDuration);
|
||||
|
||||
startTime = Date.now();
|
||||
timerInterval = setInterval(updateTimer, 1000);
|
||||
drawWaveform();
|
||||
|
||||
isRecording = true;
|
||||
updateUI();
|
||||
} catch (err) {
|
||||
if (window.location.hostname === "0.0.0.0") {
|
||||
statusText.textContent =
|
||||
"Error accessing microphone. Browsers may block microphone access on 0.0.0.0. Try using localhost:8000 instead.";
|
||||
} else {
|
||||
statusText.textContent = "Error accessing microphone. Please allow microphone access.";
|
||||
}
|
||||
console.error(err);
|
||||
}
|
||||
}
|
||||
|
||||
async function stopRecording() {
|
||||
if (wakeLock) {
|
||||
try {
|
||||
await wakeLock.release();
|
||||
} catch (e) {
|
||||
// ignore
|
||||
}
|
||||
wakeLock = null;
|
||||
}
|
||||
|
||||
userClosing = true;
|
||||
waitingForStop = true;
|
||||
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
const emptyBlob = new Blob([], { type: "audio/webm" });
|
||||
websocket.send(emptyBlob);
|
||||
statusText.textContent = "Recording stopped. Processing final audio...";
|
||||
}
|
||||
|
||||
if (recorder) {
|
||||
recorder.stop();
|
||||
recorder = null;
|
||||
}
|
||||
|
||||
if (microphone) {
|
||||
microphone.disconnect();
|
||||
microphone = null;
|
||||
}
|
||||
|
||||
if (analyser) {
|
||||
analyser = null;
|
||||
}
|
||||
|
||||
if (audioContext && audioContext.state !== "closed") {
|
||||
try {
|
||||
await audioContext.close();
|
||||
} catch (e) {
|
||||
console.warn("Could not close audio context:", e);
|
||||
}
|
||||
audioContext = null;
|
||||
}
|
||||
|
||||
if (animationFrame) {
|
||||
cancelAnimationFrame(animationFrame);
|
||||
animationFrame = null;
|
||||
}
|
||||
|
||||
if (timerInterval) {
|
||||
clearInterval(timerInterval);
|
||||
timerInterval = null;
|
||||
}
|
||||
timerElement.textContent = "00:00";
|
||||
startTime = null;
|
||||
|
||||
isRecording = false;
|
||||
updateUI();
|
||||
}
|
||||
|
||||
async function toggleRecording() {
|
||||
if (!isRecording) {
|
||||
if (waitingForStop) {
|
||||
console.log("Waiting for stop, early return");
|
||||
return;
|
||||
}
|
||||
console.log("Connecting to WebSocket");
|
||||
try {
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
await startRecording();
|
||||
} else {
|
||||
await setupWebSocket();
|
||||
await startRecording();
|
||||
}
|
||||
} catch (err) {
|
||||
statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
|
||||
console.error(err);
|
||||
}
|
||||
} else {
|
||||
console.log("Stopping recording");
|
||||
stopRecording();
|
||||
}
|
||||
}
|
||||
|
||||
function updateUI() {
|
||||
recordButton.classList.toggle("recording", isRecording);
|
||||
recordButton.disabled = waitingForStop;
|
||||
|
||||
if (waitingForStop) {
|
||||
if (statusText.textContent !== "Recording stopped. Processing final audio...") {
|
||||
statusText.textContent = "Please wait for processing to complete...";
|
||||
}
|
||||
} else if (isRecording) {
|
||||
statusText.textContent = "Recording...";
|
||||
} else {
|
||||
if (
|
||||
statusText.textContent !== "Finished processing audio! Ready to record again." &&
|
||||
statusText.textContent !== "Processing finalized or connection closed."
|
||||
) {
|
||||
statusText.textContent = "Click to start transcription";
|
||||
}
|
||||
}
|
||||
if (!waitingForStop) {
|
||||
recordButton.disabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
recordButton.addEventListener("click", toggleRecording);
|
||||
|
||||
if (microphoneSelect) {
|
||||
microphoneSelect.addEventListener("change", handleMicrophoneChange);
|
||||
}
|
||||
document.addEventListener('DOMContentLoaded', async () => {
|
||||
try {
|
||||
await enumerateMicrophones();
|
||||
} catch (error) {
|
||||
console.log("Could not enumerate microphones on load:", error);
|
||||
}
|
||||
});
|
||||
navigator.mediaDevices.addEventListener('devicechange', async () => {
|
||||
console.log('Device change detected, re-enumerating microphones');
|
||||
try {
|
||||
await enumerateMicrophones();
|
||||
} catch (error) {
|
||||
console.log("Error re-enumerating microphones:", error);
|
||||
}
|
||||
});
|
||||
1
whisperlivekit/web/src/dark_mode.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-120q-151 0-255.5-104.5T120-480q0-138 90-239.5T440-838q13-2 23 3.5t16 14.5q6 9 6.5 21t-7.5 23q-17 26-25.5 55t-8.5 61q0 90 63 153t153 63q31 0 61.5-9t54.5-25q11-7 22.5-6.5T819-479q10 5 15.5 15t3.5 24q-14 138-117.5 229T480-120Zm0-80q88 0 158-48.5T740-375q-20 5-40 8t-40 3q-123 0-209.5-86.5T364-660q0-20 3-40t8-40q-78 32-126.5 102T200-480q0 116 82 198t198 82Zm-10-270Z"/></svg>
|
||||
|
After Width: | Height: | Size: 493 B |
1
whisperlivekit/web/src/light_mode.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-360q50 0 85-35t35-85q0-50-35-85t-85-35q-50 0-85 35t-35 85q0 50 35 85t85 35Zm0 80q-83 0-141.5-58.5T280-480q0-83 58.5-141.5T480-680q83 0 141.5 58.5T680-480q0 83-58.5 141.5T480-280ZM80-440q-17 0-28.5-11.5T40-480q0-17 11.5-28.5T80-520h80q17 0 28.5 11.5T200-480q0 17-11.5 28.5T160-440H80Zm720 0q-17 0-28.5-11.5T760-480q0-17 11.5-28.5T800-520h80q17 0 28.5 11.5T920-480q0 17-11.5 28.5T880-440h-80ZM480-760q-17 0-28.5-11.5T440-800v-80q0-17 11.5-28.5T480-920q17 0 28.5 11.5T520-880v80q0 17-11.5 28.5T480-760Zm0 720q-17 0-28.5-11.5T440-80v-80q0-17 11.5-28.5T480-200q17 0 28.5 11.5T520-160v80q0 17-11.5 28.5T480-40ZM226-678l-43-42q-12-11-11.5-28t11.5-29q12-12 29-12t28 12l42 43q11 12 11 28t-11 28q-11 12-27.5 11.5T226-678Zm494 495-42-43q-11-12-11-28.5t11-27.5q11-12 27.5-11.5T734-282l43 42q12 11 11.5 28T777-183q-12 12-29 12t-28-12Zm-42-495q-12-11-11.5-27.5T678-734l42-43q11-12 28-11.5t29 11.5q12 12 12 29t-12 28l-43 42q-12 11-28 11t-28-11ZM183-183q-12-12-12-29t12-28l43-42q12-11 28.5-11t27.5 11q12 11 11.5 27.5T282-226l-42 43q-11 12-28 11.5T183-183Zm297-297Z"/></svg>
|
||||
|
After Width: | Height: | Size: 1.2 KiB |
1
whisperlivekit/web/src/settings.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M433-80q-27 0-46.5-18T363-142l-9-66q-13-5-24.5-12T307-235l-62 26q-25 11-50 2t-39-32l-47-82q-14-23-8-49t27-43l53-40q-1-7-1-13.5v-27q0-6.5 1-13.5l-53-40q-21-17-27-43t8-49l47-82q14-23 39-32t50 2l62 26q11-8 23-15t24-12l9-66q4-26 23.5-44t46.5-18h94q27 0 46.5 18t23.5 44l9 66q13 5 24.5 12t22.5 15l62-26q25-11 50-2t39 32l47 82q14 23 8 49t-27 43l-53 40q1 7 1 13.5v27q0 6.5-2 13.5l53 40q21 17 27 43t-8 49l-48 82q-14 23-39 32t-50-2l-60-26q-11 8-23 15t-24 12l-9 66q-4 26-23.5 44T527-80h-94Zm7-80h79l14-106q31-8 57.5-23.5T639-327l99 41 39-68-86-65q5-14 7-29.5t2-31.5q0-16-2-31.5t-7-29.5l86-65-39-68-99 42q-22-23-48.5-38.5T533-694l-13-106h-79l-14 106q-31 8-57.5 23.5T321-633l-99-41-39 68 86 64q-5 15-7 30t-2 32q0 16 2 31t7 30l-86 65 39 68 99-42q22 23 48.5 38.5T427-266l13 106Zm42-180q58 0 99-41t41-99q0-58-41-99t-99-41q-59 0-99.5 41T342-480q0 58 40.5 99t99.5 41Zm-2-140Z"/></svg>
|
||||
|
After Width: | Height: | Size: 982 B |
1
whisperlivekit/web/src/system_mode.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M396-396q-32-32-58.5-67T289-537q-5 14-6.5 28.5T281-480q0 83 58 141t141 58q14 0 28.5-2t28.5-6q-39-22-74-48.5T396-396Zm85 196q-56 0-107-21t-91-61q-40-40-61-91t-21-107q0-51 17-97.5t50-84.5q13-14 32-9.5t27 24.5q21 55 52.5 104t73.5 91q42 42 91 73.5T648-326q20 8 24.5 27t-9.5 32q-38 33-84.5 50T481-200Zm223-192q-16-5-23-20.5t-4-32.5q9-48-6-94.5T621-621q-35-35-80.5-49.5T448-677q-17 3-32-4t-21-23q-6-16 1.5-31t23.5-19q69-15 138 4.5T679-678q51 51 71 120t5 138q-4 17-19 25t-32 3ZM480-840q-17 0-28.5-11.5T440-880v-40q0-17 11.5-28.5T480-960q17 0 28.5 11.5T520-920v40q0 17-11.5 28.5T480-840Zm0 840q-17 0-28.5-11.5T440-40v-40q0-17 11.5-28.5T480-120q17 0 28.5 11.5T520-80v40q0 17-11.5 28.5T480 0Zm255-734q-12-12-12-28.5t12-28.5l28-28q11-11 27.5-11t28.5 11q12 12 12 28.5T819-762l-28 28q-12 12-28 12t-28-12ZM141-141q-12-12-12-28.5t12-28.5l28-28q12-12 28-12t28 12q12 12 12 28.5T225-169l-28 28q-11 11-27.5 11T141-141Zm739-299q-17 0-28.5-11.5T840-480q0-17 11.5-28.5T880-520h40q17 0 28.5 11.5T960-480q0 17-11.5 28.5T920-440h-40Zm-840 0q-17 0-28.5-11.5T0-480q0-17 11.5-28.5T40-520h40q17 0 28.5 11.5T120-480q0 17-11.5 28.5T80-440H40Zm779 299q-12 12-28.5 12T762-141l-28-28q-12-12-12-28t12-28q12-12 28.5-12t28.5 12l28 28q11 11 11 27.5T819-141ZM226-735q-12 12-28.5 12T169-735l-28-28q-11-11-11-27.5t11-28.5q12-12 28.5-12t28.5 12l28 28q12 12 12 28t-12 28Zm170 339Z"/></svg>
|
||||
|
After Width: | Height: | Size: 1.4 KiB |
1
whisperlivekit/web/src/translate.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="m603-202-34 97q-4 11-14 18t-22 7q-20 0-32.5-16.5T496-133l152-402q5-11 15-18t22-7h30q12 0 22 7t15 18l152 403q8 19-4 35.5T868-80q-13 0-22.5-7T831-106l-34-96H603ZM362-401 188-228q-11 11-27.5 11.5T132-228q-11-11-11-28t11-28l174-174q-35-35-63.5-80T190-640h84q20 39 40 68t48 58q33-33 68.5-92.5T484-720H80q-17 0-28.5-11.5T40-760q0-17 11.5-28.5T80-800h240v-40q0-17 11.5-28.5T360-880q17 0 28.5 11.5T400-840v40h240q17 0 28.5 11.5T680-760q0 17-11.5 28.5T640-720h-76q-21 72-63 148t-83 116l96 98-30 82-122-125Zm266 129h144l-72-204-72 204Z"/></svg>
|
||||
|
After Width: | Height: | Size: 650 B |
@@ -1,5 +1,6 @@
import logging
import importlib.resources as resources
import base64

logger = logging.getLogger(__name__)

@@ -10,4 +11,78 @@ def get_web_interface_html():
            return f.read()
    except Exception as e:
        logger.error(f"Error loading web interface HTML: {e}")
        return "<html><body><h1>Error loading interface</h1></body></html>"
        return "<html><body><h1>Error loading interface</h1></body></html>"

def get_inline_ui_html():
    """Returns the complete web interface HTML with all assets embedded in a single call."""
    try:
        with resources.files('whisperlivekit.web').joinpath('live_transcription.html').open('r', encoding='utf-8') as f:
            html_content = f.read()
        with resources.files('whisperlivekit.web').joinpath('live_transcription.css').open('r', encoding='utf-8') as f:
            css_content = f.read()
        with resources.files('whisperlivekit.web').joinpath('live_transcription.js').open('r', encoding='utf-8') as f:
            js_content = f.read()

        # SVG files
        with resources.files('whisperlivekit.web').joinpath('src', 'system_mode.svg').open('r', encoding='utf-8') as f:
            system_svg = f.read()
        system_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(system_svg.encode('utf-8')).decode('utf-8')}"
        with resources.files('whisperlivekit.web').joinpath('src', 'light_mode.svg').open('r', encoding='utf-8') as f:
            light_svg = f.read()
        light_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(light_svg.encode('utf-8')).decode('utf-8')}"
        with resources.files('whisperlivekit.web').joinpath('src', 'dark_mode.svg').open('r', encoding='utf-8') as f:
            dark_svg = f.read()
        dark_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(dark_svg.encode('utf-8')).decode('utf-8')}"

        # Replace external references
        html_content = html_content.replace(
            '<link rel="stylesheet" href="/web/live_transcription.css" />',
            f'<style>\n{css_content}\n</style>'
        )

        html_content = html_content.replace(
            '<script src="/web/live_transcription.js"></script>',
            f'<script>\n{js_content}\n</script>'
        )

        # Replace SVG references
        html_content = html_content.replace(
            '<img src="/web/src/system_mode.svg" alt="" />',
            f'<img src="{system_data_uri}" alt="" />'
        )

        html_content = html_content.replace(
            '<img src="/web/src/light_mode.svg" alt="" />',
            f'<img src="{light_data_uri}" alt="" />'
        )

        html_content = html_content.replace(
            '<img src="/web/src/dark_mode.svg" alt="" />',
            f'<img src="{dark_data_uri}" alt="" />'
        )

        return html_content

    except Exception as e:
        logger.error(f"Error creating embedded web interface: {e}")
        return "<html><body><h1>Error loading embedded interface</h1></body></html>"


if __name__ == '__main__':

    from fastapi import FastAPI
    from fastapi.responses import HTMLResponse
    import uvicorn
    from starlette.staticfiles import StaticFiles
    import pathlib
    import whisperlivekit.web as webpkg

    app = FastAPI()
    web_dir = pathlib.Path(webpkg.__file__).parent
    app.mount("/web", StaticFiles(directory=str(web_dir)), name="web")

    @app.get("/")
    async def get():
        return HTMLResponse(get_inline_ui_html())

    uvicorn.run(app=app)

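Note: a minimal sanity check for the embedding logic above. The import path `whisperlivekit.web.web_interface` is an assumption based on the package layout used in the `__main__` block; only `get_inline_ui_html` itself comes from the diff.

```python
# Illustrative sanity check (module path assumed, not confirmed by the diff):
from whisperlivekit.web.web_interface import get_inline_ui_html

html = get_inline_ui_html()
assert "/web/live_transcription.css" not in html   # stylesheet link replaced by inline <style>
assert "/web/live_transcription.js" not in html    # script tag replaced by inline <script>
assert "data:image/svg+xml;base64," in html        # theme icons embedded as data URIs
```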
@@ -3,43 +3,10 @@ import logging
|
||||
import io
|
||||
import soundfile as sf
|
||||
import math
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
torch = None
|
||||
from typing import List
|
||||
import numpy as np
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS = ImportError(
|
||||
"""SimulStreaming dependencies are not available.
|
||||
Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]"
|
||||
""")
|
||||
|
||||
try:
|
||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD
|
||||
from whisperlivekit.simul_whisper.whisper import tokenizer
|
||||
SIMULSTREAMING_AVAILABLE = True
|
||||
except ImportError:
|
||||
logger.warning("⚠️ SimulStreaming dependencies not available. Attempting to download them.")
|
||||
try:
|
||||
from whisperlivekit import download_simulstreaming_backend
|
||||
download_simulstreaming_backend()
|
||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD
|
||||
from whisperlivekit.simul_whisper.whisper import tokenizer
|
||||
SIMULSTREAMING_AVAILABLE = True
|
||||
logger.info("SimulStreaming dependencies downloaded successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download or import SimulStreaming dependencies: {e}")
|
||||
SIMULSTREAMING_AVAILABLE = False
|
||||
AlignAttConfig = None
|
||||
PaddedAlignAttWhisper = None
|
||||
DEC_PAD = None
|
||||
tokenizer = None
|
||||
|
||||
class ASRBase:
|
||||
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
||||
# "" for faster-whisper because it emits the spaces when needed)
|
||||
@@ -320,182 +287,4 @@ class OpenaiApiASR(ASRBase):
|
||||
self.use_vad_opt = True
|
||||
|
||||
def set_translate_task(self):
|
||||
self.task = "translate"
|
||||
|
||||
|
||||
class SimulStreamingASR(ASRBase):
|
||||
"""SimulStreaming backend with AlignAtt policy."""
|
||||
sep = ""
|
||||
|
||||
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
|
||||
if not SIMULSTREAMING_AVAILABLE:
|
||||
raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
|
||||
with open("whisperlivekit/simul_whisper/dual_license_simulstreaming.md", "r") as f:
|
||||
print("*"*80 + f.read() + "*"*80)
|
||||
self.logfile = logfile
|
||||
self.transcribe_kargs = {}
|
||||
self.original_language = None if lan == "auto" else lan
|
||||
|
||||
self.model_path = kwargs.get('model_path', './large-v3.pt')
|
||||
self.frame_threshold = kwargs.get('frame_threshold', 25)
|
||||
self.audio_max_len = kwargs.get('audio_max_len', 30.0)
|
||||
self.audio_min_len = kwargs.get('audio_min_len', 0.0)
|
||||
self.segment_length = kwargs.get('segment_length', 0.5)
|
||||
self.beams = kwargs.get('beams', 1)
|
||||
self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
|
||||
self.task = kwargs.get('task', 'transcribe')
|
||||
self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
|
||||
self.never_fire = kwargs.get('never_fire', False)
|
||||
self.init_prompt = kwargs.get('init_prompt', None)
|
||||
self.static_init_prompt = kwargs.get('static_init_prompt', None)
|
||||
self.max_context_tokens = kwargs.get('max_context_tokens', None)
|
||||
|
||||
if model_dir is not None:
|
||||
self.model_path = model_dir
|
||||
elif modelsize is not None:  # For the moment, the .en.pt models do not work!
|
||||
model_mapping = {
|
||||
'tiny': './tiny.pt',
|
||||
'base': './base.pt',
|
||||
'small': './small.pt',
|
||||
'medium': './medium.pt',
|
||||
'medium.en': './medium.en.pt',
|
||||
'large-v1': './large-v1.pt',
|
||||
'base.en': './base.en.pt',
|
||||
'small.en': './small.en.pt',
|
||||
'tiny.en': './tiny.en.pt',
|
||||
'large-v2': './large-v2.pt',
|
||||
'large-v3': './large-v3.pt',
|
||||
'large': './large-v3.pt'
|
||||
}
|
||||
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
|
||||
|
||||
self.model = self.load_model(modelsize, cache_dir, model_dir)
|
||||
|
||||
# Set up tokenizer for translation if needed
|
||||
if self.task == "translate":
|
||||
self.set_translate_task()
|
||||
|
||||
def load_model(self, modelsize, cache_dir, model_dir):
|
||||
try:
|
||||
cfg = AlignAttConfig(
|
||||
model_path=self.model_path,
|
||||
segment_length=self.segment_length,
|
||||
frame_threshold=self.frame_threshold,
|
||||
language=self.original_language,
|
||||
audio_max_len=self.audio_max_len,
|
||||
audio_min_len=self.audio_min_len,
|
||||
cif_ckpt_path=self.cif_ckpt_path,
|
||||
decoder_type="beam",
|
||||
beam_size=self.beams,
|
||||
task=self.task,
|
||||
never_fire=self.never_fire,
|
||||
init_prompt=self.init_prompt,
|
||||
max_context_tokens=self.max_context_tokens,
|
||||
static_init_prompt=self.static_init_prompt,
|
||||
)
|
||||
|
||||
logger.info(f"Loading SimulStreaming model with language: {self.original_language}")
|
||||
model = PaddedAlignAttWhisper(cfg)
|
||||
return model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load SimulStreaming model: {e}")
|
||||
raise
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
"""Transcribe audio using SimulStreaming."""
|
||||
try:
|
||||
if isinstance(audio, np.ndarray):
|
||||
audio_tensor = torch.from_numpy(audio).float()
|
||||
else:
|
||||
audio_tensor = audio
|
||||
|
||||
prompt = init_prompt if init_prompt else (self.init_prompt or "")
|
||||
|
||||
result = self.model.infer(audio_tensor, init_prompt=prompt)
|
||||
|
||||
if torch.is_tensor(result):
|
||||
result = result[result < DEC_PAD]
|
||||
|
||||
logger.debug(f"SimulStreaming transcription result: {result}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"SimulStreaming transcription failed: {e}")
|
||||
raise
|
||||
|
||||
def ts_words(self, result) -> List[ASRToken]:
|
||||
"""Convert SimulStreaming result to ASRToken list."""
|
||||
tokens = []
|
||||
|
||||
try:
|
||||
if torch.is_tensor(result):
|
||||
text = self.model.tokenizer.decode(result.cpu().numpy())
|
||||
else:
|
||||
text = str(result)
|
||||
|
||||
if not text or len(text.strip()) == 0:
|
||||
return tokens
|
||||
|
||||
# We don't have word-level timestamps here. First approach; should be improved later.
|
||||
words = text.strip().split()
|
||||
if not words:
|
||||
return tokens
|
||||
|
||||
duration_per_word = 0.1  # this will be adjusted based on the actual audio duration
# by the SimulStreamingOnlineProcessor
|
||||
|
||||
for i, word in enumerate(words):
|
||||
start_time = i * duration_per_word
|
||||
end_time = (i + 1) * duration_per_word
|
||||
|
||||
token = ASRToken(
|
||||
start=start_time,
|
||||
end=end_time,
|
||||
text=word,
|
||||
probability=1.0
|
||||
)
|
||||
tokens.append(token)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting SimulStreaming result to tokens: {e}")
|
||||
|
||||
return tokens
|
||||
|
||||
    def segments_end_ts(self, result) -> List[float]:
        """Get segment end timestamps."""
        if torch.is_tensor(result):
            num_tokens = len(result)
            return [num_tokens * 0.1]  # rough estimate
        return [1.0]

    def use_vad(self):
        """Enable VAD - SimulStreaming has different VAD handling."""
        logger.info("VAD requested for SimulStreaming - handled internally by the model")
        pass

    def set_translate_task(self):
        """Set up translation task."""
        try:
            self.model.tokenizer = tokenizer.get_tokenizer(
                multilingual=True,
                language=self.model.cfg.language,
                num_languages=self.model.model.num_languages,
                task="translate"
            )
            logger.info("SimulStreaming configured for translation task")
        except Exception as e:
            logger.error(f"Failed to configure SimulStreaming for translation: {e}")
            raise

    def warmup(self, audio, init_prompt=""):
        """Warmup the SimulStreaming model."""
        try:
            if isinstance(audio, np.ndarray):
                audio = torch.from_numpy(audio).float()
            self.model.insert_audio(audio)
            self.model.infer(True)
            self.model.refresh_segment(complete=True)
            logger.info("SimulStreaming model warmed up successfully")
        except Exception as e:
            logger.warning(f"SimulStreaming warmup failed: {e}")
        self.task = "translate"
@@ -6,18 +6,6 @@ from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript

logger = logging.getLogger(__name__)

# SimulStreaming imports - we check whether the files are present
try:
    import torch
    from whisperlivekit.simul_whisper.config import AlignAttConfig
    SIMULSTREAMING_AVAILABLE = True
except ImportError:
    logger.warning("SimulStreaming dependencies not available for online processor.")
    SIMULSTREAMING_AVAILABLE = False
    OnlineProcessorInterface = None
    torch = None


class HypothesisBuffer:
    """
    Buffer to store and process ASR hypothesis tokens.
@@ -134,6 +122,7 @@ class OnlineASRProcessor:
        self.tokenize = tokenize_method
        self.logfile = logfile
        self.confidence_validation = confidence_validation
        self.global_time_offset = 0.0
        self.init()

        self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
@@ -164,6 +153,21 @@ class OnlineASRProcessor:
        """Append an audio chunk (a numpy array) to the current audio buffer."""
        self.audio_buffer = np.append(self.audio_buffer, audio)

    def insert_silence(self, silence_duration, offset):
        """
        If the silence is longer than 5 s, we clear the context completely. Otherwise we
        just insert a short silence and shift the last_attend_frame.
        """
        # if self.transcript_buffer.buffer:
        #     self.committed.extend(self.transcript_buffer.buffer)
        #     self.transcript_buffer.buffer = []

        if True:  # silence_duration < 3: we want the trailing audio to be processed without a gap; could also be handled in ends_with_silence in the future.
            gap_silence = np.zeros(int(16000 * silence_duration), dtype=np.int16)
            self.insert_audio_chunk(gap_silence)
        else:
            self.init(offset=silence_duration + offset)
            self.global_time_offset += silence_duration
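    # Illustrative sketch, not part of the original code (hypothetical caller): when the
    # server detects a 2-second gap in the incoming stream, it can bridge it with zeros so
    # the next chunk stays aligned with the already-buffered audio.
    #
    #   online.insert_silence(silence_duration=2.0, offset=12.5)
    #   # currently always takes the zero-padding branch above; the context-clearing
    #   # branch (and the global_time_offset shift) is disabled by the `if True:`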
    def prompt(self) -> Tuple[str, str]:
        """
        Returns a tuple: (prompt, context), where:
@@ -242,6 +246,9 @@ class OnlineASRProcessor:
        logger.debug(
            f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
        )
        if self.global_time_offset:
            # Rebind the shifted tokens so the offset is not silently dropped.
            committed_tokens = [token.with_offset(self.global_time_offset) for token in committed_tokens]
        return committed_tokens, current_audio_processed_upto

    def chunk_completed_sentence(self):
@@ -403,330 +410,3 @@ class OnlineASRProcessor:
            start = None
            end = None
        return Transcript(start, end, text, probability=probability)


class VACOnlineASRProcessor:
    """
    Wraps an OnlineASRProcessor with a Voice Activity Controller (VAC).

    It receives small chunks of audio, applies VAD (e.g. with Silero),
    and when the system detects a pause in speech (or end of an utterance)
    it finalizes the utterance immediately.
    """
    SAMPLING_RATE = 16000

    def __init__(self, online_chunk_size: float, *args, **kwargs):
        self.online_chunk_size = online_chunk_size
        self.online = OnlineASRProcessor(*args, **kwargs)
        self.asr = self.online.asr

        # Load a VAD model (e.g. Silero VAD)
        import torch
        model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
        from .silero_vad_iterator import FixedVADIterator

        self.vac = FixedVADIterator(model)
        self.logfile = self.online.logfile
        self.last_input_audio_stream_end_time: float = 0.0
        self.init()

    def init(self):
        self.online.init()
        self.vac.reset_states()
        self.current_online_chunk_buffer_size = 0
        self.last_input_audio_stream_end_time = self.online.buffer_time_offset
        self.is_currently_final = False
        self.status: Optional[str] = None  # "voice" or "nonvoice"
        self.audio_buffer = np.array([], dtype=np.float32)
        self.buffer_offset = 0  # in frames
    def get_audio_buffer_end_time(self) -> float:
        """Returns the absolute end time of the audio processed by the underlying OnlineASRProcessor."""
        return self.online.get_audio_buffer_end_time()

    def clear_buffer(self):
        self.buffer_offset += len(self.audio_buffer)
        self.audio_buffer = np.array([], dtype=np.float32)

    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
        """
        Process an incoming small audio chunk:
        - run VAD on the chunk,
        - decide whether to send the audio to the online ASR processor immediately,
        - and/or to mark the current utterance as finished.
        """
        self.last_input_audio_stream_end_time = audio_stream_end_time
        res = self.vac(audio)
        self.audio_buffer = np.append(self.audio_buffer, audio)

        if res is not None:
            # VAD returned a result; adjust the frame number
            frame = list(res.values())[0] - self.buffer_offset
            if "start" in res and "end" not in res:
                self.status = "voice"
                send_audio = self.audio_buffer[frame:]
                self.online.init(offset=(frame + self.buffer_offset) / self.SAMPLING_RATE)
                self.online.insert_audio_chunk(send_audio)
                self.current_online_chunk_buffer_size += len(send_audio)
                self.clear_buffer()
            elif "end" in res and "start" not in res:
                self.status = "nonvoice"
                send_audio = self.audio_buffer[:frame]
                self.online.insert_audio_chunk(send_audio)
                self.current_online_chunk_buffer_size += len(send_audio)
                self.is_currently_final = True
                self.clear_buffer()
            else:
                beg = res["start"] - self.buffer_offset
                end = res["end"] - self.buffer_offset
                self.status = "nonvoice"
                send_audio = self.audio_buffer[beg:end]
                self.online.init(offset=(beg + self.buffer_offset) / self.SAMPLING_RATE)
                self.online.insert_audio_chunk(send_audio)
                self.current_online_chunk_buffer_size += len(send_audio)
                self.is_currently_final = True
                self.clear_buffer()
        else:
            if self.status == "voice":
                self.online.insert_audio_chunk(self.audio_buffer)
                self.current_online_chunk_buffer_size += len(self.audio_buffer)
                self.clear_buffer()
            else:
                # Keep 1 second worth of audio in case VAD later detects voice,
                # but trim to avoid unbounded memory usage.
                self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE)
                self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]
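    # Illustrative sketch, not part of the original code: the shapes of `res` this method
    # expects from the VAD iterator (absolute sample indices since the stream start):
    #
    #   {'start': 123456}                   # speech began: forward audio from that sample on
    #   {'end': 130000}                     # speech ended: flush up to that sample, finalize
    #   {'start': 123456, 'end': 130000}    # a short utterance fully contained in this chunk
    #   None                                # no boundary: buffer or forward depending on status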
    def process_iter(self) -> Tuple[List[ASRToken], float]:
        """
        Depending on the VAD status and the amount of accumulated audio,
        process the current audio chunk.
        Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
        """
        if self.is_currently_final:
            return self.finish()
        elif self.current_online_chunk_buffer_size > self.SAMPLING_RATE * self.online_chunk_size:
            self.current_online_chunk_buffer_size = 0
            return self.online.process_iter()
        else:
            logger.debug("No online update, only VAD")
            return [], self.last_input_audio_stream_end_time

    def finish(self) -> Tuple[List[ASRToken], float]:
        """
        Finish processing by flushing any remaining text.
        Returns a tuple: (list of remaining ASRToken objects, float representing the final audio processed up to time).
        """
        result_tokens, processed_upto = self.online.finish()
        self.current_online_chunk_buffer_size = 0
        self.is_currently_final = False
        return result_tokens, processed_upto

    def get_buffer(self):
        """
        Get the unvalidated buffer in string format.
        """
        return self.online.concatenate_tokens(self.online.transcript_buffer.buffer)

class SimulStreamingOnlineProcessor:
    SAMPLING_RATE = 16000

    def __init__(
        self,
        asr,
        tokenize_method: Optional[callable] = None,
        buffer_trimming: Tuple[str, float] = ("segment", 15),
        confidence_validation=False,
        logfile=sys.stderr,
    ):
        if not SIMULSTREAMING_AVAILABLE:
            raise ImportError("SimulStreaming dependencies are not available.")

        self.asr = asr
        self.tokenize = tokenize_method
        self.logfile = logfile
        self.confidence_validation = confidence_validation
        self.init()

        # Buffer trimming is not functional yet.
        self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming

    def init(self, offset: Optional[float] = None):
        """Initialize or reset the processing state."""
        self.audio_chunks = []
        self.offset = offset if offset is not None else 0.0
        self.is_last = False
        self.beg = self.offset
        self.end = self.offset
        self.cumulative_audio_duration = 0.0
        self.last_audio_stream_end_time = self.offset

        self.committed: List[ASRToken] = []
        self.last_result_tokens: List[ASRToken] = []
        self.buffer_content = ""
        self.processed_audio_duration = 0.0

    def get_audio_buffer_end_time(self) -> float:
        """Returns the absolute end time of the current audio buffer."""
        return self.end
    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
        """Append an audio chunk to be processed by SimulStreaming."""
        if torch is None:
            raise ImportError("PyTorch is required for SimulStreaming but not available")

        # Convert the numpy array to a torch tensor
        audio_tensor = torch.from_numpy(audio).float()
        self.audio_chunks.append(audio_tensor)

        # Update timing
        chunk_duration = len(audio) / self.SAMPLING_RATE
        self.cumulative_audio_duration += chunk_duration

        if audio_stream_end_time is not None:
            self.last_audio_stream_end_time = audio_stream_end_time
            self.end = audio_stream_end_time
        else:
            self.end = self.offset + self.cumulative_audio_duration

    def prompt(self) -> Tuple[str, str]:
        """
        Returns a tuple: (prompt, context).
        SimulStreaming handles prompting internally, so we return empty strings.
        """
        return "", ""

    def get_buffer(self):
        """
        Get the unvalidated buffer content.
        """
        buffer_end = self.end if hasattr(self, 'end') else None
        return Transcript(
            start=None,
            end=buffer_end,
            text=self.buffer_content,
            probability=None
        )
    def timestamped_text(self, tokens, generation):
        # Adapted from the SimulStreaming repo; self.model is replaced by self.asr.model here.
        pr = generation["progress"]
        if "result" not in generation:
            split_words, split_tokens = self.asr.model.tokenizer.split_to_word_tokens(tokens)
        else:
            split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"]

        frames = [p["most_attended_frames"][0] for p in pr]
        tokens = tokens.copy()
        ret = []
        for sw, st in zip(split_words, split_tokens):
            b = None
            for stt in st:
                t, f = tokens.pop(0), frames.pop(0)
                if t != stt:
                    raise ValueError(f"Token mismatch: {t} != {stt} at frame {f}.")
                if b is None:
                    b = f
                e = f
            out = (b * 0.02, e * 0.02, sw)
            ret.append(out)
            logger.debug(f"TS-WORD:\t{' '.join(map(str, out))}")
        return ret
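    # Illustrative sketch, not part of the original code: attention frame indices are
    # converted to seconds at 0.02 s per frame (Whisper's 20 ms hop), so a word whose
    # tokens attend frames 150..162 is stamped approximately (3.00 s, 3.24 s, word):
    #
    #   (150 * 0.02, 162 * 0.02)   # ≈ (3.0, 3.24)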
    def process_iter(self) -> Tuple[List[ASRToken], float]:
        """
        Process accumulated audio chunks using SimulStreaming.

        Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
        """
        if not self.audio_chunks:
            return [], self.end

        try:
            # Concatenate all pending audio chunks
            if len(self.audio_chunks) == 1:
                audio = self.audio_chunks[0]
            else:
                audio = torch.cat(self.audio_chunks, dim=0)

            audio_duration = audio.shape[0] / self.SAMPLING_RATE if audio.shape[0] > 0 else 0
            self.processed_audio_duration += audio_duration

            self.audio_chunks = []

            logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s")
            logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s")

            self.asr.model.insert_audio(audio)
            tokens, generation_progress = self.asr.model.infer(is_last=self.is_last)
            ts_words = self.timestamped_text(tokens, generation_progress)
            text = self.asr.model.tokenizer.decode(tokens)

            new_tokens = []
            for ts_word in ts_words:
                start, end, word = ts_word
                token = ASRToken(
                    start=start,
                    end=end,
                    text=word,
                    probability=0.95  # placeholder probability; maybe it can be extracted from the model
                )
                new_tokens.append(token)
            self.committed.extend(new_tokens)

            return new_tokens, self.end

        except Exception as e:
            logger.error(f"SimulStreaming processing error: {e}")
            logger.error(f"Error details: {type(e).__name__}: {str(e)}")
            return [], self.end
    def finish(self) -> Tuple[List[ASRToken], float]:
        logger.debug("SimulStreaming finish() called")
        self.is_last = True
        final_tokens, final_time = self.process_iter()
        self.is_last = False
        return final_tokens, final_time

    def concatenate_tokens(
        self,
        tokens: List[ASRToken],
        sep: Optional[str] = None,
        offset: float = 0
    ) -> Transcript:
        """Concatenate tokens into a Transcript object."""
        sep = sep if sep is not None else self.asr.sep
        text = sep.join(token.text for token in tokens)
        probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
        if tokens:
            start = offset + tokens[0].start
            end = offset + tokens[-1].end
        else:
            start = None
            end = None
        return Transcript(start, end, text, probability=probability)

    def chunk_at(self, time: float):
        """
        No-op; kept for interface compatibility.
        """
        logger.debug(f"SimulStreaming chunk_at({time:.2f}) - handled internally")
        pass

    def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
        """
        Create simple sentences.
        """
        if not tokens:
            return []

        full_text = " ".join(token.text for token in tokens)
        sentence = Sentence(
            start=tokens[0].start,
            end=tokens[-1].end,
            text=full_text
        )
        return [sentence]
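# Illustrative sketch, not part of the original code: a minimal driving loop for the
# processor above, assuming `asr` is a loaded SimulStreamingASR and `chunk` is a float32
# numpy array of 16 kHz samples.
#
#   online = SimulStreamingOnlineProcessor(asr)
#   online.insert_audio_chunk(chunk, audio_stream_end_time=1.0)
#   tokens, processed_upto = online.process_iter()
#   final_tokens, _ = online.finish()   # flush whatever is still pending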
@@ -5,8 +5,7 @@ import librosa
from functools import lru_cache
import time
import logging
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR, SimulStreamingASR, SIMULSTREAMING_AVAILABLE, SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor, SimulStreamingOnlineProcessor, SIMULSTREAMING_AVAILABLE as SIMULSTREAMING_ONLINE_AVAILABLE
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR

logger = logging.getLogger(__name__)
@@ -68,35 +67,7 @@ def backend_factory(args):
    backend = args.backend
    if backend == "openai-api":
        logger.debug("Using OpenAI API.")
        asr = OpenaiApiASR(lan=args.lan)
    elif backend == "simulstreaming":
        logger.debug("Using SimulStreaming backend.")
        if not SIMULSTREAMING_AVAILABLE:
            raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS

        simulstreaming_kwargs = {}
        for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
                     'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
                     'max_context_tokens', 'model_path']:
            if hasattr(args, attr):
                simulstreaming_kwargs[attr] = getattr(args, attr)

        # Add segment_length from min_chunk_size
        simulstreaming_kwargs['segment_length'] = getattr(args, 'min_chunk_size', 0.5)
        simulstreaming_kwargs['task'] = args.task

        size = args.model
        t = time.time()
        logger.info(f"Loading SimulStreaming {size} model for language {args.lan}...")
        asr = SimulStreamingASR(
            modelsize=size,
            lan=args.lan,
            cache_dir=getattr(args, 'model_cache_dir', None),
            model_dir=getattr(args, 'model_dir', None),
            **simulstreaming_kwargs
        )
        e = time.time()
        logger.info(f"done. It took {round(e-t,2)} seconds.")
        asr = OpenaiApiASR(lan=args.lan)
    else:
        if backend == "faster-whisper":
            asr_cls = FasterWhisperASR
@@ -136,107 +107,4 @@ def backend_factory(args):
        tokenizer = create_tokenizer(tgt_language)
    else:
        tokenizer = None
    return asr, tokenizer

def online_factory(args, asr, tokenizer, logfile=sys.stderr):
    if args.backend == "simulstreaming":
        if not SIMULSTREAMING_ONLINE_AVAILABLE:
            raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS

        logger.debug("Creating SimulStreaming online processor")
        online = SimulStreamingOnlineProcessor(
            asr,
            tokenizer,
            logfile=logfile,
            buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
            confidence_validation=args.confidence_validation
        )
    elif args.vac:
        online = VACOnlineASRProcessor(
            args.min_chunk_size,
            asr,
            tokenizer,
            logfile=logfile,
            buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
            confidence_validation=args.confidence_validation
        )
    else:
        online = OnlineASRProcessor(
            asr,
            tokenizer,
            logfile=logfile,
            buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
            confidence_validation=args.confidence_validation
        )
    return online

def asr_factory(args, logfile=sys.stderr):
    """
    Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
    """
    asr, tokenizer = backend_factory(args)
    online = online_factory(args, asr, tokenizer, logfile=logfile)
    return asr, online

def warmup_asr(asr, warmup_file=None, timeout=5):
    """
    Warmup the ASR model by transcribing a short audio file.
    """
    import os
    import tempfile

    is_simulstreaming = hasattr(asr, 'warmup') and callable(getattr(asr, 'warmup'))

    if warmup_file is None:
        # Download the JFK sample if not already present
        jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
        temp_dir = tempfile.gettempdir()
        warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")

        if not os.path.exists(warmup_file):
            logger.debug(f"Downloading warmup file from {jfk_url}")
            print(f"Downloading warmup file from {jfk_url}")
            import time
            import urllib.request
            import urllib.error
            import socket

            original_timeout = socket.getdefaulttimeout()
            socket.setdefaulttimeout(timeout)

            start_time = time.time()
            try:
                urllib.request.urlretrieve(jfk_url, warmup_file)
                logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
            except (urllib.error.URLError, socket.timeout) as e:
                logger.warning(f"Download failed: {e}. Proceeding without warmup.")
                return False
            finally:
                socket.setdefaulttimeout(original_timeout)
    elif not warmup_file:
        return False

    if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
        logger.warning(f"Warmup file {warmup_file} invalid or missing.")
        return False

    print(f"Warming up {'SimulStreaming' if is_simulstreaming else 'Whisper'} with {warmup_file}")
    try:
        import librosa
        audio, sr = librosa.load(warmup_file, sr=16000)
    except Exception as e:
        logger.warning(f"Failed to load audio file: {e}")
        return False

    try:
        if is_simulstreaming:
            asr.warmup(audio)
        else:
            asr.transcribe(audio)

        logger.info(f"{'SimulStreaming' if is_simulstreaming else 'Whisper'} is warmed up")
        return True

    except Exception as e:
        logger.warning(f"Warmup failed: {e}")
        return False
    return asr, tokenizer
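# Illustrative sketch, not part of the original code: typical wiring of the factories above
# from server startup code, where `args` is the parsed CLI namespace.
#
#   asr, online = asr_factory(args)
#   warmup_asr(asr)   # downloads the JFK sample to a temp file on first run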