142 Commits
0.2.2 ... 0.2.9

Author SHA1 Message Date
Quentin Fuxa
babe93b99a to 0.2.9 2025-09-11 21:36:32 +02:00
Quentin Fuxa
a4e9f3cab7 support for raw PCM input option by @YeonjunNotFR 2025-09-11 21:32:11 +02:00
Quentin Fuxa
b06866877a add --disable-punctuation-split option 2025-09-11 21:03:00 +02:00
Quentin Fuxa
967cdfebc8 fix Translation imports 2025-09-11 21:03:00 +02:00
Quentin Fuxa
3c11c60126 fix by @treeaaa 2025-09-11 21:03:00 +02:00
Quentin Fuxa
2963e8a757 translate when at least 3 new tokens 2025-09-09 21:45:00 +02:00
Quentin Fuxa
cb2d4ea88a audio processor lines use now Lines objects instead of dict 2025-09-09 21:45:00 +02:00
Quentin Fuxa
add7ea07ee translator takes all the tokens from the queue 2025-09-09 19:55:39 +02:00
Quentin Fuxa
da8726b2cb Merge pull request #211 from Alexander-ARTV/main
Fix type error when setting encoder_feature in simul_whisper->infer for faster whisper encoder
2025-09-09 15:46:59 +02:00
Quentin Fuxa
3358877054 Fix StorageView conversion for CPU/GPU compatibility 2025-09-09 15:44:16 +02:00
Quentin Fuxa
1f7798c7c1 condition on encoder_feature_ctranslate type 2025-09-09 12:16:52 +02:00
Alexander Lindberg
c7b3bb5e58 Fix regression with faster-whisper encoder_feature 2025-09-09 11:18:55 +03:00
Quentin Fuxa
f661f21675 translation asyncio task 2025-09-08 18:34:31 +02:00
Quentin Fuxa
b6164aa59b translation device determined with torch.device 2025-09-08 11:34:40 +02:00
Quentin Fuxa
4209d7f7c0 Place all tensors on the same device in sortformer diarization 2025-09-08 10:20:57 +02:00
Quentin Fuxa
334b338ab0 use platform to determine system and recommand mlx whisper 2025-09-07 15:49:11 +02:00
Quentin Fuxa
72f33be6f2 translation: use of get_nllb_code 2025-09-07 15:25:14 +02:00
Quentin Fuxa
84890b8e61 Merge pull request #201 from notV3NOM/main
Fix: simulstreaming preload model count argument in cli
2025-09-07 15:18:54 +02:00
Quentin Fuxa
c6668adcf3 Merge pull request #200 from notV3NOM/misc
docs: add vram usage for large-v3-turbo
2025-09-07 15:17:42 +02:00
notV3NOM
a178ed5c22 fix simulstreaming preload model count argument in cli 2025-09-06 18:18:09 +05:30
notV3NOM
7601c74c9c add vram usage for large-v3-turbo 2025-09-06 17:56:39 +05:30
Quentin Fuxa
fad9ee4d21 Merge pull request #198 from notV3NOM/main
Fix scrolling UX with sticky header controls
2025-09-05 20:46:36 +02:00
Quentin Fuxa
d1a9913c47 nllb v0 2025-09-05 18:02:42 +02:00
notV3NOM
e4ca2623cb Fix scrolling UX with sticky header controls 2025-09-05 21:25:13 +05:30
Quentin Fuxa
9c1bf37960 fixes #197 2025-09-05 16:34:13 +02:00
Quentin Fuxa
f46528471b revamp chromium extension settings 2025-09-05 16:19:48 +02:00
Quentin Fuxa
191680940b Merge branch 'main' of https://github.com/QuentinFuxa/whisper_streaming_web 2025-09-04 23:58:51 +02:00
Quentin Fuxa
ee02afec56 workaround to get the list of microphones in the extension 2025-09-04 23:58:48 +02:00
Quentin Fuxa
a458028de2 Merge pull request #196 from notV3NOM/main
Fix: Exponentially growing simulstreaming silence timer
2025-09-04 23:05:59 +02:00
notV3NOM
abd8f2c269 Fix exponentially growing simulstreaming silence timer 2025-09-04 21:49:07 +05:30
Quentin Fuxa
f3ad4e39e4 torch.Tensor to torch.as_tensor 2025-09-04 16:39:11 +02:00
Quentin Fuxa
e0a5cbf0e7 v0.1.0 chrome extension 2025-09-04 16:36:28 +02:00
Quentin Fuxa
953697cd86 torch.Tensor to torch.as_tensor 2025-09-04 15:25:39 +02:00
Quentin Fuxa
3bd2122eb4 0.2.8 : only the decoder of whisper is loaded in memory when a different encoder is used 2025-09-02 21:12:25 +02:00
Quentin Fuxa
50b0527858 update architecture 2025-09-01 21:24:12 +02:00
Quentin Fuxa
b044fcdec2 Merge branch 'main' of https://github.com/QuentinFuxa/whisper_streaming_web 2025-09-01 14:55:19 +02:00
Quentin Fuxa
b0508fcf2c mlx/fasterWhisper encoders are loaded once and shared in simulstreaming 2025-09-01 14:55:11 +02:00
Quentin Fuxa
ce89b0aebc Merge pull request #177 from komiyamma/translate-readme-to-japanese
Translate README.md to Japanese
2025-09-01 13:54:50 +02:00
Quentin Fuxa
d5008ed828 mlx/fasterWhisper encoders are loaded once and shared in simulstreaming 2025-09-01 12:33:19 +02:00
Quentin Fuxa
d467716e26 add microphone picker 2025-08-31 10:12:52 +02:00
Quentin Fuxa
199e21b3ef faster-whisper as an optional encoder alternative for simulstreaming 2025-08-30 23:50:16 +02:00
Quentin Fuxa
1d926f2e67 mlx-whisper used as simulstreaming encoder: improve speed for macos systems 2025-08-30 22:19:11 +02:00
Quentin Fuxa
4a71a391b8 get_web_interface_html to get_inline_ui_html for embedded web interface HTML 2025-08-30 13:44:06 +02:00
google-labs-jules[bot]
d3ed4e46e2 Translate README.md to Japanese
Create a Japanese version of the README.md file named ReadmeJP.md.
This makes the project more accessible to Japanese-speaking users.
2025-08-30 04:16:18 +00:00
Quentin Fuxa
057a1026d7 Merge branch 'main' of https://github.com/QuentinFuxa/whisper_streaming_web 2025-08-29 22:01:04 +02:00
Quentin Fuxa
1ba171a58d add embedded web interface HTML (single-file version with inline CSS/JS/SVG)
### Added
- `get_inline_ui_html()`: generates a self-contained version of the web interface, with CSS, JS, and SVG assets inlined directly into the HTML. useful for environments where serving static files is inconvenient or when a single-call UI delivery is preferred.

(cherry picked from commit aa44a92a67)
2025-08-29 22:00:59 +02:00
Quentin Fuxa
1adac67155 explanations about model persistency in containers 2025-08-29 21:27:08 +02:00
Quentin Fuxa
42be1a3773 Merge pull request #173 from CoderRahul9904/chore/docker/pytorch-timeout-retries
fix: increase pip timeout & retries for torch wheel install
2025-08-29 21:22:30 +02:00
Rahul Mourya
0a49fafa0d Update Dockerfile
fix(docker): increase pip timeout/retries for PyTorch wheel installs
2025-08-30 00:23:59 +05:30
Quentin Fuxa
4a5d5e1f3b raise Exception when language == auto and task == translation 2025-08-29 17:44:46 +02:00
Quentin Fuxa
583a2ec2e4 highlight Sortformer optional installation 2025-08-27 21:02:25 +02:00
Quentin Fuxa
19765e89e9 remove triton <3 condition 2025-08-27 20:44:39 +02:00
Quentin Fuxa
9895bc83bf auto detection of language for warmup if not indicated 2025-08-27 20:37:48 +02:00
Quentin Fuxa
ab98c31f16 trim will happen before audio processor 2025-08-27 18:17:11 +02:00
Quentin Fuxa
f9c9c4188a optional dependencies removed, ask to direct alternative package installations 2025-08-27 18:15:32 +02:00
Quentin Fuxa
c21d2302e7 to 0.2.7 2024-08-24 19:28:00 +02:00
Quentin Fuxa
4ed62e181d when silences are detected, speaker correction is no more applied 2024-08-24 19:24:00 +02:00
Quentin Fuxa
52a755a08c indications on how to choose a model 2024-08-24 19:22:00 +02:00
Quentin Fuxa
9a8d3cbd90 improve diarization + silence handling 2024-08-24 19:20:00 +02:00
Quentin Fuxa
b101ce06bd several users share the same sortformer model instance 2024-08-24 19:18:00 +02:00
Quentin Fuxa
c83fd179a8 improves phase shift correction between transcription and diarization 2024-08-24 19:15:00 +02:00
Quentin Fuxa
5258305745 default diarization backend in now sortformer 2025-08-24 18:32:01 +02:00
Quentin Fuxa
ce781831ee punctuation is checked in audio-processor's result formatter 2025-08-24 18:32:01 +02:00
Quentin Fuxa
58297daf6d sortformer diar implementation v0.3 2025-08-24 18:32:01 +02:00
Quentin Fuxa
3393a08f7e sortformer diar implementation v0.2 2025-08-24 18:32:01 +02:00
Quentin Fuxa
5b2ddeccdb correct pip installation error in image build 2025-08-22 15:37:46 +02:00
Quentin Fuxa
26cc1072dd new dockerfile for cpu only. update dockerfile from cuda 12.8 to 12.9 2025-08-22 11:04:35 +02:00
Quentin Fuxa
12973711f6 0.2.6 2025-08-21 14:34:46 +02:00
Quentin Fuxa
909ac9dd41 speaker -1 are no more sent in websocket - no buffer when their is a silence 2025-08-21 14:09:02 +02:00
Quentin Fuxa
d94a07d417 default model is now base. default backend simulstreaming 2025-08-21 11:55:36 +02:00
Quentin Fuxa
b32dd8bfc4 Align backend and frontend time handling 2025-08-21 10:33:15 +02:00
Quentin Fuxa
9feb0e597b remove VACOnlineASRProcessor backend possibility 2025-08-20 20:57:43 +02:00
Quentin Fuxa
9dab84a573 update front 2025-08-20 20:15:38 +02:00
Quentin Fuxa
d089c7fce0 .html to .html + .css + .js 2025-08-20 20:00:31 +02:00
Quentin Fuxa
253a080df5 diart diarization handles pauses/silences thanks to offset 2025-08-19 21:12:55 +02:00
Quentin Fuxa
0c6e4b2aee sortformer diar implementation v0.1 2025-08-19 19:48:51 +02:00
Quentin Fuxa
e14bbde77d sortformer diar implementation v0 2025-08-19 17:02:55 +02:00
Quentin Fuxa
7496163467 rename diart backend 2025-08-19 15:02:27 +02:00
Quentin Fuxa
696a94d1ce 1rst sortformer backend implementation 2025-08-19 15:02:17 +02:00
Quentin Fuxa
2699b0974c Fix simulstreaming imports 2025-08-19 14:43:54 +02:00
Quentin Fuxa
90c0250ba4 update optional dependencies 2025-08-19 09:36:59 +02:00
Quentin Fuxa
eb96153ffd new vac parameters 2025-08-17 22:26:28 +02:00
Quentin Fuxa
47e3eb9b5b Update README.md 2025-08-17 09:55:03 +02:00
Quentin Fuxa
b8b07adeef --vac to --no-vac 2025-08-17 09:44:26 +02:00
Quentin Fuxa
d0e9e37ef6 simulstreaming: cumulative_time_offset to keep timestamps correct when audio > 30s 2025-08-17 09:33:47 +02:00
Quentin Fuxa
820f92d8cb audio_max_len to 30 -> 20, ffmpeg timeout 5 -> 20 2025-08-17 09:32:08 +02:00
Quentin Fuxa
e42523af84 VAC activated by default 2025-08-17 01:29:34 +02:00
Quentin Fuxa
e2184d5e06 better handle silences when VAC + correct offset issue with whisperstreaming backend 2025-08-17 01:27:07 +02:00
Quentin Fuxa
7fe0353260 vac model is loaded in TranscriptionEngine, and by default 2025-08-17 00:34:25 +02:00
Quentin Fuxa
0f2eba507e use with_offset to add no audio offset to tokens 2025-08-17 00:33:24 +02:00
Quentin Fuxa
55e08474f3 recycle backend in simulstreaming thanks to new remove hooks function 2025-08-16 23:06:16 +02:00
Quentin Fuxa
28bdc52e1d VAC before doing transcription and diarization. V0 2025-08-16 23:04:21 +02:00
Quentin Fuxa
e4221fa6c3 Merge branch 'main' of https://github.com/QuentinFuxa/whisper_streaming_web 2025-08-15 23:04:05 +02:00
Quentin Fuxa
1652db9a2d Use distinct backend models for simulstreaming and add --preloaded_model_count to preload them 2025-08-15 23:03:55 +02:00
Quentin Fuxa
601f17653a Update CONTRIBUTING.md 2025-08-13 21:59:32 +02:00
Quentin Fuxa
7718190fcd Update CONTRIBUTING.md 2025-08-13 21:59:00 +02:00
Quentin Fuxa
349c7dcb9e bump version ro 0.2.5 2025-08-13 10:04:31 +02:00
Quentin Fuxa
1c42b867cf Merge branch 'main' of https://github.com/QuentinFuxa/whisper_streaming_web 2025-08-13 10:04:04 +02:00
Quentin Fuxa
d4771e563e Increase END_SILENCE_DURATION to reduce false positives 2025-08-13 10:04:00 +02:00
Quentin Fuxa
b0a5fc0693 Merge pull request #155 from davidgumberg/keepawakescrolldown
frontend: Keep screen awake and scroll down when transcribing.
2025-08-13 10:02:52 +02:00
David Gumberg
3b96fb8776 frontend: Scroll down when appending transcription 2025-08-12 17:31:32 -07:00
David Gumberg
7f93c4b978 frontend: Don't let screen sleep when transcribing. 2025-08-12 17:30:57 -07:00
Quentin Fuxa
15c3df1cba warmup base whisper when using simulstreaming 2025-08-12 18:52:52 +02:00
Quentin Fuxa
7fb8e66c01 typo 2025-08-12 18:36:32 +02:00
Quentin Fuxa
728e1f1290 simulstreaming warmup is done for each instance of online, not for the backend 2025-08-12 18:35:04 +02:00
Quentin Fuxa
87b9ed6ecd nonspeech_prob from 1 to 0.5 2025-08-12 18:34:37 +02:00
Quentin Fuxa
38b4ebe8ba Handle 3 types of silences: Indicated by whisper, between tokens, and at the end of the input. Display them in the frontend 2025-08-11 17:56:57 +02:00
Quentin Fuxa
d098af3185 each SimulStreamingOnlineProcessor now contains PaddedAlignAttWhisper instance. SimulStreamingASR only contains loaded whisper model 2025-08-11 08:24:14 +02:00
Quentin Fuxa
4e56130a40 frontend supports dark theme 2025-08-11 08:22:23 +02:00
Quentin Fuxa
2bbdc70187 lags are now updated every 0.1s 2025-08-09 23:11:05 +02:00
Quentin Fuxa
b678a55f63 remove duplicate file 2025-08-09 23:10:34 +02:00
Quentin Fuxa
5491964e81 clean SimulStreamingOnlineProcessor initialization + audio processing 2025-08-09 20:16:27 +02:00
Quentin Fuxa
b05297a96d clean simulwhisper backend and online 2025-08-09 18:02:15 +02:00
Quentin Fuxa
197293e25e refactor(simulstreaming): extract backend + online module into separate files from whisper streaming 2025-08-08 18:07:51 +02:00
Quentin Fuxa
ba41c4ab56 Remove download_simulstreaming_backend 2025-08-08 18:06:40 +02:00
Quentin Fuxa
bda72b8bc0 setup.py to pyproject.toml. Remove <2.0.0 condition on numpy dep 2025-08-03 16:32:31 +02:00
Quentin Fuxa
bb6b9f4cb1 architecture diagram : available backends for whisper streaming & diarization 2025-08-03 12:25:36 +02:00
Quentin Fuxa
e40b5a3ea0 Update architecture diagram 2025-08-02 13:51:15 +02:00
Quentin Fuxa
4cfed6e98e in MultiHeadAttention and ResidualAttentionBlock include cache_id for compatibility with simulstreaming code 2025-08-02 13:16:58 +02:00
Quentin Fuxa
687e3dd5e2 update simulstreaming model.py to match the latest version of whisper sources 2025-08-02 13:16:10 +02:00
Quentin Fuxa
e4140cd299 Update Dockerfile to install build-essential and update PyTorch version 2025-08-02 13:08:43 +02:00
Quentin Fuxa
8e056cbdf2 Upgrade SimulStreaming Whisper core from version 20230918 to 20250625 2025-08-02 13:06:36 +02:00
Quentin Fuxa
9dcfb38967 Update README.md 2025-08-01 18:02:11 +02:00
Quentin Fuxa
47b9235d70 Update README.md 2025-08-01 17:55:40 +02:00
Quentin Fuxa
f3cd53a4db Update README.md 2025-08-01 16:53:22 +02:00
Quentin Fuxa
dbdb4ea66c Update README.md 2025-08-01 16:33:26 +02:00
Quentin Fuxa
00424d7ca3 latest version of simulstreaming 2025-07-31 16:44:23 +02:00
Quentin Fuxa
4b738d6f63 fix duplicate line 2025-07-31 16:29:35 +02:00
Quentin Fuxa
8a5e2adb1e simulstreaming: fixes token handling during warm-up phase 2025-07-31 16:25:34 +02:00
Quentin Fuxa
f85329e112 Merge branch 'main' of https://github.com/QuentinFuxa/whisper_streaming_web 2025-07-31 11:42:16 +02:00
Quentin Fuxa
46efbdf1d9 solves https://github.com/QuentinFuxa/WhisperLiveKit/issues/151 2025-07-31 11:42:06 +02:00
Quentin Fuxa
8885ade003 Merge pull request #153 from luisla-rivas/main
Fix README.md to view correctly Deployment Guide info
2025-07-31 07:10:35 +02:00
luisla-rivas
2564928d83 Fix README.md to view correctly Deployment Guide info 2025-07-30 14:11:19 +02:00
Quentin Fuxa
56114d3071 Remove end_attributed_speaker in diarization_online. handled in audio processor 2025-07-16 12:09:43 +02:00
Quentin Fuxa
5b9977c9af Enhanced use_punctuation_split for diarization. further improvements still needed 2025-07-16 12:06:17 +02:00
Quentin Fuxa
12a544164f Merge branch 'main' of https://github.com/QuentinFuxa/whisper_streaming_web 2025-07-16 12:05:01 +02:00
Quentin Fuxa
2ca1156b7e Merge pull request #147 from choomegan/diar_queue
Ensure diarization_queue receives only latest PCM chunk
2025-07-16 12:04:53 +02:00
Quentin Fuxa
3ad3683ca7 Refactor speaker assignment in DiartDiarization for clarity and punctuation awareness 2025-07-15 14:38:53 +02:00
Quentin Fuxa
1599bd87a0 work on punctuation_split 2025-07-15 12:04:54 +02:00
Quentin Fuxa
90623400a4 Remove automatic downloading of SimulStreaming dependencies on import failure 2025-07-15 12:04:17 +02:00
choomegan
64e44fb24f fix: logic of adding of pcm_array to diarization_queue 2025-07-15 15:33:41 +08:00
Quentin Fuxa
156b9a133f 0.2.2 2025-07-04 17:11:35 +02:00
82 changed files with 8144 additions and 2808 deletions

3
.gitignore vendored
View File

@@ -137,4 +137,5 @@ run_*.sh
test_*.py test_*.py
launch.json launch.json
.DS_Store .DS_Store
test/* test/*
nllb-200-distilled-600M-ctranslate2/*

View File

@@ -15,7 +15,7 @@ Thank you for considering contributing ! We appreciate your time and effort to h
## Opening Issues ## Opening Issues
If you encounter a problem with diart or want to suggest an improvement, please follow these guidelines when opening an issue: If you encounter a problem with WhisperLiveKit or want to suggest an improvement, please follow these guidelines when opening an issue:
- **Bug Reports:** - **Bug Reports:**
- Clearly describe the error. **Please indicate the parameters you use, especially the model(s)** - Clearly describe the error. **Please indicate the parameters you use, especially the model(s)**
@@ -43,4 +43,4 @@ We welcome and appreciate contributions! To ensure a smooth review process, plea
## Thank You ## Thank You
Your contributions make diart better for everyone. Thank you for your time and dedication! Your contributions make WhisperLiveKit better for everyone. Thank you for your time and dedication!

70
DEV_NOTES.md Normal file
View File

@@ -0,0 +1,70 @@
# 1. Simulstreaming: Decouple the encoder for faster inference
Simulstreaming encoder time (whisperlivekit/simul_whisper/simul_whisper.py l. 397) experimentations :
On macOS Apple Silicon M4 :
| Encoder | base.en | small |
|--------|---------|-------|
| WHISPER (no modification) | 0.35s | 1.09s |
| FASTER_WHISPER | 0.4s | 1.20s |
| MLX_WHISPER | 0.07s | 0.20s |
Memory saved by only loading encoder for optimized framework:
For tiny.en, mlx whisper:
Sizes MLX whisper:
Decoder weights: 59110771 bytes
Encoder weights: 15268874 bytes
# 2. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm
Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions.
## Problem Statement
- Input: `self.total_preds` with shape `(x, x, 4)` - predictions for 4 speakers
- Output: Constrained predictions with shape `(x, x, 2)` - predictions for 2 speakers
#
### Initial Setup
For each time step `i`, we have a ranking of 4 speaker predictions (1-4). When only 2 speakers are present, the model will have close predictions for the 2 active speaker positions.
Instead of `np.argmax(preds_np, axis=1)`, we take the top 2 predictions and build a dynamic 4→2 mapping that can evolve over time.
### Algorithm
```python
top_2_speakers = np.argsort(preds_np, axis=1)[:, -2:]
```
- `DS_a_{i}`: Top detected speaker for prediction i
- `DS_b_{i}`: Second detected speaker for prediction i
- `AS_{i}`: Attributed speaker for prediction i
- `GTS_A`: Ground truth speaker A
- `GTS_B`: Ground truth speaker B
- `DIST(a, b)`: Distance between detected speakers a and b
3. **Attribution Logic**
```
AS_0 ← A
AS_1 ← B
IF DIST(DS_a_0, DS_a_1) < DIST(DS_a_0, DS_a_2) AND
DIST(DS_a_0, DS_a_1) < DIST(DS_a_1, DS_a_2):
# Likely that DS_a_0 = DS_a_1 (same speaker)
AS_1 ← A
AS_2 ← B
ELIF DIST(DS_a_0, DS_a_2) < DIST(DS_a_0, DS_a_1) AND
DIST(DS_a_0, DS_a_2) < DIST(DS_a_1, DS_a_2):
AS_2 ← A
ELSE:
AS_2 ← B
to finish
```

View File

@@ -1,4 +1,4 @@
FROM nvidia/cuda:12.8.1-cudnn-runtime-ubuntu22.04 FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04
ENV DEBIAN_FRONTEND=noninteractive ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1 ENV PYTHONUNBUFFERED=1
@@ -9,46 +9,50 @@ ARG EXTRAS
ARG HF_PRECACHE_DIR ARG HF_PRECACHE_DIR
ARG HF_TKN_FILE ARG HF_TKN_FILE
# Install system dependencies
#RUN apt-get update && \
# apt-get install -y ffmpeg git && \
# apt-get clean && \
# rm -rf /var/lib/apt/lists/*
# 2) Install system dependencies + Python + pip
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y --no-install-recommends \ apt-get install -y --no-install-recommends \
python3 \ python3 \
python3-pip \ python3-pip \
python3-venv \
ffmpeg \ ffmpeg \
git && \ git \
build-essential \
python3-dev \
ca-certificates && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 RUN python3 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# timeout/retries for large torch wheels
RUN pip3 install --upgrade pip setuptools wheel && \
pip3 --disable-pip-version-check install --timeout=120 --retries=5 \
--index-url https://download.pytorch.org/whl/cu129 \
torch torchaudio \
|| (echo "Initial install failed — retrying with extended timeout..." && \
pip3 --disable-pip-version-check install --timeout=300 --retries=3 \
--index-url https://download.pytorch.org/whl/cu129 \
torch torchvision torchaudio)
COPY . . COPY . .
# Install WhisperLiveKit directly, allowing for optional dependencies # Install WhisperLiveKit directly, allowing for optional dependencies
# Note: For gates models, need to add your HF toke. See README.md
# for more details.
RUN if [ -n "$EXTRAS" ]; then \ RUN if [ -n "$EXTRAS" ]; then \
echo "Installing with extras: [$EXTRAS]"; \ echo "Installing with extras: [$EXTRAS]"; \
pip install --no-cache-dir .[$EXTRAS]; \ pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
else \ else \
echo "Installing base package only"; \ echo "Installing base package only"; \
pip install --no-cache-dir .; \ pip install --no-cache-dir whisperlivekit; \
fi fi
# Enable in-container caching for Hugging Face models by: # In-container caching for Hugging Face models by:
# Note: If running multiple containers, better to map a shared
# bucket.
#
# A) Make the cache directory persistent via an anonymous volume. # A) Make the cache directory persistent via an anonymous volume.
# Note: This only persists for a single, named container. This is # Note: This only persists for a single, named container. This is
# only for convenience at de/test stage. # only for convenience at de/test stage.
# For prod, it is better to use a named volume via host mount/k8s. # For prod, it is better to use a named volume via host mount/k8s.
VOLUME ["/root/.cache/huggingface/hub"] VOLUME ["/root/.cache/huggingface/hub"]
# or # or
# B) Conditionally copy a local pre-cache from the build context to the # B) Conditionally copy a local pre-cache from the build context to the
# container's cache via the HF_PRECACHE_DIR build-arg. # container's cache via the HF_PRECACHE_DIR build-arg.
@@ -63,8 +67,7 @@ RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
echo "No local Hugging Face cache specified, skipping copy"; \ echo "No local Hugging Face cache specified, skipping copy"; \
fi fi
# Conditionally copy a Hugging Face token if provided # Conditionally copy a Hugging Face token if provided. Useful for Diart backend (pyannote audio models)
RUN if [ -n "$HF_TKN_FILE" ]; then \ RUN if [ -n "$HF_TKN_FILE" ]; then \
echo "Copying Hugging Face token from $HF_TKN_FILE"; \ echo "Copying Hugging Face token from $HF_TKN_FILE"; \
mkdir -p /root/.cache/huggingface && \ mkdir -p /root/.cache/huggingface && \
@@ -72,11 +75,9 @@ RUN if [ -n "$HF_TKN_FILE" ]; then \
else \ else \
echo "No Hugging Face token file specified, skipping token setup"; \ echo "No Hugging Face token file specified, skipping token setup"; \
fi fi
# Expose port for the transcription server
EXPOSE 8000 EXPOSE 8000
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"] ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
# Default args CMD ["--model", "medium"]
CMD ["--model", "tiny.en"]

61
Dockerfile.cpu Normal file
View File

@@ -0,0 +1,61 @@
FROM python:3.13-slim
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
WORKDIR /app
ARG EXTRAS
ARG HF_PRECACHE_DIR
ARG HF_TKN_FILE
RUN apt-get update && \
apt-get install -y --no-install-recommends \
ffmpeg \
git \
build-essential \
python3-dev && \
rm -rf /var/lib/apt/lists/*
# Install CPU-only PyTorch
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
COPY . .
# Install WhisperLiveKit directly, allowing for optional dependencies
RUN if [ -n "$EXTRAS" ]; then \
echo "Installing with extras: [$EXTRAS]"; \
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
else \
echo "Installing base package only"; \
pip install --no-cache-dir whisperlivekit; \
fi
# Enable in-container caching for Hugging Face models
VOLUME ["/root/.cache/huggingface/hub"]
# Conditionally copy a local pre-cache from the build context
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
mkdir -p /root/.cache/huggingface/hub && \
cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
else \
echo "No local Hugging Face cache specified, skipping copy"; \
fi
# Conditionally copy a Hugging Face token if provided
RUN if [ -n "$HF_TKN_FILE" ]; then \
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
mkdir -p /root/.cache/huggingface && \
cp $HF_TKN_FILE /root/.cache/huggingface/token; \
else \
echo "No Hugging Face token file specified, skipping token setup"; \
fi
# Expose port for the transcription server
EXPOSE 8000
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
# Default args - you might want to use a smaller model for CPU
CMD ["--model", "tiny"]

322
README.md
View File

@@ -4,133 +4,97 @@
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730"> <img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
</p> </p>
<p align="center"><b>Real-time, Fully Local Speech-to-Text with Speaker Diarization</b></p> <p align="center"><b>Real-time, Fully Local Speech-to-Text with Speaker Identification</b></p>
<p align="center"> <p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a> <a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=downloads"></a> <a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a> <a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a> <a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p> </p>
## Overview
This project is based on [WhisperStreaming](https://github.com/ufal/whisper_streaming) and [SimulStreaming](https://github.com/ufal/SimulStreaming), allowing you to transcribe audio directly from your browser. WhisperLiveKit provides a complete backend solution for real-time speech transcription with a functional, simple and customizable frontend. Everything runs locally on your machine Real-time speech transcription directly to your browser, with a ready-to-use backend+server and a simple frontend.
#### Powered by Leading Research:
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription with AlignAtt policy
- [NLLB](https://arxiv.org/abs/2207.04672), ([distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2)) (2024) - Translation to more than 100 languages.
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription with LocalAgreement policy
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - Enterprise-grade Voice Activity Detection
> **Why not just run a simple Whisper model on every audio batch?** Whisper is designed for complete utterances, not real-time chunks. Processing small segments loses context, cuts off words mid-syllable, and produces poor transcription. WhisperLiveKit uses state-of-the-art simultaneous speech research for intelligent buffering and incremental processing.
### Architecture ### Architecture
WhisperLiveKit consists of three main components: <img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />
- **Frontend**: A basic html + JS interface that captures microphone audio and streams it to the backend via WebSockets. You can use and adapt the [provided template](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). *The backend supports multiple concurrent users. Voice Activity Detection reduces overhead when no voice is detected.*
- **Backend (Web Server)**: A FastAPI-based WebSocket server that receives streamed audio data, processes it in real time, and returns transcriptions to the frontend. This is where the WebSocket logic and routing live.
- **Core Backend (Library Logic)**: A server-agnostic core that handles audio processing, ASR, and diarization. It exposes reusable components that take in audio bytes and return transcriptions.
### Installation & Quick Start
### Key Features
- **Real-time Transcription** - Locally (or on-prem) convert speech to text instantly as you speak
- **Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart)
- **Multi-User Support** - Handle multiple users simultaneously with a single backend/server
- **Automatic Silence Chunking** Automatically chunks when no audio is detected to limit buffer size
- **Confidence Validation** Immediately validate high-confidence tokens for faster inference (WhisperStreaming only)
- **Buffering Preview** Displays unvalidated transcription segments (not compatible with SimulStreaming yet)
- **Punctuation-Based Speaker Splitting [BETA]** - Align speaker changes with natural sentence boundaries for more readable transcripts
- **SimulStreaming Backend** - [Dual-licensed](https://github.com/ufal/SimulStreaming#-licence-and-contributions) - Ultra-low latency transcription using SOTA AlignAtt policy.
## Quick Start
```bash ```bash
# Install the package
pip install whisperlivekit pip install whisperlivekit
# Start the transcription server
whisperlivekit-server --model tiny.en
# Open your browser at http://localhost:8000 to see the interface.
# Use -ssl-certfile public.crt --ssl-keyfile private.key parameters to use SSL
``` ```
> You can also clone the repo and `pip install -e .` for the latest version.
That's it! Start speaking and watch your words appear on screen.
## Installation > **FFmpeg is required** and must be installed before using WhisperLiveKit
>
> | OS | How to install |
> |-----------|-------------|
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
> | MacOS | `brew install ffmpeg` |
> | Windows | Download .exe from https://ffmpeg.org/download.html and add to PATH |
#### Quick Start
1. **Start the transcription server:**
```bash
whisperlivekit-server --model base --language en
```
2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time!
> - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
#### Optional Dependencies
| Optional | `pip install` |
|-----------|-------------|
| **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| **Apple Silicon optimized backend** | `mlx-whisper` |
| **NLLB Translation** | `huggingface_hub` & `transformers` |
| *[Not recommanded]* Speaker diarization with Diart | `diart` |
| *[Not recommanded]* Original Whisper backend | `whisper` |
| *[Not recommanded]* Improved timestamps backend | `whisper-timestamped` |
| OpenAI API backend | `openai` |
See **Parameters & Configuration** below on how to use them.
### Usage Examples
**Command-line Interface**: Start the transcription server with various options:
```bash ```bash
#Install from PyPI (Recommended) # Use better model than default (small)
pip install whisperlivekit whisperlivekit-server --model large-v3
#Install from Source # Advanced configuration with diarization and language
git clone https://github.com/QuentinFuxa/WhisperLiveKit whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
cd WhisperLiveKit
pip install -e .
```
### FFmpeg Dependency
```bash
# Ubuntu/Debian
sudo apt install ffmpeg
# macOS
brew install ffmpeg
# Windows
# Download from https://ffmpeg.org/download.html and add to PATH
```
### Optional Dependencies
```bash
# Voice Activity Controller (prevents hallucinations)
pip install torch
# Sentence-based buffer trimming
pip install mosestokenizer wtpsplit
pip install tokenize_uk # If you work with Ukrainian text
# Speaker diarization
pip install diart
# Alternative Whisper backends (default is faster-whisper)
pip install whisperlivekit[whisper] # Original Whisper
pip install whisperlivekit[whisper-timestamped] # Improved timestamps
pip install whisperlivekit[mlx-whisper] # Apple Silicon optimization
pip install whisperlivekit[openai] # OpenAI API
pip install whisperlivekit[simulstreaming]
```
### 🎹 Pyannote Models Setup
For diarization, you need access to pyannote.audio models:
1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
3. [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
4. Login with HuggingFace:
```bash
pip install huggingface_hub
huggingface-cli login
```
## 💻 Usage Examples
### Command-line Interface
Start the transcription server with various options:
```bash
# Basic server with English model
whisperlivekit-server --model tiny.en
# Advanced configuration with diarization
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language auto
# SimulStreaming backend for ultra-low latency
whisperlivekit-server --backend simulstreaming --model large-v3 --frame-threshold 20
``` ```
### Python API Integration (Backend) **Python API Integration**: Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes.
Check [basic_server.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a complete example.
```python ```python
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
@@ -145,14 +109,10 @@ transcription_engine = None
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
global transcription_engine global transcription_engine
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en") transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
# You can also load from command-line arguments using parse_args()
# args = parse_args()
# transcription_engine = TranscriptionEngine(**vars(args))
yield yield
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
# Process WebSocket connections
async def handle_websocket_results(websocket: WebSocket, results_generator): async def handle_websocket_results(websocket: WebSocket, results_generator):
async for response in results_generator: async for response in results_generator:
await websocket.send_json(response) await websocket.send_json(response)
@@ -172,44 +132,44 @@ async def websocket_endpoint(websocket: WebSocket):
await audio_processor.process_audio(message) await audio_processor.process_audio(message)
``` ```
### Frontend Implementation **Frontend Implementation**: The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it using `from whisperlivekit import get_inline_ui_html` & `page = get_inline_ui_html()`
The package includes a simple HTML/JavaScript implementation that you can adapt for your project. You can find it [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html), or load its content using `get_web_interface_html()` :
```python ## Parameters & Configuration
from whisperlivekit import get_web_interface_html
html_content = get_web_interface_html()
```
## ⚙️ Configuration Reference An important list of parameters can be changed. But what *should* you change?
- the `--model` size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
- the `--language`. List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English.
- the `--backend` ? you can switch to `--backend faster-whisper` if `simulstreaming` does not work correctly or if you prefer to avoid the dual-license requirements.
- `--warmup-file`, if you have one
- `--task translate`, to translate in english
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`, if you set up a server
- `--diarization`, if you want to use it.
- [BETA] `--target-language`, to translate using NLLB. [118 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/translation/mapping_languages.py). If you want to translate to english, you should rather use `--task translate`, since Whisper can do it directly.
WhisperLiveKit offers extensive configuration options: ### Full list of parameters :
| Parameter | Description | Default | | Parameter | Description | Default |
|-----------|-------------|---------| |-----------|-------------|---------|
| `--model` | Whisper model size. | `small` |
| `--language` | Source language code or `auto` | `auto` |
| `--task` | Set to `translate` to translate to english | `transcribe` |
| `--target-language` | [BETA] Translation language target. Ex: `fr` | `None` |
| `--backend` | Processing backend | `simulstreaming` |
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
| `--no-vac` | Disable Voice Activity Controller | `False` |
| `--no-vad` | Disable Voice Activity Detection | `False` |
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
| `--host` | Server host address | `localhost` | | `--host` | Server host address | `localhost` |
| `--port` | Server port | `8000` | | `--port` | Server port | `8000` |
| `--model` | Whisper model size. Caution : '.en' models do not work with Simulstreaming | `tiny` |
| `--language` | Source language code or `auto` | `en` |
| `--task` | `transcribe` or `translate` | `transcribe` |
| `--backend` | Processing backend | `faster-whisper` |
| `--diarization` | Enable speaker identification | `False` |
| `--punctuation-split` | Use punctuation to improve speaker boundaries | `True` |
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
| `--vac` | Use Voice Activity Controller | `False` |
| `--no-vad` | Disable Voice Activity Detection | `False` |
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` | | `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` | | `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` | | `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. | `False` |
| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
**SimulStreaming-specific Options:**
| Parameter | Description | Default | | SimulStreaming backend options | Description | Default |
|-----------|-------------|---------| |-----------|-------------|---------|
| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` | | `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` | | `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` | | `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
@@ -221,68 +181,91 @@ WhisperLiveKit offers extensive configuration options:
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` | | `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
| `--max-context-tokens` | Maximum context tokens | `None` | | `--max-context-tokens` | Maximum context tokens | `None` |
| `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` | | `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` |
| `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
## 🔧 How It Works
1. **Audio Capture**: Browser's MediaRecorder API captures audio in webm/opus format | WhisperStreaming backend options | Description | Default |
2. **Streaming**: Audio chunks are sent to the server via WebSocket |-----------|-------------|---------|
3. **Processing**: Server decodes audio with FFmpeg and streams into the model for transcription | `--confidence-validation` | Use confidence scores for faster validation | `False` |
4. **Real-time Output**: Partial transcriptions appear immediately in light gray (the 'aperçu') and finalized text appears in normal color | `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
## 🚀 Deployment Guide | Diarization options | Description | Default |
|-----------|-------------|---------|
| `--diarization` | Enable speaker identification | `False` |
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
| `--disable-punctuation-split` | Disable punctuation based splits. See #214 | `False` |
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
> For diarization using Diart, you need access to pyannote.audio models:
> 1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
> 2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
> 3. [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
>4. Login with HuggingFace: `huggingface-cli login`
### 🚀 Deployment Guide
To deploy WhisperLiveKit in production: To deploy WhisperLiveKit in production:
1. **Server Setup** (Backend): 1. **Server Setup**: Install production ASGI server & launch with multiple workers
```bash ```bash
# Install production ASGI server
pip install uvicorn gunicorn pip install uvicorn gunicorn
# Launch with multiple workers
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
``` ```
2. **Frontend Integration**: 2. **Frontend**: Host your customized version of the `html` example & ensure WebSocket connection points correctly
- Host your customized version of the example HTML/JS in your web application
- Ensure WebSocket connection points to your server's address
3. **Nginx Configuration** (recommended for production): 3. **Nginx Configuration** (recommended for production):
```nginx ```nginx
server { server {
listen 80; listen 80;
server_name your-domain.com; server_name your-domain.com;
location / {
location / { proxy_pass http://localhost:8000;
proxy_pass http://localhost:8000; proxy_set_header Upgrade $http_upgrade;
proxy_set_header Upgrade $http_upgrade; proxy_set_header Connection "upgrade";
proxy_set_header Connection "upgrade"; proxy_set_header Host $host;
proxy_set_header Host $host;
}} }}
```
4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in WebSocket URL 4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in WebSocket URL
### 🐋 Docker ## 🐋 Docker
A basic Dockerfile is provided which allows re-use of Python package installation options. ⚠️ For **large** models, ensure that your **docker runtime** has enough **memory** available. See below usage examples: Deploy the application easily using Docker with GPU or CPU support.
### Prerequisites
- Docker installed on your system
- For GPU support: NVIDIA Docker runtime installed
#### All defaults ### Quick Start
- Create a reusable image with only the basics and then run as a named container:
**With GPU acceleration (recommended):**
```bash ```bash
docker build -t whisperlivekit-defaults . docker build -t wlk .
docker create --gpus all --name whisperlivekit -p 8000:8000 whisperlivekit-defaults docker run --gpus all -p 8000:8000 --name wlk wlk
docker start -i whisperlivekit
``` ```
> **Note**: If you're running on a system without NVIDIA GPU support (such as Mac with Apple Silicon or any system without CUDA capabilities), you need to **remove the `--gpus all` flag** from the `docker create` command. Without GPU acceleration, transcription will use CPU only, which may be significantly slower. Consider using small models for better performance on CPU-only systems. **CPU only:**
```bash
docker build -f Dockerfile.cpu -t wlk .
docker run -p 8000:8000 --name wlk wlk
```
### Advanced Usage
**Custom configuration:**
```bash
# Example with custom model and language
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
```
### Memory Requirements
- **Large models**: Ensure your Docker runtime has sufficient memory allocated
#### Customization #### Customization
- Customize the container options:
```bash
docker build -t whisperlivekit-defaults .
docker create --gpus all --name whisperlivekit-base -p 8000:8000 whisperlivekit-defaults --model base
docker start -i whisperlivekit-base
```
- `--build-arg` Options: - `--build-arg` Options:
- `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options! - `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
@@ -291,10 +274,3 @@ docker start -i whisperlivekit-base
## 🔮 Use Cases ## 🔮 Use Cases
Capture discussions in real-time for meeting transcription, help hearing-impaired users follow conversations through accessibility tools, transcribe podcasts or videos automatically for content creation, transcribe support calls with speaker identification for customer service... Capture discussions in real-time for meeting transcription, help hearing-impaired users follow conversations through accessibility tools, transcribe podcasts or videos automatically for content creation, transcribe support calls with speaker identification for customer service...
## 🙏 Acknowledgments
We extend our gratitude to the original authors of:
| [Whisper Streaming](https://github.com/ufal/whisper_streaming) | [SimulStreaming](https://github.com/ufal/SimulStreaming) | [Diart](https://github.com/juanmc2005/diart) | [OpenAI Whisper](https://github.com/openai/whisper) |
| -------- | ------- | -------- | ------- |

258
ReadmeJP.md Normal file
View File

@@ -0,0 +1,258 @@
<h1 align="center">WhisperLiveKit</h1>
<p align="center">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
</p>
<p align="center"><b>話者識別機能付き、リアルタイム、完全ローカルな音声テキスト変換</b></p>
<p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p>
すぐに使えるバックエンド+サーバーとシンプルなフロントエンドで、リアルタイムの音声文字起こしをブラウザに直接提供します。✨
#### 主要な研究による技術:
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - AlignAttポリシーによる超低遅延文字起こし
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - LocalAgreementポリシーによる低遅延文字起こし
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - 高度なリアルタイム話者ダイアライゼーション
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - リアルタイム話者ダイアライゼーション
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - エンタープライズグレードの音声区間検出
> **なぜ各音声バッチで単純なWhisperモデルを実行しないのか** Whisperは完全な発話向けに設計されており、リアルタイムのチャンク向けではありません。小さなセグメントを処理するとコンテキストが失われ、単語が音節の途中で途切れ、質の悪い文字起こしになります。WhisperLiveKitは、インテリジェントなバッファリングとインクリメンタルな処理のために、最先端の同時音声研究を利用しています。
### アーキテクチャ
<img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />
*バックエンドは複数の同時ユーザーをサポートします。音声が検出されない場合、音声区間検出がオーバーヘッドを削減します。*
### インストールとクイックスタート
```bash
pip install whisperlivekit
```
> **FFmpegが必要です** WhisperLiveKitを使用する前にインストールする必要があります。
>
> | OS | インストール方法 |
> |-----------|-------------|
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
> | MacOS | `brew install ffmpeg` |
> | Windows | https://ffmpeg.org/download.html から.exeをダウンロードし、PATHに追加 |
#### クイックスタート
1. **文字起こしサーバーを起動します:**
```bash
whisperlivekit-server --model base --language en
```
2. **ブラウザを開き** `http://localhost:8000` にアクセスします。話し始めると、あなたの言葉がリアルタイムで表示されます!
> - 利用可能なすべての言語のリストについては、[tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) を参照してください。
> - HTTPSの要件については、**パラメータ**セクションのSSL設定オプションを参照してください。
#### オプションの依存関係
| オプション | `pip install` |
|-----------|-------------|
| **Sortformerによる話者ダイアライゼーション** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| Diartによる話者ダイアライゼーション | `diart` |
| オリジナルのWhisperバックエンド | `whisper` |
| タイムスタンプ改善バックエンド | `whisper-timestamped` |
| Apple Silicon最適化バックエンド | `mlx-whisper` |
| OpenAI APIバックエンド | `openai` |
それらの使用方法については、以下の**パラメータと設定**を参照してください。
### 使用例
**コマンドラインインターフェース**: 様々なオプションで文字起こしサーバーを起動します:
```bash
# デフォルト(small)より良いモデルを使用
whisperlivekit-server --model large-v3
# ダイアライゼーションと言語を指定した高度な設定
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
```
**Python API連携**: 関数やクラスの使用方法のより完全な例については、[basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) を確認してください。
```python
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from contextlib import asynccontextmanager
import asyncio
transcription_engine = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global transcription_engine
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
yield
app = FastAPI(lifespan=lifespan)
async def handle_websocket_results(websocket: WebSocket, results_generator):
async for response in results_generator:
await websocket.send_json(response)
await websocket.send_json({"type": "ready_to_stop"})
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
global transcription_engine
# 接続ごとに新しいAudioProcessorを作成し、共有エンジンを渡す
audio_processor = AudioProcessor(transcription_engine=transcription_engine)
results_generator = await audio_processor.create_tasks()
results_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
await websocket.accept()
while True:
message = await websocket.receive_bytes()
await audio_processor.process_audio(message)
```
**フロントエンド実装**: パッケージにはHTML/JavaScript実装が[ここ](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html)に含まれています。`from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()` を使ってインポートすることもできます。
## パラメータと設定
重要なパラメータのリストを変更できます。しかし、何を*変更すべき*でしょうか?
- `--model` サイズ。リストと推奨事項は[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
- `--language`。リストは[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py)。`auto`を使用すると、モデルは自動的に言語を検出しようとしますが、英語に偏る傾向があります。
- `--backend` `simulstreaming`が正しく動作しない場合や、デュアルライセンス要件を避けたい場合は`--backend faster-whisper`に切り替えることができます。
- `--warmup-file`、もしあれば
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`、サーバーをセットアップする場合
- `--diarization`、使用したい場合。
残りは推奨しません。しかし、以下があなたのオプションです。
| パラメータ | 説明 | デフォルト |
|-----------|-------------|---------|
| `--model` | Whisperモデルのサイズ。 | `small` |
| `--language` | ソース言語コードまたは`auto` | `auto` |
| `--task` | `transcribe`または`translate` | `transcribe` |
| `--backend` | 処理バックエンド | `simulstreaming` |
| `--min-chunk-size` | 最小音声チャンクサイズ(秒) | `1.0` |
| `--no-vac` | 音声アクティビティコントローラーを無効化 | `False` |
| `--no-vad` | 音声区間検出を無効化 | `False` |
| `--warmup-file` | モデルのウォームアップ用音声ファイルパス | `jfk.wav` |
| `--host` | サーバーホストアドレス | `localhost` |
| `--port` | サーバーポート | `8000` |
| `--ssl-certfile` | SSL証明書ファイルへのパスHTTPSサポート用 | `None` |
| `--ssl-keyfile` | SSL秘密鍵ファイルへのパスHTTPSサポート用 | `None` |
| WhisperStreamingバックエンドオプション | 説明 | デフォルト |
|-----------|-------------|---------|
| `--confidence-validation` | 高速な検証のために信頼スコアを使用 | `False` |
| `--buffer_trimming` | バッファトリミング戦略(`sentence`または`segment` | `segment` |
| SimulStreamingバックエンドオプション | 説明 | デフォルト |
|-----------|-------------|---------|
| `--frame-threshold` | AlignAttフレームしきい値低いほど速く、高いほど正確 | `25` |
| `--beams` | ビームサーチのビーム数1 = 貪欲デコーディング) | `1` |
| `--decoder` | デコーダタイプを強制(`beam`または`greedy` | `auto` |
| `--audio-max-len` | 最大音声バッファ長(秒) | `30.0` |
| `--audio-min-len` | 処理する最小音声長(秒) | `0.0` |
| `--cif-ckpt-path` | 単語境界検出用CIFモデルへのパス | `None` |
| `--never-fire` | 未完了の単語を決して切り捨てない | `False` |
| `--init-prompt` | モデルの初期プロンプト | `None` |
| `--static-init-prompt` | スクロールしない静的プロンプト | `None` |
| `--max-context-tokens` | 最大コンテキストトークン数 | `None` |
| `--model-path` | .ptモデルファイルへの直接パス。見つからない場合はダウンロード | `./base.pt` |
| `--preloaded-model-count` | オプション。メモリにプリロードするモデルの数(予想される同時ユーザー数まで設定) | `1` |
| ダイアライゼーションオプション | 説明 | デフォルト |
|-----------|-------------|---------|
| `--diarization` | 話者識別を有効化 | `False` |
| `--diarization-backend` | `diart`または`sortformer` | `sortformer` |
| `--segmentation-model` | DiartセグメンテーションモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Diart埋め込みモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
> Diartを使用したダイアライゼーションには、pyannote.audioモデルへのアクセスが必要です
> 1. `pyannote/segmentation`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation)
> 2. `pyannote/segmentation-3.0`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation-3.0)
> 3. `pyannote/embedding`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/embedding)
>4. HuggingFaceでログイン: `huggingface-cli login`
### 🚀 デプロイガイド
WhisperLiveKitを本番環境にデプロイするには
1. **サーバーセットアップ**: 本番用ASGIサーバーをインストールし、複数のワーカーで起動します
```bash
pip install uvicorn gunicorn
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
```
2. **フロントエンド**: カスタマイズした`html`のバージョンをホストし、WebSocket接続が正しくポイントするようにします
3. **Nginx設定** (本番環境で推奨):
```nginx
server {
listen 80;
server_name your-domain.com;
location / {
proxy_pass http://localhost:8000;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
}}
```
4. **HTTPSサポート**: 安全なデプロイメントのために、WebSocket URLで "ws://" の代わりに "wss://" を使用します
## 🐋 Docker
GPUまたはCPUサポート付きでDockerを使用してアプリケーションを簡単にデプロイします。
### 前提条件
- Dockerがシステムにインストールされていること
- GPUサポートの場合: NVIDIA Dockerランタイムがインストールされていること
### クイックスタート
**GPUアクセラレーション付き (推奨):**
```bash
docker build -t wlk .
docker run --gpus all -p 8000:8000 --name wlk wlk
```
**CPUのみ:**
```bash
docker build -f Dockerfile.cpu -t wlk .
docker run -p 8000:8000 --name wlk wlk
```
### 高度な使用法
**カスタム設定:**
```bash
# カスタムモデルと言語の例
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
```
### メモリ要件
- **大規模モデル**: Dockerランタイムに十分なメモリが割り当てられていることを確認してください
#### カスタマイズ
- `--build-arg` オプション:
- `EXTRAS="whisper-timestamped"` - イメージのインストールにエクストラを追加します(スペースなし)。必要なコンテナオプションを設定することを忘れないでください!
- `HF_PRECACHE_DIR="./.cache/"` - 初回起動を高速化するためにモデルキャッシュをプリロードします
- `HF_TKN_FILE="./token"` - ゲート付きモデルをダウンロードするためにHugging Face Hubアクセストークンを追加します
## 🔮 ユースケース
会議の文字起こしのためにリアルタイムで議論をキャプチャする、聴覚障害のあるユーザーがアクセシビリティツールを通じて会話を追うのを助ける、コンテンツ作成のためにポッドキャストやビデオを自動的に文字起こしする、カスタマーサービスのために話者識別付きでサポートコールを文字起こしする...

BIN
architecture.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 368 KiB

73
available_models.md Normal file
View File

@@ -0,0 +1,73 @@
# Available model sizes:
- tiny.en (english only)
- tiny
- base.en (english only)
- base
- small.en (english only)
- small
- medium.en (english only)
- medium
- large-v1
- large-v2
- large-v3
- large-v3-turbo
## How to choose?
### Language Support
- **English only**: Use `.en` models for better accuracy and faster processing when you only need English transcription
- **Multilingual**: Do not use `.en` models.
### Resource Constraints
- **Limited GPU/CPU or need for very low latency**: Choose `small` or smaller models
- `tiny`: Fastest, lowest resource usage, acceptable quality for simple audio
- `base`: Good balance of speed and accuracy for basic use cases
- `small`: Better accuracy while still being resource-efficient
- **Good resources available**: Use `large` models for best accuracy
- `large-v2`: Excellent accuracy, good multilingual support
- `large-v3`: Best overall accuracy and language support
### Special Cases
- **No translation needed**: Use `large-v3-turbo`
- Same transcription quality as `large-v2` but significantly faster
- **Important**: Does not translate correctly, only transcribes
### Model Comparison Table
| Model | Speed | Accuracy | Multilingual | Translation | Best Use Case |
|-------|--------|----------|--------------|-------------|---------------|
| tiny(.en) | Fastest | Basic | Yes/No | Yes/No | Real-time, low resources |
| base(.en) | Fast | Good | Yes/No | Yes/No | Balanced performance |
| small(.en) | Medium | Better | Yes/No | Yes/No | Quality on limited hardware |
| medium(.en) | Slow | High | Yes/No | Yes/No | High quality, moderate resources |
| large-v2 | Slowest | Excellent | Yes | Yes | Best overall quality |
| large-v3 | Slowest | Excellent | Yes | Yes | Maximum accuracy |
| large-v3-turbo | Fast | Excellent | Yes | No | Fast, high-quality transcription |
### Additional Considerations
**Model Performance**:
- Accuracy improves significantly from tiny to large models
- English-only models are ~10-15% more accurate for English audio
- Newer versions (v2, v3) have better punctuation and formatting
**Hardware Requirements**:
- `tiny`: ~1GB VRAM
- `base`: ~1GB VRAM
- `small`: ~2GB VRAM
- `medium`: ~5GB VRAM
- `large`: ~10GB VRAM
- `largev3turbo`: ~6GB VRAM
**Audio Quality Impact**:
- Clean, clear audio: smaller models may suffice
- Noisy, accented, or technical audio: larger models recommended
- Phone/low-quality audio: use at least `small` model
### Quick Decision Tree
1. English only? → Add `.en` to your choice
2. Limited resources or need speed? → `small` or smaller
3. Good hardware and want best quality? → `large-v3`
4. Need fast, high-quality transcription without translation? → `large-v3-turbo`
5. Need translation capabilities? → `large-v2` or `large-v3` (avoid turbo)

View File

@@ -0,0 +1,17 @@
## WhisperLiveKit Chrome Extension v0.1.0
Capture the audio of your current tab, transcribe or translate it using WhisperliveKit. **Still unstable**
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">
## Running this extension
1. Clone this repository.
2. Load this directory in Chrome as an unpacked extension.
## Devs:
- Impossible to capture audio from tabs if extension is a pannel, unfortunately:
- https://issues.chromium.org/issues/40926394
- https://groups.google.com/a/chromium.org/g/chromium-extensions/c/DET2SXCFnDg
- https://issues.chromium.org/issues/40916430
- To capture microphone in an extension, there are tricks: https://github.com/justinmann/sidepanel-audio-issue , https://medium.com/@lynchee.owo/how-to-enable-microphone-access-in-chrome-extensions-by-code-924295170080 (comments)

View File

@@ -0,0 +1,9 @@
chrome.runtime.onInstalled.addListener((details) => {
if (details.reason.search(/install/g) === -1) {
return
}
chrome.tabs.create({
url: chrome.runtime.getURL("welcome.html"),
active: true
})
})

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 376 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 823 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 KiB

View File

@@ -0,0 +1,669 @@
/* Theme, WebSocket, recording, rendering logic extracted from inline script and adapted for segmented theme control and WS caption */
let isRecording = false;
let websocket = null;
let recorder = null;
let chunkDuration = 100;
let websocketUrl = "ws://localhost:8000/asr";
let userClosing = false;
let wakeLock = null;
let startTime = null;
let timerInterval = null;
let audioContext = null;
let analyser = null;
let microphone = null;
let waveCanvas = document.getElementById("waveCanvas");
let waveCtx = waveCanvas.getContext("2d");
let animationFrame = null;
let waitingForStop = false;
let lastReceivedData = null;
let lastSignature = null;
let availableMicrophones = [];
let selectedMicrophoneId = null;
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
const statusText = document.getElementById("status");
const recordButton = document.getElementById("recordButton");
const chunkSelector = document.getElementById("chunkSelector");
const websocketInput = document.getElementById("websocketInput");
const websocketDefaultSpan = document.getElementById("wsDefaultUrl");
const linesTranscriptDiv = document.getElementById("linesTranscript");
const timerElement = document.querySelector(".timer");
const themeRadios = document.querySelectorAll('input[name="theme"]');
const microphoneSelect = document.getElementById("microphoneSelect");
const settingsToggle = document.getElementById("settingsToggle");
const settingsDiv = document.querySelector(".settings");
chrome.runtime.onInstalled.addListener((details) => {
if (details.reason.search(/install/g) === -1) {
return
}
chrome.tabs.create({
url: chrome.runtime.getURL("welcome.html"),
active: true
})
})
function getWaveStroke() {
const styles = getComputedStyle(document.documentElement);
const v = styles.getPropertyValue("--wave-stroke").trim();
return v || "#000";
}
let waveStroke = getWaveStroke();
function updateWaveStroke() {
waveStroke = getWaveStroke();
}
function applyTheme(pref) {
if (pref === "light") {
document.documentElement.setAttribute("data-theme", "light");
} else if (pref === "dark") {
document.documentElement.setAttribute("data-theme", "dark");
} else {
document.documentElement.removeAttribute("data-theme");
}
updateWaveStroke();
}
// Persisted theme preference
const savedThemePref = localStorage.getItem("themePreference") || "system";
applyTheme(savedThemePref);
if (themeRadios.length) {
themeRadios.forEach((r) => {
r.checked = r.value === savedThemePref;
r.addEventListener("change", () => {
if (r.checked) {
localStorage.setItem("themePreference", r.value);
applyTheme(r.value);
}
});
});
}
// React to OS theme changes when in "system" mode
const darkMq = window.matchMedia && window.matchMedia("(prefers-color-scheme: dark)");
const handleOsThemeChange = () => {
const pref = localStorage.getItem("themePreference") || "system";
if (pref === "system") updateWaveStroke();
};
if (darkMq && darkMq.addEventListener) {
darkMq.addEventListener("change", handleOsThemeChange);
} else if (darkMq && darkMq.addListener) {
// deprecated, but included for Safari compatibility
darkMq.addListener(handleOsThemeChange);
}
async function enumerateMicrophones() {
try {
const micPermission = await navigator.permissions.query({
name: "microphone",
});
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
stream.getTracks().forEach(track => track.stop());
const devices = await navigator.mediaDevices.enumerateDevices();
availableMicrophones = devices.filter(device => device.kind === 'audioinput');
populateMicrophoneSelect();
console.log(`Found ${availableMicrophones.length} microphone(s)`);
} catch (error) {
console.error('Error enumerating microphones:', error);
statusText.textContent = "Error accessing microphones. Please grant permission.";
}
}
function populateMicrophoneSelect() {
if (!microphoneSelect) return;
microphoneSelect.innerHTML = '<option value="">Default Microphone</option>';
availableMicrophones.forEach((device, index) => {
const option = document.createElement('option');
option.value = device.deviceId;
option.textContent = device.label || `Microphone ${index + 1}`;
microphoneSelect.appendChild(option);
});
const savedMicId = localStorage.getItem('selectedMicrophone');
if (savedMicId && availableMicrophones.some(mic => mic.deviceId === savedMicId)) {
microphoneSelect.value = savedMicId;
selectedMicrophoneId = savedMicId;
}
}
function handleMicrophoneChange() {
selectedMicrophoneId = microphoneSelect.value || null;
localStorage.setItem('selectedMicrophone', selectedMicrophoneId || '');
const selectedDevice = availableMicrophones.find(mic => mic.deviceId === selectedMicrophoneId);
const deviceName = selectedDevice ? selectedDevice.label : 'Default Microphone';
console.log(`Selected microphone: ${deviceName}`);
statusText.textContent = `Microphone changed to: ${deviceName}`;
if (isRecording) {
statusText.textContent = "Switching microphone... Please wait.";
stopRecording().then(() => {
setTimeout(() => {
toggleRecording();
}, 1000);
});
}
}
// Helpers
function fmt1(x) {
const n = Number(x);
return Number.isFinite(n) ? n.toFixed(1) : x;
}
// Default WebSocket URL computation
const host = window.location.hostname || "localhost";
const port = window.location.port;
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
const defaultWebSocketUrl = websocketUrl;
// Populate default caption and input
if (websocketDefaultSpan) websocketDefaultSpan.textContent = defaultWebSocketUrl;
websocketInput.value = defaultWebSocketUrl;
websocketUrl = defaultWebSocketUrl;
// Optional chunk selector (guard for presence)
if (chunkSelector) {
chunkSelector.addEventListener("change", () => {
chunkDuration = parseInt(chunkSelector.value);
});
}
// WebSocket input change handling
websocketInput.addEventListener("change", () => {
const urlValue = websocketInput.value.trim();
if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
return;
}
websocketUrl = urlValue;
statusText.textContent = "WebSocket URL updated. Ready to connect.";
});
function setupWebSocket() {
return new Promise((resolve, reject) => {
try {
websocket = new WebSocket(websocketUrl);
} catch (error) {
statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
reject(error);
return;
}
websocket.onopen = () => {
statusText.textContent = "Connected to server.";
resolve();
};
websocket.onclose = () => {
if (userClosing) {
if (waitingForStop) {
statusText.textContent = "Processing finalized or connection closed.";
if (lastReceivedData) {
renderLinesWithBuffer(
lastReceivedData.lines || [],
lastReceivedData.buffer_diarization || "",
lastReceivedData.buffer_transcription || "",
0,
0,
true
);
}
}
} else {
statusText.textContent = "Disconnected from the WebSocket server. (Check logs if model is loading.)";
if (isRecording) {
stopRecording();
}
}
isRecording = false;
waitingForStop = false;
userClosing = false;
lastReceivedData = null;
websocket = null;
updateUI();
};
websocket.onerror = () => {
statusText.textContent = "Error connecting to WebSocket.";
reject(new Error("Error connecting to WebSocket"));
};
websocket.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.type === "ready_to_stop") {
console.log("Ready to stop received, finalizing display and closing WebSocket.");
waitingForStop = false;
if (lastReceivedData) {
renderLinesWithBuffer(
lastReceivedData.lines || [],
lastReceivedData.buffer_diarization || "",
lastReceivedData.buffer_transcription || "",
0,
0,
true
);
}
statusText.textContent = "Finished processing audio! Ready to record again.";
recordButton.disabled = false;
if (websocket) {
websocket.close();
}
return;
}
lastReceivedData = data;
const {
lines = [],
buffer_transcription = "",
buffer_diarization = "",
remaining_time_transcription = 0,
remaining_time_diarization = 0,
status = "active_transcription",
} = data;
renderLinesWithBuffer(
lines,
buffer_diarization,
buffer_transcription,
remaining_time_diarization,
remaining_time_transcription,
false,
status
);
};
});
}
function renderLinesWithBuffer(
lines,
buffer_diarization,
buffer_transcription,
remaining_time_diarization,
remaining_time_transcription,
isFinalizing = false,
current_status = "active_transcription"
) {
if (current_status === "no_audio_detected") {
linesTranscriptDiv.innerHTML =
"<p style='text-align: center; color: var(--muted); margin-top: 20px;'><em>No audio detected...</em></p>";
return;
}
const showLoading = !isFinalizing && (lines || []).some((it) => it.speaker == 0);
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
const signature = JSON.stringify({
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end })),
buffer_transcription: buffer_transcription || "",
buffer_diarization: buffer_diarization || "",
status: current_status,
showLoading,
showTransLag,
showDiaLag,
isFinalizing: !!isFinalizing,
});
if (lastSignature === signature) {
const t = document.querySelector(".lag-transcription-value");
if (t) t.textContent = fmt1(remaining_time_transcription);
const d = document.querySelector(".lag-diarization-value");
if (d) d.textContent = fmt1(remaining_time_diarization);
const ld = document.querySelector(".loading-diarization-value");
if (ld) ld.textContent = fmt1(remaining_time_diarization);
return;
}
lastSignature = signature;
const linesHtml = (lines || [])
.map((item, idx) => {
let timeInfo = "";
if (item.start !== undefined && item.end !== undefined) {
timeInfo = ` ${item.start} - ${item.end}`;
}
let speakerLabel = "";
if (item.speaker === -2) {
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
} else if (item.speaker == 0 && !isFinalizing) {
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
remaining_time_diarization
)}</span> second(s) of audio are undergoing diarization</span></span>`;
} else if (item.speaker !== 0) {
speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
}
let currentLineText = item.text || "";
if (idx === lines.length - 1) {
if (!isFinalizing && item.speaker !== -2) {
if (remaining_time_transcription > 0) {
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(
remaining_time_transcription
)}</span>s</span></span>`;
}
if (buffer_diarization && remaining_time_diarization > 0) {
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(
remaining_time_diarization
)}</span>s</span></span>`;
}
}
if (buffer_diarization) {
if (isFinalizing) {
currentLineText +=
(currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
} else {
currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
}
}
if (buffer_transcription) {
if (isFinalizing) {
currentLineText +=
(currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") +
buffer_transcription.trim();
} else {
currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
}
}
}
return currentLineText.trim().length > 0 || speakerLabel.length > 0
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
: `<p>${speakerLabel}<br/></p>`;
})
.join("");
linesTranscriptDiv.innerHTML = linesHtml;
window.scrollTo({ top: document.body.scrollHeight, behavior: "smooth" });
}
function updateTimer() {
if (!startTime) return;
const elapsed = Math.floor((Date.now() - startTime) / 1000);
const minutes = Math.floor(elapsed / 60).toString().padStart(2, "0");
const seconds = (elapsed % 60).toString().padStart(2, "0");
timerElement.textContent = `${minutes}:${seconds}`;
}
function drawWaveform() {
if (!analyser) return;
const bufferLength = analyser.frequencyBinCount;
const dataArray = new Uint8Array(bufferLength);
analyser.getByteTimeDomainData(dataArray);
waveCtx.clearRect(
0,
0,
waveCanvas.width / (window.devicePixelRatio || 1),
waveCanvas.height / (window.devicePixelRatio || 1)
);
waveCtx.lineWidth = 1;
waveCtx.strokeStyle = waveStroke;
waveCtx.beginPath();
const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength;
let x = 0;
for (let i = 0; i < bufferLength; i++) {
const v = dataArray[i] / 128.0;
const y = (v * (waveCanvas.height / (window.devicePixelRatio || 1))) / 2;
if (i === 0) {
waveCtx.moveTo(x, y);
} else {
waveCtx.lineTo(x, y);
}
x += sliceWidth;
}
waveCtx.lineTo(
waveCanvas.width / (window.devicePixelRatio || 1),
(waveCanvas.height / (window.devicePixelRatio || 1)) / 2
);
waveCtx.stroke();
animationFrame = requestAnimationFrame(drawWaveform);
}
async function startRecording() {
try {
try {
wakeLock = await navigator.wakeLock.request("screen");
} catch (err) {
console.log("Error acquiring wake lock.");
}
let stream;
try {
// Try tab capture first
stream = await new Promise((resolve, reject) => {
chrome.tabCapture.capture({audio: true}, (s) => {
if (s) {
resolve(s);
} else {
reject(new Error('Tab capture failed or not available'));
}
});
});
statusText.textContent = "Using tab audio capture.";
} catch (tabError) {
console.log('Tab capture not available, falling back to microphone', tabError);
// Fallback to microphone
const audioConstraints = selectedMicrophoneId
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
: { audio: true };
stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
statusText.textContent = "Using microphone audio.";
}
audioContext = new (window.AudioContext || window.webkitAudioContext)();
analyser = audioContext.createAnalyser();
analyser.fftSize = 256;
microphone = audioContext.createMediaStreamSource(stream);
microphone.connect(analyser);
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
recorder.ondataavailable = (e) => {
if (websocket && websocket.readyState === WebSocket.OPEN) {
websocket.send(e.data);
}
};
recorder.start(chunkDuration);
startTime = Date.now();
timerInterval = setInterval(updateTimer, 1000);
drawWaveform();
isRecording = true;
updateUI();
} catch (err) {
if (window.location.hostname === "0.0.0.0") {
statusText.textContent =
"Error accessing audio input. Browsers may block audio access on 0.0.0.0. Try using localhost:8000 instead.";
} else {
statusText.textContent = "Error accessing audio input. Please check permissions.";
}
console.error(err);
}
}
async function stopRecording() {
if (wakeLock) {
try {
await wakeLock.release();
} catch (e) {
// ignore
}
wakeLock = null;
}
userClosing = true;
waitingForStop = true;
if (websocket && websocket.readyState === WebSocket.OPEN) {
const emptyBlob = new Blob([], { type: "audio/webm" });
websocket.send(emptyBlob);
statusText.textContent = "Recording stopped. Processing final audio...";
}
if (recorder) {
recorder.stop();
recorder = null;
}
if (microphone) {
microphone.disconnect();
microphone = null;
}
if (analyser) {
analyser = null;
}
if (audioContext && audioContext.state !== "closed") {
try {
await audioContext.close();
} catch (e) {
console.warn("Could not close audio context:", e);
}
audioContext = null;
}
if (animationFrame) {
cancelAnimationFrame(animationFrame);
animationFrame = null;
}
if (timerInterval) {
clearInterval(timerInterval);
timerInterval = null;
}
timerElement.textContent = "00:00";
startTime = null;
isRecording = false;
updateUI();
}
async function toggleRecording() {
if (!isRecording) {
if (waitingForStop) {
console.log("Waiting for stop, early return");
return;
}
console.log("Connecting to WebSocket");
try {
if (websocket && websocket.readyState === WebSocket.OPEN) {
await startRecording();
} else {
await setupWebSocket();
await startRecording();
}
} catch (err) {
statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
console.error(err);
}
} else {
console.log("Stopping recording");
stopRecording();
}
}
function updateUI() {
recordButton.classList.toggle("recording", isRecording);
recordButton.disabled = waitingForStop;
if (waitingForStop) {
if (statusText.textContent !== "Recording stopped. Processing final audio...") {
statusText.textContent = "Please wait for processing to complete...";
}
} else if (isRecording) {
statusText.textContent = "Recording...";
} else {
if (
statusText.textContent !== "Finished processing audio! Ready to record again." &&
statusText.textContent !== "Processing finalized or connection closed."
) {
statusText.textContent = "Click to start transcription";
}
}
if (!waitingForStop) {
recordButton.disabled = false;
}
}
recordButton.addEventListener("click", toggleRecording);
if (microphoneSelect) {
microphoneSelect.addEventListener("change", handleMicrophoneChange);
}
// Settings toggle functionality
settingsToggle.addEventListener("click", () => {
settingsDiv.classList.toggle("visible");
settingsToggle.classList.toggle("active");
});
document.addEventListener('DOMContentLoaded', async () => {
try {
await enumerateMicrophones();
} catch (error) {
console.log("Could not enumerate microphones on load:", error);
}
});
navigator.mediaDevices.addEventListener('devicechange', async () => {
console.log('Device change detected, re-enumerating microphones');
try {
await enumerateMicrophones();
} catch (error) {
console.log("Error re-enumerating microphones:", error);
}
});
async function run() {
const micPermission = await navigator.permissions.query({
name: "microphone",
});
document.getElementById(
"audioPermission"
).innerText = `MICROPHONE: ${micPermission.state}`;
if (micPermission.state !== "granted") {
chrome.tabs.create({ url: "welcome.html" });
}
const intervalId = setInterval(async () => {
const micPermission = await navigator.permissions.query({
name: "microphone",
});
if (micPermission.state === "granted") {
document.getElementById(
"audioPermission"
).innerText = `MICROPHONE: ${micPermission.state}`;
clearInterval(intervalId);
}
}, 100);
}
void run();

View File

@@ -0,0 +1,37 @@
{
"manifest_version": 3,
"name": "WhisperLiveKit Tab Capture",
"version": "1.0",
"description": "Capture and transcribe audio from browser tabs using WhisperLiveKit.",
"background": {
"service_worker": "background.js"
},
"icons": {
"16": "icons/icon16.png",
"32": "icons/icon32.png",
"48": "icons/icon48.png",
"128": "icons/icon128.png"
},
"action": {
"default_title": "WhisperLiveKit Tab Capture",
"default_popup": "popup.html"
},
"permissions": [
"scripting",
"tabCapture",
"offscreen",
"activeTab",
"storage"
],
"web_accessible_resources": [
{
"resources": [
"requestPermissions.html",
"requestPermissions.js"
],
"matches": [
"<all_urls>"
]
}
]
}

View File

@@ -0,0 +1,78 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>WhisperLiveKit</title>
<link rel="stylesheet" href="/web/live_transcription.css" />
</head>
<body>
<div class="settings-container">
<button id="recordButton">
<div class="shape-container">
<div class="shape"></div>
</div>
<div class="recording-info">
<div class="wave-container">
<canvas id="waveCanvas"></canvas>
</div>
<div class="timer">00:00</div>
</div>
</button>
<button id="settingsToggle" class="settings-toggle" title="Show/hide settings">
<img src="/web/src/settings.svg" alt="Settings" />
</button>
<div class="settings">
<div class="field">
<label for="websocketInput">Websocket URL</label>
<input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
</div>
<div class="field">
<label id="microphoneSelectLabel" for="microphoneSelect">Select Microphone</label>
<select id="microphoneSelect">
<option value="">Default Microphone</option>
</select>
<div id="audioPermission"></div>
</div>
<div class="theme-selector-container">
<div class="segmented" role="radiogroup" aria-label="Theme selector">
<input type="radio" id="theme-system" name="theme" value="system" />
<label for="theme-system" title="System">
<img src="/web/src/system_mode.svg" alt="" />
<!-- <span>System</span> -->
</label>
<input type="radio" id="theme-light" name="theme" value="light" />
<label for="theme-light" title="Light">
<img src="/web/src/light_mode.svg" alt="" />
<!-- <span>Light</span> -->
</label>
<input type="radio" id="theme-dark" name="theme" value="dark" />
<label for="theme-dark" title="Dark">
<img src="/web/src/dark_mode.svg" alt="" />
<!-- <span>Dark</span> -->
</label>
</div>
</div>
</div>
</div>
<p id="status"></p>
<div id="linesTranscript"></div>
<script src="live_transcription.js"></script>
</body>
</html>

View File

@@ -0,0 +1,12 @@
<!DOCTYPE html>
<html>
<head>
<title>Request Permissions</title>
<script src="requestPermissions.js"></script>
</head>
<body>
This page exists to workaround an issue with Chrome that blocks permission
requests from chrome extensions
<button id="requestMicrophone">Request Microphone</button>
</body>
</html>

View File

@@ -0,0 +1,17 @@
/**
* Requests user permission for microphone access.
* @returns {Promise<void>} A Promise that resolves when permission is granted or rejects with an error.
*/
async function getUserPermission() {
console.log("Getting user permission for microphone access...");
await navigator.mediaDevices.getUserMedia({ audio: true });
const micPermission = await navigator.permissions.query({
name: "microphone",
});
if (micPermission.state == "granted") {
window.close();
}
}
// Call the function to request microphone permission
getUserPermission();

View File

@@ -0,0 +1,29 @@
console.log("sidepanel.js");
async function run() {
const micPermission = await navigator.permissions.query({
name: "microphone",
});
document.getElementById(
"audioPermission"
).innerText = `MICROPHONE: ${micPermission.state}`;
if (micPermission.state !== "granted") {
chrome.tabs.create({ url: "requestPermissions.html" });
}
const intervalId = setInterval(async () => {
const micPermission = await navigator.permissions.query({
name: "microphone",
});
if (micPermission.state === "granted") {
document.getElementById(
"audioPermission"
).innerText = `MICROPHONE: ${micPermission.state}`;
clearInterval(intervalId);
}
}, 100);
}
void run();

View File

@@ -0,0 +1,539 @@
:root {
--bg: #ffffff;
--text: #111111;
--muted: #666666;
--border: #e5e5e5;
--chip-bg: rgba(0, 0, 0, 0.04);
--chip-text: #000000;
--spinner-border: #8d8d8d5c;
--spinner-top: #b0b0b0;
--silence-bg: #f3f3f3;
--loading-bg: rgba(255, 77, 77, 0.06);
--button-bg: #ffffff;
--button-border: #e9e9e9;
--wave-stroke: #000000;
--label-dia-text: #868686;
--label-trans-text: #111111;
}
@media (prefers-color-scheme: dark) {
:root:not([data-theme="light"]) {
--bg: #0b0b0b;
--text: #e6e6e6;
--muted: #9aa0a6;
--border: #333333;
--chip-bg: rgba(255, 255, 255, 0.08);
--chip-text: #e6e6e6;
--spinner-border: #555555;
--spinner-top: #dddddd;
--silence-bg: #1a1a1a;
--loading-bg: rgba(255, 77, 77, 0.12);
--button-bg: #111111;
--button-border: #333333;
--wave-stroke: #e6e6e6;
--label-dia-text: #b3b3b3;
--label-trans-text: #ffffff;
}
}
:root[data-theme="dark"] {
--bg: #0b0b0b;
--text: #e6e6e6;
--muted: #9aa0a6;
--border: #333333;
--chip-bg: rgba(255, 255, 255, 0.08);
--chip-text: #e6e6e6;
--spinner-border: #555555;
--spinner-top: #dddddd;
--silence-bg: #1a1a1a;
--loading-bg: rgba(255, 77, 77, 0.12);
--button-bg: #111111;
--button-border: #333333;
--wave-stroke: #e6e6e6;
--label-dia-text: #b3b3b3;
--label-trans-text: #ffffff;
}
:root[data-theme="light"] {
--bg: #ffffff;
--text: #111111;
--muted: #666666;
--border: #e5e5e5;
--chip-bg: rgba(0, 0, 0, 0.04);
--chip-text: #000000;
--spinner-border: #8d8d8d5c;
--spinner-top: #b0b0b0;
--silence-bg: #f3f3f3;
--loading-bg: rgba(255, 77, 77, 0.06);
--button-bg: #ffffff;
--button-border: #e9e9e9;
--wave-stroke: #000000;
--label-dia-text: #868686;
--label-trans-text: #111111;
}
body {
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
margin: 20px;
text-align: center;
background-color: var(--bg);
color: var(--text);
}
.settings-toggle {
margin-top: 4px;
width: 40px;
height: 40px;
border: none;
border-radius: 50%;
background-color: var(--button-bg);
cursor: pointer;
transition: all 0.3s ease;
/* border: 1px solid var(--button-border); */
display: flex;
align-items: center;
justify-content: center;
position: relative;
}
.settings-toggle:hover {
background-color: var(--chip-bg);
}
.settings-toggle img {
width: 24px;
height: 24px;
opacity: 0.7;
transition: opacity 0.2s ease, transform 0.3s ease;
}
.settings-toggle:hover img {
opacity: 1;
}
.settings-toggle.active img {
transform: rotate(80deg);
}
/* Record button */
#recordButton {
width: 50px;
height: 50px;
border: none;
border-radius: 50%;
background-color: var(--button-bg);
cursor: pointer;
transition: all 0.3s ease;
border: 1px solid var(--button-border);
display: flex;
align-items: center;
justify-content: center;
position: relative;
}
#recordButton.recording {
width: 180px;
border-radius: 40px;
justify-content: flex-start;
padding-left: 20px;
}
#recordButton:active {
transform: scale(0.95);
}
.shape-container {
width: 25px;
height: 25px;
display: flex;
align-items: center;
justify-content: center;
flex-shrink: 0;
}
.shape {
width: 25px;
height: 25px;
background-color: rgb(209, 61, 53);
border-radius: 50%;
transition: all 0.3s ease;
}
#recordButton:disabled .shape {
background-color: #6e6d6d;
}
#recordButton.recording .shape {
border-radius: 5px;
width: 25px;
height: 25px;
}
/* Recording elements */
.recording-info {
display: none;
align-items: center;
margin-left: 15px;
flex-grow: 1;
}
#recordButton.recording .recording-info {
display: flex;
}
.wave-container {
width: 60px;
height: 30px;
position: relative;
display: flex;
align-items: center;
justify-content: center;
}
#waveCanvas {
width: 100%;
height: 100%;
}
.timer {
font-size: 14px;
font-weight: 500;
color: var(--text);
margin-left: 10px;
}
#status {
margin-top: 20px;
font-size: 16px;
color: var(--text);
}
/* Settings */
.settings-container {
display: flex;
justify-content: center;
align-items: flex-start;
gap: 15px;
margin-top: 20px;
flex-wrap: wrap;
}
.settings {
display: none;
flex-wrap: wrap;
align-items: flex-start;
gap: 12px;
transition: opacity 0.3s ease;
}
.settings.visible {
display: flex;
}
.field {
display: flex;
flex-direction: column;
align-items: flex-start;
gap: 3px;
}
#chunkSelector,
#websocketInput,
#themeSelector,
#microphoneSelect {
font-size: 16px;
padding: 5px 8px;
border-radius: 8px;
border: 1px solid var(--border);
background-color: var(--button-bg);
color: var(--text);
max-height: 30px;
}
#microphoneSelect {
width: 100%;
max-width: 190px;
min-width: 120px;
}
#chunkSelector:focus,
#websocketInput:focus,
#themeSelector:focus,
#microphoneSelect:focus {
outline: none;
border-color: #007bff;
box-shadow: 0 0 0 3px rgba(0, 123, 255, 0.15);
}
label {
font-size: 13px;
color: var(--muted);
}
.ws-default {
font-size: 12px;
color: var(--muted);
}
/* Segmented pill control for Theme */
.segmented {
display: inline-flex;
align-items: stretch;
border: 1px solid var(--button-border);
background-color: var(--button-bg);
border-radius: 999px;
overflow: hidden;
}
.segmented input[type="radio"] {
position: absolute;
opacity: 0;
pointer-events: none;
}
.theme-selector-container {
display: flex;
align-items: center;
margin-top: 17px;
}
.segmented label {
display: inline-flex;
align-items: center;
gap: 6px;
padding: 6px 12px;
font-size: 14px;
color: var(--muted);
cursor: pointer;
user-select: none;
transition: background-color 0.2s ease, color 0.2s ease;
}
.segmented label span {
display: none;
}
.segmented label:hover span {
display: inline;
}
.segmented label:hover {
background-color: var(--chip-bg);
}
.segmented img {
width: 16px;
height: 16px;
}
.segmented input[type="radio"]:checked + label {
background-color: var(--chip-bg);
color: var(--text);
}
.segmented input[type="radio"]:focus-visible + label,
.segmented input[type="radio"]:focus + label {
outline: 2px solid #007bff;
outline-offset: 2px;
border-radius: 999px;
}
/* Transcript area */
#linesTranscript {
margin: 20px auto;
max-width: 700px;
text-align: left;
font-size: 16px;
}
#linesTranscript p {
margin: 0px 0;
}
#linesTranscript strong {
color: var(--text);
}
#speaker {
border: 1px solid var(--border);
border-radius: 100px;
padding: 2px 10px;
font-size: 14px;
margin-bottom: 0px;
}
.label_diarization {
background-color: var(--chip-bg);
border-radius: 8px 8px 8px 8px;
padding: 2px 10px;
margin-left: 10px;
display: inline-block;
white-space: nowrap;
font-size: 14px;
margin-bottom: 0px;
color: var(--label-dia-text);
}
.label_transcription {
background-color: var(--chip-bg);
border-radius: 8px 8px 8px 8px;
padding: 2px 10px;
display: inline-block;
white-space: nowrap;
margin-left: 10px;
font-size: 14px;
margin-bottom: 0px;
color: var(--label-trans-text);
}
#timeInfo {
color: var(--muted);
margin-left: 10px;
}
.textcontent {
font-size: 16px;
padding-left: 10px;
margin-bottom: 10px;
margin-top: 1px;
padding-top: 5px;
border-radius: 0px 0px 0px 10px;
}
.buffer_diarization {
color: var(--label-dia-text);
margin-left: 4px;
}
.buffer_transcription {
color: #7474748c;
margin-left: 4px;
}
.spinner {
display: inline-block;
width: 8px;
height: 8px;
border: 2px solid var(--spinner-border);
border-top: 2px solid var(--spinner-top);
border-radius: 50%;
animation: spin 0.7s linear infinite;
vertical-align: middle;
margin-bottom: 2px;
margin-right: 5px;
}
@keyframes spin {
to {
transform: rotate(360deg);
}
}
.silence {
color: var(--muted);
background-color: var(--silence-bg);
font-size: 13px;
border-radius: 30px;
padding: 2px 10px;
}
.loading {
color: var(--muted);
background-color: var(--loading-bg);
border-radius: 8px 8px 8px 0px;
padding: 2px 10px;
font-size: 14px;
margin-bottom: 0px;
}
/* for smaller screens */
/* @media (max-width: 450px) {
.settings-container {
flex-direction: column;
gap: 10px;
align-items: center;
}
.settings {
justify-content: center;
gap: 8px;
width: 100%;
}
.field {
align-items: center;
width: 100%;
}
#websocketInput,
#microphoneSelect {
min-width: 200px;
max-width: 100%;
}
.theme-selector-container {
margin-top: 10px;
}
} */
/* @media (max-width: 768px) and (min-width: 451px) {
.settings-container {
gap: 10px;
}
.settings {
gap: 8px;
}
#websocketInput,
#microphoneSelect {
min-width: 150px;
max-width: 300px;
}
} */
/* @media (max-width: 480px) {
body {
margin: 10px;
}
.settings-toggle {
width: 35px;
height: 35px;
}
.settings-toggle img {
width: 20px;
height: 20px;
}
.settings {
flex-direction: column;
align-items: center;
gap: 6px;
}
#websocketInput,
#microphoneSelect {
max-width: 400px;
}
.segmented label {
padding: 4px 8px;
font-size: 12px;
}
.segmented img {
width: 14px;
height: 14px;
}
} */
html
{
width: 400px; /* max: 800px */
height: 600px; /* max: 600px */
border-radius: 10px;
}

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-120q-151 0-255.5-104.5T120-480q0-138 90-239.5T440-838q13-2 23 3.5t16 14.5q6 9 6.5 21t-7.5 23q-17 26-25.5 55t-8.5 61q0 90 63 153t153 63q31 0 61.5-9t54.5-25q11-7 22.5-6.5T819-479q10 5 15.5 15t3.5 24q-14 138-117.5 229T480-120Zm0-80q88 0 158-48.5T740-375q-20 5-40 8t-40 3q-123 0-209.5-86.5T364-660q0-20 3-40t8-40q-78 32-126.5 102T200-480q0 116 82 198t198 82Zm-10-270Z"/></svg>

After

Width:  |  Height:  |  Size: 493 B

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-360q50 0 85-35t35-85q0-50-35-85t-85-35q-50 0-85 35t-35 85q0 50 35 85t85 35Zm0 80q-83 0-141.5-58.5T280-480q0-83 58.5-141.5T480-680q83 0 141.5 58.5T680-480q0 83-58.5 141.5T480-280ZM80-440q-17 0-28.5-11.5T40-480q0-17 11.5-28.5T80-520h80q17 0 28.5 11.5T200-480q0 17-11.5 28.5T160-440H80Zm720 0q-17 0-28.5-11.5T760-480q0-17 11.5-28.5T800-520h80q17 0 28.5 11.5T920-480q0 17-11.5 28.5T880-440h-80ZM480-760q-17 0-28.5-11.5T440-800v-80q0-17 11.5-28.5T480-920q17 0 28.5 11.5T520-880v80q0 17-11.5 28.5T480-760Zm0 720q-17 0-28.5-11.5T440-80v-80q0-17 11.5-28.5T480-200q17 0 28.5 11.5T520-160v80q0 17-11.5 28.5T480-40ZM226-678l-43-42q-12-11-11.5-28t11.5-29q12-12 29-12t28 12l42 43q11 12 11 28t-11 28q-11 12-27.5 11.5T226-678Zm494 495-42-43q-11-12-11-28.5t11-27.5q11-12 27.5-11.5T734-282l43 42q12 11 11.5 28T777-183q-12 12-29 12t-28-12Zm-42-495q-12-11-11.5-27.5T678-734l42-43q11-12 28-11.5t29 11.5q12 12 12 29t-12 28l-43 42q-12 11-28 11t-28-11ZM183-183q-12-12-12-29t12-28l43-42q12-11 28.5-11t27.5 11q12 11 11.5 27.5T282-226l-42 43q-11 12-28 11.5T183-183Zm297-297Z"/></svg>

After

Width:  |  Height:  |  Size: 1.2 KiB

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M433-80q-27 0-46.5-18T363-142l-9-66q-13-5-24.5-12T307-235l-62 26q-25 11-50 2t-39-32l-47-82q-14-23-8-49t27-43l53-40q-1-7-1-13.5v-27q0-6.5 1-13.5l-53-40q-21-17-27-43t8-49l47-82q14-23 39-32t50 2l62 26q11-8 23-15t24-12l9-66q4-26 23.5-44t46.5-18h94q27 0 46.5 18t23.5 44l9 66q13 5 24.5 12t22.5 15l62-26q25-11 50-2t39 32l47 82q14 23 8 49t-27 43l-53 40q1 7 1 13.5v27q0 6.5-2 13.5l53 40q21 17 27 43t-8 49l-48 82q-14 23-39 32t-50-2l-60-26q-11 8-23 15t-24 12l-9 66q-4 26-23.5 44T527-80h-94Zm7-80h79l14-106q31-8 57.5-23.5T639-327l99 41 39-68-86-65q5-14 7-29.5t2-31.5q0-16-2-31.5t-7-29.5l86-65-39-68-99 42q-22-23-48.5-38.5T533-694l-13-106h-79l-14 106q-31 8-57.5 23.5T321-633l-99-41-39 68 86 64q-5 15-7 30t-2 32q0 16 2 31t7 30l-86 65 39 68 99-42q22 23 48.5 38.5T427-266l13 106Zm42-180q58 0 99-41t41-99q0-58-41-99t-99-41q-59 0-99.5 41T342-480q0 58 40.5 99t99.5 41Zm-2-140Z"/></svg>

After

Width:  |  Height:  |  Size: 982 B

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M396-396q-32-32-58.5-67T289-537q-5 14-6.5 28.5T281-480q0 83 58 141t141 58q14 0 28.5-2t28.5-6q-39-22-74-48.5T396-396Zm85 196q-56 0-107-21t-91-61q-40-40-61-91t-21-107q0-51 17-97.5t50-84.5q13-14 32-9.5t27 24.5q21 55 52.5 104t73.5 91q42 42 91 73.5T648-326q20 8 24.5 27t-9.5 32q-38 33-84.5 50T481-200Zm223-192q-16-5-23-20.5t-4-32.5q9-48-6-94.5T621-621q-35-35-80.5-49.5T448-677q-17 3-32-4t-21-23q-6-16 1.5-31t23.5-19q69-15 138 4.5T679-678q51 51 71 120t5 138q-4 17-19 25t-32 3ZM480-840q-17 0-28.5-11.5T440-880v-40q0-17 11.5-28.5T480-960q17 0 28.5 11.5T520-920v40q0 17-11.5 28.5T480-840Zm0 840q-17 0-28.5-11.5T440-40v-40q0-17 11.5-28.5T480-120q17 0 28.5 11.5T520-80v40q0 17-11.5 28.5T480 0Zm255-734q-12-12-12-28.5t12-28.5l28-28q11-11 27.5-11t28.5 11q12 12 12 28.5T819-762l-28 28q-12 12-28 12t-28-12ZM141-141q-12-12-12-28.5t12-28.5l28-28q12-12 28-12t28 12q12 12 12 28.5T225-169l-28 28q-11 11-27.5 11T141-141Zm739-299q-17 0-28.5-11.5T840-480q0-17 11.5-28.5T880-520h40q17 0 28.5 11.5T960-480q0 17-11.5 28.5T920-440h-40Zm-840 0q-17 0-28.5-11.5T0-480q0-17 11.5-28.5T40-520h40q17 0 28.5 11.5T120-480q0 17-11.5 28.5T80-440H40Zm779 299q-12 12-28.5 12T762-141l-28-28q-12-12-12-28t12-28q12-12 28.5-12t28.5 12l28 28q11 11 11 27.5T819-141ZM226-735q-12 12-28.5 12T169-735l-28-28q-11-11-11-27.5t11-28.5q12-12 28.5-12t28.5 12l28 28q12 12 12 28t-12 28Zm170 339Z"/></svg>

After

Width:  |  Height:  |  Size: 1.4 KiB

View File

@@ -0,0 +1,12 @@
<!DOCTYPE html>
<html>
<head>
<title>Welcome</title>
<script src="welcome.js"></script>
</head>
<body>
This page exists to workaround an issue with Chrome that blocks permission
requests from chrome extensions
<!-- <button id="requestMicrophone">Request Microphone</button> -->
</body>
</html>

BIN
demo.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 438 KiB

After

Width:  |  Height:  |  Size: 449 KiB

57
pyproject.toml Normal file
View File

@@ -0,0 +1,57 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "whisperlivekit"
version = "0.2.9"
description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md"
authors = [
{ name = "Quentin Fuxa" }
]
license = { file = "LICENSE" }
requires-python = ">=3.9"
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: 3.15",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Multimedia :: Sound/Audio :: Speech"
]
dependencies = [
"fastapi",
"librosa",
"soundfile",
"faster-whisper",
"uvicorn",
"websockets",
"torchaudio>=2.0.0",
"torch>=2.0.0",
"tqdm",
"tiktoken",
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
]
[project.optional-dependencies]
sentence = ["mosestokenizer", "wtpsplit"]
[project.urls]
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
[project.scripts]
whisperlivekit-server = "whisperlivekit.basic_server:main"
[tool.setuptools]
packages = ["whisperlivekit", "whisperlivekit.diarization", "whisperlivekit.simul_whisper", "whisperlivekit.simul_whisper.whisper", "whisperlivekit.simul_whisper.whisper.assets", "whisperlivekit.simul_whisper.whisper.normalizers", "whisperlivekit.web", "whisperlivekit.whisper_streaming_custom"]
[tool.setuptools.package-data]
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
"whisperlivekit.simul_whisper.whisper.assets" = ["*.tiktoken", "*.npz"]

View File

@@ -1,55 +0,0 @@
from setuptools import setup, find_packages
setup(
name="whisperlivekit",
version="0.2.1",
description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
author="Quentin Fuxa",
url="https://github.com/QuentinFuxa/WhisperLiveKit",
packages=find_packages(),
install_requires=[
"fastapi",
"librosa",
"soundfile",
"faster-whisper",
"uvicorn",
"websockets",
],
extras_require={
"diarization": ["diart"],
"vac": ["torch"],
"sentence": ["mosestokenizer", "wtpsplit"],
"whisper": ["whisper"],
"whisper-timestamped": ["whisper-timestamped"],
"mlx-whisper": ["mlx-whisper"],
"openai": ["openai"],
"simulstreaming": [
"torch",
"tqdm",
"tiktoken",
"numpy<2.0.0",
"triton>=2.0.0,<3;platform_machine==\"x86_64\" and sys_platform==\"linux\" or sys_platform==\"linux2\"",
],
},
package_data={
'whisperlivekit': ['web/*.html'],
'whisperlivekit.simul_whisper': ['dual_license_simulstreaming.md'],
'whisperlivekit.simul_whisper.whisper.assets': ['*.tiktoken', '*.npz'],
},
entry_points={
'console_scripts': [
'whisperlivekit-server=whisperlivekit.basic_server:main',
],
},
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Multimedia :: Sound/Audio :: Speech",
],
python_requires=">=3.9",
)

View File

@@ -1,13 +1,13 @@
from .download_simulstreaming_backend import download_simulstreaming_backend
from .audio_processor import AudioProcessor from .audio_processor import AudioProcessor
from .core import TranscriptionEngine from .core import TranscriptionEngine
from .parse_args import parse_args from .parse_args import parse_args
from .web.web_interface import get_web_interface_html from .web.web_interface import get_web_interface_html, get_inline_ui_html
__all__ = [ __all__ = [
"TranscriptionEngine", "TranscriptionEngine",
"AudioProcessor", "AudioProcessor",
"parse_args", "parse_args",
"get_web_interface_html", "get_web_interface_html",
"get_inline_ui_html",
"download_simulstreaming_backend", "download_simulstreaming_backend",
] ]

View File

@@ -4,12 +4,11 @@ from time import time, sleep
import math import math
import logging import logging
import traceback import traceback
from datetime import timedelta from whisperlivekit.timed_objects import ASRToken, Silence, Line
from whisperlivekit.timed_objects import ASRToken from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
from whisperlivekit.whisper_streaming_custom.whisper_online import online_factory
from whisperlivekit.core import TranscriptionEngine
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from whisperlivekit.silero_vad_iterator import FixedVADIterator
from whisperlivekit.results_formater import format_output
# Set up logging once # Set up logging once
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -17,9 +16,16 @@ logger.setLevel(logging.DEBUG)
SENTINEL = object() # unique sentinel object for end of stream marker SENTINEL = object() # unique sentinel object for end of stream marker
def format_time(seconds: float) -> str:
"""Format seconds as HH:MM:SS.""" async def get_all_from_queue(queue):
return str(timedelta(seconds=int(seconds))) items = []
try:
while True:
item = queue.get_nowait()
items.append(item)
except asyncio.QueueEmpty:
pass
return items
class AudioProcessor: class AudioProcessor:
""" """
@@ -46,25 +52,33 @@ class AudioProcessor:
self.last_ffmpeg_activity = time() self.last_ffmpeg_activity = time()
self.ffmpeg_health_check_interval = 5 self.ffmpeg_health_check_interval = 5
self.ffmpeg_max_idle_time = 10 self.ffmpeg_max_idle_time = 10
self.is_pcm_input = self.args.pcm_input
self.debug = False
# State management # State management
self.is_stopping = False self.is_stopping = False
self.silence = False
self.silence_duration = 0.0
self.tokens = [] self.tokens = []
self.translated_segments = []
self.buffer_transcription = "" self.buffer_transcription = ""
self.buffer_diarization = "" self.buffer_diarization = ""
self.full_transcription = ""
self.end_buffer = 0 self.end_buffer = 0
self.end_attributed_speaker = 0 self.end_attributed_speaker = 0
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
self.beg_loop = time() self.beg_loop = None #to deal with a potential little lag at the websocket initialization, this is now set in process_audio
self.sep = " " # Default separator self.sep = " " # Default separator
self.last_response_content = "" self.last_response_content = ""
# Models and processing # Models and processing
self.asr = models.asr self.asr = models.asr
self.tokenizer = models.tokenizer self.tokenizer = models.tokenizer
self.diarization = models.diarization self.vac_model = models.vac_model
if self.args.vac:
self.vac = FixedVADIterator(models.vac_model)
else:
self.vac = None
self.ffmpeg_manager = FFmpegManager( self.ffmpeg_manager = FFmpegManager(
sample_rate=self.sample_rate, sample_rate=self.sample_rate,
channels=self.channels channels=self.channels
@@ -79,30 +93,32 @@ class AudioProcessor:
self.transcription_queue = asyncio.Queue() if self.args.transcription else None self.transcription_queue = asyncio.Queue() if self.args.transcription else None
self.diarization_queue = asyncio.Queue() if self.args.diarization else None self.diarization_queue = asyncio.Queue() if self.args.diarization else None
self.translation_queue = asyncio.Queue() if self.args.target_language else None
self.pcm_buffer = bytearray() self.pcm_buffer = bytearray()
# Task references
self.transcription_task = None self.transcription_task = None
self.diarization_task = None self.diarization_task = None
self.ffmpeg_reader_task = None self.ffmpeg_reader_task = None
self.watchdog_task = None self.watchdog_task = None
self.all_tasks_for_cleanup = [] self.all_tasks_for_cleanup = []
# Initialize transcription engine if enabled
if self.args.transcription: if self.args.transcription:
self.online = online_factory(self.args, models.asr, models.tokenizer) self.online = online_factory(self.args, models.asr, models.tokenizer)
if self.args.diarization:
self.diarization = online_diarization_factory(self.args, models.diarization_model)
if self.args.target_language:
self.online_translation = online_translation_factory(self.args, models.translation_model)
def convert_pcm_to_float(self, pcm_buffer): def convert_pcm_to_float(self, pcm_buffer):
"""Convert PCM buffer in s16le format to normalized NumPy array.""" """Convert PCM buffer in s16le format to normalized NumPy array."""
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0 return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep): async def update_transcription(self, new_tokens, buffer, end_buffer, sep):
"""Thread-safe update of transcription with new data.""" """Thread-safe update of transcription with new data."""
async with self.lock: async with self.lock:
self.tokens.extend(new_tokens) self.tokens.extend(new_tokens)
self.buffer_transcription = buffer self.buffer_transcription = buffer
self.end_buffer = end_buffer self.end_buffer = end_buffer
self.full_transcription = full_transcription
self.sep = sep self.sep = sep
async def update_diarization(self, end_attributed_speaker, buffer_diarization=""): async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
@@ -115,7 +131,7 @@ class AudioProcessor:
async def add_dummy_token(self): async def add_dummy_token(self):
"""Placeholder token when no transcription is available.""" """Placeholder token when no transcription is available."""
async with self.lock: async with self.lock:
current_time = time() - self.beg_loop current_time = time() - self.beg_loop if self.beg_loop else 0
self.tokens.append(ASRToken( self.tokens.append(ASRToken(
start=current_time, end=current_time + 1, start=current_time, end=current_time + 1,
text=".", speaker=-1, is_dummy=True text=".", speaker=-1, is_dummy=True
@@ -129,15 +145,16 @@ class AudioProcessor:
# Calculate remaining times # Calculate remaining times
remaining_transcription = 0 remaining_transcription = 0
if self.end_buffer > 0: if self.end_buffer > 0:
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2)) remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 1))
remaining_diarization = 0 remaining_diarization = 0
if self.tokens: if self.tokens:
latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0) latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2)) remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 1))
return { return {
"tokens": self.tokens.copy(), "tokens": self.tokens.copy(),
"translated_segments": self.translated_segments.copy(),
"buffer_transcription": self.buffer_transcription, "buffer_transcription": self.buffer_transcription,
"buffer_diarization": self.buffer_diarization, "buffer_diarization": self.buffer_diarization,
"end_buffer": self.end_buffer, "end_buffer": self.end_buffer,
@@ -151,9 +168,9 @@ class AudioProcessor:
"""Reset all state variables to initial values.""" """Reset all state variables to initial values."""
async with self.lock: async with self.lock:
self.tokens = [] self.tokens = []
self.translated_segments = []
self.buffer_transcription = self.buffer_diarization = "" self.buffer_transcription = self.buffer_diarization = ""
self.end_buffer = self.end_attributed_speaker = 0 self.end_buffer = self.end_attributed_speaker = 0
self.full_transcription = self.last_response_content = ""
self.beg_loop = time() self.beg_loop = time()
async def ffmpeg_stdout_reader(self): async def ffmpeg_stdout_reader(self):
@@ -192,32 +209,9 @@ class AudioProcessor:
continue continue
self.pcm_buffer.extend(chunk) self.pcm_buffer.extend(chunk)
await self.handle_pcm_data()
# Send to diarization if enabled
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(
self.convert_pcm_to_float(self.pcm_buffer).copy()
)
# Process when enough data
if len(self.pcm_buffer) >= self.bytes_per_sec:
if len(self.pcm_buffer) > self.max_bytes_per_sec:
logger.warning(
f"Audio buffer too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f}s. "
f"Consider using a smaller model."
)
# Process audio chunk
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
# Send to transcription if enabled
if self.args.transcription and self.transcription_queue:
await self.transcription_queue.put(pcm_array.copy())
# Sleep if no processing is happening
if not self.args.transcription and not self.args.diarization:
await asyncio.sleep(0.1)
except Exception as e: except Exception as e:
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}") logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
@@ -236,37 +230,48 @@ class AudioProcessor:
if self.args.diarization and self.diarization_queue: if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(SENTINEL) await self.diarization_queue.put(SENTINEL)
logger.debug("Sentinel put into diarization_queue.") logger.debug("Sentinel put into diarization_queue.")
if self.args.target_language and self.translation_queue:
await self.translation_queue.put(SENTINEL)
async def transcription_processor(self): async def transcription_processor(self):
"""Process audio chunks for transcription.""" """Process audio chunks for transcription."""
self.full_transcription = ""
self.sep = self.online.asr.sep self.sep = self.online.asr.sep
cumulative_pcm_duration_stream_time = 0.0 cumulative_pcm_duration_stream_time = 0.0
while True: while True:
try: try:
pcm_array = await self.transcription_queue.get() item = await self.transcription_queue.get()
if pcm_array is SENTINEL: if item is SENTINEL:
logger.debug("Transcription processor received sentinel. Finishing.") logger.debug("Transcription processor received sentinel. Finishing.")
self.transcription_queue.task_done() self.transcription_queue.task_done()
break break
if not self.online: # Should not happen if queue is used if not self.online:
logger.warning("Transcription processor: self.online not initialized.") logger.warning("Transcription processor: self.online not initialized.")
self.transcription_queue.task_done() self.transcription_queue.task_done()
continue continue
asr_internal_buffer_duration_s = len(getattr(self.online, 'audio_buffer', [])) / self.online.SAMPLING_RATE asr_internal_buffer_duration_s = len(getattr(self.online, 'audio_buffer', [])) / self.online.SAMPLING_RATE
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer) transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
logger.info( if type(item) is Silence:
f"ASR processing: internal_buffer={asr_internal_buffer_duration_s:.2f}s, " asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
f"lag={transcription_lag_s:.2f}s." if self.tokens:
) asr_processing_logs += f" | last_end = {self.tokens[-1].end} |"
logger.info(asr_processing_logs)
# Process transcription if type(item) is Silence:
duration_this_chunk = len(pcm_array) / self.sample_rate if isinstance(pcm_array, np.ndarray) else 0 cumulative_pcm_duration_stream_time += item.duration
self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
continue
if isinstance(item, np.ndarray):
pcm_array = item
else:
raise Exception('item should be pcm_array')
duration_this_chunk = len(pcm_array) / self.sample_rate
cumulative_pcm_duration_stream_time += duration_this_chunk cumulative_pcm_duration_stream_time += duration_this_chunk
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
@@ -279,8 +284,6 @@ class AudioProcessor:
if new_tokens: if new_tokens:
validated_text = self.sep.join([t.text for t in new_tokens]) validated_text = self.sep.join([t.text for t in new_tokens])
self.full_transcription += validated_text
if buffer_text.startswith(validated_text): if buffer_text.startswith(validated_text):
buffer_text = buffer_text[len(validated_text):].lstrip() buffer_text = buffer_text[len(validated_text):].lstrip()
@@ -297,8 +300,13 @@ class AudioProcessor:
new_end_buffer = max(candidate_end_times) new_end_buffer = max(candidate_end_times)
await self.update_transcription( await self.update_transcription(
new_tokens, buffer_text, new_end_buffer, self.full_transcription, self.sep new_tokens, buffer_text, new_end_buffer, self.sep
) )
if new_tokens and self.args.target_language and self.translation_queue:
for token in new_tokens:
await self.translation_queue.put(token)
self.transcription_queue.task_done() self.transcription_queue.task_done()
except Exception as e: except Exception as e:
@@ -312,25 +320,35 @@ class AudioProcessor:
async def diarization_processor(self, diarization_obj): async def diarization_processor(self, diarization_obj):
"""Process audio chunks for speaker diarization.""" """Process audio chunks for speaker diarization."""
buffer_diarization = "" buffer_diarization = ""
cumulative_pcm_duration_stream_time = 0.0
while True: while True:
try: try:
pcm_array = await self.diarization_queue.get() item = await self.diarization_queue.get()
if pcm_array is SENTINEL: if item is SENTINEL:
logger.debug("Diarization processor received sentinel. Finishing.") logger.debug("Diarization processor received sentinel. Finishing.")
self.diarization_queue.task_done() self.diarization_queue.task_done()
break break
if type(item) is Silence:
cumulative_pcm_duration_stream_time += item.duration
diarization_obj.insert_silence(item.duration)
continue
if isinstance(item, np.ndarray):
pcm_array = item
else:
raise Exception('item should be pcm_array')
# Process diarization # Process diarization
await diarization_obj.diarize(pcm_array) await diarization_obj.diarize(pcm_array)
async with self.lock: async with self.lock:
new_end = diarization_obj.assign_speakers_to_tokens( self.tokens = diarization_obj.assign_speakers_to_tokens(
self.end_attributed_speaker,
self.tokens, self.tokens,
use_punctuation_split=self.args.punctuation_split use_punctuation_split=self.args.punctuation_split
) )
self.end_attributed_speaker = new_end if len(self.tokens) > 0:
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
if buffer_diarization: if buffer_diarization:
self.buffer_diarization = buffer_diarization self.buffer_diarization = buffer_diarization
@@ -343,9 +361,54 @@ class AudioProcessor:
self.diarization_queue.task_done() self.diarization_queue.task_done()
logger.info("Diarization processor task finished.") logger.info("Diarization processor task finished.")
async def translation_processor(self, online_translation):
# the idea is to ignore diarization for the moment. We use only transcription tokens.
# And the speaker is attributed given the segments used for the translation
# in the future we want to have different languages for each speaker etc, so it will be more complex.
while True:
try:
token = await self.translation_queue.get() #block until at least 1 token
if token is SENTINEL:
logger.debug("Translation processor received sentinel. Finishing.")
self.translation_queue.task_done()
break
# get all the available tokens for translation. The more words, the more precise
tokens_to_process = [token]
additional_tokens = await get_all_from_queue(self.translation_queue)
sentinel_found = False
for additional_token in additional_tokens:
if additional_token is SENTINEL:
sentinel_found = True
break
tokens_to_process.append(additional_token)
if tokens_to_process:
online_translation.insert_tokens(tokens_to_process)
self.translated_segments = online_translation.process()
self.translation_queue.task_done()
for _ in additional_tokens:
self.translation_queue.task_done()
if sentinel_found:
logger.debug("Translation processor received sentinel in batch. Finishing.")
break
except Exception as e:
logger.warning(f"Exception in translation_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
if 'token' in locals() and token is not SENTINEL:
self.translation_queue.task_done()
if 'additional_tokens' in locals():
for _ in additional_tokens:
self.translation_queue.task_done()
logger.info("Translation processor task finished.")
async def results_formatter(self): async def results_formatter(self):
"""Format processing results for output.""" """Format processing results for output."""
last_sent_trans = None
last_sent_diar = None
while True: while True:
try: try:
ffmpeg_state = await self.ffmpeg_manager.get_state() ffmpeg_state = await self.ffmpeg_manager.get_state()
@@ -370,7 +433,7 @@ class AudioProcessor:
buffer_diarization = state["buffer_diarization"] buffer_diarization = state["buffer_diarization"]
end_attributed_speaker = state["end_attributed_speaker"] end_attributed_speaker = state["end_attributed_speaker"]
sep = state["sep"] sep = state["sep"]
# Add dummy tokens if needed # Add dummy tokens if needed
if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization: if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
await self.add_dummy_token() await self.add_dummy_token()
@@ -379,40 +442,13 @@ class AudioProcessor:
tokens = state["tokens"] tokens = state["tokens"]
# Format output # Format output
previous_speaker = -1 lines, undiarized_text, buffer_transcription, buffer_diarization = format_output(
lines = [] state,
last_end_diarized = 0 self.silence,
undiarized_text = [] current_time = time() - self.beg_loop if self.beg_loop else None,
args = self.args,
# Process each token debug = self.debug
for token in tokens: )
speaker = token.speaker
# Handle diarization
if self.args.diarization:
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
undiarized_text.append(token.text)
continue
elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
speaker = previous_speaker
if speaker not in [-1, 0]:
last_end_diarized = max(token.end, last_end_diarized)
# Group by speaker
if speaker != previous_speaker or not lines:
lines.append({
"speaker": speaker,
"text": token.text,
"beg": format_time(token.start),
"end": format_time(token.end),
"diff": round(token.end - last_end_diarized, 2)
})
previous_speaker = speaker
elif token.text: # Only append if text isn't empty
lines[-1]["text"] += sep + token.text
lines[-1]["end"] = format_time(token.end)
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
# Handle undiarized text # Handle undiarized text
if undiarized_text: if undiarized_text:
combined = sep.join(undiarized_text) combined = sep.join(undiarized_text)
@@ -422,37 +458,42 @@ class AudioProcessor:
buffer_diarization = combined buffer_diarization = combined
response_status = "active_transcription" response_status = "active_transcription"
final_lines_for_response = lines.copy()
if not tokens and not buffer_transcription and not buffer_diarization: if not tokens and not buffer_transcription and not buffer_diarization:
response_status = "no_audio_detected" response_status = "no_audio_detected"
final_lines_for_response = [] lines = []
elif response_status == "active_transcription" and not final_lines_for_response: elif response_status == "active_transcription" and not lines:
final_lines_for_response = [{ lines = [Line(
"speaker": 1, speaker=1,
"text": "", start=state.get("end_buffer", 0),
"beg": format_time(state.get("end_buffer", 0)), end=state.get("end_buffer", 0)
"end": format_time(state.get("end_buffer", 0)), )]
"diff": 0
}]
response = { response = {
"status": response_status, "status": response_status,
"lines": final_lines_for_response, "lines": [line.to_dict() for line in lines],
"buffer_transcription": buffer_transcription, "buffer_transcription": buffer_transcription,
"buffer_diarization": buffer_diarization, "buffer_diarization": buffer_diarization,
"remaining_time_transcription": state["remaining_time_transcription"], "remaining_time_transcription": state["remaining_time_transcription"],
"remaining_time_diarization": state["remaining_time_diarization"] "remaining_time_diarization": state["remaining_time_diarization"] if self.args.diarization else 0
} }
current_response_signature = f"{response_status} | " + \ current_response_signature = f"{response_status} | " + \
' '.join([f"{line['speaker']} {line['text']}" for line in final_lines_for_response]) + \ ' '.join([f"{line.speaker} {line.text}" for line in lines]) + \
f" | {buffer_transcription} | {buffer_diarization}" f" | {buffer_transcription} | {buffer_diarization}"
if current_response_signature != self.last_response_content and \ trans = state["remaining_time_transcription"]
(final_lines_for_response or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"): diar = state["remaining_time_diarization"]
should_push = (
current_response_signature != self.last_response_content
or last_sent_trans is None
or round(trans, 1) != round(last_sent_trans, 1)
or round(diar, 1) != round(last_sent_diar, 1)
)
if should_push and (lines or buffer_transcription or buffer_diarization or response_status == "no_audio_detected" or trans > 0 or diar > 0):
yield response yield response
self.last_response_content = current_response_signature self.last_response_content = current_response_signature
last_sent_trans = trans
last_sent_diar = diar
# Check for termination condition # Check for termination condition
if self.is_stopping: if self.is_stopping:
@@ -464,7 +505,6 @@ class AudioProcessor:
if all_processors_done: if all_processors_done:
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.") logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
final_state = await self.get_current_state()
return return
await asyncio.sleep(0.1) # Avoid overwhelming the client await asyncio.sleep(0.1) # Avoid overwhelming the client
@@ -504,6 +544,11 @@ class AudioProcessor:
self.all_tasks_for_cleanup.append(self.diarization_task) self.all_tasks_for_cleanup.append(self.diarization_task)
processing_tasks_for_watchdog.append(self.diarization_task) processing_tasks_for_watchdog.append(self.diarization_task)
if self.args.target_language and self.args.lan != 'auto':
self.translation_task = asyncio.create_task(self.translation_processor(self.online_translation))
self.all_tasks_for_cleanup.append(self.translation_task)
processing_tasks_for_watchdog.append(self.translation_task)
self.ffmpeg_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader()) self.ffmpeg_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task) self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task)
processing_tasks_for_watchdog.append(self.ffmpeg_reader_task) processing_tasks_for_watchdog.append(self.ffmpeg_reader_task)
@@ -546,24 +591,29 @@ class AudioProcessor:
async def cleanup(self): async def cleanup(self):
"""Clean up resources when processing is complete.""" """Clean up resources when processing is complete."""
logger.info("Starting cleanup of AudioProcessor resources.") logger.info("Starting cleanup of AudioProcessor resources.")
self.is_stopping = True
for task in self.all_tasks_for_cleanup: for task in self.all_tasks_for_cleanup:
if task and not task.done(): if task and not task.done():
task.cancel() task.cancel()
created_tasks = [t for t in self.all_tasks_for_cleanup if t] created_tasks = [t for t in self.all_tasks_for_cleanup if t]
if created_tasks: if created_tasks:
await asyncio.gather(*created_tasks, return_exceptions=True) await asyncio.gather(*created_tasks, return_exceptions=True)
logger.info("All processing tasks cancelled or finished.") logger.info("All processing tasks cancelled or finished.")
await self.ffmpeg_manager.stop() await self.ffmpeg_manager.stop()
logger.info("FFmpeg manager stopped.") logger.info("FFmpeg manager stopped.")
if self.args.diarization and hasattr(self, 'diarization') and hasattr(self.diarization, 'close'): if self.args.diarization and hasattr(self, 'diarization') and hasattr(self.diarization, 'close'):
self.diarization.close() self.diarization.close()
logger.info("AudioProcessor cleanup complete.") logger.info("AudioProcessor cleanup complete.")
async def process_audio(self, message): async def process_audio(self, message):
"""Process incoming audio data.""" """Process incoming audio data."""
if not self.beg_loop:
self.beg_loop = time()
if not message: if not message:
logger.info("Empty audio message received, initiating stop sequence.") logger.info("Empty audio message received, initiating stop sequence.")
self.is_stopping = True self.is_stopping = True
@@ -575,10 +625,65 @@ class AudioProcessor:
logger.warning("AudioProcessor is stopping. Ignoring incoming audio.") logger.warning("AudioProcessor is stopping. Ignoring incoming audio.")
return return
success = await self.ffmpeg_manager.write_data(message) if self.is_pcm_input:
if not success: self.pcm_buffer.extend(message)
ffmpeg_state = await self.ffmpeg_manager.get_state() await self.handle_pcm_data()
if ffmpeg_state == FFmpegState.FAILED: else:
logger.error("FFmpeg is in FAILED state, cannot process audio") success = await self.ffmpeg_manager.write_data(message)
else: if not success:
logger.warning("Failed to write audio data to FFmpeg") ffmpeg_state = await self.ffmpeg_manager.get_state()
if ffmpeg_state == FFmpegState.FAILED:
logger.error("FFmpeg is in FAILED state, cannot process audio")
else:
logger.warning("Failed to write audio data to FFmpeg")
async def handle_pcm_data(self):
# Process when enough data
if len(self.pcm_buffer) < self.bytes_per_sec:
return
if len(self.pcm_buffer) > self.max_bytes_per_sec:
logger.warning(
f"Audio buffer too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f}s. "
f"Consider using a smaller model."
)
# Process audio chunk
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
res = None
end_of_audio = False
silence_buffer = None
if self.args.vac:
res = self.vac(pcm_array)
if res is not None:
if res.get("end", 0) > res.get("start", 0):
end_of_audio = True
elif self.silence: #end of silence
self.silence = False
silence_buffer = Silence(duration=time() - self.start_silence)
if silence_buffer:
if self.args.transcription and self.transcription_queue:
await self.transcription_queue.put(silence_buffer)
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(silence_buffer)
if not self.silence:
if self.args.transcription and self.transcription_queue:
await self.transcription_queue.put(pcm_array.copy())
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(pcm_array.copy())
self.silence_duration = 0.0
if end_of_audio:
self.silence = True
self.start_silence = time()
if not self.args.transcription and not self.args.diarization:
await asyncio.sleep(0.1)

View File

@@ -2,9 +2,12 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args
import asyncio import asyncio
import logging import logging
from starlette.staticfiles import StaticFiles
import pathlib
import whisperlivekit.web as webpkg
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger().setLevel(logging.WARNING) logging.getLogger().setLevel(logging.WARNING)
@@ -16,6 +19,15 @@ transcription_engine = None
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
#to remove after 0.2.8
if args.backend == "simulstreaming" and not args.disable_fast_encoder:
logger.warning(f"""
{'='*50}
WhisperLiveKit 0.2.8 has introduced a new fast encoder feature using MLX Whisper or Faster Whisper for improved speed. Use --disable-fast-encoder to disable if you encounter issues.
{'='*50}
""")
global transcription_engine global transcription_engine
transcription_engine = TranscriptionEngine( transcription_engine = TranscriptionEngine(
**vars(args), **vars(args),
@@ -30,10 +42,12 @@ app.add_middleware(
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
web_dir = pathlib.Path(webpkg.__file__).parent
app.mount("/web", StaticFiles(directory=str(web_dir)), name="web")
@app.get("/") @app.get("/")
async def get(): async def get():
return HTMLResponse(get_web_interface_html()) return HTMLResponse(get_inline_ui_html())
async def handle_websocket_results(websocket, results_generator): async def handle_websocket_results(websocket, results_generator):
@@ -47,7 +61,7 @@ async def handle_websocket_results(websocket, results_generator):
except WebSocketDisconnect: except WebSocketDisconnect:
logger.info("WebSocket disconnected while handling results (client likely closed connection).") logger.info("WebSocket disconnected while handling results (client likely closed connection).")
except Exception as e: except Exception as e:
logger.warning(f"Error in WebSocket results handler: {e}") logger.exception(f"Error in WebSocket results handler: {e}")
@app.websocket("/asr") @app.websocket("/asr")

View File

@@ -1,9 +1,12 @@
try: try:
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory, warmup_asr from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory
from whisperlivekit.whisper_streaming_custom.online_asr import OnlineASRProcessor
except ImportError: except ImportError:
from .whisper_streaming_custom.whisper_online import backend_factory, warmup_asr from .whisper_streaming_custom.whisper_online import backend_factory
from .whisper_streaming_custom.online_asr import OnlineASRProcessor
from whisperlivekit.warmup import warmup_asr, warmup_online
from argparse import Namespace from argparse import Namespace
import sys
class TranscriptionEngine: class TranscriptionEngine:
_instance = None _instance = None
@@ -22,7 +25,6 @@ class TranscriptionEngine:
"host": "localhost", "host": "localhost",
"port": 8000, "port": 8000,
"warmup_file": None, "warmup_file": None,
"confidence_validation": False,
"diarization": False, "diarization": False,
"punctuation_split": False, "punctuation_split": False,
"min_chunk_size": 0.5, "min_chunk_size": 0.5,
@@ -31,23 +33,26 @@ class TranscriptionEngine:
"model_dir": None, "model_dir": None,
"lan": "auto", "lan": "auto",
"task": "transcribe", "task": "transcribe",
"target_language": "",
"backend": "faster-whisper", "backend": "faster-whisper",
"vac": False, "vac": True,
"vac_chunk_size": 0.04, "vac_chunk_size": 0.04,
"buffer_trimming": "segment",
"buffer_trimming_sec": 15,
"log_level": "DEBUG", "log_level": "DEBUG",
"ssl_certfile": None, "ssl_certfile": None,
"ssl_keyfile": None, "ssl_keyfile": None,
"transcription": True, "transcription": True,
"vad": True, "vad": True,
"segmentation_model": "pyannote/segmentation-3.0", "pcm_input": False,
"embedding_model": "pyannote/embedding", # whisperstreaming params:
"buffer_trimming": "segment",
"confidence_validation": False,
"buffer_trimming_sec": 15,
# simulstreaming params: # simulstreaming params:
"disable_fast_encoder": False,
"frame_threshold": 25, "frame_threshold": 25,
"beams": 1, "beams": 1,
"decoder_type": None, "decoder_type": None,
"audio_max_len": 30.0, "audio_max_len": 20.0,
"audio_min_len": 0.0, "audio_min_len": 0.0,
"cif_ckpt_path": None, "cif_ckpt_path": None,
"never_fire": False, "never_fire": False,
@@ -55,6 +60,11 @@ class TranscriptionEngine:
"static_init_prompt": None, "static_init_prompt": None,
"max_context_tokens": None, "max_context_tokens": None,
"model_path": './base.pt', "model_path": './base.pt',
"diarization_backend": "sortformer",
# diarization params:
"disable_punctuation_split" : False,
"segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding",
} }
config_dict = {**defaults, **kwargs} config_dict = {**defaults, **kwargs}
@@ -63,6 +73,8 @@ class TranscriptionEngine:
config_dict['transcription'] = not kwargs['no_transcription'] config_dict['transcription'] = not kwargs['no_transcription']
if 'no_vad' in kwargs: if 'no_vad' in kwargs:
config_dict['vad'] = not kwargs['no_vad'] config_dict['vad'] = not kwargs['no_vad']
if 'no_vac' in kwargs:
config_dict['vac'] = not kwargs['no_vac']
config_dict.pop('no_transcription', None) config_dict.pop('no_transcription', None)
config_dict.pop('no_vad', None) config_dict.pop('no_vad', None)
@@ -76,17 +88,99 @@ class TranscriptionEngine:
self.asr = None self.asr = None
self.tokenizer = None self.tokenizer = None
self.diarization = None self.diarization = None
self.vac_model = None
if self.args.vac:
import torch
self.vac_model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
if self.args.transcription: if self.args.transcription:
self.asr, self.tokenizer = backend_factory(self.args) if self.args.backend == "simulstreaming":
warmup_asr(self.asr, self.args.warmup_file) from whisperlivekit.simul_whisper import SimulStreamingASR
self.tokenizer = None
simulstreaming_kwargs = {}
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count', 'disable_fast_encoder']:
if hasattr(self.args, attr):
simulstreaming_kwargs[attr] = getattr(self.args, attr)
# Add segment_length from min_chunk_size
simulstreaming_kwargs['segment_length'] = getattr(self.args, 'min_chunk_size', 0.5)
simulstreaming_kwargs['task'] = self.args.task
size = self.args.model
self.asr = SimulStreamingASR(
modelsize=size,
lan=self.args.lan,
cache_dir=getattr(self.args, 'model_cache_dir', None),
model_dir=getattr(self.args, 'model_dir', None),
**simulstreaming_kwargs
)
else:
self.asr, self.tokenizer = backend_factory(self.args)
warmup_asr(self.asr, self.args.warmup_file) #for simulstreaming, warmup should be done in the online class not here
if self.args.diarization: if self.args.diarization:
from whisperlivekit.diarization.diarization_online import DiartDiarization if self.args.diarization_backend == "diart":
self.diarization = DiartDiarization( from whisperlivekit.diarization.diart_backend import DiartDiarization
block_duration=self.args.min_chunk_size, self.diarization_model = DiartDiarization(
segmentation_model_name=self.args.segmentation_model, block_duration=self.args.min_chunk_size,
embedding_model_name=self.args.embedding_model segmentation_model_name=self.args.segmentation_model,
) embedding_model_name=self.args.embedding_model
)
elif self.args.diarization_backend == "sortformer":
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
self.diarization_model = SortformerDiarization()
else:
raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}")
self.translation_model = None
if self.args.target_language:
if self.args.lan == 'auto':
raise Exception('Translation cannot be set with language auto')
else:
from whisperlivekit.translation.translation import load_model
self.translation_model = load_model([self.args.lan]) #in the future we want to handle different languages for different speakers
TranscriptionEngine._initialized = True TranscriptionEngine._initialized = True
def online_factory(args, asr, tokenizer, logfile=sys.stderr):
if args.backend == "simulstreaming":
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
online = SimulStreamingOnlineProcessor(
asr,
logfile=logfile,
)
# warmup_online(online, args.warmup_file)
else:
online = OnlineASRProcessor(
asr,
tokenizer,
logfile=logfile,
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
confidence_validation = args.confidence_validation
)
return online
def online_diarization_factory(args, diarization_backend):
if args.diarization_backend == "diart":
online = diarization_backend
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended
if args.diarization_backend == "sortformer":
from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline
online = SortformerDiarizationOnline(shared_model=diarization_backend)
return online
def online_translation_factory(args, translation_model):
#should be at speaker level in the future:
#one shared nllb model for all speaker
#one tokenizer per speaker/language
from whisperlivekit.translation.translation import OnlineTranslation
return OnlineTranslation(translation_model, [args.lan], [args.target_language])

View File

@@ -29,6 +29,7 @@ class DiarizationObserver(Observer):
self.speaker_segments = [] self.speaker_segments = []
self.processed_time = 0 self.processed_time = 0
self.segment_lock = threading.Lock() self.segment_lock = threading.Lock()
self.global_time_offset = 0.0
def on_next(self, value: Tuple[Annotation, Any]): def on_next(self, value: Tuple[Annotation, Any]):
annotation, audio = value annotation, audio = value
@@ -49,8 +50,8 @@ class DiarizationObserver(Observer):
print(f" {speaker}: {start:.2f}s-{end:.2f}s") print(f" {speaker}: {start:.2f}s-{end:.2f}s")
self.speaker_segments.append(SpeakerSegment( self.speaker_segments.append(SpeakerSegment(
speaker=speaker, speaker=speaker,
start=start, start=start + self.global_time_offset,
end=end end=end + self.global_time_offset
)) ))
else: else:
logger.debug("\nNo speakers detected in this segment") logger.debug("\nNo speakers detected in this segment")
@@ -165,7 +166,7 @@ class WebSocketAudioSource(AudioSource):
class DiartDiarization: class DiartDiarization:
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "speechbrain/spkrec-ecapa-voxceleb"): def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 1.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "pyannote/embedding"):
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name) segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name) embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
@@ -199,6 +200,9 @@ class DiartDiarization:
self.inference.attach_observers(self.observer) self.inference.attach_observers(self.observer)
asyncio.get_event_loop().run_in_executor(None, self.inference) asyncio.get_event_loop().run_in_executor(None, self.inference)
def insert_silence(self, silence_duration):
self.observer.global_time_offset += silence_duration
async def diarize(self, pcm_array: np.ndarray): async def diarize(self, pcm_array: np.ndarray):
""" """
Process audio data for diarization. Process audio data for diarization.
@@ -206,15 +210,14 @@ class DiartDiarization:
""" """
if self.custom_source: if self.custom_source:
self.custom_source.push_audio(pcm_array) self.custom_source.push_audio(pcm_array)
self.observer.clear_old_segments() # self.observer.clear_old_segments()
return self.observer.get_segments()
def close(self): def close(self):
"""Close the audio source.""" """Close the audio source."""
if self.custom_source: if self.custom_source:
self.custom_source.close() self.custom_source.close()
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float: def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> float:
""" """
Assign speakers to tokens based on timing overlap with speaker segments. Assign speakers to tokens based on timing overlap with speaker segments.
Uses the segments collected by the observer. Uses the segments collected by the observer.
@@ -231,85 +234,82 @@ class DiartDiarization:
if not self.lag_diart and segments and tokens: if not self.lag_diart and segments and tokens:
self.lag_diart = segments[0].start - tokens[0].start self.lag_diart = segments[0].start - tokens[0].start
for token in tokens:
for segment in segments:
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
token.speaker = extract_number(segment.speaker) + 1
end_attributed_speaker = max(token.end, end_attributed_speaker)
if use_punctuation_split and len(tokens) > 1: if not use_punctuation_split:
punctuation_marks = {'.', '!', '?'} for token in tokens:
for segment in segments:
print("Here are the tokens:", if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
[(t.text, t.start, t.end, t.speaker) for t in tokens[:10]]) token.speaker = extract_number(segment.speaker) + 1
else:
segment_map = [] tokens = add_speaker_to_tokens(segments, tokens)
for segment in segments: return tokens
speaker_num = extract_number(segment.speaker) + 1
segment_map.append((segment.start, segment.end, speaker_num))
segment_map.sort(key=lambda x: x[0])
i = 0
while i < len(tokens):
current_token = tokens[i]
is_sentence_end = False
if current_token.text and current_token.text.strip():
text = current_token.text.strip()
if text[-1] in punctuation_marks:
is_sentence_end = True
logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
if is_sentence_end and current_token.speaker != -1:
punctuation_time = current_token.end
current_speaker = current_token.speaker
j = i + 1
next_sentence_tokens = []
while j < len(tokens):
next_token = tokens[j]
next_sentence_tokens.append(j)
# Check if this token ends the next sentence
if next_token.text and next_token.text.strip():
if next_token.text.strip()[-1] in punctuation_marks:
break
j += 1
if next_sentence_tokens:
speaker_times = {}
for idx in next_sentence_tokens:
token = tokens[idx]
# Find which segments overlap with this token
for seg_start, seg_end, seg_speaker in segment_map:
if not (seg_end <= token.start or seg_start >= token.end):
# Calculate overlap duration
overlap_start = max(seg_start, token.start)
overlap_end = min(seg_end, token.end)
overlap_duration = overlap_end - overlap_start
if seg_speaker not in speaker_times:
speaker_times[seg_speaker] = 0
speaker_times[seg_speaker] += overlap_duration
if speaker_times:
dominant_speaker = max(speaker_times.items(), key=lambda x: x[1])[0]
if dominant_speaker != current_speaker:
logger.debug(f" Speaker change after punctuation: {current_speaker}{dominant_speaker}")
for idx in next_sentence_tokens:
if tokens[idx].speaker != dominant_speaker:
logger.debug(f" Reassigning token {idx} ('{tokens[idx].text}') to Speaker {dominant_speaker}")
tokens[idx].speaker = dominant_speaker
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
else:
for idx in next_sentence_tokens:
if tokens[idx].speaker == -1:
tokens[idx].speaker = current_speaker
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
i += 1
return end_attributed_speaker def concatenate_speakers(segments):
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
for segment in segments:
speaker = extract_number(segment.speaker) + 1
if segments_concatenated[-1]['speaker'] != speaker:
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
else:
segments_concatenated[-1]['end'] = segment.end
# print("Segments concatenated:")
# for entry in segments_concatenated:
# print(f"Speaker {entry['speaker']}: {entry['begin']:.2f}s - {entry['end']:.2f}s")
return segments_concatenated
def add_speaker_to_tokens(segments, tokens):
"""
Assign speakers to tokens based on diarization segments, with punctuation-aware boundary adjustment.
"""
punctuation_marks = {'.', '!', '?'}
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
segments_concatenated = concatenate_speakers(segments)
for ind, segment in enumerate(segments_concatenated):
for i, punctuation_token in enumerate(punctuation_tokens):
if punctuation_token.start > segment['end']:
after_length = punctuation_token.start - segment['end']
before_length = segment['end'] - punctuation_tokens[i - 1].end
if before_length > after_length:
segment['end'] = punctuation_token.start
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
else:
segment['end'] = punctuation_tokens[i - 1].end
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
break
last_end = 0.0
for token in tokens:
start = max(last_end + 0.01, token.start)
token.start = start
token.end = max(start, token.end)
last_end = token.end
ind_last_speaker = 0
for segment in segments_concatenated:
for i, token in enumerate(tokens[ind_last_speaker:]):
if token.end <= segment['end']:
token.speaker = segment['speaker']
ind_last_speaker = i + 1
# print(
# f"Token '{token.text}' ('begin': {token.start:.2f}, 'end': {token.end:.2f}) "
# f"assigned to Speaker {segment['speaker']} ('segment': {segment['begin']:.2f}-{segment['end']:.2f})"
# )
elif token.start > segment['end']:
break
return tokens
def visualize_tokens(tokens):
conversation = [{"speaker": -1, "text": ""}]
for token in tokens:
speaker = conversation[-1]['speaker']
if token.speaker != speaker:
conversation.append({"speaker": token.speaker, "text": token.text})
else:
conversation[-1]['text'] += token.text
print("Conversation:")
for entry in conversation:
print(f"Speaker {entry['speaker']}: {entry['text']}")

View File

@@ -0,0 +1,465 @@
import numpy as np
import torch
import logging
import threading
import time
import wave
from typing import List, Optional
from queue import SimpleQueue, Empty
from whisperlivekit.timed_objects import SpeakerSegment
logger = logging.getLogger(__name__)
try:
from nemo.collections.asr.models import SortformerEncLabelModel
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
except ImportError:
raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
class StreamingSortformerState:
"""
This class creates a class instance that will be used to store the state of the
streaming Sortformer model.
Attributes:
spkcache (torch.Tensor): Speaker cache to store embeddings from start
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
mean_sil_emb (torch.Tensor): Mean silence embedding
n_sil_frames (torch.Tensor): Number of silence frames
"""
def __init__(self):
self.spkcache = None # Speaker cache to store embeddings from start
self.spkcache_lengths = None
self.spkcache_preds = None # speaker cache predictions
self.fifo = None # to save the embedding from the latest chunks
self.fifo_lengths = None
self.fifo_preds = None
self.spk_perm = None
self.mean_sil_emb = None
self.n_sil_frames = None
class SortformerDiarization:
def __init__(self, model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2"):
"""
Stores the shared streaming Sortformer diarization model. Used when a new online_diarization is initialized.
"""
self._load_model(model_name)
def _load_model(self, model_name: str):
"""Load and configure the Sortformer model for streaming."""
try:
self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
self.diar_model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.diar_model.to(device)
## to test
# for name, param in self.diar_model.named_parameters():
# if param.device != device:
# raise RuntimeError(f"Parameter {name} is on {param.device} but should be on {device}")
logger.info(f"Using {device.type.upper()} for Sortformer model")
self.diar_model.sortformer_modules.chunk_len = 10
self.diar_model.sortformer_modules.subsampling_factor = 10
self.diar_model.sortformer_modules.chunk_right_context = 0
self.diar_model.sortformer_modules.chunk_left_context = 10
self.diar_model.sortformer_modules.spkcache_len = 188
self.diar_model.sortformer_modules.fifo_len = 188
self.diar_model.sortformer_modules.spkcache_update_period = 144
self.diar_model.sortformer_modules.log = False
self.diar_model.sortformer_modules._check_streaming_parameters()
except Exception as e:
logger.error(f"Failed to load Sortformer model: {e}")
raise
class SortformerDiarizationOnline:
def __init__(self, shared_model, sample_rate: int = 16000):
"""
Initialize the streaming Sortformer diarization system.
Args:
sample_rate: Audio sample rate (default: 16000)
model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2")
"""
self.sample_rate = sample_rate
self.speaker_segments = []
self.buffer_audio = np.array([], dtype=np.float32)
self.segment_lock = threading.Lock()
self.global_time_offset = 0.0
self.processed_time = 0.0
self.debug = False
self.diar_model = shared_model.diar_model
self.audio2mel = AudioToMelSpectrogramPreprocessor(
window_size=0.025,
normalize="NA",
n_fft=512,
features=128,
pad_to=0
)
self.audio2mel.to(self.diar_model.device)
self.chunk_duration_seconds = (
self.diar_model.sortformer_modules.chunk_len *
self.diar_model.sortformer_modules.subsampling_factor *
self.diar_model.preprocessor._cfg.window_stride
)
self._init_streaming_state()
self._previous_chunk_features = None
self._chunk_index = 0
self._len_prediction = None
# Audio buffer to store PCM chunks for debugging
self.audio_buffer = []
# Buffer for accumulating audio chunks until reaching chunk_duration_seconds
self.audio_chunk_buffer = []
self.accumulated_duration = 0.0
logger.info("SortformerDiarization initialized successfully")
def _init_streaming_state(self):
"""Initialize the streaming state for the model."""
batch_size = 1
device = self.diar_model.device
self.streaming_state = StreamingSortformerState()
self.streaming_state.spkcache = torch.zeros(
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.fc_d_model),
device=device
)
self.streaming_state.spkcache_preds = torch.zeros(
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.n_spk),
device=device
)
self.streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
self.streaming_state.fifo = torch.zeros(
(batch_size, self.diar_model.sortformer_modules.fifo_len, self.diar_model.sortformer_modules.fc_d_model),
device=device
)
self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device)
self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
# Initialize total predictions tensor
self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)
def insert_silence(self, silence_duration: float):
"""
Insert silence period by adjusting the global time offset.
Args:
silence_duration: Duration of silence in seconds
"""
with self.segment_lock:
self.global_time_offset += silence_duration
logger.debug(f"Inserted silence of {silence_duration:.2f}s, new offset: {self.global_time_offset:.2f}s")
async def diarize(self, pcm_array: np.ndarray):
"""
Process audio data for diarization in streaming fashion.
Args:
pcm_array: Audio data as numpy array
"""
try:
if self.debug:
self.audio_buffer.append(pcm_array.copy())
threshold = int(self.chunk_duration_seconds * self.sample_rate)
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
if not len(self.buffer_audio) >= threshold:
return
audio = self.buffer_audio[:threshold]
self.buffer_audio = self.buffer_audio[threshold:]
device = self.diar_model.device
audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0)
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device)
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
audio_signal_chunk, audio_signal_length_chunk
)
processed_signal_chunk = processed_signal_chunk.to(device)
processed_signal_length_chunk = processed_signal_length_chunk.to(device)
if self._previous_chunk_features is not None:
to_add = self._previous_chunk_features[:, :, -99:].to(device)
total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device)
else:
total_features = processed_signal_chunk.to(device)
self._previous_chunk_features = processed_signal_chunk.to(device)
chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device)
with torch.inference_mode():
left_offset = 8 if self._chunk_index > 0 else 0
right_offset = 8
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device),
streaming_state=self.streaming_state,
total_preds=self.total_preds,
left_offset=left_offset,
right_offset=right_offset,
)
# Convert predictions to speaker segments
self._process_predictions()
self._chunk_index += 1
except Exception as e:
logger.error(f"Error in diarize: {e}")
raise
# TODO: Handle case when stream ends with partial buffer (accumulated_duration > 0 but < chunk_duration_seconds)
def _process_predictions(self):
"""Process model predictions and convert to speaker segments."""
try:
preds_np = self.total_preds[0].cpu().numpy()
active_speakers = np.argmax(preds_np, axis=1)
if self._len_prediction is None:
self._len_prediction = len(active_speakers)
# Get predictions for current chunk
frame_duration = self.chunk_duration_seconds / self._len_prediction
current_chunk_preds = active_speakers[-self._len_prediction:]
with self.segment_lock:
# Process predictions into segments
base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset
for idx, spk in enumerate(current_chunk_preds):
start_time = base_time + idx * frame_duration
end_time = base_time + (idx + 1) * frame_duration
# Check if this continues the last segment or starts a new one
if (self.speaker_segments and
self.speaker_segments[-1].speaker == spk and
abs(self.speaker_segments[-1].end - start_time) < frame_duration * 0.5):
# Continue existing segment
self.speaker_segments[-1].end = end_time
else:
# Create new segment
self.speaker_segments.append(SpeakerSegment(
speaker=spk,
start=start_time,
end=end_time
))
# Update processed time
self.processed_time = max(self.processed_time, base_time + self.chunk_duration_seconds)
logger.debug(f"Processed chunk {self._chunk_index}, total segments: {len(self.speaker_segments)}")
except Exception as e:
logger.error(f"Error processing predictions: {e}")
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> list:
"""
Assign speakers to tokens based on timing overlap with speaker segments.
Args:
tokens: List of tokens with timing information
use_punctuation_split: Whether to use punctuation for boundary refinement
Returns:
List of tokens with speaker assignments
"""
with self.segment_lock:
segments = self.speaker_segments.copy()
if not segments or not tokens:
logger.debug("No segments or tokens available for speaker assignment")
return tokens
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
use_punctuation_split = False
if not use_punctuation_split:
# Simple overlap-based assignment
for token in tokens:
token.speaker = -1 # Default to no speaker
for segment in segments:
# Check for timing overlap
if not (segment.end <= token.start or segment.start >= token.end):
token.speaker = segment.speaker + 1 # Convert to 1-based indexing
break
else:
# Use punctuation-aware assignment (similar to diart_backend)
tokens = self._add_speaker_to_tokens_with_punctuation(segments, tokens)
return tokens
def _add_speaker_to_tokens_with_punctuation(self, segments: List[SpeakerSegment], tokens: list) -> list:
"""
Assign speakers to tokens with punctuation-aware boundary adjustment.
Args:
segments: List of speaker segments
tokens: List of tokens to assign speakers to
Returns:
List of tokens with speaker assignments
"""
punctuation_marks = {'.', '!', '?'}
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
# Convert segments to concatenated format
segments_concatenated = self._concatenate_speakers(segments)
# Adjust segment boundaries based on punctuation
for ind, segment in enumerate(segments_concatenated):
for i, punctuation_token in enumerate(punctuation_tokens):
if punctuation_token.start > segment['end']:
after_length = punctuation_token.start - segment['end']
before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
if before_length > after_length:
segment['end'] = punctuation_token.start
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
else:
segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
break
# Ensure non-overlapping tokens
last_end = 0.0
for token in tokens:
start = max(last_end + 0.01, token.start)
token.start = start
token.end = max(start, token.end)
last_end = token.end
# Assign speakers based on adjusted segments
ind_last_speaker = 0
for segment in segments_concatenated:
for i, token in enumerate(tokens[ind_last_speaker:]):
if token.end <= segment['end']:
token.speaker = segment['speaker']
ind_last_speaker = i + 1
elif token.start > segment['end']:
break
return tokens
def _concatenate_speakers(self, segments: List[SpeakerSegment]) -> List[dict]:
"""
Concatenate consecutive segments from the same speaker.
Args:
segments: List of speaker segments
Returns:
List of concatenated speaker segments
"""
if not segments:
return []
segments_concatenated = [{"speaker": segments[0].speaker + 1, "begin": segments[0].start, "end": segments[0].end}]
for segment in segments[1:]:
speaker = segment.speaker + 1
if segments_concatenated[-1]['speaker'] != speaker:
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
else:
segments_concatenated[-1]['end'] = segment.end
return segments_concatenated
def get_segments(self) -> List[SpeakerSegment]:
"""Get a copy of the current speaker segments."""
with self.segment_lock:
return self.speaker_segments.copy()
def clear_old_segments(self, older_than: float = 30.0):
"""Clear segments older than the specified time."""
with self.segment_lock:
current_time = self.processed_time
self.speaker_segments = [
segment for segment in self.speaker_segments
if current_time - segment.end < older_than
]
logger.debug(f"Cleared old segments, remaining: {len(self.speaker_segments)}")
def close(self):
"""Close the diarization system and clean up resources."""
logger.info("Closing SortformerDiarization")
with self.segment_lock:
self.speaker_segments.clear()
if self.debug:
concatenated_audio = np.concatenate(self.audio_buffer)
audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
with wave.open("diarization_audio.wav", "wb") as wav_file:
wav_file.setnchannels(1) # mono audio
wav_file.setsampwidth(2) # 2 bytes per sample (int16)
wav_file.setframerate(self.sample_rate)
wav_file.writeframes(audio_data_int16.tobytes())
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
def extract_number(s: str) -> int:
"""Extract number from speaker string (compatibility function)."""
import re
m = re.search(r'\d+', s)
return int(m.group()) if m else 0
if __name__ == '__main__':
import asyncio
import librosa
async def main():
"""TEST ONLY."""
an4_audio = 'audio_test.mp3'
signal, sr = librosa.load(an4_audio, sr=16000)
signal = signal[:16000*30]
print("\n" + "=" * 50)
print("ground truth:")
print("Speaker 0: 0:00 - 0:09")
print("Speaker 1: 0:09 - 0:19")
print("Speaker 2: 0:19 - 0:25")
print("Speaker 0: 0:25 - 0:30")
print("=" * 50)
diarization = SortformerDiarization(sample_rate=16000)
chunk_size = 1600
for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size]
await diarization.diarize(chunk)
print(f"Processed chunk {i // chunk_size + 1}")
segments = diarization.get_segments()
print("\nDiarization results:")
for segment in segments:
print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
asyncio.run(main())

View File

@@ -0,0 +1,205 @@
import numpy as np
import torch
import logging
from nemo.collections.asr.models import SortformerEncLabelModel
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
import librosa
logger = logging.getLogger(__name__)
def load_model():
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
diar_model.eval()
if torch.cuda.is_available():
diar_model.to(torch.device("cuda"))
#we target 1 second lag for the moment. chunk_len could be reduced.
diar_model.sortformer_modules.chunk_len = 10
diar_model.sortformer_modules.subsampling_factor = 10 #8 would be better ideally
diar_model.sortformer_modules.chunk_right_context = 0 #no.
diar_model.sortformer_modules.chunk_left_context = 10 #big so it compensiate the problem with no padding later.
diar_model.sortformer_modules.spkcache_len = 188
diar_model.sortformer_modules.fifo_len = 188
diar_model.sortformer_modules.spkcache_update_period = 144
diar_model.sortformer_modules.log = False
diar_model.sortformer_modules._check_streaming_parameters()
audio2mel = AudioToMelSpectrogramPreprocessor(
window_size= 0.025,
normalize="NA",
n_fft=512,
features=128,
pad_to=0) #pad_to 16 works better than 0. On test audio, we detect a third speaker for 1 second with pad_to=0. To solve that : increase left context to 10.
return diar_model, audio2mel
diar_model, audio2mel = load_model()
class StreamingSortformerState:
"""
This class creates a class instance that will be used to store the state of the
streaming Sortformer model.
Attributes:
spkcache (torch.Tensor): Speaker cache to store embeddings from start
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
mean_sil_emb (torch.Tensor): Mean silence embedding
n_sil_frames (torch.Tensor): Number of silence frames
"""
spkcache = None # Speaker cache to store embeddings from start
spkcache_lengths = None #
spkcache_preds = None # speaker cache predictions
fifo = None # to save the embedding from the latest chunks
fifo_lengths = None
fifo_preds = None
spk_perm = None
mean_sil_emb = None
n_sil_frames = None
def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None):
"""
Initializes StreamingSortformerState with empty tensors or zero-valued tensors.
Args:
batch_size (int): Batch size for tensors in streaming state
async_streaming (bool): True for asynchronous update, False for synchronous update
device (torch.device): Device for tensors in streaming state
Returns:
streaming_state (SortformerStreamingState): initialized streaming state
"""
streaming_state = StreamingSortformerState()
if async_streaming:
streaming_state.spkcache = torch.zeros((batch_size, self.spkcache_len, self.fc_d_model), device=device)
streaming_state.spkcache_preds = torch.zeros((batch_size, self.spkcache_len, self.n_spk), device=device)
streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
streaming_state.fifo = torch.zeros((batch_size, self.fifo_len, self.fc_d_model), device=device)
streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
else:
streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device)
streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
return streaming_state
def process_diarization(chunks):
"""
what it does:
1. Preprocessing: Applies dithering and pre-emphasis (high-pass filter) if enabled
2. STFT: Computes the Short-Time Fourier Transform using:
- the window of window_size=0.025 --> size of a window : 400 samples
- the hop parameter : n_window_stride = 0.01 -> every 160 samples, a new window
3. Magnitude Calculation: Converts complex STFT output to magnitude spectrogram
4. Mel Conversion: Applies Mel filterbanks (128 filters in this case) to get Mel spectrogram
5. Logarithm: Takes the log of the Mel spectrogram (if `log=True`)
6. Normalization: Skips normalization since `normalize="NA"`
7. Padding: Pads the time dimension to a multiple of `pad_to` (default 16)
"""
previous_chunk = None
l_chunk_feat_seq_t = []
for chunk in chunks:
audio_signal_chunk = torch.tensor(chunk).unsqueeze(0).to(diar_model.device)
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(diar_model.device)
processed_signal_chunk, processed_signal_length_chunk = audio2mel.get_features(audio_signal_chunk, audio_signal_length_chunk)
if previous_chunk is not None:
to_add = previous_chunk[:, :, -99:]
total = torch.concat([to_add, processed_signal_chunk], dim=2)
else:
total = processed_signal_chunk
previous_chunk = processed_signal_chunk
l_chunk_feat_seq_t.append(torch.transpose(total, 1, 2))
batch_size = 1
streaming_state = init_streaming_state(diar_model.sortformer_modules,
batch_size = batch_size,
async_streaming = True,
device = diar_model.device
)
total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)
chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
l_speakers = [
{'start_time': 0,
'end_time': 0,
'speaker': 0
}
]
len_prediction = None
left_offset = 0
right_offset = 8
for i, chunk_feat_seq_t in enumerate(l_chunk_feat_seq_t):
with torch.inference_mode():
streaming_state, total_preds = diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
streaming_state=streaming_state,
total_preds=total_preds,
left_offset=left_offset,
right_offset=right_offset,
)
left_offset = 8
preds_np = total_preds[0].cpu().numpy()
active_speakers = np.argmax(preds_np, axis=1)
if len_prediction is None:
len_prediction = len(active_speakers) # we want to get the len of 1 prediction
frame_duration = chunk_duration_seconds / len_prediction
active_speakers = active_speakers[-len_prediction:]
for idx, spk in enumerate(active_speakers):
if spk != l_speakers[-1]['speaker']:
l_speakers.append(
{'start_time': (i * chunk_duration_seconds + idx * frame_duration),
'end_time': (i * chunk_duration_seconds + (idx + 1) * frame_duration),
'speaker': spk
})
else:
l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration
"""
Should print
[{'start_time': 0, 'end_time': 8.72, 'speaker': 0},
{'start_time': 8.72, 'end_time': 18.88, 'speaker': 1},
{'start_time': 18.88, 'end_time': 24.96, 'speaker': 2},
{'start_time': 24.96, 'end_time': 31.68, 'speaker': 0}]
"""
for speaker in l_speakers:
print(f"Speaker {speaker['speaker']}: {speaker['start_time']:.2f}s - {speaker['end_time']:.2f}s")
if __name__ == '__main__':
an4_audio = 'audio_test.mp3'
signal, sr = librosa.load(an4_audio, sr=16000)
signal = signal[:16000*30]
# signal = signal[:-(len(signal)%16000)]
print("\n" + "=" * 50)
print("Expected ground truth:")
print("Speaker 0: 0:00 - 0:09")
print("Speaker 1: 0:09 - 0:19")
print("Speaker 2: 0:19 - 0:25")
print("Speaker 0: 0:25 - 0:30")
print("=" * 50)
chunk_size = 16000 # 1 second
chunks = []
for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size]
chunks.append(chunk)
process_diarization(chunks)

View File

@@ -1,32 +0,0 @@
import os
import requests
import inspect
def get_module_path():
return os.path.dirname(inspect.getfile(inspect.currentframe()))
GITHUB_API_URL = "https://api.github.com/repos/ufal/SimulStreaming/contents/simul_whisper/whisper"
RAW_BASE_URL = "https://raw.githubusercontent.com/ufal/SimulStreaming/main/simul_whisper/whisper"
TARGET_DIR = os.path.join(get_module_path(), "simul_whisper", "whisper")
def download_files_from_github(api_url, local_dir):
os.makedirs(local_dir, exist_ok=True)
response = requests.get(api_url)
response.raise_for_status()
items = response.json()
for item in items:
if item['type'] == 'file':
download_url = item['download_url']
file_name = item['name']
file_response = requests.get(download_url)
file_response.raise_for_status()
with open(os.path.join(local_dir, file_name), 'wb') as f:
f.write(file_response.content)
elif item['type'] == 'dir':
# Recursive call for subdirectories
download_files_from_github(item['url'], os.path.join(local_dir, item['name']))
def download_simulstreaming_backend():
print(f"Downloading files into {TARGET_DIR} ...")
download_files_from_github(GITHUB_API_URL, TARGET_DIR)
print("✅ Download of SimulStreaming backend files completed successfully.")

View File

@@ -143,7 +143,7 @@ class FFmpegManager:
try: try:
data = await asyncio.wait_for( data = await asyncio.wait_for(
self.process.stdout.read(size), self.process.stdout.read(size),
timeout=5.0 timeout=20.0
) )
return data return data
except asyncio.TimeoutError: except asyncio.TimeoutError:

View File

@@ -58,12 +58,26 @@ def parse_args():
help="Hugging Face model ID for pyannote.audio embedding model.", help="Hugging Face model ID for pyannote.audio embedding model.",
) )
parser.add_argument(
"--diarization-backend",
type=str,
default="sortformer",
choices=["sortformer", "diart"],
help="The diarization backend to use.",
)
parser.add_argument( parser.add_argument(
"--no-transcription", "--no-transcription",
action="store_true", action="store_true",
help="Disable transcription to only see live diarization results.", help="Disable transcription to only see live diarization results.",
) )
parser.add_argument(
"--disable-punctuation-split",
action="store_true",
help="Disable the split parameter.",
)
parser.add_argument( parser.add_argument(
"--min-chunk-size", "--min-chunk-size",
type=float, type=float,
@@ -74,7 +88,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--model", "--model",
type=str, type=str,
default="tiny", default="small",
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.", help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
) )
@@ -104,18 +118,27 @@ def parse_args():
choices=["transcribe", "translate"], choices=["transcribe", "translate"],
help="Transcribe or translate.", help="Transcribe or translate.",
) )
parser.add_argument(
"--target-language",
type=str,
default="",
dest="target_language",
help="Target language for translation. Not functional yet.",
)
parser.add_argument( parser.add_argument(
"--backend", "--backend",
type=str, type=str,
default="faster-whisper", default="simulstreaming",
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api", "simulstreaming"], choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api", "simulstreaming"],
help="Load only this backend for Whisper processing.", help="Load only this backend for Whisper processing.",
) )
parser.add_argument( parser.add_argument(
"--vac", "--no-vac",
action="store_true", action="store_true",
default=False, default=False,
help="Use VAC = voice activity controller. Recommended. Requires torch.", help="Disable VAC = voice activity controller.",
) )
parser.add_argument( parser.add_argument(
"--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds." "--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
@@ -150,9 +173,22 @@ def parse_args():
) )
parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None) parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None) parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
parser.add_argument(
"--pcm-input",
action="store_true",
default=False,
help="If set, raw PCM (s16le) data is expected as input and FFmpeg will be bypassed."
)
# SimulStreaming-specific arguments # SimulStreaming-specific arguments
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)') simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')
simulstreaming_group.add_argument(
"--disable-fast-encoder",
action="store_true",
default=False,
dest="disable_fast_encoder",
help="Disable Faster Whisper or MLX Whisper backends for encoding (if installed). Slower but helpful when GPU memory is limited",
)
simulstreaming_group.add_argument( simulstreaming_group.add_argument(
"--frame-threshold", "--frame-threshold",
@@ -242,6 +278,14 @@ def parse_args():
dest="model_path", dest="model_path",
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.", help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
) )
simulstreaming_group.add_argument(
"--preload-model-count",
type=int,
default=1,
dest="preload_model_count",
help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).",
)
args = parser.parse_args() args = parser.parse_args()

View File

@@ -0,0 +1,110 @@
from whisperlivekit.timed_objects import ASRToken
import re
MIN_SILENCE_DURATION = 4 #in seconds
END_SILENCE_DURATION = 8 #in seconds. you should keep it important to not have false positive when the model lag is important
END_SILENCE_DURATION_VAC = 3 #VAC is good at detecting silences, but we want to skip the smallest silences
def blank_to_silence(tokens):
full_string = ''.join([t.text for t in tokens])
patterns = [re.compile(r'(?:\s*\[BLANK_AUDIO\]\s*)+'), re.compile(r'(?:\s*\[typing\]\s*)+')]
matches = []
for pattern in patterns:
for m in pattern.finditer(full_string):
matches.append({
'start': m.start(),
'end': m.end()
})
if matches:
# cleaned = pattern.sub(' ', full_string).strip()
# print("Cleaned:", cleaned)
cumulated_len = 0
silence_token = None
cleaned_tokens = []
for token in tokens:
if matches:
start = cumulated_len
end = cumulated_len + len(token.text)
cumulated_len = end
if start >= matches[0]['start'] and end <= matches[0]['end']:
if silence_token: #previous token was already silence
silence_token.start = min(silence_token.start, token.start)
silence_token.end = max(silence_token.end, token.end)
else: #new silence
silence_token = ASRToken(
start=token.start,
end=token.end,
speaker=-2,
probability=0.95
)
else:
if silence_token: #there was silence but no more
if silence_token.end - silence_token.start >= MIN_SILENCE_DURATION:
cleaned_tokens.append(
silence_token
)
silence_token = None
matches.pop(0)
cleaned_tokens.append(token)
# print(cleaned_tokens)
return cleaned_tokens
return tokens
def no_token_to_silence(tokens):
new_tokens = []
silence_token = None
for token in tokens:
if token.speaker == -2:
if new_tokens and new_tokens[-1].speaker == -2: #if token is silence and previous one too
new_tokens[-1].end = token.end
else:
new_tokens.append(token)
last_end = new_tokens[-1].end if new_tokens else 0.0
if token.start - last_end >= MIN_SILENCE_DURATION: #if token is not silence but important gap
if new_tokens and new_tokens[-1].speaker == -2:
new_tokens[-1].end = token.start
else:
silence_token = ASRToken(
start=last_end,
end=token.start,
speaker=-2,
probability=0.95
)
new_tokens.append(silence_token)
if token.speaker != -2:
new_tokens.append(token)
return new_tokens
def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
if not tokens:
return [], buffer_transcription, buffer_diarization
last_token = tokens[-1]
if tokens and current_time and (
current_time - last_token.end >= END_SILENCE_DURATION
or
(current_time - last_token.end >= 3 and vac_detected_silence)
):
if last_token.speaker == -2:
last_token.end = current_time
else:
tokens.append(
ASRToken(
start=tokens[-1].end,
end=current_time,
speaker=-2,
probability=0.95
)
)
buffer_transcription = "" # for whisperstreaming backend, we should probably validate the buffer has because of the silence
buffer_diarization = ""
return tokens, buffer_transcription, buffer_diarization
def handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
tokens = no_token_to_silence(tokens)
tokens, buffer_transcription, buffer_diarization = ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence)
return tokens, buffer_transcription, buffer_diarization

View File

@@ -0,0 +1,137 @@
import logging
from whisperlivekit.remove_silences import handle_silences
from whisperlivekit.timed_objects import Line, format_time
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''}
CHECK_AROUND = 4
def is_punctuation(token):
if token.text.strip() in PUNCTUATION_MARKS:
return True
return False
def next_punctuation_change(i, tokens):
for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)):
if is_punctuation(tokens[ind]):
return ind
return None
def next_speaker_change(i, tokens, speaker):
for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1):
token = tokens[ind]
if is_punctuation(token):
break
if token.speaker != speaker:
return ind, token.speaker
return None, speaker
def new_line(
token,
speaker,
debug_info = ""
):
return Line(
speaker = speaker,
text = token.text + debug_info,
start = token.start,
end = token.end,
)
def append_token_to_last_line(lines, sep, token, debug_info):
if token.text:
lines[-1].text += sep + token.text + debug_info
lines[-1].end = token.end
def format_output(state, silence, current_time, args, debug):
diarization = args.diarization
disable_punctuation_split = args.disable_punctuation_split
tokens = state["tokens"]
translated_segments = state["translated_segments"] # Here we will attribute the speakers only based on the timestamps of the segments
buffer_transcription = state["buffer_transcription"]
buffer_diarization = state["buffer_diarization"]
end_attributed_speaker = state["end_attributed_speaker"]
sep = state["sep"]
previous_speaker = -1
lines = []
undiarized_text = []
tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, silence)
last_punctuation = None
for i, token in enumerate(tokens):
speaker = token.speaker
if not diarization and speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
speaker = 1
if diarization and not tokens[-1].speaker == -2:
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
undiarized_text.append(token.text)
continue
elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
speaker = previous_speaker
debug_info = ""
if debug:
debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
if not lines:
lines.append(new_line(token, speaker, debug_info = ""))
continue
else:
previous_speaker = lines[-1].speaker
if is_punctuation(token):
last_punctuation = i
if last_punctuation == i-1:
if speaker != previous_speaker:
# perfect, diarization perfectly aligned
lines.append(new_line(token, speaker, debug_info = ""))
last_punctuation, next_punctuation = None, None
continue
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
if speaker_change_pos:
# Corrects delay:
# That was the idea. Okay haha |SPLIT SPEAKER| that's a good one
# should become:
# That was the idea. |SPLIT SPEAKER| Okay haha that's a good one
lines.append(new_line(token, new_speaker, debug_info = ""))
else:
# No speaker change to come
append_token_to_last_line(lines, sep, token, debug_info)
continue
if speaker != previous_speaker:
if speaker == -2 or previous_speaker == -2: #silences can happen anytime
lines.append(new_line(token, speaker, debug_info = ""))
continue
elif next_punctuation_change(i, tokens):
# Corrects advance:
# Are you |SPLIT SPEAKER| okay? yeah, sure. Absolutely
# should become:
# Are you okay? |SPLIT SPEAKER| yeah, sure. Absolutely
append_token_to_last_line(lines, sep, token, debug_info)
continue
else: #we create a new speaker, but that's no ideal. We are not sure about the split. We prefer to append to previous line
if disable_punctuation_split:
lines.append(new_line(token, speaker, debug_info = ""))
continue
pass
append_token_to_last_line(lines, sep, token, debug_info)
if lines and translated_segments:
cts_idx = 0 # current_translated_segment_idx
for line in lines:
while cts_idx < len(translated_segments):
ts = translated_segments[cts_idx]
if ts.start and ts.start >= line.start and ts.end <= line.end:
line.translation += ts.text + ' '
cts_idx += 1
else:
break
return lines, undiarized_text, buffer_transcription, ''

View File

@@ -0,0 +1,6 @@
from .backend import SimulStreamingASR, SimulStreamingOnlineProcessor
__all__ = [
"SimulStreamingASR",
"SimulStreamingOnlineProcessor",
]

View File

@@ -0,0 +1,365 @@
import sys
import numpy as np
import logging
from typing import List, Tuple, Optional
import logging
import platform
from whisperlivekit.timed_objects import ASRToken, Transcript
from whisperlivekit.warmup import load_file
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
from .whisper import load_model, tokenizer
from .whisper.audio import TOKENS_PER_SECOND
import os
import gc
logger = logging.getLogger(__name__)
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
from whisperlivekit.simul_whisper.whisper import tokenizer
try:
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
HAS_MLX_WHISPER = True
except ImportError:
if platform.system() == "Darwin" and platform.machine() == "arm64":
print('MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper')
HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False
else:
try:
from faster_whisper import WhisperModel
HAS_FASTER_WHISPER = True
except ImportError:
HAS_FASTER_WHISPER = False
# TOO_MANY_REPETITIONS = 3
class SimulStreamingOnlineProcessor:
SAMPLING_RATE = 16000
def __init__(
self,
asr,
logfile=sys.stderr,
warmup_file=None
):
self.asr = asr
self.logfile = logfile
self.end = 0.0
self.global_time_offset = 0.0
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
self.load_new_backend()
#can be moved
if asr.tokenizer:
self.model.tokenizer = asr.tokenizer
def load_new_backend(self):
model = self.asr.get_new_model_instance()
self.model = PaddedAlignAttWhisper(
cfg=self.asr.cfg,
loaded_model=model,
mlx_encoder=self.asr.mlx_encoder,
fw_encoder=self.asr.fw_encoder,
)
def insert_silence(self, silence_duration, offset):
"""
If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame
"""
if silence_duration < 5:
gap_silence = torch.zeros(int(16000*silence_duration))
self.model.insert_audio(gap_silence)
# self.global_time_offset += silence_duration
else:
self.process_iter(is_last=True) #we want to totally process what remains in the buffer.
self.model.refresh_segment(complete=True)
self.global_time_offset = silence_duration + offset
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
"""Append an audio chunk to be processed by SimulStreaming."""
# Convert numpy array to torch tensor
audio_tensor = torch.from_numpy(audio).float()
self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend.
self.model.insert_audio(audio_tensor)
def get_buffer(self):
return Transcript(
start=None,
end=None,
text='',
probability=None
)
def timestamped_text(self, tokens, generation):
"""
generate timestamped text from tokens and generation data.
args:
tokens: List of tokens to process
generation: Dictionary containing generation progress and optionally results
returns:
List of tuples containing (start_time, end_time, word) for each word
"""
FRAME_DURATION = 0.02
if "result" in generation:
split_words = generation["result"]["split_words"]
split_tokens = generation["result"]["split_tokens"]
else:
split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
progress = generation["progress"]
frames = [p["most_attended_frames"][0] for p in progress]
absolute_timestamps = [p["absolute_timestamps"][0] for p in progress]
tokens_queue = tokens.copy()
timestamped_words = []
for word, word_tokens in zip(split_words, split_tokens):
# start_frame = None
# end_frame = None
for expected_token in word_tokens:
if not tokens_queue or not frames:
raise ValueError(f"Insufficient tokens or frames for word '{word}'")
actual_token = tokens_queue.pop(0)
current_frame = frames.pop(0)
current_timestamp = absolute_timestamps.pop(0)
if actual_token != expected_token:
raise ValueError(
f"Token mismatch: expected '{expected_token}', "
f"got '{actual_token}' at frame {current_frame}"
)
# if start_frame is None:
# start_frame = current_frame
# end_frame = current_frame
# start_time = start_frame * FRAME_DURATION
# end_time = end_frame * FRAME_DURATION
start_time = current_timestamp
end_time = current_timestamp + 0.1
timestamp_entry = (start_time, end_time, word)
timestamped_words.append(timestamp_entry)
logger.debug(f"TS-WORD:\t{start_time:.2f}\t{end_time:.2f}\t{word}")
return timestamped_words
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
"""
Process accumulated audio chunks using SimulStreaming.
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
"""
try:
tokens, generation_progress = self.model.infer(is_last=is_last)
ts_words = self.timestamped_text(tokens, generation_progress)
new_tokens = []
for ts_word in ts_words:
start, end, word = ts_word
token = ASRToken(
start=start,
end=end,
text=word,
probability=0.95 # fake prob. Maybe we can extract it from the model?
).with_offset(
self.global_time_offset
)
new_tokens.append(token)
# identical_tokens = 0
# n_new_tokens = len(new_tokens)
# if n_new_tokens:
self.committed.extend(new_tokens)
# if token in self.committed:
# pos = len(self.committed) - 1 - self.committed[::-1].index(token)
# if pos:
# for i in range(len(self.committed) - n_new_tokens, -1, -n_new_tokens):
# commited_segment = self.committed[i:i+n_new_tokens]
# if commited_segment == new_tokens:
# identical_segments +=1
# if identical_tokens >= TOO_MANY_REPETITIONS:
# logger.warning('Too many repetition, model is stuck. Load a new one')
# self.committed = self.committed[:i]
# self.load_new_backend()
# return [], self.end
# pos = self.committed.rindex(token)
return new_tokens, self.end
except Exception as e:
logger.exception(f"SimulStreaming processing error: {e}")
return [], self.end
def warmup(self, audio, init_prompt=""):
"""Warmup the SimulStreaming model."""
try:
self.model.insert_audio(audio)
self.model.infer(True)
self.model.refresh_segment(complete=True)
logger.info("SimulStreaming model warmed up successfully")
except Exception as e:
logger.exception(f"SimulStreaming warmup failed: {e}")
def __del__(self):
# free the model and add a new model to stack.
# del self.model
gc.collect()
torch.cuda.empty_cache()
# self.asr.new_model_to_stack()
self.model.remove_hooks()
class SimulStreamingASR():
"""SimulStreaming backend with AlignAtt policy."""
sep = ""
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
logger.warning(SIMULSTREAMING_LICENSE)
self.logfile = logfile
self.transcribe_kargs = {}
self.original_language = lan
self.model_path = kwargs.get('model_path', './large-v3.pt')
self.frame_threshold = kwargs.get('frame_threshold', 25)
self.audio_max_len = kwargs.get('audio_max_len', 20.0)
self.audio_min_len = kwargs.get('audio_min_len', 0.0)
self.segment_length = kwargs.get('segment_length', 0.5)
self.beams = kwargs.get('beams', 1)
self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
self.task = kwargs.get('task', 'transcribe')
self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
self.never_fire = kwargs.get('never_fire', False)
self.init_prompt = kwargs.get('init_prompt', None)
self.static_init_prompt = kwargs.get('static_init_prompt', None)
self.max_context_tokens = kwargs.get('max_context_tokens', None)
self.warmup_file = kwargs.get('warmup_file', None)
self.preload_model_count = kwargs.get('preload_model_count', 1)
self.disable_fast_encoder = kwargs.get('disable_fast_encoder', False)
self.fast_encoder = False
if model_dir is not None:
self.model_path = model_dir
elif modelsize is not None:
model_mapping = {
'tiny': './tiny.pt',
'base': './base.pt',
'small': './small.pt',
'medium': './medium.pt',
'medium.en': './medium.en.pt',
'large-v1': './large-v1.pt',
'base.en': './base.en.pt',
'small.en': './small.en.pt',
'tiny.en': './tiny.en.pt',
'large-v2': './large-v2.pt',
'large-v3': './large-v3.pt',
'large': './large-v3.pt'
}
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
self.cfg = AlignAttConfig(
model_path=self.model_path,
segment_length=self.segment_length,
frame_threshold=self.frame_threshold,
language=self.original_language,
audio_max_len=self.audio_max_len,
audio_min_len=self.audio_min_len,
cif_ckpt_path=self.cif_ckpt_path,
decoder_type="beam",
beam_size=self.beams,
task=self.task,
never_fire=self.never_fire,
init_prompt=self.init_prompt,
max_context_tokens=self.max_context_tokens,
static_init_prompt=self.static_init_prompt,
)
# Set up tokenizer for translation if needed
if self.task == "translate":
self.tokenizer = self.set_translate_task()
else:
self.tokenizer = None
self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
self.mlx_encoder, self.fw_encoder = None, None
if not self.disable_fast_encoder:
if HAS_MLX_WHISPER:
print('Simulstreaming will use MLX whisper for a faster encoder.')
mlx_model_name = mlx_model_mapping[self.model_name]
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name)
self.fast_encoder = True
elif HAS_FASTER_WHISPER:
print('Simulstreaming will use Faster Whisper for the encoder.')
self.fw_encoder = WhisperModel(
self.model_name,
device='auto',
compute_type='auto',
)
self.fast_encoder = True
self.models = [self.load_model() for i in range(self.preload_model_count)]
def load_model(self):
whisper_model = load_model(name=self.model_name, download_root=self.model_path, decoder_only=self.fast_encoder)
warmup_audio = load_file(self.warmup_file)
if warmup_audio is not None:
warmup_audio = torch.from_numpy(warmup_audio).float()
if self.fast_encoder:
temp_model = PaddedAlignAttWhisper(
cfg=self.cfg,
loaded_model=whisper_model,
mlx_encoder=self.mlx_encoder,
fw_encoder=self.fw_encoder,
)
temp_model.warmup(warmup_audio)
temp_model.remove_hooks()
else:
# For standard encoder, use the original transcribe warmup
warmup_audio = load_file(self.warmup_file)
whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None)
return whisper_model
def get_new_model_instance(self):
"""
SimulStreaming cannot share the same backend because it uses global forward hooks on the attention layers.
Therefore, each user requires a separate model instance, which can be memory-intensive. To maintain speed, we preload the models into memory.
"""
if len(self.models) == 0:
self.models.append(self.load_model())
new_model = self.models.pop()
return new_model
# self.models[0]
def new_model_to_stack(self):
self.models.append(self.load_model())
def set_translate_task(self):
"""Set up translation task."""
if self.cfg.language == 'auto':
raise Exception('Translation cannot be done with language = auto')
return tokenizer.get_tokenizer(
multilingual=True,
language=self.cfg.language,
num_languages=99,
task="translate"
)
def transcribe(self, audio):
"""
Warmup is done directly in load_model
"""
pass

View File

@@ -8,7 +8,7 @@ class SimulWhisperConfig:
'''Options that are common for all simul policies that could be implemented in SimulWhisper.''' '''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
model_path: str model_path: str
language: str = field(default="zh") language: str = field(default="zh")
nonspeech_prob: float = 1.0 nonspeech_prob: float = 0.5
audio_min_len: float = 1.0 audio_min_len: float = 1.0
decoder_type: Literal["greedy","beam"] = "greedy" decoder_type: Literal["greedy","beam"] = "greedy"
beam_size: int = 5 beam_size: int = 5
@@ -24,6 +24,6 @@ class AlignAttConfig(SimulWhisperConfig):
segment_length: float = field(default=1.0, metadata = {"help": "in second"}) segment_length: float = field(default=1.0, metadata = {"help": "in second"})
frame_threshold: int = 4 frame_threshold: int = 4
rewind_threshold: int = 200 rewind_threshold: int = 200
audio_max_len: float = 30.0 audio_max_len: float = 20.0
cif_ckpt_path: str = "" cif_ckpt_path: str = ""
never_fire: bool = False never_fire: bool = False

View File

@@ -1,25 +0,0 @@
📄 SimulStreaming (https://github.com/ufal/SimulStreaming) Licence
SimulStreaming is dual-licensed:
🔹 Non-Commercial Use
You may use SimulStreaming under the **PolyForm Noncommercial License 1.0.0** if you
obtain the code through the GitHub repository. This license is **free of charge**
and comes with **no obligations** for non-commercial users.
🔸 Commercial Use
Understanding who uses SimulStreaming commercially helps us improve and
prioritize development. Therefore, we want to **require registration** of those who acquire a commercial licence.
We plan to make the commercial licenceses **affordable** to SMEs and individuals. We
are considering to provide commercial licenses either for free or for symbolic
one-time fee, and maybe also provide additional support. You can share your preference via the [questionnaire](https://forms.cloud.microsoft/e/7tCxb4gJfB).
You can also leave your contact [there](https://forms.cloud.microsoft/e/7tCxb4gJfB) to be notified when the commercial licenses become
available.
✉️ Contact
[Dominik Macháček](https://ufal.mff.cuni.cz/dominik-machacek/), machacek@ufal.mff.cuni.cz

View File

@@ -25,6 +25,9 @@ class BeamTokens(Tokens):
def __repr__(self): def __repr__(self):
return self.__str__() return self.__str__()
def as_text(self, tokenizer):
return tokenizer.decode(self.tokens)
class Logits(Tokens): class Logits(Tokens):
def __init__(self, logits): def __init__(self, logits):
super().__init__(logits) super().__init__(logits)

View File

@@ -0,0 +1,5 @@
SIMULSTREAMING_LICENSE = f"""
SimulStreaming backend is dual-licensed:
• Non-Commercial Use: PolyForm Noncommercial License 1.0.0.
• Commercial Use: Check SimulStreaming README (github.com/ufal/SimulStreaming) for more details.
"""

View File

@@ -0,0 +1,72 @@
import json
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from mlx_whisper import whisper
mlx_model_mapping = {
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
"tiny": "mlx-community/whisper-tiny-mlx",
"base.en": "mlx-community/whisper-base.en-mlx",
"base": "mlx-community/whisper-base-mlx",
"small.en": "mlx-community/whisper-small.en-mlx",
"small": "mlx-community/whisper-small-mlx",
"medium.en": "mlx-community/whisper-medium.en-mlx",
"medium": "mlx-community/whisper-medium-mlx",
"large-v1": "mlx-community/whisper-large-v1-mlx",
"large-v2": "mlx-community/whisper-large-v2-mlx",
"large-v3": "mlx-community/whisper-large-v3-mlx",
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
"large": "mlx-community/whisper-large-mlx",
}
def load_mlx_encoder(
path_or_hf_repo: str,
dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
with open(str(model_path / "config.json"), "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
quantization = config.pop("quantization", None)
model_args = whisper.ModelDimensions(**config)
wf = model_path / "weights.safetensors"
if not wf.exists():
wf = model_path / "weights.npz"
weights = mx.load(str(wf))
model = whisper.Whisper(model_args, dtype)
if quantization is not None:
class_predicate = (
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
and f"{p}.scales" in weights
)
nn.quantize(model, **quantization, class_predicate=class_predicate)
weights = tree_unflatten(list(weights.items()))
# we only want to load the encoder weights here.
# Size examples: for tiny.en,
# Decoder weights: 59110771 bytes
# Encoder weights: 15268874 bytes
encoder_weights = {}
encoder_weights['encoder'] = weights['encoder']
del(weights)
model.update(encoder_weights)
mx.eval(model.parameters())
return model

View File

@@ -10,12 +10,12 @@ from .whisper import load_model, DecodingOptions, tokenizer
from .config import AlignAttConfig from .config import AlignAttConfig
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
from .whisper.timing import median_filter from .whisper.timing import median_filter
from .whisper.decoding import SuppressBlank, GreedyDecoder, BeamSearchDecoder, SuppressTokens from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
from .beam import BeamPyTorchInference from .beam import BeamPyTorchInference
from .eow_detection import fire_at_boundary, load_cif from .eow_detection import fire_at_boundary, load_cif
import os import os
from time import time
from whisperlivekit.simul_whisper.token_buffer import TokenBuffer from .token_buffer import TokenBuffer
import numpy as np import numpy as np
from .generation_progress import * from .generation_progress import *
@@ -23,7 +23,22 @@ from .generation_progress import *
DEC_PAD = 50257 DEC_PAD = 50257
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
import sys
try:
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
HAS_MLX_WHISPER = True
except ImportError:
HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False
else:
try:
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
from faster_whisper.feature_extractor import FeatureExtractor
HAS_FASTER_WHISPER = True
except ImportError:
HAS_FASTER_WHISPER = False
# New features added to the original version of Simul-Whisper: # New features added to the original version of Simul-Whisper:
# - large-v3 model support # - large-v3 model support
@@ -32,28 +47,43 @@ import sys
# - prompt -- static vs. non-static # - prompt -- static vs. non-static
# - context # - context
class PaddedAlignAttWhisper: class PaddedAlignAttWhisper:
def __init__(self, cfg: AlignAttConfig) -> None: def __init__(
self,
cfg: AlignAttConfig,
loaded_model=None,
mlx_encoder=None,
fw_encoder=None,
) -> None:
self.log_segments = 0
model_name = os.path.basename(cfg.model_path).replace(".pt", "") model_name = os.path.basename(cfg.model_path).replace(".pt", "")
model_path = os.path.dirname(os.path.abspath(cfg.model_path)) model_path = os.path.dirname(os.path.abspath(cfg.model_path))
self.model = load_model(name=model_name, download_root=model_path) if loaded_model:
self.model = loaded_model
else:
self.model = load_model(name=model_name, download_root=model_path)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.mlx_encoder = mlx_encoder
self.fw_encoder = fw_encoder
if fw_encoder:
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
logger.info(f"Model dimensions: {self.model.dims}") logger.info(f"Model dimensions: {self.model.dims}")
decode_options = DecodingOptions( self.decode_options = DecodingOptions(
language = cfg.language, language = cfg.language,
without_timestamps = True, without_timestamps = True,
task=cfg.task task=cfg.task
) )
self.tokenizer = tokenizer.get_tokenizer( self.tokenizer_is_multilingual = not model_name.endswith(".en")
multilingual=not model_name.endswith(".en"), self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
language=cfg.language, self.detected_language = cfg.language if cfg.language != "auto" else None
num_languages=self.model.num_languages,
task=decode_options.task
)
self.max_text_len = self.model.dims.n_text_ctx self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks) self.num_decoder_layers = len(self.model.decoder.blocks)
self.cfg = cfg self.cfg = cfg
self.l_hooks = []
# model to detect end-of-word boundary at the end of the segment # model to detect end-of-word boundary at the end of the segment
self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg, self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
@@ -67,7 +97,8 @@ class PaddedAlignAttWhisper:
t = F.softmax(net_output[1], dim=-1) t = F.softmax(net_output[1], dim=-1)
self.dec_attns.append(t.squeeze(0)) self.dec_attns.append(t.squeeze(0))
for b in self.model.decoder.blocks: for b in self.model.decoder.blocks:
b.cross_attn.register_forward_hook(layer_hook) hook = b.cross_attn.register_forward_hook(layer_hook)
self.l_hooks.append(hook)
self.kv_cache = {} self.kv_cache = {}
def kv_hook(module: torch.nn.Linear, _, net_output: torch.Tensor): def kv_hook(module: torch.nn.Linear, _, net_output: torch.Tensor):
@@ -80,10 +111,13 @@ class PaddedAlignAttWhisper:
return self.kv_cache[module.cache_id] return self.kv_cache[module.cache_id]
for i,b in enumerate(self.model.decoder.blocks): for i,b in enumerate(self.model.decoder.blocks):
b.attn.key.register_forward_hook(kv_hook) hooks = [
b.attn.value.register_forward_hook(kv_hook) b.attn.key.register_forward_hook(kv_hook),
b.cross_attn.key.register_forward_hook(kv_hook) b.attn.value.register_forward_hook(kv_hook),
b.cross_attn.value.register_forward_hook(kv_hook) b.cross_attn.key.register_forward_hook(kv_hook),
b.cross_attn.value.register_forward_hook(kv_hook),
]
self.l_hooks.extend(hooks)
self.align_source = {} self.align_source = {}
self.num_align_heads = 0 self.num_align_heads = 0
@@ -95,14 +129,6 @@ class PaddedAlignAttWhisper:
self.num_align_heads += 1 self.num_align_heads += 1
# init tokens (mandatory prompt)
self.initial_tokens = torch.tensor(
self.tokenizer.sot_sequence_including_notimestamps,
dtype=torch.long,
device=self.model.device).unsqueeze(0)
self.initial_token_length = self.initial_tokens.shape[1]
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
# tokens to be suppressed from decoding, to prevent hallucinations # tokens to be suppressed from decoding, to prevent hallucinations
suppress_tokens = [ suppress_tokens = [
self.tokenizer.transcribe, self.tokenizer.transcribe,
@@ -121,6 +147,18 @@ class PaddedAlignAttWhisper:
self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None) self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None)
# blank tokens are suppresed for new segments near the line 334 # blank tokens are suppresed for new segments near the line 334
# it's going to be regenerated after lang id
self.segments = []
self.init_tokens()
self.last_attend_frame = -self.cfg.rewind_threshold
self.cumulative_time_offset = 0.0
if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
else:
self.max_context_tokens = self.cfg.max_context_tokens
self.init_context()
# decoder type: greedy or beam # decoder type: greedy or beam
if cfg.decoder_type == "greedy": if cfg.decoder_type == "greedy":
@@ -134,17 +172,27 @@ class PaddedAlignAttWhisper:
self.inference.kv_cache = self.kv_cache self.inference.kv_cache = self.kv_cache
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size) self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
def remove_hooks(self):
for hook in self.l_hooks:
hook.remove()
# init state def warmup(self, audio):
self.segments = [] try:
self.tokens = [self.initial_tokens] self.insert_audio(audio)
self.last_attend_frame = -self.cfg.rewind_threshold self.infer(is_last=True)
self.refresh_segment(complete=True)
logger.info("Model warmed up successfully")
except Exception as e:
logger.exception(f"Model warmup failed: {e}")
if self.cfg.max_context_tokens is None: def create_tokenizer(self, language=None):
self.max_context_tokens = self.max_text_len self.tokenizer = tokenizer.get_tokenizer(
else: multilingual=self.tokenizer_is_multilingual,
self.max_context_tokens = self.cfg.max_context_tokens language=language,
self.init_context() num_languages=self.model.num_languages,
task=self.decode_options.task
)
def init_context(self): def init_context(self):
kw = {'tokenizer': self.tokenizer, kw = {'tokenizer': self.tokenizer,
@@ -156,6 +204,19 @@ class PaddedAlignAttWhisper:
if self.cfg.init_prompt is not None: if self.cfg.init_prompt is not None:
self.context.text += self.cfg.init_prompt self.context.text += self.cfg.init_prompt
def init_tokens(self):
logger.debug(f"init tokens, {len(self.segments)}")
# init tokens (mandatory prompt)
self.initial_tokens = torch.tensor(
self.tokenizer.sot_sequence_including_notimestamps,
dtype=torch.long,
device=self.model.device).unsqueeze(0)
self.initial_token_length = self.initial_tokens.shape[1]
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
# self.segments = []
logger.debug(f"init tokens after, {len(self.segments)}")
self.tokens = [self.initial_tokens]
def trim_context(self): def trim_context(self):
logger.info("Trimming context") logger.info("Trimming context")
c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids) c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids)
@@ -191,15 +252,20 @@ class PaddedAlignAttWhisper:
def refresh_segment(self, complete=False): def refresh_segment(self, complete=False):
logger.debug("Refreshing segment") logger.debug("Refreshing segment:")
self.tokens = [self.initial_tokens] self.init_tokens()
self.last_attend_frame = -self.cfg.rewind_threshold self.last_attend_frame = -self.cfg.rewind_threshold
self.detected_language = None
self.cumulative_time_offset = 0.0
self.init_context() self.init_context()
logger.debug(f"Context: {self.context}") logger.debug(f"Context: {self.context}")
if not complete and len(self.segments) > 2: if not complete and len(self.segments) > 2:
logger.debug("keeping last two segments because they are and it is not complete.")
self.segments = self.segments[-2:] self.segments = self.segments[-2:]
else: else:
logger.debug("removing all segments.")
self.segments = [] self.segments = []
self.log_segments += 1
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor): def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
@@ -208,8 +274,6 @@ class PaddedAlignAttWhisper:
return fire_at_boundary(chunked_encoder_feature, self.CIFLinear) return fire_at_boundary(chunked_encoder_feature, self.CIFLinear)
def _current_tokens(self): def _current_tokens(self):
toks = self.tokens toks = self.tokens
@@ -256,16 +320,60 @@ class PaddedAlignAttWhisper:
removed_len = 0 removed_len = 0
# len of audio is bigger than buffer_len. Going to remove the first segment # len of audio is bigger than buffer_len. Going to remove the first segment
segments_len = self.segments_len() segments_len = self.segments_len()
while segments_len > self.cfg.audio_max_len: while len(self.segments) > 1 and segments_len > self.cfg.audio_max_len:
removed_len = self.segments[0].shape[0] / 16000 removed_len = self.segments[0].shape[0] / 16000
segments_len -= removed_len segments_len -= removed_len
self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len) self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
self.cumulative_time_offset += removed_len # Track cumulative time removed
self.segments = self.segments[1:] self.segments = self.segments[1:]
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}") logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s")
self.context.append_token_ids(self.tokens[1][0,:]) if len(self.tokens) > 1:
self.tokens = [self.initial_tokens] + self.tokens[2:] self.context.append_token_ids(self.tokens[1][0,:])
self.tokens = [self.initial_tokens] + self.tokens[2:]
return removed_len return removed_len
def _clean_cache(self):
'''clean the cache that stores the attention matrices and kv_cache.
It must be called every time after generation with the model.'''
# cleaning cache
self.dec_attns = []
self.kv_cache = {}
if self.decoder_type == "beam":
self.inference.kv_cache = self.kv_cache
self.token_decoder.reset()
@torch.no_grad()
def lang_id(self, encoder_features):
"""Language detection from encoder features.
This code is trimmed and copy-pasted from whisper.decoding.detect_language .
"""
# forward pass using a single token, startoftranscript
n_audio = encoder_features.shape[0]
x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1]
logits = self.model.logits(x, encoder_features)[:, 0]
# collect detected languages; suppress all non-language tokens
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
mask[list(self.tokenizer.all_language_tokens)] = False
logits[:, mask] = -np.inf
language_tokens = logits.argmax(dim=-1)
language_token_probs = logits.softmax(dim=-1).cpu()
language_probs = [
{
c: language_token_probs[i, j].item()
for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
}
for i in range(n_audio)
]
single = encoder_features.ndim == 2
if single:
language_tokens = language_tokens[0]
language_probs = language_probs[0]
self._clean_cache()
return language_tokens, language_probs
### transcription / translation ### transcription / translation
@@ -273,9 +381,12 @@ class PaddedAlignAttWhisper:
def infer(self, is_last=False): def infer(self, is_last=False):
new_segment = True new_segment = True
if len(self.segments) == 0: if len(self.segments) == 0:
return [] logger.debug("No segments, nothing to do")
return [], {}
if not self._apply_minseglen(): if not self._apply_minseglen():
return [] logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
input_segments = torch.cat(self.segments, dim=0)
return [], {}
# input_segments is concatenation of audio, it's one array # input_segments is concatenation of audio, it's one array
if len(self.segments) > 1: if len(self.segments) > 1:
@@ -283,30 +394,67 @@ class PaddedAlignAttWhisper:
else: else:
input_segments = self.segments[0] input_segments = self.segments[0]
# NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder
beg_encode = time()
if self.mlx_encoder:
mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
encoder_feature = torch.as_tensor(mlx_encoder_feature)
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
elif self.fw_encoder:
audio_length_seconds = len(input_segments) / 16000
content_mel_len = int(audio_length_seconds * 100)//2
mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
encoder_feature_ctranslate = self.fw_encoder.encode(mel)
if self.device == 'cpu': #it seems that on gpu, passing StorageView to torch.as_tensor fails and wrapping in the array works
encoder_feature_ctranslate = np.array(encoder_feature_ctranslate)
try:
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
except TypeError: # Normally the cpu condition should prevent having exceptions, but just in case:
encoder_feature = torch.as_tensor(np.array(encoder_feature_ctranslate), device=self.device)
else:
# mel + padding to 30s
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
device=self.device).unsqueeze(0)
# trim to 3000
mel = pad_or_trim(mel_padded, N_FRAMES)
# the len of actual audio
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
encoder_feature = self.model.encoder(mel)
end_encode = time()
# print('Encoder duration:', end_encode-beg_encode)
# logger.debug(f"Encoder feature shape: {encoder_feature.shape}")
# if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
# logger.debug("mel ")
if self.cfg.language == "auto" and self.detected_language is None:
language_tokens, language_probs = self.lang_id(encoder_feature)
logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
logger.info(f"Detected language: {top_lan} with p={p:.4f}")
#self.tokenizer.language = top_lan
#self.tokenizer.__post_init__()
self.create_tokenizer(top_lan)
self.detected_language = top_lan
self.init_tokens()
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
self.trim_context() self.trim_context()
current_tokens = self._current_tokens() current_tokens = self._current_tokens()
#
# mel + padding to 30s
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
device=self.model.device).unsqueeze(0)
# trim to 3000
mel = pad_or_trim(mel_padded, N_FRAMES)
# the len of actual audio
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
encoder_feature = self.model.encoder(mel)
sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
completed = False
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :]) fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
####################### Decoding loop ####################### Decoding loop
logger.info("Decoding loop starts\n") logger.info("Decoding loop starts\n")
sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device)
completed = False
attn_of_alignment_heads = None attn_of_alignment_heads = None
miost_attended_frame = None most_attended_frame = None
token_len_before_decoding = current_tokens.shape[1] token_len_before_decoding = current_tokens.shape[1]
@@ -412,7 +560,13 @@ class PaddedAlignAttWhisper:
# for each beam, the most attended frame is: # for each beam, the most attended frame is:
most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1) most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
generation_progress_loop.append(("most_attended_frames",most_attended_frames.clone().tolist())) generation_progress_loop.append(("most_attended_frames",most_attended_frames.clone().tolist()))
# Calculate absolute timestamps accounting for cumulative offset
absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()]
generation_progress_loop.append(("absolute_timestamps", absolute_timestamps))
logger.debug(str(most_attended_frames.tolist()) + " most att frames") logger.debug(str(most_attended_frames.tolist()) + " most att frames")
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)")
most_attended_frame = most_attended_frames[0].item() most_attended_frame = most_attended_frames[0].item()
@@ -507,7 +661,7 @@ class PaddedAlignAttWhisper:
### new hypothesis ### new hypothesis
logger.debug(f"new_hypothesis: {new_hypothesis}") logger.debug(f"new_hypothesis: {new_hypothesis}")
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to( new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
device=self.model.device, device=self.device,
) )
self.tokens.append(new_tokens) self.tokens.append(new_tokens)
# TODO: test if this is redundant or not # TODO: test if this is redundant or not
@@ -515,11 +669,6 @@ class PaddedAlignAttWhisper:
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}") logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
# cleaning cache self._clean_cache()
self.dec_attns = []
self.kv_cache = {}
if self.decoder_type == "beam":
self.inference.kv_cache = self.kv_cache
self.token_decoder.reset()
return new_hypothesis, generation return new_hypothesis, generation

View File

@@ -105,6 +105,7 @@ def load_model(
device: Optional[Union[str, torch.device]] = None, device: Optional[Union[str, torch.device]] = None,
download_root: str = None, download_root: str = None,
in_memory: bool = False, in_memory: bool = False,
decoder_only=False
) -> Whisper: ) -> Whisper:
""" """
Load a Whisper ASR model Load a Whisper ASR model
@@ -151,7 +152,14 @@ def load_model(
del checkpoint_file del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"]) dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims) model = Whisper(dims, decoder_only=decoder_only)
if decoder_only:
checkpoint["model_state_dict"] = {
k: v for k, v in checkpoint["model_state_dict"].items()
if 'encoder' not in k
}
model.load_state_dict(checkpoint["model_state_dict"]) model.load_state_dict(checkpoint["model_state_dict"])
if alignment_heads is not None: if alignment_heads is not None:

View File

@@ -32,7 +32,9 @@ def detect_language(
list of dictionaries containing the probability distribution over all languages. list of dictionaries containing the probability distribution over all languages.
""" """
if tokenizer is None: if tokenizer is None:
tokenizer = get_tokenizer(model.is_multilingual) tokenizer = get_tokenizer(
model.is_multilingual, num_languages=model.num_languages
)
if ( if (
tokenizer.language is None tokenizer.language is None
or tokenizer.language_token not in tokenizer.sot_sequence or tokenizer.language_token not in tokenizer.sot_sequence
@@ -111,9 +113,6 @@ class DecodingOptions:
# implementation details # implementation details
fp16: bool = True # use fp16 for most of the calculation fp16: bool = True # use fp16 for most of the calculation
# streaming
add_sot: Optional[bool] = True
@dataclass(frozen=True) @dataclass(frozen=True)
class DecodingResult: class DecodingResult:
@@ -513,19 +512,17 @@ class DecodingTask:
logit_filters: List[LogitFilter] logit_filters: List[LogitFilter]
def __init__(self, model: "Whisper", options: DecodingOptions): def __init__(self, model: "Whisper", options: DecodingOptions):
self.options: DecodingOptions = self._verify_options(options) self.model = model
if self.options.fp16:
self.model = model.half()
else:
self.model = model
language = options.language or "en" language = options.language or "en"
tokenizer = get_tokenizer( tokenizer = get_tokenizer(
model.is_multilingual, language=language, task=options.task model.is_multilingual,
num_languages=model.num_languages,
language=language,
task=options.task,
) )
self.tokenizer: Tokenizer = tokenizer self.tokenizer: Tokenizer = tokenizer
self.options: DecodingOptions = self._verify_options(options)
# print(self.options)
self.n_group: int = options.beam_size or options.best_of or 1 self.n_group: int = options.beam_size or options.best_of or 1
self.n_ctx: int = model.dims.n_text_ctx self.n_ctx: int = model.dims.n_text_ctx
@@ -589,7 +586,7 @@ class DecodingTask:
def _get_initial_tokens(self) -> Tuple[int]: def _get_initial_tokens(self) -> Tuple[int]:
tokens = list(self.sot_sequence) tokens = list(self.sot_sequence)
# print("prefix", prefix)
if prefix := self.options.prefix: if prefix := self.options.prefix:
prefix_tokens = ( prefix_tokens = (
self.tokenizer.encode(" " + prefix.strip()) self.tokenizer.encode(" " + prefix.strip())
@@ -607,15 +604,12 @@ class DecodingTask:
if isinstance(prompt, str) if isinstance(prompt, str)
else prompt else prompt
) )
# if self.options.add_sot:
tokens = ( tokens = (
[self.tokenizer.sot_prev] [self.tokenizer.sot_prev]
+ prompt_tokens[-(self.n_ctx // 2 - 1) :] + prompt_tokens[-(self.n_ctx // 2 - 1) :]
+ tokens + tokens
) )
#else:
# tokens = ([self.tokenizer.sot_prev] + tokens + prompt_tokens[-(self.n_ctx // 2 - 1) :])
# print("return", tokens)
return tuple(tokens) return tuple(tokens)
def _get_suppress_tokens(self) -> Tuple[int]: def _get_suppress_tokens(self) -> Tuple[int]:
@@ -663,7 +657,7 @@ class DecodingTask:
if audio_features.dtype != ( if audio_features.dtype != (
torch.float16 if self.options.fp16 else torch.float32 torch.float16 if self.options.fp16 else torch.float32
): ):
raise TypeError( return TypeError(
f"audio_features has an incorrect dtype: {audio_features.dtype}" f"audio_features has an incorrect dtype: {audio_features.dtype}"
) )
@@ -689,10 +683,9 @@ class DecodingTask:
no_speech_probs = [np.nan] * n_batch no_speech_probs = [np.nan] * n_batch
try: try:
for i in range(self.sample_len): # 最多循环448次 for i in range(self.sample_len):
# print("in decode main loop", i , tokens[0].tolist())
logits = self.inference.logits(tokens, audio_features) logits = self.inference.logits(tokens, audio_features)
# print(logits)
if ( if (
i == 0 and self.tokenizer.no_speech is not None i == 0 and self.tokenizer.no_speech is not None
): # save no_speech_probs ): # save no_speech_probs
@@ -724,7 +717,7 @@ class DecodingTask:
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1) tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
# print("initial_tokens", self.initial_tokens)
# detect language if requested, overwriting the language token # detect language if requested, overwriting the language token
languages, language_probs = self._detect_language(audio_features, tokens) languages, language_probs = self._detect_language(audio_features, tokens)
if self.options.task == "lang_id": if self.options.task == "lang_id":

View File

@@ -13,7 +13,6 @@ from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function from .transcribe import transcribe as transcribe_function
try: try:
from torch.nn.functional import scaled_dot_product_attention from torch.nn.functional import scaled_dot_product_attention
@@ -37,26 +36,27 @@ class ModelDimensions:
n_text_layer: int n_text_layer: int
# class LayerNorm(nn.LayerNorm): class LayerNorm(nn.LayerNorm):
# def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
# return super().forward(x.float()).type(x.dtype) return super().forward(x.float()).type(x.dtype)
# class Linear(nn.Linear):
# def forward(self, x: Tensor) -> Tensor:
# return F.linear(
# x,
# self.weight.to(x.dtype),
# None if self.bias is None else self.bias.to(x.dtype),
# )
# class Conv1d(nn.Conv1d): class Linear(nn.Linear):
# def _conv_forward( def forward(self, x: Tensor) -> Tensor:
# self, x: Tensor, weight: Tensor, bias: Optional[Tensor] return F.linear(
# ) -> Tensor: x,
# return super()._conv_forward( self.weight.to(x.dtype),
# x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) None if self.bias is None else self.bias.to(x.dtype),
# ) )
class Conv1d(nn.Conv1d):
def _conv_forward(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
) -> Tensor:
return super()._conv_forward(
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
)
def sinusoids(length, channels, max_timescale=10000): def sinusoids(length, channels, max_timescale=10000):
@@ -67,21 +67,30 @@ def sinusoids(length, channels, max_timescale=10000):
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
import sys ## this is mine, for debugging
@contextmanager
def disable_sdpa():
prev_state = MultiHeadAttention.use_sdpa
try:
MultiHeadAttention.use_sdpa = False
yield
finally:
MultiHeadAttention.use_sdpa = prev_state
class MultiHeadAttention(nn.Module): class MultiHeadAttention(nn.Module):
use_sdpa = False # Disable SDPA to ensure qk is always computed for hooks
use_sdpa = False # disabling: https://github.com/linto-ai/whisper-timestamped/issues/212 def __init__(self, n_state: int, n_head: int, cache_id: str = ""):
def __init__(self, n_state: int, n_head: int, cache_id: str):
super().__init__() super().__init__()
self.n_head = n_head self.n_head = n_head
self.query = nn.Linear(n_state, n_state) self.query = Linear(n_state, n_state)
self.key = nn.Linear(n_state, n_state, bias=False) self.key = Linear(n_state, n_state, bias=False)
self.key.cache_id = f"{cache_id}_key" self.value = Linear(n_state, n_state)
self.value = nn.Linear(n_state, n_state) self.out = Linear(n_state, n_state)
self.value.cache_id = f"{cache_id}_value"
self.out = nn.Linear(n_state, n_state)
self.cache_id = cache_id self.cache_id = cache_id
self.key.cache_id = f"{cache_id}_key"
self.value.cache_id = f"{cache_id}_value"
def forward( def forward(
self, self,
@@ -90,45 +99,21 @@ class MultiHeadAttention(nn.Module):
mask: Optional[Tensor] = None, mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None, kv_cache: Optional[dict] = None,
): ):
#print("MultiHeadAttention forward",file=sys.stderr)
q = self.query(x) q = self.query(x)
# print(q.shape, x is None, mask is None, list(kv_cache.keys()) if kv_cache is not None else None, file=sys.stderr)
# print(mask, kv_cache, xa, file=sys.stderr)
if kv_cache is None or xa is None or self.key.cache_id not in kv_cache: if kv_cache is None or xa is None or self.key not in kv_cache:
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
# otherwise, perform key/value projections for self- or cross-attention as usual.
k = self.key(x if xa is None else xa) k = self.key(x if xa is None else xa)
v = self.value(x if xa is None else xa) v = self.value(x if xa is None else xa)
# print(self.key.cache_id, "cache miss") # , kv_cache is None, xa is None, self.key.cache_id not in kv_cache if kv_cache is not None else None, k.shape, x.shape)
# if kv_cache is not None:
# print(kv_cache.keys())
else: else:
# print(self.key.cache_id, "cache hit") #, kv_cache is None, xa is None, self.key.cache_id not in kv_cache) # for cross-attention, calculate keys and values once and reuse in subsequent calls.
# if kv_cache is not None: k = kv_cache[self.key]
# print(kv_cache.keys()) v = kv_cache[self.value]
k = kv_cache[self.key.cache_id]
v = kv_cache[self.value.cache_id]
# print(self.key.cache_id, "qkv attention", q.shape, k.shape, v.shape)
wv, qk = self.qkv_attention(q, k, v, mask) wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), qk return self.out(wv), qk
# def qkv_attention(
# self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
# ):
# n_batch, n_ctx, n_state = q.shape
# scale = (n_state // self.n_head) ** -0.25
# q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
# k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
# v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
# qk = q @ k
# if mask is not None:
# qk = qk + mask[:n_ctx, :n_ctx]
# # qk = qk.float()
# w = F.softmax(qk, dim=-1) # .to(q.dtype)
# return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
def qkv_attention( def qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -158,21 +143,22 @@ class MultiHeadAttention(nn.Module):
class ResidualAttentionBlock(nn.Module): class ResidualAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int, cache_id: str="", cross_attention: bool = False): def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, cache_id: str = ""):
super().__init__() super().__init__()
self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn") self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn")
self.attn_ln = nn.LayerNorm(n_state) self.attn_ln = LayerNorm(n_state)
self.cross_attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None self.cross_attn = (
MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None )
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4 n_mlp = n_state * 4
self.mlp = nn.Sequential( self.mlp = nn.Sequential(
nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state) Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
) )
self.mlp_ln = nn.LayerNorm(n_state) self.mlp_ln = LayerNorm(n_state)
def forward( def forward(
self, self,
@@ -181,8 +167,6 @@ class ResidualAttentionBlock(nn.Module):
mask: Optional[Tensor] = None, mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None, kv_cache: Optional[dict] = None,
): ):
# print("ResidualAttentionBlock forward",file=sys.stderr)
# print(x.shape, file=sys.stderr)
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
if self.cross_attn: if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
@@ -195,44 +179,32 @@ class AudioEncoder(nn.Module):
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
): ):
super().__init__() super().__init__()
self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1) self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head, cache_id=f"enc_layer{i}") for i in range(n_layer)] [ResidualAttentionBlock(n_state, n_head, cache_id=f"enc_layer{i}") for i in range(n_layer)]
) )
self.ln_post = nn.LayerNorm(n_state) self.ln_post = LayerNorm(n_state)
def forward(self, x: Tensor, return_layer_results: bool=False): def forward(self, x: Tensor):
""" """
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio the mel spectrogram of the audio
""" """
x = F.gelu(self.conv1(x)) x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x)) x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1) # BDT -> BTD x = x.permute(0, 2, 1)
# 两层卷积2倍降采样 assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
# 最终剩下1500帧 x = (x + self.positional_embedding).to(x.dtype)
x = (x + self.positional_embedding[:x.shape[1], :]) #.to(x.dtype)
layer_results = []
i = 0
for block in self.blocks: for block in self.blocks:
# print(f"encoder layer {i}")
x = block(x) x = block(x)
layer_results.append(x)
i += 1
x = self.ln_post(x) x = self.ln_post(x)
return x
if return_layer_results:
return x, layer_results
else:
return x
class TextDecoder(nn.Module): class TextDecoder(nn.Module):
@@ -250,7 +222,7 @@ class TextDecoder(nn.Module):
for i in range(n_layer) for i in range(n_layer)
] ]
) )
self.ln = nn.LayerNorm(n_state) self.ln = LayerNorm(n_state)
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False) self.register_buffer("mask", mask, persistent=False)
@@ -262,37 +234,37 @@ class TextDecoder(nn.Module):
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
the encoded audio features to be attended on the encoded audio features to be attended on
""" """
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = ( x = (
self.token_embedding(x) self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]] + self.positional_embedding[offset : offset + x.shape[-1]]
) )
# x = x.to(xa.dtype) x = x.to(xa.dtype)
i = 0
for block in self.blocks: for block in self.blocks:
# print(f"decoder layer {i}")
x = block(x, xa, mask=self.mask, kv_cache=kv_cache) x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
i += 1
x = self.ln(x) x = self.ln(x)
logits = x @ torch.transpose(self.token_embedding.weight, 0, 1) logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()
return logits return logits
class Whisper(nn.Module): class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions): def __init__(self, dims: ModelDimensions, decoder_only: bool = False):
super().__init__() super().__init__()
self.dims = dims self.dims = dims
self.encoder = AudioEncoder(
self.dims.n_mels, if not decoder_only:
self.dims.n_audio_ctx, self.encoder = AudioEncoder(
self.dims.n_audio_state, self.dims.n_mels,
self.dims.n_audio_head, self.dims.n_audio_ctx,
self.dims.n_audio_layer, self.dims.n_audio_state,
) self.dims.n_audio_head,
self.dims.n_audio_layer,
)
self.decoder = TextDecoder( self.decoder = TextDecoder(
self.dims.n_vocab, self.dims.n_vocab,
self.dims.n_text_ctx, self.dims.n_text_ctx,
@@ -300,7 +272,8 @@ class Whisper(nn.Module):
self.dims.n_text_head, self.dims.n_text_head,
self.dims.n_text_layer, self.dims.n_text_layer,
) )
# use the last half layers for alignment by default; see `set_alignment_heads()` below # use the last half among the decoder layers for time alignment by default;
# to use a specific set of heads, see `set_alignment_heads()` below.
all_heads = torch.zeros( all_heads = torch.zeros(
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
) )
@@ -320,15 +293,11 @@ class Whisper(nn.Module):
return self.encoder(mel) return self.encoder(mel)
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
# tokens = tokens.to(self.decoder.ln.weight.dtype)
# audio_features = audio_features.to(self.decoder.ln.weight.dtype)
return self.decoder(tokens, audio_features) return self.decoder(tokens, audio_features)
def forward( def forward(
self, mel: torch.Tensor, tokens: torch.Tensor self, mel: torch.Tensor, tokens: torch.Tensor
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
# mel = mel.to(self.decoder.ln.weight.dtype)
# tokens = tokens.to(self.decoder.ln.weight.dtype)
return self.decoder(tokens, self.encoder(mel)) return self.decoder(tokens, self.encoder(mel))
@property @property
@@ -343,7 +312,6 @@ class Whisper(nn.Module):
def num_languages(self): def num_languages(self):
return self.dims.n_vocab - 51765 - int(self.is_multilingual) return self.dims.n_vocab - 51765 - int(self.is_multilingual)
# 为decoder加入缓存机制每次推理时保存上次的k和v下次推理无需重新计算
def install_kv_cache_hooks(self, cache: Optional[dict] = None): def install_kv_cache_hooks(self, cache: Optional[dict] = None):
""" """
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value

View File

@@ -30,15 +30,19 @@ def remove_symbols_and_diacritics(s: str, keep=""):
and drop any diacritics (category 'Mn' and some manual mappings) and drop any diacritics (category 'Mn' and some manual mappings)
""" """
return "".join( return "".join(
c (
if c in keep c
else ADDITIONAL_DIACRITICS[c] if c in keep
if c in ADDITIONAL_DIACRITICS else (
else "" ADDITIONAL_DIACRITICS[c]
if unicodedata.category(c) == "Mn" if c in ADDITIONAL_DIACRITICS
else " " else (
if unicodedata.category(c)[0] in "MSP" ""
else c if unicodedata.category(c) == "Mn"
else " " if unicodedata.category(c)[0] in "MSP" else c
)
)
)
for c in unicodedata.normalize("NFKD", s) for c in unicodedata.normalize("NFKD", s)
) )

File diff suppressed because it is too large Load Diff

View File

@@ -56,9 +56,8 @@ def median_filter(x: torch.Tensor, filter_width: int):
@numba.jit(nopython=True) @numba.jit(nopython=True)
def backtrace(trace: np.ndarray): def backtrace(trace: np.ndarray):
i = trace.shape[0] - 1 # trace: (N+1, M+1), i=N i = trace.shape[0] - 1
j = trace.shape[1] - 1 # j=M j = trace.shape[1] - 1
# 边界点其实无意义?
trace[0, :] = 2 trace[0, :] = 2
trace[:, 0] = 1 trace[:, 0] = 1
@@ -83,8 +82,8 @@ def backtrace(trace: np.ndarray):
@numba.jit(nopython=True, parallel=True) @numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray): def dtw_cpu(x: np.ndarray):
N, M = x.shape N, M = x.shape
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf # cost: x[0, 0]到x[i-1, j-1]的最小代价 cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
trace = -np.ones((N + 1, M + 1), dtype=np.float32) # trace: trace = -np.ones((N + 1, M + 1), dtype=np.float32)
cost[0, 0] = 0 cost[0, 0] = 0
for j in range(1, M + 1): for j in range(1, M + 1):
@@ -118,7 +117,7 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
x_skew = x_skew.T.contiguous() x_skew = x_skew.T.contiguous()
cost = torch.ones(N + M + 2, M + 2) * np.inf cost = torch.ones(N + M + 2, M + 2) * np.inf
cost[0, 0] = 0 cost[0, 0] = 0
cost = cost.cuda() cost = cost.to(x.device)
trace = torch.zeros_like(cost, dtype=torch.int32) trace = torch.zeros_like(cost, dtype=torch.int32)
dtw_kernel[(1,)]( dtw_kernel[(1,)](
@@ -192,21 +191,19 @@ def find_alignment(
for i, block in enumerate(model.decoder.blocks) for i, block in enumerate(model.decoder.blocks)
] ]
# 进行前传获得token概率 from .model import disable_sdpa
with torch.no_grad():
with torch.no_grad(), disable_sdpa():
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0] logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot] sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
token_probs = sampled_logits.softmax(dim=-1) token_probs = sampled_logits.softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens] text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
text_token_probs = text_token_probs.tolist() text_token_probs = text_token_probs.tolist()
# 移除钩子
for hook in hooks: for hook in hooks:
hook.remove() hook.remove()
# heads * tokens * frames # heads * tokens * frames
# print(model.alignment_heads)
# exit(0)
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T]) weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
weights = weights[:, :, : num_frames // 2] weights = weights[:, :, : num_frames // 2]
weights = (weights * qk_scale).softmax(dim=-1) weights = (weights * qk_scale).softmax(dim=-1)
@@ -215,18 +212,9 @@ def find_alignment(
weights = median_filter(weights, medfilt_width) weights = median_filter(weights, medfilt_width)
matrix = weights.mean(axis=0) matrix = weights.mean(axis=0)
print("attention", matrix.shape, matrix[:5, :5])
matrix = matrix[len(tokenizer.sot_sequence) : -1] matrix = matrix[len(tokenizer.sot_sequence) : -1]
print("attention", matrix.shape, matrix[:5, :5])
text_indices, time_indices = dtw(-matrix) text_indices, time_indices = dtw(-matrix)
print("num_frames", num_frames)
print("attention", matrix.shape, matrix[:5, :5])
print("text_indices", text_indices)
print("time", time_indices)
print("text_tokens", text_tokens, tokenizer.decode(text_tokens), len(text_tokens))
print("eot", tokenizer.eot)
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot]) words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
if len(word_tokens) <= 1: if len(word_tokens) <= 1:
# return on eot only # return on eot only
@@ -238,9 +226,7 @@ def find_alignment(
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
# print("jumps", jumps, jumps.shape)
jump_times = time_indices[jumps] / TOKENS_PER_SECOND jump_times = time_indices[jumps] / TOKENS_PER_SECOND
# print("jump_times", jump_times)
start_times = jump_times[word_boundaries[:-1]] start_times = jump_times[word_boundaries[:-1]]
end_times = jump_times[word_boundaries[1:]] end_times = jump_times[word_boundaries[1:]]
word_probabilities = [ word_probabilities = [
@@ -315,6 +301,7 @@ def add_word_timestamps(
word_durations = np.array([t.end - t.start for t in alignment]) word_durations = np.array([t.end - t.start for t in alignment])
word_durations = word_durations[word_durations.nonzero()] word_durations = word_durations[word_durations.nonzero()]
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0 median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
median_duration = min(0.7, float(median_duration))
max_duration = median_duration * 2 max_duration = median_duration * 2
# hack: truncate long words at sentence boundaries. # hack: truncate long words at sentence boundaries.

View File

@@ -1,501 +0,0 @@
import argparse
import os
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
import numpy as np
import torch
import tqdm
from whisper.audio import (
FRAMES_PER_SECOND,
HOP_LENGTH,
N_FRAMES,
N_SAMPLES,
SAMPLE_RATE,
log_mel_spectrogram,
pad_or_trim,
)
from whisper.decoding import DecodingOptions, DecodingResult
from whisper.timing import add_word_timestamps
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from whisper.utils import (
exact_div,
format_timestamp,
get_writer,
make_safe,
optional_float,
optional_int,
str2bool,
)
if TYPE_CHECKING:
from whisper.model import Whisper
def transcribe(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor],
*,
verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
**decode_options,
):
"""
Transcribe an audio file using Whisper
Parameters
----------
model: Whisper
The Whisper model instance
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
verbose: bool
Whether to display the text being decoded to the console. If True, displays all the details,
If False, displays minimal details. If None, does not display anything
temperature: Union[float, Tuple[float, ...]]
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
compression_ratio_threshold: float
If the gzip compression ratio is above this value, treat as failed
logprob_threshold: float
If the average log probability over sampled tokens is below this value, treat as failed
no_speech_threshold: float
If the no_speech probability is higher than this value AND the average log probability
over sampled tokens is below `logprob_threshold`, consider the segment as silent
condition_on_previous_text: bool
if True, the previous output of the model is provided as a prompt for the next window;
disabling may make the text inconsistent across windows, but the model becomes less prone to
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
word_timestamps: bool
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
and include the timestamps for each word in each segment.
prepend_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the next word
append_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the previous word
initial_prompt: Optional[str]
Optional text to provide as a prompt for the first window. This can be used to provide, or
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
to make it more likely to predict those word correctly.
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
# print("HACKED")
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
if model.device == torch.device("cpu"):
if torch.cuda.is_available():
warnings.warn("Performing inference on CPU when CUDA is available")
if dtype == torch.float16:
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
dtype = torch.float32
if dtype == torch.float32:
decode_options["fp16"] = False
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, padding=0) # log_mel_spectrogram(audio, padding=N_SAMPLES) # 添加16000*30 = 480000个点
# mel = pad_or_trim(mel, 3000)
content_frames = mel.shape[-1] # - N_FRAMES # 对应3000帧真正有内容的是去掉尾部3000的那些数据
# 判断语种
if decode_options.get("language", None) is None:
# 如果是单语种模型,直接设成英文
if not model.is_multilingual:
decode_options["language"] = "en"
# 否则需要前传一次
else:
if verbose:
print(
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
)
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
# print(mel_segment.shape)
_, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
)
language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe")
# 输出编码器
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
# 词级别时间戳
if word_timestamps and task == "translate":
warnings.warn("Word-level timestamps on translations may not be reliable.")
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
temperatures = (
[temperature] if isinstance(temperature, (int, float)) else temperature
)
decode_result = None
for t in temperatures:
kwargs = {**decode_options}
if t > 0:
# disable beam_size and patience when t > 0
kwargs.pop("beam_size", None)
kwargs.pop("patience", None)
else:
# disable best_of when t == 0
kwargs.pop("best_of", None)
options = DecodingOptions(**kwargs, temperature=t)
decode_result = model.decode(segment, options)
# 几种解码可能失败的情况。这些情况下会重复解码
# 感觉是一种KnowHow的东西 或许ChatGPT里有不少这种trick
needs_fallback = False
if (
compression_ratio_threshold is not None
and decode_result.compression_ratio > compression_ratio_threshold
):
needs_fallback = True # too repetitive
if (
logprob_threshold is not None
and decode_result.avg_logprob < logprob_threshold
):
needs_fallback = True # average log probability is too low
if (
no_speech_threshold is not None
and decode_result.no_speech_prob > no_speech_threshold
):
needs_fallback = False # silence
if not needs_fallback:
break
# print("decode with temperature {} compress rate {:.3f}/{:.3f}, log_prob {:.3f}/{:.3f}, {:.3f}/{:.3f}".format(
# t,
# decode_result.compression_ratio, compression_ratio_threshold,
# -decode_result.avg_logprob, -logprob_threshold,
# decode_result.no_speech_prob, no_speech_threshold
# ))
return decode_result
seek = 0
input_stride = exact_div(
N_FRAMES, model.dims.n_audio_ctx
) # mel frames per output token: 2
# 这里output token指的应该是CNN输出的那个东西
time_precision = (
input_stride * HOP_LENGTH / SAMPLE_RATE
) # time per output token: 0.02 (seconds)
all_tokens = []
all_segments = []
prompt_reset_since = 0
if initial_prompt is not None:
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt_tokens)
else:
initial_prompt_tokens = []
def new_segment(
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
):
tokens = tokens.tolist()
text_tokens = [token for token in tokens if token < tokenizer.eot]
return {
"seek": seek,
"start": start,
"end": end,
"text": tokenizer.decode(text_tokens),
"tokens": tokens,
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_speech_prob": result.no_speech_prob,
}
# show the progress bar when verbose is False (if True, transcribed text will be printed)
with tqdm.tqdm(
total=content_frames, unit="frames", disable=verbose is not False
) as pbar:
last_speech_timestamp = 0.0
while seek < content_frames: # seek标记mel频谱当前帧的位置 直接跳过Padding上的部分
# print("seek segments", seek, content_frames)
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) # 本片段的开始时间
# mel_segment = mel[:, seek : seek + N_FRAMES] # 获得当前片段的数据
mel_segment = mel[:, seek:]
segment_size = min(N_FRAMES, content_frames - seek) # segment_size: 排除padding的真的长度。content_frames有内容的段的真正长度 如果不够N_FRAMES的话就会截断
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE # 当前片段的时长
mel_segment = mel_segment.to(model.device).to(dtype) # pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) # 补到mel_segment帧
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(mel_segment)
tokens = torch.tensor(result.tokens)
# 跳过静音部分
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > no_speech_threshold
if (
logprob_threshold is not None
and result.avg_logprob > logprob_threshold
):
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
if should_skip:
seek += segment_size # fast-forward to the next segment boundary
continue
previous_seek = seek
current_segments = []
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) # timestamp begin是<|0.00|>的tokenbos比文字token大eos的值比bos还大所以是ge
timestamp_tokens[-1] = False
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] # 如果最后是[False,True]:本段里一个句子结束了
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
# torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).
# timestamp_token就是个一维向量吧 那为啥不直接nonzero
# 如果有两个连续的时间戳 这个会是一个一维tensor 是这两个连续时间戳的结尾位置
# 多个的话指向第二个 那如果有三个怎么办?
# 否则是个0维tensor
consecutive.add_(1) # 0维tensor+1还是0维 哪儿找的这些edge cases js是吧
if len(consecutive) > 0:
# if the output contains two consecutive timestamp tokens
slices = consecutive.tolist()
if single_timestamp_ending:
slices.append(len(tokens)) # 把最后一段的结尾也加进去
# print("many sentenses", consecutive)
last_slice = 0
for current_slice in slices:
sliced_tokens = tokens[last_slice:current_slice]
# 看起来语音开始帧、语音结束帧的位置会被编码到start_timestamp中
start_timestamp_pos = (
sliced_tokens[0].item() - tokenizer.timestamp_begin
)
end_timestamp_pos = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin
)
# 获取一个新的语音段
current_segments.append(
new_segment(
start=time_offset + start_timestamp_pos * time_precision,
end=time_offset + end_timestamp_pos * time_precision,
tokens=sliced_tokens,
result=result,
)
)
last_slice = current_slice
if single_timestamp_ending:
# single timestamp at the end means no speech after the last timestamp.
seek += segment_size
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
# 如果语音尚未结束那么seek变为上一个结束的语段的位置
# 换句话说就是针对30s长的chunk的语音设计的
last_timestamp_pos = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_pos * input_stride
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
# print(timestamps)
if (
len(timestamps) > 0
and timestamps[-1].item() != tokenizer.timestamp_begin
):
# no consecutive timestamps but it has a timestamp; use the last one.
# 取最后一个假设要么有一个结束的time stamp要么有一对儿
# 如果里面只有一个开始的timestamp 似乎后面的东西都会被丢掉?
last_timestamp_pos = (
timestamps[-1].item() - tokenizer.timestamp_begin
)
duration = last_timestamp_pos * time_precision
current_segments.append(
new_segment(
start=time_offset,
end=time_offset + duration,
tokens=tokens,
result=result,
)
)
seek += segment_size
# 每个token有自己的时间戳
if word_timestamps:
add_word_timestamps(
segments=current_segments,
model=model,
tokenizer=tokenizer,
mel=mel_segment,
num_frames=segment_size,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
last_speech_timestamp=last_speech_timestamp,
)
word_end_timestamps = [
w["end"] for s in current_segments for w in s["words"]
]
if len(word_end_timestamps) > 0:
last_speech_timestamp = word_end_timestamps[-1]
if not single_timestamp_ending and len(word_end_timestamps) > 0:
seek_shift = round(
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
)
if seek_shift > 0:
seek = previous_seek + seek_shift
if verbose:
for segment in current_segments:
start, end, text = segment["start"], segment["end"], segment["text"]
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
print(make_safe(line))
# if a segment is instantaneous or does not contain text, clear it
for i, segment in enumerate(current_segments):
if segment["start"] == segment["end"] or segment["text"].strip() == "":
segment["text"] = ""
segment["tokens"] = []
segment["words"] = []
# 更新结果
all_segments.extend(
[
{"id": i, **segment}
for i, segment in enumerate(
current_segments, start=len(all_segments)
)
]
)
all_tokens.extend(
[token for segment in current_segments for token in segment["tokens"]]
)
if not condition_on_previous_text or result.temperature > 0.5:
# do not feed the prompt tokens if a high temperature was used
prompt_reset_since = len(all_tokens)
# update progress bar
pbar.update(min(content_frames, seek) - previous_seek)
# print("太长了")
# break
return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
segments=all_segments,
language=language,
)
def cli():
from . import available_models
# fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,!?::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
# fmt: on
args = parser.parse_args().__dict__
model_name: str = args.pop("model")
model_dir: str = args.pop("model_dir")
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
device: str = args.pop("device")
os.makedirs(output_dir, exist_ok=True)
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None:
warnings.warn(
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
)
args["language"] = "en"
temperature = args.pop("temperature")
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
else:
temperature = [temperature]
if (threads := args.pop("threads")) > 0:
torch.set_num_threads(threads)
from . import load_model
model = load_model(model_name, device=device, download_root=model_dir)
writer = get_writer(output_format, output_dir)
word_options = ["highlight_words", "max_line_count", "max_line_width"]
if not args["word_timestamps"]:
for option in word_options:
if args[option]:
parser.error(f"--{option} requires --word_timestamps True")
if args["max_line_count"] and not args["max_line_width"]:
warnings.warn("--max_line_count has no effect without --max_line_width")
writer_args = {arg: args.pop(arg) for arg in word_options}
for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
writer(result, audio_path, writer_args)
if __name__ == "__main__":
cli()

View File

@@ -1,7 +1,8 @@
import argparse import argparse
import os import os
import traceback
import warnings import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@@ -22,6 +23,7 @@ from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import ( from .utils import (
exact_div, exact_div,
format_timestamp, format_timestamp,
get_end,
get_writer, get_writer,
make_safe, make_safe,
optional_float, optional_float,
@@ -44,9 +46,12 @@ def transcribe(
no_speech_threshold: Optional[float] = 0.6, no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True, condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None, initial_prompt: Optional[str] = None,
carry_initial_prompt: bool = False,
word_timestamps: bool = False, word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-", prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、", append_punctuations: str = "\"'.。,!?::”)]}、",
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
**decode_options, **decode_options,
): ):
""" """
@@ -98,15 +103,27 @@ def transcribe(
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
to make it more likely to predict those word correctly. to make it more likely to predict those word correctly.
carry_initial_prompt: bool
If carry_initial_prompt is True, `initial_prompt` is prepended to the prompt of each internal
`decode()` call. If there is not enough context space at the start of the prompt, it is
left-sliced to make space.
decode_options: dict decode_options: dict
Keyword arguments to construct `DecodingOptions` instances Keyword arguments to construct `DecodingOptions` instances
clip_timestamps: Union[str, List[float]]
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
The last end timestamp defaults to the end of the file.
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
when a possible hallucination is detected
Returns Returns
------- -------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None. the spoken language ("language"), which is detected when `decode_options["language"]` is None.
""" """
# print("transcribe")
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32 dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
if model.device == torch.device("cpu"): if model.device == torch.device("cpu"):
if torch.cuda.is_available(): if torch.cuda.is_available():
@@ -119,8 +136,9 @@ def transcribe(
decode_options["fp16"] = False decode_options["fp16"] = False
# Pad 30-seconds of silence to the input audio, for slicing # Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, padding=N_SAMPLES) mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
content_frames = mel.shape[-1] - N_FRAMES content_frames = mel.shape[-1] - N_FRAMES
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
if decode_options.get("language", None) is None: if decode_options.get("language", None) is None:
if not model.is_multilingual: if not model.is_multilingual:
@@ -131,7 +149,6 @@ def transcribe(
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language" "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
) )
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
# print(mel_segment.shape)
_, probs = model.detect_language(mel_segment) _, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get) decode_options["language"] = max(probs, key=probs.get)
if verbose is not None: if verbose is not None:
@@ -141,7 +158,25 @@ def transcribe(
language: str = decode_options["language"] language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe") task: str = decode_options.get("task", "transcribe")
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task) tokenizer = get_tokenizer(
model.is_multilingual,
num_languages=model.num_languages,
language=language,
task=task,
)
if isinstance(clip_timestamps, str):
clip_timestamps = [
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
]
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
if len(seek_points) == 0:
seek_points.append(0)
if len(seek_points) % 2 == 1:
seek_points.append(content_frames)
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
punctuation = "\"'“¿([{-\"'.。,!?::”)]}、"
if word_timestamps and task == "translate": if word_timestamps and task == "translate":
warnings.warn("Word-level timestamps on translations may not be reliable.") warnings.warn("Word-level timestamps on translations may not be reliable.")
@@ -179,6 +214,8 @@ def transcribe(
if ( if (
no_speech_threshold is not None no_speech_threshold is not None
and decode_result.no_speech_prob > no_speech_threshold and decode_result.no_speech_prob > no_speech_threshold
and logprob_threshold is not None
and decode_result.avg_logprob < logprob_threshold
): ):
needs_fallback = False # silence needs_fallback = False # silence
if not needs_fallback: if not needs_fallback:
@@ -186,7 +223,8 @@ def transcribe(
return decode_result return decode_result
seek = 0 clip_idx = 0
seek = seek_clips[clip_idx][0]
input_stride = exact_div( input_stride = exact_div(
N_FRAMES, model.dims.n_audio_ctx N_FRAMES, model.dims.n_audio_ctx
) # mel frames per output token: 2 ) # mel frames per output token: 2
@@ -197,9 +235,11 @@ def transcribe(
all_segments = [] all_segments = []
prompt_reset_since = 0 prompt_reset_since = 0
remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
if initial_prompt is not None: if initial_prompt is not None:
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip()) initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt_tokens) all_tokens.extend(initial_prompt_tokens)
remaining_prompt_length -= len(initial_prompt_tokens)
else: else:
initial_prompt_tokens = [] initial_prompt_tokens = []
@@ -225,16 +265,33 @@ def transcribe(
total=content_frames, unit="frames", disable=verbose is not False total=content_frames, unit="frames", disable=verbose is not False
) as pbar: ) as pbar:
last_speech_timestamp = 0.0 last_speech_timestamp = 0.0
while seek < content_frames: # NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
# for seek_clip_start, seek_clip_end in seek_clips:
# while seek < seek_clip_end
while clip_idx < len(seek_clips):
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
if seek < seek_clip_start:
seek = seek_clip_start
if seek >= seek_clip_end:
clip_idx += 1
if clip_idx < len(seek_clips):
seek = seek_clips[clip_idx][0]
continue
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
mel_segment = mel[:, seek : seek + N_FRAMES] window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
segment_size = min(N_FRAMES, content_frames - seek) segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
mel_segment = mel[:, seek : seek + segment_size]
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
# print("melshape", mel_segment.shape) if carry_initial_prompt:
nignored = max(len(initial_prompt_tokens), prompt_reset_since)
remaining_prompt = all_tokens[nignored:][-remaining_prompt_length:]
decode_options["prompt"] = initial_prompt_tokens + remaining_prompt
else:
decode_options["prompt"] = all_tokens[prompt_reset_since:]
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(mel_segment) result: DecodingResult = decode_with_fallback(mel_segment)
tokens = torch.tensor(result.tokens) tokens = torch.tensor(result.tokens)
@@ -255,6 +312,30 @@ def transcribe(
previous_seek = seek previous_seek = seek
current_segments = [] current_segments = []
# anomalous words are very long/short/improbable
def word_anomaly_score(word: dict) -> float:
probability = word.get("probability", 0.0)
duration = word["end"] - word["start"]
score = 0.0
if probability < 0.15:
score += 1.0
if duration < 0.133:
score += (0.133 - duration) * 15
if duration > 2.0:
score += duration - 2.0
return score
def is_segment_anomaly(segment: Optional[dict]) -> bool:
if segment is None or not segment["words"]:
return False
words = [w for w in segment["words"] if w["word"] not in punctuation]
words = words[:8]
score = sum(word_anomaly_score(w) for w in words)
return score >= 3 or score + 0.01 >= len(words)
def next_words_segment(segments: List[dict]) -> Optional[dict]:
return next((s for s in segments if s["words"]), None)
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
@@ -317,9 +398,7 @@ def transcribe(
) )
seek += segment_size seek += segment_size
# print("word_timestamps, ", word_timestamps)
if word_timestamps: if word_timestamps:
# print("=========run timestamps here=========")
add_word_timestamps( add_word_timestamps(
segments=current_segments, segments=current_segments,
model=model, model=model,
@@ -330,17 +409,71 @@ def transcribe(
append_punctuations=append_punctuations, append_punctuations=append_punctuations,
last_speech_timestamp=last_speech_timestamp, last_speech_timestamp=last_speech_timestamp,
) )
word_end_timestamps = [
w["end"] for s in current_segments for w in s["words"] if not single_timestamp_ending:
] last_word_end = get_end(current_segments)
if len(word_end_timestamps) > 0: if last_word_end is not None and last_word_end > time_offset:
last_speech_timestamp = word_end_timestamps[-1] seek = round(last_word_end * FRAMES_PER_SECOND)
if not single_timestamp_ending and len(word_end_timestamps) > 0:
seek_shift = round( # skip silence before possible hallucinations
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND if hallucination_silence_threshold is not None:
) threshold = hallucination_silence_threshold
if seek_shift > 0: if not single_timestamp_ending:
seek = previous_seek + seek_shift last_word_end = get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
remaining_duration = window_end_time - last_word_end
if remaining_duration > threshold:
seek = round(last_word_end * FRAMES_PER_SECOND)
else:
seek = previous_seek + segment_size
# if first segment might be a hallucination, skip leading silence
first_segment = next_words_segment(current_segments)
if first_segment is not None and is_segment_anomaly(first_segment):
gap = first_segment["start"] - time_offset
if gap > threshold:
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
continue
# skip silence before any possible hallucination that is surrounded
# by silence or more hallucinations
hal_last_end = last_speech_timestamp
for si in range(len(current_segments)):
segment = current_segments[si]
if not segment["words"]:
continue
if is_segment_anomaly(segment):
next_segment = next_words_segment(
current_segments[si + 1 :]
)
if next_segment is not None:
hal_next_start = next_segment["words"][0]["start"]
else:
hal_next_start = time_offset + segment_duration
silence_before = (
segment["start"] - hal_last_end > threshold
or segment["start"] < threshold
or segment["start"] - time_offset < 2.0
)
silence_after = (
hal_next_start - segment["end"] > threshold
or is_segment_anomaly(next_segment)
or window_end_time - segment["end"] < 2.0
)
if silence_before and silence_after:
seek = round(
max(time_offset + 1, segment["start"])
* FRAMES_PER_SECOND
)
if content_duration - segment["end"] < threshold:
seek = content_frames
current_segments[si:] = []
break
hal_last_end = segment["end"]
last_word_end = get_end(current_segments)
if last_word_end is not None:
last_speech_timestamp = last_word_end
if verbose: if verbose:
for segment in current_segments: for segment in current_segments:
@@ -384,10 +517,17 @@ def transcribe(
def cli(): def cli():
from . import available_models from . import available_models
def valid_model_name(name):
if name in available_models() or os.path.exists(name):
return name
raise ValueError(
f"model should be one of {available_models()} or path to a model checkpoint"
)
# fmt: off # fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
@@ -405,6 +545,8 @@ def cli():
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
parser.add_argument("--carry_initial_prompt", type=str2bool, default=False, help="if True, prepend initial_prompt to every internal decode() call. May reduce the effectiveness of condition_on_previous_text")
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
@@ -418,7 +560,10 @@ def cli():
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt") parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line") parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment") parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
# fmt: on # fmt: on
args = parser.parse_args().__dict__ args = parser.parse_args().__dict__
@@ -450,17 +595,28 @@ def cli():
model = load_model(model_name, device=device, download_root=model_dir) model = load_model(model_name, device=device, download_root=model_dir)
writer = get_writer(output_format, output_dir) writer = get_writer(output_format, output_dir)
word_options = ["highlight_words", "max_line_count", "max_line_width"] word_options = [
"highlight_words",
"max_line_count",
"max_line_width",
"max_words_per_line",
]
if not args["word_timestamps"]: if not args["word_timestamps"]:
for option in word_options: for option in word_options:
if args[option]: if args[option]:
parser.error(f"--{option} requires --word_timestamps True") parser.error(f"--{option} requires --word_timestamps True")
if args["max_line_count"] and not args["max_line_width"]: if args["max_line_count"] and not args["max_line_width"]:
warnings.warn("--max_line_count has no effect without --max_line_width") warnings.warn("--max_line_count has no effect without --max_line_width")
if args["max_words_per_line"] and args["max_line_width"]:
warnings.warn("--max_words_per_line has no effect with --max_line_width")
writer_args = {arg: args.pop(arg) for arg in word_options} writer_args = {arg: args.pop(arg) for arg in word_options}
for audio_path in args.pop("audio"): for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args) try:
writer(result, audio_path, writer_args) result = transcribe(model, audio_path, temperature=temperature, **args)
writer(result, audio_path, **writer_args)
except Exception as e:
traceback.print_exc()
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -60,7 +60,7 @@ def median_kernel(filter_width: int):
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821 tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
kernel = triton.JITFunction(kernel.fn) kernel = triton.JITFunction(kernel.fn)
kernel.src = kernel.src.replace( new_kernel = kernel.src.replace(
" LOAD_ALL_ROWS_HERE", " LOAD_ALL_ROWS_HERE",
"\n".join( "\n".join(
[ [
@@ -69,7 +69,8 @@ def median_kernel(filter_width: int):
] ]
), ),
) )
kernel.src = kernel.src.replace(
new_kernel = new_kernel.replace(
" BUBBLESORT_HERE", " BUBBLESORT_HERE",
"\n\n".join( "\n\n".join(
[ [
@@ -90,7 +91,14 @@ def median_kernel(filter_width: int):
] ]
), ),
) )
kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
new_kernel = new_kernel.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
if hasattr(kernel, "_unsafe_update_src") is True:
kernel._unsafe_update_src(new_kernel)
kernel.hash = None
else:
kernel.src = new_kernel
return kernel return kernel

View File

@@ -3,7 +3,7 @@ import os
import re import re
import sys import sys
import zlib import zlib
from typing import Callable, Optional, TextIO from typing import Callable, List, Optional, TextIO
system_encoding = sys.getdefaultencoding() system_encoding = sys.getdefaultencoding()
@@ -68,13 +68,29 @@ def format_timestamp(
) )
def get_start(segments: List[dict]) -> Optional[float]:
return next(
(w["start"] for s in segments for w in s["words"]),
segments[0]["start"] if segments else None,
)
def get_end(segments: List[dict]) -> Optional[float]:
return next(
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
segments[-1]["end"] if segments else None,
)
class ResultWriter: class ResultWriter:
extension: str extension: str
def __init__(self, output_dir: str): def __init__(self, output_dir: str):
self.output_dir = output_dir self.output_dir = output_dir
def __call__(self, result: dict, audio_path: str, options: dict): def __call__(
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
):
audio_basename = os.path.basename(audio_path) audio_basename = os.path.basename(audio_path)
audio_basename = os.path.splitext(audio_basename)[0] audio_basename = os.path.splitext(audio_basename)[0]
output_path = os.path.join( output_path = os.path.join(
@@ -82,16 +98,20 @@ class ResultWriter:
) )
with open(output_path, "w", encoding="utf-8") as f: with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f, options=options) self.write_result(result, file=f, options=options, **kwargs)
def write_result(self, result: dict, file: TextIO, options: dict): def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
raise NotImplementedError raise NotImplementedError
class WriteTXT(ResultWriter): class WriteTXT(ResultWriter):
extension: str = "txt" extension: str = "txt"
def write_result(self, result: dict, file: TextIO, options: dict): def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for segment in result["segments"]: for segment in result["segments"]:
print(segment["text"].strip(), file=file, flush=True) print(segment["text"].strip(), file=file, flush=True)
@@ -100,48 +120,76 @@ class SubtitlesWriter(ResultWriter):
always_include_hours: bool always_include_hours: bool
decimal_marker: str decimal_marker: str
def iterate_result(self, result: dict, options: dict): def iterate_result(
raw_max_line_width: Optional[int] = options["max_line_width"] self,
max_line_count: Optional[int] = options["max_line_count"] result: dict,
highlight_words: bool = options["highlight_words"] options: Optional[dict] = None,
max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width *,
preserve_segments = max_line_count is None or raw_max_line_width is None max_line_width: Optional[int] = None,
max_line_count: Optional[int] = None,
highlight_words: bool = False,
max_words_per_line: Optional[int] = None,
):
options = options or {}
max_line_width = max_line_width or options.get("max_line_width")
max_line_count = max_line_count or options.get("max_line_count")
highlight_words = highlight_words or options.get("highlight_words", False)
max_words_per_line = max_words_per_line or options.get("max_words_per_line")
preserve_segments = max_line_count is None or max_line_width is None
max_line_width = max_line_width or 1000
max_words_per_line = max_words_per_line or 1000
def iterate_subtitles(): def iterate_subtitles():
line_len = 0 line_len = 0
line_count = 1 line_count = 1
# the next subtitle to yield (a list of word timings with whitespace) # the next subtitle to yield (a list of word timings with whitespace)
subtitle: list[dict] = [] subtitle: List[dict] = []
last = result["segments"][0]["words"][0]["start"] last: float = get_start(result["segments"]) or 0.0
for segment in result["segments"]: for segment in result["segments"]:
for i, original_timing in enumerate(segment["words"]): chunk_index = 0
timing = original_timing.copy() words_count = max_words_per_line
long_pause = not preserve_segments and timing["start"] - last > 3.0 while chunk_index < len(segment["words"]):
has_room = line_len + len(timing["word"]) <= max_line_width remaining_words = len(segment["words"]) - chunk_index
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments if max_words_per_line > len(segment["words"]) - chunk_index:
if line_len > 0 and has_room and not long_pause and not seg_break: words_count = remaining_words
# line continuation for i, original_timing in enumerate(
line_len += len(timing["word"]) segment["words"][chunk_index : chunk_index + words_count]
else: ):
# new line timing = original_timing.copy()
timing["word"] = timing["word"].strip() long_pause = (
not preserve_segments and timing["start"] - last > 3.0
)
has_room = line_len + len(timing["word"]) <= max_line_width
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
if ( if (
len(subtitle) > 0 line_len > 0
and max_line_count is not None and has_room
and (long_pause or line_count >= max_line_count) and not long_pause
or seg_break and not seg_break
): ):
# subtitle break # line continuation
yield subtitle line_len += len(timing["word"])
subtitle = [] else:
line_count = 1 # new line
elif line_len > 0: timing["word"] = timing["word"].strip()
# line break if (
line_count += 1 len(subtitle) > 0
timing["word"] = "\n" + timing["word"] and max_line_count is not None
line_len = len(timing["word"].strip()) and (long_pause or line_count >= max_line_count)
subtitle.append(timing) or seg_break
last = timing["start"] ):
# subtitle break
yield subtitle
subtitle = []
line_count = 1
elif line_len > 0:
# line break
line_count += 1
timing["word"] = "\n" + timing["word"]
line_len = len(timing["word"].strip())
subtitle.append(timing)
last = timing["start"]
chunk_index += max_words_per_line
if len(subtitle) > 0: if len(subtitle) > 0:
yield subtitle yield subtitle
@@ -161,9 +209,11 @@ class SubtitlesWriter(ResultWriter):
yield start, end, "".join( yield start, end, "".join(
[ [
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word) (
if j == i re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
else word if j == i
else word
)
for j, word in enumerate(all_words) for j, word in enumerate(all_words)
] ]
) )
@@ -190,9 +240,11 @@ class WriteVTT(SubtitlesWriter):
always_include_hours: bool = False always_include_hours: bool = False
decimal_marker: str = "." decimal_marker: str = "."
def write_result(self, result: dict, file: TextIO, options: dict): def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
print("WEBVTT\n", file=file) print("WEBVTT\n", file=file)
for start, end, text in self.iterate_result(result, options): for start, end, text in self.iterate_result(result, options, **kwargs):
print(f"{start} --> {end}\n{text}\n", file=file, flush=True) print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
@@ -201,9 +253,11 @@ class WriteSRT(SubtitlesWriter):
always_include_hours: bool = True always_include_hours: bool = True
decimal_marker: str = "," decimal_marker: str = ","
def write_result(self, result: dict, file: TextIO, options: dict): def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for i, (start, end, text) in enumerate( for i, (start, end, text) in enumerate(
self.iterate_result(result, options), start=1 self.iterate_result(result, options, **kwargs), start=1
): ):
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
@@ -220,7 +274,9 @@ class WriteTSV(ResultWriter):
extension: str = "tsv" extension: str = "tsv"
def write_result(self, result: dict, file: TextIO, options: dict): def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
print("start", "end", "text", sep="\t", file=file) print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]: for segment in result["segments"]:
print(round(1000 * segment["start"]), file=file, end="\t") print(round(1000 * segment["start"]), file=file, end="\t")
@@ -231,7 +287,9 @@ class WriteTSV(ResultWriter):
class WriteJSON(ResultWriter): class WriteJSON(ResultWriter):
extension: str = "json" extension: str = "json"
def write_result(self, result: dict, file: TextIO, options: dict): def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
json.dump(result, file) json.dump(result, file)
@@ -249,9 +307,11 @@ def get_writer(
if output_format == "all": if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()] all_writers = [writer(output_dir) for writer in writers.values()]
def write_all(result: dict, file: TextIO, options: dict): def write_all(
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for writer in all_writers: for writer in all_writers:
writer(result, file, options) writer(result, file, options, **kwargs)
return write_all return write_all

View File

@@ -1 +1 @@
__version__ = "20230918" __version__ = "20250625"

View File

@@ -1,10 +1,16 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from datetime import timedelta
def format_time(seconds: float) -> str:
"""Format seconds as HH:MM:SS."""
return str(timedelta(seconds=int(seconds)))
@dataclass @dataclass
class TimedText: class TimedText:
start: Optional[float] start: Optional[float] = 0
end: Optional[float] end: Optional[float] = 0
text: Optional[str] = '' text: Optional[str] = ''
speaker: Optional[int] = -1 speaker: Optional[int] = -1
probability: Optional[float] = None probability: Optional[float] = None
@@ -29,4 +35,26 @@ class SpeakerSegment(TimedText):
"""Represents a segment of audio attributed to a specific speaker. """Represents a segment of audio attributed to a specific speaker.
No text nor probability is associated with this segment. No text nor probability is associated with this segment.
""" """
pass pass
@dataclass
class Translation(TimedText):
pass
@dataclass
class Silence():
duration: float
@dataclass
class Line(TimedText):
translation: str = ''
def to_dict(self):
return {
'speaker': int(self.speaker),
'text': self.text,
'translation': self.translation,
'start': format_time(self.start),
'end': format_time(self.end),
}

View File

@@ -1,73 +0,0 @@
import torch
import sys
class TokenBuffer:
def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):
self.text = text
self.prefix_token_ids = prefix_token_ids
self.tokenizer = tokenizer
self.device = device
def as_token_ids(self, tokenizer=None):
if tokenizer is None:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer is not set.")
return self.prefix_token_ids + tokenizer.encode(self.text)
def as_tensor(self, device=None):
if device is None:
device = self.device
if device is None:
raise ValueError("Device is not set.")
tok_ids = self.as_token_ids()
return torch.tensor(tok_ids,
dtype=torch.long, device=device).unsqueeze(0)
def as_tensor_beam(self, beam, device=None):
t = self.as_tensor(device=device)
return t.repeat_interleave(beam, dim=0)
def as_text(self):
return self.text
@staticmethod
def empty(*a, **kw):
return TokenBuffer(*a,**kw)
@staticmethod
def from_text(text, *a, **kw):
return TokenBuffer(*a, text=text, **kw)
def is_empty(self):
return self.text is None or self.text == ""
def trim_words(self, num=1, after=0):
'''
num: how many words to trim from the beginning
after: how many characters to skip (length of the static prompt)
'''
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
ids = tokenizer.encode(self.text[after:])
words, wids = self.tokenizer.split_to_word_tokens(ids)
print(words, file=sys.stderr)
print(wids, file=sys.stderr)
if not words:
return 0
self.text = self.text[:after] + "".join(words[num:])
return sum(len(wi) for wi in wids[:num])
def append_token_ids(self, token_ids):
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
self.text += self.tokenizer.decode(token_ids)
def as_split_word_tokens(self):
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
ids = tokenizer.encode(self.text)
return tokenizer.split_to_word_tokens(ids)

View File

@@ -0,0 +1,60 @@
from typing import Sequence, Callable, Any, Optional, Dict
def _detect_tail_repetition(
seq: Sequence[Any],
key: Callable[[Any], Any] = lambda x: x, # extract comparable value
min_block: int = 1, # set to 2 to ignore 1-token loops like "."
max_tail: int = 300, # search window from the end for speed
prefer: str = "longest", # "longest" coverage or "smallest" block
) -> Optional[Dict]:
vals = [key(x) for x in seq][-max_tail:]
n = len(vals)
best = None
# try every possible block length
for b in range(min_block, n // 2 + 1):
block = vals[-b:]
# count how many times this block repeats contiguously at the very end
count, i = 0, n
while i - b >= 0 and vals[i - b:i] == block:
count += 1
i -= b
if count >= 2:
cand = {
"block_size": b,
"count": count,
"start_index": len(seq) - count * b, # in original seq
"end_index": len(seq),
}
if (best is None or
(prefer == "longest" and count * b > best["count"] * best["block_size"]) or
(prefer == "smallest" and b < best["block_size"])):
best = cand
return best
def trim_tail_repetition(
seq: Sequence[Any],
key: Callable[[Any], Any] = lambda x: x,
min_block: int = 1,
max_tail: int = 300,
prefer: str = "longest",
keep: int = 1, # how many copies of the repeating block to keep at the end (0 or 1 are common)
):
"""
Returns a new sequence with repeated tail trimmed.
keep=1 -> keep a single copy of the repeated block.
keep=0 -> remove all copies of the repeated block.
"""
rep = _detect_tail_repetition(seq, key, min_block, max_tail, prefer)
if not rep:
return seq, False # nothing to trim
b, c = rep["block_size"], rep["count"]
if keep < 0:
keep = 0
if keep >= c:
return seq, False # nothing to trim (already <= keep copies)
# new length = total - (copies_to_remove * block_size)
new_len = len(seq) - (c - keep) * b
return seq[:new_len], True

View File

View File

@@ -0,0 +1,182 @@
"""
adapted from https://store.crowdin.com/custom-mt
"""
LANGUAGES = [
{"name": "Afrikaans", "nllb": "afr_Latn", "crowdin": "af"},
{"name": "Akan", "nllb": "aka_Latn", "crowdin": "ak"},
{"name": "Amharic", "nllb": "amh_Ethi", "crowdin": "am"},
{"name": "Assamese", "nllb": "asm_Beng", "crowdin": "as"},
{"name": "Asturian", "nllb": "ast_Latn", "crowdin": "ast"},
{"name": "Bashkir", "nllb": "bak_Cyrl", "crowdin": "ba"},
{"name": "Bambara", "nllb": "bam_Latn", "crowdin": "bm"},
{"name": "Balinese", "nllb": "ban_Latn", "crowdin": "ban"},
{"name": "Belarusian", "nllb": "bel_Cyrl", "crowdin": "be"},
{"name": "Bengali", "nllb": "ben_Beng", "crowdin": "bn"},
{"name": "Bosnian", "nllb": "bos_Latn", "crowdin": "bs"},
{"name": "Bulgarian", "nllb": "bul_Cyrl", "crowdin": "bg"},
{"name": "Catalan", "nllb": "cat_Latn", "crowdin": "ca"},
{"name": "Cebuano", "nllb": "ceb_Latn", "crowdin": "ceb"},
{"name": "Czech", "nllb": "ces_Latn", "crowdin": "cs"},
{"name": "Welsh", "nllb": "cym_Latn", "crowdin": "cy"},
{"name": "Danish", "nllb": "dan_Latn", "crowdin": "da"},
{"name": "German", "nllb": "deu_Latn", "crowdin": "de"},
{"name": "Dzongkha", "nllb": "dzo_Tibt", "crowdin": "dz"},
{"name": "Greek", "nllb": "ell_Grek", "crowdin": "el"},
{"name": "English", "nllb": "eng_Latn", "crowdin": "en"},
{"name": "Esperanto", "nllb": "epo_Latn", "crowdin": "eo"},
{"name": "Estonian", "nllb": "est_Latn", "crowdin": "et"},
{"name": "Basque", "nllb": "eus_Latn", "crowdin": "eu"},
{"name": "Ewe", "nllb": "ewe_Latn", "crowdin": "ee"},
{"name": "Faroese", "nllb": "fao_Latn", "crowdin": "fo"},
{"name": "Fijian", "nllb": "fij_Latn", "crowdin": "fj"},
{"name": "Finnish", "nllb": "fin_Latn", "crowdin": "fi"},
{"name": "French", "nllb": "fra_Latn", "crowdin": "fr"},
{"name": "Friulian", "nllb": "fur_Latn", "crowdin": "fur-IT"},
{"name": "Scottish Gaelic", "nllb": "gla_Latn", "crowdin": "gd"},
{"name": "Irish", "nllb": "gle_Latn", "crowdin": "ga-IE"},
{"name": "Galician", "nllb": "glg_Latn", "crowdin": "gl"},
{"name": "Guarani", "nllb": "grn_Latn", "crowdin": "gn"},
{"name": "Gujarati", "nllb": "guj_Gujr", "crowdin": "gu-IN"},
{"name": "Haitian Creole", "nllb": "hat_Latn", "crowdin": "ht"},
{"name": "Hausa", "nllb": "hau_Latn", "crowdin": "ha"},
{"name": "Hebrew", "nllb": "heb_Hebr", "crowdin": "he"},
{"name": "Hindi", "nllb": "hin_Deva", "crowdin": "hi"},
{"name": "Croatian", "nllb": "hrv_Latn", "crowdin": "hr"},
{"name": "Hungarian", "nllb": "hun_Latn", "crowdin": "hu"},
{"name": "Armenian", "nllb": "hye_Armn", "crowdin": "hy-AM"},
{"name": "Igbo", "nllb": "ibo_Latn", "crowdin": "ig"},
{"name": "Indonesian", "nllb": "ind_Latn", "crowdin": "id"},
{"name": "Icelandic", "nllb": "isl_Latn", "crowdin": "is"},
{"name": "Italian", "nllb": "ita_Latn", "crowdin": "it"},
{"name": "Javanese", "nllb": "jav_Latn", "crowdin": "jv"},
{"name": "Japanese", "nllb": "jpn_Jpan", "crowdin": "ja"},
{"name": "Kabyle", "nllb": "kab_Latn", "crowdin": "kab"},
{"name": "Kannada", "nllb": "kan_Knda", "crowdin": "kn"},
{"name": "Georgian", "nllb": "kat_Geor", "crowdin": "ka"},
{"name": "Kazakh", "nllb": "kaz_Cyrl", "crowdin": "kk"},
{"name": "Khmer", "nllb": "khm_Khmr", "crowdin": "km"},
{"name": "Kinyarwanda", "nllb": "kin_Latn", "crowdin": "rw"},
{"name": "Kyrgyz", "nllb": "kir_Cyrl", "crowdin": "ky"},
{"name": "Korean", "nllb": "kor_Hang", "crowdin": "ko"},
{"name": "Lao", "nllb": "lao_Laoo", "crowdin": "lo"},
{"name": "Ligurian", "nllb": "lij_Latn", "crowdin": "lij"},
{"name": "Limburgish", "nllb": "lim_Latn", "crowdin": "li"},
{"name": "Lingala", "nllb": "lin_Latn", "crowdin": "ln"},
{"name": "Lithuanian", "nllb": "lit_Latn", "crowdin": "lt"},
{"name": "Luxembourgish", "nllb": "ltz_Latn", "crowdin": "lb"},
{"name": "Maithili", "nllb": "mai_Deva", "crowdin": "mai"},
{"name": "Malayalam", "nllb": "mal_Mlym", "crowdin": "ml-IN"},
{"name": "Marathi", "nllb": "mar_Deva", "crowdin": "mr"},
{"name": "Macedonian", "nllb": "mkd_Cyrl", "crowdin": "mk"},
{"name": "Maltese", "nllb": "mlt_Latn", "crowdin": "mt"},
{"name": "Mossi", "nllb": "mos_Latn", "crowdin": "mos"},
{"name": "Maori", "nllb": "mri_Latn", "crowdin": "mi"},
{"name": "Burmese", "nllb": "mya_Mymr", "crowdin": "my"},
{"name": "Dutch", "nllb": "nld_Latn", "crowdin": "nl"},
{"name": "Norwegian Nynorsk", "nllb": "nno_Latn", "crowdin": "nn-NO"},
{"name": "Nepali", "nllb": "npi_Deva", "crowdin": "ne-NP"},
{"name": "Northern Sotho", "nllb": "nso_Latn", "crowdin": "nso"},
{"name": "Occitan", "nllb": "oci_Latn", "crowdin": "oc"},
{"name": "Odia", "nllb": "ory_Orya", "crowdin": "or"},
{"name": "Papiamento", "nllb": "pap_Latn", "crowdin": "pap"},
{"name": "Polish", "nllb": "pol_Latn", "crowdin": "pl"},
{"name": "Portuguese", "nllb": "por_Latn", "crowdin": "pt-PT"},
{"name": "Dari", "nllb": "prs_Arab", "crowdin": "fa-AF"},
{"name": "Romanian", "nllb": "ron_Latn", "crowdin": "ro"},
{"name": "Rundi", "nllb": "run_Latn", "crowdin": "rn"},
{"name": "Russian", "nllb": "rus_Cyrl", "crowdin": "ru"},
{"name": "Sango", "nllb": "sag_Latn", "crowdin": "sg"},
{"name": "Sanskrit", "nllb": "san_Deva", "crowdin": "sa"},
{"name": "Santali", "nllb": "sat_Olck", "crowdin": "sat"},
{"name": "Sinhala", "nllb": "sin_Sinh", "crowdin": "si-LK"},
{"name": "Slovak", "nllb": "slk_Latn", "crowdin": "sk"},
{"name": "Slovenian", "nllb": "slv_Latn", "crowdin": "sl"},
{"name": "Shona", "nllb": "sna_Latn", "crowdin": "sn"},
{"name": "Sindhi", "nllb": "snd_Arab", "crowdin": "sd"},
{"name": "Somali", "nllb": "som_Latn", "crowdin": "so"},
{"name": "Southern Sotho", "nllb": "sot_Latn", "crowdin": "st"},
{"name": "Spanish", "nllb": "spa_Latn", "crowdin": "es-ES"},
{"name": "Sardinian", "nllb": "srd_Latn", "crowdin": "sc"},
{"name": "Swati", "nllb": "ssw_Latn", "crowdin": "ss"},
{"name": "Sundanese", "nllb": "sun_Latn", "crowdin": "su"},
{"name": "Swedish", "nllb": "swe_Latn", "crowdin": "sv-SE"},
{"name": "Swahili", "nllb": "swh_Latn", "crowdin": "sw"},
{"name": "Tamil", "nllb": "tam_Taml", "crowdin": "ta"},
{"name": "Tatar", "nllb": "tat_Cyrl", "crowdin": "tt-RU"},
{"name": "Telugu", "nllb": "tel_Telu", "crowdin": "te"},
{"name": "Tajik", "nllb": "tgk_Cyrl", "crowdin": "tg"},
{"name": "Tagalog", "nllb": "tgl_Latn", "crowdin": "tl"},
{"name": "Thai", "nllb": "tha_Thai", "crowdin": "th"},
{"name": "Tigrinya", "nllb": "tir_Ethi", "crowdin": "ti"},
{"name": "Tswana", "nllb": "tsn_Latn", "crowdin": "tn"},
{"name": "Tsonga", "nllb": "tso_Latn", "crowdin": "ts"},
{"name": "Turkmen", "nllb": "tuk_Latn", "crowdin": "tk"},
{"name": "Turkish", "nllb": "tur_Latn", "crowdin": "tr"},
{"name": "Uyghur", "nllb": "uig_Arab", "crowdin": "ug"},
{"name": "Ukrainian", "nllb": "ukr_Cyrl", "crowdin": "uk"},
{"name": "Venetian", "nllb": "vec_Latn", "crowdin": "vec"},
{"name": "Vietnamese", "nllb": "vie_Latn", "crowdin": "vi"},
{"name": "Wolof", "nllb": "wol_Latn", "crowdin": "wo"},
{"name": "Xhosa", "nllb": "xho_Latn", "crowdin": "xh"},
{"name": "Yoruba", "nllb": "yor_Latn", "crowdin": "yo"},
{"name": "Zulu", "nllb": "zul_Latn", "crowdin": "zu"},
]
NAME_TO_NLLB = {lang["name"]: lang["nllb"] for lang in LANGUAGES}
NAME_TO_CROWDIN = {lang["name"]: lang["crowdin"] for lang in LANGUAGES}
CROWDIN_TO_NLLB = {lang["crowdin"]: lang["nllb"] for lang in LANGUAGES}
NLLB_TO_CROWDIN = {lang["nllb"]: lang["crowdin"] for lang in LANGUAGES}
CROWDIN_TO_NAME = {lang["crowdin"]: lang["name"] for lang in LANGUAGES}
NLLB_TO_NAME = {lang["nllb"]: lang["name"] for lang in LANGUAGES}
def get_nllb_code(crowdin_code):
return CROWDIN_TO_NLLB.get(crowdin_code, None)
def get_crowdin_code(nllb_code):
return NLLB_TO_CROWDIN.get(nllb_code)
def get_language_name_by_crowdin(crowdin_code):
return CROWDIN_TO_NAME.get(crowdin_code)
def get_language_name_by_nllb(nllb_code):
return NLLB_TO_NAME.get(nllb_code)
def get_language_info(identifier, identifier_type="auto"):
if identifier_type == "auto":
for lang in LANGUAGES:
if (lang["name"].lower() == identifier.lower() or
lang["nllb"] == identifier or
lang["crowdin"] == identifier):
return lang
elif identifier_type == "name":
for lang in LANGUAGES:
if lang["name"].lower() == identifier.lower():
return lang
elif identifier_type == "nllb":
for lang in LANGUAGES:
if lang["nllb"] == identifier:
return lang
elif identifier_type == "crowdin":
for lang in LANGUAGES:
if lang["crowdin"] == identifier:
return lang
return None
def list_all_languages():
return [lang["name"] for lang in LANGUAGES]
def list_all_nllb_codes():
return [lang["nllb"] for lang in LANGUAGES]
def list_all_crowdin_codes():
return [lang["crowdin"] for lang in LANGUAGES]

View File

@@ -0,0 +1,137 @@
import ctranslate2
import torch
import transformers
from dataclasses import dataclass
import huggingface_hub
from whisperlivekit.translation.mapping_languages import get_nllb_code
from whisperlivekit.timed_objects import Translation
#In diarization case, we may want to translate just one speaker, or at least start the sentences there
PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''}
@dataclass
class TranslationModel():
translator: ctranslate2.Translator
tokenizer: dict
def load_model(src_langs):
MODEL = 'nllb-200-distilled-600M-ctranslate2'
MODEL_GUY = 'entai2965'
huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL)
device = "cuda" if torch.cuda.is_available() else "cpu"
translator = ctranslate2.Translator(MODEL,device=device)
tokenizer = dict()
for src_lang in src_langs:
tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
return TranslationModel(
translator=translator,
tokenizer=tokenizer
)
def translate(input, translation_model, tgt_lang):
source = translation_model.tokenizer.convert_ids_to_tokens(translation_model.tokenizer.encode(input))
target_prefix = [tgt_lang]
results = translation_model.translator.translate_batch([source], target_prefix=[target_prefix])
target = results[0].hypotheses[0][1:]
return translation_model.tokenizer.decode(translation_model.tokenizer.convert_tokens_to_ids(target))
class OnlineTranslation:
def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
self.buffer = []
self.len_processed_buffer = 0
self.translation_remaining = Translation()
self.validated = []
self.translation_pending_validation = ''
self.translation_model = translation_model
self.input_languages = input_languages
self.output_languages = output_languages
def compute_common_prefix(self, results):
#we dont want want to prune the result for the moment.
if not self.buffer:
self.buffer = results
else:
for i in range(min(len(self.buffer), len(results))):
if self.buffer[i] != results[i]:
self.commited.extend(self.buffer[:i])
self.buffer = results[i:]
def translate(self, input, input_lang=None, output_lang=None):
if not input:
return ""
if input_lang is None:
input_lang = self.input_languages[0]
if output_lang is None:
output_lang = self.output_languages[0]
nllb_output_lang = get_nllb_code(output_lang)
source = self.translation_model.tokenizer[input_lang].convert_ids_to_tokens(self.translation_model.tokenizer[input_lang].encode(input))
results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]]) #we can use return_attention=True to try to optimize the stuff.
target = results[0].hypotheses[0][1:]
results = self.translation_model.tokenizer[input_lang].decode(self.translation_model.tokenizer[input_lang].convert_tokens_to_ids(target))
return results
def translate_tokens(self, tokens):
if tokens:
text = ' '.join([token.text for token in tokens])
start = tokens[0].start
end = tokens[-1].end
translated_text = self.translate(text)
translation = Translation(
text=translated_text,
start=start,
end=end,
)
return translation
return None
def insert_tokens(self, tokens):
self.buffer.extend(tokens)
pass
def process(self):
i = 0
if len(self.buffer) < self.len_processed_buffer + 3: #nothing new to process
return self.validated + [self.translation_remaining]
while i < len(self.buffer):
if self.buffer[i].text in PUNCTUATION_MARKS:
translation_sentence = self.translate_tokens(self.buffer[:i+1])
self.validated.append(translation_sentence)
self.buffer = self.buffer[i+1:]
i = 0
else:
i+=1
self.translation_remaining = self.translate_tokens(self.buffer)
self.len_processed_buffer = len(self.buffer)
return self.validated + [self.translation_remaining]
if __name__ == '__main__':
output_lang = 'fr'
input_lang = "en"
test_string = """
Transcription technology has improved so much in the past few years. Have you noticed how accurate real-time speech-to-text is now?
"""
test = test_string.split(' ')
step = len(test) // 3
shared_model = load_model([input_lang])
online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang])
for id in range(5):
val = test[id*step : (id+1)*step]
val_str = ' '.join(val)
result = online_translation.translate(val_str)
print(result)
# print(result)

62
whisperlivekit/warmup.py Normal file
View File

@@ -0,0 +1,62 @@
import logging
logger = logging.getLogger(__name__)
def load_file(warmup_file=None, timeout=5):
import os
import tempfile
import librosa
if warmup_file is None:
# Download JFK sample if not already present
jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
temp_dir = tempfile.gettempdir()
warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")
if not os.path.exists(warmup_file):
logger.debug(f"Downloading warmup file from {jfk_url}")
print(f"Downloading warmup file from {jfk_url}")
import time
import urllib.request
import urllib.error
import socket
original_timeout = socket.getdefaulttimeout()
socket.setdefaulttimeout(timeout)
start_time = time.time()
try:
urllib.request.urlretrieve(jfk_url, warmup_file)
logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
except (urllib.error.URLError, socket.timeout) as e:
logger.warning(f"Download failed: {e}. Proceeding without warmup.")
return None
finally:
socket.setdefaulttimeout(original_timeout)
elif not warmup_file:
return None
if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
logger.warning(f"Warmup file {warmup_file} invalid or missing.")
return None
try:
audio, sr = librosa.load(warmup_file, sr=16000)
except Exception as e:
logger.warning(f"Failed to load audio file: {e}")
return None
return audio
def warmup_asr(asr, warmup_file=None, timeout=5):
"""
Warmup the ASR model by transcribing a short audio file.
"""
audio = load_file(warmup_file=None, timeout=5)
asr.transcribe(audio)
logger.info("ASR model is warmed up")
def warmup_online(online, warmup_file=None, timeout=5):
audio = load_file(warmup_file=None, timeout=5)
online.warmup(audio)
logger.warning("ASR is warmed up")

View File

@@ -0,0 +1,517 @@
:root {
--bg: #ffffff;
--text: #111111;
--muted: #666666;
--border: #e5e5e5;
--chip-bg: rgba(0, 0, 0, 0.04);
--chip-text: #000000;
--spinner-border: #8d8d8d5c;
--spinner-top: #b0b0b0;
--silence-bg: #f3f3f3;
--loading-bg: rgba(255, 77, 77, 0.06);
--button-bg: #ffffff;
--button-border: #e9e9e9;
--wave-stroke: #000000;
--label-dia-text: #868686;
--label-trans-text: #111111;
}
@media (prefers-color-scheme: dark) {
:root:not([data-theme="light"]) {
--bg: #0b0b0b;
--text: #e6e6e6;
--muted: #9aa0a6;
--border: #333333;
--chip-bg: rgba(255, 255, 255, 0.08);
--chip-text: #e6e6e6;
--spinner-border: #555555;
--spinner-top: #dddddd;
--silence-bg: #1a1a1a;
--loading-bg: rgba(255, 77, 77, 0.12);
--button-bg: #111111;
--button-border: #333333;
--wave-stroke: #e6e6e6;
--label-dia-text: #b3b3b3;
--label-trans-text: #ffffff;
}
}
:root[data-theme="dark"] {
--bg: #0b0b0b;
--text: #e6e6e6;
--muted: #9aa0a6;
--border: #333333;
--chip-bg: rgba(255, 255, 255, 0.08);
--chip-text: #e6e6e6;
--spinner-border: #555555;
--spinner-top: #dddddd;
--silence-bg: #1a1a1a;
--loading-bg: rgba(255, 77, 77, 0.12);
--button-bg: #111111;
--button-border: #333333;
--wave-stroke: #e6e6e6;
--label-dia-text: #b3b3b3;
--label-trans-text: #ffffff;
}
:root[data-theme="light"] {
--bg: #ffffff;
--text: #111111;
--muted: #666666;
--border: #e5e5e5;
--chip-bg: rgba(0, 0, 0, 0.04);
--chip-text: #000000;
--spinner-border: #8d8d8d5c;
--spinner-top: #b0b0b0;
--silence-bg: #f3f3f3;
--loading-bg: rgba(255, 77, 77, 0.06);
--button-bg: #ffffff;
--button-border: #e9e9e9;
--wave-stroke: #000000;
--label-dia-text: #868686;
--label-trans-text: #111111;
}
body {
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
margin: 0;
text-align: center;
background-color: var(--bg);
color: var(--text);
height: 100vh;
display: flex;
flex-direction: column;
}
/* Record button */
#recordButton {
width: 50px;
height: 50px;
border: none;
border-radius: 50%;
background-color: var(--button-bg);
cursor: pointer;
transition: all 0.3s ease;
border: 1px solid var(--button-border);
display: flex;
align-items: center;
justify-content: center;
position: relative;
}
#recordButton.recording {
width: 180px;
border-radius: 40px;
justify-content: flex-start;
padding-left: 20px;
}
#recordButton:active {
transform: scale(0.95);
}
.shape-container {
width: 25px;
height: 25px;
display: flex;
align-items: center;
justify-content: center;
flex-shrink: 0;
}
.shape {
width: 25px;
height: 25px;
background-color: rgb(209, 61, 53);
border-radius: 50%;
transition: all 0.3s ease;
}
#recordButton:disabled .shape {
background-color: #6e6d6d;
}
#recordButton.recording .shape {
border-radius: 5px;
width: 25px;
height: 25px;
}
/* Recording elements */
.recording-info {
display: none;
align-items: center;
margin-left: 15px;
flex-grow: 1;
}
#recordButton.recording .recording-info {
display: flex;
}
.wave-container {
width: 60px;
height: 30px;
position: relative;
display: flex;
align-items: center;
justify-content: center;
}
#waveCanvas {
width: 100%;
height: 100%;
}
.timer {
font-size: 14px;
font-weight: 500;
color: var(--text);
margin-left: 10px;
}
#status {
margin-top: 15px;
font-size: 16px;
color: var(--text);
margin-bottom: 0;
}
.header-container {
position: sticky;
top: 0;
background-color: var(--bg);
z-index: 100;
padding: 20px;
}
/* Settings */
.settings-container {
display: flex;
justify-content: center;
align-items: center;
gap: 15px;
}
.settings {
display: flex;
flex-wrap: wrap;
align-items: flex-start;
gap: 12px;
}
.field {
display: flex;
flex-direction: column;
align-items: flex-start;
gap: 3px;
}
#chunkSelector,
#websocketInput,
#themeSelector,
#microphoneSelect {
font-size: 16px;
padding: 5px 8px;
border-radius: 8px;
border: 1px solid var(--border);
background-color: var(--button-bg);
color: var(--text);
max-height: 30px;
}
#microphoneSelect {
width: 100%;
max-width: 190px;
min-width: 120px;
}
#chunkSelector:focus,
#websocketInput:focus,
#themeSelector:focus,
#microphoneSelect:focus {
outline: none;
border-color: #007bff;
box-shadow: 0 0 0 3px rgba(0, 123, 255, 0.15);
}
label {
font-size: 13px;
color: var(--muted);
}
.ws-default {
font-size: 12px;
color: var(--muted);
}
/* Segmented pill control for Theme */
.segmented {
display: inline-flex;
align-items: stretch;
border: 1px solid var(--button-border);
background-color: var(--button-bg);
border-radius: 999px;
overflow: hidden;
}
.segmented input[type="radio"] {
position: absolute;
opacity: 0;
pointer-events: none;
}
.theme-selector-container {
display: flex;
align-items: center;
margin-top: 17px;
}
.segmented label {
display: inline-flex;
align-items: center;
gap: 6px;
padding: 6px 12px;
font-size: 14px;
color: var(--muted);
cursor: pointer;
user-select: none;
transition: background-color 0.2s ease, color 0.2s ease;
}
.segmented label span {
display: none;
}
.segmented label:hover span {
display: inline;
}
.segmented label:hover {
background-color: var(--chip-bg);
}
.segmented img {
width: 16px;
height: 16px;
}
.segmented input[type="radio"]:checked + label {
background-color: var(--chip-bg);
color: var(--text);
}
.segmented input[type="radio"]:focus-visible + label,
.segmented input[type="radio"]:focus + label {
outline: 2px solid #007bff;
outline-offset: 2px;
border-radius: 999px;
}
.transcript-container {
flex: 1;
overflow-y: auto;
padding: 20px;
scrollbar-width: none;
-ms-overflow-style: none;
}
.transcript-container::-webkit-scrollbar {
display: none;
}
/* Transcript area */
#linesTranscript {
margin: 0 auto;
max-width: 700px;
text-align: left;
font-size: 16px;
}
#linesTranscript p {
margin: 0px 0;
}
#linesTranscript strong {
color: var(--text);
}
#speaker {
border: 1px solid var(--border);
border-radius: 100px;
padding: 2px 10px;
font-size: 14px;
margin-bottom: 0px;
}
.label_diarization {
background-color: var(--chip-bg);
border-radius: 8px 8px 8px 8px;
padding: 2px 10px;
margin-left: 10px;
display: inline-block;
white-space: nowrap;
font-size: 14px;
margin-bottom: 0px;
color: var(--label-dia-text);
}
.label_transcription {
background-color: var(--chip-bg);
border-radius: 8px 8px 8px 8px;
padding: 2px 10px;
display: inline-block;
white-space: nowrap;
margin-left: 10px;
font-size: 14px;
margin-bottom: 0px;
color: var(--label-trans-text);
}
.label_translation {
background-color: var(--chip-bg);
border-radius: 10px;
padding: 4px 8px;
margin-top: 4px;
font-size: 14px;
color: var(--text);
display: flex;
align-items: flex-start;
gap: 4px;
}
.label_translation img {
margin-top: 2px;
}
.label_translation img {
width: 12px;
height: 12px;
}
#timeInfo {
color: var(--muted);
margin-left: 10px;
}
.textcontent {
font-size: 16px;
padding-left: 10px;
margin-bottom: 10px;
margin-top: 1px;
padding-top: 5px;
border-radius: 0px 0px 0px 10px;
}
.buffer_diarization {
color: var(--label-dia-text);
margin-left: 4px;
}
.buffer_transcription {
color: #7474748c;
margin-left: 4px;
}
.spinner {
display: inline-block;
width: 8px;
height: 8px;
border: 2px solid var(--spinner-border);
border-top: 2px solid var(--spinner-top);
border-radius: 50%;
animation: spin 0.7s linear infinite;
vertical-align: middle;
margin-bottom: 2px;
margin-right: 5px;
}
@keyframes spin {
to {
transform: rotate(360deg);
}
}
.silence {
color: var(--muted);
background-color: var(--silence-bg);
font-size: 13px;
border-radius: 30px;
padding: 2px 10px;
display: none;
}
.loading {
color: var(--muted);
background-color: var(--loading-bg);
border-radius: 8px 8px 8px 0px;
padding: 2px 10px;
font-size: 14px;
margin-bottom: 0px;
}
/* for smaller screens */
@media (max-width: 768px) {
.header-container {
padding: 15px;
}
.settings-container {
flex-direction: column;
gap: 10px;
}
.settings {
justify-content: center;
gap: 8px;
}
.field {
align-items: center;
}
#websocketInput,
#microphoneSelect {
min-width: 100px;
max-width: 160px;
}
.theme-selector-container {
margin-top: 10px;
}
.transcript-container {
padding: 15px;
}
}
@media (max-width: 480px) {
.header-container {
padding: 10px;
}
.settings {
flex-direction: column;
align-items: center;
gap: 6px;
}
#websocketInput,
#microphoneSelect {
max-width: 140px;
}
.segmented label {
padding: 4px 8px;
font-size: 12px;
}
.segmented img {
width: 14px;
height: 14px;
}
.transcript-container {
padding: 10px;
}
}

View File

@@ -4,679 +4,70 @@
<head> <head>
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Audio Transcription</title> <title>WhisperLiveKit</title>
<style> <link rel="stylesheet" href="/web/live_transcription.css" />
body {
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
margin: 20px;
text-align: center;
}
#recordButton {
width: 50px;
height: 50px;
border: none;
border-radius: 50%;
background-color: white;
cursor: pointer;
transition: all 0.3s ease;
border: 1px solid rgb(233, 233, 233);
display: flex;
align-items: center;
justify-content: center;
position: relative;
}
#recordButton.recording {
width: 180px;
border-radius: 40px;
justify-content: flex-start;
padding-left: 20px;
}
#recordButton:active {
transform: scale(0.95);
}
.shape-container {
width: 25px;
height: 25px;
display: flex;
align-items: center;
justify-content: center;
flex-shrink: 0;
}
.shape {
width: 25px;
height: 25px;
background-color: rgb(209, 61, 53);
border-radius: 50%;
transition: all 0.3s ease;
}
#recordButton:disabled .shape {
background-color: #6e6d6d;
}
#recordButton.recording .shape {
border-radius: 5px;
width: 25px;
height: 25px;
}
/* Recording elements */
.recording-info {
display: none;
align-items: center;
margin-left: 15px;
flex-grow: 1;
}
#recordButton.recording .recording-info {
display: flex;
}
.wave-container {
width: 60px;
height: 30px;
position: relative;
display: flex;
align-items: center;
justify-content: center;
}
#waveCanvas {
width: 100%;
height: 100%;
}
.timer {
font-size: 14px;
font-weight: 500;
color: #333;
margin-left: 10px;
}
#status {
margin-top: 20px;
font-size: 16px;
color: #333;
}
.settings-container {
display: flex;
justify-content: center;
align-items: center;
gap: 15px;
margin-top: 20px;
}
.settings {
display: flex;
flex-direction: column;
align-items: flex-start;
gap: 5px;
}
#chunkSelector,
#websocketInput {
font-size: 16px;
padding: 5px;
border-radius: 5px;
border: 1px solid #ddd;
background-color: #ffffff;
max-height: 30px;
}
#websocketInput {
width: 200px;
}
#chunkSelector:focus,
#websocketInput:focus {
outline: none;
border-color: #007bff;
}
label {
font-size: 14px;
}
/* Speaker-labeled transcript area */
#linesTranscript {
margin: 20px auto;
max-width: 700px;
text-align: left;
font-size: 16px;
}
#linesTranscript p {
margin: 0px 0;
}
#linesTranscript strong {
color: #333;
}
#speaker {
border: 1px solid rgb(229, 229, 229);
border-radius: 100px;
padding: 2px 10px;
font-size: 14px;
margin-bottom: 0px;
}
.label_diarization {
background-color: #ffffff66;
border-radius: 8px 8px 8px 8px;
padding: 2px 10px;
margin-left: 10px;
display: inline-block;
white-space: nowrap;
font-size: 14px;
margin-bottom: 0px;
color: rgb(134, 134, 134)
}
.label_transcription {
background-color: #ffffff66;
border-radius: 8px 8px 8px 8px;
padding: 2px 10px;
display: inline-block;
white-space: nowrap;
margin-left: 10px;
font-size: 14px;
margin-bottom: 0px;
color: #000000
}
#timeInfo {
color: #666;
margin-left: 10px;
}
.textcontent {
font-size: 16px;
/* margin-left: 10px; */
padding-left: 10px;
margin-bottom: 10px;
margin-top: 1px;
padding-top: 5px;
border-radius: 0px 0px 0px 10px;
}
.buffer_diarization {
color: rgb(134, 134, 134);
margin-left: 4px;
}
.buffer_transcription {
color: #7474748c;
margin-left: 4px;
}
.spinner {
display: inline-block;
width: 8px;
height: 8px;
border: 2px solid #8d8d8d5c;
border-top: 2px solid #6c6c6ce5;
border-radius: 50%;
animation: spin 0.6s linear infinite;
vertical-align: middle;
margin-bottom: 2px;
margin-right: 5px;
}
@keyframes spin {
to {
transform: rotate(360deg);
}
}
.silence {
color: #666;
background-color: #f3f3f3;
font-size: 13px;
border-radius: 30px;
padding: 2px 10px;
}
.loading {
color: #666;
background-color: #ff4d4d0f;
border-radius: 8px 8px 8px 0px;
padding: 2px 10px;
font-size: 14px;
margin-bottom: 0px;
}
</style>
</head> </head>
<body> <body>
<div class="header-container">
<div class="settings-container"> <div class="settings-container">
<button id="recordButton"> <button id="recordButton">
<div class="shape-container"> <div class="shape-container">
<div class="shape"></div> <div class="shape"></div>
</div> </div>
<div class="recording-info"> <div class="recording-info">
<div class="wave-container"> <div class="wave-container">
<canvas id="waveCanvas"></canvas> <canvas id="waveCanvas"></canvas>
</div>
<div class="timer">00:00</div>
</div>
</button>
<div class="settings">
<div class="field">
<label for="websocketInput">Websocket URL</label>
<input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
</div>
<div class="field">
<label id="microphoneSelectLabel" for="microphoneSelect">Select Microphone</label>
<select id="microphoneSelect">
<option value="">Default Microphone</option>
</select>
</div>
<div class="theme-selector-container">
<div class="segmented" role="radiogroup" aria-label="Theme selector">
<input type="radio" id="theme-system" name="theme" value="system" />
<label for="theme-system" title="System">
<img src="/web/src/system_mode.svg" alt="" />
<span>System</span>
</label>
<input type="radio" id="theme-light" name="theme" value="light" />
<label for="theme-light" title="Light">
<img src="/web/src/light_mode.svg" alt="" />
<span>Light</span>
</label>
<input type="radio" id="theme-dark" name="theme" value="dark" />
<label for="theme-dark" title="Dark">
<img src="/web/src/dark_mode.svg" alt="" />
<span>Dark</span>
</label>
</div>
</div> </div>
<div class="timer">00:00</div>
</div>
</button>
<div class="settings">
<div>
<label for="chunkSelector">Chunk size (ms):</label>
<select id="chunkSelector">
<option value="500">500 ms</option>
<option value="1000" selected>1000 ms</option>
<option value="2000">2000 ms</option>
<option value="3000">3000 ms</option>
<option value="4000">4000 ms</option>
<option value="5000">5000 ms</option>
</select>
</div>
<div>
<label for="websocketInput">WebSocket URL:</label>
<input id="websocketInput" type="text" />
</div> </div>
</div> </div>
<p id="status"></p>
</div> </div>
<p id="status"></p> <div class="transcript-container">
<div id="linesTranscript"></div>
</div>
<!-- Speaker-labeled transcript --> <script src="/web/live_transcription.js"></script>
<div id="linesTranscript"></div>
<script>
let isRecording = false;
let websocket = null;
let recorder = null;
let chunkDuration = 1000;
let websocketUrl = "ws://localhost:8000/asr";
let userClosing = false;
let startTime = null;
let timerInterval = null;
let audioContext = null;
let analyser = null;
let microphone = null;
let waveCanvas = document.getElementById("waveCanvas");
let waveCtx = waveCanvas.getContext("2d");
let animationFrame = null;
let waitingForStop = false;
let lastReceivedData = null;
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
const statusText = document.getElementById("status");
const recordButton = document.getElementById("recordButton");
const chunkSelector = document.getElementById("chunkSelector");
const websocketInput = document.getElementById("websocketInput");
const linesTranscriptDiv = document.getElementById("linesTranscript");
const timerElement = document.querySelector(".timer");
const host = window.location.hostname || "localhost";
const port = window.location.port;
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
const defaultWebSocketUrl = `${protocol}://${host}:${port}/asr`;
websocketInput.value = defaultWebSocketUrl;
websocketUrl = defaultWebSocketUrl;
chunkSelector.addEventListener("change", () => {
chunkDuration = parseInt(chunkSelector.value);
});
websocketInput.addEventListener("change", () => {
const urlValue = websocketInput.value.trim();
if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
return;
}
websocketUrl = urlValue;
statusText.textContent = "WebSocket URL updated. Ready to connect.";
});
function setupWebSocket() {
return new Promise((resolve, reject) => {
try {
websocket = new WebSocket(websocketUrl);
} catch (error) {
statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
reject(error);
return;
}
websocket.onopen = () => {
statusText.textContent = "Connected to server.";
resolve();
};
websocket.onclose = () => {
if (userClosing) {
if (waitingForStop) {
statusText.textContent = "Processing finalized or connection closed.";
if (lastReceivedData) {
renderLinesWithBuffer(
lastReceivedData.lines || [],
lastReceivedData.buffer_diarization || "",
lastReceivedData.buffer_transcription || "",
0, 0, true // isFinalizing = true
);
}
}
// If ready_to_stop was received, statusText is already "Finished processing..."
// and waitingForStop is false.
} else {
statusText.textContent = "Disconnected from the WebSocket server. (Check logs if model is loading.)";
if (isRecording) {
stopRecording();
}
}
isRecording = false;
waitingForStop = false;
userClosing = false;
lastReceivedData = null;
websocket = null;
updateUI();
};
websocket.onerror = () => {
statusText.textContent = "Error connecting to WebSocket.";
reject(new Error("Error connecting to WebSocket"));
};
// Handle messages from server
websocket.onmessage = (event) => {
const data = JSON.parse(event.data);
// Check for status messages
if (data.type === "ready_to_stop") {
console.log("Ready to stop received, finalizing display and closing WebSocket.");
waitingForStop = false;
if (lastReceivedData) {
renderLinesWithBuffer(
lastReceivedData.lines || [],
lastReceivedData.buffer_diarization || "",
lastReceivedData.buffer_transcription || "",
0, // No more lag
0, // No more lag
true // isFinalizing = true
);
}
statusText.textContent = "Finished processing audio! Ready to record again.";
recordButton.disabled = false;
if (websocket) {
websocket.close(); // will trigger onclose
// websocket = null; // onclose handle setting websocket to null
}
return;
}
lastReceivedData = data;
// Handle normal transcription updates
const {
lines = [],
buffer_transcription = "",
buffer_diarization = "",
remaining_time_transcription = 0,
remaining_time_diarization = 0,
status = "active_transcription"
} = data;
renderLinesWithBuffer(
lines,
buffer_diarization,
buffer_transcription,
remaining_time_diarization,
remaining_time_transcription,
false,
status
);
};
});
}
function renderLinesWithBuffer(lines, buffer_diarization, buffer_transcription, remaining_time_diarization, remaining_time_transcription, isFinalizing = false, current_status = "active_transcription") {
if (current_status === "no_audio_detected") {
linesTranscriptDiv.innerHTML = "<p style='text-align: center; color: #666; margin-top: 20px;'><em>No audio detected...</em></p>";
return;
}
const linesHtml = lines.map((item, idx) => {
let timeInfo = "";
if (item.beg !== undefined && item.end !== undefined) {
timeInfo = ` ${item.beg} - ${item.end}`;
}
let speakerLabel = "";
if (item.speaker === -2) {
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
} else if (item.speaker == 0 && !isFinalizing) {
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'>${remaining_time_diarization} second(s) of audio are undergoing diarization</span></span>`;
} else if (item.speaker == -1) {
speakerLabel = `<span id="speaker">Speaker 1<span id='timeInfo'>${timeInfo}</span></span>`;
} else if (item.speaker !== -1 && item.speaker !== 0) {
speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
}
let currentLineText = item.text || "";
if (idx === lines.length - 1) {
if (!isFinalizing) {
if (remaining_time_transcription > 0) {
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'>${remaining_time_transcription}s</span></span>`;
}
if (buffer_diarization && remaining_time_diarization > 0) {
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'>${remaining_time_diarization}s</span></span>`;
}
}
if (buffer_diarization) {
if (isFinalizing) {
currentLineText += (currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
} else {
currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
}
}
if (buffer_transcription) {
if (isFinalizing) {
currentLineText += (currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") + buffer_transcription.trim();
} else {
currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
}
}
}
return currentLineText.trim().length > 0 || speakerLabel.length > 0
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
: `<p>${speakerLabel}<br/></p>`;
}).join("");
linesTranscriptDiv.innerHTML = linesHtml;
}
function updateTimer() {
if (!startTime) return;
const elapsed = Math.floor((Date.now() - startTime) / 1000);
const minutes = Math.floor(elapsed / 60).toString().padStart(2, "0");
const seconds = (elapsed % 60).toString().padStart(2, "0");
timerElement.textContent = `${minutes}:${seconds}`;
}
function drawWaveform() {
if (!analyser) return;
const bufferLength = analyser.frequencyBinCount;
const dataArray = new Uint8Array(bufferLength);
analyser.getByteTimeDomainData(dataArray);
waveCtx.clearRect(0, 0, waveCanvas.width / (window.devicePixelRatio || 1), waveCanvas.height / (window.devicePixelRatio || 1));
waveCtx.lineWidth = 1;
waveCtx.strokeStyle = 'rgb(0, 0, 0)';
waveCtx.beginPath();
const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength;
let x = 0;
for (let i = 0; i < bufferLength; i++) {
const v = dataArray[i] / 128.0;
const y = v * (waveCanvas.height / (window.devicePixelRatio || 1)) / 2;
if (i === 0) {
waveCtx.moveTo(x, y);
} else {
waveCtx.lineTo(x, y);
}
x += sliceWidth;
}
waveCtx.lineTo(waveCanvas.width / (window.devicePixelRatio || 1), waveCanvas.height / (window.devicePixelRatio || 1) / 2);
waveCtx.stroke();
animationFrame = requestAnimationFrame(drawWaveform);
}
async function startRecording() {
try {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
audioContext = new (window.AudioContext || window.webkitAudioContext)();
analyser = audioContext.createAnalyser();
analyser.fftSize = 256;
microphone = audioContext.createMediaStreamSource(stream);
microphone.connect(analyser);
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
recorder.ondataavailable = (e) => {
if (websocket && websocket.readyState === WebSocket.OPEN) {
websocket.send(e.data);
}
};
recorder.start(chunkDuration);
startTime = Date.now();
timerInterval = setInterval(updateTimer, 1000);
drawWaveform();
isRecording = true;
updateUI();
} catch (err) {
statusText.textContent = "Error accessing microphone. Please allow microphone access.";
console.error(err);
}
}
async function stopRecording() {
userClosing = true;
waitingForStop = true;
if (websocket && websocket.readyState === WebSocket.OPEN) {
// Send empty audio buffer as stop signal
const emptyBlob = new Blob([], { type: 'audio/webm' });
websocket.send(emptyBlob);
statusText.textContent = "Recording stopped. Processing final audio...";
}
if (recorder) {
recorder.stop();
recorder = null;
}
if (microphone) {
microphone.disconnect();
microphone = null;
}
if (analyser) {
analyser = null;
}
if (audioContext && audioContext.state !== 'closed') {
try {
audioContext.close();
} catch (e) {
console.warn("Could not close audio context:", e);
}
audioContext = null;
}
if (animationFrame) {
cancelAnimationFrame(animationFrame);
animationFrame = null;
}
if (timerInterval) {
clearInterval(timerInterval);
timerInterval = null;
}
timerElement.textContent = "00:00";
startTime = null;
isRecording = false;
updateUI();
}
async function toggleRecording() {
if (!isRecording) {
if (waitingForStop) {
console.log("Waiting for stop, early return");
return; // Early return, UI is already updated
}
console.log("Connecting to WebSocket");
try {
// If we have an active WebSocket that's still processing, just restart audio capture
if (websocket && websocket.readyState === WebSocket.OPEN) {
await startRecording();
} else {
// If no active WebSocket or it's closed, create new one
await setupWebSocket();
await startRecording();
}
} catch (err) {
statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
console.error(err);
}
} else {
console.log("Stopping recording");
stopRecording();
}
}
function updateUI() {
recordButton.classList.toggle("recording", isRecording);
recordButton.disabled = waitingForStop;
if (waitingForStop) {
if (statusText.textContent !== "Recording stopped. Processing final audio...") {
statusText.textContent = "Please wait for processing to complete...";
}
} else if (isRecording) {
statusText.textContent = "Recording...";
} else {
if (statusText.textContent !== "Finished processing audio! Ready to record again." &&
statusText.textContent !== "Processing finalized or connection closed.") {
statusText.textContent = "Click to start transcription";
}
}
if (!waitingForStop) {
recordButton.disabled = false;
}
}
recordButton.addEventListener("click", toggleRecording);
</script>
</body> </body>
</html> </html>

View File

@@ -0,0 +1,609 @@
/* Theme, WebSocket, recording, rendering logic extracted from inline script and adapted for segmented theme control and WS caption */
let isRecording = false;
let websocket = null;
let recorder = null;
let chunkDuration = 100;
let websocketUrl = "ws://localhost:8000/asr";
let userClosing = false;
let wakeLock = null;
let startTime = null;
let timerInterval = null;
let audioContext = null;
let analyser = null;
let microphone = null;
let waveCanvas = document.getElementById("waveCanvas");
let waveCtx = waveCanvas.getContext("2d");
let animationFrame = null;
let waitingForStop = false;
let lastReceivedData = null;
let lastSignature = null;
let availableMicrophones = [];
let selectedMicrophoneId = null;
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
const statusText = document.getElementById("status");
const recordButton = document.getElementById("recordButton");
const chunkSelector = document.getElementById("chunkSelector");
const websocketInput = document.getElementById("websocketInput");
const websocketDefaultSpan = document.getElementById("wsDefaultUrl");
const linesTranscriptDiv = document.getElementById("linesTranscript");
const timerElement = document.querySelector(".timer");
const themeRadios = document.querySelectorAll('input[name="theme"]');
const microphoneSelect = document.getElementById("microphoneSelect");
function getWaveStroke() {
const styles = getComputedStyle(document.documentElement);
const v = styles.getPropertyValue("--wave-stroke").trim();
return v || "#000";
}
let waveStroke = getWaveStroke();
function updateWaveStroke() {
waveStroke = getWaveStroke();
}
function applyTheme(pref) {
if (pref === "light") {
document.documentElement.setAttribute("data-theme", "light");
} else if (pref === "dark") {
document.documentElement.setAttribute("data-theme", "dark");
} else {
document.documentElement.removeAttribute("data-theme");
}
updateWaveStroke();
}
// Persisted theme preference
const savedThemePref = localStorage.getItem("themePreference") || "system";
applyTheme(savedThemePref);
if (themeRadios.length) {
themeRadios.forEach((r) => {
r.checked = r.value === savedThemePref;
r.addEventListener("change", () => {
if (r.checked) {
localStorage.setItem("themePreference", r.value);
applyTheme(r.value);
}
});
});
}
// React to OS theme changes when in "system" mode
const darkMq = window.matchMedia && window.matchMedia("(prefers-color-scheme: dark)");
const handleOsThemeChange = () => {
const pref = localStorage.getItem("themePreference") || "system";
if (pref === "system") updateWaveStroke();
};
if (darkMq && darkMq.addEventListener) {
darkMq.addEventListener("change", handleOsThemeChange);
} else if (darkMq && darkMq.addListener) {
// deprecated, but included for Safari compatibility
darkMq.addListener(handleOsThemeChange);
}
async function enumerateMicrophones() {
try {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
stream.getTracks().forEach(track => track.stop());
const devices = await navigator.mediaDevices.enumerateDevices();
availableMicrophones = devices.filter(device => device.kind === 'audioinput');
populateMicrophoneSelect();
console.log(`Found ${availableMicrophones.length} microphone(s)`);
} catch (error) {
console.error('Error enumerating microphones:', error);
statusText.textContent = "Error accessing microphones. Please grant permission.";
}
}
function populateMicrophoneSelect() {
if (!microphoneSelect) return;
microphoneSelect.innerHTML = '<option value="">Default Microphone</option>';
availableMicrophones.forEach((device, index) => {
const option = document.createElement('option');
option.value = device.deviceId;
option.textContent = device.label || `Microphone ${index + 1}`;
microphoneSelect.appendChild(option);
});
const savedMicId = localStorage.getItem('selectedMicrophone');
if (savedMicId && availableMicrophones.some(mic => mic.deviceId === savedMicId)) {
microphoneSelect.value = savedMicId;
selectedMicrophoneId = savedMicId;
}
}
function handleMicrophoneChange() {
selectedMicrophoneId = microphoneSelect.value || null;
localStorage.setItem('selectedMicrophone', selectedMicrophoneId || '');
const selectedDevice = availableMicrophones.find(mic => mic.deviceId === selectedMicrophoneId);
const deviceName = selectedDevice ? selectedDevice.label : 'Default Microphone';
console.log(`Selected microphone: ${deviceName}`);
statusText.textContent = `Microphone changed to: ${deviceName}`;
if (isRecording) {
statusText.textContent = "Switching microphone... Please wait.";
stopRecording().then(() => {
setTimeout(() => {
toggleRecording();
}, 1000);
});
}
}
// Helpers
function fmt1(x) {
const n = Number(x);
return Number.isFinite(n) ? n.toFixed(1) : x;
}
// Default WebSocket URL computation
const host = window.location.hostname || "localhost";
const port = window.location.port;
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
const defaultWebSocketUrl = `${protocol}://${host}${port ? ":" + port : ""}/asr`;
// Populate default caption and input
if (websocketDefaultSpan) websocketDefaultSpan.textContent = defaultWebSocketUrl;
websocketInput.value = defaultWebSocketUrl;
websocketUrl = defaultWebSocketUrl;
// Optional chunk selector (guard for presence)
if (chunkSelector) {
chunkSelector.addEventListener("change", () => {
chunkDuration = parseInt(chunkSelector.value);
});
}
// WebSocket input change handling
websocketInput.addEventListener("change", () => {
const urlValue = websocketInput.value.trim();
if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
return;
}
websocketUrl = urlValue;
statusText.textContent = "WebSocket URL updated. Ready to connect.";
});
function setupWebSocket() {
return new Promise((resolve, reject) => {
try {
websocket = new WebSocket(websocketUrl);
} catch (error) {
statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
reject(error);
return;
}
websocket.onopen = () => {
statusText.textContent = "Connected to server.";
resolve();
};
websocket.onclose = () => {
if (userClosing) {
if (waitingForStop) {
statusText.textContent = "Processing finalized or connection closed.";
if (lastReceivedData) {
renderLinesWithBuffer(
lastReceivedData.lines || [],
lastReceivedData.buffer_diarization || "",
lastReceivedData.buffer_transcription || "",
0,
0,
true
);
}
}
} else {
statusText.textContent = "Disconnected from the WebSocket server. (Check logs if model is loading.)";
if (isRecording) {
stopRecording();
}
}
isRecording = false;
waitingForStop = false;
userClosing = false;
lastReceivedData = null;
websocket = null;
updateUI();
};
websocket.onerror = () => {
statusText.textContent = "Error connecting to WebSocket.";
reject(new Error("Error connecting to WebSocket"));
};
websocket.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.type === "ready_to_stop") {
console.log("Ready to stop received, finalizing display and closing WebSocket.");
waitingForStop = false;
if (lastReceivedData) {
renderLinesWithBuffer(
lastReceivedData.lines || [],
lastReceivedData.buffer_diarization || "",
lastReceivedData.buffer_transcription || "",
0,
0,
true
);
}
statusText.textContent = "Finished processing audio! Ready to record again.";
recordButton.disabled = false;
if (websocket) {
websocket.close();
}
return;
}
lastReceivedData = data;
const {
lines = [],
buffer_transcription = "",
buffer_diarization = "",
remaining_time_transcription = 0,
remaining_time_diarization = 0,
status = "active_transcription",
} = data;
renderLinesWithBuffer(
lines,
buffer_diarization,
buffer_transcription,
remaining_time_diarization,
remaining_time_transcription,
false,
status
);
};
});
}
function renderLinesWithBuffer(
lines,
buffer_diarization,
buffer_transcription,
remaining_time_diarization,
remaining_time_transcription,
isFinalizing = false,
current_status = "active_transcription"
) {
if (current_status === "no_audio_detected") {
linesTranscriptDiv.innerHTML =
"<p style='text-align: center; color: var(--muted); margin-top: 20px;'><em>No audio detected...</em></p>";
return;
}
const showLoading = !isFinalizing && (lines || []).some((it) => it.speaker == 0);
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
const signature = JSON.stringify({
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end })),
buffer_transcription: buffer_transcription || "",
buffer_diarization: buffer_diarization || "",
status: current_status,
showLoading,
showTransLag,
showDiaLag,
isFinalizing: !!isFinalizing,
});
if (lastSignature === signature) {
const t = document.querySelector(".lag-transcription-value");
if (t) t.textContent = fmt1(remaining_time_transcription);
const d = document.querySelector(".lag-diarization-value");
if (d) d.textContent = fmt1(remaining_time_diarization);
const ld = document.querySelector(".loading-diarization-value");
if (ld) ld.textContent = fmt1(remaining_time_diarization);
return;
}
lastSignature = signature;
const linesHtml = (lines || [])
.map((item, idx) => {
let timeInfo = "";
if (item.start !== undefined && item.end !== undefined) {
timeInfo = ` ${item.start} - ${item.end}`;
}
let speakerLabel = "";
if (item.speaker === -2) {
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
} else if (item.speaker == 0 && !isFinalizing) {
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
remaining_time_diarization
)}</span> second(s) of audio are undergoing diarization</span></span>`;
} else if (item.speaker !== 0) {
speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
}
let currentLineText = item.text || "";
if (item.translation) {
currentLineText += `<div class="label_translation">
<img src="/web/src/translate.svg" alt="Translation" width="12" height="12" />
<span>${item.translation}</span>
</div>`;
}
if (idx === lines.length - 1) {
if (!isFinalizing && item.speaker !== -2) {
if (remaining_time_transcription > 0) {
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(
remaining_time_transcription
)}</span>s</span></span>`;
}
if (buffer_diarization && remaining_time_diarization > 0) {
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(
remaining_time_diarization
)}</span>s</span></span>`;
}
}
if (buffer_diarization) {
if (isFinalizing) {
currentLineText +=
(currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
} else {
currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
}
}
if (buffer_transcription) {
if (isFinalizing) {
currentLineText +=
(currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") +
buffer_transcription.trim();
} else {
currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
}
}
}
return currentLineText.trim().length > 0 || speakerLabel.length > 0
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
: `<p>${speakerLabel}<br/></p>`;
})
.join("");
linesTranscriptDiv.innerHTML = linesHtml;
const transcriptContainer = document.querySelector('.transcript-container');
if (transcriptContainer) {
transcriptContainer.scrollTo({ top: transcriptContainer.scrollHeight, behavior: "smooth" });
}
}
function updateTimer() {
if (!startTime) return;
const elapsed = Math.floor((Date.now() - startTime) / 1000);
const minutes = Math.floor(elapsed / 60).toString().padStart(2, "0");
const seconds = (elapsed % 60).toString().padStart(2, "0");
timerElement.textContent = `${minutes}:${seconds}`;
}
function drawWaveform() {
if (!analyser) return;
const bufferLength = analyser.frequencyBinCount;
const dataArray = new Uint8Array(bufferLength);
analyser.getByteTimeDomainData(dataArray);
waveCtx.clearRect(
0,
0,
waveCanvas.width / (window.devicePixelRatio || 1),
waveCanvas.height / (window.devicePixelRatio || 1)
);
waveCtx.lineWidth = 1;
waveCtx.strokeStyle = waveStroke;
waveCtx.beginPath();
const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength;
let x = 0;
for (let i = 0; i < bufferLength; i++) {
const v = dataArray[i] / 128.0;
const y = (v * (waveCanvas.height / (window.devicePixelRatio || 1))) / 2;
if (i === 0) {
waveCtx.moveTo(x, y);
} else {
waveCtx.lineTo(x, y);
}
x += sliceWidth;
}
waveCtx.lineTo(
waveCanvas.width / (window.devicePixelRatio || 1),
(waveCanvas.height / (window.devicePixelRatio || 1)) / 2
);
waveCtx.stroke();
animationFrame = requestAnimationFrame(drawWaveform);
}
async function startRecording() {
try {
try {
wakeLock = await navigator.wakeLock.request("screen");
} catch (err) {
console.log("Error acquiring wake lock.");
}
const audioConstraints = selectedMicrophoneId
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
: { audio: true };
const stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
audioContext = new (window.AudioContext || window.webkitAudioContext)();
analyser = audioContext.createAnalyser();
analyser.fftSize = 256;
microphone = audioContext.createMediaStreamSource(stream);
microphone.connect(analyser);
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
recorder.ondataavailable = (e) => {
if (websocket && websocket.readyState === WebSocket.OPEN) {
websocket.send(e.data);
}
};
recorder.start(chunkDuration);
startTime = Date.now();
timerInterval = setInterval(updateTimer, 1000);
drawWaveform();
isRecording = true;
updateUI();
} catch (err) {
if (window.location.hostname === "0.0.0.0") {
statusText.textContent =
"Error accessing microphone. Browsers may block microphone access on 0.0.0.0. Try using localhost:8000 instead.";
} else {
statusText.textContent = "Error accessing microphone. Please allow microphone access.";
}
console.error(err);
}
}
async function stopRecording() {
if (wakeLock) {
try {
await wakeLock.release();
} catch (e) {
// ignore
}
wakeLock = null;
}
userClosing = true;
waitingForStop = true;
if (websocket && websocket.readyState === WebSocket.OPEN) {
const emptyBlob = new Blob([], { type: "audio/webm" });
websocket.send(emptyBlob);
statusText.textContent = "Recording stopped. Processing final audio...";
}
if (recorder) {
recorder.stop();
recorder = null;
}
if (microphone) {
microphone.disconnect();
microphone = null;
}
if (analyser) {
analyser = null;
}
if (audioContext && audioContext.state !== "closed") {
try {
await audioContext.close();
} catch (e) {
console.warn("Could not close audio context:", e);
}
audioContext = null;
}
if (animationFrame) {
cancelAnimationFrame(animationFrame);
animationFrame = null;
}
if (timerInterval) {
clearInterval(timerInterval);
timerInterval = null;
}
timerElement.textContent = "00:00";
startTime = null;
isRecording = false;
updateUI();
}
async function toggleRecording() {
if (!isRecording) {
if (waitingForStop) {
console.log("Waiting for stop, early return");
return;
}
console.log("Connecting to WebSocket");
try {
if (websocket && websocket.readyState === WebSocket.OPEN) {
await startRecording();
} else {
await setupWebSocket();
await startRecording();
}
} catch (err) {
statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
console.error(err);
}
} else {
console.log("Stopping recording");
stopRecording();
}
}
function updateUI() {
recordButton.classList.toggle("recording", isRecording);
recordButton.disabled = waitingForStop;
if (waitingForStop) {
if (statusText.textContent !== "Recording stopped. Processing final audio...") {
statusText.textContent = "Please wait for processing to complete...";
}
} else if (isRecording) {
statusText.textContent = "Recording...";
} else {
if (
statusText.textContent !== "Finished processing audio! Ready to record again." &&
statusText.textContent !== "Processing finalized or connection closed."
) {
statusText.textContent = "Click to start transcription";
}
}
if (!waitingForStop) {
recordButton.disabled = false;
}
}
recordButton.addEventListener("click", toggleRecording);
if (microphoneSelect) {
microphoneSelect.addEventListener("change", handleMicrophoneChange);
}
document.addEventListener('DOMContentLoaded', async () => {
try {
await enumerateMicrophones();
} catch (error) {
console.log("Could not enumerate microphones on load:", error);
}
});
navigator.mediaDevices.addEventListener('devicechange', async () => {
console.log('Device change detected, re-enumerating microphones');
try {
await enumerateMicrophones();
} catch (error) {
console.log("Error re-enumerating microphones:", error);
}
});

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-120q-151 0-255.5-104.5T120-480q0-138 90-239.5T440-838q13-2 23 3.5t16 14.5q6 9 6.5 21t-7.5 23q-17 26-25.5 55t-8.5 61q0 90 63 153t153 63q31 0 61.5-9t54.5-25q11-7 22.5-6.5T819-479q10 5 15.5 15t3.5 24q-14 138-117.5 229T480-120Zm0-80q88 0 158-48.5T740-375q-20 5-40 8t-40 3q-123 0-209.5-86.5T364-660q0-20 3-40t8-40q-78 32-126.5 102T200-480q0 116 82 198t198 82Zm-10-270Z"/></svg>

After

Width:  |  Height:  |  Size: 493 B

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-360q50 0 85-35t35-85q0-50-35-85t-85-35q-50 0-85 35t-35 85q0 50 35 85t85 35Zm0 80q-83 0-141.5-58.5T280-480q0-83 58.5-141.5T480-680q83 0 141.5 58.5T680-480q0 83-58.5 141.5T480-280ZM80-440q-17 0-28.5-11.5T40-480q0-17 11.5-28.5T80-520h80q17 0 28.5 11.5T200-480q0 17-11.5 28.5T160-440H80Zm720 0q-17 0-28.5-11.5T760-480q0-17 11.5-28.5T800-520h80q17 0 28.5 11.5T920-480q0 17-11.5 28.5T880-440h-80ZM480-760q-17 0-28.5-11.5T440-800v-80q0-17 11.5-28.5T480-920q17 0 28.5 11.5T520-880v80q0 17-11.5 28.5T480-760Zm0 720q-17 0-28.5-11.5T440-80v-80q0-17 11.5-28.5T480-200q17 0 28.5 11.5T520-160v80q0 17-11.5 28.5T480-40ZM226-678l-43-42q-12-11-11.5-28t11.5-29q12-12 29-12t28 12l42 43q11 12 11 28t-11 28q-11 12-27.5 11.5T226-678Zm494 495-42-43q-11-12-11-28.5t11-27.5q11-12 27.5-11.5T734-282l43 42q12 11 11.5 28T777-183q-12 12-29 12t-28-12Zm-42-495q-12-11-11.5-27.5T678-734l42-43q11-12 28-11.5t29 11.5q12 12 12 29t-12 28l-43 42q-12 11-28 11t-28-11ZM183-183q-12-12-12-29t12-28l43-42q12-11 28.5-11t27.5 11q12 11 11.5 27.5T282-226l-42 43q-11 12-28 11.5T183-183Zm297-297Z"/></svg>

After

Width:  |  Height:  |  Size: 1.2 KiB

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M433-80q-27 0-46.5-18T363-142l-9-66q-13-5-24.5-12T307-235l-62 26q-25 11-50 2t-39-32l-47-82q-14-23-8-49t27-43l53-40q-1-7-1-13.5v-27q0-6.5 1-13.5l-53-40q-21-17-27-43t8-49l47-82q14-23 39-32t50 2l62 26q11-8 23-15t24-12l9-66q4-26 23.5-44t46.5-18h94q27 0 46.5 18t23.5 44l9 66q13 5 24.5 12t22.5 15l62-26q25-11 50-2t39 32l47 82q14 23 8 49t-27 43l-53 40q1 7 1 13.5v27q0 6.5-2 13.5l53 40q21 17 27 43t-8 49l-48 82q-14 23-39 32t-50-2l-60-26q-11 8-23 15t-24 12l-9 66q-4 26-23.5 44T527-80h-94Zm7-80h79l14-106q31-8 57.5-23.5T639-327l99 41 39-68-86-65q5-14 7-29.5t2-31.5q0-16-2-31.5t-7-29.5l86-65-39-68-99 42q-22-23-48.5-38.5T533-694l-13-106h-79l-14 106q-31 8-57.5 23.5T321-633l-99-41-39 68 86 64q-5 15-7 30t-2 32q0 16 2 31t7 30l-86 65 39 68 99-42q22 23 48.5 38.5T427-266l13 106Zm42-180q58 0 99-41t41-99q0-58-41-99t-99-41q-59 0-99.5 41T342-480q0 58 40.5 99t99.5 41Zm-2-140Z"/></svg>

After

Width:  |  Height:  |  Size: 982 B

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M396-396q-32-32-58.5-67T289-537q-5 14-6.5 28.5T281-480q0 83 58 141t141 58q14 0 28.5-2t28.5-6q-39-22-74-48.5T396-396Zm85 196q-56 0-107-21t-91-61q-40-40-61-91t-21-107q0-51 17-97.5t50-84.5q13-14 32-9.5t27 24.5q21 55 52.5 104t73.5 91q42 42 91 73.5T648-326q20 8 24.5 27t-9.5 32q-38 33-84.5 50T481-200Zm223-192q-16-5-23-20.5t-4-32.5q9-48-6-94.5T621-621q-35-35-80.5-49.5T448-677q-17 3-32-4t-21-23q-6-16 1.5-31t23.5-19q69-15 138 4.5T679-678q51 51 71 120t5 138q-4 17-19 25t-32 3ZM480-840q-17 0-28.5-11.5T440-880v-40q0-17 11.5-28.5T480-960q17 0 28.5 11.5T520-920v40q0 17-11.5 28.5T480-840Zm0 840q-17 0-28.5-11.5T440-40v-40q0-17 11.5-28.5T480-120q17 0 28.5 11.5T520-80v40q0 17-11.5 28.5T480 0Zm255-734q-12-12-12-28.5t12-28.5l28-28q11-11 27.5-11t28.5 11q12 12 12 28.5T819-762l-28 28q-12 12-28 12t-28-12ZM141-141q-12-12-12-28.5t12-28.5l28-28q12-12 28-12t28 12q12 12 12 28.5T225-169l-28 28q-11 11-27.5 11T141-141Zm739-299q-17 0-28.5-11.5T840-480q0-17 11.5-28.5T880-520h40q17 0 28.5 11.5T960-480q0 17-11.5 28.5T920-440h-40Zm-840 0q-17 0-28.5-11.5T0-480q0-17 11.5-28.5T40-520h40q17 0 28.5 11.5T120-480q0 17-11.5 28.5T80-440H40Zm779 299q-12 12-28.5 12T762-141l-28-28q-12-12-12-28t12-28q12-12 28.5-12t28.5 12l28 28q11 11 11 27.5T819-141ZM226-735q-12 12-28.5 12T169-735l-28-28q-11-11-11-27.5t11-28.5q12-12 28.5-12t28.5 12l28 28q12 12 12 28t-12 28Zm170 339Z"/></svg>

After

Width:  |  Height:  |  Size: 1.4 KiB

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="m603-202-34 97q-4 11-14 18t-22 7q-20 0-32.5-16.5T496-133l152-402q5-11 15-18t22-7h30q12 0 22 7t15 18l152 403q8 19-4 35.5T868-80q-13 0-22.5-7T831-106l-34-96H603ZM362-401 188-228q-11 11-27.5 11.5T132-228q-11-11-11-28t11-28l174-174q-35-35-63.5-80T190-640h84q20 39 40 68t48 58q33-33 68.5-92.5T484-720H80q-17 0-28.5-11.5T40-760q0-17 11.5-28.5T80-800h240v-40q0-17 11.5-28.5T360-880q17 0 28.5 11.5T400-840v40h240q17 0 28.5 11.5T680-760q0 17-11.5 28.5T640-720h-76q-21 72-63 148t-83 116l96 98-30 82-122-125Zm266 129h144l-72-204-72 204Z"/></svg>

After

Width:  |  Height:  |  Size: 650 B

View File

@@ -1,5 +1,6 @@
import logging import logging
import importlib.resources as resources import importlib.resources as resources
import base64
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -10,4 +11,78 @@ def get_web_interface_html():
return f.read() return f.read()
except Exception as e: except Exception as e:
logger.error(f"Error loading web interface HTML: {e}") logger.error(f"Error loading web interface HTML: {e}")
return "<html><body><h1>Error loading interface</h1></body></html>" return "<html><body><h1>Error loading interface</h1></body></html>"
def get_inline_ui_html():
"""Returns the complete web interface HTML with all assets embedded in a single call."""
try:
with resources.files('whisperlivekit.web').joinpath('live_transcription.html').open('r', encoding='utf-8') as f:
html_content = f.read()
with resources.files('whisperlivekit.web').joinpath('live_transcription.css').open('r', encoding='utf-8') as f:
css_content = f.read()
with resources.files('whisperlivekit.web').joinpath('live_transcription.js').open('r', encoding='utf-8') as f:
js_content = f.read()
# SVG files
with resources.files('whisperlivekit.web').joinpath('src', 'system_mode.svg').open('r', encoding='utf-8') as f:
system_svg = f.read()
system_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(system_svg.encode('utf-8')).decode('utf-8')}"
with resources.files('whisperlivekit.web').joinpath('src', 'light_mode.svg').open('r', encoding='utf-8') as f:
light_svg = f.read()
light_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(light_svg.encode('utf-8')).decode('utf-8')}"
with resources.files('whisperlivekit.web').joinpath('src', 'dark_mode.svg').open('r', encoding='utf-8') as f:
dark_svg = f.read()
dark_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(dark_svg.encode('utf-8')).decode('utf-8')}"
# Replace external references
html_content = html_content.replace(
'<link rel="stylesheet" href="/web/live_transcription.css" />',
f'<style>\n{css_content}\n</style>'
)
html_content = html_content.replace(
'<script src="/web/live_transcription.js"></script>',
f'<script>\n{js_content}\n</script>'
)
# Replace SVG references
html_content = html_content.replace(
'<img src="/web/src/system_mode.svg" alt="" />',
f'<img src="{system_data_uri}" alt="" />'
)
html_content = html_content.replace(
'<img src="/web/src/light_mode.svg" alt="" />',
f'<img src="{light_data_uri}" alt="" />'
)
html_content = html_content.replace(
'<img src="/web/src/dark_mode.svg" alt="" />',
f'<img src="{dark_data_uri}" alt="" />'
)
return html_content
except Exception as e:
logger.error(f"Error creating embedded web interface: {e}")
return "<html><body><h1>Error loading embedded interface</h1></body></html>"
if __name__ == '__main__':
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
import uvicorn
from starlette.staticfiles import StaticFiles
import pathlib
import whisperlivekit.web as webpkg
app = FastAPI()
web_dir = pathlib.Path(webpkg.__file__).parent
app.mount("/web", StaticFiles(directory=str(web_dir)), name="web")
@app.get("/")
async def get():
return HTMLResponse(get_inline_ui_html())
uvicorn.run(app=app)

View File

@@ -3,43 +3,10 @@ import logging
import io import io
import soundfile as sf import soundfile as sf
import math import math
try:
import torch
except ImportError:
torch = None
from typing import List from typing import List
import numpy as np import numpy as np
from whisperlivekit.timed_objects import ASRToken from whisperlivekit.timed_objects import ASRToken
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS = ImportError(
"""SimulStreaming dependencies are not available.
Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]"
""")
try:
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD
from whisperlivekit.simul_whisper.whisper import tokenizer
SIMULSTREAMING_AVAILABLE = True
except ImportError:
logger.warning("⚠️ SimulStreaming dependencies not available. Attempting to download them.")
try:
from whisperlivekit import download_simulstreaming_backend
download_simulstreaming_backend()
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD
from whisperlivekit.simul_whisper.whisper import tokenizer
SIMULSTREAMING_AVAILABLE = True
logger.info("SimulStreaming dependencies downloaded successfully.")
except Exception as e:
logger.error(f"Failed to download or import SimulStreaming dependencies: {e}")
SIMULSTREAMING_AVAILABLE = False
AlignAttConfig = None
PaddedAlignAttWhisper = None
DEC_PAD = None
tokenizer = None
class ASRBase: class ASRBase:
sep = " " # join transcribe words with this character (" " for whisper_timestamped, sep = " " # join transcribe words with this character (" " for whisper_timestamped,
# "" for faster-whisper because it emits the spaces when needed) # "" for faster-whisper because it emits the spaces when needed)
@@ -320,182 +287,4 @@ class OpenaiApiASR(ASRBase):
self.use_vad_opt = True self.use_vad_opt = True
def set_translate_task(self): def set_translate_task(self):
self.task = "translate" self.task = "translate"
class SimulStreamingASR(ASRBase):
"""SimulStreaming backend with AlignAtt policy."""
sep = ""
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
if not SIMULSTREAMING_AVAILABLE:
raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
with open("whisperlivekit/simul_whisper/dual_license_simulstreaming.md", "r") as f:
print("*"*80 + f.read() + "*"*80)
self.logfile = logfile
self.transcribe_kargs = {}
self.original_language = None if lan == "auto" else lan
self.model_path = kwargs.get('model_path', './large-v3.pt')
self.frame_threshold = kwargs.get('frame_threshold', 25)
self.audio_max_len = kwargs.get('audio_max_len', 30.0)
self.audio_min_len = kwargs.get('audio_min_len', 0.0)
self.segment_length = kwargs.get('segment_length', 0.5)
self.beams = kwargs.get('beams', 1)
self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
self.task = kwargs.get('task', 'transcribe')
self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
self.never_fire = kwargs.get('never_fire', False)
self.init_prompt = kwargs.get('init_prompt', None)
self.static_init_prompt = kwargs.get('static_init_prompt', None)
self.max_context_tokens = kwargs.get('max_context_tokens', None)
if model_dir is not None:
self.model_path = model_dir
elif modelsize is not None: #For the moment the .en.pt models do not work!
model_mapping = {
'tiny': './tiny.pt',
'base': './base.pt',
'small': './small.pt',
'medium': './medium.pt',
'medium.en': './medium.en.pt',
'large-v1': './large-v1.pt',
'base.en': './base.en.pt',
'small.en': './small.en.pt',
'tiny.en': './tiny.en.pt',
'large-v2': './large-v2.pt',
'large-v3': './large-v3.pt',
'large': './large-v3.pt'
}
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
self.model = self.load_model(modelsize, cache_dir, model_dir)
# Set up tokenizer for translation if needed
if self.task == "translate":
self.set_translate_task()
def load_model(self, modelsize, cache_dir, model_dir):
try:
cfg = AlignAttConfig(
model_path=self.model_path,
segment_length=self.segment_length,
frame_threshold=self.frame_threshold,
language=self.original_language,
audio_max_len=self.audio_max_len,
audio_min_len=self.audio_min_len,
cif_ckpt_path=self.cif_ckpt_path,
decoder_type="beam",
beam_size=self.beams,
task=self.task,
never_fire=self.never_fire,
init_prompt=self.init_prompt,
max_context_tokens=self.max_context_tokens,
static_init_prompt=self.static_init_prompt,
)
logger.info(f"Loading SimulStreaming model with language: {self.original_language}")
model = PaddedAlignAttWhisper(cfg)
return model
except Exception as e:
logger.error(f"Failed to load SimulStreaming model: {e}")
raise
def transcribe(self, audio, init_prompt=""):
"""Transcribe audio using SimulStreaming."""
try:
if isinstance(audio, np.ndarray):
audio_tensor = torch.from_numpy(audio).float()
else:
audio_tensor = audio
prompt = init_prompt if init_prompt else (self.init_prompt or "")
result = self.model.infer(audio_tensor, init_prompt=prompt)
if torch.is_tensor(result):
result = result[result < DEC_PAD]
logger.debug(f"SimulStreaming transcription result: {result}")
return result
except Exception as e:
logger.error(f"SimulStreaming transcription failed: {e}")
raise
def ts_words(self, result) -> List[ASRToken]:
"""Convert SimulStreaming result to ASRToken list."""
tokens = []
try:
if torch.is_tensor(result):
text = self.model.tokenizer.decode(result.cpu().numpy())
else:
text = str(result)
if not text or len(text.strip()) == 0:
return tokens
# We dont have word-level timestamps here. 1rst approach, should be improved later.
words = text.strip().split()
if not words:
return tokens
duration_per_word = 0.1 # this will be modified based on actual audio duration
#with the SimulStreamingOnlineProcessor
for i, word in enumerate(words):
start_time = i * duration_per_word
end_time = (i + 1) * duration_per_word
token = ASRToken(
start=start_time,
end=end_time,
text=word,
probability=1.0
)
tokens.append(token)
except Exception as e:
logger.error(f"Error converting SimulStreaming result to tokens: {e}")
return tokens
def segments_end_ts(self, result) -> List[float]:
"""Get segment end timestamps."""
if torch.is_tensor(result):
num_tokens = len(result)
return [num_tokens * 0.1] # rough estimate
return [1.0]
def use_vad(self):
"""Enable VAD - SimulStreaming has different VAD handling."""
logger.info("VAD requested for SimulStreaming - handled internally by the model")
pass
def set_translate_task(self):
"""Set up translation task."""
try:
self.model.tokenizer = tokenizer.get_tokenizer(
multilingual=True,
language=self.model.cfg.language,
num_languages=self.model.model.num_languages,
task="translate"
)
logger.info("SimulStreaming configured for translation task")
except Exception as e:
logger.error(f"Failed to configure SimulStreaming for translation: {e}")
raise
def warmup(self, audio, init_prompt=""):
"""Warmup the SimulStreaming model."""
try:
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio).float()
self.model.insert_audio(audio)
self.model.infer(True)
self.model.refresh_segment(complete=True)
logger.info("SimulStreaming model warmed up successfully")
except Exception as e:
logger.warning(f"SimulStreaming warmup failed: {e}")

View File

@@ -6,18 +6,6 @@ from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# simulStreaming imports - we check if the files are here
try:
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
SIMULSTREAMING_AVAILABLE = True
except ImportError:
logger.warning("SimulStreaming dependencies not available for online processor.")
SIMULSTREAMING_AVAILABLE = False
OnlineProcessorInterface = None
torch = None
class HypothesisBuffer: class HypothesisBuffer:
""" """
Buffer to store and process ASR hypothesis tokens. Buffer to store and process ASR hypothesis tokens.
@@ -134,6 +122,7 @@ class OnlineASRProcessor:
self.tokenize = tokenize_method self.tokenize = tokenize_method
self.logfile = logfile self.logfile = logfile
self.confidence_validation = confidence_validation self.confidence_validation = confidence_validation
self.global_time_offset = 0.0
self.init() self.init()
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
@@ -164,6 +153,21 @@ class OnlineASRProcessor:
"""Append an audio chunk (a numpy array) to the current audio buffer.""" """Append an audio chunk (a numpy array) to the current audio buffer."""
self.audio_buffer = np.append(self.audio_buffer, audio) self.audio_buffer = np.append(self.audio_buffer, audio)
def insert_silence(self, silence_duration, offset):
"""
If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame
"""
# if self.transcript_buffer.buffer:
# self.committed.extend(self.transcript_buffer.buffer)
# self.transcript_buffer.buffer = []
if True: #silence_duration < 3: #we want the last audio to be treated to not have a gap. could also be handled in the future in ends_with_silence.
gap_silence = np.zeros(int(16000 * silence_duration), dtype=np.int16)
self.insert_audio_chunk(gap_silence)
else:
self.init(offset=silence_duration + offset)
self.global_time_offset += silence_duration
def prompt(self) -> Tuple[str, str]: def prompt(self) -> Tuple[str, str]:
""" """
Returns a tuple: (prompt, context), where: Returns a tuple: (prompt, context), where:
@@ -242,6 +246,9 @@ class OnlineASRProcessor:
logger.debug( logger.debug(
f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds" f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
) )
if self.global_time_offset:
for token in committed_tokens:
token = token.with_offset(self.global_time_offset)
return committed_tokens, current_audio_processed_upto return committed_tokens, current_audio_processed_upto
def chunk_completed_sentence(self): def chunk_completed_sentence(self):
@@ -403,330 +410,3 @@ class OnlineASRProcessor:
start = None start = None
end = None end = None
return Transcript(start, end, text, probability=probability) return Transcript(start, end, text, probability=probability)
class VACOnlineASRProcessor:
"""
Wraps an OnlineASRProcessor with a Voice Activity Controller (VAC).
It receives small chunks of audio, applies VAD (e.g. with Silero),
and when the system detects a pause in speech (or end of an utterance)
it finalizes the utterance immediately.
"""
SAMPLING_RATE = 16000
def __init__(self, online_chunk_size: float, *args, **kwargs):
self.online_chunk_size = online_chunk_size
self.online = OnlineASRProcessor(*args, **kwargs)
self.asr = self.online.asr
# Load a VAD model (e.g. Silero VAD)
import torch
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
from .silero_vad_iterator import FixedVADIterator
self.vac = FixedVADIterator(model)
self.logfile = self.online.logfile
self.last_input_audio_stream_end_time: float = 0.0
self.init()
def init(self):
self.online.init()
self.vac.reset_states()
self.current_online_chunk_buffer_size = 0
self.last_input_audio_stream_end_time = self.online.buffer_time_offset
self.is_currently_final = False
self.status: Optional[str] = None # "voice" or "nonvoice"
self.audio_buffer = np.array([], dtype=np.float32)
self.buffer_offset = 0 # in frames
def get_audio_buffer_end_time(self) -> float:
"""Returns the absolute end time of the audio processed by the underlying OnlineASRProcessor."""
return self.online.get_audio_buffer_end_time()
def clear_buffer(self):
self.buffer_offset += len(self.audio_buffer)
self.audio_buffer = np.array([], dtype=np.float32)
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
"""
Process an incoming small audio chunk:
- run VAD on the chunk,
- decide whether to send the audio to the online ASR processor immediately,
- and/or to mark the current utterance as finished.
"""
self.last_input_audio_stream_end_time = audio_stream_end_time
res = self.vac(audio)
self.audio_buffer = np.append(self.audio_buffer, audio)
if res is not None:
# VAD returned a result; adjust the frame number
frame = list(res.values())[0] - self.buffer_offset
if "start" in res and "end" not in res:
self.status = "voice"
send_audio = self.audio_buffer[frame:]
self.online.init(offset=(frame + self.buffer_offset) / self.SAMPLING_RATE)
self.online.insert_audio_chunk(send_audio)
self.current_online_chunk_buffer_size += len(send_audio)
self.clear_buffer()
elif "end" in res and "start" not in res:
self.status = "nonvoice"
send_audio = self.audio_buffer[:frame]
self.online.insert_audio_chunk(send_audio)
self.current_online_chunk_buffer_size += len(send_audio)
self.is_currently_final = True
self.clear_buffer()
else:
beg = res["start"] - self.buffer_offset
end = res["end"] - self.buffer_offset
self.status = "nonvoice"
send_audio = self.audio_buffer[beg:end]
self.online.init(offset=(beg + self.buffer_offset) / self.SAMPLING_RATE)
self.online.insert_audio_chunk(send_audio)
self.current_online_chunk_buffer_size += len(send_audio)
self.is_currently_final = True
self.clear_buffer()
else:
if self.status == "voice":
self.online.insert_audio_chunk(self.audio_buffer)
self.current_online_chunk_buffer_size += len(self.audio_buffer)
self.clear_buffer()
else:
# Keep 1 second worth of audio in case VAD later detects voice,
# but trim to avoid unbounded memory usage.
self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE)
self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]
def process_iter(self) -> Tuple[List[ASRToken], float]:
"""
Depending on the VAD status and the amount of accumulated audio,
process the current audio chunk.
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
"""
if self.is_currently_final:
return self.finish()
elif self.current_online_chunk_buffer_size > self.SAMPLING_RATE * self.online_chunk_size:
self.current_online_chunk_buffer_size = 0
return self.online.process_iter()
else:
logger.debug("No online update, only VAD")
return [], self.last_input_audio_stream_end_time
def finish(self) -> Tuple[List[ASRToken], float]:
"""
Finish processing by flushing any remaining text.
Returns a tuple: (list of remaining ASRToken objects, float representing the final audio processed up to time).
"""
result_tokens, processed_upto = self.online.finish()
self.current_online_chunk_buffer_size = 0
self.is_currently_final = False
return result_tokens, processed_upto
def get_buffer(self):
"""
Get the unvalidated buffer in string format.
"""
return self.online.concatenate_tokens(self.online.transcript_buffer.buffer)
class SimulStreamingOnlineProcessor:
SAMPLING_RATE = 16000
def __init__(
self,
asr,
tokenize_method: Optional[callable] = None,
buffer_trimming: Tuple[str, float] = ("segment", 15),
confidence_validation = False,
logfile=sys.stderr,
):
if not SIMULSTREAMING_AVAILABLE:
raise ImportError("SimulStreaming dependencies are not available.")
self.asr = asr
self.tokenize = tokenize_method
self.logfile = logfile
self.confidence_validation = confidence_validation
self.init()
# buffer does not work yet
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
def init(self, offset: Optional[float] = None):
"""Initialize or reset the processing state."""
self.audio_chunks = []
self.offset = offset if offset is not None else 0.0
self.is_last = False
self.beg = self.offset
self.end = self.offset
self.cumulative_audio_duration = 0.0
self.last_audio_stream_end_time = self.offset
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
self.buffer_content = ""
self.processed_audio_duration = 0.0
def get_audio_buffer_end_time(self) -> float:
"""Returns the absolute end time of the current audio buffer."""
return self.end
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
"""Append an audio chunk to be processed by SimulStreaming."""
if torch is None:
raise ImportError("PyTorch is required for SimulStreaming but not available")
# Convert numpy array to torch tensor
audio_tensor = torch.from_numpy(audio).float()
self.audio_chunks.append(audio_tensor)
# Update timing
chunk_duration = len(audio) / self.SAMPLING_RATE
self.cumulative_audio_duration += chunk_duration
if audio_stream_end_time is not None:
self.last_audio_stream_end_time = audio_stream_end_time
self.end = audio_stream_end_time
else:
self.end = self.offset + self.cumulative_audio_duration
def prompt(self) -> Tuple[str, str]:
"""
Returns a tuple: (prompt, context).
SimulStreaming handles prompting internally, so we return empty strings.
"""
return "", ""
def get_buffer(self):
"""
Get the unvalidated buffer content.
"""
buffer_end = self.end if hasattr(self, 'end') else None
return Transcript(
start=None,
end=buffer_end,
text=self.buffer_content,
probability=None
)
def timestamped_text(self, tokens, generation):
# From the simulstreaming repo. self.model to self.asr.model
pr = generation["progress"]
if "result" not in generation:
split_words, split_tokens = self.asr.model.tokenizer.split_to_word_tokens(tokens)
else:
split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"]
frames = [p["most_attended_frames"][0] for p in pr]
tokens = tokens.copy()
ret = []
for sw,st in zip(split_words,split_tokens):
b = None
for stt in st:
t,f = tokens.pop(0), frames.pop(0)
if t != stt:
raise ValueError(f"Token mismatch: {t} != {stt} at frame {f}.")
if b is None:
b = f
e = f
out = (b*0.02, e*0.02, sw)
ret.append(out)
logger.debug(f"TS-WORD:\t{' '.join(map(str, out))}")
return ret
def process_iter(self) -> Tuple[List[ASRToken], float]:
"""
Process accumulated audio chunks using SimulStreaming.
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
"""
if not self.audio_chunks:
return [], self.end
try:
# concatenate all audio chunks
if len(self.audio_chunks) == 1:
audio = self.audio_chunks[0]
else:
audio = torch.cat(self.audio_chunks, dim=0)
audio_duration = audio.shape[0] / self.SAMPLING_RATE if audio.shape[0] > 0 else 0
self.processed_audio_duration += audio_duration
self.audio_chunks = []
logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s")
logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s")
self.asr.model.insert_audio(audio)
tokens, generation_progress = self.asr.model.infer(is_last=self.is_last)
ts_words = self.timestamped_text(tokens, generation_progress)
text = self.asr.model.tokenizer.decode(tokens)
new_tokens = []
for ts_word in ts_words:
start, end, word = ts_word
token = ASRToken(
start=start,
end=end,
text=word,
probability=0.95 # fake prob. Maybe we can extract it from the model?
)
new_tokens.append(token)
self.committed.extend(new_tokens)
return new_tokens, self.end
except Exception as e:
logger.error(f"SimulStreaming processing error: {e}")
logger.error(f"Error details: {type(e).__name__}: {str(e)}")
return [], self.end
def finish(self) -> Tuple[List[ASRToken], float]:
logger.debug("SimulStreaming finish() called")
self.is_last = True
final_tokens, final_time = self.process_iter()
self.is_last = False
return final_tokens, final_time
def concatenate_tokens(
self,
tokens: List[ASRToken],
sep: Optional[str] = None,
offset: float = 0
) -> Transcript:
"""Concatenate tokens into a Transcript object."""
sep = sep if sep is not None else self.asr.sep
text = sep.join(token.text for token in tokens)
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
if tokens:
start = offset + tokens[0].start
end = offset + tokens[-1].end
else:
start = None
end = None
return Transcript(start, end, text, probability=probability)
def chunk_at(self, time: float):
"""
useless but kept for compatibility
"""
logger.debug(f"SimulStreaming chunk_at({time:.2f}) - handled internally")
pass
def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
"""
Create simple sentences.
"""
if not tokens:
return []
full_text = " ".join(token.text for token in tokens)
sentence = Sentence(
start=tokens[0].start,
end=tokens[-1].end,
text=full_text
)
return [sentence]

View File

@@ -5,8 +5,7 @@ import librosa
from functools import lru_cache from functools import lru_cache
import time import time
import logging import logging
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR, SimulStreamingASR, SIMULSTREAMING_AVAILABLE, SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor, SimulStreamingOnlineProcessor, SIMULSTREAMING_AVAILABLE as SIMULSTREAMING_ONLINE_AVAILABLE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -68,35 +67,7 @@ def backend_factory(args):
backend = args.backend backend = args.backend
if backend == "openai-api": if backend == "openai-api":
logger.debug("Using OpenAI API.") logger.debug("Using OpenAI API.")
asr = OpenaiApiASR(lan=args.lan) asr = OpenaiApiASR(lan=args.lan)
elif backend == "simulstreaming":
logger.debug("Using SimulStreaming backend.")
if not SIMULSTREAMING_AVAILABLE:
raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
simulstreaming_kwargs = {}
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
'max_context_tokens', 'model_path']:
if hasattr(args, attr):
simulstreaming_kwargs[attr] = getattr(args, attr)
# Add segment_length from min_chunk_size
simulstreaming_kwargs['segment_length'] = getattr(args, 'min_chunk_size', 0.5)
simulstreaming_kwargs['task'] = args.task
size = args.model
t = time.time()
logger.info(f"Loading SimulStreaming {size} model for language {args.lan}...")
asr = SimulStreamingASR(
modelsize=size,
lan=args.lan,
cache_dir=getattr(args, 'model_cache_dir', None),
model_dir=getattr(args, 'model_dir', None),
**simulstreaming_kwargs
)
e = time.time()
logger.info(f"done. It took {round(e-t,2)} seconds.")
else: else:
if backend == "faster-whisper": if backend == "faster-whisper":
asr_cls = FasterWhisperASR asr_cls = FasterWhisperASR
@@ -136,107 +107,4 @@ def backend_factory(args):
tokenizer = create_tokenizer(tgt_language) tokenizer = create_tokenizer(tgt_language)
else: else:
tokenizer = None tokenizer = None
return asr, tokenizer return asr, tokenizer
def online_factory(args, asr, tokenizer, logfile=sys.stderr):
if args.backend == "simulstreaming":
if not SIMULSTREAMING_ONLINE_AVAILABLE:
raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
logger.debug("Creating SimulStreaming online processor")
online = SimulStreamingOnlineProcessor(
asr,
tokenizer,
logfile=logfile,
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
confidence_validation=args.confidence_validation
)
elif args.vac:
online = VACOnlineASRProcessor(
args.min_chunk_size,
asr,
tokenizer,
logfile=logfile,
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
confidence_validation = args.confidence_validation
)
else:
online = OnlineASRProcessor(
asr,
tokenizer,
logfile=logfile,
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
confidence_validation = args.confidence_validation
)
return online
def asr_factory(args, logfile=sys.stderr):
"""
Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
"""
asr, tokenizer = backend_factory(args)
online = online_factory(args, asr, tokenizer, logfile=logfile)
return asr, online
def warmup_asr(asr, warmup_file=None, timeout=5):
"""
Warmup the ASR model by transcribing a short audio file.
"""
import os
import tempfile
is_simulstreaming = hasattr(asr, 'warmup') and callable(getattr(asr, 'warmup'))
if warmup_file is None:
# Download JFK sample if not already present
jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
temp_dir = tempfile.gettempdir()
warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")
if not os.path.exists(warmup_file):
logger.debug(f"Downloading warmup file from {jfk_url}")
print(f"Downloading warmup file from {jfk_url}")
import time
import urllib.request
import urllib.error
import socket
original_timeout = socket.getdefaulttimeout()
socket.setdefaulttimeout(timeout)
start_time = time.time()
try:
urllib.request.urlretrieve(jfk_url, warmup_file)
logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
except (urllib.error.URLError, socket.timeout) as e:
logger.warning(f"Download failed: {e}. Proceeding without warmup.")
return False
finally:
socket.setdefaulttimeout(original_timeout)
elif not warmup_file:
return False
if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
logger.warning(f"Warmup file {warmup_file} invalid or missing.")
return False
print(f"Warming up {'SimulStreaming' if is_simulstreaming else 'Whisper'} with {warmup_file}")
try:
import librosa
audio, sr = librosa.load(warmup_file, sr=16000)
except Exception as e:
logger.warning(f"Failed to load audio file: {e}")
return False
try:
if is_simulstreaming:
asr.warmup(audio)
else:
asr.transcribe(audio)
logger.info(f"{'SimulStreaming' if is_simulstreaming else 'Whisper'} is warmed up")
return True
except Exception as e:
logger.warning(f"Warmup failed: {e}")
return False