2 Commits

Author SHA1 Message Date
Quentin Fuxa
aa44a92a67 add embedded web interface HTML (single-file version with inline CSS/JS/SVG)
### Added
- `get_inline_ui_html()`: generates a self-contained version of the web interface, with CSS, JS, and SVG assets inlined directly into the HTML. useful for environments where serving static files is inconvenient or when a single-call UI delivery is preferred.
2025-08-29 21:58:51 +02:00
Quentin Fuxa
01d791470b add test files 2025-08-29 17:45:32 +02:00
67 changed files with 1970 additions and 3772 deletions

19
.gitignore vendored
View File

@@ -54,6 +54,21 @@ coverage.xml
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
@@ -122,6 +137,4 @@ run_*.sh
test_*.py
launch.json
.DS_Store
test/*
nllb-200-distilled-600M-ctranslate2/*
*.mp3
test/*

View File

@@ -1,91 +0,0 @@
# 1. Simulstreaming: Decouple the encoder for faster inference
Simulstreaming encoder time (whisperlivekit/simul_whisper/simul_whisper.py l. 397) experimentations :
On macOS Apple Silicon M4 :
| Encoder | base.en | small |
|--------|---------|-------|
| WHISPER (no modification) | 0.35s | 1.09s |
| FASTER_WHISPER | 0.4s | 1.20s |
| MLX_WHISPER | 0.07s | 0.20s |
Memory saved by only loading encoder for optimized framework:
For tiny.en, mlx whisper:
Sizes MLX whisper:
Decoder weights: 59110771 bytes
Encoder weights: 15268874 bytes
# 2. Translation: Faster model for each system
## Benchmark Results
Testing on MacBook M3 with NLLB-200-distilled-600M model:
### Standard Transformers vs CTranslate2
| Test Text | Standard Inference Time | CTranslate2 Inference Time | Speedup |
|-----------|-------------------------|---------------------------|---------|
| UN Chief says there is no military solution in Syria | 0.9395s | 2.0472s | 0.5x |
| The rapid advancement of AI technology is transforming various industries | 0.7171s | 1.7516s | 0.4x |
| Climate change poses a significant threat to global ecosystems | 0.8533s | 1.8323s | 0.5x |
| International cooperation is essential for addressing global challenges | 0.7209s | 1.3575s | 0.5x |
| The development of renewable energy sources is crucial for a sustainable future | 0.8760s | 1.5589s | 0.6x |
**Results:**
- Total Standard time: 4.1068s
- Total CTranslate2 time: 8.5476s
- CTranslate2 is slower on this system --> Use Transformers, and ideally we would have an mlx implementation.
# 3. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm
Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions.
## Problem Statement
- Input: `self.total_preds` with shape `(x, x, 4)` - predictions for 4 speakers
- Output: Constrained predictions with shape `(x, x, 2)` - predictions for 2 speakers
#
### Initial Setup
For each time step `i`, we have a ranking of 4 speaker predictions (1-4). When only 2 speakers are present, the model will have close predictions for the 2 active speaker positions.
Instead of `np.argmax(preds_np, axis=1)`, we take the top 2 predictions and build a dynamic 4→2 mapping that can evolve over time.
### Algorithm
```python
top_2_speakers = np.argsort(preds_np, axis=1)[:, -2:]
```
- `DS_a_{i}`: Top detected speaker for prediction i
- `DS_b_{i}`: Second detected speaker for prediction i
- `AS_{i}`: Attributed speaker for prediction i
- `GTS_A`: Ground truth speaker A
- `GTS_B`: Ground truth speaker B
- `DIST(a, b)`: Distance between detected speakers a and b
3. **Attribution Logic**
```
AS_0 ← A
AS_1 ← B
IF DIST(DS_a_0, DS_a_1) < DIST(DS_a_0, DS_a_2) AND
DIST(DS_a_0, DS_a_1) < DIST(DS_a_1, DS_a_2):
# Likely that DS_a_0 = DS_a_1 (same speaker)
AS_1 ← A
AS_2 ← B
ELIF DIST(DS_a_0, DS_a_2) < DIST(DS_a_0, DS_a_1) AND
DIST(DS_a_0, DS_a_2) < DIST(DS_a_1, DS_a_2):
AS_2 ← A
ELSE:
AS_2 ← B
to finish
```

View File

@@ -17,26 +17,18 @@ RUN apt-get update && \
ffmpeg \
git \
build-essential \
python3-dev \
ca-certificates && \
python3-dev && \
rm -rf /var/lib/apt/lists/*
RUN python3 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# timeout/retries for large torch wheels
RUN pip3 install --upgrade pip setuptools wheel && \
pip3 --disable-pip-version-check install --timeout=120 --retries=5 \
--index-url https://download.pytorch.org/whl/cu129 \
torch torchaudio \
|| (echo "Initial install failed — retrying with extended timeout..." && \
pip3 --disable-pip-version-check install --timeout=300 --retries=3 \
--index-url https://download.pytorch.org/whl/cu129 \
torch torchvision torchaudio)
RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu129
COPY . .
# Install WhisperLiveKit directly, allowing for optional dependencies
# Note: For gates models, need to add your HF toke. See README.md
# for more details.
RUN if [ -n "$EXTRAS" ]; then \
echo "Installing with extras: [$EXTRAS]"; \
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
@@ -45,14 +37,16 @@ RUN if [ -n "$EXTRAS" ]; then \
pip install --no-cache-dir whisperlivekit; \
fi
# In-container caching for Hugging Face models by:
# Enable in-container caching for Hugging Face models by:
# Note: If running multiple containers, better to map a shared
# bucket.
#
# A) Make the cache directory persistent via an anonymous volume.
# Note: This only persists for a single, named container. This is
# only for convenience at de/test stage.
# For prod, it is better to use a named volume via host mount/k8s.
VOLUME ["/root/.cache/huggingface/hub"]
# or
# B) Conditionally copy a local pre-cache from the build context to the
# container's cache via the HF_PRECACHE_DIR build-arg.
@@ -67,7 +61,8 @@ RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
echo "No local Hugging Face cache specified, skipping copy"; \
fi
# Conditionally copy a Hugging Face token if provided. Useful for Diart backend (pyannote audio models)
# Conditionally copy a Hugging Face token if provided
RUN if [ -n "$HF_TKN_FILE" ]; then \
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
mkdir -p /root/.cache/huggingface && \
@@ -75,9 +70,11 @@ RUN if [ -n "$HF_TKN_FILE" ]; then \
else \
echo "No Hugging Face token file specified, skipping token setup"; \
fi
# Expose port for the transcription server
EXPOSE 8000
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
CMD ["--model", "medium"]
# Default args
CMD ["--model", "medium"]

226
LICENSE
View File

@@ -1,210 +1,52 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
# License
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
## Main Software License
1. Definitions.
MIT License
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
Copyright (c) 2025 Quentin Fuxa.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
## SimulStreaming Backend License
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
**When using the SimulStreaming backend (SimulWhisper), additional licensing terms apply:**
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
SimulStreaming (https://github.com/ufal/SimulStreaming) is dual-licensed:
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
### 🔹 Non-Commercial Use
You may use SimulStreaming under the **PolyForm Noncommercial License 1.0.0** if you obtain the code through the GitHub repository. This license is **free of charge** and comes with **no obligations** for non-commercial users.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
### 🔸 Commercial Use
Understanding who uses SimulStreaming commercially helps improve and prioritize development. Therefore, **registration is required** for those who acquire a commercial license.
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
Commercial licenses are planned to be **affordable** to SMEs and individuals. They are considering providing commercial licenses either for free or for a symbolic one-time fee, and may also provide additional support. You can share your preference via the [questionnaire](https://forms.cloud.microsoft.com/e/7tCxb4gJfB).
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
You can also leave your contact [there](https://forms.cloud.microsoft.com/e/7tCxb4gJfB) to be notified when commercial licenses become available.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
**Contact for SimulStreaming licensing:**
[Dominik Macháček](https://ufal.mff.cuni.cz/dominik-machacek/), machacek@ufal.mff.cuni.cz
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2025 Quentin Fuxa
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
---
## Based on:
- **SimulWhisper** by Speech and Audio Technology LAB of Tsinghua University Apache-2.0 https://github.com/ufal/SimulStreaming
- **SimulStreaming** by ÚFAL MIT License https://github.com/ufal/SimulStreaming
- **NeMo** by NVidia - Apache-2.0 - https://github.com/NVIDIA-NeMo/NeMo
- **whisper_streaming** by ÚFAL MIT License https://github.com/ufal/whisper_streaming.
- **silero-vad** by Snakers4 MIT License https://github.com/snakers4/silero-vad.
- **Diart** by juanmc2005 MIT License https://github.com/juanmc2005/diart.
- **whisper_streaming** by ÚFAL MIT License https://github.com/ufal/whisper_streaming. The original work by ÚFAL. License: https://github.com/ufal/whisper_streaming/blob/main/LICENSE
- **silero-vad** by Snakers4 MIT License https://github.com/snakers4/silero-vad. The work by Snakers4 (silero-vad). License: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
- **Diart** by juanmc2005 MIT License https://github.com/juanmc2005/diart. The work in Diart by juanmc2005. License: https://github.com/juanmc2005/diart/blob/main/LICENSE
- **SimulStreaming** by ÚFAL Dual License (PolyForm Noncommercial License 1.0.0 / Commercial License) https://github.com/ufal/SimulStreaming

110
README.md
View File

@@ -9,18 +9,17 @@
<p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-Apache 2.0-dark_green"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p>
Real-time transcription directly to your browser, with a ready-to-use backend+server and a simple frontend.
Real-time speech transcription directly to your browser, with a ready-to-use backend+server and a simple frontend.
#### Powered by Leading Research:
- Simul-[Whisper](https://github.com/backspacetg/simul_whisper)/[Streaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
- [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simulatenous translation from & to 200 languages.
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription with AlignAtt policy
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription with LocalAgreement policy
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - Enterprise-grade Voice Activity Detection
@@ -40,7 +39,14 @@ Real-time transcription directly to your browser, with a ready-to-use backend+se
```bash
pip install whisperlivekit
```
> You can also clone the repo and `pip install -e .` for the latest version.
> **FFmpeg is required** and must be installed before using WhisperLiveKit
>
> | OS | How to install |
> |-----------|-------------|
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
> | MacOS | `brew install ffmpeg` |
> | Windows | Download .exe from https://ffmpeg.org/download.html and add to PATH |
#### Quick Start
1. **Start the transcription server:**
@@ -54,26 +60,17 @@ pip install whisperlivekit
> - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
#### Use it to capture audio from web pages.
Go to `chrome-extension` for instructions.
<p align="center">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="600">
</p>
#### Optional Dependencies
| Optional | `pip install` |
|-----------|-------------|
| **Speaker diarization** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| **Apple Silicon optimizations** | `mlx-whisper` |
| **Translation** | `nllw` |
| *[Not recommanded]* Speaker diarization with Diart | `diart` |
| *[Not recommanded]* Original Whisper backend | `whisper` |
| *[Not recommanded]* Improved timestamps backend | `whisper-timestamped` |
| **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| Speaker diarization with Diart | `diart` |
| Original Whisper backend | `whisper` |
| Improved timestamps backend | `whisper-timestamped` |
| Apple Silicon optimization backend | `mlx-whisper` |
| OpenAI API backend | `openai` |
See **Parameters & Configuration** below on how to use them.
@@ -85,11 +82,11 @@ See **Parameters & Configuration** below on how to use them.
**Command-line Interface**: Start the transcription server with various options:
```bash
# Large model and translate from french to danish
whisperlivekit-server --model large-v3 --language fr --target-language da
# Use better model than default (small)
whisperlivekit-server --model large-v3
# Diarization and server listening on */80
whisperlivekit-server --host 0.0.0.0 --port 80 --model medium --diarization --language fr
# Advanced configuration with diarization and language
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
```
@@ -131,21 +128,28 @@ async def websocket_endpoint(websocket: WebSocket):
await audio_processor.process_audio(message)
```
**Frontend Implementation**: The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it using `from whisperlivekit import get_inline_ui_html` & `page = get_inline_ui_html()`
**Frontend Implementation**: The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it using `from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()`
## Parameters & Configuration
An important list of parameters can be changed. But what *should* you change?
- the `--model` size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
- the `--language`. List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English.
- the `--backend` ? you can switch to `--backend faster-whisper` if `simulstreaming` does not work correctly or if you prefer to avoid the dual-license requirements.
- `--warmup-file`, if you have one
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`, if you set up a server
- `--diarization`, if you want to use it.
The rest I don't recommend. But below are your options.
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/available_models.md) | `small` |
| `--model-path` | .pt file/directory containing whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/models_compatible_formats.md) | `None` |
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
| `--target-language` | If sets, translate to using NLLB. Ex: `fr`. [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you should rather use `--task translate`, since Whisper can do it directly. | `None` |
| `--task` | Set to `translate` to translate *only* to english, using Whisper translation. | `transcribe` |
| `--diarization` | Enable speaker identification | `False` |
| `--backend` | Processing backend. You can switch to `faster-whisper` if `simulstreaming` does not work correctly | `simulstreaming` |
| `--model` | Whisper model size. | `small` |
| `--language` | Source language code or `auto` | `auto` |
| `--task` | `transcribe` or `translate` | `transcribe` |
| `--backend` | Processing backend | `simulstreaming` |
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
| `--no-vac` | Disable Voice Activity Controller | `False` |
| `--no-vad` | Disable Voice Activity Detection | `False` |
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
@@ -153,25 +157,16 @@ async def websocket_endpoint(websocket: WebSocket):
| `--port` | Server port | `8000` |
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
| `--forwarded-allow-ips` | Ip or Ips allowed to reverse proxy the whisperlivekit-server. Supported types are IP Addresses (e.g. 127.0.0.1), IP Networks (e.g. 10.100.0.0/16), or Literals (e.g. /path/to/socket.sock) | `None` |
| `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder | `False` |
| Translation options | Description | Default |
|-----------|-------------|---------|
| `--nllb-backend` | `transformers` or `ctranslate2` | `ctranslate2` |
| `--nllb-size` | `600M` or `1.3B` | `600M` |
| Diarization options | Description | Default |
| WhisperStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
| `--disable-punctuation-split` | Disable punctuation based splits. See #214 | `False` |
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
| SimulStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used | `None` |
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
@@ -182,19 +177,22 @@ async def websocket_endpoint(websocket: WebSocket):
| `--init-prompt` | Initial prompt for the model | `None` |
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
| `--max-context-tokens` | Maximum context tokens | `None` |
| `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
| `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` |
| `--preloaded-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
| WhisperStreaming backend options | Description | Default |
| Diarization options | Description | Default |
|-----------|-------------|---------|
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
| `--diarization` | Enable speaker identification | `False` |
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
> For diarization using Diart, you need to accept user conditions [here](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model, [here](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model and [here](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model. **Then**, login to HuggingFace: `huggingface-cli login`
> For diarization using Diart, you need access to pyannote.audio models:
> 1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
> 2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
> 3. [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
>4. Login with HuggingFace: `huggingface-cli login`
### 🚀 Deployment Guide

View File

@@ -1,258 +0,0 @@
<h1 align="center">WhisperLiveKit</h1>
<p align="center">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
</p>
<p align="center"><b>話者識別機能付き、リアルタイム、完全ローカルな音声テキスト変換</b></p>
<p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p>
すぐに使えるバックエンド+サーバーとシンプルなフロントエンドで、リアルタイムの音声文字起こしをブラウザに直接提供します。✨
#### 主要な研究による技術:
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - AlignAttポリシーによる超低遅延文字起こし
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - LocalAgreementポリシーによる低遅延文字起こし
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - 高度なリアルタイム話者ダイアライゼーション
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - リアルタイム話者ダイアライゼーション
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - エンタープライズグレードの音声区間検出
> **なぜ各音声バッチで単純なWhisperモデルを実行しないのか** Whisperは完全な発話向けに設計されており、リアルタイムのチャンク向けではありません。小さなセグメントを処理するとコンテキストが失われ、単語が音節の途中で途切れ、質の悪い文字起こしになります。WhisperLiveKitは、インテリジェントなバッファリングとインクリメンタルな処理のために、最先端の同時音声研究を利用しています。
### アーキテクチャ
<img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />
*バックエンドは複数の同時ユーザーをサポートします。音声が検出されない場合、音声区間検出がオーバーヘッドを削減します。*
### インストールとクイックスタート
```bash
pip install whisperlivekit
```
> **FFmpegが必要です** WhisperLiveKitを使用する前にインストールする必要があります。
>
> | OS | インストール方法 |
> |-----------|-------------|
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
> | MacOS | `brew install ffmpeg` |
> | Windows | https://ffmpeg.org/download.html から.exeをダウンロードし、PATHに追加 |
#### クイックスタート
1. **文字起こしサーバーを起動します:**
```bash
whisperlivekit-server --model base --language en
```
2. **ブラウザを開き** `http://localhost:8000` にアクセスします。話し始めると、あなたの言葉がリアルタイムで表示されます!
> - 利用可能なすべての言語のリストについては、[tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) を参照してください。
> - HTTPSの要件については、**パラメータ**セクションのSSL設定オプションを参照してください。
#### オプションの依存関係
| オプション | `pip install` |
|-----------|-------------|
| **Sortformerによる話者ダイアライゼーション** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| Diartによる話者ダイアライゼーション | `diart` |
| オリジナルのWhisperバックエンド | `whisper` |
| タイムスタンプ改善バックエンド | `whisper-timestamped` |
| Apple Silicon最適化バックエンド | `mlx-whisper` |
| OpenAI APIバックエンド | `openai` |
それらの使用方法については、以下の**パラメータと設定**を参照してください。
### 使用例
**コマンドラインインターフェース**: 様々なオプションで文字起こしサーバーを起動します:
```bash
# デフォルト(small)より良いモデルを使用
whisperlivekit-server --model large-v3
# ダイアライゼーションと言語を指定した高度な設定
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
```
**Python API連携**: 関数やクラスの使用方法のより完全な例については、[basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) を確認してください。
```python
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from contextlib import asynccontextmanager
import asyncio
transcription_engine = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global transcription_engine
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
yield
app = FastAPI(lifespan=lifespan)
async def handle_websocket_results(websocket: WebSocket, results_generator):
async for response in results_generator:
await websocket.send_json(response)
await websocket.send_json({"type": "ready_to_stop"})
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
global transcription_engine
# 接続ごとに新しいAudioProcessorを作成し、共有エンジンを渡す
audio_processor = AudioProcessor(transcription_engine=transcription_engine)
results_generator = await audio_processor.create_tasks()
results_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
await websocket.accept()
while True:
message = await websocket.receive_bytes()
await audio_processor.process_audio(message)
```
**フロントエンド実装**: パッケージにはHTML/JavaScript実装が[ここ](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html)に含まれています。`from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()` を使ってインポートすることもできます。
## パラメータと設定
重要なパラメータのリストを変更できます。しかし、何を*変更すべき*でしょうか?
- `--model` サイズ。リストと推奨事項は[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
- `--language`。リストは[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py)。`auto`を使用すると、モデルは自動的に言語を検出しようとしますが、英語に偏る傾向があります。
- `--backend` `simulstreaming`が正しく動作しない場合や、デュアルライセンス要件を避けたい場合は`--backend faster-whisper`に切り替えることができます。
- `--warmup-file`、もしあれば
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`、サーバーをセットアップする場合
- `--diarization`、使用したい場合。
残りは推奨しません。しかし、以下があなたのオプションです。
| パラメータ | 説明 | デフォルト |
|-----------|-------------|---------|
| `--model` | Whisperモデルのサイズ。 | `small` |
| `--language` | ソース言語コードまたは`auto` | `auto` |
| `--task` | `transcribe`または`translate` | `transcribe` |
| `--backend` | 処理バックエンド | `simulstreaming` |
| `--min-chunk-size` | 最小音声チャンクサイズ(秒) | `1.0` |
| `--no-vac` | 音声アクティビティコントローラーを無効化 | `False` |
| `--no-vad` | 音声区間検出を無効化 | `False` |
| `--warmup-file` | モデルのウォームアップ用音声ファイルパス | `jfk.wav` |
| `--host` | サーバーホストアドレス | `localhost` |
| `--port` | サーバーポート | `8000` |
| `--ssl-certfile` | SSL証明書ファイルへのパスHTTPSサポート用 | `None` |
| `--ssl-keyfile` | SSL秘密鍵ファイルへのパスHTTPSサポート用 | `None` |
| WhisperStreamingバックエンドオプション | 説明 | デフォルト |
|-----------|-------------|---------|
| `--confidence-validation` | 高速な検証のために信頼スコアを使用 | `False` |
| `--buffer_trimming` | バッファトリミング戦略(`sentence`または`segment` | `segment` |
| SimulStreamingバックエンドオプション | 説明 | デフォルト |
|-----------|-------------|---------|
| `--frame-threshold` | AlignAttフレームしきい値低いほど速く、高いほど正確 | `25` |
| `--beams` | ビームサーチのビーム数1 = 貪欲デコーディング) | `1` |
| `--decoder` | デコーダタイプを強制(`beam`または`greedy` | `auto` |
| `--audio-max-len` | 最大音声バッファ長(秒) | `30.0` |
| `--audio-min-len` | 処理する最小音声長(秒) | `0.0` |
| `--cif-ckpt-path` | 単語境界検出用CIFモデルへのパス | `None` |
| `--never-fire` | 未完了の単語を決して切り捨てない | `False` |
| `--init-prompt` | モデルの初期プロンプト | `None` |
| `--static-init-prompt` | スクロールしない静的プロンプト | `None` |
| `--max-context-tokens` | 最大コンテキストトークン数 | `None` |
| `--model-path` | .ptモデルファイルへの直接パス。見つからない場合はダウンロード | `./base.pt` |
| `--preloaded-model-count` | オプション。メモリにプリロードするモデルの数(予想される同時ユーザー数まで設定) | `1` |
| ダイアライゼーションオプション | 説明 | デフォルト |
|-----------|-------------|---------|
| `--diarization` | 話者識別を有効化 | `False` |
| `--diarization-backend` | `diart`または`sortformer` | `sortformer` |
| `--segmentation-model` | DiartセグメンテーションモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Diart埋め込みモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
> Diartを使用したダイアライゼーションには、pyannote.audioモデルへのアクセスが必要です
> 1. `pyannote/segmentation`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation)
> 2. `pyannote/segmentation-3.0`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation-3.0)
> 3. `pyannote/embedding`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/embedding)
>4. HuggingFaceでログイン: `huggingface-cli login`
### 🚀 デプロイガイド
WhisperLiveKitを本番環境にデプロイするには
1. **サーバーセットアップ**: 本番用ASGIサーバーをインストールし、複数のワーカーで起動します
```bash
pip install uvicorn gunicorn
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
```
2. **フロントエンド**: カスタマイズした`html`のバージョンをホストし、WebSocket接続が正しくポイントするようにします
3. **Nginx設定** (本番環境で推奨):
```nginx
server {
listen 80;
server_name your-domain.com;
location / {
proxy_pass http://localhost:8000;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
}}
```
4. **HTTPSサポート**: 安全なデプロイメントのために、WebSocket URLで "ws://" の代わりに "wss://" を使用します
## 🐋 Docker
GPUまたはCPUサポート付きでDockerを使用してアプリケーションを簡単にデプロイします。
### 前提条件
- Dockerがシステムにインストールされていること
- GPUサポートの場合: NVIDIA Dockerランタイムがインストールされていること
### クイックスタート
**GPUアクセラレーション付き (推奨):**
```bash
docker build -t wlk .
docker run --gpus all -p 8000:8000 --name wlk wlk
```
**CPUのみ:**
```bash
docker build -f Dockerfile.cpu -t wlk .
docker run -p 8000:8000 --name wlk wlk
```
### 高度な使用法
**カスタム設定:**
```bash
# カスタムモデルと言語の例
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
```
### メモリ要件
- **大規模モデル**: Dockerランタイムに十分なメモリが割り当てられていることを確認してください
#### カスタマイズ
- `--build-arg` オプション:
- `EXTRAS="whisper-timestamped"` - イメージのインストールにエクストラを追加します(スペースなし)。必要なコンテナオプションを設定することを忘れないでください!
- `HF_PRECACHE_DIR="./.cache/"` - 初回起動を高速化するためにモデルキャッシュをプリロードします
- `HF_TKN_FILE="./token"` - ゲート付きモデルをダウンロードするためにHugging Face Hubアクセストークンを追加します
## 🔮 ユースケース
会議の文字起こしのためにリアルタイムで議論をキャプチャする、聴覚障害のあるユーザーがアクセシビリティツールを通じて会話を追うのを助ける、コンテンツ作成のためにポッドキャストやビデオを自動的に文字起こしする、カスタマーサービスのために話者識別付きでサポートコールを文字起こしする...

Binary file not shown.

Before

Width:  |  Height:  |  Size: 406 KiB

After

Width:  |  Height:  |  Size: 388 KiB

View File

@@ -1,4 +1,4 @@
# Available Whisper model sizes:
# Available model sizes:
- tiny.en (english only)
- tiny
@@ -58,7 +58,6 @@
- `small`: ~2GB VRAM
- `medium`: ~5GB VRAM
- `large`: ~10GB VRAM
- `largev3turbo`: ~6GB VRAM
**Audio Quality Impact**:
- Clean, clear audio: smaller models may suffice
@@ -70,40 +69,4 @@
2. Limited resources or need speed? → `small` or smaller
3. Good hardware and want best quality? → `large-v3`
4. Need fast, high-quality transcription without translation? → `large-v3-turbo`
5. Need translation capabilities? → `large-v2` or `large-v3` (avoid turbo)
_______________________
# Translation Models and Backend
**Language Support**: ~200 languages
## Distilled Model Sizes Available
| Model | Size | Parameters | VRAM (FP16) | VRAM (INT8) | Quality |
|-------|------|------------|-------------|-------------|---------|
| 600M | 2.46 GB | 600M | ~1.5GB | ~800MB | Good, understandable |
| 1.3B | 5.48 GB | 1.3B | ~3GB | ~1.5GB | Better accuracy, context |
**Quality Impact**: 1.3B has ~15-25% better BLEU scores vs 600M across language pairs.
## Backend Performance
| Backend | Speed vs Base | Memory Usage | Quality Loss |
|---------|---------------|--------------|--------------|
| CTranslate2 | 6-10x faster | 40-60% less | ~5% BLEU drop |
| Transformers | Baseline | High | None |
| Transformers + MPS (on Apple Silicon) | 2x faster | Medium | None |
**Metrics**:
- CTranslate2: 50-100+ tokens/sec
- Transformers: 10-30 tokens/sec
- Apple Silicon with MPS: Up to 2x faster than CTranslate2
## Quick Decision Matrix
**Choose 600M**: Limited resources, close to 0 lag
**Choose 1.3B**: Quality matters
**Choose Transformers**: On Apple Silicon
5. Need translation capabilities? → `large-v2` or `large-v3` (avoid turbo)

View File

@@ -1,19 +0,0 @@
## WhisperLiveKit Chrome Extension v0.1.1
Capture the audio of your current tab, transcribe diarize and translate it using WhisperliveKit, in Chrome and other Chromium-based browsers.
> Currently, only the tab audio is captured; your microphone audio is not recorded.
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">
## Running this extension
1. Run `python sync_extension.py` to copy frontend files to the `chrome-extension` directory.
2. Load the `chrome-extension` directory in Chrome as an unpacked extension.
## Devs:
- Impossible to capture audio from tabs if extension is a pannel, unfortunately:
- https://issues.chromium.org/issues/40926394
- https://groups.google.com/a/chromium.org/g/chromium-extensions/c/DET2SXCFnDg
- https://issues.chromium.org/issues/40916430
- To capture microphone in an extension, there are tricks: https://github.com/justinmann/sidepanel-audio-issue , https://medium.com/@lynchee.owo/how-to-enable-microphone-access-in-chrome-extensions-by-code-924295170080 (comments)

View File

@@ -1,9 +0,0 @@
chrome.runtime.onInstalled.addListener((details) => {
if (details.reason.search(/install/g) === -1) {
return
}
chrome.tabs.create({
url: chrome.runtime.getURL("welcome.html"),
active: true
})
})

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.8 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 376 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 823 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.4 KiB

View File

@@ -1,23 +0,0 @@
{
"manifest_version": 3,
"name": "WhisperLiveKit Tab Capture",
"version": "1.0",
"description": "Capture and transcribe audio from browser tabs using WhisperLiveKit.",
"icons": {
"16": "icons/icon16.png",
"32": "icons/icon32.png",
"48": "icons/icon48.png",
"128": "icons/icon128.png"
},
"action": {
"default_title": "WhisperLiveKit Tab Capture",
"default_popup": "live_transcription.html"
},
"permissions": [
"scripting",
"tabCapture",
"offscreen",
"activeTab",
"storage"
]
}

View File

@@ -1,12 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<title>Request Permissions</title>
<script src="requestPermissions.js"></script>
</head>
<body>
This page exists to workaround an issue with Chrome that blocks permission
requests from chrome extensions
<button id="requestMicrophone">Request Microphone</button>
</body>
</html>

View File

@@ -1,17 +0,0 @@
/**
* Requests user permission for microphone access.
* @returns {Promise<void>} A Promise that resolves when permission is granted or rejects with an error.
*/
async function getUserPermission() {
console.log("Getting user permission for microphone access...");
await navigator.mediaDevices.getUserMedia({ audio: true });
const micPermission = await navigator.permissions.query({
name: "microphone",
});
if (micPermission.state == "granted") {
window.close();
}
}
// Call the function to request microphone permission
getUserPermission();

View File

@@ -1,29 +0,0 @@
console.log("sidepanel.js");
async function run() {
const micPermission = await navigator.permissions.query({
name: "microphone",
});
document.getElementById(
"audioPermission"
).innerText = `MICROPHONE: ${micPermission.state}`;
if (micPermission.state !== "granted") {
chrome.tabs.create({ url: "requestPermissions.html" });
}
const intervalId = setInterval(async () => {
const micPermission = await navigator.permissions.query({
name: "microphone",
});
if (micPermission.state === "granted") {
document.getElementById(
"audioPermission"
).innerText = `MICROPHONE: ${micPermission.state}`;
clearInterval(intervalId);
}
}, 100);
}
void run();

BIN
demo.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 985 KiB

After

Width:  |  Height:  |  Size: 423 KiB

View File

@@ -1,264 +0,0 @@
# WhisperLiveKit WebSocket API Documentation
> !! **Note**: The new API structure described in this document is currently under deployment.
This documentation is intended for devs who want to build custom frontends.
WLK provides real-time speech transcription, speaker diarization, and translation through a WebSocket API. The server sends incremental updates as audio is processed, allowing clients to display live transcription results with minimal latency.
---
## Legacy API (Current)
### Message Structure
The current API sends complete state snapshots on each update (several time per second)
```typescript
{
"type": str,
"status": str,
"lines": [
{
"speaker": int,
"text": str,
"start": float,
"end": float,
"translation": str | null,
"detected_language": str
}
],
"buffer_transcription": str,
"buffer_diarization": str,
"remaining_time_transcription": float,
"remaining_time_diarization": float
}
```
---
## New API (Under Development)
### Philosophy
Principles:
- **Incremental Updates**: Only updates and new segments are sent
- **Ephemeral Buffers**: Temporary, unvalidated data displayed in real-time but overwritten on next update, at speaker level
## Message Format
```typescript
{
"type": "transcript_update",
"status": "active_transcription" | "no_audio_detected",
"segments": [
{
"id": number,
"speaker": number,
"text": string,
"start_speaker": float,
"start": float,
"end": float,
"language": string | null,
"translation": string,
"words": [
{
"text": string,
"start": float,
"end": float,
"validated": {
"text": boolean,
"speaker": boolean,
}
}
],
"buffer": {
"transcription": string,
"diarization": string,
"translation": string
}
}
],
"metadata": {
"remaining_time_transcription": float,
"remaining_time_diarization": float
}
}
```
### Other Message Types
#### Config Message (sent on connection)
```json
{
"type": "config",
"useAudioWorklet": true / false
}
```
#### Ready to Stop Message (sent after processing complete)
```json
{
"type": "ready_to_stop"
}
```
---
## Field Descriptions
### Segment Fields
| Field | Type | Description |
|-------|------|-------------|
| `id` | `number` | Unique identifier for this segment. Used by clients to update specific segments efficiently. |
| `speaker` | `number` | Speaker ID (1, 2, 3...). Special value `-2` indicates silence. |
| `text` | `string` | Validated transcription text for this update. Should be **appended** to the segment's text on the client side. |
| `start_speaker` | `float` | Timestamp (seconds) when this speaker segment began. |
| `start` | `float` | Timestamp (seconds) of the first word in this update. |
| `end` | `float` | Timestamp (seconds) of the last word in this update. |
| `language` | `string \| null` | ISO language code (e.g., "en", "fr"). `null` until language is detected. |
| `translation` | `string` | Validated translation text for this update. Should be **appended** to the segment's translation on the client side. |
| `words` | `Array` | Array of word-level objects with timing and validation information. |
| `buffer` | `Object` | Per-segment temporary buffers, see below |
### Word Object
| Field | Type | Description |
|-------|------|-------------|
| `text` | `string` | The word text. |
| `start` | `number` | Start timestamp (seconds) of this word. |
| `end` | `number` | End timestamp (seconds) of this word. |
| `validated.text` | `boolean` | Whether the transcription text has been validated. if false, word is also in buffer: transcription |
| `validated.speaker` | `boolean` | Whether the speaker assignment has been validated. if false, word is also in buffer: diarization |
| `validated.language` | `boolean` | Whether the language detection has been validated. if false, word is also in buffer: translation |
### Buffer Object (Per-Segment)
Buffers are **ephemeral**. They should be displayed to the user but not stored permanently in the frontend. Each update may contain a completely different buffer value, and previous buffer is likely to be in the next validated text.
| Field | Type | Description |
|-------|------|-------------|
| `transcription` | `string` | Pending transcription text. Displayed immediately but **overwritten** on next update. |
| `diarization` | `string` | Pending diarization text (text waiting for speaker assignment). Displayed immediately but **overwritten** on next update. |
| `translation` | `string` | Pending translation text. Displayed immediately but **overwritten** on next update. |
### Metadata Fields
| Field | Type | Description |
|-------|------|-------------|
| `remaining_time_transcription` | `float` | Seconds of audio waiting for transcription processing. |
| `remaining_time_diarization` | `float` | Seconds of audio waiting for speaker diarization. |
### Status Values
| Status | Description |
|--------|-------------|
| `active_transcription` | Normal operation, transcription is active. |
| `no_audio_detected` | No audio has been detected yet. |
---
## Update Behavior
### Incremental Updates
The API sends **only changed or new segments**. Clients should:
1. Maintain a local map of segments by ID
2. When receiving an update, merge/update segments by ID
3. Render only the changed segments
### Language Detection
When language is detected for a segment:
```jsonc
// Update 1: No language yet
{
"segments": [
{"id": 1, "speaker": 1, "text": "May see", "language": null}
]
}
// Update 2: Same segment ID, language now detected
{
"segments": [
{"id": 1, "speaker": 1, "text": "Merci", "language": "fr"}
]
}
```
**Client behavior**: **Replace** the existing segment with the same ID.
### Buffer Behavior
Buffers are **per-segment** to handle multi-speaker scenarios correctly.
#### Example: Translation with diarization and translation
```jsonc
// Update 1
{
"segments": [
{
"id": 1,
"speaker": 1,
"text": "Hello world, how are",
"translation": "",
"buffer": {
"transcription": "",
"diarization": " you on",
"translation": "Bonjour le monde"
}
}
]
}
// ==== Frontend ====
// <SPEAKER>1</SPEAKER>
// <TRANSCRIPTION>Hello world, how are <DIARIZATION BUFFER> you on</DIARIZATION BUFFER></TRANSCRIPTION>
// <TRANSLATION><TRANSLATION BUFFER>Bonjour le monde</TRANSLATION BUFFER></TRANSLATION>
// Update 2
{
"segments": [
{
"id": 1,
"speaker": 1,
"text": " you on this",
"translation": "Bonjour tout le monde",
"buffer": {
"transcription": "",
"diarization": " beautiful day",
"translation": ",comment"
}
},
]
}
// ==== Frontend ====
// <SPEAKER>1</SPEAKER>
// <TRANSCRIPTION>Hello world, how are you on this<DIARIZATION BUFFER> beautiful day</DIARIZATION BUFFER></TRANSCRIPTION>
// <TRANSLATION>Bonjour tout le monde<TRANSLATION BUFFER>, comment</TRANSLATION BUFFER><TRANSLATION>
```
### Silence Segments
Silence is represented with the speaker id = `-2`:
```jsonc
{
"id": 5,
"speaker": -2,
"text": "",
"start": 10.5,
"end": 12.3
}
```

View File

@@ -1,14 +0,0 @@
# Model Path Formats
The `--model-path` parameter accepts:
## File Path
- **`.pt` format only** (required for AlignAtt policy decoder)
## Directory Path (recommended)
Must contain:
- **`.pt` file** (required for decoder)
May optionally contain:
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)

View File

@@ -1,265 +0,0 @@
# Supported Languages
WhisperLiveKit supports translation into **201 languages** from the FLORES-200 dataset through the NLLB (No Language Left Behind) translation system.
## How to Specify Languages
You can specify languages in **three different ways**:
1. **Language Name** (case-insensitive): `"English"`, `"French"`, `"Spanish"`
2. **ISO Language Code**: `"en"`, `"fr"`, `"es"`
3. **NLLB Code** (FLORES-200): `"eng_Latn"`, `"fra_Latn"`, `"spa_Latn"`
## Usage Examples
### Command Line
```bash
# Using language name
whisperlivekit-server --target-language "French"
# Using ISO code
whisperlivekit-server --target-language fr
# Using NLLB code
whisperlivekit-server --target-language fra_Latn
```
### Python API
```python
from nllw.translation import get_language_info
# Get language information by name
lang_info = get_language_info("French")
print(lang_info)
# {'name': 'French', 'nllb': 'fra_Latn', 'language_code': 'fr'}
# Get language information by ISO code
lang_info = get_language_info("fr")
# Get language information by NLLB code
lang_info = get_language_info("fra_Latn")
# All three return the same result
```
## Complete Language List
The following table lists all 201 supported languages with their corresponding codes:
| Language Name | ISO Code | NLLB Code |
|---------------|----------|-----------|
| Acehnese (Arabic script) | ace_Arab | ace_Arab |
| Acehnese (Latin script) | ace_Latn | ace_Latn |
| Mesopotamian Arabic | acm_Arab | acm_Arab |
| Ta'izzi-Adeni Arabic | acq_Arab | acq_Arab |
| Tunisian Arabic | aeb_Arab | aeb_Arab |
| Afrikaans | af | afr_Latn |
| South Levantine Arabic | ajp_Arab | ajp_Arab |
| Akan | ak | aka_Latn |
| Tosk Albanian | als | als_Latn |
| Amharic | am | amh_Ethi |
| North Levantine Arabic | apc_Arab | apc_Arab |
| Modern Standard Arabic | ar | arb_Arab |
| Modern Standard Arabic (Romanized) | arb_Latn | arb_Latn |
| Najdi Arabic | ars_Arab | ars_Arab |
| Moroccan Arabic | ary_Arab | ary_Arab |
| Egyptian Arabic | arz_Arab | arz_Arab |
| Assamese | as | asm_Beng |
| Asturian | ast | ast_Latn |
| Awadhi | awa | awa_Deva |
| Central Aymara | ay | ayr_Latn |
| South Azerbaijani | azb | azb_Arab |
| North Azerbaijani | az | azj_Latn |
| Bashkir | ba | bak_Cyrl |
| Bambara | bm | bam_Latn |
| Balinese | ban | ban_Latn |
| Belarusian | be | bel_Cyrl |
| Bemba | bem | bem_Latn |
| Bengali | bn | ben_Beng |
| Bhojpuri | bho | bho_Deva |
| Banjar (Arabic script) | bjn_Arab | bjn_Arab |
| Banjar (Latin script) | bjn_Latn | bjn_Latn |
| Standard Tibetan | bo | bod_Tibt |
| Bosnian | bs | bos_Latn |
| Buginese | bug | bug_Latn |
| Bulgarian | bg | bul_Cyrl |
| Catalan | ca | cat_Latn |
| Cebuano | ceb | ceb_Latn |
| Czech | cs | ces_Latn |
| Chokwe | cjk | cjk_Latn |
| Central Kurdish | ckb | ckb_Arab |
| Crimean Tatar | crh | crh_Latn |
| Welsh | cy | cym_Latn |
| Danish | da | dan_Latn |
| German | de | deu_Latn |
| Southwestern Dinka | dik | dik_Latn |
| Dyula | dyu | dyu_Latn |
| Dzongkha | dz | dzo_Tibt |
| Greek | el | ell_Grek |
| English | en | eng_Latn |
| Esperanto | eo | epo_Latn |
| Estonian | et | est_Latn |
| Basque | eu | eus_Latn |
| Ewe | ee | ewe_Latn |
| Faroese | fo | fao_Latn |
| Fijian | fj | fij_Latn |
| Finnish | fi | fin_Latn |
| Fon | fon | fon_Latn |
| French | fr | fra_Latn |
| Friulian | fur-IT | fur_Latn |
| Nigerian Fulfulde | fuv | fuv_Latn |
| West Central Oromo | om | gaz_Latn |
| Scottish Gaelic | gd | gla_Latn |
| Irish | ga-IE | gle_Latn |
| Galician | gl | glg_Latn |
| Guarani | gn | grn_Latn |
| Gujarati | gu-IN | guj_Gujr |
| Haitian Creole | ht | hat_Latn |
| Hausa | ha | hau_Latn |
| Hebrew | he | heb_Hebr |
| Hindi | hi | hin_Deva |
| Chhattisgarhi | hne | hne_Deva |
| Croatian | hr | hrv_Latn |
| Hungarian | hu | hun_Latn |
| Armenian | hy-AM | hye_Armn |
| Igbo | ig | ibo_Latn |
| Ilocano | ilo | ilo_Latn |
| Indonesian | id | ind_Latn |
| Icelandic | is | isl_Latn |
| Italian | it | ita_Latn |
| Javanese | jv | jav_Latn |
| Japanese | ja | jpn_Jpan |
| Kabyle | kab | kab_Latn |
| Jingpho | kac | kac_Latn |
| Kamba | kam | kam_Latn |
| Kannada | kn | kan_Knda |
| Kashmiri (Arabic script) | kas_Arab | kas_Arab |
| Kashmiri (Devanagari script) | kas_Deva | kas_Deva |
| Georgian | ka | kat_Geor |
| Kazakh | kk | kaz_Cyrl |
| Kabiyè | kbp | kbp_Latn |
| Kabuverdianu | kea | kea_Latn |
| Halh Mongolian | mn | khk_Cyrl |
| Khmer | km | khm_Khmr |
| Kikuyu | ki | kik_Latn |
| Kinyarwanda | rw | kin_Latn |
| Kyrgyz | ky | kir_Cyrl |
| Kimbundu | kmb | kmb_Latn |
| Northern Kurdish | kmr | kmr_Latn |
| Central Kanuri (Arabic script) | knc_Arab | knc_Arab |
| Central Kanuri (Latin script) | knc_Latn | knc_Latn |
| Kikongo | kg | kon_Latn |
| Korean | ko | kor_Hang |
| Lao | lo | lao_Laoo |
| Ligurian | lij | lij_Latn |
| Limburgish | li | lim_Latn |
| Lingala | ln | lin_Latn |
| Lithuanian | lt | lit_Latn |
| Lombard | lmo | lmo_Latn |
| Latgalian | ltg | ltg_Latn |
| Luxembourgish | lb | ltz_Latn |
| Luba-Kasai | lua | lua_Latn |
| Ganda | lg | lug_Latn |
| Luo | luo | luo_Latn |
| Mizo | lus | lus_Latn |
| Standard Latvian | lv | lvs_Latn |
| Magahi | mag | mag_Deva |
| Maithili | mai | mai_Deva |
| Malayalam | ml-IN | mal_Mlym |
| Marathi | mr | mar_Deva |
| Minangkabau (Arabic script) | min_Arab | min_Arab |
| Minangkabau (Latin script) | min_Latn | min_Latn |
| Macedonian | mk | mkd_Cyrl |
| Maltese | mt | mlt_Latn |
| Meitei (Bengali script) | mni | mni_Beng |
| Mossi | mos | mos_Latn |
| Maori | mi | mri_Latn |
| Burmese | my | mya_Mymr |
| Dutch | nl | nld_Latn |
| Norwegian Nynorsk | nn-NO | nno_Latn |
| Norwegian Bokmål | nb | nob_Latn |
| Nepali | ne-NP | npi_Deva |
| Northern Sotho | nso | nso_Latn |
| Nuer | nus | nus_Latn |
| Nyanja | ny | nya_Latn |
| Occitan | oc | oci_Latn |
| Odia | or | ory_Orya |
| Pangasinan | pag | pag_Latn |
| Eastern Panjabi | pa | pan_Guru |
| Papiamento | pap | pap_Latn |
| Southern Pashto | pbt | pbt_Arab |
| Western Persian | fa | pes_Arab |
| Plateau Malagasy | mg | plt_Latn |
| Polish | pl | pol_Latn |
| Portuguese | pt-PT | por_Latn |
| Dari | fa-AF | prs_Arab |
| Ayacucho Quechua | qu | quy_Latn |
| Romanian | ro | ron_Latn |
| Rundi | rn | run_Latn |
| Russian | ru | rus_Cyrl |
| Sango | sg | sag_Latn |
| Sanskrit | sa | san_Deva |
| Santali | sat | sat_Olck |
| Sicilian | scn | scn_Latn |
| Shan | shn | shn_Mymr |
| Sinhala | si-LK | sin_Sinh |
| Slovak | sk | slk_Latn |
| Slovenian | sl | slv_Latn |
| Samoan | sm | smo_Latn |
| Shona | sn | sna_Latn |
| Sindhi | sd | snd_Arab |
| Somali | so | som_Latn |
| Southern Sotho | st | sot_Latn |
| Spanish | es-ES | spa_Latn |
| Sardinian | sc | srd_Latn |
| Serbian | sr | srp_Cyrl |
| Swati | ss | ssw_Latn |
| Sundanese | su | sun_Latn |
| Swedish | sv-SE | swe_Latn |
| Swahili | sw | swh_Latn |
| Silesian | szl | szl_Latn |
| Tamil | ta | tam_Taml |
| Tamasheq (Latin script) | taq_Latn | taq_Latn |
| Tamasheq (Tifinagh script) | taq_Tfng | taq_Tfng |
| Tatar | tt-RU | tat_Cyrl |
| Telugu | te | tel_Telu |
| Tajik | tg | tgk_Cyrl |
| Tagalog | tl | tgl_Latn |
| Thai | th | tha_Thai |
| Tigrinya | ti | tir_Ethi |
| Tok Pisin | tpi | tpi_Latn |
| Tswana | tn | tsn_Latn |
| Tsonga | ts | tso_Latn |
| Turkmen | tk | tuk_Latn |
| Tumbuka | tum | tum_Latn |
| Turkish | tr | tur_Latn |
| Twi | tw | twi_Latn |
| Central Atlas Tamazight | tzm | tzm_Tfng |
| Uyghur | ug | uig_Arab |
| Ukrainian | uk | ukr_Cyrl |
| Umbundu | umb | umb_Latn |
| Urdu | ur | urd_Arab |
| Northern Uzbek | uz | uzn_Latn |
| Venetian | vec | vec_Latn |
| Vietnamese | vi | vie_Latn |
| Waray | war | war_Latn |
| Wolof | wo | wol_Latn |
| Xhosa | xh | xho_Latn |
| Eastern Yiddish | yi | ydd_Hebr |
| Yoruba | yo | yor_Latn |
| Yue Chinese | yue | yue_Hant |
| Chinese (Simplified) | zh-CN | zho_Hans |
| Chinese (Traditional) | zh-TW | zho_Hant |
| Standard Malay | ms | zsm_Latn |
| Zulu | zu | zul_Latn |
## Special Features
### Multiple Script Support
Several languages are available in multiple scripts (e.g., Arabic and Latin):
- **Acehnese**: Arabic (`ace_Arab`) and Latin (`ace_Latn`)
- **Banjar**: Arabic (`bjn_Arab`) and Latin (`bjn_Latn`)
- **Kashmiri**: Arabic (`kas_Arab`) and Devanagari (`kas_Deva`)
- **Minangkabau**: Arabic (`min_Arab`) and Latin (`min_Latn`)
- **Tamasheq**: Latin (`taq_Latn`) and Tifinagh (`taq_Tfng`)
- **Central Kanuri**: Arabic (`knc_Arab`) and Latin (`knc_Latn`)

View File

@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"
[project]
name = "whisperlivekit"
version = "0.2.13"
description = "Real-time speech-to-text with speaker diarization using Whisper"
version = "0.2.7"
description = "Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization"
readme = "README.md"
authors = [
{ name = "Quentin Fuxa" }
@@ -18,11 +18,6 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: 3.15",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Multimedia :: Sound/Audio :: Speech"
]
@@ -33,16 +28,14 @@ dependencies = [
"faster-whisper",
"uvicorn",
"websockets",
"torchaudio>=2.0.0",
"torch>=2.0.0",
"torch",
"tqdm",
"tiktoken",
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
]
[project.optional-dependencies]
translation = ["nllw"]
sentence_tokenizer = ["mosestokenizer", "wtpsplit"]
sentence = ["mosestokenizer", "wtpsplit"]
[project.urls]
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
@@ -51,19 +44,8 @@ Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
whisperlivekit-server = "whisperlivekit.basic_server:main"
[tool.setuptools]
packages = [
"whisperlivekit",
"whisperlivekit.diarization",
"whisperlivekit.simul_whisper",
"whisperlivekit.simul_whisper.whisper",
"whisperlivekit.simul_whisper.whisper.assets",
"whisperlivekit.simul_whisper.whisper.normalizers",
"whisperlivekit.web",
"whisperlivekit.whisper_streaming_custom",
"whisperlivekit.vad_models"
]
packages = ["whisperlivekit", "whisperlivekit.diarization", "whisperlivekit.simul_whisper", "whisperlivekit.simul_whisper.whisper", "whisperlivekit.simul_whisper.whisper.assets", "whisperlivekit.simul_whisper.whisper.normalizers", "whisperlivekit.web", "whisperlivekit.whisper_streaming_custom"]
[tool.setuptools.package-data]
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
"whisperlivekit.simul_whisper.whisper.assets" = ["*.tiktoken", "*.npz"]
"whisperlivekit.vad_models" = ["*.jit", "*.onnx"]

View File

@@ -1,38 +0,0 @@
import shutil
import os
from pathlib import Path
def sync_extension_files():
"""Copy core files from web directory to Chrome extension directory."""
web_dir = Path("whisperlivekit/web")
extension_dir = Path("chrome-extension")
files_to_sync = [
"live_transcription.html", "live_transcription.js", "live_transcription.css"
]
svg_files = [
"system_mode.svg",
"light_mode.svg",
"dark_mode.svg",
"settings.svg"
]
for file in files_to_sync:
src_path = web_dir / file
dest_path = extension_dir / file
dest_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(src_path, dest_path)
for svg_file in svg_files:
src_path = web_dir / "src" / svg_file
dest_path = extension_dir / "web" / "src" / svg_file
dest_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(src_path, dest_path)
if __name__ == "__main__":
sync_extension_files()

View File

@@ -4,41 +4,18 @@ from time import time, sleep
import math
import logging
import traceback
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript, ChangeSpeaker
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
from whisperlivekit.silero_vad_iterator import FixedVADIterator
from whisperlivekit.results_formater import format_output
from whisperlivekit.timed_objects import ASRToken, Silence
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from whisperlivekit.silero_vad_iterator import FixedVADIterator
from whisperlivekit.results_formater import format_output, format_time
# Set up logging once
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
SENTINEL = object() # unique sentinel object for end of stream marker
def cut_at(cumulative_pcm, cut_sec):
cumulative_len = 0
cut_sample = int(cut_sec * 16000)
for ind, pcm_array in enumerate(cumulative_pcm):
if (cumulative_len + len(pcm_array)) >= cut_sample:
cut_chunk = cut_sample - cumulative_len
before = np.concatenate(cumulative_pcm[:ind] + [cumulative_pcm[ind][:cut_chunk]])
after = [cumulative_pcm[ind][cut_chunk:]] + cumulative_pcm[ind+1:]
return before, after
cumulative_len += len(pcm_array)
return np.concatenate(cumulative_pcm), []
async def get_all_from_queue(queue):
items = []
try:
while True:
item = queue.get_nowait()
items.append(item)
except asyncio.QueueEmpty:
pass
return items
class AudioProcessor:
"""
Processes audio streams for transcription and diarization.
@@ -61,82 +38,90 @@ class AudioProcessor:
self.bytes_per_sample = 2
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
self.is_pcm_input = self.args.pcm_input
self.last_ffmpeg_activity = time()
self.ffmpeg_health_check_interval = 5
self.ffmpeg_max_idle_time = 10
self.debug = False
# State management
self.is_stopping = False
self.silence = False
self.silence_duration = 0.0
self.state = State()
self.tokens = []
self.buffer_transcription = ""
self.buffer_diarization = ""
self.end_buffer = 0
self.end_attributed_speaker = 0
self.lock = asyncio.Lock()
self.beg_loop = None #to deal with a potential little lag at the websocket initialization, this is now set in process_audio
self.sep = " " # Default separator
self.last_response_content = FrontData()
self.last_detected_speaker = None
self.speaker_languages = {}
self.diarization_before_transcription = False
self.segments = []
if self.diarization_before_transcription:
self.cumulative_pcm = []
self.last_start = 0.0
self.last_end = 0.0
self.last_response_content = ""
# Models and processing
self.asr = models.asr
self.tokenizer = models.tokenizer
self.vac_model = models.vac_model
if self.args.vac:
self.vac = FixedVADIterator(models.vac_model)
else:
self.vac = None
self.ffmpeg_manager = None
self.ffmpeg_reader_task = None
self.ffmpeg_manager = FFmpegManager(
sample_rate=self.sample_rate,
channels=self.channels
)
async def handle_ffmpeg_error(error_type: str):
logger.error(f"FFmpeg error: {error_type}")
self._ffmpeg_error = error_type
self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
self._ffmpeg_error = None
if not self.is_pcm_input:
self.ffmpeg_manager = FFmpegManager(
sample_rate=self.sample_rate,
channels=self.channels
)
async def handle_ffmpeg_error(error_type: str):
logger.error(f"FFmpeg error: {error_type}")
self._ffmpeg_error = error_type
self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
self.transcription_queue = asyncio.Queue() if self.args.transcription else None
self.diarization_queue = asyncio.Queue() if self.args.diarization else None
self.translation_queue = asyncio.Queue() if self.args.target_language else None
self.pcm_buffer = bytearray()
# Task references
self.transcription_task = None
self.diarization_task = None
self.translation_task = None
self.ffmpeg_reader_task = None
self.watchdog_task = None
self.all_tasks_for_cleanup = []
self.transcription = None
self.translation = None
self.diarization = None
# Initialize transcription engine if enabled
if self.args.transcription:
self.transcription = online_factory(self.args, models.asr)
self.sep = self.transcription.asr.sep
self.online = online_factory(self.args, models.asr, models.tokenizer)
# Initialize diarization engine if enabled
if self.args.diarization:
self.diarization = online_diarization_factory(self.args, models.diarization_model)
if models.translation_model:
self.translation = online_translation_factory(self.args, models.translation_model)
def convert_pcm_to_float(self, pcm_buffer):
"""Convert PCM buffer in s16le format to normalized NumPy array."""
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
async def update_transcription(self, new_tokens, buffer, end_buffer, sep):
"""Thread-safe update of transcription with new data."""
async with self.lock:
self.tokens.extend(new_tokens)
self.buffer_transcription = buffer
self.end_buffer = end_buffer
self.sep = sep
async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
"""Thread-safe update of diarization with new data."""
async with self.lock:
self.end_attributed_speaker = end_attributed_speaker
if buffer_diarization:
self.buffer_diarization = buffer_diarization
async def add_dummy_token(self):
"""Placeholder token when no transcription is available."""
async with self.lock:
current_time = time() - self.state.beg_loop
self.state.tokens.append(ASRToken(
current_time = time() - self.beg_loop if self.beg_loop else 0
self.tokens.append(ASRToken(
start=current_time, end=current_time + 1,
text=".", speaker=-1, is_dummy=True
))
@@ -146,30 +131,43 @@ class AudioProcessor:
async with self.lock:
current_time = time()
# Calculate remaining times
remaining_transcription = 0
if self.state.end_buffer > 0:
remaining_transcription = max(0, round(current_time - self.state.beg_loop - self.state.end_buffer, 1))
if self.end_buffer > 0:
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 1))
remaining_diarization = 0
if self.state.tokens:
latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1))
if self.tokens:
latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 1))
self.state.remaining_time_transcription = remaining_transcription
self.state.remaining_time_diarization = remaining_diarization
return {
"tokens": self.tokens.copy(),
"buffer_transcription": self.buffer_transcription,
"buffer_diarization": self.buffer_diarization,
"end_buffer": self.end_buffer,
"end_attributed_speaker": self.end_attributed_speaker,
"sep": self.sep,
"remaining_time_transcription": remaining_transcription,
"remaining_time_diarization": remaining_diarization
}
return self.state
async def reset(self):
"""Reset all state variables to initial values."""
async with self.lock:
self.tokens = []
self.buffer_transcription = self.buffer_diarization = ""
self.end_buffer = self.end_attributed_speaker = 0
self.beg_loop = time()
async def ffmpeg_stdout_reader(self):
"""Read audio data from FFmpeg stdout and process it into the PCM pipeline."""
"""Read audio data from FFmpeg stdout and process it."""
beg = time()
while True:
try:
if self.is_stopping:
logger.info("Stopping ffmpeg_stdout_reader due to stopping flag.")
break
state = await self.ffmpeg_manager.get_state() if self.ffmpeg_manager else FFmpegState.STOPPED
# Check if FFmpeg is running
state = await self.ffmpeg_manager.get_state()
if state == FFmpegState.FAILED:
logger.error("FFmpeg is in FAILED state, cannot read data")
break
@@ -177,41 +175,100 @@ class AudioProcessor:
logger.info("FFmpeg is stopped")
break
elif state != FFmpegState.RUNNING:
await asyncio.sleep(0.1)
logger.warning(f"FFmpeg is in {state} state, waiting...")
await asyncio.sleep(0.5)
continue
current_time = time()
elapsed_time = max(0.0, current_time - beg)
buffer_size = max(int(32000 * elapsed_time), 4096) # dynamic read
elapsed_time = math.floor((current_time - beg) * 10) / 10
buffer_size = max(int(32000 * elapsed_time), 4096)
beg = current_time
chunk = await self.ffmpeg_manager.read_data(buffer_size)
if not chunk:
# No data currently available
await asyncio.sleep(0.05)
continue
if self.is_stopping:
logger.info("FFmpeg stdout closed, stopping.")
break
else:
# No data available, but not stopping - FFmpeg might be restarting
await asyncio.sleep(0.1)
continue
self.pcm_buffer.extend(chunk)
await self.handle_pcm_data()
except asyncio.CancelledError:
logger.info("ffmpeg_stdout_reader cancelled.")
break
# Process when enough data
if len(self.pcm_buffer) >= self.bytes_per_sec:
if len(self.pcm_buffer) > self.max_bytes_per_sec:
logger.warning(
f"Audio buffer too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f}s. "
f"Consider using a smaller model."
)
# Process audio chunk
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
res = None
end_of_audio = False
silence_buffer = None
if self.args.vac:
res = self.vac(pcm_array)
if res is not None:
if res.get('end', 0) > res.get('start', 0):
end_of_audio = True
elif self.silence: #end of silence
self.silence = False
silence_buffer = Silence(duration=time() - self.start_silence)
if silence_buffer:
if self.args.transcription and self.transcription_queue:
await self.transcription_queue.put(silence_buffer)
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(silence_buffer)
if not self.silence:
if self.args.transcription and self.transcription_queue:
await self.transcription_queue.put(pcm_array.copy())
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(pcm_array.copy())
self.silence_duration = 0.0
if end_of_audio:
self.silence = True
self.start_silence = time()
# Sleep if no processing is happening
if not self.args.transcription and not self.args.diarization:
await asyncio.sleep(0.1)
except Exception as e:
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
logger.debug(f"Traceback: {traceback.format_exc()}")
await asyncio.sleep(0.2)
logger.info("FFmpeg stdout processing finished. Signaling downstream processors if needed.")
if not self.diarization_before_transcription and self.transcription_queue:
logger.warning(f"Traceback: {traceback.format_exc()}")
# Try to recover by waiting a bit
await asyncio.sleep(1)
# Check if we should exit
if self.is_stopping:
break
logger.info("FFmpeg stdout processing finished. Signaling downstream processors.")
if self.args.transcription and self.transcription_queue:
await self.transcription_queue.put(SENTINEL)
if self.diarization:
logger.debug("Sentinel put into transcription_queue.")
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(SENTINEL)
if self.translation:
await self.translation_queue.put(SENTINEL)
logger.debug("Sentinel put into diarization_queue.")
async def transcription_processor(self):
"""Process audio chunks for transcription."""
self.sep = self.online.asr.sep
cumulative_pcm_duration_stream_time = 0.0
while True:
@@ -221,59 +278,65 @@ class AudioProcessor:
logger.debug("Transcription processor received sentinel. Finishing.")
self.transcription_queue.task_done()
break
if not self.online:
logger.warning("Transcription processor: self.online not initialized.")
self.transcription_queue.task_done()
continue
asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
transcription_lag_s = max(0.0, time() - self.state.beg_loop - self.state.end_buffer)
asr_internal_buffer_duration_s = len(getattr(self.online, 'audio_buffer', [])) / self.online.SAMPLING_RATE
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
if type(item) is Silence:
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
if self.state.tokens:
asr_processing_logs += f" | last_end = {self.state.tokens[-1].end} |"
logger.info(asr_processing_logs)
cumulative_pcm_duration_stream_time += item.duration
self.transcription.insert_silence(item.duration, self.state.tokens[-1].end if self.state.tokens else 0)
continue
elif isinstance(item, ChangeSpeaker):
self.transcription.new_speaker(item)
elif isinstance(item, np.ndarray):
pcm_array = item
if self.tokens:
asr_processing_logs += f" | last_end = {self.tokens[-1].end} |"
logger.info(asr_processing_logs)
if type(item) is Silence:
cumulative_pcm_duration_stream_time += item.duration
self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
continue
if isinstance(item, np.ndarray):
pcm_array = item
else:
raise Exception('item should be pcm_array')
duration_this_chunk = len(pcm_array) / self.sample_rate
cumulative_pcm_duration_stream_time += duration_this_chunk
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter)
_buffer_transcript = self.transcription.get_buffer()
buffer_text = _buffer_transcript.text
self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
new_tokens, current_audio_processed_upto = self.online.process_iter()
# Get buffer information
_buffer_transcript_obj = self.online.get_buffer()
buffer_text = _buffer_transcript_obj.text
if new_tokens:
validated_text = self.sep.join([t.text for t in new_tokens])
if buffer_text.startswith(validated_text):
_buffer_transcript.text = buffer_text[len(validated_text):].lstrip()
buffer_text = buffer_text[len(validated_text):].lstrip()
candidate_end_times = [self.state.end_buffer]
candidate_end_times = [self.end_buffer]
if new_tokens:
candidate_end_times.append(new_tokens[-1].end)
if _buffer_transcript.end is not None:
candidate_end_times.append(_buffer_transcript.end)
if _buffer_transcript_obj.end is not None:
candidate_end_times.append(_buffer_transcript_obj.end)
candidate_end_times.append(current_audio_processed_upto)
async with self.lock:
self.state.tokens.extend(new_tokens)
self.state.buffer_transcription = _buffer_transcript
self.state.end_buffer = max(candidate_end_times)
new_end_buffer = max(candidate_end_times)
if self.translation_queue:
for token in new_tokens:
await self.translation_queue.put(token)
await self.update_transcription(
new_tokens, buffer_text, new_end_buffer, self.sep
)
self.transcription_queue.task_done()
except Exception as e:
@@ -281,22 +344,13 @@ class AudioProcessor:
logger.warning(f"Traceback: {traceback.format_exc()}")
if 'pcm_array' in locals() and pcm_array is not SENTINEL : # Check if pcm_array was assigned from queue
self.transcription_queue.task_done()
if self.is_stopping:
logger.info("Transcription processor finishing due to stopping flag.")
if self.diarization_queue:
await self.diarization_queue.put(SENTINEL)
if self.translation_queue:
await self.translation_queue.put(SENTINEL)
logger.info("Transcription processor task finished.")
async def diarization_processor(self, diarization_obj):
"""Process audio chunks for speaker diarization."""
if self.diarization_before_transcription:
self.current_speaker = 0
await self.transcription_queue.put(ChangeSpeaker(speaker=self.current_speaker, start=0.0))
buffer_diarization = ""
cumulative_pcm_duration_stream_time = 0.0
while True:
try:
item = await self.diarization_queue.get()
@@ -304,49 +358,30 @@ class AudioProcessor:
logger.debug("Diarization processor received sentinel. Finishing.")
self.diarization_queue.task_done()
break
elif type(item) is Silence:
if type(item) is Silence:
cumulative_pcm_duration_stream_time += item.duration
diarization_obj.insert_silence(item.duration)
continue
elif isinstance(item, np.ndarray):
if isinstance(item, np.ndarray):
pcm_array = item
else:
raise Exception('item should be pcm_array')
# Process diarization
await diarization_obj.diarize(pcm_array)
if self.diarization_before_transcription:
segments = diarization_obj.get_segments()
self.cumulative_pcm.append(pcm_array)
if segments:
last_segment = segments[-1]
if last_segment.speaker != self.current_speaker:
cut_sec = last_segment.start - self.last_end
to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec)
await self.transcription_queue.put(to_transcript)
self.current_speaker = last_segment.speaker
await self.transcription_queue.put(ChangeSpeaker(speaker=self.current_speaker, start=last_segment.start))
cut_sec = last_segment.end - last_segment.start
to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec)
await self.transcription_queue.put(to_transcript)
self.last_start = last_segment.start
self.last_end = last_segment.end
else:
cut_sec = last_segment.end - self.last_end
to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec)
await self.transcription_queue.put(to_transcript)
self.last_end = last_segment.end
elif not self.diarization_before_transcription:
async with self.lock:
self.state.tokens = diarization_obj.assign_speakers_to_tokens(
self.state.tokens,
use_punctuation_split=self.args.punctuation_split
)
if len(self.state.tokens) > 0:
self.state.end_attributed_speaker = max(self.state.tokens[-1].end, self.state.end_attributed_speaker)
async with self.lock:
self.tokens = diarization_obj.assign_speakers_to_tokens(
self.tokens,
use_punctuation_split=self.args.punctuation_split
)
if len(self.tokens) > 0:
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
if buffer_diarization:
self.buffer_diarization = buffer_diarization
self.diarization_queue.task_done()
except Exception as e:
@@ -356,159 +391,155 @@ class AudioProcessor:
self.diarization_queue.task_done()
logger.info("Diarization processor task finished.")
async def translation_processor(self):
# the idea is to ignore diarization for the moment. We use only transcription tokens.
# And the speaker is attributed given the segments used for the translation
# in the future we want to have different languages for each speaker etc, so it will be more complex.
while True:
try:
item = await self.translation_queue.get() #block until at least 1 token
if item is SENTINEL:
logger.debug("Translation processor received sentinel. Finishing.")
self.translation_queue.task_done()
break
elif type(item) is Silence:
self.translation.insert_silence(item.duration)
continue
# get all the available tokens for translation. The more words, the more precise
tokens_to_process = [item]
additional_tokens = await get_all_from_queue(self.translation_queue)
sentinel_found = False
for additional_token in additional_tokens:
if additional_token is SENTINEL:
sentinel_found = True
break
elif type(additional_token) is Silence:
self.translation.insert_silence(additional_token.duration)
continue
else:
tokens_to_process.append(additional_token)
if tokens_to_process:
self.translation.insert_tokens(tokens_to_process)
translation_validated_segments, translation_buffer = await asyncio.to_thread(self.translation.process)
async with self.lock:
self.state.translation_validated_segments = translation_validated_segments
self.state.translation_buffer = translation_buffer
self.translation_queue.task_done()
for _ in additional_tokens:
self.translation_queue.task_done()
if sentinel_found:
logger.debug("Translation processor received sentinel in batch. Finishing.")
break
except Exception as e:
logger.warning(f"Exception in translation_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
if 'token' in locals() and item is not SENTINEL:
self.translation_queue.task_done()
if 'additional_tokens' in locals():
for _ in additional_tokens:
self.translation_queue.task_done()
logger.info("Translation processor task finished.")
async def results_formatter(self):
"""Format processing results for output."""
last_sent_trans = None
last_sent_diar = None
while True:
try:
if self._ffmpeg_error:
yield FrontData(status="error", error=f"FFmpeg error: {self._ffmpeg_error}")
ffmpeg_state = await self.ffmpeg_manager.get_state()
if ffmpeg_state == FFmpegState.FAILED and self._ffmpeg_error:
yield {
"status": "error",
"error": f"FFmpeg error: {self._ffmpeg_error}",
"lines": [],
"buffer_transcription": "",
"buffer_diarization": "",
"remaining_time_transcription": 0,
"remaining_time_diarization": 0
}
self._ffmpeg_error = None
await asyncio.sleep(1)
continue
state = await self.get_current_state()
lines, undiarized_text = format_output(
# Get current state
state = await self.get_current_state()
tokens = state["tokens"]
buffer_transcription = state["buffer_transcription"]
buffer_diarization = state["buffer_diarization"]
end_attributed_speaker = state["end_attributed_speaker"]
sep = state["sep"]
# Add dummy tokens if needed
if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
await self.add_dummy_token()
sleep(0.5)
state = await self.get_current_state()
tokens = state["tokens"]
# Format output
lines, undiarized_text, buffer_transcription, buffer_diarization = format_output(
state,
self.silence,
args = self.args,
sep=self.sep
current_time = time() - self.beg_loop if self.beg_loop else None,
diarization = self.args.diarization,
debug = self.debug
)
if lines and lines[-1].speaker == -2:
buffer_transcription = Transcript()
else:
buffer_transcription = state.buffer_transcription
buffer_diarization = ''
# Handle undiarized text
if undiarized_text:
buffer_diarization = self.sep.join(undiarized_text)
async with self.lock:
self.state.end_attributed_speaker = state.end_attributed_speaker
combined = sep.join(undiarized_text)
if buffer_transcription:
combined += sep
await self.update_diarization(end_attributed_speaker, combined)
buffer_diarization = combined
response_status = "active_transcription"
if not state.tokens and not buffer_transcription and not buffer_diarization:
final_lines_for_response = lines.copy()
if not tokens and not buffer_transcription and not buffer_diarization:
response_status = "no_audio_detected"
lines = []
elif not lines:
lines = [Line(
speaker=1,
start=state.end_buffer,
end=state.end_buffer
)]
final_lines_for_response = []
elif response_status == "active_transcription" and not final_lines_for_response:
final_lines_for_response = [{
"speaker": 1,
"text": "",
"beg": format_time(state.get("end_buffer", 0)),
"end": format_time(state.get("end_buffer", 0)),
"diff": 0
}]
response = FrontData(
status=response_status,
lines=lines,
buffer_transcription=buffer_transcription.text.strip(),
buffer_diarization=buffer_diarization,
remaining_time_transcription=state.remaining_time_transcription,
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
response = {
"status": response_status,
"lines": final_lines_for_response,
"buffer_transcription": buffer_transcription,
"buffer_diarization": buffer_diarization,
"remaining_time_transcription": state["remaining_time_transcription"],
"remaining_time_diarization": state["remaining_time_diarization"] if self.args.diarization else 0
}
current_response_signature = f"{response_status} | " + \
' '.join([f"{line['speaker']} {line['text']}" for line in final_lines_for_response]) + \
f" | {buffer_transcription} | {buffer_diarization}"
trans = state["remaining_time_transcription"]
diar = state["remaining_time_diarization"]
should_push = (
current_response_signature != self.last_response_content
or last_sent_trans is None
or round(trans, 1) != round(last_sent_trans, 1)
or round(diar, 1) != round(last_sent_diar, 1)
)
should_push = (response != self.last_response_content)
if should_push and (lines or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"):
if should_push and (final_lines_for_response or buffer_transcription or buffer_diarization or response_status == "no_audio_detected" or trans > 0 or diar > 0):
yield response
self.last_response_content = response
self.last_response_content = current_response_signature
last_sent_trans = trans
last_sent_diar = diar
if self.is_stopping and self.transcription_task and self.transcription_task.done() and self.diarization_task and self.diarization_task.done():
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
return
# Check for termination condition
if self.is_stopping:
all_processors_done = True
if self.args.transcription and self.transcription_task and not self.transcription_task.done():
all_processors_done = False
if self.args.diarization and self.diarization_task and not self.diarization_task.done():
all_processors_done = False
if all_processors_done:
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
final_state = await self.get_current_state()
return
await asyncio.sleep(0.05)
await asyncio.sleep(0.1) # Avoid overwhelming the client
except Exception as e:
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
await asyncio.sleep(0.5)
logger.warning(f"Exception in results_formatter: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
await asyncio.sleep(0.5) # Back off on error
async def create_tasks(self):
"""Create and start processing tasks."""
self.all_tasks_for_cleanup = []
processing_tasks_for_watchdog = []
# If using FFmpeg (non-PCM input), start it and spawn stdout reader
if not self.is_pcm_input:
success = await self.ffmpeg_manager.start()
if not success:
logger.error("Failed to start FFmpeg manager")
async def error_generator():
yield FrontData(
status="error",
error="FFmpeg failed to start. Please check that FFmpeg is installed."
)
return error_generator()
self.ffmpeg_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task)
processing_tasks_for_watchdog.append(self.ffmpeg_reader_task)
success = await self.ffmpeg_manager.start()
if not success:
logger.error("Failed to start FFmpeg manager")
async def error_generator():
yield {
"status": "error",
"error": "FFmpeg failed to start. Please check that FFmpeg is installed.",
"lines": [],
"buffer_transcription": "",
"buffer_diarization": "",
"remaining_time_transcription": 0,
"remaining_time_diarization": 0
}
return error_generator()
if self.transcription:
if self.args.transcription and self.online:
self.transcription_task = asyncio.create_task(self.transcription_processor())
self.all_tasks_for_cleanup.append(self.transcription_task)
processing_tasks_for_watchdog.append(self.transcription_task)
if self.diarization:
if self.args.diarization and self.diarization:
self.diarization_task = asyncio.create_task(self.diarization_processor(self.diarization))
self.all_tasks_for_cleanup.append(self.diarization_task)
processing_tasks_for_watchdog.append(self.diarization_task)
if self.translation:
self.translation_task = asyncio.create_task(self.translation_processor())
self.all_tasks_for_cleanup.append(self.translation_task)
processing_tasks_for_watchdog.append(self.translation_task)
self.ffmpeg_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task)
processing_tasks_for_watchdog.append(self.ffmpeg_reader_task)
# Monitor overall system health
self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog))
self.all_tasks_for_cleanup.append(self.watchdog_task)
@@ -529,6 +560,15 @@ class AudioProcessor:
logger.error(f"{task_name} unexpectedly completed with exception: {exc}")
else:
logger.info(f"{task_name} completed normally.")
# Check FFmpeg status through the manager
ffmpeg_state = await self.ffmpeg_manager.get_state()
if ffmpeg_state == FFmpegState.FAILED:
logger.error("FFmpeg is in FAILED state, notifying results formatter")
# FFmpeg manager will handle its own recovery
elif ffmpeg_state == FFmpegState.STOPPED and not self.is_stopping:
logger.warning("FFmpeg unexpectedly stopped, attempting restart")
await self.ffmpeg_manager.restart()
except asyncio.CancelledError:
logger.info("Watchdog task cancelled.")
@@ -538,24 +578,18 @@ class AudioProcessor:
async def cleanup(self):
"""Clean up resources when processing is complete."""
logger.info("Starting cleanup of AudioProcessor resources.")
self.is_stopping = True
logger.info("Starting cleanup of AudioProcessor resources.")
for task in self.all_tasks_for_cleanup:
if task and not task.done():
task.cancel()
created_tasks = [t for t in self.all_tasks_for_cleanup if t]
if created_tasks:
await asyncio.gather(*created_tasks, return_exceptions=True)
logger.info("All processing tasks cancelled or finished.")
if not self.is_pcm_input and self.ffmpeg_manager:
try:
await self.ffmpeg_manager.stop()
logger.info("FFmpeg manager stopped.")
except Exception as e:
logger.warning(f"Error stopping FFmpeg manager: {e}")
if self.diarization:
await self.ffmpeg_manager.stop()
logger.info("FFmpeg manager stopped.")
if self.args.diarization and hasattr(self, 'diarization') and hasattr(self.diarization, 'close'):
self.diarization.close()
logger.info("AudioProcessor cleanup complete.")
@@ -563,93 +597,24 @@ class AudioProcessor:
async def process_audio(self, message):
"""Process incoming audio data."""
if not self.state.beg_loop:
self.state.beg_loop = time()
if not self.beg_loop:
self.beg_loop = time()
if not message:
logger.info("Empty audio message received, initiating stop sequence.")
self.is_stopping = True
if self.transcription_queue:
await self.transcription_queue.put(SENTINEL)
if not self.is_pcm_input and self.ffmpeg_manager:
await self.ffmpeg_manager.stop()
# Signal FFmpeg manager to stop accepting data
await self.ffmpeg_manager.stop()
return
if self.is_stopping:
logger.warning("AudioProcessor is stopping. Ignoring incoming audio.")
return
if self.is_pcm_input:
self.pcm_buffer.extend(message)
await self.handle_pcm_data()
else:
if not self.ffmpeg_manager:
logger.error("FFmpeg manager not initialized for non-PCM input.")
return
success = await self.ffmpeg_manager.write_data(message)
if not success:
ffmpeg_state = await self.ffmpeg_manager.get_state()
if ffmpeg_state == FFmpegState.FAILED:
logger.error("FFmpeg is in FAILED state, cannot process audio")
else:
logger.warning("Failed to write audio data to FFmpeg")
async def handle_pcm_data(self):
# Process when enough data
if len(self.pcm_buffer) < self.bytes_per_sec:
return
if len(self.pcm_buffer) > self.max_bytes_per_sec:
logger.warning(
f"Audio buffer too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f}s. "
f"Consider using a smaller model."
)
chunk_size = min(len(self.pcm_buffer), self.max_bytes_per_sec)
aligned_chunk_size = (chunk_size // self.bytes_per_sample) * self.bytes_per_sample
if aligned_chunk_size == 0:
return
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size])
self.pcm_buffer = self.pcm_buffer[aligned_chunk_size:]
res = None
end_of_audio = False
silence_buffer = None
if self.args.vac:
res = self.vac(pcm_array)
if res is not None:
if res.get("end", 0) > res.get("start", 0):
end_of_audio = True
elif self.silence: #end of silence
self.silence = False
silence_buffer = Silence(duration=time() - self.start_silence)
if silence_buffer:
if not self.diarization_before_transcription and self.transcription_queue:
await self.transcription_queue.put(silence_buffer)
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(silence_buffer)
if self.translation_queue:
await self.translation_queue.put(silence_buffer)
if not self.silence:
if not self.diarization_before_transcription and self.transcription_queue:
await self.transcription_queue.put(pcm_array.copy())
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(pcm_array.copy())
self.silence_duration = 0.0
if end_of_audio:
self.silence = True
self.start_silence = time()
if not self.args.transcription and not self.args.diarization:
await asyncio.sleep(0.1)
success = await self.ffmpeg_manager.write_data(message)
if not success:
ffmpeg_state = await self.ffmpeg_manager.get_state()
if ffmpeg_state == FFmpegState.FAILED:
logger.error("FFmpeg is in FAILED state, cannot process audio")
else:
logger.warning("Failed to write audio data to FFmpeg")

View File

@@ -5,6 +5,9 @@ from fastapi.middleware.cors import CORSMiddleware
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args
import asyncio
import logging
from starlette.staticfiles import StaticFiles
import pathlib
import whisperlivekit.web as webpkg
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger().setLevel(logging.WARNING)
@@ -15,7 +18,7 @@ args = parse_args()
transcription_engine = None
@asynccontextmanager
async def lifespan(app: FastAPI):
async def lifespan(app: FastAPI):
global transcription_engine
transcription_engine = TranscriptionEngine(
**vars(args),
@@ -30,6 +33,8 @@ app.add_middleware(
allow_methods=["*"],
allow_headers=["*"],
)
web_dir = pathlib.Path(webpkg.__file__).parent
app.mount("/web", StaticFiles(directory=str(web_dir)), name="web")
@app.get("/")
async def get():
@@ -40,7 +45,7 @@ async def handle_websocket_results(websocket, results_generator):
"""Consumes results from the audio processor and sends them via WebSocket."""
try:
async for response in results_generator:
await websocket.send_json(response.to_dict())
await websocket.send_json(response)
# when the results_generator finishes it means all audio has been processed
logger.info("Results generator finished. Sending 'ready_to_stop' to client.")
await websocket.send_json({"type": "ready_to_stop"})
@@ -58,11 +63,6 @@ async def websocket_endpoint(websocket: WebSocket):
)
await websocket.accept()
logger.info("WebSocket connection opened.")
try:
await websocket.send_json({"type": "config", "useAudioWorklet": bool(args.pcm_input)})
except Exception as e:
logger.warning(f"Failed to send config to client: {e}")
results_generator = await audio_processor.create_tasks()
websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
@@ -118,8 +118,6 @@ def main():
if ssl_kwargs:
uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs}
if args.forwarded_allow_ips:
uvicorn_kwargs = { **uvicorn_kwargs, "forwarded_allow_ips" : args.forwarded_allow_ips }
uvicorn.run(**uvicorn_kwargs)

View File

@@ -4,15 +4,10 @@ try:
except ImportError:
from .whisper_streaming_custom.whisper_online import backend_factory
from .whisper_streaming_custom.online_asr import OnlineASRProcessor
from whisperlivekit.warmup import warmup_asr, warmup_online
from argparse import Namespace
import sys
def update_with_kwargs(_dict, kwargs):
_dict.update({
k: v for k, v in kwargs.items() if k in _dict
})
return _dict
class TranscriptionEngine:
_instance = None
_initialized = False
@@ -26,49 +21,65 @@ class TranscriptionEngine:
if TranscriptionEngine._initialized:
return
global_params = {
defaults = {
"host": "localhost",
"port": 8000,
"warmup_file": None,
"diarization": False,
"punctuation_split": False,
"target_language": "",
"vac": True,
"vac_onnx": False,
"vac_chunk_size": 0.04,
"log_level": "DEBUG",
"ssl_certfile": None,
"ssl_keyfile": None,
"forwarded_allow_ips": None,
"transcription": True,
"vad": True,
"pcm_input": False,
"disable_punctuation_split" : False,
"diarization_backend": "sortformer",
}
global_params = update_with_kwargs(global_params, kwargs)
transcription_common_params = {
"backend": "simulstreaming",
"warmup_file": None,
"min_chunk_size": 0.5,
"model_size": "tiny",
"model": "tiny",
"model_cache_dir": None,
"model_dir": None,
"lan": "auto",
"task": "transcribe",
"backend": "faster-whisper",
"vac": True,
"vac_chunk_size": 0.04,
"log_level": "DEBUG",
"ssl_certfile": None,
"ssl_keyfile": None,
"transcription": True,
"vad": True,
# whisperstreaming params:
"buffer_trimming": "segment",
"confidence_validation": False,
"buffer_trimming_sec": 15,
# simulstreaming params:
"frame_threshold": 25,
"beams": 1,
"decoder_type": None,
"audio_max_len": 20.0,
"audio_min_len": 0.0,
"cif_ckpt_path": None,
"never_fire": False,
"init_prompt": None,
"static_init_prompt": None,
"max_context_tokens": None,
"model_path": './base.pt',
"diarization_backend": "sortformer",
# diart params:
"segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding",
}
transcription_common_params = update_with_kwargs(transcription_common_params, kwargs)
if transcription_common_params['model_size'].endswith(".en"):
transcription_common_params["lan"] = "en"
config_dict = {**defaults, **kwargs}
if 'no_transcription' in kwargs:
global_params['transcription'] = not global_params['no_transcription']
config_dict['transcription'] = not kwargs['no_transcription']
if 'no_vad' in kwargs:
global_params['vad'] = not kwargs['no_vad']
config_dict['vad'] = not kwargs['no_vad']
if 'no_vac' in kwargs:
global_params['vac'] = not kwargs['no_vac']
config_dict['vac'] = not kwargs['no_vac']
config_dict.pop('no_transcription', None)
config_dict.pop('no_vad', None)
self.args = Namespace(**{**global_params, **transcription_common_params})
if 'language' in kwargs:
config_dict['lan'] = kwargs['language']
config_dict.pop('language', None)
self.args = Namespace(**config_dict)
self.asr = None
self.tokenizer = None
@@ -76,108 +87,82 @@ class TranscriptionEngine:
self.vac_model = None
if self.args.vac:
from whisperlivekit.silero_vad_iterator import load_silero_vad
# Use ONNX if specified, otherwise use JIT (default)
use_onnx = kwargs.get('vac_onnx', False)
self.vac_model = load_silero_vad(onnx=use_onnx)
import torch
self.vac_model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
if self.args.transcription:
if self.args.backend == "simulstreaming":
from whisperlivekit.simul_whisper import SimulStreamingASR
self.tokenizer = None
simulstreaming_kwargs = {}
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count']:
if hasattr(self.args, attr):
simulstreaming_kwargs[attr] = getattr(self.args, attr)
# Add segment_length from min_chunk_size
simulstreaming_kwargs['segment_length'] = getattr(self.args, 'min_chunk_size', 0.5)
simulstreaming_kwargs['task'] = self.args.task
simulstreaming_params = {
"disable_fast_encoder": False,
"custom_alignment_heads": None,
"frame_threshold": 25,
"beams": 1,
"decoder_type": None,
"audio_max_len": 20.0,
"audio_min_len": 0.0,
"cif_ckpt_path": None,
"never_fire": False,
"init_prompt": None,
"static_init_prompt": None,
"max_context_tokens": None,
"model_path": './base.pt',
"preload_model_count": 1,
}
simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
self.tokenizer = None
size = self.args.model
self.asr = SimulStreamingASR(
**transcription_common_params, **simulstreaming_params
modelsize=size,
lan=self.args.lan,
cache_dir=getattr(self.args, 'model_cache_dir', None),
model_dir=getattr(self.args, 'model_dir', None),
**simulstreaming_kwargs
)
else:
whisperstreaming_params = {
"buffer_trimming": "segment",
"confidence_validation": False,
"buffer_trimming_sec": 15,
}
whisperstreaming_params = update_with_kwargs(whisperstreaming_params, kwargs)
self.asr = backend_factory(
**transcription_common_params, **whisperstreaming_params
)
self.asr, self.tokenizer = backend_factory(self.args)
warmup_asr(self.asr, self.args.warmup_file) #for simulstreaming, warmup should be done in the online class not here
if self.args.diarization:
if self.args.diarization_backend == "diart":
from whisperlivekit.diarization.diart_backend import DiartDiarization
diart_params = {
"segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding",
}
diart_params = update_with_kwargs(diart_params, kwargs)
self.diarization_model = DiartDiarization(
block_duration=self.args.min_chunk_size,
**diart_params
segmentation_model_name=self.args.segmentation_model,
embedding_model_name=self.args.embedding_model
)
elif self.args.diarization_backend == "sortformer":
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
self.diarization_model = SortformerDiarization()
self.translation_model = None
if self.args.target_language:
if self.args.lan == 'auto' and self.args.backend != "simulstreaming":
raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming')
else:
try:
from nllw import load_model
except:
raise Exception('To use translation, you must install nllw: `pip install nllw`')
translation_params = {
"nllb_backend": "transformers",
"nllb_size": "600M"
}
translation_params = update_with_kwargs(translation_params, kwargs)
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}")
TranscriptionEngine._initialized = True
def online_factory(args, asr):
def online_factory(args, asr, tokenizer, logfile=sys.stderr):
if args.backend == "simulstreaming":
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
online = SimulStreamingOnlineProcessor(asr)
online = SimulStreamingOnlineProcessor(
asr,
logfile=logfile,
)
# warmup_online(online, args.warmup_file)
else:
online = OnlineASRProcessor(asr)
online = OnlineASRProcessor(
asr,
tokenizer,
logfile=logfile,
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
confidence_validation = args.confidence_validation
)
return online
def online_diarization_factory(args, diarization_backend):
if args.diarization_backend == "diart":
online = diarization_backend
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommanded
if args.diarization_backend == "sortformer":
from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline
online = SortformerDiarizationOnline(shared_model=diarization_backend)
return online
def online_translation_factory(args, translation_model):
#should be at speaker level in the future:
#one shared nllb model for all speaker
#one tokenizer per speaker/language
from nllw import OnlineTranslation
from nllw import OnlineTranslation
return OnlineTranslation(translation_model, [args.lan], [args.target_language])

View File

@@ -60,15 +60,11 @@ class SortformerDiarization:
self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
self.diar_model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.diar_model.to(device)
## to test
# for name, param in self.diar_model.named_parameters():
# if param.device != device:
# raise RuntimeError(f"Parameter {name} is on {param.device} but should be on {device}")
logger.info(f"Using {device.type.upper()} for Sortformer model")
if torch.cuda.is_available():
self.diar_model.to(torch.device("cuda"))
logger.info("Using CUDA for Sortformer model")
else:
logger.info("Using CPU for Sortformer model")
self.diar_model.sortformer_modules.chunk_len = 10
self.diar_model.sortformer_modules.subsampling_factor = 10
@@ -110,7 +106,6 @@ class SortformerDiarizationOnline:
features=128,
pad_to=0
)
self.audio2mel.to(self.diar_model.device)
self.chunk_duration_seconds = (
self.diar_model.sortformer_modules.chunk_len *
@@ -191,25 +186,22 @@ class SortformerDiarizationOnline:
audio = self.buffer_audio[:threshold]
self.buffer_audio = self.buffer_audio[threshold:]
device = self.diar_model.device
audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0)
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device)
audio_signal_chunk = torch.tensor(audio).unsqueeze(0).to(self.diar_model.device)
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(self.diar_model.device)
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
audio_signal_chunk, audio_signal_length_chunk
)
processed_signal_chunk = processed_signal_chunk.to(device)
processed_signal_length_chunk = processed_signal_length_chunk.to(device)
if self._previous_chunk_features is not None:
to_add = self._previous_chunk_features[:, :, -99:].to(device)
total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device)
to_add = self._previous_chunk_features[:, :, -99:]
total_features = torch.concat([to_add, processed_signal_chunk], dim=2)
else:
total_features = processed_signal_chunk.to(device)
total_features = processed_signal_chunk
self._previous_chunk_features = processed_signal_chunk.to(device)
self._previous_chunk_features = processed_signal_chunk
chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device)
chunk_feat_seq_t = torch.transpose(total_features, 1, 2)
with torch.inference_mode():
left_offset = 8 if self._chunk_index > 0 else 0
@@ -217,7 +209,7 @@ class SortformerDiarizationOnline:
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device),
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
streaming_state=self.streaming_state,
total_preds=self.total_preds,
left_offset=left_offset,
@@ -289,7 +281,6 @@ class SortformerDiarizationOnline:
Returns:
List of tokens with speaker assignments
Last speaker_segment
"""
with self.segment_lock:
segments = self.speaker_segments.copy()

View File

@@ -7,12 +7,11 @@ import contextlib
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
ERROR_INSTALL_INSTRUCTIONS = f"""
{'='*50}
ERROR_INSTALL_INSTRUCTIONS = """
FFmpeg is not installed or not found in your system's PATH.
Alternative Solution: You can still use WhisperLiveKit without FFmpeg by adding the --pcm-input parameter. Note that when using this option, audio will not be compressed between the frontend and backend, which may result in higher bandwidth usage.
Please install FFmpeg to enable audio processing.
If you want to install FFmpeg:
Installation instructions:
# Ubuntu/Debian:
sudo apt update && sudo apt install ffmpeg
@@ -26,7 +25,6 @@ brew install ffmpeg
# 3. Add the 'bin' directory (e.g., C:\\FFmpeg\\bin) to your system's PATH environment variable.
After installation, please restart the application.
{'='*50}
"""
class FFmpegState(Enum):
@@ -185,8 +183,6 @@ class FFmpegManager:
async def _drain_stderr(self):
try:
while True:
if not self.process or not self.process.stderr:
break
line = await self.process.stderr.readline()
if not line:
break
@@ -194,4 +190,4 @@ class FFmpegManager:
except asyncio.CancelledError:
logger.info("FFmpeg stderr drain task cancelled.")
except Exception as e:
logger.error(f"Error draining FFmpeg stderr: {e}")
logger.error(f"Error draining FFmpeg stderr: {e}")

View File

@@ -20,7 +20,7 @@ def parse_args():
help="""
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
If empty, no warmup is performed.
If False, no warmup is performed.
""",
)
@@ -72,12 +72,6 @@ def parse_args():
help="Disable transcription to only see live diarization results.",
)
parser.add_argument(
"--disable-punctuation-split",
action="store_true",
help="Disable the split parameter.",
)
parser.add_argument(
"--min-chunk-size",
type=float,
@@ -89,7 +83,6 @@ def parse_args():
"--model",
type=str,
default="small",
dest='model_size',
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
)
@@ -110,7 +103,6 @@ def parse_args():
"--language",
type=str,
default="auto",
dest='lan',
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
)
parser.add_argument(
@@ -120,15 +112,6 @@ def parse_args():
choices=["transcribe", "translate"],
help="Transcribe or translate.",
)
parser.add_argument(
"--target-language",
type=str,
default="",
dest="target_language",
help="Target language for translation. Not functional yet.",
)
parser.add_argument(
"--backend",
type=str,
@@ -175,30 +158,9 @@ def parse_args():
)
parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
parser.add_argument("--forwarded-allow-ips", type=str, help="Allowed ips for reverse proxying.", default=None)
parser.add_argument(
"--pcm-input",
action="store_true",
default=False,
help="If set, raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder."
)
# SimulStreaming-specific arguments
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')
simulstreaming_group.add_argument(
"--disable-fast-encoder",
action="store_true",
default=False,
dest="disable_fast_encoder",
help="Disable Faster Whisper or MLX Whisper backends for encoding (if installed). Slower but helpful when GPU memory is limited",
)
simulstreaming_group.add_argument(
"--custom-alignment-heads",
type=str,
default=None,
help="Use your own alignment heads, useful when `--model-dir` is used",
)
simulstreaming_group.add_argument(
"--frame-threshold",
@@ -290,27 +252,13 @@ def parse_args():
)
simulstreaming_group.add_argument(
"--preload-model-count",
"--preloaded_model_count",
type=int,
default=1,
dest="preload_model_count",
dest="preloaded_model_count",
help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).",
)
simulstreaming_group.add_argument(
"--nllb-backend",
type=str,
default="transformers",
help="transformers or ctranslate2",
)
simulstreaming_group.add_argument(
"--nllb-size",
type=str,
default="600M",
help="600M or 1.3B",
)
args = parser.parse_args()
args.transcription = not args.no_transcription

View File

@@ -1,5 +1,4 @@
from whisperlivekit.timed_objects import ASRToken
from time import time
import re
MIN_SILENCE_DURATION = 4 #in seconds
@@ -40,7 +39,7 @@ def blank_to_silence(tokens):
)
else:
if silence_token: #there was silence but no more
if silence_token.duration() >= MIN_SILENCE_DURATION:
if silence_token.end - silence_token.start >= MIN_SILENCE_DURATION:
cleaned_tokens.append(
silence_token
)
@@ -78,11 +77,15 @@ def no_token_to_silence(tokens):
new_tokens.append(token)
return new_tokens
def ends_with_silence(tokens, beg_loop, vac_detected_silence):
current_time = time() - (beg_loop if beg_loop else 0.0)
def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
if not tokens:
return [], buffer_transcription, buffer_diarization
last_token = tokens[-1]
silence_duration = current_time - last_token.end
if (vac_detected_silence and silence_duration > END_SILENCE_DURATION_VAC) or (silence_duration >= END_SILENCE_DURATION):
if tokens and current_time and (
current_time - last_token.end >= END_SILENCE_DURATION
or
(current_time - last_token.end >= 3 and vac_detected_silence)
):
if last_token.speaker == -2:
last_token.end = current_time
else:
@@ -94,14 +97,14 @@ def ends_with_silence(tokens, beg_loop, vac_detected_silence):
probability=0.95
)
)
return tokens
buffer_transcription = "" # for whisperstreaming backend, we should probably validate the buffer has because of the silence
buffer_diarization = ""
return tokens, buffer_transcription, buffer_diarization
def handle_silences(tokens, beg_loop, vac_detected_silence):
if not tokens:
return []
def handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
tokens = no_token_to_silence(tokens)
tokens = ends_with_silence(tokens, beg_loop, vac_detected_silence)
return tokens
tokens, buffer_transcription, buffer_diarization = ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence)
return tokens, buffer_transcription, buffer_diarization

View File

@@ -1,17 +1,21 @@
import logging
from datetime import timedelta
from whisperlivekit.remove_silences import handle_silences
from whisperlivekit.timed_objects import Line, Segment, format_time
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
PUNCTUATION_MARKS = {'.', '!', '?'}
CHECK_AROUND = 4
DEBUG = False
def format_time(seconds: float) -> str:
"""Format seconds as HH:MM:SS."""
return str(timedelta(seconds=int(seconds)))
def is_punctuation(token):
if token.is_punctuation():
if token.text.strip() in PUNCTUATION_MARKS:
return True
return False
@@ -30,138 +34,105 @@ def next_speaker_change(i, tokens, speaker):
return ind, token.speaker
return None, speaker
def new_line(
token,
speaker,
last_end_diarized,
debug_info = ""
):
return Line(
speaker = token.corrected_speaker,
text = token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else ""),
start = token.start,
end = token.end,
detected_language=token.detected_language
)
return {
"speaker": int(speaker),
"text": token.text + debug_info,
"beg": format_time(token.start),
"end": format_time(token.end),
"diff": round(token.end - last_end_diarized, 2)
}
def append_token_to_last_line(lines, sep, token):
if not lines:
lines.append(new_line(token))
else:
if token.text:
lines[-1].text += sep + token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else "")
lines[-1].end = token.end
if not lines[-1].detected_language and token.detected_language:
lines[-1].detected_language = token.detected_language
def append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized):
if token.text:
lines[-1]["text"] += sep + token.text + debug_info
lines[-1]["end"] = format_time(token.end)
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
def format_output(state, silence, args, sep):
diarization = args.diarization
disable_punctuation_split = args.disable_punctuation_split
tokens = state.tokens
translation_validated_segments = state.translation_validated_segments # Here we will attribute the speakers only based on the timestamps of the segments
translation_buffer = state.translation_buffer
last_validated_token = state.last_validated_token
previous_speaker = 1
undiarized_text = []
tokens = handle_silences(tokens, state.beg_loop, silence)
last_punctuation = None
for i, token in enumerate(tokens[last_validated_token:]):
speaker = int(token.speaker)
token.corrected_speaker = speaker
if not diarization:
if speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
token.corrected_speaker = 1
token.validated_speaker = True
else:
if is_punctuation(token):
last_punctuation = i
if last_punctuation == i-1:
if token.speaker != previous_speaker:
token.validated_speaker = True
# perfect, diarization perfectly aligned
last_punctuation = None
else:
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
if speaker_change_pos:
# Corrects delay:
# That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
# should become:
# That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
token.corrected_speaker = new_speaker
token.validated_speaker = True
elif speaker != previous_speaker:
if not (speaker == -2 or previous_speaker == -2):
if next_punctuation_change(i, tokens):
# Corrects advance:
# Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
# should become:
# Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
token.corrected_speaker = previous_speaker
token.validated_speaker = True
else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True.
if not disable_punctuation_split:
token.corrected_speaker = previous_speaker
token.validated_speaker = False
if token.validated_speaker:
state.last_validated_token = i + last_validated_token
previous_speaker = token.corrected_speaker
for token in tokens[last_validated_token+1:state.last_validated_token+1]:
if not state.segments or int(token.corrected_speaker) != int(state.segments[-1].speaker):
state.segments.append(
Segment(
speaker=token.corrected_speaker,
words=[token]
)
)
else:
state.segments[-1].words.append(token)
for token in tokens[state.last_validated_token+1:]:
# if not state.segments or int(token.corrected_speaker) != int(state.segments[-1].speaker):
# state.segments.append(
# Segment(
# speaker=token.corrected_speaker,
# buffer_tokens=[token]
# )
# )
# else:
state.segments[-1].buffer_tokens.append(token)
for segment in state.segments:
segment.consolidate(sep)
# lines = []
# for token in tokens:
# if int(token.corrected_speaker) != int(previous_speaker):
# lines.append(new_line(token))
# else:
# append_token_to_last_line(lines, sep, token)
# previous_speaker = token.corrected_speaker
for ts in translation_validated_segments:
for segment in state.segments[state.last_validated_segment:]:
if ts.is_within(segment):
segment.translation += ts.text + sep
break
for ts in translation_buffer:
for segment in state.segments[state.last_validated_segment:]:
if ts.is_within(segment):
segment.buffer.translation += ts.text + sep
break
# if state.buffer_transcription and lines:
# lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
def format_output(state, silence, current_time, diarization, debug):
tokens = state["tokens"]
buffer_transcription = state["buffer_transcription"]
buffer_diarization = state["buffer_diarization"]
end_attributed_speaker = state["end_attributed_speaker"]
sep = state["sep"]
previous_speaker = -1
lines = []
for segment in state.segments:
lines.append(Line(
start=segment.start,
end=segment.end,
speaker=segment.speaker,
text=segment.text,
translation=segment.translation
))
return lines, undiarized_text
last_end_diarized = 0
undiarized_text = []
tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, silence)
last_punctuation = None
for i, token in enumerate(tokens):
speaker = token.speaker
if not diarization and speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
speaker = 1
if diarization and not tokens[-1].speaker == -2:
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
undiarized_text.append(token.text)
continue
elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
speaker = previous_speaker
if speaker not in [-1, 0]:
last_end_diarized = max(token.end, last_end_diarized)
debug_info = ""
if debug:
debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
if not lines:
lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
continue
else:
previous_speaker = lines[-1]['speaker']
if is_punctuation(token):
last_punctuation = i
if last_punctuation == i-1:
if speaker != previous_speaker:
# perfect, diarization perfectly aligned
lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
last_punctuation, next_punctuation = None, None
continue
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
if speaker_change_pos:
# Corrects delay:
# That was the idea. Okay haha |SPLIT SPEAKER| that's a good one
# should become:
# That was the idea. |SPLIT SPEAKER| Okay haha that's a good one
lines.append(new_line(token, new_speaker, last_end_diarized, debug_info = ""))
else:
# No speaker change to come
append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized)
continue
if speaker != previous_speaker:
if speaker == -2 or previous_speaker == -2: #silences can happen anytime
lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
continue
elif next_punctuation_change(i, tokens):
# Corrects advance:
# Are you |SPLIT SPEAKER| okay? yeah, sure. Absolutely
# should become:
# Are you okay? |SPLIT SPEAKER| yeah, sure. Absolutely
append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized)
continue
else: #we create a new speaker, but that's no ideal. We are not sure about the split. We prefer to append to previous line
# lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
pass
append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized)
return lines, undiarized_text, buffer_transcription, ''

View File

@@ -1,182 +1,27 @@
import torch
import numpy as np
import warnings
from pathlib import Path
"""
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
"""
# This is copied from silero-vad's vad_utils.py:
# https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/utils_vad.py#L340
# (except changed defaults)
def init_jit_model(model_path: str, device=torch.device('cpu')):
"""Load a JIT model from file."""
model = torch.jit.load(model_path, map_location=device)
model.eval()
return model
class OnnxWrapper():
"""ONNX Runtime wrapper for Silero VAD model."""
def __init__(self, path, force_onnx_cpu=False):
global np
import numpy as np
import onnxruntime
opts = onnxruntime.SessionOptions()
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 1
if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
else:
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
self.reset_states()
if '16k' in path:
warnings.warn('This model support only 16000 sampling rate!')
self.sample_rates = [16000]
else:
self.sample_rates = [8000, 16000]
def _validate_input(self, x, sr: int):
if x.dim() == 1:
x = x.unsqueeze(0)
if x.dim() > 2:
raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
if sr != 16000 and (sr % 16000 == 0):
step = sr // 16000
x = x[:,::step]
sr = 16000
if sr not in self.sample_rates:
raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
if sr / x.shape[1] > 31.25:
raise ValueError("Input audio chunk is too short")
return x, sr
def reset_states(self, batch_size=1):
self._state = torch.zeros((2, batch_size, 128)).float()
self._context = torch.zeros(0)
self._last_sr = 0
self._last_batch_size = 0
def __call__(self, x, sr: int):
x, sr = self._validate_input(x, sr)
num_samples = 512 if sr == 16000 else 256
if x.shape[-1] != num_samples:
raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
batch_size = x.shape[0]
context_size = 64 if sr == 16000 else 32
if not self._last_batch_size:
self.reset_states(batch_size)
if (self._last_sr) and (self._last_sr != sr):
self.reset_states(batch_size)
if (self._last_batch_size) and (self._last_batch_size != batch_size):
self.reset_states(batch_size)
if not len(self._context):
self._context = torch.zeros(batch_size, context_size)
x = torch.cat([self._context, x], dim=1)
if sr in [8000, 16000]:
ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
ort_outs = self.session.run(None, ort_inputs)
out, state = ort_outs
self._state = torch.from_numpy(state)
else:
raise ValueError()
self._context = x[..., -context_size:]
self._last_sr = sr
self._last_batch_size = batch_size
out = torch.from_numpy(out)
return out
def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: int = 16):
"""
Load Silero VAD model (JIT or ONNX).
Parameters
----------
model_path : str, optional
Path to model file. If None, uses default bundled model.
onnx : bool, default False
Whether to use ONNX runtime (requires onnxruntime package).
opset_version : int, default 16
ONNX opset version (15 or 16). Only used if onnx=True.
Returns
-------
model
Loaded VAD model (JIT or ONNX wrapper)
"""
available_ops = [15, 16]
if onnx and opset_version not in available_ops:
raise Exception(f'Available ONNX opset_version: {available_ops}')
if model_path is None:
current_dir = Path(__file__).parent
data_dir = current_dir / 'vad_models'
if onnx:
if opset_version == 16:
model_name = 'silero_vad.onnx'
else:
model_name = f'silero_vad_16k_op{opset_version}.onnx'
else:
model_name = 'silero_vad.jit'
model_path = data_dir / model_name
if not model_path.exists():
raise FileNotFoundError(
f"Model file not found: {model_path}\n"
f"Please ensure the whisperlivekit/vad_models/ directory contains the model files."
)
else:
model_path = Path(model_path)
if onnx:
try:
model = OnnxWrapper(str(model_path), force_onnx_cpu=True)
except ImportError:
raise ImportError(
"ONNX runtime not available. Install with: pip install onnxruntime\n"
"Or use JIT model by setting onnx=False"
)
else:
model = init_jit_model(str(model_path))
return model
# Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
class VADIterator:
"""
Voice Activity Detection iterator for streaming audio.
This is the Silero VAD v6 implementation.
"""
def __init__(self,
model,
threshold: float = 0.5,
sampling_rate: int = 16000,
min_silence_duration_ms: int = 100,
speech_pad_ms: int = 30
):
def __init__(
self,
model,
threshold: float = 0.5,
sampling_rate: int = 16000,
min_silence_duration_ms: int = 500, # makes sense on one recording that I checked
speech_pad_ms: int = 100, # same
):
"""
Class for stream imitation
Parameters
----------
model: preloaded .jit/.onnx silero VAD model
model: preloaded .jit silero VAD model
threshold: float (default - 0.5)
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
@@ -197,7 +42,9 @@ class VADIterator:
self.sampling_rate = sampling_rate
if sampling_rate not in [8000, 16000]:
raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
raise ValueError(
"VADIterator does not support sampling rates other than [8000, 16000]"
)
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
@@ -210,17 +57,13 @@ class VADIterator:
self.temp_end = 0
self.current_sample = 0
@torch.no_grad()
def __call__(self, x, return_seconds=False, time_resolution: int = 1):
def __call__(self, x, return_seconds=False):
"""
x: torch.Tensor
audio chunk (see examples in repo)
return_seconds: bool (default - False)
whether return timestamps in seconds (default - samples)
time_resolution: int (default - 1)
time resolution of speech coordinates when requested as seconds
"""
if not torch.is_tensor(x):
@@ -239,8 +82,14 @@ class VADIterator:
if (speech_prob >= self.threshold) and not self.triggered:
self.triggered = True
speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)}
speech_start = self.current_sample - self.speech_pad_samples
return {
"start": (
int(speech_start)
if not return_seconds
else round(speech_start / self.sampling_rate, 1)
)
}
if (speech_prob < self.threshold - 0.15) and self.triggered:
if not self.temp_end:
@@ -248,17 +97,30 @@ class VADIterator:
if self.current_sample - self.temp_end < self.min_silence_samples:
return None
else:
speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
speech_end = self.temp_end + self.speech_pad_samples
self.temp_end = 0
self.triggered = False
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)}
return {
"end": (
int(speech_end)
if not return_seconds
else round(speech_end / self.sampling_rate, 1)
)
}
return None
#######################
# because Silero now requires exactly 512-sized audio chunks
import numpy as np
class FixedVADIterator(VADIterator):
"""
Fixed VAD Iterator that handles variable-length audio chunks, not only exactly 512 frames at once.
"""It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once.
If audio to be processed at once is long and multiple voiced segments detected,
then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment.
"""
def reset_states(self):
@@ -275,20 +137,27 @@ class FixedVADIterator(VADIterator):
ret = r
elif r is not None:
if "end" in r:
ret["end"] = r["end"]
if "start" in r and "end" in ret:
ret["end"] = r["end"] # the latter end
if "start" in r and "end" in ret: # there is an earlier start.
# Remove end, merging this segment with the previous one.
del ret["end"]
return ret if ret != {} else None
if __name__ == "__main__":
model = load_silero_vad(onnx=False)
vad = FixedVADIterator(model)
audio_buffer = np.array([0] * 512, dtype=np.float32)
result = vad(audio_buffer)
print(f" 512 samples: {result}")
# test with 511 samples
audio_buffer = np.array([0] * 511, dtype=np.float32)
result = vad(audio_buffer)
# test/demonstrate the need for FixedVADIterator:
import torch
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
vac = FixedVADIterator(model)
# vac = VADIterator(model) # the second case crashes with this
# this works: for both
audio_buffer = np.array([0] * (512), dtype=np.float32)
vac(audio_buffer)
# this crashes on the non FixedVADIterator with
# ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
audio_buffer = np.array([0] * (512 - 1), dtype=np.float32)
vac(audio_buffer)

View File

@@ -3,55 +3,27 @@ import numpy as np
import logging
from typing import List, Tuple, Optional
import logging
import platform
from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
from whisperlivekit.timed_objects import ASRToken, Transcript
from whisperlivekit.warmup import load_file
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
from .whisper import load_model, tokenizer
from .whisper.audio import TOKENS_PER_SECOND
import os
import gc
from pathlib import Path
logger = logging.getLogger(__name__)
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
from whisperlivekit.simul_whisper.whisper import tokenizer
try:
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
HAS_MLX_WHISPER = True
except ImportError:
if platform.system() == "Darwin" and platform.machine() == "arm64":
print(f"""{"="*50}\nMLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper\n{"="*50}""")
HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False
else:
try:
from faster_whisper import WhisperModel
HAS_FASTER_WHISPER = True
except ImportError:
HAS_FASTER_WHISPER = False
def model_path_and_type(model_path):
path = Path(model_path)
compatible_whisper_mlx = False
compatible_faster_whisper = False
pt_path = path if path.is_file() and path.suffix.lower() == '.pt' else None
if path.is_dir():
for file in path.iterdir():
if file.is_file():
if file.name in ['weights.npz', "weights.safetensors"]:
compatible_whisper_mlx = True
elif file.suffix.lower() == '.bin':
compatible_faster_whisper = True
elif file.suffix.lower() == '.pt':
pt_path = file
return pt_path, compatible_whisper_mlx, compatible_faster_whisper
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
from whisperlivekit.simul_whisper.whisper import tokenizer
except ImportError as e:
raise ImportError(
"""SimulStreaming dependencies are not available.
Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]".""")
# TOO_MANY_REPETITIONS = 3
class SimulStreamingOnlineProcessor:
SAMPLING_RATE = 16000
@@ -60,11 +32,13 @@ class SimulStreamingOnlineProcessor:
self,
asr,
logfile=sys.stderr,
warmup_file=None
):
self.asr = asr
self.logfile = logfile
self.end = 0.0
self.buffer = []
self.global_time_offset = 0.0
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
self.load_new_backend()
@@ -77,10 +51,7 @@ class SimulStreamingOnlineProcessor:
model = self.asr.get_new_model_instance()
self.model = PaddedAlignAttWhisper(
cfg=self.asr.cfg,
loaded_model=model,
mlx_encoder=self.asr.mlx_encoder,
fw_encoder=self.asr.fw_encoder,
)
loaded_model=model)
def insert_silence(self, silence_duration, offset):
"""
@@ -93,7 +64,7 @@ class SimulStreamingOnlineProcessor:
else:
self.process_iter(is_last=True) #we want to totally process what remains in the buffer.
self.model.refresh_segment(complete=True)
self.model.global_time_offset = silence_duration + offset
self.global_time_offset += silence_duration + offset
@@ -105,15 +76,63 @@ class SimulStreamingOnlineProcessor:
self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend.
self.model.insert_audio(audio_tensor)
def new_speaker(self, change_speaker: ChangeSpeaker):
self.process_iter(is_last=True)
self.model.refresh_segment(complete=True)
self.model.speaker = change_speaker.speaker
self.global_time_offset = change_speaker.start
def get_buffer(self):
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
return concat_buffer
return Transcript(
start=None,
end=None,
text='',
probability=None
)
def timestamped_text(self, tokens, generation):
"""
generate timestamped text from tokens and generation data.
args:
tokens: List of tokens to process
generation: Dictionary containing generation progress and optionally results
returns:
List of tuples containing (start_time, end_time, word) for each word
"""
FRAME_DURATION = 0.02
if "result" in generation:
split_words = generation["result"]["split_words"]
split_tokens = generation["result"]["split_tokens"]
else:
split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
progress = generation["progress"]
frames = [p["most_attended_frames"][0] for p in progress]
absolute_timestamps = [p["absolute_timestamps"][0] for p in progress]
tokens_queue = tokens.copy()
timestamped_words = []
for word, word_tokens in zip(split_words, split_tokens):
# start_frame = None
# end_frame = None
for expected_token in word_tokens:
if not tokens_queue or not frames:
raise ValueError(f"Insufficient tokens or frames for word '{word}'")
actual_token = tokens_queue.pop(0)
current_frame = frames.pop(0)
current_timestamp = absolute_timestamps.pop(0)
if actual_token != expected_token:
raise ValueError(
f"Token mismatch: expected '{expected_token}', "
f"got '{actual_token}' at frame {current_frame}"
)
# if start_frame is None:
# start_frame = current_frame
# end_frame = current_frame
# start_time = start_frame * FRAME_DURATION
# end_time = end_frame * FRAME_DURATION
start_time = current_timestamp
end_time = current_timestamp + 0.1
timestamp_entry = (start_time, end_time, word)
timestamped_words.append(timestamp_entry)
logger.debug(f"TS-WORD:\t{start_time:.2f}\t{end_time:.2f}\t{word}")
return timestamped_words
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
"""
@@ -122,14 +141,47 @@ class SimulStreamingOnlineProcessor:
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
"""
try:
timestamped_words = self.model.infer(is_last=is_last)
if self.model.cfg.language == "auto" and timestamped_words and timestamped_words[0].detected_language == None:
self.buffer.extend(timestamped_words)
return [], self.end
tokens, generation_progress = self.model.infer(is_last=is_last)
ts_words = self.timestamped_text(tokens, generation_progress)
self.committed.extend(timestamped_words)
self.buffer = []
return timestamped_words, self.end
new_tokens = []
for ts_word in ts_words:
start, end, word = ts_word
token = ASRToken(
start=start,
end=end,
text=word,
probability=0.95 # fake prob. Maybe we can extract it from the model?
).with_offset(
self.global_time_offset
)
new_tokens.append(token)
# identical_tokens = 0
# n_new_tokens = len(new_tokens)
# if n_new_tokens:
self.committed.extend(new_tokens)
# if token in self.committed:
# pos = len(self.committed) - 1 - self.committed[::-1].index(token)
# if pos:
# for i in range(len(self.committed) - n_new_tokens, -1, -n_new_tokens):
# commited_segment = self.committed[i:i+n_new_tokens]
# if commited_segment == new_tokens:
# identical_segments +=1
# if identical_tokens >= TOO_MANY_REPETITIONS:
# logger.warning('Too many repetition, model is stuck. Load a new one')
# self.committed = self.committed[:i]
# self.load_new_backend()
# return [], self.end
# pos = self.committed.rindex(token)
return new_tokens, self.end
except Exception as e:
@@ -158,23 +210,31 @@ class SimulStreamingASR():
"""SimulStreaming backend with AlignAtt policy."""
sep = ""
def __init__(self, logfile=sys.stderr, **kwargs):
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
logger.warning(SIMULSTREAMING_LICENSE)
self.logfile = logfile
self.transcribe_kargs = {}
self.original_language = lan
for key, value in kwargs.items():
setattr(self, key, value)
if self.decoder_type is None:
self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
self.fast_encoder = False
self.model_path = kwargs.get('model_path', './large-v3.pt')
self.frame_threshold = kwargs.get('frame_threshold', 25)
self.audio_max_len = kwargs.get('audio_max_len', 20.0)
self.audio_min_len = kwargs.get('audio_min_len', 0.0)
self.segment_length = kwargs.get('segment_length', 0.5)
self.beams = kwargs.get('beams', 1)
self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
self.task = kwargs.get('task', 'transcribe')
self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
self.never_fire = kwargs.get('never_fire', False)
self.init_prompt = kwargs.get('init_prompt', None)
self.static_init_prompt = kwargs.get('static_init_prompt', None)
self.max_context_tokens = kwargs.get('max_context_tokens', None)
self.warmup_file = kwargs.get('warmup_file', None)
self.preload_model_count = kwargs.get('preload_model_count', 1)
pt_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
if self.model_path:
pt_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(self.model_path)
elif self.model_size is not None:
if model_dir is not None:
self.model_path = model_dir
elif modelsize is not None:
model_mapping = {
'tiny': './tiny.pt',
'base': './base.pt',
@@ -189,15 +249,13 @@ class SimulStreamingASR():
'large-v3': './large-v3.pt',
'large': './large-v3.pt'
}
pt_path = Path(model_mapping.get(self.model_size, f'./{self.model_size}.pt'))
self.model_name = pt_path.name.replace(".pt", "")
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
self.cfg = AlignAttConfig(
tokenizer_is_multilingual= not self.model_name.endswith(".en"),
segment_length=self.min_chunk_size,
model_path=self.model_path,
segment_length=self.segment_length,
frame_threshold=self.frame_threshold,
language=self.lan,
language=self.original_language,
audio_max_len=self.audio_max_len,
audio_min_len=self.audio_min_len,
cif_ckpt_path=self.cif_ckpt_path,
@@ -216,58 +274,17 @@ class SimulStreamingASR():
else:
self.tokenizer = None
self.mlx_encoder, self.fw_encoder = None, None
if not self.disable_fast_encoder:
if HAS_MLX_WHISPER:
print('Simulstreaming will use MLX whisper to increase encoding speed.')
if self.model_path and compatible_whisper_mlx:
mlx_model = self.model_path
else:
mlx_model = mlx_model_mapping[self.model_name]
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
self.fast_encoder = True
elif HAS_FASTER_WHISPER and compatible_faster_whisper:
print('Simulstreaming will use Faster Whisper for the encoder.')
if self.model_path and compatible_faster_whisper:
fw_model = self.model_path
else:
fw_model = self.model_name
self.fw_encoder = WhisperModel(
fw_model,
device='auto',
compute_type='auto',
)
self.fast_encoder = True
self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
self.models = [self.load_model() for i in range(self.preload_model_count)]
def load_model(self):
whisper_model = load_model(
name=self.model_path if self.model_path else self.model_name,
download_root=self.model_path,
decoder_only=self.fast_encoder,
custom_alignment_heads=self.custom_alignment_heads
)
whisper_model = load_model(name=self.model_name, download_root=self.model_path)
warmup_audio = load_file(self.warmup_file)
if warmup_audio is not None:
warmup_audio = torch.from_numpy(warmup_audio).float()
if self.fast_encoder:
temp_model = PaddedAlignAttWhisper(
cfg=self.cfg,
loaded_model=whisper_model,
mlx_encoder=self.mlx_encoder,
fw_encoder=self.fw_encoder,
)
temp_model.warmup(warmup_audio)
temp_model.remove_hooks()
else:
# For standard encoder, use the original transcribe warmup
warmup_audio = load_file(self.warmup_file)
whisper_model.transcribe(warmup_audio, language=self.lan if self.lan != 'auto' else None)
whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None)
return whisper_model
def get_new_model_instance(self):
@@ -300,4 +317,4 @@ class SimulStreamingASR():
"""
Warmup is done directly in load_model
"""
pass
pass

View File

@@ -4,22 +4,26 @@ from dataclasses import dataclass, field
from typing import Literal
@dataclass
class AlignAttConfig():
eval_data_path: str = "tmp"
segment_length: float = field(default=1.0, metadata = {"help": "in second"})
frame_threshold: int = 4
rewind_threshold: int = 200
audio_max_len: float = 20.0
cif_ckpt_path: str = ""
never_fire: bool = False
class SimulWhisperConfig:
'''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
model_path: str
language: str = field(default="zh")
nonspeech_prob: float = 0.5
audio_min_len: float = 1.0
decoder_type: Literal["greedy","beam"] = "greedy"
beam_size: int = 5
task: Literal["transcribe","translate"] = "transcribe"
tokenizer_is_multilingual: bool = False
init_prompt: str = field(default=None)
static_init_prompt: str = field(default=None)
max_context_tokens: int = field(default=None)
@dataclass
class AlignAttConfig(SimulWhisperConfig):
'''Options specific to the AlignAtt policy.'''
eval_data_path: str = "tmp"
segment_length: float = field(default=1.0, metadata = {"help": "in second"})
frame_threshold: int = 4
rewind_threshold: int = 200
audio_max_len: float = 20.0
cif_ckpt_path: str = ""
never_fire: bool = False

View File

@@ -0,0 +1,5 @@
SIMULSTREAMING_LICENSE = f"""
SimulStreaming backend is dual-licensed:
• Non-Commercial Use: PolyForm Noncommercial License 1.0.0.
• Commercial Use: Check SimulStreaming README (github.com/ufal/SimulStreaming) for more details.
"""

View File

@@ -1,72 +0,0 @@
import json
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from mlx_whisper import whisper
mlx_model_mapping = {
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
"tiny": "mlx-community/whisper-tiny-mlx",
"base.en": "mlx-community/whisper-base.en-mlx",
"base": "mlx-community/whisper-base-mlx",
"small.en": "mlx-community/whisper-small.en-mlx",
"small": "mlx-community/whisper-small-mlx",
"medium.en": "mlx-community/whisper-medium.en-mlx",
"medium": "mlx-community/whisper-medium-mlx",
"large-v1": "mlx-community/whisper-large-v1-mlx",
"large-v2": "mlx-community/whisper-large-v2-mlx",
"large-v3": "mlx-community/whisper-large-v3-mlx",
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
"large": "mlx-community/whisper-large-mlx",
}
def load_mlx_encoder(
path_or_hf_repo: str,
dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
with open(str(model_path / "config.json"), "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
quantization = config.pop("quantization", None)
model_args = whisper.ModelDimensions(**config)
wf = model_path / "weights.safetensors"
if not wf.exists():
wf = model_path / "weights.npz"
weights = mx.load(str(wf))
model = whisper.Whisper(model_args, dtype)
if quantization is not None:
class_predicate = (
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
and f"{p}.scales" in weights
)
nn.quantize(model, **quantization, class_predicate=class_predicate)
weights = tree_unflatten(list(weights.items()))
# we only want to load the encoder weights here.
# Size examples: for tiny.en,
# Decoder weights: 59110771 bytes
# Encoder weights: 15268874 bytes
encoder_weights = {}
encoder_weights['encoder'] = weights['encoder']
del(weights)
model.update(encoder_weights)
mx.eval(model.parameters())
return model

View File

@@ -8,71 +8,50 @@ import torch.nn.functional as F
from .whisper import load_model, DecodingOptions, tokenizer
from .config import AlignAttConfig
from whisperlivekit.timed_objects import ASRToken
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
from .whisper.timing import median_filter
from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
from .beam import BeamPyTorchInference
from .eow_detection import fire_at_boundary, load_cif
import os
from time import time
from .token_buffer import TokenBuffer
import numpy as np
from ..timed_objects import PUNCTUATION_MARKS
from .generation_progress import *
DEC_PAD = 50257
logger = logging.getLogger(__name__)
import sys
import wave
try:
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
HAS_MLX_WHISPER = True
except ImportError:
HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False
else:
try:
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
from faster_whisper.feature_extractor import FeatureExtractor
HAS_FASTER_WHISPER = True
except ImportError:
HAS_FASTER_WHISPER = False
# New features added to the original version of Simul-Whisper:
# - large-v3 model support
# - translation support
# - beam search
# - prompt -- static vs. non-static
# - context
class PaddedAlignAttWhisper:
def __init__(
self,
cfg: AlignAttConfig,
loaded_model=None,
mlx_encoder=None,
fw_encoder=None,
) -> None:
def __init__(self, cfg: AlignAttConfig, loaded_model=None) -> None:
self.log_segments = 0
self.model = loaded_model
self.mlx_encoder = mlx_encoder
self.fw_encoder = fw_encoder
if fw_encoder:
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name = os.path.basename(cfg.model_path).replace(".pt", "")
model_path = os.path.dirname(os.path.abspath(cfg.model_path))
if loaded_model:
self.model = loaded_model
else:
self.model = load_model(name=model_name, download_root=model_path)
logger.info(f"Model dimensions: {self.model.dims}")
self.speaker = -1
self.decode_options = DecodingOptions(
language = cfg.language,
without_timestamps = True,
task=cfg.task
)
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
self.tokenizer_is_multilingual = not model_name.endswith(".en")
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
# self.create_tokenizer('en')
self.detected_language = cfg.language if cfg.language != "auto" else None
self.global_time_offset = 0.0
self.reset_tokenizer_to_auto_next_call = False
self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks)
@@ -147,7 +126,6 @@ class PaddedAlignAttWhisper:
self.last_attend_frame = -self.cfg.rewind_threshold
self.cumulative_time_offset = 0.0
self.first_timestamp = None
if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
@@ -167,23 +145,12 @@ class PaddedAlignAttWhisper:
self.inference.kv_cache = self.kv_cache
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
# Tokens to carry over to next chunk for incomplete UTF-8 characters
self.pending_incomplete_tokens = []
def remove_hooks(self):
print('remove hook')
for hook in self.l_hooks:
hook.remove()
def warmup(self, audio):
try:
self.insert_audio(audio)
self.infer(is_last=True)
self.refresh_segment(complete=True)
logger.info("Model warmed up successfully")
except Exception as e:
logger.exception(f"Model warmup failed: {e}")
def create_tokenizer(self, language=None):
self.tokenizer = tokenizer.get_tokenizer(
multilingual=self.tokenizer_is_multilingual,
@@ -258,13 +225,13 @@ class PaddedAlignAttWhisper:
self.init_context()
logger.debug(f"Context: {self.context}")
if not complete and len(self.segments) > 2:
logger.debug("keeping last two segments because they are and it is not complete.")
self.segments = self.segments[-2:]
else:
logger.debug("removing all segments.")
self.segments = []
self.log_segments += 1
self.pending_incomplete_tokens = []
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
if self.always_fire: return True
@@ -326,7 +293,7 @@ class PaddedAlignAttWhisper:
self.segments = self.segments[1:]
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s")
if len(self.tokens) > 1:
self.context.append_token_ids(self.tokens[1][0,:].tolist())
self.context.append_token_ids(self.tokens[1][0,:])
self.tokens = [self.initial_tokens] + self.tokens[2:]
return removed_len
@@ -380,11 +347,11 @@ class PaddedAlignAttWhisper:
new_segment = True
if len(self.segments) == 0:
logger.debug("No segments, nothing to do")
return []
return [], {}
if not self._apply_minseglen():
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
input_segments = torch.cat(self.segments, dim=0)
return []
return [], {}
# input_segments is concatenation of audio, it's one array
if len(self.segments) > 1:
@@ -392,77 +359,72 @@ class PaddedAlignAttWhisper:
else:
input_segments = self.segments[0]
# if self.cfg.language == "auto" and self.reset_tokenizer_to_auto_next_call:
# logger.debug("Resetting tokenizer to auto for new sentence.")
# self.create_tokenizer(None)
# self.detected_language = None
# self.init_tokens()
# self.reset_tokenizer_to_auto_next_call = False
# NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder
beg_encode = time()
if self.mlx_encoder:
mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
encoder_feature = torch.as_tensor(mlx_encoder_feature)
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
elif self.fw_encoder:
audio_length_seconds = len(input_segments) / 16000
content_mel_len = int(audio_length_seconds * 100)//2
mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
encoder_feature_ctranslate = self.fw_encoder.encode(mel)
if self.device == 'cpu': #it seems that on gpu, passing StorageView to torch.as_tensor fails and wrapping in the array works
encoder_feature_ctranslate = np.array(encoder_feature_ctranslate)
try:
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
except TypeError: # Normally the cpu condition should prevent having exceptions, but just in case:
encoder_feature = torch.as_tensor(np.array(encoder_feature_ctranslate), device=self.device)
else:
# mel + padding to 30s
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
device=self.device).unsqueeze(0)
# trim to 3000
mel = pad_or_trim(mel_padded, N_FRAMES)
# the len of actual audio
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
encoder_feature = self.model.encoder(mel)
end_encode = time()
# print('Encoder duration:', end_encode-beg_encode)
if self.cfg.language == "auto" and self.detected_language is None and self.first_timestamp:
seconds_since_start = self.segments_len() - self.first_timestamp
if seconds_since_start >= 2.0:
language_tokens, language_probs = self.lang_id(encoder_feature)
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
print(f"Detected language: {top_lan} with p={p:.4f}")
self.create_tokenizer(top_lan)
self.last_attend_frame = -self.cfg.rewind_threshold
self.cumulative_time_offset = 0.0
self.init_tokens()
self.init_context()
self.detected_language = top_lan
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
# mel + padding to 30s
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
device=self.model.device).unsqueeze(0)
# trim to 3000
mel = pad_or_trim(mel_padded, N_FRAMES)
# the len of actual audio
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
# encode
encoder_feature = self.model.encoder(mel)
# logger.debug(f"Encoder feature shape: {encoder_feature.shape}")
# if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
# logger.debug("mel ")
if self.cfg.language == "auto" and self.detected_language is None:
language_tokens, language_probs = self.lang_id(encoder_feature)
logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
logger.info(f"Detected language: {top_lan} with p={p:.4f}")
#self.tokenizer.language = top_lan
#self.tokenizer.__post_init__()
self.create_tokenizer(top_lan)
self.detected_language = top_lan
self.init_tokens()
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
self.trim_context()
current_tokens = self._current_tokens()
#
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device)
####################### Decoding loop
logger.info("Decoding loop starts\n")
sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
completed = False
# punctuation_stop = False
attn_of_alignment_heads = None
most_attended_frame = None
token_len_before_decoding = current_tokens.shape[1]
l_absolute_timestamps = []
generation_progress = []
generation = {
"starting_tokens": BeamTokens(current_tokens[0,:].clone(), self.cfg.beam_size),
"token_len_before_decoding": token_len_before_decoding,
#"fire_detected": fire_detected,
"frames_len": content_mel_len,
"frames_threshold": 4 if is_last else self.cfg.frame_threshold,
# to be filled later
"logits_starting": None,
# to be filled later
"no_speech_prob": None,
"no_speech": False,
# to be filled in the loop
"progress": generation_progress,
}
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
generation_progress_loop = []
if new_segment:
tokens_for_logits = current_tokens
@@ -471,26 +433,50 @@ class PaddedAlignAttWhisper:
tokens_for_logits = current_tokens[:,-1:]
logits = self.logits(tokens_for_logits, encoder_feature) # B, len(tokens), token dict size
if new_segment:
generation["logits_starting"] = Logits(logits[:,:,:])
if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
generation["no_speech_prob"] = no_speech_probs[0]
if no_speech_probs[0] > self.cfg.nonspeech_prob:
generation["no_speech"] = True
logger.info("no speech, stop")
break
logits = logits[:, -1, :] # logits for the last token
generation_progress_loop.append(("logits_before_suppress",Logits(logits)))
# supress blank tokens only at the beginning of the segment
if new_segment:
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
new_segment = False
self.suppress_tokens(logits)
#generation_progress_loop.append(("logits_after_suppres",BeamLogits(logits[0,:].clone(), self.cfg.beam_size)))
generation_progress_loop.append(("logits_after_suppress",Logits(logits)))
current_tokens, completed = self.token_decoder.update(current_tokens, logits, sum_logprobs)
generation_progress_loop.append(("beam_tokens",Tokens(current_tokens[:,-1].clone())))
generation_progress_loop.append(("sum_logprobs",sum_logprobs.tolist()))
generation_progress_loop.append(("completed",completed))
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
self.debug_print_tokens(current_tokens)
# if self.decoder_type == "beam":
# logger.debug(f"Finished sequences: {self.token_decoder.finished_sequences}")
# logprobs = F.log_softmax(logits.float(), dim=-1)
# idx = 0
# logger.debug(f"Beam search topk: {logprobs[idx].topk(self.cfg.beam_size + 1)}")
# logger.debug(f"Greedy search argmax: {logits.argmax(dim=-1)}")
# if completed:
# self.debug_print_tokens(current_tokens)
# logger.debug("decode stopped because decoder completed")
attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
for i, attn_mat in enumerate(self.dec_attns):
layer_rank = int(i % len(self.model.decoder.blocks))
@@ -509,24 +495,30 @@ class PaddedAlignAttWhisper:
t = torch.cat(mat, dim=1)
tmp.append(t)
attn_of_alignment_heads = torch.stack(tmp, dim=1)
# logger.debug(str(attn_of_alignment_heads.shape) + " tttady")
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / std
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7) # from whisper.timing
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
# logger.debug(str(attn_of_alignment_heads.shape) + " po mean")
attn_of_alignment_heads = attn_of_alignment_heads[:,:, :content_mel_len]
# logger.debug(str(attn_of_alignment_heads.shape) + " pak ")
# for each beam, the most attended frame is:
most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
generation_progress_loop.append(("most_attended_frames",most_attended_frames.clone().tolist()))
# Calculate absolute timestamps accounting for cumulative offset
absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()]
generation_progress_loop.append(("absolute_timestamps", absolute_timestamps))
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)")
most_attended_frame = most_attended_frames[0].item()
l_absolute_timestamps.append(absolute_timestamps[0])
generation_progress.append(dict(generation_progress_loop))
logger.debug("current tokens" + str(current_tokens.shape))
if completed:
# # stripping the last token, the eot
@@ -564,72 +556,66 @@ class PaddedAlignAttWhisper:
self.tokenizer.decode([current_tokens[i, -1].item()])
))
# for k,v in generation.items():
# print(k,v,file=sys.stderr)
# for x in generation_progress:
# for y in x.items():
# print("\t\t",*y,file=sys.stderr)
# print("\t","----", file=sys.stderr)
# print("\t", "end of generation_progress_loop", file=sys.stderr)
# sys.exit(1)
####################### End of decoding loop
logger.info("End of decoding loop")
# if attn_of_alignment_heads is not None:
# seg_len = int(segment.shape[0] / 16000 * TOKENS_PER_SECOND)
# # Lets' now consider only the top hypothesis in the beam search
# top_beam_attn_of_alignment_heads = attn_of_alignment_heads[0]
# # debug print: how is the new token attended?
# new_token_attn = top_beam_attn_of_alignment_heads[token_len_before_decoding:, -seg_len:]
# logger.debug(f"New token attention shape: {new_token_attn.shape}")
# if new_token_attn.shape[0] == 0: # it's not attended in the current audio segment
# logger.debug("no token generated")
# else: # it is, and the max attention is:
# new_token_max_attn, _ = new_token_attn.max(dim=-1)
# logger.debug(f"segment max attention: {new_token_max_attn.mean().item()/len(self.segments)}")
# let's now operate only with the top beam hypothesis
tokens_to_split = current_tokens[0, token_len_before_decoding:]
# Prepend pending tokens from previous chunk if any
if self.pending_incomplete_tokens:
logger.debug(f"[UTF-8 Fix] Prepending {len(self.pending_incomplete_tokens)} pending tokens: {self.pending_incomplete_tokens}")
pending_tensor = torch.tensor(self.pending_incomplete_tokens, dtype=torch.long, device=self.device)
tokens_to_split = torch.cat([pending_tensor, tokens_to_split])
if fire_detected or is_last: #or punctuation_stop:
if fire_detected or is_last:
new_hypothesis = tokens_to_split.flatten().tolist()
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
else:
# going to truncate the tokens after the last space
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split.tolist())
generation["result"] = {"split_words": split_words[:-1], "split_tokens": split_tokens[:-1]}
generation["result_truncated"] = {"split_words": split_words[-1:], "split_tokens": split_tokens[-1:]}
# text_to_split = self.tokenizer.decode(tokens_to_split)
# logger.debug(f"text_to_split: {text_to_split}")
# logger.debug("text at current step: {}".format(text_to_split.replace(" ", "<space>")))
# text_before_space = " ".join(text_to_split.split(" ")[:-1])
# logger.debug("before the last space: {}".format(text_before_space.replace(" ", "<space>")))
if len(split_words) > 1:
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
else:
new_hypothesis = []
### new hypothesis
logger.debug(f"new_hypothesis: {new_hypothesis}")
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
device=self.device,
device=self.model.device,
)
self.tokens.append(new_tokens)
# TODO: test if this is redundant or not
# ret = ret[ret<DEC_PAD]
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
self._clean_cache()
if len(l_absolute_timestamps) >=2 and self.first_timestamp is None:
self.first_timestamp = l_absolute_timestamps[0]
timestamped_words = []
timestamp_idx = 0
replacement_char = "\ufffd"
for word, word_tokens in zip(split_words, split_tokens):
# Skip words containing incomplete UTF-8 from client output
if replacement_char in word:
logger.warning(f"[UTF-8 Filter] Skipping incomplete word from client output: {repr(word)}")
timestamp_idx += len(word_tokens)
continue
try:
current_timestamp = l_absolute_timestamps[timestamp_idx]
except:
pass
timestamp_idx += len(word_tokens)
timestamp_entry = ASRToken(
start=current_timestamp,
end=current_timestamp + 0.1,
text= word,
probability=0.95,
speaker=self.speaker,
detected_language=self.detected_language
).with_offset(
self.global_time_offset
)
timestamped_words.append(timestamp_entry)
# Hold incomplete tokens for next chunk
self.pending_incomplete_tokens = []
if split_words and replacement_char in split_words[-1]:
self.pending_incomplete_tokens = split_tokens[-1]
logger.warning(f"[UTF-8 Fix] Holding {len(self.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.pending_incomplete_tokens}")
return timestamped_words
return new_hypothesis, generation

View File

@@ -7,7 +7,6 @@ class TokenBuffer:
self.prefix_token_ids = prefix_token_ids
self.tokenizer = tokenizer
self.device = device
self.pending_token_ids = []
def as_token_ids(self, tokenizer=None):
@@ -65,26 +64,7 @@ class TokenBuffer:
def append_token_ids(self, token_ids):
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
all_tokens = self.pending_token_ids + token_ids
decoded = tokenizer.decode(all_tokens)
replacement_char = "\ufffd"
if replacement_char in decoded:
if len(all_tokens) > 1:
decoded_partial = tokenizer.decode(all_tokens[:-1])
if replacement_char not in decoded_partial:
self.text += decoded_partial
self.pending_token_ids = [all_tokens[-1]]
else:
self.pending_token_ids = all_tokens
else:
self.pending_token_ids = all_tokens
else:
self.text += decoded
self.pending_token_ids = []
self.text += self.tokenizer.decode(token_ids)
def as_split_word_tokens(self):
tokenizer = self.tokenizer

View File

@@ -105,8 +105,6 @@ def load_model(
device: Optional[Union[str, torch.device]] = None,
download_root: str = None,
in_memory: bool = False,
decoder_only=False,
custom_alignment_heads=None
) -> Whisper:
"""
Load a Whisper ASR model
@@ -136,17 +134,15 @@ def load_model(
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
alignment_heads = _ALIGNMENT_HEADS[name]
elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name
alignment_heads = None
else:
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
)
alignment_heads = _ALIGNMENT_HEADS.get(name, None)
if custom_alignment_heads:
alignment_heads = custom_alignment_heads.encode()
with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
@@ -155,14 +151,7 @@ def load_model(
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims, decoder_only=decoder_only)
if decoder_only:
checkpoint["model_state_dict"] = {
k: v for k, v in checkpoint["model_state_dict"].items()
if 'encoder' not in k
}
model = Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])
if alignment_heads is not None:

View File

@@ -253,18 +253,16 @@ class TextDecoder(nn.Module):
class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions, decoder_only: bool = False):
def __init__(self, dims: ModelDimensions):
super().__init__()
self.dims = dims
if not decoder_only:
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
)
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
)
self.decoder = TextDecoder(
self.dims.n_vocab,
self.dims.n_text_ctx,

View File

@@ -1,57 +1,20 @@
from dataclasses import dataclass, field
from typing import Optional, Any, List
from datetime import timedelta
PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''}
def format_time(seconds: float) -> str:
"""Format seconds as HH:MM:SS."""
return str(timedelta(seconds=int(seconds)))
from dataclasses import dataclass
from typing import Optional
@dataclass
class TimedText:
start: Optional[float] = 0
end: Optional[float] = 0
start: Optional[float]
end: Optional[float]
text: Optional[str] = ''
speaker: Optional[int] = -1
probability: Optional[float] = None
is_dummy: Optional[bool] = False
detected_language: Optional[str] = None
def is_punctuation(self):
return self.text.strip() in PUNCTUATION_MARKS
def overlaps_with(self, other: 'TimedText') -> bool:
return not (self.end <= other.start or other.end <= self.start)
def is_within(self, other: 'TimedText') -> bool:
return other.contains_timespan(self)
def duration(self) -> float:
return self.end - self.start
def contains_time(self, time: float) -> bool:
return self.start <= time <= self.end
def contains_timespan(self, other: 'TimedText') -> bool:
return self.start <= other.start and self.end >= other.end
def __bool__(self):
return bool(self.text)
@dataclass()
@dataclass
class ASRToken(TimedText):
corrected_speaker: Optional[int] = -1
validated_speaker: bool = False
validated_text: bool = False
validated_language: bool = False
def with_offset(self, offset: float) -> "ASRToken":
"""Return a new token with the time offset added."""
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability, detected_language=self.detected_language)
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability)
@dataclass
class Sentence(TimedText):
@@ -59,28 +22,7 @@ class Sentence(TimedText):
@dataclass
class Transcript(TimedText):
"""
represents a concatenation of several ASRToken
"""
@classmethod
def from_tokens(
cls,
tokens: List[ASRToken],
sep: Optional[str] = None,
offset: float = 0
) -> "Transcript":
sep = sep if sep is not None else ' '
text = sep.join(token.text for token in tokens)
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
if tokens:
start = offset + tokens[0].start
end = offset + tokens[-1].end
else:
start = None
end = None
return cls(start, end, text, probability=probability)
pass
@dataclass
class SpeakerSegment(TimedText):
@@ -89,201 +31,6 @@ class SpeakerSegment(TimedText):
"""
pass
@dataclass
class Translation(TimedText):
is_validated : bool = False
pass
# def split(self):
# return self.text.split(" ") # should be customized with the sep
def approximate_cut_at(self, cut_time):
"""
Each word in text is considered to be of duration (end-start)/len(words in text)
"""
if not self.text or not self.contains_time(cut_time):
return self, None
words = self.text.split()
num_words = len(words)
if num_words == 0:
return self, None
duration_per_word = self.duration() / num_words
cut_word_index = int((cut_time - self.start) / duration_per_word)
if cut_word_index >= num_words:
cut_word_index = num_words -1
text0 = " ".join(words[:cut_word_index])
text1 = " ".join(words[cut_word_index:])
segment0 = Translation(start=self.start, end=cut_time, text=text0)
segment1 = Translation(start=cut_time, end=self.end, text=text1)
return segment0, segment1
def cut_position(self, position):
sep=" "
words = self.text.split(sep)
num_words = len(words)
duration_per_word = self.duration() / num_words
cut_time=duration_per_word*position
text0 = sep.join(words[:position])
text1 = sep.join(words[position:])
segment0 = Translation(start=self.start, end=cut_time, text=text0)
segment1 = Translation(start=cut_time, end=self.end, text=text1)
return segment0, segment1
@dataclass
class Silence():
duration: float
@dataclass
class Line(TimedText):
translation: str = ''
def to_dict(self):
_dict = {
'speaker': int(self.speaker) if self.speaker != -1 else 1,
'text': self.text,
'start': format_time(self.start),
'end': format_time(self.end),
}
if self.translation:
_dict['translation'] = self.translation
if self.detected_language:
_dict['detected_language'] = self.detected_language
return _dict
@dataclass
class WordValidation:
"""Validation status for word-level data."""
text: bool = False
speaker: bool = False
language: bool = False
def to_dict(self):
return {
'text': self.text,
'speaker': self.speaker,
'language': self.language
}
@dataclass
class Word:
"""Word-level object with timing and validation information."""
text: str = ''
start: float = 0.0
end: float = 0.0
validated: WordValidation = field(default_factory=WordValidation)
def to_dict(self):
return {
'text': self.text,
'start': self.start,
'end': self.end,
'validated': self.validated.to_dict()
}
@dataclass
class SegmentBuffer:
"""Per-segment temporary buffers for ephemeral data."""
transcription: str = ''
diarization: str = ''
translation: str = ''
def to_dict(self):
return {
'transcription': self.transcription,
'diarization': self.diarization,
'translation': self.translation
}
@dataclass
class Segment:
"""Represents a segment in the new API structure."""
id: int = 0
speaker: int = -1
text: str = ''
start_speaker: float = 0.0
start: float = 0.0
end: float = 0.0
language: Optional[str] = None
translation: str = ''
words: List[ASRToken] = field(default_factory=list)
buffer_tokens: List[ASRToken] = field(default_factory=list)
buffer_translation = ''
buffer: SegmentBuffer = field(default_factory=SegmentBuffer)
def to_dict(self):
"""Convert segment to dictionary for JSON serialization."""
return {
'id': self.id,
'speaker': self.speaker,
'text': self.text,
'start_speaker': self.start_speaker,
'start': self.start,
'end': self.end,
'language': self.language,
'translation': self.translation,
'words': [word.to_dict() for word in self.words],
'buffer': self.buffer.to_dict()
}
def consolidate(self, sep):
self.text = sep.join([word.text for word in self.words])
if self.words:
self.start = self.words[0].start
self.end = self.words[-1].end
@dataclass
class FrontData():
status: str = ''
error: str = ''
lines: list[Line] = field(default_factory=list)
buffer_transcription: str = ''
buffer_diarization: str = ''
remaining_time_transcription: float = 0.
remaining_time_diarization: float = 0.
def to_dict(self):
_dict = {
'status': self.status,
'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)],
'buffer_transcription': self.buffer_transcription,
'buffer_diarization': self.buffer_diarization,
'remaining_time_transcription': self.remaining_time_transcription,
'remaining_time_diarization': self.remaining_time_diarization,
}
if self.error:
_dict['error'] = self.error
return _dict
@dataclass
class ChangeSpeaker:
speaker: int
start: int
@dataclass
class State():
tokens: list = field(default_factory=list)
segments: list = field(default_factory=list)
last_validated_token: int = 0
last_validated_segment: int = 0 # validated means tokens speaker and transcription are validated and terminated
translation_validated_segments: list = field(default_factory=list)
translation_buffer: list = field(default_factory=list)
buffer_transcription: str = field(default_factory=Transcript)
end_buffer: float = 0.0
end_attributed_speaker: float = 0.0
remaining_time_transcription: float = 0.0
remaining_time_diarization: float = 0.0
beg_loop: Optional[int] = None
duration: float

View File

@@ -0,0 +1,60 @@
# gemma_translate.py
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_ID = "google/gemma-3-270m-it"
def build_prompt(tokenizer, text, target_lang, source_lang=None):
# Use the model's chat template for best results
if source_lang:
user_msg = (
f"Translate the following {source_lang} text into {target_lang}.\n"
f"Return only the translation.\n\n"
f"Text:\n{text}"
)
else:
user_msg = (
f"Translate the following text into {target_lang}.\n"
f"Return only the translation.\n\n"
f"Text:\n{text}"
)
chat = [{"role": "user", "content": user_msg}]
return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
def translate(text, target_lang, source_lang=None, max_new_tokens=256, temperature=0.2, top_p=0.95):
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto"
)
prompt = build_prompt(tokenizer, text, target_lang, source_lang)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
output_ids = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
do_sample=temperature > 0.0,
eos_token_id=tokenizer.eos_token_id,
)
# Slice off the prompt to keep only the assistant answer
generated_ids = output_ids[0][inputs["input_ids"].shape[1]:]
out = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
return out
if __name__ == "__main__":
ap = argparse.ArgumentParser(description="Translate with google/gemma-3-270m-it")
ap.add_argument("--text", required=True, help="Text to translate")
ap.add_argument("--to", dest="target_lang", required=True, help="Target language (e.g., French, Spanish)")
ap.add_argument("--from", dest="source_lang", default=None, help="Source language (optional)")
ap.add_argument("--temp", type=float, default=0.2, help="Sampling temperature (0 = deterministic-ish)")
ap.add_argument("--max-new", type=int, default=256, help="Max new tokens")
args = ap.parse_args()
print(translate(args.text, args.target_lang, args.source_lang, max_new_tokens=args.max_new, temperature=args.temp))

View File

@@ -0,0 +1,121 @@
# nllb_translate.py
import argparse
from pathlib import Path
from typing import List
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
MODEL_ID = "facebook/nllb-200-distilled-600M"
# Common language shortcuts → NLLB codes (extend as needed)
LANG_MAP = {
"english": "eng_Latn",
"en": "eng_Latn",
"french": "fra_Latn",
"fr": "fra_Latn",
"spanish": "spa_Latn",
"es": "spa_Latn",
"german": "deu_Latn",
"de": "deu_Latn",
"italian": "ita_Latn",
"it": "ita_Latn",
"portuguese": "por_Latn",
"pt": "por_Latn",
"arabic": "arb_Arab",
"ar": "arb_Arab",
"russian": "rus_Cyrl",
"ru": "rus_Cyrl",
"turkish": "tur_Latn",
"tr": "tur_Latn",
"chinese": "zho_Hans",
"zh": "zho_Hans", # Simplified
"zh-cn": "zho_Hans",
"zh-hans": "zho_Hans",
"zh-hant": "zho_Hant", # Traditional
"japanese": "jpn_Jpan",
"ja": "jpn_Jpan",
"korean": "kor_Hang",
"ko": "kor_Hang",
"dutch": "nld_Latn",
"nl": "nld_Latn",
"polish": "pol_Latn",
"pl": "pol_Latn",
"swedish": "swe_Latn",
"sv": "swe_Latn",
"norwegian": "nob_Latn",
"no": "nob_Latn",
"danish": "dan_Latn",
"da": "dan_Latn",
"finnish": "fin_Latn",
"fi": "fin_Latn",
"catalan": "cat_Latn",
"ca": "cat_Latn",
"hindi": "hin_Deva",
"hi": "hin_Deva",
"vietnamese": "vie_Latn",
"vi": "vie_Latn",
"indonesian": "ind_Latn",
"id": "ind_Latn",
"thai": "tha_Thai",
"th": "tha_Thai",
}
def norm_lang(code: str) -> str:
c = code.strip().lower()
return LANG_MAP.get(c, code)
def translate_texts(texts: List[str], src_code: str, tgt_code: str,
max_new_tokens=512, device=None, dtype=None) -> List[str]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, src_lang=src_code)
model = AutoModelForSeq2SeqLM.from_pretrained(
MODEL_ID,
torch_dtype=dtype if dtype is not None else (torch.float16 if torch.cuda.is_available() else torch.float32),
device_map="auto" if torch.cuda.is_available() else None,
)
if device:
model.to(device)
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
if device or torch.cuda.is_available():
inputs = {k: v.to(model.device) for k, v in inputs.items()}
forced_bos = tokenizer.convert_tokens_to_ids(tgt_code)
with torch.no_grad():
gen = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
forced_bos_token_id=forced_bos,
)
outs = tokenizer.batch_decode(gen, skip_special_tokens=True)
return [o.strip() for o in outs]
def main():
ap = argparse.ArgumentParser(description="Translate with facebook/nllb-200-distilled-600M")
ap.add_argument("--text", help="Inline text to translate")
ap.add_argument("--file", help="Path to a UTF-8 text file (one example per line)")
ap.add_argument("--src", required=True, help="Source language (e.g. fr, fra_Latn)")
ap.add_argument("--tgt", required=True, help="Target language (e.g. en, eng_Latn)")
ap.add_argument("--max-new", type=int, default=512, help="Max new tokens")
args = ap.parse_args()
src = norm_lang(args.src)
tgt = norm_lang(args.tgt)
batch: List[str] = []
if args.text:
batch.append(args.text)
if args.file:
lines = Path(args.file).read_text(encoding="utf-8").splitlines()
batch.extend([ln for ln in lines if ln.strip()])
if not batch:
raise SystemExit("Provide --text or --file")
results = translate_texts(batch, src, tgt, max_new_tokens=args.max_new)
for i, (inp, out) in enumerate(zip(batch, results), 1):
print(f"\n--- Sample {i} ---")
print(f"SRC [{src}]: {inp}")
print(f"TGT [{tgt}]: {out}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,38 @@
import regex
from functools import lru_cache
class SentenceSegmenter:
"""
Regex sentence splitter for Latin languages, Japanese and Chinese.
It is based on sacrebleu TokenizerV14International(BaseTokenizer).
Returns: a list of strings, where each string is a sentence.
Spaces following punctuation are appended after punctuation within the sequence.
Total number of characters in the output is the same as in the input.
"""
sep = 'ŽžŽžSentenceSeparatorŽžŽž' # string that certainly won't be in src or target
latin_terminals = '!?.'
jap_zh_terminals = '。!?'
terminals = latin_terminals + jap_zh_terminals
def __init__(self):
# end of sentence characters:
terminals = self.terminals
self._re = [
# Separate out punctuations preceeded by a non-digit.
# If followed by space-like sequence of characters, they are
# appended to the punctuation, not to the next sequence.
(regex.compile(r'(\P{N})(['+terminals+r'])(\p{Z}*)'), r'\1\2\3'+self.sep),
# Separate out punctuations followed by a non-digit
(regex.compile(r'('+terminals+r')(\P{N})'), r'\1'+self.sep+r'\2'),
# # Separate out symbols
# -> no, we don't tokenize but segment the punctuation
# (regex.compile(r'(\p{S})'), r' \1 '),
]
@lru_cache(maxsize=2**16)
def __call__(self, line):
for (_re, repl) in self._re:
line = _re.sub(repl, line)
return [ t for t in line.split(self.sep) if t != '' ]

View File

@@ -0,0 +1,466 @@
import sys
import ctranslate2
import sentencepiece as spm
import transformers
import argparse
def generate_words(sp, step_results):
tokens_buffer = []
for step_result in step_results:
is_new_word = step_result.token.startswith("")
if is_new_word and tokens_buffer:
word = sp.decode(tokens_buffer)
if word:
yield word
tokens_buffer = []
tokens_buffer.append(step_result.token_id)
if tokens_buffer:
word = sp.decode(tokens_buffer)
if word:
yield word
from sentence_segmenter import SentenceSegmenter
class LLMTranslator:
def __init__(self, system_prompt='Please translate.', max_context_length=4096, len_ratio=None):
self.system_prompt = system_prompt
print("Loading the model...", file=sys.stderr)
self.generator = ctranslate2.Generator("ct2_EuroLLM-9B-Instruct/", device="cuda")
self.sp = spm.SentencePieceProcessor("EuroLLM-9B-Instruct/tokenizer.model")
self.tokenizer = transformers.AutoTokenizer.from_pretrained("EuroLLM-9B-Instruct/")
print("...done", file=sys.stderr)
self.max_context_length = max_context_length
self.max_tokens_to_trim = self.max_context_length - 10
self.len_ratio = len_ratio
# my regex sentence segmenter
self.segmenter = SentenceSegmenter()
# self.max_generation_length = 512
# self.max_prompt_length = context_length - max_generation_length
def start_dialog(self):
return [{'role':'system', 'content': self.system_prompt }]
def build_prompt(self, dialog):
toks = self.tokenizer.apply_chat_template(dialog, tokenize=True, add_generation_prompt=False)
if len(dialog) == 3:
toks = toks[:-2]
print("len toks:", len(toks), file=sys.stderr)
# print(toks, file=sys.stderr)
c = self.tokenizer.convert_ids_to_tokens(toks)
# print(c,file=sys.stderr)
return c
def translate(self, src, tgt_forced=""):
#src, tgt_forced = self.trim(src, tgt_forced)
dialog = self.start_dialog()
dialog += [{'role':'user','content': src}]
if tgt_forced != "":
dialog += [{'role':'assistant','content': tgt_forced}]
prompt_tokens = self.build_prompt(dialog)
if self.len_ratio is not None:
limit_len = int(len(self.tokenizer.encode(src)) * self.len_ratio) + 10
limit_kw = {'max_length': limit_len}
else:
limit_kw = {}
step_results = self.generator.generate_tokens(
prompt_tokens,
**limit_kw,
# end_token=tokenizer.eos_token,
# sampling_temperature=0.6,
# sampling_topk=20,
# sampling_topp=1,
)
res = []
#output_ids = []
for step_result in step_results:
# is_new_word = step_result.token.startswith("▁")
# if is_new_word and output_ids:
# word = self.sp.decode(output_ids)
# print(word, end=" ", flush=True, file=sys.stderr)
# output_ids = []
# output_ids.append(step_result.token_id)
res.append(step_result)
#if output_ids:
# word = self.sp.decode(output_ids)
# print(word, file=sys.stderr)
return self.sp.decode([r.token_id for r in res])
# print(res)
# print([s.token for s in res], file=sys.stderr)
# print([s.token==self.tokenizer.eos_token for s in res], file=sys.stderr)
class ParallelTextBuffer:
def __init__(self, tokenizer, max_tokens, trimming="segments", init_src="", init_tgt=""):
self.tokenizer = tokenizer
self.max_tokens = max_tokens
self.src_buffer = [] # list of lists
if init_src:
self.src_buffer.append(init_src)
self.tgt_buffer = [] # list of strings
if init_tgt:
self.tgt_buffer.append(init_tgt)
self.trimming = trimming
if self.trimming == "sentences":
self.segmenter = SentenceSegmenter()
def len_src(self):
return sum(len(t) for t in self.src_buffer) + len(self.src_buffer) - 1
def insert(self, src, tgt):
self.src_buffer.append(src)
self.tgt_buffer.append(tgt)
def insert_src_suffix(self, s):
if self.src_buffer:
self.src_buffer[-1][-1] += s
else:
self.src_buffer.append([s])
def trim_sentences(self):
# src_tok_lens = [len(self.tokenizer.encode(" ".join(b))) for b in self.src_buffer]
# tgt_tok_lens = [len(self.tokenizer.encode(t)) for t in self.tgt_buffer]
src = " ".join(" ".join(b) for b in self.src_buffer)
tgt = "".join(self.tgt_buffer)
src_sp_toks = self.tokenizer.encode(src)
tgt_sp_toks = self.tokenizer.encode(tgt)
def trim_sentence(text):
sents = self.segmenter(text)
print("SENTS:", len(sents), sents, file=sys.stderr)
return "".join(sents[1:])
while len(src_sp_toks) + len(tgt_sp_toks) > self.max_tokens:
nsrc = trim_sentence(src)
ntgt = trim_sentence(tgt)
if not nsrc or not ntgt:
print("src or tgt is empty after trimming.", file=sys.stderr)
print("src: ", src, file=sys.stderr)
print("tgt: ", tgt, file=sys.stderr)
break
src = nsrc
tgt = ntgt
src_sp_toks = self.tokenizer.encode(src)
tgt_sp_toks = self.tokenizer.encode(tgt)
print("TRIMMED SRC:", (src,), file=sys.stderr)
print("TRIMMED TGT:", (tgt,), file=sys.stderr)
self.src_buffer = [src.split()]
self.tgt_buffer = [tgt]
return src, tgt
def trim_segments(self):
print("BUFFER:", file=sys.stderr)
for s,t in zip(self.src_buffer, self.tgt_buffer):
print("\t", s,"...",t,file=sys.stderr) #,self.src_buffer, self.tgt_buffer, file=sys.stderr)
src = " ".join(" ".join(b) for b in self.src_buffer)
tgt = "".join(self.tgt_buffer)
src_sp_toks = self.tokenizer.encode(src)
tgt_sp_toks = self.tokenizer.encode(tgt)
while len(src_sp_toks) + len(tgt_sp_toks) > self.max_tokens:
if len(self.src_buffer) > 1 and len(self.tgt_buffer) > 1:
self.src_buffer.pop(0)
self.tgt_buffer.pop(0)
else:
break
src = " ".join(" ".join(b) for b in self.src_buffer)
tgt = "".join(self.tgt_buffer)
src_sp_toks = self.tokenizer.encode(src)
tgt_sp_toks = self.tokenizer.encode(tgt)
print("TRIMMED SEGMENTS SRC:", (src,), file=sys.stderr)
print("TRIMMED SEGMENTS TGT:", (tgt,), file=sys.stderr)
return src, tgt
def trim(self):
if self.trimming == "sentences":
return self.trim_sentences()
return self.trim_segments()
class SimulLLM:
def __init__(self, llmtrans, min_len=0, chunk=1, trimming="sentences", language="ja", init_src="", init_tgt=""):
self.llmtranslator = llmtrans
#self.src_buffer = init_src
#self.confirmed_tgt = init_tgt
self.buffer = ParallelTextBuffer(self.llmtranslator.tokenizer, self.llmtranslator.max_tokens_to_trim, trimming=trimming, init_src=init_src, init_tgt=init_tgt)
self.last_inserted = []
self.last_unconfirmed = ""
self.min_len = min_len
self.step = chunk
self.language = language
if language in ["ja", "zh"]:
self.specific_space = ""
else:
self.specific_space = " "
def insert(self, src):
if isinstance(src, str):
self.last_inserted.append(src)
else:
self.last_inserted += src
def insert_suffix(self, text):
'''
Insert suffix of a word to the last inserted word.
It may be because the word was split to multiple parts in the input, each with different timestamps.
'''
if self.last_inserted:
self.last_inserted[-1] += text
elif self.src_buffer:
self.buffer.insert_src_suffix(text)
else:
# this shouldn't happen
self.last_inserted.append(text)
def trim_longest_common_prefix(self, a,b):
if self.language not in ["ja", "zh"]:
a = a.split()
b = b.split()
i = 0
for i,(x,y) in enumerate(zip(a,b)):
if x != y:
break
if self.language in ["ja", "zh"]:
#print("tady160",(a, b, i), file=sys.stderr)
return a[:i], b[i:]
else:
return " ".join(a[:i]), " ".join(b[i:])
def process_iter(self):
if self.buffer.len_src() + len(self.last_inserted) < self.min_len:
return ""
src, forced_tgt = self.buffer.trim() #llmtranslator.trim(" ".join(self.src_buffer), self.confirmed_tgt)
#self.src_buffer = self.src_buffer.split()
#src = " ".join(self.src_buffer)
confirmed_out = ""
run = False
for i in range(self.step, len(self.last_inserted), self.step):
for w in self.last_inserted[i-self.step:i]:
src += " " + w
run = True
if not run: break
print("SRC",src,file=sys.stderr)
print("FORCED TGT",forced_tgt,file=sys.stderr)
out = self.llmtranslator.translate(src, forced_tgt)
print("OUT",out,file=sys.stderr)
confirmed, unconfirmed = self.trim_longest_common_prefix(self.last_unconfirmed, out)
self.last_unconfirmed = unconfirmed
#print("tady", (self.confirmed_tgt, self.specific_space, confirmed), file=sys.stderr)
if confirmed:
# self.confirmed_tgt += self.specific_space + confirmed
# print(confirmed_out, confirmed, file=sys.stderr)
confirmed_out += self.specific_space + confirmed
print("CONFIRMED NOW:",confirmed,file=sys.stderr)
print(file=sys.stderr)
print(file=sys.stderr)
print("#################",file=sys.stderr)
if run:
self.buffer.insert(self.last_inserted, confirmed_out)
self.last_inserted = []
ret = confirmed_out
print("RET:",ret,file=sys.stderr)
return ret
def finalize(self):
return self.last_unconfirmed
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--input-instance', type=str, default=None, help="Filename of instances to simulate input. If not set, txt input is read from stdin.")
#parser.add_argument('--output_instance', type=str, default=None, help="Write output as instance into this file, while also writing to stdout.")
parser.add_argument('--min-chunk-size', type=int, default=1,
help='Minimum number of space-delimited words to process in each LocalAgreement update. The more, the higher quality, but slower.')
parser.add_argument('--min-len', type=int, default=1,
help='Minimum number of space-delimited words at the beginning.')
#parser.add_argument('--start_at', type=int, default=0, help='Skip first N words.')
# maybe later
#parser.add_argument('--offline', action="store_true", default=False, help='Offline mode.')
#parser.add_argument('--comp_unaware', action="store_true", default=False, help='Computationally unaware simulation.')
lan_to_name = {
"de": "German",
"ja": "Japanese",
"zh-tr": "Chinese Traditional",
"zh-sim": "Chinese Simplified",
"cs": "Czech",
}
parser.add_argument('--lan', '--language', type=str, default="de",
help="Target language code.",
choices=["de", "ja","zh-tr","zh-sim","cs"])
SrcLang = "English" # always
TgtLang = "German"
default_prompt="You are simultaneous interpreter from {SrcLang} to {TgtLang}. We are at a conference. It is important that you translate " + \
"only what you hear, nothing else!"
parser.add_argument('--sys_prompt', type=str, default=None,
help='System prompt. If None, default one is used, depending on the language. The prompt should ')
default_init = "Please, go ahead, you can start with your presentation, we are ready."
default_inits_tgt = {
'de': "Bitte schön, Sie können mit Ihrer Präsentation beginnen, wir sind bereit.",
'ja': "どうぞ、プレゼンテーションを始めてください。", # # Please go ahead and start your presentation. # this is in English
'zh-tr': "請繼續,您可以開始您的簡報,我們已經準備好了。",
'zh-sim': "请吧,你可以开始发言了,我们已经准备好了。",
'cs': "Prosím, můžete začít s prezentací, jsme připraveni.",
}
parser.add_argument('--init_prompt_src', type=str, default=None, help='Init translation with source text. It should be a complete sentence in the source language. '
'It can be context specific for the given input. Default is ')
parser.add_argument('--init_prompt_tgt', type=str, default=None, help='Init translation with this target. It should be example translation of init_prompt_src. '
' There is default init message, depending on the language.')
parser.add_argument('--len-threshold', type=float, default=None, help='Ratio of the length of the source and generated target, in number of sentencepiece tokens. '
'It should reflect the target language and. If not set, no len-threshold is used.')
# how many times is target text longer than English
lan_thresholds = {
'de': 1.3, # 12751/9817 ... the proportion of subword tokens for ACL6060 dev de vs. en text, for EuroLLM-9B-Instruct tokenizer
'ja': 1.34, # 13187/9817
'zh': 1.23, # 12115/9817
'zh-tr': 1.23, # 12115/9817
'zh-sim': 1.23, # 12115/9817
# 'cs': I don't know # guessed
}
parser.add_argument('--language-specific-len-threshold', default=False, action="store_true",
help='Use language-specific length threshold, e.g. 1.3 for German.')
parser.add_argument("--max-context-length", type=int, default=4096, help="Maximum number of tokens in the model to use.")
parser.add_argument("--buffer_trimming", type=str, default="sentences", choices=["segments","sentences"], help="Buffer trimming strategy.")
args = parser.parse_args()
if args.sys_prompt is None:
TgtLang = lan_to_name[args.lan]
sys_prompt = default_prompt.format(SrcLang=SrcLang, TgtLang=TgtLang)
else:
sys_prompt = args.sys_prompt
if args.init_prompt_src is None:
init_src = default_init.split()
if args.init_prompt_tgt is None:
init_tgt = default_inits_tgt[args.lan]
if args.lan == "ja":
init_src = 'Please go ahead and start your presentation.'.split()
print("WARNING: Default init_prompt_src not set and language is Japanese. The init_src prompt changed to be more verbose.", file=sys.stderr)
else:
print("WARNING: init_prompt_tgt is used, init_prompt_src is None, the default one. It may be wrong!", file=sys.stderr)
init_tgt = args.init_prompt_tgt
else:
init_src = args.init_prompt_src.split()
if args.init_prompt_tgt is None:
print("WARNING: init_prompt_src is used, init_prompt_tgt is None, so the default one is used. It may be wrong!", file=sys.stderr)
init_tgt = default_inits_tgt[args.lan]
else:
init_tgt = args.init_prompt_tgt
print("INFO: System prompt:", sys_prompt, file=sys.stderr)
print("INFO: Init prompt src:", init_src, file=sys.stderr)
print("INFO: Init prompt tgt:", init_tgt, file=sys.stderr)
if args.language_specific_len_threshold:
if args.len_threshold is not None:
print("ERROR: --len-threshold is set, but --language-specific-len-threshold is also set. Only one can be used.", file=sys.stderr)
sys.exit(1)
else:
len_threshold = lan_thresholds[args.lan]
else:
len_threshold = args.len_threshold
llmtrans = LLMTranslator(system_prompt=sys_prompt, max_context_length=args.max_context_length, len_ratio=len_threshold)
lan = args.lan if not args.lan.startswith("zh") else "zh"
simul = SimulLLM(llmtrans,language=lan, min_len=args.min_len, chunk=args.min_chunk_size,
init_src=init_src, init_tgt=init_tgt, trimming=args.buffer_trimming
)
# two input options
if args.input_instance is not None:
print("INFO: Reading input from file", args.input_instance, file=sys.stderr)
import json
with open(args.input_instance, "r") as f:
instance = json.load(f)
asr_source = instance["prediction"]
timestamps = instance["delays"]
elapsed = instance["elapsed"]
yield_ts_words = zip(timestamps, timestamps, elapsed, asr_source.split())
else:
print("INFO: Reading stdin in txt format", file=sys.stderr)
def yield_input():
for line in sys.stdin:
line = line.strip()
ts, beg, end, *_ = line.split()
text = line[len(ts)+len(beg)+len(end)+3:]
ts = float(ts)
# in rare cases, the first word is a suffix of the previous word, that was split to multiple parts
if text[0] != " ":
first, *words = text.split()
yield (ts, beg, end, " "+first) # marking the first word with " ", so that it can be later detected and inserted as suffix
else:
words = text.split()
for w in words:
yield (ts, beg, end, w)
yield_ts_words = yield_input()
#i = 0
for t,b,e,w in yield_ts_words:
if w.startswith(" "): # it is suffix of the previous word
w = w[1:]
simul.insert_suffix(w)
continue
simul.insert(w)
out = simul.process_iter()
if out:
print(t,b,e,out,flush=True)
# if i > 50:
# break
# i += 1
out = simul.finalize()
print(t,b,e,out,flush=True)

View File

@@ -6,46 +6,57 @@ logger = logging.getLogger(__name__)
def load_file(warmup_file=None, timeout=5):
import os
import tempfile
import urllib.request
import librosa
if warmup_file == "":
logger.info(f"Skipping warmup.")
return None
# Download JFK sample if not already present
if warmup_file is None:
# Download JFK sample if not already present
jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
temp_dir = tempfile.gettempdir()
warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")
if not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
if not os.path.exists(warmup_file):
logger.debug(f"Downloading warmup file from {jfk_url}")
print(f"Downloading warmup file from {jfk_url}")
import time
import urllib.request
import urllib.error
import socket
original_timeout = socket.getdefaulttimeout()
socket.setdefaulttimeout(timeout)
start_time = time.time()
try:
logger.debug(f"Downloading warmup file from {jfk_url}")
with urllib.request.urlopen(jfk_url, timeout=timeout) as r, open(warmup_file, "wb") as f:
f.write(r.read())
except Exception as e:
logger.warning(f"Warmup file download failed: {e}.")
return None
# Validate file and load
if not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
logger.warning(f"Warmup file {warmup_file} is invalid or missing.")
return None
urllib.request.urlretrieve(jfk_url, warmup_file)
logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
except (urllib.error.URLError, socket.timeout) as e:
logger.warning(f"Download failed: {e}. Proceeding without warmup.")
return False
finally:
socket.setdefaulttimeout(original_timeout)
elif not warmup_file:
return False
if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
logger.warning(f"Warmup file {warmup_file} invalid or missing.")
return False
try:
audio, _ = librosa.load(warmup_file, sr=16000)
return audio
audio, sr = librosa.load(warmup_file, sr=16000)
except Exception as e:
logger.warning(f"Failed to load warmup file: {e}")
return None
logger.warning(f"Failed to load audio file: {e}")
return False
return audio
def warmup_asr(asr, warmup_file=None, timeout=5):
"""
Warmup the ASR model by transcribing a short audio file.
"""
audio = load_file(warmup_file=warmup_file, timeout=timeout)
if audio is None:
logger.warning("Warmup file unavailable. Skipping ASR warmup.")
return
audio = load_file(warmup_file=None, timeout=5)
asr.transcribe(audio)
logger.info("ASR model is warmed up.")
logger.info("ASR model is warmed up")
def warmup_online(online, warmup_file=None, timeout=5):
audio = load_file(warmup_file=None, timeout=5)
online.warmup(audio)
logger.warning("ASR is warmed up")

View File

@@ -72,21 +72,12 @@
--label-trans-text: #111111;
}
html.is-extension
{
width: 350px;
height: 500px;
}
body {
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
margin: 0;
margin: 20px;
text-align: center;
background-color: var(--bg);
color: var(--text);
height: 100vh;
display: flex;
flex-direction: column;
}
/* Record button */
@@ -177,18 +168,9 @@ body {
}
#status {
margin-top: 15px;
margin-top: 20px;
font-size: 16px;
color: var(--text);
margin-bottom: 0;
}
.header-container {
position: sticky;
top: 0;
background-color: var(--bg);
z-index: 100;
padding: 20px;
}
/* Settings */
@@ -197,83 +179,16 @@ body {
justify-content: center;
align-items: center;
gap: 15px;
position: relative;
flex-wrap: wrap;
}
.buttons-container {
display: flex;
align-items: center;
gap: 15px;
margin-top: 20px;
}
.settings {
display: flex;
flex-wrap: wrap;
flex-direction: column;
align-items: flex-start;
gap: 12px;
}
.settings-toggle {
width: 40px;
height: 40px;
border: none;
border-radius: 50%;
background-color: var(--button-bg);
border: 1px solid var(--button-border);
cursor: pointer;
display: none;
align-items: center;
justify-content: center;
transition: all 0.2s ease;
}
.settings-toggle:hover {
background-color: var(--chip-bg);
}
.settings-toggle.active {
background-color: var(--chip-bg);
}
.settings-toggle img {
width: 20px;
height: 20px;
}
@media (max-width: 10000px) {
.settings-toggle {
display: flex;
}
.settings {
display: none;
background: var(--bg);
border: 1px solid var(--border);
border-radius: 18px;
padding: 12px;
}
.settings.visible {
display: flex;
}
}
@media (max-width: 600px) {
.settings-container {
flex-direction: column;
align-items: center;
gap: 10px;
}
.buttons-container {
display: flex;
justify-content: center;
align-items: center;
gap: 15px;
}
}
.field {
display: flex;
flex-direction: column;
@@ -283,27 +198,23 @@ body {
#chunkSelector,
#websocketInput,
#themeSelector,
#microphoneSelect {
#themeSelector {
font-size: 16px;
padding: 5px 8px;
border-radius: 8px;
border: 1px solid var(--border);
background-color: var(--button-bg);
color: var(--text);
max-height: 30px;
max-height: 34px;
}
#microphoneSelect {
width: 100%;
max-width: 190px;
min-width: 120px;
#websocketInput {
width: 220px;
}
#chunkSelector:focus,
#websocketInput:focus,
#themeSelector:focus,
#microphoneSelect:focus {
#themeSelector:focus {
outline: none;
border-color: #007bff;
box-shadow: 0 0 0 3px rgba(0, 123, 255, 0.15);
@@ -336,9 +247,9 @@ label {
}
.theme-selector-container {
display: flex;
align-items: center;
margin-top: 17px;
position: absolute;
top: 20px;
right: 20px;
}
.segmented label {
@@ -382,21 +293,9 @@ label {
border-radius: 999px;
}
.transcript-container {
flex: 1;
overflow-y: auto;
padding: 20px;
scrollbar-width: none;
-ms-overflow-style: none;
}
.transcript-container::-webkit-scrollbar {
display: none;
}
/* Transcript area */
#linesTranscript {
margin: 0 auto;
margin: 20px auto;
max-width: 700px;
text-align: left;
font-size: 16px;
@@ -420,7 +319,7 @@ label {
.label_diarization {
background-color: var(--chip-bg);
border-radius: 100px;
border-radius: 8px 8px 8px 8px;
padding: 2px 10px;
margin-left: 10px;
display: inline-block;
@@ -432,7 +331,7 @@ label {
.label_transcription {
background-color: var(--chip-bg);
border-radius: 100px;
border-radius: 8px 8px 8px 8px;
padding: 2px 10px;
display: inline-block;
white-space: nowrap;
@@ -442,34 +341,9 @@ label {
color: var(--label-trans-text);
}
.label_translation {
background-color: var(--chip-bg);
display: inline-flex;
border-radius: 10px;
padding: 4px 8px;
margin-top: 4px;
font-size: 14px;
color: var(--text);
align-items: flex-start;
gap: 4px;
}
.lag-diarization-value {
margin-left: 10px;
}
.label_translation img {
margin-top: 2px;
}
.label_translation img {
width: 12px;
height: 12px;
}
#timeInfo {
color: var(--muted);
margin-left: 0px;
margin-left: 10px;
}
.textcontent {
@@ -483,6 +357,7 @@ label {
.buffer_diarization {
color: var(--label-dia-text);
margin-left: 4px;
}
.buffer_transcription {
@@ -525,101 +400,3 @@ label {
font-size: 14px;
margin-bottom: 0px;
}
/* for smaller screens */
@media (max-width: 200px) {
.header-container {
padding: 15px;
}
.settings-container {
flex-direction: column;
gap: 10px;
}
.buttons-container {
gap: 10px;
}
.settings {
justify-content: center;
gap: 8px;
}
.field {
align-items: center;
}
#websocketInput,
#microphoneSelect {
min-width: 100px;
max-width: 160px;
}
.theme-selector-container {
margin-top: 10px;
}
.transcript-container {
padding: 15px;
}
}
@media (max-width: 480px) {
.header-container {
padding: 10px;
}
.settings {
flex-direction: column;
align-items: center;
gap: 6px;
}
#websocketInput,
#microphoneSelect {
max-width: 140px;
}
.segmented label {
padding: 4px 8px;
font-size: 12px;
}
.segmented img {
width: 14px;
height: 14px;
}
.transcript-container {
padding: 10px;
}
}
.label_language {
background-color: var(--chip-bg);
margin-bottom: 0px;
border-radius: 100px;
padding: 2px 8px;
margin-left: 10px;
display: inline-flex;
align-items: center;
gap: 4px;
font-size: 14px;
color: var(--muted);
}
.speaker-badge {
display: inline-flex;
align-items: center;
justify-content: center;
width: 16px;
height: 16px;
margin-left: -5px;
border-radius: 50%;
font-size: 11px;
line-height: 1;
font-weight: 800;
color: var(--muted);
}

View File

@@ -1,79 +1,61 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>WhisperLiveKit</title>
<link rel="stylesheet" href="live_transcription.css" />
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>WhisperLiveKit</title>
<link rel="stylesheet" href="/web/live_transcription.css" />
</head>
<body>
<div class="header-container">
<div class="settings-container">
<div class="buttons-container">
<button id="recordButton">
<div class="shape-container">
<div class="shape"></div>
</div>
<div class="recording-info">
<div class="wave-container">
<canvas id="waveCanvas"></canvas>
</div>
<div class="timer">00:00</div>
</div>
</button>
<button id="settingsToggle" class="settings-toggle" title="Show/hide settings">
<img src="web/src/settings.svg" alt="Settings" />
</button>
</div>
<div class="settings">
<div class="field">
<label for="websocketInput">Websocket URL</label>
<input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
</div>
<div class="field">
<label id="microphoneSelectLabel" for="microphoneSelect">Select Microphone</label>
<select id="microphoneSelect">
<option value="">Default Microphone</option>
</select>
</div>
<div class="theme-selector-container">
<div class="segmented" role="radiogroup" aria-label="Theme selector">
<input type="radio" id="theme-system" name="theme" value="system" />
<label for="theme-system" title="System">
<img src="/web/src/system_mode.svg" alt="" />
<span>System</span>
</label>
<input type="radio" id="theme-light" name="theme" value="light" />
<label for="theme-light" title="Light">
<img src="/web/src/light_mode.svg" alt="" />
<span>Light</span>
</label>
<input type="radio" id="theme-dark" name="theme" value="dark" />
<label for="theme-dark" title="Dark">
<img src="/web/src/dark_mode.svg" alt="" />
<span>Dark</span>
</label>
</div>
</div>
</div>
<div class="settings-container">
<button id="recordButton">
<div class="shape-container">
<div class="shape"></div>
</div>
<div class="recording-info">
<div class="wave-container">
<canvas id="waveCanvas"></canvas>
</div>
<p id="status"></p>
</div>
<div class="timer">00:00</div>
</div>
</button>
<div class="transcript-container">
<div id="linesTranscript"></div>
</div>
<div class="settings">
<div class="field">
<label for="websocketInput">WebSocket URL</label>
<input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
</div>
<script src="live_transcription.js"></script>
</div>
</div>
</div>
<div class="theme-selector-container">
<div class="segmented" role="radiogroup" aria-label="Theme selector">
<input type="radio" id="theme-system" name="theme" value="system" />
<label for="theme-system" title="System">
<img src="/web/src/system_mode.svg" alt="" />
<span>System</span>
</label>
<input type="radio" id="theme-light" name="theme" value="light" />
<label for="theme-light" title="Light">
<img src="/web/src/light_mode.svg" alt="" />
<span>Light</span>
</label>
<input type="radio" id="theme-dark" name="theme" value="dark" />
<label for="theme-dark" title="Dark">
<img src="/web/src/dark_mode.svg" alt="" />
<span>Dark</span>
</label>
</div>
</div>
<p id="status"></p>
<div id="linesTranscript"></div>
<script src="/web/live_transcription.js"></script>
</body>
</html>

View File

@@ -1,8 +1,4 @@
const isExtension = typeof chrome !== 'undefined' && chrome.runtime && chrome.runtime.getURL;
if (isExtension) {
document.documentElement.classList.add('is-extension');
}
const isWebContext = !isExtension;
/* Theme, WebSocket, recording, rendering logic extracted from inline script and adapted for segmented theme control and WS caption */
let isRecording = false;
let websocket = null;
@@ -16,21 +12,12 @@ let timerInterval = null;
let audioContext = null;
let analyser = null;
let microphone = null;
let workletNode = null;
let recorderWorker = null;
let waveCanvas = document.getElementById("waveCanvas");
let waveCtx = waveCanvas.getContext("2d");
let animationFrame = null;
let waitingForStop = false;
let lastReceivedData = null;
let lastSignature = null;
let availableMicrophones = [];
let selectedMicrophoneId = null;
let serverUseAudioWorklet = null;
let configReadyResolve;
const configReady = new Promise((r) => (configReadyResolve = r));
let outputAudioContext = null;
let audioSource = null;
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
@@ -44,27 +31,6 @@ const websocketDefaultSpan = document.getElementById("wsDefaultUrl");
const linesTranscriptDiv = document.getElementById("linesTranscript");
const timerElement = document.querySelector(".timer");
const themeRadios = document.querySelectorAll('input[name="theme"]');
const microphoneSelect = document.getElementById("microphoneSelect");
const settingsToggle = document.getElementById("settingsToggle");
const settingsDiv = document.querySelector(".settings");
// if (isExtension) {
// chrome.runtime.onInstalled.addListener((details) => {
// if (details.reason.search(/install/g) === -1) {
// return;
// }
// chrome.tabs.create({
// url: chrome.runtime.getURL("welcome.html"),
// active: true
// });
// });
// }
const translationIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="12px" viewBox="0 -960 960 960" width="12px" fill="#5f6368"><path d="m603-202-34 97q-4 11-14 18t-22 7q-20 0-32.5-16.5T496-133l152-402q5-11 15-18t22-7h30q12 0 22 7t15 18l152 403q8 19-4 35.5T868-80q-13 0-22.5-7T831-106l-34-96H603ZM362-401 188-228q-11 11-27.5 11.5T132-228q-11-11-11-28t11-28l174-174q-35-35-63.5-80T190-640h84q20 39 40 68t48 58q33-33 68.5-92.5T484-720H80q-17 0-28.5-11.5T40-760q0-17 11.5-28.5T80-800h240v-40q0-17 11.5-28.5T360-880q17 0 28.5 11.5T400-840v40h240q17 0 28.5 11.5T680-760q0 17-11.5 28.5T640-720h-76q-21 72-63 148t-83 116l96 98-30 82-122-125Zm266 129h144l-72-204-72 204Z"/></svg>`
const silenceIcon = `<svg xmlns="http://www.w3.org/2000/svg" style="vertical-align: text-bottom;" height="14px" viewBox="0 -960 960 960" width="14px" fill="#5f6368"><path d="M514-556 320-752q9-3 19-5.5t21-2.5q66 0 113 47t47 113q0 11-1.5 22t-4.5 22ZM40-200v-32q0-33 17-62t47-44q51-26 115-44t141-18q26 0 49.5 2.5T456-392l-56-54q-9 3-19 4.5t-21 1.5q-66 0-113-47t-47-113q0-11 1.5-21t4.5-19L84-764q-11-11-11-28t11-28q12-12 28.5-12t27.5 12l675 685q11 11 11.5 27.5T816-80q-11 13-28 12.5T759-80L641-200h39q0 33-23.5 56.5T600-120H120q-33 0-56.5-23.5T40-200Zm80 0h480v-32q0-14-4.5-19.5T580-266q-36-18-92.5-36T360-320q-71 0-127.5 18T140-266q-9 5-14.5 14t-5.5 20v32Zm240 0Zm560-400q0 69-24.5 131.5T829-355q-12 14-30 15t-32-13q-13-13-12-31t12-33q30-38 46.5-85t16.5-98q0-51-16.5-97T767-781q-12-15-12.5-33t12.5-32q13-14 31.5-13.5T829-845q42 51 66.5 113.5T920-600Zm-182 0q0 32-10 61.5T700-484q-11 15-29.5 15.5T638-482q-13-13-13.5-31.5T633-549q6-11 9.5-24t3.5-27q0-14-3.5-27t-9.5-25q-9-17-8.5-35t13.5-31q14-14 32.5-13.5T700-716q18 25 28 54.5t10 61.5Z"/></svg>`;
const languageIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="12" viewBox="0 -960 960 960" width="12" fill="#5f6368"><path d="M480-80q-82 0-155-31.5t-127.5-86Q143-252 111.5-325T80-480q0-83 31.5-155.5t86-127Q252-817 325-848.5T480-880q83 0 155.5 31.5t127 86q54.5 54.5 86 127T880-480q0 82-31.5 155t-86 127.5q-54.5 54.5-127 86T480-80Zm0-82q26-36 45-75t31-83H404q12 44 31 83t45 75Zm-104-16q-18-33-31.5-68.5T322-320H204q29 50 72.5 87t99.5 55Zm208 0q56-18 99.5-55t72.5-87H638q-9 38-22.5 73.5T584-178ZM170-400h136q-3-20-4.5-39.5T300-480q0-21 1.5-40.5T306-560H170q-5 20-7.5 39.5T160-480q0 21 2.5 40.5T170-400Zm216 0h188q3-20 4.5-39.5T580-480q0-21-1.5-40.5T574-560H386q-3 20-4.5 39.5T380-480q0 21 1.5 40.5T386-400Zm268 0h136q5-20 7.5-39.5T800-480q0-21-2.5-40.5T790-560H654q3 20 4.5 39.5T660-480q0 21-1.5 40.5T654-400Zm-16-240h118q-29-50-72.5-87T584-782q18 33 31.5 68.5T638-640Zm-234 0h152q-12-44-31-83t-45-75q-26 36-45 75t-31 83Zm-200 0h118q9-38 22.5-73.5T376-782q-56 18-99.5 55T204-640Z"/></svg>`
const speakerIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="16px" style="vertical-align: text-bottom;" viewBox="0 -960 960 960" width="16px" fill="#5f6368"><path d="M480-480q-66 0-113-47t-47-113q0-66 47-113t113-47q66 0 113 47t47 113q0 66-47 113t-113 47ZM160-240v-32q0-34 17.5-62.5T224-378q62-31 126-46.5T480-440q66 0 130 15.5T736-378q29 15 46.5 43.5T800-272v32q0 33-23.5 56.5T720-160H240q-33 0-56.5-23.5T160-240Zm80 0h480v-32q0-11-5.5-20T700-306q-54-27-109-40.5T480-360q-56 0-111 13.5T260-306q-9 5-14.5 14t-5.5 20v32Zm240-320q33 0 56.5-23.5T560-640q0-33-23.5-56.5T480-720q-33 0-56.5 23.5T400-640q0 33 23.5 56.5T480-560Zm0-80Zm0 400Z"/></svg>`;
function getWaveStroke() {
const styles = getComputedStyle(document.documentElement);
@@ -116,77 +82,16 @@ if (darkMq && darkMq.addEventListener) {
darkMq.addListener(handleOsThemeChange);
}
async function enumerateMicrophones() {
try {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
stream.getTracks().forEach(track => track.stop());
const devices = await navigator.mediaDevices.enumerateDevices();
availableMicrophones = devices.filter(device => device.kind === 'audioinput');
populateMicrophoneSelect();
console.log(`Found ${availableMicrophones.length} microphone(s)`);
} catch (error) {
console.error('Error enumerating microphones:', error);
statusText.textContent = "Error accessing microphones. Please grant permission.";
}
}
function populateMicrophoneSelect() {
if (!microphoneSelect) return;
microphoneSelect.innerHTML = '<option value="">Default Microphone</option>';
availableMicrophones.forEach((device, index) => {
const option = document.createElement('option');
option.value = device.deviceId;
option.textContent = device.label || `Microphone ${index + 1}`;
microphoneSelect.appendChild(option);
});
const savedMicId = localStorage.getItem('selectedMicrophone');
if (savedMicId && availableMicrophones.some(mic => mic.deviceId === savedMicId)) {
microphoneSelect.value = savedMicId;
selectedMicrophoneId = savedMicId;
}
}
function handleMicrophoneChange() {
selectedMicrophoneId = microphoneSelect.value || null;
localStorage.setItem('selectedMicrophone', selectedMicrophoneId || '');
const selectedDevice = availableMicrophones.find(mic => mic.deviceId === selectedMicrophoneId);
const deviceName = selectedDevice ? selectedDevice.label : 'Default Microphone';
console.log(`Selected microphone: ${deviceName}`);
statusText.textContent = `Microphone changed to: ${deviceName}`;
if (isRecording) {
statusText.textContent = "Switching microphone... Please wait.";
stopRecording().then(() => {
setTimeout(() => {
toggleRecording();
}, 1000);
});
}
}
// Helpers
function fmt1(x) {
const n = Number(x);
return Number.isFinite(n) ? n.toFixed(1) : x;
}
let host, port, protocol;
port = 8000;
if (isExtension) {
host = "localhost";
protocol = "ws";
} else {
host = window.location.hostname || "localhost";
port = window.location.port;
protocol = window.location.protocol === "https:" ? "wss" : "ws";
}
// Default WebSocket URL computation
const host = window.location.hostname || "localhost";
const port = window.location.port;
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
const defaultWebSocketUrl = `${protocol}://${host}${port ? ":" + port : ""}/asr`;
// Populate default caption and input
@@ -263,14 +168,6 @@ function setupWebSocket() {
websocket.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.type === "config") {
serverUseAudioWorklet = !!data.useAudioWorklet;
statusText.textContent = serverUseAudioWorklet
? "Connected. Using AudioWorklet (PCM)."
: "Connected. Using MediaRecorder (WebM).";
if (configReadyResolve) configReadyResolve();
return;
}
if (data.type === "ready_to_stop") {
console.log("Ready to stop received, finalizing display and closing WebSocket.");
@@ -338,7 +235,7 @@ function renderLinesWithBuffer(
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
const signature = JSON.stringify({
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end, detected_language: it.detected_language })),
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, beg: it.beg, end: it.end })),
buffer_transcription: buffer_transcription || "",
buffer_diarization: buffer_diarization || "",
status: current_status,
@@ -361,24 +258,19 @@ function renderLinesWithBuffer(
const linesHtml = (lines || [])
.map((item, idx) => {
let timeInfo = "";
if (item.start !== undefined && item.end !== undefined) {
timeInfo = ` ${item.start} - ${item.end}`;
if (item.beg !== undefined && item.end !== undefined) {
timeInfo = ` ${item.beg} - ${item.end}`;
}
let speakerLabel = "";
if (item.speaker === -2) {
speakerLabel = `<span class="silence">${silenceIcon}<span id='timeInfo'>${timeInfo}</span></span>`;
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
} else if (item.speaker == 0 && !isFinalizing) {
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
remaining_time_diarization
)}</span> second(s) of audio are undergoing diarization</span></span>`;
} else if (item.speaker !== 0) {
const speakerNum = `<span class="speaker-badge">${item.speaker}</span>`;
speakerLabel = `<span id="speaker">${speakerIcon}${speakerNum}<span id='timeInfo'>${timeInfo}</span></span>`;
if (item.detected_language) {
speakerLabel += `<span class="label_language">${languageIcon}<span>${item.detected_language}</span></span>`;
}
speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
}
let currentLineText = item.text || "";
@@ -415,16 +307,6 @@ function renderLinesWithBuffer(
}
}
}
if (item.translation) {
currentLineText += `
<div>
<div class="label_translation">
${translationIcon}
<span>${item.translation}</span>
</div>
</div>`;
}
return currentLineText.trim().length > 0 || speakerLabel.length > 0
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
@@ -433,10 +315,7 @@ function renderLinesWithBuffer(
.join("");
linesTranscriptDiv.innerHTML = linesHtml;
const transcriptContainer = document.querySelector('.transcript-container');
if (transcriptContainer) {
transcriptContainer.scrollTo({ top: transcriptContainer.scrollHeight, behavior: "smooth" });
}
window.scrollTo({ top: document.body.scrollHeight, behavior: "smooth" });
}
function updateTimer() {
@@ -498,44 +377,7 @@ async function startRecording() {
console.log("Error acquiring wake lock.");
}
let stream;
// chromium extension. in the future, both chrome page audio and mic will be used
if (isExtension) {
try {
stream = await new Promise((resolve, reject) => {
chrome.tabCapture.capture({audio: true}, (s) => {
if (s) {
resolve(s);
} else {
reject(new Error('Tab capture failed or not available'));
}
});
});
try {
outputAudioContext = new (window.AudioContext || window.webkitAudioContext)();
audioSource = outputAudioContext.createMediaStreamSource(stream);
audioSource.connect(outputAudioContext.destination);
} catch (audioError) {
console.warn('could not preserve system audio:', audioError);
}
statusText.textContent = "Using tab audio capture.";
} catch (tabError) {
console.log('Tab capture not available, falling back to microphone', tabError);
const audioConstraints = selectedMicrophoneId
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
: { audio: true };
stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
statusText.textContent = "Using microphone audio.";
}
} else if (isWebContext) {
const audioConstraints = selectedMicrophoneId
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
: { audio: true };
stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
}
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
audioContext = new (window.AudioContext || window.webkitAudioContext)();
analyser = audioContext.createAnalyser();
@@ -543,54 +385,13 @@ async function startRecording() {
microphone = audioContext.createMediaStreamSource(stream);
microphone.connect(analyser);
if (serverUseAudioWorklet) {
if (!audioContext.audioWorklet) {
throw new Error("AudioWorklet is not supported in this browser");
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
recorder.ondataavailable = (e) => {
if (websocket && websocket.readyState === WebSocket.OPEN) {
websocket.send(e.data);
}
await audioContext.audioWorklet.addModule("/web/pcm_worklet.js");
workletNode = new AudioWorkletNode(audioContext, "pcm-forwarder", { numberOfInputs: 1, numberOfOutputs: 0, channelCount: 1 });
microphone.connect(workletNode);
recorderWorker = new Worker("/web/recorder_worker.js");
recorderWorker.postMessage({
command: "init",
config: {
sampleRate: audioContext.sampleRate,
},
});
recorderWorker.onmessage = (e) => {
if (websocket && websocket.readyState === WebSocket.OPEN) {
websocket.send(e.data.buffer);
}
};
workletNode.port.onmessage = (e) => {
const data = e.data;
const ab = data instanceof ArrayBuffer ? data : data.buffer;
recorderWorker.postMessage(
{
command: "record",
buffer: ab,
},
[ab]
);
};
} else {
try {
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
} catch (e) {
recorder = new MediaRecorder(stream);
}
recorder.ondataavailable = (e) => {
if (websocket && websocket.readyState === WebSocket.OPEN) {
if (e.data && e.data.size > 0) {
websocket.send(e.data);
}
}
};
recorder.start(chunkDuration);
}
};
recorder.start(chunkDuration);
startTime = Date.now();
timerInterval = setInterval(updateTimer, 1000);
@@ -629,28 +430,10 @@ async function stopRecording() {
}
if (recorder) {
try {
recorder.stop();
} catch (e) {
}
recorder.stop();
recorder = null;
}
if (recorderWorker) {
recorderWorker.terminate();
recorderWorker = null;
}
if (workletNode) {
try {
workletNode.port.onmessage = null;
} catch (e) {}
try {
workletNode.disconnect();
} catch (e) {}
workletNode = null;
}
if (microphone) {
microphone.disconnect();
microphone = null;
@@ -669,16 +452,6 @@ async function stopRecording() {
audioContext = null;
}
if (audioSource) {
audioSource.disconnect();
audioSource = null;
}
if (outputAudioContext && outputAudioContext.state !== "closed") {
outputAudioContext.close()
outputAudioContext = null;
}
if (animationFrame) {
cancelAnimationFrame(animationFrame);
animationFrame = null;
@@ -704,11 +477,9 @@ async function toggleRecording() {
console.log("Connecting to WebSocket");
try {
if (websocket && websocket.readyState === WebSocket.OPEN) {
await configReady;
await startRecording();
} else {
await setupWebSocket();
await configReady;
await startRecording();
}
} catch (err) {
@@ -730,7 +501,7 @@ function updateUI() {
statusText.textContent = "Please wait for processing to complete...";
}
} else if (isRecording) {
statusText.textContent = "";
statusText.textContent = "Recording...";
} else {
if (
statusText.textContent !== "Finished processing audio! Ready to record again." &&
@@ -745,59 +516,3 @@ function updateUI() {
}
recordButton.addEventListener("click", toggleRecording);
if (microphoneSelect) {
microphoneSelect.addEventListener("change", handleMicrophoneChange);
}
document.addEventListener('DOMContentLoaded', async () => {
try {
await enumerateMicrophones();
} catch (error) {
console.log("Could not enumerate microphones on load:", error);
}
});
navigator.mediaDevices.addEventListener('devicechange', async () => {
console.log('Device change detected, re-enumerating microphones');
try {
await enumerateMicrophones();
} catch (error) {
console.log("Error re-enumerating microphones:", error);
}
});
settingsToggle.addEventListener("click", () => {
settingsDiv.classList.toggle("visible");
settingsToggle.classList.toggle("active");
});
if (isExtension) {
async function checkAndRequestPermissions() {
const micPermission = await navigator.permissions.query({
name: "microphone",
});
const permissionDisplay = document.getElementById("audioPermission");
if (permissionDisplay) {
permissionDisplay.innerText = `MICROPHONE: ${micPermission.state}`;
}
// if (micPermission.state !== "granted") {
// chrome.tabs.create({ url: "welcome.html" });
// }
const intervalId = setInterval(async () => {
const micPermission = await navigator.permissions.query({
name: "microphone",
});
if (micPermission.state === "granted") {
if (permissionDisplay) {
permissionDisplay.innerText = `MICROPHONE: ${micPermission.state}`;
}
clearInterval(intervalId);
}
}, 100);
}
void checkAndRequestPermissions();
}

View File

@@ -1,16 +0,0 @@
class PCMForwarder extends AudioWorkletProcessor {
process(inputs) {
const input = inputs[0];
if (input && input[0] && input[0].length) {
// Forward mono channel (0). If multi-channel, downmixing can be added here.
const channelData = input[0];
const copy = new Float32Array(channelData.length);
copy.set(channelData);
this.port.postMessage(copy, [copy.buffer]);
}
// Keep processor alive
return true;
}
}
registerProcessor('pcm-forwarder', PCMForwarder);

View File

@@ -1,58 +0,0 @@
let sampleRate = 48000;
let targetSampleRate = 16000;
self.onmessage = function (e) {
switch (e.data.command) {
case 'init':
init(e.data.config);
break;
case 'record':
record(e.data.buffer);
break;
}
};
function init(config) {
sampleRate = config.sampleRate;
targetSampleRate = config.targetSampleRate || 16000;
}
function record(inputBuffer) {
const buffer = new Float32Array(inputBuffer);
const resampledBuffer = resample(buffer, sampleRate, targetSampleRate);
const pcmBuffer = toPCM(resampledBuffer);
self.postMessage({ buffer: pcmBuffer }, [pcmBuffer]);
}
function resample(buffer, from, to) {
if (from === to) {
return buffer;
}
const ratio = from / to;
const newLength = Math.round(buffer.length / ratio);
const result = new Float32Array(newLength);
let offsetResult = 0;
let offsetBuffer = 0;
while (offsetResult < result.length) {
const nextOffsetBuffer = Math.round((offsetResult + 1) * ratio);
let accum = 0, count = 0;
for (let i = offsetBuffer; i < nextOffsetBuffer && i < buffer.length; i++) {
accum += buffer[i];
count++;
}
result[offsetResult] = accum / count;
offsetResult++;
offsetBuffer = nextOffsetBuffer;
}
return result;
}
function toPCM(input) {
const buffer = new ArrayBuffer(input.length * 2);
const view = new DataView(buffer);
for (let i = 0; i < input.length; i++) {
const s = Math.max(-1, Math.min(1, input[i]));
view.setInt16(i * 2, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
}
return buffer;
}

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-80q-82 0-155-31.5t-127.5-86Q143-252 111.5-325T80-480q0-83 31.5-155.5t86-127Q252-817 325-848.5T480-880q83 0 155.5 31.5t127 86q54.5 54.5 86 127T880-480q0 82-31.5 155t-86 127.5q-54.5 54.5-127 86T480-80Zm0-82q26-36 45-75t31-83H404q12 44 31 83t45 75Zm-104-16q-18-33-31.5-68.5T322-320H204q29 50 72.5 87t99.5 55Zm208 0q56-18 99.5-55t72.5-87H638q-9 38-22.5 73.5T584-178ZM170-400h136q-3-20-4.5-39.5T300-480q0-21 1.5-40.5T306-560H170q-5 20-7.5 39.5T160-480q0 21 2.5 40.5T170-400Zm216 0h188q3-20 4.5-39.5T580-480q0-21-1.5-40.5T574-560H386q-3 20-4.5 39.5T380-480q0 21 1.5 40.5T386-400Zm268 0h136q5-20 7.5-39.5T800-480q0-21-2.5-40.5T790-560H654q3 20 4.5 39.5T660-480q0 21-1.5 40.5T654-400Zm-16-240h118q-29-50-72.5-87T584-782q18 33 31.5 68.5T638-640Zm-234 0h152q-12-44-31-83t-45-75q-26 36-45 75t-31 83Zm-200 0h118q9-38 22.5-73.5T376-782q-56 18-99.5 55T204-640Z"/></svg>

Before

Width:  |  Height:  |  Size: 976 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M433-80q-27 0-46.5-18T363-142l-9-66q-13-5-24.5-12T307-235l-62 26q-25 11-50 2t-39-32l-47-82q-14-23-8-49t27-43l53-40q-1-7-1-13.5v-27q0-6.5 1-13.5l-53-40q-21-17-27-43t8-49l47-82q14-23 39-32t50 2l62 26q11-8 23-15t24-12l9-66q4-26 23.5-44t46.5-18h94q27 0 46.5 18t23.5 44l9 66q13 5 24.5 12t22.5 15l62-26q25-11 50-2t39 32l47 82q14 23 8 49t-27 43l-53 40q1 7 1 13.5v27q0 6.5-2 13.5l53 40q21 17 27 43t-8 49l-48 82q-14 23-39 32t-50-2l-60-26q-11 8-23 15t-24 12l-9 66q-4 26-23.5 44T527-80h-94Zm7-80h79l14-106q31-8 57.5-23.5T639-327l99 41 39-68-86-65q5-14 7-29.5t2-31.5q0-16-2-31.5t-7-29.5l86-65-39-68-99 42q-22-23-48.5-38.5T533-694l-13-106h-79l-14 106q-31 8-57.5 23.5T321-633l-99-41-39 68 86 64q-5 15-7 30t-2 32q0 16 2 31t7 30l-86 65 39 68 99-42q22 23 48.5 38.5T427-266l13 106Zm42-180q58 0 99-41t41-99q0-58-41-99t-99-41q-59 0-99.5 41T342-480q0 58 40.5 99t99.5 41Zm-2-140Z"/></svg>

Before

Width:  |  Height:  |  Size: 982 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M514-556 320-752q9-3 19-5.5t21-2.5q66 0 113 47t47 113q0 11-1.5 22t-4.5 22ZM40-200v-32q0-33 17-62t47-44q51-26 115-44t141-18q26 0 49.5 2.5T456-392l-56-54q-9 3-19 4.5t-21 1.5q-66 0-113-47t-47-113q0-11 1.5-21t4.5-19L84-764q-11-11-11-28t11-28q12-12 28.5-12t27.5 12l675 685q11 11 11.5 27.5T816-80q-11 13-28 12.5T759-80L641-200h39q0 33-23.5 56.5T600-120H120q-33 0-56.5-23.5T40-200Zm80 0h480v-32q0-14-4.5-19.5T580-266q-36-18-92.5-36T360-320q-71 0-127.5 18T140-266q-9 5-14.5 14t-5.5 20v32Zm240 0Zm560-400q0 69-24.5 131.5T829-355q-12 14-30 15t-32-13q-13-13-12-31t12-33q30-38 46.5-85t16.5-98q0-51-16.5-97T767-781q-12-15-12.5-33t12.5-32q13-14 31.5-13.5T829-845q42 51 66.5 113.5T920-600Zm-182 0q0 32-10 61.5T700-484q-11 15-29.5 15.5T638-482q-13-13-13.5-31.5T633-549q6-11 9.5-24t3.5-27q0-14-3.5-27t-9.5-25q-9-17-8.5-35t13.5-31q14-14 32.5-13.5T700-716q18 25 28 54.5t10 61.5Z"/></svg>

Before

Width:  |  Height:  |  Size: 984 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-480q-66 0-113-47t-47-113q0-66 47-113t113-47q66 0 113 47t47 113q0 66-47 113t-113 47ZM160-240v-32q0-34 17.5-62.5T224-378q62-31 126-46.5T480-440q66 0 130 15.5T736-378q29 15 46.5 43.5T800-272v32q0 33-23.5 56.5T720-160H240q-33 0-56.5-23.5T160-240Zm80 0h480v-32q0-11-5.5-20T700-306q-54-27-109-40.5T480-360q-56 0-111 13.5T260-306q-9 5-14.5 14t-5.5 20v32Zm240-320q33 0 56.5-23.5T560-640q0-33-23.5-56.5T480-720q-33 0-56.5 23.5T400-640q0 33 23.5 56.5T480-560Zm0-80Zm0 400Z"/></svg>

Before

Width:  |  Height:  |  Size: 592 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="m603-202-34 97q-4 11-14 18t-22 7q-20 0-32.5-16.5T496-133l152-402q5-11 15-18t22-7h30q12 0 22 7t15 18l152 403q8 19-4 35.5T868-80q-13 0-22.5-7T831-106l-34-96H603ZM362-401 188-228q-11 11-27.5 11.5T132-228q-11-11-11-28t11-28l174-174q-35-35-63.5-80T190-640h84q20 39 40 68t48 58q33-33 68.5-92.5T484-720H80q-17 0-28.5-11.5T40-760q0-17 11.5-28.5T80-800h240v-40q0-17 11.5-28.5T360-880q17 0 28.5 11.5T400-840v40h240q17 0 28.5 11.5T680-760q0 17-11.5 28.5T640-720h-76q-21 72-63 148t-83 116l96 98-30 82-122-125Zm266 129h144l-72-204-72 204Z"/></svg>

Before

Width:  |  Height:  |  Size: 650 B

View File

@@ -16,57 +16,43 @@ def get_web_interface_html():
def get_inline_ui_html():
"""Returns the complete web interface HTML with all assets embedded in a single call."""
try:
# Load HTML template
with resources.files('whisperlivekit.web').joinpath('live_transcription.html').open('r', encoding='utf-8') as f:
html_content = f.read()
html_content = f.read()
# Load CSS and embed it
with resources.files('whisperlivekit.web').joinpath('live_transcription.css').open('r', encoding='utf-8') as f:
css_content = f.read()
# Load JS and embed it
with resources.files('whisperlivekit.web').joinpath('live_transcription.js').open('r', encoding='utf-8') as f:
js_content = f.read()
with resources.files('whisperlivekit.web').joinpath('pcm_worklet.js').open('r', encoding='utf-8') as f:
worklet_code = f.read()
with resources.files('whisperlivekit.web').joinpath('recorder_worker.js').open('r', encoding='utf-8') as f:
worker_code = f.read()
js_content = js_content.replace(
'await audioContext.audioWorklet.addModule("/web/pcm_worklet.js");',
'const workletBlob = new Blob([`' + worklet_code + '`], { type: "application/javascript" });\n' +
'const workletUrl = URL.createObjectURL(workletBlob);\n' +
'await audioContext.audioWorklet.addModule(workletUrl);'
)
js_content = js_content.replace(
'recorderWorker = new Worker("/web/recorder_worker.js");',
'const workerBlob = new Blob([`' + worker_code + '`], { type: "application/javascript" });\n' +
'const workerUrl = URL.createObjectURL(workerBlob);\n' +
'recorderWorker = new Worker(workerUrl);'
)
# SVG files
# Load SVG files and convert to data URIs
with resources.files('whisperlivekit.web').joinpath('src', 'system_mode.svg').open('r', encoding='utf-8') as f:
system_svg = f.read()
system_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(system_svg.encode('utf-8')).decode('utf-8')}"
with resources.files('whisperlivekit.web').joinpath('src', 'light_mode.svg').open('r', encoding='utf-8') as f:
light_svg = f.read()
light_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(light_svg.encode('utf-8')).decode('utf-8')}"
with resources.files('whisperlivekit.web').joinpath('src', 'dark_mode.svg').open('r', encoding='utf-8') as f:
dark_svg = f.read()
dark_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(dark_svg.encode('utf-8')).decode('utf-8')}"
with resources.files('whisperlivekit.web').joinpath('src', 'settings.svg').open('r', encoding='utf-8') as f:
settings = f.read()
settings_uri = f"data:image/svg+xml;base64,{base64.b64encode(settings.encode('utf-8')).decode('utf-8')}"
# Replace external references
# Replace external references with embedded content
html_content = html_content.replace(
'<link rel="stylesheet" href="live_transcription.css" />',
'<link rel="stylesheet" href="/web/live_transcription.css" />',
f'<style>\n{css_content}\n</style>'
)
html_content = html_content.replace(
'<script src="live_transcription.js"></script>',
'<script src="/web/live_transcription.js"></script>',
f'<script>\n{js_content}\n</script>'
)
# Replace SVG references
# Replace SVG references with data URIs
html_content = html_content.replace(
'<img src="/web/src/system_mode.svg" alt="" />',
f'<img src="{system_data_uri}" alt="" />'
@@ -82,11 +68,6 @@ def get_inline_ui_html():
f'<img src="{dark_data_uri}" alt="" />'
)
html_content = html_content.replace(
'<img src="web/src/settings.svg" alt="Settings" />',
f'<img src="{settings_uri}" alt="" />'
)
return html_content
except Exception as e:

View File

@@ -11,14 +11,14 @@ class ASRBase:
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
# "" for faster-whisper because it emits the spaces when needed)
def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
self.logfile = logfile
self.transcribe_kargs = {}
if lan == "auto":
self.original_language = None
else:
self.original_language = lan
self.model = self.load_model(model_size, cache_dir, model_dir)
self.model = self.load_model(modelsize, cache_dir, model_dir)
def with_offset(self, offset: float) -> ASRToken:
# This method is kept for compatibility (typically you will use ASRToken.with_offset)
@@ -27,7 +27,7 @@ class ASRBase:
def __repr__(self):
return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"
def load_model(self, model_size, cache_dir, model_dir):
def load_model(self, modelsize, cache_dir, model_dir):
raise NotImplementedError("must be implemented in the child class")
def transcribe(self, audio, init_prompt=""):
@@ -41,7 +41,7 @@ class WhisperTimestampedASR(ASRBase):
"""Uses whisper_timestamped as the backend."""
sep = " "
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
import whisper
import whisper_timestamped
from whisper_timestamped import transcribe_timestamped
@@ -49,7 +49,7 @@ class WhisperTimestampedASR(ASRBase):
self.transcribe_timestamped = transcribe_timestamped
if model_dir is not None:
logger.debug("ignoring model_dir, not implemented")
return whisper.load_model(model_size, download_root=cache_dir)
return whisper.load_model(modelsize, download_root=cache_dir)
def transcribe(self, audio, init_prompt=""):
result = self.transcribe_timestamped(
@@ -88,17 +88,17 @@ class FasterWhisperASR(ASRBase):
"""Uses faster-whisper as the backend."""
sep = ""
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
from faster_whisper import WhisperModel
if model_dir is not None:
logger.debug(f"Loading whisper model from model_dir {model_dir}. "
f"model_size and cache_dir parameters are not used.")
f"modelsize and cache_dir parameters are not used.")
model_size_or_path = model_dir
elif model_size is not None:
model_size_or_path = model_size
elif modelsize is not None:
model_size_or_path = modelsize
else:
raise ValueError("Either model_size or model_dir must be set")
raise ValueError("Either modelsize or model_dir must be set")
device = "auto" # Allow CTranslate2 to decide available device
compute_type = "auto" # Allow CTranslate2 to decide faster compute type
@@ -149,18 +149,18 @@ class MLXWhisper(ASRBase):
"""
sep = ""
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
from mlx_whisper.transcribe import ModelHolder, transcribe
import mlx.core as mx
if model_dir is not None:
logger.debug(f"Loading whisper model from model_dir {model_dir}. model_size parameter is not used.")
logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
model_size_or_path = model_dir
elif model_size is not None:
model_size_or_path = self.translate_model_name(model_size)
logger.debug(f"Loading whisper model {model_size}. You use mlx whisper, so {model_size_or_path} will be used.")
elif modelsize is not None:
model_size_or_path = self.translate_model_name(modelsize)
logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")
else:
raise ValueError("Either model_size or model_dir must be set")
raise ValueError("Either modelsize or model_dir must be set")
self.model_size_or_path = model_size_or_path
dtype = mx.float16

View File

@@ -106,6 +106,9 @@ class OnlineASRProcessor:
def __init__(
self,
asr,
tokenize_method: Optional[callable] = None,
buffer_trimming: Tuple[str, float] = ("segment", 15),
confidence_validation = False,
logfile=sys.stderr,
):
"""
@@ -116,14 +119,13 @@ class OnlineASRProcessor:
buffer_trimming: A tuple (option, seconds), where option is either "sentence" or "segment".
"""
self.asr = asr
self.tokenize = asr.tokenizer
self.tokenize = tokenize_method
self.logfile = logfile
self.confidence_validation = asr.confidence_validation
self.confidence_validation = confidence_validation
self.global_time_offset = 0.0
self.init()
self.buffer_trimming_way = asr.buffer_trimming
self.buffer_trimming_sec = asr.buffer_trimming_sec
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
if self.buffer_trimming_way not in ["sentence", "segment"]:
raise ValueError("buffer_trimming must be either 'sentence' or 'segment'")

View File

@@ -6,7 +6,6 @@ from functools import lru_cache
import time
import logging
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
from whisperlivekit.warmup import warmup_asr
logger = logging.getLogger(__name__)
@@ -64,23 +63,11 @@ def create_tokenizer(lan):
return WtPtok()
def backend_factory(
backend,
lan,
model_size,
model_cache_dir,
model_dir,
task,
buffer_trimming,
buffer_trimming_sec,
confidence_validation,
warmup_file=None,
min_chunk_size=None,
):
backend = backend
def backend_factory(args):
backend = args.backend
if backend == "openai-api":
logger.debug("Using OpenAI API.")
asr = OpenaiApiASR(lan=lan)
asr = OpenaiApiASR(lan=args.lan)
else:
if backend == "faster-whisper":
asr_cls = FasterWhisperASR
@@ -90,33 +77,34 @@ def backend_factory(
asr_cls = WhisperTimestampedASR
# Only for FasterWhisperASR and WhisperTimestampedASR
size = args.model
t = time.time()
logger.info(f"Loading Whisper {model_size} model for language {lan}...")
logger.info(f"Loading Whisper {size} model for language {args.lan}...")
asr = asr_cls(
model_size=model_size,
lan=lan,
cache_dir=model_cache_dir,
model_dir=model_dir,
modelsize=size,
lan=args.lan,
cache_dir=getattr(args, 'model_cache_dir', None),
model_dir=getattr(args, 'model_dir', None),
)
e = time.time()
logger.info(f"done. It took {round(e-t,2)} seconds.")
if task == "translate":
# Apply common configurations
if getattr(args, "vad", False): # Checks if VAD argument is present and True
logger.info("Setting VAD filter")
asr.use_vad()
language = args.lan
if args.task == "translate":
if backend != "simulstreaming":
asr.set_translate_task()
tgt_language = "en" # Whisper translates into English
else:
tgt_language = lan # Whisper transcribes in this language
tgt_language = language # Whisper transcribes in this language
# Create the tokenizer
if buffer_trimming == "sentence":
if args.buffer_trimming == "sentence":
tokenizer = create_tokenizer(tgt_language)
else:
tokenizer = None
warmup_asr(asr, warmup_file)
asr.confidence_validation = confidence_validation
asr.tokenizer = tokenizer
asr.buffer_trimming = buffer_trimming
asr.buffer_trimming_sec = buffer_trimming_sec
return asr
return asr, tokenizer