mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bcffdbc6b3 | ||
|
|
80b77998f9 | ||
|
|
d310f7e25f | ||
|
|
8d9be88fe6 | ||
|
|
16461052ed | ||
|
|
5491dbd824 | ||
|
|
13401ffe24 | ||
|
|
7108d2ddc5 | ||
|
|
a732e0903e | ||
|
|
0491681be4 | ||
|
|
ffe5284764 | ||
|
|
41ca17acda | ||
|
|
06b31f51eb | ||
|
|
ece02db6a3 | ||
|
|
939a7ebf8b | ||
|
|
61edb70fff | ||
|
|
4e455b8aab | ||
|
|
9434390ad3 | ||
|
|
65250db92c | ||
|
|
416dce7975 | ||
|
|
0c5365e7c6 | ||
|
|
19e9d76610 | ||
|
|
e7b05b0138 | ||
|
|
818c9c37ca | ||
|
|
714fb3b14a | ||
|
|
0af379c465 | ||
|
|
9c5bb5df19 | ||
|
|
dc6ea79036 | ||
|
|
21bbb59e31 | ||
|
|
3467109668 |
18
.gitignore
vendored
18
.gitignore
vendored
@@ -54,21 +54,6 @@ coverage.xml
|
|||||||
# Translations
|
# Translations
|
||||||
*.mo
|
*.mo
|
||||||
*.pot
|
*.pot
|
||||||
# Django stuff:
|
|
||||||
*.log
|
|
||||||
local_settings.py
|
|
||||||
db.sqlite3
|
|
||||||
db.sqlite3-journal
|
|
||||||
|
|
||||||
# Flask stuff:
|
|
||||||
instance/
|
|
||||||
.webassets-cache
|
|
||||||
|
|
||||||
# Scrapy stuff:
|
|
||||||
.scrapy
|
|
||||||
|
|
||||||
# Sphinx documentation
|
|
||||||
docs/_build/
|
|
||||||
|
|
||||||
# PyBuilder
|
# PyBuilder
|
||||||
target/
|
target/
|
||||||
@@ -138,4 +123,5 @@ test_*.py
|
|||||||
launch.json
|
launch.json
|
||||||
.DS_Store
|
.DS_Store
|
||||||
test/*
|
test/*
|
||||||
nllb-200-distilled-600M-ctranslate2/*
|
nllb-200-distilled-600M-ctranslate2/*
|
||||||
|
*.mp3
|
||||||
226
LICENSE
226
LICENSE
@@ -1,52 +1,210 @@
|
|||||||
# License
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
## Main Software License
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
MIT License
|
1. Definitions.
|
||||||
|
|
||||||
Copyright (c) 2025 Quentin Fuxa.
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
the copyright owner that is granting the License.
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
copies or substantial portions of the Software.
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
exercising permissions granted by this License.
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
||||||
|
|
||||||
## SimulStreaming Backend License
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
**When using the SimulStreaming backend (SimulWhisper), additional licensing terms apply:**
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
SimulStreaming (https://github.com/ufal/SimulStreaming) is dual-licensed:
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
### 🔹 Non-Commercial Use
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
You may use SimulStreaming under the **PolyForm Noncommercial License 1.0.0** if you obtain the code through the GitHub repository. This license is **free of charge** and comes with **no obligations** for non-commercial users.
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
### 🔸 Commercial Use
|
"Contribution" shall mean any work of authorship, including
|
||||||
Understanding who uses SimulStreaming commercially helps improve and prioritize development. Therefore, **registration is required** for those who acquire a commercial license.
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
Commercial licenses are planned to be **affordable** to SMEs and individuals. They are considering providing commercial licenses either for free or for a symbolic one-time fee, and may also provide additional support. You can share your preference via the [questionnaire](https://forms.cloud.microsoft.com/e/7tCxb4gJfB).
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
You can also leave your contact [there](https://forms.cloud.microsoft.com/e/7tCxb4gJfB) to be notified when commercial licenses become available.
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
**Contact for SimulStreaming licensing:**
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
[Dominik Macháček](https://ufal.mff.cuni.cz/dominik-machacek/), machacek@ufal.mff.cuni.cz
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright 2025 Quentin Fuxa
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
---
|
---
|
||||||
|
|
||||||
## Based on:
|
## Based on:
|
||||||
- **whisper_streaming** by ÚFAL – MIT License – https://github.com/ufal/whisper_streaming. The original work by ÚFAL. License: https://github.com/ufal/whisper_streaming/blob/main/LICENSE
|
- **SimulWhisper** by Speech and Audio Technology LAB of Tsinghua University – Apache-2.0 – https://github.com/backspacetg/simul_whisper
|
||||||
- **silero-vad** by Snakers4 – MIT License – https://github.com/snakers4/silero-vad. The work by Snakers4 (silero-vad). License: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
|
- **SimulStreaming** by ÚFAL – MIT License – https://github.com/ufal/SimulStreaming
|
||||||
- **Diart** by juanmc2005 – MIT License – https://github.com/juanmc2005/diart. The work in Diart by juanmc2005. License: https://github.com/juanmc2005/diart/blob/main/LICENSE
|
- **NeMo** by NVIDIA – Apache-2.0 – https://github.com/NVIDIA-NeMo/NeMo
|
||||||
- **SimulStreaming** by ÚFAL – Dual License (PolyForm Noncommercial License 1.0.0 / Commercial License) – https://github.com/ufal/SimulStreaming
|
- **whisper_streaming** by ÚFAL – MIT License – https://github.com/ufal/whisper_streaming.
|
||||||
|
- **silero-vad** by Snakers4 – MIT License – https://github.com/snakers4/silero-vad.
|
||||||
|
- **Diart** by juanmc2005 – MIT License – https://github.com/juanmc2005/diart.
|
||||||
|
|||||||
40
README.md
40
README.md
@@ -10,16 +10,16 @@
|
|||||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
|
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
|
||||||
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
|
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
|
||||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
|
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
|
||||||
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
|
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-Apache 2.0-dark_green"></a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
|
||||||
Real-time speech transcription directly to your browser, with a ready-to-use backend+server and a simple frontend. ✨
|
Real-time transcription directly to your browser, with a ready-to-use backend+server and a simple frontend.
|
||||||
|
|
||||||
#### Powered by Leading Research:
|
#### Powered by Leading Research:
|
||||||
|
|
||||||
- [SimulStreaming](https://github.com/ufalSimulStreaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
|
- Simul-[Whisper](https://github.com/backspacetg/simul_whisper)/[Streaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
|
||||||
- [NLLB](https://arxiv.org/abs/2207.04672), ([distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2)) (2024) - Translation to more than 100 languages.
|
- [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simultaneous translation from & to 200 languages.
|
||||||
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
|
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
|
||||||
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
|
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
|
||||||
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
|
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
|
||||||
@@ -45,7 +45,7 @@ pip install whisperlivekit
|
|||||||
#### Quick Start
|
#### Quick Start
|
||||||
1. **Start the transcription server:**
|
1. **Start the transcription server:**
|
||||||
```bash
|
```bash
|
||||||
whisperlivekit-server --model base --language en
|
wlk --model base --language en
|
||||||
```
|
```
|
||||||
|
|
||||||
2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time!
|
2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time!
|
||||||
@@ -53,6 +53,7 @@ pip install whisperlivekit
|
|||||||
|
|
||||||
> - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
|
> - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
|
||||||
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
|
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
|
||||||
|
> - The CLI entry point is exposed as both `wlk` and `whisperlivekit-server`; they are equivalent.
|
||||||
|
|
||||||
#### Use it to capture audio from web pages.
|
#### Use it to capture audio from web pages.
|
||||||
|
|
||||||
@@ -68,13 +69,12 @@ Go to `chrome-extension` for instructions.
|
|||||||
|
|
||||||
| Optional | `pip install` |
|
| Optional | `pip install` |
|
||||||
|-----------|-------------|
|
|-----------|-------------|
|
||||||
| **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
|
| **Windows/Linux optimizations** | `faster-whisper` |
|
||||||
| **Apple Silicon optimized backend** | `mlx-whisper` |
|
| **Apple Silicon optimizations** | `mlx-whisper` |
|
||||||
| **NLLB Translation** | `huggingface_hub` & `transformers` |
|
| **Translation** | `nllw` |
|
||||||
|
| **Speaker diarization** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
|
||||||
|
| OpenAI API | `openai` |
|
||||||
| *[Not recommended]* Speaker diarization with Diart | `diart` |
|
| *[Not recommended]* Speaker diarization with Diart | `diart` |
|
||||||
| *[Not recommanded]* Original Whisper backend | `whisper` |
|
|
||||||
| *[Not recommanded]* Improved timestamps backend | `whisper-timestamped` |
|
|
||||||
| OpenAI API backend | `openai` |
|
|
||||||
|
|
||||||
See **Parameters & Configuration** below on how to use them.
|
See **Parameters & Configuration** below on how to use them.
|
||||||
|
|
||||||
@@ -86,10 +86,10 @@ See **Parameters & Configuration** below on how to use them.
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Large model and translate from french to danish
|
# Large model and translate from french to danish
|
||||||
whisperlivekit-server --model large-v3 --language fr --target-language da
|
wlk --model large-v3 --language fr --target-language da
|
||||||
|
|
||||||
# Diarization and server listening on */80
|
# Diarization and server listening on */80
|
||||||
whisperlivekit-server --host 0.0.0.0 --port 80 --model medium --diarization --language fr
|
wlk --host 0.0.0.0 --port 80 --model medium --diarization --language fr
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
@@ -139,13 +139,13 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
|
|
||||||
| Parameter | Description | Default |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md) | `small` |
|
| `--model` | Whisper model size. List and recommendations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/available_models.md) | `small` |
|
||||||
| `--model-dir` | Directory containing Whisper model.bin and other files. Overrides `--model`. | `None` |
|
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommendations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/models_compatible_formats.md) | `None` |
|
||||||
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
||||||
| `--target-language` | If sets, activates translation using NLLB. Ex: `fr`. [118 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/translation/mapping_languages.py). If you want to translate to english, you should rather use `--task translate`, since Whisper can do it directly. | `None` |
|
| `--target-language` | If set, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to English, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
|
||||||
| `--task` | Set to `translate` to translate *only* to english, using Whisper translation. | `transcribe` |
|
|
||||||
| `--diarization` | Enable speaker identification | `False` |
|
| `--diarization` | Enable speaker identification | `False` |
|
||||||
| `--backend` | Processing backend. You can switch to `faster-whisper` if `simulstreaming` does not work correctly | `simulstreaming` |
|
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
|
||||||
|
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
|
||||||
| `--no-vac` | Disable Voice Activity Controller | `False` |
|
| `--no-vac` | Disable Voice Activity Controller | `False` |
|
||||||
| `--no-vad` | Disable Voice Activity Detection | `False` |
|
| `--no-vad` | Disable Voice Activity Detection | `False` |
|
||||||
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
||||||
@@ -171,7 +171,8 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
| SimulStreaming backend options | Description | Default |
|
| SimulStreaming backend options | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
|
| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
|
||||||
| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used | `None` |
|
| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-path` is used. Use `scripts/determine_alignment_heads.py` to extract them. <img src="scripts/alignment_heads.png" alt="Alignment heads" width="300">
|
||||||
|
| `None` |
|
||||||
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
|
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
|
||||||
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
|
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
|
||||||
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
|
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
|
||||||
@@ -182,7 +183,6 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
| `--init-prompt` | Initial prompt for the model | `None` |
|
| `--init-prompt` | Initial prompt for the model | `None` |
|
||||||
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
|
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
|
||||||
| `--max-context-tokens` | Maximum context tokens | `None` |
|
| `--max-context-tokens` | Maximum context tokens | `None` |
|
||||||
| `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` |
|
|
||||||
| `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
|
| `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
19
docs/models_compatible_formats.md
Normal file
19
docs/models_compatible_formats.md
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# Model Path Formats
|
||||||
|
|
||||||
|
The `--model-path` parameter accepts:
|
||||||
|
|
||||||
|
## File Path
|
||||||
|
- **`.pt` / `.bin` / `.safetensors` formats**: the file must be loadable by PyTorch or safetensors.
|
||||||
|
|
||||||
|
## Directory Path (recommended)
|
||||||
|
Must contain:
|
||||||
|
- **`.pt` / `.bin` / `.safetensors` file** (required for decoder)
|
||||||
|
|
||||||
|
May optionally contain:
|
||||||
|
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
|
||||||
|
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires mlx-whisper)
|
||||||
|
|
||||||
|
## Hugging Face Repo ID
|
||||||
|
- Provide the repo ID (e.g. `openai/whisper-large-v3`) and WhisperLiveKit will download and cache the snapshot automatically. For gated repos, authenticate via `huggingface-cli login` first.
|
||||||
|
|
||||||
|
To improve speed and reduce hallucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and pass them to WLK with `--custom-alignment-heads`. Otherwise, the alignment heads default to all the heads in the last half of the decoder layers.
|
||||||
265
docs/supported_languages.md
Normal file
265
docs/supported_languages.md
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
# Supported Languages
|
||||||
|
|
||||||
|
WhisperLiveKit supports translation into **201 languages** from the FLORES-200 dataset through the NLLB (No Language Left Behind) translation system.
|
||||||
|
|
||||||
|
## How to Specify Languages
|
||||||
|
|
||||||
|
You can specify languages in **three different ways**:
|
||||||
|
|
||||||
|
1. **Language Name** (case-insensitive): `"English"`, `"French"`, `"Spanish"`
|
||||||
|
2. **ISO Language Code**: `"en"`, `"fr"`, `"es"`
|
||||||
|
3. **NLLB Code** (FLORES-200): `"eng_Latn"`, `"fra_Latn"`, `"spa_Latn"`
|
||||||
|
|
||||||
|
## Usage Examples
|
||||||
|
|
||||||
|
### Command Line
|
||||||
|
```bash
|
||||||
|
# Using language name
|
||||||
|
whisperlivekit-server --target-language "French"
|
||||||
|
|
||||||
|
# Using ISO code
|
||||||
|
whisperlivekit-server --target-language fr
|
||||||
|
|
||||||
|
# Using NLLB code
|
||||||
|
whisperlivekit-server --target-language fra_Latn
|
||||||
|
```
|
||||||
|
|
||||||
|
### Python API
|
||||||
|
```python
|
||||||
|
from nllw.translation import get_language_info
|
||||||
|
|
||||||
|
# Get language information by name
|
||||||
|
lang_info = get_language_info("French")
|
||||||
|
print(lang_info)
|
||||||
|
# {'name': 'French', 'nllb': 'fra_Latn', 'language_code': 'fr'}
|
||||||
|
|
||||||
|
# Get language information by ISO code
|
||||||
|
lang_info = get_language_info("fr")
|
||||||
|
|
||||||
|
# Get language information by NLLB code
|
||||||
|
lang_info = get_language_info("fra_Latn")
|
||||||
|
|
||||||
|
# All three return the same result
|
||||||
|
```
|
||||||
|
|
||||||
|
## Complete Language List
|
||||||
|
|
||||||
|
The following table lists all 201 supported languages with their corresponding codes:
|
||||||
|
|
||||||
|
| Language Name | ISO Code | NLLB Code |
|
||||||
|
|---------------|----------|-----------|
|
||||||
|
| Acehnese (Arabic script) | ace_Arab | ace_Arab |
|
||||||
|
| Acehnese (Latin script) | ace_Latn | ace_Latn |
|
||||||
|
| Mesopotamian Arabic | acm_Arab | acm_Arab |
|
||||||
|
| Ta'izzi-Adeni Arabic | acq_Arab | acq_Arab |
|
||||||
|
| Tunisian Arabic | aeb_Arab | aeb_Arab |
|
||||||
|
| Afrikaans | af | afr_Latn |
|
||||||
|
| South Levantine Arabic | ajp_Arab | ajp_Arab |
|
||||||
|
| Akan | ak | aka_Latn |
|
||||||
|
| Tosk Albanian | als | als_Latn |
|
||||||
|
| Amharic | am | amh_Ethi |
|
||||||
|
| North Levantine Arabic | apc_Arab | apc_Arab |
|
||||||
|
| Modern Standard Arabic | ar | arb_Arab |
|
||||||
|
| Modern Standard Arabic (Romanized) | arb_Latn | arb_Latn |
|
||||||
|
| Najdi Arabic | ars_Arab | ars_Arab |
|
||||||
|
| Moroccan Arabic | ary_Arab | ary_Arab |
|
||||||
|
| Egyptian Arabic | arz_Arab | arz_Arab |
|
||||||
|
| Assamese | as | asm_Beng |
|
||||||
|
| Asturian | ast | ast_Latn |
|
||||||
|
| Awadhi | awa | awa_Deva |
|
||||||
|
| Central Aymara | ay | ayr_Latn |
|
||||||
|
| South Azerbaijani | azb | azb_Arab |
|
||||||
|
| North Azerbaijani | az | azj_Latn |
|
||||||
|
| Bashkir | ba | bak_Cyrl |
|
||||||
|
| Bambara | bm | bam_Latn |
|
||||||
|
| Balinese | ban | ban_Latn |
|
||||||
|
| Belarusian | be | bel_Cyrl |
|
||||||
|
| Bemba | bem | bem_Latn |
|
||||||
|
| Bengali | bn | ben_Beng |
|
||||||
|
| Bhojpuri | bho | bho_Deva |
|
||||||
|
| Banjar (Arabic script) | bjn_Arab | bjn_Arab |
|
||||||
|
| Banjar (Latin script) | bjn_Latn | bjn_Latn |
|
||||||
|
| Standard Tibetan | bo | bod_Tibt |
|
||||||
|
| Bosnian | bs | bos_Latn |
|
||||||
|
| Buginese | bug | bug_Latn |
|
||||||
|
| Bulgarian | bg | bul_Cyrl |
|
||||||
|
| Catalan | ca | cat_Latn |
|
||||||
|
| Cebuano | ceb | ceb_Latn |
|
||||||
|
| Czech | cs | ces_Latn |
|
||||||
|
| Chokwe | cjk | cjk_Latn |
|
||||||
|
| Central Kurdish | ckb | ckb_Arab |
|
||||||
|
| Crimean Tatar | crh | crh_Latn |
|
||||||
|
| Welsh | cy | cym_Latn |
|
||||||
|
| Danish | da | dan_Latn |
|
||||||
|
| German | de | deu_Latn |
|
||||||
|
| Southwestern Dinka | dik | dik_Latn |
|
||||||
|
| Dyula | dyu | dyu_Latn |
|
||||||
|
| Dzongkha | dz | dzo_Tibt |
|
||||||
|
| Greek | el | ell_Grek |
|
||||||
|
| English | en | eng_Latn |
|
||||||
|
| Esperanto | eo | epo_Latn |
|
||||||
|
| Estonian | et | est_Latn |
|
||||||
|
| Basque | eu | eus_Latn |
|
||||||
|
| Ewe | ee | ewe_Latn |
|
||||||
|
| Faroese | fo | fao_Latn |
|
||||||
|
| Fijian | fj | fij_Latn |
|
||||||
|
| Finnish | fi | fin_Latn |
|
||||||
|
| Fon | fon | fon_Latn |
|
||||||
|
| French | fr | fra_Latn |
|
||||||
|
| Friulian | fur-IT | fur_Latn |
|
||||||
|
| Nigerian Fulfulde | fuv | fuv_Latn |
|
||||||
|
| West Central Oromo | om | gaz_Latn |
|
||||||
|
| Scottish Gaelic | gd | gla_Latn |
|
||||||
|
| Irish | ga-IE | gle_Latn |
|
||||||
|
| Galician | gl | glg_Latn |
|
||||||
|
| Guarani | gn | grn_Latn |
|
||||||
|
| Gujarati | gu-IN | guj_Gujr |
|
||||||
|
| Haitian Creole | ht | hat_Latn |
|
||||||
|
| Hausa | ha | hau_Latn |
|
||||||
|
| Hebrew | he | heb_Hebr |
|
||||||
|
| Hindi | hi | hin_Deva |
|
||||||
|
| Chhattisgarhi | hne | hne_Deva |
|
||||||
|
| Croatian | hr | hrv_Latn |
|
||||||
|
| Hungarian | hu | hun_Latn |
|
||||||
|
| Armenian | hy-AM | hye_Armn |
|
||||||
|
| Igbo | ig | ibo_Latn |
|
||||||
|
| Ilocano | ilo | ilo_Latn |
|
||||||
|
| Indonesian | id | ind_Latn |
|
||||||
|
| Icelandic | is | isl_Latn |
|
||||||
|
| Italian | it | ita_Latn |
|
||||||
|
| Javanese | jv | jav_Latn |
|
||||||
|
| Japanese | ja | jpn_Jpan |
|
||||||
|
| Kabyle | kab | kab_Latn |
|
||||||
|
| Jingpho | kac | kac_Latn |
|
||||||
|
| Kamba | kam | kam_Latn |
|
||||||
|
| Kannada | kn | kan_Knda |
|
||||||
|
| Kashmiri (Arabic script) | kas_Arab | kas_Arab |
|
||||||
|
| Kashmiri (Devanagari script) | kas_Deva | kas_Deva |
|
||||||
|
| Georgian | ka | kat_Geor |
|
||||||
|
| Kazakh | kk | kaz_Cyrl |
|
||||||
|
| Kabiyè | kbp | kbp_Latn |
|
||||||
|
| Kabuverdianu | kea | kea_Latn |
|
||||||
|
| Halh Mongolian | mn | khk_Cyrl |
|
||||||
|
| Khmer | km | khm_Khmr |
|
||||||
|
| Kikuyu | ki | kik_Latn |
|
||||||
|
| Kinyarwanda | rw | kin_Latn |
|
||||||
|
| Kyrgyz | ky | kir_Cyrl |
|
||||||
|
| Kimbundu | kmb | kmb_Latn |
|
||||||
|
| Northern Kurdish | kmr | kmr_Latn |
|
||||||
|
| Central Kanuri (Arabic script) | knc_Arab | knc_Arab |
|
||||||
|
| Central Kanuri (Latin script) | knc_Latn | knc_Latn |
|
||||||
|
| Kikongo | kg | kon_Latn |
|
||||||
|
| Korean | ko | kor_Hang |
|
||||||
|
| Lao | lo | lao_Laoo |
|
||||||
|
| Ligurian | lij | lij_Latn |
|
||||||
|
| Limburgish | li | lim_Latn |
|
||||||
|
| Lingala | ln | lin_Latn |
|
||||||
|
| Lithuanian | lt | lit_Latn |
|
||||||
|
| Lombard | lmo | lmo_Latn |
|
||||||
|
| Latgalian | ltg | ltg_Latn |
|
||||||
|
| Luxembourgish | lb | ltz_Latn |
|
||||||
|
| Luba-Kasai | lua | lua_Latn |
|
||||||
|
| Ganda | lg | lug_Latn |
|
||||||
|
| Luo | luo | luo_Latn |
|
||||||
|
| Mizo | lus | lus_Latn |
|
||||||
|
| Standard Latvian | lv | lvs_Latn |
|
||||||
|
| Magahi | mag | mag_Deva |
|
||||||
|
| Maithili | mai | mai_Deva |
|
||||||
|
| Malayalam | ml-IN | mal_Mlym |
|
||||||
|
| Marathi | mr | mar_Deva |
|
||||||
|
| Minangkabau (Arabic script) | min_Arab | min_Arab |
|
||||||
|
| Minangkabau (Latin script) | min_Latn | min_Latn |
|
||||||
|
| Macedonian | mk | mkd_Cyrl |
|
||||||
|
| Maltese | mt | mlt_Latn |
|
||||||
|
| Meitei (Bengali script) | mni | mni_Beng |
|
||||||
|
| Mossi | mos | mos_Latn |
|
||||||
|
| Maori | mi | mri_Latn |
|
||||||
|
| Burmese | my | mya_Mymr |
|
||||||
|
| Dutch | nl | nld_Latn |
|
||||||
|
| Norwegian Nynorsk | nn-NO | nno_Latn |
|
||||||
|
| Norwegian Bokmål | nb | nob_Latn |
|
||||||
|
| Nepali | ne-NP | npi_Deva |
|
||||||
|
| Northern Sotho | nso | nso_Latn |
|
||||||
|
| Nuer | nus | nus_Latn |
|
||||||
|
| Nyanja | ny | nya_Latn |
|
||||||
|
| Occitan | oc | oci_Latn |
|
||||||
|
| Odia | or | ory_Orya |
|
||||||
|
| Pangasinan | pag | pag_Latn |
|
||||||
|
| Eastern Panjabi | pa | pan_Guru |
|
||||||
|
| Papiamento | pap | pap_Latn |
|
||||||
|
| Southern Pashto | pbt | pbt_Arab |
|
||||||
|
| Western Persian | fa | pes_Arab |
|
||||||
|
| Plateau Malagasy | mg | plt_Latn |
|
||||||
|
| Polish | pl | pol_Latn |
|
||||||
|
| Portuguese | pt-PT | por_Latn |
|
||||||
|
| Dari | fa-AF | prs_Arab |
|
||||||
|
| Ayacucho Quechua | qu | quy_Latn |
|
||||||
|
| Romanian | ro | ron_Latn |
|
||||||
|
| Rundi | rn | run_Latn |
|
||||||
|
| Russian | ru | rus_Cyrl |
|
||||||
|
| Sango | sg | sag_Latn |
|
||||||
|
| Sanskrit | sa | san_Deva |
|
||||||
|
| Santali | sat | sat_Olck |
|
||||||
|
| Sicilian | scn | scn_Latn |
|
||||||
|
| Shan | shn | shn_Mymr |
|
||||||
|
| Sinhala | si-LK | sin_Sinh |
|
||||||
|
| Slovak | sk | slk_Latn |
|
||||||
|
| Slovenian | sl | slv_Latn |
|
||||||
|
| Samoan | sm | smo_Latn |
|
||||||
|
| Shona | sn | sna_Latn |
|
||||||
|
| Sindhi | sd | snd_Arab |
|
||||||
|
| Somali | so | som_Latn |
|
||||||
|
| Southern Sotho | st | sot_Latn |
|
||||||
|
| Spanish | es-ES | spa_Latn |
|
||||||
|
| Sardinian | sc | srd_Latn |
|
||||||
|
| Serbian | sr | srp_Cyrl |
|
||||||
|
| Swati | ss | ssw_Latn |
|
||||||
|
| Sundanese | su | sun_Latn |
|
||||||
|
| Swedish | sv-SE | swe_Latn |
|
||||||
|
| Swahili | sw | swh_Latn |
|
||||||
|
| Silesian | szl | szl_Latn |
|
||||||
|
| Tamil | ta | tam_Taml |
|
||||||
|
| Tamasheq (Latin script) | taq_Latn | taq_Latn |
|
||||||
|
| Tamasheq (Tifinagh script) | taq_Tfng | taq_Tfng |
|
||||||
|
| Tatar | tt-RU | tat_Cyrl |
|
||||||
|
| Telugu | te | tel_Telu |
|
||||||
|
| Tajik | tg | tgk_Cyrl |
|
||||||
|
| Tagalog | tl | tgl_Latn |
|
||||||
|
| Thai | th | tha_Thai |
|
||||||
|
| Tigrinya | ti | tir_Ethi |
|
||||||
|
| Tok Pisin | tpi | tpi_Latn |
|
||||||
|
| Tswana | tn | tsn_Latn |
|
||||||
|
| Tsonga | ts | tso_Latn |
|
||||||
|
| Turkmen | tk | tuk_Latn |
|
||||||
|
| Tumbuka | tum | tum_Latn |
|
||||||
|
| Turkish | tr | tur_Latn |
|
||||||
|
| Twi | tw | twi_Latn |
|
||||||
|
| Central Atlas Tamazight | tzm | tzm_Tfng |
|
||||||
|
| Uyghur | ug | uig_Arab |
|
||||||
|
| Ukrainian | uk | ukr_Cyrl |
|
||||||
|
| Umbundu | umb | umb_Latn |
|
||||||
|
| Urdu | ur | urd_Arab |
|
||||||
|
| Northern Uzbek | uz | uzn_Latn |
|
||||||
|
| Venetian | vec | vec_Latn |
|
||||||
|
| Vietnamese | vi | vie_Latn |
|
||||||
|
| Waray | war | war_Latn |
|
||||||
|
| Wolof | wo | wol_Latn |
|
||||||
|
| Xhosa | xh | xho_Latn |
|
||||||
|
| Eastern Yiddish | yi | ydd_Hebr |
|
||||||
|
| Yoruba | yo | yor_Latn |
|
||||||
|
| Yue Chinese | yue | yue_Hant |
|
||||||
|
| Chinese (Simplified) | zh-CN | zho_Hans |
|
||||||
|
| Chinese (Traditional) | zh-TW | zho_Hant |
|
||||||
|
| Standard Malay | ms | zsm_Latn |
|
||||||
|
| Zulu | zu | zul_Latn |
|
||||||
|
|
||||||
|
## Special Features
|
||||||
|
|
||||||
|
### Multiple Script Support
|
||||||
|
Several languages are available in multiple scripts (e.g., Arabic and Latin):
|
||||||
|
- **Acehnese**: Arabic (`ace_Arab`) and Latin (`ace_Latn`)
|
||||||
|
- **Banjar**: Arabic (`bjn_Arab`) and Latin (`bjn_Latn`)
|
||||||
|
- **Kashmiri**: Arabic (`kas_Arab`) and Devanagari (`kas_Deva`)
|
||||||
|
- **Minangkabau**: Arabic (`min_Arab`) and Latin (`min_Latn`)
|
||||||
|
- **Tamasheq**: Latin (`taq_Latn`) and Tifinagh (`taq_Tfng`)
|
||||||
|
- **Central Kanuri**: Arabic (`knc_Arab`) and Latin (`knc_Latn`)
|
||||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "whisperlivekit"
|
name = "whisperlivekit"
|
||||||
version = "0.2.12"
|
version = "0.2.14"
|
||||||
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [
|
authors = [
|
||||||
@@ -30,28 +30,41 @@ dependencies = [
|
|||||||
"fastapi",
|
"fastapi",
|
||||||
"librosa",
|
"librosa",
|
||||||
"soundfile",
|
"soundfile",
|
||||||
"faster-whisper",
|
|
||||||
"uvicorn",
|
"uvicorn",
|
||||||
"websockets",
|
"websockets",
|
||||||
"torchaudio>=2.0.0",
|
"torchaudio>=2.0.0",
|
||||||
"torch>=2.0.0",
|
"torch>=2.0.0",
|
||||||
|
"huggingface-hub>=0.25.0",
|
||||||
"tqdm",
|
"tqdm",
|
||||||
"tiktoken",
|
"tiktoken",
|
||||||
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
|
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
sentence = ["mosestokenizer", "wtpsplit"]
|
translation = ["nllw"]
|
||||||
|
sentence_tokenizer = ["mosestokenizer", "wtpsplit"]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
|
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
whisperlivekit-server = "whisperlivekit.basic_server:main"
|
whisperlivekit-server = "whisperlivekit.basic_server:main"
|
||||||
|
wlk = "whisperlivekit.basic_server:main"
|
||||||
|
|
||||||
[tool.setuptools]
|
[tool.setuptools]
|
||||||
packages = ["whisperlivekit", "whisperlivekit.diarization", "whisperlivekit.simul_whisper", "whisperlivekit.simul_whisper.whisper", "whisperlivekit.simul_whisper.whisper.assets", "whisperlivekit.simul_whisper.whisper.normalizers", "whisperlivekit.web", "whisperlivekit.whisper_streaming_custom", "whisperlivekit.translation"]
|
packages = [
|
||||||
|
"whisperlivekit",
|
||||||
|
"whisperlivekit.diarization",
|
||||||
|
"whisperlivekit.simul_whisper",
|
||||||
|
"whisperlivekit.whisper",
|
||||||
|
"whisperlivekit.whisper.assets",
|
||||||
|
"whisperlivekit.whisper.normalizers",
|
||||||
|
"whisperlivekit.web",
|
||||||
|
"whisperlivekit.local_agreement",
|
||||||
|
"whisperlivekit.vad_models"
|
||||||
|
]
|
||||||
|
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
|
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
|
||||||
"whisperlivekit.simul_whisper.whisper.assets" = ["*.tiktoken", "*.npz"]
|
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
|
||||||
|
"whisperlivekit.vad_models" = ["*.jit", "*.onnx"]
|
||||||
|
|||||||
BIN
scripts/alignment_heads.png
Normal file
BIN
scripts/alignment_heads.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 276 KiB |
153
scripts/convert_hf_whisper.py
Normal file
153
scripts/convert_hf_whisper.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Convert a Hugging Face style Whisper checkpoint into a WhisperLiveKit .pt file.
|
||||||
|
|
||||||
|
Optionally shrink the supported audio chunk length (in seconds) by trimming the
|
||||||
|
encoder positional embeddings and updating the stored model dimensions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from whisperlivekit.whisper.audio import HOP_LENGTH, SAMPLE_RATE
|
||||||
|
from whisperlivekit.whisper.model import ModelDimensions
|
||||||
|
from whisperlivekit.whisper.utils import exact_div
|
||||||
|
from whisperlivekit.whisper import _convert_hf_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _load_state_dict(repo_path: Path) -> Dict[str, torch.Tensor]:
    """Load the raw Hugging Face state dict stored under *repo_path*.

    ``model.safetensors`` is preferred when it exists; otherwise
    ``pytorch_model.bin`` is loaded onto the CPU.

    Args:
        repo_path: Directory containing the Hugging Face checkpoint files.

    Returns:
        The checkpoint's state dict (parameter name -> tensor).

    Raises:
        RuntimeError: ``model.safetensors`` is present but the ``safetensors``
            package cannot be imported.
        FileNotFoundError: Neither weight file exists under *repo_path*.
    """
    safetensor_path = repo_path / "model.safetensors"
    bin_path = repo_path / "pytorch_model.bin"

    if safetensor_path.is_file():
        try:
            from safetensors.torch import load_file  # type: ignore
        except ImportError as exc:  # pragma: no cover - import guard
            # Only a missing package should produce the "install it" hint;
            # any other failure is left to propagate unchanged.
            raise RuntimeError(
                "Install safetensors to load model.safetensors "
                "(pip install safetensors)"
            ) from exc
        return load_file(str(safetensor_path))

    if bin_path.is_file():
        # pytorch_model.bin is a pickle, usually downloaded from a third
        # party. weights_only=True restricts unpickling to tensors and plain
        # containers so a malicious checkpoint cannot run arbitrary code.
        return torch.load(bin_path, map_location="cpu", weights_only=True)

    raise FileNotFoundError(
        f"Could not find model.safetensors or pytorch_model.bin under {repo_path}"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _load_config(repo_path: Path) -> Dict:
    """Read and parse ``config.json`` from a Hugging Face checkpoint directory.

    Args:
        repo_path: Directory expected to contain ``config.json``.

    Returns:
        The decoded JSON document.

    Raises:
        FileNotFoundError: ``config.json`` does not exist under *repo_path*.
    """
    config_file = repo_path / "config.json"
    if config_file.is_file():
        with open(config_file, "r", encoding="utf-8") as handle:
            return json.load(handle)
    raise FileNotFoundError(
        f"Hugging Face checkpoint at {repo_path} is missing config.json"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _derive_audio_ctx(chunk_length: float) -> Tuple[int, int]:
    """Compute the mel frame count and encoder context for a chunk length.

    Args:
        chunk_length: Audio chunk length in seconds.

    Returns:
        ``(n_frames, n_audio_ctx)`` where ``n_frames`` is the sample count
        divided by ``HOP_LENGTH`` and ``n_audio_ctx`` is ``n_frames`` divided
        by 2, both via ``exact_div``.

    Raises:
        ValueError: ``chunk_length * SAMPLE_RATE`` is not (within 1e-6) an
            integer number of samples.
    """
    exact_samples = chunk_length * SAMPLE_RATE
    n_samples = int(round(exact_samples))
    if abs(n_samples - exact_samples) > 1e-6:
        raise ValueError(
            "chunk_length must align with sample rate so that "
            "chunk_length * SAMPLE_RATE is an integer"
        )
    # NOTE(review): exact_div presumably rejects non-divisible inputs — confirm.
    n_frames = exact_div(n_samples, HOP_LENGTH)
    return n_frames, exact_div(n_frames, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_dims(config: Dict, chunk_length: float) -> Dict:
    """Translate a Hugging Face Whisper config into a dims dictionary.

    The config values are first passed through ``ModelDimensions``; the
    resulting attribute dict is then copied, its ``n_audio_ctx`` replaced by
    the value derived from *chunk_length*, and ``chunk_length`` is added.

    Args:
        config: Parsed Hugging Face ``config.json``.
        chunk_length: Audio chunk length in seconds the model should support.

    Returns:
        A plain dict of model dimensions plus a ``"chunk_length"`` entry.
    """
    hf_fields = {
        "n_mels": config["num_mel_bins"],
        "n_audio_ctx": config["max_source_positions"],
        "n_audio_state": config["d_model"],
        "n_audio_head": config["encoder_attention_heads"],
        # Fall back to num_hidden_layers when encoder_layers is absent/falsy.
        "n_audio_layer": config.get("encoder_layers") or config["num_hidden_layers"],
        "n_vocab": config["vocab_size"],
        "n_text_ctx": config["max_target_positions"],
        "n_text_state": config["d_model"],
        "n_text_head": config["decoder_attention_heads"],
        "n_text_layer": config["decoder_layers"],
    }
    dims = dict(vars(ModelDimensions(**hf_fields)))

    audio_ctx = _derive_audio_ctx(chunk_length)[1]
    dims["n_audio_ctx"] = audio_ctx
    dims["chunk_length"] = chunk_length
    return dims
|
||||||
|
|
||||||
|
|
||||||
|
def _trim_positional_embedding(
    state_dict: Dict[str, torch.Tensor], target_ctx: int
) -> None:
    """Shrink the encoder positional embedding to *target_ctx* rows in place.

    Args:
        state_dict: Converted state dict; its
            ``"encoder.positional_embedding"`` entry is replaced when trimming
            is needed.
        target_ctx: Desired size of the embedding's first dimension.

    Raises:
        KeyError: The positional embedding is missing from *state_dict*.
        ValueError: The stored embedding has fewer than *target_ctx* rows.
    """
    name = "encoder.positional_embedding"
    if name not in state_dict:
        raise KeyError(f"{name} missing from converted state dict")

    embedding = state_dict[name]
    current_ctx = embedding.shape[0]
    if current_ctx < target_ctx:
        raise ValueError(
            f"Cannot increase encoder ctx from {current_ctx} to {target_ctx}"
        )
    if current_ctx > target_ctx:
        # Keep the leading rows; contiguous() detaches the slice's storage
        # layout from the original, larger tensor.
        state_dict[name] = embedding[:target_ctx].contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
def convert_checkpoint(hf_path: Path, output_path: Path, chunk_length: float) -> None:
    """Convert a Hugging Face Whisper checkpoint and save it as a ``.pt`` file.

    Args:
        hf_path: Directory of the Hugging Face checkpoint.
        output_path: Destination file; parent directories are created.
        chunk_length: Audio chunk length in seconds used to size the encoder
            context and trim the positional embedding.
    """
    converted = _convert_hf_state_dict(_load_state_dict(hf_path))
    dims = _build_dims(_load_config(hf_path), chunk_length)

    # The embedding must match the (possibly reduced) n_audio_ctx in dims.
    _trim_positional_embedding(converted, dims["n_audio_ctx"])

    output_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({"dims": dims, "model_state_dict": converted}, output_path)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
    """Parse the converter's command-line arguments from ``sys.argv``.

    Returns:
        Namespace with ``hf_path``, ``output`` and ``chunk_length``.
    """
    cli = argparse.ArgumentParser(
        description="Convert Hugging Face Whisper checkpoint to WhisperLiveKit format."
    )
    argument_specs = (
        (
            "hf_path",
            {
                "type": str,
                "help": "Path to the cloned Hugging Face repository (e.g. whisper-tiny.en)",
            },
        ),
        (
            "--output",
            {
                "type": str,
                "default": "converted-whisper.pt",
                "help": "Destination path for the .pt file",
            },
        ),
        (
            "--chunk-length",
            {
                "type": float,
                "default": 30.0,
                "help": "Audio chunk length in seconds to support (default: 30)",
            },
        ),
    )
    for flag, options in argument_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: resolve the paths, convert, and report the output file."""
    args = parse_args()
    # Expand "~" and make both paths absolute before any file access.
    hf_path, output_path = (
        Path(os.path.expanduser(raw)).resolve()
        for raw in (args.hf_path, args.output)
    )

    convert_checkpoint(hf_path, output_path, args.chunk_length)
    print(f"Saved converted checkpoint to {output_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
292
scripts/determine_alignment_heads.py
Normal file
292
scripts/determine_alignment_heads.py
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
"""Determine alignment heads for a variants, such as distilled model"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import base64
|
||||||
|
import gzip
|
||||||
|
import io
|
||||||
|
import pathlib
|
||||||
|
import sys
|
||||||
|
import math
|
||||||
|
from typing import List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from datasets import Audio as DatasetAudio, load_dataset
|
||||||
|
import soundfile as sf
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
|
||||||
|
WHISPER_ROOT = REPO_ROOT / "whisper"
|
||||||
|
|
||||||
|
sys.path.insert(0, str(REPO_ROOT))
|
||||||
|
sys.path.insert(0, str(WHISPER_ROOT))
|
||||||
|
|
||||||
|
from whisper import load_model
|
||||||
|
from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||||
|
from whisper.tokenizer import get_tokenizer
|
||||||
|
|
||||||
|
AudioInput = Union[str, pathlib.Path, np.ndarray, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset_clips(name, config, split, limit):
    """Load up to *limit* ``(waveform, transcript)`` pairs from a dataset.

    Audio is kept undecoded by ``datasets`` and decoded here with
    ``soundfile`` as float32; multi-channel audio is averaged to mono.

    Args:
        name: Dataset name passed to ``load_dataset``.
        config: Dataset configuration name.
        split: Split expression passed to ``load_dataset``.
        limit: Maximum number of rows to read, or ``None`` for all rows.

    Returns:
        List of ``(numpy waveform, transcript string)`` tuples.
    """
    dataset = load_dataset(name, config, split=split)
    dataset = dataset.cast_column("audio", DatasetAudio(decode=False))

    pairs = []
    for position, record in enumerate(dataset):
        if limit is not None and position >= limit:
            break
        encoded_audio = record["audio"]
        text = record["text"]

        # NOTE(review): the file's sample rate is discarded here; this assumes
        # the dataset already matches the rate the model expects — confirm.
        samples, _ = sf.read(io.BytesIO(encoded_audio["bytes"]), dtype="float32")
        if samples.ndim > 1:
            samples = samples.mean(axis=1)

        pairs.append((samples, str(text)))
    return pairs
|
||||||
|
|
||||||
|
|
||||||
|
def load_clips(args):
    """Load evaluation clips using the dataset options in parsed CLI *args*."""
    name, config = args.dataset, args.dataset_config
    split, limit = args.dataset_split, args.dataset_num_samples
    return load_dataset_clips(name, config, split, limit)
|
||||||
|
|
||||||
|
|
||||||
|
def _waveform_from_source(source: AudioInput) -> torch.Tensor:
    """Convert any supported audio input into a float32 waveform tensor.

    Args:
        source: A NumPy array, a torch tensor, or a path (``str`` /
            ``pathlib.Path``) to an audio file.

    Returns:
        The waveform as a ``torch.float32`` tensor.

    Raises:
        TypeError: *source* is none of the supported types.
    """
    if isinstance(source, torch.Tensor):
        return source.to(torch.float32)
    if isinstance(source, np.ndarray):
        # copy=False avoids duplicating arrays that are already float32.
        return torch.from_numpy(source.astype(np.float32, copy=False))
    if isinstance(source, (str, pathlib.Path)):
        # NOTE(review): load_audio (imported from whisper.audio at file top)
        # is assumed to return a NumPy waveform — confirm.
        decoded = np.asarray(load_audio(str(source)), dtype=np.float32)
        return torch.from_numpy(decoded)
    raise TypeError(
        f"Unsupported audio source type: {type(source).__name__}"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_args():
    """Parse the alignment-head script's command-line options from ``sys.argv``.

    Returns:
        ``argparse.Namespace`` with model, device, dataset, threshold, votes,
        output and visualization settings.
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    option_table = (
        ("--model", {"type": str, "default": "pytorch_model.bin"}),
        (
            "--device",
            {"type": str, "default": default_device, "help": "Torch device to run on"},
        ),
        ("--dataset", {"type": str, "default": "librispeech_asr"}),
        ("--dataset-config", {"type": str, "default": "clean"}),
        ("--dataset-split", {"type": str, "default": "validation[:1%]"}),
        ("--dataset-num-samples", {"type": int, "default": 16}),
        (
            "--threshold",
            {
                "type": float,
                "default": 1.5,
                "help": "Z score threshold for a head to be selected",
            },
        ),
        (
            "--votes",
            {
                "type": float,
                "default": 0.75,
                "help": "percentage of clips that must vote for a head",
            },
        ),
        ("--output", {"type": str, "default": "alignment_heads.b85"}),
        ("--visualize-top-k", {"type": int, "default": 32}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def collect_heads(
    model,
    tokenizer,
    clips: Sequence[Tuple[AudioInput, str]],
    threshold: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Score every decoder cross-attention head over a set of clips.

    For each clip the model is run once on the reference transcript while
    forward hooks capture each decoder block's cross-attention output. Per
    head, the softmax peak over time is taken for every token; a head "votes"
    for a clip when any token's peak has a z-score above *threshold*.

    Args:
        model: Whisper model exposing ``device``, ``dims`` and
            ``decoder.blocks[*].cross_attn``.
        tokenizer: Tokenizer providing ``sot_sequence``, ``no_timestamps``,
            ``encode`` and ``eot``.
        clips: ``(audio, transcript)`` pairs.
        threshold: Z-score a token peak must exceed for a head to vote.

    Returns:
        ``(votes, strengths)``, both shaped ``(n_text_layer, n_text_head)``:
        the fraction of clips that voted for each head and the mean peak
        attention of each head. Both are all zeros when *clips* is empty.
    """
    device = model.device
    votes = torch.zeros(model.dims.n_text_layer, model.dims.n_text_head, device=device)
    strengths = torch.zeros_like(votes)

    # Without clips the averages below would be 0/0 (NaN); return zeros.
    if not clips:
        return votes, strengths

    for audio_source, transcript in clips:
        waveform = pad_or_trim(_waveform_from_source(audio_source))
        mel = log_mel_spectrogram(waveform, device=device)

        tokens = torch.tensor(
            [
                *tokenizer.sot_sequence,
                tokenizer.no_timestamps,
                *tokenizer.encode(transcript),
                tokenizer.eot,
            ],
            device=device,
        )

        qks = [None] * model.dims.n_text_layer
        # index=i binds the layer number at definition time (avoids the
        # late-binding closure pitfall).
        # NOTE(review): outputs[-1][0] is presumably the attention logits for
        # the single batch item, shaped (heads, tokens, frames) — confirm.
        hooks = [
            block.cross_attn.register_forward_hook(
                lambda _, __, outputs, index=i: qks.__setitem__(index, outputs[-1][0])
            )
            for i, block in enumerate(model.decoder.blocks)
        ]

        try:
            with torch.no_grad():
                model(mel.unsqueeze(0), tokens.unsqueeze(0))
        finally:
            # Always detach the hooks, even if the forward pass fails.
            for hook in hooks:
                hook.remove()

        n_frames = mel.shape[-1] // 2
        for layer_idx, tensor in enumerate(qks):
            if tensor is None:
                continue
            tensor = tensor[:, :, :n_frames].softmax(dim=-1)
            peak = tensor.max(dim=-1).values  # [heads, tokens]
            strengths[layer_idx] += peak.mean(dim=-1)
            zscore = (peak - peak.mean(dim=-1, keepdim=True)) / (
                peak.std(dim=-1, keepdim=True, unbiased=False) + 1e-6
            )
            # Use the caller-supplied threshold instead of a hard-coded 3.
            mask = (zscore > threshold).any(dim=-1)
            votes[layer_idx] += mask.float()

    votes /= len(clips)
    strengths /= len(clips)
    return votes, strengths
|
||||||
|
|
||||||
|
|
||||||
|
def _select_heads_for_visualization(selection, strengths, top_k):
    """Pick the strongest selected heads for plotting.

    Args:
        selection: Boolean tensor indexed ``[layer, head]`` marking chosen heads.
        strengths: Tensor indexed ``[layer, head]`` with each head's score.
        top_k: Maximum number of heads to return.

    Returns:
        Up to *top_k* ``(layer, head, strength)`` tuples, sorted by strength
        in descending order; an empty list when nothing is selected.
    """
    chosen = torch.nonzero(selection, as_tuple=False).tolist()
    scored = [
        (int(layer), int(head), float(strengths[layer, head].item()))
        for layer, head in chosen
    ]
    ranked = sorted(scored, key=lambda entry: entry[2], reverse=True)
    return ranked[:top_k]
|
||||||
|
|
||||||
|
def _extract_heatmaps(
    model,
    tokenizer,
    clip: Tuple[AudioInput, str],
    heads: Sequence[Tuple[int, int, float]],
) -> dict:
    """Collect cross-attention maps of the requested heads for one clip.

    Args:
        model: Whisper model exposing ``device``, ``dims`` and
            ``decoder.blocks[*].cross_attn``.
        tokenizer: Tokenizer providing ``sot_sequence``, ``no_timestamps``,
            ``encode`` and ``eot``.
        clip: ``(audio, transcript)`` pair to run through the model.
        heads: ``(layer, head, score)`` tuples naming the heads to extract.

    Returns:
        Dict mapping ``(layer, head)`` to that head's softmaxed attention
        tensor on the CPU; empty when *heads* is empty.
    """
    if not heads:
        return {}

    # Group the wanted head indices by layer.
    wanted = {}
    for layer_no, head_no, _score in heads:
        wanted.setdefault(layer_no, set()).add(head_no)

    audio, transcript = clip[0], clip[1]
    waveform = pad_or_trim(_waveform_from_source(audio))
    mel = log_mel_spectrogram(waveform, device=model.device)
    token_ids = [
        *tokenizer.sot_sequence,
        tokenizer.no_timestamps,
        *tokenizer.encode(transcript),
        tokenizer.eot,
    ]
    tokens = torch.tensor(token_ids, device=model.device)

    captured = [None] * model.dims.n_text_layer
    # index=i freezes the layer number for each hook.
    # NOTE(review): outputs[-1][0] is presumably the attention logits for the
    # single batch item — confirm against the attention module.
    handles = [
        block.cross_attn.register_forward_hook(
            lambda _, __, outputs, index=i: captured.__setitem__(index, outputs[-1][0])
        )
        for i, block in enumerate(model.decoder.blocks)
    ]

    with torch.no_grad():
        model(mel.unsqueeze(0), tokens.unsqueeze(0))

    for handle in handles:
        handle.remove()

    n_frames = mel.shape[-1] // 2
    heatmaps = {}
    for layer_no, attention in enumerate(captured):
        if attention is None or layer_no not in wanted:
            continue
        weights = attention[:, :, :n_frames].softmax(dim=-1).cpu()
        for head_no in wanted[layer_no]:
            heatmaps[(layer_no, head_no)] = weights[head_no]

    return heatmaps
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_heatmaps(heads, heatmaps, output_path):
    """Render the attention heatmap of each head into one image grid.

    Args:
        heads: ``(layer, head, score)`` tuples, one subplot per entry.
        heatmaps: Dict mapping ``(layer, head)`` to a 2-D tensor; heads
            without an entry get a blank subplot.
        output_path: Destination image path.

    Returns:
        None. Nothing is written when *heads* is empty.
    """
    # No selected heads would give cols == 0 and a ZeroDivisionError below.
    if not heads:
        return

    cols = min(3, len(heads))
    rows = math.ceil(len(heads) / cols)
    fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3.2 * rows), squeeze=False)

    for idx, (layer, head, score) in enumerate(heads):
        ax = axes[idx // cols][idx % cols]
        mat = heatmaps.get((layer, head))
        if mat is None:
            ax.axis("off")
            continue
        # Cast to float32 so half-precision tensors convert to NumPy cleanly.
        ax.imshow(mat.to(torch.float32).numpy(), aspect="auto", origin="lower")
        ax.set_title(f"L{layer} H{head} · score {score:.2f}")
        ax.set_xlabel("time")
        ax.set_ylabel("tokens")

    # Hide the unused cells of the last grid row.
    for j in range(len(heads), rows * cols):
        axes[j // cols][j % cols].axis("off")

    fig.tight_layout()
    fig.savefig(output_path, dpi=200)
    plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def _dump_mask(mask: torch.Tensor, output_path: str):
    """Write *mask* to *output_path* as base85-encoded, gzip-compressed bytes.

    Args:
        mask: CPU tensor; it is converted to a NumPy boolean array before
            serialization.
        output_path: File to create or overwrite.
    """
    raw = mask.numpy().astype(np.bool_).tobytes()
    encoded = base64.b85encode(gzip.compress(raw))
    pathlib.Path(output_path).write_bytes(encoded)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: score heads, save the selection mask, plot heatmaps."""
    args = _parse_args()

    model = load_model(args.model, device=args.device)
    model.eval()
    tokenizer = get_tokenizer(multilingual=model.is_multilingual)
    clips = load_clips(args)

    _votes, strengths = collect_heads(model, tokenizer, clips, args.threshold)
    # Heads are currently chosen by mean peak strength; the vote fractions
    # are computed but not used for the selection.
    selection = strengths > 0.05
    _dump_mask(selection.cpu(), args.output)

    top_heads = _select_heads_for_visualization(
        selection, strengths, args.visualize_top_k
    )
    # Heatmaps are drawn from the first clip only.
    heatmaps = _extract_heatmaps(model, tokenizer, clips[0], top_heads)
    _plot_heatmaps(top_heads, heatmaps, "alignment_heads.png")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,9 +1,10 @@
|
|||||||
|
"""Copy core files from web directory to Chrome extension directory."""
|
||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
def sync_extension_files():
|
def sync_extension_files():
|
||||||
"""Copy core files from web directory to Chrome extension directory."""
|
|
||||||
|
|
||||||
web_dir = Path("whisperlivekit/web")
|
web_dir = Path("whisperlivekit/web")
|
||||||
extension_dir = Path("chrome-extension")
|
extension_dir = Path("chrome-extension")
|
||||||
@@ -67,20 +67,17 @@ class AudioProcessor:
|
|||||||
self.is_stopping = False
|
self.is_stopping = False
|
||||||
self.silence = False
|
self.silence = False
|
||||||
self.silence_duration = 0.0
|
self.silence_duration = 0.0
|
||||||
self.tokens = []
|
self.state = State()
|
||||||
self.last_validated_token = 0
|
|
||||||
self.translated_segments = []
|
|
||||||
self.buffer_transcription = Transcript()
|
|
||||||
self.end_buffer = 0
|
|
||||||
self.end_attributed_speaker = 0
|
|
||||||
self.lock = asyncio.Lock()
|
self.lock = asyncio.Lock()
|
||||||
self.beg_loop = 0.0 #to deal with a potential little lag at the websocket initialization, this is now set in process_audio
|
|
||||||
self.sep = " " # Default separator
|
self.sep = " " # Default separator
|
||||||
self.last_response_content = FrontData()
|
self.last_response_content = FrontData()
|
||||||
self.last_detected_speaker = None
|
self.last_detected_speaker = None
|
||||||
self.speaker_languages = {}
|
self.speaker_languages = {}
|
||||||
self.diarization_before_transcription = False
|
self.diarization_before_transcription = False
|
||||||
|
|
||||||
|
self.segments = []
|
||||||
|
|
||||||
|
|
||||||
if self.diarization_before_transcription:
|
if self.diarization_before_transcription:
|
||||||
self.cumulative_pcm = []
|
self.cumulative_pcm = []
|
||||||
self.last_start = 0.0
|
self.last_start = 0.0
|
||||||
@@ -138,8 +135,8 @@ class AudioProcessor:
|
|||||||
async def add_dummy_token(self):
|
async def add_dummy_token(self):
|
||||||
"""Placeholder token when no transcription is available."""
|
"""Placeholder token when no transcription is available."""
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
current_time = time() - self.beg_loop
|
current_time = time() - self.state.beg_loop
|
||||||
self.tokens.append(ASRToken(
|
self.state.tokens.append(ASRToken(
|
||||||
start=current_time, end=current_time + 1,
|
start=current_time, end=current_time + 1,
|
||||||
text=".", speaker=-1, is_dummy=True
|
text=".", speaker=-1, is_dummy=True
|
||||||
))
|
))
|
||||||
@@ -149,35 +146,19 @@ class AudioProcessor:
|
|||||||
async with self.lock:
|
async with self.lock:
|
||||||
current_time = time()
|
current_time = time()
|
||||||
|
|
||||||
# Calculate remaining times
|
|
||||||
remaining_transcription = 0
|
remaining_transcription = 0
|
||||||
if self.end_buffer > 0:
|
if self.state.end_buffer > 0:
|
||||||
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 1))
|
remaining_transcription = max(0, round(current_time - self.state.beg_loop - self.state.end_buffer, 1))
|
||||||
|
|
||||||
remaining_diarization = 0
|
remaining_diarization = 0
|
||||||
if self.tokens:
|
if self.state.tokens:
|
||||||
latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
|
latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
|
||||||
remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 1))
|
remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1))
|
||||||
|
|
||||||
return State(
|
self.state.remaining_time_transcription = remaining_transcription
|
||||||
tokens=self.tokens.copy(),
|
self.state.remaining_time_diarization = remaining_diarization
|
||||||
last_validated_token=self.last_validated_token,
|
|
||||||
translated_segments=self.translated_segments.copy(),
|
|
||||||
buffer_transcription=self.buffer_transcription,
|
|
||||||
end_buffer=self.end_buffer,
|
|
||||||
end_attributed_speaker=self.end_attributed_speaker,
|
|
||||||
remaining_time_transcription=remaining_transcription,
|
|
||||||
remaining_time_diarization=remaining_diarization
|
|
||||||
)
|
|
||||||
|
|
||||||
async def reset(self):
|
return self.state
|
||||||
"""Reset all state variables to initial values."""
|
|
||||||
async with self.lock:
|
|
||||||
self.tokens = []
|
|
||||||
self.translated_segments = []
|
|
||||||
self.buffer_transcription = Transcript()
|
|
||||||
self.end_buffer = self.end_attributed_speaker = 0
|
|
||||||
self.beg_loop = time()
|
|
||||||
|
|
||||||
async def ffmpeg_stdout_reader(self):
|
async def ffmpeg_stdout_reader(self):
|
||||||
"""Read audio data from FFmpeg stdout and process it into the PCM pipeline."""
|
"""Read audio data from FFmpeg stdout and process it into the PCM pipeline."""
|
||||||
@@ -242,15 +223,15 @@ class AudioProcessor:
|
|||||||
break
|
break
|
||||||
|
|
||||||
asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
|
asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
|
||||||
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
|
transcription_lag_s = max(0.0, time() - self.state.beg_loop - self.state.end_buffer)
|
||||||
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
|
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
|
||||||
if type(item) is Silence:
|
if type(item) is Silence:
|
||||||
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
|
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
|
||||||
if self.tokens:
|
if self.state.tokens:
|
||||||
asr_processing_logs += f" | last_end = {self.tokens[-1].end} |"
|
asr_processing_logs += f" | last_end = {self.state.tokens[-1].end} |"
|
||||||
logger.info(asr_processing_logs)
|
logger.info(asr_processing_logs)
|
||||||
cumulative_pcm_duration_stream_time += item.duration
|
cumulative_pcm_duration_stream_time += item.duration
|
||||||
self.transcription.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
|
self.transcription.insert_silence(item.duration, self.state.tokens[-1].end if self.state.tokens else 0)
|
||||||
continue
|
continue
|
||||||
elif isinstance(item, ChangeSpeaker):
|
elif isinstance(item, ChangeSpeaker):
|
||||||
self.transcription.new_speaker(item)
|
self.transcription.new_speaker(item)
|
||||||
@@ -274,7 +255,7 @@ class AudioProcessor:
|
|||||||
if buffer_text.startswith(validated_text):
|
if buffer_text.startswith(validated_text):
|
||||||
_buffer_transcript.text = buffer_text[len(validated_text):].lstrip()
|
_buffer_transcript.text = buffer_text[len(validated_text):].lstrip()
|
||||||
|
|
||||||
candidate_end_times = [self.end_buffer]
|
candidate_end_times = [self.state.end_buffer]
|
||||||
|
|
||||||
if new_tokens:
|
if new_tokens:
|
||||||
candidate_end_times.append(new_tokens[-1].end)
|
candidate_end_times.append(new_tokens[-1].end)
|
||||||
@@ -285,9 +266,9 @@ class AudioProcessor:
|
|||||||
candidate_end_times.append(current_audio_processed_upto)
|
candidate_end_times.append(current_audio_processed_upto)
|
||||||
|
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
self.tokens.extend(new_tokens)
|
self.state.tokens.extend(new_tokens)
|
||||||
self.buffer_transcription = _buffer_transcript
|
self.state.buffer_transcription = _buffer_transcript
|
||||||
self.end_buffer = max(candidate_end_times)
|
self.state.end_buffer = max(candidate_end_times)
|
||||||
|
|
||||||
if self.translation_queue:
|
if self.translation_queue:
|
||||||
for token in new_tokens:
|
for token in new_tokens:
|
||||||
@@ -360,12 +341,12 @@ class AudioProcessor:
|
|||||||
self.last_end = last_segment.end
|
self.last_end = last_segment.end
|
||||||
elif not self.diarization_before_transcription:
|
elif not self.diarization_before_transcription:
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
self.tokens = diarization_obj.assign_speakers_to_tokens(
|
self.state.tokens = diarization_obj.assign_speakers_to_tokens(
|
||||||
self.tokens,
|
self.state.tokens,
|
||||||
use_punctuation_split=self.args.punctuation_split
|
use_punctuation_split=self.args.punctuation_split
|
||||||
)
|
)
|
||||||
if len(self.tokens) > 0:
|
if len(self.state.tokens) > 0:
|
||||||
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
|
self.state.end_attributed_speaker = max(self.state.tokens[-1].end, self.state.end_attributed_speaker)
|
||||||
self.diarization_queue.task_done()
|
self.diarization_queue.task_done()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -406,7 +387,10 @@ class AudioProcessor:
|
|||||||
tokens_to_process.append(additional_token)
|
tokens_to_process.append(additional_token)
|
||||||
if tokens_to_process:
|
if tokens_to_process:
|
||||||
self.translation.insert_tokens(tokens_to_process)
|
self.translation.insert_tokens(tokens_to_process)
|
||||||
self.translated_segments = await asyncio.to_thread(self.translation.process)
|
translation_validated_segments, buffer_translation = await asyncio.to_thread(self.translation.process)
|
||||||
|
async with self.lock:
|
||||||
|
self.state.translation_validated_segments = translation_validated_segments
|
||||||
|
self.state.buffer_translation = buffer_translation
|
||||||
self.translation_queue.task_done()
|
self.translation_queue.task_done()
|
||||||
for _ in additional_tokens:
|
for _ in additional_tokens:
|
||||||
self.translation_queue.task_done()
|
self.translation_queue.task_done()
|
||||||
@@ -437,11 +421,9 @@ class AudioProcessor:
|
|||||||
|
|
||||||
state = await self.get_current_state()
|
state = await self.get_current_state()
|
||||||
|
|
||||||
|
|
||||||
lines, undiarized_text = format_output(
|
lines, undiarized_text = format_output(
|
||||||
state,
|
state,
|
||||||
self.silence,
|
self.silence,
|
||||||
current_time = time() - self.beg_loop,
|
|
||||||
args = self.args,
|
args = self.args,
|
||||||
sep=self.sep
|
sep=self.sep
|
||||||
)
|
)
|
||||||
@@ -455,7 +437,13 @@ class AudioProcessor:
|
|||||||
buffer_diarization = self.sep.join(undiarized_text)
|
buffer_diarization = self.sep.join(undiarized_text)
|
||||||
|
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
self.end_attributed_speaker = state.end_attributed_speaker
|
self.state.end_attributed_speaker = state.end_attributed_speaker
|
||||||
|
|
||||||
|
buffer_translation_text = ''
|
||||||
|
if state.buffer_translation:
|
||||||
|
raw_buffer_translation = getattr(state.buffer_translation, 'text', state.buffer_translation)
|
||||||
|
if raw_buffer_translation:
|
||||||
|
buffer_translation_text = raw_buffer_translation.strip()
|
||||||
|
|
||||||
response_status = "active_transcription"
|
response_status = "active_transcription"
|
||||||
if not state.tokens and not buffer_transcription and not buffer_diarization:
|
if not state.tokens and not buffer_transcription and not buffer_diarization:
|
||||||
@@ -473,6 +461,7 @@ class AudioProcessor:
|
|||||||
lines=lines,
|
lines=lines,
|
||||||
buffer_transcription=buffer_transcription.text.strip(),
|
buffer_transcription=buffer_transcription.text.strip(),
|
||||||
buffer_diarization=buffer_diarization,
|
buffer_diarization=buffer_diarization,
|
||||||
|
buffer_translation=buffer_translation_text,
|
||||||
remaining_time_transcription=state.remaining_time_transcription,
|
remaining_time_transcription=state.remaining_time_transcription,
|
||||||
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
|
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
|
||||||
)
|
)
|
||||||
@@ -482,23 +471,14 @@ class AudioProcessor:
|
|||||||
yield response
|
yield response
|
||||||
self.last_response_content = response
|
self.last_response_content = response
|
||||||
|
|
||||||
# Check for termination condition
|
if self.is_stopping and self._processing_tasks_done():
|
||||||
if self.is_stopping:
|
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
|
||||||
all_processors_done = True
|
return
|
||||||
if self.args.transcription and self.transcription_task and not self.transcription_task.done():
|
|
||||||
all_processors_done = False
|
|
||||||
if self.args.diarization and self.diarization_task and not self.diarization_task.done():
|
|
||||||
all_processors_done = False
|
|
||||||
|
|
||||||
if all_processors_done:
|
|
||||||
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
|
|
||||||
return
|
|
||||||
|
|
||||||
await asyncio.sleep(0.05)
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Exception in results_formatter: {e}")
|
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
|
||||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
async def create_tasks(self):
|
async def create_tasks(self):
|
||||||
@@ -544,11 +524,16 @@ class AudioProcessor:
|
|||||||
|
|
||||||
async def watchdog(self, tasks_to_monitor):
|
async def watchdog(self, tasks_to_monitor):
|
||||||
"""Monitors the health of critical processing tasks."""
|
"""Monitors the health of critical processing tasks."""
|
||||||
|
tasks_remaining = [task for task in tasks_to_monitor if task]
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
if not tasks_remaining:
|
||||||
|
logger.info("Watchdog task finishing: all monitored tasks completed.")
|
||||||
|
return
|
||||||
|
|
||||||
await asyncio.sleep(10)
|
await asyncio.sleep(10)
|
||||||
|
|
||||||
for i, task in enumerate(tasks_to_monitor):
|
for i, task in enumerate(list(tasks_remaining)):
|
||||||
if task.done():
|
if task.done():
|
||||||
exc = task.exception()
|
exc = task.exception()
|
||||||
task_name = task.get_name() if hasattr(task, 'get_name') else f"Monitored Task {i}"
|
task_name = task.get_name() if hasattr(task, 'get_name') else f"Monitored Task {i}"
|
||||||
@@ -556,6 +541,7 @@ class AudioProcessor:
|
|||||||
logger.error(f"{task_name} unexpectedly completed with exception: {exc}")
|
logger.error(f"{task_name} unexpectedly completed with exception: {exc}")
|
||||||
else:
|
else:
|
||||||
logger.info(f"{task_name} completed normally.")
|
logger.info(f"{task_name} completed normally.")
|
||||||
|
tasks_remaining.remove(task)
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("Watchdog task cancelled.")
|
logger.info("Watchdog task cancelled.")
|
||||||
@@ -586,12 +572,22 @@ class AudioProcessor:
|
|||||||
self.diarization.close()
|
self.diarization.close()
|
||||||
logger.info("AudioProcessor cleanup complete.")
|
logger.info("AudioProcessor cleanup complete.")
|
||||||
|
|
||||||
|
def _processing_tasks_done(self):
|
||||||
|
"""Return True when all active processing tasks have completed."""
|
||||||
|
tasks_to_check = [
|
||||||
|
self.transcription_task,
|
||||||
|
self.diarization_task,
|
||||||
|
self.translation_task,
|
||||||
|
self.ffmpeg_reader_task,
|
||||||
|
]
|
||||||
|
return all(task.done() for task in tasks_to_check if task)
|
||||||
|
|
||||||
|
|
||||||
async def process_audio(self, message):
|
async def process_audio(self, message):
|
||||||
"""Process incoming audio data."""
|
"""Process incoming audio data."""
|
||||||
|
|
||||||
if not self.beg_loop:
|
if not self.state.beg_loop:
|
||||||
self.beg_loop = time()
|
self.state.beg_loop = time()
|
||||||
|
|
||||||
if not message:
|
if not message:
|
||||||
logger.info("Empty audio message received, initiating stop sequence.")
|
logger.info("Empty audio message received, initiating stop sequence.")
|
||||||
|
|||||||
41
whisperlivekit/backend_support.py
Normal file
41
whisperlivekit/backend_support.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
import importlib.util
|
||||||
|
import logging
|
||||||
|
import platform
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def module_available(module_name):
|
||||||
|
"""Return True if the given module can be imported."""
|
||||||
|
return importlib.util.find_spec(module_name) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def mlx_backend_available(warn_on_missing = False):
|
||||||
|
is_macos = platform.system() == "Darwin"
|
||||||
|
is_arm = platform.machine() == "arm64"
|
||||||
|
available = (
|
||||||
|
is_macos
|
||||||
|
and is_arm
|
||||||
|
and module_available("mlx_whisper")
|
||||||
|
)
|
||||||
|
if not available and warn_on_missing and is_macos and is_arm:
|
||||||
|
logger.warning(
|
||||||
|
"=" * 50
|
||||||
|
+ "\nMLX Whisper not found but you are on Apple Silicon. "
|
||||||
|
"Consider installing mlx-whisper for better performance: "
|
||||||
|
"`pip install mlx-whisper`\n"
|
||||||
|
+ "=" * 50
|
||||||
|
)
|
||||||
|
return available
|
||||||
|
|
||||||
|
|
||||||
|
def faster_backend_available(warn_on_missing = False):
|
||||||
|
available = module_available("faster_whisper")
|
||||||
|
if not available and warn_on_missing and platform.system() != "Darwin":
|
||||||
|
logger.warning(
|
||||||
|
"=" * 50
|
||||||
|
+ "\nFaster-Whisper not found. Consider installing faster-whisper "
|
||||||
|
"for better performance: `pip install faster-whisper`\n"
|
||||||
|
+ "=" * 50
|
||||||
|
)
|
||||||
|
return available
|
||||||
@@ -1,11 +1,9 @@
|
|||||||
try:
|
from whisperlivekit.local_agreement.whisper_online import backend_factory
|
||||||
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory
|
from whisperlivekit.simul_whisper import SimulStreamingASR
|
||||||
from whisperlivekit.whisper_streaming_custom.online_asr import OnlineASRProcessor
|
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
|
||||||
except ImportError:
|
|
||||||
from .whisper_streaming_custom.whisper_online import backend_factory
|
|
||||||
from .whisper_streaming_custom.online_asr import OnlineASRProcessor
|
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
import sys
|
import sys
|
||||||
|
import logging
|
||||||
|
|
||||||
def update_with_kwargs(_dict, kwargs):
|
def update_with_kwargs(_dict, kwargs):
|
||||||
_dict.update({
|
_dict.update({
|
||||||
@@ -13,6 +11,9 @@ def update_with_kwargs(_dict, kwargs):
|
|||||||
})
|
})
|
||||||
return _dict
|
return _dict
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class TranscriptionEngine:
|
class TranscriptionEngine:
|
||||||
_instance = None
|
_instance = None
|
||||||
_initialized = False
|
_initialized = False
|
||||||
@@ -33,6 +34,7 @@ class TranscriptionEngine:
|
|||||||
"punctuation_split": False,
|
"punctuation_split": False,
|
||||||
"target_language": "",
|
"target_language": "",
|
||||||
"vac": True,
|
"vac": True,
|
||||||
|
"vac_onnx": False,
|
||||||
"vac_chunk_size": 0.04,
|
"vac_chunk_size": 0.04,
|
||||||
"log_level": "DEBUG",
|
"log_level": "DEBUG",
|
||||||
"ssl_certfile": None,
|
"ssl_certfile": None,
|
||||||
@@ -43,18 +45,20 @@ class TranscriptionEngine:
|
|||||||
"pcm_input": False,
|
"pcm_input": False,
|
||||||
"disable_punctuation_split" : False,
|
"disable_punctuation_split" : False,
|
||||||
"diarization_backend": "sortformer",
|
"diarization_backend": "sortformer",
|
||||||
|
"backend_policy": "simulstreaming",
|
||||||
|
"backend": "auto",
|
||||||
}
|
}
|
||||||
global_params = update_with_kwargs(global_params, kwargs)
|
global_params = update_with_kwargs(global_params, kwargs)
|
||||||
|
|
||||||
transcription_common_params = {
|
transcription_common_params = {
|
||||||
"backend": "simulstreaming",
|
|
||||||
"warmup_file": None,
|
"warmup_file": None,
|
||||||
"min_chunk_size": 0.5,
|
"min_chunk_size": 0.5,
|
||||||
"model_size": "tiny",
|
"model_size": "tiny",
|
||||||
"model_cache_dir": None,
|
"model_cache_dir": None,
|
||||||
"model_dir": None,
|
"model_dir": None,
|
||||||
|
"model_path": None,
|
||||||
"lan": "auto",
|
"lan": "auto",
|
||||||
"task": "transcribe",
|
"direct_english_translation": False,
|
||||||
}
|
}
|
||||||
transcription_common_params = update_with_kwargs(transcription_common_params, kwargs)
|
transcription_common_params = update_with_kwargs(transcription_common_params, kwargs)
|
||||||
|
|
||||||
@@ -75,13 +79,14 @@ class TranscriptionEngine:
|
|||||||
self.vac_model = None
|
self.vac_model = None
|
||||||
|
|
||||||
if self.args.vac:
|
if self.args.vac:
|
||||||
import torch
|
from whisperlivekit.silero_vad_iterator import load_silero_vad
|
||||||
self.vac_model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
|
# Use ONNX if specified, otherwise use JIT (default)
|
||||||
|
use_onnx = kwargs.get('vac_onnx', False)
|
||||||
|
self.vac_model = load_silero_vad(onnx=use_onnx)
|
||||||
|
|
||||||
|
backend_policy = self.args.backend_policy
|
||||||
if self.args.transcription:
|
if self.args.transcription:
|
||||||
if self.args.backend == "simulstreaming":
|
if backend_policy == "simulstreaming":
|
||||||
from whisperlivekit.simul_whisper import SimulStreamingASR
|
|
||||||
|
|
||||||
simulstreaming_params = {
|
simulstreaming_params = {
|
||||||
"disable_fast_encoder": False,
|
"disable_fast_encoder": False,
|
||||||
"custom_alignment_heads": None,
|
"custom_alignment_heads": None,
|
||||||
@@ -95,14 +100,19 @@ class TranscriptionEngine:
|
|||||||
"init_prompt": None,
|
"init_prompt": None,
|
||||||
"static_init_prompt": None,
|
"static_init_prompt": None,
|
||||||
"max_context_tokens": None,
|
"max_context_tokens": None,
|
||||||
"model_path": './base.pt',
|
|
||||||
"preload_model_count": 1,
|
"preload_model_count": 1,
|
||||||
}
|
}
|
||||||
simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
|
simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
|
||||||
|
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.asr = SimulStreamingASR(
|
self.asr = SimulStreamingASR(
|
||||||
**transcription_common_params, **simulstreaming_params
|
**transcription_common_params,
|
||||||
|
**simulstreaming_params,
|
||||||
|
backend=self.args.backend,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Using SimulStreaming policy with %s backend",
|
||||||
|
getattr(self.asr, "encoder_backend", "whisper"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
@@ -114,7 +124,13 @@ class TranscriptionEngine:
|
|||||||
whisperstreaming_params = update_with_kwargs(whisperstreaming_params, kwargs)
|
whisperstreaming_params = update_with_kwargs(whisperstreaming_params, kwargs)
|
||||||
|
|
||||||
self.asr = backend_factory(
|
self.asr = backend_factory(
|
||||||
**transcription_common_params, **whisperstreaming_params
|
backend=self.args.backend,
|
||||||
|
**transcription_common_params,
|
||||||
|
**whisperstreaming_params,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Using LocalAgreement policy with %s backend",
|
||||||
|
getattr(self.asr, "backend_choice", self.asr.__class__.__name__),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.args.diarization:
|
if self.args.diarization:
|
||||||
@@ -135,12 +151,15 @@ class TranscriptionEngine:
|
|||||||
|
|
||||||
self.translation_model = None
|
self.translation_model = None
|
||||||
if self.args.target_language:
|
if self.args.target_language:
|
||||||
if self.args.lan == 'auto' and self.args.backend != "simulstreaming":
|
if self.args.lan == 'auto' and backend_policy != "simulstreaming":
|
||||||
raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming')
|
raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming')
|
||||||
else:
|
else:
|
||||||
from whisperlivekit.translation.translation import load_model
|
try:
|
||||||
|
from nllw import load_model
|
||||||
|
except:
|
||||||
|
raise Exception('To use translation, you must install nllw: `pip install nllw`')
|
||||||
translation_params = {
|
translation_params = {
|
||||||
"nllb_backend": "ctranslate2",
|
"nllb_backend": "transformers",
|
||||||
"nllb_size": "600M"
|
"nllb_size": "600M"
|
||||||
}
|
}
|
||||||
translation_params = update_with_kwargs(translation_params, kwargs)
|
translation_params = update_with_kwargs(translation_params, kwargs)
|
||||||
@@ -149,7 +168,7 @@ class TranscriptionEngine:
|
|||||||
|
|
||||||
|
|
||||||
def online_factory(args, asr):
|
def online_factory(args, asr):
|
||||||
if args.backend == "simulstreaming":
|
if args.backend_policy == "simulstreaming":
|
||||||
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
||||||
online = SimulStreamingOnlineProcessor(asr)
|
online = SimulStreamingOnlineProcessor(asr)
|
||||||
else:
|
else:
|
||||||
@@ -172,5 +191,5 @@ def online_translation_factory(args, translation_model):
|
|||||||
#should be at speaker level in the future:
|
#should be at speaker level in the future:
|
||||||
#one shared nllb model for all speaker
|
#one shared nllb model for all speaker
|
||||||
#one tokenizer per speaker/language
|
#one tokenizer per speaker/language
|
||||||
from whisperlivekit.translation.translation import OnlineTranslation
|
from nllw import OnlineTranslation
|
||||||
return OnlineTranslation(translation_model, [args.lan], [args.target_language])
|
return OnlineTranslation(translation_model, [args.lan], [args.target_language])
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import math
|
|||||||
from typing import List
|
from typing import List
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from whisperlivekit.timed_objects import ASRToken
|
from whisperlivekit.timed_objects import ASRToken
|
||||||
|
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
|
||||||
|
from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
class ASRBase:
|
class ASRBase:
|
||||||
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
||||||
@@ -37,40 +39,60 @@ class ASRBase:
|
|||||||
raise NotImplementedError("must be implemented in the child class")
|
raise NotImplementedError("must be implemented in the child class")
|
||||||
|
|
||||||
|
|
||||||
class WhisperTimestampedASR(ASRBase):
|
class WhisperASR(ASRBase):
|
||||||
"""Uses whisper_timestamped as the backend."""
|
"""Uses WhisperLiveKit's built-in Whisper implementation."""
|
||||||
sep = " "
|
sep = " "
|
||||||
|
|
||||||
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
||||||
import whisper
|
from whisperlivekit.whisper import load_model as load_model
|
||||||
import whisper_timestamped
|
|
||||||
from whisper_timestamped import transcribe_timestamped
|
|
||||||
|
|
||||||
self.transcribe_timestamped = transcribe_timestamped
|
|
||||||
if model_dir is not None:
|
if model_dir is not None:
|
||||||
logger.debug("ignoring model_dir, not implemented")
|
resolved_path = resolve_model_path(model_dir)
|
||||||
return whisper.load_model(model_size, download_root=cache_dir)
|
if resolved_path.is_dir():
|
||||||
|
pytorch_path, _, _ = model_path_and_type(resolved_path)
|
||||||
|
if pytorch_path is None:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"No supported PyTorch checkpoint found under {resolved_path}"
|
||||||
|
)
|
||||||
|
resolved_path = pytorch_path
|
||||||
|
logger.debug(f"Loading Whisper model from custom path {resolved_path}")
|
||||||
|
return load_model(str(resolved_path))
|
||||||
|
|
||||||
|
if model_size is None:
|
||||||
|
raise ValueError("Either model_size or model_dir must be set for WhisperASR")
|
||||||
|
|
||||||
|
return load_model(model_size, download_root=cache_dir)
|
||||||
|
|
||||||
def transcribe(self, audio, init_prompt=""):
|
def transcribe(self, audio, init_prompt=""):
|
||||||
result = self.transcribe_timestamped(
|
options = dict(self.transcribe_kargs)
|
||||||
|
options.pop("vad", None)
|
||||||
|
options.pop("vad_filter", None)
|
||||||
|
language = self.original_language if self.original_language else None
|
||||||
|
|
||||||
|
result = whisper_transcribe(
|
||||||
self.model,
|
self.model,
|
||||||
audio,
|
audio,
|
||||||
language=self.original_language,
|
language=language,
|
||||||
initial_prompt=init_prompt,
|
initial_prompt=init_prompt,
|
||||||
verbose=None,
|
|
||||||
condition_on_previous_text=True,
|
condition_on_previous_text=True,
|
||||||
**self.transcribe_kargs,
|
word_timestamps=True,
|
||||||
|
**options,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def ts_words(self, r) -> List[ASRToken]:
|
def ts_words(self, r) -> List[ASRToken]:
|
||||||
"""
|
"""
|
||||||
Converts the whisper_timestamped result to a list of ASRToken objects.
|
Converts the Whisper result to a list of ASRToken objects.
|
||||||
"""
|
"""
|
||||||
tokens = []
|
tokens = []
|
||||||
for segment in r["segments"]:
|
for segment in r["segments"]:
|
||||||
for word in segment["words"]:
|
for word in segment["words"]:
|
||||||
token = ASRToken(word["start"], word["end"], word["text"])
|
token = ASRToken(
|
||||||
|
word["start"],
|
||||||
|
word["end"],
|
||||||
|
word["word"],
|
||||||
|
probability=word.get("probability"),
|
||||||
|
)
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
@@ -78,11 +100,7 @@ class WhisperTimestampedASR(ASRBase):
|
|||||||
return [segment["end"] for segment in res["segments"]]
|
return [segment["end"] for segment in res["segments"]]
|
||||||
|
|
||||||
def use_vad(self):
|
def use_vad(self):
|
||||||
self.transcribe_kargs["vad"] = True
|
logger.warning("VAD is not currently supported for WhisperASR backend and will be ignored.")
|
||||||
|
|
||||||
def set_translate_task(self):
|
|
||||||
self.transcribe_kargs["task"] = "translate"
|
|
||||||
|
|
||||||
|
|
||||||
class FasterWhisperASR(ASRBase):
|
class FasterWhisperASR(ASRBase):
|
||||||
"""Uses faster-whisper as the backend."""
|
"""Uses faster-whisper as the backend."""
|
||||||
@@ -92,9 +110,10 @@ class FasterWhisperASR(ASRBase):
|
|||||||
from faster_whisper import WhisperModel
|
from faster_whisper import WhisperModel
|
||||||
|
|
||||||
if model_dir is not None:
|
if model_dir is not None:
|
||||||
logger.debug(f"Loading whisper model from model_dir {model_dir}. "
|
resolved_path = resolve_model_path(model_dir)
|
||||||
|
logger.debug(f"Loading faster-whisper model from {resolved_path}. "
|
||||||
f"model_size and cache_dir parameters are not used.")
|
f"model_size and cache_dir parameters are not used.")
|
||||||
model_size_or_path = model_dir
|
model_size_or_path = str(resolved_path)
|
||||||
elif model_size is not None:
|
elif model_size is not None:
|
||||||
model_size_or_path = model_size
|
model_size_or_path = model_size
|
||||||
else:
|
else:
|
||||||
@@ -139,10 +158,6 @@ class FasterWhisperASR(ASRBase):
|
|||||||
def use_vad(self):
|
def use_vad(self):
|
||||||
self.transcribe_kargs["vad_filter"] = True
|
self.transcribe_kargs["vad_filter"] = True
|
||||||
|
|
||||||
def set_translate_task(self):
|
|
||||||
self.transcribe_kargs["task"] = "translate"
|
|
||||||
|
|
||||||
|
|
||||||
class MLXWhisper(ASRBase):
|
class MLXWhisper(ASRBase):
|
||||||
"""
|
"""
|
||||||
Uses MLX Whisper optimized for Apple Silicon.
|
Uses MLX Whisper optimized for Apple Silicon.
|
||||||
@@ -154,8 +169,9 @@ class MLXWhisper(ASRBase):
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
if model_dir is not None:
|
if model_dir is not None:
|
||||||
logger.debug(f"Loading whisper model from model_dir {model_dir}. model_size parameter is not used.")
|
resolved_path = resolve_model_path(model_dir)
|
||||||
model_size_or_path = model_dir
|
logger.debug(f"Loading MLX Whisper model from {resolved_path}. model_size parameter is not used.")
|
||||||
|
model_size_or_path = str(resolved_path)
|
||||||
elif model_size is not None:
|
elif model_size is not None:
|
||||||
model_size_or_path = self.translate_model_name(model_size)
|
model_size_or_path = self.translate_model_name(model_size)
|
||||||
logger.debug(f"Loading whisper model {model_size}. You use mlx whisper, so {model_size_or_path} will be used.")
|
logger.debug(f"Loading whisper model {model_size}. You use mlx whisper, so {model_size_or_path} will be used.")
|
||||||
@@ -218,10 +234,6 @@ class MLXWhisper(ASRBase):
|
|||||||
def use_vad(self):
|
def use_vad(self):
|
||||||
self.transcribe_kargs["vad_filter"] = True
|
self.transcribe_kargs["vad_filter"] = True
|
||||||
|
|
||||||
def set_translate_task(self):
|
|
||||||
self.transcribe_kargs["task"] = "translate"
|
|
||||||
|
|
||||||
|
|
||||||
class OpenaiApiASR(ASRBase):
|
class OpenaiApiASR(ASRBase):
|
||||||
"""Uses OpenAI's Whisper API for transcription."""
|
"""Uses OpenAI's Whisper API for transcription."""
|
||||||
def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
|
def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
|
||||||
@@ -232,7 +244,7 @@ class OpenaiApiASR(ASRBase):
|
|||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.load_model()
|
self.load_model()
|
||||||
self.use_vad_opt = False
|
self.use_vad_opt = False
|
||||||
self.task = "transcribe"
|
self.direct_english_translation = False
|
||||||
|
|
||||||
def load_model(self, *args, **kwargs):
|
def load_model(self, *args, **kwargs):
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
@@ -274,7 +286,7 @@ class OpenaiApiASR(ASRBase):
|
|||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"timestamp_granularities": ["word", "segment"],
|
"timestamp_granularities": ["word", "segment"],
|
||||||
}
|
}
|
||||||
if self.task != "translate" and self.original_language:
|
if not self.direct_english_translation and self.original_language:
|
||||||
params["language"] = self.original_language
|
params["language"] = self.original_language
|
||||||
if prompt:
|
if prompt:
|
||||||
params["prompt"] = prompt
|
params["prompt"] = prompt
|
||||||
@@ -285,6 +297,3 @@ class OpenaiApiASR(ASRBase):
|
|||||||
|
|
||||||
def use_vad(self):
|
def use_vad(self):
|
||||||
self.use_vad_opt = True
|
self.use_vad_opt = True
|
||||||
|
|
||||||
def set_translate_task(self):
|
|
||||||
self.task = "translate"
|
|
||||||
199
whisperlivekit/local_agreement/whisper_online.py
Normal file
199
whisperlivekit/local_agreement/whisper_online.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
import librosa
|
||||||
|
from functools import lru_cache
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
import platform
|
||||||
|
from .backends import FasterWhisperASR, MLXWhisper, WhisperASR, OpenaiApiASR
|
||||||
|
from whisperlivekit.warmup import warmup_asr
|
||||||
|
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
|
||||||
|
from whisperlivekit.backend_support import (
|
||||||
|
mlx_backend_available,
|
||||||
|
faster_backend_available,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# List of language-code strings built by splitting the comma-separated literal.
# create_tokenizer() asserts that its `lan` argument is a member of this list.
WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
    ","
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_tokenizer(lan):
    """Return a sentence splitter for ``lan`` that exposes a ``split`` method.

    The returned object mimics the interface of MosesTokenizer. Selection order:

    1. ``"uk"`` -> wrapper around ``tokenize_uk.tokenize_sents``.
    2. Codes handled by fast-mosestokenizer -> ``MosesSentenceSplitter(lan)``.
    3. Anything else -> wrapper around a wtpsplit ``WtP`` model; codes that
       wtpsplit does not know are passed to it as ``lang_code=None``.

    Raises:
        AssertionError: if ``lan`` is not listed in ``WHISPER_LANG_CODES``.
    """
    assert lan in WHISPER_LANG_CODES, (
        "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
    )

    if lan == "uk":
        import tokenize_uk

        class UkrainianTokenizer:
            def split(self, text):
                return tokenize_uk.tokenize_sents(text)

        return UkrainianTokenizer()

    # supported by fast-mosestokenizer
    moses_languages = "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
    if lan in moses_languages:
        from mosestokenizer import MosesSentenceSplitter

        return MosesSentenceSplitter(lan)

    # the following languages are in Whisper, but not in wtpsplit:
    missing_in_wtpsplit = "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
    if lan in missing_in_wtpsplit:
        logger.debug(
            f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
        )
        wtp_lang_code = None
    else:
        wtp_lang_code = lan

    from wtpsplit import WtP

    # downloads the model from huggingface on the first use
    splitter_model = WtP("wtp-canine-s-12l-no-adapters")

    class WtPtok:
        def split(self, sent):
            return splitter_model.split(sent, lang_code=wtp_lang_code)

    return WtPtok()
|
||||||
|
|
||||||
|
|
||||||
|
def backend_factory(
    backend,
    lan,
    model_size,
    model_cache_dir,
    model_dir,
    model_path,
    direct_english_translation,
    buffer_trimming,
    buffer_trimming_sec,
    confidence_validation,
    warmup_file=None,
    min_chunk_size=None,
):
    """Create, warm up and configure the ASR object for the requested backend.

    Args:
        backend: backend name; ``"openai-api"`` is handled here, every other
            value is validated by ``_normalize_backend_choice``.
        lan: source language, forwarded to the ASR constructor.
        model_size: model size name, forwarded to the local ASR constructor.
        model_cache_dir: forwarded to the local ASR constructor as ``cache_dir``.
        model_dir, model_path: custom model reference; ``model_path`` wins
            when both are set. It is resolved with ``resolve_model_path``.
        direct_english_translation: when truthy the tokenizer is built for
            ``"en"`` instead of ``lan``.
        buffer_trimming: a tokenizer is created only when this equals
            ``"sentence"``.
        buffer_trimming_sec, confidence_validation: stored on the ASR object.
        warmup_file: passed to ``warmup_asr``.
        min_chunk_size: accepted but not used inside this function.

    Returns:
        The ASR instance with ``confidence_validation``, ``tokenizer``,
        ``buffer_trimming``, ``buffer_trimming_sec`` and ``backend_choice``
        attributes set.

    Raises:
        ValueError: a directory-based backend was given a non-directory path.
        FileNotFoundError: the ``whisper`` backend found no PyTorch checkpoint
            for the custom reference.
    """
    model_reference = model_path or model_dir
    model_root = None
    checkpoint = None
    mlx_weights_found = False
    fw_weights_found = False

    if model_reference:
        model_root = resolve_model_path(model_reference)
        if not model_root.is_dir():
            # A single file is treated as the PyTorch checkpoint itself.
            checkpoint = model_root
        else:
            checkpoint, mlx_weights_found, fw_weights_found = model_path_and_type(model_root)

    if backend == "openai-api":
        selected_backend = backend
        logger.debug("Using OpenAI API.")
        asr = OpenaiApiASR(lan=lan)
    else:
        selected_backend = _normalize_backend_choice(
            backend,
            model_root,
            mlx_weights_found,
            fw_weights_found,
        )

        # Backends that load from a directory: (ASR class, error when given a file).
        directory_backends = {
            "faster-whisper": (
                FasterWhisperASR,
                "Faster-Whisper backend expects a directory with CTranslate2 weights.",
            ),
            "mlx-whisper": (
                MLXWhisper,
                "MLX Whisper backend expects a directory containing MLX weights.",
            ),
        }

        if selected_backend in directory_backends:
            asr_cls, not_a_directory_message = directory_backends[selected_backend]
            if model_root is not None and not model_root.is_dir():
                raise ValueError(not_a_directory_message)
            model_override = None if model_root is None else str(model_root)
        else:
            asr_cls = WhisperASR
            model_override = None if checkpoint is None else str(checkpoint)
            if model_reference and model_override is None:
                raise FileNotFoundError(
                    f"No PyTorch checkpoint found under {model_root or model_reference}"
                )

        load_started = time.time()
        logger.info(f"Loading Whisper {model_size} model for language {lan} using backend {selected_backend}...")
        asr = asr_cls(
            model_size=model_size,
            lan=lan,
            cache_dir=model_cache_dir,
            model_dir=model_override,
        )
        load_finished = time.time()
        logger.info(f"done. It took {round(load_finished-load_started,2)} seconds.")

    # Whisper translates into English; otherwise it transcribes in `lan`.
    tgt_language = "en" if direct_english_translation else lan

    # Create the tokenizer
    tokenizer = create_tokenizer(tgt_language) if buffer_trimming == "sentence" else None

    warmup_asr(asr, warmup_file)

    asr.confidence_validation = confidence_validation
    asr.tokenizer = tokenizer
    asr.buffer_trimming = buffer_trimming
    asr.buffer_trimming_sec = buffer_trimming_sec
    asr.backend_choice = selected_backend
    return asr
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_backend_choice(
    preferred_backend,
    resolved_root,
    has_mlx_weights,
    has_fw_weights,
):
    """Turn the requested backend name into a concrete, validated one.

    Args:
        preferred_backend: ``"auto"``, ``"mlx-whisper"``, ``"faster-whisper"``
            or ``"whisper"``.
        resolved_root: resolved custom model path, or ``None`` when no custom
            model was given.
        has_mlx_weights: whether MLX weights were found under ``resolved_root``.
        has_fw_weights: whether Faster-Whisper weights were found there.

    Returns:
        ``"mlx-whisper"``, ``"faster-whisper"`` or ``"whisper"``.

    Raises:
        RuntimeError: an explicitly requested backend is not installed.
        FileNotFoundError: an explicitly requested backend has no matching
            weights under ``resolved_root``.
        ValueError: ``preferred_backend`` is not a known name.
    """
    if preferred_backend not in ("auto", "mlx-whisper", "faster-whisper", "whisper"):
        raise ValueError(f"Unknown backend '{preferred_backend}' for LocalAgreement.")

    uses_default_model = resolved_root is None

    if preferred_backend == "auto":
        # Preference order: MLX, then Faster-Whisper, then plain Whisper.
        mlx_usable = mlx_backend_available(warn_on_missing=True) and (
            uses_default_model or has_mlx_weights
        )
        if mlx_usable:
            return "mlx-whisper"
        fw_usable = faster_backend_available(warn_on_missing=True) and (
            uses_default_model or has_fw_weights
        )
        return "faster-whisper" if fw_usable else "whisper"

    if preferred_backend == "mlx-whisper":
        if not mlx_backend_available():
            raise RuntimeError("mlx-whisper backend requested but mlx-whisper is not installed.")
        if not (uses_default_model or has_mlx_weights):
            raise FileNotFoundError(
                f"mlx-whisper backend requested but no MLX weights were found under {resolved_root}"
            )
        if platform.system() != "Darwin":
            logger.warning("mlx-whisper backend requested on a non-macOS system; this may fail.")
    elif preferred_backend == "faster-whisper":
        if not faster_backend_available():
            raise RuntimeError("faster-whisper backend requested but faster-whisper is not installed.")
        if not (uses_default_model or has_fw_weights):
            raise FileNotFoundError(
                f"faster-whisper backend requested but no Faster-Whisper weights were found under {resolved_root}"
            )

    # Explicit choices (including "whisper") are returned unchanged once validated.
    return preferred_backend
|
||||||
69
whisperlivekit/model_paths.py
Normal file
69
whisperlivekit/model_paths.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
|
def model_path_and_type(model_path: Union[str, Path]) -> Tuple[Optional[Path], bool, bool]:
    """
    Report which model formats can be found at ``model_path``.

    A file ending in ``.pt``, ``.safetensors`` or ``.bin`` is returned as the
    PyTorch checkpoint directly. A directory is scanned one level deep (regular
    files only, names compared case-insensitively):

    * ``weights.npz`` / ``weights.safetensors`` mark MLX weights;
    * ``model.bin`` / ``encoder.bin`` / ``decoder.bin`` mark Faster-Whisper
      (ctranslate2) weights;
    * any other ``.pt`` / ``.safetensors`` file, or ``pytorch_model.bin``, is
      taken as the PyTorch checkpoint (the last one met during iteration wins).

    Returns:
        pytorch_path: Path to a PyTorch checkpoint, or None.
        compatible_whisper_mlx: True if MLX weights exist in this folder.
        compatible_faster_whisper: True if Faster-Whisper weights exist.
    """
    root = Path(model_path)

    if root.is_file() and root.suffix.lower() in [".pt", ".safetensors", ".bin"]:
        return root, False, False

    if not root.is_dir():
        return None, False, False

    mlx_filenames = {"weights.npz", "weights.safetensors"}
    fw_filenames = {"model.bin", "encoder.bin", "decoder.bin"}

    found_mlx = False
    found_fw = False
    checkpoint: Optional[Path] = None

    for entry in root.iterdir():
        if not entry.is_file():
            continue

        lowered_name = entry.name.lower()
        if lowered_name in mlx_filenames:
            found_mlx = True
            continue
        if lowered_name in fw_filenames:
            found_fw = True
            continue
        if entry.suffix.lower() in {".pt", ".safetensors"} or lowered_name == "pytorch_model.bin":
            checkpoint = entry

    if checkpoint is None:
        candidate = root / "pytorch_model.bin"
        if candidate.exists():
            checkpoint = candidate

    return checkpoint, found_mlx, found_fw
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_model_path(model_path: Union[str, Path]) -> Path:
    """
    Map a model reference to a path on the local filesystem.

    The reference is first tried as a local path (``~`` is expanded). When
    nothing exists there, the original reference is handed to
    ``huggingface_hub.snapshot_download`` as a repo id and the download
    location is returned.

    Raises:
        FileNotFoundError: the path does not exist locally and
            ``huggingface_hub`` cannot be imported.
    """
    local_candidate = Path(model_path).expanduser()
    if not local_candidate.exists():
        try:
            from huggingface_hub import snapshot_download
        except ImportError as exc:  # pragma: no cover - optional dependency guard
            raise FileNotFoundError(
                f"Model path '{model_path}' does not exist locally and huggingface_hub "
                "is not installed to download it."
            ) from exc

        return Path(snapshot_download(repo_id=str(model_path)))

    return local_candidate
|
||||||
@@ -114,11 +114,10 @@ def parse_args():
|
|||||||
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
|
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--task",
|
"--direct-english-translation",
|
||||||
type=str,
|
action="store_true",
|
||||||
default="transcribe",
|
default=False,
|
||||||
choices=["transcribe", "translate"],
|
help="Use Whisper to directly translate to english.",
|
||||||
help="Transcribe or translate.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -130,11 +129,18 @@ def parse_args():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend",
|
"--backend-policy",
|
||||||
type=str,
|
type=str,
|
||||||
default="simulstreaming",
|
default="simulstreaming",
|
||||||
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api", "simulstreaming"],
|
choices=["1", "2", "simulstreaming", "localagreement"],
|
||||||
help="Load only this backend for Whisper processing.",
|
help="Select the streaming policy: 1 or 'simulstreaming' for AlignAtt, 2 or 'localagreement' for LocalAgreement.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend",
|
||||||
|
type=str,
|
||||||
|
default="auto",
|
||||||
|
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api"],
|
||||||
|
help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no-vac",
|
"--no-vac",
|
||||||
@@ -300,7 +306,7 @@ def parse_args():
|
|||||||
simulstreaming_group.add_argument(
|
simulstreaming_group.add_argument(
|
||||||
"--nllb-backend",
|
"--nllb-backend",
|
||||||
type=str,
|
type=str,
|
||||||
default="ctranslate2",
|
default="transformers",
|
||||||
help="transformers or ctranslate2",
|
help="transformers or ctranslate2",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -317,5 +323,10 @@ def parse_args():
|
|||||||
args.vad = not args.no_vad
|
args.vad = not args.no_vad
|
||||||
delattr(args, 'no_transcription')
|
delattr(args, 'no_transcription')
|
||||||
delattr(args, 'no_vad')
|
delattr(args, 'no_vad')
|
||||||
|
|
||||||
|
if args.backend_policy == "1":
|
||||||
|
args.backend_policy = "simulstreaming"
|
||||||
|
elif args.backend_policy == "2":
|
||||||
|
args.backend_policy = "localagreement"
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from whisperlivekit.timed_objects import ASRToken
|
from whisperlivekit.timed_objects import ASRToken
|
||||||
|
from time import time
|
||||||
import re
|
import re
|
||||||
|
|
||||||
MIN_SILENCE_DURATION = 4 #in seconds
|
MIN_SILENCE_DURATION = 4 #in seconds
|
||||||
@@ -77,7 +78,8 @@ def no_token_to_silence(tokens):
|
|||||||
new_tokens.append(token)
|
new_tokens.append(token)
|
||||||
return new_tokens
|
return new_tokens
|
||||||
|
|
||||||
def ends_with_silence(tokens, current_time, vac_detected_silence):
|
def ends_with_silence(tokens, beg_loop, vac_detected_silence):
|
||||||
|
current_time = time() - (beg_loop if beg_loop else 0.0)
|
||||||
last_token = tokens[-1]
|
last_token = tokens[-1]
|
||||||
if vac_detected_silence or (current_time - last_token.end >= END_SILENCE_DURATION):
|
if vac_detected_silence or (current_time - last_token.end >= END_SILENCE_DURATION):
|
||||||
if last_token.speaker == -2:
|
if last_token.speaker == -2:
|
||||||
@@ -94,11 +96,11 @@ def ends_with_silence(tokens, current_time, vac_detected_silence):
|
|||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
def handle_silences(tokens, current_time, vac_detected_silence):
|
def handle_silences(tokens, beg_loop, vac_detected_silence):
|
||||||
if not tokens:
|
if not tokens:
|
||||||
return []
|
return []
|
||||||
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
|
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
|
||||||
tokens = no_token_to_silence(tokens)
|
tokens = no_token_to_silence(tokens)
|
||||||
tokens = ends_with_silence(tokens, current_time, vac_detected_silence)
|
tokens = ends_with_silence(tokens, beg_loop, vac_detected_silence)
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
@@ -52,18 +52,18 @@ def append_token_to_last_line(lines, sep, token):
|
|||||||
lines[-1].detected_language = token.detected_language
|
lines[-1].detected_language = token.detected_language
|
||||||
|
|
||||||
|
|
||||||
def format_output(state, silence, current_time, args, sep):
|
def format_output(state, silence, args, sep):
|
||||||
diarization = args.diarization
|
diarization = args.diarization
|
||||||
disable_punctuation_split = args.disable_punctuation_split
|
disable_punctuation_split = args.disable_punctuation_split
|
||||||
tokens = state.tokens
|
tokens = state.tokens
|
||||||
translated_segments = state.translated_segments # Here we will attribute the speakers only based on the timestamps of the segments
|
translation_validated_segments = state.translation_validated_segments # Here we will attribute the speakers only based on the timestamps of the segments
|
||||||
last_validated_token = state.last_validated_token
|
last_validated_token = state.last_validated_token
|
||||||
|
|
||||||
previous_speaker = 1
|
previous_speaker = 1
|
||||||
undiarized_text = []
|
undiarized_text = []
|
||||||
tokens = handle_silences(tokens, current_time, silence)
|
tokens = handle_silences(tokens, state.beg_loop, silence)
|
||||||
last_punctuation = None
|
for i in range(last_validated_token, len(tokens)):
|
||||||
for i, token in enumerate(tokens[last_validated_token:]):
|
token = tokens[i]
|
||||||
speaker = int(token.speaker)
|
speaker = int(token.speaker)
|
||||||
token.corrected_speaker = speaker
|
token.corrected_speaker = speaker
|
||||||
if not diarization:
|
if not diarization:
|
||||||
@@ -71,17 +71,10 @@ def format_output(state, silence, current_time, args, sep):
|
|||||||
token.corrected_speaker = 1
|
token.corrected_speaker = 1
|
||||||
token.validated_speaker = True
|
token.validated_speaker = True
|
||||||
else:
|
else:
|
||||||
# if token.end > end_attributed_speaker and token.speaker != -2:
|
|
||||||
# if tokens[-1].speaker == -2: #if it finishes by a silence, we want to append the undiarized text to the last speaker.
|
|
||||||
# token.corrected_speaker = previous_speaker
|
|
||||||
# else:
|
|
||||||
# undiarized_text.append(token.text)
|
|
||||||
# continue
|
|
||||||
# else:
|
|
||||||
if is_punctuation(token):
|
if is_punctuation(token):
|
||||||
last_punctuation = i
|
state.last_punctuation_index = i
|
||||||
|
|
||||||
if last_punctuation == i-1:
|
if state.last_punctuation_index == i-1:
|
||||||
if token.speaker != previous_speaker:
|
if token.speaker != previous_speaker:
|
||||||
token.validated_speaker = True
|
token.validated_speaker = True
|
||||||
# perfect, diarization perfectly aligned
|
# perfect, diarization perfectly aligned
|
||||||
@@ -123,9 +116,9 @@ def format_output(state, silence, current_time, args, sep):
|
|||||||
|
|
||||||
previous_speaker = token.corrected_speaker
|
previous_speaker = token.corrected_speaker
|
||||||
|
|
||||||
if lines and translated_segments:
|
if lines:
|
||||||
unassigned_translated_segments = []
|
unassigned_translated_segments = []
|
||||||
for ts in translated_segments:
|
for ts in translation_validated_segments:
|
||||||
assigned = False
|
assigned = False
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if ts and ts.overlaps_with(line):
|
if ts and ts.overlaps_with(line):
|
||||||
|
|||||||
@@ -1,27 +1,182 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
# This is copied from silero-vad's vad_utils.py:
|
"""
|
||||||
# https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/utils_vad.py#L340
|
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
|
||||||
# (except changed defaults)
|
"""
|
||||||
|
|
||||||
# Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
|
def init_jit_model(model_path: str, device=torch.device('cpu')):
    """Load a TorchScript model from ``model_path`` onto ``device`` and switch it to eval mode."""
    jit_model = torch.jit.load(model_path, map_location=device)
    jit_model.eval()
    return jit_model
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxWrapper():
    """Silero VAD model executed through ONNX Runtime.

    The wrapper keeps the recurrent state and a short tail of the previous
    chunk (the "context") between calls, and resets both whenever the sampling
    rate or batch size changes.
    """

    def __init__(self, path, force_onnx_cpu=False):
        """Open an inference session for the model file at ``path``.

        ``force_onnx_cpu`` restricts the session to ``CPUExecutionProvider``
        when that provider is available.
        """
        global np
        import numpy as np
        import onnxruntime

        session_options = onnxruntime.SessionOptions()
        session_options.inter_op_num_threads = 1
        session_options.intra_op_num_threads = 1

        session_kwargs = {'sess_options': session_options}
        if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
            session_kwargs['providers'] = ['CPUExecutionProvider']
        self.session = onnxruntime.InferenceSession(path, **session_kwargs)

        self.reset_states()

        # A '16k' marker in the path selects the 16 kHz-only configuration.
        only_16k = '16k' in path
        if only_16k:
            warnings.warn('This model support only 16000 sampling rate!')
        self.sample_rates = [16000] if only_16k else [8000, 16000]

    def _validate_input(self, x, sr: int):
        """Normalise ``x`` to 2-D and ``sr`` to a supported rate.

        Returns the (possibly unsqueezed / decimated) tensor and the effective
        sampling rate. Raises ``ValueError`` for inputs with more than two
        dimensions, unsupported rates, or chunks that are too short.
        """
        rank = x.dim()
        if rank == 1:
            x = x.unsqueeze(0)
        elif rank > 2:
            raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")

        # Integer multiples of 16 kHz are decimated down to 16 kHz.
        if sr % 16000 == 0 and sr != 16000:
            stride = sr // 16000
            x = x[:,::stride]
            sr = 16000

        if sr not in self.sample_rates:
            raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
        if sr / x.shape[1] > 31.25:
            raise ValueError("Input audio chunk is too short")

        return x, sr

    def reset_states(self, batch_size=1):
        """Zero the recurrent state and forget the context and last-call bookkeeping."""
        self._state = torch.zeros((2, batch_size, 128)).float()
        self._context = torch.zeros(0)
        self._last_sr = 0
        self._last_batch_size = 0

    def __call__(self, x, sr: int):
        """Run the model on one chunk and return its output as a torch tensor.

        After validation the chunk must hold exactly 512 samples at 16 kHz or
        256 samples at 8 kHz; otherwise ``ValueError`` is raised.
        """
        x, sr = self._validate_input(x, sr)

        is_16k = sr == 16000
        expected_samples = 512 if is_16k else 256
        if x.shape[-1] != expected_samples:
            raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")

        batch_size = x.shape[0]
        context_size = 64 if is_16k else 32

        # Start from a clean state on the first call or when sr / batch size changed.
        previous_sr = self._last_sr
        previous_batch = self._last_batch_size
        needs_reset = (
            (not previous_batch)
            or (previous_sr and previous_sr != sr)
            or (previous_batch and previous_batch != batch_size)
        )
        if needs_reset:
            self.reset_states(batch_size)

        if not len(self._context):
            self._context = torch.zeros(batch_size, context_size)

        x = torch.cat([self._context, x], dim=1)
        if sr not in [8000, 16000]:
            raise ValueError()

        ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
        out, state = self.session.run(None, ort_inputs)
        self._state = torch.from_numpy(state)

        # Keep the tail of this chunk to prepend to the next one.
        self._context = x[..., -context_size:]
        self._last_sr = sr
        self._last_batch_size = batch_size

        return torch.from_numpy(out)
|
||||||
|
|
||||||
|
|
||||||
|
def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: int = 16):
    """
    Load the Silero VAD model, either as TorchScript or through ONNX Runtime.

    Parameters
    ----------
    model_path : str, optional
        Model file to load. When None, the file is looked up in the
        ``vad_models`` directory next to this module.
    onnx : bool, default False
        Load with ``OnnxWrapper`` (needs the onnxruntime package) instead of
        ``init_jit_model``.
    opset_version : int, default 16
        ONNX opset version (15 or 16). Only checked and used when onnx=True.

    Returns
    -------
    model
        The loaded VAD model (JIT module or ONNX wrapper).

    Raises
    ------
    Exception
        If onnx=True and ``opset_version`` is not 15 or 16.
    FileNotFoundError
        If the default model file does not exist.
    ImportError
        If the ONNX wrapper cannot import its runtime.
    """
    supported_opsets = [15, 16]
    if onnx and opset_version not in supported_opsets:
        raise Exception(f'Available ONNX opset_version: {supported_opsets}')

    if model_path is not None:
        model_path = Path(model_path)
    else:
        if not onnx:
            default_name = 'silero_vad.jit'
        elif opset_version == 16:
            default_name = 'silero_vad.onnx'
        else:
            default_name = f'silero_vad_16k_op{opset_version}.onnx'

        model_path = Path(__file__).parent / 'vad_models' / default_name

        if not model_path.exists():
            raise FileNotFoundError(
                f"Model file not found: {model_path}\n"
                f"Please ensure the whisperlivekit/vad_models/ directory contains the model files."
            )

    if not onnx:
        return init_jit_model(str(model_path))

    try:
        return OnnxWrapper(str(model_path), force_onnx_cpu=True)
    except ImportError:
        raise ImportError(
            "ONNX runtime not available. Install with: pip install onnxruntime\n"
            "Or use JIT model by setting onnx=False"
        )
|
||||||
|
|
||||||
|
|
||||||
class VADIterator:
|
class VADIterator:
|
||||||
def __init__(
|
"""
|
||||||
self,
|
Voice Activity Detection iterator for streaming audio.
|
||||||
model,
|
|
||||||
threshold: float = 0.5,
|
This is the Silero VAD v6 implementation.
|
||||||
sampling_rate: int = 16000,
|
"""
|
||||||
min_silence_duration_ms: int = 500, # makes sense on one recording that I checked
|
|
||||||
speech_pad_ms: int = 100, # same
|
def __init__(self,
|
||||||
):
|
model,
|
||||||
|
threshold: float = 0.5,
|
||||||
|
sampling_rate: int = 16000,
|
||||||
|
min_silence_duration_ms: int = 100,
|
||||||
|
speech_pad_ms: int = 30
|
||||||
|
):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Class for stream imitation
|
Class for stream imitation
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
model: preloaded .jit silero VAD model
|
model: preloaded .jit/.onnx silero VAD model
|
||||||
|
|
||||||
threshold: float (default - 0.5)
|
threshold: float (default - 0.5)
|
||||||
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
|
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
|
||||||
@@ -42,9 +197,7 @@ class VADIterator:
|
|||||||
self.sampling_rate = sampling_rate
|
self.sampling_rate = sampling_rate
|
||||||
|
|
||||||
if sampling_rate not in [8000, 16000]:
|
if sampling_rate not in [8000, 16000]:
|
||||||
raise ValueError(
|
raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
|
||||||
"VADIterator does not support sampling rates other than [8000, 16000]"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
|
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
|
||||||
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
||||||
@@ -57,13 +210,17 @@ class VADIterator:
|
|||||||
self.temp_end = 0
|
self.temp_end = 0
|
||||||
self.current_sample = 0
|
self.current_sample = 0
|
||||||
|
|
||||||
def __call__(self, x, return_seconds=False):
|
@torch.no_grad()
|
||||||
|
def __call__(self, x, return_seconds=False, time_resolution: int = 1):
|
||||||
"""
|
"""
|
||||||
x: torch.Tensor
|
x: torch.Tensor
|
||||||
audio chunk (see examples in repo)
|
audio chunk (see examples in repo)
|
||||||
|
|
||||||
return_seconds: bool (default - False)
|
return_seconds: bool (default - False)
|
||||||
whether return timestamps in seconds (default - samples)
|
whether return timestamps in seconds (default - samples)
|
||||||
|
|
||||||
|
time_resolution: int (default - 1)
|
||||||
|
time resolution of speech coordinates when requested as seconds
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not torch.is_tensor(x):
|
if not torch.is_tensor(x):
|
||||||
@@ -82,14 +239,8 @@ class VADIterator:
|
|||||||
|
|
||||||
if (speech_prob >= self.threshold) and not self.triggered:
|
if (speech_prob >= self.threshold) and not self.triggered:
|
||||||
self.triggered = True
|
self.triggered = True
|
||||||
speech_start = self.current_sample - self.speech_pad_samples
|
speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
|
||||||
return {
|
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)}
|
||||||
"start": (
|
|
||||||
int(speech_start)
|
|
||||||
if not return_seconds
|
|
||||||
else round(speech_start / self.sampling_rate, 1)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if (speech_prob < self.threshold - 0.15) and self.triggered:
|
if (speech_prob < self.threshold - 0.15) and self.triggered:
|
||||||
if not self.temp_end:
|
if not self.temp_end:
|
||||||
@@ -97,30 +248,17 @@ class VADIterator:
|
|||||||
if self.current_sample - self.temp_end < self.min_silence_samples:
|
if self.current_sample - self.temp_end < self.min_silence_samples:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
speech_end = self.temp_end + self.speech_pad_samples
|
speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
|
||||||
self.temp_end = 0
|
self.temp_end = 0
|
||||||
self.triggered = False
|
self.triggered = False
|
||||||
return {
|
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)}
|
||||||
"end": (
|
|
||||||
int(speech_end)
|
|
||||||
if not return_seconds
|
|
||||||
else round(speech_end / self.sampling_rate, 1)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
#######################
|
|
||||||
# because Silero now requires exactly 512-sized audio chunks
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
class FixedVADIterator(VADIterator):
|
class FixedVADIterator(VADIterator):
|
||||||
"""It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once.
|
"""
|
||||||
If audio to be processed at once is long and multiple voiced segments detected,
|
Fixed VAD Iterator that handles variable-length audio chunks, not only exactly 512 frames at once.
|
||||||
then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def reset_states(self):
|
def reset_states(self):
|
||||||
@@ -137,27 +275,20 @@ class FixedVADIterator(VADIterator):
|
|||||||
ret = r
|
ret = r
|
||||||
elif r is not None:
|
elif r is not None:
|
||||||
if "end" in r:
|
if "end" in r:
|
||||||
ret["end"] = r["end"] # the latter end
|
ret["end"] = r["end"]
|
||||||
if "start" in r and "end" in ret: # there is an earlier start.
|
if "start" in r and "end" in ret:
|
||||||
# Remove end, merging this segment with the previous one.
|
|
||||||
del ret["end"]
|
del ret["end"]
|
||||||
return ret if ret != {} else None
|
return ret if ret != {} else None
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# test/demonstrate the need for FixedVADIterator:
|
model = load_silero_vad(onnx=False)
|
||||||
|
vad = FixedVADIterator(model)
|
||||||
import torch
|
|
||||||
|
audio_buffer = np.array([0] * 512, dtype=np.float32)
|
||||||
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
|
result = vad(audio_buffer)
|
||||||
vac = FixedVADIterator(model)
|
print(f" 512 samples: {result}")
|
||||||
# vac = VADIterator(model) # the second case crashes with this
|
|
||||||
|
# test with 511 samples
|
||||||
# this works: for both
|
audio_buffer = np.array([0] * 511, dtype=np.float32)
|
||||||
audio_buffer = np.array([0] * (512), dtype=np.float32)
|
result = vad(audio_buffer)
|
||||||
vac(audio_buffer)
|
|
||||||
|
|
||||||
# this crashes on the non FixedVADIterator with
|
|
||||||
# ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
|
|
||||||
audio_buffer = np.array([0] * (512 - 1), dtype=np.float32)
|
|
||||||
vac(audio_buffer)
|
|
||||||
@@ -2,41 +2,37 @@ import sys
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Tuple, Optional
|
from typing import List, Tuple, Optional
|
||||||
import logging
|
|
||||||
import platform
|
import platform
|
||||||
from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
|
from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
|
||||||
from whisperlivekit.warmup import load_file
|
from whisperlivekit.warmup import load_file
|
||||||
from .whisper import load_model, tokenizer
|
from whisperlivekit.whisper import load_model, tokenizer
|
||||||
from .whisper.audio import TOKENS_PER_SECOND
|
from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
|
||||||
import os
|
import os
|
||||||
import gc
|
import gc
|
||||||
logger = logging.getLogger(__name__)
|
from pathlib import Path
|
||||||
|
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
|
||||||
|
from whisperlivekit.backend_support import (
|
||||||
|
mlx_backend_available,
|
||||||
|
faster_backend_available,
|
||||||
|
)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||||
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
|
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
|
||||||
from whisperlivekit.simul_whisper.whisper import tokenizer
|
|
||||||
|
|
||||||
try:
|
logger = logging.getLogger(__name__)
|
||||||
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
|
|
||||||
HAS_MLX_WHISPER = True
|
|
||||||
except ImportError:
|
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
|
||||||
if platform.system() == "Darwin" and platform.machine() == "arm64":
|
|
||||||
print(f"""{"="*50}
|
|
||||||
MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper
|
|
||||||
{"="*50}""")
|
|
||||||
HAS_MLX_WHISPER = False
|
|
||||||
if HAS_MLX_WHISPER:
|
if HAS_MLX_WHISPER:
|
||||||
HAS_FASTER_WHISPER = False
|
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
|
||||||
else:
|
else:
|
||||||
try:
|
mlx_model_mapping = {}
|
||||||
from faster_whisper import WhisperModel
|
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
|
||||||
HAS_FASTER_WHISPER = True
|
if HAS_FASTER_WHISPER:
|
||||||
except ImportError:
|
from faster_whisper import WhisperModel
|
||||||
HAS_FASTER_WHISPER = False
|
else:
|
||||||
|
WhisperModel = None
|
||||||
|
|
||||||
# TOO_MANY_REPETITIONS = 3
|
|
||||||
|
|
||||||
class SimulStreamingOnlineProcessor:
|
class SimulStreamingOnlineProcessor:
|
||||||
SAMPLING_RATE = 16000
|
SAMPLING_RATE = 16000
|
||||||
@@ -154,8 +150,22 @@ class SimulStreamingASR():
|
|||||||
self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
|
self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
|
||||||
|
|
||||||
self.fast_encoder = False
|
self.fast_encoder = False
|
||||||
if self.model_dir is not None:
|
self._resolved_model_path = None
|
||||||
self.model_path = self.model_dir
|
self.encoder_backend = "whisper"
|
||||||
|
preferred_backend = getattr(self, "backend", "auto")
|
||||||
|
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
|
||||||
|
if self.model_path:
|
||||||
|
resolved_model_path = resolve_model_path(self.model_path)
|
||||||
|
self._resolved_model_path = resolved_model_path
|
||||||
|
self.model_path = str(resolved_model_path)
|
||||||
|
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(resolved_model_path)
|
||||||
|
if self.pytorch_path:
|
||||||
|
self.model_name = self.pytorch_path.stem
|
||||||
|
else:
|
||||||
|
self.model_name = Path(self.model_path).stem
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
|
||||||
|
)
|
||||||
elif self.model_size is not None:
|
elif self.model_size is not None:
|
||||||
model_mapping = {
|
model_mapping = {
|
||||||
'tiny': './tiny.pt',
|
'tiny': './tiny.pt',
|
||||||
@@ -171,10 +181,23 @@ class SimulStreamingASR():
|
|||||||
'large-v3': './large-v3.pt',
|
'large-v3': './large-v3.pt',
|
||||||
'large': './large-v3.pt'
|
'large': './large-v3.pt'
|
||||||
}
|
}
|
||||||
self.model_path = model_mapping.get(self.model_size, f'./{self.model_size}.pt')
|
self.model_name = self.model_size
|
||||||
|
else:
|
||||||
|
raise ValueError("Either model_size or model_path must be specified for SimulStreaming.")
|
||||||
|
|
||||||
|
is_multilingual = not self.model_name.endswith(".en")
|
||||||
|
|
||||||
|
self.encoder_backend = self._resolve_encoder_backend(
|
||||||
|
preferred_backend,
|
||||||
|
compatible_whisper_mlx,
|
||||||
|
compatible_faster_whisper,
|
||||||
|
)
|
||||||
|
self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper")
|
||||||
|
if self.encoder_backend == "whisper":
|
||||||
|
self.disable_fast_encoder = True
|
||||||
|
|
||||||
self.cfg = AlignAttConfig(
|
self.cfg = AlignAttConfig(
|
||||||
model_path=self.model_path,
|
tokenizer_is_multilingual= is_multilingual,
|
||||||
segment_length=self.min_chunk_size,
|
segment_length=self.min_chunk_size,
|
||||||
frame_threshold=self.frame_threshold,
|
frame_threshold=self.frame_threshold,
|
||||||
language=self.lan,
|
language=self.lan,
|
||||||
@@ -183,7 +206,7 @@ class SimulStreamingASR():
|
|||||||
cif_ckpt_path=self.cif_ckpt_path,
|
cif_ckpt_path=self.cif_ckpt_path,
|
||||||
decoder_type="beam",
|
decoder_type="beam",
|
||||||
beam_size=self.beams,
|
beam_size=self.beams,
|
||||||
task=self.task,
|
task=self.direct_english_translation,
|
||||||
never_fire=self.never_fire,
|
never_fire=self.never_fire,
|
||||||
init_prompt=self.init_prompt,
|
init_prompt=self.init_prompt,
|
||||||
max_context_tokens=self.max_context_tokens,
|
max_context_tokens=self.max_context_tokens,
|
||||||
@@ -191,40 +214,84 @@ class SimulStreamingASR():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Set up tokenizer for translation if needed
|
# Set up tokenizer for translation if needed
|
||||||
if self.task == "translate":
|
if self.direct_english_translation:
|
||||||
self.tokenizer = self.set_translate_task()
|
self.tokenizer = self.set_translate_task()
|
||||||
else:
|
else:
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
|
|
||||||
if self.model_dir:
|
|
||||||
self.model_name = self.model_dir
|
|
||||||
self.model_path = None
|
|
||||||
else:
|
|
||||||
self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
|
|
||||||
self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
|
|
||||||
|
|
||||||
self.mlx_encoder, self.fw_encoder = None, None
|
self.mlx_encoder, self.fw_encoder = None, None
|
||||||
if not self.disable_fast_encoder:
|
if self.encoder_backend == "mlx-whisper":
|
||||||
if HAS_MLX_WHISPER:
|
print('Simulstreaming will use MLX whisper to increase encoding speed.')
|
||||||
print('Simulstreaming will use MLX whisper for a faster encoder.')
|
if self._resolved_model_path is not None:
|
||||||
mlx_model_name = mlx_model_mapping[self.model_name]
|
mlx_model = str(self._resolved_model_path)
|
||||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name)
|
else:
|
||||||
self.fast_encoder = True
|
mlx_model = mlx_model_mapping.get(self.model_name)
|
||||||
elif HAS_FASTER_WHISPER:
|
if not mlx_model:
|
||||||
print('Simulstreaming will use Faster Whisper for the encoder.')
|
raise FileNotFoundError(
|
||||||
self.fw_encoder = WhisperModel(
|
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
|
||||||
self.model_name,
|
|
||||||
device='auto',
|
|
||||||
compute_type='auto',
|
|
||||||
)
|
)
|
||||||
self.fast_encoder = True
|
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
|
||||||
|
elif self.encoder_backend == "faster-whisper":
|
||||||
|
print('Simulstreaming will use Faster Whisper for the encoder.')
|
||||||
|
if self._resolved_model_path is not None:
|
||||||
|
fw_model = str(self._resolved_model_path)
|
||||||
|
else:
|
||||||
|
fw_model = self.model_name
|
||||||
|
self.fw_encoder = WhisperModel(
|
||||||
|
fw_model,
|
||||||
|
device='auto',
|
||||||
|
compute_type='auto',
|
||||||
|
)
|
||||||
|
|
||||||
self.models = [self.load_model() for i in range(self.preload_model_count)]
|
self.models = [self.load_model() for i in range(self.preload_model_count)]
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
|
||||||
|
choice = preferred_backend or "auto"
|
||||||
|
if self.disable_fast_encoder:
|
||||||
|
return "whisper"
|
||||||
|
if choice == "whisper":
|
||||||
|
return "whisper"
|
||||||
|
if choice == "mlx-whisper":
|
||||||
|
if not self._can_use_mlx(compatible_whisper_mlx):
|
||||||
|
raise RuntimeError("mlx-whisper backend requested but MLX Whisper is unavailable or incompatible with the provided model.")
|
||||||
|
return "mlx-whisper"
|
||||||
|
if choice == "faster-whisper":
|
||||||
|
if not self._can_use_faster(compatible_faster_whisper):
|
||||||
|
raise RuntimeError("faster-whisper backend requested but Faster-Whisper is unavailable or incompatible with the provided model.")
|
||||||
|
return "faster-whisper"
|
||||||
|
if choice == "openai-api":
|
||||||
|
raise ValueError("openai-api backend is only supported with the LocalAgreement policy.")
|
||||||
|
# auto mode
|
||||||
|
if platform.system() == "Darwin" and self._can_use_mlx(compatible_whisper_mlx):
|
||||||
|
return "mlx-whisper"
|
||||||
|
if self._can_use_faster(compatible_faster_whisper):
|
||||||
|
return "faster-whisper"
|
||||||
|
return "whisper"
|
||||||
|
|
||||||
|
def _has_custom_model_path(self):
|
||||||
|
return self._resolved_model_path is not None
|
||||||
|
|
||||||
|
def _can_use_mlx(self, compatible_whisper_mlx):
|
||||||
|
if not HAS_MLX_WHISPER:
|
||||||
|
return False
|
||||||
|
if self._has_custom_model_path():
|
||||||
|
return compatible_whisper_mlx
|
||||||
|
return self.model_name in mlx_model_mapping
|
||||||
|
|
||||||
|
def _can_use_faster(self, compatible_faster_whisper):
|
||||||
|
if not HAS_FASTER_WHISPER:
|
||||||
|
return False
|
||||||
|
if self._has_custom_model_path():
|
||||||
|
return compatible_faster_whisper
|
||||||
|
return True
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
whisper_model = load_model(
|
whisper_model = load_model(
|
||||||
name=self.model_name,
|
name=self.pytorch_path if self.pytorch_path else self.model_name,
|
||||||
download_root=self.model_path,
|
download_root=self.model_path,
|
||||||
decoder_only=self.fast_encoder,
|
decoder_only=self.fast_encoder,
|
||||||
custom_alignment_heads=self.custom_alignment_heads
|
custom_alignment_heads=self.custom_alignment_heads
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from .whisper.decoding import PyTorchInference
|
from whisperlivekit.whisper.decoding import PyTorchInference
|
||||||
|
|
||||||
# extention of PyTorchInference for beam search
|
# extention of PyTorchInference for beam search
|
||||||
class BeamPyTorchInference(PyTorchInference):
|
class BeamPyTorchInference(PyTorchInference):
|
||||||
|
|||||||
@@ -1,29 +1,23 @@
|
|||||||
# This code was originally in simul_whisper/transcriber/simul_whisper.py . It is adapted a lot for SimulStreaming.
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SimulWhisperConfig:
|
class AlignAttConfig():
|
||||||
'''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
|
|
||||||
model_path: str
|
|
||||||
language: str = field(default="zh")
|
|
||||||
nonspeech_prob: float = 0.5
|
|
||||||
audio_min_len: float = 1.0
|
|
||||||
decoder_type: Literal["greedy","beam"] = "greedy"
|
|
||||||
beam_size: int = 5
|
|
||||||
task: Literal["transcribe","translate"] = "transcribe"
|
|
||||||
init_prompt: str = field(default=None)
|
|
||||||
static_init_prompt: str = field(default=None)
|
|
||||||
max_context_tokens: int = field(default=None)
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AlignAttConfig(SimulWhisperConfig):
|
|
||||||
'''Options specific to the AlignAtt policy.'''
|
|
||||||
eval_data_path: str = "tmp"
|
eval_data_path: str = "tmp"
|
||||||
segment_length: float = field(default=1.0, metadata = {"help": "in second"})
|
segment_length: float = field(default=1.0, metadata = {"help": "in second"})
|
||||||
frame_threshold: int = 4
|
frame_threshold: int = 4
|
||||||
rewind_threshold: int = 200
|
rewind_threshold: int = 200
|
||||||
audio_max_len: float = 20.0
|
audio_max_len: float = 20.0
|
||||||
cif_ckpt_path: str = ""
|
cif_ckpt_path: str = ""
|
||||||
never_fire: bool = False
|
never_fire: bool = False
|
||||||
|
language: str = field(default="zh")
|
||||||
|
nonspeech_prob: float = 0.5
|
||||||
|
audio_min_len: float = 1.0
|
||||||
|
decoder_type: Literal["greedy","beam"] = "greedy"
|
||||||
|
beam_size: int = 5
|
||||||
|
task: Literal["transcribe","translate"] = "transcribe"
|
||||||
|
tokenizer_is_multilingual: bool = False
|
||||||
|
init_prompt: str = field(default=None)
|
||||||
|
static_init_prompt: str = field(default=None)
|
||||||
|
max_context_tokens: int = field(default=None)
|
||||||
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
SIMULSTREAMING_LICENSE = f"""
|
|
||||||
SimulStreaming backend is dual-licensed:
|
|
||||||
• Non-Commercial Use: PolyForm Noncommercial License 1.0.0.
|
|
||||||
• Commercial Use: Check SimulStreaming README (github.com/ufal/SimulStreaming) for more details.
|
|
||||||
"""
|
|
||||||
@@ -6,17 +6,21 @@ import logging
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from .whisper import load_model, DecodingOptions, tokenizer
|
from whisperlivekit.whisper import load_model, DecodingOptions, tokenizer
|
||||||
from .config import AlignAttConfig
|
from .config import AlignAttConfig
|
||||||
from whisperlivekit.timed_objects import ASRToken
|
from whisperlivekit.timed_objects import ASRToken
|
||||||
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
|
from whisperlivekit.whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
|
||||||
from .whisper.timing import median_filter
|
from whisperlivekit.whisper.timing import median_filter
|
||||||
from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
|
from whisperlivekit.whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
|
||||||
from .beam import BeamPyTorchInference
|
from .beam import BeamPyTorchInference
|
||||||
from .eow_detection import fire_at_boundary, load_cif
|
from .eow_detection import fire_at_boundary, load_cif
|
||||||
import os
|
import os
|
||||||
from time import time
|
from time import time
|
||||||
from .token_buffer import TokenBuffer
|
from .token_buffer import TokenBuffer
|
||||||
|
from whisperlivekit.backend_support import (
|
||||||
|
mlx_backend_available,
|
||||||
|
faster_backend_available,
|
||||||
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ..timed_objects import PUNCTUATION_MARKS
|
from ..timed_objects import PUNCTUATION_MARKS
|
||||||
@@ -26,21 +30,18 @@ DEC_PAD = 50257
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
try:
|
HAS_MLX_WHISPER = False
|
||||||
|
HAS_FASTER_WHISPER = False
|
||||||
|
|
||||||
|
if mlx_backend_available():
|
||||||
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
|
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
|
||||||
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
||||||
HAS_MLX_WHISPER = True
|
HAS_MLX_WHISPER = True
|
||||||
except ImportError:
|
|
||||||
HAS_MLX_WHISPER = False
|
if faster_backend_available():
|
||||||
if HAS_MLX_WHISPER:
|
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
|
||||||
HAS_FASTER_WHISPER = False
|
from faster_whisper.feature_extractor import FeatureExtractor
|
||||||
else:
|
HAS_FASTER_WHISPER = True
|
||||||
try:
|
|
||||||
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
|
|
||||||
from faster_whisper.feature_extractor import FeatureExtractor
|
|
||||||
HAS_FASTER_WHISPER = True
|
|
||||||
except ImportError:
|
|
||||||
HAS_FASTER_WHISPER = False
|
|
||||||
|
|
||||||
class PaddedAlignAttWhisper:
|
class PaddedAlignAttWhisper:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -51,20 +52,15 @@ class PaddedAlignAttWhisper:
|
|||||||
fw_encoder=None,
|
fw_encoder=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.log_segments = 0
|
self.log_segments = 0
|
||||||
model_name = os.path.basename(cfg.model_path).replace(".pt", "")
|
|
||||||
model_path = os.path.dirname(os.path.abspath(cfg.model_path))
|
|
||||||
if loaded_model:
|
|
||||||
self.model = loaded_model
|
|
||||||
else:
|
|
||||||
self.model = load_model(name=model_name, download_root=model_path)
|
|
||||||
|
|
||||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
self.model = loaded_model
|
||||||
|
|
||||||
self.mlx_encoder = mlx_encoder
|
self.mlx_encoder = mlx_encoder
|
||||||
self.fw_encoder = fw_encoder
|
self.fw_encoder = fw_encoder
|
||||||
if fw_encoder:
|
if fw_encoder:
|
||||||
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
|
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
|
||||||
|
|
||||||
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
|
||||||
logger.info(f"Model dimensions: {self.model.dims}")
|
logger.info(f"Model dimensions: {self.model.dims}")
|
||||||
self.speaker = -1
|
self.speaker = -1
|
||||||
self.decode_options = DecodingOptions(
|
self.decode_options = DecodingOptions(
|
||||||
@@ -72,7 +68,7 @@ class PaddedAlignAttWhisper:
|
|||||||
without_timestamps = True,
|
without_timestamps = True,
|
||||||
task=cfg.task
|
task=cfg.task
|
||||||
)
|
)
|
||||||
self.tokenizer_is_multilingual = not model_name.endswith(".en")
|
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
|
||||||
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
||||||
# self.create_tokenizer('en')
|
# self.create_tokenizer('en')
|
||||||
self.detected_language = cfg.language if cfg.language != "auto" else None
|
self.detected_language = cfg.language if cfg.language != "auto" else None
|
||||||
@@ -172,7 +168,10 @@ class PaddedAlignAttWhisper:
|
|||||||
self.inference.kv_cache = self.kv_cache
|
self.inference.kv_cache = self.kv_cache
|
||||||
|
|
||||||
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
|
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
|
||||||
|
|
||||||
|
# Tokens to carry over to next chunk for incomplete UTF-8 characters
|
||||||
|
self.pending_incomplete_tokens = []
|
||||||
|
|
||||||
def remove_hooks(self):
|
def remove_hooks(self):
|
||||||
for hook in self.l_hooks:
|
for hook in self.l_hooks:
|
||||||
hook.remove()
|
hook.remove()
|
||||||
@@ -266,6 +265,7 @@ class PaddedAlignAttWhisper:
|
|||||||
self.segments = []
|
self.segments = []
|
||||||
self.log_segments += 1
|
self.log_segments += 1
|
||||||
|
|
||||||
|
self.pending_incomplete_tokens = []
|
||||||
|
|
||||||
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
|
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
|
||||||
if self.always_fire: return True
|
if self.always_fire: return True
|
||||||
@@ -327,7 +327,7 @@ class PaddedAlignAttWhisper:
|
|||||||
self.segments = self.segments[1:]
|
self.segments = self.segments[1:]
|
||||||
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s")
|
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s")
|
||||||
if len(self.tokens) > 1:
|
if len(self.tokens) > 1:
|
||||||
self.context.append_token_ids(self.tokens[1][0,:])
|
self.context.append_token_ids(self.tokens[1][0,:].tolist())
|
||||||
self.tokens = [self.initial_tokens] + self.tokens[2:]
|
self.tokens = [self.initial_tokens] + self.tokens[2:]
|
||||||
return removed_len
|
return removed_len
|
||||||
|
|
||||||
@@ -567,6 +567,12 @@ class PaddedAlignAttWhisper:
|
|||||||
|
|
||||||
tokens_to_split = current_tokens[0, token_len_before_decoding:]
|
tokens_to_split = current_tokens[0, token_len_before_decoding:]
|
||||||
|
|
||||||
|
# Prepend pending tokens from previous chunk if any
|
||||||
|
if self.pending_incomplete_tokens:
|
||||||
|
logger.debug(f"[UTF-8 Fix] Prepending {len(self.pending_incomplete_tokens)} pending tokens: {self.pending_incomplete_tokens}")
|
||||||
|
pending_tensor = torch.tensor(self.pending_incomplete_tokens, dtype=torch.long, device=self.device)
|
||||||
|
tokens_to_split = torch.cat([pending_tensor, tokens_to_split])
|
||||||
|
|
||||||
if fire_detected or is_last: #or punctuation_stop:
|
if fire_detected or is_last: #or punctuation_stop:
|
||||||
new_hypothesis = tokens_to_split.flatten().tolist()
|
new_hypothesis = tokens_to_split.flatten().tolist()
|
||||||
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
|
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
|
||||||
@@ -595,7 +601,14 @@ class PaddedAlignAttWhisper:
|
|||||||
|
|
||||||
timestamped_words = []
|
timestamped_words = []
|
||||||
timestamp_idx = 0
|
timestamp_idx = 0
|
||||||
|
replacement_char = "\ufffd"
|
||||||
for word, word_tokens in zip(split_words, split_tokens):
|
for word, word_tokens in zip(split_words, split_tokens):
|
||||||
|
# Skip words containing incomplete UTF-8 from client output
|
||||||
|
if replacement_char in word:
|
||||||
|
logger.warning(f"[UTF-8 Filter] Skipping incomplete word from client output: {repr(word)}")
|
||||||
|
timestamp_idx += len(word_tokens)
|
||||||
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
||||||
except:
|
except:
|
||||||
@@ -613,5 +626,11 @@ class PaddedAlignAttWhisper:
|
|||||||
self.global_time_offset
|
self.global_time_offset
|
||||||
)
|
)
|
||||||
timestamped_words.append(timestamp_entry)
|
timestamped_words.append(timestamp_entry)
|
||||||
|
|
||||||
return timestamped_words
|
# Hold incomplete tokens for next chunk
|
||||||
|
self.pending_incomplete_tokens = []
|
||||||
|
if split_words and replacement_char in split_words[-1]:
|
||||||
|
self.pending_incomplete_tokens = split_tokens[-1]
|
||||||
|
logger.warning(f"[UTF-8 Fix] Holding {len(self.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.pending_incomplete_tokens}")
|
||||||
|
|
||||||
|
return timestamped_words
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ class TokenBuffer:
|
|||||||
self.prefix_token_ids = prefix_token_ids
|
self.prefix_token_ids = prefix_token_ids
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.pending_token_ids = []
|
||||||
|
|
||||||
def as_token_ids(self, tokenizer=None):
|
def as_token_ids(self, tokenizer=None):
|
||||||
|
|
||||||
@@ -64,7 +65,26 @@ class TokenBuffer:
|
|||||||
def append_token_ids(self, token_ids):
|
def append_token_ids(self, token_ids):
|
||||||
tokenizer = self.tokenizer
|
tokenizer = self.tokenizer
|
||||||
assert tokenizer is not None, "Tokenizer is not set."
|
assert tokenizer is not None, "Tokenizer is not set."
|
||||||
self.text += self.tokenizer.decode(token_ids)
|
|
||||||
|
all_tokens = self.pending_token_ids + token_ids
|
||||||
|
|
||||||
|
decoded = tokenizer.decode(all_tokens)
|
||||||
|
replacement_char = "\ufffd"
|
||||||
|
|
||||||
|
if replacement_char in decoded:
|
||||||
|
if len(all_tokens) > 1:
|
||||||
|
decoded_partial = tokenizer.decode(all_tokens[:-1])
|
||||||
|
|
||||||
|
if replacement_char not in decoded_partial:
|
||||||
|
self.text += decoded_partial
|
||||||
|
self.pending_token_ids = [all_tokens[-1]]
|
||||||
|
else:
|
||||||
|
self.pending_token_ids = all_tokens
|
||||||
|
else:
|
||||||
|
self.pending_token_ids = all_tokens
|
||||||
|
else:
|
||||||
|
self.text += decoded
|
||||||
|
self.pending_token_ids = []
|
||||||
|
|
||||||
def as_split_word_tokens(self):
|
def as_split_word_tokens(self):
|
||||||
tokenizer = self.tokenizer
|
tokenizer = self.tokenizer
|
||||||
|
|||||||
@@ -151,6 +151,7 @@ class FrontData():
|
|||||||
lines: list[Line] = field(default_factory=list)
|
lines: list[Line] = field(default_factory=list)
|
||||||
buffer_transcription: str = ''
|
buffer_transcription: str = ''
|
||||||
buffer_diarization: str = ''
|
buffer_diarization: str = ''
|
||||||
|
buffer_translation: str = ''
|
||||||
remaining_time_transcription: float = 0.
|
remaining_time_transcription: float = 0.
|
||||||
remaining_time_diarization: float = 0.
|
remaining_time_diarization: float = 0.
|
||||||
|
|
||||||
@@ -160,6 +161,7 @@ class FrontData():
|
|||||||
'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)],
|
'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)],
|
||||||
'buffer_transcription': self.buffer_transcription,
|
'buffer_transcription': self.buffer_transcription,
|
||||||
'buffer_diarization': self.buffer_diarization,
|
'buffer_diarization': self.buffer_diarization,
|
||||||
|
'buffer_translation': self.buffer_translation,
|
||||||
'remaining_time_transcription': self.remaining_time_transcription,
|
'remaining_time_transcription': self.remaining_time_transcription,
|
||||||
'remaining_time_diarization': self.remaining_time_diarization,
|
'remaining_time_diarization': self.remaining_time_diarization,
|
||||||
}
|
}
|
||||||
@@ -174,11 +176,14 @@ class ChangeSpeaker:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class State():
|
class State():
|
||||||
tokens: list
|
tokens: list = field(default_factory=list)
|
||||||
last_validated_token: int
|
last_validated_token: int = 0
|
||||||
translated_segments: list
|
last_punctuation_index: Optional[int] = None
|
||||||
buffer_transcription: str
|
translation_validated_segments: list = field(default_factory=list)
|
||||||
end_buffer: float
|
buffer_translation: str = field(default_factory=Transcript)
|
||||||
end_attributed_speaker: float
|
buffer_transcription: str = field(default_factory=Transcript)
|
||||||
remaining_time_transcription: float
|
end_buffer: float = 0.0
|
||||||
remaining_time_diarization: float
|
end_attributed_speaker: float = 0.0
|
||||||
|
remaining_time_transcription: float = 0.0
|
||||||
|
remaining_time_diarization: float = 0.0
|
||||||
|
beg_loop: Optional[int] = None
|
||||||
|
|||||||
@@ -1,182 +0,0 @@
|
|||||||
"""
|
|
||||||
adapted from https://store.crowdin.com/custom-mt
|
|
||||||
"""
|
|
||||||
|
|
||||||
LANGUAGES = [
|
|
||||||
{"name": "Afrikaans", "nllb": "afr_Latn", "crowdin": "af"},
|
|
||||||
{"name": "Akan", "nllb": "aka_Latn", "crowdin": "ak"},
|
|
||||||
{"name": "Amharic", "nllb": "amh_Ethi", "crowdin": "am"},
|
|
||||||
{"name": "Assamese", "nllb": "asm_Beng", "crowdin": "as"},
|
|
||||||
{"name": "Asturian", "nllb": "ast_Latn", "crowdin": "ast"},
|
|
||||||
{"name": "Bashkir", "nllb": "bak_Cyrl", "crowdin": "ba"},
|
|
||||||
{"name": "Bambara", "nllb": "bam_Latn", "crowdin": "bm"},
|
|
||||||
{"name": "Balinese", "nllb": "ban_Latn", "crowdin": "ban"},
|
|
||||||
{"name": "Belarusian", "nllb": "bel_Cyrl", "crowdin": "be"},
|
|
||||||
{"name": "Bengali", "nllb": "ben_Beng", "crowdin": "bn"},
|
|
||||||
{"name": "Bosnian", "nllb": "bos_Latn", "crowdin": "bs"},
|
|
||||||
{"name": "Bulgarian", "nllb": "bul_Cyrl", "crowdin": "bg"},
|
|
||||||
{"name": "Catalan", "nllb": "cat_Latn", "crowdin": "ca"},
|
|
||||||
{"name": "Cebuano", "nllb": "ceb_Latn", "crowdin": "ceb"},
|
|
||||||
{"name": "Czech", "nllb": "ces_Latn", "crowdin": "cs"},
|
|
||||||
{"name": "Welsh", "nllb": "cym_Latn", "crowdin": "cy"},
|
|
||||||
{"name": "Danish", "nllb": "dan_Latn", "crowdin": "da"},
|
|
||||||
{"name": "German", "nllb": "deu_Latn", "crowdin": "de"},
|
|
||||||
{"name": "Dzongkha", "nllb": "dzo_Tibt", "crowdin": "dz"},
|
|
||||||
{"name": "Greek", "nllb": "ell_Grek", "crowdin": "el"},
|
|
||||||
{"name": "English", "nllb": "eng_Latn", "crowdin": "en"},
|
|
||||||
{"name": "Esperanto", "nllb": "epo_Latn", "crowdin": "eo"},
|
|
||||||
{"name": "Estonian", "nllb": "est_Latn", "crowdin": "et"},
|
|
||||||
{"name": "Basque", "nllb": "eus_Latn", "crowdin": "eu"},
|
|
||||||
{"name": "Ewe", "nllb": "ewe_Latn", "crowdin": "ee"},
|
|
||||||
{"name": "Faroese", "nllb": "fao_Latn", "crowdin": "fo"},
|
|
||||||
{"name": "Fijian", "nllb": "fij_Latn", "crowdin": "fj"},
|
|
||||||
{"name": "Finnish", "nllb": "fin_Latn", "crowdin": "fi"},
|
|
||||||
{"name": "French", "nllb": "fra_Latn", "crowdin": "fr"},
|
|
||||||
{"name": "Friulian", "nllb": "fur_Latn", "crowdin": "fur-IT"},
|
|
||||||
{"name": "Scottish Gaelic", "nllb": "gla_Latn", "crowdin": "gd"},
|
|
||||||
{"name": "Irish", "nllb": "gle_Latn", "crowdin": "ga-IE"},
|
|
||||||
{"name": "Galician", "nllb": "glg_Latn", "crowdin": "gl"},
|
|
||||||
{"name": "Guarani", "nllb": "grn_Latn", "crowdin": "gn"},
|
|
||||||
{"name": "Gujarati", "nllb": "guj_Gujr", "crowdin": "gu-IN"},
|
|
||||||
{"name": "Haitian Creole", "nllb": "hat_Latn", "crowdin": "ht"},
|
|
||||||
{"name": "Hausa", "nllb": "hau_Latn", "crowdin": "ha"},
|
|
||||||
{"name": "Hebrew", "nllb": "heb_Hebr", "crowdin": "he"},
|
|
||||||
{"name": "Hindi", "nllb": "hin_Deva", "crowdin": "hi"},
|
|
||||||
{"name": "Croatian", "nllb": "hrv_Latn", "crowdin": "hr"},
|
|
||||||
{"name": "Hungarian", "nllb": "hun_Latn", "crowdin": "hu"},
|
|
||||||
{"name": "Armenian", "nllb": "hye_Armn", "crowdin": "hy-AM"},
|
|
||||||
{"name": "Igbo", "nllb": "ibo_Latn", "crowdin": "ig"},
|
|
||||||
{"name": "Indonesian", "nllb": "ind_Latn", "crowdin": "id"},
|
|
||||||
{"name": "Icelandic", "nllb": "isl_Latn", "crowdin": "is"},
|
|
||||||
{"name": "Italian", "nllb": "ita_Latn", "crowdin": "it"},
|
|
||||||
{"name": "Javanese", "nllb": "jav_Latn", "crowdin": "jv"},
|
|
||||||
{"name": "Japanese", "nllb": "jpn_Jpan", "crowdin": "ja"},
|
|
||||||
{"name": "Kabyle", "nllb": "kab_Latn", "crowdin": "kab"},
|
|
||||||
{"name": "Kannada", "nllb": "kan_Knda", "crowdin": "kn"},
|
|
||||||
{"name": "Georgian", "nllb": "kat_Geor", "crowdin": "ka"},
|
|
||||||
{"name": "Kazakh", "nllb": "kaz_Cyrl", "crowdin": "kk"},
|
|
||||||
{"name": "Khmer", "nllb": "khm_Khmr", "crowdin": "km"},
|
|
||||||
{"name": "Kinyarwanda", "nllb": "kin_Latn", "crowdin": "rw"},
|
|
||||||
{"name": "Kyrgyz", "nllb": "kir_Cyrl", "crowdin": "ky"},
|
|
||||||
{"name": "Korean", "nllb": "kor_Hang", "crowdin": "ko"},
|
|
||||||
{"name": "Lao", "nllb": "lao_Laoo", "crowdin": "lo"},
|
|
||||||
{"name": "Ligurian", "nllb": "lij_Latn", "crowdin": "lij"},
|
|
||||||
{"name": "Limburgish", "nllb": "lim_Latn", "crowdin": "li"},
|
|
||||||
{"name": "Lingala", "nllb": "lin_Latn", "crowdin": "ln"},
|
|
||||||
{"name": "Lithuanian", "nllb": "lit_Latn", "crowdin": "lt"},
|
|
||||||
{"name": "Luxembourgish", "nllb": "ltz_Latn", "crowdin": "lb"},
|
|
||||||
{"name": "Maithili", "nllb": "mai_Deva", "crowdin": "mai"},
|
|
||||||
{"name": "Malayalam", "nllb": "mal_Mlym", "crowdin": "ml-IN"},
|
|
||||||
{"name": "Marathi", "nllb": "mar_Deva", "crowdin": "mr"},
|
|
||||||
{"name": "Macedonian", "nllb": "mkd_Cyrl", "crowdin": "mk"},
|
|
||||||
{"name": "Maltese", "nllb": "mlt_Latn", "crowdin": "mt"},
|
|
||||||
{"name": "Mossi", "nllb": "mos_Latn", "crowdin": "mos"},
|
|
||||||
{"name": "Maori", "nllb": "mri_Latn", "crowdin": "mi"},
|
|
||||||
{"name": "Burmese", "nllb": "mya_Mymr", "crowdin": "my"},
|
|
||||||
{"name": "Dutch", "nllb": "nld_Latn", "crowdin": "nl"},
|
|
||||||
{"name": "Norwegian Nynorsk", "nllb": "nno_Latn", "crowdin": "nn-NO"},
|
|
||||||
{"name": "Nepali", "nllb": "npi_Deva", "crowdin": "ne-NP"},
|
|
||||||
{"name": "Northern Sotho", "nllb": "nso_Latn", "crowdin": "nso"},
|
|
||||||
{"name": "Occitan", "nllb": "oci_Latn", "crowdin": "oc"},
|
|
||||||
{"name": "Odia", "nllb": "ory_Orya", "crowdin": "or"},
|
|
||||||
{"name": "Papiamento", "nllb": "pap_Latn", "crowdin": "pap"},
|
|
||||||
{"name": "Polish", "nllb": "pol_Latn", "crowdin": "pl"},
|
|
||||||
{"name": "Portuguese", "nllb": "por_Latn", "crowdin": "pt-PT"},
|
|
||||||
{"name": "Dari", "nllb": "prs_Arab", "crowdin": "fa-AF"},
|
|
||||||
{"name": "Romanian", "nllb": "ron_Latn", "crowdin": "ro"},
|
|
||||||
{"name": "Rundi", "nllb": "run_Latn", "crowdin": "rn"},
|
|
||||||
{"name": "Russian", "nllb": "rus_Cyrl", "crowdin": "ru"},
|
|
||||||
{"name": "Sango", "nllb": "sag_Latn", "crowdin": "sg"},
|
|
||||||
{"name": "Sanskrit", "nllb": "san_Deva", "crowdin": "sa"},
|
|
||||||
{"name": "Santali", "nllb": "sat_Olck", "crowdin": "sat"},
|
|
||||||
{"name": "Sinhala", "nllb": "sin_Sinh", "crowdin": "si-LK"},
|
|
||||||
{"name": "Slovak", "nllb": "slk_Latn", "crowdin": "sk"},
|
|
||||||
{"name": "Slovenian", "nllb": "slv_Latn", "crowdin": "sl"},
|
|
||||||
{"name": "Shona", "nllb": "sna_Latn", "crowdin": "sn"},
|
|
||||||
{"name": "Sindhi", "nllb": "snd_Arab", "crowdin": "sd"},
|
|
||||||
{"name": "Somali", "nllb": "som_Latn", "crowdin": "so"},
|
|
||||||
{"name": "Southern Sotho", "nllb": "sot_Latn", "crowdin": "st"},
|
|
||||||
{"name": "Spanish", "nllb": "spa_Latn", "crowdin": "es-ES"},
|
|
||||||
{"name": "Sardinian", "nllb": "srd_Latn", "crowdin": "sc"},
|
|
||||||
{"name": "Swati", "nllb": "ssw_Latn", "crowdin": "ss"},
|
|
||||||
{"name": "Sundanese", "nllb": "sun_Latn", "crowdin": "su"},
|
|
||||||
{"name": "Swedish", "nllb": "swe_Latn", "crowdin": "sv-SE"},
|
|
||||||
{"name": "Swahili", "nllb": "swh_Latn", "crowdin": "sw"},
|
|
||||||
{"name": "Tamil", "nllb": "tam_Taml", "crowdin": "ta"},
|
|
||||||
{"name": "Tatar", "nllb": "tat_Cyrl", "crowdin": "tt-RU"},
|
|
||||||
{"name": "Telugu", "nllb": "tel_Telu", "crowdin": "te"},
|
|
||||||
{"name": "Tajik", "nllb": "tgk_Cyrl", "crowdin": "tg"},
|
|
||||||
{"name": "Tagalog", "nllb": "tgl_Latn", "crowdin": "tl"},
|
|
||||||
{"name": "Thai", "nllb": "tha_Thai", "crowdin": "th"},
|
|
||||||
{"name": "Tigrinya", "nllb": "tir_Ethi", "crowdin": "ti"},
|
|
||||||
{"name": "Tswana", "nllb": "tsn_Latn", "crowdin": "tn"},
|
|
||||||
{"name": "Tsonga", "nllb": "tso_Latn", "crowdin": "ts"},
|
|
||||||
{"name": "Turkmen", "nllb": "tuk_Latn", "crowdin": "tk"},
|
|
||||||
{"name": "Turkish", "nllb": "tur_Latn", "crowdin": "tr"},
|
|
||||||
{"name": "Uyghur", "nllb": "uig_Arab", "crowdin": "ug"},
|
|
||||||
{"name": "Ukrainian", "nllb": "ukr_Cyrl", "crowdin": "uk"},
|
|
||||||
{"name": "Venetian", "nllb": "vec_Latn", "crowdin": "vec"},
|
|
||||||
{"name": "Vietnamese", "nllb": "vie_Latn", "crowdin": "vi"},
|
|
||||||
{"name": "Wolof", "nllb": "wol_Latn", "crowdin": "wo"},
|
|
||||||
{"name": "Xhosa", "nllb": "xho_Latn", "crowdin": "xh"},
|
|
||||||
{"name": "Yoruba", "nllb": "yor_Latn", "crowdin": "yo"},
|
|
||||||
{"name": "Zulu", "nllb": "zul_Latn", "crowdin": "zu"},
|
|
||||||
]
|
|
||||||
|
|
||||||
def _lookup_table(key_field, value_field):
    """Map every language's `key_field` value to its `value_field` value."""
    table = {}
    for entry in LANGUAGES:
        table[entry[key_field]] = entry[value_field]
    return table


# Pairwise lookup tables between display name, NLLB code and Crowdin code.
NAME_TO_NLLB = _lookup_table("name", "nllb")
NAME_TO_CROWDIN = _lookup_table("name", "crowdin")
CROWDIN_TO_NLLB = _lookup_table("crowdin", "nllb")
NLLB_TO_CROWDIN = _lookup_table("nllb", "crowdin")
CROWDIN_TO_NAME = _lookup_table("crowdin", "name")
NLLB_TO_NAME = _lookup_table("nllb", "name")
|
|
||||||
|
|
||||||
|
|
||||||
def get_nllb_code(crowdin_code):
    """Return the NLLB code registered for a Crowdin code, or None if unknown."""
    try:
        return CROWDIN_TO_NLLB[crowdin_code]
    except KeyError:
        return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_crowdin_code(nllb_code):
    """Return the Crowdin code registered for an NLLB code, or None if unknown."""
    crowdin_code = NLLB_TO_CROWDIN.get(nllb_code, None)
    return crowdin_code
|
|
||||||
|
|
||||||
|
|
||||||
def get_language_name_by_crowdin(crowdin_code):
    """Return the language name for a Crowdin code, or None if unknown."""
    language_name = CROWDIN_TO_NAME.get(crowdin_code, None)
    return language_name
|
|
||||||
|
|
||||||
|
|
||||||
def get_language_name_by_nllb(nllb_code):
    """Return the language name for an NLLB code, or None if unknown."""
    language_name = NLLB_TO_NAME.get(nllb_code, None)
    return language_name
|
|
||||||
|
|
||||||
|
|
||||||
def get_language_info(identifier, identifier_type="auto"):
    """
    Return the LANGUAGES entry matching `identifier`, or None if there is none.

    identifier_type selects which field is compared:
      - "name":    language name, case-insensitive
      - "nllb":    NLLB code, exact match
      - "crowdin": Crowdin code, exact match
      - "auto":    any of the three above
    Any other identifier_type, or a non-string identifier, yields None.
    """
    # Which exact-match code fields each mode inspects.
    code_fields_by_type = {
        "auto": ("nllb", "crowdin"),
        "name": (),
        "nllb": ("nllb",),
        "crowdin": ("crowdin",),
    }
    code_fields = code_fields_by_type.get(identifier_type)
    # Guard: a None/non-string identifier used to raise AttributeError on .lower().
    if code_fields is None or not isinstance(identifier, str):
        return None

    match_name = identifier_type in ("auto", "name")
    wanted_name = identifier.lower()  # hoisted: invariant across the loop

    for lang in LANGUAGES:
        if match_name and lang["name"].lower() == wanted_name:
            return lang
        if any(lang[code_field] == identifier for code_field in code_fields):
            return lang
    return None
|
|
||||||
|
|
||||||
|
|
||||||
def list_all_languages():
    """Return the names of all entries in LANGUAGES, in table order."""
    names = []
    for entry in LANGUAGES:
        names.append(entry["name"])
    return names
|
|
||||||
|
|
||||||
|
|
||||||
def list_all_nllb_codes():
    """Return the NLLB codes of all entries in LANGUAGES, in table order."""
    codes = []
    for entry in LANGUAGES:
        codes.append(entry["nllb"])
    return codes
|
|
||||||
|
|
||||||
|
|
||||||
def list_all_crowdin_codes():
    """Return the Crowdin codes of all entries in LANGUAGES, in table order."""
    codes = []
    for entry in LANGUAGES:
        codes.append(entry["crowdin"])
    return codes
|
|
||||||
@@ -1,169 +0,0 @@
|
|||||||
import logging
|
|
||||||
import time
|
|
||||||
import ctranslate2
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
import huggingface_hub
|
|
||||||
from whisperlivekit.translation.mapping_languages import get_nllb_code
|
|
||||||
from whisperlivekit.timed_objects import Translation
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# In the diarization case, we may want to translate just one speaker, or at least start the sentences there.
|
|
||||||
|
|
||||||
MIN_SILENCE_DURATION_DEL_BUFFER = 3 #After a silence of x seconds, we consider the model should not use the buffer, even if the previous
|
|
||||||
# sentence is not finished.
|
|
||||||
|
|
||||||
@dataclass
class TranslationModel():
    """Translator backend together with a cache of per-language tokenizers."""
    # Backend object that performs the translation.
    translator: ctranslate2.Translator
    # Device name the tokenized inputs are moved to.
    device: str
    # Cache: source language -> tokenizer instance.
    tokenizer: dict = field(default_factory=dict)
    # Backend selector; 'ctranslate2' by default.
    backend_type: str = 'ctranslate2'
    # Size suffix of the facebook/nllb-200-distilled-* checkpoint.
    nllb_size: str = '600M'

    def get_tokenizer(self, input_lang):
        """Return the tokenizer for `input_lang`, loading and caching it on first use."""
        cached = self.tokenizer.get(input_lang, False)
        if cached:
            return cached

        loaded = transformers.AutoTokenizer.from_pretrained(
            f"facebook/nllb-200-distilled-{self.nllb_size}",
            src_lang=input_lang,
            clean_up_tokenization_spaces=True,
        )
        self.tokenizer[input_lang] = loaded
        return loaded
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(src_langs, nllb_backend='ctranslate2', nllb_size='600M'):
    """
    Load an NLLB translation backend and pre-load tokenizers for `src_langs`.

    Args:
        src_langs: iterable of source language identifiers; 'auto' entries are
            skipped (their tokenizer is loaded lazily by get_tokenizer).
        nllb_backend: 'ctranslate2' or 'transformers'.
        nllb_size: size suffix of the nllb-200-distilled checkpoint.

    Returns:
        A TranslationModel wrapping the translator and its tokenizers.

    Raises:
        ValueError: if `nllb_backend` is not a supported backend.
    """
    # Fail early with a clear message instead of an UnboundLocalError on `translator`.
    if nllb_backend not in ('ctranslate2', 'transformers'):
        raise ValueError(
            f"Unsupported nllb_backend {nllb_backend!r}; expected 'ctranslate2' or 'transformers'."
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    explicit_langs = [src_lang for src_lang in src_langs if src_lang != 'auto']
    tokenizer = dict()

    if nllb_backend == 'ctranslate2':
        MODEL = f'nllb-200-distilled-{nllb_size}-ctranslate2'
        MODEL_GUY = 'entai2965'
        huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL, local_dir=MODEL)
        translator = ctranslate2.Translator(MODEL, device=device)
        # The local snapshot directory only exists for this backend, so only
        # here can tokenizers be loaded from it.
        for src_lang in explicit_langs:
            tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(
                MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True
            )
    else:
        translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(
            f"facebook/nllb-200-distilled-{nllb_size}"
        )
        # Inputs are moved to `device` at translation time; the model must live there too.
        translator = translator.to(device)

    translation_model = TranslationModel(
        translator=translator,
        tokenizer=tokenizer,
        backend_type=nllb_backend,
        device=device,
        nllb_size=nllb_size,
    )
    # Loads any tokenizer not cached above (always the case for 'transformers').
    for src_lang in explicit_langs:
        translation_model.get_tokenizer(src_lang)
    return translation_model
|
|
||||||
|
|
||||||
class OnlineTranslation:
    """Incrementally translate a stream of transcription tokens.

    Tokens are accumulated in `buffer`. Every span ending on a punctuation
    token is translated and moved to `validated`; the unfinished tail is
    translated into `translation_remaining`.
    """

    def __init__(self, translation_model: "TranslationModel", input_languages: list, output_languages: list):
        self.buffer = []
        # Prefix confirmed by compute_common_prefix (attribute name kept as spelled there).
        self.commited = []
        self.len_processed_buffer = 0
        self.translation_remaining = Translation()
        self.validated = []
        self.translation_pending_validation = ''
        self.translation_model = translation_model
        self.input_languages = input_languages
        self.output_languages = output_languages

    def compute_common_prefix(self, results):
        """Commit the prefix shared by the buffer and `results`, up to the first mismatch.

        With an empty buffer, `results` simply becomes the buffer. If no
        element differs, nothing changes.
        """
        # We don't want to prune the result for the moment.
        if not self.buffer:
            self.buffer = results
            return
        for i, (previous, current) in enumerate(zip(self.buffer, results)):
            if previous != current:
                self.commited.extend(self.buffer[:i])
                self.buffer = results[i:]
                # Stop at the first divergence: the buffer was just replaced,
                # so comparing further indices would be meaningless.
                break

    def translate(self, input, input_lang, output_lang):
        """Translate the string `input` and return the translated text.

        `input_lang` is handed to the tokenizer as source language;
        `output_lang` is converted with get_nllb_code before being used as
        the target-language token. Empty input returns "".
        """
        if not input:
            return ""
        nllb_output_lang = get_nllb_code(output_lang)

        tokenizer = self.translation_model.get_tokenizer(input_lang)
        tokenizer_output = tokenizer(input, return_tensors="pt").to(self.translation_model.device)

        if self.translation_model.backend_type == 'ctranslate2':
            source = tokenizer.convert_ids_to_tokens(tokenizer_output['input_ids'][0])
            results = self.translation_model.translator.translate_batch(
                [source], target_prefix=[[nllb_output_lang]]
            )
            # Skip the first hypothesis token (presumably the echoed target-language prefix).
            target = results[0].hypotheses[0][1:]
            result = tokenizer.decode(tokenizer.convert_tokens_to_ids(target))
        else:
            translated_tokens = self.translation_model.translator.generate(
                **tokenizer_output,
                forced_bos_token_id=tokenizer.convert_tokens_to_ids(nllb_output_lang),
            )
            result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
        return result

    def translate_tokens(self, tokens):
        """Translate a list of tokens into one Translation; return None if `tokens` is empty.

        The Translation spans from the first token's start to the last
        token's end. With input language 'auto', the first token's
        detected_language is used as source language.
        """
        if not tokens:
            return None

        text = ' '.join([token.text for token in tokens])
        start = tokens[0].start
        end = tokens[-1].end
        if self.input_languages[0] == 'auto':
            input_lang = tokens[0].detected_language
        else:
            input_lang = self.input_languages[0]

        translated_text = self.translate(text, input_lang, self.output_languages[0])
        return Translation(
            text=translated_text,
            start=start,
            end=end,
        )

    def insert_tokens(self, tokens):
        """Append new transcription tokens to the buffer."""
        self.buffer.extend(tokens)

    def process(self):
        """Translate completed sentences and the unfinished remainder.

        Returns `validated` followed by `translation_remaining` (which is
        None when the buffer is empty after processing). Nothing is
        recomputed unless at least 3 tokens arrived since the last run.
        """
        i = 0
        if len(self.buffer) < self.len_processed_buffer + 3:  # nothing new to process
            return self.validated + [self.translation_remaining]
        while i < len(self.buffer):
            if self.buffer[i].is_punctuation():
                translation_sentence = self.translate_tokens(self.buffer[:i+1])
                self.validated.append(translation_sentence)
                self.buffer = self.buffer[i+1:]
                i = 0
            else:
                i += 1
        self.translation_remaining = self.translate_tokens(self.buffer)
        self.len_processed_buffer = len(self.buffer)
        return self.validated + [self.translation_remaining]

    def insert_silence(self, silence_duration: float):
        """Flush the pending translation after a long enough silence.

        Silences shorter than MIN_SILENCE_DURATION_DEL_BUFFER are ignored.
        """
        if silence_duration < MIN_SILENCE_DURATION_DEL_BUFFER:
            return
        self.buffer = []
        # The buffer restarts from scratch, so the "already processed" mark
        # must restart too; otherwise process() waits for stale_len + 3 tokens.
        self.len_processed_buffer = 0
        if self.translation_remaining is not None:
            self.validated.append(self.translation_remaining)
        # Already moved to `validated`; clear it so process() does not return it twice.
        self.translation_remaining = None
|
|
||||||
|
|
||||||
if __name__ == '__main__':
    # Manual smoke test: translate a short English text to French in three
    # chunks and print each result plus the total inference time.
    output_lang = 'fr'
    input_lang = "en"

    test_string = """
    Transcription technology has improved so much in the past few years. Have you noticed how accurate real-time speech-to-text is now?
    """
    # split() without argument drops the empty strings produced by the
    # surrounding newlines and indentation.
    words = test_string.split()
    step = max(1, len(words) // 3)

    shared_model = load_model([input_lang], nllb_backend='ctranslate2')
    online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang])

    beg_inference = time.time()
    # Walk the whole text in `step`-sized chunks (no empty trailing chunks).
    for chunk_start in range(0, len(words), step):
        val_str = ' '.join(words[chunk_start:chunk_start + step])
        # translate() requires the source and target languages explicitly.
        result = online_translation.translate(val_str, input_lang, output_lang)
        print(result)
    print('inference time:', time.time() - beg_inference)
|
|
||||||
BIN
whisperlivekit/vad_models/silero_vad.jit
Normal file
BIN
whisperlivekit/vad_models/silero_vad.jit
Normal file
Binary file not shown.
BIN
whisperlivekit/vad_models/silero_vad.onnx
Normal file
BIN
whisperlivekit/vad_models/silero_vad.onnx
Normal file
Binary file not shown.
BIN
whisperlivekit/vad_models/silero_vad_16k_op15.onnx
Normal file
BIN
whisperlivekit/vad_models/silero_vad_16k_op15.onnx
Normal file
Binary file not shown.
BIN
whisperlivekit/vad_models/silero_vad_half.onnx
Normal file
BIN
whisperlivekit/vad_models/silero_vad_half.onnx
Normal file
Binary file not shown.
@@ -490,6 +490,11 @@ label {
|
|||||||
margin-left: 4px;
|
margin-left: 4px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.buffer_translation {
|
||||||
|
color: #a0a0a0;
|
||||||
|
margin-left: 6px;
|
||||||
|
}
|
||||||
|
|
||||||
.spinner {
|
.spinner {
|
||||||
display: inline-block;
|
display: inline-block;
|
||||||
width: 8px;
|
width: 8px;
|
||||||
|
|||||||
@@ -232,10 +232,11 @@ function setupWebSocket() {
|
|||||||
if (waitingForStop) {
|
if (waitingForStop) {
|
||||||
statusText.textContent = "Processing finalized or connection closed.";
|
statusText.textContent = "Processing finalized or connection closed.";
|
||||||
if (lastReceivedData) {
|
if (lastReceivedData) {
|
||||||
renderLinesWithBuffer(
|
renderLinesWithBuffer(
|
||||||
lastReceivedData.lines || [],
|
lastReceivedData.lines || [],
|
||||||
lastReceivedData.buffer_diarization || "",
|
lastReceivedData.buffer_diarization || "",
|
||||||
lastReceivedData.buffer_transcription || "",
|
lastReceivedData.buffer_transcription || "",
|
||||||
|
lastReceivedData.buffer_translation || "",
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
true
|
true
|
||||||
@@ -281,6 +282,7 @@ function setupWebSocket() {
|
|||||||
lastReceivedData.lines || [],
|
lastReceivedData.lines || [],
|
||||||
lastReceivedData.buffer_diarization || "",
|
lastReceivedData.buffer_diarization || "",
|
||||||
lastReceivedData.buffer_transcription || "",
|
lastReceivedData.buffer_transcription || "",
|
||||||
|
lastReceivedData.buffer_translation || "",
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
true
|
true
|
||||||
@@ -301,6 +303,7 @@ function setupWebSocket() {
|
|||||||
lines = [],
|
lines = [],
|
||||||
buffer_transcription = "",
|
buffer_transcription = "",
|
||||||
buffer_diarization = "",
|
buffer_diarization = "",
|
||||||
|
buffer_translation = "",
|
||||||
remaining_time_transcription = 0,
|
remaining_time_transcription = 0,
|
||||||
remaining_time_diarization = 0,
|
remaining_time_diarization = 0,
|
||||||
status = "active_transcription",
|
status = "active_transcription",
|
||||||
@@ -310,6 +313,7 @@ function setupWebSocket() {
|
|||||||
lines,
|
lines,
|
||||||
buffer_diarization,
|
buffer_diarization,
|
||||||
buffer_transcription,
|
buffer_transcription,
|
||||||
|
buffer_translation,
|
||||||
remaining_time_diarization,
|
remaining_time_diarization,
|
||||||
remaining_time_transcription,
|
remaining_time_transcription,
|
||||||
false,
|
false,
|
||||||
@@ -323,6 +327,7 @@ function renderLinesWithBuffer(
|
|||||||
lines,
|
lines,
|
||||||
buffer_diarization,
|
buffer_diarization,
|
||||||
buffer_transcription,
|
buffer_transcription,
|
||||||
|
buffer_translation,
|
||||||
remaining_time_diarization,
|
remaining_time_diarization,
|
||||||
remaining_time_transcription,
|
remaining_time_transcription,
|
||||||
isFinalizing = false,
|
isFinalizing = false,
|
||||||
@@ -341,6 +346,7 @@ function renderLinesWithBuffer(
|
|||||||
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end, detected_language: it.detected_language })),
|
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end, detected_language: it.detected_language })),
|
||||||
buffer_transcription: buffer_transcription || "",
|
buffer_transcription: buffer_transcription || "",
|
||||||
buffer_diarization: buffer_diarization || "",
|
buffer_diarization: buffer_diarization || "",
|
||||||
|
buffer_translation: buffer_translation,
|
||||||
status: current_status,
|
status: current_status,
|
||||||
showLoading,
|
showLoading,
|
||||||
showTransLag,
|
showTransLag,
|
||||||
@@ -415,13 +421,22 @@ function renderLinesWithBuffer(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
let translationContent = "";
|
||||||
if (item.translation) {
|
if (item.translation) {
|
||||||
|
translationContent += item.translation.trim();
|
||||||
|
}
|
||||||
|
if (idx === lines.length - 1 && buffer_translation) {
|
||||||
|
const bufferPiece = isFinalizing
|
||||||
|
? buffer_translation
|
||||||
|
: `<span class="buffer_translation">${buffer_translation}</span>`;
|
||||||
|
translationContent += translationContent ? `${bufferPiece}` : bufferPiece;
|
||||||
|
}
|
||||||
|
if (translationContent.trim().length > 0) {
|
||||||
currentLineText += `
|
currentLineText += `
|
||||||
<div>
|
<div>
|
||||||
<div class="label_translation">
|
<div class="label_translation">
|
||||||
${translationIcon}
|
${translationIcon}
|
||||||
<span>${item.translation}</span>
|
<span class="translation_text">${translationContent}</span>
|
||||||
</div>
|
</div>
|
||||||
</div>`;
|
</div>`;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import io
|
import io
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import urllib
|
import urllib
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||||
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||||
@@ -100,6 +102,137 @@ def available_models() -> List[str]:
|
|||||||
return list(_MODELS.keys())
|
return list(_MODELS.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
    """
    Try to build ModelDimensions from a HuggingFace-style config.json.

    The config is looked up inside `path` when it is a directory, otherwise
    in the directory that contains `path`. Useful for distilled models.

    Returns None when no config.json is found, or (after emitting a warning)
    when the config lacks a required key.
    """
    config_dir = path if os.path.isdir(path) else os.path.dirname(path)
    config_path = os.path.join(config_dir, "config.json")
    if not os.path.isfile(config_path):
        return None

    with open(config_path, "r", encoding="utf-8") as handle:
        hf_config = json.load(handle)

    try:
        dims = ModelDimensions(
            n_mels=hf_config["num_mel_bins"],
            n_audio_ctx=hf_config["max_source_positions"],
            n_audio_state=hf_config["d_model"],
            n_audio_head=hf_config["encoder_attention_heads"],
            # Fall back to num_hidden_layers when encoder_layers is absent or falsy.
            n_audio_layer=hf_config.get("encoder_layers")
            or hf_config["num_hidden_layers"],
            n_vocab=hf_config["vocab_size"],
            n_text_ctx=hf_config["max_target_positions"],
            n_text_state=hf_config["d_model"],
            n_text_head=hf_config["decoder_attention_heads"],
            n_text_layer=hf_config["decoder_layers"],
        )
    except KeyError as err:
        warnings.warn(f"Missing key {err} in HuggingFace config {config_path}")
        return None

    return dims
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
converts a HF checkpoint state_dict into the naming convention used by
|
||||||
|
default whisper
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not any(k.startswith("model.") for k in state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def map_block(prefix: str, target_prefix: str, remainder: str) -> Optional[str]:
|
||||||
|
if remainder.startswith("self_attn."):
|
||||||
|
suffix = remainder.split(".", 1)[1]
|
||||||
|
mapping = {
|
||||||
|
"q_proj": "attn.query",
|
||||||
|
"k_proj": "attn.key",
|
||||||
|
"v_proj": "attn.value",
|
||||||
|
"out_proj": "attn.out",
|
||||||
|
}
|
||||||
|
stem = mapping.get(suffix.split(".")[0])
|
||||||
|
if stem:
|
||||||
|
rest = suffix.split(".", 1)[1] if "." in suffix else ""
|
||||||
|
return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
|
||||||
|
elif remainder == "self_attn_layer_norm.weight":
|
||||||
|
return f"{target_prefix}.attn_ln.weight"
|
||||||
|
elif remainder == "self_attn_layer_norm.bias":
|
||||||
|
return f"{target_prefix}.attn_ln.bias"
|
||||||
|
elif remainder.startswith("encoder_attn."):
|
||||||
|
suffix = remainder.split(".", 1)[1]
|
||||||
|
mapping = {
|
||||||
|
"q_proj": "cross_attn.query",
|
||||||
|
"k_proj": "cross_attn.key",
|
||||||
|
"v_proj": "cross_attn.value",
|
||||||
|
"out_proj": "cross_attn.out",
|
||||||
|
}
|
||||||
|
stem = mapping.get(suffix.split(".", 1)[0])
|
||||||
|
if stem:
|
||||||
|
rest = suffix.split(".", 1)[1] if "." in suffix else ""
|
||||||
|
return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
|
||||||
|
elif remainder == "encoder_attn_layer_norm.weight":
|
||||||
|
return f"{target_prefix}.cross_attn_ln.weight"
|
||||||
|
elif remainder == "encoder_attn_layer_norm.bias":
|
||||||
|
return f"{target_prefix}.cross_attn_ln.bias"
|
||||||
|
elif remainder.startswith("fc1."):
|
||||||
|
return f"{target_prefix}.mlp.0.{remainder.split('.',1)[1]}"
|
||||||
|
elif remainder.startswith("fc2."):
|
||||||
|
return f"{target_prefix}.mlp.2.{remainder.split('.',1)[1]}"
|
||||||
|
elif remainder == "final_layer_norm.weight":
|
||||||
|
return f"{target_prefix}.mlp_ln.weight"
|
||||||
|
elif remainder == "final_layer_norm.bias":
|
||||||
|
return f"{target_prefix}.mlp_ln.bias"
|
||||||
|
return None
|
||||||
|
|
||||||
|
converted = {}
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if not key.startswith("model."):
|
||||||
|
continue
|
||||||
|
subkey = key[len("model.") :]
|
||||||
|
|
||||||
|
if subkey.startswith("encoder.layers."):
|
||||||
|
parts = subkey.split(".")
|
||||||
|
layer_idx = parts[2]
|
||||||
|
remainder = ".".join(parts[3:])
|
||||||
|
mapped = map_block(subkey, f"encoder.blocks.{layer_idx}", remainder)
|
||||||
|
elif subkey.startswith("decoder.layers."):
|
||||||
|
parts = subkey.split(".")
|
||||||
|
layer_idx = parts[2]
|
||||||
|
remainder = ".".join(parts[3:])
|
||||||
|
mapped = map_block(subkey, f"decoder.blocks.{layer_idx}", remainder)
|
||||||
|
elif subkey.startswith("encoder.conv") or subkey.startswith("decoder.conv"):
|
||||||
|
mapped = subkey
|
||||||
|
elif subkey == "encoder.embed_positions.weight":
|
||||||
|
mapped = "encoder.positional_embedding"
|
||||||
|
elif subkey == "decoder.embed_positions.weight":
|
||||||
|
mapped = "decoder.positional_embedding"
|
||||||
|
elif subkey == "encoder.layer_norm.weight":
|
||||||
|
mapped = "encoder.ln_post.weight"
|
||||||
|
elif subkey == "encoder.layer_norm.bias":
|
||||||
|
mapped = "encoder.ln_post.bias"
|
||||||
|
elif subkey.startswith("decoder.embed_tokens."):
|
||||||
|
mapped = subkey.replace("embed_tokens", "token_embedding", 1)
|
||||||
|
elif subkey == "decoder.layer_norm.weight":
|
||||||
|
mapped = "decoder.ln.weight"
|
||||||
|
elif subkey == "decoder.layer_norm.bias":
|
||||||
|
mapped = "decoder.ln.bias"
|
||||||
|
else:
|
||||||
|
mapped = None
|
||||||
|
|
||||||
|
if mapped:
|
||||||
|
converted[mapped] = value
|
||||||
|
|
||||||
|
return converted if converted else state_dict
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
name: str,
|
name: str,
|
||||||
device: Optional[Union[str, torch.device]] = None,
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
@@ -134,7 +267,6 @@ def load_model(
|
|||||||
if download_root is None:
|
if download_root is None:
|
||||||
default = os.path.join(os.path.expanduser("~"), ".cache")
|
default = os.path.join(os.path.expanduser("~"), ".cache")
|
||||||
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
|
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
|
||||||
|
|
||||||
if name in _MODELS:
|
if name in _MODELS:
|
||||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||||
elif os.path.isfile(name):
|
elif os.path.isfile(name):
|
||||||
@@ -148,22 +280,50 @@ def load_model(
|
|||||||
if custom_alignment_heads:
|
if custom_alignment_heads:
|
||||||
alignment_heads = custom_alignment_heads.encode()
|
alignment_heads = custom_alignment_heads.encode()
|
||||||
|
|
||||||
with (
|
if isinstance(checkpoint_file, Path) and checkpoint_file.suffix == '.safetensors':
|
||||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
try:
|
||||||
) as fp:
|
from safetensors.torch import load_file
|
||||||
checkpoint = torch.load(fp, map_location=device)
|
except ImportError:
|
||||||
|
raise ImportError("Please install safetensors to load .safetensors model files: `pip install safetensors`")
|
||||||
|
if in_memory:
|
||||||
|
checkpoint = load_file(checkpoint_file, device=device)
|
||||||
|
else:
|
||||||
|
checkpoint = load_file(checkpoint_file, device=device)
|
||||||
|
else:
|
||||||
|
with (
|
||||||
|
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
||||||
|
) as fp:
|
||||||
|
checkpoint = torch.load(fp, map_location=device)
|
||||||
del checkpoint_file
|
del checkpoint_file
|
||||||
|
|
||||||
dims = ModelDimensions(**checkpoint["dims"])
|
dims_cfg = checkpoint.get("dims") if isinstance(checkpoint, dict) else None
|
||||||
|
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
|
||||||
|
state_dict = checkpoint["model_state_dict"]
|
||||||
|
else:
|
||||||
|
state_dict = checkpoint
|
||||||
|
state_dict = _convert_hf_state_dict(state_dict)
|
||||||
|
|
||||||
|
if dims_cfg is not None:
|
||||||
|
dims = ModelDimensions(**dims_cfg)
|
||||||
|
else:
|
||||||
|
dims = _infer_dims_from_config(name)
|
||||||
|
if dims is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Could not determine model dimensions. "
|
||||||
|
"Ensure the checkpoint includes 'dims' or a HuggingFace config.json is present."
|
||||||
|
)
|
||||||
|
if not isinstance(state_dict, dict):
|
||||||
|
state_dict = checkpoint
|
||||||
|
|
||||||
model = Whisper(dims, decoder_only=decoder_only)
|
model = Whisper(dims, decoder_only=decoder_only)
|
||||||
|
|
||||||
if decoder_only:
|
if decoder_only:
|
||||||
checkpoint["model_state_dict"] = {
|
state_dict = {
|
||||||
k: v for k, v in checkpoint["model_state_dict"].items()
|
k: v for k, v in state_dict.items()
|
||||||
if 'encoder' not in k
|
if 'encoder' not in k
|
||||||
}
|
}
|
||||||
|
|
||||||
model.load_state_dict(checkpoint["model_state_dict"])
|
model.load_state_dict(state_dict)
|
||||||
|
|
||||||
if alignment_heads is not None:
|
if alignment_heads is not None:
|
||||||
model.set_alignment_heads(alignment_heads)
|
model.set_alignment_heads(alignment_heads)
|
||||||
@@ -1,122 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
import sys
|
|
||||||
import numpy as np
|
|
||||||
import librosa
|
|
||||||
from functools import lru_cache
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
|
|
||||||
from whisperlivekit.warmup import warmup_asr
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)


# Language codes accepted by Whisper, kept as one comma-separated literal
# and split into a list once at import time.
WHISPER_LANG_CODES = (
    "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh"
).split(",")
|
|
||||||
|
|
||||||
|
|
||||||
def create_tokenizer(lan):
    """Build a sentence splitter for the language code ``lan``.

    The returned object exposes a ``split`` method that works like the one
    of MosesTokenizer. Selection order:

    * ``"uk"`` -> a thin wrapper around ``tokenize_uk.tokenize_sents``;
    * codes supported by fast-mosestokenizer -> ``MosesSentenceSplitter``;
    * anything else -> a wtpsplit ``WtP`` model, called with
      ``lang_code=None`` for codes wtpsplit does not know.

    ``lan`` must be one of ``WHISPER_LANG_CODES``; otherwise the assertion
    below fails.
    """
    assert lan in WHISPER_LANG_CODES, (
        "language must be Whisper's supported lang code: "
        + " ".join(WHISPER_LANG_CODES)
    )

    if lan == "uk":
        import tokenize_uk

        class UkrainianTokenizer:
            def split(self, text):
                return tokenize_uk.tokenize_sents(text)

        return UkrainianTokenizer()

    # supported by fast-mosestokenizer
    moses_langs = "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
    if lan in moses_langs:
        from mosestokenizer import MosesSentenceSplitter

        return MosesSentenceSplitter(lan)

    # the following languages are in Whisper, but not in wtpsplit:
    wtp_unsupported_langs = "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
    if lan in wtp_unsupported_langs:
        logger.debug(
            f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
        )
        # The closure below reads ``lan`` at call time, so this reset is
        # what makes WtP run without a language hint.
        lan = None

    from wtpsplit import WtP

    # downloads the model from huggingface on the first use
    wtp_model = WtP("wtp-canine-s-12l-no-adapters")

    class WtPtok:
        def split(self, sent):
            return wtp_model.split(sent, lang_code=lan)

    return WtPtok()
|
|
||||||
|
|
||||||
|
|
||||||
def backend_factory(
    backend,
    lan,
    model_size,
    model_cache_dir,
    model_dir,
    task,
    buffer_trimming,
    buffer_trimming_sec,
    confidence_validation,
    warmup_file=None,
    min_chunk_size=None,
):
    """Instantiate, warm up and configure the ASR backend named ``backend``.

    ``"openai-api"`` builds an ``OpenaiApiASR`` from ``lan`` alone. Any other
    value loads a local model: ``"faster-whisper"`` -> ``FasterWhisperASR``,
    ``"mlx-whisper"`` -> ``MLXWhisper``, everything else ->
    ``WhisperTimestampedASR``.

    When ``buffer_trimming`` is ``"sentence"`` a tokenizer is created for the
    output language (``"en"`` if ``task`` is ``"translate"``, else ``lan``).
    The ASR is then passed to ``warmup_asr`` together with ``warmup_file``,
    and the trimming / confidence settings are attached to it as attributes.

    ``min_chunk_size`` is accepted but not used by this function.

    Returns the configured ASR object.
    """
    if backend == "openai-api":
        logger.debug("Using OpenAI API.")
        asr = OpenaiApiASR(lan=lan)
    else:
        asr_cls = (
            FasterWhisperASR
            if backend == "faster-whisper"
            else MLXWhisper
            if backend == "mlx-whisper"
            else WhisperTimestampedASR
        )

        # Only for FasterWhisperASR and WhisperTimestampedASR

        started = time.time()
        logger.info(f"Loading Whisper {model_size} model for language {lan}...")
        asr = asr_cls(
            model_size=model_size,
            lan=lan,
            cache_dir=model_cache_dir,
            model_dir=model_dir,
        )
        finished = time.time()
        logger.info(f"done. It took {round(finished - started, 2)} seconds.")

    # Whisper translates into English; otherwise it transcribes in ``lan``.
    tgt_language = "en" if task == "translate" else lan

    # A tokenizer is only created for sentence-based buffer trimming.
    tokenizer = (
        create_tokenizer(tgt_language) if buffer_trimming == "sentence" else None
    )

    warmup_asr(asr, warmup_file)

    asr.confidence_validation = confidence_validation
    asr.tokenizer = tokenizer
    asr.buffer_trimming = buffer_trimming
    asr.buffer_trimming_sec = buffer_trimming_sec
    return asr
|
|
||||||
Reference in New Issue
Block a user