mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bcffdbc6b3 | ||
|
|
80b77998f9 | ||
|
|
d310f7e25f | ||
|
|
8d9be88fe6 | ||
|
|
16461052ed | ||
|
|
5491dbd824 | ||
|
|
13401ffe24 | ||
|
|
7108d2ddc5 | ||
|
|
a732e0903e | ||
|
|
0491681be4 | ||
|
|
ffe5284764 | ||
|
|
41ca17acda | ||
|
|
06b31f51eb | ||
|
|
ece02db6a3 | ||
|
|
939a7ebf8b | ||
|
|
61edb70fff | ||
|
|
4e455b8aab | ||
|
|
9434390ad3 | ||
|
|
65250db92c | ||
|
|
416dce7975 | ||
|
|
0c5365e7c6 | ||
|
|
19e9d76610 | ||
|
|
e7b05b0138 | ||
|
|
818c9c37ca | ||
|
|
714fb3b14a | ||
|
|
0af379c465 | ||
|
|
9c5bb5df19 | ||
|
|
dc6ea79036 | ||
|
|
21bbb59e31 | ||
|
|
3467109668 |
18
.gitignore
vendored
18
.gitignore
vendored
@@ -54,21 +54,6 @@ coverage.xml
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
@@ -138,4 +123,5 @@ test_*.py
|
||||
launch.json
|
||||
.DS_Store
|
||||
test/*
|
||||
nllb-200-distilled-600M-ctranslate2/*
|
||||
nllb-200-distilled-600M-ctranslate2/*
|
||||
*.mp3
|
||||
226
LICENSE
226
LICENSE
@@ -1,52 +1,210 @@
|
||||
# License
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
## Main Software License
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
MIT License
|
||||
1. Definitions.
|
||||
|
||||
Copyright (c) 2025 Quentin Fuxa.
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
## SimulStreaming Backend License
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
**When using the SimulStreaming backend (SimulWhisper), additional licensing terms apply:**
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
SimulStreaming (https://github.com/ufal/SimulStreaming) is dual-licensed:
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
### 🔹 Non-Commercial Use
|
||||
You may use SimulStreaming under the **PolyForm Noncommercial License 1.0.0** if you obtain the code through the GitHub repository. This license is **free of charge** and comes with **no obligations** for non-commercial users.
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
### 🔸 Commercial Use
|
||||
Understanding who uses SimulStreaming commercially helps improve and prioritize development. Therefore, **registration is required** for those who acquire a commercial license.
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
Commercial licenses are planned to be **affordable** to SMEs and individuals. They are considering providing commercial licenses either for free or for a symbolic one-time fee, and may also provide additional support. You can share your preference via the [questionnaire](https://forms.cloud.microsoft.com/e/7tCxb4gJfB).
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
You can also leave your contact [there](https://forms.cloud.microsoft.com/e/7tCxb4gJfB) to be notified when commercial licenses become available.
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
**Contact for SimulStreaming licensing:**
|
||||
[Dominik Macháček](https://ufal.mff.cuni.cz/dominik-machacek/), machacek@ufal.mff.cuni.cz
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2025 Quentin Fuxa
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
---
|
||||
|
||||
## Based on:
|
||||
- **whisper_streaming** by ÚFAL – MIT License – https://github.com/ufal/whisper_streaming. The original work by ÚFAL. License: https://github.com/ufal/whisper_streaming/blob/main/LICENSE
|
||||
- **silero-vad** by Snakers4 – MIT License – https://github.com/snakers4/silero-vad. The work by Snakers4 (silero-vad). License: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
|
||||
- **Diart** by juanmc2005 – MIT License – https://github.com/juanmc2005/diart. The work in Diart by juanmc2005. License: https://github.com/juanmc2005/diart/blob/main/LICENSE
|
||||
- **SimulStreaming** by ÚFAL – Dual License (PolyForm Noncommercial License 1.0.0 / Commercial License) – https://github.com/ufal/SimulStreaming
|
||||
- **SimulWhisper** by Speech and Audio Technology LAB of Tsinghua University – Apache-2.0 – https://github.com/ufal/SimulStreaming
|
||||
- **SimulStreaming** by ÚFAL – MIT License – https://github.com/ufal/SimulStreaming
|
||||
- **NeMo** by NVidia - Apache-2.0 - https://github.com/NVIDIA-NeMo/NeMo
|
||||
- **whisper_streaming** by ÚFAL – MIT License – https://github.com/ufal/whisper_streaming.
|
||||
- **silero-vad** by Snakers4 – MIT License – https://github.com/snakers4/silero-vad.
|
||||
- **Diart** by juanmc2005 – MIT License – https://github.com/juanmc2005/diart.
|
||||
|
||||
40
README.md
40
README.md
@@ -10,16 +10,16 @@
|
||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
|
||||
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
|
||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
|
||||
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
|
||||
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-Apache 2.0-dark_green"></a>
|
||||
</p>
|
||||
|
||||
|
||||
Real-time speech transcription directly to your browser, with a ready-to-use backend+server and a simple frontend. ✨
|
||||
Real-time transcription directly to your browser, with a ready-to-use backend+server and a simple frontend.
|
||||
|
||||
#### Powered by Leading Research:
|
||||
|
||||
- [SimulStreaming](https://github.com/ufalSimulStreaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
|
||||
- [NLLB](https://arxiv.org/abs/2207.04672), ([distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2)) (2024) - Translation to more than 100 languages.
|
||||
- Simul-[Whisper](https://github.com/backspacetg/simul_whisper)/[Streaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
|
||||
- [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simulatenous translation from & to 200 languages.
|
||||
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
|
||||
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
|
||||
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
|
||||
@@ -45,7 +45,7 @@ pip install whisperlivekit
|
||||
#### Quick Start
|
||||
1. **Start the transcription server:**
|
||||
```bash
|
||||
whisperlivekit-server --model base --language en
|
||||
wlk --model base --language en
|
||||
```
|
||||
|
||||
2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time!
|
||||
@@ -53,6 +53,7 @@ pip install whisperlivekit
|
||||
|
||||
> - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
|
||||
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
|
||||
> - The CLI entry point is exposed as both `wlk` and `whisperlivekit-server`; they are equivalent.
|
||||
|
||||
#### Use it to capture audio from web pages.
|
||||
|
||||
@@ -68,13 +69,12 @@ Go to `chrome-extension` for instructions.
|
||||
|
||||
| Optional | `pip install` |
|
||||
|-----------|-------------|
|
||||
| **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
|
||||
| **Apple Silicon optimized backend** | `mlx-whisper` |
|
||||
| **NLLB Translation** | `huggingface_hub` & `transformers` |
|
||||
| **Windows/Linux optimizations** | `faster-whisper` |
|
||||
| **Apple Silicon optimizations** | `mlx-whisper` |
|
||||
| **Translation** | `nllw` |
|
||||
| **Speaker diarization** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
|
||||
| OpenAI API | `openai` |
|
||||
| *[Not recommanded]* Speaker diarization with Diart | `diart` |
|
||||
| *[Not recommanded]* Original Whisper backend | `whisper` |
|
||||
| *[Not recommanded]* Improved timestamps backend | `whisper-timestamped` |
|
||||
| OpenAI API backend | `openai` |
|
||||
|
||||
See **Parameters & Configuration** below on how to use them.
|
||||
|
||||
@@ -86,10 +86,10 @@ See **Parameters & Configuration** below on how to use them.
|
||||
|
||||
```bash
|
||||
# Large model and translate from french to danish
|
||||
whisperlivekit-server --model large-v3 --language fr --target-language da
|
||||
wlk --model large-v3 --language fr --target-language da
|
||||
|
||||
# Diarization and server listening on */80
|
||||
whisperlivekit-server --host 0.0.0.0 --port 80 --model medium --diarization --language fr
|
||||
wlk --host 0.0.0.0 --port 80 --model medium --diarization --language fr
|
||||
```
|
||||
|
||||
|
||||
@@ -139,13 +139,13 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md) | `small` |
|
||||
| `--model-dir` | Directory containing Whisper model.bin and other files. Overrides `--model`. | `None` |
|
||||
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/available_models.md) | `small` |
|
||||
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/models_compatible_formats.md) | `None` |
|
||||
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
||||
| `--target-language` | If sets, activates translation using NLLB. Ex: `fr`. [118 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/translation/mapping_languages.py). If you want to translate to english, you should rather use `--task translate`, since Whisper can do it directly. | `None` |
|
||||
| `--task` | Set to `translate` to translate *only* to english, using Whisper translation. | `transcribe` |
|
||||
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
|
||||
| `--diarization` | Enable speaker identification | `False` |
|
||||
| `--backend` | Processing backend. You can switch to `faster-whisper` if `simulstreaming` does not work correctly | `simulstreaming` |
|
||||
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
|
||||
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
|
||||
| `--no-vac` | Disable Voice Activity Controller | `False` |
|
||||
| `--no-vad` | Disable Voice Activity Detection | `False` |
|
||||
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
||||
@@ -171,7 +171,8 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
| SimulStreaming backend options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
|
||||
| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used | `None` |
|
||||
| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used. Use `scripts/determine_alignment_heads.py` to extract them. <img src="scripts/alignment_heads.png" alt="WhisperLiveKit Demo" width="300">
|
||||
| `None` |
|
||||
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
|
||||
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
|
||||
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
|
||||
@@ -182,7 +183,6 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
| `--init-prompt` | Initial prompt for the model | `None` |
|
||||
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
|
||||
| `--max-context-tokens` | Maximum context tokens | `None` |
|
||||
| `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` |
|
||||
| `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
|
||||
|
||||
|
||||
|
||||
19
docs/models_compatible_formats.md
Normal file
19
docs/models_compatible_formats.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# Model Path Formats
|
||||
|
||||
The `--model-path` parameter accepts:
|
||||
|
||||
## File Path
|
||||
- **`.pt` / `.bin` / `.safetensor` formats** Should be openable by pytorch/safetensor.
|
||||
|
||||
## Directory Path (recommended)
|
||||
Must contain:
|
||||
- **`.pt` / `.bin` / `.safetensor` file** (required for decoder)
|
||||
|
||||
May optionally contain:
|
||||
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
|
||||
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
|
||||
|
||||
## Hugging Face Repo ID
|
||||
- Provide the repo ID (e.g. `openai/whisper-large-v3`) and WhisperLiveKit will download and cache the snapshot automatically. For gated repos, authenticate via `huggingface-cli login` first.
|
||||
|
||||
To improve speed/reduce allucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` to pass them to WLK. If not, alignement heads are set to be all the heads of the last half layer of decoder.
|
||||
265
docs/supported_languages.md
Normal file
265
docs/supported_languages.md
Normal file
@@ -0,0 +1,265 @@
|
||||
# Supported Languages
|
||||
|
||||
WhisperLiveKit supports translation into **201 languages** from the FLORES-200 dataset through the NLLB (No Language Left Behind) translation system.
|
||||
|
||||
## How to Specify Languages
|
||||
|
||||
You can specify languages in **three different ways**:
|
||||
|
||||
1. **Language Name** (case-insensitive): `"English"`, `"French"`, `"Spanish"`
|
||||
2. **ISO Language Code**: `"en"`, `"fr"`, `"es"`
|
||||
3. **NLLB Code** (FLORES-200): `"eng_Latn"`, `"fra_Latn"`, `"spa_Latn"`
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Command Line
|
||||
```bash
|
||||
# Using language name
|
||||
whisperlivekit-server --target-language "French"
|
||||
|
||||
# Using ISO code
|
||||
whisperlivekit-server --target-language fr
|
||||
|
||||
# Using NLLB code
|
||||
whisperlivekit-server --target-language fra_Latn
|
||||
```
|
||||
|
||||
### Python API
|
||||
```python
|
||||
from nllw.translation import get_language_info
|
||||
|
||||
# Get language information by name
|
||||
lang_info = get_language_info("French")
|
||||
print(lang_info)
|
||||
# {'name': 'French', 'nllb': 'fra_Latn', 'language_code': 'fr'}
|
||||
|
||||
# Get language information by ISO code
|
||||
lang_info = get_language_info("fr")
|
||||
|
||||
# Get language information by NLLB code
|
||||
lang_info = get_language_info("fra_Latn")
|
||||
|
||||
# All three return the same result
|
||||
```
|
||||
|
||||
## Complete Language List
|
||||
|
||||
The following table lists all 201 supported languages with their corresponding codes:
|
||||
|
||||
| Language Name | ISO Code | NLLB Code |
|
||||
|---------------|----------|-----------|
|
||||
| Acehnese (Arabic script) | ace_Arab | ace_Arab |
|
||||
| Acehnese (Latin script) | ace_Latn | ace_Latn |
|
||||
| Mesopotamian Arabic | acm_Arab | acm_Arab |
|
||||
| Ta'izzi-Adeni Arabic | acq_Arab | acq_Arab |
|
||||
| Tunisian Arabic | aeb_Arab | aeb_Arab |
|
||||
| Afrikaans | af | afr_Latn |
|
||||
| South Levantine Arabic | ajp_Arab | ajp_Arab |
|
||||
| Akan | ak | aka_Latn |
|
||||
| Tosk Albanian | als | als_Latn |
|
||||
| Amharic | am | amh_Ethi |
|
||||
| North Levantine Arabic | apc_Arab | apc_Arab |
|
||||
| Modern Standard Arabic | ar | arb_Arab |
|
||||
| Modern Standard Arabic (Romanized) | arb_Latn | arb_Latn |
|
||||
| Najdi Arabic | ars_Arab | ars_Arab |
|
||||
| Moroccan Arabic | ary_Arab | ary_Arab |
|
||||
| Egyptian Arabic | arz_Arab | arz_Arab |
|
||||
| Assamese | as | asm_Beng |
|
||||
| Asturian | ast | ast_Latn |
|
||||
| Awadhi | awa | awa_Deva |
|
||||
| Central Aymara | ay | ayr_Latn |
|
||||
| South Azerbaijani | azb | azb_Arab |
|
||||
| North Azerbaijani | az | azj_Latn |
|
||||
| Bashkir | ba | bak_Cyrl |
|
||||
| Bambara | bm | bam_Latn |
|
||||
| Balinese | ban | ban_Latn |
|
||||
| Belarusian | be | bel_Cyrl |
|
||||
| Bemba | bem | bem_Latn |
|
||||
| Bengali | bn | ben_Beng |
|
||||
| Bhojpuri | bho | bho_Deva |
|
||||
| Banjar (Arabic script) | bjn_Arab | bjn_Arab |
|
||||
| Banjar (Latin script) | bjn_Latn | bjn_Latn |
|
||||
| Standard Tibetan | bo | bod_Tibt |
|
||||
| Bosnian | bs | bos_Latn |
|
||||
| Buginese | bug | bug_Latn |
|
||||
| Bulgarian | bg | bul_Cyrl |
|
||||
| Catalan | ca | cat_Latn |
|
||||
| Cebuano | ceb | ceb_Latn |
|
||||
| Czech | cs | ces_Latn |
|
||||
| Chokwe | cjk | cjk_Latn |
|
||||
| Central Kurdish | ckb | ckb_Arab |
|
||||
| Crimean Tatar | crh | crh_Latn |
|
||||
| Welsh | cy | cym_Latn |
|
||||
| Danish | da | dan_Latn |
|
||||
| German | de | deu_Latn |
|
||||
| Southwestern Dinka | dik | dik_Latn |
|
||||
| Dyula | dyu | dyu_Latn |
|
||||
| Dzongkha | dz | dzo_Tibt |
|
||||
| Greek | el | ell_Grek |
|
||||
| English | en | eng_Latn |
|
||||
| Esperanto | eo | epo_Latn |
|
||||
| Estonian | et | est_Latn |
|
||||
| Basque | eu | eus_Latn |
|
||||
| Ewe | ee | ewe_Latn |
|
||||
| Faroese | fo | fao_Latn |
|
||||
| Fijian | fj | fij_Latn |
|
||||
| Finnish | fi | fin_Latn |
|
||||
| Fon | fon | fon_Latn |
|
||||
| French | fr | fra_Latn |
|
||||
| Friulian | fur-IT | fur_Latn |
|
||||
| Nigerian Fulfulde | fuv | fuv_Latn |
|
||||
| West Central Oromo | om | gaz_Latn |
|
||||
| Scottish Gaelic | gd | gla_Latn |
|
||||
| Irish | ga-IE | gle_Latn |
|
||||
| Galician | gl | glg_Latn |
|
||||
| Guarani | gn | grn_Latn |
|
||||
| Gujarati | gu-IN | guj_Gujr |
|
||||
| Haitian Creole | ht | hat_Latn |
|
||||
| Hausa | ha | hau_Latn |
|
||||
| Hebrew | he | heb_Hebr |
|
||||
| Hindi | hi | hin_Deva |
|
||||
| Chhattisgarhi | hne | hne_Deva |
|
||||
| Croatian | hr | hrv_Latn |
|
||||
| Hungarian | hu | hun_Latn |
|
||||
| Armenian | hy-AM | hye_Armn |
|
||||
| Igbo | ig | ibo_Latn |
|
||||
| Ilocano | ilo | ilo_Latn |
|
||||
| Indonesian | id | ind_Latn |
|
||||
| Icelandic | is | isl_Latn |
|
||||
| Italian | it | ita_Latn |
|
||||
| Javanese | jv | jav_Latn |
|
||||
| Japanese | ja | jpn_Jpan |
|
||||
| Kabyle | kab | kab_Latn |
|
||||
| Jingpho | kac | kac_Latn |
|
||||
| Kamba | kam | kam_Latn |
|
||||
| Kannada | kn | kan_Knda |
|
||||
| Kashmiri (Arabic script) | kas_Arab | kas_Arab |
|
||||
| Kashmiri (Devanagari script) | kas_Deva | kas_Deva |
|
||||
| Georgian | ka | kat_Geor |
|
||||
| Kazakh | kk | kaz_Cyrl |
|
||||
| Kabiyè | kbp | kbp_Latn |
|
||||
| Kabuverdianu | kea | kea_Latn |
|
||||
| Halh Mongolian | mn | khk_Cyrl |
|
||||
| Khmer | km | khm_Khmr |
|
||||
| Kikuyu | ki | kik_Latn |
|
||||
| Kinyarwanda | rw | kin_Latn |
|
||||
| Kyrgyz | ky | kir_Cyrl |
|
||||
| Kimbundu | kmb | kmb_Latn |
|
||||
| Northern Kurdish | kmr | kmr_Latn |
|
||||
| Central Kanuri (Arabic script) | knc_Arab | knc_Arab |
|
||||
| Central Kanuri (Latin script) | knc_Latn | knc_Latn |
|
||||
| Kikongo | kg | kon_Latn |
|
||||
| Korean | ko | kor_Hang |
|
||||
| Lao | lo | lao_Laoo |
|
||||
| Ligurian | lij | lij_Latn |
|
||||
| Limburgish | li | lim_Latn |
|
||||
| Lingala | ln | lin_Latn |
|
||||
| Lithuanian | lt | lit_Latn |
|
||||
| Lombard | lmo | lmo_Latn |
|
||||
| Latgalian | ltg | ltg_Latn |
|
||||
| Luxembourgish | lb | ltz_Latn |
|
||||
| Luba-Kasai | lua | lua_Latn |
|
||||
| Ganda | lg | lug_Latn |
|
||||
| Luo | luo | luo_Latn |
|
||||
| Mizo | lus | lus_Latn |
|
||||
| Standard Latvian | lv | lvs_Latn |
|
||||
| Magahi | mag | mag_Deva |
|
||||
| Maithili | mai | mai_Deva |
|
||||
| Malayalam | ml-IN | mal_Mlym |
|
||||
| Marathi | mr | mar_Deva |
|
||||
| Minangkabau (Arabic script) | min_Arab | min_Arab |
|
||||
| Minangkabau (Latin script) | min_Latn | min_Latn |
|
||||
| Macedonian | mk | mkd_Cyrl |
|
||||
| Maltese | mt | mlt_Latn |
|
||||
| Meitei (Bengali script) | mni | mni_Beng |
|
||||
| Mossi | mos | mos_Latn |
|
||||
| Maori | mi | mri_Latn |
|
||||
| Burmese | my | mya_Mymr |
|
||||
| Dutch | nl | nld_Latn |
|
||||
| Norwegian Nynorsk | nn-NO | nno_Latn |
|
||||
| Norwegian Bokmål | nb | nob_Latn |
|
||||
| Nepali | ne-NP | npi_Deva |
|
||||
| Northern Sotho | nso | nso_Latn |
|
||||
| Nuer | nus | nus_Latn |
|
||||
| Nyanja | ny | nya_Latn |
|
||||
| Occitan | oc | oci_Latn |
|
||||
| Odia | or | ory_Orya |
|
||||
| Pangasinan | pag | pag_Latn |
|
||||
| Eastern Panjabi | pa | pan_Guru |
|
||||
| Papiamento | pap | pap_Latn |
|
||||
| Southern Pashto | pbt | pbt_Arab |
|
||||
| Western Persian | fa | pes_Arab |
|
||||
| Plateau Malagasy | mg | plt_Latn |
|
||||
| Polish | pl | pol_Latn |
|
||||
| Portuguese | pt-PT | por_Latn |
|
||||
| Dari | fa-AF | prs_Arab |
|
||||
| Ayacucho Quechua | qu | quy_Latn |
|
||||
| Romanian | ro | ron_Latn |
|
||||
| Rundi | rn | run_Latn |
|
||||
| Russian | ru | rus_Cyrl |
|
||||
| Sango | sg | sag_Latn |
|
||||
| Sanskrit | sa | san_Deva |
|
||||
| Santali | sat | sat_Olck |
|
||||
| Sicilian | scn | scn_Latn |
|
||||
| Shan | shn | shn_Mymr |
|
||||
| Sinhala | si-LK | sin_Sinh |
|
||||
| Slovak | sk | slk_Latn |
|
||||
| Slovenian | sl | slv_Latn |
|
||||
| Samoan | sm | smo_Latn |
|
||||
| Shona | sn | sna_Latn |
|
||||
| Sindhi | sd | snd_Arab |
|
||||
| Somali | so | som_Latn |
|
||||
| Southern Sotho | st | sot_Latn |
|
||||
| Spanish | es-ES | spa_Latn |
|
||||
| Sardinian | sc | srd_Latn |
|
||||
| Serbian | sr | srp_Cyrl |
|
||||
| Swati | ss | ssw_Latn |
|
||||
| Sundanese | su | sun_Latn |
|
||||
| Swedish | sv-SE | swe_Latn |
|
||||
| Swahili | sw | swh_Latn |
|
||||
| Silesian | szl | szl_Latn |
|
||||
| Tamil | ta | tam_Taml |
|
||||
| Tamasheq (Latin script) | taq_Latn | taq_Latn |
|
||||
| Tamasheq (Tifinagh script) | taq_Tfng | taq_Tfng |
|
||||
| Tatar | tt-RU | tat_Cyrl |
|
||||
| Telugu | te | tel_Telu |
|
||||
| Tajik | tg | tgk_Cyrl |
|
||||
| Tagalog | tl | tgl_Latn |
|
||||
| Thai | th | tha_Thai |
|
||||
| Tigrinya | ti | tir_Ethi |
|
||||
| Tok Pisin | tpi | tpi_Latn |
|
||||
| Tswana | tn | tsn_Latn |
|
||||
| Tsonga | ts | tso_Latn |
|
||||
| Turkmen | tk | tuk_Latn |
|
||||
| Tumbuka | tum | tum_Latn |
|
||||
| Turkish | tr | tur_Latn |
|
||||
| Twi | tw | twi_Latn |
|
||||
| Central Atlas Tamazight | tzm | tzm_Tfng |
|
||||
| Uyghur | ug | uig_Arab |
|
||||
| Ukrainian | uk | ukr_Cyrl |
|
||||
| Umbundu | umb | umb_Latn |
|
||||
| Urdu | ur | urd_Arab |
|
||||
| Northern Uzbek | uz | uzn_Latn |
|
||||
| Venetian | vec | vec_Latn |
|
||||
| Vietnamese | vi | vie_Latn |
|
||||
| Waray | war | war_Latn |
|
||||
| Wolof | wo | wol_Latn |
|
||||
| Xhosa | xh | xho_Latn |
|
||||
| Eastern Yiddish | yi | ydd_Hebr |
|
||||
| Yoruba | yo | yor_Latn |
|
||||
| Yue Chinese | yue | yue_Hant |
|
||||
| Chinese (Simplified) | zh-CN | zho_Hans |
|
||||
| Chinese (Traditional) | zh-TW | zho_Hant |
|
||||
| Standard Malay | ms | zsm_Latn |
|
||||
| Zulu | zu | zul_Latn |
|
||||
|
||||
## Special Features
|
||||
|
||||
### Multiple Script Support
|
||||
Several languages are available in multiple scripts (e.g., Arabic and Latin):
|
||||
- **Acehnese**: Arabic (`ace_Arab`) and Latin (`ace_Latn`)
|
||||
- **Banjar**: Arabic (`bjn_Arab`) and Latin (`bjn_Latn`)
|
||||
- **Kashmiri**: Arabic (`kas_Arab`) and Devanagari (`kas_Deva`)
|
||||
- **Minangkabau**: Arabic (`min_Arab`) and Latin (`min_Latn`)
|
||||
- **Tamasheq**: Latin (`taq_Latn`) and Tifinagh (`taq_Tfng`)
|
||||
- **Central Kanuri**: Arabic (`knc_Arab`) and Latin (`knc_Latn`)
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "whisperlivekit"
|
||||
version = "0.2.12"
|
||||
version = "0.2.14"
|
||||
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
@@ -30,28 +30,41 @@ dependencies = [
|
||||
"fastapi",
|
||||
"librosa",
|
||||
"soundfile",
|
||||
"faster-whisper",
|
||||
"uvicorn",
|
||||
"websockets",
|
||||
"torchaudio>=2.0.0",
|
||||
"torch>=2.0.0",
|
||||
"huggingface-hub>=0.25.0",
|
||||
"tqdm",
|
||||
"tiktoken",
|
||||
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
sentence = ["mosestokenizer", "wtpsplit"]
|
||||
translation = ["nllw"]
|
||||
sentence_tokenizer = ["mosestokenizer", "wtpsplit"]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
|
||||
|
||||
[project.scripts]
|
||||
whisperlivekit-server = "whisperlivekit.basic_server:main"
|
||||
wlk = "whisperlivekit.basic_server:main"
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["whisperlivekit", "whisperlivekit.diarization", "whisperlivekit.simul_whisper", "whisperlivekit.simul_whisper.whisper", "whisperlivekit.simul_whisper.whisper.assets", "whisperlivekit.simul_whisper.whisper.normalizers", "whisperlivekit.web", "whisperlivekit.whisper_streaming_custom", "whisperlivekit.translation"]
|
||||
packages = [
|
||||
"whisperlivekit",
|
||||
"whisperlivekit.diarization",
|
||||
"whisperlivekit.simul_whisper",
|
||||
"whisperlivekit.whisper",
|
||||
"whisperlivekit.whisper.assets",
|
||||
"whisperlivekit.whisper.normalizers",
|
||||
"whisperlivekit.web",
|
||||
"whisperlivekit.local_agreement",
|
||||
"whisperlivekit.vad_models"
|
||||
]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
|
||||
"whisperlivekit.simul_whisper.whisper.assets" = ["*.tiktoken", "*.npz"]
|
||||
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
|
||||
"whisperlivekit.vad_models" = ["*.jit", "*.onnx"]
|
||||
|
||||
BIN
scripts/alignment_heads.png
Normal file
BIN
scripts/alignment_heads.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 276 KiB |
153
scripts/convert_hf_whisper.py
Normal file
153
scripts/convert_hf_whisper.py
Normal file
@@ -0,0 +1,153 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Convert a Hugging Face style Whisper checkpoint into a WhisperLiveKit .pt file.
|
||||
|
||||
Optionally shrink the supported audio chunk length (in seconds) by trimming the
|
||||
encoder positional embeddings and updating the stored model dimensions.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from whisperlivekit.whisper.audio import HOP_LENGTH, SAMPLE_RATE
|
||||
from whisperlivekit.whisper.model import ModelDimensions
|
||||
from whisperlivekit.whisper.utils import exact_div
|
||||
from whisperlivekit.whisper import _convert_hf_state_dict
|
||||
|
||||
|
||||
def _load_state_dict(repo_path: Path) -> Dict[str, torch.Tensor]:
|
||||
safetensor_path = repo_path / "model.safetensors"
|
||||
bin_path = repo_path / "pytorch_model.bin"
|
||||
|
||||
if safetensor_path.is_file():
|
||||
try:
|
||||
from safetensors.torch import load_file # type: ignore
|
||||
except Exception as exc: # pragma: no cover - import guard
|
||||
raise RuntimeError(
|
||||
"Install safetensors to load model.safetensors "
|
||||
"(pip install safetensors)"
|
||||
) from exc
|
||||
return load_file(str(safetensor_path))
|
||||
|
||||
if bin_path.is_file():
|
||||
return torch.load(bin_path, map_location="cpu")
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"Could not find model.safetensors or pytorch_model.bin under {repo_path}"
|
||||
)
|
||||
|
||||
|
||||
def _load_config(repo_path: Path) -> Dict:
|
||||
config_path = repo_path / "config.json"
|
||||
if not config_path.is_file():
|
||||
raise FileNotFoundError(
|
||||
f"Hugging Face checkpoint at {repo_path} is missing config.json"
|
||||
)
|
||||
with open(config_path, "r", encoding="utf-8") as fp:
|
||||
return json.load(fp)
|
||||
|
||||
|
||||
def _derive_audio_ctx(chunk_length: float) -> Tuple[int, int]:
|
||||
n_samples = int(round(chunk_length * SAMPLE_RATE))
|
||||
expected_samples = chunk_length * SAMPLE_RATE
|
||||
if abs(n_samples - expected_samples) > 1e-6:
|
||||
raise ValueError(
|
||||
"chunk_length must align with sample rate so that "
|
||||
"chunk_length * SAMPLE_RATE is an integer"
|
||||
)
|
||||
n_frames = exact_div(n_samples, HOP_LENGTH)
|
||||
n_audio_ctx = exact_div(n_frames, 2)
|
||||
return n_frames, n_audio_ctx
|
||||
|
||||
|
||||
def _build_dims(config: Dict, chunk_length: float) -> Dict:
|
||||
base_dims = ModelDimensions(
|
||||
n_mels=config["num_mel_bins"],
|
||||
n_audio_ctx=config["max_source_positions"],
|
||||
n_audio_state=config["d_model"],
|
||||
n_audio_head=config["encoder_attention_heads"],
|
||||
n_audio_layer=config.get("encoder_layers") or config["num_hidden_layers"],
|
||||
n_vocab=config["vocab_size"],
|
||||
n_text_ctx=config["max_target_positions"],
|
||||
n_text_state=config["d_model"],
|
||||
n_text_head=config["decoder_attention_heads"],
|
||||
n_text_layer=config["decoder_layers"],
|
||||
).__dict__.copy()
|
||||
|
||||
_, n_audio_ctx = _derive_audio_ctx(chunk_length)
|
||||
base_dims["n_audio_ctx"] = n_audio_ctx
|
||||
base_dims["chunk_length"] = chunk_length
|
||||
return base_dims
|
||||
|
||||
|
||||
def _trim_positional_embedding(
|
||||
state_dict: Dict[str, torch.Tensor], target_ctx: int
|
||||
) -> None:
|
||||
key = "encoder.positional_embedding"
|
||||
if key not in state_dict:
|
||||
raise KeyError(f"{key} missing from converted state dict")
|
||||
|
||||
tensor = state_dict[key]
|
||||
if tensor.shape[0] < target_ctx:
|
||||
raise ValueError(
|
||||
f"Cannot increase encoder ctx from {tensor.shape[0]} to {target_ctx}"
|
||||
)
|
||||
if tensor.shape[0] == target_ctx:
|
||||
return
|
||||
state_dict[key] = tensor[:target_ctx].contiguous()
|
||||
|
||||
|
||||
def convert_checkpoint(hf_path: Path, output_path: Path, chunk_length: float) -> None:
|
||||
state_dict = _load_state_dict(hf_path)
|
||||
converted = _convert_hf_state_dict(state_dict)
|
||||
|
||||
config = _load_config(hf_path)
|
||||
dims = _build_dims(config, chunk_length)
|
||||
|
||||
_trim_positional_embedding(converted, dims["n_audio_ctx"])
|
||||
|
||||
package = {"dims": dims, "model_state_dict": converted}
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(package, output_path)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert Hugging Face Whisper checkpoint to WhisperLiveKit format."
|
||||
)
|
||||
parser.add_argument(
|
||||
"hf_path",
|
||||
type=str,
|
||||
help="Path to the cloned Hugging Face repository (e.g. whisper-tiny.en)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="converted-whisper.pt",
|
||||
help="Destination path for the .pt file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk-length",
|
||||
type=float,
|
||||
default=30.0,
|
||||
help="Audio chunk length in seconds to support (default: 30)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
hf_path = Path(os.path.expanduser(args.hf_path)).resolve()
|
||||
output_path = Path(os.path.expanduser(args.output)).resolve()
|
||||
|
||||
convert_checkpoint(hf_path, output_path, args.chunk_length)
|
||||
print(f"Saved converted checkpoint to {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
292
scripts/determine_alignment_heads.py
Normal file
292
scripts/determine_alignment_heads.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""Determine alignment heads for a variants, such as distilled model"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import gzip
|
||||
import io
|
||||
import pathlib
|
||||
import sys
|
||||
import math
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Audio as DatasetAudio, load_dataset
|
||||
import soundfile as sf
|
||||
import matplotlib.pyplot as plt
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
|
||||
WHISPER_ROOT = REPO_ROOT / "whisper"
|
||||
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
sys.path.insert(0, str(WHISPER_ROOT))
|
||||
|
||||
from whisper import load_model
|
||||
from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from whisper.tokenizer import get_tokenizer
|
||||
|
||||
AudioInput = Union[str, pathlib.Path, np.ndarray, torch.Tensor]
|
||||
|
||||
|
||||
def load_dataset_clips(name, config, split, limit):
|
||||
ds = load_dataset(name, config, split=split)
|
||||
ds = ds.cast_column("audio", DatasetAudio(decode=False))
|
||||
clips = []
|
||||
for idx, row in enumerate(ds):
|
||||
if limit is not None and idx >= limit:
|
||||
break
|
||||
audio_field = row["audio"]
|
||||
transcript = row["text"]
|
||||
|
||||
waveform_np, _ = sf.read(io.BytesIO(audio_field["bytes"]), dtype="float32")
|
||||
if waveform_np.ndim > 1:
|
||||
waveform_np = waveform_np.mean(axis=1)
|
||||
waveform = waveform_np
|
||||
transcript = str(transcript)
|
||||
|
||||
clips.append((waveform, transcript))
|
||||
return clips
|
||||
|
||||
|
||||
def load_clips(args):
|
||||
return load_dataset_clips(
|
||||
args.dataset,
|
||||
args.dataset_config,
|
||||
args.dataset_split,
|
||||
args.dataset_num_samples,
|
||||
)
|
||||
|
||||
|
||||
def _waveform_from_source(source: AudioInput) -> torch.Tensor:
|
||||
waveform = torch.from_numpy(source.astype(np.float32, copy=False))
|
||||
return waveform
|
||||
|
||||
|
||||
def _parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="pytorch_model.bin",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
help="Torch device to run on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="librispeech_asr"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-config",
|
||||
type=str,
|
||||
default="clean"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-split",
|
||||
type=str,
|
||||
default="validation[:1%]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-num-samples",
|
||||
type=int,
|
||||
default=16,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--threshold",
|
||||
type=float,
|
||||
default=1.5,
|
||||
help="Z score threshold for a head to be selected",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--votes",
|
||||
type=float,
|
||||
default=0.75,
|
||||
help="percentage of clips that must vote for a head",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="alignment_heads.b85",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--visualize-top-k",
|
||||
type=int,
|
||||
default=32,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def collect_heads(
|
||||
model,
|
||||
tokenizer,
|
||||
clips: Sequence[Tuple[AudioInput, str]],
|
||||
threshold: float,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
device = model.device
|
||||
votes = torch.zeros(model.dims.n_text_layer, model.dims.n_text_head, device=device)
|
||||
strengths = torch.zeros_like(votes)
|
||||
|
||||
for audio_source, transcript in clips:
|
||||
waveform = pad_or_trim(_waveform_from_source(audio_source))
|
||||
mel = log_mel_spectrogram(waveform, device=device)
|
||||
|
||||
tokens = torch.tensor(
|
||||
[
|
||||
*tokenizer.sot_sequence,
|
||||
tokenizer.no_timestamps,
|
||||
*tokenizer.encode(transcript),
|
||||
tokenizer.eot,
|
||||
],
|
||||
device=device,
|
||||
)
|
||||
|
||||
qks = [None] * model.dims.n_text_layer
|
||||
hooks = [
|
||||
block.cross_attn.register_forward_hook(
|
||||
lambda _, __, outputs, index=i: qks.__setitem__(index, outputs[-1][0])
|
||||
)
|
||||
for i, block in enumerate(model.decoder.blocks)
|
||||
]
|
||||
|
||||
with torch.no_grad():
|
||||
model(mel.unsqueeze(0), tokens.unsqueeze(0))
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
for layer_idx, tensor in enumerate(qks):
|
||||
if tensor is None:
|
||||
continue
|
||||
tensor = tensor[:, :, : mel.shape[-1] // 2]
|
||||
tensor = tensor.softmax(dim=-1)
|
||||
peak = tensor.max(dim=-1).values # [heads, tokens]
|
||||
strengths[layer_idx] += peak.mean(dim=-1)
|
||||
zscore = (peak - peak.mean(dim=-1, keepdim=True)) / (
|
||||
peak.std(dim=-1, keepdim=True, unbiased=False) + 1e-6
|
||||
)
|
||||
mask = (zscore > 3).any(dim=-1)
|
||||
votes[layer_idx] += mask.float()
|
||||
|
||||
votes /= len(clips)
|
||||
strengths /= len(clips)
|
||||
return votes, strengths
|
||||
|
||||
|
||||
def _select_heads_for_visualization(selection, strengths, top_k):
|
||||
selected = torch.nonzero(selection, as_tuple=False)
|
||||
if selected.numel() == 0:
|
||||
return []
|
||||
|
||||
entries = [
|
||||
(int(layer.item()), int(head.item()), float(strengths[layer, head].item()))
|
||||
for layer, head in selected
|
||||
]
|
||||
entries.sort(key=lambda item: item[2], reverse=True)
|
||||
return entries[:top_k]
|
||||
|
||||
def _extract_heatmaps(
|
||||
model,
|
||||
tokenizer,
|
||||
clip: Tuple[AudioInput, str],
|
||||
heads: Sequence[Tuple[int, int, float]],
|
||||
) -> dict:
|
||||
if not heads:
|
||||
return {}
|
||||
|
||||
target_map = {}
|
||||
for layer, head, _ in heads:
|
||||
target_map.setdefault(layer, set()).add(head)
|
||||
|
||||
waveform = pad_or_trim(_waveform_from_source(clip[0]))
|
||||
mel = log_mel_spectrogram(waveform, device=model.device)
|
||||
transcript = clip[1]
|
||||
tokens = torch.tensor(
|
||||
[
|
||||
*tokenizer.sot_sequence,
|
||||
tokenizer.no_timestamps,
|
||||
*tokenizer.encode(transcript),
|
||||
tokenizer.eot,
|
||||
],
|
||||
device=model.device,
|
||||
)
|
||||
|
||||
QKs = [None] * model.dims.n_text_layer
|
||||
hooks = [
|
||||
block.cross_attn.register_forward_hook(
|
||||
lambda _, __, outputs, index=i: QKs.__setitem__(index, outputs[-1][0])
|
||||
)
|
||||
for i, block in enumerate(model.decoder.blocks)
|
||||
]
|
||||
|
||||
with torch.no_grad():
|
||||
model(mel.unsqueeze(0), tokens.unsqueeze(0))
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
heatmaps = {}
|
||||
for layer_idx, tensor in enumerate(QKs):
|
||||
if tensor is None or layer_idx not in target_map:
|
||||
continue
|
||||
tensor = tensor[:, :, : mel.shape[-1] // 2]
|
||||
tensor = tensor.softmax(dim=-1).cpu()
|
||||
for head_idx in target_map[layer_idx]:
|
||||
heatmaps[(layer_idx, head_idx)] = tensor[head_idx]
|
||||
|
||||
return heatmaps
|
||||
|
||||
|
||||
def _plot_heatmaps(
|
||||
heads, heatmaps, output_path):
|
||||
cols = min(3, len(heads))
|
||||
rows = math.ceil(len(heads) / cols)
|
||||
fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3.2 * rows), squeeze=False)
|
||||
|
||||
for idx, (layer, head, score) in enumerate(heads):
|
||||
ax = axes[idx // cols][idx % cols]
|
||||
mat = heatmaps.get((layer, head))
|
||||
if mat is None:
|
||||
ax.axis("off")
|
||||
continue
|
||||
im = ax.imshow(mat.to(torch.float32).numpy(), aspect="auto", origin="lower")
|
||||
ax.set_title(f"L{layer} H{head} · score {score:.2f}")
|
||||
ax.set_xlabel("time")
|
||||
ax.set_ylabel("tokens")
|
||||
|
||||
for j in range(len(heads), rows * cols):
|
||||
axes[j // cols][j % cols].axis("off")
|
||||
|
||||
fig.tight_layout()
|
||||
fig.savefig(output_path, dpi=200)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def _dump_mask(mask: torch.Tensor, output_path: str):
|
||||
payload = mask.numpy().astype(np.bool_)
|
||||
blob = base64.b85encode(gzip.compress(payload.tobytes()))
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(blob)
|
||||
|
||||
|
||||
def main():
|
||||
args = _parse_args()
|
||||
model = load_model(args.model, device=args.device)
|
||||
model.eval()
|
||||
tokenizer = get_tokenizer(multilingual=model.is_multilingual)
|
||||
clips = load_clips(args)
|
||||
|
||||
votes, strengths = collect_heads(model, tokenizer, clips, args.threshold)
|
||||
# selection = votes > 0.5
|
||||
selection = strengths > 0.05
|
||||
_dump_mask(selection.cpu(), args.output)
|
||||
|
||||
viz_heads = _select_heads_for_visualization(selection, strengths, args.visualize_top_k)
|
||||
heatmaps = _extract_heatmaps(model, tokenizer, clips[0], viz_heads)
|
||||
_plot_heatmaps(viz_heads, heatmaps, "alignment_heads.png")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Copy core files from web directory to Chrome extension directory."""
|
||||
|
||||
import shutil
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
def sync_extension_files():
|
||||
"""Copy core files from web directory to Chrome extension directory."""
|
||||
|
||||
web_dir = Path("whisperlivekit/web")
|
||||
extension_dir = Path("chrome-extension")
|
||||
@@ -67,20 +67,17 @@ class AudioProcessor:
|
||||
self.is_stopping = False
|
||||
self.silence = False
|
||||
self.silence_duration = 0.0
|
||||
self.tokens = []
|
||||
self.last_validated_token = 0
|
||||
self.translated_segments = []
|
||||
self.buffer_transcription = Transcript()
|
||||
self.end_buffer = 0
|
||||
self.end_attributed_speaker = 0
|
||||
self.state = State()
|
||||
self.lock = asyncio.Lock()
|
||||
self.beg_loop = 0.0 #to deal with a potential little lag at the websocket initialization, this is now set in process_audio
|
||||
self.sep = " " # Default separator
|
||||
self.last_response_content = FrontData()
|
||||
self.last_detected_speaker = None
|
||||
self.speaker_languages = {}
|
||||
self.diarization_before_transcription = False
|
||||
|
||||
self.segments = []
|
||||
|
||||
|
||||
if self.diarization_before_transcription:
|
||||
self.cumulative_pcm = []
|
||||
self.last_start = 0.0
|
||||
@@ -138,8 +135,8 @@ class AudioProcessor:
|
||||
async def add_dummy_token(self):
|
||||
"""Placeholder token when no transcription is available."""
|
||||
async with self.lock:
|
||||
current_time = time() - self.beg_loop
|
||||
self.tokens.append(ASRToken(
|
||||
current_time = time() - self.state.beg_loop
|
||||
self.state.tokens.append(ASRToken(
|
||||
start=current_time, end=current_time + 1,
|
||||
text=".", speaker=-1, is_dummy=True
|
||||
))
|
||||
@@ -149,35 +146,19 @@ class AudioProcessor:
|
||||
async with self.lock:
|
||||
current_time = time()
|
||||
|
||||
# Calculate remaining times
|
||||
remaining_transcription = 0
|
||||
if self.end_buffer > 0:
|
||||
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 1))
|
||||
if self.state.end_buffer > 0:
|
||||
remaining_transcription = max(0, round(current_time - self.state.beg_loop - self.state.end_buffer, 1))
|
||||
|
||||
remaining_diarization = 0
|
||||
if self.tokens:
|
||||
latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
|
||||
remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 1))
|
||||
if self.state.tokens:
|
||||
latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
|
||||
remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1))
|
||||
|
||||
return State(
|
||||
tokens=self.tokens.copy(),
|
||||
last_validated_token=self.last_validated_token,
|
||||
translated_segments=self.translated_segments.copy(),
|
||||
buffer_transcription=self.buffer_transcription,
|
||||
end_buffer=self.end_buffer,
|
||||
end_attributed_speaker=self.end_attributed_speaker,
|
||||
remaining_time_transcription=remaining_transcription,
|
||||
remaining_time_diarization=remaining_diarization
|
||||
)
|
||||
self.state.remaining_time_transcription = remaining_transcription
|
||||
self.state.remaining_time_diarization = remaining_diarization
|
||||
|
||||
async def reset(self):
|
||||
"""Reset all state variables to initial values."""
|
||||
async with self.lock:
|
||||
self.tokens = []
|
||||
self.translated_segments = []
|
||||
self.buffer_transcription = Transcript()
|
||||
self.end_buffer = self.end_attributed_speaker = 0
|
||||
self.beg_loop = time()
|
||||
return self.state
|
||||
|
||||
async def ffmpeg_stdout_reader(self):
|
||||
"""Read audio data from FFmpeg stdout and process it into the PCM pipeline."""
|
||||
@@ -242,15 +223,15 @@ class AudioProcessor:
|
||||
break
|
||||
|
||||
asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
|
||||
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
|
||||
transcription_lag_s = max(0.0, time() - self.state.beg_loop - self.state.end_buffer)
|
||||
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
|
||||
if type(item) is Silence:
|
||||
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
|
||||
if self.tokens:
|
||||
asr_processing_logs += f" | last_end = {self.tokens[-1].end} |"
|
||||
if self.state.tokens:
|
||||
asr_processing_logs += f" | last_end = {self.state.tokens[-1].end} |"
|
||||
logger.info(asr_processing_logs)
|
||||
cumulative_pcm_duration_stream_time += item.duration
|
||||
self.transcription.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
|
||||
self.transcription.insert_silence(item.duration, self.state.tokens[-1].end if self.state.tokens else 0)
|
||||
continue
|
||||
elif isinstance(item, ChangeSpeaker):
|
||||
self.transcription.new_speaker(item)
|
||||
@@ -274,7 +255,7 @@ class AudioProcessor:
|
||||
if buffer_text.startswith(validated_text):
|
||||
_buffer_transcript.text = buffer_text[len(validated_text):].lstrip()
|
||||
|
||||
candidate_end_times = [self.end_buffer]
|
||||
candidate_end_times = [self.state.end_buffer]
|
||||
|
||||
if new_tokens:
|
||||
candidate_end_times.append(new_tokens[-1].end)
|
||||
@@ -285,9 +266,9 @@ class AudioProcessor:
|
||||
candidate_end_times.append(current_audio_processed_upto)
|
||||
|
||||
async with self.lock:
|
||||
self.tokens.extend(new_tokens)
|
||||
self.buffer_transcription = _buffer_transcript
|
||||
self.end_buffer = max(candidate_end_times)
|
||||
self.state.tokens.extend(new_tokens)
|
||||
self.state.buffer_transcription = _buffer_transcript
|
||||
self.state.end_buffer = max(candidate_end_times)
|
||||
|
||||
if self.translation_queue:
|
||||
for token in new_tokens:
|
||||
@@ -360,12 +341,12 @@ class AudioProcessor:
|
||||
self.last_end = last_segment.end
|
||||
elif not self.diarization_before_transcription:
|
||||
async with self.lock:
|
||||
self.tokens = diarization_obj.assign_speakers_to_tokens(
|
||||
self.tokens,
|
||||
self.state.tokens = diarization_obj.assign_speakers_to_tokens(
|
||||
self.state.tokens,
|
||||
use_punctuation_split=self.args.punctuation_split
|
||||
)
|
||||
if len(self.tokens) > 0:
|
||||
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
|
||||
if len(self.state.tokens) > 0:
|
||||
self.state.end_attributed_speaker = max(self.state.tokens[-1].end, self.state.end_attributed_speaker)
|
||||
self.diarization_queue.task_done()
|
||||
|
||||
except Exception as e:
|
||||
@@ -406,7 +387,10 @@ class AudioProcessor:
|
||||
tokens_to_process.append(additional_token)
|
||||
if tokens_to_process:
|
||||
self.translation.insert_tokens(tokens_to_process)
|
||||
self.translated_segments = await asyncio.to_thread(self.translation.process)
|
||||
translation_validated_segments, buffer_translation = await asyncio.to_thread(self.translation.process)
|
||||
async with self.lock:
|
||||
self.state.translation_validated_segments = translation_validated_segments
|
||||
self.state.buffer_translation = buffer_translation
|
||||
self.translation_queue.task_done()
|
||||
for _ in additional_tokens:
|
||||
self.translation_queue.task_done()
|
||||
@@ -437,11 +421,9 @@ class AudioProcessor:
|
||||
|
||||
state = await self.get_current_state()
|
||||
|
||||
|
||||
lines, undiarized_text = format_output(
|
||||
state,
|
||||
self.silence,
|
||||
current_time = time() - self.beg_loop,
|
||||
args = self.args,
|
||||
sep=self.sep
|
||||
)
|
||||
@@ -455,7 +437,13 @@ class AudioProcessor:
|
||||
buffer_diarization = self.sep.join(undiarized_text)
|
||||
|
||||
async with self.lock:
|
||||
self.end_attributed_speaker = state.end_attributed_speaker
|
||||
self.state.end_attributed_speaker = state.end_attributed_speaker
|
||||
|
||||
buffer_translation_text = ''
|
||||
if state.buffer_translation:
|
||||
raw_buffer_translation = getattr(state.buffer_translation, 'text', state.buffer_translation)
|
||||
if raw_buffer_translation:
|
||||
buffer_translation_text = raw_buffer_translation.strip()
|
||||
|
||||
response_status = "active_transcription"
|
||||
if not state.tokens and not buffer_transcription and not buffer_diarization:
|
||||
@@ -473,6 +461,7 @@ class AudioProcessor:
|
||||
lines=lines,
|
||||
buffer_transcription=buffer_transcription.text.strip(),
|
||||
buffer_diarization=buffer_diarization,
|
||||
buffer_translation=buffer_translation_text,
|
||||
remaining_time_transcription=state.remaining_time_transcription,
|
||||
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
|
||||
)
|
||||
@@ -482,23 +471,14 @@ class AudioProcessor:
|
||||
yield response
|
||||
self.last_response_content = response
|
||||
|
||||
# Check for termination condition
|
||||
if self.is_stopping:
|
||||
all_processors_done = True
|
||||
if self.args.transcription and self.transcription_task and not self.transcription_task.done():
|
||||
all_processors_done = False
|
||||
if self.args.diarization and self.diarization_task and not self.diarization_task.done():
|
||||
all_processors_done = False
|
||||
|
||||
if all_processors_done:
|
||||
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
|
||||
return
|
||||
if self.is_stopping and self._processing_tasks_done():
|
||||
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
|
||||
return
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in results_formatter: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
async def create_tasks(self):
|
||||
@@ -544,11 +524,16 @@ class AudioProcessor:
|
||||
|
||||
async def watchdog(self, tasks_to_monitor):
|
||||
"""Monitors the health of critical processing tasks."""
|
||||
tasks_remaining = [task for task in tasks_to_monitor if task]
|
||||
while True:
|
||||
try:
|
||||
if not tasks_remaining:
|
||||
logger.info("Watchdog task finishing: all monitored tasks completed.")
|
||||
return
|
||||
|
||||
await asyncio.sleep(10)
|
||||
|
||||
for i, task in enumerate(tasks_to_monitor):
|
||||
for i, task in enumerate(list(tasks_remaining)):
|
||||
if task.done():
|
||||
exc = task.exception()
|
||||
task_name = task.get_name() if hasattr(task, 'get_name') else f"Monitored Task {i}"
|
||||
@@ -556,6 +541,7 @@ class AudioProcessor:
|
||||
logger.error(f"{task_name} unexpectedly completed with exception: {exc}")
|
||||
else:
|
||||
logger.info(f"{task_name} completed normally.")
|
||||
tasks_remaining.remove(task)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Watchdog task cancelled.")
|
||||
@@ -586,12 +572,22 @@ class AudioProcessor:
|
||||
self.diarization.close()
|
||||
logger.info("AudioProcessor cleanup complete.")
|
||||
|
||||
def _processing_tasks_done(self):
|
||||
"""Return True when all active processing tasks have completed."""
|
||||
tasks_to_check = [
|
||||
self.transcription_task,
|
||||
self.diarization_task,
|
||||
self.translation_task,
|
||||
self.ffmpeg_reader_task,
|
||||
]
|
||||
return all(task.done() for task in tasks_to_check if task)
|
||||
|
||||
|
||||
async def process_audio(self, message):
|
||||
"""Process incoming audio data."""
|
||||
|
||||
if not self.beg_loop:
|
||||
self.beg_loop = time()
|
||||
if not self.state.beg_loop:
|
||||
self.state.beg_loop = time()
|
||||
|
||||
if not message:
|
||||
logger.info("Empty audio message received, initiating stop sequence.")
|
||||
|
||||
41
whisperlivekit/backend_support.py
Normal file
41
whisperlivekit/backend_support.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import importlib.util
|
||||
import logging
|
||||
import platform
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def module_available(module_name):
|
||||
"""Return True if the given module can be imported."""
|
||||
return importlib.util.find_spec(module_name) is not None
|
||||
|
||||
|
||||
def mlx_backend_available(warn_on_missing = False):
|
||||
is_macos = platform.system() == "Darwin"
|
||||
is_arm = platform.machine() == "arm64"
|
||||
available = (
|
||||
is_macos
|
||||
and is_arm
|
||||
and module_available("mlx_whisper")
|
||||
)
|
||||
if not available and warn_on_missing and is_macos and is_arm:
|
||||
logger.warning(
|
||||
"=" * 50
|
||||
+ "\nMLX Whisper not found but you are on Apple Silicon. "
|
||||
"Consider installing mlx-whisper for better performance: "
|
||||
"`pip install mlx-whisper`\n"
|
||||
+ "=" * 50
|
||||
)
|
||||
return available
|
||||
|
||||
|
||||
def faster_backend_available(warn_on_missing = False):
|
||||
available = module_available("faster_whisper")
|
||||
if not available and warn_on_missing and platform.system() != "Darwin":
|
||||
logger.warning(
|
||||
"=" * 50
|
||||
+ "\nFaster-Whisper not found. Consider installing faster-whisper "
|
||||
"for better performance: `pip install faster-whisper`\n"
|
||||
+ "=" * 50
|
||||
)
|
||||
return available
|
||||
@@ -1,11 +1,9 @@
|
||||
try:
|
||||
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory
|
||||
from whisperlivekit.whisper_streaming_custom.online_asr import OnlineASRProcessor
|
||||
except ImportError:
|
||||
from .whisper_streaming_custom.whisper_online import backend_factory
|
||||
from .whisper_streaming_custom.online_asr import OnlineASRProcessor
|
||||
from whisperlivekit.local_agreement.whisper_online import backend_factory
|
||||
from whisperlivekit.simul_whisper import SimulStreamingASR
|
||||
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
|
||||
from argparse import Namespace
|
||||
import sys
|
||||
import logging
|
||||
|
||||
def update_with_kwargs(_dict, kwargs):
|
||||
_dict.update({
|
||||
@@ -13,6 +11,9 @@ def update_with_kwargs(_dict, kwargs):
|
||||
})
|
||||
return _dict
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TranscriptionEngine:
|
||||
_instance = None
|
||||
_initialized = False
|
||||
@@ -33,6 +34,7 @@ class TranscriptionEngine:
|
||||
"punctuation_split": False,
|
||||
"target_language": "",
|
||||
"vac": True,
|
||||
"vac_onnx": False,
|
||||
"vac_chunk_size": 0.04,
|
||||
"log_level": "DEBUG",
|
||||
"ssl_certfile": None,
|
||||
@@ -43,18 +45,20 @@ class TranscriptionEngine:
|
||||
"pcm_input": False,
|
||||
"disable_punctuation_split" : False,
|
||||
"diarization_backend": "sortformer",
|
||||
"backend_policy": "simulstreaming",
|
||||
"backend": "auto",
|
||||
}
|
||||
global_params = update_with_kwargs(global_params, kwargs)
|
||||
|
||||
transcription_common_params = {
|
||||
"backend": "simulstreaming",
|
||||
"warmup_file": None,
|
||||
"min_chunk_size": 0.5,
|
||||
"model_size": "tiny",
|
||||
"model_cache_dir": None,
|
||||
"model_dir": None,
|
||||
"model_path": None,
|
||||
"lan": "auto",
|
||||
"task": "transcribe",
|
||||
"direct_english_translation": False,
|
||||
}
|
||||
transcription_common_params = update_with_kwargs(transcription_common_params, kwargs)
|
||||
|
||||
@@ -75,13 +79,14 @@ class TranscriptionEngine:
|
||||
self.vac_model = None
|
||||
|
||||
if self.args.vac:
|
||||
import torch
|
||||
self.vac_model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
|
||||
from whisperlivekit.silero_vad_iterator import load_silero_vad
|
||||
# Use ONNX if specified, otherwise use JIT (default)
|
||||
use_onnx = kwargs.get('vac_onnx', False)
|
||||
self.vac_model = load_silero_vad(onnx=use_onnx)
|
||||
|
||||
backend_policy = self.args.backend_policy
|
||||
if self.args.transcription:
|
||||
if self.args.backend == "simulstreaming":
|
||||
from whisperlivekit.simul_whisper import SimulStreamingASR
|
||||
|
||||
if backend_policy == "simulstreaming":
|
||||
simulstreaming_params = {
|
||||
"disable_fast_encoder": False,
|
||||
"custom_alignment_heads": None,
|
||||
@@ -95,14 +100,19 @@ class TranscriptionEngine:
|
||||
"init_prompt": None,
|
||||
"static_init_prompt": None,
|
||||
"max_context_tokens": None,
|
||||
"model_path": './base.pt',
|
||||
"preload_model_count": 1,
|
||||
}
|
||||
simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
|
||||
|
||||
self.tokenizer = None
|
||||
self.asr = SimulStreamingASR(
|
||||
**transcription_common_params, **simulstreaming_params
|
||||
**transcription_common_params,
|
||||
**simulstreaming_params,
|
||||
backend=self.args.backend,
|
||||
)
|
||||
logger.info(
|
||||
"Using SimulStreaming policy with %s backend",
|
||||
getattr(self.asr, "encoder_backend", "whisper"),
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -114,7 +124,13 @@ class TranscriptionEngine:
|
||||
whisperstreaming_params = update_with_kwargs(whisperstreaming_params, kwargs)
|
||||
|
||||
self.asr = backend_factory(
|
||||
**transcription_common_params, **whisperstreaming_params
|
||||
backend=self.args.backend,
|
||||
**transcription_common_params,
|
||||
**whisperstreaming_params,
|
||||
)
|
||||
logger.info(
|
||||
"Using LocalAgreement policy with %s backend",
|
||||
getattr(self.asr, "backend_choice", self.asr.__class__.__name__),
|
||||
)
|
||||
|
||||
if self.args.diarization:
|
||||
@@ -135,12 +151,15 @@ class TranscriptionEngine:
|
||||
|
||||
self.translation_model = None
|
||||
if self.args.target_language:
|
||||
if self.args.lan == 'auto' and self.args.backend != "simulstreaming":
|
||||
if self.args.lan == 'auto' and backend_policy != "simulstreaming":
|
||||
raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming')
|
||||
else:
|
||||
from whisperlivekit.translation.translation import load_model
|
||||
try:
|
||||
from nllw import load_model
|
||||
except:
|
||||
raise Exception('To use translation, you must install nllw: `pip install nllw`')
|
||||
translation_params = {
|
||||
"nllb_backend": "ctranslate2",
|
||||
"nllb_backend": "transformers",
|
||||
"nllb_size": "600M"
|
||||
}
|
||||
translation_params = update_with_kwargs(translation_params, kwargs)
|
||||
@@ -149,7 +168,7 @@ class TranscriptionEngine:
|
||||
|
||||
|
||||
def online_factory(args, asr):
|
||||
if args.backend == "simulstreaming":
|
||||
if args.backend_policy == "simulstreaming":
|
||||
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
||||
online = SimulStreamingOnlineProcessor(asr)
|
||||
else:
|
||||
@@ -172,5 +191,5 @@ def online_translation_factory(args, translation_model):
|
||||
#should be at speaker level in the future:
|
||||
#one shared nllb model for all speaker
|
||||
#one tokenizer per speaker/language
|
||||
from whisperlivekit.translation.translation import OnlineTranslation
|
||||
return OnlineTranslation(translation_model, [args.lan], [args.target_language])
|
||||
from nllw import OnlineTranslation
|
||||
return OnlineTranslation(translation_model, [args.lan], [args.target_language])
|
||||
|
||||
@@ -6,6 +6,8 @@ import math
|
||||
from typing import List
|
||||
import numpy as np
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
|
||||
from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe
|
||||
logger = logging.getLogger(__name__)
|
||||
class ASRBase:
|
||||
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
||||
@@ -37,40 +39,60 @@ class ASRBase:
|
||||
raise NotImplementedError("must be implemented in the child class")
|
||||
|
||||
|
||||
class WhisperTimestampedASR(ASRBase):
|
||||
"""Uses whisper_timestamped as the backend."""
|
||||
class WhisperASR(ASRBase):
|
||||
"""Uses WhisperLiveKit's built-in Whisper implementation."""
|
||||
sep = " "
|
||||
|
||||
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
||||
import whisper
|
||||
import whisper_timestamped
|
||||
from whisper_timestamped import transcribe_timestamped
|
||||
from whisperlivekit.whisper import load_model as load_model
|
||||
|
||||
self.transcribe_timestamped = transcribe_timestamped
|
||||
if model_dir is not None:
|
||||
logger.debug("ignoring model_dir, not implemented")
|
||||
return whisper.load_model(model_size, download_root=cache_dir)
|
||||
resolved_path = resolve_model_path(model_dir)
|
||||
if resolved_path.is_dir():
|
||||
pytorch_path, _, _ = model_path_and_type(resolved_path)
|
||||
if pytorch_path is None:
|
||||
raise FileNotFoundError(
|
||||
f"No supported PyTorch checkpoint found under {resolved_path}"
|
||||
)
|
||||
resolved_path = pytorch_path
|
||||
logger.debug(f"Loading Whisper model from custom path {resolved_path}")
|
||||
return load_model(str(resolved_path))
|
||||
|
||||
if model_size is None:
|
||||
raise ValueError("Either model_size or model_dir must be set for WhisperASR")
|
||||
|
||||
return load_model(model_size, download_root=cache_dir)
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
result = self.transcribe_timestamped(
|
||||
options = dict(self.transcribe_kargs)
|
||||
options.pop("vad", None)
|
||||
options.pop("vad_filter", None)
|
||||
language = self.original_language if self.original_language else None
|
||||
|
||||
result = whisper_transcribe(
|
||||
self.model,
|
||||
audio,
|
||||
language=self.original_language,
|
||||
language=language,
|
||||
initial_prompt=init_prompt,
|
||||
verbose=None,
|
||||
condition_on_previous_text=True,
|
||||
**self.transcribe_kargs,
|
||||
word_timestamps=True,
|
||||
**options,
|
||||
)
|
||||
return result
|
||||
|
||||
def ts_words(self, r) -> List[ASRToken]:
|
||||
"""
|
||||
Converts the whisper_timestamped result to a list of ASRToken objects.
|
||||
Converts the Whisper result to a list of ASRToken objects.
|
||||
"""
|
||||
tokens = []
|
||||
for segment in r["segments"]:
|
||||
for word in segment["words"]:
|
||||
token = ASRToken(word["start"], word["end"], word["text"])
|
||||
token = ASRToken(
|
||||
word["start"],
|
||||
word["end"],
|
||||
word["word"],
|
||||
probability=word.get("probability"),
|
||||
)
|
||||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
@@ -78,11 +100,7 @@ class WhisperTimestampedASR(ASRBase):
|
||||
return [segment["end"] for segment in res["segments"]]
|
||||
|
||||
def use_vad(self):
|
||||
self.transcribe_kargs["vad"] = True
|
||||
|
||||
def set_translate_task(self):
|
||||
self.transcribe_kargs["task"] = "translate"
|
||||
|
||||
logger.warning("VAD is not currently supported for WhisperASR backend and will be ignored.")
|
||||
|
||||
class FasterWhisperASR(ASRBase):
|
||||
"""Uses faster-whisper as the backend."""
|
||||
@@ -92,9 +110,10 @@ class FasterWhisperASR(ASRBase):
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
if model_dir is not None:
|
||||
logger.debug(f"Loading whisper model from model_dir {model_dir}. "
|
||||
resolved_path = resolve_model_path(model_dir)
|
||||
logger.debug(f"Loading faster-whisper model from {resolved_path}. "
|
||||
f"model_size and cache_dir parameters are not used.")
|
||||
model_size_or_path = model_dir
|
||||
model_size_or_path = str(resolved_path)
|
||||
elif model_size is not None:
|
||||
model_size_or_path = model_size
|
||||
else:
|
||||
@@ -139,10 +158,6 @@ class FasterWhisperASR(ASRBase):
|
||||
def use_vad(self):
|
||||
self.transcribe_kargs["vad_filter"] = True
|
||||
|
||||
def set_translate_task(self):
|
||||
self.transcribe_kargs["task"] = "translate"
|
||||
|
||||
|
||||
class MLXWhisper(ASRBase):
|
||||
"""
|
||||
Uses MLX Whisper optimized for Apple Silicon.
|
||||
@@ -154,8 +169,9 @@ class MLXWhisper(ASRBase):
|
||||
import mlx.core as mx
|
||||
|
||||
if model_dir is not None:
|
||||
logger.debug(f"Loading whisper model from model_dir {model_dir}. model_size parameter is not used.")
|
||||
model_size_or_path = model_dir
|
||||
resolved_path = resolve_model_path(model_dir)
|
||||
logger.debug(f"Loading MLX Whisper model from {resolved_path}. model_size parameter is not used.")
|
||||
model_size_or_path = str(resolved_path)
|
||||
elif model_size is not None:
|
||||
model_size_or_path = self.translate_model_name(model_size)
|
||||
logger.debug(f"Loading whisper model {model_size}. You use mlx whisper, so {model_size_or_path} will be used.")
|
||||
@@ -218,10 +234,6 @@ class MLXWhisper(ASRBase):
|
||||
def use_vad(self):
|
||||
self.transcribe_kargs["vad_filter"] = True
|
||||
|
||||
def set_translate_task(self):
|
||||
self.transcribe_kargs["task"] = "translate"
|
||||
|
||||
|
||||
class OpenaiApiASR(ASRBase):
|
||||
"""Uses OpenAI's Whisper API for transcription."""
|
||||
def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
|
||||
@@ -232,7 +244,7 @@ class OpenaiApiASR(ASRBase):
|
||||
self.temperature = temperature
|
||||
self.load_model()
|
||||
self.use_vad_opt = False
|
||||
self.task = "transcribe"
|
||||
self.direct_english_translation = False
|
||||
|
||||
def load_model(self, *args, **kwargs):
|
||||
from openai import OpenAI
|
||||
@@ -274,7 +286,7 @@ class OpenaiApiASR(ASRBase):
|
||||
"temperature": self.temperature,
|
||||
"timestamp_granularities": ["word", "segment"],
|
||||
}
|
||||
if self.task != "translate" and self.original_language:
|
||||
if not self.direct_english_translation and self.original_language:
|
||||
params["language"] = self.original_language
|
||||
if prompt:
|
||||
params["prompt"] = prompt
|
||||
@@ -285,6 +297,3 @@ class OpenaiApiASR(ASRBase):
|
||||
|
||||
def use_vad(self):
|
||||
self.use_vad_opt = True
|
||||
|
||||
def set_translate_task(self):
|
||||
self.task = "translate"
|
||||
199
whisperlivekit/local_agreement/whisper_online.py
Normal file
199
whisperlivekit/local_agreement/whisper_online.py
Normal file
@@ -0,0 +1,199 @@
|
||||
#!/usr/bin/env python3
|
||||
import sys
|
||||
import numpy as np
|
||||
import librosa
|
||||
from functools import lru_cache
|
||||
import time
|
||||
import logging
|
||||
import platform
|
||||
from .backends import FasterWhisperASR, MLXWhisper, WhisperASR, OpenaiApiASR
|
||||
from whisperlivekit.warmup import warmup_asr
|
||||
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
|
||||
from whisperlivekit.backend_support import (
|
||||
mlx_backend_available,
|
||||
faster_backend_available,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
|
||||
","
|
||||
)
|
||||
|
||||
|
||||
def create_tokenizer(lan):
|
||||
"""returns an object that has split function that works like the one of MosesTokenizer"""
|
||||
|
||||
assert (
|
||||
lan in WHISPER_LANG_CODES
|
||||
), "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
|
||||
|
||||
if lan == "uk":
|
||||
import tokenize_uk
|
||||
|
||||
class UkrainianTokenizer:
|
||||
def split(self, text):
|
||||
return tokenize_uk.tokenize_sents(text)
|
||||
|
||||
return UkrainianTokenizer()
|
||||
|
||||
# supported by fast-mosestokenizer
|
||||
if (
|
||||
lan
|
||||
in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
|
||||
):
|
||||
from mosestokenizer import MosesSentenceSplitter
|
||||
|
||||
return MosesSentenceSplitter(lan)
|
||||
|
||||
# the following languages are in Whisper, but not in wtpsplit:
|
||||
if (
|
||||
lan
|
||||
in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
|
||||
):
|
||||
logger.debug(
|
||||
f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
|
||||
)
|
||||
lan = None
|
||||
|
||||
from wtpsplit import WtP
|
||||
|
||||
# downloads the model from huggingface on the first use
|
||||
wtp = WtP("wtp-canine-s-12l-no-adapters")
|
||||
|
||||
class WtPtok:
|
||||
def split(self, sent):
|
||||
return wtp.split(sent, lang_code=lan)
|
||||
|
||||
return WtPtok()
|
||||
|
||||
|
||||
def backend_factory(
|
||||
backend,
|
||||
lan,
|
||||
model_size,
|
||||
model_cache_dir,
|
||||
model_dir,
|
||||
model_path,
|
||||
direct_english_translation,
|
||||
buffer_trimming,
|
||||
buffer_trimming_sec,
|
||||
confidence_validation,
|
||||
warmup_file=None,
|
||||
min_chunk_size=None,
|
||||
):
|
||||
backend_choice = backend
|
||||
custom_reference = model_path or model_dir
|
||||
resolved_root = None
|
||||
pytorch_checkpoint = None
|
||||
has_mlx_weights = False
|
||||
has_fw_weights = False
|
||||
|
||||
if custom_reference:
|
||||
resolved_root = resolve_model_path(custom_reference)
|
||||
if resolved_root.is_dir():
|
||||
pytorch_checkpoint, has_mlx_weights, has_fw_weights = model_path_and_type(resolved_root)
|
||||
else:
|
||||
pytorch_checkpoint = resolved_root
|
||||
|
||||
if backend_choice == "openai-api":
|
||||
logger.debug("Using OpenAI API.")
|
||||
asr = OpenaiApiASR(lan=lan)
|
||||
else:
|
||||
backend_choice = _normalize_backend_choice(
|
||||
backend_choice,
|
||||
resolved_root,
|
||||
has_mlx_weights,
|
||||
has_fw_weights,
|
||||
)
|
||||
|
||||
if backend_choice == "faster-whisper":
|
||||
asr_cls = FasterWhisperASR
|
||||
if resolved_root is not None and not resolved_root.is_dir():
|
||||
raise ValueError("Faster-Whisper backend expects a directory with CTranslate2 weights.")
|
||||
model_override = str(resolved_root) if resolved_root is not None else None
|
||||
elif backend_choice == "mlx-whisper":
|
||||
asr_cls = MLXWhisper
|
||||
if resolved_root is not None and not resolved_root.is_dir():
|
||||
raise ValueError("MLX Whisper backend expects a directory containing MLX weights.")
|
||||
model_override = str(resolved_root) if resolved_root is not None else None
|
||||
else:
|
||||
asr_cls = WhisperASR
|
||||
model_override = str(pytorch_checkpoint) if pytorch_checkpoint is not None else None
|
||||
if custom_reference and model_override is None:
|
||||
raise FileNotFoundError(
|
||||
f"No PyTorch checkpoint found under {resolved_root or custom_reference}"
|
||||
)
|
||||
|
||||
t = time.time()
|
||||
logger.info(f"Loading Whisper {model_size} model for language {lan} using backend {backend_choice}...")
|
||||
asr = asr_cls(
|
||||
model_size=model_size,
|
||||
lan=lan,
|
||||
cache_dir=model_cache_dir,
|
||||
model_dir=model_override,
|
||||
)
|
||||
e = time.time()
|
||||
logger.info(f"done. It took {round(e-t,2)} seconds.")
|
||||
|
||||
if direct_english_translation:
|
||||
tgt_language = "en" # Whisper translates into English
|
||||
else:
|
||||
tgt_language = lan # Whisper transcribes in this language
|
||||
|
||||
# Create the tokenizer
|
||||
if buffer_trimming == "sentence":
|
||||
tokenizer = create_tokenizer(tgt_language)
|
||||
else:
|
||||
tokenizer = None
|
||||
|
||||
warmup_asr(asr, warmup_file)
|
||||
|
||||
asr.confidence_validation = confidence_validation
|
||||
asr.tokenizer = tokenizer
|
||||
asr.buffer_trimming = buffer_trimming
|
||||
asr.buffer_trimming_sec = buffer_trimming_sec
|
||||
asr.backend_choice = backend_choice
|
||||
return asr
|
||||
|
||||
|
||||
def _normalize_backend_choice(
|
||||
preferred_backend,
|
||||
resolved_root,
|
||||
has_mlx_weights,
|
||||
has_fw_weights,
|
||||
):
|
||||
backend_choice = preferred_backend
|
||||
|
||||
if backend_choice == "auto":
|
||||
if mlx_backend_available(warn_on_missing=True) and (resolved_root is None or has_mlx_weights):
|
||||
return "mlx-whisper"
|
||||
if faster_backend_available(warn_on_missing=True) and (resolved_root is None or has_fw_weights):
|
||||
return "faster-whisper"
|
||||
return "whisper"
|
||||
|
||||
if backend_choice == "mlx-whisper":
|
||||
if not mlx_backend_available():
|
||||
raise RuntimeError("mlx-whisper backend requested but mlx-whisper is not installed.")
|
||||
if resolved_root is not None and not has_mlx_weights:
|
||||
raise FileNotFoundError(
|
||||
f"mlx-whisper backend requested but no MLX weights were found under {resolved_root}"
|
||||
)
|
||||
if platform.system() != "Darwin":
|
||||
logger.warning("mlx-whisper backend requested on a non-macOS system; this may fail.")
|
||||
return backend_choice
|
||||
|
||||
if backend_choice == "faster-whisper":
|
||||
if not faster_backend_available():
|
||||
raise RuntimeError("faster-whisper backend requested but faster-whisper is not installed.")
|
||||
if resolved_root is not None and not has_fw_weights:
|
||||
raise FileNotFoundError(
|
||||
f"faster-whisper backend requested but no Faster-Whisper weights were found under {resolved_root}"
|
||||
)
|
||||
return backend_choice
|
||||
|
||||
if backend_choice == "whisper":
|
||||
return backend_choice
|
||||
|
||||
raise ValueError(f"Unknown backend '{preferred_backend}' for LocalAgreement.")
|
||||
69
whisperlivekit/model_paths.py
Normal file
69
whisperlivekit/model_paths.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
|
||||
def model_path_and_type(model_path: Union[str, Path]) -> Tuple[Optional[Path], bool, bool]:
|
||||
"""
|
||||
Inspect the provided path and determine which model formats are available.
|
||||
|
||||
Returns:
|
||||
pytorch_path: Path to a PyTorch checkpoint (if present).
|
||||
compatible_whisper_mlx: True if MLX weights exist in this folder.
|
||||
compatible_faster_whisper: True if Faster-Whisper (ctranslate2) weights exist.
|
||||
"""
|
||||
path = Path(model_path)
|
||||
|
||||
compatible_whisper_mlx = False
|
||||
compatible_faster_whisper = False
|
||||
pytorch_path: Optional[Path] = None
|
||||
|
||||
if path.is_file() and path.suffix.lower() in [".pt", ".safetensors", ".bin"]:
|
||||
pytorch_path = path
|
||||
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
|
||||
|
||||
if path.is_dir():
|
||||
for file in path.iterdir():
|
||||
if not file.is_file():
|
||||
continue
|
||||
|
||||
filename = file.name.lower()
|
||||
suffix = file.suffix.lower()
|
||||
|
||||
if filename in {"weights.npz", "weights.safetensors"}:
|
||||
compatible_whisper_mlx = True
|
||||
elif filename in {"model.bin", "encoder.bin", "decoder.bin"}:
|
||||
compatible_faster_whisper = True
|
||||
elif suffix in {".pt", ".safetensors"}:
|
||||
pytorch_path = file
|
||||
elif filename == "pytorch_model.bin":
|
||||
pytorch_path = file
|
||||
|
||||
if pytorch_path is None:
|
||||
fallback = path / "pytorch_model.bin"
|
||||
if fallback.exists():
|
||||
pytorch_path = fallback
|
||||
|
||||
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
|
||||
|
||||
|
||||
def resolve_model_path(model_path: Union[str, Path]) -> Path:
|
||||
"""
|
||||
Return a local path for the provided model reference.
|
||||
|
||||
If the path does not exist locally, it is treated as a Hugging Face repo id
|
||||
and downloaded via snapshot_download.
|
||||
"""
|
||||
path = Path(model_path).expanduser()
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
except ImportError as exc: # pragma: no cover - optional dependency guard
|
||||
raise FileNotFoundError(
|
||||
f"Model path '{model_path}' does not exist locally and huggingface_hub "
|
||||
"is not installed to download it."
|
||||
) from exc
|
||||
|
||||
downloaded_path = Path(snapshot_download(repo_id=str(model_path)))
|
||||
return downloaded_path
|
||||
@@ -114,11 +114,10 @@ def parse_args():
|
||||
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default="transcribe",
|
||||
choices=["transcribe", "translate"],
|
||||
help="Transcribe or translate.",
|
||||
"--direct-english-translation",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use Whisper to directly translate to english.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -130,11 +129,18 @@ def parse_args():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
"--backend-policy",
|
||||
type=str,
|
||||
default="simulstreaming",
|
||||
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api", "simulstreaming"],
|
||||
help="Load only this backend for Whisper processing.",
|
||||
choices=["1", "2", "simulstreaming", "localagreement"],
|
||||
help="Select the streaming policy: 1 or 'simulstreaming' for AlignAtt, 2 or 'localagreement' for LocalAgreement.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api"],
|
||||
help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-vac",
|
||||
@@ -300,7 +306,7 @@ def parse_args():
|
||||
simulstreaming_group.add_argument(
|
||||
"--nllb-backend",
|
||||
type=str,
|
||||
default="ctranslate2",
|
||||
default="transformers",
|
||||
help="transformers or ctranslate2",
|
||||
)
|
||||
|
||||
@@ -317,5 +323,10 @@ def parse_args():
|
||||
args.vad = not args.no_vad
|
||||
delattr(args, 'no_transcription')
|
||||
delattr(args, 'no_vad')
|
||||
|
||||
if args.backend_policy == "1":
|
||||
args.backend_policy = "simulstreaming"
|
||||
elif args.backend_policy == "2":
|
||||
args.backend_policy = "localagreement"
|
||||
|
||||
return args
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from time import time
|
||||
import re
|
||||
|
||||
MIN_SILENCE_DURATION = 4 #in seconds
|
||||
@@ -77,7 +78,8 @@ def no_token_to_silence(tokens):
|
||||
new_tokens.append(token)
|
||||
return new_tokens
|
||||
|
||||
def ends_with_silence(tokens, current_time, vac_detected_silence):
|
||||
def ends_with_silence(tokens, beg_loop, vac_detected_silence):
|
||||
current_time = time() - (beg_loop if beg_loop else 0.0)
|
||||
last_token = tokens[-1]
|
||||
if vac_detected_silence or (current_time - last_token.end >= END_SILENCE_DURATION):
|
||||
if last_token.speaker == -2:
|
||||
@@ -94,11 +96,11 @@ def ends_with_silence(tokens, current_time, vac_detected_silence):
|
||||
return tokens
|
||||
|
||||
|
||||
def handle_silences(tokens, current_time, vac_detected_silence):
|
||||
def handle_silences(tokens, beg_loop, vac_detected_silence):
|
||||
if not tokens:
|
||||
return []
|
||||
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
|
||||
tokens = no_token_to_silence(tokens)
|
||||
tokens = ends_with_silence(tokens, current_time, vac_detected_silence)
|
||||
tokens = ends_with_silence(tokens, beg_loop, vac_detected_silence)
|
||||
return tokens
|
||||
|
||||
@@ -52,18 +52,18 @@ def append_token_to_last_line(lines, sep, token):
|
||||
lines[-1].detected_language = token.detected_language
|
||||
|
||||
|
||||
def format_output(state, silence, current_time, args, sep):
|
||||
def format_output(state, silence, args, sep):
|
||||
diarization = args.diarization
|
||||
disable_punctuation_split = args.disable_punctuation_split
|
||||
tokens = state.tokens
|
||||
translated_segments = state.translated_segments # Here we will attribute the speakers only based on the timestamps of the segments
|
||||
translation_validated_segments = state.translation_validated_segments # Here we will attribute the speakers only based on the timestamps of the segments
|
||||
last_validated_token = state.last_validated_token
|
||||
|
||||
previous_speaker = 1
|
||||
undiarized_text = []
|
||||
tokens = handle_silences(tokens, current_time, silence)
|
||||
last_punctuation = None
|
||||
for i, token in enumerate(tokens[last_validated_token:]):
|
||||
tokens = handle_silences(tokens, state.beg_loop, silence)
|
||||
for i in range(last_validated_token, len(tokens)):
|
||||
token = tokens[i]
|
||||
speaker = int(token.speaker)
|
||||
token.corrected_speaker = speaker
|
||||
if not diarization:
|
||||
@@ -71,17 +71,10 @@ def format_output(state, silence, current_time, args, sep):
|
||||
token.corrected_speaker = 1
|
||||
token.validated_speaker = True
|
||||
else:
|
||||
# if token.end > end_attributed_speaker and token.speaker != -2:
|
||||
# if tokens[-1].speaker == -2: #if it finishes by a silence, we want to append the undiarized text to the last speaker.
|
||||
# token.corrected_speaker = previous_speaker
|
||||
# else:
|
||||
# undiarized_text.append(token.text)
|
||||
# continue
|
||||
# else:
|
||||
if is_punctuation(token):
|
||||
last_punctuation = i
|
||||
state.last_punctuation_index = i
|
||||
|
||||
if last_punctuation == i-1:
|
||||
if state.last_punctuation_index == i-1:
|
||||
if token.speaker != previous_speaker:
|
||||
token.validated_speaker = True
|
||||
# perfect, diarization perfectly aligned
|
||||
@@ -123,9 +116,9 @@ def format_output(state, silence, current_time, args, sep):
|
||||
|
||||
previous_speaker = token.corrected_speaker
|
||||
|
||||
if lines and translated_segments:
|
||||
if lines:
|
||||
unassigned_translated_segments = []
|
||||
for ts in translated_segments:
|
||||
for ts in translation_validated_segments:
|
||||
assigned = False
|
||||
for line in lines:
|
||||
if ts and ts.overlaps_with(line):
|
||||
|
||||
@@ -1,27 +1,182 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
# This is copied from silero-vad's vad_utils.py:
|
||||
# https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/utils_vad.py#L340
|
||||
# (except changed defaults)
|
||||
"""
|
||||
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
|
||||
"""
|
||||
|
||||
# Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
|
||||
def init_jit_model(model_path: str, device=torch.device('cpu')):
|
||||
"""Load a JIT model from file."""
|
||||
model = torch.jit.load(model_path, map_location=device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
class OnnxWrapper():
|
||||
"""ONNX Runtime wrapper for Silero VAD model."""
|
||||
|
||||
def __init__(self, path, force_onnx_cpu=False):
|
||||
global np
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
|
||||
opts = onnxruntime.SessionOptions()
|
||||
opts.inter_op_num_threads = 1
|
||||
opts.intra_op_num_threads = 1
|
||||
|
||||
if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
|
||||
self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
|
||||
else:
|
||||
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
|
||||
|
||||
self.reset_states()
|
||||
if '16k' in path:
|
||||
warnings.warn('This model support only 16000 sampling rate!')
|
||||
self.sample_rates = [16000]
|
||||
else:
|
||||
self.sample_rates = [8000, 16000]
|
||||
|
||||
def _validate_input(self, x, sr: int):
|
||||
if x.dim() == 1:
|
||||
x = x.unsqueeze(0)
|
||||
if x.dim() > 2:
|
||||
raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
|
||||
|
||||
if sr != 16000 and (sr % 16000 == 0):
|
||||
step = sr // 16000
|
||||
x = x[:,::step]
|
||||
sr = 16000
|
||||
|
||||
if sr not in self.sample_rates:
|
||||
raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
|
||||
if sr / x.shape[1] > 31.25:
|
||||
raise ValueError("Input audio chunk is too short")
|
||||
|
||||
return x, sr
|
||||
|
||||
def reset_states(self, batch_size=1):
|
||||
self._state = torch.zeros((2, batch_size, 128)).float()
|
||||
self._context = torch.zeros(0)
|
||||
self._last_sr = 0
|
||||
self._last_batch_size = 0
|
||||
|
||||
def __call__(self, x, sr: int):
|
||||
|
||||
x, sr = self._validate_input(x, sr)
|
||||
num_samples = 512 if sr == 16000 else 256
|
||||
|
||||
if x.shape[-1] != num_samples:
|
||||
raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
|
||||
|
||||
batch_size = x.shape[0]
|
||||
context_size = 64 if sr == 16000 else 32
|
||||
|
||||
if not self._last_batch_size:
|
||||
self.reset_states(batch_size)
|
||||
if (self._last_sr) and (self._last_sr != sr):
|
||||
self.reset_states(batch_size)
|
||||
if (self._last_batch_size) and (self._last_batch_size != batch_size):
|
||||
self.reset_states(batch_size)
|
||||
|
||||
if not len(self._context):
|
||||
self._context = torch.zeros(batch_size, context_size)
|
||||
|
||||
x = torch.cat([self._context, x], dim=1)
|
||||
if sr in [8000, 16000]:
|
||||
ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
|
||||
ort_outs = self.session.run(None, ort_inputs)
|
||||
out, state = ort_outs
|
||||
self._state = torch.from_numpy(state)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
self._context = x[..., -context_size:]
|
||||
self._last_sr = sr
|
||||
self._last_batch_size = batch_size
|
||||
|
||||
out = torch.from_numpy(out)
|
||||
return out
|
||||
|
||||
|
||||
def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: int = 16):
|
||||
"""
|
||||
Load Silero VAD model (JIT or ONNX).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_path : str, optional
|
||||
Path to model file. If None, uses default bundled model.
|
||||
onnx : bool, default False
|
||||
Whether to use ONNX runtime (requires onnxruntime package).
|
||||
opset_version : int, default 16
|
||||
ONNX opset version (15 or 16). Only used if onnx=True.
|
||||
|
||||
Returns
|
||||
-------
|
||||
model
|
||||
Loaded VAD model (JIT or ONNX wrapper)
|
||||
"""
|
||||
available_ops = [15, 16]
|
||||
if onnx and opset_version not in available_ops:
|
||||
raise Exception(f'Available ONNX opset_version: {available_ops}')
|
||||
if model_path is None:
|
||||
current_dir = Path(__file__).parent
|
||||
data_dir = current_dir / 'vad_models'
|
||||
|
||||
if onnx:
|
||||
if opset_version == 16:
|
||||
model_name = 'silero_vad.onnx'
|
||||
else:
|
||||
model_name = f'silero_vad_16k_op{opset_version}.onnx'
|
||||
else:
|
||||
model_name = 'silero_vad.jit'
|
||||
|
||||
model_path = data_dir / model_name
|
||||
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Model file not found: {model_path}\n"
|
||||
f"Please ensure the whisperlivekit/vad_models/ directory contains the model files."
|
||||
)
|
||||
else:
|
||||
model_path = Path(model_path)
|
||||
if onnx:
|
||||
try:
|
||||
model = OnnxWrapper(str(model_path), force_onnx_cpu=True)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"ONNX runtime not available. Install with: pip install onnxruntime\n"
|
||||
"Or use JIT model by setting onnx=False"
|
||||
)
|
||||
else:
|
||||
model = init_jit_model(str(model_path))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class VADIterator:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
threshold: float = 0.5,
|
||||
sampling_rate: int = 16000,
|
||||
min_silence_duration_ms: int = 500, # makes sense on one recording that I checked
|
||||
speech_pad_ms: int = 100, # same
|
||||
):
|
||||
"""
|
||||
Voice Activity Detection iterator for streaming audio.
|
||||
|
||||
This is the Silero VAD v6 implementation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model,
|
||||
threshold: float = 0.5,
|
||||
sampling_rate: int = 16000,
|
||||
min_silence_duration_ms: int = 100,
|
||||
speech_pad_ms: int = 30
|
||||
):
|
||||
|
||||
"""
|
||||
Class for stream imitation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: preloaded .jit silero VAD model
|
||||
model: preloaded .jit/.onnx silero VAD model
|
||||
|
||||
threshold: float (default - 0.5)
|
||||
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
|
||||
@@ -42,9 +197,7 @@ class VADIterator:
|
||||
self.sampling_rate = sampling_rate
|
||||
|
||||
if sampling_rate not in [8000, 16000]:
|
||||
raise ValueError(
|
||||
"VADIterator does not support sampling rates other than [8000, 16000]"
|
||||
)
|
||||
raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
|
||||
|
||||
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
|
||||
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
||||
@@ -57,13 +210,17 @@ class VADIterator:
|
||||
self.temp_end = 0
|
||||
self.current_sample = 0
|
||||
|
||||
def __call__(self, x, return_seconds=False):
|
||||
@torch.no_grad()
|
||||
def __call__(self, x, return_seconds=False, time_resolution: int = 1):
|
||||
"""
|
||||
x: torch.Tensor
|
||||
audio chunk (see examples in repo)
|
||||
|
||||
return_seconds: bool (default - False)
|
||||
whether return timestamps in seconds (default - samples)
|
||||
|
||||
time_resolution: int (default - 1)
|
||||
time resolution of speech coordinates when requested as seconds
|
||||
"""
|
||||
|
||||
if not torch.is_tensor(x):
|
||||
@@ -82,14 +239,8 @@ class VADIterator:
|
||||
|
||||
if (speech_prob >= self.threshold) and not self.triggered:
|
||||
self.triggered = True
|
||||
speech_start = self.current_sample - self.speech_pad_samples
|
||||
return {
|
||||
"start": (
|
||||
int(speech_start)
|
||||
if not return_seconds
|
||||
else round(speech_start / self.sampling_rate, 1)
|
||||
)
|
||||
}
|
||||
speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
|
||||
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)}
|
||||
|
||||
if (speech_prob < self.threshold - 0.15) and self.triggered:
|
||||
if not self.temp_end:
|
||||
@@ -97,30 +248,17 @@ class VADIterator:
|
||||
if self.current_sample - self.temp_end < self.min_silence_samples:
|
||||
return None
|
||||
else:
|
||||
speech_end = self.temp_end + self.speech_pad_samples
|
||||
speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
|
||||
self.temp_end = 0
|
||||
self.triggered = False
|
||||
return {
|
||||
"end": (
|
||||
int(speech_end)
|
||||
if not return_seconds
|
||||
else round(speech_end / self.sampling_rate, 1)
|
||||
)
|
||||
}
|
||||
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
#######################
|
||||
# because Silero now requires exactly 512-sized audio chunks
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class FixedVADIterator(VADIterator):
|
||||
"""It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once.
|
||||
If audio to be processed at once is long and multiple voiced segments detected,
|
||||
then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment.
|
||||
"""
|
||||
Fixed VAD Iterator that handles variable-length audio chunks, not only exactly 512 frames at once.
|
||||
"""
|
||||
|
||||
def reset_states(self):
|
||||
@@ -137,27 +275,20 @@ class FixedVADIterator(VADIterator):
|
||||
ret = r
|
||||
elif r is not None:
|
||||
if "end" in r:
|
||||
ret["end"] = r["end"] # the latter end
|
||||
if "start" in r and "end" in ret: # there is an earlier start.
|
||||
# Remove end, merging this segment with the previous one.
|
||||
ret["end"] = r["end"]
|
||||
if "start" in r and "end" in ret:
|
||||
del ret["end"]
|
||||
return ret if ret != {} else None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test/demonstrate the need for FixedVADIterator:
|
||||
|
||||
import torch
|
||||
|
||||
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
|
||||
vac = FixedVADIterator(model)
|
||||
# vac = VADIterator(model) # the second case crashes with this
|
||||
|
||||
# this works: for both
|
||||
audio_buffer = np.array([0] * (512), dtype=np.float32)
|
||||
vac(audio_buffer)
|
||||
|
||||
# this crashes on the non FixedVADIterator with
|
||||
# ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
|
||||
audio_buffer = np.array([0] * (512 - 1), dtype=np.float32)
|
||||
vac(audio_buffer)
|
||||
model = load_silero_vad(onnx=False)
|
||||
vad = FixedVADIterator(model)
|
||||
|
||||
audio_buffer = np.array([0] * 512, dtype=np.float32)
|
||||
result = vad(audio_buffer)
|
||||
print(f" 512 samples: {result}")
|
||||
|
||||
# test with 511 samples
|
||||
audio_buffer = np.array([0] * 511, dtype=np.float32)
|
||||
result = vad(audio_buffer)
|
||||
@@ -2,41 +2,37 @@ import sys
|
||||
import numpy as np
|
||||
import logging
|
||||
from typing import List, Tuple, Optional
|
||||
import logging
|
||||
import platform
|
||||
from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
|
||||
from whisperlivekit.warmup import load_file
|
||||
from .whisper import load_model, tokenizer
|
||||
from .whisper.audio import TOKENS_PER_SECOND
|
||||
from whisperlivekit.whisper import load_model, tokenizer
|
||||
from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
|
||||
import os
|
||||
import gc
|
||||
logger = logging.getLogger(__name__)
|
||||
from pathlib import Path
|
||||
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
|
||||
from whisperlivekit.backend_support import (
|
||||
mlx_backend_available,
|
||||
faster_backend_available,
|
||||
)
|
||||
|
||||
import torch
|
||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
|
||||
from whisperlivekit.simul_whisper.whisper import tokenizer
|
||||
|
||||
try:
|
||||
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
|
||||
HAS_MLX_WHISPER = True
|
||||
except ImportError:
|
||||
if platform.system() == "Darwin" and platform.machine() == "arm64":
|
||||
print(f"""{"="*50}
|
||||
MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper
|
||||
{"="*50}""")
|
||||
HAS_MLX_WHISPER = False
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
|
||||
if HAS_MLX_WHISPER:
|
||||
HAS_FASTER_WHISPER = False
|
||||
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
|
||||
else:
|
||||
try:
|
||||
from faster_whisper import WhisperModel
|
||||
HAS_FASTER_WHISPER = True
|
||||
except ImportError:
|
||||
HAS_FASTER_WHISPER = False
|
||||
|
||||
|
||||
# TOO_MANY_REPETITIONS = 3
|
||||
mlx_model_mapping = {}
|
||||
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
|
||||
if HAS_FASTER_WHISPER:
|
||||
from faster_whisper import WhisperModel
|
||||
else:
|
||||
WhisperModel = None
|
||||
|
||||
class SimulStreamingOnlineProcessor:
|
||||
SAMPLING_RATE = 16000
|
||||
@@ -154,8 +150,22 @@ class SimulStreamingASR():
|
||||
self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
|
||||
|
||||
self.fast_encoder = False
|
||||
if self.model_dir is not None:
|
||||
self.model_path = self.model_dir
|
||||
self._resolved_model_path = None
|
||||
self.encoder_backend = "whisper"
|
||||
preferred_backend = getattr(self, "backend", "auto")
|
||||
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
|
||||
if self.model_path:
|
||||
resolved_model_path = resolve_model_path(self.model_path)
|
||||
self._resolved_model_path = resolved_model_path
|
||||
self.model_path = str(resolved_model_path)
|
||||
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(resolved_model_path)
|
||||
if self.pytorch_path:
|
||||
self.model_name = self.pytorch_path.stem
|
||||
else:
|
||||
self.model_name = Path(self.model_path).stem
|
||||
raise FileNotFoundError(
|
||||
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
|
||||
)
|
||||
elif self.model_size is not None:
|
||||
model_mapping = {
|
||||
'tiny': './tiny.pt',
|
||||
@@ -171,10 +181,23 @@ class SimulStreamingASR():
|
||||
'large-v3': './large-v3.pt',
|
||||
'large': './large-v3.pt'
|
||||
}
|
||||
self.model_path = model_mapping.get(self.model_size, f'./{self.model_size}.pt')
|
||||
|
||||
self.model_name = self.model_size
|
||||
else:
|
||||
raise ValueError("Either model_size or model_path must be specified for SimulStreaming.")
|
||||
|
||||
is_multilingual = not self.model_name.endswith(".en")
|
||||
|
||||
self.encoder_backend = self._resolve_encoder_backend(
|
||||
preferred_backend,
|
||||
compatible_whisper_mlx,
|
||||
compatible_faster_whisper,
|
||||
)
|
||||
self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper")
|
||||
if self.encoder_backend == "whisper":
|
||||
self.disable_fast_encoder = True
|
||||
|
||||
self.cfg = AlignAttConfig(
|
||||
model_path=self.model_path,
|
||||
tokenizer_is_multilingual= is_multilingual,
|
||||
segment_length=self.min_chunk_size,
|
||||
frame_threshold=self.frame_threshold,
|
||||
language=self.lan,
|
||||
@@ -183,7 +206,7 @@ class SimulStreamingASR():
|
||||
cif_ckpt_path=self.cif_ckpt_path,
|
||||
decoder_type="beam",
|
||||
beam_size=self.beams,
|
||||
task=self.task,
|
||||
task=self.direct_english_translation,
|
||||
never_fire=self.never_fire,
|
||||
init_prompt=self.init_prompt,
|
||||
max_context_tokens=self.max_context_tokens,
|
||||
@@ -191,40 +214,84 @@ class SimulStreamingASR():
|
||||
)
|
||||
|
||||
# Set up tokenizer for translation if needed
|
||||
if self.task == "translate":
|
||||
if self.direct_english_translation:
|
||||
self.tokenizer = self.set_translate_task()
|
||||
else:
|
||||
self.tokenizer = None
|
||||
|
||||
if self.model_dir:
|
||||
self.model_name = self.model_dir
|
||||
self.model_path = None
|
||||
else:
|
||||
self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
|
||||
self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
|
||||
|
||||
|
||||
|
||||
self.mlx_encoder, self.fw_encoder = None, None
|
||||
if not self.disable_fast_encoder:
|
||||
if HAS_MLX_WHISPER:
|
||||
print('Simulstreaming will use MLX whisper for a faster encoder.')
|
||||
mlx_model_name = mlx_model_mapping[self.model_name]
|
||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name)
|
||||
self.fast_encoder = True
|
||||
elif HAS_FASTER_WHISPER:
|
||||
print('Simulstreaming will use Faster Whisper for the encoder.')
|
||||
self.fw_encoder = WhisperModel(
|
||||
self.model_name,
|
||||
device='auto',
|
||||
compute_type='auto',
|
||||
if self.encoder_backend == "mlx-whisper":
|
||||
print('Simulstreaming will use MLX whisper to increase encoding speed.')
|
||||
if self._resolved_model_path is not None:
|
||||
mlx_model = str(self._resolved_model_path)
|
||||
else:
|
||||
mlx_model = mlx_model_mapping.get(self.model_name)
|
||||
if not mlx_model:
|
||||
raise FileNotFoundError(
|
||||
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
|
||||
)
|
||||
self.fast_encoder = True
|
||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
|
||||
elif self.encoder_backend == "faster-whisper":
|
||||
print('Simulstreaming will use Faster Whisper for the encoder.')
|
||||
if self._resolved_model_path is not None:
|
||||
fw_model = str(self._resolved_model_path)
|
||||
else:
|
||||
fw_model = self.model_name
|
||||
self.fw_encoder = WhisperModel(
|
||||
fw_model,
|
||||
device='auto',
|
||||
compute_type='auto',
|
||||
)
|
||||
|
||||
self.models = [self.load_model() for i in range(self.preload_model_count)]
|
||||
|
||||
|
||||
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
|
||||
choice = preferred_backend or "auto"
|
||||
if self.disable_fast_encoder:
|
||||
return "whisper"
|
||||
if choice == "whisper":
|
||||
return "whisper"
|
||||
if choice == "mlx-whisper":
|
||||
if not self._can_use_mlx(compatible_whisper_mlx):
|
||||
raise RuntimeError("mlx-whisper backend requested but MLX Whisper is unavailable or incompatible with the provided model.")
|
||||
return "mlx-whisper"
|
||||
if choice == "faster-whisper":
|
||||
if not self._can_use_faster(compatible_faster_whisper):
|
||||
raise RuntimeError("faster-whisper backend requested but Faster-Whisper is unavailable or incompatible with the provided model.")
|
||||
return "faster-whisper"
|
||||
if choice == "openai-api":
|
||||
raise ValueError("openai-api backend is only supported with the LocalAgreement policy.")
|
||||
# auto mode
|
||||
if platform.system() == "Darwin" and self._can_use_mlx(compatible_whisper_mlx):
|
||||
return "mlx-whisper"
|
||||
if self._can_use_faster(compatible_faster_whisper):
|
||||
return "faster-whisper"
|
||||
return "whisper"
|
||||
|
||||
def _has_custom_model_path(self):
|
||||
return self._resolved_model_path is not None
|
||||
|
||||
def _can_use_mlx(self, compatible_whisper_mlx):
|
||||
if not HAS_MLX_WHISPER:
|
||||
return False
|
||||
if self._has_custom_model_path():
|
||||
return compatible_whisper_mlx
|
||||
return self.model_name in mlx_model_mapping
|
||||
|
||||
def _can_use_faster(self, compatible_faster_whisper):
|
||||
if not HAS_FASTER_WHISPER:
|
||||
return False
|
||||
if self._has_custom_model_path():
|
||||
return compatible_faster_whisper
|
||||
return True
|
||||
|
||||
def load_model(self):
|
||||
whisper_model = load_model(
|
||||
name=self.model_name,
|
||||
name=self.pytorch_path if self.pytorch_path else self.model_name,
|
||||
download_root=self.model_path,
|
||||
decoder_only=self.fast_encoder,
|
||||
custom_alignment_heads=self.custom_alignment_heads
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .whisper.decoding import PyTorchInference
|
||||
from whisperlivekit.whisper.decoding import PyTorchInference
|
||||
|
||||
# extention of PyTorchInference for beam search
|
||||
class BeamPyTorchInference(PyTorchInference):
|
||||
|
||||
@@ -1,29 +1,23 @@
|
||||
# This code was originally in simul_whisper/transcriber/simul_whisper.py . It is adapted a lot for SimulStreaming.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
@dataclass
|
||||
class SimulWhisperConfig:
|
||||
'''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
|
||||
model_path: str
|
||||
language: str = field(default="zh")
|
||||
nonspeech_prob: float = 0.5
|
||||
audio_min_len: float = 1.0
|
||||
decoder_type: Literal["greedy","beam"] = "greedy"
|
||||
beam_size: int = 5
|
||||
task: Literal["transcribe","translate"] = "transcribe"
|
||||
init_prompt: str = field(default=None)
|
||||
static_init_prompt: str = field(default=None)
|
||||
max_context_tokens: int = field(default=None)
|
||||
|
||||
@dataclass
|
||||
class AlignAttConfig(SimulWhisperConfig):
|
||||
'''Options specific to the AlignAtt policy.'''
|
||||
class AlignAttConfig():
|
||||
eval_data_path: str = "tmp"
|
||||
segment_length: float = field(default=1.0, metadata = {"help": "in second"})
|
||||
frame_threshold: int = 4
|
||||
rewind_threshold: int = 200
|
||||
audio_max_len: float = 20.0
|
||||
cif_ckpt_path: str = ""
|
||||
never_fire: bool = False
|
||||
never_fire: bool = False
|
||||
language: str = field(default="zh")
|
||||
nonspeech_prob: float = 0.5
|
||||
audio_min_len: float = 1.0
|
||||
decoder_type: Literal["greedy","beam"] = "greedy"
|
||||
beam_size: int = 5
|
||||
task: Literal["transcribe","translate"] = "transcribe"
|
||||
tokenizer_is_multilingual: bool = False
|
||||
init_prompt: str = field(default=None)
|
||||
static_init_prompt: str = field(default=None)
|
||||
max_context_tokens: int = field(default=None)
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
SIMULSTREAMING_LICENSE = f"""
|
||||
SimulStreaming backend is dual-licensed:
|
||||
• Non-Commercial Use: PolyForm Noncommercial License 1.0.0.
|
||||
• Commercial Use: Check SimulStreaming README (github.com/ufal/SimulStreaming) for more details.
|
||||
"""
|
||||
@@ -6,17 +6,21 @@ import logging
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .whisper import load_model, DecodingOptions, tokenizer
|
||||
from whisperlivekit.whisper import load_model, DecodingOptions, tokenizer
|
||||
from .config import AlignAttConfig
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
|
||||
from .whisper.timing import median_filter
|
||||
from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
|
||||
from whisperlivekit.whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
|
||||
from whisperlivekit.whisper.timing import median_filter
|
||||
from whisperlivekit.whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
|
||||
from .beam import BeamPyTorchInference
|
||||
from .eow_detection import fire_at_boundary, load_cif
|
||||
import os
|
||||
from time import time
|
||||
from .token_buffer import TokenBuffer
|
||||
from whisperlivekit.backend_support import (
|
||||
mlx_backend_available,
|
||||
faster_backend_available,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from ..timed_objects import PUNCTUATION_MARKS
|
||||
@@ -26,21 +30,18 @@ DEC_PAD = 50257
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
try:
|
||||
HAS_MLX_WHISPER = False
|
||||
HAS_FASTER_WHISPER = False
|
||||
|
||||
if mlx_backend_available():
|
||||
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
|
||||
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
||||
HAS_MLX_WHISPER = True
|
||||
except ImportError:
|
||||
HAS_MLX_WHISPER = False
|
||||
if HAS_MLX_WHISPER:
|
||||
HAS_FASTER_WHISPER = False
|
||||
else:
|
||||
try:
|
||||
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
|
||||
from faster_whisper.feature_extractor import FeatureExtractor
|
||||
HAS_FASTER_WHISPER = True
|
||||
except ImportError:
|
||||
HAS_FASTER_WHISPER = False
|
||||
|
||||
if faster_backend_available():
|
||||
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
|
||||
from faster_whisper.feature_extractor import FeatureExtractor
|
||||
HAS_FASTER_WHISPER = True
|
||||
|
||||
class PaddedAlignAttWhisper:
|
||||
def __init__(
|
||||
@@ -51,20 +52,15 @@ class PaddedAlignAttWhisper:
|
||||
fw_encoder=None,
|
||||
) -> None:
|
||||
self.log_segments = 0
|
||||
model_name = os.path.basename(cfg.model_path).replace(".pt", "")
|
||||
model_path = os.path.dirname(os.path.abspath(cfg.model_path))
|
||||
if loaded_model:
|
||||
self.model = loaded_model
|
||||
else:
|
||||
self.model = load_model(name=model_name, download_root=model_path)
|
||||
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
self.model = loaded_model
|
||||
self.mlx_encoder = mlx_encoder
|
||||
self.fw_encoder = fw_encoder
|
||||
if fw_encoder:
|
||||
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
|
||||
|
||||
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
logger.info(f"Model dimensions: {self.model.dims}")
|
||||
self.speaker = -1
|
||||
self.decode_options = DecodingOptions(
|
||||
@@ -72,7 +68,7 @@ class PaddedAlignAttWhisper:
|
||||
without_timestamps = True,
|
||||
task=cfg.task
|
||||
)
|
||||
self.tokenizer_is_multilingual = not model_name.endswith(".en")
|
||||
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
|
||||
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
||||
# self.create_tokenizer('en')
|
||||
self.detected_language = cfg.language if cfg.language != "auto" else None
|
||||
@@ -172,7 +168,10 @@ class PaddedAlignAttWhisper:
|
||||
self.inference.kv_cache = self.kv_cache
|
||||
|
||||
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
|
||||
|
||||
|
||||
# Tokens to carry over to next chunk for incomplete UTF-8 characters
|
||||
self.pending_incomplete_tokens = []
|
||||
|
||||
def remove_hooks(self):
|
||||
for hook in self.l_hooks:
|
||||
hook.remove()
|
||||
@@ -266,6 +265,7 @@ class PaddedAlignAttWhisper:
|
||||
self.segments = []
|
||||
self.log_segments += 1
|
||||
|
||||
self.pending_incomplete_tokens = []
|
||||
|
||||
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
|
||||
if self.always_fire: return True
|
||||
@@ -327,7 +327,7 @@ class PaddedAlignAttWhisper:
|
||||
self.segments = self.segments[1:]
|
||||
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s")
|
||||
if len(self.tokens) > 1:
|
||||
self.context.append_token_ids(self.tokens[1][0,:])
|
||||
self.context.append_token_ids(self.tokens[1][0,:].tolist())
|
||||
self.tokens = [self.initial_tokens] + self.tokens[2:]
|
||||
return removed_len
|
||||
|
||||
@@ -567,6 +567,12 @@ class PaddedAlignAttWhisper:
|
||||
|
||||
tokens_to_split = current_tokens[0, token_len_before_decoding:]
|
||||
|
||||
# Prepend pending tokens from previous chunk if any
|
||||
if self.pending_incomplete_tokens:
|
||||
logger.debug(f"[UTF-8 Fix] Prepending {len(self.pending_incomplete_tokens)} pending tokens: {self.pending_incomplete_tokens}")
|
||||
pending_tensor = torch.tensor(self.pending_incomplete_tokens, dtype=torch.long, device=self.device)
|
||||
tokens_to_split = torch.cat([pending_tensor, tokens_to_split])
|
||||
|
||||
if fire_detected or is_last: #or punctuation_stop:
|
||||
new_hypothesis = tokens_to_split.flatten().tolist()
|
||||
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
|
||||
@@ -595,7 +601,14 @@ class PaddedAlignAttWhisper:
|
||||
|
||||
timestamped_words = []
|
||||
timestamp_idx = 0
|
||||
replacement_char = "\ufffd"
|
||||
for word, word_tokens in zip(split_words, split_tokens):
|
||||
# Skip words containing incomplete UTF-8 from client output
|
||||
if replacement_char in word:
|
||||
logger.warning(f"[UTF-8 Filter] Skipping incomplete word from client output: {repr(word)}")
|
||||
timestamp_idx += len(word_tokens)
|
||||
continue
|
||||
|
||||
try:
|
||||
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
||||
except:
|
||||
@@ -613,5 +626,11 @@ class PaddedAlignAttWhisper:
|
||||
self.global_time_offset
|
||||
)
|
||||
timestamped_words.append(timestamp_entry)
|
||||
|
||||
return timestamped_words
|
||||
|
||||
# Hold incomplete tokens for next chunk
|
||||
self.pending_incomplete_tokens = []
|
||||
if split_words and replacement_char in split_words[-1]:
|
||||
self.pending_incomplete_tokens = split_tokens[-1]
|
||||
logger.warning(f"[UTF-8 Fix] Holding {len(self.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.pending_incomplete_tokens}")
|
||||
|
||||
return timestamped_words
|
||||
|
||||
@@ -7,6 +7,7 @@ class TokenBuffer:
|
||||
self.prefix_token_ids = prefix_token_ids
|
||||
self.tokenizer = tokenizer
|
||||
self.device = device
|
||||
self.pending_token_ids = []
|
||||
|
||||
def as_token_ids(self, tokenizer=None):
|
||||
|
||||
@@ -64,7 +65,26 @@ class TokenBuffer:
|
||||
def append_token_ids(self, token_ids):
|
||||
tokenizer = self.tokenizer
|
||||
assert tokenizer is not None, "Tokenizer is not set."
|
||||
self.text += self.tokenizer.decode(token_ids)
|
||||
|
||||
all_tokens = self.pending_token_ids + token_ids
|
||||
|
||||
decoded = tokenizer.decode(all_tokens)
|
||||
replacement_char = "\ufffd"
|
||||
|
||||
if replacement_char in decoded:
|
||||
if len(all_tokens) > 1:
|
||||
decoded_partial = tokenizer.decode(all_tokens[:-1])
|
||||
|
||||
if replacement_char not in decoded_partial:
|
||||
self.text += decoded_partial
|
||||
self.pending_token_ids = [all_tokens[-1]]
|
||||
else:
|
||||
self.pending_token_ids = all_tokens
|
||||
else:
|
||||
self.pending_token_ids = all_tokens
|
||||
else:
|
||||
self.text += decoded
|
||||
self.pending_token_ids = []
|
||||
|
||||
def as_split_word_tokens(self):
|
||||
tokenizer = self.tokenizer
|
||||
|
||||
@@ -151,6 +151,7 @@ class FrontData():
|
||||
lines: list[Line] = field(default_factory=list)
|
||||
buffer_transcription: str = ''
|
||||
buffer_diarization: str = ''
|
||||
buffer_translation: str = ''
|
||||
remaining_time_transcription: float = 0.
|
||||
remaining_time_diarization: float = 0.
|
||||
|
||||
@@ -160,6 +161,7 @@ class FrontData():
|
||||
'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)],
|
||||
'buffer_transcription': self.buffer_transcription,
|
||||
'buffer_diarization': self.buffer_diarization,
|
||||
'buffer_translation': self.buffer_translation,
|
||||
'remaining_time_transcription': self.remaining_time_transcription,
|
||||
'remaining_time_diarization': self.remaining_time_diarization,
|
||||
}
|
||||
@@ -174,11 +176,14 @@ class ChangeSpeaker:
|
||||
|
||||
@dataclass
|
||||
class State():
|
||||
tokens: list
|
||||
last_validated_token: int
|
||||
translated_segments: list
|
||||
buffer_transcription: str
|
||||
end_buffer: float
|
||||
end_attributed_speaker: float
|
||||
remaining_time_transcription: float
|
||||
remaining_time_diarization: float
|
||||
tokens: list = field(default_factory=list)
|
||||
last_validated_token: int = 0
|
||||
last_punctuation_index: Optional[int] = None
|
||||
translation_validated_segments: list = field(default_factory=list)
|
||||
buffer_translation: str = field(default_factory=Transcript)
|
||||
buffer_transcription: str = field(default_factory=Transcript)
|
||||
end_buffer: float = 0.0
|
||||
end_attributed_speaker: float = 0.0
|
||||
remaining_time_transcription: float = 0.0
|
||||
remaining_time_diarization: float = 0.0
|
||||
beg_loop: Optional[int] = None
|
||||
|
||||
@@ -1,182 +0,0 @@
|
||||
"""
|
||||
adapted from https://store.crowdin.com/custom-mt
|
||||
"""
|
||||
|
||||
LANGUAGES = [
|
||||
{"name": "Afrikaans", "nllb": "afr_Latn", "crowdin": "af"},
|
||||
{"name": "Akan", "nllb": "aka_Latn", "crowdin": "ak"},
|
||||
{"name": "Amharic", "nllb": "amh_Ethi", "crowdin": "am"},
|
||||
{"name": "Assamese", "nllb": "asm_Beng", "crowdin": "as"},
|
||||
{"name": "Asturian", "nllb": "ast_Latn", "crowdin": "ast"},
|
||||
{"name": "Bashkir", "nllb": "bak_Cyrl", "crowdin": "ba"},
|
||||
{"name": "Bambara", "nllb": "bam_Latn", "crowdin": "bm"},
|
||||
{"name": "Balinese", "nllb": "ban_Latn", "crowdin": "ban"},
|
||||
{"name": "Belarusian", "nllb": "bel_Cyrl", "crowdin": "be"},
|
||||
{"name": "Bengali", "nllb": "ben_Beng", "crowdin": "bn"},
|
||||
{"name": "Bosnian", "nllb": "bos_Latn", "crowdin": "bs"},
|
||||
{"name": "Bulgarian", "nllb": "bul_Cyrl", "crowdin": "bg"},
|
||||
{"name": "Catalan", "nllb": "cat_Latn", "crowdin": "ca"},
|
||||
{"name": "Cebuano", "nllb": "ceb_Latn", "crowdin": "ceb"},
|
||||
{"name": "Czech", "nllb": "ces_Latn", "crowdin": "cs"},
|
||||
{"name": "Welsh", "nllb": "cym_Latn", "crowdin": "cy"},
|
||||
{"name": "Danish", "nllb": "dan_Latn", "crowdin": "da"},
|
||||
{"name": "German", "nllb": "deu_Latn", "crowdin": "de"},
|
||||
{"name": "Dzongkha", "nllb": "dzo_Tibt", "crowdin": "dz"},
|
||||
{"name": "Greek", "nllb": "ell_Grek", "crowdin": "el"},
|
||||
{"name": "English", "nllb": "eng_Latn", "crowdin": "en"},
|
||||
{"name": "Esperanto", "nllb": "epo_Latn", "crowdin": "eo"},
|
||||
{"name": "Estonian", "nllb": "est_Latn", "crowdin": "et"},
|
||||
{"name": "Basque", "nllb": "eus_Latn", "crowdin": "eu"},
|
||||
{"name": "Ewe", "nllb": "ewe_Latn", "crowdin": "ee"},
|
||||
{"name": "Faroese", "nllb": "fao_Latn", "crowdin": "fo"},
|
||||
{"name": "Fijian", "nllb": "fij_Latn", "crowdin": "fj"},
|
||||
{"name": "Finnish", "nllb": "fin_Latn", "crowdin": "fi"},
|
||||
{"name": "French", "nllb": "fra_Latn", "crowdin": "fr"},
|
||||
{"name": "Friulian", "nllb": "fur_Latn", "crowdin": "fur-IT"},
|
||||
{"name": "Scottish Gaelic", "nllb": "gla_Latn", "crowdin": "gd"},
|
||||
{"name": "Irish", "nllb": "gle_Latn", "crowdin": "ga-IE"},
|
||||
{"name": "Galician", "nllb": "glg_Latn", "crowdin": "gl"},
|
||||
{"name": "Guarani", "nllb": "grn_Latn", "crowdin": "gn"},
|
||||
{"name": "Gujarati", "nllb": "guj_Gujr", "crowdin": "gu-IN"},
|
||||
{"name": "Haitian Creole", "nllb": "hat_Latn", "crowdin": "ht"},
|
||||
{"name": "Hausa", "nllb": "hau_Latn", "crowdin": "ha"},
|
||||
{"name": "Hebrew", "nllb": "heb_Hebr", "crowdin": "he"},
|
||||
{"name": "Hindi", "nllb": "hin_Deva", "crowdin": "hi"},
|
||||
{"name": "Croatian", "nllb": "hrv_Latn", "crowdin": "hr"},
|
||||
{"name": "Hungarian", "nllb": "hun_Latn", "crowdin": "hu"},
|
||||
{"name": "Armenian", "nllb": "hye_Armn", "crowdin": "hy-AM"},
|
||||
{"name": "Igbo", "nllb": "ibo_Latn", "crowdin": "ig"},
|
||||
{"name": "Indonesian", "nllb": "ind_Latn", "crowdin": "id"},
|
||||
{"name": "Icelandic", "nllb": "isl_Latn", "crowdin": "is"},
|
||||
{"name": "Italian", "nllb": "ita_Latn", "crowdin": "it"},
|
||||
{"name": "Javanese", "nllb": "jav_Latn", "crowdin": "jv"},
|
||||
{"name": "Japanese", "nllb": "jpn_Jpan", "crowdin": "ja"},
|
||||
{"name": "Kabyle", "nllb": "kab_Latn", "crowdin": "kab"},
|
||||
{"name": "Kannada", "nllb": "kan_Knda", "crowdin": "kn"},
|
||||
{"name": "Georgian", "nllb": "kat_Geor", "crowdin": "ka"},
|
||||
{"name": "Kazakh", "nllb": "kaz_Cyrl", "crowdin": "kk"},
|
||||
{"name": "Khmer", "nllb": "khm_Khmr", "crowdin": "km"},
|
||||
{"name": "Kinyarwanda", "nllb": "kin_Latn", "crowdin": "rw"},
|
||||
{"name": "Kyrgyz", "nllb": "kir_Cyrl", "crowdin": "ky"},
|
||||
{"name": "Korean", "nllb": "kor_Hang", "crowdin": "ko"},
|
||||
{"name": "Lao", "nllb": "lao_Laoo", "crowdin": "lo"},
|
||||
{"name": "Ligurian", "nllb": "lij_Latn", "crowdin": "lij"},
|
||||
{"name": "Limburgish", "nllb": "lim_Latn", "crowdin": "li"},
|
||||
{"name": "Lingala", "nllb": "lin_Latn", "crowdin": "ln"},
|
||||
{"name": "Lithuanian", "nllb": "lit_Latn", "crowdin": "lt"},
|
||||
{"name": "Luxembourgish", "nllb": "ltz_Latn", "crowdin": "lb"},
|
||||
{"name": "Maithili", "nllb": "mai_Deva", "crowdin": "mai"},
|
||||
{"name": "Malayalam", "nllb": "mal_Mlym", "crowdin": "ml-IN"},
|
||||
{"name": "Marathi", "nllb": "mar_Deva", "crowdin": "mr"},
|
||||
{"name": "Macedonian", "nllb": "mkd_Cyrl", "crowdin": "mk"},
|
||||
{"name": "Maltese", "nllb": "mlt_Latn", "crowdin": "mt"},
|
||||
{"name": "Mossi", "nllb": "mos_Latn", "crowdin": "mos"},
|
||||
{"name": "Maori", "nllb": "mri_Latn", "crowdin": "mi"},
|
||||
{"name": "Burmese", "nllb": "mya_Mymr", "crowdin": "my"},
|
||||
{"name": "Dutch", "nllb": "nld_Latn", "crowdin": "nl"},
|
||||
{"name": "Norwegian Nynorsk", "nllb": "nno_Latn", "crowdin": "nn-NO"},
|
||||
{"name": "Nepali", "nllb": "npi_Deva", "crowdin": "ne-NP"},
|
||||
{"name": "Northern Sotho", "nllb": "nso_Latn", "crowdin": "nso"},
|
||||
{"name": "Occitan", "nllb": "oci_Latn", "crowdin": "oc"},
|
||||
{"name": "Odia", "nllb": "ory_Orya", "crowdin": "or"},
|
||||
{"name": "Papiamento", "nllb": "pap_Latn", "crowdin": "pap"},
|
||||
{"name": "Polish", "nllb": "pol_Latn", "crowdin": "pl"},
|
||||
{"name": "Portuguese", "nllb": "por_Latn", "crowdin": "pt-PT"},
|
||||
{"name": "Dari", "nllb": "prs_Arab", "crowdin": "fa-AF"},
|
||||
{"name": "Romanian", "nllb": "ron_Latn", "crowdin": "ro"},
|
||||
{"name": "Rundi", "nllb": "run_Latn", "crowdin": "rn"},
|
||||
{"name": "Russian", "nllb": "rus_Cyrl", "crowdin": "ru"},
|
||||
{"name": "Sango", "nllb": "sag_Latn", "crowdin": "sg"},
|
||||
{"name": "Sanskrit", "nllb": "san_Deva", "crowdin": "sa"},
|
||||
{"name": "Santali", "nllb": "sat_Olck", "crowdin": "sat"},
|
||||
{"name": "Sinhala", "nllb": "sin_Sinh", "crowdin": "si-LK"},
|
||||
{"name": "Slovak", "nllb": "slk_Latn", "crowdin": "sk"},
|
||||
{"name": "Slovenian", "nllb": "slv_Latn", "crowdin": "sl"},
|
||||
{"name": "Shona", "nllb": "sna_Latn", "crowdin": "sn"},
|
||||
{"name": "Sindhi", "nllb": "snd_Arab", "crowdin": "sd"},
|
||||
{"name": "Somali", "nllb": "som_Latn", "crowdin": "so"},
|
||||
{"name": "Southern Sotho", "nllb": "sot_Latn", "crowdin": "st"},
|
||||
{"name": "Spanish", "nllb": "spa_Latn", "crowdin": "es-ES"},
|
||||
{"name": "Sardinian", "nllb": "srd_Latn", "crowdin": "sc"},
|
||||
{"name": "Swati", "nllb": "ssw_Latn", "crowdin": "ss"},
|
||||
{"name": "Sundanese", "nllb": "sun_Latn", "crowdin": "su"},
|
||||
{"name": "Swedish", "nllb": "swe_Latn", "crowdin": "sv-SE"},
|
||||
{"name": "Swahili", "nllb": "swh_Latn", "crowdin": "sw"},
|
||||
{"name": "Tamil", "nllb": "tam_Taml", "crowdin": "ta"},
|
||||
{"name": "Tatar", "nllb": "tat_Cyrl", "crowdin": "tt-RU"},
|
||||
{"name": "Telugu", "nllb": "tel_Telu", "crowdin": "te"},
|
||||
{"name": "Tajik", "nllb": "tgk_Cyrl", "crowdin": "tg"},
|
||||
{"name": "Tagalog", "nllb": "tgl_Latn", "crowdin": "tl"},
|
||||
{"name": "Thai", "nllb": "tha_Thai", "crowdin": "th"},
|
||||
{"name": "Tigrinya", "nllb": "tir_Ethi", "crowdin": "ti"},
|
||||
{"name": "Tswana", "nllb": "tsn_Latn", "crowdin": "tn"},
|
||||
{"name": "Tsonga", "nllb": "tso_Latn", "crowdin": "ts"},
|
||||
{"name": "Turkmen", "nllb": "tuk_Latn", "crowdin": "tk"},
|
||||
{"name": "Turkish", "nllb": "tur_Latn", "crowdin": "tr"},
|
||||
{"name": "Uyghur", "nllb": "uig_Arab", "crowdin": "ug"},
|
||||
{"name": "Ukrainian", "nllb": "ukr_Cyrl", "crowdin": "uk"},
|
||||
{"name": "Venetian", "nllb": "vec_Latn", "crowdin": "vec"},
|
||||
{"name": "Vietnamese", "nllb": "vie_Latn", "crowdin": "vi"},
|
||||
{"name": "Wolof", "nllb": "wol_Latn", "crowdin": "wo"},
|
||||
{"name": "Xhosa", "nllb": "xho_Latn", "crowdin": "xh"},
|
||||
{"name": "Yoruba", "nllb": "yor_Latn", "crowdin": "yo"},
|
||||
{"name": "Zulu", "nllb": "zul_Latn", "crowdin": "zu"},
|
||||
]
|
||||
|
||||
NAME_TO_NLLB = {lang["name"]: lang["nllb"] for lang in LANGUAGES}
|
||||
NAME_TO_CROWDIN = {lang["name"]: lang["crowdin"] for lang in LANGUAGES}
|
||||
CROWDIN_TO_NLLB = {lang["crowdin"]: lang["nllb"] for lang in LANGUAGES}
|
||||
NLLB_TO_CROWDIN = {lang["nllb"]: lang["crowdin"] for lang in LANGUAGES}
|
||||
CROWDIN_TO_NAME = {lang["crowdin"]: lang["name"] for lang in LANGUAGES}
|
||||
NLLB_TO_NAME = {lang["nllb"]: lang["name"] for lang in LANGUAGES}
|
||||
|
||||
|
||||
def get_nllb_code(crowdin_code):
|
||||
return CROWDIN_TO_NLLB.get(crowdin_code, None)
|
||||
|
||||
|
||||
def get_crowdin_code(nllb_code):
|
||||
return NLLB_TO_CROWDIN.get(nllb_code)
|
||||
|
||||
|
||||
def get_language_name_by_crowdin(crowdin_code):
|
||||
return CROWDIN_TO_NAME.get(crowdin_code)
|
||||
|
||||
|
||||
def get_language_name_by_nllb(nllb_code):
|
||||
return NLLB_TO_NAME.get(nllb_code)
|
||||
|
||||
|
||||
def get_language_info(identifier, identifier_type="auto"):
|
||||
if identifier_type == "auto":
|
||||
for lang in LANGUAGES:
|
||||
if (lang["name"].lower() == identifier.lower() or
|
||||
lang["nllb"] == identifier or
|
||||
lang["crowdin"] == identifier):
|
||||
return lang
|
||||
elif identifier_type == "name":
|
||||
for lang in LANGUAGES:
|
||||
if lang["name"].lower() == identifier.lower():
|
||||
return lang
|
||||
elif identifier_type == "nllb":
|
||||
for lang in LANGUAGES:
|
||||
if lang["nllb"] == identifier:
|
||||
return lang
|
||||
elif identifier_type == "crowdin":
|
||||
for lang in LANGUAGES:
|
||||
if lang["crowdin"] == identifier:
|
||||
return lang
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def list_all_languages():
|
||||
return [lang["name"] for lang in LANGUAGES]
|
||||
|
||||
|
||||
def list_all_nllb_codes():
|
||||
return [lang["nllb"] for lang in LANGUAGES]
|
||||
|
||||
|
||||
def list_all_crowdin_codes():
|
||||
return [lang["crowdin"] for lang in LANGUAGES]
|
||||
@@ -1,169 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
import ctranslate2
|
||||
import torch
|
||||
import transformers
|
||||
from dataclasses import dataclass, field
|
||||
import huggingface_hub
|
||||
from whisperlivekit.translation.mapping_languages import get_nllb_code
|
||||
from whisperlivekit.timed_objects import Translation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
#In diarization case, we may want to translate just one speaker, or at least start the sentences there
|
||||
|
||||
MIN_SILENCE_DURATION_DEL_BUFFER = 3 #After a silence of x seconds, we consider the model should not use the buffer, even if the previous
|
||||
# sentence is not finished.
|
||||
|
||||
@dataclass
|
||||
class TranslationModel():
|
||||
translator: ctranslate2.Translator
|
||||
device: str
|
||||
tokenizer: dict = field(default_factory=dict)
|
||||
backend_type: str = 'ctranslate2'
|
||||
nllb_size: str = '600M'
|
||||
|
||||
def get_tokenizer(self, input_lang):
|
||||
if not self.tokenizer.get(input_lang, False):
|
||||
self.tokenizer[input_lang] = transformers.AutoTokenizer.from_pretrained(
|
||||
f"facebook/nllb-200-distilled-{self.nllb_size}",
|
||||
src_lang=input_lang,
|
||||
clean_up_tokenization_spaces=True
|
||||
)
|
||||
return self.tokenizer[input_lang]
|
||||
|
||||
|
||||
def load_model(src_langs, nllb_backend='ctranslate2', nllb_size='600M'):
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
MODEL = f'nllb-200-distilled-{nllb_size}-ctranslate2'
|
||||
if nllb_backend=='ctranslate2':
|
||||
MODEL_GUY = 'entai2965'
|
||||
huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL)
|
||||
translator = ctranslate2.Translator(MODEL,device=device)
|
||||
elif nllb_backend=='transformers':
|
||||
translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(f"facebook/nllb-200-distilled-{nllb_size}")
|
||||
tokenizer = dict()
|
||||
for src_lang in src_langs:
|
||||
if src_lang != 'auto':
|
||||
tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
|
||||
|
||||
translation_model = TranslationModel(
|
||||
translator=translator,
|
||||
tokenizer=tokenizer,
|
||||
backend_type=nllb_backend,
|
||||
device = device,
|
||||
nllb_size = nllb_size
|
||||
)
|
||||
for src_lang in src_langs:
|
||||
if src_lang != 'auto':
|
||||
translation_model.get_tokenizer(src_lang)
|
||||
return translation_model
|
||||
|
||||
class OnlineTranslation:
|
||||
def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
|
||||
self.buffer = []
|
||||
self.len_processed_buffer = 0
|
||||
self.translation_remaining = Translation()
|
||||
self.validated = []
|
||||
self.translation_pending_validation = ''
|
||||
self.translation_model = translation_model
|
||||
self.input_languages = input_languages
|
||||
self.output_languages = output_languages
|
||||
|
||||
def compute_common_prefix(self, results):
|
||||
#we dont want want to prune the result for the moment.
|
||||
if not self.buffer:
|
||||
self.buffer = results
|
||||
else:
|
||||
for i in range(min(len(self.buffer), len(results))):
|
||||
if self.buffer[i] != results[i]:
|
||||
self.commited.extend(self.buffer[:i])
|
||||
self.buffer = results[i:]
|
||||
|
||||
def translate(self, input, input_lang, output_lang):
|
||||
if not input:
|
||||
return ""
|
||||
nllb_output_lang = get_nllb_code(output_lang)
|
||||
|
||||
tokenizer = self.translation_model.get_tokenizer(input_lang)
|
||||
tokenizer_output = tokenizer(input, return_tensors="pt").to(self.translation_model.device)
|
||||
|
||||
if self.translation_model.backend_type == 'ctranslate2':
|
||||
source = tokenizer.convert_ids_to_tokens(tokenizer_output['input_ids'][0])
|
||||
results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]])
|
||||
target = results[0].hypotheses[0][1:]
|
||||
result = tokenizer.decode(tokenizer.convert_tokens_to_ids(target))
|
||||
else:
|
||||
translated_tokens = self.translation_model.translator.generate(**tokenizer_output, forced_bos_token_id=tokenizer.convert_tokens_to_ids(nllb_output_lang))
|
||||
result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
|
||||
return result
|
||||
|
||||
def translate_tokens(self, tokens):
|
||||
if tokens:
|
||||
text = ' '.join([token.text for token in tokens])
|
||||
start = tokens[0].start
|
||||
end = tokens[-1].end
|
||||
if self.input_languages[0] == 'auto':
|
||||
input_lang = tokens[0].detected_language
|
||||
else:
|
||||
input_lang = self.input_languages[0]
|
||||
|
||||
translated_text = self.translate(text,
|
||||
input_lang,
|
||||
self.output_languages[0]
|
||||
)
|
||||
translation = Translation(
|
||||
text=translated_text,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
return translation
|
||||
return None
|
||||
|
||||
|
||||
def insert_tokens(self, tokens):
|
||||
self.buffer.extend(tokens)
|
||||
pass
|
||||
|
||||
def process(self):
|
||||
i = 0
|
||||
if len(self.buffer) < self.len_processed_buffer + 3: #nothing new to process
|
||||
return self.validated + [self.translation_remaining]
|
||||
while i < len(self.buffer):
|
||||
if self.buffer[i].is_punctuation():
|
||||
translation_sentence = self.translate_tokens(self.buffer[:i+1])
|
||||
self.validated.append(translation_sentence)
|
||||
self.buffer = self.buffer[i+1:]
|
||||
i = 0
|
||||
else:
|
||||
i+=1
|
||||
self.translation_remaining = self.translate_tokens(self.buffer)
|
||||
self.len_processed_buffer = len(self.buffer)
|
||||
return self.validated + [self.translation_remaining]
|
||||
|
||||
def insert_silence(self, silence_duration: float):
|
||||
if silence_duration >= MIN_SILENCE_DURATION_DEL_BUFFER:
|
||||
self.buffer = []
|
||||
self.validated += [self.translation_remaining]
|
||||
|
||||
if __name__ == '__main__':
|
||||
output_lang = 'fr'
|
||||
input_lang = "en"
|
||||
|
||||
|
||||
test_string = """
|
||||
Transcription technology has improved so much in the past few years. Have you noticed how accurate real-time speech-to-text is now?
|
||||
"""
|
||||
test = test_string.split(' ')
|
||||
step = len(test) // 3
|
||||
|
||||
shared_model = load_model([input_lang], nllb_backend='ctranslate2')
|
||||
online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang])
|
||||
|
||||
beg_inference = time.time()
|
||||
for id in range(5):
|
||||
val = test[id*step : (id+1)*step]
|
||||
val_str = ' '.join(val)
|
||||
result = online_translation.translate(val_str)
|
||||
print(result)
|
||||
print('inference time:', time.time() - beg_inference)
|
||||
BIN
whisperlivekit/vad_models/silero_vad.jit
Normal file
BIN
whisperlivekit/vad_models/silero_vad.jit
Normal file
Binary file not shown.
BIN
whisperlivekit/vad_models/silero_vad.onnx
Normal file
BIN
whisperlivekit/vad_models/silero_vad.onnx
Normal file
Binary file not shown.
BIN
whisperlivekit/vad_models/silero_vad_16k_op15.onnx
Normal file
BIN
whisperlivekit/vad_models/silero_vad_16k_op15.onnx
Normal file
Binary file not shown.
BIN
whisperlivekit/vad_models/silero_vad_half.onnx
Normal file
BIN
whisperlivekit/vad_models/silero_vad_half.onnx
Normal file
Binary file not shown.
@@ -490,6 +490,11 @@ label {
|
||||
margin-left: 4px;
|
||||
}
|
||||
|
||||
.buffer_translation {
|
||||
color: #a0a0a0;
|
||||
margin-left: 6px;
|
||||
}
|
||||
|
||||
.spinner {
|
||||
display: inline-block;
|
||||
width: 8px;
|
||||
|
||||
@@ -232,10 +232,11 @@ function setupWebSocket() {
|
||||
if (waitingForStop) {
|
||||
statusText.textContent = "Processing finalized or connection closed.";
|
||||
if (lastReceivedData) {
|
||||
renderLinesWithBuffer(
|
||||
renderLinesWithBuffer(
|
||||
lastReceivedData.lines || [],
|
||||
lastReceivedData.buffer_diarization || "",
|
||||
lastReceivedData.buffer_transcription || "",
|
||||
lastReceivedData.buffer_translation || "",
|
||||
0,
|
||||
0,
|
||||
true
|
||||
@@ -281,6 +282,7 @@ function setupWebSocket() {
|
||||
lastReceivedData.lines || [],
|
||||
lastReceivedData.buffer_diarization || "",
|
||||
lastReceivedData.buffer_transcription || "",
|
||||
lastReceivedData.buffer_translation || "",
|
||||
0,
|
||||
0,
|
||||
true
|
||||
@@ -301,6 +303,7 @@ function setupWebSocket() {
|
||||
lines = [],
|
||||
buffer_transcription = "",
|
||||
buffer_diarization = "",
|
||||
buffer_translation = "",
|
||||
remaining_time_transcription = 0,
|
||||
remaining_time_diarization = 0,
|
||||
status = "active_transcription",
|
||||
@@ -310,6 +313,7 @@ function setupWebSocket() {
|
||||
lines,
|
||||
buffer_diarization,
|
||||
buffer_transcription,
|
||||
buffer_translation,
|
||||
remaining_time_diarization,
|
||||
remaining_time_transcription,
|
||||
false,
|
||||
@@ -323,6 +327,7 @@ function renderLinesWithBuffer(
|
||||
lines,
|
||||
buffer_diarization,
|
||||
buffer_transcription,
|
||||
buffer_translation,
|
||||
remaining_time_diarization,
|
||||
remaining_time_transcription,
|
||||
isFinalizing = false,
|
||||
@@ -341,6 +346,7 @@ function renderLinesWithBuffer(
|
||||
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end, detected_language: it.detected_language })),
|
||||
buffer_transcription: buffer_transcription || "",
|
||||
buffer_diarization: buffer_diarization || "",
|
||||
buffer_translation: buffer_translation,
|
||||
status: current_status,
|
||||
showLoading,
|
||||
showTransLag,
|
||||
@@ -415,13 +421,22 @@ function renderLinesWithBuffer(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let translationContent = "";
|
||||
if (item.translation) {
|
||||
translationContent += item.translation.trim();
|
||||
}
|
||||
if (idx === lines.length - 1 && buffer_translation) {
|
||||
const bufferPiece = isFinalizing
|
||||
? buffer_translation
|
||||
: `<span class="buffer_translation">${buffer_translation}</span>`;
|
||||
translationContent += translationContent ? `${bufferPiece}` : bufferPiece;
|
||||
}
|
||||
if (translationContent.trim().length > 0) {
|
||||
currentLineText += `
|
||||
<div>
|
||||
<div class="label_translation">
|
||||
${translationIcon}
|
||||
<span>${item.translation}</span>
|
||||
<span class="translation_text">${translationContent}</span>
|
||||
</div>
|
||||
</div>`;
|
||||
}
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Union, Dict
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
|
||||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||
@@ -100,6 +102,137 @@ def available_models() -> List[str]:
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
|
||||
"""
|
||||
attempt to infer ModelDimensions from a HF style config.json located
|
||||
next to the given checkpoint, usefull for distilled models
|
||||
"""
|
||||
candidates = []
|
||||
if os.path.isdir(path):
|
||||
candidates.append(os.path.join(path, "config.json"))
|
||||
else:
|
||||
candidates.append(os.path.join(os.path.dirname(path), "config.json"))
|
||||
|
||||
for candidate in candidates:
|
||||
if not os.path.isfile(candidate):
|
||||
continue
|
||||
with open(candidate, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
|
||||
try:
|
||||
return ModelDimensions(
|
||||
n_mels=config["num_mel_bins"],
|
||||
n_audio_ctx=config["max_source_positions"],
|
||||
n_audio_state=config["d_model"],
|
||||
n_audio_head=config["encoder_attention_heads"],
|
||||
n_audio_layer=config.get("encoder_layers")
|
||||
or config["num_hidden_layers"],
|
||||
n_vocab=config["vocab_size"],
|
||||
n_text_ctx=config["max_target_positions"],
|
||||
n_text_state=config["d_model"],
|
||||
n_text_head=config["decoder_attention_heads"],
|
||||
n_text_layer=config["decoder_layers"],
|
||||
)
|
||||
except KeyError as err:
|
||||
warnings.warn(f"Missing key {err} in HuggingFace config {candidate}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
converts a HF checkpoint state_dict into the naming convention used by
|
||||
default whisper
|
||||
"""
|
||||
|
||||
if not any(k.startswith("model.") for k in state_dict):
|
||||
return state_dict
|
||||
|
||||
def map_block(prefix: str, target_prefix: str, remainder: str) -> Optional[str]:
|
||||
if remainder.startswith("self_attn."):
|
||||
suffix = remainder.split(".", 1)[1]
|
||||
mapping = {
|
||||
"q_proj": "attn.query",
|
||||
"k_proj": "attn.key",
|
||||
"v_proj": "attn.value",
|
||||
"out_proj": "attn.out",
|
||||
}
|
||||
stem = mapping.get(suffix.split(".")[0])
|
||||
if stem:
|
||||
rest = suffix.split(".", 1)[1] if "." in suffix else ""
|
||||
return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
|
||||
elif remainder == "self_attn_layer_norm.weight":
|
||||
return f"{target_prefix}.attn_ln.weight"
|
||||
elif remainder == "self_attn_layer_norm.bias":
|
||||
return f"{target_prefix}.attn_ln.bias"
|
||||
elif remainder.startswith("encoder_attn."):
|
||||
suffix = remainder.split(".", 1)[1]
|
||||
mapping = {
|
||||
"q_proj": "cross_attn.query",
|
||||
"k_proj": "cross_attn.key",
|
||||
"v_proj": "cross_attn.value",
|
||||
"out_proj": "cross_attn.out",
|
||||
}
|
||||
stem = mapping.get(suffix.split(".", 1)[0])
|
||||
if stem:
|
||||
rest = suffix.split(".", 1)[1] if "." in suffix else ""
|
||||
return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
|
||||
elif remainder == "encoder_attn_layer_norm.weight":
|
||||
return f"{target_prefix}.cross_attn_ln.weight"
|
||||
elif remainder == "encoder_attn_layer_norm.bias":
|
||||
return f"{target_prefix}.cross_attn_ln.bias"
|
||||
elif remainder.startswith("fc1."):
|
||||
return f"{target_prefix}.mlp.0.{remainder.split('.',1)[1]}"
|
||||
elif remainder.startswith("fc2."):
|
||||
return f"{target_prefix}.mlp.2.{remainder.split('.',1)[1]}"
|
||||
elif remainder == "final_layer_norm.weight":
|
||||
return f"{target_prefix}.mlp_ln.weight"
|
||||
elif remainder == "final_layer_norm.bias":
|
||||
return f"{target_prefix}.mlp_ln.bias"
|
||||
return None
|
||||
|
||||
converted = {}
|
||||
for key, value in state_dict.items():
|
||||
if not key.startswith("model."):
|
||||
continue
|
||||
subkey = key[len("model.") :]
|
||||
|
||||
if subkey.startswith("encoder.layers."):
|
||||
parts = subkey.split(".")
|
||||
layer_idx = parts[2]
|
||||
remainder = ".".join(parts[3:])
|
||||
mapped = map_block(subkey, f"encoder.blocks.{layer_idx}", remainder)
|
||||
elif subkey.startswith("decoder.layers."):
|
||||
parts = subkey.split(".")
|
||||
layer_idx = parts[2]
|
||||
remainder = ".".join(parts[3:])
|
||||
mapped = map_block(subkey, f"decoder.blocks.{layer_idx}", remainder)
|
||||
elif subkey.startswith("encoder.conv") or subkey.startswith("decoder.conv"):
|
||||
mapped = subkey
|
||||
elif subkey == "encoder.embed_positions.weight":
|
||||
mapped = "encoder.positional_embedding"
|
||||
elif subkey == "decoder.embed_positions.weight":
|
||||
mapped = "decoder.positional_embedding"
|
||||
elif subkey == "encoder.layer_norm.weight":
|
||||
mapped = "encoder.ln_post.weight"
|
||||
elif subkey == "encoder.layer_norm.bias":
|
||||
mapped = "encoder.ln_post.bias"
|
||||
elif subkey.startswith("decoder.embed_tokens."):
|
||||
mapped = subkey.replace("embed_tokens", "token_embedding", 1)
|
||||
elif subkey == "decoder.layer_norm.weight":
|
||||
mapped = "decoder.ln.weight"
|
||||
elif subkey == "decoder.layer_norm.bias":
|
||||
mapped = "decoder.ln.bias"
|
||||
else:
|
||||
mapped = None
|
||||
|
||||
if mapped:
|
||||
converted[mapped] = value
|
||||
|
||||
return converted if converted else state_dict
|
||||
|
||||
|
||||
def load_model(
|
||||
name: str,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
@@ -134,7 +267,6 @@ def load_model(
|
||||
if download_root is None:
|
||||
default = os.path.join(os.path.expanduser("~"), ".cache")
|
||||
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
|
||||
|
||||
if name in _MODELS:
|
||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||
elif os.path.isfile(name):
|
||||
@@ -148,22 +280,50 @@ def load_model(
|
||||
if custom_alignment_heads:
|
||||
alignment_heads = custom_alignment_heads.encode()
|
||||
|
||||
with (
|
||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
||||
) as fp:
|
||||
checkpoint = torch.load(fp, map_location=device)
|
||||
if isinstance(checkpoint_file, Path) and checkpoint_file.suffix == '.safetensors':
|
||||
try:
|
||||
from safetensors.torch import load_file
|
||||
except ImportError:
|
||||
raise ImportError("Please install safetensors to load .safetensors model files: `pip install safetensors`")
|
||||
if in_memory:
|
||||
checkpoint = load_file(checkpoint_file, device=device)
|
||||
else:
|
||||
checkpoint = load_file(checkpoint_file, device=device)
|
||||
else:
|
||||
with (
|
||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
||||
) as fp:
|
||||
checkpoint = torch.load(fp, map_location=device)
|
||||
del checkpoint_file
|
||||
|
||||
dims = ModelDimensions(**checkpoint["dims"])
|
||||
dims_cfg = checkpoint.get("dims") if isinstance(checkpoint, dict) else None
|
||||
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
|
||||
state_dict = checkpoint["model_state_dict"]
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
state_dict = _convert_hf_state_dict(state_dict)
|
||||
|
||||
if dims_cfg is not None:
|
||||
dims = ModelDimensions(**dims_cfg)
|
||||
else:
|
||||
dims = _infer_dims_from_config(name)
|
||||
if dims is None:
|
||||
raise RuntimeError(
|
||||
"Could not determine model dimensions. "
|
||||
"Ensure the checkpoint includes 'dims' or a HuggingFace config.json is present."
|
||||
)
|
||||
if not isinstance(state_dict, dict):
|
||||
state_dict = checkpoint
|
||||
|
||||
model = Whisper(dims, decoder_only=decoder_only)
|
||||
|
||||
if decoder_only:
|
||||
checkpoint["model_state_dict"] = {
|
||||
k: v for k, v in checkpoint["model_state_dict"].items()
|
||||
state_dict = {
|
||||
k: v for k, v in state_dict.items()
|
||||
if 'encoder' not in k
|
||||
}
|
||||
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
if alignment_heads is not None:
|
||||
model.set_alignment_heads(alignment_heads)
|
||||
@@ -1,122 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import sys
|
||||
import numpy as np
|
||||
import librosa
|
||||
from functools import lru_cache
|
||||
import time
|
||||
import logging
|
||||
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
|
||||
from whisperlivekit.warmup import warmup_asr
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
|
||||
","
|
||||
)
|
||||
|
||||
|
||||
def create_tokenizer(lan):
|
||||
"""returns an object that has split function that works like the one of MosesTokenizer"""
|
||||
|
||||
assert (
|
||||
lan in WHISPER_LANG_CODES
|
||||
), "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
|
||||
|
||||
if lan == "uk":
|
||||
import tokenize_uk
|
||||
|
||||
class UkrainianTokenizer:
|
||||
def split(self, text):
|
||||
return tokenize_uk.tokenize_sents(text)
|
||||
|
||||
return UkrainianTokenizer()
|
||||
|
||||
# supported by fast-mosestokenizer
|
||||
if (
|
||||
lan
|
||||
in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
|
||||
):
|
||||
from mosestokenizer import MosesSentenceSplitter
|
||||
|
||||
return MosesSentenceSplitter(lan)
|
||||
|
||||
# the following languages are in Whisper, but not in wtpsplit:
|
||||
if (
|
||||
lan
|
||||
in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
|
||||
):
|
||||
logger.debug(
|
||||
f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
|
||||
)
|
||||
lan = None
|
||||
|
||||
from wtpsplit import WtP
|
||||
|
||||
# downloads the model from huggingface on the first use
|
||||
wtp = WtP("wtp-canine-s-12l-no-adapters")
|
||||
|
||||
class WtPtok:
|
||||
def split(self, sent):
|
||||
return wtp.split(sent, lang_code=lan)
|
||||
|
||||
return WtPtok()
|
||||
|
||||
|
||||
def backend_factory(
|
||||
backend,
|
||||
lan,
|
||||
model_size,
|
||||
model_cache_dir,
|
||||
model_dir,
|
||||
task,
|
||||
buffer_trimming,
|
||||
buffer_trimming_sec,
|
||||
confidence_validation,
|
||||
warmup_file=None,
|
||||
min_chunk_size=None,
|
||||
):
|
||||
backend = backend
|
||||
if backend == "openai-api":
|
||||
logger.debug("Using OpenAI API.")
|
||||
asr = OpenaiApiASR(lan=lan)
|
||||
else:
|
||||
if backend == "faster-whisper":
|
||||
asr_cls = FasterWhisperASR
|
||||
elif backend == "mlx-whisper":
|
||||
asr_cls = MLXWhisper
|
||||
else:
|
||||
asr_cls = WhisperTimestampedASR
|
||||
|
||||
# Only for FasterWhisperASR and WhisperTimestampedASR
|
||||
|
||||
t = time.time()
|
||||
logger.info(f"Loading Whisper {model_size} model for language {lan}...")
|
||||
asr = asr_cls(
|
||||
model_size=model_size,
|
||||
lan=lan,
|
||||
cache_dir=model_cache_dir,
|
||||
model_dir=model_dir,
|
||||
)
|
||||
e = time.time()
|
||||
logger.info(f"done. It took {round(e-t,2)} seconds.")
|
||||
|
||||
if task == "translate":
|
||||
tgt_language = "en" # Whisper translates into English
|
||||
else:
|
||||
tgt_language = lan # Whisper transcribes in this language
|
||||
|
||||
# Create the tokenizer
|
||||
if buffer_trimming == "sentence":
|
||||
tokenizer = create_tokenizer(tgt_language)
|
||||
else:
|
||||
tokenizer = None
|
||||
|
||||
warmup_asr(asr, warmup_file)
|
||||
|
||||
asr.confidence_validation = confidence_validation
|
||||
asr.tokenizer = tokenizer
|
||||
asr.buffer_trimming = buffer_trimming
|
||||
asr.buffer_trimming_sec = buffer_trimming_sec
|
||||
return asr
|
||||
Reference in New Issue
Block a user