Mirror of https://github.com/QuentinFuxa/WhisperLiveKit.git (synced 2026-03-07 22:33:36 +00:00)

Compare commits: seamless-s ... 0.0.1 (248 commits)
.gitignore (vendored, 3 added lines)

@@ -127,3 +127,6 @@ dmypy.json
 # Pyre type checker
 .pyre/
 
+
+*.wav
+run_*.sh
CONTRIBUTING.md (new file, 46 lines)

@@ -0,0 +1,46 @@
```markdown
# Contributing

Thank you for considering contributing! We appreciate your time and effort to help make this project better.

## Before You Start

1. **Search for Existing Issues or Discussions:**
   - Before opening a new issue or discussion, please check if there's already an existing one related to your topic. This helps avoid duplicates and keeps discussions centralized.

2. **Discuss Your Contribution:**
   - If you plan to make a significant change, it's advisable to discuss it in an issue first. This ensures that your contribution aligns with the project's goals and avoids duplicated effort.

3. **General questions about whisper streaming web:**
   - For general questions about whisper streaming web, use the discussion space on GitHub. This helps foster a collaborative environment and encourages knowledge sharing.

## Opening Issues

If you encounter a problem with whisper streaming web or want to suggest an improvement, please follow these guidelines when opening an issue:

- **Bug Reports:**
  - Clearly describe the error. **Please indicate the parameters you use, especially the model(s).**
  - Provide a minimal, reproducible example that demonstrates the issue.

- **Feature Requests:**
  - Clearly outline the new feature you are proposing.
  - Explain how it would benefit the project.

## Opening Pull Requests

We welcome and appreciate contributions! To ensure a smooth review process, please follow these guidelines when opening a pull request:

- **Commit Messages:**
  - Write clear and concise commit messages explaining the purpose of each change.

- **Documentation:**
  - Update documentation when introducing new features or making changes that impact existing functionality.

- **Tests:**
  - If applicable, add or update tests to cover your changes.

- **Discuss Before Major Changes:**
  - If your PR includes significant changes, discuss it in an issue first.

## Thank You

Your contributions make whisper streaming web better for everyone. Thank you for your time and dedication!
```
README.md (323 changed lines)

@@ -1,231 +1,158 @@
-# whisper_streaming
+<h1 align="center">WhisperLiveKit</h1>
-Whisper realtime streaming for long speech-to-text transcription and translation
+<p align="center"><b>Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization</b></p>
-
-**Turning Whisper into Real-Time Transcription System**
-
-Demonstration paper, by Dominik Macháček, Raj Dabre, Ondřej Bojar, 2023
-
-Abstract: Whisper is one of the recent state-of-the-art multilingual speech recognition and translation models, however, it is not designed for real time transcription. In this paper, we build on top of Whisper and create Whisper-Streaming, an implementation of real-time speech transcription and translation of Whisper-like models. Whisper-Streaming uses local agreement policy with self-adaptive latency to enable streaming transcription. We show that Whisper-Streaming achieves high quality and 3.3 seconds latency on unsegmented long-form speech transcription test set, and we demonstrate its robustness and practical usability as a component in live transcription service at a multilingual conference.
-
-Paper in proceedings: http://www.afnlp.org/conferences/ijcnlp2023/proceedings/main-demo/cdrom/pdf/2023.ijcnlp-demo.3.pdf
+This project is based on [Whisper Streaming](https://github.com/ufal/whisper_streaming) and lets you transcribe audio directly from your browser. Simply launch the local server and grant microphone access. Everything runs locally on your machine ✨
-
-Demo video: https://player.vimeo.com/video/840442741
+<p align="center">
+  <img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/demo.png" alt="Demo Screenshot" width="730">
+</p>
-
-[Slides](http://ufallab.ms.mff.cuni.cz/~machacek/pre-prints/AACL23-2.11.2023-Turning-Whisper-oral.pdf) -- 15 minutes oral presentation at IJCNLP-AACL 2023
+### Differences from [Whisper Streaming](https://github.com/ufal/whisper_streaming)
-
-Please, cite us. [Bibtex citation](http://www.afnlp.org/conferences/ijcnlp2023/proceedings/main-demo/cdrom/bib/2023.ijcnlp-demo.3.bib):
+#### ⚙️ **Core Improvements**
+- **Buffering Preview** – Displays unvalidated transcription segments
+- **Multi-User Support** – Handles multiple users simultaneously by decoupling backend and online ASR
+- **MLX Whisper Backend** – Optimized for Apple Silicon for faster local processing.
+- **Confidence validation** – Immediately validate high-confidence tokens for faster inference
-
-```
+#### 🎙️ **Speaker Identification**
-@InProceedings{machacek-dabre-bojar:2023:ijcnlp,
+- **Real-Time Diarization** – Identify different speakers in real time using [Diart](https://github.com/juanmc2005/diart)
-  author = {Macháček, Dominik and Dabre, Raj and Bojar, Ondřej},
-  title = {Turning Whisper into Real-Time Transcription System},
+#### 🌐 **Web & API**
-  booktitle = {System Demonstrations},
+- **Built-in Web UI** – Simple raw HTML browser interface with no frontend setup required
-  month = {November},
+- **FastAPI WebSocket Server** – Real-time speech-to-text processing with async FFmpeg streaming.
-  year = {2023},
+- **JavaScript Client** – Ready-to-use MediaRecorder implementation for seamless client-side integration.
-  address = {Bali, Indonesia},
-  publisher = {Asian Federation of Natural Language Processing},
-  pages = {17--24},
-}
-```
 
 ## Installation
 
-1) ``pip install librosa`` -- audio processing library
+### Via pip
-
-2) Whisper backend.
+```bash
+pip install whisperlivekit
-
-Two alternative backends are integrated. The most recommended one is [faster-whisper](https://github.com/guillaumekln/faster-whisper) with GPU support. Follow their instructions for NVIDIA libraries -- we succeeded with CUDNN 8.5.0 and CUDA 11.7. Install with `pip install faster-whisper`.
-
-Alternative, less restrictive, but slower backend is [whisper-timestamped](https://github.com/linto-ai/whisper-timestamped): `pip install git+https://github.com/linto-ai/whisper-timestamped`
-
-The backend is loaded only when chosen. The unused one does not have to be installed.
-
-3) Optional, not recommended: sentence segmenter (aka sentence tokenizer)
-
-Two buffer trimming options are integrated and evaluated. They have impact on
-the quality and latency. The default "segment" option performs better according
-to our tests and does not require any sentence segmentation installed.
-
-The other option, "sentence" -- trimming at the end of confirmed sentences,
-requires sentence segmenter installed. It splits punctuated text to sentences by full
-stops, avoiding the dots that are not full stops. The segmenters are language
-specific. The unused one does not have to be installed. We integrate the
-following segmenters, but suggestions for better alternatives are welcome.
-
-- `pip install opus-fast-mosestokenizer` for the languages with codes `as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh`
-
-- `pip install tokenize_uk` for Ukrainian -- `uk`
-
-- for other languages, we integrate a good performing multi-lingual model of `wtpsplit`. It requires `pip install torch wtpsplit`, and its neural model `wtp-canine-s-12l-no-adapters`. It is downloaded to the default huggingface cache during the first use.
-
-- we did not find a segmenter for languages `as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt` that are supported by Whisper and not by wtpsplit. The default fallback option for them is wtpsplit with unspecified language. Alternative suggestions welcome.
-
-In case of installation issues of opus-fast-mosestokenizer, especially on Windows and Mac, we recommend using only the "segment" option that does not require it.
-## Usage
-
-### Real-time simulation from audio file
-
-```
-usage: whisper_online.py [-h] [--min-chunk-size MIN_CHUNK_SIZE] [--model {tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large}] [--model_cache_dir MODEL_CACHE_DIR] [--model_dir MODEL_DIR] [--lan LAN] [--task {transcribe,translate}]
-                         [--backend {faster-whisper,whisper_timestamped}] [--vad] [--buffer_trimming {sentence,segment}] [--buffer_trimming_sec BUFFER_TRIMMING_SEC] [--start_at START_AT] [--offline] [--comp_unaware]
-                         audio_path
-
-positional arguments:
-  audio_path            Filename of 16kHz mono channel wav, on which live streaming is simulated.
-
-options:
-  -h, --help            show this help message and exit
-  --min-chunk-size MIN_CHUNK_SIZE
-                        Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.
-  --model {tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large}
-                        Name size of the Whisper model to use (default: large-v2). The model is automatically downloaded from the model hub if not present in model cache dir.
-  --model_cache_dir MODEL_CACHE_DIR
-                        Overriding the default model cache dir where models downloaded from the hub are saved
-  --model_dir MODEL_DIR
-                        Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.
-  --lan LAN, --language LAN
-                        Language code for transcription, e.g. en,de,cs.
-  --task {transcribe,translate}
-                        Transcribe or translate.
-  --backend {faster-whisper,whisper_timestamped}
-                        Load only this backend for Whisper processing.
-  --vad                 Use VAD = voice activity detection, with the default parameters.
-  --buffer_trimming {sentence,segment}
-                        Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.
-  --buffer_trimming_sec BUFFER_TRIMMING_SEC
-                        Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.
-  --start_at START_AT   Start processing audio at this time.
-  --offline             Offline mode.
-  --comp_unaware        Computationally unaware simulation.
-```
+```
-
-Example:
+### From source
-
-It simulates realtime processing from a pre-recorded mono 16k wav file.
+1. **Clone the Repository**:
-
-```
+```bash
-python3 whisper_online.py en-demo16.wav --language en --min-chunk-size 1 > out.txt
+git clone https://github.com/QuentinFuxa/WhisperLiveKit
-```
+cd WhisperLiveKit
+pip install -e .
+```
-
-Simulation modes:
+### System Dependencies
-
-- default mode, no special option: real-time simulation from file, computationally aware. The chunk size is `MIN_CHUNK_SIZE` or larger, if more audio arrived during last update computation.
+You need to install FFmpeg on your system:
-
-- `--comp_unaware` option: computationally unaware simulation. It means that the timer that counts the emission times "stops" when the model is computing. The chunk size is always `MIN_CHUNK_SIZE`. The latency is caused only by the model being unable to confirm the output, e.g. because of language ambiguity etc., and not because of slow hardware or suboptimal implementation. We implement this feature for finding the lower bound for latency.
+- Install system dependencies:
+```bash
+# Install FFmpeg on your system (required for audio processing)
+# For Ubuntu/Debian:
+sudo apt install ffmpeg
+
+# For macOS:
+brew install ffmpeg
+
+# For Windows:
+# Download from https://ffmpeg.org/download.html and add to PATH
+```
-
-- `--start_at START_AT`: Start processing audio at this time. The first update receives the whole audio by `START_AT`. It is useful for debugging, e.g. when we observe a bug in a specific time in audio file, and want to reproduce it quickly, without long waiting.
+- Install required Python dependencies:
-
-- `--offline` option: It processes the whole audio file at once, in offline mode. We implement it to find out the lowest possible WER on given audio file.
+```bash
+# Whisper streaming required dependencies
+pip install librosa soundfile
+
+# Whisper streaming web required dependencies
+pip install fastapi ffmpeg-python
+```
+- Install at least one whisper backend among:
+```
+whisper
+whisper-timestamped
+faster-whisper (faster backend on NVIDIA GPU)
+mlx-whisper (faster backend on Apple Silicon)
+```
+- Optional dependencies
+```
+# If you want to use VAC (Voice Activity Controller). Useful for preventing hallucinations
+torch
+
+# If you choose sentences as buffer trimming strategy
+mosestokenizer
+wtpsplit
+tokenize_uk # If you work with Ukrainian text
+
+# If you want to run the server using uvicorn (recommended)
+uvicorn
+
+# If you want to use diarization
+diart
+```
+Diart uses by default [pyannote.audio](https://github.com/pyannote/pyannote-audio) models from the _huggingface hub_. To use them, please follow the steps described [here](https://github.com/juanmc2005/diart?tab=readme-ov-file#get-access-to--pyannote-models).
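For the backend list just above, a quick way to see which optional engines are actually importable in your environment is a snippet like the following. This is an editorial sketch, not part of the repository; the module names are assumed to mirror the package names with dashes replaced by underscores.

```python
# Editorial sketch: report which optional Whisper backends can be imported.
import importlib.util

for module in ("whisper", "whisper_timestamped", "faster_whisper", "mlx_whisper"):
    status = "available" if importlib.util.find_spec(module) else "not installed"
    print(f"{module}: {status}")
```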
+3. **Run the FastAPI Server**:
-
-### Output format
+```bash
+python whisper_fastapi_online_server.py --host 0.0.0.0 --port 8000
+```
-
-```
+**Parameters**
-2691.4399 300 1380 Chairman, thank you.
-6914.5501 1940 4940 If the debate today had a
+The following parameters are supported:
-9019.0277 5160 7160 the subject the situation in
-10065.1274 7180 7480 Gaza
+- `--host` and `--port` let you specify the server's IP/port.
-11058.3558 7480 9460 Strip, I might
+- `--min-chunk-size` sets the minimum chunk size for audio processing. Make sure this value aligns with the chunk size selected in the frontend. If not aligned, the system will work but may unnecessarily over-process audio data.
-12224.3731 9460 9760 have
+- `--transcription`: Enable/disable transcription (default: True)
-13555.1929 9760 11060 joined Mrs.
+- `--diarization`: Enable/disable speaker diarization (default: False)
-14928.5479 11140 12240 De Kaiser and all the
+- `--confidence-validation`: Use confidence scores for faster validation. Transcription will be faster but punctuation might be less accurate (default: True)
-16588.0787 12240 12560 other
+- `--warmup-file`: The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast:
-18324.9285 12560 14420 colleagues across the
+  - If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
-```
+  - If False, no warmup is performed.
+- `--min-chunk-size` Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.
+- `--model` {_tiny.en, tiny, base.en, base, small.en, small, medium.en, medium, large-v1, large-v2, large-v3, large, large-v3-turbo_}
+  Name size of the Whisper model to use (default: tiny). The model is automatically downloaded from the model hub if not present in model cache dir.
+- `--model_cache_dir` Overriding the default model cache dir where models downloaded from the hub are saved
+- `--model_dir` Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.
+- `--lan`, `--language` Source language code, e.g. en,de,cs, or 'auto' for language detection.
+- `--task` {_transcribe, translate_} Transcribe or translate. If translate is set, we recommend avoiding the _large-v3-turbo_ backend, as it [performs significantly worse](https://github.com/QuentinFuxa/whisper_streaming_web/issues/40#issuecomment-2652816533) than other models for translation.
+- `--backend` {_faster-whisper, whisper_timestamped, openai-api, mlx-whisper_} Load only this backend for Whisper processing.
+- `--vac` Use VAC = voice activity controller. Requires torch.
+- `--vac-chunk-size` VAC sample size in seconds.
+- `--vad` Use VAD = voice activity detection, with the default parameters.
+- `--buffer_trimming` {_sentence, segment_} Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.
+- `--buffer_trimming_sec` Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.
-
-[See description here](https://github.com/ufal/whisper_streaming/blob/d915d790a62d7be4e7392dde1480e7981eb142ae/whisper_online.py#L361)
+5. **Open the Provided HTML**:
-
-### As a module
+- By default, the server root endpoint `/` serves a simple `live_transcription.html` page.
+- Open your browser at `http://localhost:8000` (or replace `localhost` and `8000` with whatever you specified).
+- The page uses vanilla JavaScript and the WebSocket API to capture your microphone and stream audio to the server in real time.
-
-TL;DR: use OnlineASRProcessor object and its methods insert_audio_chunk and process_iter.
+### How the Live Interface Works
-
-The code whisper_online.py is nicely commented, read it as the full documentation.
+- Once you **allow microphone access**, the page records small chunks of audio using the **MediaRecorder** API in **webm/opus** format.
+- These chunks are sent over a **WebSocket** to the FastAPI endpoint at `/asr`.
+- The Python server decodes `.webm` chunks on the fly using **FFmpeg** and streams them into the **whisper streaming** implementation for transcription.
+- **Partial transcription** appears as soon as enough audio is processed. The “unvalidated” text is shown in **lighter or grey color** (i.e., an ‘aperçu’) to indicate it’s still buffered partial output. Once Whisper finalizes that segment, it’s displayed in normal text.
+- You can watch the transcription update in near real time, ideal for demos, prototyping, or quick debugging.
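To make the FFmpeg decoding step described in the bullets above concrete, here is a stand-alone editorial sketch, not the project's code (which streams chunks through a long-lived asynchronous FFmpeg process): it converts a complete webm/opus recording into the 16 kHz mono 16-bit PCM that Whisper expects, assuming `ffmpeg` is on your PATH.

```python
# Editorial sketch only: one-shot webm/opus -> 16 kHz mono s16le PCM via ffmpeg.
import subprocess

def decode_webm_to_pcm(webm_bytes: bytes) -> bytes:
    """Feed webm/opus bytes to ffmpeg on stdin, read raw PCM from stdout."""
    result = subprocess.run(
        ["ffmpeg", "-loglevel", "error", "-i", "pipe:0",
         "-f", "s16le", "-ar", "16000", "-ac", "1", "pipe:1"],
        input=webm_bytes,
        stdout=subprocess.PIPE,
        check=True,
    )
    return result.stdout
```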
+### Deploying to a Remote Server
-
-This pseudocode describes the interface that we suggest for your implementation. You can implement any features that you need for your application.
+If you want to **deploy** this setup:
-
-```
+1. **Host the FastAPI app** behind a production-grade HTTP(S) server (like **Uvicorn + Nginx** or Docker). If you use HTTPS, use "wss" instead of "ws" in the WebSocket URL.
-from whisper_online import *
+2. The **HTML/JS page** can be served by the same FastAPI app or a separate static host.
+3. Users open the page in **Chrome/Firefox** (any modern browser that supports MediaRecorder + WebSocket).
-
-src_lan = "en"  # source language
-tgt_lan = "en"  # target language -- same as source for ASR, "en" if translate task is used
-
-asr = FasterWhisperASR(lan, "large-v2")  # loads and wraps Whisper model
-# set options:
-# asr.set_translate_task()  # it will translate from lan into English
-# asr.use_vad()  # set using VAD
-
-online = OnlineASRProcessor(asr)  # create processing object with default buffer trimming option
-
-while audio_has_not_ended:   # processing loop:
-    a = # receive new audio chunk (and e.g. wait for min_chunk_size seconds first, ...)
-    online.insert_audio_chunk(a)
-    o = online.process_iter()
-    print(o)  # do something with current partial output
-# at the end of this audio processing
-o = online.finish()
-print(o)  # do something with the last output
-
-online.init()  # refresh if you're going to re-use the object for the next audio
-```
-
-### Server -- real-time from mic
-
-`whisper_online_server.py` has the same model options as `whisper_online.py`, plus `--host` and `--port` of the TCP connection. See help message (`-h` option).
-
-Client example:
-
-```
-arecord -f S16_LE -c1 -r 16000 -t raw -D default | nc localhost 43001
-```
-
-- arecord sends realtime audio from a sound device (e.g. mic), in raw audio format -- 16000 sampling rate, mono channel, S16\_LE -- signed 16-bit integer low endian. (use the alternative to arecord that works for you)
-
-- nc is netcat with server's host and port
-
-## Background
-
-Default Whisper is intended for audio chunks of at most 30 seconds that contain
-one full sentence. Longer audio files must be split to shorter chunks and
-merged with "init prompt". In low latency simultaneous streaming mode, the
-simple and naive chunking into fixed-sized windows does not work well, it can split
-a word in the middle. It is also necessary to know when the transcript is
-stable, should be confirmed ("commited") and followed up, and when the future
-content makes the transcript clearer.
-
-For that, there is LocalAgreement-n policy: if n consecutive updates, each with
-a newly available audio stream chunk, agree on a prefix transcript, it is
-confirmed. (Reference: CUNI-KIT at IWSLT 2022 etc.)
-
-In this project, we re-use the idea of Peter Polák from this demo:
-https://github.com/pe-trik/transformers/blob/online_decode/examples/pytorch/online-decoding/whisper-online-demo.py
-However, it doesn't do any sentence segmentation, but Whisper produces
-punctuation and the libraries `faster-whisper` and `whisper_timestamped` make
-word-level timestamps. In short: we
-consecutively process new audio chunks, emit the transcripts that are confirmed
-by 2 iterations, and scroll the audio processing buffer on a timestamp of a
-confirmed complete sentence. The processing audio buffer is not too long and
-the processing is fast.
-
-In more detail: we use the init prompt, we handle the inaccurate timestamps, we
-re-process confirmed sentence prefixes and skip them, making sure they don't
-overlap, and we limit the processing buffer window.
-
-Contributions are welcome.
-
-### Performance evaluation
-
-[See the paper.](http://www.afnlp.org/conferences/ijcnlp2023/proceedings/main-demo/cdrom/pdf/2023.ijcnlp-demo.3.pdf)
-
-## Contact
-
-Dominik Macháček, machacek@ufal.mff.cuni.cz
+No additional front-end libraries or frameworks are required. The WebSocket logic in `live_transcription.html` is minimal enough to adapt for your own custom UI or embed in other pages.
+
+## Acknowledgments
+
+This project builds upon the foundational work of the Whisper Streaming project. We extend our gratitude to the original authors for their contributions.
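As a toy illustration of the LocalAgreement-n policy described in the Background section of the old README above (an editorial simplification, not the HypothesisBuffer logic the code actually uses), confirming the prefix on which two consecutive hypotheses agree can be sketched like this:

```python
# Editorial sketch of LocalAgreement-2: the word-level prefix shared by two
# consecutive hypotheses is treated as confirmed; the divergent tail stays buffered.
def agreed_prefix(prev_hypothesis: list[str], new_hypothesis: list[str]) -> list[str]:
    """Return the longest common word-level prefix of two hypotheses."""
    confirmed = []
    for prev_word, new_word in zip(prev_hypothesis, new_hypothesis):
        if prev_word != new_word:
            break
        confirmed.append(prev_word)
    return confirmed

# Example: only "and so my fellow" is confirmed.
print(agreed_prefix("and so my fellow americans ask".split(),
                    "and so my fellow citizens".split()))
```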
Deleted file:

@@ -1,94 +0,0 @@
```python
#!/usr/bin/env python3

"""Functions for sending and receiving individual lines of text over a socket.

Used by marian-server-server.py to communicate with the Marian worker.

A line is transmitted using one or more fixed-size packets of UTF-8 bytes
containing:

  - Zero or more bytes of UTF-8, excluding \n and \0, followed by

  - Zero or more \0 bytes as required to pad the packet to PACKET_SIZE

"""

PACKET_SIZE = 65536


def send_one_line(socket, text):
    """Sends a line of text over the given socket.

    The 'text' argument should contain a single line of text (line break
    characters are optional). Line boundaries are determined by Python's
    str.splitlines() function [1]. We also count '\0' as a line terminator.
    If 'text' contains multiple lines then only the first will be sent.

    If the send fails then an exception will be raised.

    [1] https://docs.python.org/3.5/library/stdtypes.html#str.splitlines

    Args:
        socket: a socket object.
        text: string containing a line of text for transmission.
    """
    text.replace('\0', '\n')
    lines = text.splitlines()
    first_line = '' if len(lines) == 0 else lines[0]
    # TODO Is there a better way of handling bad input than 'replace'?
    data = first_line.encode('utf-8', errors='replace') + b'\n\0'
    for offset in range(0, len(data), PACKET_SIZE):
        bytes_remaining = len(data) - offset
        if bytes_remaining < PACKET_SIZE:
            padding_length = PACKET_SIZE - bytes_remaining
            packet = data[offset:] + b'\0' * padding_length
        else:
            packet = data[offset:offset+PACKET_SIZE]
        socket.sendall(packet)


def receive_one_line(socket):
    """Receives a line of text from the given socket.

    This function will (attempt to) receive a single line of text. If data is
    currently unavailable then it will block until data becomes available or
    the sender has closed the connection (in which case it will return an
    empty string).

    The string should not contain any newline characters, but if it does then
    only the first line will be returned.

    Args:
        socket: a socket object.

    Returns:
        A string representing a single line with a terminating newline or
        None if the connection has been closed.
    """
    data = b''
    while True:
        packet = socket.recv(PACKET_SIZE)
        if not packet:  # Connection has been closed.
            return None
        data += packet
        if b'\0' in packet:
            break
    # TODO Is there a better way of handling bad input than 'replace'?
    text = data.decode('utf-8', errors='replace').strip('\0')
    lines = text.split('\n')
    return lines[0] + '\n'


def receive_lines(socket):
    try:
        data = socket.recv(PACKET_SIZE)
    except BlockingIOError:
        return []
    if data is None:  # Connection has been closed.
        return None
    # TODO Is there a better way of handling bad input than 'replace'?
    text = data.decode('utf-8', errors='replace').strip('\0')
    lines = text.split('\n')
    if len(lines) == 1 and not lines[0]:
        return None
    return lines
```
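A quick editorial round-trip check for the helpers above (hypothetical: it assumes `send_one_line` and `receive_one_line` are in scope, e.g. pasted into the same script):

```python
# Hypothetical usage of the deleted helpers: send one padded packet and read it back.
import socket

a, b = socket.socketpair()
send_one_line(a, "Hello, Marian worker")
print(repr(receive_one_line(b)))  # -> 'Hello, Marian worker\n'
a.close()
b.close()
```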
setup.py (new file, 44 lines)

@@ -0,0 +1,44 @@
```python
from setuptools import setup, find_packages

setup(
    name="whisperlivekit",
    version="0.1.0",
    description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
    long_description=open("README.md", "r", encoding="utf-8").read(),
    long_description_content_type="text/markdown",
    author="Quentin Fuxa",
    url="https://github.com/QuentinFuxa/WhisperLiveKit",
    packages=find_packages(),
    install_requires=[
        "fastapi",
        "ffmpeg-python",
        "librosa",
        "soundfile",
        "faster-whisper",
        "uvicorn",
        "websockets",
    ],
    extras_require={
        "diarization": ["diart"],
        "vac": ["torch"],
        "sentence": ["mosestokenizer", "wtpsplit"],
    },
    package_data={
        'whisperlivekit': ['web/*.html'],
    },
    entry_points={
        'console_scripts': [
            'whisperlivekit-server=whisperlivekit.server:run_server',
        ],
    },
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Multimedia :: Sound/Audio :: Speech",
    ],
    python_requires=">=3.9",
)
```
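Given the `extras_require` table above, the optional stacks can be pulled in at install time, for example `pip install -e ".[diarization,vac]"` from a source checkout (assuming the same extras names are also exposed by the published package).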
whisper_fastapi_online_server.py (new file, 82 lines)

@@ -0,0 +1,82 @@
```python
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware

from whisperlivekit import WhisperLiveKit
from whisperlivekit.audio_processor import AudioProcessor

import asyncio
import logging
import os

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger().setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

kit = None

# Create the shared WhisperLiveKit engine once at server startup.
@asynccontextmanager
async def lifespan(app: FastAPI):
    global kit
    kit = WhisperLiveKit()
    yield

app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def get():
    return HTMLResponse(kit.web_interface())


async def handle_websocket_results(websocket, results_generator):
    """Consumes results from the audio processor and sends them via WebSocket."""
    try:
        async for response in results_generator:
            await websocket.send_json(response)
    except Exception as e:
        logger.warning(f"Error in WebSocket results handler: {e}")


# Each WebSocket connection gets its own AudioProcessor; incoming audio bytes are
# fed to it while transcription results are forwarded back to the client as JSON.
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
    audio_processor = AudioProcessor()

    await websocket.accept()
    logger.info("WebSocket connection opened.")

    results_generator = await audio_processor.create_tasks()
    websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))

    try:
        while True:
            message = await websocket.receive_bytes()
            await audio_processor.process_audio(message)
    except WebSocketDisconnect:
        logger.warning("WebSocket disconnected.")
    finally:
        websocket_task.cancel()
        await audio_processor.cleanup()
        logger.info("WebSocket endpoint cleaned up.")


if __name__ == "__main__":
    import uvicorn

    temp_kit = WhisperLiveKit(transcription=False, diarization=False)

    uvicorn.run(
        "whisper_fastapi_online_server:app",
        host=temp_kit.args.host,
        port=temp_kit.args.port,
        reload=True,
        log_level="info"
    )
```
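For quick testing without a browser, a minimal Python client along these lines can exercise the `/asr` endpoint. This is an editorial sketch, not part of the repository: it assumes the server above is running on localhost:8000 and that `recording.webm` holds audio in the same webm/opus format the browser sends.

```python
# Editorial sketch: stream a pre-recorded webm/opus file to /asr and print results.
import asyncio
import websockets

async def stream_file(path: str, chunk_size: int = 16000) -> None:
    async with websockets.connect("ws://localhost:8000/asr") as ws:

        async def send_audio() -> None:
            with open(path, "rb") as f:
                while chunk := f.read(chunk_size):
                    await ws.send(chunk)       # raw bytes, as the browser client does
                    await asyncio.sleep(0.25)  # crude pacing to mimic live capture

        sender = asyncio.create_task(send_audio())
        try:
            while True:
                message = await asyncio.wait_for(ws.recv(), timeout=5.0)
                print(message)                 # JSON payloads from the AudioProcessor
        except asyncio.TimeoutError:
            pass                               # no results for a while: assume we are done
        finally:
            await sender

asyncio.run(stream_file("recording.webm"))
```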
@@ -1,611 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
import sys
|
|
||||||
import numpy as np
|
|
||||||
import librosa
|
|
||||||
from functools import lru_cache
|
|
||||||
import time
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
|
||||||
def load_audio(fname):
|
|
||||||
a, _ = librosa.load(fname, sr=16000)
|
|
||||||
return a
|
|
||||||
|
|
||||||
def load_audio_chunk(fname, beg, end):
|
|
||||||
audio = load_audio(fname)
|
|
||||||
beg_s = int(beg*16000)
|
|
||||||
end_s = int(end*16000)
|
|
||||||
return audio[beg_s:end_s]
|
|
||||||
|
|
||||||
|
|
||||||
# Whisper backend
|
|
||||||
|
|
||||||
class ASRBase:
|
|
||||||
|
|
||||||
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
|
||||||
# "" for faster-whisper because it emits the spaces when neeeded)
|
|
||||||
|
|
||||||
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
|
|
||||||
self.logfile = logfile
|
|
||||||
|
|
||||||
self.transcribe_kargs = {}
|
|
||||||
self.original_language = lan
|
|
||||||
|
|
||||||
self.model = self.load_model(modelsize, cache_dir, model_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(self, modelsize, cache_dir):
|
|
||||||
raise NotImplemented("must be implemented in the child class")
|
|
||||||
|
|
||||||
def transcribe(self, audio, init_prompt=""):
|
|
||||||
raise NotImplemented("must be implemented in the child class")
|
|
||||||
|
|
||||||
def use_vad(self):
|
|
||||||
raise NotImplemented("must be implemented in the child class")
|
|
||||||
|
|
||||||
|
|
||||||
class WhisperTimestampedASR(ASRBase):
|
|
||||||
"""Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper.
|
|
||||||
On the other hand, the installation for GPU could be easier.
|
|
||||||
"""
|
|
||||||
|
|
||||||
sep = " "
|
|
||||||
|
|
||||||
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
|
||||||
import whisper
|
|
||||||
from whisper_timestamped import transcribe_timestamped
|
|
||||||
self.transcribe_timestamped = transcribe_timestamped
|
|
||||||
if model_dir is not None:
|
|
||||||
print("ignoring model_dir, not implemented",file=self.logfile)
|
|
||||||
return whisper.load_model(modelsize, download_root=cache_dir)
|
|
||||||
|
|
||||||
def transcribe(self, audio, init_prompt=""):
|
|
||||||
result = self.transcribe_timestamped(self.model,
|
|
||||||
audio, language=self.original_language,
|
|
||||||
initial_prompt=init_prompt, verbose=None,
|
|
||||||
condition_on_previous_text=True, **self.transcribe_kargs)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def ts_words(self,r):
|
|
||||||
# return: transcribe result object to [(beg,end,"word1"), ...]
|
|
||||||
o = []
|
|
||||||
for s in r["segments"]:
|
|
||||||
for w in s["words"]:
|
|
||||||
t = (w["start"],w["end"],w["text"])
|
|
||||||
o.append(t)
|
|
||||||
return o
|
|
||||||
|
|
||||||
def segments_end_ts(self, res):
|
|
||||||
return [s["end"] for s in res["segments"]]
|
|
||||||
|
|
||||||
def use_vad(self):
|
|
||||||
self.transcribe_kargs["vad"] = True
|
|
||||||
|
|
||||||
def set_translate_task(self):
|
|
||||||
self.transcribe_kargs["task"] = "translate"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FasterWhisperASR(ASRBase):
|
|
||||||
"""Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version.
|
|
||||||
"""
|
|
||||||
|
|
||||||
sep = ""
|
|
||||||
|
|
||||||
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
|
||||||
from faster_whisper import WhisperModel
|
|
||||||
if model_dir is not None:
|
|
||||||
print(f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.",file=self.logfile)
|
|
||||||
model_size_or_path = model_dir
|
|
||||||
elif modelsize is not None:
|
|
||||||
model_size_or_path = modelsize
|
|
||||||
else:
|
|
||||||
raise ValueError("modelsize or model_dir parameter must be set")
|
|
||||||
|
|
||||||
|
|
||||||
# this worked fast and reliably on NVIDIA L40
|
|
||||||
model = WhisperModel(model_size_or_path, device="cuda", compute_type="float16", download_root=cache_dir)
|
|
||||||
|
|
||||||
# or run on GPU with INT8
|
|
||||||
# tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
|
|
||||||
#model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
|
|
||||||
|
|
||||||
# or run on CPU with INT8
|
|
||||||
# tested: works, but slow, appx 10-times than cuda FP16
|
|
||||||
# model = WhisperModel(modelsize, device="cpu", compute_type="int8") #, download_root="faster-disk-cache-dir/")
|
|
||||||
return model
|
|
||||||
|
|
||||||
def transcribe(self, audio, init_prompt=""):
|
|
||||||
# tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
|
|
||||||
segments, info = self.model.transcribe(audio, language=self.original_language, initial_prompt=init_prompt, beam_size=5, word_timestamps=True, condition_on_previous_text=True, **self.transcribe_kargs)
|
|
||||||
return list(segments)
|
|
||||||
|
|
||||||
def ts_words(self, segments):
|
|
||||||
o = []
|
|
||||||
for segment in segments:
|
|
||||||
for word in segment.words:
|
|
||||||
# not stripping the spaces -- should not be merged with them!
|
|
||||||
w = word.word
|
|
||||||
t = (word.start, word.end, w)
|
|
||||||
o.append(t)
|
|
||||||
return o
|
|
||||||
|
|
||||||
def segments_end_ts(self, res):
|
|
||||||
return [s.end for s in res]
|
|
||||||
|
|
||||||
def use_vad(self):
|
|
||||||
self.transcribe_kargs["vad_filter"] = True
|
|
||||||
|
|
||||||
def set_translate_task(self):
|
|
||||||
self.transcribe_kargs["task"] = "translate"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class HypothesisBuffer:
|
|
||||||
|
|
||||||
def __init__(self, logfile=sys.stderr):
|
|
||||||
self.commited_in_buffer = []
|
|
||||||
self.buffer = []
|
|
||||||
self.new = []
|
|
||||||
|
|
||||||
self.last_commited_time = 0
|
|
||||||
self.last_commited_word = None
|
|
||||||
|
|
||||||
self.logfile = logfile
|
|
||||||
|
|
||||||
def insert(self, new, offset):
|
|
||||||
# compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
|
|
||||||
# the new tail is added to self.new
|
|
||||||
|
|
||||||
new = [(a+offset,b+offset,t) for a,b,t in new]
|
|
||||||
self.new = [(a,b,t) for a,b,t in new if a > self.last_commited_time-0.1]
|
|
||||||
|
|
||||||
if len(self.new) >= 1:
|
|
||||||
a,b,t = self.new[0]
|
|
||||||
if abs(a - self.last_commited_time) < 1:
|
|
||||||
if self.commited_in_buffer:
|
|
||||||
# it's going to search for 1, 2, ..., 5 consecutive words (n-grams) that are identical in commited and new. If they are, they're dropped.
|
|
||||||
cn = len(self.commited_in_buffer)
|
|
||||||
nn = len(self.new)
|
|
||||||
for i in range(1,min(min(cn,nn),5)+1): # 5 is the maximum
|
|
||||||
c = " ".join([self.commited_in_buffer[-j][2] for j in range(1,i+1)][::-1])
|
|
||||||
tail = " ".join(self.new[j-1][2] for j in range(1,i+1))
|
|
||||||
if c == tail:
|
|
||||||
print("removing last",i,"words:",file=self.logfile)
|
|
||||||
for j in range(i):
|
|
||||||
print("\t",self.new.pop(0),file=self.logfile)
|
|
||||||
break
|
|
||||||
|
|
||||||
def flush(self):
|
|
||||||
# returns commited chunk = the longest common prefix of 2 last inserts.
|
|
||||||
|
|
||||||
commit = []
|
|
||||||
while self.new:
|
|
||||||
na, nb, nt = self.new[0]
|
|
||||||
|
|
||||||
if len(self.buffer) == 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
if nt == self.buffer[0][2]:
|
|
||||||
commit.append((na,nb,nt))
|
|
||||||
self.last_commited_word = nt
|
|
||||||
self.last_commited_time = nb
|
|
||||||
self.buffer.pop(0)
|
|
||||||
self.new.pop(0)
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
self.buffer = self.new
|
|
||||||
self.new = []
|
|
||||||
self.commited_in_buffer.extend(commit)
|
|
||||||
return commit
|
|
||||||
|
|
||||||
def pop_commited(self, time):
|
|
||||||
while self.commited_in_buffer and self.commited_in_buffer[0][1] <= time:
|
|
||||||
self.commited_in_buffer.pop(0)
|
|
||||||
|
|
||||||
def complete(self):
|
|
||||||
return self.buffer
|
|
||||||
|
|
||||||
class OnlineASRProcessor:
|
|
||||||
|
|
||||||
SAMPLING_RATE = 16000
|
|
||||||
|
|
||||||
def __init__(self, asr, tokenizer=None, buffer_trimming=("segment", 15), logfile=sys.stderr):
|
|
||||||
"""asr: WhisperASR object
|
|
||||||
tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer. It can be None, if "segment" buffer trimming option is used, then tokenizer is not used at all.
|
|
||||||
("segment", 15)
|
|
||||||
buffer_trimming: a pair of (option, seconds), where option is either "sentence" or "segment", and seconds is a number. Buffer is trimmed if it is longer than "seconds" threshold. Default is the most recommended option.
|
|
||||||
logfile: where to store the log.
|
|
||||||
"""
|
|
||||||
self.asr = asr
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.logfile = logfile
|
|
||||||
|
|
||||||
self.init()
|
|
||||||
|
|
||||||
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
|
|
||||||
|
|
||||||
def init(self):
|
|
||||||
"""run this when starting or restarting processing"""
|
|
||||||
self.audio_buffer = np.array([],dtype=np.float32)
|
|
||||||
self.buffer_time_offset = 0
|
|
||||||
|
|
||||||
self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
|
|
||||||
self.commited = []
|
|
||||||
self.last_chunked_at = 0
|
|
||||||
|
|
||||||
self.silence_iters = 0
|
|
||||||
|
|
||||||
def insert_audio_chunk(self, audio):
|
|
||||||
self.audio_buffer = np.append(self.audio_buffer, audio)
|
|
||||||
|
|
||||||
def prompt(self):
|
|
||||||
"""Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer.
|
|
||||||
"context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons.
|
|
||||||
"""
|
|
||||||
k = max(0,len(self.commited)-1)
|
|
||||||
while k > 0 and self.commited[k-1][1] > self.last_chunked_at:
|
|
||||||
k -= 1
|
|
||||||
|
|
||||||
p = self.commited[:k]
|
|
||||||
p = [t for _,_,t in p]
|
|
||||||
prompt = []
|
|
||||||
l = 0
|
|
||||||
while p and l < 200: # 200 characters prompt size
|
|
||||||
x = p.pop(-1)
|
|
||||||
l += len(x)+1
|
|
||||||
prompt.append(x)
|
|
||||||
non_prompt = self.commited[k:]
|
|
||||||
return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(t for _,_,t in non_prompt)
|
|
||||||
|
|
||||||
def process_iter(self):
|
|
||||||
"""Runs on the current audio buffer.
|
|
||||||
Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, "").
|
|
||||||
The non-emty text is confirmed (committed) partial transcript.
|
|
||||||
"""
|
|
||||||
|
|
||||||
prompt, non_prompt = self.prompt()
|
|
||||||
print("PROMPT:", prompt, file=self.logfile)
|
|
||||||
print("CONTEXT:", non_prompt, file=self.logfile)
|
|
||||||
print(f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}",file=self.logfile)
|
|
||||||
res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
|
|
||||||
|
|
||||||
# transform to [(beg,end,"word1"), ...]
|
|
||||||
tsw = self.asr.ts_words(res)
|
|
||||||
|
|
||||||
self.transcript_buffer.insert(tsw, self.buffer_time_offset)
|
|
||||||
o = self.transcript_buffer.flush()
|
|
||||||
self.commited.extend(o)
|
|
||||||
print(">>>>COMPLETE NOW:",self.to_flush(o),file=self.logfile,flush=True)
|
|
||||||
print("INCOMPLETE:",self.to_flush(self.transcript_buffer.complete()),file=self.logfile,flush=True)
|
|
||||||
|
|
||||||
# there is a newly confirmed text
|
|
||||||
|
|
||||||
if o and self.buffer_trimming_way == "sentence": # trim the completed sentences
|
|
||||||
if len(self.audio_buffer)/self.SAMPLING_RATE > self.buffer_trimming_sec: # longer than this
|
|
||||||
self.chunk_completed_sentence()
|
|
||||||
|
|
||||||
|
|
||||||
if self.buffer_trimming_way == "segment":
|
|
||||||
s = self.buffer_trimming_sec # trim the completed segments longer than s,
|
|
||||||
else:
|
|
||||||
s = 30 # if the audio buffer is longer than 30s, trim it
|
|
||||||
|
|
||||||
if len(self.audio_buffer)/self.SAMPLING_RATE > s:
|
|
||||||
self.chunk_completed_segment(res)
|
|
||||||
|
|
||||||
# alternative: on any word
|
|
||||||
#l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
|
|
||||||
# let's find commited word that is less
|
|
||||||
#k = len(self.commited)-1
|
|
||||||
#while k>0 and self.commited[k][1] > l:
|
|
||||||
# k -= 1
|
|
||||||
#t = self.commited[k][1]
|
|
||||||
print(f"chunking segment",file=self.logfile)
|
|
||||||
#self.chunk_at(t)
|
|
||||||
|
|
||||||
print(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}",file=self.logfile)
|
|
||||||
return self.to_flush(o)
|
|
||||||
|
|
||||||
def chunk_completed_sentence(self):
|
|
||||||
if self.commited == []: return
|
|
||||||
print(self.commited,file=self.logfile)
|
|
||||||
sents = self.words_to_sentences(self.commited)
|
|
||||||
for s in sents:
|
|
||||||
print("\t\tSENT:",s,file=self.logfile)
|
|
||||||
if len(sents) < 2:
|
|
||||||
return
|
|
||||||
while len(sents) > 2:
|
|
||||||
sents.pop(0)
|
|
||||||
# we will continue with audio processing at this timestamp
|
|
||||||
chunk_at = sents[-2][1]
|
|
||||||
|
|
||||||
print(f"--- sentence chunked at {chunk_at:2.2f}",file=self.logfile)
|
|
||||||
self.chunk_at(chunk_at)
|
|
||||||
|
|
||||||
def chunk_completed_segment(self, res):
|
|
||||||
if self.commited == []: return
|
|
||||||
|
|
||||||
ends = self.asr.segments_end_ts(res)
|
|
||||||
|
|
||||||
t = self.commited[-1][1]
|
|
||||||
|
|
||||||
if len(ends) > 1:
|
|
||||||
|
|
||||||
e = ends[-2]+self.buffer_time_offset
|
|
||||||
while len(ends) > 2 and e > t:
|
|
||||||
ends.pop(-1)
|
|
||||||
e = ends[-2]+self.buffer_time_offset
|
|
||||||
if e <= t:
|
|
||||||
print(f"--- segment chunked at {e:2.2f}",file=self.logfile)
|
|
||||||
self.chunk_at(e)
|
|
||||||
else:
|
|
||||||
print(f"--- last segment not within commited area",file=self.logfile)
|
|
||||||
else:
|
|
||||||
print(f"--- not enough segments to chunk",file=self.logfile)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def chunk_at(self, time):
|
|
||||||
"""trims the hypothesis and audio buffer at "time"
|
|
||||||
"""
|
|
||||||
self.transcript_buffer.pop_commited(time)
|
|
||||||
cut_seconds = time - self.buffer_time_offset
|
|
||||||
self.audio_buffer = self.audio_buffer[int(cut_seconds*self.SAMPLING_RATE):]
|
|
||||||
self.buffer_time_offset = time
|
|
||||||
self.last_chunked_at = time
|
|
||||||
|
|
||||||
def words_to_sentences(self, words):
|
|
||||||
"""Uses self.tokenizer for sentence segmentation of words.
|
|
||||||
Returns: [(beg,end,"sentence 1"),...]
|
|
||||||
"""
|
|
||||||
|
|
||||||
cwords = [w for w in words]
|
|
||||||
t = " ".join(o[2] for o in cwords)
|
|
||||||
s = self.tokenizer.split(t)
|
|
||||||
out = []
|
|
||||||
while s:
|
|
||||||
beg = None
|
|
||||||
end = None
|
|
||||||
sent = s.pop(0).strip()
|
|
||||||
fsent = sent
|
|
||||||
while cwords:
|
|
||||||
b,e,w = cwords.pop(0)
|
|
||||||
w = w.strip()
|
|
||||||
if beg is None and sent.startswith(w):
|
|
||||||
beg = b
|
|
||||||
elif end is None and sent == w:
|
|
||||||
end = e
|
|
||||||
out.append((beg,end,fsent))
|
|
||||||
break
|
|
||||||
sent = sent[len(w):].strip()
|
|
||||||
return out
|
|
||||||
|
|
||||||
def finish(self):
|
|
||||||
"""Flush the incomplete text when the whole processing ends.
|
|
||||||
Returns: the same format as self.process_iter()
|
|
||||||
"""
|
|
||||||
o = self.transcript_buffer.complete()
|
|
||||||
f = self.to_flush(o)
|
|
||||||
print("last, noncommited:",f,file=self.logfile)
|
|
||||||
return f
|
|
||||||
|
|
||||||
|
|
||||||
def to_flush(self, sents, sep=None, offset=0, ):
|
|
||||||
# concatenates the timestamped words or sentences into one sequence that is flushed in one line
|
|
||||||
# sents: [(beg1, end1, "sentence1"), ...] or [] if empty
|
|
||||||
# return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
|
|
||||||
if sep is None:
|
|
||||||
sep = self.asr.sep
|
|
||||||
t = sep.join(s[2] for s in sents)
|
|
||||||
if len(sents) == 0:
|
|
||||||
b = None
|
|
||||||
e = None
|
|
||||||
else:
|
|
||||||
b = offset + sents[0][0]
|
|
||||||
e = offset + sents[-1][1]
|
|
||||||
return (b,e,t)
|
|
||||||
|
|
||||||
WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(",")

def create_tokenizer(lan):
    """returns an object that has split function that works like the one of MosesTokenizer"""

    assert lan in WHISPER_LANG_CODES, "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)

    if lan == "uk":
        import tokenize_uk
        class UkrainianTokenizer:
            def split(self, text):
                return tokenize_uk.tokenize_sents(text)
        return UkrainianTokenizer()

    # supported by fast-mosestokenizer
    if lan in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split():
        from mosestokenizer import MosesTokenizer
        return MosesTokenizer(lan)

    # the following languages are in Whisper, but not in wtpsplit:
    if lan in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split():
        print(f"{lan} code is not supported by wtpsplit. Going to use None lang_code option.", file=sys.stderr)
        lan = None

    from wtpsplit import WtP
    # downloads the model from huggingface on the first use
    wtp = WtP("wtp-canine-s-12l-no-adapters")
    class WtPtok:
        def split(self, sent):
            return wtp.split(sent, lang_code=lan)
    return WtPtok()
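A minimal usage sketch of the tokenizer factory. The sample sentence is illustrative, and the sentence segmenter package for the chosen language (here fast-mosestokenizer for "en") must be installed.

tokenizer = create_tokenizer("en")
print(tokenizer.split("This is streamed. It arrives in chunks."))
# Expected: a list of two sentences.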
def add_shared_args(parser):
    """shared args for simulation (this entry point) and server
    parser: argparse.ArgumentParser object
    """
    parser.add_argument('--min-chunk-size', type=float, default=1.0, help='Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.')
    parser.add_argument('--model', type=str, default='large-v2', choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large".split(","),help="Name size of the Whisper model to use (default: large-v2). The model is automatically downloaded from the model hub if not present in model cache dir.")
    parser.add_argument('--model_cache_dir', type=str, default=None, help="Overriding the default model cache dir where models downloaded from the hub are saved")
    parser.add_argument('--model_dir', type=str, default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.")
    parser.add_argument('--lan', '--language', type=str, default='en', help="Language code for transcription, e.g. en,de,cs.")
    parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe","translate"],help="Transcribe or translate.")
    parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped"],help='Load only this backend for Whisper processing.')
    parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
    parser.add_argument('--buffer_trimming', type=str, default="segment", choices=["sentence", "segment"],help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.')
    parser.add_argument('--buffer_trimming_sec', type=float, default=15, help='Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.')


## main:
if __name__ == "__main__":

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('audio_path', type=str, help="Filename of 16kHz mono channel wav, on which live streaming is simulated.")
    add_shared_args(parser)
    parser.add_argument('--start_at', type=float, default=0.0, help='Start processing audio at this time.')
    parser.add_argument('--offline', action="store_true", default=False, help='Offline mode.')
    parser.add_argument('--comp_unaware', action="store_true", default=False, help='Computationally unaware simulation.')

    args = parser.parse_args()

    # reset to store stderr to different file stream, e.g. open(os.devnull,"w")
    logfile = sys.stderr

    if args.offline and args.comp_unaware:
        print("Only one of --offline and --comp_unaware can be used, not both. Exiting.",file=logfile)
        sys.exit(1)

    audio_path = args.audio_path

    SAMPLING_RATE = 16000
    duration = len(load_audio(audio_path))/SAMPLING_RATE
    print("Audio duration is: %2.2f seconds" % duration, file=logfile)

    size = args.model
    language = args.lan

    t = time.time()
    print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)

    if args.backend == "faster-whisper":
        asr_cls = FasterWhisperASR
    else:
        asr_cls = WhisperTimestampedASR

    asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)

    if args.task == "translate":
        asr.set_translate_task()
        tgt_language = "en"  # Whisper translates into English
    else:
        tgt_language = language  # Whisper transcribes in this language

    e = time.time()
    print(f"done. It took {round(e-t,2)} seconds.",file=logfile)

    if args.vad:
        print("setting VAD filter",file=logfile)
        asr.use_vad()

    min_chunk = args.min_chunk_size
    if args.buffer_trimming == "sentence":
        tokenizer = create_tokenizer(tgt_language)
    else:
        tokenizer = None
    online = OnlineASRProcessor(asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))

    # load the audio into the LRU cache before we start the timer
    a = load_audio_chunk(audio_path,0,1)

    # warm up the ASR, because the very first transcribe takes much more time than the other
    asr.transcribe(a)

    beg = args.start_at
    start = time.time()-beg
    def output_transcript(o, now=None):
        # output format in stdout is like:
        # 4186.3606 0 1720 Takhle to je
        # - the first three words are:
        #    - emission time from beginning of processing, in milliseconds
        #    - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
        # - the next words: segment transcript
        if now is None:
            now = time.time()-start
        if o[0] is not None:
            print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=logfile,flush=True)
            print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),flush=True)
        else:
            print(o,file=logfile,flush=True)
    if args.offline: ## offline mode processing (for testing/debugging)
        a = load_audio(audio_path)
        online.insert_audio_chunk(a)
        try:
            o = online.process_iter()
        except AssertionError:
            print("assertion error",file=logfile)
            pass
        else:
            output_transcript(o)
        now = None
    elif args.comp_unaware:  # computationally unaware mode
        end = beg + min_chunk
        while True:
            a = load_audio_chunk(audio_path,beg,end)
            online.insert_audio_chunk(a)
            try:
                o = online.process_iter()
            except AssertionError:
                print("assertion error",file=logfile)
                pass
            else:
                output_transcript(o, now=end)

            print(f"## last processed {end:.2f}s",file=logfile,flush=True)

            if end >= duration:
                break

            beg = end

            if end + min_chunk > duration:
                end = duration
            else:
                end += min_chunk
        now = duration

    else: # online = simultaneous mode
        end = 0
        while True:
            now = time.time() - start
            if now < end+min_chunk:
                time.sleep(min_chunk+end-now)
            end = time.time() - start
            a = load_audio_chunk(audio_path,beg,end)
            beg = end
            online.insert_audio_chunk(a)

            try:
                o = online.process_iter()
            except AssertionError:
                print("assertion error",file=logfile)
                pass
            else:
                output_transcript(o)
            now = time.time() - start
            print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=logfile,flush=True)

            if end >= duration:
                break
        now = None

    o = online.finish()
    output_transcript(o, now=now)
@@ -1,214 +0,0 @@
#!/usr/bin/env python3
from whisper_online import *

import sys
import argparse
import os
parser = argparse.ArgumentParser()

# server options
parser.add_argument("--host", type=str, default='localhost')
parser.add_argument("--port", type=int, default=43007)


# options from whisper_online
add_shared_args(parser)
args = parser.parse_args()


# setting whisper object by args

SAMPLING_RATE = 16000

size = args.model
language = args.lan

t = time.time()
print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",flush=True)

if args.backend == "faster-whisper":
    from faster_whisper import WhisperModel
    asr_cls = FasterWhisperASR
else:
    import whisper
    import whisper_timestamped
    # from whisper_timestamped_model import WhisperTimestampedASR
    asr_cls = WhisperTimestampedASR

asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)

if args.task == "translate":
    asr.set_translate_task()
    tgt_language = "en"
else:
    tgt_language = language

e = time.time()
print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr)

if args.vad:
    print("setting VAD filter",file=sys.stderr)
    asr.use_vad()


min_chunk = args.min_chunk_size

if args.buffer_trimming == "sentence":
    tokenizer = create_tokenizer(tgt_language)
else:
    tokenizer = None
online = OnlineASRProcessor(asr,tokenizer,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))


demo_audio_path = "cs-maji-2.16k.wav"
if os.path.exists(demo_audio_path):
    # load the audio into the LRU cache before we start the timer
    a = load_audio_chunk(demo_audio_path,0,1)

    # TODO: it should be tested whether it's meaningful
    # warm up the ASR, because the very first transcribe takes much more time than the other
    asr.transcribe(a)
else:
    print("Whisper is not warmed up",file=sys.stderr)
######### Server objects

import line_packet
import socket

import logging


class Connection:
    '''it wraps conn object'''
    PACKET_SIZE = 65536

    def __init__(self, conn):
        self.conn = conn
        self.last_line = ""

        self.conn.setblocking(True)

    def send(self, line):
        '''it doesn't send the same line twice, because it was problematic in online-text-flow-events'''
        if line == self.last_line:
            return
        line_packet.send_one_line(self.conn, line)
        self.last_line = line

    def receive_lines(self):
        in_line = line_packet.receive_lines(self.conn)
        return in_line

    def non_blocking_receive_audio(self):
        r = self.conn.recv(self.PACKET_SIZE)
        return r
import io
import soundfile

# wraps socket and ASR object, and serves one client connection.
# next client should be served by a new instance of this object
class ServerProcessor:

    def __init__(self, c, online_asr_proc, min_chunk):
        self.connection = c
        self.online_asr_proc = online_asr_proc
        self.min_chunk = min_chunk

        self.last_end = None

    def receive_audio_chunk(self):
        # receive all audio that is available by this time
        # blocks operation if less than self.min_chunk seconds is available
        # unblocks if connection is closed or a chunk is available
        out = []
        while sum(len(x) for x in out) < self.min_chunk*SAMPLING_RATE:
            raw_bytes = self.connection.non_blocking_receive_audio()
            print(raw_bytes[:10])
            print(len(raw_bytes))
            if not raw_bytes:
                break
            sf = soundfile.SoundFile(io.BytesIO(raw_bytes), channels=1,endian="LITTLE",samplerate=SAMPLING_RATE, subtype="PCM_16",format="RAW")
            audio, _ = librosa.load(sf,sr=SAMPLING_RATE)
            out.append(audio)
        if not out:
            return None
        return np.concatenate(out)

    def format_output_transcript(self,o):
        # output format in stdout is like:
        # 0 1720 Takhle to je
        # - the first two words are:
        #    - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
        # - the next words: segment transcript

        # This function differs from whisper_online.output_transcript in the following:
        # succeeding [beg,end] intervals are not overlapping because the ELITR protocol (implemented in online-text-flow events) requires it.
        # Therefore, beg is the max of the previous end and the current beg output by Whisper.
        # Usually it differs negligibly, by approx 20 ms.

        if o[0] is not None:
            beg, end = o[0]*1000,o[1]*1000
            if self.last_end is not None:
                beg = max(beg, self.last_end)

            self.last_end = end
            print("%1.0f %1.0f %s" % (beg,end,o[2]),flush=True,file=sys.stderr)
            return "%1.0f %1.0f %s" % (beg,end,o[2])
        else:
            print(o,file=sys.stderr,flush=True)
            return None

    def send_result(self, o):
        msg = self.format_output_transcript(o)
        if msg is not None:
            self.connection.send(msg)

    def process(self):
        # handle one client connection
        self.online_asr_proc.init()
        while True:
            a = self.receive_audio_chunk()
            if a is None:
                print("break here",file=sys.stderr)
                break
            self.online_asr_proc.insert_audio_chunk(a)
            o = online.process_iter()
            try:
                self.send_result(o)
            except BrokenPipeError:
                print("broken pipe -- connection closed?",file=sys.stderr)
                break

        # o = online.finish()  # this should be working
        # self.send_result(o)
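The non-overlap rule in format_output_transcript above can be illustrated with made-up numbers:

# Suppose the previous segment ended at 1840 ms and Whisper now emits (1.82, 2.65, "to je"):
#   beg = max(1.82 * 1000, 1840) = 1840   # clamped forward by 20 ms
#   end = 2650
# The emitted line is then: "1840 2650 to je"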
# Start logging.
level = logging.INFO
logging.basicConfig(level=level, format='whisper-server-%(levelname)s: %(message)s')

# server loop

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind((args.host, args.port))
    s.listen(1)
    logging.info('Listening on ' + str((args.host, args.port)))
    while True:
        conn, addr = s.accept()
        logging.info('Connected to client on {}'.format(addr))
        connection = Connection(conn)
        proc = ServerProcessor(connection, online, min_chunk)
        proc.process()
        conn.close()
        logging.info('Connection to client closed')
logging.info('Connection closed, terminating.')
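For reference, a minimal client sketch for this TCP server. It streams headerless 16 kHz mono s16le PCM, which is what receive_audio_chunk decodes, and prints whatever transcript text comes back. The file name, chunk pacing, and the plain recv-and-decode of replies are illustrative assumptions; the exact reply framing is defined by line_packet and is only approximated here.

import socket
import sys

HOST, PORT = "localhost", 43007
CHUNK = 2 * 16000  # roughly 1 second of 16 kHz mono s16le audio

# audio.raw is assumed to contain headerless 16 kHz mono s16le PCM, e.g. made with:
#   ffmpeg -i input.wav -f s16le -ac 1 -ar 16000 audio.raw
with socket.create_connection((HOST, PORT)) as sock, open("audio.raw", "rb") as f:
    sock.settimeout(0.1)
    while True:
        chunk = f.read(CHUNK)
        if not chunk:
            break
        sock.sendall(chunk)
        # Read whatever timestamped lines are currently available.
        try:
            data = sock.recv(65536)
            if data:
                sys.stdout.write(data.decode("utf-8", errors="replace"))
        except socket.timeout:
            pass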
4 whisperlivekit/__init__.py (Normal file)
@@ -0,0 +1,4 @@
from .core import WhisperLiveKit, parse_args
from .audio_processor import AudioProcessor

__all__ = ['WhisperLiveKit', 'AudioProcessor', 'parse_args']
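These exports are enough for a small server sketch. The FastAPI wiring below is an illustrative assumption (the endpoint names and the receive loop are not taken from this diff); it only relies on the WhisperLiveKit, AudioProcessor, create_tasks, process_audio, and web_interface members introduced in the files added here.

import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket
from fastapi.responses import HTMLResponse
from whisperlivekit import WhisperLiveKit, AudioProcessor

kit = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    global kit
    kit = WhisperLiveKit(model="tiny", diarization=False)  # loads models once (singleton)
    yield

app = FastAPI(lifespan=lifespan)

@app.get("/")
async def root():
    return HTMLResponse(kit.web_interface())  # bundled live_transcription.html

@app.websocket("/asr")
async def asr(websocket: WebSocket):
    await websocket.accept()
    processor = AudioProcessor()               # picks up the shared WhisperLiveKit config
    results = await processor.create_tasks()   # async generator of formatted results

    async def forward_results():
        async for response in results:
            await websocket.send_json(response)

    forwarder = asyncio.create_task(forward_results())
    try:
        while True:
            message = await websocket.receive_bytes()  # WebM/Opus chunks from the browser
            await processor.process_audio(message)
    finally:
        forwarder.cancel()
        await processor.cleanup()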
409 whisperlivekit/audio_processor.py (Normal file)
@@ -0,0 +1,409 @@
|
|||||||
|
import asyncio
|
||||||
|
import numpy as np
|
||||||
|
import ffmpeg
|
||||||
|
from time import time, sleep
|
||||||
|
import math
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
from datetime import timedelta
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from whisperlivekit.timed_objects import ASRToken
|
||||||
|
from whisperlivekit.whisper_streaming_custom.whisper_online import online_factory
|
||||||
|
from whisperlivekit.core import WhisperLiveKit
|
||||||
|
|
||||||
|
# Set up logging once
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
def format_time(seconds: float) -> str:
|
||||||
|
"""Format seconds as HH:MM:SS."""
|
||||||
|
return str(timedelta(seconds=int(seconds)))
|
||||||
|
|
||||||
|
class AudioProcessor:
|
||||||
|
"""
|
||||||
|
Processes audio streams for transcription and diarization.
|
||||||
|
Handles audio processing, state management, and result formatting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the audio processor with configuration, models, and state."""
|
||||||
|
|
||||||
|
models = WhisperLiveKit()
|
||||||
|
|
||||||
|
# Audio processing settings
|
||||||
|
self.args = models.args
|
||||||
|
self.sample_rate = 16000
|
||||||
|
self.channels = 1
|
||||||
|
self.samples_per_sec = int(self.sample_rate * self.args.min_chunk_size)
|
||||||
|
self.bytes_per_sample = 2
|
||||||
|
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
|
||||||
|
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
|
||||||
|
|
||||||
|
# State management
|
||||||
|
self.tokens = []
|
||||||
|
self.buffer_transcription = ""
|
||||||
|
self.buffer_diarization = ""
|
||||||
|
self.full_transcription = ""
|
||||||
|
self.end_buffer = 0
|
||||||
|
self.end_attributed_speaker = 0
|
||||||
|
self.lock = asyncio.Lock()
|
||||||
|
self.beg_loop = time()
|
||||||
|
self.sep = " " # Default separator
|
||||||
|
self.last_response_content = ""
|
||||||
|
|
||||||
|
# Models and processing
|
||||||
|
self.asr = models.asr
|
||||||
|
self.tokenizer = models.tokenizer
|
||||||
|
self.diarization = models.diarization
|
||||||
|
self.ffmpeg_process = self.start_ffmpeg_decoder()
|
||||||
|
self.transcription_queue = asyncio.Queue() if self.args.transcription else None
|
||||||
|
self.diarization_queue = asyncio.Queue() if self.args.diarization else None
|
||||||
|
self.pcm_buffer = bytearray()
|
||||||
|
|
||||||
|
# Initialize transcription engine if enabled
|
||||||
|
if self.args.transcription:
|
||||||
|
self.online = online_factory(self.args, models.asr, models.tokenizer)
|
||||||
|
|
||||||
|
def convert_pcm_to_float(self, pcm_buffer):
|
||||||
|
"""Convert PCM buffer in s16le format to normalized NumPy array."""
|
||||||
|
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
|
||||||
|
|
||||||
|
def start_ffmpeg_decoder(self):
|
||||||
|
"""Start FFmpeg process for WebM to PCM conversion."""
|
||||||
|
return (ffmpeg.input("pipe:0", format="webm")
|
||||||
|
.output("pipe:1", format="s16le", acodec="pcm_s16le",
|
||||||
|
ac=self.channels, ar=str(self.sample_rate))
|
||||||
|
.run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True))
|
||||||
|
|
||||||
|
async def restart_ffmpeg(self):
|
||||||
|
"""Restart the FFmpeg process after failure."""
|
||||||
|
if self.ffmpeg_process:
|
||||||
|
try:
|
||||||
|
self.ffmpeg_process.kill()
|
||||||
|
await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error killing FFmpeg process: {e}")
|
||||||
|
self.ffmpeg_process = self.start_ffmpeg_decoder()
|
||||||
|
self.pcm_buffer = bytearray()
|
||||||
|
|
||||||
|
async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep):
|
||||||
|
"""Thread-safe update of transcription with new data."""
|
||||||
|
async with self.lock:
|
||||||
|
self.tokens.extend(new_tokens)
|
||||||
|
self.buffer_transcription = buffer
|
||||||
|
self.end_buffer = end_buffer
|
||||||
|
self.full_transcription = full_transcription
|
||||||
|
self.sep = sep
|
||||||
|
|
||||||
|
async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
|
||||||
|
"""Thread-safe update of diarization with new data."""
|
||||||
|
async with self.lock:
|
||||||
|
self.end_attributed_speaker = end_attributed_speaker
|
||||||
|
if buffer_diarization:
|
||||||
|
self.buffer_diarization = buffer_diarization
|
||||||
|
|
||||||
|
async def add_dummy_token(self):
|
||||||
|
"""Placeholder token when no transcription is available."""
|
||||||
|
async with self.lock:
|
||||||
|
current_time = time() - self.beg_loop
|
||||||
|
self.tokens.append(ASRToken(
|
||||||
|
start=current_time, end=current_time + 1,
|
||||||
|
text=".", speaker=-1, is_dummy=True
|
||||||
|
))
|
||||||
|
|
||||||
|
async def get_current_state(self):
|
||||||
|
"""Get current state."""
|
||||||
|
async with self.lock:
|
||||||
|
current_time = time()
|
||||||
|
|
||||||
|
# Calculate remaining times
|
||||||
|
remaining_transcription = 0
|
||||||
|
if self.end_buffer > 0:
|
||||||
|
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2))
|
||||||
|
|
||||||
|
remaining_diarization = 0
|
||||||
|
if self.tokens:
|
||||||
|
latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
|
||||||
|
remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"tokens": self.tokens.copy(),
|
||||||
|
"buffer_transcription": self.buffer_transcription,
|
||||||
|
"buffer_diarization": self.buffer_diarization,
|
||||||
|
"end_buffer": self.end_buffer,
|
||||||
|
"end_attributed_speaker": self.end_attributed_speaker,
|
||||||
|
"sep": self.sep,
|
||||||
|
"remaining_time_transcription": remaining_transcription,
|
||||||
|
"remaining_time_diarization": remaining_diarization
|
||||||
|
}
|
||||||
|
|
||||||
|
async def reset(self):
|
||||||
|
"""Reset all state variables to initial values."""
|
||||||
|
async with self.lock:
|
||||||
|
self.tokens = []
|
||||||
|
self.buffer_transcription = self.buffer_diarization = ""
|
||||||
|
self.end_buffer = self.end_attributed_speaker = 0
|
||||||
|
self.full_transcription = self.last_response_content = ""
|
||||||
|
self.beg_loop = time()
|
||||||
|
|
||||||
|
async def ffmpeg_stdout_reader(self):
|
||||||
|
"""Read audio data from FFmpeg stdout and process it."""
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
beg = time()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Calculate buffer size based on elapsed time
|
||||||
|
elapsed_time = math.floor((time() - beg) * 10) / 10 # Round to 0.1 sec
|
||||||
|
buffer_size = max(int(32000 * elapsed_time), 4096)
|
||||||
|
beg = time()
|
||||||
|
|
||||||
|
# Read chunk with timeout
|
||||||
|
try:
|
||||||
|
chunk = await asyncio.wait_for(
|
||||||
|
loop.run_in_executor(None, self.ffmpeg_process.stdout.read, buffer_size),
|
||||||
|
timeout=15.0
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("FFmpeg read timeout. Restarting...")
|
||||||
|
await self.restart_ffmpeg()
|
||||||
|
beg = time()
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not chunk:
|
||||||
|
logger.info("FFmpeg stdout closed.")
|
||||||
|
break
|
||||||
|
|
||||||
|
self.pcm_buffer.extend(chunk)
|
||||||
|
|
||||||
|
# Send to diarization if enabled
|
||||||
|
if self.args.diarization and self.diarization_queue:
|
||||||
|
await self.diarization_queue.put(
|
||||||
|
self.convert_pcm_to_float(self.pcm_buffer).copy()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process when we have enough data
|
||||||
|
if len(self.pcm_buffer) >= self.bytes_per_sec:
|
||||||
|
if len(self.pcm_buffer) > self.max_bytes_per_sec:
|
||||||
|
logger.warning(
|
||||||
|
f"Audio buffer too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f}s. "
|
||||||
|
f"Consider using a smaller model."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process audio chunk
|
||||||
|
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
|
||||||
|
self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
|
||||||
|
|
||||||
|
# Send to transcription if enabled
|
||||||
|
if self.args.transcription and self.transcription_queue:
|
||||||
|
await self.transcription_queue.put(pcm_array.copy())
|
||||||
|
|
||||||
|
# Sleep if no processing is happening
|
||||||
|
if not self.args.transcription and not self.args.diarization:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
|
||||||
|
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||||
|
break
|
||||||
|
|
||||||
|
async def transcription_processor(self):
|
||||||
|
"""Process audio chunks for transcription."""
|
||||||
|
self.full_transcription = ""
|
||||||
|
self.sep = self.online.asr.sep
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
pcm_array = await self.transcription_queue.get()
|
||||||
|
|
||||||
|
logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio to process.")
|
||||||
|
|
||||||
|
# Process transcription
|
||||||
|
self.online.insert_audio_chunk(pcm_array)
|
||||||
|
new_tokens = self.online.process_iter()
|
||||||
|
|
||||||
|
if new_tokens:
|
||||||
|
self.full_transcription += self.sep.join([t.text for t in new_tokens])
|
||||||
|
|
||||||
|
# Get buffer information
|
||||||
|
_buffer = self.online.get_buffer()
|
||||||
|
buffer = _buffer.text
|
||||||
|
end_buffer = _buffer.end if _buffer.end else (
|
||||||
|
new_tokens[-1].end if new_tokens else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Avoid duplicating content
|
||||||
|
if buffer in self.full_transcription:
|
||||||
|
buffer = ""
|
||||||
|
|
||||||
|
await self.update_transcription(
|
||||||
|
new_tokens, buffer, end_buffer, self.full_transcription, self.sep
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Exception in transcription_processor: {e}")
|
||||||
|
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||||
|
finally:
|
||||||
|
self.transcription_queue.task_done()
|
||||||
|
|
||||||
|
async def diarization_processor(self, diarization_obj):
|
||||||
|
"""Process audio chunks for speaker diarization."""
|
||||||
|
buffer_diarization = ""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
pcm_array = await self.diarization_queue.get()
|
||||||
|
|
||||||
|
# Process diarization
|
||||||
|
await diarization_obj.diarize(pcm_array)
|
||||||
|
|
||||||
|
# Get current state and update speakers
|
||||||
|
state = await self.get_current_state()
|
||||||
|
new_end = diarization_obj.assign_speakers_to_tokens(
|
||||||
|
state["end_attributed_speaker"], state["tokens"]
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.update_diarization(new_end, buffer_diarization)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Exception in diarization_processor: {e}")
|
||||||
|
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||||
|
finally:
|
||||||
|
self.diarization_queue.task_done()
|
||||||
|
|
||||||
|
async def results_formatter(self):
|
||||||
|
"""Format processing results for output."""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Get current state
|
||||||
|
state = await self.get_current_state()
|
||||||
|
tokens = state["tokens"]
|
||||||
|
buffer_transcription = state["buffer_transcription"]
|
||||||
|
buffer_diarization = state["buffer_diarization"]
|
||||||
|
end_attributed_speaker = state["end_attributed_speaker"]
|
||||||
|
sep = state["sep"]
|
||||||
|
|
||||||
|
# Add dummy tokens if needed
|
||||||
|
if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
|
||||||
|
await self.add_dummy_token()
|
||||||
|
sleep(0.5)
|
||||||
|
state = await self.get_current_state()
|
||||||
|
tokens = state["tokens"]
|
||||||
|
|
||||||
|
# Format output
|
||||||
|
previous_speaker = -1
|
||||||
|
lines = []
|
||||||
|
last_end_diarized = 0
|
||||||
|
undiarized_text = []
|
||||||
|
|
||||||
|
# Process each token
|
||||||
|
for token in tokens:
|
||||||
|
speaker = token.speaker
|
||||||
|
|
||||||
|
# Handle diarization
|
||||||
|
if self.args.diarization:
|
||||||
|
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
|
||||||
|
undiarized_text.append(token.text)
|
||||||
|
continue
|
||||||
|
elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
|
||||||
|
speaker = previous_speaker
|
||||||
|
if speaker not in [-1, 0]:
|
||||||
|
last_end_diarized = max(token.end, last_end_diarized)
|
||||||
|
|
||||||
|
# Group by speaker
|
||||||
|
if speaker != previous_speaker or not lines:
|
||||||
|
lines.append({
|
||||||
|
"speaker": speaker,
|
||||||
|
"text": token.text,
|
||||||
|
"beg": format_time(token.start),
|
||||||
|
"end": format_time(token.end),
|
||||||
|
"diff": round(token.end - last_end_diarized, 2)
|
||||||
|
})
|
||||||
|
previous_speaker = speaker
|
||||||
|
elif token.text: # Only append if text isn't empty
|
||||||
|
lines[-1]["text"] += sep + token.text
|
||||||
|
lines[-1]["end"] = format_time(token.end)
|
||||||
|
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
|
||||||
|
|
||||||
|
# Handle undiarized text
|
||||||
|
if undiarized_text:
|
||||||
|
combined = sep.join(undiarized_text)
|
||||||
|
if buffer_transcription:
|
||||||
|
combined += sep
|
||||||
|
await self.update_diarization(end_attributed_speaker, combined)
|
||||||
|
buffer_diarization = combined
|
||||||
|
|
||||||
|
# Create response object
|
||||||
|
if not lines:
|
||||||
|
lines = [{
|
||||||
|
"speaker": 1,
|
||||||
|
"text": "",
|
||||||
|
"beg": format_time(0),
|
||||||
|
"end": format_time(tokens[-1].end if tokens else 0),
|
||||||
|
"diff": 0
|
||||||
|
}]
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"lines": lines,
|
||||||
|
"buffer_transcription": buffer_transcription,
|
||||||
|
"buffer_diarization": buffer_diarization,
|
||||||
|
"remaining_time_transcription": state["remaining_time_transcription"],
|
||||||
|
"remaining_time_diarization": state["remaining_time_diarization"]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Only yield if content has changed
|
||||||
|
response_content = ' '.join([f"{line['speaker']} {line['text']}" for line in lines]) + \
|
||||||
|
f" | {buffer_transcription} | {buffer_diarization}"
|
||||||
|
|
||||||
|
if response_content != self.last_response_content and (lines or buffer_transcription or buffer_diarization):
|
||||||
|
yield response
|
||||||
|
self.last_response_content = response_content
|
||||||
|
|
||||||
|
await asyncio.sleep(0.1) # Avoid overwhelming the client
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Exception in results_formatter: {e}")
|
||||||
|
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||||
|
await asyncio.sleep(0.5) # Back off on error
|
||||||
|
|
||||||
|
async def create_tasks(self):
|
||||||
|
"""Create and start processing tasks."""
|
||||||
|
|
||||||
|
tasks = []
|
||||||
|
if self.args.transcription and self.online:
|
||||||
|
tasks.append(asyncio.create_task(self.transcription_processor()))
|
||||||
|
|
||||||
|
if self.args.diarization and self.diarization:
|
||||||
|
tasks.append(asyncio.create_task(self.diarization_processor(self.diarization)))
|
||||||
|
|
||||||
|
tasks.append(asyncio.create_task(self.ffmpeg_stdout_reader()))
|
||||||
|
self.tasks = tasks
|
||||||
|
|
||||||
|
return self.results_formatter()
|
||||||
|
|
||||||
|
async def cleanup(self):
|
||||||
|
"""Clean up resources when processing is complete."""
|
||||||
|
for task in self.tasks:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*self.tasks, return_exceptions=True)
|
||||||
|
self.ffmpeg_process.stdin.close()
|
||||||
|
self.ffmpeg_process.wait()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error during cleanup: {e}")
|
||||||
|
|
||||||
|
if self.args.diarization and hasattr(self, 'diarization'):
|
||||||
|
self.diarization.close()
|
||||||
|
|
||||||
|
async def process_audio(self, message):
|
||||||
|
"""Process incoming audio data."""
|
||||||
|
try:
|
||||||
|
self.ffmpeg_process.stdin.write(message)
|
||||||
|
self.ffmpeg_process.stdin.flush()
|
||||||
|
except (BrokenPipeError, AttributeError) as e:
|
||||||
|
logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
|
||||||
|
await self.restart_ffmpeg()
|
||||||
|
self.ffmpeg_process.stdin.write(message)
|
||||||
|
self.ffmpeg_process.stdin.flush()
|
||||||
174 whisperlivekit/core.py (Normal file)
@@ -0,0 +1,174 @@
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
from argparse import Namespace, ArgumentParser

def parse_args():
    parser = ArgumentParser(description="Whisper FastAPI Online Server")
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help="The host address to bind the server to.",
    )
    parser.add_argument(
        "--port", type=int, default=8000, help="The port number to bind the server to."
    )
    parser.add_argument(
        "--warmup-file",
        type=str,
        default=None,
        dest="warmup_file",
        help="""
        The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
        If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
        If False, no warmup is performed.
        """,
    )

    # Note: argparse's type=bool converts any non-empty string to True.
    parser.add_argument(
        "--confidence-validation",
        type=bool,
        default=False,
        help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
    )

    parser.add_argument(
        "--diarization",
        type=bool,
        default=True,
        help="Whether to enable speaker diarization.",
    )

    parser.add_argument(
        "--transcription",
        type=bool,
        default=True,
        help="Disable to only see live diarization results.",
    )

    parser.add_argument(
        "--min-chunk-size",
        type=float,
        default=0.5,
        help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="tiny",
        choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
            ","
        ),
        help="Name size of the Whisper model to use (default: tiny). The model is automatically downloaded from the model hub if not present in model cache dir.",
    )
    parser.add_argument(
        "--model_cache_dir",
        type=str,
        default=None,
        help="Overriding the default model cache dir where models downloaded from the hub are saved",
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default=None,
        help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
    )
    parser.add_argument(
        "--lan",
        "--language",
        type=str,
        default="auto",
        help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
    )
    parser.add_argument(
        "--task",
        type=str,
        default="transcribe",
        choices=["transcribe", "translate"],
        help="Transcribe or translate.",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="faster-whisper",
        choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
        help="Load only this backend for Whisper processing.",
    )
    parser.add_argument(
        "--vac",
        action="store_true",
        default=False,
        help="Use VAC = voice activity controller. Recommended. Requires torch.",
    )
    parser.add_argument(
        "--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
    )
    parser.add_argument(
        "--vad",
        action="store_true",
        default=True,
        help="Use VAD = voice activity detection, with the default parameters.",
    )
    parser.add_argument(
        "--buffer_trimming",
        type=str,
        default="segment",
        choices=["sentence", "segment"],
        help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.',
    )
    parser.add_argument(
        "--buffer_trimming_sec",
        type=float,
        default=15,
        help="Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.",
    )
    parser.add_argument(
        "-l",
        "--log-level",
        dest="log_level",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Set the log level",
        default="DEBUG",
    )

    args = parser.parse_args()
    return args

class WhisperLiveKit:
    _instance = None
    _initialized = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, **kwargs):
        if WhisperLiveKit._initialized:
            return

        default_args = vars(parse_args())

        merged_args = {**default_args, **kwargs}

        self.args = Namespace(**merged_args)

        self.asr = None
        self.tokenizer = None
        self.diarization = None

        if self.args.transcription:
            self.asr, self.tokenizer = backend_factory(self.args)
            warmup_asr(self.asr, self.args.warmup_file)

        if self.args.diarization:
            from whisperlivekit.diarization.diarization_online import DiartDiarization
            self.diarization = DiartDiarization()

        WhisperLiveKit._initialized = True

    def web_interface(self):
        import pkg_resources
        html_path = pkg_resources.resource_filename('whisperlivekit', 'web/live_transcription.html')
        with open(html_path, "r", encoding="utf-8") as f:
            html = f.read()
        return html
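Because of the __new__/_initialized pattern above, constructing WhisperLiveKit twice returns the same configured object, and keyword arguments are only honoured on the first call. A small sketch, with illustrative argument values:

kit_a = WhisperLiveKit(model="tiny", diarization=False)  # loads the backend once
kit_b = WhisperLiveKit(model="large-v3")                  # same instance; "large-v3" is ignored
assert kit_a is kit_b
print(kit_a.args.model)  # -> "tiny"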
153 whisperlivekit/diarization/diarization_online.py (Normal file)
@@ -0,0 +1,153 @@
import asyncio
import re
import threading
import numpy as np
import logging


from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import StreamingInference
from diart.sources import AudioSource
from whisperlivekit.timed_objects import SpeakerSegment
from diart.sources import MicrophoneAudioSource
from rx.core import Observer
from typing import Tuple, Any, List
from pyannote.core import Annotation

logger = logging.getLogger(__name__)

def extract_number(s: str) -> int:
    m = re.search(r'\d+', s)
    return int(m.group()) if m else None

class DiarizationObserver(Observer):
    """Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""

    def __init__(self):
        self.speaker_segments = []
        self.processed_time = 0
        self.segment_lock = threading.Lock()

    def on_next(self, value: Tuple[Annotation, Any]):
        annotation, audio = value

        logger.debug("\n--- New Diarization Result ---")

        duration = audio.extent.end - audio.extent.start
        logger.debug(f"Audio segment: {audio.extent.start:.2f}s - {audio.extent.end:.2f}s (duration: {duration:.2f}s)")
        logger.debug(f"Audio shape: {audio.data.shape}")

        with self.segment_lock:
            if audio.extent.end > self.processed_time:
                self.processed_time = audio.extent.end
            if annotation and len(annotation._labels) > 0:
                logger.debug("\nSpeaker segments:")
                for speaker, label in annotation._labels.items():
                    for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]):
                        print(f"  {speaker}: {start:.2f}s-{end:.2f}s")
                        self.speaker_segments.append(SpeakerSegment(
                            speaker=speaker,
                            start=start,
                            end=end
                        ))
            else:
                logger.debug("\nNo speakers detected in this segment")

    def get_segments(self) -> List[SpeakerSegment]:
        """Get a copy of the current speaker segments."""
        with self.segment_lock:
            return self.speaker_segments.copy()

    def clear_old_segments(self, older_than: float = 30.0):
        """Clear segments older than the specified time."""
        with self.segment_lock:
            current_time = self.processed_time
            self.speaker_segments = [
                segment for segment in self.speaker_segments
                if current_time - segment.end < older_than
            ]

    def on_error(self, error):
        """Handle an error in the stream."""
        logger.debug(f"Error in diarization stream: {error}")

    def on_completed(self):
        """Handle the completion of the stream."""
        logger.debug("Diarization stream completed")


class WebSocketAudioSource(AudioSource):
    """
    Custom AudioSource that blocks in read() until close() is called.
    Use push_audio() to inject PCM chunks.
    """
    def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
        super().__init__(uri, sample_rate)
        self._closed = False
        self._close_event = threading.Event()

    def read(self):
        self._close_event.wait()

    def close(self):
        if not self._closed:
            self._closed = True
            self.stream.on_completed()
            self._close_event.set()

    def push_audio(self, chunk: np.ndarray):
        if not self._closed:
            new_audio = np.expand_dims(chunk, axis=0)
            logger.debug("Add new chunk with shape: %s", new_audio.shape)
            self.stream.on_next(new_audio)


class DiartDiarization:
    def __init__(self, sample_rate: int = 16000, config: SpeakerDiarizationConfig = None, use_microphone: bool = False):
        self.pipeline = SpeakerDiarization(config=config)
        self.observer = DiarizationObserver()

        if use_microphone:
            self.source = MicrophoneAudioSource()
            self.custom_source = None
        else:
            self.custom_source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
            self.source = self.custom_source

        self.inference = StreamingInference(
            pipeline=self.pipeline,
            source=self.source,
            do_plot=False,
            show_progress=False,
        )
        self.inference.attach_observers(self.observer)
        asyncio.get_event_loop().run_in_executor(None, self.inference)

    async def diarize(self, pcm_array: np.ndarray):
        """
        Process audio data for diarization.
        Only used when working with WebSocketAudioSource.
        """
        if self.custom_source:
            self.custom_source.push_audio(pcm_array)
        self.observer.clear_old_segments()
        return self.observer.get_segments()

    def close(self):
        """Close the audio source."""
        if self.custom_source:
            self.custom_source.close()

    def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float:
        """
        Assign speakers to tokens based on timing overlap with speaker segments.
        Uses the segments collected by the observer.
        """
        segments = self.observer.get_segments()

        for token in tokens:
            for segment in segments:
                if not (segment.end <= token.start or segment.start >= token.end):
                    token.speaker = extract_number(segment.speaker) + 1
                    end_attributed_speaker = max(token.end, end_attributed_speaker)
        return end_attributed_speaker
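The overlap test in assign_speakers_to_tokens treats any non-empty intersection as a match. A tiny illustration with made-up times:

# Speaker segment "speaker0" covers 0.0-2.0 s; a token "hello" covers 1.8-2.3 s.
# not (2.0 <= 1.8 or 0.0 >= 2.3)  ->  True, so the token overlaps the segment.
# token.speaker = extract_number("speaker0") + 1 = 1
# end_attributed_speaker = max(2.3, end_attributed_speaker)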
163 whisperlivekit/silero_vad_iterator.py (Normal file)
@@ -0,0 +1,163 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
# This is copied from silero-vad's vad_utils.py:
|
||||||
|
# https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/utils_vad.py#L340
|
||||||
|
# (except changed defaults)
|
||||||
|
|
||||||
|
# Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
|
||||||
|
|
||||||
|
|
||||||
|
class VADIterator:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
threshold: float = 0.5,
|
||||||
|
sampling_rate: int = 16000,
|
||||||
|
min_silence_duration_ms: int = 500, # makes sense on one recording that I checked
|
||||||
|
speech_pad_ms: int = 100, # same
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Class for stream imitation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model: preloaded .jit silero VAD model
|
||||||
|
|
||||||
|
threshold: float (default - 0.5)
|
||||||
|
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
|
||||||
|
It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
|
||||||
|
|
||||||
|
sampling_rate: int (default - 16000)
|
||||||
|
Currently silero VAD models support 8000 and 16000 sample rates
|
||||||
|
|
||||||
|
min_silence_duration_ms: int (default - 100 milliseconds)
|
||||||
|
In the end of each speech chunk wait for min_silence_duration_ms before separating it
|
||||||
|
|
||||||
|
speech_pad_ms: int (default - 30 milliseconds)
|
||||||
|
Final speech chunks are padded by speech_pad_ms each side
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
self.threshold = threshold
|
||||||
|
self.sampling_rate = sampling_rate
|
||||||
|
|
||||||
|
if sampling_rate not in [8000, 16000]:
|
||||||
|
raise ValueError(
|
||||||
|
"VADIterator does not support sampling rates other than [8000, 16000]"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
|
||||||
|
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
||||||
|
self.reset_states()
|
||||||
|
|
||||||
|
def reset_states(self):
|
||||||
|
|
||||||
|
self.model.reset_states()
|
||||||
|
self.triggered = False
|
||||||
|
self.temp_end = 0
|
||||||
|
self.current_sample = 0
|
||||||
|
|
||||||
|
def __call__(self, x, return_seconds=False):
|
||||||
|
"""
|
||||||
|
x: torch.Tensor
|
||||||
|
audio chunk (see examples in repo)
|
||||||
|
|
||||||
|
return_seconds: bool (default - False)
|
||||||
|
whether return timestamps in seconds (default - samples)
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not torch.is_tensor(x):
|
||||||
|
try:
|
||||||
|
x = torch.Tensor(x)
|
||||||
|
except:
|
||||||
|
raise TypeError("Audio cannot be casted to tensor. Cast it manually")
|
||||||
|
|
||||||
|
window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
|
||||||
|
self.current_sample += window_size_samples
|
||||||
|
|
||||||
|
speech_prob = self.model(x, self.sampling_rate).item()
|
||||||
|
|
||||||
|
if (speech_prob >= self.threshold) and self.temp_end:
|
||||||
|
self.temp_end = 0
|
||||||
|
|
||||||
|
if (speech_prob >= self.threshold) and not self.triggered:
|
||||||
|
self.triggered = True
|
||||||
|
speech_start = self.current_sample - self.speech_pad_samples
|
||||||
|
return {
|
||||||
|
"start": (
|
||||||
|
int(speech_start)
|
||||||
|
if not return_seconds
|
||||||
|
else round(speech_start / self.sampling_rate, 1)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (speech_prob < self.threshold - 0.15) and self.triggered:
|
||||||
|
if not self.temp_end:
|
||||||
|
self.temp_end = self.current_sample
|
||||||
|
if self.current_sample - self.temp_end < self.min_silence_samples:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
speech_end = self.temp_end + self.speech_pad_samples
|
||||||
|
self.temp_end = 0
|
||||||
|
self.triggered = False
|
||||||
|
return {
|
||||||
|
"end": (
|
||||||
|
int(speech_end)
|
||||||
|
if not return_seconds
|
||||||
|
else round(speech_end / self.sampling_rate, 1)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
#######################
|
||||||
|
# because Silero now requires exactly 512-sized audio chunks
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class FixedVADIterator(VADIterator):
|
||||||
|
"""It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once.
|
||||||
|
If audio to be processed at once is long and multiple voiced segments detected,
|
||||||
|
then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def reset_states(self):
|
||||||
|
super().reset_states()
|
||||||
|
self.buffer = np.array([], dtype=np.float32)
|
||||||
|
|
||||||
|
def __call__(self, x, return_seconds=False):
|
||||||
|
self.buffer = np.append(self.buffer, x)
|
||||||
|
ret = None
|
||||||
|
while len(self.buffer) >= 512:
|
||||||
|
r = super().__call__(self.buffer[:512], return_seconds=return_seconds)
|
||||||
|
self.buffer = self.buffer[512:]
|
||||||
|
if ret is None:
|
||||||
|
ret = r
|
||||||
|
elif r is not None:
|
||||||
|
if "end" in r:
|
||||||
|
ret["end"] = r["end"] # the latter end
|
||||||
|
if "start" in r and "end" in ret: # there is an earlier start.
|
||||||
|
# Remove end, merging this segment with the previous one.
|
||||||
|
del ret["end"]
|
||||||
|
return ret if ret != {} else None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# test/demonstrate the need for FixedVADIterator:
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
|
||||||
|
vac = FixedVADIterator(model)
|
||||||
|
# vac = VADIterator(model) # the second case crashes with this
|
||||||
|
|
||||||
|
# this works: for both
|
||||||
|
audio_buffer = np.array([0] * (512), dtype=np.float32)
|
||||||
|
vac(audio_buffer)
|
||||||
|
|
||||||
|
# this crashes on the non FixedVADIterator with
|
||||||
|
# ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
|
||||||
|
audio_buffer = np.array([0] * (512 - 1), dtype=np.float32)
|
||||||
|
vac(audio_buffer)
|
||||||
29 whisperlivekit/timed_objects.py (Normal file)
@@ -0,0 +1,29 @@
from dataclasses import dataclass
from typing import Optional

@dataclass
class TimedText:
    start: Optional[float]
    end: Optional[float]
    text: Optional[str] = ''
    speaker: Optional[int] = -1
    probability: Optional[float] = None
    is_dummy: Optional[bool] = False

@dataclass
class ASRToken(TimedText):
    def with_offset(self, offset: float) -> "ASRToken":
        """Return a new token with the time offset added."""
        return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability)

@dataclass
class Sentence(TimedText):
    pass

@dataclass
class Transcript(TimedText):
    pass

@dataclass
class SpeakerSegment(TimedText):
    pass
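A short sketch of how these dataclasses compose (values are illustrative). Note that with_offset passes only the first five fields to the new token, so is_dummy falls back to its default:

tok = ASRToken(start=1.2, end=1.6, text="hello", probability=0.91)
shifted = tok.with_offset(10.0)
print(shifted.start, shifted.end, shifted.text)  # 11.2 11.6 hello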
568 whisperlivekit/web/live_transcription.html (Normal file)
@@ -0,0 +1,568 @@
|
|||||||
|
<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Audio Transcription</title>
    <style>
        body {
            font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
            margin: 20px;
            text-align: center;
        }

        #recordButton {
            width: 50px;
            height: 50px;
            border: none;
            border-radius: 50%;
            background-color: white;
            cursor: pointer;
            transition: all 0.3s ease;
            border: 1px solid rgb(233, 233, 233);
            display: flex;
            align-items: center;
            justify-content: center;
            position: relative;
        }

        #recordButton.recording {
            width: 180px;
            border-radius: 40px;
            justify-content: flex-start;
            padding-left: 20px;
        }

        #recordButton:active {
            transform: scale(0.95);
        }

        /* Shape inside the button */
        .shape-container {
            width: 25px;
            height: 25px;
            display: flex;
            align-items: center;
            justify-content: center;
            flex-shrink: 0;
        }

        .shape {
            width: 25px;
            height: 25px;
            background-color: rgb(209, 61, 53);
            border-radius: 50%;
            transition: all 0.3s ease;
        }

        #recordButton.recording .shape {
            border-radius: 5px;
            width: 25px;
            height: 25px;
        }

        /* Recording elements */
        .recording-info {
            display: none;
            align-items: center;
            margin-left: 15px;
            flex-grow: 1;
        }

        #recordButton.recording .recording-info {
            display: flex;
        }

        .wave-container {
            width: 60px;
            height: 30px;
            position: relative;
            display: flex;
            align-items: center;
            justify-content: center;
        }

        #waveCanvas {
            width: 100%;
            height: 100%;
        }

        .timer {
            font-size: 14px;
            font-weight: 500;
            color: #333;
            margin-left: 10px;
        }

        #status {
            margin-top: 20px;
            font-size: 16px;
            color: #333;
        }

        .settings-container {
            display: flex;
            justify-content: center;
            align-items: center;
            gap: 15px;
            margin-top: 20px;
        }

        .settings {
            display: flex;
            flex-direction: column;
            align-items: flex-start;
            gap: 5px;
        }

        #chunkSelector,
        #websocketInput {
            font-size: 16px;
            padding: 5px;
            border-radius: 5px;
            border: 1px solid #ddd;
            background-color: #ffffff;
            max-height: 30px;
        }

        #websocketInput {
            width: 200px;
        }

        #chunkSelector:focus,
        #websocketInput:focus {
            outline: none;
            border-color: #007bff;
        }

        label {
            font-size: 14px;
        }

        /* Speaker-labeled transcript area */
        #linesTranscript {
            margin: 20px auto;
            max-width: 700px;
            text-align: left;
            font-size: 16px;
        }

        #linesTranscript p {
            margin: 0px 0;
        }

        #linesTranscript strong {
            color: #333;
        }

        #speaker {
            border: 1px solid rgb(229, 229, 229);
            border-radius: 100px;
            padding: 2px 10px;
            font-size: 14px;
            margin-bottom: 0px;
        }
        .label_diarization {
            background-color: #ffffff66;
            border-radius: 8px 8px 8px 8px;
            padding: 2px 10px;
            margin-left: 10px;
            display: inline-block;
            white-space: nowrap;
            font-size: 14px;
            margin-bottom: 0px;
            color: rgb(134, 134, 134)
        }

        .label_transcription {
            background-color: #ffffff66;
            border-radius: 8px 8px 8px 8px;
            padding: 2px 10px;
            display: inline-block;
            white-space: nowrap;
            margin-left: 10px;
            font-size: 14px;
            margin-bottom: 0px;
            color: #000000
        }

        #timeInfo {
            color: #666;
            margin-left: 10px;
        }

        .textcontent {
            font-size: 16px;
            /* margin-left: 10px; */
            padding-left: 10px;
            margin-bottom: 10px;
            margin-top: 1px;
            padding-top: 5px;
            border-radius: 0px 0px 0px 10px;
        }

        .buffer_diarization {
            color: rgb(134, 134, 134);
            margin-left: 4px;
        }

        .buffer_transcription {
            color: #7474748c;
            margin-left: 4px;
        }


        .spinner {
            display: inline-block;
            width: 8px;
            height: 8px;
            border: 2px solid #8d8d8d5c;
            border-top: 2px solid #6c6c6ce5;
            border-radius: 50%;
            animation: spin 0.6s linear infinite;
            vertical-align: middle;
            margin-bottom: 2px;
            margin-right: 5px;
        }

        @keyframes spin {
            to {
                transform: rotate(360deg);
            }
        }

        .silence {
            color: #666;
            background-color: #f3f3f3;
            font-size: 13px;
            border-radius: 30px;
            padding: 2px 10px;
        }

        .loading {
            color: #666;
            background-color: #ff4d4d0f;
            border-radius: 8px 8px 8px 0px;
            padding: 2px 10px;
            font-size: 14px;
            margin-bottom: 0px;
        }
    </style>
</head>

<body>

    <div class="settings-container">
        <button id="recordButton">
            <div class="shape-container">
                <div class="shape"></div>
            </div>
            <div class="recording-info">
                <div class="wave-container">
                    <canvas id="waveCanvas"></canvas>
                </div>
                <div class="timer">00:00</div>
            </div>
        </button>
        <div class="settings">
            <div>
                <label for="chunkSelector">Chunk size (ms):</label>
                <select id="chunkSelector">
                    <option value="500">500 ms</option>
                    <option value="1000" selected>1000 ms</option>
                    <option value="2000">2000 ms</option>
                    <option value="3000">3000 ms</option>
                    <option value="4000">4000 ms</option>
                    <option value="5000">5000 ms</option>
                </select>
            </div>
            <div>
                <label for="websocketInput">WebSocket URL:</label>
                <input id="websocketInput" type="text" value="ws://localhost:8000/asr" />
            </div>
        </div>
    </div>

    <p id="status"></p>

    <!-- Speaker-labeled transcript -->
    <div id="linesTranscript"></div>

    <script>
        let isRecording = false;
        let websocket = null;
        let recorder = null;
        let chunkDuration = 1000;
        let websocketUrl = "ws://localhost:8000/asr";
        let userClosing = false;
        let startTime = null;
        let timerInterval = null;
        let audioContext = null;
        let analyser = null;
        let microphone = null;
        let waveCanvas = document.getElementById("waveCanvas");
        let waveCtx = waveCanvas.getContext("2d");
        let animationFrame = null;
        waveCanvas.width = 60 * (window.devicePixelRatio || 1);
        waveCanvas.height = 30 * (window.devicePixelRatio || 1);
        waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);

        const statusText = document.getElementById("status");
        const recordButton = document.getElementById("recordButton");
        const chunkSelector = document.getElementById("chunkSelector");
        const websocketInput = document.getElementById("websocketInput");
        const linesTranscriptDiv = document.getElementById("linesTranscript");
        const timerElement = document.querySelector(".timer");

        chunkSelector.addEventListener("change", () => {
            chunkDuration = parseInt(chunkSelector.value);
        });

        websocketInput.addEventListener("change", () => {
            const urlValue = websocketInput.value.trim();
            if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
                statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
                return;
            }
            websocketUrl = urlValue;
            statusText.textContent = "WebSocket URL updated. Ready to connect.";
        });

        function setupWebSocket() {
            return new Promise((resolve, reject) => {
                try {
                    websocket = new WebSocket(websocketUrl);
                } catch (error) {
                    statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
                    reject(error);
                    return;
                }

                websocket.onopen = () => {
                    statusText.textContent = "Connected to server.";
                    resolve();
                };

                websocket.onclose = () => {
                    if (userClosing) {
                        statusText.textContent = "WebSocket closed by user.";
                    } else {
                        statusText.textContent =
                            "Disconnected from the WebSocket server. (Check logs if model is loading.)";
                    }
                    userClosing = false;
                };

                websocket.onerror = () => {
                    statusText.textContent = "Error connecting to WebSocket.";
                    reject(new Error("Error connecting to WebSocket"));
                };

                // Handle messages from server
                websocket.onmessage = (event) => {
                    const data = JSON.parse(event.data);

                    const {
                        lines = [],
                        buffer_transcription = "",
                        buffer_diarization = "",
                        remaining_time_transcription = 0,
                        remaining_time_diarization = 0
                    } = data;

                    renderLinesWithBuffer(
                        lines,
                        buffer_diarization,
                        buffer_transcription,
                        remaining_time_diarization,
                        remaining_time_transcription
                    );
                };
            });
        }

        function renderLinesWithBuffer(lines, buffer_diarization, buffer_transcription, remaining_time_diarization, remaining_time_transcription) {
            const linesHtml = lines.map((item, idx) => {
                let timeInfo = "";
                if (item.beg !== undefined && item.end !== undefined) {
                    timeInfo = ` ${item.beg} - ${item.end}`;
                }

                let speakerLabel = "";
                if (item.speaker === -2) {
                    speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
                } else if (item.speaker == 0) {
                    speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'>${remaining_time_diarization} second(s) of audio are undergoing diarization</span></span>`;
                } else if (item.speaker == -1) {
                    speakerLabel = `<span id="speaker"><span id='timeInfo'>${timeInfo}</span></span>`;
                } else if (item.speaker !== -1) {
                    speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
                }

                let textContent = item.text;
                if (idx === lines.length - 1) {
                    speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'>${remaining_time_transcription}s</span></span>`
                }
                if (idx === lines.length - 1 && buffer_diarization) {
                    speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'>${remaining_time_diarization}s</span></span>`
                    textContent += `<span class="buffer_diarization">${buffer_diarization}</span>`;
                }
                if (idx === lines.length - 1) {
                    textContent += `<span class="buffer_transcription">${buffer_transcription}</span>`;
                }

                return textContent
                    ? `<p>${speakerLabel}<br/><div class='textcontent'>${textContent}</div></p>`
                    : `<p>${speakerLabel}<br/></p>`;
            }).join("");

            linesTranscriptDiv.innerHTML = linesHtml;
        }

        function updateTimer() {
            if (!startTime) return;

            const elapsed = Math.floor((Date.now() - startTime) / 1000);
            const minutes = Math.floor(elapsed / 60).toString().padStart(2, "0");
            const seconds = (elapsed % 60).toString().padStart(2, "0");
            timerElement.textContent = `${minutes}:${seconds}`;
        }

        function drawWaveform() {
            if (!analyser) return;

            const bufferLength = analyser.frequencyBinCount;
            const dataArray = new Uint8Array(bufferLength);
            analyser.getByteTimeDomainData(dataArray);

            waveCtx.clearRect(0, 0, waveCanvas.width / (window.devicePixelRatio || 1), waveCanvas.height / (window.devicePixelRatio || 1));
            waveCtx.lineWidth = 1;
            waveCtx.strokeStyle = 'rgb(0, 0, 0)';
            waveCtx.beginPath();

            const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength;
            let x = 0;

            for (let i = 0; i < bufferLength; i++) {
                const v = dataArray[i] / 128.0;
                const y = v * (waveCanvas.height / (window.devicePixelRatio || 1)) / 2;

                if (i === 0) {
                    waveCtx.moveTo(x, y);
                } else {
                    waveCtx.lineTo(x, y);
                }

                x += sliceWidth;
            }

            waveCtx.lineTo(waveCanvas.width / (window.devicePixelRatio || 1), waveCanvas.height / (window.devicePixelRatio || 1) / 2);
            waveCtx.stroke();

            animationFrame = requestAnimationFrame(drawWaveform);
        }

        async function startRecording() {
            try {
                const stream = await navigator.mediaDevices.getUserMedia({ audio: true });

                audioContext = new (window.AudioContext || window.webkitAudioContext)();
                analyser = audioContext.createAnalyser();
                analyser.fftSize = 256;
                microphone = audioContext.createMediaStreamSource(stream);
                microphone.connect(analyser);

                recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
                recorder.ondataavailable = (e) => {
                    if (websocket && websocket.readyState === WebSocket.OPEN) {
                        websocket.send(e.data);
                    }
                };
                recorder.start(chunkDuration);

                startTime = Date.now();
                timerInterval = setInterval(updateTimer, 1000);
                drawWaveform();

                isRecording = true;
                updateUI();
            } catch (err) {
                statusText.textContent = "Error accessing microphone. Please allow microphone access.";
                console.error(err);
            }
        }

        function stopRecording() {
            userClosing = true;
            if (recorder) {
                recorder.stop();
                recorder = null;
            }

            if (microphone) {
                microphone.disconnect();
                microphone = null;
            }

            if (analyser) {
                analyser = null;
            }

            if (audioContext && audioContext.state !== 'closed') {
                try {
                    audioContext.close();
                } catch (e) {
                    console.warn("Could not close audio context:", e);
                }
                audioContext = null;
            }

            if (animationFrame) {
                cancelAnimationFrame(animationFrame);
                animationFrame = null;
            }

            if (timerInterval) {
                clearInterval(timerInterval);
                timerInterval = null;
            }
            timerElement.textContent = "00:00";
            startTime = null;

            isRecording = false;

            if (websocket) {
                websocket.close();
                websocket = null;
            }

            updateUI();
        }

        async function toggleRecording() {
            if (!isRecording) {
                linesTranscriptDiv.innerHTML = "";
                try {
                    await setupWebSocket();
                    await startRecording();
                } catch (err) {
                    statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
                    console.error(err);
                }
            } else {
                stopRecording();
            }
        }

        function updateUI() {
            recordButton.classList.toggle("recording", isRecording);
            statusText.textContent = isRecording ? "Recording..." : "Click to start transcription";
        }

        recordButton.addEventListener("click", toggleRecording);
    </script>
</body>

</html>
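The page above streams MediaRecorder chunks (audio/webm) over the WebSocket and renders the JSON messages it gets back (lines, buffer_transcription, buffer_diarization and the remaining-time fields). The same protocol can be exercised without a browser. Below is a minimal, non-authoritative sketch of a Python client, assuming the third-party websockets package (not a project dependency), a local server on ws://localhost:8000/asr, and a hypothetical pre-recorded sample.webm file sent in roughly one-second chunks:

# Hypothetical command-line client; only illustrates the WebSocket protocol used by the page.
import asyncio
import json
import websockets  # pip install websockets (assumption, not part of this repo)

async def stream_file(path: str, url: str = "ws://localhost:8000/asr", chunk_ms: int = 1000):
    async with websockets.connect(url) as ws:
        async def sender():
            # a WebM/Opus recording, as MediaRecorder would produce
            with open(path, "rb") as f:
                while chunk := f.read(16000):  # rough stand-in for one chunk_ms of audio
                    await ws.send(chunk)
                    await asyncio.sleep(chunk_ms / 1000)
        async def receiver():
            async for message in ws:
                data = json.loads(message)
                for line in data.get("lines", []):
                    print(line.get("speaker"), line.get("text"))
        await asyncio.gather(sender(), receiver())

# asyncio.run(stream_file("sample.webm"))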
292 whisperlivekit/whisper_streaming_custom/backends.py Normal file
@@ -0,0 +1,292 @@
import sys
import logging
import io
import soundfile as sf
import math
import torch
from typing import List
import numpy as np
from whisperlivekit.timed_objects import ASRToken

logger = logging.getLogger(__name__)


class ASRBase:
    sep = " "  # join transcribe words with this character (" " for whisper_timestamped,
               # "" for faster-whisper because it emits the spaces when needed)

    def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
        self.logfile = logfile
        self.transcribe_kargs = {}
        if lan == "auto":
            self.original_language = None
        else:
            self.original_language = lan
        self.model = self.load_model(modelsize, cache_dir, model_dir)

    def with_offset(self, offset: float) -> ASRToken:
        # This method is kept for compatibility (typically you will use ASRToken.with_offset)
        return ASRToken(self.start + offset, self.end + offset, self.text)

    def __repr__(self):
        return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"

    def load_model(self, modelsize, cache_dir, model_dir):
        raise NotImplementedError("must be implemented in the child class")

    def transcribe(self, audio, init_prompt=""):
        raise NotImplementedError("must be implemented in the child class")

    def use_vad(self):
        raise NotImplementedError("must be implemented in the child class")


class WhisperTimestampedASR(ASRBase):
    """Uses whisper_timestamped as the backend."""
    sep = " "

    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
        import whisper
        import whisper_timestamped
        from whisper_timestamped import transcribe_timestamped

        self.transcribe_timestamped = transcribe_timestamped
        if model_dir is not None:
            logger.debug("ignoring model_dir, not implemented")
        return whisper.load_model(modelsize, download_root=cache_dir)

    def transcribe(self, audio, init_prompt=""):
        result = self.transcribe_timestamped(
            self.model,
            audio,
            language=self.original_language,
            initial_prompt=init_prompt,
            verbose=None,
            condition_on_previous_text=True,
            **self.transcribe_kargs,
        )
        return result

    def ts_words(self, r) -> List[ASRToken]:
        """
        Converts the whisper_timestamped result to a list of ASRToken objects.
        """
        tokens = []
        for segment in r["segments"]:
            for word in segment["words"]:
                token = ASRToken(word["start"], word["end"], word["text"])
                tokens.append(token)
        return tokens

    def segments_end_ts(self, res) -> List[float]:
        return [segment["end"] for segment in res["segments"]]

    def use_vad(self):
        self.transcribe_kargs["vad"] = True

    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"

class FasterWhisperASR(ASRBase):
    """Uses faster-whisper as the backend."""
    sep = ""

    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
        from faster_whisper import WhisperModel

        if model_dir is not None:
            logger.debug(f"Loading whisper model from model_dir {model_dir}. "
                         f"modelsize and cache_dir parameters are not used.")
            model_size_or_path = model_dir
        elif modelsize is not None:
            model_size_or_path = modelsize
        else:
            raise ValueError("Either modelsize or model_dir must be set")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        compute_type = "float16" if device == "cuda" else "float32"

        model = WhisperModel(
            model_size_or_path,
            device=device,
            compute_type=compute_type,
            download_root=cache_dir,
        )
        return model

    def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
        segments, info = self.model.transcribe(
            audio,
            language=self.original_language,
            initial_prompt=init_prompt,
            beam_size=5,
            word_timestamps=True,
            condition_on_previous_text=True,
            **self.transcribe_kargs,
        )
        return list(segments)

    def ts_words(self, segments) -> List[ASRToken]:
        tokens = []
        for segment in segments:
            if segment.no_speech_prob > 0.9:
                continue
            for word in segment.words:
                token = ASRToken(word.start, word.end, word.word, probability=word.probability)
                tokens.append(token)
        return tokens

    def segments_end_ts(self, segments) -> List[float]:
        return [segment.end for segment in segments]

    def use_vad(self):
        self.transcribe_kargs["vad_filter"] = True

    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"

class MLXWhisper(ASRBase):
    """
    Uses MLX Whisper optimized for Apple Silicon.
    """
    sep = ""

    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
        from mlx_whisper.transcribe import ModelHolder, transcribe
        import mlx.core as mx

        if model_dir is not None:
            logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
            model_size_or_path = model_dir
        elif modelsize is not None:
            model_size_or_path = self.translate_model_name(modelsize)
            logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")
        else:
            raise ValueError("Either modelsize or model_dir must be set")

        self.model_size_or_path = model_size_or_path
        dtype = mx.float16
        ModelHolder.get_model(model_size_or_path, dtype)
        return transcribe

    def translate_model_name(self, model_name):
        model_mapping = {
            "tiny.en": "mlx-community/whisper-tiny.en-mlx",
            "tiny": "mlx-community/whisper-tiny-mlx",
            "base.en": "mlx-community/whisper-base.en-mlx",
            "base": "mlx-community/whisper-base-mlx",
            "small.en": "mlx-community/whisper-small.en-mlx",
            "small": "mlx-community/whisper-small-mlx",
            "medium.en": "mlx-community/whisper-medium.en-mlx",
            "medium": "mlx-community/whisper-medium-mlx",
            "large-v1": "mlx-community/whisper-large-v1-mlx",
            "large-v2": "mlx-community/whisper-large-v2-mlx",
            "large-v3": "mlx-community/whisper-large-v3-mlx",
            "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
            "large": "mlx-community/whisper-large-mlx",
        }
        mlx_model_path = model_mapping.get(model_name)
        if mlx_model_path:
            return mlx_model_path
        else:
            raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")

    def transcribe(self, audio, init_prompt=""):
        if self.transcribe_kargs:
            logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
        segments = self.model(
            audio,
            language=self.original_language,
            initial_prompt=init_prompt,
            word_timestamps=True,
            condition_on_previous_text=True,
            path_or_hf_repo=self.model_size_or_path,
        )
        return segments.get("segments", [])

    def ts_words(self, segments) -> List[ASRToken]:
        tokens = []
        for segment in segments:
            if segment.get("no_speech_prob", 0) > 0.9:
                continue
            for word in segment.get("words", []):
                token = ASRToken(word["start"], word["end"], word["word"], probability=word["probability"])
                tokens.append(token)
        return tokens

    def segments_end_ts(self, res) -> List[float]:
        return [s["end"] for s in res]

    def use_vad(self):
        self.transcribe_kargs["vad_filter"] = True

    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"

class OpenaiApiASR(ASRBase):
    """Uses OpenAI's Whisper API for transcription."""
    def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
        self.logfile = logfile
        self.modelname = "whisper-1"
        self.original_language = None if lan == "auto" else lan
        self.response_format = "verbose_json"
        self.temperature = temperature
        self.load_model()
        self.use_vad_opt = False
        self.task = "transcribe"

    def load_model(self, *args, **kwargs):
        from openai import OpenAI
        self.client = OpenAI()
        self.transcribed_seconds = 0

    def ts_words(self, segments) -> List[ASRToken]:
        """
        Converts OpenAI API response words into ASRToken objects while
        optionally skipping words that fall into no-speech segments.
        """
        no_speech_segments = []
        if self.use_vad_opt:
            for segment in segments.segments:
                if segment["no_speech_prob"] > 0.8:
                    no_speech_segments.append((segment.get("start"), segment.get("end")))
        tokens = []
        for word in segments.words:
            start = word.start
            end = word.end
            if any(s[0] <= start <= s[1] for s in no_speech_segments):
                continue
            tokens.append(ASRToken(start, end, word.word))
        return tokens

    def segments_end_ts(self, res) -> List[float]:
        return [s.end for s in res.words]

    def transcribe(self, audio_data, prompt=None, *args, **kwargs):
        buffer = io.BytesIO()
        buffer.name = "temp.wav"
        sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
        buffer.seek(0)
        self.transcribed_seconds += math.ceil(len(audio_data) / 16000)
        params = {
            "model": self.modelname,
            "file": buffer,
            "response_format": self.response_format,
            "temperature": self.temperature,
            "timestamp_granularities": ["word", "segment"],
        }
        if self.task != "translate" and self.original_language:
            params["language"] = self.original_language
        if prompt:
            params["prompt"] = prompt
        proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions
        transcript = proc.create(**params)
        logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
        return transcript

    def use_vad(self):
        self.use_vad_opt = True

    def set_translate_task(self):
        self.task = "translate"
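Each backend exposes the same small surface (load_model, transcribe, ts_words, segments_end_ts, use_vad, set_translate_task), so the streaming layer stays backend-agnostic. A minimal sketch of driving FasterWhisperASR directly, assuming faster-whisper is installed and that sample.wav is a hypothetical 16 kHz mono file:

import librosa
from whisperlivekit.whisper_streaming_custom.backends import FasterWhisperASR

asr = FasterWhisperASR(lan="en", modelsize="tiny")   # loads (or downloads) the tiny model
audio, _ = librosa.load("sample.wav", sr=16000)      # hypothetical test file
segments = asr.transcribe(audio, init_prompt="")
for token in asr.ts_words(segments):                 # ASRToken objects with word timings
    print(f"{token.start:.2f}-{token.end:.2f} {token.text}")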
453 whisperlivekit/whisper_streaming_custom/online_asr.py Normal file
@@ -0,0 +1,453 @@
import sys
import numpy as np
import logging
from typing import List, Tuple, Optional
from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript

logger = logging.getLogger(__name__)


class HypothesisBuffer:
    """
    Buffer to store and process ASR hypothesis tokens.

    It holds:
    - committed_in_buffer: tokens that have been confirmed (committed)
    - buffer: the last hypothesis that is not yet committed
    - new: new tokens coming from the recognizer
    """
    def __init__(self, logfile=sys.stderr, confidence_validation=False):
        self.confidence_validation = confidence_validation
        self.committed_in_buffer: List[ASRToken] = []
        self.buffer: List[ASRToken] = []
        self.new: List[ASRToken] = []
        self.last_committed_time = 0.0
        self.last_committed_word: Optional[str] = None
        self.logfile = logfile

    def insert(self, new_tokens: List[ASRToken], offset: float):
        """
        Insert new tokens (after applying a time offset) and compare them with the
        already committed tokens. Only tokens that extend the committed hypothesis
        are added.
        """
        # Apply the offset to each token.
        new_tokens = [token.with_offset(offset) for token in new_tokens]
        # Only keep tokens that are roughly "new"
        self.new = [token for token in new_tokens if token.start > self.last_committed_time - 0.1]

        if self.new:
            first_token = self.new[0]
            if abs(first_token.start - self.last_committed_time) < 1:
                if self.committed_in_buffer:
                    committed_len = len(self.committed_in_buffer)
                    new_len = len(self.new)
                    # Try to match 1 to 5 consecutive tokens
                    max_ngram = min(min(committed_len, new_len), 5)
                    for i in range(1, max_ngram + 1):
                        committed_ngram = " ".join(token.text for token in self.committed_in_buffer[-i:])
                        new_ngram = " ".join(token.text for token in self.new[:i])
                        if committed_ngram == new_ngram:
                            removed = []
                            for _ in range(i):
                                removed_token = self.new.pop(0)
                                removed.append(repr(removed_token))
                            logger.debug(f"Removing last {i} words: {' '.join(removed)}")
                            break

    def flush(self) -> List[ASRToken]:
        """
        Returns the committed chunk, defined as the longest common prefix
        between the previous hypothesis and the new tokens.
        """
        committed: List[ASRToken] = []
        while self.new:
            current_new = self.new[0]
            if self.confidence_validation and current_new.probability and current_new.probability > 0.95:
                committed.append(current_new)
                self.last_committed_word = current_new.text
                self.last_committed_time = current_new.end
                self.new.pop(0)
                self.buffer.pop(0) if self.buffer else None
            elif not self.buffer:
                break
            elif current_new.text == self.buffer[0].text:
                committed.append(current_new)
                self.last_committed_word = current_new.text
                self.last_committed_time = current_new.end
                self.buffer.pop(0)
                self.new.pop(0)
            else:
                break
        self.buffer = self.new
        self.new = []
        self.committed_in_buffer.extend(committed)
        return committed

    def pop_committed(self, time: float):
        """
        Remove tokens (from the beginning) that have ended before `time`.
        """
        while self.committed_in_buffer and self.committed_in_buffer[0].end <= time:
            self.committed_in_buffer.pop(0)


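To make the commit rule concrete, here is a small standalone sketch (token texts and timings invented for illustration) showing that flush only commits the prefix on which two consecutive hypotheses agree:

from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper_streaming_custom.online_asr import HypothesisBuffer

buf = HypothesisBuffer()
# first hypothesis: nothing is committed yet, it only becomes the pending buffer
buf.insert([ASRToken(0.0, 0.5, "hello"), ASRToken(0.5, 1.0, "world")], offset=0.0)
print([t.text for t in buf.flush()])   # []
# second hypothesis repeats "hello" but revises the following word
buf.insert([ASRToken(0.0, 0.5, "hello"), ASRToken(0.5, 1.1, "word")], offset=0.0)
print([t.text for t in buf.flush()])   # ['hello'] is the agreed prefix, so it is committed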
class OnlineASRProcessor:
    """
    Processes incoming audio in a streaming fashion, calling the ASR system
    periodically, and uses a hypothesis buffer to commit and trim recognized text.

    The processor supports two types of buffer trimming:
    - "sentence": trims at sentence boundaries (using a sentence tokenizer)
    - "segment": trims at fixed segment durations.
    """
    SAMPLING_RATE = 16000

    def __init__(
        self,
        asr,
        tokenize_method: Optional[callable] = None,
        buffer_trimming: Tuple[str, float] = ("segment", 15),
        confidence_validation = False,
        logfile=sys.stderr,
    ):
        """
        asr: An ASR system object (for example, a WhisperASR instance) that
             provides a `transcribe` method, a `ts_words` method (to extract tokens),
             a `segments_end_ts` method, and a separator attribute `sep`.
        tokenize_method: A function that receives text and returns a list of sentence strings.
        buffer_trimming: A tuple (option, seconds), where option is either "sentence" or "segment".
        """
        self.asr = asr
        self.tokenize = tokenize_method
        self.logfile = logfile
        self.confidence_validation = confidence_validation
        self.init()

        self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming

        if self.buffer_trimming_way not in ["sentence", "segment"]:
            raise ValueError("buffer_trimming must be either 'sentence' or 'segment'")
        if self.buffer_trimming_sec <= 0:
            raise ValueError("buffer_trimming_sec must be positive")
        elif self.buffer_trimming_sec > 30:
            logger.warning(
                f"buffer_trimming_sec is set to {self.buffer_trimming_sec}, which is very long. It may cause OOM."
            )

    def init(self, offset: Optional[float] = None):
        """Initialize or reset the processing buffers."""
        self.audio_buffer = np.array([], dtype=np.float32)
        self.transcript_buffer = HypothesisBuffer(logfile=self.logfile, confidence_validation=self.confidence_validation)
        self.buffer_time_offset = offset if offset is not None else 0.0
        self.transcript_buffer.last_committed_time = self.buffer_time_offset
        self.committed: List[ASRToken] = []

    def insert_audio_chunk(self, audio: np.ndarray):
        """Append an audio chunk (a numpy array) to the current audio buffer."""
        self.audio_buffer = np.append(self.audio_buffer, audio)

    def prompt(self) -> Tuple[str, str]:
        """
        Returns a tuple: (prompt, context), where:
        - prompt is a 200-character suffix of committed text that falls
          outside the current audio buffer.
        - context is the committed text within the current audio buffer.
        """
        k = len(self.committed)
        while k > 0 and self.committed[k - 1].end > self.buffer_time_offset:
            k -= 1

        prompt_tokens = self.committed[:k]
        prompt_words = [token.text for token in prompt_tokens]
        prompt_list = []
        length_count = 0
        # Use the last words until reaching 200 characters.
        while prompt_words and length_count < 200:
            word = prompt_words.pop(-1)
            length_count += len(word) + 1
            prompt_list.append(word)
        non_prompt_tokens = self.committed[k:]
        context_text = self.asr.sep.join(token.text for token in non_prompt_tokens)
        return self.asr.sep.join(prompt_list[::-1]), context_text

    def get_buffer(self):
        """
        Get the unvalidated buffer in string format.
        """
        return self.concatenate_tokens(self.transcript_buffer.buffer)


    def process_iter(self) -> Transcript:
        """
        Processes the current audio buffer.

        Returns a Transcript object representing the committed transcript.
        """
        prompt_text, _ = self.prompt()
        logger.debug(
            f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds from {self.buffer_time_offset:.2f}"
        )
        res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt_text)
        tokens = self.asr.ts_words(res)  # Expecting List[ASRToken]
        self.transcript_buffer.insert(tokens, self.buffer_time_offset)
        committed_tokens = self.transcript_buffer.flush()
        self.committed.extend(committed_tokens)
        completed = self.concatenate_tokens(committed_tokens)
        logger.debug(f">>>> COMPLETE NOW: {completed.text}")
        incomp = self.concatenate_tokens(self.transcript_buffer.buffer)
        logger.debug(f"INCOMPLETE: {incomp.text}")

        if committed_tokens and self.buffer_trimming_way == "sentence":
            if len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec:
                self.chunk_completed_sentence()

        s = self.buffer_trimming_sec if self.buffer_trimming_way == "segment" else 30
        if len(self.audio_buffer) / self.SAMPLING_RATE > s:
            self.chunk_completed_segment(res)
            logger.debug("Chunking segment")
        logger.debug(
            f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
        )
        return committed_tokens

    def chunk_completed_sentence(self):
        """
        If the committed tokens form at least two sentences, chunk the audio
        buffer at the end time of the penultimate sentence.
        """
        if not self.committed:
            return
        logger.debug("COMPLETED SENTENCE: " + " ".join(token.text for token in self.committed))
        sentences = self.words_to_sentences(self.committed)
        for sentence in sentences:
            logger.debug(f"\tSentence: {sentence.text}")
        if len(sentences) < 2:
            return
        # Keep the last two sentences.
        while len(sentences) > 2:
            sentences.pop(0)
        chunk_time = sentences[-2].end
        logger.debug(f"--- Sentence chunked at {chunk_time:.2f}")
        self.chunk_at(chunk_time)

    def chunk_completed_segment(self, res):
        """
        Chunk the audio buffer based on segment-end timestamps reported by the ASR.
        """
        if not self.committed:
            return
        ends = self.asr.segments_end_ts(res)
        last_committed_time = self.committed[-1].end
        if len(ends) > 1:
            e = ends[-2] + self.buffer_time_offset
            while len(ends) > 2 and e > last_committed_time:
                ends.pop(-1)
                e = ends[-2] + self.buffer_time_offset
            if e <= last_committed_time:
                logger.debug(f"--- Segment chunked at {e:.2f}")
                self.chunk_at(e)
            else:
                logger.debug("--- Last segment not within committed area")
        else:
            logger.debug("--- Not enough segments to chunk")

    def chunk_at(self, time: float):
        """
        Trim both the hypothesis and audio buffer at the given time.
        """
        logger.debug(f"Chunking at {time:.2f}s")
        logger.debug(
            f"Audio buffer length before chunking: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f}s"
        )
        self.transcript_buffer.pop_committed(time)
        cut_seconds = time - self.buffer_time_offset
        self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE):]
        self.buffer_time_offset = time
        logger.debug(
            f"Audio buffer length after chunking: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f}s"
        )

    def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
        """
        Converts a list of tokens to a list of Sentence objects using the provided
        sentence tokenizer.
        """
        if not tokens:
            return []

        full_text = " ".join(token.text for token in tokens)

        if self.tokenize:
            try:
                sentence_texts = self.tokenize(full_text)
            except Exception as e:
                # Some tokenizers (e.g., MosesSentenceSplitter) expect a list input.
                try:
                    sentence_texts = self.tokenize([full_text])
                except Exception as e2:
                    raise ValueError("Tokenization failed") from e2
        else:
            sentence_texts = [full_text]

        sentences: List[Sentence] = []
        token_index = 0
        for sent_text in sentence_texts:
            sent_text = sent_text.strip()
            if not sent_text:
                continue
            sent_tokens = []
            accumulated = ""
            # Accumulate tokens until roughly matching the length of the sentence text.
            while token_index < len(tokens) and len(accumulated) < len(sent_text):
                token = tokens[token_index]
                accumulated = (accumulated + " " + token.text).strip() if accumulated else token.text
                sent_tokens.append(token)
                token_index += 1
            if sent_tokens:
                sentence = Sentence(
                    start=sent_tokens[0].start,
                    end=sent_tokens[-1].end,
                    text=" ".join(t.text for t in sent_tokens),
                )
                sentences.append(sentence)
        return sentences

    def finish(self) -> Transcript:
        """
        Flush the remaining transcript when processing ends.
        """
        remaining_tokens = self.transcript_buffer.buffer
        final_transcript = self.concatenate_tokens(remaining_tokens)
        logger.debug(f"Final non-committed transcript: {final_transcript}")
        self.buffer_time_offset += len(self.audio_buffer) / self.SAMPLING_RATE
        return final_transcript

    def concatenate_tokens(
        self,
        tokens: List[ASRToken],
        sep: Optional[str] = None,
        offset: float = 0
    ) -> Transcript:
        sep = sep if sep is not None else self.asr.sep
        text = sep.join(token.text for token in tokens)
        probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
        if tokens:
            start = offset + tokens[0].start
            end = offset + tokens[-1].end
        else:
            start = None
            end = None
        return Transcript(start, end, text, probability=probability)


class VACOnlineASRProcessor:
    """
    Wraps an OnlineASRProcessor with a Voice Activity Controller (VAC).

    It receives small chunks of audio, applies VAD (e.g. with Silero),
    and when the system detects a pause in speech (or end of an utterance)
    it finalizes the utterance immediately.
    """
    SAMPLING_RATE = 16000

    def __init__(self, online_chunk_size: float, *args, **kwargs):
        self.online_chunk_size = online_chunk_size
        self.online = OnlineASRProcessor(*args, **kwargs)

        # Load a VAD model (e.g. Silero VAD)
        import torch
        model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
        from silero_vad_iterator import FixedVADIterator

        self.vac = FixedVADIterator(model)
        self.logfile = self.online.logfile
        self.init()

    def init(self):
        self.online.init()
        self.vac.reset_states()
        self.current_online_chunk_buffer_size = 0
        self.is_currently_final = False
        self.status: Optional[str] = None  # "voice" or "nonvoice"
        self.audio_buffer = np.array([], dtype=np.float32)
        self.buffer_offset = 0  # in frames

    def clear_buffer(self):
        self.buffer_offset += len(self.audio_buffer)
        self.audio_buffer = np.array([], dtype=np.float32)

    def insert_audio_chunk(self, audio: np.ndarray):
        """
        Process an incoming small audio chunk:
        - run VAD on the chunk,
        - decide whether to send the audio to the online ASR processor immediately,
        - and/or to mark the current utterance as finished.
        """
        res = self.vac(audio)
        self.audio_buffer = np.append(self.audio_buffer, audio)

        if res is not None:
            # VAD returned a result; adjust the frame number
            frame = list(res.values())[0] - self.buffer_offset
            if "start" in res and "end" not in res:
                self.status = "voice"
                send_audio = self.audio_buffer[frame:]
                self.online.init(offset=(frame + self.buffer_offset) / self.SAMPLING_RATE)
                self.online.insert_audio_chunk(send_audio)
                self.current_online_chunk_buffer_size += len(send_audio)
                self.clear_buffer()
            elif "end" in res and "start" not in res:
                self.status = "nonvoice"
                send_audio = self.audio_buffer[:frame]
                self.online.insert_audio_chunk(send_audio)
                self.current_online_chunk_buffer_size += len(send_audio)
                self.is_currently_final = True
                self.clear_buffer()
            else:
                beg = res["start"] - self.buffer_offset
                end = res["end"] - self.buffer_offset
                self.status = "nonvoice"
                send_audio = self.audio_buffer[beg:end]
                self.online.init(offset=(beg + self.buffer_offset) / self.SAMPLING_RATE)
                self.online.insert_audio_chunk(send_audio)
                self.current_online_chunk_buffer_size += len(send_audio)
                self.is_currently_final = True
                self.clear_buffer()
        else:
            if self.status == "voice":
                self.online.insert_audio_chunk(self.audio_buffer)
                self.current_online_chunk_buffer_size += len(self.audio_buffer)
                self.clear_buffer()
            else:
                # Keep 1 second worth of audio in case VAD later detects voice,
                # but trim to avoid unbounded memory usage.
                self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE)
                self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]

    def process_iter(self) -> Transcript:
        """
        Depending on the VAD status and the amount of accumulated audio,
        process the current audio chunk.
        """
        if self.is_currently_final:
            return self.finish()
        elif self.current_online_chunk_buffer_size > self.SAMPLING_RATE * self.online_chunk_size:
            self.current_online_chunk_buffer_size = 0
            return self.online.process_iter()
        else:
            logger.debug("No online update, only VAD")
            return Transcript(None, None, "")

    def finish(self) -> Transcript:
        """Finish processing by flushing any remaining text."""
        result = self.online.finish()
        self.current_online_chunk_buffer_size = 0
        self.is_currently_final = False
        return result

    def get_buffer(self):
        """
        Get the unvalidated buffer in string format.
        """
        return self.online.concatenate_tokens(self.online.transcript_buffer.buffer).text
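Putting the pieces together, the online processor is fed audio incrementally and polled for newly committed text; as implemented above, OnlineASRProcessor.process_iter returns the list of newly committed ASRToken objects. A minimal offline-driven sketch (file path, model size and chunk size are illustrative, and a FasterWhisperASR backend is assumed):

import librosa
from whisperlivekit.whisper_streaming_custom.backends import FasterWhisperASR
from whisperlivekit.whisper_streaming_custom.online_asr import OnlineASRProcessor

asr = FasterWhisperASR(lan="en", modelsize="tiny")
online = OnlineASRProcessor(asr, buffer_trimming=("segment", 15))

audio, _ = librosa.load("sample.wav", sr=16000)   # hypothetical 16 kHz test file
chunk = 16000  # feed one second at a time
for i in range(0, len(audio), chunk):
    online.insert_audio_chunk(audio[i:i + chunk])
    committed = online.process_iter()   # newly committed tokens for this iteration
    for token in committed:
        print(f"{token.start:.2f}-{token.end:.2f} {token.text}")
print("tail:", online.finish().text)    # flush whatever is still unconfirmed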
194 whisperlivekit/whisper_streaming_custom/whisper_online.py Normal file
@@ -0,0 +1,194 @@
#!/usr/bin/env python3
import sys
import numpy as np
import librosa
from functools import lru_cache
import time
import logging
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor

logger = logging.getLogger(__name__)



WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
    ","
)


def create_tokenizer(lan):
    """returns an object that has a split function that works like the one of MosesTokenizer"""

    assert (
        lan in WHISPER_LANG_CODES
    ), "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)

    if lan == "uk":
        import tokenize_uk

        class UkrainianTokenizer:
            def split(self, text):
                return tokenize_uk.tokenize_sents(text)

        return UkrainianTokenizer()

    # supported by fast-mosestokenizer
    if (
        lan
        in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
    ):
        from mosestokenizer import MosesSentenceSplitter

        return MosesSentenceSplitter(lan)

    # the following languages are in Whisper, but not in wtpsplit:
    if (
        lan
        in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
    ):
        logger.debug(
            f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
        )
        lan = None

    from wtpsplit import WtP

    # downloads the model from huggingface on the first use
    wtp = WtP("wtp-canine-s-12l-no-adapters")

    class WtPtok:
        def split(self, sent):
            return wtp.split(sent, lang_code=lan)

    return WtPtok()


def backend_factory(args):
    backend = args.backend
    if backend == "openai-api":
        logger.debug("Using OpenAI API.")
        asr = OpenaiApiASR(lan=args.lan)
    else:
        if backend == "faster-whisper":
            asr_cls = FasterWhisperASR
        elif backend == "mlx-whisper":
            asr_cls = MLXWhisper
        else:
            asr_cls = WhisperTimestampedASR

        # Only for FasterWhisperASR and WhisperTimestampedASR
        size = args.model
        t = time.time()
        logger.info(f"Loading Whisper {size} model for language {args.lan}...")
        asr = asr_cls(
            modelsize=size,
            lan=args.lan,
            cache_dir=args.model_cache_dir,
            model_dir=args.model_dir,
        )
        e = time.time()
        logger.info(f"done. It took {round(e-t,2)} seconds.")

    # Apply common configurations
    if getattr(args, "vad", False):  # Checks if VAD argument is present and True
        logger.info("Setting VAD filter")
        asr.use_vad()

    language = args.lan
    if args.task == "translate":
        asr.set_translate_task()
        tgt_language = "en"  # Whisper translates into English
    else:
        tgt_language = language  # Whisper transcribes in this language

    # Create the tokenizer
    if args.buffer_trimming == "sentence":

        tokenizer = create_tokenizer(tgt_language)
    else:
        tokenizer = None
    return asr, tokenizer

def online_factory(args, asr, tokenizer, logfile=sys.stderr):
    if args.vac:
        online = VACOnlineASRProcessor(
            args.min_chunk_size,
            asr,
            tokenizer,
            logfile=logfile,
            buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
            confidence_validation = args.confidence_validation
        )
    else:
        online = OnlineASRProcessor(
            asr,
            tokenizer,
            logfile=logfile,
            buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
            confidence_validation = args.confidence_validation
        )
    return online

def asr_factory(args, logfile=sys.stderr):
    """
    Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
    """
    asr, tokenizer = backend_factory(args)
    online = online_factory(args, asr, tokenizer, logfile=logfile)
    return asr, online

def warmup_asr(asr, warmup_file=None, timeout=5):
    """
    Warm up the ASR model by transcribing a short audio file.
    """
    import os
    import tempfile


    if warmup_file is None:
        # Download JFK sample if not already present
        jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
        temp_dir = tempfile.gettempdir()
        warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")

        if not os.path.exists(warmup_file):
            logger.debug(f"Downloading warmup file from {jfk_url}")
            print(f"Downloading warmup file from {jfk_url}")
            import time
            import urllib.request
            import urllib.error
            import socket

            original_timeout = socket.getdefaulttimeout()
            socket.setdefaulttimeout(timeout)

            start_time = time.time()
            try:
                urllib.request.urlretrieve(jfk_url, warmup_file)
                logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
            except (urllib.error.URLError, socket.timeout) as e:
                logger.warning(f"Download failed: {e}. Proceeding without warmup.")
                return False
            finally:
                socket.setdefaulttimeout(original_timeout)
    elif not warmup_file:
        return False

    if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
        logger.warning(f"Warmup file {warmup_file} invalid or missing.")
        return False

    print(f"Warming up Whisper with {warmup_file}")
    try:
        import librosa
        audio, sr = librosa.load(warmup_file, sr=16000)
    except Exception as e:
        logger.warning(f"Failed to load audio file: {e}")
        return False

    # Process the audio
    asr.transcribe(audio)

    logger.info("Whisper is warmed up")
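backend_factory and online_factory only read plain attributes off an args object, so they can be driven without the CLI. A minimal sketch using argparse.Namespace (the attribute values below are illustrative defaults, not the package's official ones):

from argparse import Namespace
from whisperlivekit.whisper_streaming_custom.whisper_online import asr_factory, warmup_asr

args = Namespace(
    backend="faster-whisper",
    model="tiny",
    lan="en",
    task="transcribe",
    model_cache_dir=None,
    model_dir=None,
    vad=False,
    vac=False,
    min_chunk_size=1.0,
    buffer_trimming="segment",
    buffer_trimming_sec=15,
    confidence_validation=False,
)

asr, online = asr_factory(args)
warmup_asr(asr)  # fetches a short JFK sample on first use, then transcribes it once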