Compare commits
102 Commits
ScriptProc
...
0.2.15
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6206fff118 | ||
|
|
b5067249c0 | ||
|
|
f4f9831d39 | ||
|
|
254faaf64c | ||
|
|
8e7aea4fcf | ||
|
|
270faf2069 | ||
|
|
b7c1cc77cc | ||
|
|
9a45ec221c | ||
|
|
3e13ee6fc3 | ||
|
|
b7d20a0ff0 | ||
|
|
c1bb9c2bde | ||
|
|
11e9def0b2 | ||
|
|
3104f40f6e | ||
|
|
e9b4ceeee5 | ||
|
|
437641fb43 | ||
|
|
bfd60b3921 | ||
|
|
1e67bf97f0 | ||
|
|
bbd4fd6cff | ||
|
|
28985962a0 | ||
|
|
a38c103fcd | ||
|
|
4d2ffb24f8 | ||
|
|
1bbbb7903c | ||
|
|
bcffdbc6b3 | ||
|
|
80b77998f9 | ||
|
|
d310f7e25f | ||
|
|
8d9be88fe6 | ||
|
|
16461052ed | ||
|
|
5491dbd824 | ||
|
|
13401ffe24 | ||
|
|
7108d2ddc5 | ||
|
|
a732e0903e | ||
|
|
0491681be4 | ||
|
|
ffe5284764 | ||
|
|
41ca17acda | ||
|
|
06b31f51eb | ||
|
|
ece02db6a3 | ||
|
|
939a7ebf8b | ||
|
|
61edb70fff | ||
|
|
4e455b8aab | ||
|
|
9434390ad3 | ||
|
|
65250db92c | ||
|
|
416dce7975 | ||
|
|
0c5365e7c6 | ||
|
|
19e9d76610 | ||
|
|
e7b05b0138 | ||
|
|
818c9c37ca | ||
|
|
714fb3b14a | ||
|
|
0af379c465 | ||
|
|
9c5bb5df19 | ||
|
|
dc6ea79036 | ||
|
|
21bbb59e31 | ||
|
|
12a69205ed | ||
|
|
1f684cdd97 | ||
|
|
3467109668 | ||
|
|
971f8473eb | ||
|
|
8434ef5efc | ||
|
|
290470dd60 | ||
|
|
425ac7b51d | ||
|
|
0382cfbeba | ||
|
|
9b1e061b32 | ||
|
|
b4abc158b9 | ||
|
|
5832d7433d | ||
|
|
3736458503 | ||
|
|
374618e050 | ||
|
|
543972ef38 | ||
|
|
73f36cc0ef | ||
|
|
a7db39d999 | ||
|
|
a153e11fe0 | ||
|
|
ca6f9246cc | ||
|
|
d080d675a8 | ||
|
|
40bff38933 | ||
|
|
2fe3ca0188 | ||
|
|
545ea15c9a | ||
|
|
8cbaeecc75 | ||
|
|
70e854b346 | ||
|
|
d55490cd27 | ||
|
|
1fa9e1f656 | ||
|
|
994f30e1ed | ||
|
|
b22478c0b4 | ||
|
|
94c34efd90 | ||
|
|
32099b9275 | ||
|
|
9fc6654a4a | ||
|
|
d24c110d55 | ||
|
|
4dd5d8bf8a | ||
|
|
cd9a32a36b | ||
|
|
6caf3e0485 | ||
|
|
93f002cafb | ||
|
|
c5e30c2c07 | ||
|
|
1c2afb8bd2 | ||
|
|
674b20d3af | ||
|
|
a5503308c5 | ||
|
|
e61afdefa3 | ||
|
|
426d70a790 | ||
|
|
b03a212fbf | ||
|
|
1833e7c921 | ||
|
|
777ec63a71 | ||
|
|
0a6e5ae9c1 | ||
|
|
ee448a37e9 | ||
|
|
9c051052b0 | ||
|
|
65025cc448 | ||
|
|
bbba1d9bb7 | ||
|
|
99dc96c644 |
16
.gitignore
vendored
@@ -54,21 +54,6 @@ coverage.xml
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
@@ -139,3 +124,4 @@ launch.json
|
||||
.DS_Store
|
||||
test/*
|
||||
nllb-200-distilled-600M-ctranslate2/*
|
||||
*.mp3
|
||||
23
DEV_NOTES.md
@@ -18,8 +18,29 @@ Decoder weights: 59110771 bytes
|
||||
Encoder weights: 15268874 bytes
|
||||
|
||||
|
||||
# 2. Translation: Faster model for each system
|
||||
|
||||
# 2. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm
|
||||
## Benchmark Results
|
||||
|
||||
Testing on MacBook M3 with NLLB-200-distilled-600M model:
|
||||
|
||||
### Standard Transformers vs CTranslate2
|
||||
|
||||
| Test Text | Standard Inference Time | CTranslate2 Inference Time | Speedup |
|
||||
|-----------|-------------------------|---------------------------|---------|
|
||||
| UN Chief says there is no military solution in Syria | 0.9395s | 2.0472s | 0.5x |
|
||||
| The rapid advancement of AI technology is transforming various industries | 0.7171s | 1.7516s | 0.4x |
|
||||
| Climate change poses a significant threat to global ecosystems | 0.8533s | 1.8323s | 0.5x |
|
||||
| International cooperation is essential for addressing global challenges | 0.7209s | 1.3575s | 0.5x |
|
||||
| The development of renewable energy sources is crucial for a sustainable future | 0.8760s | 1.5589s | 0.6x |
|
||||
|
||||
**Results:**
|
||||
- Total Standard time: 4.1068s
|
||||
- Total CTranslate2 time: 8.5476s
|
||||
- CTranslate2 is slower on this system --> Use Transformers, and ideally we would have an mlx implementation.
|
||||
|
||||
|
||||
# 3. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm
|
||||
|
||||
Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions.
|
||||
|
||||
|
||||
226
LICENSE
@@ -1,52 +1,210 @@
|
||||
# License
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
## Main Software License
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
MIT License
|
||||
1. Definitions.
|
||||
|
||||
Copyright (c) 2025 Quentin Fuxa.
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
## SimulStreaming Backend License
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
**When using the SimulStreaming backend (SimulWhisper), additional licensing terms apply:**
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
SimulStreaming (https://github.com/ufal/SimulStreaming) is dual-licensed:
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
### 🔹 Non-Commercial Use
|
||||
You may use SimulStreaming under the **PolyForm Noncommercial License 1.0.0** if you obtain the code through the GitHub repository. This license is **free of charge** and comes with **no obligations** for non-commercial users.
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
### 🔸 Commercial Use
|
||||
Understanding who uses SimulStreaming commercially helps improve and prioritize development. Therefore, **registration is required** for those who acquire a commercial license.
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
Commercial licenses are planned to be **affordable** to SMEs and individuals. They are considering providing commercial licenses either for free or for a symbolic one-time fee, and may also provide additional support. You can share your preference via the [questionnaire](https://forms.cloud.microsoft.com/e/7tCxb4gJfB).
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
You can also leave your contact [there](https://forms.cloud.microsoft.com/e/7tCxb4gJfB) to be notified when commercial licenses become available.
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
**Contact for SimulStreaming licensing:**
|
||||
[Dominik Macháček](https://ufal.mff.cuni.cz/dominik-machacek/), machacek@ufal.mff.cuni.cz
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2025 Quentin Fuxa
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
---
|
||||
|
||||
## Based on:
|
||||
- **whisper_streaming** by ÚFAL – MIT License – https://github.com/ufal/whisper_streaming. The original work by ÚFAL. License: https://github.com/ufal/whisper_streaming/blob/main/LICENSE
|
||||
- **silero-vad** by Snakers4 – MIT License – https://github.com/snakers4/silero-vad. The work by Snakers4 (silero-vad). License: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
|
||||
- **Diart** by juanmc2005 – MIT License – https://github.com/juanmc2005/diart. The work in Diart by juanmc2005. License: https://github.com/juanmc2005/diart/blob/main/LICENSE
|
||||
- **SimulStreaming** by ÚFAL – Dual License (PolyForm Noncommercial License 1.0.0 / Commercial License) – https://github.com/ufal/SimulStreaming
|
||||
- **SimulWhisper** by Speech and Audio Technology LAB of Tsinghua University – Apache-2.0 – https://github.com/ufal/SimulStreaming
|
||||
- **SimulStreaming** by ÚFAL – MIT License – https://github.com/ufal/SimulStreaming
|
||||
- **NeMo** by NVidia - Apache-2.0 - https://github.com/NVIDIA-NeMo/NeMo
|
||||
- **whisper_streaming** by ÚFAL – MIT License – https://github.com/ufal/whisper_streaming.
|
||||
- **silero-vad** by Snakers4 – MIT License – https://github.com/snakers4/silero-vad.
|
||||
- **Diart** by juanmc2005 – MIT License – https://github.com/juanmc2005/diart.
|
||||
|
||||
105
README.md
@@ -10,17 +10,17 @@
|
||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
|
||||
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
|
||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
|
||||
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
|
||||
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-Apache 2.0-dark_green"></a>
|
||||
</p>
|
||||
|
||||
|
||||
Real-time speech transcription directly to your browser, with a ready-to-use backend+server and a simple frontend. ✨
|
||||
Real-time transcription directly to your browser, with a ready-to-use backend+server and a simple frontend.
|
||||
|
||||
#### Powered by Leading Research:
|
||||
|
||||
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription with AlignAtt policy
|
||||
- [NLLB](https://arxiv.org/abs/2207.04672), ([distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2)) (2024) - Translation to more than 100 languages.
|
||||
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription with LocalAgreement policy
|
||||
- Simul-[Whisper](https://github.com/backspacetg/simul_whisper)/[Streaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
|
||||
- [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simulatenous translation from & to 200 languages.
|
||||
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
|
||||
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
|
||||
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
|
||||
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - Enterprise-grade Voice Activity Detection
|
||||
@@ -42,19 +42,10 @@ pip install whisperlivekit
|
||||
```
|
||||
> You can also clone the repo and `pip install -e .` for the latest version.
|
||||
|
||||
|
||||
> **FFmpeg is required** and must be installed before using WhisperLiveKit
|
||||
>
|
||||
> | OS | How to install |
|
||||
> |-----------|-------------|
|
||||
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
|
||||
> | MacOS | `brew install ffmpeg` |
|
||||
> | Windows | Download .exe from https://ffmpeg.org/download.html and add to PATH |
|
||||
|
||||
#### Quick Start
|
||||
1. **Start the transcription server:**
|
||||
```bash
|
||||
whisperlivekit-server --model base --language en
|
||||
wlk --model base --language en
|
||||
```
|
||||
|
||||
2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time!
|
||||
@@ -62,6 +53,15 @@ pip install whisperlivekit
|
||||
|
||||
> - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
|
||||
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
|
||||
> - The CLI entry point is exposed as both `wlk` and `whisperlivekit-server`; they are equivalent.
|
||||
|
||||
#### Use it to capture audio from web pages.
|
||||
|
||||
Go to `chrome-extension` for instructions.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="600">
|
||||
</p>
|
||||
|
||||
|
||||
|
||||
@@ -69,13 +69,12 @@ pip install whisperlivekit
|
||||
|
||||
| Optional | `pip install` |
|
||||
|-----------|-------------|
|
||||
| **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
|
||||
| **Apple Silicon optimized backend** | `mlx-whisper` |
|
||||
| **NLLB Translation** | `huggingface_hub` & `transformers` |
|
||||
| **Windows/Linux optimizations** | `faster-whisper` |
|
||||
| **Apple Silicon optimizations** | `mlx-whisper` |
|
||||
| **Translation** | `nllw` |
|
||||
| **Speaker diarization** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
|
||||
| OpenAI API | `openai` |
|
||||
| *[Not recommanded]* Speaker diarization with Diart | `diart` |
|
||||
| *[Not recommanded]* Original Whisper backend | `whisper` |
|
||||
| *[Not recommanded]* Improved timestamps backend | `whisper-timestamped` |
|
||||
| OpenAI API backend | `openai` |
|
||||
|
||||
See **Parameters & Configuration** below on how to use them.
|
||||
|
||||
@@ -86,11 +85,11 @@ See **Parameters & Configuration** below on how to use them.
|
||||
**Command-line Interface**: Start the transcription server with various options:
|
||||
|
||||
```bash
|
||||
# Use better model than default (small)
|
||||
whisperlivekit-server --model large-v3
|
||||
# Large model and translate from french to danish
|
||||
wlk --model large-v3 --language fr --target-language da
|
||||
|
||||
# Advanced configuration with diarization and language
|
||||
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
|
||||
# Diarization and server listening on */80
|
||||
wlk --host 0.0.0.0 --port 80 --model medium --diarization --language fr
|
||||
```
|
||||
|
||||
|
||||
@@ -137,26 +136,16 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
|
||||
## Parameters & Configuration
|
||||
|
||||
An important list of parameters can be changed. But what *should* you change?
|
||||
- the `--model` size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
|
||||
- the `--language`. List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English.
|
||||
- the `--backend` ? you can switch to `--backend faster-whisper` if `simulstreaming` does not work correctly or if you prefer to avoid the dual-license requirements.
|
||||
- `--warmup-file`, if you have one
|
||||
- `--task translate`, to translate in english
|
||||
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`, if you set up a server
|
||||
- `--diarization`, if you want to use it.
|
||||
- [BETA] `--target-language`, to translate using NLLB. [118 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/translation/mapping_languages.py). If you want to translate to english, you should rather use `--task translate`, since Whisper can do it directly.
|
||||
|
||||
### Full list of parameters :
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--model` | Whisper model size. | `small` |
|
||||
| `--language` | Source language code or `auto` | `auto` |
|
||||
| `--task` | Set to `translate` to translate to english | `transcribe` |
|
||||
| `--target-language` | [BETA] Translation language target. Ex: `fr` | `None` |
|
||||
| `--backend` | Processing backend | `simulstreaming` |
|
||||
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
|
||||
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/available_models.md) | `small` |
|
||||
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/models_compatible_formats.md) | `None` |
|
||||
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
||||
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
|
||||
| `--diarization` | Enable speaker identification | `False` |
|
||||
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
|
||||
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
|
||||
| `--no-vac` | Disable Voice Activity Controller | `False` |
|
||||
| `--no-vad` | Disable Voice Activity Detection | `False` |
|
||||
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
||||
@@ -164,12 +153,26 @@ An important list of parameters can be changed. But what *should* you change?
|
||||
| `--port` | Server port | `8000` |
|
||||
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
|
||||
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
|
||||
| `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. | `False` |
|
||||
| `--forwarded-allow-ips` | Ip or Ips allowed to reverse proxy the whisperlivekit-server. Supported types are IP Addresses (e.g. 127.0.0.1), IP Networks (e.g. 10.100.0.0/16), or Literals (e.g. /path/to/socket.sock) | `None` |
|
||||
| `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder | `False` |
|
||||
|
||||
| Translation options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--nllb-backend` | `transformers` or `ctranslate2` | `ctranslate2` |
|
||||
| `--nllb-size` | `600M` or `1.3B` | `600M` |
|
||||
|
||||
| Diarization options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
|
||||
| `--disable-punctuation-split` | Disable punctuation based splits. See #214 | `False` |
|
||||
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
||||
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||
|
||||
| SimulStreaming backend options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
|
||||
| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used. Use `scripts/determine_alignment_heads.py` to extract them. <img src="scripts/alignment_heads.png" alt="WhisperLiveKit Demo" width="300">
|
||||
| `None` |
|
||||
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
|
||||
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
|
||||
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
|
||||
@@ -180,29 +183,19 @@ An important list of parameters can be changed. But what *should* you change?
|
||||
| `--init-prompt` | Initial prompt for the model | `None` |
|
||||
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
|
||||
| `--max-context-tokens` | Maximum context tokens | `None` |
|
||||
| `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` |
|
||||
| `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
|
||||
|
||||
|
||||
|
||||
| WhisperStreaming backend options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
|
||||
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
|
||||
|
||||
| Diarization options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--diarization` | Enable speaker identification | `False` |
|
||||
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
|
||||
| `--disable-punctuation-split` | Disable punctuation based splits. See #214 | `False` |
|
||||
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
||||
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||
|
||||
|
||||
> For diarization using Diart, you need access to pyannote.audio models:
|
||||
> 1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
|
||||
> 2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
|
||||
> 3. [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
|
||||
>4. Login with HuggingFace: `huggingface-cli login`
|
||||
|
||||
> For diarization using Diart, you need to accept user conditions [here](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model, [here](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model and [here](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model. **Then**, login to HuggingFace: `huggingface-cli login`
|
||||
|
||||
### 🚀 Deployment Guide
|
||||
|
||||
|
||||
BIN
architecture.png
|
Before Width: | Height: | Size: 368 KiB After Width: | Height: | Size: 406 KiB |
@@ -1,11 +1,13 @@
|
||||
## WhisperLiveKit Chrome Extension v0.1.0
|
||||
Capture the audio of your current tab, transcribe or translate it using WhisperliveKit. **Still unstable**
|
||||
## WhisperLiveKit Chrome Extension v0.1.1
|
||||
Capture the audio of your current tab, transcribe diarize and translate it using WhisperliveKit, in Chrome and other Chromium-based browsers.
|
||||
|
||||
> Currently, only the tab audio is captured; your microphone audio is not recorded.
|
||||
|
||||
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">
|
||||
|
||||
## Running this extension
|
||||
1. Clone this repository.
|
||||
2. Load this directory in Chrome as an unpacked extension.
|
||||
1. Run `python sync_extension.py` to copy frontend files to the `chrome-extension` directory.
|
||||
2. Load the `chrome-extension` directory in Chrome as an unpacked extension.
|
||||
|
||||
|
||||
## Devs:
|
||||
|
||||
|
Before Width: | Height: | Size: 1.2 MiB After Width: | Height: | Size: 5.8 MiB |
@@ -1,669 +0,0 @@
|
||||
/* Theme, WebSocket, recording, rendering logic extracted from inline script and adapted for segmented theme control and WS caption */
|
||||
let isRecording = false;
|
||||
let websocket = null;
|
||||
let recorder = null;
|
||||
let chunkDuration = 100;
|
||||
let websocketUrl = "ws://localhost:8000/asr";
|
||||
let userClosing = false;
|
||||
let wakeLock = null;
|
||||
let startTime = null;
|
||||
let timerInterval = null;
|
||||
let audioContext = null;
|
||||
let analyser = null;
|
||||
let microphone = null;
|
||||
let waveCanvas = document.getElementById("waveCanvas");
|
||||
let waveCtx = waveCanvas.getContext("2d");
|
||||
let animationFrame = null;
|
||||
let waitingForStop = false;
|
||||
let lastReceivedData = null;
|
||||
let lastSignature = null;
|
||||
let availableMicrophones = [];
|
||||
let selectedMicrophoneId = null;
|
||||
|
||||
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
|
||||
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
|
||||
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
|
||||
|
||||
const statusText = document.getElementById("status");
|
||||
const recordButton = document.getElementById("recordButton");
|
||||
const chunkSelector = document.getElementById("chunkSelector");
|
||||
const websocketInput = document.getElementById("websocketInput");
|
||||
const websocketDefaultSpan = document.getElementById("wsDefaultUrl");
|
||||
const linesTranscriptDiv = document.getElementById("linesTranscript");
|
||||
const timerElement = document.querySelector(".timer");
|
||||
const themeRadios = document.querySelectorAll('input[name="theme"]');
|
||||
const microphoneSelect = document.getElementById("microphoneSelect");
|
||||
const settingsToggle = document.getElementById("settingsToggle");
|
||||
const settingsDiv = document.querySelector(".settings");
|
||||
|
||||
|
||||
|
||||
chrome.runtime.onInstalled.addListener((details) => {
|
||||
if (details.reason.search(/install/g) === -1) {
|
||||
return
|
||||
}
|
||||
chrome.tabs.create({
|
||||
url: chrome.runtime.getURL("welcome.html"),
|
||||
active: true
|
||||
})
|
||||
})
|
||||
|
||||
function getWaveStroke() {
|
||||
const styles = getComputedStyle(document.documentElement);
|
||||
const v = styles.getPropertyValue("--wave-stroke").trim();
|
||||
return v || "#000";
|
||||
}
|
||||
|
||||
let waveStroke = getWaveStroke();
|
||||
function updateWaveStroke() {
|
||||
waveStroke = getWaveStroke();
|
||||
}
|
||||
|
||||
function applyTheme(pref) {
|
||||
if (pref === "light") {
|
||||
document.documentElement.setAttribute("data-theme", "light");
|
||||
} else if (pref === "dark") {
|
||||
document.documentElement.setAttribute("data-theme", "dark");
|
||||
} else {
|
||||
document.documentElement.removeAttribute("data-theme");
|
||||
}
|
||||
updateWaveStroke();
|
||||
}
|
||||
|
||||
// Persisted theme preference
|
||||
const savedThemePref = localStorage.getItem("themePreference") || "system";
|
||||
applyTheme(savedThemePref);
|
||||
if (themeRadios.length) {
|
||||
themeRadios.forEach((r) => {
|
||||
r.checked = r.value === savedThemePref;
|
||||
r.addEventListener("change", () => {
|
||||
if (r.checked) {
|
||||
localStorage.setItem("themePreference", r.value);
|
||||
applyTheme(r.value);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// React to OS theme changes when in "system" mode
|
||||
const darkMq = window.matchMedia && window.matchMedia("(prefers-color-scheme: dark)");
|
||||
const handleOsThemeChange = () => {
|
||||
const pref = localStorage.getItem("themePreference") || "system";
|
||||
if (pref === "system") updateWaveStroke();
|
||||
};
|
||||
if (darkMq && darkMq.addEventListener) {
|
||||
darkMq.addEventListener("change", handleOsThemeChange);
|
||||
} else if (darkMq && darkMq.addListener) {
|
||||
// deprecated, but included for Safari compatibility
|
||||
darkMq.addListener(handleOsThemeChange);
|
||||
}
|
||||
|
||||
async function enumerateMicrophones() {
|
||||
try {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
|
||||
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||
stream.getTracks().forEach(track => track.stop());
|
||||
|
||||
const devices = await navigator.mediaDevices.enumerateDevices();
|
||||
availableMicrophones = devices.filter(device => device.kind === 'audioinput');
|
||||
|
||||
populateMicrophoneSelect();
|
||||
console.log(`Found ${availableMicrophones.length} microphone(s)`);
|
||||
} catch (error) {
|
||||
console.error('Error enumerating microphones:', error);
|
||||
statusText.textContent = "Error accessing microphones. Please grant permission.";
|
||||
}
|
||||
}
|
||||
|
||||
function populateMicrophoneSelect() {
|
||||
if (!microphoneSelect) return;
|
||||
|
||||
microphoneSelect.innerHTML = '<option value="">Default Microphone</option>';
|
||||
|
||||
availableMicrophones.forEach((device, index) => {
|
||||
const option = document.createElement('option');
|
||||
option.value = device.deviceId;
|
||||
option.textContent = device.label || `Microphone ${index + 1}`;
|
||||
microphoneSelect.appendChild(option);
|
||||
});
|
||||
|
||||
const savedMicId = localStorage.getItem('selectedMicrophone');
|
||||
if (savedMicId && availableMicrophones.some(mic => mic.deviceId === savedMicId)) {
|
||||
microphoneSelect.value = savedMicId;
|
||||
selectedMicrophoneId = savedMicId;
|
||||
}
|
||||
}
|
||||
|
||||
function handleMicrophoneChange() {
|
||||
selectedMicrophoneId = microphoneSelect.value || null;
|
||||
localStorage.setItem('selectedMicrophone', selectedMicrophoneId || '');
|
||||
|
||||
const selectedDevice = availableMicrophones.find(mic => mic.deviceId === selectedMicrophoneId);
|
||||
const deviceName = selectedDevice ? selectedDevice.label : 'Default Microphone';
|
||||
|
||||
console.log(`Selected microphone: ${deviceName}`);
|
||||
statusText.textContent = `Microphone changed to: ${deviceName}`;
|
||||
|
||||
if (isRecording) {
|
||||
statusText.textContent = "Switching microphone... Please wait.";
|
||||
stopRecording().then(() => {
|
||||
setTimeout(() => {
|
||||
toggleRecording();
|
||||
}, 1000);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Helpers
|
||||
function fmt1(x) {
|
||||
const n = Number(x);
|
||||
return Number.isFinite(n) ? n.toFixed(1) : x;
|
||||
}
|
||||
|
||||
// Default WebSocket URL computation
|
||||
const host = window.location.hostname || "localhost";
|
||||
const port = window.location.port;
|
||||
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
|
||||
const defaultWebSocketUrl = websocketUrl;
|
||||
|
||||
// Populate default caption and input
|
||||
if (websocketDefaultSpan) websocketDefaultSpan.textContent = defaultWebSocketUrl;
|
||||
websocketInput.value = defaultWebSocketUrl;
|
||||
websocketUrl = defaultWebSocketUrl;
|
||||
|
||||
// Optional chunk selector (guard for presence)
|
||||
if (chunkSelector) {
|
||||
chunkSelector.addEventListener("change", () => {
|
||||
chunkDuration = parseInt(chunkSelector.value);
|
||||
});
|
||||
}
|
||||
|
||||
// WebSocket input change handling
|
||||
websocketInput.addEventListener("change", () => {
|
||||
const urlValue = websocketInput.value.trim();
|
||||
if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
|
||||
statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
|
||||
return;
|
||||
}
|
||||
websocketUrl = urlValue;
|
||||
statusText.textContent = "WebSocket URL updated. Ready to connect.";
|
||||
});
|
||||
|
||||
function setupWebSocket() {
|
||||
return new Promise((resolve, reject) => {
|
||||
try {
|
||||
websocket = new WebSocket(websocketUrl);
|
||||
} catch (error) {
|
||||
statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
|
||||
reject(error);
|
||||
return;
|
||||
}
|
||||
|
||||
websocket.onopen = () => {
|
||||
statusText.textContent = "Connected to server.";
|
||||
resolve();
|
||||
};
|
||||
|
||||
websocket.onclose = () => {
|
||||
if (userClosing) {
|
||||
if (waitingForStop) {
|
||||
statusText.textContent = "Processing finalized or connection closed.";
|
||||
if (lastReceivedData) {
|
||||
renderLinesWithBuffer(
|
||||
lastReceivedData.lines || [],
|
||||
lastReceivedData.buffer_diarization || "",
|
||||
lastReceivedData.buffer_transcription || "",
|
||||
0,
|
||||
0,
|
||||
true
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
statusText.textContent = "Disconnected from the WebSocket server. (Check logs if model is loading.)";
|
||||
if (isRecording) {
|
||||
stopRecording();
|
||||
}
|
||||
}
|
||||
isRecording = false;
|
||||
waitingForStop = false;
|
||||
userClosing = false;
|
||||
lastReceivedData = null;
|
||||
websocket = null;
|
||||
updateUI();
|
||||
};
|
||||
|
||||
websocket.onerror = () => {
|
||||
statusText.textContent = "Error connecting to WebSocket.";
|
||||
reject(new Error("Error connecting to WebSocket"));
|
||||
};
|
||||
|
||||
websocket.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
|
||||
if (data.type === "ready_to_stop") {
|
||||
console.log("Ready to stop received, finalizing display and closing WebSocket.");
|
||||
waitingForStop = false;
|
||||
|
||||
if (lastReceivedData) {
|
||||
renderLinesWithBuffer(
|
||||
lastReceivedData.lines || [],
|
||||
lastReceivedData.buffer_diarization || "",
|
||||
lastReceivedData.buffer_transcription || "",
|
||||
0,
|
||||
0,
|
||||
true
|
||||
);
|
||||
}
|
||||
statusText.textContent = "Finished processing audio! Ready to record again.";
|
||||
recordButton.disabled = false;
|
||||
|
||||
if (websocket) {
|
||||
websocket.close();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
lastReceivedData = data;
|
||||
|
||||
const {
|
||||
lines = [],
|
||||
buffer_transcription = "",
|
||||
buffer_diarization = "",
|
||||
remaining_time_transcription = 0,
|
||||
remaining_time_diarization = 0,
|
||||
status = "active_transcription",
|
||||
} = data;
|
||||
|
||||
renderLinesWithBuffer(
|
||||
lines,
|
||||
buffer_diarization,
|
||||
buffer_transcription,
|
||||
remaining_time_diarization,
|
||||
remaining_time_transcription,
|
||||
false,
|
||||
status
|
||||
);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
function renderLinesWithBuffer(
|
||||
lines,
|
||||
buffer_diarization,
|
||||
buffer_transcription,
|
||||
remaining_time_diarization,
|
||||
remaining_time_transcription,
|
||||
isFinalizing = false,
|
||||
current_status = "active_transcription"
|
||||
) {
|
||||
if (current_status === "no_audio_detected") {
|
||||
linesTranscriptDiv.innerHTML =
|
||||
"<p style='text-align: center; color: var(--muted); margin-top: 20px;'><em>No audio detected...</em></p>";
|
||||
return;
|
||||
}
|
||||
|
||||
const showLoading = !isFinalizing && (lines || []).some((it) => it.speaker == 0);
|
||||
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
|
||||
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
|
||||
const signature = JSON.stringify({
|
||||
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end })),
|
||||
buffer_transcription: buffer_transcription || "",
|
||||
buffer_diarization: buffer_diarization || "",
|
||||
status: current_status,
|
||||
showLoading,
|
||||
showTransLag,
|
||||
showDiaLag,
|
||||
isFinalizing: !!isFinalizing,
|
||||
});
|
||||
if (lastSignature === signature) {
|
||||
const t = document.querySelector(".lag-transcription-value");
|
||||
if (t) t.textContent = fmt1(remaining_time_transcription);
|
||||
const d = document.querySelector(".lag-diarization-value");
|
||||
if (d) d.textContent = fmt1(remaining_time_diarization);
|
||||
const ld = document.querySelector(".loading-diarization-value");
|
||||
if (ld) ld.textContent = fmt1(remaining_time_diarization);
|
||||
return;
|
||||
}
|
||||
lastSignature = signature;
|
||||
|
||||
const linesHtml = (lines || [])
|
||||
.map((item, idx) => {
|
||||
let timeInfo = "";
|
||||
if (item.start !== undefined && item.end !== undefined) {
|
||||
timeInfo = ` ${item.start} - ${item.end}`;
|
||||
}
|
||||
|
||||
let speakerLabel = "";
|
||||
if (item.speaker === -2) {
|
||||
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
} else if (item.speaker == 0 && !isFinalizing) {
|
||||
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
|
||||
remaining_time_diarization
|
||||
)}</span> second(s) of audio are undergoing diarization</span></span>`;
|
||||
} else if (item.speaker !== 0) {
|
||||
speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
}
|
||||
|
||||
let currentLineText = item.text || "";
|
||||
|
||||
if (idx === lines.length - 1) {
|
||||
if (!isFinalizing && item.speaker !== -2) {
|
||||
if (remaining_time_transcription > 0) {
|
||||
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(
|
||||
remaining_time_transcription
|
||||
)}</span>s</span></span>`;
|
||||
}
|
||||
if (buffer_diarization && remaining_time_diarization > 0) {
|
||||
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(
|
||||
remaining_time_diarization
|
||||
)}</span>s</span></span>`;
|
||||
}
|
||||
}
|
||||
|
||||
if (buffer_diarization) {
|
||||
if (isFinalizing) {
|
||||
currentLineText +=
|
||||
(currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
|
||||
} else {
|
||||
currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
|
||||
}
|
||||
}
|
||||
if (buffer_transcription) {
|
||||
if (isFinalizing) {
|
||||
currentLineText +=
|
||||
(currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") +
|
||||
buffer_transcription.trim();
|
||||
} else {
|
||||
currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return currentLineText.trim().length > 0 || speakerLabel.length > 0
|
||||
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
|
||||
: `<p>${speakerLabel}<br/></p>`;
|
||||
})
|
||||
.join("");
|
||||
|
||||
linesTranscriptDiv.innerHTML = linesHtml;
|
||||
window.scrollTo({ top: document.body.scrollHeight, behavior: "smooth" });
|
||||
}
|
||||
|
||||
function updateTimer() {
|
||||
if (!startTime) return;
|
||||
|
||||
const elapsed = Math.floor((Date.now() - startTime) / 1000);
|
||||
const minutes = Math.floor(elapsed / 60).toString().padStart(2, "0");
|
||||
const seconds = (elapsed % 60).toString().padStart(2, "0");
|
||||
timerElement.textContent = `${minutes}:${seconds}`;
|
||||
}
|
||||
|
||||
function drawWaveform() {
|
||||
if (!analyser) return;
|
||||
|
||||
const bufferLength = analyser.frequencyBinCount;
|
||||
const dataArray = new Uint8Array(bufferLength);
|
||||
analyser.getByteTimeDomainData(dataArray);
|
||||
|
||||
waveCtx.clearRect(
|
||||
0,
|
||||
0,
|
||||
waveCanvas.width / (window.devicePixelRatio || 1),
|
||||
waveCanvas.height / (window.devicePixelRatio || 1)
|
||||
);
|
||||
waveCtx.lineWidth = 1;
|
||||
waveCtx.strokeStyle = waveStroke;
|
||||
waveCtx.beginPath();
|
||||
|
||||
const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength;
|
||||
let x = 0;
|
||||
|
||||
for (let i = 0; i < bufferLength; i++) {
|
||||
const v = dataArray[i] / 128.0;
|
||||
const y = (v * (waveCanvas.height / (window.devicePixelRatio || 1))) / 2;
|
||||
|
||||
if (i === 0) {
|
||||
waveCtx.moveTo(x, y);
|
||||
} else {
|
||||
waveCtx.lineTo(x, y);
|
||||
}
|
||||
|
||||
x += sliceWidth;
|
||||
}
|
||||
|
||||
waveCtx.lineTo(
|
||||
waveCanvas.width / (window.devicePixelRatio || 1),
|
||||
(waveCanvas.height / (window.devicePixelRatio || 1)) / 2
|
||||
);
|
||||
waveCtx.stroke();
|
||||
|
||||
animationFrame = requestAnimationFrame(drawWaveform);
|
||||
}
|
||||
|
||||
async function startRecording() {
|
||||
try {
|
||||
try {
|
||||
wakeLock = await navigator.wakeLock.request("screen");
|
||||
} catch (err) {
|
||||
console.log("Error acquiring wake lock.");
|
||||
}
|
||||
|
||||
let stream;
|
||||
try {
|
||||
// Try tab capture first
|
||||
stream = await new Promise((resolve, reject) => {
|
||||
chrome.tabCapture.capture({audio: true}, (s) => {
|
||||
if (s) {
|
||||
resolve(s);
|
||||
} else {
|
||||
reject(new Error('Tab capture failed or not available'));
|
||||
}
|
||||
});
|
||||
});
|
||||
statusText.textContent = "Using tab audio capture.";
|
||||
} catch (tabError) {
|
||||
console.log('Tab capture not available, falling back to microphone', tabError);
|
||||
// Fallback to microphone
|
||||
const audioConstraints = selectedMicrophoneId
|
||||
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
|
||||
: { audio: true };
|
||||
stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
|
||||
statusText.textContent = "Using microphone audio.";
|
||||
}
|
||||
|
||||
audioContext = new (window.AudioContext || window.webkitAudioContext)();
|
||||
analyser = audioContext.createAnalyser();
|
||||
analyser.fftSize = 256;
|
||||
microphone = audioContext.createMediaStreamSource(stream);
|
||||
microphone.connect(analyser);
|
||||
|
||||
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
|
||||
recorder.ondataavailable = (e) => {
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
websocket.send(e.data);
|
||||
}
|
||||
};
|
||||
recorder.start(chunkDuration);
|
||||
|
||||
startTime = Date.now();
|
||||
timerInterval = setInterval(updateTimer, 1000);
|
||||
drawWaveform();
|
||||
|
||||
isRecording = true;
|
||||
updateUI();
|
||||
} catch (err) {
|
||||
if (window.location.hostname === "0.0.0.0") {
|
||||
statusText.textContent =
|
||||
"Error accessing audio input. Browsers may block audio access on 0.0.0.0. Try using localhost:8000 instead.";
|
||||
} else {
|
||||
statusText.textContent = "Error accessing audio input. Please check permissions.";
|
||||
}
|
||||
console.error(err);
|
||||
}
|
||||
}
|
||||
|
||||
async function stopRecording() {
|
||||
if (wakeLock) {
|
||||
try {
|
||||
await wakeLock.release();
|
||||
} catch (e) {
|
||||
// ignore
|
||||
}
|
||||
wakeLock = null;
|
||||
}
|
||||
|
||||
userClosing = true;
|
||||
waitingForStop = true;
|
||||
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
const emptyBlob = new Blob([], { type: "audio/webm" });
|
||||
websocket.send(emptyBlob);
|
||||
statusText.textContent = "Recording stopped. Processing final audio...";
|
||||
}
|
||||
|
||||
if (recorder) {
|
||||
recorder.stop();
|
||||
recorder = null;
|
||||
}
|
||||
|
||||
if (microphone) {
|
||||
microphone.disconnect();
|
||||
microphone = null;
|
||||
}
|
||||
|
||||
if (analyser) {
|
||||
analyser = null;
|
||||
}
|
||||
|
||||
if (audioContext && audioContext.state !== "closed") {
|
||||
try {
|
||||
await audioContext.close();
|
||||
} catch (e) {
|
||||
console.warn("Could not close audio context:", e);
|
||||
}
|
||||
audioContext = null;
|
||||
}
|
||||
|
||||
if (animationFrame) {
|
||||
cancelAnimationFrame(animationFrame);
|
||||
animationFrame = null;
|
||||
}
|
||||
|
||||
if (timerInterval) {
|
||||
clearInterval(timerInterval);
|
||||
timerInterval = null;
|
||||
}
|
||||
timerElement.textContent = "00:00";
|
||||
startTime = null;
|
||||
|
||||
isRecording = false;
|
||||
updateUI();
|
||||
}
|
||||
|
||||
async function toggleRecording() {
|
||||
if (!isRecording) {
|
||||
if (waitingForStop) {
|
||||
console.log("Waiting for stop, early return");
|
||||
return;
|
||||
}
|
||||
console.log("Connecting to WebSocket");
|
||||
try {
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
await startRecording();
|
||||
} else {
|
||||
await setupWebSocket();
|
||||
await startRecording();
|
||||
}
|
||||
} catch (err) {
|
||||
statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
|
||||
console.error(err);
|
||||
}
|
||||
} else {
|
||||
console.log("Stopping recording");
|
||||
stopRecording();
|
||||
}
|
||||
}
|
||||
|
||||
function updateUI() {
|
||||
recordButton.classList.toggle("recording", isRecording);
|
||||
recordButton.disabled = waitingForStop;
|
||||
|
||||
if (waitingForStop) {
|
||||
if (statusText.textContent !== "Recording stopped. Processing final audio...") {
|
||||
statusText.textContent = "Please wait for processing to complete...";
|
||||
}
|
||||
} else if (isRecording) {
|
||||
statusText.textContent = "Recording...";
|
||||
} else {
|
||||
if (
|
||||
statusText.textContent !== "Finished processing audio! Ready to record again." &&
|
||||
statusText.textContent !== "Processing finalized or connection closed."
|
||||
) {
|
||||
statusText.textContent = "Click to start transcription";
|
||||
}
|
||||
}
|
||||
if (!waitingForStop) {
|
||||
recordButton.disabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
recordButton.addEventListener("click", toggleRecording);
|
||||
|
||||
if (microphoneSelect) {
|
||||
microphoneSelect.addEventListener("change", handleMicrophoneChange);
|
||||
}
|
||||
|
||||
// Settings toggle functionality
|
||||
settingsToggle.addEventListener("click", () => {
|
||||
settingsDiv.classList.toggle("visible");
|
||||
settingsToggle.classList.toggle("active");
|
||||
});
|
||||
|
||||
document.addEventListener('DOMContentLoaded', async () => {
|
||||
try {
|
||||
await enumerateMicrophones();
|
||||
} catch (error) {
|
||||
console.log("Could not enumerate microphones on load:", error);
|
||||
}
|
||||
});
|
||||
navigator.mediaDevices.addEventListener('devicechange', async () => {
|
||||
console.log('Device change detected, re-enumerating microphones');
|
||||
try {
|
||||
await enumerateMicrophones();
|
||||
} catch (error) {
|
||||
console.log("Error re-enumerating microphones:", error);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
async function run() {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
|
||||
document.getElementById(
|
||||
"audioPermission"
|
||||
).innerText = `MICROPHONE: ${micPermission.state}`;
|
||||
|
||||
if (micPermission.state !== "granted") {
|
||||
chrome.tabs.create({ url: "welcome.html" });
|
||||
}
|
||||
|
||||
const intervalId = setInterval(async () => {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
if (micPermission.state === "granted") {
|
||||
document.getElementById(
|
||||
"audioPermission"
|
||||
).innerText = `MICROPHONE: ${micPermission.state}`;
|
||||
clearInterval(intervalId);
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
|
||||
void run();
|
||||
@@ -3,9 +3,6 @@
|
||||
"name": "WhisperLiveKit Tab Capture",
|
||||
"version": "1.0",
|
||||
"description": "Capture and transcribe audio from browser tabs using WhisperLiveKit.",
|
||||
"background": {
|
||||
"service_worker": "background.js"
|
||||
},
|
||||
"icons": {
|
||||
"16": "icons/icon16.png",
|
||||
"32": "icons/icon32.png",
|
||||
@@ -14,7 +11,7 @@
|
||||
},
|
||||
"action": {
|
||||
"default_title": "WhisperLiveKit Tab Capture",
|
||||
"default_popup": "popup.html"
|
||||
"default_popup": "live_transcription.html"
|
||||
},
|
||||
"permissions": [
|
||||
"scripting",
|
||||
@@ -22,16 +19,5 @@
|
||||
"offscreen",
|
||||
"activeTab",
|
||||
"storage"
|
||||
],
|
||||
"web_accessible_resources": [
|
||||
{
|
||||
"resources": [
|
||||
"requestPermissions.html",
|
||||
"requestPermissions.js"
|
||||
],
|
||||
"matches": [
|
||||
"<all_urls>"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>WhisperLiveKit</title>
|
||||
<link rel="stylesheet" href="/web/live_transcription.css" />
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div class="settings-container">
|
||||
<button id="recordButton">
|
||||
<div class="shape-container">
|
||||
<div class="shape"></div>
|
||||
</div>
|
||||
<div class="recording-info">
|
||||
<div class="wave-container">
|
||||
<canvas id="waveCanvas"></canvas>
|
||||
</div>
|
||||
<div class="timer">00:00</div>
|
||||
</div>
|
||||
</button>
|
||||
|
||||
<button id="settingsToggle" class="settings-toggle" title="Show/hide settings">
|
||||
<img src="/web/src/settings.svg" alt="Settings" />
|
||||
</button>
|
||||
|
||||
<div class="settings">
|
||||
<div class="field">
|
||||
<label for="websocketInput">Websocket URL</label>
|
||||
<input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
|
||||
</div>
|
||||
|
||||
<div class="field">
|
||||
<label id="microphoneSelectLabel" for="microphoneSelect">Select Microphone</label>
|
||||
<select id="microphoneSelect">
|
||||
<option value="">Default Microphone</option>
|
||||
</select>
|
||||
<div id="audioPermission"></div>
|
||||
|
||||
</div>
|
||||
|
||||
<div class="theme-selector-container">
|
||||
<div class="segmented" role="radiogroup" aria-label="Theme selector">
|
||||
<input type="radio" id="theme-system" name="theme" value="system" />
|
||||
<label for="theme-system" title="System">
|
||||
<img src="/web/src/system_mode.svg" alt="" />
|
||||
<!-- <span>System</span> -->
|
||||
</label>
|
||||
|
||||
<input type="radio" id="theme-light" name="theme" value="light" />
|
||||
<label for="theme-light" title="Light">
|
||||
<img src="/web/src/light_mode.svg" alt="" />
|
||||
<!-- <span>Light</span> -->
|
||||
</label>
|
||||
|
||||
<input type="radio" id="theme-dark" name="theme" value="dark" />
|
||||
<label for="theme-dark" title="Dark">
|
||||
<img src="/web/src/dark_mode.svg" alt="" />
|
||||
<!-- <span>Dark</span> -->
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
<p id="status"></p>
|
||||
|
||||
<div id="linesTranscript"></div>
|
||||
|
||||
<script src="live_transcription.js"></script>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
@@ -1,539 +0,0 @@
|
||||
:root {
|
||||
--bg: #ffffff;
|
||||
--text: #111111;
|
||||
--muted: #666666;
|
||||
--border: #e5e5e5;
|
||||
--chip-bg: rgba(0, 0, 0, 0.04);
|
||||
--chip-text: #000000;
|
||||
--spinner-border: #8d8d8d5c;
|
||||
--spinner-top: #b0b0b0;
|
||||
--silence-bg: #f3f3f3;
|
||||
--loading-bg: rgba(255, 77, 77, 0.06);
|
||||
--button-bg: #ffffff;
|
||||
--button-border: #e9e9e9;
|
||||
--wave-stroke: #000000;
|
||||
--label-dia-text: #868686;
|
||||
--label-trans-text: #111111;
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: dark) {
|
||||
:root:not([data-theme="light"]) {
|
||||
--bg: #0b0b0b;
|
||||
--text: #e6e6e6;
|
||||
--muted: #9aa0a6;
|
||||
--border: #333333;
|
||||
--chip-bg: rgba(255, 255, 255, 0.08);
|
||||
--chip-text: #e6e6e6;
|
||||
--spinner-border: #555555;
|
||||
--spinner-top: #dddddd;
|
||||
--silence-bg: #1a1a1a;
|
||||
--loading-bg: rgba(255, 77, 77, 0.12);
|
||||
--button-bg: #111111;
|
||||
--button-border: #333333;
|
||||
--wave-stroke: #e6e6e6;
|
||||
--label-dia-text: #b3b3b3;
|
||||
--label-trans-text: #ffffff;
|
||||
}
|
||||
}
|
||||
|
||||
:root[data-theme="dark"] {
|
||||
--bg: #0b0b0b;
|
||||
--text: #e6e6e6;
|
||||
--muted: #9aa0a6;
|
||||
--border: #333333;
|
||||
--chip-bg: rgba(255, 255, 255, 0.08);
|
||||
--chip-text: #e6e6e6;
|
||||
--spinner-border: #555555;
|
||||
--spinner-top: #dddddd;
|
||||
--silence-bg: #1a1a1a;
|
||||
--loading-bg: rgba(255, 77, 77, 0.12);
|
||||
--button-bg: #111111;
|
||||
--button-border: #333333;
|
||||
--wave-stroke: #e6e6e6;
|
||||
--label-dia-text: #b3b3b3;
|
||||
--label-trans-text: #ffffff;
|
||||
}
|
||||
|
||||
:root[data-theme="light"] {
|
||||
--bg: #ffffff;
|
||||
--text: #111111;
|
||||
--muted: #666666;
|
||||
--border: #e5e5e5;
|
||||
--chip-bg: rgba(0, 0, 0, 0.04);
|
||||
--chip-text: #000000;
|
||||
--spinner-border: #8d8d8d5c;
|
||||
--spinner-top: #b0b0b0;
|
||||
--silence-bg: #f3f3f3;
|
||||
--loading-bg: rgba(255, 77, 77, 0.06);
|
||||
--button-bg: #ffffff;
|
||||
--button-border: #e9e9e9;
|
||||
--wave-stroke: #000000;
|
||||
--label-dia-text: #868686;
|
||||
--label-trans-text: #111111;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
|
||||
margin: 20px;
|
||||
text-align: center;
|
||||
background-color: var(--bg);
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.settings-toggle {
|
||||
margin-top: 4px;
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
border: none;
|
||||
border-radius: 50%;
|
||||
background-color: var(--button-bg);
|
||||
cursor: pointer;
|
||||
transition: all 0.3s ease;
|
||||
/* border: 1px solid var(--button-border); */
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.settings-toggle:hover {
|
||||
background-color: var(--chip-bg);
|
||||
}
|
||||
|
||||
.settings-toggle img {
|
||||
width: 24px;
|
||||
height: 24px;
|
||||
opacity: 0.7;
|
||||
transition: opacity 0.2s ease, transform 0.3s ease;
|
||||
}
|
||||
|
||||
.settings-toggle:hover img {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.settings-toggle.active img {
|
||||
transform: rotate(80deg);
|
||||
}
|
||||
|
||||
/* Record button */
|
||||
#recordButton {
|
||||
width: 50px;
|
||||
height: 50px;
|
||||
border: none;
|
||||
border-radius: 50%;
|
||||
background-color: var(--button-bg);
|
||||
cursor: pointer;
|
||||
transition: all 0.3s ease;
|
||||
border: 1px solid var(--button-border);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
#recordButton.recording {
|
||||
width: 180px;
|
||||
border-radius: 40px;
|
||||
justify-content: flex-start;
|
||||
padding-left: 20px;
|
||||
}
|
||||
|
||||
#recordButton:active {
|
||||
transform: scale(0.95);
|
||||
}
|
||||
|
||||
.shape-container {
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.shape {
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
background-color: rgb(209, 61, 53);
|
||||
border-radius: 50%;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
#recordButton:disabled .shape {
|
||||
background-color: #6e6d6d;
|
||||
}
|
||||
|
||||
#recordButton.recording .shape {
|
||||
border-radius: 5px;
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
}
|
||||
|
||||
/* Recording elements */
|
||||
.recording-info {
|
||||
display: none;
|
||||
align-items: center;
|
||||
margin-left: 15px;
|
||||
flex-grow: 1;
|
||||
}
|
||||
|
||||
#recordButton.recording .recording-info {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.wave-container {
|
||||
width: 60px;
|
||||
height: 30px;
|
||||
position: relative;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
#waveCanvas {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
.timer {
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
color: var(--text);
|
||||
margin-left: 10px;
|
||||
}
|
||||
|
||||
#status {
|
||||
margin-top: 20px;
|
||||
font-size: 16px;
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
/* Settings */
|
||||
.settings-container {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: flex-start;
|
||||
gap: 15px;
|
||||
margin-top: 20px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.settings {
|
||||
display: none;
|
||||
flex-wrap: wrap;
|
||||
align-items: flex-start;
|
||||
gap: 12px;
|
||||
transition: opacity 0.3s ease;
|
||||
}
|
||||
|
||||
.settings.visible {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.field {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: flex-start;
|
||||
gap: 3px;
|
||||
}
|
||||
|
||||
#chunkSelector,
|
||||
#websocketInput,
|
||||
#themeSelector,
|
||||
#microphoneSelect {
|
||||
font-size: 16px;
|
||||
padding: 5px 8px;
|
||||
border-radius: 8px;
|
||||
border: 1px solid var(--border);
|
||||
background-color: var(--button-bg);
|
||||
color: var(--text);
|
||||
max-height: 30px;
|
||||
}
|
||||
|
||||
#microphoneSelect {
|
||||
width: 100%;
|
||||
max-width: 190px;
|
||||
min-width: 120px;
|
||||
}
|
||||
|
||||
#chunkSelector:focus,
|
||||
#websocketInput:focus,
|
||||
#themeSelector:focus,
|
||||
#microphoneSelect:focus {
|
||||
outline: none;
|
||||
border-color: #007bff;
|
||||
box-shadow: 0 0 0 3px rgba(0, 123, 255, 0.15);
|
||||
}
|
||||
|
||||
label {
|
||||
font-size: 13px;
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
.ws-default {
|
||||
font-size: 12px;
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
/* Segmented pill control for Theme */
|
||||
.segmented {
|
||||
display: inline-flex;
|
||||
align-items: stretch;
|
||||
border: 1px solid var(--button-border);
|
||||
background-color: var(--button-bg);
|
||||
border-radius: 999px;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.segmented input[type="radio"] {
|
||||
position: absolute;
|
||||
opacity: 0;
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.theme-selector-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-top: 17px;
|
||||
}
|
||||
|
||||
.segmented label {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 6px 12px;
|
||||
font-size: 14px;
|
||||
color: var(--muted);
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
transition: background-color 0.2s ease, color 0.2s ease;
|
||||
}
|
||||
|
||||
.segmented label span {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.segmented label:hover span {
|
||||
display: inline;
|
||||
}
|
||||
|
||||
.segmented label:hover {
|
||||
background-color: var(--chip-bg);
|
||||
}
|
||||
|
||||
.segmented img {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
}
|
||||
|
||||
.segmented input[type="radio"]:checked + label {
|
||||
background-color: var(--chip-bg);
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.segmented input[type="radio"]:focus-visible + label,
|
||||
.segmented input[type="radio"]:focus + label {
|
||||
outline: 2px solid #007bff;
|
||||
outline-offset: 2px;
|
||||
border-radius: 999px;
|
||||
}
|
||||
|
||||
/* Transcript area */
|
||||
#linesTranscript {
|
||||
margin: 20px auto;
|
||||
max-width: 700px;
|
||||
text-align: left;
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
#linesTranscript p {
|
||||
margin: 0px 0;
|
||||
}
|
||||
|
||||
#linesTranscript strong {
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
#speaker {
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 100px;
|
||||
padding: 2px 10px;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
}
|
||||
|
||||
.label_diarization {
|
||||
background-color: var(--chip-bg);
|
||||
border-radius: 8px 8px 8px 8px;
|
||||
padding: 2px 10px;
|
||||
margin-left: 10px;
|
||||
display: inline-block;
|
||||
white-space: nowrap;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
color: var(--label-dia-text);
|
||||
}
|
||||
|
||||
.label_transcription {
|
||||
background-color: var(--chip-bg);
|
||||
border-radius: 8px 8px 8px 8px;
|
||||
padding: 2px 10px;
|
||||
display: inline-block;
|
||||
white-space: nowrap;
|
||||
margin-left: 10px;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
color: var(--label-trans-text);
|
||||
}
|
||||
|
||||
#timeInfo {
|
||||
color: var(--muted);
|
||||
margin-left: 10px;
|
||||
}
|
||||
|
||||
.textcontent {
|
||||
font-size: 16px;
|
||||
padding-left: 10px;
|
||||
margin-bottom: 10px;
|
||||
margin-top: 1px;
|
||||
padding-top: 5px;
|
||||
border-radius: 0px 0px 0px 10px;
|
||||
}
|
||||
|
||||
.buffer_diarization {
|
||||
color: var(--label-dia-text);
|
||||
margin-left: 4px;
|
||||
}
|
||||
|
||||
.buffer_transcription {
|
||||
color: #7474748c;
|
||||
margin-left: 4px;
|
||||
}
|
||||
|
||||
.spinner {
|
||||
display: inline-block;
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border: 2px solid var(--spinner-border);
|
||||
border-top: 2px solid var(--spinner-top);
|
||||
border-radius: 50%;
|
||||
animation: spin 0.7s linear infinite;
|
||||
vertical-align: middle;
|
||||
margin-bottom: 2px;
|
||||
margin-right: 5px;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
to {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
|
||||
.silence {
|
||||
color: var(--muted);
|
||||
background-color: var(--silence-bg);
|
||||
font-size: 13px;
|
||||
border-radius: 30px;
|
||||
padding: 2px 10px;
|
||||
}
|
||||
|
||||
.loading {
|
||||
color: var(--muted);
|
||||
background-color: var(--loading-bg);
|
||||
border-radius: 8px 8px 8px 0px;
|
||||
padding: 2px 10px;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
}
|
||||
|
||||
/* for smaller screens */
|
||||
/* @media (max-width: 450px) {
|
||||
.settings-container {
|
||||
flex-direction: column;
|
||||
gap: 10px;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.settings {
|
||||
justify-content: center;
|
||||
gap: 8px;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.field {
|
||||
align-items: center;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
#websocketInput,
|
||||
#microphoneSelect {
|
||||
min-width: 200px;
|
||||
max-width: 100%;
|
||||
}
|
||||
|
||||
.theme-selector-container {
|
||||
margin-top: 10px;
|
||||
}
|
||||
} */
|
||||
|
||||
/* @media (max-width: 768px) and (min-width: 451px) {
|
||||
.settings-container {
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.settings {
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
#websocketInput,
|
||||
#microphoneSelect {
|
||||
min-width: 150px;
|
||||
max-width: 300px;
|
||||
}
|
||||
} */
|
||||
|
||||
/* @media (max-width: 480px) {
|
||||
body {
|
||||
margin: 10px;
|
||||
}
|
||||
|
||||
.settings-toggle {
|
||||
width: 35px;
|
||||
height: 35px;
|
||||
}
|
||||
|
||||
.settings-toggle img {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
}
|
||||
|
||||
.settings {
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
}
|
||||
|
||||
#websocketInput,
|
||||
#microphoneSelect {
|
||||
max-width: 400px;
|
||||
}
|
||||
|
||||
.segmented label {
|
||||
padding: 4px 8px;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
.segmented img {
|
||||
width: 14px;
|
||||
height: 14px;
|
||||
}
|
||||
} */
|
||||
|
||||
|
||||
html
|
||||
{
|
||||
width: 400px; /* max: 800px */
|
||||
height: 600px; /* max: 600px */
|
||||
border-radius: 10px;
|
||||
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-120q-151 0-255.5-104.5T120-480q0-138 90-239.5T440-838q13-2 23 3.5t16 14.5q6 9 6.5 21t-7.5 23q-17 26-25.5 55t-8.5 61q0 90 63 153t153 63q31 0 61.5-9t54.5-25q11-7 22.5-6.5T819-479q10 5 15.5 15t3.5 24q-14 138-117.5 229T480-120Zm0-80q88 0 158-48.5T740-375q-20 5-40 8t-40 3q-123 0-209.5-86.5T364-660q0-20 3-40t8-40q-78 32-126.5 102T200-480q0 116 82 198t198 82Zm-10-270Z"/></svg>
|
||||
|
Before Width: | Height: | Size: 493 B |
@@ -1 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-360q50 0 85-35t35-85q0-50-35-85t-85-35q-50 0-85 35t-35 85q0 50 35 85t85 35Zm0 80q-83 0-141.5-58.5T280-480q0-83 58.5-141.5T480-680q83 0 141.5 58.5T680-480q0 83-58.5 141.5T480-280ZM80-440q-17 0-28.5-11.5T40-480q0-17 11.5-28.5T80-520h80q17 0 28.5 11.5T200-480q0 17-11.5 28.5T160-440H80Zm720 0q-17 0-28.5-11.5T760-480q0-17 11.5-28.5T800-520h80q17 0 28.5 11.5T920-480q0 17-11.5 28.5T880-440h-80ZM480-760q-17 0-28.5-11.5T440-800v-80q0-17 11.5-28.5T480-920q17 0 28.5 11.5T520-880v80q0 17-11.5 28.5T480-760Zm0 720q-17 0-28.5-11.5T440-80v-80q0-17 11.5-28.5T480-200q17 0 28.5 11.5T520-160v80q0 17-11.5 28.5T480-40ZM226-678l-43-42q-12-11-11.5-28t11.5-29q12-12 29-12t28 12l42 43q11 12 11 28t-11 28q-11 12-27.5 11.5T226-678Zm494 495-42-43q-11-12-11-28.5t11-27.5q11-12 27.5-11.5T734-282l43 42q12 11 11.5 28T777-183q-12 12-29 12t-28-12Zm-42-495q-12-11-11.5-27.5T678-734l42-43q11-12 28-11.5t29 11.5q12 12 12 29t-12 28l-43 42q-12 11-28 11t-28-11ZM183-183q-12-12-12-29t12-28l43-42q12-11 28.5-11t27.5 11q12 11 11.5 27.5T282-226l-42 43q-11 12-28 11.5T183-183Zm297-297Z"/></svg>
|
||||
|
Before Width: | Height: | Size: 1.2 KiB |
@@ -1 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M433-80q-27 0-46.5-18T363-142l-9-66q-13-5-24.5-12T307-235l-62 26q-25 11-50 2t-39-32l-47-82q-14-23-8-49t27-43l53-40q-1-7-1-13.5v-27q0-6.5 1-13.5l-53-40q-21-17-27-43t8-49l47-82q14-23 39-32t50 2l62 26q11-8 23-15t24-12l9-66q4-26 23.5-44t46.5-18h94q27 0 46.5 18t23.5 44l9 66q13 5 24.5 12t22.5 15l62-26q25-11 50-2t39 32l47 82q14 23 8 49t-27 43l-53 40q1 7 1 13.5v27q0 6.5-2 13.5l53 40q21 17 27 43t-8 49l-48 82q-14 23-39 32t-50-2l-60-26q-11 8-23 15t-24 12l-9 66q-4 26-23.5 44T527-80h-94Zm7-80h79l14-106q31-8 57.5-23.5T639-327l99 41 39-68-86-65q5-14 7-29.5t2-31.5q0-16-2-31.5t-7-29.5l86-65-39-68-99 42q-22-23-48.5-38.5T533-694l-13-106h-79l-14 106q-31 8-57.5 23.5T321-633l-99-41-39 68 86 64q-5 15-7 30t-2 32q0 16 2 31t7 30l-86 65 39 68 99-42q22 23 48.5 38.5T427-266l13 106Zm42-180q58 0 99-41t41-99q0-58-41-99t-99-41q-59 0-99.5 41T342-480q0 58 40.5 99t99.5 41Zm-2-140Z"/></svg>
|
||||
|
Before Width: | Height: | Size: 982 B |
@@ -1 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M396-396q-32-32-58.5-67T289-537q-5 14-6.5 28.5T281-480q0 83 58 141t141 58q14 0 28.5-2t28.5-6q-39-22-74-48.5T396-396Zm85 196q-56 0-107-21t-91-61q-40-40-61-91t-21-107q0-51 17-97.5t50-84.5q13-14 32-9.5t27 24.5q21 55 52.5 104t73.5 91q42 42 91 73.5T648-326q20 8 24.5 27t-9.5 32q-38 33-84.5 50T481-200Zm223-192q-16-5-23-20.5t-4-32.5q9-48-6-94.5T621-621q-35-35-80.5-49.5T448-677q-17 3-32-4t-21-23q-6-16 1.5-31t23.5-19q69-15 138 4.5T679-678q51 51 71 120t5 138q-4 17-19 25t-32 3ZM480-840q-17 0-28.5-11.5T440-880v-40q0-17 11.5-28.5T480-960q17 0 28.5 11.5T520-920v40q0 17-11.5 28.5T480-840Zm0 840q-17 0-28.5-11.5T440-40v-40q0-17 11.5-28.5T480-120q17 0 28.5 11.5T520-80v40q0 17-11.5 28.5T480 0Zm255-734q-12-12-12-28.5t12-28.5l28-28q11-11 27.5-11t28.5 11q12 12 12 28.5T819-762l-28 28q-12 12-28 12t-28-12ZM141-141q-12-12-12-28.5t12-28.5l28-28q12-12 28-12t28 12q12 12 12 28.5T225-169l-28 28q-11 11-27.5 11T141-141Zm739-299q-17 0-28.5-11.5T840-480q0-17 11.5-28.5T880-520h40q17 0 28.5 11.5T960-480q0 17-11.5 28.5T920-440h-40Zm-840 0q-17 0-28.5-11.5T0-480q0-17 11.5-28.5T40-520h40q17 0 28.5 11.5T120-480q0 17-11.5 28.5T80-440H40Zm779 299q-12 12-28.5 12T762-141l-28-28q-12-12-12-28t12-28q12-12 28.5-12t28.5 12l28 28q11 11 11 27.5T819-141ZM226-735q-12 12-28.5 12T169-735l-28-28q-11-11-11-27.5t11-28.5q12-12 28.5-12t28.5 12l28 28q12 12 12 28t-12 28Zm170 339Z"/></svg>
|
||||
|
Before Width: | Height: | Size: 1.4 KiB |
@@ -1,12 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Welcome</title>
|
||||
<script src="welcome.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
This page exists to workaround an issue with Chrome that blocks permission
|
||||
requests from chrome extensions
|
||||
<!-- <button id="requestMicrophone">Request Microphone</button> -->
|
||||
</body>
|
||||
</html>
|
||||
BIN
demo.png
|
Before Width: | Height: | Size: 449 KiB After Width: | Height: | Size: 985 KiB |
264
docs/API.md
Normal file
@@ -0,0 +1,264 @@
|
||||
# WhisperLiveKit WebSocket API Documentation
|
||||
|
||||
> !! **Note**: The new API structure described in this document is currently under deployment.
|
||||
This documentation is intended for devs who want to build custom frontends.
|
||||
|
||||
WLK provides real-time speech transcription, speaker diarization, and translation through a WebSocket API. The server sends incremental updates as audio is processed, allowing clients to display live transcription results with minimal latency.
|
||||
|
||||
---
|
||||
|
||||
## Legacy API (Current)
|
||||
|
||||
### Message Structure
|
||||
|
||||
The current API sends complete state snapshots on each update (several time per second)
|
||||
|
||||
```typescript
|
||||
{
|
||||
"type": str,
|
||||
"status": str,
|
||||
"lines": [
|
||||
{
|
||||
"speaker": int,
|
||||
"text": str,
|
||||
"start": float,
|
||||
"end": float,
|
||||
"translation": str | null,
|
||||
"detected_language": str
|
||||
}
|
||||
],
|
||||
"buffer_transcription": str,
|
||||
"buffer_diarization": str,
|
||||
"remaining_time_transcription": float,
|
||||
"remaining_time_diarization": float
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## New API (Under Development)
|
||||
|
||||
### Philosophy
|
||||
|
||||
Principles:
|
||||
|
||||
- **Incremental Updates**: Only updates and new segments are sent
|
||||
- **Ephemeral Buffers**: Temporary, unvalidated data displayed in real-time but overwritten on next update, at speaker level
|
||||
|
||||
|
||||
## Message Format
|
||||
|
||||
|
||||
```typescript
|
||||
{
|
||||
"type": "transcript_update",
|
||||
"status": "active_transcription" | "no_audio_detected",
|
||||
"segments": [
|
||||
{
|
||||
"id": number,
|
||||
"speaker": number,
|
||||
"text": string,
|
||||
"start_speaker": float,
|
||||
"start": float,
|
||||
"end": float,
|
||||
"language": string | null,
|
||||
"translation": string,
|
||||
"words": [
|
||||
{
|
||||
"text": string,
|
||||
"start": float,
|
||||
"end": float,
|
||||
"validated": {
|
||||
"text": boolean,
|
||||
"speaker": boolean,
|
||||
}
|
||||
}
|
||||
],
|
||||
"buffer": {
|
||||
"transcription": string,
|
||||
"diarization": string,
|
||||
"translation": string
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"remaining_time_transcription": float,
|
||||
"remaining_time_diarization": float
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Other Message Types
|
||||
|
||||
#### Config Message (sent on connection)
|
||||
```json
|
||||
{
|
||||
"type": "config",
|
||||
"useAudioWorklet": true / false
|
||||
}
|
||||
```
|
||||
|
||||
#### Ready to Stop Message (sent after processing complete)
|
||||
```json
|
||||
{
|
||||
"type": "ready_to_stop"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Field Descriptions
|
||||
|
||||
### Segment Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `id` | `number` | Unique identifier for this segment. Used by clients to update specific segments efficiently. |
|
||||
| `speaker` | `number` | Speaker ID (1, 2, 3...). Special value `-2` indicates silence. |
|
||||
| `text` | `string` | Validated transcription text for this update. Should be **appended** to the segment's text on the client side. |
|
||||
| `start_speaker` | `float` | Timestamp (seconds) when this speaker segment began. |
|
||||
| `start` | `float` | Timestamp (seconds) of the first word in this update. |
|
||||
| `end` | `float` | Timestamp (seconds) of the last word in this update. |
|
||||
| `language` | `string \| null` | ISO language code (e.g., "en", "fr"). `null` until language is detected. |
|
||||
| `translation` | `string` | Validated translation text for this update. Should be **appended** to the segment's translation on the client side. |
|
||||
| `words` | `Array` | Array of word-level objects with timing and validation information. |
|
||||
| `buffer` | `Object` | Per-segment temporary buffers, see below |
|
||||
|
||||
### Word Object
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `text` | `string` | The word text. |
|
||||
| `start` | `number` | Start timestamp (seconds) of this word. |
|
||||
| `end` | `number` | End timestamp (seconds) of this word. |
|
||||
| `validated.text` | `boolean` | Whether the transcription text has been validated. if false, word is also in buffer: transcription |
|
||||
| `validated.speaker` | `boolean` | Whether the speaker assignment has been validated. if false, word is also in buffer: diarization |
|
||||
| `validated.language` | `boolean` | Whether the language detection has been validated. if false, word is also in buffer: translation |
|
||||
|
||||
### Buffer Object (Per-Segment)
|
||||
|
||||
Buffers are **ephemeral**. They should be displayed to the user but not stored permanently in the frontend. Each update may contain a completely different buffer value, and previous buffer is likely to be in the next validated text.
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `transcription` | `string` | Pending transcription text. Displayed immediately but **overwritten** on next update. |
|
||||
| `diarization` | `string` | Pending diarization text (text waiting for speaker assignment). Displayed immediately but **overwritten** on next update. |
|
||||
| `translation` | `string` | Pending translation text. Displayed immediately but **overwritten** on next update. |
|
||||
|
||||
|
||||
### Metadata Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `remaining_time_transcription` | `float` | Seconds of audio waiting for transcription processing. |
|
||||
| `remaining_time_diarization` | `float` | Seconds of audio waiting for speaker diarization. |
|
||||
|
||||
### Status Values
|
||||
|
||||
| Status | Description |
|
||||
|--------|-------------|
|
||||
| `active_transcription` | Normal operation, transcription is active. |
|
||||
| `no_audio_detected` | No audio has been detected yet. |
|
||||
|
||||
---
|
||||
|
||||
## Update Behavior
|
||||
|
||||
### Incremental Updates
|
||||
|
||||
The API sends **only changed or new segments**. Clients should:
|
||||
|
||||
1. Maintain a local map of segments by ID
|
||||
2. When receiving an update, merge/update segments by ID
|
||||
3. Render only the changed segments
|
||||
|
||||
### Language Detection
|
||||
|
||||
When language is detected for a segment:
|
||||
|
||||
```jsonc
|
||||
// Update 1: No language yet
|
||||
{
|
||||
"segments": [
|
||||
{"id": 1, "speaker": 1, "text": "May see", "language": null}
|
||||
]
|
||||
}
|
||||
|
||||
// Update 2: Same segment ID, language now detected
|
||||
{
|
||||
"segments": [
|
||||
{"id": 1, "speaker": 1, "text": "Merci", "language": "fr"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Client behavior**: **Replace** the existing segment with the same ID.
|
||||
|
||||
### Buffer Behavior
|
||||
|
||||
Buffers are **per-segment** to handle multi-speaker scenarios correctly.
|
||||
|
||||
#### Example: Translation with diarization and translation
|
||||
|
||||
```jsonc
|
||||
// Update 1
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"id": 1,
|
||||
"speaker": 1,
|
||||
"text": "Hello world, how are",
|
||||
"translation": "",
|
||||
"buffer": {
|
||||
"transcription": "",
|
||||
"diarization": " you on",
|
||||
"translation": "Bonjour le monde"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
// ==== Frontend ====
|
||||
// <SPEAKER>1</SPEAKER>
|
||||
// <TRANSCRIPTION>Hello world, how are <DIARIZATION BUFFER> you on</DIARIZATION BUFFER></TRANSCRIPTION>
|
||||
// <TRANSLATION><TRANSLATION BUFFER>Bonjour le monde</TRANSLATION BUFFER></TRANSLATION>
|
||||
|
||||
|
||||
// Update 2
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"id": 1,
|
||||
"speaker": 1,
|
||||
"text": " you on this",
|
||||
"translation": "Bonjour tout le monde",
|
||||
"buffer": {
|
||||
"transcription": "",
|
||||
"diarization": " beautiful day",
|
||||
"translation": ",comment"
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
// ==== Frontend ====
|
||||
// <SPEAKER>1</SPEAKER>
|
||||
// <TRANSCRIPTION>Hello world, how are you on this<DIARIZATION BUFFER> beautiful day</DIARIZATION BUFFER></TRANSCRIPTION>
|
||||
// <TRANSLATION>Bonjour tout le monde<TRANSLATION BUFFER>, comment</TRANSLATION BUFFER><TRANSLATION>
|
||||
```
|
||||
|
||||
### Silence Segments
|
||||
|
||||
Silence is represented with the speaker id = `-2`:
|
||||
|
||||
```jsonc
|
||||
{
|
||||
"id": 5,
|
||||
"speaker": -2,
|
||||
"text": "",
|
||||
"start": 10.5,
|
||||
"end": 12.3
|
||||
}
|
||||
```
|
||||
71
docs/alignement_principles.md
Normal file
@@ -0,0 +1,71 @@
|
||||
### Alignment between STT Tokens and Diarization Segments
|
||||
|
||||
- Example 1: The punctuation from STT and the speaker change from Diariation come in the prediction `t`
|
||||
- Example 2: The punctuation from STT comes from prediction `t`, but the speaker change from Diariation come in the prediction `t-1`
|
||||
- Example 3: The punctuation from STT comes from prediction `t-1`, but the speaker change from Diariation come in the prediction `t`
|
||||
|
||||
> `#` Is the split between the `t-1` prediction and `t` prediction.
|
||||
|
||||
|
||||
## Example 1:
|
||||
```text
|
||||
punctuations_segments : __#_______.__________________!____
|
||||
diarization_segments:
|
||||
SPK1 __#____________
|
||||
SPK2 # ___________________
|
||||
-->
|
||||
ALIGNED SPK1 __#_______.
|
||||
ALIGNED SPK2 # __________________!____
|
||||
|
||||
t-1 output:
|
||||
SPK1: __#
|
||||
SPK2: NO
|
||||
DIARIZATION BUFFER: NO
|
||||
|
||||
t output:
|
||||
SPK1: __#__.
|
||||
SPK2: __________________!____
|
||||
DIARIZATION BUFFER: No
|
||||
```
|
||||
|
||||
## Example 2:
|
||||
```text
|
||||
punctuations_segments : _____#__.___________
|
||||
diarization_segments:
|
||||
SPK1 ___ #
|
||||
SPK2 __#______________
|
||||
-->
|
||||
ALIGNED SPK1 _____#__.
|
||||
ALIGNED SPK2 # ___________
|
||||
|
||||
t-1 output:
|
||||
SPK1: ___ #
|
||||
SPK2:
|
||||
DIARIZATION BUFFER: __#
|
||||
|
||||
t output:
|
||||
SPK1: __#__.
|
||||
SPK2: ___________
|
||||
DIARIZATION BUFFER: No
|
||||
```
|
||||
|
||||
## Example 3:
|
||||
```text
|
||||
punctuations_segments : ___.__#__________
|
||||
diarization_segments:
|
||||
SPK1 ______#__
|
||||
SPK2 # ________
|
||||
-->
|
||||
ALIGNED SPK1 ___. #
|
||||
ALIGNED SPK2 __#__________
|
||||
|
||||
t-1 output:
|
||||
SPK1: ___. #
|
||||
SPK2:
|
||||
DIARIZATION BUFFER: __#
|
||||
|
||||
t output:
|
||||
SPK1: #
|
||||
SPK2: __#___________
|
||||
DIARIZATION BUFFER: NO
|
||||
```
|
||||
@@ -1,4 +1,4 @@
|
||||
# Available model sizes:
|
||||
# Available Whisper model sizes:
|
||||
|
||||
- tiny.en (english only)
|
||||
- tiny
|
||||
@@ -71,3 +71,39 @@
|
||||
3. Good hardware and want best quality? → `large-v3`
|
||||
4. Need fast, high-quality transcription without translation? → `large-v3-turbo`
|
||||
5. Need translation capabilities? → `large-v2` or `large-v3` (avoid turbo)
|
||||
|
||||
|
||||
_______________________
|
||||
|
||||
# Translation Models and Backend
|
||||
|
||||
**Language Support**: ~200 languages
|
||||
|
||||
## Distilled Model Sizes Available
|
||||
|
||||
| Model | Size | Parameters | VRAM (FP16) | VRAM (INT8) | Quality |
|
||||
|-------|------|------------|-------------|-------------|---------|
|
||||
| 600M | 2.46 GB | 600M | ~1.5GB | ~800MB | Good, understandable |
|
||||
| 1.3B | 5.48 GB | 1.3B | ~3GB | ~1.5GB | Better accuracy, context |
|
||||
|
||||
**Quality Impact**: 1.3B has ~15-25% better BLEU scores vs 600M across language pairs.
|
||||
|
||||
## Backend Performance
|
||||
|
||||
| Backend | Speed vs Base | Memory Usage | Quality Loss |
|
||||
|---------|---------------|--------------|--------------|
|
||||
| CTranslate2 | 6-10x faster | 40-60% less | ~5% BLEU drop |
|
||||
| Transformers | Baseline | High | None |
|
||||
| Transformers + MPS (on Apple Silicon) | 2x faster | Medium | None |
|
||||
|
||||
**Metrics**:
|
||||
- CTranslate2: 50-100+ tokens/sec
|
||||
- Transformers: 10-30 tokens/sec
|
||||
- Apple Silicon with MPS: Up to 2x faster than CTranslate2
|
||||
|
||||
## Quick Decision Matrix
|
||||
|
||||
**Choose 600M**: Limited resources, close to 0 lag
|
||||
**Choose 1.3B**: Quality matters
|
||||
**Choose Transformers**: On Apple Silicon
|
||||
|
||||
19
docs/models_compatible_formats.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# Model Path Formats
|
||||
|
||||
The `--model-path` parameter accepts:
|
||||
|
||||
## File Path
|
||||
- **`.pt` / `.bin` / `.safetensor` formats** Should be openable by pytorch/safetensor.
|
||||
|
||||
## Directory Path (recommended)
|
||||
Must contain:
|
||||
- **`.pt` / `.bin` / `.safetensor` file** (required for decoder)
|
||||
|
||||
May optionally contain:
|
||||
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
|
||||
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
|
||||
|
||||
## Hugging Face Repo ID
|
||||
- Provide the repo ID (e.g. `openai/whisper-large-v3`) and WhisperLiveKit will download and cache the snapshot automatically. For gated repos, authenticate via `huggingface-cli login` first.
|
||||
|
||||
To improve speed/reduce allucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` to pass them to WLK. If not, alignement heads are set to be all the heads of the last half layer of decoder.
|
||||
265
docs/supported_languages.md
Normal file
@@ -0,0 +1,265 @@
|
||||
# Supported Languages
|
||||
|
||||
WhisperLiveKit supports translation into **201 languages** from the FLORES-200 dataset through the NLLB (No Language Left Behind) translation system.
|
||||
|
||||
## How to Specify Languages
|
||||
|
||||
You can specify languages in **three different ways**:
|
||||
|
||||
1. **Language Name** (case-insensitive): `"English"`, `"French"`, `"Spanish"`
|
||||
2. **ISO Language Code**: `"en"`, `"fr"`, `"es"`
|
||||
3. **NLLB Code** (FLORES-200): `"eng_Latn"`, `"fra_Latn"`, `"spa_Latn"`
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Command Line
|
||||
```bash
|
||||
# Using language name
|
||||
whisperlivekit-server --target-language "French"
|
||||
|
||||
# Using ISO code
|
||||
whisperlivekit-server --target-language fr
|
||||
|
||||
# Using NLLB code
|
||||
whisperlivekit-server --target-language fra_Latn
|
||||
```
|
||||
|
||||
### Python API
|
||||
```python
|
||||
from nllw.translation import get_language_info
|
||||
|
||||
# Get language information by name
|
||||
lang_info = get_language_info("French")
|
||||
print(lang_info)
|
||||
# {'name': 'French', 'nllb': 'fra_Latn', 'language_code': 'fr'}
|
||||
|
||||
# Get language information by ISO code
|
||||
lang_info = get_language_info("fr")
|
||||
|
||||
# Get language information by NLLB code
|
||||
lang_info = get_language_info("fra_Latn")
|
||||
|
||||
# All three return the same result
|
||||
```
|
||||
|
||||
## Complete Language List
|
||||
|
||||
The following table lists all 201 supported languages with their corresponding codes:
|
||||
|
||||
| Language Name | ISO Code | NLLB Code |
|
||||
|---------------|----------|-----------|
|
||||
| Acehnese (Arabic script) | ace_Arab | ace_Arab |
|
||||
| Acehnese (Latin script) | ace_Latn | ace_Latn |
|
||||
| Mesopotamian Arabic | acm_Arab | acm_Arab |
|
||||
| Ta'izzi-Adeni Arabic | acq_Arab | acq_Arab |
|
||||
| Tunisian Arabic | aeb_Arab | aeb_Arab |
|
||||
| Afrikaans | af | afr_Latn |
|
||||
| South Levantine Arabic | ajp_Arab | ajp_Arab |
|
||||
| Akan | ak | aka_Latn |
|
||||
| Tosk Albanian | als | als_Latn |
|
||||
| Amharic | am | amh_Ethi |
|
||||
| North Levantine Arabic | apc_Arab | apc_Arab |
|
||||
| Modern Standard Arabic | ar | arb_Arab |
|
||||
| Modern Standard Arabic (Romanized) | arb_Latn | arb_Latn |
|
||||
| Najdi Arabic | ars_Arab | ars_Arab |
|
||||
| Moroccan Arabic | ary_Arab | ary_Arab |
|
||||
| Egyptian Arabic | arz_Arab | arz_Arab |
|
||||
| Assamese | as | asm_Beng |
|
||||
| Asturian | ast | ast_Latn |
|
||||
| Awadhi | awa | awa_Deva |
|
||||
| Central Aymara | ay | ayr_Latn |
|
||||
| South Azerbaijani | azb | azb_Arab |
|
||||
| North Azerbaijani | az | azj_Latn |
|
||||
| Bashkir | ba | bak_Cyrl |
|
||||
| Bambara | bm | bam_Latn |
|
||||
| Balinese | ban | ban_Latn |
|
||||
| Belarusian | be | bel_Cyrl |
|
||||
| Bemba | bem | bem_Latn |
|
||||
| Bengali | bn | ben_Beng |
|
||||
| Bhojpuri | bho | bho_Deva |
|
||||
| Banjar (Arabic script) | bjn_Arab | bjn_Arab |
|
||||
| Banjar (Latin script) | bjn_Latn | bjn_Latn |
|
||||
| Standard Tibetan | bo | bod_Tibt |
|
||||
| Bosnian | bs | bos_Latn |
|
||||
| Buginese | bug | bug_Latn |
|
||||
| Bulgarian | bg | bul_Cyrl |
|
||||
| Catalan | ca | cat_Latn |
|
||||
| Cebuano | ceb | ceb_Latn |
|
||||
| Czech | cs | ces_Latn |
|
||||
| Chokwe | cjk | cjk_Latn |
|
||||
| Central Kurdish | ckb | ckb_Arab |
|
||||
| Crimean Tatar | crh | crh_Latn |
|
||||
| Welsh | cy | cym_Latn |
|
||||
| Danish | da | dan_Latn |
|
||||
| German | de | deu_Latn |
|
||||
| Southwestern Dinka | dik | dik_Latn |
|
||||
| Dyula | dyu | dyu_Latn |
|
||||
| Dzongkha | dz | dzo_Tibt |
|
||||
| Greek | el | ell_Grek |
|
||||
| English | en | eng_Latn |
|
||||
| Esperanto | eo | epo_Latn |
|
||||
| Estonian | et | est_Latn |
|
||||
| Basque | eu | eus_Latn |
|
||||
| Ewe | ee | ewe_Latn |
|
||||
| Faroese | fo | fao_Latn |
|
||||
| Fijian | fj | fij_Latn |
|
||||
| Finnish | fi | fin_Latn |
|
||||
| Fon | fon | fon_Latn |
|
||||
| French | fr | fra_Latn |
|
||||
| Friulian | fur-IT | fur_Latn |
|
||||
| Nigerian Fulfulde | fuv | fuv_Latn |
|
||||
| West Central Oromo | om | gaz_Latn |
|
||||
| Scottish Gaelic | gd | gla_Latn |
|
||||
| Irish | ga-IE | gle_Latn |
|
||||
| Galician | gl | glg_Latn |
|
||||
| Guarani | gn | grn_Latn |
|
||||
| Gujarati | gu-IN | guj_Gujr |
|
||||
| Haitian Creole | ht | hat_Latn |
|
||||
| Hausa | ha | hau_Latn |
|
||||
| Hebrew | he | heb_Hebr |
|
||||
| Hindi | hi | hin_Deva |
|
||||
| Chhattisgarhi | hne | hne_Deva |
|
||||
| Croatian | hr | hrv_Latn |
|
||||
| Hungarian | hu | hun_Latn |
|
||||
| Armenian | hy-AM | hye_Armn |
|
||||
| Igbo | ig | ibo_Latn |
|
||||
| Ilocano | ilo | ilo_Latn |
|
||||
| Indonesian | id | ind_Latn |
|
||||
| Icelandic | is | isl_Latn |
|
||||
| Italian | it | ita_Latn |
|
||||
| Javanese | jv | jav_Latn |
|
||||
| Japanese | ja | jpn_Jpan |
|
||||
| Kabyle | kab | kab_Latn |
|
||||
| Jingpho | kac | kac_Latn |
|
||||
| Kamba | kam | kam_Latn |
|
||||
| Kannada | kn | kan_Knda |
|
||||
| Kashmiri (Arabic script) | kas_Arab | kas_Arab |
|
||||
| Kashmiri (Devanagari script) | kas_Deva | kas_Deva |
|
||||
| Georgian | ka | kat_Geor |
|
||||
| Kazakh | kk | kaz_Cyrl |
|
||||
| Kabiyè | kbp | kbp_Latn |
|
||||
| Kabuverdianu | kea | kea_Latn |
|
||||
| Halh Mongolian | mn | khk_Cyrl |
|
||||
| Khmer | km | khm_Khmr |
|
||||
| Kikuyu | ki | kik_Latn |
|
||||
| Kinyarwanda | rw | kin_Latn |
|
||||
| Kyrgyz | ky | kir_Cyrl |
|
||||
| Kimbundu | kmb | kmb_Latn |
|
||||
| Northern Kurdish | kmr | kmr_Latn |
|
||||
| Central Kanuri (Arabic script) | knc_Arab | knc_Arab |
|
||||
| Central Kanuri (Latin script) | knc_Latn | knc_Latn |
|
||||
| Kikongo | kg | kon_Latn |
|
||||
| Korean | ko | kor_Hang |
|
||||
| Lao | lo | lao_Laoo |
|
||||
| Ligurian | lij | lij_Latn |
|
||||
| Limburgish | li | lim_Latn |
|
||||
| Lingala | ln | lin_Latn |
|
||||
| Lithuanian | lt | lit_Latn |
|
||||
| Lombard | lmo | lmo_Latn |
|
||||
| Latgalian | ltg | ltg_Latn |
|
||||
| Luxembourgish | lb | ltz_Latn |
|
||||
| Luba-Kasai | lua | lua_Latn |
|
||||
| Ganda | lg | lug_Latn |
|
||||
| Luo | luo | luo_Latn |
|
||||
| Mizo | lus | lus_Latn |
|
||||
| Standard Latvian | lv | lvs_Latn |
|
||||
| Magahi | mag | mag_Deva |
|
||||
| Maithili | mai | mai_Deva |
|
||||
| Malayalam | ml-IN | mal_Mlym |
|
||||
| Marathi | mr | mar_Deva |
|
||||
| Minangkabau (Arabic script) | min_Arab | min_Arab |
|
||||
| Minangkabau (Latin script) | min_Latn | min_Latn |
|
||||
| Macedonian | mk | mkd_Cyrl |
|
||||
| Maltese | mt | mlt_Latn |
|
||||
| Meitei (Bengali script) | mni | mni_Beng |
|
||||
| Mossi | mos | mos_Latn |
|
||||
| Maori | mi | mri_Latn |
|
||||
| Burmese | my | mya_Mymr |
|
||||
| Dutch | nl | nld_Latn |
|
||||
| Norwegian Nynorsk | nn-NO | nno_Latn |
|
||||
| Norwegian Bokmål | nb | nob_Latn |
|
||||
| Nepali | ne-NP | npi_Deva |
|
||||
| Northern Sotho | nso | nso_Latn |
|
||||
| Nuer | nus | nus_Latn |
|
||||
| Nyanja | ny | nya_Latn |
|
||||
| Occitan | oc | oci_Latn |
|
||||
| Odia | or | ory_Orya |
|
||||
| Pangasinan | pag | pag_Latn |
|
||||
| Eastern Panjabi | pa | pan_Guru |
|
||||
| Papiamento | pap | pap_Latn |
|
||||
| Southern Pashto | pbt | pbt_Arab |
|
||||
| Western Persian | fa | pes_Arab |
|
||||
| Plateau Malagasy | mg | plt_Latn |
|
||||
| Polish | pl | pol_Latn |
|
||||
| Portuguese | pt-PT | por_Latn |
|
||||
| Dari | fa-AF | prs_Arab |
|
||||
| Ayacucho Quechua | qu | quy_Latn |
|
||||
| Romanian | ro | ron_Latn |
|
||||
| Rundi | rn | run_Latn |
|
||||
| Russian | ru | rus_Cyrl |
|
||||
| Sango | sg | sag_Latn |
|
||||
| Sanskrit | sa | san_Deva |
|
||||
| Santali | sat | sat_Olck |
|
||||
| Sicilian | scn | scn_Latn |
|
||||
| Shan | shn | shn_Mymr |
|
||||
| Sinhala | si-LK | sin_Sinh |
|
||||
| Slovak | sk | slk_Latn |
|
||||
| Slovenian | sl | slv_Latn |
|
||||
| Samoan | sm | smo_Latn |
|
||||
| Shona | sn | sna_Latn |
|
||||
| Sindhi | sd | snd_Arab |
|
||||
| Somali | so | som_Latn |
|
||||
| Southern Sotho | st | sot_Latn |
|
||||
| Spanish | es-ES | spa_Latn |
|
||||
| Sardinian | sc | srd_Latn |
|
||||
| Serbian | sr | srp_Cyrl |
|
||||
| Swati | ss | ssw_Latn |
|
||||
| Sundanese | su | sun_Latn |
|
||||
| Swedish | sv-SE | swe_Latn |
|
||||
| Swahili | sw | swh_Latn |
|
||||
| Silesian | szl | szl_Latn |
|
||||
| Tamil | ta | tam_Taml |
|
||||
| Tamasheq (Latin script) | taq_Latn | taq_Latn |
|
||||
| Tamasheq (Tifinagh script) | taq_Tfng | taq_Tfng |
|
||||
| Tatar | tt-RU | tat_Cyrl |
|
||||
| Telugu | te | tel_Telu |
|
||||
| Tajik | tg | tgk_Cyrl |
|
||||
| Tagalog | tl | tgl_Latn |
|
||||
| Thai | th | tha_Thai |
|
||||
| Tigrinya | ti | tir_Ethi |
|
||||
| Tok Pisin | tpi | tpi_Latn |
|
||||
| Tswana | tn | tsn_Latn |
|
||||
| Tsonga | ts | tso_Latn |
|
||||
| Turkmen | tk | tuk_Latn |
|
||||
| Tumbuka | tum | tum_Latn |
|
||||
| Turkish | tr | tur_Latn |
|
||||
| Twi | tw | twi_Latn |
|
||||
| Central Atlas Tamazight | tzm | tzm_Tfng |
|
||||
| Uyghur | ug | uig_Arab |
|
||||
| Ukrainian | uk | ukr_Cyrl |
|
||||
| Umbundu | umb | umb_Latn |
|
||||
| Urdu | ur | urd_Arab |
|
||||
| Northern Uzbek | uz | uzn_Latn |
|
||||
| Venetian | vec | vec_Latn |
|
||||
| Vietnamese | vi | vie_Latn |
|
||||
| Waray | war | war_Latn |
|
||||
| Wolof | wo | wol_Latn |
|
||||
| Xhosa | xh | xho_Latn |
|
||||
| Eastern Yiddish | yi | ydd_Hebr |
|
||||
| Yoruba | yo | yor_Latn |
|
||||
| Yue Chinese | yue | yue_Hant |
|
||||
| Chinese (Simplified) | zh-CN | zho_Hans |
|
||||
| Chinese (Traditional) | zh-TW | zho_Hant |
|
||||
| Standard Malay | ms | zsm_Latn |
|
||||
| Zulu | zu | zul_Latn |
|
||||
|
||||
## Special Features
|
||||
|
||||
### Multiple Script Support
|
||||
Several languages are available in multiple scripts (e.g., Arabic and Latin):
|
||||
- **Acehnese**: Arabic (`ace_Arab`) and Latin (`ace_Latn`)
|
||||
- **Banjar**: Arabic (`bjn_Arab`) and Latin (`bjn_Latn`)
|
||||
- **Kashmiri**: Arabic (`kas_Arab`) and Devanagari (`kas_Deva`)
|
||||
- **Minangkabau**: Arabic (`min_Arab`) and Latin (`min_Latn`)
|
||||
- **Tamasheq**: Latin (`taq_Latn`) and Tifinagh (`taq_Tfng`)
|
||||
- **Central Kanuri**: Arabic (`knc_Arab`) and Latin (`knc_Latn`)
|
||||
43
docs/technical_integration.md
Normal file
@@ -0,0 +1,43 @@
|
||||
# Technical Integration Guide
|
||||
|
||||
This document introduce how to reuse the core components when you do **not** want to ship the bundled frontend, FastAPI server, or even the provided CLI.
|
||||
|
||||
---
|
||||
|
||||
## 1. Runtime Components
|
||||
|
||||
| Layer | File(s) | Purpose |
|
||||
|-------|---------|---------|
|
||||
| Transport | `whisperlivekit/basic_server.py`, any ASGI/WebSocket server | Accepts audio over WebSocket (MediaRecorder WebM or raw PCM chunks) and streams JSON updates back |
|
||||
| Audio processing | `whisperlivekit/audio_processor.py` | Buffers audio, orchestrates transcription, diarization, translation, handles FFmpeg/PCM input |
|
||||
| Engines | `whisperlivekit/core.py`, `whisperlivekit/simul_whisper/*`, `whisperlivekit/local_agreement/*` | Load models once (SimulStreaming or LocalAgreement), expose `TranscriptionEngine` and helpers |
|
||||
| Frontends | `whisperlivekit/web/*`, `chrome-extension/*` | Optional UI layers feeding the WebSocket endpoint |
|
||||
|
||||
**Key idea:** The server boundary is just `AudioProcessor.process_audio()` for incoming bytes and the async generator returned by `AudioProcessor.create_tasks()` for outgoing updates (`FrontData`). Everything else is optional.
|
||||
|
||||
---
|
||||
|
||||
## 2. Running Without the Bundled Frontend
|
||||
|
||||
1. Start the server/engine however you like:
|
||||
```bash
|
||||
wlk --model small --language en --host 0.0.0.0 --port 9000
|
||||
# or launch your own app that instantiates TranscriptionEngine(...)
|
||||
```
|
||||
2. Build your own client (browser, mobile, desktop) that:
|
||||
- Opens `ws(s)://<host>:<port>/asr`
|
||||
- Sends either MediaRecorder/Opus WebM blobs **or** raw PCM (`--pcm-input` on the server tells the client to use the AudioWorklet).
|
||||
- Consumes the JSON payload defined in `docs/API.md`.
|
||||
|
||||
---
|
||||
|
||||
## 3. Running Without FastAPI
|
||||
|
||||
`whisperlivekit/basic_server.py` is just an example. Any async framework works, as long as you:
|
||||
|
||||
1. Create a global `TranscriptionEngine` (expensive to initialize; reuse it).
|
||||
2. Instantiate `AudioProcessor(transcription_engine=engine)` for each connection.
|
||||
3. Call `create_tasks()` to get the async generator, `process_audio()` with incoming bytes, and ensure `cleanup()` runs when the client disconnects.
|
||||
|
||||
|
||||
If you prefer to send compressed audio, instantiate `AudioProcessor(pcm_input=False)` and pipe encoded chunks through `FFmpegManager` transparently—just ensure `ffmpeg` is available or be ready to handle the `"ffmpeg_not_found"` error in the streamed `FrontData`.
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "whisperlivekit"
|
||||
version = "0.2.9"
|
||||
version = "0.2.15"
|
||||
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
@@ -30,28 +30,41 @@ dependencies = [
|
||||
"fastapi",
|
||||
"librosa",
|
||||
"soundfile",
|
||||
"faster-whisper",
|
||||
"uvicorn",
|
||||
"websockets",
|
||||
"torchaudio>=2.0.0",
|
||||
"torch>=2.0.0",
|
||||
"huggingface-hub>=0.25.0",
|
||||
"tqdm",
|
||||
"tiktoken",
|
||||
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
sentence = ["mosestokenizer", "wtpsplit"]
|
||||
translation = ["nllw"]
|
||||
sentence_tokenizer = ["mosestokenizer", "wtpsplit"]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
|
||||
|
||||
[project.scripts]
|
||||
whisperlivekit-server = "whisperlivekit.basic_server:main"
|
||||
wlk = "whisperlivekit.basic_server:main"
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["whisperlivekit", "whisperlivekit.diarization", "whisperlivekit.simul_whisper", "whisperlivekit.simul_whisper.whisper", "whisperlivekit.simul_whisper.whisper.assets", "whisperlivekit.simul_whisper.whisper.normalizers", "whisperlivekit.web", "whisperlivekit.whisper_streaming_custom"]
|
||||
packages = [
|
||||
"whisperlivekit",
|
||||
"whisperlivekit.diarization",
|
||||
"whisperlivekit.simul_whisper",
|
||||
"whisperlivekit.whisper",
|
||||
"whisperlivekit.whisper.assets",
|
||||
"whisperlivekit.whisper.normalizers",
|
||||
"whisperlivekit.web",
|
||||
"whisperlivekit.local_agreement",
|
||||
"whisperlivekit.vad_models"
|
||||
]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
|
||||
"whisperlivekit.simul_whisper.whisper.assets" = ["*.tiktoken", "*.npz"]
|
||||
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
|
||||
"whisperlivekit.vad_models" = ["*.jit", "*.onnx"]
|
||||
|
||||
BIN
scripts/alignment_heads.png
Normal file
|
After Width: | Height: | Size: 276 KiB |
153
scripts/convert_hf_whisper.py
Normal file
@@ -0,0 +1,153 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Convert a Hugging Face style Whisper checkpoint into a WhisperLiveKit .pt file.
|
||||
|
||||
Optionally shrink the supported audio chunk length (in seconds) by trimming the
|
||||
encoder positional embeddings and updating the stored model dimensions.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from whisperlivekit.whisper.audio import HOP_LENGTH, SAMPLE_RATE
|
||||
from whisperlivekit.whisper.model import ModelDimensions
|
||||
from whisperlivekit.whisper.utils import exact_div
|
||||
from whisperlivekit.whisper import _convert_hf_state_dict
|
||||
|
||||
|
||||
def _load_state_dict(repo_path: Path) -> Dict[str, torch.Tensor]:
|
||||
safetensor_path = repo_path / "model.safetensors"
|
||||
bin_path = repo_path / "pytorch_model.bin"
|
||||
|
||||
if safetensor_path.is_file():
|
||||
try:
|
||||
from safetensors.torch import load_file # type: ignore
|
||||
except Exception as exc: # pragma: no cover - import guard
|
||||
raise RuntimeError(
|
||||
"Install safetensors to load model.safetensors "
|
||||
"(pip install safetensors)"
|
||||
) from exc
|
||||
return load_file(str(safetensor_path))
|
||||
|
||||
if bin_path.is_file():
|
||||
return torch.load(bin_path, map_location="cpu")
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"Could not find model.safetensors or pytorch_model.bin under {repo_path}"
|
||||
)
|
||||
|
||||
|
||||
def _load_config(repo_path: Path) -> Dict:
|
||||
config_path = repo_path / "config.json"
|
||||
if not config_path.is_file():
|
||||
raise FileNotFoundError(
|
||||
f"Hugging Face checkpoint at {repo_path} is missing config.json"
|
||||
)
|
||||
with open(config_path, "r", encoding="utf-8") as fp:
|
||||
return json.load(fp)
|
||||
|
||||
|
||||
def _derive_audio_ctx(chunk_length: float) -> Tuple[int, int]:
|
||||
n_samples = int(round(chunk_length * SAMPLE_RATE))
|
||||
expected_samples = chunk_length * SAMPLE_RATE
|
||||
if abs(n_samples - expected_samples) > 1e-6:
|
||||
raise ValueError(
|
||||
"chunk_length must align with sample rate so that "
|
||||
"chunk_length * SAMPLE_RATE is an integer"
|
||||
)
|
||||
n_frames = exact_div(n_samples, HOP_LENGTH)
|
||||
n_audio_ctx = exact_div(n_frames, 2)
|
||||
return n_frames, n_audio_ctx
|
||||
|
||||
|
||||
def _build_dims(config: Dict, chunk_length: float) -> Dict:
|
||||
base_dims = ModelDimensions(
|
||||
n_mels=config["num_mel_bins"],
|
||||
n_audio_ctx=config["max_source_positions"],
|
||||
n_audio_state=config["d_model"],
|
||||
n_audio_head=config["encoder_attention_heads"],
|
||||
n_audio_layer=config.get("encoder_layers") or config["num_hidden_layers"],
|
||||
n_vocab=config["vocab_size"],
|
||||
n_text_ctx=config["max_target_positions"],
|
||||
n_text_state=config["d_model"],
|
||||
n_text_head=config["decoder_attention_heads"],
|
||||
n_text_layer=config["decoder_layers"],
|
||||
).__dict__.copy()
|
||||
|
||||
_, n_audio_ctx = _derive_audio_ctx(chunk_length)
|
||||
base_dims["n_audio_ctx"] = n_audio_ctx
|
||||
base_dims["chunk_length"] = chunk_length
|
||||
return base_dims
|
||||
|
||||
|
||||
def _trim_positional_embedding(
|
||||
state_dict: Dict[str, torch.Tensor], target_ctx: int
|
||||
) -> None:
|
||||
key = "encoder.positional_embedding"
|
||||
if key not in state_dict:
|
||||
raise KeyError(f"{key} missing from converted state dict")
|
||||
|
||||
tensor = state_dict[key]
|
||||
if tensor.shape[0] < target_ctx:
|
||||
raise ValueError(
|
||||
f"Cannot increase encoder ctx from {tensor.shape[0]} to {target_ctx}"
|
||||
)
|
||||
if tensor.shape[0] == target_ctx:
|
||||
return
|
||||
state_dict[key] = tensor[:target_ctx].contiguous()
|
||||
|
||||
|
||||
def convert_checkpoint(hf_path: Path, output_path: Path, chunk_length: float) -> None:
|
||||
state_dict = _load_state_dict(hf_path)
|
||||
converted = _convert_hf_state_dict(state_dict)
|
||||
|
||||
config = _load_config(hf_path)
|
||||
dims = _build_dims(config, chunk_length)
|
||||
|
||||
_trim_positional_embedding(converted, dims["n_audio_ctx"])
|
||||
|
||||
package = {"dims": dims, "model_state_dict": converted}
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(package, output_path)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert Hugging Face Whisper checkpoint to WhisperLiveKit format."
|
||||
)
|
||||
parser.add_argument(
|
||||
"hf_path",
|
||||
type=str,
|
||||
help="Path to the cloned Hugging Face repository (e.g. whisper-tiny.en)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="converted-whisper.pt",
|
||||
help="Destination path for the .pt file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk-length",
|
||||
type=float,
|
||||
default=30.0,
|
||||
help="Audio chunk length in seconds to support (default: 30)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
hf_path = Path(os.path.expanduser(args.hf_path)).resolve()
|
||||
output_path = Path(os.path.expanduser(args.output)).resolve()
|
||||
|
||||
convert_checkpoint(hf_path, output_path, args.chunk_length)
|
||||
print(f"Saved converted checkpoint to {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
292
scripts/determine_alignment_heads.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""Determine alignment heads for a variants, such as distilled model"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import gzip
|
||||
import io
|
||||
import pathlib
|
||||
import sys
|
||||
import math
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Audio as DatasetAudio, load_dataset
|
||||
import soundfile as sf
|
||||
import matplotlib.pyplot as plt
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
|
||||
WHISPER_ROOT = REPO_ROOT / "whisper"
|
||||
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
sys.path.insert(0, str(WHISPER_ROOT))
|
||||
|
||||
from whisper import load_model
|
||||
from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from whisper.tokenizer import get_tokenizer
|
||||
|
||||
AudioInput = Union[str, pathlib.Path, np.ndarray, torch.Tensor]
|
||||
|
||||
|
||||
def load_dataset_clips(name, config, split, limit):
|
||||
ds = load_dataset(name, config, split=split)
|
||||
ds = ds.cast_column("audio", DatasetAudio(decode=False))
|
||||
clips = []
|
||||
for idx, row in enumerate(ds):
|
||||
if limit is not None and idx >= limit:
|
||||
break
|
||||
audio_field = row["audio"]
|
||||
transcript = row["text"]
|
||||
|
||||
waveform_np, _ = sf.read(io.BytesIO(audio_field["bytes"]), dtype="float32")
|
||||
if waveform_np.ndim > 1:
|
||||
waveform_np = waveform_np.mean(axis=1)
|
||||
waveform = waveform_np
|
||||
transcript = str(transcript)
|
||||
|
||||
clips.append((waveform, transcript))
|
||||
return clips
|
||||
|
||||
|
||||
def load_clips(args):
|
||||
return load_dataset_clips(
|
||||
args.dataset,
|
||||
args.dataset_config,
|
||||
args.dataset_split,
|
||||
args.dataset_num_samples,
|
||||
)
|
||||
|
||||
|
||||
def _waveform_from_source(source: AudioInput) -> torch.Tensor:
|
||||
waveform = torch.from_numpy(source.astype(np.float32, copy=False))
|
||||
return waveform
|
||||
|
||||
|
||||
def _parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="pytorch_model.bin",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
help="Torch device to run on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="librispeech_asr"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-config",
|
||||
type=str,
|
||||
default="clean"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-split",
|
||||
type=str,
|
||||
default="validation[:1%]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-num-samples",
|
||||
type=int,
|
||||
default=16,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--threshold",
|
||||
type=float,
|
||||
default=1.5,
|
||||
help="Z score threshold for a head to be selected",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--votes",
|
||||
type=float,
|
||||
default=0.75,
|
||||
help="percentage of clips that must vote for a head",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="alignment_heads.b85",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--visualize-top-k",
|
||||
type=int,
|
||||
default=32,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def collect_heads(
|
||||
model,
|
||||
tokenizer,
|
||||
clips: Sequence[Tuple[AudioInput, str]],
|
||||
threshold: float,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
device = model.device
|
||||
votes = torch.zeros(model.dims.n_text_layer, model.dims.n_text_head, device=device)
|
||||
strengths = torch.zeros_like(votes)
|
||||
|
||||
for audio_source, transcript in clips:
|
||||
waveform = pad_or_trim(_waveform_from_source(audio_source))
|
||||
mel = log_mel_spectrogram(waveform, device=device)
|
||||
|
||||
tokens = torch.tensor(
|
||||
[
|
||||
*tokenizer.sot_sequence,
|
||||
tokenizer.no_timestamps,
|
||||
*tokenizer.encode(transcript),
|
||||
tokenizer.eot,
|
||||
],
|
||||
device=device,
|
||||
)
|
||||
|
||||
qks = [None] * model.dims.n_text_layer
|
||||
hooks = [
|
||||
block.cross_attn.register_forward_hook(
|
||||
lambda _, __, outputs, index=i: qks.__setitem__(index, outputs[-1][0])
|
||||
)
|
||||
for i, block in enumerate(model.decoder.blocks)
|
||||
]
|
||||
|
||||
with torch.no_grad():
|
||||
model(mel.unsqueeze(0), tokens.unsqueeze(0))
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
for layer_idx, tensor in enumerate(qks):
|
||||
if tensor is None:
|
||||
continue
|
||||
tensor = tensor[:, :, : mel.shape[-1] // 2]
|
||||
tensor = tensor.softmax(dim=-1)
|
||||
peak = tensor.max(dim=-1).values # [heads, tokens]
|
||||
strengths[layer_idx] += peak.mean(dim=-1)
|
||||
zscore = (peak - peak.mean(dim=-1, keepdim=True)) / (
|
||||
peak.std(dim=-1, keepdim=True, unbiased=False) + 1e-6
|
||||
)
|
||||
mask = (zscore > 3).any(dim=-1)
|
||||
votes[layer_idx] += mask.float()
|
||||
|
||||
votes /= len(clips)
|
||||
strengths /= len(clips)
|
||||
return votes, strengths
|
||||
|
||||
|
||||
def _select_heads_for_visualization(selection, strengths, top_k):
|
||||
selected = torch.nonzero(selection, as_tuple=False)
|
||||
if selected.numel() == 0:
|
||||
return []
|
||||
|
||||
entries = [
|
||||
(int(layer.item()), int(head.item()), float(strengths[layer, head].item()))
|
||||
for layer, head in selected
|
||||
]
|
||||
entries.sort(key=lambda item: item[2], reverse=True)
|
||||
return entries[:top_k]
|
||||
|
||||
def _extract_heatmaps(
|
||||
model,
|
||||
tokenizer,
|
||||
clip: Tuple[AudioInput, str],
|
||||
heads: Sequence[Tuple[int, int, float]],
|
||||
) -> dict:
|
||||
if not heads:
|
||||
return {}
|
||||
|
||||
target_map = {}
|
||||
for layer, head, _ in heads:
|
||||
target_map.setdefault(layer, set()).add(head)
|
||||
|
||||
waveform = pad_or_trim(_waveform_from_source(clip[0]))
|
||||
mel = log_mel_spectrogram(waveform, device=model.device)
|
||||
transcript = clip[1]
|
||||
tokens = torch.tensor(
|
||||
[
|
||||
*tokenizer.sot_sequence,
|
||||
tokenizer.no_timestamps,
|
||||
*tokenizer.encode(transcript),
|
||||
tokenizer.eot,
|
||||
],
|
||||
device=model.device,
|
||||
)
|
||||
|
||||
QKs = [None] * model.dims.n_text_layer
|
||||
hooks = [
|
||||
block.cross_attn.register_forward_hook(
|
||||
lambda _, __, outputs, index=i: QKs.__setitem__(index, outputs[-1][0])
|
||||
)
|
||||
for i, block in enumerate(model.decoder.blocks)
|
||||
]
|
||||
|
||||
with torch.no_grad():
|
||||
model(mel.unsqueeze(0), tokens.unsqueeze(0))
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
heatmaps = {}
|
||||
for layer_idx, tensor in enumerate(QKs):
|
||||
if tensor is None or layer_idx not in target_map:
|
||||
continue
|
||||
tensor = tensor[:, :, : mel.shape[-1] // 2]
|
||||
tensor = tensor.softmax(dim=-1).cpu()
|
||||
for head_idx in target_map[layer_idx]:
|
||||
heatmaps[(layer_idx, head_idx)] = tensor[head_idx]
|
||||
|
||||
return heatmaps
|
||||
|
||||
|
||||
def _plot_heatmaps(
|
||||
heads, heatmaps, output_path):
|
||||
cols = min(3, len(heads))
|
||||
rows = math.ceil(len(heads) / cols)
|
||||
fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3.2 * rows), squeeze=False)
|
||||
|
||||
for idx, (layer, head, score) in enumerate(heads):
|
||||
ax = axes[idx // cols][idx % cols]
|
||||
mat = heatmaps.get((layer, head))
|
||||
if mat is None:
|
||||
ax.axis("off")
|
||||
continue
|
||||
im = ax.imshow(mat.to(torch.float32).numpy(), aspect="auto", origin="lower")
|
||||
ax.set_title(f"L{layer} H{head} · score {score:.2f}")
|
||||
ax.set_xlabel("time")
|
||||
ax.set_ylabel("tokens")
|
||||
|
||||
for j in range(len(heads), rows * cols):
|
||||
axes[j // cols][j % cols].axis("off")
|
||||
|
||||
fig.tight_layout()
|
||||
fig.savefig(output_path, dpi=200)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def _dump_mask(mask: torch.Tensor, output_path: str):
|
||||
payload = mask.numpy().astype(np.bool_)
|
||||
blob = base64.b85encode(gzip.compress(payload.tobytes()))
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(blob)
|
||||
|
||||
|
||||
def main():
|
||||
args = _parse_args()
|
||||
model = load_model(args.model, device=args.device)
|
||||
model.eval()
|
||||
tokenizer = get_tokenizer(multilingual=model.is_multilingual)
|
||||
clips = load_clips(args)
|
||||
|
||||
votes, strengths = collect_heads(model, tokenizer, clips, args.threshold)
|
||||
# selection = votes > 0.5
|
||||
selection = strengths > 0.05
|
||||
_dump_mask(selection.cpu(), args.output)
|
||||
|
||||
viz_heads = _select_heads_for_visualization(selection, strengths, args.visualize_top_k)
|
||||
heatmaps = _extract_heatmaps(model, tokenizer, clips[0], viz_heads)
|
||||
_plot_heatmaps(viz_heads, heatmaps, "alignment_heads.png")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
39
scripts/sync_extension.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Copy core files from web directory to Chrome extension directory."""
|
||||
|
||||
import shutil
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
def sync_extension_files():
|
||||
|
||||
web_dir = Path("whisperlivekit/web")
|
||||
extension_dir = Path("chrome-extension")
|
||||
|
||||
files_to_sync = [
|
||||
"live_transcription.html", "live_transcription.js", "live_transcription.css"
|
||||
]
|
||||
|
||||
svg_files = [
|
||||
"system_mode.svg",
|
||||
"light_mode.svg",
|
||||
"dark_mode.svg",
|
||||
"settings.svg"
|
||||
]
|
||||
|
||||
for file in files_to_sync:
|
||||
src_path = web_dir / file
|
||||
dest_path = extension_dir / file
|
||||
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src_path, dest_path)
|
||||
|
||||
for svg_file in svg_files:
|
||||
src_path = web_dir / "src" / svg_file
|
||||
dest_path = extension_dir / "web" / "src" / svg_file
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src_path, dest_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
sync_extension_files()
|
||||
@@ -1,29 +1,45 @@
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from time import time, sleep
|
||||
import math
|
||||
from time import time
|
||||
import logging
|
||||
import traceback
|
||||
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State
|
||||
from typing import Optional, Union, List, Any, AsyncGenerator
|
||||
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript, ChangeSpeaker
|
||||
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
|
||||
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
||||
from whisperlivekit.results_formater import format_output
|
||||
# Set up logging once
|
||||
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||
from whisperlivekit.tokens_alignment import TokensAlignment
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
SENTINEL = object() # unique sentinel object for end of stream marker
|
||||
MIN_DURATION_REAL_SILENCE = 5
|
||||
|
||||
async def get_all_from_queue(queue: asyncio.Queue) -> Union[object, Silence, np.ndarray, List[Any]]:
|
||||
items: List[Any] = []
|
||||
|
||||
first_item = await queue.get()
|
||||
queue.task_done()
|
||||
if first_item is SENTINEL:
|
||||
return first_item
|
||||
if isinstance(first_item, Silence):
|
||||
return first_item
|
||||
items.append(first_item)
|
||||
|
||||
async def get_all_from_queue(queue):
|
||||
items = []
|
||||
try:
|
||||
while True:
|
||||
item = queue.get_nowait()
|
||||
items.append(item)
|
||||
except asyncio.QueueEmpty:
|
||||
pass
|
||||
if not queue._queue:
|
||||
break
|
||||
next_item = queue._queue[0]
|
||||
if next_item is SENTINEL:
|
||||
break
|
||||
if isinstance(next_item, Silence):
|
||||
break
|
||||
items.append(await queue.get())
|
||||
queue.task_done()
|
||||
if isinstance(items[0], np.ndarray):
|
||||
return np.concatenate(items)
|
||||
else: #translation
|
||||
return items
|
||||
|
||||
class AudioProcessor:
|
||||
@@ -32,7 +48,7 @@ class AudioProcessor:
|
||||
Handles audio processing, state management, and result formatting.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the audio processor with configuration, models, and state."""
|
||||
|
||||
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
|
||||
@@ -48,187 +64,261 @@ class AudioProcessor:
|
||||
self.bytes_per_sample = 2
|
||||
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
|
||||
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
|
||||
self.is_pcm_input = True
|
||||
self.debug = False
|
||||
self.is_pcm_input = self.args.pcm_input
|
||||
|
||||
# State management
|
||||
self.is_stopping = False
|
||||
self.silence = False
|
||||
self.silence_duration = 0.0
|
||||
self.tokens = []
|
||||
self.translated_segments = []
|
||||
self.buffer_transcription = ""
|
||||
self.buffer_diarization = ""
|
||||
self.end_buffer = 0
|
||||
self.end_attributed_speaker = 0
|
||||
self.lock = asyncio.Lock()
|
||||
self.beg_loop = None #to deal with a potential little lag at the websocket initialization, this is now set in process_audio
|
||||
self.sep = " " # Default separator
|
||||
self.last_response_content = FrontData()
|
||||
self.is_stopping: bool = False
|
||||
self.current_silence: Optional[Silence] = None
|
||||
self.state: State = State()
|
||||
self.lock: asyncio.Lock = asyncio.Lock()
|
||||
self.sep: str = " " # Default separator
|
||||
self.last_response_content: FrontData = FrontData()
|
||||
|
||||
self.tokens_alignment: TokensAlignment = TokensAlignment(self.state, self.args, self.sep)
|
||||
self.beg_loop: Optional[float] = None
|
||||
|
||||
# Models and processing
|
||||
self.asr = models.asr
|
||||
self.tokenizer = models.tokenizer
|
||||
self.vac_model = models.vac_model
|
||||
self.asr: Any = models.asr
|
||||
self.vac_model: Any = models.vac_model
|
||||
if self.args.vac:
|
||||
self.vac = FixedVADIterator(models.vac_model)
|
||||
self.vac: Optional[FixedVADIterator] = FixedVADIterator(models.vac_model)
|
||||
else:
|
||||
self.vac = None
|
||||
self.vac: Optional[FixedVADIterator] = None
|
||||
|
||||
self.transcription_queue = asyncio.Queue() if self.args.transcription else None
|
||||
self.diarization_queue = asyncio.Queue() if self.args.diarization else None
|
||||
self.translation_queue = asyncio.Queue() if self.args.target_language else None
|
||||
self.pcm_buffer = bytearray()
|
||||
self.ffmpeg_manager: Optional[FFmpegManager] = None
|
||||
self.ffmpeg_reader_task: Optional[asyncio.Task] = None
|
||||
self._ffmpeg_error: Optional[str] = None
|
||||
|
||||
self.transcription_task = None
|
||||
self.diarization_task = None
|
||||
self.watchdog_task = None
|
||||
self.all_tasks_for_cleanup = []
|
||||
if not self.is_pcm_input:
|
||||
self.ffmpeg_manager = FFmpegManager(
|
||||
sample_rate=self.sample_rate,
|
||||
channels=self.channels
|
||||
)
|
||||
async def handle_ffmpeg_error(error_type: str):
|
||||
logger.error(f"FFmpeg error: {error_type}")
|
||||
self._ffmpeg_error = error_type
|
||||
self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
|
||||
|
||||
self.transcription_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.transcription else None
|
||||
self.diarization_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.diarization else None
|
||||
self.translation_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.target_language else None
|
||||
self.pcm_buffer: bytearray = bytearray()
|
||||
self.total_pcm_samples: int = 0
|
||||
self.transcription_task: Optional[asyncio.Task] = None
|
||||
self.diarization_task: Optional[asyncio.Task] = None
|
||||
self.translation_task: Optional[asyncio.Task] = None
|
||||
self.watchdog_task: Optional[asyncio.Task] = None
|
||||
self.all_tasks_for_cleanup: List[asyncio.Task] = []
|
||||
|
||||
self.transcription: Optional[Any] = None
|
||||
self.translation: Optional[Any] = None
|
||||
self.diarization: Optional[Any] = None
|
||||
|
||||
if self.args.transcription:
|
||||
self.online = online_factory(self.args, models.asr, models.tokenizer)
|
||||
self.sep = self.online.asr.sep
|
||||
self.transcription = online_factory(self.args, models.asr)
|
||||
self.sep = self.transcription.asr.sep
|
||||
if self.args.diarization:
|
||||
self.diarization = online_diarization_factory(self.args, models.diarization_model)
|
||||
if self.args.target_language:
|
||||
self.online_translation = online_translation_factory(self.args, models.translation_model)
|
||||
if models.translation_model:
|
||||
self.translation = online_translation_factory(self.args, models.translation_model)
|
||||
|
||||
def convert_pcm_to_float(self, pcm_buffer):
|
||||
async def _push_silence_event(self) -> None:
|
||||
if self.transcription_queue:
|
||||
await self.transcription_queue.put(self.current_silence)
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(self.current_silence)
|
||||
if self.translation_queue:
|
||||
await self.translation_queue.put(self.current_silence)
|
||||
|
||||
async def _begin_silence(self) -> None:
|
||||
if self.current_silence:
|
||||
return
|
||||
now = time() - self.beg_loop
|
||||
self.current_silence = Silence(
|
||||
is_starting=True, start=now
|
||||
)
|
||||
await self._push_silence_event()
|
||||
|
||||
async def _end_silence(self) -> None:
|
||||
if not self.current_silence:
|
||||
return
|
||||
now = time() - self.beg_loop
|
||||
self.current_silence.end = now
|
||||
self.current_silence.is_starting=False
|
||||
self.current_silence.has_ended=True
|
||||
self.current_silence.compute_duration()
|
||||
if self.current_silence.duration > MIN_DURATION_REAL_SILENCE:
|
||||
self.state.new_tokens.append(self.current_silence)
|
||||
await self._push_silence_event()
|
||||
self.current_silence = None
|
||||
|
||||
async def _enqueue_active_audio(self, pcm_chunk: np.ndarray) -> None:
|
||||
if pcm_chunk is None or pcm_chunk.size == 0:
|
||||
return
|
||||
if self.transcription_queue:
|
||||
await self.transcription_queue.put(pcm_chunk.copy())
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(pcm_chunk.copy())
|
||||
|
||||
def _slice_before_silence(self, pcm_array: np.ndarray, chunk_sample_start: int, silence_sample: Optional[int]) -> Optional[np.ndarray]:
|
||||
if silence_sample is None:
|
||||
return None
|
||||
relative_index = int(silence_sample - chunk_sample_start)
|
||||
if relative_index <= 0:
|
||||
return None
|
||||
split_index = min(relative_index, len(pcm_array))
|
||||
if split_index <= 0:
|
||||
return None
|
||||
return pcm_array[:split_index]
|
||||
|
||||
def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray:
|
||||
"""Convert PCM buffer in s16le format to normalized NumPy array."""
|
||||
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
|
||||
async def update_transcription(self, new_tokens, buffer, end_buffer):
|
||||
"""Thread-safe update of transcription with new data."""
|
||||
async with self.lock:
|
||||
self.tokens.extend(new_tokens)
|
||||
self.buffer_transcription = buffer
|
||||
self.end_buffer = end_buffer
|
||||
|
||||
async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
|
||||
"""Thread-safe update of diarization with new data."""
|
||||
async with self.lock:
|
||||
self.end_attributed_speaker = end_attributed_speaker
|
||||
if buffer_diarization:
|
||||
self.buffer_diarization = buffer_diarization
|
||||
|
||||
async def add_dummy_token(self):
|
||||
"""Placeholder token when no transcription is available."""
|
||||
async with self.lock:
|
||||
current_time = time() - self.beg_loop if self.beg_loop else 0
|
||||
self.tokens.append(ASRToken(
|
||||
start=current_time, end=current_time + 1,
|
||||
text=".", speaker=-1, is_dummy=True
|
||||
))
|
||||
|
||||
async def get_current_state(self):
|
||||
async def get_current_state(self) -> State:
|
||||
"""Get current state."""
|
||||
async with self.lock:
|
||||
current_time = time()
|
||||
|
||||
# Calculate remaining times
|
||||
remaining_transcription = 0
|
||||
if self.end_buffer > 0:
|
||||
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 1))
|
||||
if self.state.end_buffer > 0:
|
||||
remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1))
|
||||
|
||||
remaining_diarization = 0
|
||||
if self.tokens:
|
||||
latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
|
||||
remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 1))
|
||||
if self.state.tokens:
|
||||
latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
|
||||
remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1))
|
||||
|
||||
return State(
|
||||
tokens=self.tokens.copy(),
|
||||
translated_segments=self.translated_segments.copy(),
|
||||
buffer_transcription=self.buffer_transcription,
|
||||
buffer_diarization=self.buffer_diarization,
|
||||
end_buffer=self.end_buffer,
|
||||
end_attributed_speaker=self.end_attributed_speaker,
|
||||
remaining_time_transcription=remaining_transcription,
|
||||
remaining_time_diarization=remaining_diarization
|
||||
)
|
||||
self.state.remaining_time_transcription = remaining_transcription
|
||||
self.state.remaining_time_diarization = remaining_diarization
|
||||
|
||||
async def reset(self):
|
||||
"""Reset all state variables to initial values."""
|
||||
async with self.lock:
|
||||
self.tokens = []
|
||||
self.translated_segments = []
|
||||
self.buffer_transcription = self.buffer_diarization = ""
|
||||
self.end_buffer = self.end_attributed_speaker = 0
|
||||
self.beg_loop = time()
|
||||
return self.state
|
||||
|
||||
async def transcription_processor(self):
|
||||
async def ffmpeg_stdout_reader(self) -> None:
|
||||
"""Read audio data from FFmpeg stdout and process it into the PCM pipeline."""
|
||||
beg = time()
|
||||
while True:
|
||||
try:
|
||||
if self.is_stopping:
|
||||
logger.info("Stopping ffmpeg_stdout_reader due to stopping flag.")
|
||||
break
|
||||
|
||||
state = await self.ffmpeg_manager.get_state() if self.ffmpeg_manager else FFmpegState.STOPPED
|
||||
if state == FFmpegState.FAILED:
|
||||
logger.error("FFmpeg is in FAILED state, cannot read data")
|
||||
break
|
||||
elif state == FFmpegState.STOPPED:
|
||||
logger.info("FFmpeg is stopped")
|
||||
break
|
||||
elif state != FFmpegState.RUNNING:
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
current_time = time()
|
||||
elapsed_time = max(0.0, current_time - beg)
|
||||
buffer_size = max(int(32000 * elapsed_time), 4096) # dynamic read
|
||||
beg = current_time
|
||||
|
||||
chunk = await self.ffmpeg_manager.read_data(buffer_size)
|
||||
if not chunk:
|
||||
# No data currently available
|
||||
await asyncio.sleep(0.05)
|
||||
continue
|
||||
|
||||
self.pcm_buffer.extend(chunk)
|
||||
await self.handle_pcm_data()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("ffmpeg_stdout_reader cancelled.")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
|
||||
logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
logger.info("FFmpeg stdout processing finished. Signaling downstream processors if needed.")
|
||||
if self.transcription_queue:
|
||||
await self.transcription_queue.put(SENTINEL)
|
||||
if self.diarization:
|
||||
await self.diarization_queue.put(SENTINEL)
|
||||
if self.translation:
|
||||
await self.translation_queue.put(SENTINEL)
|
||||
|
||||
async def transcription_processor(self) -> None:
|
||||
"""Process audio chunks for transcription."""
|
||||
cumulative_pcm_duration_stream_time = 0.0
|
||||
|
||||
while True:
|
||||
try:
|
||||
item = await self.transcription_queue.get()
|
||||
# item = await self.transcription_queue.get()
|
||||
item = await get_all_from_queue(self.transcription_queue)
|
||||
if item is SENTINEL:
|
||||
logger.debug("Transcription processor received sentinel. Finishing.")
|
||||
self.transcription_queue.task_done()
|
||||
break
|
||||
|
||||
if not self.online:
|
||||
logger.warning("Transcription processor: self.online not initialized.")
|
||||
self.transcription_queue.task_done()
|
||||
continue
|
||||
|
||||
asr_internal_buffer_duration_s = len(getattr(self.online, 'audio_buffer', [])) / self.online.SAMPLING_RATE
|
||||
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
|
||||
asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
|
||||
transcription_lag_s = max(0.0, time() - self.beg_loop - self.state.end_buffer)
|
||||
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
|
||||
if type(item) is Silence:
|
||||
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
|
||||
if self.tokens:
|
||||
asr_processing_logs += f" | last_end = {self.tokens[-1].end} |"
|
||||
logger.info(asr_processing_logs)
|
||||
|
||||
if type(item) is Silence:
|
||||
cumulative_pcm_duration_stream_time += item.duration
|
||||
self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
|
||||
continue
|
||||
|
||||
if isinstance(item, np.ndarray):
|
||||
pcm_array = item
|
||||
else:
|
||||
raise Exception('item should be pcm_array')
|
||||
|
||||
duration_this_chunk = len(pcm_array) / self.sample_rate
|
||||
cumulative_pcm_duration_stream_time += duration_this_chunk
|
||||
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
|
||||
new_tokens = []
|
||||
current_audio_processed_upto = self.state.end_buffer
|
||||
|
||||
self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
|
||||
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.online.process_iter)
|
||||
if isinstance(item, Silence):
|
||||
if item.is_starting:
|
||||
new_tokens, current_audio_processed_upto = await asyncio.to_thread(
|
||||
self.transcription.start_silence
|
||||
)
|
||||
asr_processing_logs += f" + Silence starting"
|
||||
if item.has_ended:
|
||||
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
|
||||
cumulative_pcm_duration_stream_time += item.duration
|
||||
current_audio_processed_upto = cumulative_pcm_duration_stream_time
|
||||
self.transcription.end_silence(item.duration, self.state.tokens[-1].end if self.state.tokens else 0)
|
||||
if self.state.tokens:
|
||||
asr_processing_logs += f" | last_end = {self.state.tokens[-1].end} |"
|
||||
logger.info(asr_processing_logs)
|
||||
new_tokens = new_tokens or []
|
||||
current_audio_processed_upto = max(current_audio_processed_upto, stream_time_end_of_current_pcm)
|
||||
elif isinstance(item, ChangeSpeaker):
|
||||
self.transcription.new_speaker(item)
|
||||
continue
|
||||
elif isinstance(item, np.ndarray):
|
||||
pcm_array = item
|
||||
logger.info(asr_processing_logs)
|
||||
cumulative_pcm_duration_stream_time += len(pcm_array) / self.sample_rate
|
||||
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
|
||||
self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
|
||||
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter)
|
||||
new_tokens = new_tokens or []
|
||||
|
||||
# Get buffer information
|
||||
_buffer_transcript_obj = self.online.get_buffer()
|
||||
buffer_text = _buffer_transcript_obj.text
|
||||
_buffer_transcript = self.transcription.get_buffer()
|
||||
buffer_text = _buffer_transcript.text
|
||||
|
||||
if new_tokens:
|
||||
validated_text = self.sep.join([t.text for t in new_tokens])
|
||||
if buffer_text.startswith(validated_text):
|
||||
buffer_text = buffer_text[len(validated_text):].lstrip()
|
||||
_buffer_transcript.text = buffer_text[len(validated_text):].lstrip()
|
||||
|
||||
candidate_end_times = [self.end_buffer]
|
||||
candidate_end_times = [self.state.end_buffer]
|
||||
|
||||
if new_tokens:
|
||||
candidate_end_times.append(new_tokens[-1].end)
|
||||
|
||||
if _buffer_transcript_obj.end is not None:
|
||||
candidate_end_times.append(_buffer_transcript_obj.end)
|
||||
if _buffer_transcript.end is not None:
|
||||
candidate_end_times.append(_buffer_transcript.end)
|
||||
|
||||
candidate_end_times.append(current_audio_processed_upto)
|
||||
|
||||
new_end_buffer = max(candidate_end_times)
|
||||
async with self.lock:
|
||||
self.state.tokens.extend(new_tokens)
|
||||
self.state.buffer_transcription = _buffer_transcript
|
||||
self.state.end_buffer = max(candidate_end_times)
|
||||
self.state.new_tokens.extend(new_tokens)
|
||||
self.state.new_tokens_buffer = _buffer_transcript
|
||||
|
||||
await self.update_transcription(
|
||||
new_tokens, buffer_text, new_end_buffer
|
||||
)
|
||||
|
||||
if new_tokens and self.args.target_language and self.translation_queue:
|
||||
if self.translation_queue:
|
||||
for token in new_tokens:
|
||||
await self.translation_queue.put(token)
|
||||
|
||||
self.transcription_queue.task_done()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in transcription_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
@@ -245,185 +335,137 @@ class AudioProcessor:
|
||||
logger.info("Transcription processor task finished.")
|
||||
|
||||
|
||||
async def diarization_processor(self, diarization_obj):
|
||||
"""Process audio chunks for speaker diarization."""
|
||||
buffer_diarization = ""
|
||||
cumulative_pcm_duration_stream_time = 0.0
|
||||
async def diarization_processor(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
item = await self.diarization_queue.get()
|
||||
item = await get_all_from_queue(self.diarization_queue)
|
||||
if item is SENTINEL:
|
||||
logger.debug("Diarization processor received sentinel. Finishing.")
|
||||
self.diarization_queue.task_done()
|
||||
break
|
||||
|
||||
if type(item) is Silence:
|
||||
cumulative_pcm_duration_stream_time += item.duration
|
||||
diarization_obj.insert_silence(item.duration)
|
||||
elif type(item) is Silence:
|
||||
if item.has_ended:
|
||||
self.diarization.insert_silence(item.duration)
|
||||
continue
|
||||
|
||||
if isinstance(item, np.ndarray):
|
||||
pcm_array = item
|
||||
else:
|
||||
raise Exception('item should be pcm_array')
|
||||
|
||||
# Process diarization
|
||||
await diarization_obj.diarize(pcm_array)
|
||||
|
||||
async with self.lock:
|
||||
self.tokens = diarization_obj.assign_speakers_to_tokens(
|
||||
self.tokens,
|
||||
use_punctuation_split=self.args.punctuation_split
|
||||
)
|
||||
if len(self.tokens) > 0:
|
||||
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
|
||||
if buffer_diarization:
|
||||
self.buffer_diarization = buffer_diarization
|
||||
|
||||
self.diarization_queue.task_done()
|
||||
self.diarization.insert_audio_chunk(item)
|
||||
diarization_segments = await self.diarization.diarize()
|
||||
self.state.new_diarization = diarization_segments
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in diarization_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
if 'pcm_array' in locals() and pcm_array is not SENTINEL:
|
||||
self.diarization_queue.task_done()
|
||||
logger.info("Diarization processor task finished.")
|
||||
|
||||
async def translation_processor(self, online_translation):
|
||||
async def translation_processor(self) -> None:
|
||||
# the idea is to ignore diarization for the moment. We use only transcription tokens.
|
||||
# And the speaker is attributed given the segments used for the translation
|
||||
# in the future we want to have different languages for each speaker etc, so it will be more complex.
|
||||
while True:
|
||||
try:
|
||||
token = await self.translation_queue.get() #block until at least 1 token
|
||||
if token is SENTINEL:
|
||||
item = await get_all_from_queue(self.translation_queue)
|
||||
if item is SENTINEL:
|
||||
logger.debug("Translation processor received sentinel. Finishing.")
|
||||
self.translation_queue.task_done()
|
||||
break
|
||||
|
||||
# get all the available tokens for translation. The more words, the more precise
|
||||
tokens_to_process = [token]
|
||||
additional_tokens = await get_all_from_queue(self.translation_queue)
|
||||
|
||||
sentinel_found = False
|
||||
for additional_token in additional_tokens:
|
||||
if additional_token is SENTINEL:
|
||||
sentinel_found = True
|
||||
break
|
||||
tokens_to_process.append(additional_token)
|
||||
if tokens_to_process:
|
||||
online_translation.insert_tokens(tokens_to_process)
|
||||
self.translated_segments = await asyncio.to_thread(online_translation.process)
|
||||
|
||||
self.translation_queue.task_done()
|
||||
for _ in additional_tokens:
|
||||
self.translation_queue.task_done()
|
||||
|
||||
if sentinel_found:
|
||||
logger.debug("Translation processor received sentinel in batch. Finishing.")
|
||||
break
|
||||
|
||||
elif type(item) is Silence:
|
||||
if item.is_starting:
|
||||
new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
|
||||
if item.has_ended:
|
||||
self.translation.insert_silence(item.duration)
|
||||
continue
|
||||
elif isinstance(item, ChangeSpeaker):
|
||||
new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
|
||||
pass
|
||||
else:
|
||||
self.translation.insert_tokens(item)
|
||||
new_translation, new_translation_buffer = await asyncio.to_thread(self.translation.process)
|
||||
async with self.lock:
|
||||
self.state.new_translation.append(new_translation)
|
||||
self.state.new_translation_buffer = new_translation_buffer
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in translation_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
if 'token' in locals() and token is not SENTINEL:
|
||||
self.translation_queue.task_done()
|
||||
if 'additional_tokens' in locals():
|
||||
for _ in additional_tokens:
|
||||
self.translation_queue.task_done()
|
||||
logger.info("Translation processor task finished.")
|
||||
|
||||
async def results_formatter(self):
|
||||
async def results_formatter(self) -> AsyncGenerator[FrontData, None]:
|
||||
"""Format processing results for output."""
|
||||
while True:
|
||||
try:
|
||||
# Get current state
|
||||
state = await self.get_current_state()
|
||||
if self._ffmpeg_error:
|
||||
yield FrontData(status="error", error=f"FFmpeg error: {self._ffmpeg_error}")
|
||||
self._ffmpeg_error = None
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
# Add dummy tokens if needed
|
||||
if (not state.tokens or state.tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
|
||||
await self.add_dummy_token()
|
||||
sleep(0.5)
|
||||
state = await self.get_current_state()
|
||||
|
||||
# Format output
|
||||
lines, undiarized_text, buffer_transcription, buffer_diarization = format_output(
|
||||
state,
|
||||
self.silence,
|
||||
current_time = time() - self.beg_loop if self.beg_loop else None,
|
||||
args = self.args,
|
||||
debug = self.debug,
|
||||
sep=self.sep
|
||||
self.tokens_alignment.update()
|
||||
lines, buffer_diarization_text, buffer_translation_text = self.tokens_alignment.get_lines(
|
||||
diarization=self.args.diarization,
|
||||
translation=bool(self.translation),
|
||||
current_silence=self.current_silence
|
||||
)
|
||||
# Handle undiarized text
|
||||
if undiarized_text:
|
||||
combined = self.sep.join(undiarized_text)
|
||||
if buffer_transcription:
|
||||
combined += self.sep
|
||||
await self.update_diarization(state.end_attributed_speaker, combined)
|
||||
buffer_diarization = combined
|
||||
state = await self.get_current_state()
|
||||
|
||||
buffer_transcription_text = state.buffer_transcription.text if state.buffer_transcription else ''
|
||||
|
||||
response_status = "active_transcription"
|
||||
if not state.tokens and not buffer_transcription and not buffer_diarization:
|
||||
if not lines and not buffer_transcription_text and not buffer_diarization_text:
|
||||
response_status = "no_audio_detected"
|
||||
lines = []
|
||||
elif response_status == "active_transcription" and not lines:
|
||||
lines = [Line(
|
||||
speaker=1,
|
||||
start=state.get("end_buffer", 0),
|
||||
end=state.get("end_buffer", 0)
|
||||
)]
|
||||
|
||||
response = FrontData(
|
||||
status=response_status,
|
||||
lines=lines,
|
||||
buffer_transcription=buffer_transcription,
|
||||
buffer_diarization=buffer_diarization,
|
||||
buffer_transcription=buffer_transcription_text,
|
||||
buffer_diarization=buffer_diarization_text,
|
||||
buffer_translation=buffer_translation_text,
|
||||
remaining_time_transcription=state.remaining_time_transcription,
|
||||
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
|
||||
)
|
||||
|
||||
should_push = (response != self.last_response_content)
|
||||
if should_push and (lines or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"):
|
||||
if should_push:
|
||||
yield response
|
||||
self.last_response_content = response
|
||||
|
||||
# Check for termination condition
|
||||
if self.is_stopping:
|
||||
all_processors_done = True
|
||||
if self.args.transcription and self.transcription_task and not self.transcription_task.done():
|
||||
all_processors_done = False
|
||||
if self.args.diarization and self.diarization_task and not self.diarization_task.done():
|
||||
all_processors_done = False
|
||||
|
||||
if all_processors_done:
|
||||
if self.is_stopping and self._processing_tasks_done():
|
||||
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
|
||||
return
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in results_formatter: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
async def create_tasks(self):
|
||||
async def create_tasks(self) -> AsyncGenerator[FrontData, None]:
|
||||
"""Create and start processing tasks."""
|
||||
self.all_tasks_for_cleanup = []
|
||||
processing_tasks_for_watchdog = []
|
||||
processing_tasks_for_watchdog: List[asyncio.Task] = []
|
||||
|
||||
if self.args.transcription and self.online:
|
||||
# If using FFmpeg (non-PCM input), start it and spawn stdout reader
|
||||
if not self.is_pcm_input:
|
||||
success = await self.ffmpeg_manager.start()
|
||||
if not success:
|
||||
logger.error("Failed to start FFmpeg manager")
|
||||
async def error_generator() -> AsyncGenerator[FrontData, None]:
|
||||
yield FrontData(
|
||||
status="error",
|
||||
error="FFmpeg failed to start. Please check that FFmpeg is installed."
|
||||
)
|
||||
return error_generator()
|
||||
self.ffmpeg_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
|
||||
self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task)
|
||||
processing_tasks_for_watchdog.append(self.ffmpeg_reader_task)
|
||||
|
||||
if self.transcription:
|
||||
self.transcription_task = asyncio.create_task(self.transcription_processor())
|
||||
self.all_tasks_for_cleanup.append(self.transcription_task)
|
||||
processing_tasks_for_watchdog.append(self.transcription_task)
|
||||
|
||||
if self.args.diarization and self.diarization:
|
||||
self.diarization_task = asyncio.create_task(self.diarization_processor(self.diarization))
|
||||
if self.diarization:
|
||||
self.diarization_task = asyncio.create_task(self.diarization_processor())
|
||||
self.all_tasks_for_cleanup.append(self.diarization_task)
|
||||
processing_tasks_for_watchdog.append(self.diarization_task)
|
||||
|
||||
if self.args.target_language and self.args.lan != 'auto':
|
||||
self.translation_task = asyncio.create_task(self.translation_processor(self.online_translation))
|
||||
if self.translation:
|
||||
self.translation_task = asyncio.create_task(self.translation_processor())
|
||||
self.all_tasks_for_cleanup.append(self.translation_task)
|
||||
processing_tasks_for_watchdog.append(self.translation_task)
|
||||
|
||||
@@ -433,13 +475,18 @@ class AudioProcessor:
|
||||
|
||||
return self.results_formatter()
|
||||
|
||||
async def watchdog(self, tasks_to_monitor):
|
||||
async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None:
|
||||
"""Monitors the health of critical processing tasks."""
|
||||
tasks_remaining: List[asyncio.Task] = [task for task in tasks_to_monitor if task]
|
||||
while True:
|
||||
try:
|
||||
if not tasks_remaining:
|
||||
logger.info("Watchdog task finishing: all monitored tasks completed.")
|
||||
return
|
||||
|
||||
await asyncio.sleep(10)
|
||||
|
||||
for i, task in enumerate(tasks_to_monitor):
|
||||
for i, task in enumerate(list(tasks_remaining)):
|
||||
if task.done():
|
||||
exc = task.exception()
|
||||
task_name = task.get_name() if hasattr(task, 'get_name') else f"Monitored Task {i}"
|
||||
@@ -447,6 +494,7 @@ class AudioProcessor:
|
||||
logger.error(f"{task_name} unexpectedly completed with exception: {exc}")
|
||||
else:
|
||||
logger.info(f"{task_name} completed normally.")
|
||||
tasks_remaining.remove(task)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Watchdog task cancelled.")
|
||||
@@ -454,7 +502,7 @@ class AudioProcessor:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in watchdog task: {e}", exc_info=True)
|
||||
|
||||
async def cleanup(self):
|
||||
async def cleanup(self) -> None:
|
||||
"""Clean up resources when processing is complete."""
|
||||
logger.info("Starting cleanup of AudioProcessor resources.")
|
||||
self.is_stopping = True
|
||||
@@ -466,16 +514,35 @@ class AudioProcessor:
|
||||
if created_tasks:
|
||||
await asyncio.gather(*created_tasks, return_exceptions=True)
|
||||
logger.info("All processing tasks cancelled or finished.")
|
||||
if self.args.diarization and hasattr(self, 'diarization') and hasattr(self.diarization, 'close'):
|
||||
|
||||
if not self.is_pcm_input and self.ffmpeg_manager:
|
||||
try:
|
||||
await self.ffmpeg_manager.stop()
|
||||
logger.info("FFmpeg manager stopped.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error stopping FFmpeg manager: {e}")
|
||||
if self.diarization:
|
||||
self.diarization.close()
|
||||
logger.info("AudioProcessor cleanup complete.")
|
||||
|
||||
def _processing_tasks_done(self) -> bool:
|
||||
"""Return True when all active processing tasks have completed."""
|
||||
tasks_to_check = [
|
||||
self.transcription_task,
|
||||
self.diarization_task,
|
||||
self.translation_task,
|
||||
self.ffmpeg_reader_task,
|
||||
]
|
||||
return all(task.done() for task in tasks_to_check if task)
|
||||
|
||||
async def process_audio(self, message):
|
||||
|
||||
async def process_audio(self, message: Optional[bytes]) -> None:
|
||||
"""Process incoming audio data."""
|
||||
|
||||
if not self.beg_loop:
|
||||
self.beg_loop = time()
|
||||
self.current_silence = Silence(start=0.0, is_starting=True)
|
||||
self.tokens_alignment.beg_loop = self.beg_loop
|
||||
|
||||
if not message:
|
||||
logger.info("Empty audio message received, initiating stop sequence.")
|
||||
@@ -484,6 +551,9 @@ class AudioProcessor:
|
||||
if self.transcription_queue:
|
||||
await self.transcription_queue.put(SENTINEL)
|
||||
|
||||
if not self.is_pcm_input and self.ffmpeg_manager:
|
||||
await self.ffmpeg_manager.stop()
|
||||
|
||||
return
|
||||
|
||||
if self.is_stopping:
|
||||
@@ -493,8 +563,19 @@ class AudioProcessor:
|
||||
if self.is_pcm_input:
|
||||
self.pcm_buffer.extend(message)
|
||||
await self.handle_pcm_data()
|
||||
else:
|
||||
if not self.ffmpeg_manager:
|
||||
logger.error("FFmpeg manager not initialized for non-PCM input.")
|
||||
return
|
||||
success = await self.ffmpeg_manager.write_data(message)
|
||||
if not success:
|
||||
ffmpeg_state = await self.ffmpeg_manager.get_state()
|
||||
if ffmpeg_state == FFmpegState.FAILED:
|
||||
logger.error("FFmpeg is in FAILED state, cannot process audio")
|
||||
else:
|
||||
logger.warning("Failed to write audio data to FFmpeg")
|
||||
|
||||
async def handle_pcm_data(self):
|
||||
async def handle_pcm_data(self) -> None:
|
||||
# Process when enough data
|
||||
if len(self.pcm_buffer) < self.bytes_per_sec:
|
||||
return
|
||||
@@ -505,42 +586,38 @@ class AudioProcessor:
|
||||
f"Consider using a smaller model."
|
||||
)
|
||||
|
||||
# Process audio chunk
|
||||
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
|
||||
self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
|
||||
chunk_size = min(len(self.pcm_buffer), self.max_bytes_per_sec)
|
||||
aligned_chunk_size = (chunk_size // self.bytes_per_sample) * self.bytes_per_sample
|
||||
|
||||
if aligned_chunk_size == 0:
|
||||
return
|
||||
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size])
|
||||
self.pcm_buffer = self.pcm_buffer[aligned_chunk_size:]
|
||||
|
||||
num_samples = len(pcm_array)
|
||||
chunk_sample_start = self.total_pcm_samples
|
||||
chunk_sample_end = chunk_sample_start + num_samples
|
||||
|
||||
res = None
|
||||
end_of_audio = False
|
||||
silence_buffer = None
|
||||
|
||||
if self.args.vac:
|
||||
res = self.vac(pcm_array)
|
||||
|
||||
if res is not None:
|
||||
if res.get("end", 0) > res.get("start", 0):
|
||||
end_of_audio = True
|
||||
elif self.silence: #end of silence
|
||||
self.silence = False
|
||||
silence_buffer = Silence(duration=time() - self.start_silence)
|
||||
silence_detected = res.get("end", 0) > res.get("start", 0)
|
||||
if silence_detected and not self.current_silence:
|
||||
pre_silence_chunk = self._slice_before_silence(
|
||||
pcm_array, chunk_sample_start, res.get("end")
|
||||
)
|
||||
if pre_silence_chunk is not None and pre_silence_chunk.size > 0:
|
||||
await self._enqueue_active_audio(pre_silence_chunk)
|
||||
await self._begin_silence()
|
||||
elif self.current_silence:
|
||||
await self._end_silence()
|
||||
|
||||
if silence_buffer:
|
||||
if self.args.transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(silence_buffer)
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(silence_buffer)
|
||||
if not self.current_silence:
|
||||
await self._enqueue_active_audio(pcm_array)
|
||||
|
||||
if not self.silence:
|
||||
if self.args.transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(pcm_array.copy())
|
||||
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(pcm_array.copy())
|
||||
|
||||
self.silence_duration = 0.0
|
||||
|
||||
if end_of_audio:
|
||||
self.silence = True
|
||||
self.start_silence = time()
|
||||
self.total_pcm_samples = chunk_sample_end
|
||||
|
||||
if not self.args.transcription and not self.args.diarization:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
41
whisperlivekit/backend_support.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import importlib.util
|
||||
import logging
|
||||
import platform
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def module_available(module_name):
|
||||
"""Return True if the given module can be imported."""
|
||||
return importlib.util.find_spec(module_name) is not None
|
||||
|
||||
|
||||
def mlx_backend_available(warn_on_missing = False):
|
||||
is_macos = platform.system() == "Darwin"
|
||||
is_arm = platform.machine() == "arm64"
|
||||
available = (
|
||||
is_macos
|
||||
and is_arm
|
||||
and module_available("mlx_whisper")
|
||||
)
|
||||
if not available and warn_on_missing and is_macos and is_arm:
|
||||
logger.warning(
|
||||
"=" * 50
|
||||
+ "\nMLX Whisper not found but you are on Apple Silicon. "
|
||||
"Consider installing mlx-whisper for better performance: "
|
||||
"`pip install mlx-whisper`\n"
|
||||
+ "=" * 50
|
||||
)
|
||||
return available
|
||||
|
||||
|
||||
def faster_backend_available(warn_on_missing = False):
|
||||
available = module_available("faster_whisper")
|
||||
if not available and warn_on_missing and platform.system() != "Darwin":
|
||||
logger.warning(
|
||||
"=" * 50
|
||||
+ "\nFaster-Whisper not found. Consider installing faster-whisper "
|
||||
"for better performance: `pip install faster-whisper`\n"
|
||||
+ "=" * 50
|
||||
)
|
||||
return available
|
||||
@@ -5,9 +5,6 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args
|
||||
import asyncio
|
||||
import logging
|
||||
from starlette.staticfiles import StaticFiles
|
||||
import pathlib
|
||||
import whisperlivekit.web as webpkg
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
@@ -19,15 +16,6 @@ transcription_engine = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
|
||||
#to remove after 0.2.8
|
||||
if args.backend == "simulstreaming" and not args.disable_fast_encoder:
|
||||
logger.warning(f"""
|
||||
{'='*50}
|
||||
WhisperLiveKit 0.2.8 has introduced a new fast encoder feature using MLX Whisper or Faster Whisper for improved speed. Use --disable-fast-encoder to disable if you encounter issues.
|
||||
{'='*50}
|
||||
""")
|
||||
|
||||
global transcription_engine
|
||||
transcription_engine = TranscriptionEngine(
|
||||
**vars(args),
|
||||
@@ -42,8 +30,6 @@ app.add_middleware(
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
web_dir = pathlib.Path(webpkg.__file__).parent
|
||||
app.mount("/web", StaticFiles(directory=str(web_dir)), name="web")
|
||||
|
||||
@app.get("/")
|
||||
async def get():
|
||||
@@ -73,6 +59,11 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
logger.info("WebSocket connection opened.")
|
||||
|
||||
try:
|
||||
await websocket.send_json({"type": "config", "useAudioWorklet": bool(args.pcm_input)})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send config to client: {e}")
|
||||
|
||||
results_generator = await audio_processor.create_tasks()
|
||||
websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
|
||||
|
||||
@@ -127,6 +118,8 @@ def main():
|
||||
|
||||
if ssl_kwargs:
|
||||
uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs}
|
||||
if args.forwarded_allow_ips:
|
||||
uvicorn_kwargs = { **uvicorn_kwargs, "forwarded_allow_ips" : args.forwarded_allow_ips }
|
||||
|
||||
uvicorn.run(**uvicorn_kwargs)
|
||||
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
try:
|
||||
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory
|
||||
from whisperlivekit.whisper_streaming_custom.online_asr import OnlineASRProcessor
|
||||
except ImportError:
|
||||
from .whisper_streaming_custom.whisper_online import backend_factory
|
||||
from .whisper_streaming_custom.online_asr import OnlineASRProcessor
|
||||
from whisperlivekit.warmup import warmup_asr
|
||||
from whisperlivekit.local_agreement.whisper_online import backend_factory
|
||||
from whisperlivekit.simul_whisper import SimulStreamingASR
|
||||
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
|
||||
from argparse import Namespace
|
||||
import sys
|
||||
import logging
|
||||
|
||||
def update_with_kwargs(_dict, kwargs):
|
||||
_dict.update({
|
||||
k: v for k, v in kwargs.items() if k in _dict
|
||||
})
|
||||
return _dict
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TranscriptionEngine:
|
||||
_instance = None
|
||||
@@ -21,34 +27,69 @@ class TranscriptionEngine:
|
||||
if TranscriptionEngine._initialized:
|
||||
return
|
||||
|
||||
defaults = {
|
||||
global_params = {
|
||||
"host": "localhost",
|
||||
"port": 8000,
|
||||
"warmup_file": None,
|
||||
"diarization": False,
|
||||
"punctuation_split": False,
|
||||
"min_chunk_size": 0.5,
|
||||
"model": "tiny",
|
||||
"model_cache_dir": None,
|
||||
"model_dir": None,
|
||||
"lan": "auto",
|
||||
"task": "transcribe",
|
||||
"target_language": "",
|
||||
"backend": "faster-whisper",
|
||||
"vac": True,
|
||||
"vac_onnx": False,
|
||||
"vac_chunk_size": 0.04,
|
||||
"log_level": "DEBUG",
|
||||
"ssl_certfile": None,
|
||||
"ssl_keyfile": None,
|
||||
"forwarded_allow_ips": None,
|
||||
"transcription": True,
|
||||
"vad": True,
|
||||
"pcm_input": False,
|
||||
# whisperstreaming params:
|
||||
"buffer_trimming": "segment",
|
||||
"confidence_validation": False,
|
||||
"buffer_trimming_sec": 15,
|
||||
# simulstreaming params:
|
||||
"disable_punctuation_split" : False,
|
||||
"diarization_backend": "sortformer",
|
||||
"backend_policy": "simulstreaming",
|
||||
"backend": "auto",
|
||||
}
|
||||
global_params = update_with_kwargs(global_params, kwargs)
|
||||
|
||||
transcription_common_params = {
|
||||
"warmup_file": None,
|
||||
"min_chunk_size": 0.1,
|
||||
"model_size": "base",
|
||||
"model_cache_dir": None,
|
||||
"model_dir": None,
|
||||
"model_path": None,
|
||||
"lan": "auto",
|
||||
"direct_english_translation": False,
|
||||
}
|
||||
transcription_common_params = update_with_kwargs(transcription_common_params, kwargs)
|
||||
|
||||
if transcription_common_params['model_size'].endswith(".en"):
|
||||
transcription_common_params["lan"] = "en"
|
||||
if 'no_transcription' in kwargs:
|
||||
global_params['transcription'] = not global_params['no_transcription']
|
||||
if 'no_vad' in kwargs:
|
||||
global_params['vad'] = not kwargs['no_vad']
|
||||
if 'no_vac' in kwargs:
|
||||
global_params['vac'] = not kwargs['no_vac']
|
||||
|
||||
self.args = Namespace(**{**global_params, **transcription_common_params})
|
||||
|
||||
self.asr = None
|
||||
self.tokenizer = None
|
||||
self.diarization = None
|
||||
self.vac_model = None
|
||||
|
||||
if self.args.vac:
|
||||
from whisperlivekit.silero_vad_iterator import load_silero_vad
|
||||
# Use ONNX if specified, otherwise use JIT (default)
|
||||
use_onnx = kwargs.get('vac_onnx', False)
|
||||
self.vac_model = load_silero_vad(onnx=use_onnx)
|
||||
|
||||
backend_policy = self.args.backend_policy
|
||||
if self.args.transcription:
|
||||
if backend_policy == "simulstreaming":
|
||||
simulstreaming_params = {
|
||||
"disable_fast_encoder": False,
|
||||
"custom_alignment_heads": None,
|
||||
"frame_threshold": 25,
|
||||
"beams": 1,
|
||||
"decoder_type": None,
|
||||
@@ -59,110 +100,79 @@ class TranscriptionEngine:
|
||||
"init_prompt": None,
|
||||
"static_init_prompt": None,
|
||||
"max_context_tokens": None,
|
||||
"model_path": './base.pt',
|
||||
"diarization_backend": "sortformer",
|
||||
# diarization params:
|
||||
"disable_punctuation_split" : False,
|
||||
"segmentation_model": "pyannote/segmentation-3.0",
|
||||
"embedding_model": "pyannote/embedding",
|
||||
"preload_model_count": 1,
|
||||
}
|
||||
simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
|
||||
|
||||
config_dict = {**defaults, **kwargs}
|
||||
|
||||
if 'no_transcription' in kwargs:
|
||||
config_dict['transcription'] = not kwargs['no_transcription']
|
||||
if 'no_vad' in kwargs:
|
||||
config_dict['vad'] = not kwargs['no_vad']
|
||||
if 'no_vac' in kwargs:
|
||||
config_dict['vac'] = not kwargs['no_vac']
|
||||
|
||||
config_dict.pop('no_transcription', None)
|
||||
config_dict.pop('no_vad', None)
|
||||
|
||||
if 'language' in kwargs:
|
||||
config_dict['lan'] = kwargs['language']
|
||||
config_dict.pop('language', None)
|
||||
|
||||
self.args = Namespace(**config_dict)
|
||||
|
||||
self.asr = None
|
||||
self.tokenizer = None
|
||||
self.diarization = None
|
||||
self.vac_model = None
|
||||
|
||||
if self.args.vac:
|
||||
import torch
|
||||
self.vac_model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
|
||||
|
||||
if self.args.transcription:
|
||||
if self.args.backend == "simulstreaming":
|
||||
from whisperlivekit.simul_whisper import SimulStreamingASR
|
||||
self.tokenizer = None
|
||||
simulstreaming_kwargs = {}
|
||||
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
|
||||
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
|
||||
'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count', 'disable_fast_encoder']:
|
||||
if hasattr(self.args, attr):
|
||||
simulstreaming_kwargs[attr] = getattr(self.args, attr)
|
||||
|
||||
# Add segment_length from min_chunk_size
|
||||
simulstreaming_kwargs['segment_length'] = getattr(self.args, 'min_chunk_size', 0.5)
|
||||
simulstreaming_kwargs['task'] = self.args.task
|
||||
|
||||
size = self.args.model
|
||||
self.asr = SimulStreamingASR(
|
||||
modelsize=size,
|
||||
lan=self.args.lan,
|
||||
cache_dir=getattr(self.args, 'model_cache_dir', None),
|
||||
model_dir=getattr(self.args, 'model_dir', None),
|
||||
**simulstreaming_kwargs
|
||||
**transcription_common_params,
|
||||
**simulstreaming_params,
|
||||
backend=self.args.backend,
|
||||
)
|
||||
logger.info(
|
||||
"Using SimulStreaming policy with %s backend",
|
||||
getattr(self.asr, "encoder_backend", "whisper"),
|
||||
)
|
||||
|
||||
else:
|
||||
self.asr, self.tokenizer = backend_factory(self.args)
|
||||
warmup_asr(self.asr, self.args.warmup_file) #for simulstreaming, warmup should be done in the online class not here
|
||||
|
||||
whisperstreaming_params = {
|
||||
"buffer_trimming": "segment",
|
||||
"confidence_validation": False,
|
||||
"buffer_trimming_sec": 15,
|
||||
}
|
||||
whisperstreaming_params = update_with_kwargs(whisperstreaming_params, kwargs)
|
||||
|
||||
self.asr = backend_factory(
|
||||
backend=self.args.backend,
|
||||
**transcription_common_params,
|
||||
**whisperstreaming_params,
|
||||
)
|
||||
logger.info(
|
||||
"Using LocalAgreement policy with %s backend",
|
||||
getattr(self.asr, "backend_choice", self.asr.__class__.__name__),
|
||||
)
|
||||
|
||||
if self.args.diarization:
|
||||
if self.args.diarization_backend == "diart":
|
||||
from whisperlivekit.diarization.diart_backend import DiartDiarization
|
||||
diart_params = {
|
||||
"segmentation_model": "pyannote/segmentation-3.0",
|
||||
"embedding_model": "pyannote/embedding",
|
||||
}
|
||||
diart_params = update_with_kwargs(diart_params, kwargs)
|
||||
self.diarization_model = DiartDiarization(
|
||||
block_duration=self.args.min_chunk_size,
|
||||
segmentation_model_name=self.args.segmentation_model,
|
||||
embedding_model_name=self.args.embedding_model
|
||||
**diart_params
|
||||
)
|
||||
elif self.args.diarization_backend == "sortformer":
|
||||
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
|
||||
self.diarization_model = SortformerDiarization()
|
||||
else:
|
||||
raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}")
|
||||
|
||||
self.translation_model = None
|
||||
if self.args.target_language:
|
||||
if self.args.lan == 'auto':
|
||||
raise Exception('Translation cannot be set with language auto')
|
||||
if self.args.lan == 'auto' and backend_policy != "simulstreaming":
|
||||
raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming')
|
||||
else:
|
||||
from whisperlivekit.translation.translation import load_model
|
||||
self.translation_model = load_model([self.args.lan]) #in the future we want to handle different languages for different speakers
|
||||
|
||||
try:
|
||||
from nllw import load_model
|
||||
except:
|
||||
raise Exception('To use translation, you must install nllw: `pip install nllw`')
|
||||
translation_params = {
|
||||
"nllb_backend": "transformers",
|
||||
"nllb_size": "600M"
|
||||
}
|
||||
translation_params = update_with_kwargs(translation_params, kwargs)
|
||||
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
|
||||
TranscriptionEngine._initialized = True
|
||||
|
||||
|
||||
|
||||
def online_factory(args, asr, tokenizer, logfile=sys.stderr):
|
||||
if args.backend == "simulstreaming":
|
||||
def online_factory(args, asr):
|
||||
if args.backend_policy == "simulstreaming":
|
||||
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
||||
online = SimulStreamingOnlineProcessor(
|
||||
asr,
|
||||
logfile=logfile,
|
||||
)
|
||||
online = SimulStreamingOnlineProcessor(asr)
|
||||
else:
|
||||
online = OnlineASRProcessor(
|
||||
asr,
|
||||
tokenizer,
|
||||
logfile=logfile,
|
||||
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
|
||||
confidence_validation = args.confidence_validation
|
||||
)
|
||||
online = OnlineASRProcessor(asr)
|
||||
return online
|
||||
|
||||
|
||||
@@ -181,5 +191,5 @@ def online_translation_factory(args, translation_model):
|
||||
#should be at speaker level in the future:
|
||||
#one shared nllb model for all speaker
|
||||
#one tokenizer per speaker/language
|
||||
from whisperlivekit.translation.translation import OnlineTranslation
|
||||
from nllw import OnlineTranslation
|
||||
return OnlineTranslation(translation_model, [args.lan], [args.target_language])
|
||||
@@ -26,7 +26,7 @@ class DiarizationObserver(Observer):
|
||||
"""Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""
|
||||
|
||||
def __init__(self):
|
||||
self.speaker_segments = []
|
||||
self.diarization_segments = []
|
||||
self.processed_time = 0
|
||||
self.segment_lock = threading.Lock()
|
||||
self.global_time_offset = 0.0
|
||||
@@ -48,7 +48,7 @@ class DiarizationObserver(Observer):
|
||||
for speaker, label in annotation._labels.items():
|
||||
for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]):
|
||||
print(f" {speaker}: {start:.2f}s-{end:.2f}s")
|
||||
self.speaker_segments.append(SpeakerSegment(
|
||||
self.diarization_segments.append(SpeakerSegment(
|
||||
speaker=speaker,
|
||||
start=start + self.global_time_offset,
|
||||
end=end + self.global_time_offset
|
||||
@@ -59,14 +59,14 @@ class DiarizationObserver(Observer):
|
||||
def get_segments(self) -> List[SpeakerSegment]:
|
||||
"""Get a copy of the current speaker segments."""
|
||||
with self.segment_lock:
|
||||
return self.speaker_segments.copy()
|
||||
return self.diarization_segments.copy()
|
||||
|
||||
def clear_old_segments(self, older_than: float = 30.0):
|
||||
"""Clear segments older than the specified time."""
|
||||
with self.segment_lock:
|
||||
current_time = self.processed_time
|
||||
self.speaker_segments = [
|
||||
segment for segment in self.speaker_segments
|
||||
self.diarization_segments = [
|
||||
segment for segment in self.diarization_segments
|
||||
if current_time - segment.end < older_than
|
||||
]
|
||||
|
||||
@@ -178,7 +178,6 @@ class DiartDiarization:
|
||||
|
||||
self.pipeline = SpeakerDiarization(config=config)
|
||||
self.observer = DiarizationObserver()
|
||||
self.lag_diart = None
|
||||
|
||||
if use_microphone:
|
||||
self.source = MicrophoneAudioSource(block_duration=block_duration)
|
||||
@@ -217,32 +216,6 @@ class DiartDiarization:
|
||||
if self.custom_source:
|
||||
self.custom_source.close()
|
||||
|
||||
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> float:
|
||||
"""
|
||||
Assign speakers to tokens based on timing overlap with speaker segments.
|
||||
Uses the segments collected by the observer.
|
||||
|
||||
If use_punctuation_split is True, uses punctuation marks to refine speaker boundaries.
|
||||
"""
|
||||
segments = self.observer.get_segments()
|
||||
|
||||
# Debug logging
|
||||
logger.debug(f"assign_speakers_to_tokens called with {len(tokens)} tokens")
|
||||
logger.debug(f"Available segments: {len(segments)}")
|
||||
for i, seg in enumerate(segments[:5]): # Show first 5 segments
|
||||
logger.debug(f" Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]")
|
||||
|
||||
if not self.lag_diart and segments and tokens:
|
||||
self.lag_diart = segments[0].start - tokens[0].start
|
||||
|
||||
if not use_punctuation_split:
|
||||
for token in tokens:
|
||||
for segment in segments:
|
||||
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
|
||||
token.speaker = extract_number(segment.speaker) + 1
|
||||
else:
|
||||
tokens = add_speaker_to_tokens(segments, tokens)
|
||||
return tokens
|
||||
|
||||
def concatenate_speakers(segments):
|
||||
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
|
||||
|
||||
@@ -94,11 +94,11 @@ class SortformerDiarizationOnline:
|
||||
model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2")
|
||||
"""
|
||||
self.sample_rate = sample_rate
|
||||
self.speaker_segments = []
|
||||
self.diarization_segments = []
|
||||
self.diar_segments = []
|
||||
self.buffer_audio = np.array([], dtype=np.float32)
|
||||
self.segment_lock = threading.Lock()
|
||||
self.global_time_offset = 0.0
|
||||
self.processed_time = 0.0
|
||||
self.debug = False
|
||||
|
||||
self.diar_model = shared_model.diar_model
|
||||
@@ -156,11 +156,9 @@ class SortformerDiarizationOnline:
|
||||
self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device)
|
||||
self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
|
||||
# Initialize total predictions tensor
|
||||
self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)
|
||||
|
||||
def insert_silence(self, silence_duration: float):
|
||||
def insert_silence(self, silence_duration: Optional[float]):
|
||||
"""
|
||||
Insert silence period by adjusting the global time offset.
|
||||
|
||||
@@ -171,22 +169,24 @@ class SortformerDiarizationOnline:
|
||||
self.global_time_offset += silence_duration
|
||||
logger.debug(f"Inserted silence of {silence_duration:.2f}s, new offset: {self.global_time_offset:.2f}s")
|
||||
|
||||
async def diarize(self, pcm_array: np.ndarray):
|
||||
def insert_audio_chunk(self, pcm_array: np.ndarray):
|
||||
if self.debug:
|
||||
self.audio_buffer.append(pcm_array.copy())
|
||||
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
|
||||
|
||||
|
||||
async def diarize(self):
|
||||
"""
|
||||
Process audio data for diarization in streaming fashion.
|
||||
|
||||
Args:
|
||||
pcm_array: Audio data as numpy array
|
||||
"""
|
||||
try:
|
||||
if self.debug:
|
||||
self.audio_buffer.append(pcm_array.copy())
|
||||
|
||||
threshold = int(self.chunk_duration_seconds * self.sample_rate)
|
||||
|
||||
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
|
||||
if not len(self.buffer_audio) >= threshold:
|
||||
return
|
||||
return []
|
||||
|
||||
audio = self.buffer_audio[:threshold]
|
||||
self.buffer_audio = self.buffer_audio[threshold:]
|
||||
@@ -223,195 +223,57 @@ class SortformerDiarizationOnline:
|
||||
left_offset=left_offset,
|
||||
right_offset=right_offset,
|
||||
)
|
||||
|
||||
# Convert predictions to speaker segments
|
||||
self._process_predictions()
|
||||
new_segments = self._process_predictions()
|
||||
|
||||
self._chunk_index += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in diarize: {e}")
|
||||
raise
|
||||
|
||||
# TODO: Handle case when stream ends with partial buffer (accumulated_duration > 0 but < chunk_duration_seconds)
|
||||
return new_segments
|
||||
|
||||
def _process_predictions(self):
|
||||
"""Process model predictions and convert to speaker segments."""
|
||||
try:
|
||||
preds_np = self.total_preds[0].cpu().numpy()
|
||||
active_speakers = np.argmax(preds_np, axis=1)
|
||||
|
||||
if self._len_prediction is None:
|
||||
self._len_prediction = len(active_speakers)
|
||||
self._len_prediction = len(active_speakers) #12
|
||||
|
||||
# Get predictions for current chunk
|
||||
frame_duration = self.chunk_duration_seconds / self._len_prediction
|
||||
current_chunk_preds = active_speakers[-self._len_prediction:]
|
||||
|
||||
new_segments = []
|
||||
|
||||
with self.segment_lock:
|
||||
# Process predictions into segments
|
||||
base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset
|
||||
|
||||
current_spk = current_chunk_preds[0]
|
||||
start_time = round(base_time, 2)
|
||||
for idx, spk in enumerate(current_chunk_preds):
|
||||
start_time = base_time + idx * frame_duration
|
||||
end_time = base_time + (idx + 1) * frame_duration
|
||||
|
||||
# Check if this continues the last segment or starts a new one
|
||||
if (self.speaker_segments and
|
||||
self.speaker_segments[-1].speaker == spk and
|
||||
abs(self.speaker_segments[-1].end - start_time) < frame_duration * 0.5):
|
||||
# Continue existing segment
|
||||
self.speaker_segments[-1].end = end_time
|
||||
else:
|
||||
|
||||
# Create new segment
|
||||
self.speaker_segments.append(SpeakerSegment(
|
||||
speaker=spk,
|
||||
current_time = round(base_time + idx * frame_duration, 2)
|
||||
if spk != current_spk:
|
||||
new_segments.append(SpeakerSegment(
|
||||
speaker=current_spk,
|
||||
start=start_time,
|
||||
end=end_time
|
||||
end=current_time
|
||||
))
|
||||
|
||||
# Update processed time
|
||||
self.processed_time = max(self.processed_time, base_time + self.chunk_duration_seconds)
|
||||
|
||||
logger.debug(f"Processed chunk {self._chunk_index}, total segments: {len(self.speaker_segments)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing predictions: {e}")
|
||||
|
||||
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> list:
|
||||
"""
|
||||
Assign speakers to tokens based on timing overlap with speaker segments.
|
||||
|
||||
Args:
|
||||
tokens: List of tokens with timing information
|
||||
use_punctuation_split: Whether to use punctuation for boundary refinement
|
||||
|
||||
Returns:
|
||||
List of tokens with speaker assignments
|
||||
"""
|
||||
with self.segment_lock:
|
||||
segments = self.speaker_segments.copy()
|
||||
|
||||
if not segments or not tokens:
|
||||
logger.debug("No segments or tokens available for speaker assignment")
|
||||
return tokens
|
||||
|
||||
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
|
||||
use_punctuation_split = False
|
||||
if not use_punctuation_split:
|
||||
# Simple overlap-based assignment
|
||||
for token in tokens:
|
||||
token.speaker = -1 # Default to no speaker
|
||||
for segment in segments:
|
||||
# Check for timing overlap
|
||||
if not (segment.end <= token.start or segment.start >= token.end):
|
||||
token.speaker = segment.speaker + 1 # Convert to 1-based indexing
|
||||
break
|
||||
else:
|
||||
# Use punctuation-aware assignment (similar to diart_backend)
|
||||
tokens = self._add_speaker_to_tokens_with_punctuation(segments, tokens)
|
||||
|
||||
return tokens
|
||||
|
||||
def _add_speaker_to_tokens_with_punctuation(self, segments: List[SpeakerSegment], tokens: list) -> list:
|
||||
"""
|
||||
Assign speakers to tokens with punctuation-aware boundary adjustment.
|
||||
|
||||
Args:
|
||||
segments: List of speaker segments
|
||||
tokens: List of tokens to assign speakers to
|
||||
|
||||
Returns:
|
||||
List of tokens with speaker assignments
|
||||
"""
|
||||
punctuation_marks = {'.', '!', '?'}
|
||||
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
|
||||
|
||||
# Convert segments to concatenated format
|
||||
segments_concatenated = self._concatenate_speakers(segments)
|
||||
|
||||
# Adjust segment boundaries based on punctuation
|
||||
for ind, segment in enumerate(segments_concatenated):
|
||||
for i, punctuation_token in enumerate(punctuation_tokens):
|
||||
if punctuation_token.start > segment['end']:
|
||||
after_length = punctuation_token.start - segment['end']
|
||||
before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
|
||||
|
||||
if before_length > after_length:
|
||||
segment['end'] = punctuation_token.start
|
||||
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
|
||||
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
|
||||
else:
|
||||
segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
|
||||
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
|
||||
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
|
||||
break
|
||||
|
||||
# Ensure non-overlapping tokens
|
||||
last_end = 0.0
|
||||
for token in tokens:
|
||||
start = max(last_end + 0.01, token.start)
|
||||
token.start = start
|
||||
token.end = max(start, token.end)
|
||||
last_end = token.end
|
||||
|
||||
# Assign speakers based on adjusted segments
|
||||
ind_last_speaker = 0
|
||||
for segment in segments_concatenated:
|
||||
for i, token in enumerate(tokens[ind_last_speaker:]):
|
||||
if token.end <= segment['end']:
|
||||
token.speaker = segment['speaker']
|
||||
ind_last_speaker = i + 1
|
||||
elif token.start > segment['end']:
|
||||
break
|
||||
|
||||
return tokens
|
||||
|
||||
def _concatenate_speakers(self, segments: List[SpeakerSegment]) -> List[dict]:
|
||||
"""
|
||||
Concatenate consecutive segments from the same speaker.
|
||||
|
||||
Args:
|
||||
segments: List of speaker segments
|
||||
|
||||
Returns:
|
||||
List of concatenated speaker segments
|
||||
"""
|
||||
if not segments:
|
||||
return []
|
||||
|
||||
segments_concatenated = [{"speaker": segments[0].speaker + 1, "begin": segments[0].start, "end": segments[0].end}]
|
||||
|
||||
for segment in segments[1:]:
|
||||
speaker = segment.speaker + 1
|
||||
if segments_concatenated[-1]['speaker'] != speaker:
|
||||
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
|
||||
else:
|
||||
segments_concatenated[-1]['end'] = segment.end
|
||||
|
||||
return segments_concatenated
|
||||
start_time = current_time
|
||||
current_spk = spk
|
||||
new_segments.append(
|
||||
SpeakerSegment(
|
||||
speaker=current_spk,
|
||||
start=start_time,
|
||||
end=current_time
|
||||
)
|
||||
)
|
||||
return new_segments
|
||||
|
||||
def get_segments(self) -> List[SpeakerSegment]:
|
||||
"""Get a copy of the current speaker segments."""
|
||||
with self.segment_lock:
|
||||
return self.speaker_segments.copy()
|
||||
|
||||
def clear_old_segments(self, older_than: float = 30.0):
|
||||
"""Clear segments older than the specified time."""
|
||||
with self.segment_lock:
|
||||
current_time = self.processed_time
|
||||
self.speaker_segments = [
|
||||
segment for segment in self.speaker_segments
|
||||
if current_time - segment.end < older_than
|
||||
]
|
||||
logger.debug(f"Cleared old segments, remaining: {len(self.speaker_segments)}")
|
||||
return self.diarization_segments.copy()
|
||||
|
||||
def close(self):
|
||||
"""Close the diarization system and clean up resources."""
|
||||
logger.info("Closing SortformerDiarization")
|
||||
with self.segment_lock:
|
||||
self.speaker_segments.clear()
|
||||
self.diarization_segments.clear()
|
||||
|
||||
if self.debug:
|
||||
concatenated_audio = np.concatenate(self.audio_buffer)
|
||||
@@ -437,7 +299,7 @@ if __name__ == '__main__':
|
||||
|
||||
async def main():
|
||||
"""TEST ONLY."""
|
||||
an4_audio = 'audio_test.mp3'
|
||||
an4_audio = 'diarization_audio.wav'
|
||||
signal, sr = librosa.load(an4_audio, sr=16000)
|
||||
signal = signal[:16000*30]
|
||||
|
||||
@@ -449,13 +311,15 @@ if __name__ == '__main__':
|
||||
print("Speaker 0: 0:25 - 0:30")
|
||||
print("=" * 50)
|
||||
|
||||
diarization = SortformerDiarization(sample_rate=16000)
|
||||
diarization_backend = SortformerDiarization()
|
||||
diarization = SortformerDiarizationOnline(shared_model = diarization_backend)
|
||||
chunk_size = 1600
|
||||
|
||||
for i in range(0, len(signal), chunk_size):
|
||||
chunk = signal[i:i+chunk_size]
|
||||
await diarization.diarize(chunk)
|
||||
new_segments = await diarization.diarize(chunk)
|
||||
print(f"Processed chunk {i // chunk_size + 1}")
|
||||
print(new_segments)
|
||||
|
||||
segments = diarization.get_segments()
|
||||
print("\nDiarization results:")
|
||||
|
||||
@@ -1,205 +0,0 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from nemo.collections.asr.models import SortformerEncLabelModel
|
||||
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
|
||||
import librosa
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def load_model():
|
||||
|
||||
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
|
||||
diar_model.eval()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
diar_model.to(torch.device("cuda"))
|
||||
|
||||
#we target 1 second lag for the moment. chunk_len could be reduced.
|
||||
diar_model.sortformer_modules.chunk_len = 10
|
||||
diar_model.sortformer_modules.subsampling_factor = 10 #8 would be better ideally
|
||||
|
||||
diar_model.sortformer_modules.chunk_right_context = 0 #no.
|
||||
diar_model.sortformer_modules.chunk_left_context = 10 #big so it compensiate the problem with no padding later.
|
||||
|
||||
diar_model.sortformer_modules.spkcache_len = 188
|
||||
diar_model.sortformer_modules.fifo_len = 188
|
||||
diar_model.sortformer_modules.spkcache_update_period = 144
|
||||
diar_model.sortformer_modules.log = False
|
||||
diar_model.sortformer_modules._check_streaming_parameters()
|
||||
|
||||
|
||||
audio2mel = AudioToMelSpectrogramPreprocessor(
|
||||
window_size= 0.025,
|
||||
normalize="NA",
|
||||
n_fft=512,
|
||||
features=128,
|
||||
pad_to=0) #pad_to 16 works better than 0. On test audio, we detect a third speaker for 1 second with pad_to=0. To solve that : increase left context to 10.
|
||||
|
||||
return diar_model, audio2mel
|
||||
|
||||
diar_model, audio2mel = load_model()
|
||||
|
||||
class StreamingSortformerState:
|
||||
"""
|
||||
This class creates a class instance that will be used to store the state of the
|
||||
streaming Sortformer model.
|
||||
|
||||
Attributes:
|
||||
spkcache (torch.Tensor): Speaker cache to store embeddings from start
|
||||
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
|
||||
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
|
||||
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
|
||||
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
|
||||
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
|
||||
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
|
||||
mean_sil_emb (torch.Tensor): Mean silence embedding
|
||||
n_sil_frames (torch.Tensor): Number of silence frames
|
||||
"""
|
||||
|
||||
spkcache = None # Speaker cache to store embeddings from start
|
||||
spkcache_lengths = None #
|
||||
spkcache_preds = None # speaker cache predictions
|
||||
fifo = None # to save the embedding from the latest chunks
|
||||
fifo_lengths = None
|
||||
fifo_preds = None
|
||||
spk_perm = None
|
||||
mean_sil_emb = None
|
||||
n_sil_frames = None
|
||||
|
||||
|
||||
def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None):
|
||||
"""
|
||||
Initializes StreamingSortformerState with empty tensors or zero-valued tensors.
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size for tensors in streaming state
|
||||
async_streaming (bool): True for asynchronous update, False for synchronous update
|
||||
device (torch.device): Device for tensors in streaming state
|
||||
|
||||
Returns:
|
||||
streaming_state (SortformerStreamingState): initialized streaming state
|
||||
"""
|
||||
streaming_state = StreamingSortformerState()
|
||||
if async_streaming:
|
||||
streaming_state.spkcache = torch.zeros((batch_size, self.spkcache_len, self.fc_d_model), device=device)
|
||||
streaming_state.spkcache_preds = torch.zeros((batch_size, self.spkcache_len, self.n_spk), device=device)
|
||||
streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
streaming_state.fifo = torch.zeros((batch_size, self.fifo_len, self.fc_d_model), device=device)
|
||||
streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
else:
|
||||
streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
|
||||
streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
|
||||
streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device)
|
||||
streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
return streaming_state
|
||||
|
||||
|
||||
def process_diarization(chunks):
|
||||
"""
|
||||
what it does:
|
||||
1. Preprocessing: Applies dithering and pre-emphasis (high-pass filter) if enabled
|
||||
2. STFT: Computes the Short-Time Fourier Transform using:
|
||||
- the window of window_size=0.025 --> size of a window : 400 samples
|
||||
- the hop parameter : n_window_stride = 0.01 -> every 160 samples, a new window
|
||||
3. Magnitude Calculation: Converts complex STFT output to magnitude spectrogram
|
||||
4. Mel Conversion: Applies Mel filterbanks (128 filters in this case) to get Mel spectrogram
|
||||
5. Logarithm: Takes the log of the Mel spectrogram (if `log=True`)
|
||||
6. Normalization: Skips normalization since `normalize="NA"`
|
||||
7. Padding: Pads the time dimension to a multiple of `pad_to` (default 16)
|
||||
"""
|
||||
previous_chunk = None
|
||||
l_chunk_feat_seq_t = []
|
||||
for chunk in chunks:
|
||||
audio_signal_chunk = torch.tensor(chunk).unsqueeze(0).to(diar_model.device)
|
||||
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(diar_model.device)
|
||||
processed_signal_chunk, processed_signal_length_chunk = audio2mel.get_features(audio_signal_chunk, audio_signal_length_chunk)
|
||||
if previous_chunk is not None:
|
||||
to_add = previous_chunk[:, :, -99:]
|
||||
total = torch.concat([to_add, processed_signal_chunk], dim=2)
|
||||
else:
|
||||
total = processed_signal_chunk
|
||||
previous_chunk = processed_signal_chunk
|
||||
l_chunk_feat_seq_t.append(torch.transpose(total, 1, 2))
|
||||
|
||||
batch_size = 1
|
||||
streaming_state = init_streaming_state(diar_model.sortformer_modules,
|
||||
batch_size = batch_size,
|
||||
async_streaming = True,
|
||||
device = diar_model.device
|
||||
)
|
||||
total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)
|
||||
|
||||
chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
|
||||
|
||||
l_speakers = [
|
||||
{'start_time': 0,
|
||||
'end_time': 0,
|
||||
'speaker': 0
|
||||
}
|
||||
]
|
||||
len_prediction = None
|
||||
left_offset = 0
|
||||
right_offset = 8
|
||||
for i, chunk_feat_seq_t in enumerate(l_chunk_feat_seq_t):
|
||||
with torch.inference_mode():
|
||||
streaming_state, total_preds = diar_model.forward_streaming_step(
|
||||
processed_signal=chunk_feat_seq_t,
|
||||
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
|
||||
streaming_state=streaming_state,
|
||||
total_preds=total_preds,
|
||||
left_offset=left_offset,
|
||||
right_offset=right_offset,
|
||||
)
|
||||
left_offset = 8
|
||||
preds_np = total_preds[0].cpu().numpy()
|
||||
active_speakers = np.argmax(preds_np, axis=1)
|
||||
if len_prediction is None:
|
||||
len_prediction = len(active_speakers) # we want to get the len of 1 prediction
|
||||
frame_duration = chunk_duration_seconds / len_prediction
|
||||
active_speakers = active_speakers[-len_prediction:]
|
||||
for idx, spk in enumerate(active_speakers):
|
||||
if spk != l_speakers[-1]['speaker']:
|
||||
l_speakers.append(
|
||||
{'start_time': (i * chunk_duration_seconds + idx * frame_duration),
|
||||
'end_time': (i * chunk_duration_seconds + (idx + 1) * frame_duration),
|
||||
'speaker': spk
|
||||
})
|
||||
else:
|
||||
l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration
|
||||
|
||||
|
||||
"""
|
||||
Should print
|
||||
[{'start_time': 0, 'end_time': 8.72, 'speaker': 0},
|
||||
{'start_time': 8.72, 'end_time': 18.88, 'speaker': 1},
|
||||
{'start_time': 18.88, 'end_time': 24.96, 'speaker': 2},
|
||||
{'start_time': 24.96, 'end_time': 31.68, 'speaker': 0}]
|
||||
"""
|
||||
for speaker in l_speakers:
|
||||
print(f"Speaker {speaker['speaker']}: {speaker['start_time']:.2f}s - {speaker['end_time']:.2f}s")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
an4_audio = 'audio_test.mp3'
|
||||
signal, sr = librosa.load(an4_audio, sr=16000)
|
||||
signal = signal[:16000*30]
|
||||
# signal = signal[:-(len(signal)%16000)]
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("Expected ground truth:")
|
||||
print("Speaker 0: 0:00 - 0:09")
|
||||
print("Speaker 1: 0:09 - 0:19")
|
||||
print("Speaker 2: 0:19 - 0:25")
|
||||
print("Speaker 0: 0:25 - 0:30")
|
||||
print("=" * 50)
|
||||
|
||||
chunk_size = 16000 # 1 second
|
||||
chunks = []
|
||||
for i in range(0, len(signal), chunk_size):
|
||||
chunk = signal[i:i+chunk_size]
|
||||
chunks.append(chunk)
|
||||
|
||||
process_diarization(chunks)
|
||||
197
whisperlivekit/ffmpeg_manager.py
Normal file
@@ -0,0 +1,197 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Optional, Callable
|
||||
import contextlib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
ERROR_INSTALL_INSTRUCTIONS = f"""
|
||||
{'='*50}
|
||||
FFmpeg is not installed or not found in your system's PATH.
|
||||
Alternative Solution: You can still use WhisperLiveKit without FFmpeg by adding the --pcm-input parameter. Note that when using this option, audio will not be compressed between the frontend and backend, which may result in higher bandwidth usage.
|
||||
|
||||
If you want to install FFmpeg:
|
||||
|
||||
# Ubuntu/Debian:
|
||||
sudo apt update && sudo apt install ffmpeg
|
||||
|
||||
# macOS (using Homebrew):
|
||||
brew install ffmpeg
|
||||
|
||||
# Windows:
|
||||
# 1. Download the latest static build from https://ffmpeg.org/download.html
|
||||
# 2. Extract the archive (e.g., to C:\\FFmpeg).
|
||||
# 3. Add the 'bin' directory (e.g., C:\\FFmpeg\\bin) to your system's PATH environment variable.
|
||||
|
||||
After installation, please restart the application.
|
||||
{'='*50}
|
||||
"""
|
||||
|
||||
class FFmpegState(Enum):
|
||||
STOPPED = "stopped"
|
||||
STARTING = "starting"
|
||||
RUNNING = "running"
|
||||
RESTARTING = "restarting"
|
||||
FAILED = "failed"
|
||||
|
||||
class FFmpegManager:
|
||||
def __init__(self, sample_rate: int = 16000, channels: int = 1):
|
||||
self.sample_rate = sample_rate
|
||||
self.channels = channels
|
||||
|
||||
self.process: Optional[asyncio.subprocess.Process] = None
|
||||
self._stderr_task: Optional[asyncio.Task] = None
|
||||
|
||||
self.on_error_callback: Optional[Callable[[str], None]] = None
|
||||
|
||||
self.state = FFmpegState.STOPPED
|
||||
self._state_lock = asyncio.Lock()
|
||||
|
||||
async def start(self) -> bool:
|
||||
async with self._state_lock:
|
||||
if self.state != FFmpegState.STOPPED:
|
||||
logger.warning(f"FFmpeg already running in state: {self.state}")
|
||||
return False
|
||||
self.state = FFmpegState.STARTING
|
||||
|
||||
try:
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-hide_banner",
|
||||
"-loglevel", "error",
|
||||
"-i", "pipe:0",
|
||||
"-f", "s16le",
|
||||
"-acodec", "pcm_s16le",
|
||||
"-ac", str(self.channels),
|
||||
"-ar", str(self.sample_rate),
|
||||
"pipe:1"
|
||||
]
|
||||
|
||||
self.process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
self._stderr_task = asyncio.create_task(self._drain_stderr())
|
||||
|
||||
async with self._state_lock:
|
||||
self.state = FFmpegState.RUNNING
|
||||
|
||||
logger.info("FFmpeg started.")
|
||||
return True
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(ERROR_INSTALL_INSTRUCTIONS)
|
||||
async with self._state_lock:
|
||||
self.state = FFmpegState.FAILED
|
||||
if self.on_error_callback:
|
||||
await self.on_error_callback("ffmpeg_not_found")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting FFmpeg: {e}")
|
||||
async with self._state_lock:
|
||||
self.state = FFmpegState.FAILED
|
||||
if self.on_error_callback:
|
||||
await self.on_error_callback("start_failed")
|
||||
return False
|
||||
|
||||
async def stop(self):
|
||||
async with self._state_lock:
|
||||
if self.state == FFmpegState.STOPPED:
|
||||
return
|
||||
self.state = FFmpegState.STOPPED
|
||||
|
||||
if self.process:
|
||||
if self.process.stdin and not self.process.stdin.is_closing():
|
||||
self.process.stdin.close()
|
||||
await self.process.stdin.wait_closed()
|
||||
await self.process.wait()
|
||||
self.process = None
|
||||
|
||||
if self._stderr_task:
|
||||
self._stderr_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._stderr_task
|
||||
|
||||
logger.info("FFmpeg stopped.")
|
||||
|
||||
async def write_data(self, data: bytes) -> bool:
|
||||
async with self._state_lock:
|
||||
if self.state != FFmpegState.RUNNING:
|
||||
logger.warning(f"Cannot write, FFmpeg state: {self.state}")
|
||||
return False
|
||||
|
||||
try:
|
||||
self.process.stdin.write(data)
|
||||
await self.process.stdin.drain()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing to FFmpeg: {e}")
|
||||
if self.on_error_callback:
|
||||
await self.on_error_callback("write_error")
|
||||
return False
|
||||
|
||||
async def read_data(self, size: int) -> Optional[bytes]:
|
||||
async with self._state_lock:
|
||||
if self.state != FFmpegState.RUNNING:
|
||||
logger.warning(f"Cannot read, FFmpeg state: {self.state}")
|
||||
return None
|
||||
|
||||
try:
|
||||
data = await asyncio.wait_for(
|
||||
self.process.stdout.read(size),
|
||||
timeout=20.0
|
||||
)
|
||||
return data
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("FFmpeg read timeout.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading from FFmpeg: {e}")
|
||||
if self.on_error_callback:
|
||||
await self.on_error_callback("read_error")
|
||||
return None
|
||||
|
||||
async def get_state(self) -> FFmpegState:
|
||||
async with self._state_lock:
|
||||
return self.state
|
||||
|
||||
async def restart(self) -> bool:
|
||||
async with self._state_lock:
|
||||
if self.state == FFmpegState.RESTARTING:
|
||||
logger.warning("Restart already in progress.")
|
||||
return False
|
||||
self.state = FFmpegState.RESTARTING
|
||||
|
||||
logger.info("Restarting FFmpeg...")
|
||||
|
||||
try:
|
||||
await self.stop()
|
||||
await asyncio.sleep(1) # short delay before restarting
|
||||
return await self.start()
|
||||
except Exception as e:
|
||||
logger.error(f"Error during FFmpeg restart: {e}")
|
||||
async with self._state_lock:
|
||||
self.state = FFmpegState.FAILED
|
||||
if self.on_error_callback:
|
||||
await self.on_error_callback("restart_failed")
|
||||
return False
|
||||
|
||||
async def _drain_stderr(self):
|
||||
try:
|
||||
while True:
|
||||
if not self.process or not self.process.stderr:
|
||||
break
|
||||
line = await self.process.stderr.readline()
|
||||
if not line:
|
||||
break
|
||||
logger.debug(f"FFmpeg stderr: {line.decode(errors='ignore').strip()}")
|
||||
except asyncio.CancelledError:
|
||||
logger.info("FFmpeg stderr drain task cancelled.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error draining FFmpeg stderr: {e}")
|
||||
@@ -6,19 +6,21 @@ import math
|
||||
from typing import List
|
||||
import numpy as np
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
|
||||
from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe
|
||||
logger = logging.getLogger(__name__)
|
||||
class ASRBase:
|
||||
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
||||
# "" for faster-whisper because it emits the spaces when needed)
|
||||
|
||||
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
|
||||
def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
|
||||
self.logfile = logfile
|
||||
self.transcribe_kargs = {}
|
||||
if lan == "auto":
|
||||
self.original_language = None
|
||||
else:
|
||||
self.original_language = lan
|
||||
self.model = self.load_model(modelsize, cache_dir, model_dir)
|
||||
self.model = self.load_model(model_size, cache_dir, model_dir)
|
||||
|
||||
def with_offset(self, offset: float) -> ASRToken:
|
||||
# This method is kept for compatibility (typically you will use ASRToken.with_offset)
|
||||
@@ -27,7 +29,7 @@ class ASRBase:
|
||||
def __repr__(self):
|
||||
return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"
|
||||
|
||||
def load_model(self, modelsize, cache_dir, model_dir):
|
||||
def load_model(self, model_size, cache_dir, model_dir):
|
||||
raise NotImplementedError("must be implemented in the child class")
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
@@ -37,40 +39,60 @@ class ASRBase:
|
||||
raise NotImplementedError("must be implemented in the child class")
|
||||
|
||||
|
||||
class WhisperTimestampedASR(ASRBase):
|
||||
"""Uses whisper_timestamped as the backend."""
|
||||
class WhisperASR(ASRBase):
|
||||
"""Uses WhisperLiveKit's built-in Whisper implementation."""
|
||||
sep = " "
|
||||
|
||||
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
||||
import whisper
|
||||
import whisper_timestamped
|
||||
from whisper_timestamped import transcribe_timestamped
|
||||
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
||||
from whisperlivekit.whisper import load_model as load_model
|
||||
|
||||
self.transcribe_timestamped = transcribe_timestamped
|
||||
if model_dir is not None:
|
||||
logger.debug("ignoring model_dir, not implemented")
|
||||
return whisper.load_model(modelsize, download_root=cache_dir)
|
||||
resolved_path = resolve_model_path(model_dir)
|
||||
if resolved_path.is_dir():
|
||||
pytorch_path, _, _ = model_path_and_type(resolved_path)
|
||||
if pytorch_path is None:
|
||||
raise FileNotFoundError(
|
||||
f"No supported PyTorch checkpoint found under {resolved_path}"
|
||||
)
|
||||
resolved_path = pytorch_path
|
||||
logger.debug(f"Loading Whisper model from custom path {resolved_path}")
|
||||
return load_model(str(resolved_path))
|
||||
|
||||
if model_size is None:
|
||||
raise ValueError("Either model_size or model_dir must be set for WhisperASR")
|
||||
|
||||
return load_model(model_size, download_root=cache_dir)
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
result = self.transcribe_timestamped(
|
||||
options = dict(self.transcribe_kargs)
|
||||
options.pop("vad", None)
|
||||
options.pop("vad_filter", None)
|
||||
language = self.original_language if self.original_language else None
|
||||
|
||||
result = whisper_transcribe(
|
||||
self.model,
|
||||
audio,
|
||||
language=self.original_language,
|
||||
language=language,
|
||||
initial_prompt=init_prompt,
|
||||
verbose=None,
|
||||
condition_on_previous_text=True,
|
||||
**self.transcribe_kargs,
|
||||
word_timestamps=True,
|
||||
**options,
|
||||
)
|
||||
return result
|
||||
|
||||
def ts_words(self, r) -> List[ASRToken]:
|
||||
"""
|
||||
Converts the whisper_timestamped result to a list of ASRToken objects.
|
||||
Converts the Whisper result to a list of ASRToken objects.
|
||||
"""
|
||||
tokens = []
|
||||
for segment in r["segments"]:
|
||||
for word in segment["words"]:
|
||||
token = ASRToken(word["start"], word["end"], word["text"])
|
||||
token = ASRToken(
|
||||
word["start"],
|
||||
word["end"],
|
||||
word["word"],
|
||||
probability=word.get("probability"),
|
||||
)
|
||||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
@@ -78,27 +100,24 @@ class WhisperTimestampedASR(ASRBase):
|
||||
return [segment["end"] for segment in res["segments"]]
|
||||
|
||||
def use_vad(self):
|
||||
self.transcribe_kargs["vad"] = True
|
||||
|
||||
def set_translate_task(self):
|
||||
self.transcribe_kargs["task"] = "translate"
|
||||
|
||||
logger.warning("VAD is not currently supported for WhisperASR backend and will be ignored.")
|
||||
|
||||
class FasterWhisperASR(ASRBase):
|
||||
"""Uses faster-whisper as the backend."""
|
||||
sep = ""
|
||||
|
||||
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
||||
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
if model_dir is not None:
|
||||
logger.debug(f"Loading whisper model from model_dir {model_dir}. "
|
||||
f"modelsize and cache_dir parameters are not used.")
|
||||
model_size_or_path = model_dir
|
||||
elif modelsize is not None:
|
||||
model_size_or_path = modelsize
|
||||
resolved_path = resolve_model_path(model_dir)
|
||||
logger.debug(f"Loading faster-whisper model from {resolved_path}. "
|
||||
f"model_size and cache_dir parameters are not used.")
|
||||
model_size_or_path = str(resolved_path)
|
||||
elif model_size is not None:
|
||||
model_size_or_path = model_size
|
||||
else:
|
||||
raise ValueError("Either modelsize or model_dir must be set")
|
||||
raise ValueError("Either model_size or model_dir must be set")
|
||||
device = "auto" # Allow CTranslate2 to decide available device
|
||||
compute_type = "auto" # Allow CTranslate2 to decide faster compute type
|
||||
|
||||
@@ -139,28 +158,25 @@ class FasterWhisperASR(ASRBase):
|
||||
def use_vad(self):
|
||||
self.transcribe_kargs["vad_filter"] = True
|
||||
|
||||
def set_translate_task(self):
|
||||
self.transcribe_kargs["task"] = "translate"
|
||||
|
||||
|
||||
class MLXWhisper(ASRBase):
|
||||
"""
|
||||
Uses MLX Whisper optimized for Apple Silicon.
|
||||
"""
|
||||
sep = ""
|
||||
|
||||
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
||||
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
||||
from mlx_whisper.transcribe import ModelHolder, transcribe
|
||||
import mlx.core as mx
|
||||
|
||||
if model_dir is not None:
|
||||
logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
|
||||
model_size_or_path = model_dir
|
||||
elif modelsize is not None:
|
||||
model_size_or_path = self.translate_model_name(modelsize)
|
||||
logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")
|
||||
resolved_path = resolve_model_path(model_dir)
|
||||
logger.debug(f"Loading MLX Whisper model from {resolved_path}. model_size parameter is not used.")
|
||||
model_size_or_path = str(resolved_path)
|
||||
elif model_size is not None:
|
||||
model_size_or_path = self.translate_model_name(model_size)
|
||||
logger.debug(f"Loading whisper model {model_size}. You use mlx whisper, so {model_size_or_path} will be used.")
|
||||
else:
|
||||
raise ValueError("Either modelsize or model_dir must be set")
|
||||
raise ValueError("Either model_size or model_dir must be set")
|
||||
|
||||
self.model_size_or_path = model_size_or_path
|
||||
dtype = mx.float16
|
||||
@@ -208,7 +224,8 @@ class MLXWhisper(ASRBase):
|
||||
if segment.get("no_speech_prob", 0) > 0.9:
|
||||
continue
|
||||
for word in segment.get("words", []):
|
||||
token = ASRToken(word["start"], word["end"], word["word"], probability=word["probability"])
|
||||
probability=word["probability"]
|
||||
token = ASRToken(word["start"], word["end"], word["word"])
|
||||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
@@ -218,10 +235,6 @@ class MLXWhisper(ASRBase):
|
||||
def use_vad(self):
|
||||
self.transcribe_kargs["vad_filter"] = True
|
||||
|
||||
def set_translate_task(self):
|
||||
self.transcribe_kargs["task"] = "translate"
|
||||
|
||||
|
||||
class OpenaiApiASR(ASRBase):
|
||||
"""Uses OpenAI's Whisper API for transcription."""
|
||||
def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
|
||||
@@ -232,7 +245,7 @@ class OpenaiApiASR(ASRBase):
|
||||
self.temperature = temperature
|
||||
self.load_model()
|
||||
self.use_vad_opt = False
|
||||
self.task = "transcribe"
|
||||
self.direct_english_translation = False
|
||||
|
||||
def load_model(self, *args, **kwargs):
|
||||
from openai import OpenAI
|
||||
@@ -274,7 +287,7 @@ class OpenaiApiASR(ASRBase):
|
||||
"temperature": self.temperature,
|
||||
"timestamp_granularities": ["word", "segment"],
|
||||
}
|
||||
if self.task != "translate" and self.original_language:
|
||||
if not self.direct_english_translation and self.original_language:
|
||||
params["language"] = self.original_language
|
||||
if prompt:
|
||||
params["prompt"] = prompt
|
||||
@@ -285,6 +298,3 @@ class OpenaiApiASR(ASRBase):
|
||||
|
||||
def use_vad(self):
|
||||
self.use_vad_opt = True
|
||||
|
||||
def set_translate_task(self):
|
||||
self.task = "translate"
|
||||
@@ -106,9 +106,6 @@ class OnlineASRProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
asr,
|
||||
tokenize_method: Optional[callable] = None,
|
||||
buffer_trimming: Tuple[str, float] = ("segment", 15),
|
||||
confidence_validation = False,
|
||||
logfile=sys.stderr,
|
||||
):
|
||||
"""
|
||||
@@ -119,13 +116,14 @@ class OnlineASRProcessor:
|
||||
buffer_trimming: A tuple (option, seconds), where option is either "sentence" or "segment".
|
||||
"""
|
||||
self.asr = asr
|
||||
self.tokenize = tokenize_method
|
||||
self.tokenize = asr.tokenizer
|
||||
self.logfile = logfile
|
||||
self.confidence_validation = confidence_validation
|
||||
self.confidence_validation = asr.confidence_validation
|
||||
self.global_time_offset = 0.0
|
||||
self.init()
|
||||
|
||||
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
|
||||
self.buffer_trimming_way = asr.buffer_trimming
|
||||
self.buffer_trimming_sec = asr.buffer_trimming_sec
|
||||
|
||||
if self.buffer_trimming_way not in ["sentence", "segment"]:
|
||||
raise ValueError("buffer_trimming must be either 'sentence' or 'segment'")
|
||||
@@ -153,21 +151,32 @@ class OnlineASRProcessor:
|
||||
"""Append an audio chunk (a numpy array) to the current audio buffer."""
|
||||
self.audio_buffer = np.append(self.audio_buffer, audio)
|
||||
|
||||
def insert_silence(self, silence_duration, offset):
|
||||
"""
|
||||
If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame
|
||||
"""
|
||||
# if self.transcript_buffer.buffer:
|
||||
# self.committed.extend(self.transcript_buffer.buffer)
|
||||
# self.transcript_buffer.buffer = []
|
||||
def start_silence(self):
|
||||
if self.audio_buffer.size == 0:
|
||||
return [], self.get_audio_buffer_end_time()
|
||||
return self.process_iter()
|
||||
|
||||
if True: #silence_duration < 3: #we want the last audio to be treated to not have a gap. could also be handled in the future in ends_with_silence.
|
||||
gap_silence = np.zeros(int(16000 * silence_duration), dtype=np.int16)
|
||||
def end_silence(self, silence_duration: Optional[float], offset: float):
|
||||
if not silence_duration or silence_duration <= 0:
|
||||
return
|
||||
|
||||
long_silence = silence_duration >= 5
|
||||
if not long_silence:
|
||||
gap_samples = int(self.SAMPLING_RATE * silence_duration)
|
||||
if gap_samples > 0:
|
||||
gap_silence = np.zeros(gap_samples, dtype=np.float32)
|
||||
self.insert_audio_chunk(gap_silence)
|
||||
else:
|
||||
self.init(offset=silence_duration + offset)
|
||||
|
||||
self.global_time_offset += silence_duration
|
||||
|
||||
def insert_silence(self, silence_duration, offset):
|
||||
"""
|
||||
Backwards compatibility shim for legacy callers that still use insert_silence.
|
||||
"""
|
||||
self.end_silence(silence_duration, offset)
|
||||
|
||||
def prompt(self) -> Tuple[str, str]:
|
||||
"""
|
||||
Returns a tuple: (prompt, context), where:
|
||||
@@ -402,11 +411,11 @@ class OnlineASRProcessor:
|
||||
) -> Transcript:
|
||||
sep = sep if sep is not None else self.asr.sep
|
||||
text = sep.join(token.text for token in tokens)
|
||||
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
||||
# probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
||||
if tokens:
|
||||
start = offset + tokens[0].start
|
||||
end = offset + tokens[-1].end
|
||||
else:
|
||||
start = None
|
||||
end = None
|
||||
return Transcript(start, end, text, probability=probability)
|
||||
return Transcript(start, end, text)
|
||||
199
whisperlivekit/local_agreement/whisper_online.py
Normal file
@@ -0,0 +1,199 @@
|
||||
#!/usr/bin/env python3
|
||||
import sys
|
||||
import numpy as np
|
||||
import librosa
|
||||
from functools import lru_cache
|
||||
import time
|
||||
import logging
|
||||
import platform
|
||||
from .backends import FasterWhisperASR, MLXWhisper, WhisperASR, OpenaiApiASR
|
||||
from whisperlivekit.warmup import warmup_asr
|
||||
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
|
||||
from whisperlivekit.backend_support import (
|
||||
mlx_backend_available,
|
||||
faster_backend_available,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
|
||||
","
|
||||
)
|
||||
|
||||
|
||||
def create_tokenizer(lan):
|
||||
"""returns an object that has split function that works like the one of MosesTokenizer"""
|
||||
|
||||
assert (
|
||||
lan in WHISPER_LANG_CODES
|
||||
), "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
|
||||
|
||||
if lan == "uk":
|
||||
import tokenize_uk
|
||||
|
||||
class UkrainianTokenizer:
|
||||
def split(self, text):
|
||||
return tokenize_uk.tokenize_sents(text)
|
||||
|
||||
return UkrainianTokenizer()
|
||||
|
||||
# supported by fast-mosestokenizer
|
||||
if (
|
||||
lan
|
||||
in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
|
||||
):
|
||||
from mosestokenizer import MosesSentenceSplitter
|
||||
|
||||
return MosesSentenceSplitter(lan)
|
||||
|
||||
# the following languages are in Whisper, but not in wtpsplit:
|
||||
if (
|
||||
lan
|
||||
in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
|
||||
):
|
||||
logger.debug(
|
||||
f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
|
||||
)
|
||||
lan = None
|
||||
|
||||
from wtpsplit import WtP
|
||||
|
||||
# downloads the model from huggingface on the first use
|
||||
wtp = WtP("wtp-canine-s-12l-no-adapters")
|
||||
|
||||
class WtPtok:
|
||||
def split(self, sent):
|
||||
return wtp.split(sent, lang_code=lan)
|
||||
|
||||
return WtPtok()
|
||||
|
||||
|
||||
def backend_factory(
|
||||
backend,
|
||||
lan,
|
||||
model_size,
|
||||
model_cache_dir,
|
||||
model_dir,
|
||||
model_path,
|
||||
direct_english_translation,
|
||||
buffer_trimming,
|
||||
buffer_trimming_sec,
|
||||
confidence_validation,
|
||||
warmup_file=None,
|
||||
min_chunk_size=None,
|
||||
):
|
||||
backend_choice = backend
|
||||
custom_reference = model_path or model_dir
|
||||
resolved_root = None
|
||||
pytorch_checkpoint = None
|
||||
has_mlx_weights = False
|
||||
has_fw_weights = False
|
||||
|
||||
if custom_reference:
|
||||
resolved_root = resolve_model_path(custom_reference)
|
||||
if resolved_root.is_dir():
|
||||
pytorch_checkpoint, has_mlx_weights, has_fw_weights = model_path_and_type(resolved_root)
|
||||
else:
|
||||
pytorch_checkpoint = resolved_root
|
||||
|
||||
if backend_choice == "openai-api":
|
||||
logger.debug("Using OpenAI API.")
|
||||
asr = OpenaiApiASR(lan=lan)
|
||||
else:
|
||||
backend_choice = _normalize_backend_choice(
|
||||
backend_choice,
|
||||
resolved_root,
|
||||
has_mlx_weights,
|
||||
has_fw_weights,
|
||||
)
|
||||
|
||||
if backend_choice == "faster-whisper":
|
||||
asr_cls = FasterWhisperASR
|
||||
if resolved_root is not None and not resolved_root.is_dir():
|
||||
raise ValueError("Faster-Whisper backend expects a directory with CTranslate2 weights.")
|
||||
model_override = str(resolved_root) if resolved_root is not None else None
|
||||
elif backend_choice == "mlx-whisper":
|
||||
asr_cls = MLXWhisper
|
||||
if resolved_root is not None and not resolved_root.is_dir():
|
||||
raise ValueError("MLX Whisper backend expects a directory containing MLX weights.")
|
||||
model_override = str(resolved_root) if resolved_root is not None else None
|
||||
else:
|
||||
asr_cls = WhisperASR
|
||||
model_override = str(pytorch_checkpoint) if pytorch_checkpoint is not None else None
|
||||
if custom_reference and model_override is None:
|
||||
raise FileNotFoundError(
|
||||
f"No PyTorch checkpoint found under {resolved_root or custom_reference}"
|
||||
)
|
||||
|
||||
t = time.time()
|
||||
logger.info(f"Loading Whisper {model_size} model for language {lan} using backend {backend_choice}...")
|
||||
asr = asr_cls(
|
||||
model_size=model_size,
|
||||
lan=lan,
|
||||
cache_dir=model_cache_dir,
|
||||
model_dir=model_override,
|
||||
)
|
||||
e = time.time()
|
||||
logger.info(f"done. It took {round(e-t,2)} seconds.")
|
||||
|
||||
if direct_english_translation:
|
||||
tgt_language = "en" # Whisper translates into English
|
||||
else:
|
||||
tgt_language = lan # Whisper transcribes in this language
|
||||
|
||||
# Create the tokenizer
|
||||
if buffer_trimming == "sentence":
|
||||
tokenizer = create_tokenizer(tgt_language)
|
||||
else:
|
||||
tokenizer = None
|
||||
|
||||
warmup_asr(asr, warmup_file)
|
||||
|
||||
asr.confidence_validation = confidence_validation
|
||||
asr.tokenizer = tokenizer
|
||||
asr.buffer_trimming = buffer_trimming
|
||||
asr.buffer_trimming_sec = buffer_trimming_sec
|
||||
asr.backend_choice = backend_choice
|
||||
return asr
|
||||
|
||||
|
||||
def _normalize_backend_choice(
|
||||
preferred_backend,
|
||||
resolved_root,
|
||||
has_mlx_weights,
|
||||
has_fw_weights,
|
||||
):
|
||||
backend_choice = preferred_backend
|
||||
|
||||
if backend_choice == "auto":
|
||||
if mlx_backend_available(warn_on_missing=True) and (resolved_root is None or has_mlx_weights):
|
||||
return "mlx-whisper"
|
||||
if faster_backend_available(warn_on_missing=True) and (resolved_root is None or has_fw_weights):
|
||||
return "faster-whisper"
|
||||
return "whisper"
|
||||
|
||||
if backend_choice == "mlx-whisper":
|
||||
if not mlx_backend_available():
|
||||
raise RuntimeError("mlx-whisper backend requested but mlx-whisper is not installed.")
|
||||
if resolved_root is not None and not has_mlx_weights:
|
||||
raise FileNotFoundError(
|
||||
f"mlx-whisper backend requested but no MLX weights were found under {resolved_root}"
|
||||
)
|
||||
if platform.system() != "Darwin":
|
||||
logger.warning("mlx-whisper backend requested on a non-macOS system; this may fail.")
|
||||
return backend_choice
|
||||
|
||||
if backend_choice == "faster-whisper":
|
||||
if not faster_backend_available():
|
||||
raise RuntimeError("faster-whisper backend requested but faster-whisper is not installed.")
|
||||
if resolved_root is not None and not has_fw_weights:
|
||||
raise FileNotFoundError(
|
||||
f"faster-whisper backend requested but no Faster-Whisper weights were found under {resolved_root}"
|
||||
)
|
||||
return backend_choice
|
||||
|
||||
if backend_choice == "whisper":
|
||||
return backend_choice
|
||||
|
||||
raise ValueError(f"Unknown backend '{preferred_backend}' for LocalAgreement.")
|
||||
69
whisperlivekit/model_paths.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
|
||||
def model_path_and_type(model_path: Union[str, Path]) -> Tuple[Optional[Path], bool, bool]:
|
||||
"""
|
||||
Inspect the provided path and determine which model formats are available.
|
||||
|
||||
Returns:
|
||||
pytorch_path: Path to a PyTorch checkpoint (if present).
|
||||
compatible_whisper_mlx: True if MLX weights exist in this folder.
|
||||
compatible_faster_whisper: True if Faster-Whisper (ctranslate2) weights exist.
|
||||
"""
|
||||
path = Path(model_path)
|
||||
|
||||
compatible_whisper_mlx = False
|
||||
compatible_faster_whisper = False
|
||||
pytorch_path: Optional[Path] = None
|
||||
|
||||
if path.is_file() and path.suffix.lower() in [".pt", ".safetensors", ".bin"]:
|
||||
pytorch_path = path
|
||||
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
|
||||
|
||||
if path.is_dir():
|
||||
for file in path.iterdir():
|
||||
if not file.is_file():
|
||||
continue
|
||||
|
||||
filename = file.name.lower()
|
||||
suffix = file.suffix.lower()
|
||||
|
||||
if filename in {"weights.npz", "weights.safetensors"}:
|
||||
compatible_whisper_mlx = True
|
||||
elif filename in {"model.bin", "encoder.bin", "decoder.bin"}:
|
||||
compatible_faster_whisper = True
|
||||
elif suffix in {".pt", ".safetensors"}:
|
||||
pytorch_path = file
|
||||
elif filename == "pytorch_model.bin":
|
||||
pytorch_path = file
|
||||
|
||||
if pytorch_path is None:
|
||||
fallback = path / "pytorch_model.bin"
|
||||
if fallback.exists():
|
||||
pytorch_path = fallback
|
||||
|
||||
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
|
||||
|
||||
|
||||
def resolve_model_path(model_path: Union[str, Path]) -> Path:
|
||||
"""
|
||||
Return a local path for the provided model reference.
|
||||
|
||||
If the path does not exist locally, it is treated as a Hugging Face repo id
|
||||
and downloaded via snapshot_download.
|
||||
"""
|
||||
path = Path(model_path).expanduser()
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
except ImportError as exc: # pragma: no cover - optional dependency guard
|
||||
raise FileNotFoundError(
|
||||
f"Model path '{model_path}' does not exist locally and huggingface_hub "
|
||||
"is not installed to download it."
|
||||
) from exc
|
||||
|
||||
downloaded_path = Path(snapshot_download(repo_id=str(model_path)))
|
||||
return downloaded_path
|
||||
@@ -81,14 +81,15 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--min-chunk-size",
|
||||
type=float,
|
||||
default=0.5,
|
||||
default=0.1,
|
||||
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="small",
|
||||
default="base",
|
||||
dest='model_size',
|
||||
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
|
||||
)
|
||||
|
||||
@@ -109,14 +110,14 @@ def parse_args():
|
||||
"--language",
|
||||
type=str,
|
||||
default="auto",
|
||||
dest='lan',
|
||||
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default="transcribe",
|
||||
choices=["transcribe", "translate"],
|
||||
help="Transcribe or translate.",
|
||||
"--direct-english-translation",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use Whisper to directly translate to english.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -128,11 +129,18 @@ def parse_args():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
"--backend-policy",
|
||||
type=str,
|
||||
default="simulstreaming",
|
||||
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api", "simulstreaming"],
|
||||
help="Load only this backend for Whisper processing.",
|
||||
choices=["1", "2", "simulstreaming", "localagreement"],
|
||||
help="Select the streaming policy: 1 or 'simulstreaming' for AlignAtt, 2 or 'localagreement' for LocalAgreement.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api"],
|
||||
help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-vac",
|
||||
@@ -173,11 +181,12 @@ def parse_args():
|
||||
)
|
||||
parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
|
||||
parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
|
||||
parser.add_argument("--forwarded-allow-ips", type=str, help="Allowed ips for reverse proxying.", default=None)
|
||||
parser.add_argument(
|
||||
"--pcm-input",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="If set, raw PCM (s16le) data is expected as input and FFmpeg will be bypassed."
|
||||
help="If set, raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder."
|
||||
)
|
||||
# SimulStreaming-specific arguments
|
||||
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')
|
||||
@@ -190,6 +199,13 @@ def parse_args():
|
||||
help="Disable Faster Whisper or MLX Whisper backends for encoding (if installed). Slower but helpful when GPU memory is limited",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--custom-alignment-heads",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Use your own alignment heads, useful when `--model-dir` is used",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--frame-threshold",
|
||||
type=int,
|
||||
@@ -287,6 +303,20 @@ def parse_args():
|
||||
help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--nllb-backend",
|
||||
type=str,
|
||||
default="transformers",
|
||||
help="transformers or ctranslate2",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--nllb-size",
|
||||
type=str,
|
||||
default="600M",
|
||||
help="600M or 1.3B",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
args.transcription = not args.no_transcription
|
||||
@@ -294,4 +324,9 @@ def parse_args():
|
||||
delattr(args, 'no_transcription')
|
||||
delattr(args, 'no_vad')
|
||||
|
||||
if args.backend_policy == "1":
|
||||
args.backend_policy = "simulstreaming"
|
||||
elif args.backend_policy == "2":
|
||||
args.backend_policy = "localagreement"
|
||||
|
||||
return args
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
import re
|
||||
|
||||
MIN_SILENCE_DURATION = 4 #in seconds
|
||||
END_SILENCE_DURATION = 8 #in seconds. you should keep it important to not have false positive when the model lag is important
|
||||
END_SILENCE_DURATION_VAC = 3 #VAC is good at detecting silences, but we want to skip the smallest silences
|
||||
|
||||
def blank_to_silence(tokens):
|
||||
full_string = ''.join([t.text for t in tokens])
|
||||
patterns = [re.compile(r'(?:\s*\[BLANK_AUDIO\]\s*)+'), re.compile(r'(?:\s*\[typing\]\s*)+')]
|
||||
matches = []
|
||||
for pattern in patterns:
|
||||
for m in pattern.finditer(full_string):
|
||||
matches.append({
|
||||
'start': m.start(),
|
||||
'end': m.end()
|
||||
})
|
||||
if matches:
|
||||
# cleaned = pattern.sub(' ', full_string).strip()
|
||||
# print("Cleaned:", cleaned)
|
||||
cumulated_len = 0
|
||||
silence_token = None
|
||||
cleaned_tokens = []
|
||||
for token in tokens:
|
||||
if matches:
|
||||
start = cumulated_len
|
||||
end = cumulated_len + len(token.text)
|
||||
cumulated_len = end
|
||||
if start >= matches[0]['start'] and end <= matches[0]['end']:
|
||||
if silence_token: #previous token was already silence
|
||||
silence_token.start = min(silence_token.start, token.start)
|
||||
silence_token.end = max(silence_token.end, token.end)
|
||||
else: #new silence
|
||||
silence_token = ASRToken(
|
||||
start=token.start,
|
||||
end=token.end,
|
||||
speaker=-2,
|
||||
probability=0.95
|
||||
)
|
||||
else:
|
||||
if silence_token: #there was silence but no more
|
||||
if silence_token.end - silence_token.start >= MIN_SILENCE_DURATION:
|
||||
cleaned_tokens.append(
|
||||
silence_token
|
||||
)
|
||||
silence_token = None
|
||||
matches.pop(0)
|
||||
cleaned_tokens.append(token)
|
||||
# print(cleaned_tokens)
|
||||
return cleaned_tokens
|
||||
return tokens
|
||||
|
||||
def no_token_to_silence(tokens):
|
||||
new_tokens = []
|
||||
silence_token = None
|
||||
for token in tokens:
|
||||
if token.speaker == -2:
|
||||
if new_tokens and new_tokens[-1].speaker == -2: #if token is silence and previous one too
|
||||
new_tokens[-1].end = token.end
|
||||
else:
|
||||
new_tokens.append(token)
|
||||
|
||||
last_end = new_tokens[-1].end if new_tokens else 0.0
|
||||
if token.start - last_end >= MIN_SILENCE_DURATION: #if token is not silence but important gap
|
||||
if new_tokens and new_tokens[-1].speaker == -2:
|
||||
new_tokens[-1].end = token.start
|
||||
else:
|
||||
silence_token = ASRToken(
|
||||
start=last_end,
|
||||
end=token.start,
|
||||
speaker=-2,
|
||||
probability=0.95
|
||||
)
|
||||
new_tokens.append(silence_token)
|
||||
|
||||
if token.speaker != -2:
|
||||
new_tokens.append(token)
|
||||
return new_tokens
|
||||
|
||||
def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
|
||||
if not tokens:
|
||||
return [], buffer_transcription, buffer_diarization
|
||||
last_token = tokens[-1]
|
||||
if tokens and current_time and (
|
||||
current_time - last_token.end >= END_SILENCE_DURATION
|
||||
or
|
||||
(current_time - last_token.end >= 3 and vac_detected_silence)
|
||||
):
|
||||
if last_token.speaker == -2:
|
||||
last_token.end = current_time
|
||||
else:
|
||||
tokens.append(
|
||||
ASRToken(
|
||||
start=tokens[-1].end,
|
||||
end=current_time,
|
||||
speaker=-2,
|
||||
probability=0.95
|
||||
)
|
||||
)
|
||||
buffer_transcription = "" # for whisperstreaming backend, we should probably validate the buffer has because of the silence
|
||||
buffer_diarization = ""
|
||||
return tokens, buffer_transcription, buffer_diarization
|
||||
|
||||
|
||||
def handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
|
||||
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
|
||||
tokens = no_token_to_silence(tokens)
|
||||
tokens, buffer_transcription, buffer_diarization = ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence)
|
||||
return tokens, buffer_transcription, buffer_diarization
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
|
||||
import logging
|
||||
from whisperlivekit.remove_silences import handle_silences
|
||||
from whisperlivekit.timed_objects import Line, format_time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||
CHECK_AROUND = 4
|
||||
|
||||
def is_punctuation(token):
|
||||
if token.text.strip() in PUNCTUATION_MARKS:
|
||||
return True
|
||||
return False
|
||||
|
||||
def next_punctuation_change(i, tokens):
|
||||
for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)):
|
||||
if is_punctuation(tokens[ind]):
|
||||
return ind
|
||||
return None
|
||||
|
||||
def next_speaker_change(i, tokens, speaker):
|
||||
for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1):
|
||||
token = tokens[ind]
|
||||
if is_punctuation(token):
|
||||
break
|
||||
if token.speaker != speaker:
|
||||
return ind, token.speaker
|
||||
return None, speaker
|
||||
|
||||
def new_line(
|
||||
token,
|
||||
speaker,
|
||||
debug_info = ""
|
||||
):
|
||||
return Line(
|
||||
speaker = speaker,
|
||||
text = token.text + debug_info,
|
||||
start = token.start,
|
||||
end = token.end,
|
||||
)
|
||||
|
||||
def append_token_to_last_line(lines, sep, token, debug_info):
|
||||
if token.text:
|
||||
lines[-1].text += sep + token.text + debug_info
|
||||
lines[-1].end = token.end
|
||||
|
||||
def format_output(state, silence, current_time, args, debug, sep):
|
||||
diarization = args.diarization
|
||||
disable_punctuation_split = args.disable_punctuation_split
|
||||
tokens = state.tokens
|
||||
translated_segments = state.translated_segments # Here we will attribute the speakers only based on the timestamps of the segments
|
||||
buffer_transcription = state.buffer_transcription
|
||||
buffer_diarization = state.buffer_diarization
|
||||
end_attributed_speaker = state.end_attributed_speaker
|
||||
|
||||
previous_speaker = -1
|
||||
lines = []
|
||||
undiarized_text = []
|
||||
tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, silence)
|
||||
last_punctuation = None
|
||||
for i, token in enumerate(tokens):
|
||||
speaker = token.speaker
|
||||
if not diarization and speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
|
||||
speaker = 1
|
||||
if diarization and not tokens[-1].speaker == -2:
|
||||
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
|
||||
undiarized_text.append(token.text)
|
||||
continue
|
||||
elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
|
||||
speaker = previous_speaker
|
||||
debug_info = ""
|
||||
if debug:
|
||||
debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
|
||||
|
||||
if not lines:
|
||||
lines.append(new_line(token, speaker, debug_info = ""))
|
||||
continue
|
||||
else:
|
||||
previous_speaker = lines[-1].speaker
|
||||
|
||||
if is_punctuation(token):
|
||||
last_punctuation = i
|
||||
|
||||
|
||||
if last_punctuation == i-1:
|
||||
if speaker != previous_speaker:
|
||||
# perfect, diarization perfectly aligned
|
||||
lines.append(new_line(token, speaker, debug_info = ""))
|
||||
last_punctuation, next_punctuation = None, None
|
||||
continue
|
||||
|
||||
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
|
||||
if speaker_change_pos:
|
||||
# Corrects delay:
|
||||
# That was the idea. Okay haha |SPLIT SPEAKER| that's a good one
|
||||
# should become:
|
||||
# That was the idea. |SPLIT SPEAKER| Okay haha that's a good one
|
||||
lines.append(new_line(token, new_speaker, debug_info = ""))
|
||||
else:
|
||||
# No speaker change to come
|
||||
append_token_to_last_line(lines, sep, token, debug_info)
|
||||
continue
|
||||
|
||||
|
||||
if speaker != previous_speaker:
|
||||
if speaker == -2 or previous_speaker == -2: #silences can happen anytime
|
||||
lines.append(new_line(token, speaker, debug_info = ""))
|
||||
continue
|
||||
elif next_punctuation_change(i, tokens):
|
||||
# Corrects advance:
|
||||
# Are you |SPLIT SPEAKER| okay? yeah, sure. Absolutely
|
||||
# should become:
|
||||
# Are you okay? |SPLIT SPEAKER| yeah, sure. Absolutely
|
||||
append_token_to_last_line(lines, sep, token, debug_info)
|
||||
continue
|
||||
else: #we create a new speaker, but that's no ideal. We are not sure about the split. We prefer to append to previous line
|
||||
if disable_punctuation_split:
|
||||
lines.append(new_line(token, speaker, debug_info = ""))
|
||||
continue
|
||||
pass
|
||||
|
||||
append_token_to_last_line(lines, sep, token, debug_info)
|
||||
if lines and translated_segments:
|
||||
cts_idx = 0 # current_translated_segment_idx
|
||||
for line in lines:
|
||||
while cts_idx < len(translated_segments):
|
||||
ts = translated_segments[cts_idx]
|
||||
if ts and ts.start and ts.start >= line.start and ts.end <= line.end:
|
||||
line.translation += ts.text + ' '
|
||||
cts_idx += 1
|
||||
else:
|
||||
break
|
||||
return lines, undiarized_text, buffer_transcription, ''
|
||||
|
||||
@@ -1,27 +1,182 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
# This is copied from silero-vad's vad_utils.py:
|
||||
# https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/utils_vad.py#L340
|
||||
# (except changed defaults)
|
||||
"""
|
||||
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
|
||||
"""
|
||||
|
||||
# Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
|
||||
def init_jit_model(model_path: str, device=torch.device('cpu')):
|
||||
"""Load a JIT model from file."""
|
||||
model = torch.jit.load(model_path, map_location=device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
class OnnxWrapper():
|
||||
"""ONNX Runtime wrapper for Silero VAD model."""
|
||||
|
||||
def __init__(self, path, force_onnx_cpu=False):
|
||||
global np
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
|
||||
opts = onnxruntime.SessionOptions()
|
||||
opts.inter_op_num_threads = 1
|
||||
opts.intra_op_num_threads = 1
|
||||
|
||||
if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
|
||||
self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
|
||||
else:
|
||||
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
|
||||
|
||||
self.reset_states()
|
||||
if '16k' in path:
|
||||
warnings.warn('This model support only 16000 sampling rate!')
|
||||
self.sample_rates = [16000]
|
||||
else:
|
||||
self.sample_rates = [8000, 16000]
|
||||
|
||||
def _validate_input(self, x, sr: int):
|
||||
if x.dim() == 1:
|
||||
x = x.unsqueeze(0)
|
||||
if x.dim() > 2:
|
||||
raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
|
||||
|
||||
if sr != 16000 and (sr % 16000 == 0):
|
||||
step = sr // 16000
|
||||
x = x[:,::step]
|
||||
sr = 16000
|
||||
|
||||
if sr not in self.sample_rates:
|
||||
raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
|
||||
if sr / x.shape[1] > 31.25:
|
||||
raise ValueError("Input audio chunk is too short")
|
||||
|
||||
return x, sr
|
||||
|
||||
def reset_states(self, batch_size=1):
|
||||
self._state = torch.zeros((2, batch_size, 128)).float()
|
||||
self._context = torch.zeros(0)
|
||||
self._last_sr = 0
|
||||
self._last_batch_size = 0
|
||||
|
||||
def __call__(self, x, sr: int):
|
||||
|
||||
x, sr = self._validate_input(x, sr)
|
||||
num_samples = 512 if sr == 16000 else 256
|
||||
|
||||
if x.shape[-1] != num_samples:
|
||||
raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
|
||||
|
||||
batch_size = x.shape[0]
|
||||
context_size = 64 if sr == 16000 else 32
|
||||
|
||||
if not self._last_batch_size:
|
||||
self.reset_states(batch_size)
|
||||
if (self._last_sr) and (self._last_sr != sr):
|
||||
self.reset_states(batch_size)
|
||||
if (self._last_batch_size) and (self._last_batch_size != batch_size):
|
||||
self.reset_states(batch_size)
|
||||
|
||||
if not len(self._context):
|
||||
self._context = torch.zeros(batch_size, context_size)
|
||||
|
||||
x = torch.cat([self._context, x], dim=1)
|
||||
if sr in [8000, 16000]:
|
||||
ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
|
||||
ort_outs = self.session.run(None, ort_inputs)
|
||||
out, state = ort_outs
|
||||
self._state = torch.from_numpy(state)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
self._context = x[..., -context_size:]
|
||||
self._last_sr = sr
|
||||
self._last_batch_size = batch_size
|
||||
|
||||
out = torch.from_numpy(out)
|
||||
return out
|
||||
|
||||
|
||||
def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: int = 16):
|
||||
"""
|
||||
Load Silero VAD model (JIT or ONNX).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_path : str, optional
|
||||
Path to model file. If None, uses default bundled model.
|
||||
onnx : bool, default False
|
||||
Whether to use ONNX runtime (requires onnxruntime package).
|
||||
opset_version : int, default 16
|
||||
ONNX opset version (15 or 16). Only used if onnx=True.
|
||||
|
||||
Returns
|
||||
-------
|
||||
model
|
||||
Loaded VAD model (JIT or ONNX wrapper)
|
||||
"""
|
||||
available_ops = [15, 16]
|
||||
if onnx and opset_version not in available_ops:
|
||||
raise Exception(f'Available ONNX opset_version: {available_ops}')
|
||||
if model_path is None:
|
||||
current_dir = Path(__file__).parent
|
||||
data_dir = current_dir / 'vad_models'
|
||||
|
||||
if onnx:
|
||||
if opset_version == 16:
|
||||
model_name = 'silero_vad.onnx'
|
||||
else:
|
||||
model_name = f'silero_vad_16k_op{opset_version}.onnx'
|
||||
else:
|
||||
model_name = 'silero_vad.jit'
|
||||
|
||||
model_path = data_dir / model_name
|
||||
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Model file not found: {model_path}\n"
|
||||
f"Please ensure the whisperlivekit/vad_models/ directory contains the model files."
|
||||
)
|
||||
else:
|
||||
model_path = Path(model_path)
|
||||
if onnx:
|
||||
try:
|
||||
model = OnnxWrapper(str(model_path), force_onnx_cpu=True)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"ONNX runtime not available. Install with: pip install onnxruntime\n"
|
||||
"Or use JIT model by setting onnx=False"
|
||||
)
|
||||
else:
|
||||
model = init_jit_model(str(model_path))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class VADIterator:
|
||||
def __init__(
|
||||
self,
|
||||
"""
|
||||
Voice Activity Detection iterator for streaming audio.
|
||||
|
||||
This is the Silero VAD v6 implementation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model,
|
||||
threshold: float = 0.5,
|
||||
sampling_rate: int = 16000,
|
||||
min_silence_duration_ms: int = 500, # makes sense on one recording that I checked
|
||||
speech_pad_ms: int = 100, # same
|
||||
min_silence_duration_ms: int = 100,
|
||||
speech_pad_ms: int = 30
|
||||
):
|
||||
|
||||
"""
|
||||
Class for stream imitation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: preloaded .jit silero VAD model
|
||||
model: preloaded .jit/.onnx silero VAD model
|
||||
|
||||
threshold: float (default - 0.5)
|
||||
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
|
||||
@@ -42,9 +197,7 @@ class VADIterator:
|
||||
self.sampling_rate = sampling_rate
|
||||
|
||||
if sampling_rate not in [8000, 16000]:
|
||||
raise ValueError(
|
||||
"VADIterator does not support sampling rates other than [8000, 16000]"
|
||||
)
|
||||
raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
|
||||
|
||||
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
|
||||
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
||||
@@ -57,13 +210,17 @@ class VADIterator:
|
||||
self.temp_end = 0
|
||||
self.current_sample = 0
|
||||
|
||||
def __call__(self, x, return_seconds=False):
|
||||
@torch.no_grad()
|
||||
def __call__(self, x, return_seconds=False, time_resolution: int = 1):
|
||||
"""
|
||||
x: torch.Tensor
|
||||
audio chunk (see examples in repo)
|
||||
|
||||
return_seconds: bool (default - False)
|
||||
whether return timestamps in seconds (default - samples)
|
||||
|
||||
time_resolution: int (default - 1)
|
||||
time resolution of speech coordinates when requested as seconds
|
||||
"""
|
||||
|
||||
if not torch.is_tensor(x):
|
||||
@@ -82,14 +239,8 @@ class VADIterator:
|
||||
|
||||
if (speech_prob >= self.threshold) and not self.triggered:
|
||||
self.triggered = True
|
||||
speech_start = self.current_sample - self.speech_pad_samples
|
||||
return {
|
||||
"start": (
|
||||
int(speech_start)
|
||||
if not return_seconds
|
||||
else round(speech_start / self.sampling_rate, 1)
|
||||
)
|
||||
}
|
||||
speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
|
||||
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)}
|
||||
|
||||
if (speech_prob < self.threshold - 0.15) and self.triggered:
|
||||
if not self.temp_end:
|
||||
@@ -97,30 +248,17 @@ class VADIterator:
|
||||
if self.current_sample - self.temp_end < self.min_silence_samples:
|
||||
return None
|
||||
else:
|
||||
speech_end = self.temp_end + self.speech_pad_samples
|
||||
speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
|
||||
self.temp_end = 0
|
||||
self.triggered = False
|
||||
return {
|
||||
"end": (
|
||||
int(speech_end)
|
||||
if not return_seconds
|
||||
else round(speech_end / self.sampling_rate, 1)
|
||||
)
|
||||
}
|
||||
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
#######################
|
||||
# because Silero now requires exactly 512-sized audio chunks
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class FixedVADIterator(VADIterator):
|
||||
"""It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once.
|
||||
If audio to be processed at once is long and multiple voiced segments detected,
|
||||
then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment.
|
||||
"""
|
||||
Fixed VAD Iterator that handles variable-length audio chunks, not only exactly 512 frames at once.
|
||||
"""
|
||||
|
||||
def reset_states(self):
|
||||
@@ -137,27 +275,20 @@ class FixedVADIterator(VADIterator):
|
||||
ret = r
|
||||
elif r is not None:
|
||||
if "end" in r:
|
||||
ret["end"] = r["end"] # the latter end
|
||||
if "start" in r and "end" in ret: # there is an earlier start.
|
||||
# Remove end, merging this segment with the previous one.
|
||||
ret["end"] = r["end"]
|
||||
if "start" in r and "end" in ret:
|
||||
del ret["end"]
|
||||
return ret if ret != {} else None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test/demonstrate the need for FixedVADIterator:
|
||||
model = load_silero_vad(onnx=False)
|
||||
vad = FixedVADIterator(model)
|
||||
|
||||
import torch
|
||||
audio_buffer = np.array([0] * 512, dtype=np.float32)
|
||||
result = vad(audio_buffer)
|
||||
print(f" 512 samples: {result}")
|
||||
|
||||
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
|
||||
vac = FixedVADIterator(model)
|
||||
# vac = VADIterator(model) # the second case crashes with this
|
||||
|
||||
# this works: for both
|
||||
audio_buffer = np.array([0] * (512), dtype=np.float32)
|
||||
vac(audio_buffer)
|
||||
|
||||
# this crashes on the non FixedVADIterator with
|
||||
# ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
|
||||
audio_buffer = np.array([0] * (512 - 1), dtype=np.float32)
|
||||
vac(audio_buffer)
|
||||
# test with 511 samples
|
||||
audio_buffer = np.array([0] * 511, dtype=np.float32)
|
||||
result = vad(audio_buffer)
|
||||
@@ -2,40 +2,39 @@ import sys
|
||||
import numpy as np
|
||||
import logging
|
||||
from typing import List, Tuple, Optional
|
||||
import logging
|
||||
import platform
|
||||
from whisperlivekit.timed_objects import ASRToken, Transcript
|
||||
from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
|
||||
from whisperlivekit.warmup import load_file
|
||||
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
|
||||
from .whisper import load_model, tokenizer
|
||||
from .whisper.audio import TOKENS_PER_SECOND
|
||||
from whisperlivekit.whisper import load_model, tokenizer
|
||||
from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
|
||||
import os
|
||||
import gc
|
||||
logger = logging.getLogger(__name__)
|
||||
from pathlib import Path
|
||||
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
|
||||
from whisperlivekit.backend_support import (
|
||||
mlx_backend_available,
|
||||
faster_backend_available,
|
||||
)
|
||||
|
||||
import torch
|
||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
|
||||
from whisperlivekit.simul_whisper.whisper import tokenizer
|
||||
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
|
||||
|
||||
try:
|
||||
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
|
||||
HAS_MLX_WHISPER = True
|
||||
except ImportError:
|
||||
if platform.system() == "Darwin" and platform.machine() == "arm64":
|
||||
print('MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper')
|
||||
HAS_MLX_WHISPER = False
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
|
||||
if HAS_MLX_WHISPER:
|
||||
HAS_FASTER_WHISPER = False
|
||||
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
|
||||
else:
|
||||
try:
|
||||
mlx_model_mapping = {}
|
||||
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
|
||||
if HAS_FASTER_WHISPER:
|
||||
from faster_whisper import WhisperModel
|
||||
HAS_FASTER_WHISPER = True
|
||||
except ImportError:
|
||||
HAS_FASTER_WHISPER = False
|
||||
else:
|
||||
WhisperModel = None
|
||||
|
||||
|
||||
# TOO_MANY_REPETITIONS = 3
|
||||
MIN_DURATION_REAL_SILENCE = 5
|
||||
|
||||
class SimulStreamingOnlineProcessor:
|
||||
SAMPLING_RATE = 16000
|
||||
@@ -44,13 +43,11 @@ class SimulStreamingOnlineProcessor:
|
||||
self,
|
||||
asr,
|
||||
logfile=sys.stderr,
|
||||
warmup_file=None
|
||||
):
|
||||
self.asr = asr
|
||||
self.logfile = logfile
|
||||
self.end = 0.0
|
||||
self.global_time_offset = 0.0
|
||||
|
||||
self.buffer = []
|
||||
self.committed: List[ASRToken] = []
|
||||
self.last_result_tokens: List[ASRToken] = []
|
||||
self.load_new_backend()
|
||||
@@ -61,25 +58,31 @@ class SimulStreamingOnlineProcessor:
|
||||
|
||||
def load_new_backend(self):
|
||||
model = self.asr.get_new_model_instance()
|
||||
self.model = PaddedAlignAttWhisper(
|
||||
self.model = AlignAtt(
|
||||
cfg=self.asr.cfg,
|
||||
loaded_model=model,
|
||||
mlx_encoder=self.asr.mlx_encoder,
|
||||
fw_encoder=self.asr.fw_encoder,
|
||||
)
|
||||
|
||||
def insert_silence(self, silence_duration, offset):
|
||||
def start_silence(self):
|
||||
tokens, processed_upto = self.process_iter(is_last=True)
|
||||
return tokens, processed_upto
|
||||
|
||||
def end_silence(self, silence_duration, offset):
|
||||
"""
|
||||
If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame
|
||||
If silences are > MIN_DURATION_REAL_SILENCE, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame
|
||||
"""
|
||||
if silence_duration < 5:
|
||||
gap_silence = torch.zeros(int(16000*silence_duration))
|
||||
self.end += silence_duration
|
||||
long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
|
||||
if not long_silence:
|
||||
gap_len = int(16000 * silence_duration)
|
||||
if gap_len > 0:
|
||||
gap_silence = torch.zeros(gap_len)
|
||||
self.model.insert_audio(gap_silence)
|
||||
# self.global_time_offset += silence_duration
|
||||
else:
|
||||
self.process_iter(is_last=True) #we want to totally process what remains in the buffer.
|
||||
if long_silence:
|
||||
self.model.refresh_segment(complete=True)
|
||||
self.global_time_offset = silence_duration + offset
|
||||
self.model.global_time_offset = silence_duration + offset
|
||||
|
||||
|
||||
|
||||
@@ -91,63 +94,15 @@ class SimulStreamingOnlineProcessor:
|
||||
self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend.
|
||||
self.model.insert_audio(audio_tensor)
|
||||
|
||||
def new_speaker(self, change_speaker: ChangeSpeaker):
|
||||
self.process_iter(is_last=True)
|
||||
self.model.refresh_segment(complete=True)
|
||||
self.model.speaker = change_speaker.speaker
|
||||
self.global_time_offset = change_speaker.start
|
||||
|
||||
def get_buffer(self):
|
||||
return Transcript(
|
||||
start=None,
|
||||
end=None,
|
||||
text='',
|
||||
probability=None
|
||||
)
|
||||
|
||||
def timestamped_text(self, tokens, generation):
|
||||
"""
|
||||
generate timestamped text from tokens and generation data.
|
||||
|
||||
args:
|
||||
tokens: List of tokens to process
|
||||
generation: Dictionary containing generation progress and optionally results
|
||||
|
||||
returns:
|
||||
List of tuples containing (start_time, end_time, word) for each word
|
||||
"""
|
||||
FRAME_DURATION = 0.02
|
||||
if "result" in generation:
|
||||
split_words = generation["result"]["split_words"]
|
||||
split_tokens = generation["result"]["split_tokens"]
|
||||
else:
|
||||
split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
|
||||
progress = generation["progress"]
|
||||
frames = [p["most_attended_frames"][0] for p in progress]
|
||||
absolute_timestamps = [p["absolute_timestamps"][0] for p in progress]
|
||||
tokens_queue = tokens.copy()
|
||||
timestamped_words = []
|
||||
|
||||
for word, word_tokens in zip(split_words, split_tokens):
|
||||
# start_frame = None
|
||||
# end_frame = None
|
||||
for expected_token in word_tokens:
|
||||
if not tokens_queue or not frames:
|
||||
raise ValueError(f"Insufficient tokens or frames for word '{word}'")
|
||||
|
||||
actual_token = tokens_queue.pop(0)
|
||||
current_frame = frames.pop(0)
|
||||
current_timestamp = absolute_timestamps.pop(0)
|
||||
if actual_token != expected_token:
|
||||
raise ValueError(
|
||||
f"Token mismatch: expected '{expected_token}', "
|
||||
f"got '{actual_token}' at frame {current_frame}"
|
||||
)
|
||||
# if start_frame is None:
|
||||
# start_frame = current_frame
|
||||
# end_frame = current_frame
|
||||
# start_time = start_frame * FRAME_DURATION
|
||||
# end_time = end_frame * FRAME_DURATION
|
||||
start_time = current_timestamp
|
||||
end_time = current_timestamp + 0.1
|
||||
timestamp_entry = (start_time, end_time, word)
|
||||
timestamped_words.append(timestamp_entry)
|
||||
logger.debug(f"TS-WORD:\t{start_time:.2f}\t{end_time:.2f}\t{word}")
|
||||
return timestamped_words
|
||||
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
|
||||
return concat_buffer
|
||||
|
||||
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
||||
"""
|
||||
@@ -156,47 +111,14 @@ class SimulStreamingOnlineProcessor:
|
||||
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
||||
"""
|
||||
try:
|
||||
tokens, generation_progress = self.model.infer(is_last=is_last)
|
||||
ts_words = self.timestamped_text(tokens, generation_progress)
|
||||
timestamped_words = self.model.infer(is_last=is_last)
|
||||
if self.model.cfg.language == "auto" and timestamped_words and timestamped_words[0].detected_language == None:
|
||||
self.buffer.extend(timestamped_words)
|
||||
return [], self.end
|
||||
|
||||
new_tokens = []
|
||||
for ts_word in ts_words:
|
||||
|
||||
start, end, word = ts_word
|
||||
token = ASRToken(
|
||||
start=start,
|
||||
end=end,
|
||||
text=word,
|
||||
probability=0.95 # fake prob. Maybe we can extract it from the model?
|
||||
).with_offset(
|
||||
self.global_time_offset
|
||||
)
|
||||
new_tokens.append(token)
|
||||
|
||||
# identical_tokens = 0
|
||||
# n_new_tokens = len(new_tokens)
|
||||
# if n_new_tokens:
|
||||
|
||||
self.committed.extend(new_tokens)
|
||||
|
||||
# if token in self.committed:
|
||||
# pos = len(self.committed) - 1 - self.committed[::-1].index(token)
|
||||
# if pos:
|
||||
# for i in range(len(self.committed) - n_new_tokens, -1, -n_new_tokens):
|
||||
# commited_segment = self.committed[i:i+n_new_tokens]
|
||||
# if commited_segment == new_tokens:
|
||||
# identical_segments +=1
|
||||
# if identical_tokens >= TOO_MANY_REPETITIONS:
|
||||
# logger.warning('Too many repetition, model is stuck. Load a new one')
|
||||
# self.committed = self.committed[:i]
|
||||
# self.load_new_backend()
|
||||
# return [], self.end
|
||||
|
||||
# pos = self.committed.rindex(token)
|
||||
|
||||
|
||||
|
||||
return new_tokens, self.end
|
||||
self.committed.extend(timestamped_words)
|
||||
self.buffer = []
|
||||
return timestamped_words, self.end
|
||||
|
||||
|
||||
except Exception as e:
|
||||
@@ -225,32 +147,34 @@ class SimulStreamingASR():
|
||||
"""SimulStreaming backend with AlignAtt policy."""
|
||||
sep = ""
|
||||
|
||||
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
|
||||
logger.warning(SIMULSTREAMING_LICENSE)
|
||||
def __init__(self, logfile=sys.stderr, **kwargs):
|
||||
self.logfile = logfile
|
||||
self.transcribe_kargs = {}
|
||||
self.original_language = lan
|
||||
|
||||
self.model_path = kwargs.get('model_path', './large-v3.pt')
|
||||
self.frame_threshold = kwargs.get('frame_threshold', 25)
|
||||
self.audio_max_len = kwargs.get('audio_max_len', 20.0)
|
||||
self.audio_min_len = kwargs.get('audio_min_len', 0.0)
|
||||
self.segment_length = kwargs.get('segment_length', 0.5)
|
||||
self.beams = kwargs.get('beams', 1)
|
||||
self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
|
||||
self.task = kwargs.get('task', 'transcribe')
|
||||
self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
|
||||
self.never_fire = kwargs.get('never_fire', False)
|
||||
self.init_prompt = kwargs.get('init_prompt', None)
|
||||
self.static_init_prompt = kwargs.get('static_init_prompt', None)
|
||||
self.max_context_tokens = kwargs.get('max_context_tokens', None)
|
||||
self.warmup_file = kwargs.get('warmup_file', None)
|
||||
self.preload_model_count = kwargs.get('preload_model_count', 1)
|
||||
self.disable_fast_encoder = kwargs.get('disable_fast_encoder', False)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
if self.decoder_type is None:
|
||||
self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
|
||||
|
||||
self.fast_encoder = False
|
||||
if model_dir is not None:
|
||||
self.model_path = model_dir
|
||||
elif modelsize is not None:
|
||||
self._resolved_model_path = None
|
||||
self.encoder_backend = "whisper"
|
||||
preferred_backend = getattr(self, "backend", "auto")
|
||||
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
|
||||
if self.model_path:
|
||||
resolved_model_path = resolve_model_path(self.model_path)
|
||||
self._resolved_model_path = resolved_model_path
|
||||
self.model_path = str(resolved_model_path)
|
||||
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(resolved_model_path)
|
||||
if self.pytorch_path:
|
||||
self.model_name = self.pytorch_path.stem
|
||||
else:
|
||||
self.model_name = Path(self.model_path).stem
|
||||
raise FileNotFoundError(
|
||||
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
|
||||
)
|
||||
elif self.model_size is not None:
|
||||
model_mapping = {
|
||||
'tiny': './tiny.pt',
|
||||
'base': './base.pt',
|
||||
@@ -265,19 +189,32 @@ class SimulStreamingASR():
|
||||
'large-v3': './large-v3.pt',
|
||||
'large': './large-v3.pt'
|
||||
}
|
||||
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
|
||||
self.model_name = self.model_size
|
||||
else:
|
||||
raise ValueError("Either model_size or model_path must be specified for SimulStreaming.")
|
||||
|
||||
is_multilingual = not self.model_name.endswith(".en")
|
||||
|
||||
self.encoder_backend = self._resolve_encoder_backend(
|
||||
preferred_backend,
|
||||
compatible_whisper_mlx,
|
||||
compatible_faster_whisper,
|
||||
)
|
||||
self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper")
|
||||
if self.encoder_backend == "whisper":
|
||||
self.disable_fast_encoder = True
|
||||
|
||||
self.cfg = AlignAttConfig(
|
||||
model_path=self.model_path,
|
||||
segment_length=self.segment_length,
|
||||
tokenizer_is_multilingual= is_multilingual,
|
||||
segment_length=self.min_chunk_size,
|
||||
frame_threshold=self.frame_threshold,
|
||||
language=self.original_language,
|
||||
language=self.lan,
|
||||
audio_max_len=self.audio_max_len,
|
||||
audio_min_len=self.audio_min_len,
|
||||
cif_ckpt_path=self.cif_ckpt_path,
|
||||
decoder_type="beam",
|
||||
beam_size=self.beams,
|
||||
task=self.task,
|
||||
task=self.direct_english_translation,
|
||||
never_fire=self.never_fire,
|
||||
init_prompt=self.init_prompt,
|
||||
max_context_tokens=self.max_context_tokens,
|
||||
@@ -285,40 +222,93 @@ class SimulStreamingASR():
|
||||
)
|
||||
|
||||
# Set up tokenizer for translation if needed
|
||||
if self.task == "translate":
|
||||
if self.direct_english_translation:
|
||||
self.tokenizer = self.set_translate_task()
|
||||
else:
|
||||
self.tokenizer = None
|
||||
|
||||
self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
|
||||
self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
|
||||
|
||||
|
||||
|
||||
self.mlx_encoder, self.fw_encoder = None, None
|
||||
if not self.disable_fast_encoder:
|
||||
if HAS_MLX_WHISPER:
|
||||
print('Simulstreaming will use MLX whisper for a faster encoder.')
|
||||
mlx_model_name = mlx_model_mapping[self.model_name]
|
||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name)
|
||||
self.fast_encoder = True
|
||||
elif HAS_FASTER_WHISPER:
|
||||
if self.encoder_backend == "mlx-whisper":
|
||||
print('Simulstreaming will use MLX whisper to increase encoding speed.')
|
||||
if self._resolved_model_path is not None:
|
||||
mlx_model = str(self._resolved_model_path)
|
||||
else:
|
||||
mlx_model = mlx_model_mapping.get(self.model_name)
|
||||
if not mlx_model:
|
||||
raise FileNotFoundError(
|
||||
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
|
||||
)
|
||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
|
||||
elif self.encoder_backend == "faster-whisper":
|
||||
print('Simulstreaming will use Faster Whisper for the encoder.')
|
||||
if self._resolved_model_path is not None:
|
||||
fw_model = str(self._resolved_model_path)
|
||||
else:
|
||||
fw_model = self.model_name
|
||||
self.fw_encoder = WhisperModel(
|
||||
self.model_name,
|
||||
fw_model,
|
||||
device='auto',
|
||||
compute_type='auto',
|
||||
)
|
||||
self.fast_encoder = True
|
||||
|
||||
self.models = [self.load_model() for i in range(self.preload_model_count)]
|
||||
|
||||
|
||||
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
|
||||
choice = preferred_backend or "auto"
|
||||
if self.disable_fast_encoder:
|
||||
return "whisper"
|
||||
if choice == "whisper":
|
||||
return "whisper"
|
||||
if choice == "mlx-whisper":
|
||||
if not self._can_use_mlx(compatible_whisper_mlx):
|
||||
raise RuntimeError("mlx-whisper backend requested but MLX Whisper is unavailable or incompatible with the provided model.")
|
||||
return "mlx-whisper"
|
||||
if choice == "faster-whisper":
|
||||
if not self._can_use_faster(compatible_faster_whisper):
|
||||
raise RuntimeError("faster-whisper backend requested but Faster-Whisper is unavailable or incompatible with the provided model.")
|
||||
return "faster-whisper"
|
||||
if choice == "openai-api":
|
||||
raise ValueError("openai-api backend is only supported with the LocalAgreement policy.")
|
||||
# auto mode
|
||||
if platform.system() == "Darwin" and self._can_use_mlx(compatible_whisper_mlx):
|
||||
return "mlx-whisper"
|
||||
if self._can_use_faster(compatible_faster_whisper):
|
||||
return "faster-whisper"
|
||||
return "whisper"
|
||||
|
||||
def _has_custom_model_path(self):
|
||||
return self._resolved_model_path is not None
|
||||
|
||||
def _can_use_mlx(self, compatible_whisper_mlx):
|
||||
if not HAS_MLX_WHISPER:
|
||||
return False
|
||||
if self._has_custom_model_path():
|
||||
return compatible_whisper_mlx
|
||||
return self.model_name in mlx_model_mapping
|
||||
|
||||
def _can_use_faster(self, compatible_faster_whisper):
|
||||
if not HAS_FASTER_WHISPER:
|
||||
return False
|
||||
if self._has_custom_model_path():
|
||||
return compatible_faster_whisper
|
||||
return True
|
||||
|
||||
def load_model(self):
|
||||
whisper_model = load_model(name=self.model_name, download_root=self.model_path, decoder_only=self.fast_encoder)
|
||||
whisper_model = load_model(
|
||||
name=self.pytorch_path if self.pytorch_path else self.model_name,
|
||||
download_root=self.model_path,
|
||||
decoder_only=self.fast_encoder,
|
||||
custom_alignment_heads=self.custom_alignment_heads
|
||||
)
|
||||
warmup_audio = load_file(self.warmup_file)
|
||||
if warmup_audio is not None:
|
||||
warmup_audio = torch.from_numpy(warmup_audio).float()
|
||||
if self.fast_encoder:
|
||||
temp_model = PaddedAlignAttWhisper(
|
||||
temp_model = AlignAtt(
|
||||
cfg=self.cfg,
|
||||
loaded_model=whisper_model,
|
||||
mlx_encoder=self.mlx_encoder,
|
||||
@@ -329,7 +319,7 @@ class SimulStreamingASR():
|
||||
else:
|
||||
# For standard encoder, use the original transcribe warmup
|
||||
warmup_audio = load_file(self.warmup_file)
|
||||
whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None)
|
||||
whisper_model.transcribe(warmup_audio, language=self.lan if self.lan != 'auto' else None)
|
||||
return whisper_model
|
||||
|
||||
def get_new_model_instance(self):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .whisper.decoding import PyTorchInference
|
||||
from whisperlivekit.whisper.decoding import PyTorchInference
|
||||
|
||||
# extention of PyTorchInference for beam search
|
||||
class BeamPyTorchInference(PyTorchInference):
|
||||
|
||||
@@ -1,25 +1,8 @@
|
||||
# This code was originally in simul_whisper/transcriber/simul_whisper.py . It is adapted a lot for SimulStreaming.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
@dataclass
|
||||
class SimulWhisperConfig:
|
||||
'''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
|
||||
model_path: str
|
||||
language: str = field(default="zh")
|
||||
nonspeech_prob: float = 0.5
|
||||
audio_min_len: float = 1.0
|
||||
decoder_type: Literal["greedy","beam"] = "greedy"
|
||||
beam_size: int = 5
|
||||
task: Literal["transcribe","translate"] = "transcribe"
|
||||
init_prompt: str = field(default=None)
|
||||
static_init_prompt: str = field(default=None)
|
||||
max_context_tokens: int = field(default=None)
|
||||
|
||||
@dataclass
|
||||
class AlignAttConfig(SimulWhisperConfig):
|
||||
'''Options specific to the AlignAtt policy.'''
|
||||
class AlignAttConfig():
|
||||
eval_data_path: str = "tmp"
|
||||
segment_length: float = field(default=1.0, metadata = {"help": "in second"})
|
||||
frame_threshold: int = 4
|
||||
@@ -27,3 +10,14 @@ class AlignAttConfig(SimulWhisperConfig):
|
||||
audio_max_len: float = 20.0
|
||||
cif_ckpt_path: str = ""
|
||||
never_fire: bool = False
|
||||
language: str = field(default="zh")
|
||||
nonspeech_prob: float = 0.5
|
||||
audio_min_len: float = 1.0
|
||||
decoder_type: Literal["greedy","beam"] = "greedy"
|
||||
beam_size: int = 5
|
||||
task: Literal["transcribe","translate"] = "transcribe"
|
||||
tokenizer_is_multilingual: bool = False
|
||||
init_prompt: str = field(default=None)
|
||||
static_init_prompt: str = field(default=None)
|
||||
max_context_tokens: int = field(default=None)
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
class Tokens:
|
||||
def __init__(self, tokens):
|
||||
self.tokens = tokens
|
||||
|
||||
# def clone(self):
|
||||
# return Tokens(self.tokens.clone())
|
||||
|
||||
def __str__(self):
|
||||
return str(self.tokens.tolist())
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
class BeamTokens(Tokens):
|
||||
def __init__(self, tokens, beam_size):
|
||||
self.tokens = tokens
|
||||
self.beam_size = beam_size
|
||||
|
||||
def clone(self):
|
||||
return BeamTokens(self.tokens.clone())
|
||||
|
||||
def __str__(self):
|
||||
return f"BeamTokens({self.tokens.tolist()}, beam_size={self.beam_size})"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def as_text(self, tokenizer):
|
||||
return tokenizer.decode(self.tokens)
|
||||
|
||||
class Logits(Tokens):
|
||||
def __init__(self, logits):
|
||||
super().__init__(logits)
|
||||
|
||||
# def clone(self):
|
||||
# return Logits(self.tokens.clone(), self.beam_size)
|
||||
|
||||
def __str__(self):
|
||||
# return "abc"
|
||||
return f"Logits({self.tokens.shape})"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
@@ -1,5 +0,0 @@
|
||||
SIMULSTREAMING_LICENSE = f"""
|
||||
SimulStreaming backend is dual-licensed:
|
||||
• Non-Commercial Use: PolyForm Noncommercial License 1.0.0.
|
||||
• Commercial Use: Check SimulStreaming README (github.com/ufal/SimulStreaming) for more details.
|
||||
"""
|
||||
@@ -1,52 +1,57 @@
|
||||
# This code was originally in simul_whisper/transcriber/simul_whisper.py . It is adapted a lot for SimulStreaming.
|
||||
|
||||
import os
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from .whisper import load_model, DecodingOptions, tokenizer
|
||||
from whisperlivekit.whisper import DecodingOptions, tokenizer
|
||||
from .config import AlignAttConfig
|
||||
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
|
||||
from .whisper.timing import median_filter
|
||||
from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from whisperlivekit.whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
|
||||
from whisperlivekit.whisper.timing import median_filter
|
||||
from whisperlivekit.whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens
|
||||
from .beam import BeamPyTorchInference
|
||||
from .eow_detection import fire_at_boundary, load_cif
|
||||
import os
|
||||
from time import time
|
||||
from .token_buffer import TokenBuffer
|
||||
from whisperlivekit.backend_support import (
|
||||
mlx_backend_available,
|
||||
faster_backend_available,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from .generation_progress import *
|
||||
from ..timed_objects import PUNCTUATION_MARKS
|
||||
|
||||
DEC_PAD = 50257
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
try:
|
||||
if mlx_backend_available():
|
||||
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
|
||||
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
||||
HAS_MLX_WHISPER = True
|
||||
except ImportError:
|
||||
HAS_MLX_WHISPER = False
|
||||
if HAS_MLX_WHISPER:
|
||||
HAS_FASTER_WHISPER = False
|
||||
else:
|
||||
try:
|
||||
|
||||
if faster_backend_available():
|
||||
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
|
||||
from faster_whisper.feature_extractor import FeatureExtractor
|
||||
HAS_FASTER_WHISPER = True
|
||||
except ImportError:
|
||||
HAS_FASTER_WHISPER = False
|
||||
|
||||
# New features added to the original version of Simul-Whisper:
|
||||
# - large-v3 model support
|
||||
# - translation support
|
||||
# - beam search
|
||||
# - prompt -- static vs. non-static
|
||||
# - context
|
||||
class PaddedAlignAttWhisper:
|
||||
USE_MLCORE = False
|
||||
|
||||
|
||||
def load_coreml_encoder():
|
||||
try:
|
||||
from coremltools.models import MLModel
|
||||
except ImportError:
|
||||
logger.warning("coremltools is not installed")
|
||||
return None
|
||||
COREML_ENCODER_PATH = os.environ.get("MLCORE_ENCODER_PATH", "whisperlivekit/whisper/whisper_encoder.mlpackage")
|
||||
_coreml_encoder = MLModel(COREML_ENCODER_PATH)
|
||||
spec = _coreml_encoder.get_spec()
|
||||
_coreml_input_name = spec.description.input[0].name if spec.description.input else "mel"
|
||||
_coreml_output_name = spec.description.output[0].name if spec.description.output else None
|
||||
return _coreml_encoder, _coreml_input_name, _coreml_output_name
|
||||
|
||||
|
||||
class AlignAtt:
|
||||
def __init__(
|
||||
self,
|
||||
cfg: AlignAttConfig,
|
||||
@@ -55,30 +60,32 @@ class PaddedAlignAttWhisper:
|
||||
fw_encoder=None,
|
||||
) -> None:
|
||||
self.log_segments = 0
|
||||
model_name = os.path.basename(cfg.model_path).replace(".pt", "")
|
||||
model_path = os.path.dirname(os.path.abspath(cfg.model_path))
|
||||
if loaded_model:
|
||||
|
||||
self.model = loaded_model
|
||||
else:
|
||||
self.model = load_model(name=model_name, download_root=model_path)
|
||||
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
self.mlx_encoder = mlx_encoder
|
||||
self.fw_encoder = fw_encoder
|
||||
if fw_encoder:
|
||||
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
|
||||
self.coreml_encoder_tuple = None
|
||||
if USE_MLCORE:
|
||||
self.coreml_encoder_tuple = load_coreml_encoder()
|
||||
self.use_mlcore = self.coreml_encoder_tuple is not None
|
||||
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
logger.info(f"Model dimensions: {self.model.dims}")
|
||||
|
||||
self.speaker = -1
|
||||
self.decode_options = DecodingOptions(
|
||||
language = cfg.language,
|
||||
without_timestamps = True,
|
||||
task=cfg.task
|
||||
)
|
||||
self.tokenizer_is_multilingual = not model_name.endswith(".en")
|
||||
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
|
||||
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
||||
# self.create_tokenizer('en')
|
||||
self.detected_language = cfg.language if cfg.language != "auto" else None
|
||||
self.global_time_offset = 0.0
|
||||
self.reset_tokenizer_to_auto_next_call = False
|
||||
|
||||
self.max_text_len = self.model.dims.n_text_ctx
|
||||
self.num_decoder_layers = len(self.model.decoder.blocks)
|
||||
@@ -153,6 +160,7 @@ class PaddedAlignAttWhisper:
|
||||
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.cumulative_time_offset = 0.0
|
||||
self.first_timestamp = None
|
||||
|
||||
if self.cfg.max_context_tokens is None:
|
||||
self.max_context_tokens = self.max_text_len
|
||||
@@ -173,6 +181,9 @@ class PaddedAlignAttWhisper:
|
||||
|
||||
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
|
||||
|
||||
# Tokens to carry over to next chunk for incomplete UTF-8 characters
|
||||
self.pending_incomplete_tokens = []
|
||||
|
||||
def remove_hooks(self):
|
||||
for hook in self.l_hooks:
|
||||
hook.remove()
|
||||
@@ -255,18 +266,18 @@ class PaddedAlignAttWhisper:
|
||||
logger.debug("Refreshing segment:")
|
||||
self.init_tokens()
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.detected_language = None
|
||||
# self.detected_language = None
|
||||
self.cumulative_time_offset = 0.0
|
||||
self.init_context()
|
||||
logger.debug(f"Context: {self.context}")
|
||||
if not complete and len(self.segments) > 2:
|
||||
logger.debug("keeping last two segments because they are and it is not complete.")
|
||||
self.segments = self.segments[-2:]
|
||||
else:
|
||||
logger.debug("removing all segments.")
|
||||
self.segments = []
|
||||
self.log_segments += 1
|
||||
|
||||
self.pending_incomplete_tokens = []
|
||||
|
||||
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
|
||||
if self.always_fire: return True
|
||||
@@ -328,7 +339,7 @@ class PaddedAlignAttWhisper:
|
||||
self.segments = self.segments[1:]
|
||||
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s")
|
||||
if len(self.tokens) > 1:
|
||||
self.context.append_token_ids(self.tokens[1][0,:])
|
||||
self.context.append_token_ids(self.tokens[1][0,:].tolist())
|
||||
self.tokens = [self.initial_tokens] + self.tokens[2:]
|
||||
return removed_len
|
||||
|
||||
@@ -382,11 +393,11 @@ class PaddedAlignAttWhisper:
|
||||
new_segment = True
|
||||
if len(self.segments) == 0:
|
||||
logger.debug("No segments, nothing to do")
|
||||
return [], {}
|
||||
return []
|
||||
if not self._apply_minseglen():
|
||||
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
||||
input_segments = torch.cat(self.segments, dim=0)
|
||||
return [], {}
|
||||
return []
|
||||
|
||||
# input_segments is concatenation of audio, it's one array
|
||||
if len(self.segments) > 1:
|
||||
@@ -394,8 +405,28 @@ class PaddedAlignAttWhisper:
|
||||
else:
|
||||
input_segments = self.segments[0]
|
||||
|
||||
# NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder
|
||||
beg_encode = time()
|
||||
if self.use_mlcore:
|
||||
coreml_encoder, coreml_input_name, coreml_output_name = self.coreml_encoder_tuple
|
||||
mel_padded = log_mel_spectrogram(
|
||||
input_segments,
|
||||
n_mels=self.model.dims.n_mels,
|
||||
padding=N_SAMPLES,
|
||||
device="cpu",
|
||||
).unsqueeze(0)
|
||||
mel = pad_or_trim(mel_padded, N_FRAMES)
|
||||
content_mel_len = int((mel_padded.shape[2] - mel.shape[2]) / 2)
|
||||
mel_np = np.ascontiguousarray(mel.numpy())
|
||||
ml_inputs = {coreml_input_name or "mel": mel_np}
|
||||
coreml_outputs = coreml_encoder.predict(ml_inputs)
|
||||
if coreml_output_name and coreml_output_name in coreml_outputs:
|
||||
encoder_feature_np = coreml_outputs[coreml_output_name]
|
||||
else:
|
||||
encoder_feature_np = next(iter(coreml_outputs.values()))
|
||||
encoder_feature = torch.as_tensor(
|
||||
np.array(encoder_feature_np),
|
||||
device=self.device,
|
||||
)
|
||||
if self.mlx_encoder:
|
||||
mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
|
||||
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
|
||||
@@ -426,58 +457,38 @@ class PaddedAlignAttWhisper:
|
||||
end_encode = time()
|
||||
# print('Encoder duration:', end_encode-beg_encode)
|
||||
|
||||
# logger.debug(f"Encoder feature shape: {encoder_feature.shape}")
|
||||
# if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
|
||||
# logger.debug("mel ")
|
||||
if self.cfg.language == "auto" and self.detected_language is None:
|
||||
if self.cfg.language == "auto" and self.detected_language is None and self.first_timestamp:
|
||||
seconds_since_start = self.segments_len() - self.first_timestamp
|
||||
if seconds_since_start >= 2.0:
|
||||
language_tokens, language_probs = self.lang_id(encoder_feature)
|
||||
logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
|
||||
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
||||
logger.info(f"Detected language: {top_lan} with p={p:.4f}")
|
||||
#self.tokenizer.language = top_lan
|
||||
#self.tokenizer.__post_init__()
|
||||
print(f"Detected language: {top_lan} with p={p:.4f}")
|
||||
self.create_tokenizer(top_lan)
|
||||
self.detected_language = top_lan
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.cumulative_time_offset = 0.0
|
||||
self.init_tokens()
|
||||
self.init_context()
|
||||
self.detected_language = top_lan
|
||||
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
|
||||
|
||||
self.trim_context()
|
||||
current_tokens = self._current_tokens()
|
||||
#
|
||||
|
||||
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
|
||||
|
||||
|
||||
####################### Decoding loop
|
||||
logger.info("Decoding loop starts\n")
|
||||
|
||||
sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device)
|
||||
completed = False
|
||||
# punctuation_stop = False
|
||||
|
||||
attn_of_alignment_heads = None
|
||||
most_attended_frame = None
|
||||
|
||||
token_len_before_decoding = current_tokens.shape[1]
|
||||
|
||||
generation_progress = []
|
||||
generation = {
|
||||
"starting_tokens": BeamTokens(current_tokens[0,:].clone(), self.cfg.beam_size),
|
||||
"token_len_before_decoding": token_len_before_decoding,
|
||||
#"fire_detected": fire_detected,
|
||||
"frames_len": content_mel_len,
|
||||
"frames_threshold": 4 if is_last else self.cfg.frame_threshold,
|
||||
l_absolute_timestamps = []
|
||||
|
||||
# to be filled later
|
||||
"logits_starting": None,
|
||||
|
||||
# to be filled later
|
||||
"no_speech_prob": None,
|
||||
"no_speech": False,
|
||||
|
||||
# to be filled in the loop
|
||||
"progress": generation_progress,
|
||||
}
|
||||
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
|
||||
generation_progress_loop = []
|
||||
|
||||
if new_segment:
|
||||
tokens_for_logits = current_tokens
|
||||
@@ -486,50 +497,26 @@ class PaddedAlignAttWhisper:
|
||||
tokens_for_logits = current_tokens[:,-1:]
|
||||
|
||||
logits = self.logits(tokens_for_logits, encoder_feature) # B, len(tokens), token dict size
|
||||
if new_segment:
|
||||
generation["logits_starting"] = Logits(logits[:,:,:])
|
||||
|
||||
if new_segment and self.tokenizer.no_speech is not None:
|
||||
probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
|
||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||
generation["no_speech_prob"] = no_speech_probs[0]
|
||||
if no_speech_probs[0] > self.cfg.nonspeech_prob:
|
||||
generation["no_speech"] = True
|
||||
logger.info("no speech, stop")
|
||||
break
|
||||
|
||||
logits = logits[:, -1, :] # logits for the last token
|
||||
generation_progress_loop.append(("logits_before_suppress",Logits(logits)))
|
||||
|
||||
# supress blank tokens only at the beginning of the segment
|
||||
if new_segment:
|
||||
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
||||
new_segment = False
|
||||
self.suppress_tokens(logits)
|
||||
#generation_progress_loop.append(("logits_after_suppres",BeamLogits(logits[0,:].clone(), self.cfg.beam_size)))
|
||||
generation_progress_loop.append(("logits_after_suppress",Logits(logits)))
|
||||
|
||||
current_tokens, completed = self.token_decoder.update(current_tokens, logits, sum_logprobs)
|
||||
generation_progress_loop.append(("beam_tokens",Tokens(current_tokens[:,-1].clone())))
|
||||
generation_progress_loop.append(("sum_logprobs",sum_logprobs.tolist()))
|
||||
generation_progress_loop.append(("completed",completed))
|
||||
|
||||
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
|
||||
self.debug_print_tokens(current_tokens)
|
||||
|
||||
|
||||
# if self.decoder_type == "beam":
|
||||
# logger.debug(f"Finished sequences: {self.token_decoder.finished_sequences}")
|
||||
|
||||
# logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
# idx = 0
|
||||
# logger.debug(f"Beam search topk: {logprobs[idx].topk(self.cfg.beam_size + 1)}")
|
||||
# logger.debug(f"Greedy search argmax: {logits.argmax(dim=-1)}")
|
||||
# if completed:
|
||||
# self.debug_print_tokens(current_tokens)
|
||||
|
||||
# logger.debug("decode stopped because decoder completed")
|
||||
|
||||
attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
|
||||
for i, attn_mat in enumerate(self.dec_attns):
|
||||
layer_rank = int(i % len(self.model.decoder.blocks))
|
||||
@@ -548,30 +535,24 @@ class PaddedAlignAttWhisper:
|
||||
t = torch.cat(mat, dim=1)
|
||||
tmp.append(t)
|
||||
attn_of_alignment_heads = torch.stack(tmp, dim=1)
|
||||
# logger.debug(str(attn_of_alignment_heads.shape) + " tttady")
|
||||
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
|
||||
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / std
|
||||
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7) # from whisper.timing
|
||||
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
|
||||
# logger.debug(str(attn_of_alignment_heads.shape) + " po mean")
|
||||
attn_of_alignment_heads = attn_of_alignment_heads[:,:, :content_mel_len]
|
||||
# logger.debug(str(attn_of_alignment_heads.shape) + " pak ")
|
||||
|
||||
# for each beam, the most attended frame is:
|
||||
most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
|
||||
generation_progress_loop.append(("most_attended_frames",most_attended_frames.clone().tolist()))
|
||||
|
||||
# Calculate absolute timestamps accounting for cumulative offset
|
||||
absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()]
|
||||
generation_progress_loop.append(("absolute_timestamps", absolute_timestamps))
|
||||
|
||||
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
|
||||
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)")
|
||||
|
||||
most_attended_frame = most_attended_frames[0].item()
|
||||
l_absolute_timestamps.append(absolute_timestamps[0])
|
||||
|
||||
|
||||
generation_progress.append(dict(generation_progress_loop))
|
||||
logger.debug("current tokens" + str(current_tokens.shape))
|
||||
if completed:
|
||||
# # stripping the last token, the eot
|
||||
@@ -609,66 +590,71 @@ class PaddedAlignAttWhisper:
|
||||
self.tokenizer.decode([current_tokens[i, -1].item()])
|
||||
))
|
||||
|
||||
# for k,v in generation.items():
|
||||
# print(k,v,file=sys.stderr)
|
||||
# for x in generation_progress:
|
||||
# for y in x.items():
|
||||
# print("\t\t",*y,file=sys.stderr)
|
||||
# print("\t","----", file=sys.stderr)
|
||||
# print("\t", "end of generation_progress_loop", file=sys.stderr)
|
||||
# sys.exit(1)
|
||||
####################### End of decoding loop
|
||||
|
||||
logger.info("End of decoding loop")
|
||||
|
||||
# if attn_of_alignment_heads is not None:
|
||||
# seg_len = int(segment.shape[0] / 16000 * TOKENS_PER_SECOND)
|
||||
|
||||
# # Lets' now consider only the top hypothesis in the beam search
|
||||
# top_beam_attn_of_alignment_heads = attn_of_alignment_heads[0]
|
||||
|
||||
# # debug print: how is the new token attended?
|
||||
# new_token_attn = top_beam_attn_of_alignment_heads[token_len_before_decoding:, -seg_len:]
|
||||
# logger.debug(f"New token attention shape: {new_token_attn.shape}")
|
||||
# if new_token_attn.shape[0] == 0: # it's not attended in the current audio segment
|
||||
# logger.debug("no token generated")
|
||||
# else: # it is, and the max attention is:
|
||||
# new_token_max_attn, _ = new_token_attn.max(dim=-1)
|
||||
# logger.debug(f"segment max attention: {new_token_max_attn.mean().item()/len(self.segments)}")
|
||||
|
||||
|
||||
# let's now operate only with the top beam hypothesis
|
||||
tokens_to_split = current_tokens[0, token_len_before_decoding:]
|
||||
if fire_detected or is_last:
|
||||
|
||||
# Prepend pending tokens from previous chunk if any
|
||||
if self.pending_incomplete_tokens:
|
||||
logger.debug(f"[UTF-8 Fix] Prepending {len(self.pending_incomplete_tokens)} pending tokens: {self.pending_incomplete_tokens}")
|
||||
pending_tensor = torch.tensor(self.pending_incomplete_tokens, dtype=torch.long, device=self.device)
|
||||
tokens_to_split = torch.cat([pending_tensor, tokens_to_split])
|
||||
|
||||
if fire_detected or is_last: #or punctuation_stop:
|
||||
new_hypothesis = tokens_to_split.flatten().tolist()
|
||||
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
|
||||
else:
|
||||
# going to truncate the tokens after the last space
|
||||
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split.tolist())
|
||||
generation["result"] = {"split_words": split_words[:-1], "split_tokens": split_tokens[:-1]}
|
||||
generation["result_truncated"] = {"split_words": split_words[-1:], "split_tokens": split_tokens[-1:]}
|
||||
|
||||
# text_to_split = self.tokenizer.decode(tokens_to_split)
|
||||
# logger.debug(f"text_to_split: {text_to_split}")
|
||||
# logger.debug("text at current step: {}".format(text_to_split.replace(" ", "<space>")))
|
||||
# text_before_space = " ".join(text_to_split.split(" ")[:-1])
|
||||
# logger.debug("before the last space: {}".format(text_before_space.replace(" ", "<space>")))
|
||||
if len(split_words) > 1:
|
||||
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
|
||||
else:
|
||||
new_hypothesis = []
|
||||
|
||||
|
||||
### new hypothesis
|
||||
logger.debug(f"new_hypothesis: {new_hypothesis}")
|
||||
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
|
||||
device=self.device,
|
||||
)
|
||||
self.tokens.append(new_tokens)
|
||||
# TODO: test if this is redundant or not
|
||||
# ret = ret[ret<DEC_PAD]
|
||||
|
||||
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
|
||||
|
||||
self._clean_cache()
|
||||
|
||||
return new_hypothesis, generation
|
||||
if len(l_absolute_timestamps) >=2 and self.first_timestamp is None:
|
||||
self.first_timestamp = l_absolute_timestamps[0]
|
||||
|
||||
|
||||
timestamped_words = []
|
||||
timestamp_idx = 0
|
||||
replacement_char = "\ufffd"
|
||||
for word, word_tokens in zip(split_words, split_tokens):
|
||||
# Skip words containing incomplete UTF-8 from client output
|
||||
if replacement_char in word:
|
||||
logger.warning(f"[UTF-8 Filter] Skipping incomplete word from client output: {repr(word)}")
|
||||
timestamp_idx += len(word_tokens)
|
||||
continue
|
||||
|
||||
try:
|
||||
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
||||
except:
|
||||
pass
|
||||
timestamp_idx += len(word_tokens)
|
||||
|
||||
timestamp_entry = ASRToken(
|
||||
start=round(current_timestamp, 2),
|
||||
end=round(current_timestamp + 0.1, 2),
|
||||
text= word,
|
||||
speaker=self.speaker,
|
||||
detected_language=self.detected_language
|
||||
).with_offset(
|
||||
self.global_time_offset
|
||||
)
|
||||
timestamped_words.append(timestamp_entry)
|
||||
|
||||
# Hold incomplete tokens for next chunk
|
||||
self.pending_incomplete_tokens = []
|
||||
if split_words and replacement_char in split_words[-1]:
|
||||
self.pending_incomplete_tokens = split_tokens[-1]
|
||||
logger.warning(f"[UTF-8 Fix] Holding {len(self.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.pending_incomplete_tokens}")
|
||||
|
||||
return timestamped_words
|
||||
|
||||
@@ -7,6 +7,7 @@ class TokenBuffer:
|
||||
self.prefix_token_ids = prefix_token_ids
|
||||
self.tokenizer = tokenizer
|
||||
self.device = device
|
||||
self.pending_token_ids = []
|
||||
|
||||
def as_token_ids(self, tokenizer=None):
|
||||
|
||||
@@ -64,7 +65,26 @@ class TokenBuffer:
|
||||
def append_token_ids(self, token_ids):
|
||||
tokenizer = self.tokenizer
|
||||
assert tokenizer is not None, "Tokenizer is not set."
|
||||
self.text += self.tokenizer.decode(token_ids)
|
||||
|
||||
all_tokens = self.pending_token_ids + token_ids
|
||||
|
||||
decoded = tokenizer.decode(all_tokens)
|
||||
replacement_char = "\ufffd"
|
||||
|
||||
if replacement_char in decoded:
|
||||
if len(all_tokens) > 1:
|
||||
decoded_partial = tokenizer.decode(all_tokens[:-1])
|
||||
|
||||
if replacement_char not in decoded_partial:
|
||||
self.text += decoded_partial
|
||||
self.pending_token_ids = [all_tokens[-1]]
|
||||
else:
|
||||
self.pending_token_ids = all_tokens
|
||||
else:
|
||||
self.pending_token_ids = all_tokens
|
||||
else:
|
||||
self.text += decoded
|
||||
self.pending_token_ids = []
|
||||
|
||||
def as_split_word_tokens(self):
|
||||
tokenizer = self.tokenizer
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||
from .model import ModelDimensions, Whisper
|
||||
from .transcribe import transcribe
|
||||
from .version import __version__
|
||||
|
||||
_MODELS = {
|
||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
||||
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
|
||||
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
|
||||
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
|
||||
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
|
||||
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
|
||||
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
|
||||
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
|
||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||
"large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
||||
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
||||
}
|
||||
|
||||
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
|
||||
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
|
||||
_ALIGNMENT_HEADS = {
|
||||
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
|
||||
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
|
||||
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
|
||||
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
|
||||
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
|
||||
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
|
||||
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
|
||||
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
|
||||
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
|
||||
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
||||
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||
"large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
||||
"turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
||||
}
|
||||
|
||||
|
||||
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
download_target = os.path.join(root, os.path.basename(url))
|
||||
|
||||
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
||||
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
with open(download_target, "rb") as f:
|
||||
model_bytes = f.read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return model_bytes if in_memory else download_target
|
||||
else:
|
||||
warnings.warn(
|
||||
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
||||
)
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(
|
||||
total=int(source.info().get("Content-Length")),
|
||||
ncols=80,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||
)
|
||||
|
||||
return model_bytes if in_memory else download_target
|
||||
|
||||
|
||||
def available_models() -> List[str]:
|
||||
"""Returns the names of available models"""
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load_model(
|
||||
name: str,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
download_root: str = None,
|
||||
in_memory: bool = False,
|
||||
decoder_only=False
|
||||
) -> Whisper:
|
||||
"""
|
||||
Load a Whisper ASR model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
one of the official model names listed by `whisper.available_models()`, or
|
||||
path to a model checkpoint containing the model dimensions and the model state_dict.
|
||||
device : Union[str, torch.device]
|
||||
the PyTorch device to put the model into
|
||||
download_root: str
|
||||
path to download the model files; by default, it uses "~/.cache/whisper"
|
||||
in_memory: bool
|
||||
whether to preload the model weights into host memory
|
||||
|
||||
Returns
|
||||
-------
|
||||
model : Whisper
|
||||
The Whisper ASR model instance
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if download_root is None:
|
||||
default = os.path.join(os.path.expanduser("~"), ".cache")
|
||||
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
|
||||
|
||||
if name in _MODELS:
|
||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||
alignment_heads = _ALIGNMENT_HEADS[name]
|
||||
elif os.path.isfile(name):
|
||||
checkpoint_file = open(name, "rb").read() if in_memory else name
|
||||
alignment_heads = None
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Model {name} not found; available models = {available_models()}"
|
||||
)
|
||||
|
||||
with (
|
||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
||||
) as fp:
|
||||
checkpoint = torch.load(fp, map_location=device)
|
||||
del checkpoint_file
|
||||
|
||||
dims = ModelDimensions(**checkpoint["dims"])
|
||||
model = Whisper(dims, decoder_only=decoder_only)
|
||||
|
||||
if decoder_only:
|
||||
checkpoint["model_state_dict"] = {
|
||||
k: v for k, v in checkpoint["model_state_dict"].items()
|
||||
if 'encoder' not in k
|
||||
}
|
||||
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
if alignment_heads is not None:
|
||||
model.set_alignment_heads(alignment_heads)
|
||||
|
||||
return model.to(device)
|
||||
@@ -1,26 +1,52 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from typing import Optional, List, Union, Dict, Any
|
||||
from datetime import timedelta
|
||||
|
||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||
|
||||
def format_time(seconds: float) -> str:
|
||||
"""Format seconds as HH:MM:SS."""
|
||||
return str(timedelta(seconds=int(seconds)))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimedText:
|
||||
class Timed:
|
||||
start: Optional[float] = 0
|
||||
end: Optional[float] = 0
|
||||
text: Optional[str] = ''
|
||||
speaker: Optional[int] = -1
|
||||
probability: Optional[float] = None
|
||||
is_dummy: Optional[bool] = False
|
||||
|
||||
@dataclass
|
||||
class TimedText(Timed):
|
||||
text: Optional[str] = ''
|
||||
speaker: Optional[int] = -1
|
||||
detected_language: Optional[str] = None
|
||||
|
||||
def has_punctuation(self) -> bool:
|
||||
return any(char in PUNCTUATION_MARKS for char in self.text.strip())
|
||||
|
||||
def is_within(self, other: 'TimedText') -> bool:
|
||||
return other.contains_timespan(self)
|
||||
|
||||
def duration(self) -> float:
|
||||
return self.end - self.start
|
||||
|
||||
def contains_timespan(self, other: 'TimedText') -> bool:
|
||||
return self.start <= other.start and self.end >= other.end
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.text)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.text)
|
||||
|
||||
@dataclass()
|
||||
class ASRToken(TimedText):
|
||||
|
||||
def with_offset(self, offset: float) -> "ASRToken":
|
||||
"""Return a new token with the time offset added."""
|
||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability)
|
||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language)
|
||||
|
||||
def is_silence(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class Sentence(TimedText):
|
||||
@@ -28,13 +54,35 @@ class Sentence(TimedText):
|
||||
|
||||
@dataclass
|
||||
class Transcript(TimedText):
|
||||
pass
|
||||
"""
|
||||
represents a concatenation of several ASRToken
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_tokens(
|
||||
cls,
|
||||
tokens: List[ASRToken],
|
||||
sep: Optional[str] = None,
|
||||
offset: float = 0
|
||||
) -> "Transcript":
|
||||
"""Collapse multiple ASR tokens into a single transcript span."""
|
||||
sep = sep if sep is not None else ' '
|
||||
text = sep.join(token.text for token in tokens)
|
||||
if tokens:
|
||||
start = offset + tokens[0].start
|
||||
end = offset + tokens[-1].end
|
||||
else:
|
||||
start = None
|
||||
end = None
|
||||
return cls(start, end, text)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeakerSegment(TimedText):
|
||||
class SpeakerSegment(Timed):
|
||||
"""Represents a segment of audio attributed to a specific speaker.
|
||||
No text nor probability is associated with this segment.
|
||||
"""
|
||||
speaker: Optional[int] = -1
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
@@ -43,21 +91,106 @@ class Translation(TimedText):
|
||||
|
||||
@dataclass
|
||||
class Silence():
|
||||
duration: float
|
||||
start: Optional[float] = None
|
||||
end: Optional[float] = None
|
||||
duration: Optional[float] = None
|
||||
is_starting: bool = False
|
||||
has_ended: bool = False
|
||||
|
||||
def compute_duration(self) -> Optional[float]:
|
||||
if self.start is None or self.end is None:
|
||||
return None
|
||||
self.duration = self.end - self.start
|
||||
return self.duration
|
||||
|
||||
def is_silence(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class Segment(TimedText):
|
||||
"""Generic contiguous span built from tokens or silence markers."""
|
||||
start: Optional[float]
|
||||
end: Optional[float]
|
||||
text: Optional[str]
|
||||
speaker: Optional[str]
|
||||
@classmethod
|
||||
def from_tokens(
|
||||
cls,
|
||||
tokens: List[Union[ASRToken, Silence]],
|
||||
is_silence: bool = False
|
||||
) -> Optional["Segment"]:
|
||||
"""Return a normalized segment representing the provided tokens."""
|
||||
if not tokens:
|
||||
return None
|
||||
|
||||
start_token = tokens[0]
|
||||
end_token = tokens[-1]
|
||||
if is_silence:
|
||||
return cls(
|
||||
start=start_token.start,
|
||||
end=end_token.end,
|
||||
text=None,
|
||||
speaker=-2
|
||||
)
|
||||
else:
|
||||
return cls(
|
||||
start=start_token.start,
|
||||
end=end_token.end,
|
||||
text=''.join(token.text for token in tokens),
|
||||
speaker=-1,
|
||||
detected_language=start_token.detected_language
|
||||
)
|
||||
def is_silence(self) -> bool:
|
||||
"""True when this segment represents a silence gap."""
|
||||
return self.speaker == -2
|
||||
|
||||
|
||||
@dataclass
|
||||
class Line(TimedText):
|
||||
translation: str = ''
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'speaker': int(self.speaker),
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serialize the line for frontend consumption."""
|
||||
_dict: Dict[str, Any] = {
|
||||
'speaker': int(self.speaker) if self.speaker != -1 else 1,
|
||||
'text': self.text,
|
||||
'translation': self.translation,
|
||||
'start': format_time(self.start),
|
||||
'end': format_time(self.end),
|
||||
}
|
||||
if self.translation:
|
||||
_dict['translation'] = self.translation
|
||||
if self.detected_language:
|
||||
_dict['detected_language'] = self.detected_language
|
||||
return _dict
|
||||
|
||||
def build_from_tokens(self, tokens: List[ASRToken]) -> "Line":
|
||||
"""Populate line attributes from a contiguous token list."""
|
||||
self.text = ''.join([token.text for token in tokens])
|
||||
self.start = tokens[0].start
|
||||
self.end = tokens[-1].end
|
||||
self.speaker = 1
|
||||
self.detected_language = tokens[0].detected_language
|
||||
return self
|
||||
|
||||
def build_from_segment(self, segment: Segment) -> "Line":
|
||||
"""Populate the line fields from a pre-built segment."""
|
||||
self.text = segment.text
|
||||
self.start = segment.start
|
||||
self.end = segment.end
|
||||
self.speaker = segment.speaker
|
||||
self.detected_language = segment.detected_language
|
||||
return self
|
||||
|
||||
def is_silent(self) -> bool:
|
||||
return self.speaker == -2
|
||||
|
||||
class SilentLine(Line):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.speaker = -2
|
||||
self.text = ''
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrontData():
|
||||
@@ -66,15 +199,18 @@ class FrontData():
|
||||
lines: list[Line] = field(default_factory=list)
|
||||
buffer_transcription: str = ''
|
||||
buffer_diarization: str = ''
|
||||
buffer_translation: str = ''
|
||||
remaining_time_transcription: float = 0.
|
||||
remaining_time_diarization: float = 0.
|
||||
|
||||
def to_dict(self):
|
||||
_dict = {
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serialize the front-end data payload."""
|
||||
_dict: Dict[str, Any] = {
|
||||
'status': self.status,
|
||||
'lines': [line.to_dict() for line in self.lines],
|
||||
'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)],
|
||||
'buffer_transcription': self.buffer_transcription,
|
||||
'buffer_diarization': self.buffer_diarization,
|
||||
'buffer_translation': self.buffer_translation,
|
||||
'remaining_time_transcription': self.remaining_time_transcription,
|
||||
'remaining_time_diarization': self.remaining_time_diarization,
|
||||
}
|
||||
@@ -82,13 +218,29 @@ class FrontData():
|
||||
_dict['error'] = self.error
|
||||
return _dict
|
||||
|
||||
@dataclass
|
||||
class ChangeSpeaker:
|
||||
speaker: int
|
||||
start: int
|
||||
|
||||
@dataclass
|
||||
class State():
|
||||
tokens: list
|
||||
translated_segments: list
|
||||
buffer_transcription: str
|
||||
buffer_diarization: str
|
||||
end_buffer: float
|
||||
end_attributed_speaker: float
|
||||
remaining_time_transcription: float
|
||||
remaining_time_diarization: float
|
||||
"""Unified state class for audio processing.
|
||||
|
||||
Contains both persistent state (tokens, buffers) and temporary update buffers
|
||||
(new_* fields) that are consumed by TokensAlignment.
|
||||
"""
|
||||
# Persistent state
|
||||
tokens: List[ASRToken] = field(default_factory=list)
|
||||
buffer_transcription: Transcript = field(default_factory=Transcript)
|
||||
end_buffer: float = 0.0
|
||||
end_attributed_speaker: float = 0.0
|
||||
remaining_time_transcription: float = 0.0
|
||||
remaining_time_diarization: float = 0.0
|
||||
|
||||
# Temporary update buffers (consumed by TokensAlignment.update())
|
||||
new_tokens: List[Union[ASRToken, Silence]] = field(default_factory=list)
|
||||
new_translation: List[Any] = field(default_factory=list)
|
||||
new_diarization: List[Any] = field(default_factory=list)
|
||||
new_tokens_buffer: List[Any] = field(default_factory=list) # only when local agreement
|
||||
new_translation_buffer= TimedText()
|
||||
177
whisperlivekit/tokens_alignment.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from time import time
|
||||
from typing import Optional, List, Tuple, Union, Any
|
||||
|
||||
from whisperlivekit.timed_objects import Line, SilentLine, ASRToken, SpeakerSegment, Silence, TimedText, Segment
|
||||
|
||||
|
||||
class TokensAlignment:
|
||||
|
||||
def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None:
|
||||
self.state = state
|
||||
self.diarization = args.diarization
|
||||
self._tokens_index: int = 0
|
||||
self._diarization_index: int = 0
|
||||
self._translation_index: int = 0
|
||||
|
||||
self.all_tokens: List[ASRToken] = []
|
||||
self.all_diarization_segments: List[SpeakerSegment] = []
|
||||
self.all_translation_segments: List[Any] = []
|
||||
|
||||
self.new_tokens: List[ASRToken] = []
|
||||
self.new_diarization: List[SpeakerSegment] = []
|
||||
self.new_translation: List[Any] = []
|
||||
self.new_translation_buffer: Union[TimedText, str] = TimedText()
|
||||
self.new_tokens_buffer: List[Any] = []
|
||||
self.sep: str = sep if sep is not None else ' '
|
||||
self.beg_loop: Optional[float] = None
|
||||
|
||||
def update(self) -> None:
|
||||
"""Drain state buffers into the running alignment context."""
|
||||
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
|
||||
self.new_diarization, self.state.new_diarization = self.state.new_diarization, []
|
||||
self.new_translation, self.state.new_translation = self.state.new_translation, []
|
||||
self.new_tokens_buffer, self.state.new_tokens_buffer = self.state.new_tokens_buffer, []
|
||||
|
||||
self.all_tokens.extend(self.new_tokens)
|
||||
self.all_diarization_segments.extend(self.new_diarization)
|
||||
self.all_translation_segments.extend(self.new_translation)
|
||||
self.new_translation_buffer = self.state.new_translation_buffer
|
||||
|
||||
def add_translation(self, line: Line) -> None:
|
||||
"""Append translated text segments that overlap with a line."""
|
||||
for ts in self.all_translation_segments:
|
||||
if ts.is_within(line):
|
||||
line.translation += ts.text + (self.sep if ts.text else '')
|
||||
elif line.translation:
|
||||
break
|
||||
|
||||
|
||||
def compute_punctuations_segments(self, tokens: Optional[List[ASRToken]] = None) -> List[Segment]:
|
||||
"""Group tokens into segments split by punctuation and explicit silence."""
|
||||
segments = []
|
||||
segment_start_idx = 0
|
||||
for i, token in enumerate(self.all_tokens):
|
||||
if token.is_silence():
|
||||
previous_segment = Segment.from_tokens(
|
||||
tokens=self.all_tokens[segment_start_idx: i],
|
||||
)
|
||||
if previous_segment:
|
||||
segments.append(previous_segment)
|
||||
segment = Segment.from_tokens(
|
||||
tokens=[token],
|
||||
is_silence=True
|
||||
)
|
||||
segments.append(segment)
|
||||
segment_start_idx = i+1
|
||||
else:
|
||||
if token.has_punctuation():
|
||||
segment = Segment.from_tokens(
|
||||
tokens=self.all_tokens[segment_start_idx: i+1],
|
||||
)
|
||||
segments.append(segment)
|
||||
segment_start_idx = i+1
|
||||
|
||||
final_segment = Segment.from_tokens(
|
||||
tokens=self.all_tokens[segment_start_idx:],
|
||||
)
|
||||
if final_segment:
|
||||
segments.append(final_segment)
|
||||
return segments
|
||||
|
||||
|
||||
def concatenate_diar_segments(self) -> List[SpeakerSegment]:
|
||||
"""Merge consecutive diarization slices that share the same speaker."""
|
||||
if not self.all_diarization_segments:
|
||||
return []
|
||||
merged = [self.all_diarization_segments[0]]
|
||||
for segment in self.all_diarization_segments[1:]:
|
||||
if segment.speaker == merged[-1].speaker:
|
||||
merged[-1].end = segment.end
|
||||
else:
|
||||
merged.append(segment)
|
||||
return merged
|
||||
|
||||
|
||||
@staticmethod
|
||||
def intersection_duration(seg1: TimedText, seg2: TimedText) -> float:
|
||||
"""Return the overlap duration between two timed segments."""
|
||||
start = max(seg1.start, seg2.start)
|
||||
end = min(seg1.end, seg2.end)
|
||||
|
||||
return max(0, end - start)
|
||||
|
||||
def get_lines_diarization(self) -> Tuple[List[Line], str]:
|
||||
"""Build lines when diarization is enabled and track overflow buffer."""
|
||||
diarization_buffer = ''
|
||||
punctuation_segments = self.compute_punctuations_segments()
|
||||
diarization_segments = self.concatenate_diar_segments()
|
||||
for punctuation_segment in punctuation_segments:
|
||||
if not punctuation_segment.is_silence():
|
||||
if diarization_segments and punctuation_segment.start >= diarization_segments[-1].end:
|
||||
diarization_buffer += punctuation_segment.text
|
||||
else:
|
||||
max_overlap = 0.0
|
||||
max_overlap_speaker = 1
|
||||
for diarization_segment in diarization_segments:
|
||||
intersec = self.intersection_duration(punctuation_segment, diarization_segment)
|
||||
if intersec > max_overlap:
|
||||
max_overlap = intersec
|
||||
max_overlap_speaker = diarization_segment.speaker + 1
|
||||
punctuation_segment.speaker = max_overlap_speaker
|
||||
|
||||
lines = []
|
||||
if punctuation_segments:
|
||||
lines = [Line().build_from_segment(punctuation_segments[0])]
|
||||
for segment in punctuation_segments[1:]:
|
||||
if segment.speaker == lines[-1].speaker:
|
||||
if lines[-1].text:
|
||||
lines[-1].text += segment.text
|
||||
lines[-1].end = segment.end
|
||||
else:
|
||||
lines.append(Line().build_from_segment(segment))
|
||||
|
||||
return lines, diarization_buffer
|
||||
|
||||
|
||||
def get_lines(
|
||||
self,
|
||||
diarization: bool = False,
|
||||
translation: bool = False,
|
||||
current_silence: Optional[Silence] = None
|
||||
) -> Tuple[List[Line], str, Union[str, TimedText]]:
|
||||
"""Return the formatted lines plus buffers, optionally with diarization/translation."""
|
||||
if diarization:
|
||||
lines, diarization_buffer = self.get_lines_diarization()
|
||||
else:
|
||||
diarization_buffer = ''
|
||||
lines = []
|
||||
current_line_tokens = []
|
||||
for token in self.all_tokens:
|
||||
if token.is_silence():
|
||||
if current_line_tokens:
|
||||
lines.append(Line().build_from_tokens(current_line_tokens))
|
||||
current_line_tokens = []
|
||||
end_silence = token.end if token.has_ended else time() - self.beg_loop
|
||||
if lines and lines[-1].is_silent():
|
||||
lines[-1].end = end_silence
|
||||
else:
|
||||
lines.append(SilentLine(
|
||||
start = token.start,
|
||||
end = end_silence
|
||||
))
|
||||
else:
|
||||
current_line_tokens.append(token)
|
||||
if current_line_tokens:
|
||||
lines.append(Line().build_from_tokens(current_line_tokens))
|
||||
if current_silence:
|
||||
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
|
||||
if lines and lines[-1].is_silent():
|
||||
lines[-1].end = end_silence
|
||||
else:
|
||||
lines.append(SilentLine(
|
||||
start = current_silence.start,
|
||||
end = end_silence
|
||||
))
|
||||
if translation:
|
||||
[self.add_translation(line) for line in lines if not type(line) == Silence]
|
||||
return lines, diarization_buffer, self.new_translation_buffer.text
|
||||
@@ -1,60 +0,0 @@
|
||||
from typing import Sequence, Callable, Any, Optional, Dict
|
||||
|
||||
def _detect_tail_repetition(
|
||||
seq: Sequence[Any],
|
||||
key: Callable[[Any], Any] = lambda x: x, # extract comparable value
|
||||
min_block: int = 1, # set to 2 to ignore 1-token loops like "."
|
||||
max_tail: int = 300, # search window from the end for speed
|
||||
prefer: str = "longest", # "longest" coverage or "smallest" block
|
||||
) -> Optional[Dict]:
|
||||
vals = [key(x) for x in seq][-max_tail:]
|
||||
n = len(vals)
|
||||
best = None
|
||||
|
||||
# try every possible block length
|
||||
for b in range(min_block, n // 2 + 1):
|
||||
block = vals[-b:]
|
||||
# count how many times this block repeats contiguously at the very end
|
||||
count, i = 0, n
|
||||
while i - b >= 0 and vals[i - b:i] == block:
|
||||
count += 1
|
||||
i -= b
|
||||
|
||||
if count >= 2:
|
||||
cand = {
|
||||
"block_size": b,
|
||||
"count": count,
|
||||
"start_index": len(seq) - count * b, # in original seq
|
||||
"end_index": len(seq),
|
||||
}
|
||||
if (best is None or
|
||||
(prefer == "longest" and count * b > best["count"] * best["block_size"]) or
|
||||
(prefer == "smallest" and b < best["block_size"])):
|
||||
best = cand
|
||||
return best
|
||||
|
||||
def trim_tail_repetition(
|
||||
seq: Sequence[Any],
|
||||
key: Callable[[Any], Any] = lambda x: x,
|
||||
min_block: int = 1,
|
||||
max_tail: int = 300,
|
||||
prefer: str = "longest",
|
||||
keep: int = 1, # how many copies of the repeating block to keep at the end (0 or 1 are common)
|
||||
):
|
||||
"""
|
||||
Returns a new sequence with repeated tail trimmed.
|
||||
keep=1 -> keep a single copy of the repeated block.
|
||||
keep=0 -> remove all copies of the repeated block.
|
||||
"""
|
||||
rep = _detect_tail_repetition(seq, key, min_block, max_tail, prefer)
|
||||
if not rep:
|
||||
return seq, False # nothing to trim
|
||||
|
||||
b, c = rep["block_size"], rep["count"]
|
||||
if keep < 0:
|
||||
keep = 0
|
||||
if keep >= c:
|
||||
return seq, False # nothing to trim (already <= keep copies)
|
||||
# new length = total - (copies_to_remove * block_size)
|
||||
new_len = len(seq) - (c - keep) * b
|
||||
return seq[:new_len], True
|
||||
@@ -1,182 +0,0 @@
|
||||
"""
|
||||
adapted from https://store.crowdin.com/custom-mt
|
||||
"""
|
||||
|
||||
LANGUAGES = [
|
||||
{"name": "Afrikaans", "nllb": "afr_Latn", "crowdin": "af"},
|
||||
{"name": "Akan", "nllb": "aka_Latn", "crowdin": "ak"},
|
||||
{"name": "Amharic", "nllb": "amh_Ethi", "crowdin": "am"},
|
||||
{"name": "Assamese", "nllb": "asm_Beng", "crowdin": "as"},
|
||||
{"name": "Asturian", "nllb": "ast_Latn", "crowdin": "ast"},
|
||||
{"name": "Bashkir", "nllb": "bak_Cyrl", "crowdin": "ba"},
|
||||
{"name": "Bambara", "nllb": "bam_Latn", "crowdin": "bm"},
|
||||
{"name": "Balinese", "nllb": "ban_Latn", "crowdin": "ban"},
|
||||
{"name": "Belarusian", "nllb": "bel_Cyrl", "crowdin": "be"},
|
||||
{"name": "Bengali", "nllb": "ben_Beng", "crowdin": "bn"},
|
||||
{"name": "Bosnian", "nllb": "bos_Latn", "crowdin": "bs"},
|
||||
{"name": "Bulgarian", "nllb": "bul_Cyrl", "crowdin": "bg"},
|
||||
{"name": "Catalan", "nllb": "cat_Latn", "crowdin": "ca"},
|
||||
{"name": "Cebuano", "nllb": "ceb_Latn", "crowdin": "ceb"},
|
||||
{"name": "Czech", "nllb": "ces_Latn", "crowdin": "cs"},
|
||||
{"name": "Welsh", "nllb": "cym_Latn", "crowdin": "cy"},
|
||||
{"name": "Danish", "nllb": "dan_Latn", "crowdin": "da"},
|
||||
{"name": "German", "nllb": "deu_Latn", "crowdin": "de"},
|
||||
{"name": "Dzongkha", "nllb": "dzo_Tibt", "crowdin": "dz"},
|
||||
{"name": "Greek", "nllb": "ell_Grek", "crowdin": "el"},
|
||||
{"name": "English", "nllb": "eng_Latn", "crowdin": "en"},
|
||||
{"name": "Esperanto", "nllb": "epo_Latn", "crowdin": "eo"},
|
||||
{"name": "Estonian", "nllb": "est_Latn", "crowdin": "et"},
|
||||
{"name": "Basque", "nllb": "eus_Latn", "crowdin": "eu"},
|
||||
{"name": "Ewe", "nllb": "ewe_Latn", "crowdin": "ee"},
|
||||
{"name": "Faroese", "nllb": "fao_Latn", "crowdin": "fo"},
|
||||
{"name": "Fijian", "nllb": "fij_Latn", "crowdin": "fj"},
|
||||
{"name": "Finnish", "nllb": "fin_Latn", "crowdin": "fi"},
|
||||
{"name": "French", "nllb": "fra_Latn", "crowdin": "fr"},
|
||||
{"name": "Friulian", "nllb": "fur_Latn", "crowdin": "fur-IT"},
|
||||
{"name": "Scottish Gaelic", "nllb": "gla_Latn", "crowdin": "gd"},
|
||||
{"name": "Irish", "nllb": "gle_Latn", "crowdin": "ga-IE"},
|
||||
{"name": "Galician", "nllb": "glg_Latn", "crowdin": "gl"},
|
||||
{"name": "Guarani", "nllb": "grn_Latn", "crowdin": "gn"},
|
||||
{"name": "Gujarati", "nllb": "guj_Gujr", "crowdin": "gu-IN"},
|
||||
{"name": "Haitian Creole", "nllb": "hat_Latn", "crowdin": "ht"},
|
||||
{"name": "Hausa", "nllb": "hau_Latn", "crowdin": "ha"},
|
||||
{"name": "Hebrew", "nllb": "heb_Hebr", "crowdin": "he"},
|
||||
{"name": "Hindi", "nllb": "hin_Deva", "crowdin": "hi"},
|
||||
{"name": "Croatian", "nllb": "hrv_Latn", "crowdin": "hr"},
|
||||
{"name": "Hungarian", "nllb": "hun_Latn", "crowdin": "hu"},
|
||||
{"name": "Armenian", "nllb": "hye_Armn", "crowdin": "hy-AM"},
|
||||
{"name": "Igbo", "nllb": "ibo_Latn", "crowdin": "ig"},
|
||||
{"name": "Indonesian", "nllb": "ind_Latn", "crowdin": "id"},
|
||||
{"name": "Icelandic", "nllb": "isl_Latn", "crowdin": "is"},
|
||||
{"name": "Italian", "nllb": "ita_Latn", "crowdin": "it"},
|
||||
{"name": "Javanese", "nllb": "jav_Latn", "crowdin": "jv"},
|
||||
{"name": "Japanese", "nllb": "jpn_Jpan", "crowdin": "ja"},
|
||||
{"name": "Kabyle", "nllb": "kab_Latn", "crowdin": "kab"},
|
||||
{"name": "Kannada", "nllb": "kan_Knda", "crowdin": "kn"},
|
||||
{"name": "Georgian", "nllb": "kat_Geor", "crowdin": "ka"},
|
||||
{"name": "Kazakh", "nllb": "kaz_Cyrl", "crowdin": "kk"},
|
||||
{"name": "Khmer", "nllb": "khm_Khmr", "crowdin": "km"},
|
||||
{"name": "Kinyarwanda", "nllb": "kin_Latn", "crowdin": "rw"},
|
||||
{"name": "Kyrgyz", "nllb": "kir_Cyrl", "crowdin": "ky"},
|
||||
{"name": "Korean", "nllb": "kor_Hang", "crowdin": "ko"},
|
||||
{"name": "Lao", "nllb": "lao_Laoo", "crowdin": "lo"},
|
||||
{"name": "Ligurian", "nllb": "lij_Latn", "crowdin": "lij"},
|
||||
{"name": "Limburgish", "nllb": "lim_Latn", "crowdin": "li"},
|
||||
{"name": "Lingala", "nllb": "lin_Latn", "crowdin": "ln"},
|
||||
{"name": "Lithuanian", "nllb": "lit_Latn", "crowdin": "lt"},
|
||||
{"name": "Luxembourgish", "nllb": "ltz_Latn", "crowdin": "lb"},
|
||||
{"name": "Maithili", "nllb": "mai_Deva", "crowdin": "mai"},
|
||||
{"name": "Malayalam", "nllb": "mal_Mlym", "crowdin": "ml-IN"},
|
||||
{"name": "Marathi", "nllb": "mar_Deva", "crowdin": "mr"},
|
||||
{"name": "Macedonian", "nllb": "mkd_Cyrl", "crowdin": "mk"},
|
||||
{"name": "Maltese", "nllb": "mlt_Latn", "crowdin": "mt"},
|
||||
{"name": "Mossi", "nllb": "mos_Latn", "crowdin": "mos"},
|
||||
{"name": "Maori", "nllb": "mri_Latn", "crowdin": "mi"},
|
||||
{"name": "Burmese", "nllb": "mya_Mymr", "crowdin": "my"},
|
||||
{"name": "Dutch", "nllb": "nld_Latn", "crowdin": "nl"},
|
||||
{"name": "Norwegian Nynorsk", "nllb": "nno_Latn", "crowdin": "nn-NO"},
|
||||
{"name": "Nepali", "nllb": "npi_Deva", "crowdin": "ne-NP"},
|
||||
{"name": "Northern Sotho", "nllb": "nso_Latn", "crowdin": "nso"},
|
||||
{"name": "Occitan", "nllb": "oci_Latn", "crowdin": "oc"},
|
||||
{"name": "Odia", "nllb": "ory_Orya", "crowdin": "or"},
|
||||
{"name": "Papiamento", "nllb": "pap_Latn", "crowdin": "pap"},
|
||||
{"name": "Polish", "nllb": "pol_Latn", "crowdin": "pl"},
|
||||
{"name": "Portuguese", "nllb": "por_Latn", "crowdin": "pt-PT"},
|
||||
{"name": "Dari", "nllb": "prs_Arab", "crowdin": "fa-AF"},
|
||||
{"name": "Romanian", "nllb": "ron_Latn", "crowdin": "ro"},
|
||||
{"name": "Rundi", "nllb": "run_Latn", "crowdin": "rn"},
|
||||
{"name": "Russian", "nllb": "rus_Cyrl", "crowdin": "ru"},
|
||||
{"name": "Sango", "nllb": "sag_Latn", "crowdin": "sg"},
|
||||
{"name": "Sanskrit", "nllb": "san_Deva", "crowdin": "sa"},
|
||||
{"name": "Santali", "nllb": "sat_Olck", "crowdin": "sat"},
|
||||
{"name": "Sinhala", "nllb": "sin_Sinh", "crowdin": "si-LK"},
|
||||
{"name": "Slovak", "nllb": "slk_Latn", "crowdin": "sk"},
|
||||
{"name": "Slovenian", "nllb": "slv_Latn", "crowdin": "sl"},
|
||||
{"name": "Shona", "nllb": "sna_Latn", "crowdin": "sn"},
|
||||
{"name": "Sindhi", "nllb": "snd_Arab", "crowdin": "sd"},
|
||||
{"name": "Somali", "nllb": "som_Latn", "crowdin": "so"},
|
||||
{"name": "Southern Sotho", "nllb": "sot_Latn", "crowdin": "st"},
|
||||
{"name": "Spanish", "nllb": "spa_Latn", "crowdin": "es-ES"},
|
||||
{"name": "Sardinian", "nllb": "srd_Latn", "crowdin": "sc"},
|
||||
{"name": "Swati", "nllb": "ssw_Latn", "crowdin": "ss"},
|
||||
{"name": "Sundanese", "nllb": "sun_Latn", "crowdin": "su"},
|
||||
{"name": "Swedish", "nllb": "swe_Latn", "crowdin": "sv-SE"},
|
||||
{"name": "Swahili", "nllb": "swh_Latn", "crowdin": "sw"},
|
||||
{"name": "Tamil", "nllb": "tam_Taml", "crowdin": "ta"},
|
||||
{"name": "Tatar", "nllb": "tat_Cyrl", "crowdin": "tt-RU"},
|
||||
{"name": "Telugu", "nllb": "tel_Telu", "crowdin": "te"},
|
||||
{"name": "Tajik", "nllb": "tgk_Cyrl", "crowdin": "tg"},
|
||||
{"name": "Tagalog", "nllb": "tgl_Latn", "crowdin": "tl"},
|
||||
{"name": "Thai", "nllb": "tha_Thai", "crowdin": "th"},
|
||||
{"name": "Tigrinya", "nllb": "tir_Ethi", "crowdin": "ti"},
|
||||
{"name": "Tswana", "nllb": "tsn_Latn", "crowdin": "tn"},
|
||||
{"name": "Tsonga", "nllb": "tso_Latn", "crowdin": "ts"},
|
||||
{"name": "Turkmen", "nllb": "tuk_Latn", "crowdin": "tk"},
|
||||
{"name": "Turkish", "nllb": "tur_Latn", "crowdin": "tr"},
|
||||
{"name": "Uyghur", "nllb": "uig_Arab", "crowdin": "ug"},
|
||||
{"name": "Ukrainian", "nllb": "ukr_Cyrl", "crowdin": "uk"},
|
||||
{"name": "Venetian", "nllb": "vec_Latn", "crowdin": "vec"},
|
||||
{"name": "Vietnamese", "nllb": "vie_Latn", "crowdin": "vi"},
|
||||
{"name": "Wolof", "nllb": "wol_Latn", "crowdin": "wo"},
|
||||
{"name": "Xhosa", "nllb": "xho_Latn", "crowdin": "xh"},
|
||||
{"name": "Yoruba", "nllb": "yor_Latn", "crowdin": "yo"},
|
||||
{"name": "Zulu", "nllb": "zul_Latn", "crowdin": "zu"},
|
||||
]
|
||||
|
||||
NAME_TO_NLLB = {lang["name"]: lang["nllb"] for lang in LANGUAGES}
|
||||
NAME_TO_CROWDIN = {lang["name"]: lang["crowdin"] for lang in LANGUAGES}
|
||||
CROWDIN_TO_NLLB = {lang["crowdin"]: lang["nllb"] for lang in LANGUAGES}
|
||||
NLLB_TO_CROWDIN = {lang["nllb"]: lang["crowdin"] for lang in LANGUAGES}
|
||||
CROWDIN_TO_NAME = {lang["crowdin"]: lang["name"] for lang in LANGUAGES}
|
||||
NLLB_TO_NAME = {lang["nllb"]: lang["name"] for lang in LANGUAGES}
|
||||
|
||||
|
||||
def get_nllb_code(crowdin_code):
|
||||
return CROWDIN_TO_NLLB.get(crowdin_code, None)
|
||||
|
||||
|
||||
def get_crowdin_code(nllb_code):
|
||||
return NLLB_TO_CROWDIN.get(nllb_code)
|
||||
|
||||
|
||||
def get_language_name_by_crowdin(crowdin_code):
|
||||
return CROWDIN_TO_NAME.get(crowdin_code)
|
||||
|
||||
|
||||
def get_language_name_by_nllb(nllb_code):
|
||||
return NLLB_TO_NAME.get(nllb_code)
|
||||
|
||||
|
||||
def get_language_info(identifier, identifier_type="auto"):
|
||||
if identifier_type == "auto":
|
||||
for lang in LANGUAGES:
|
||||
if (lang["name"].lower() == identifier.lower() or
|
||||
lang["nllb"] == identifier or
|
||||
lang["crowdin"] == identifier):
|
||||
return lang
|
||||
elif identifier_type == "name":
|
||||
for lang in LANGUAGES:
|
||||
if lang["name"].lower() == identifier.lower():
|
||||
return lang
|
||||
elif identifier_type == "nllb":
|
||||
for lang in LANGUAGES:
|
||||
if lang["nllb"] == identifier:
|
||||
return lang
|
||||
elif identifier_type == "crowdin":
|
||||
for lang in LANGUAGES:
|
||||
if lang["crowdin"] == identifier:
|
||||
return lang
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def list_all_languages():
|
||||
return [lang["name"] for lang in LANGUAGES]
|
||||
|
||||
|
||||
def list_all_nllb_codes():
|
||||
return [lang["nllb"] for lang in LANGUAGES]
|
||||
|
||||
|
||||
def list_all_crowdin_codes():
|
||||
return [lang["crowdin"] for lang in LANGUAGES]
|
||||
@@ -1,137 +0,0 @@
|
||||
import ctranslate2
|
||||
import torch
|
||||
import transformers
|
||||
from dataclasses import dataclass
|
||||
import huggingface_hub
|
||||
from whisperlivekit.translation.mapping_languages import get_nllb_code
|
||||
from whisperlivekit.timed_objects import Translation
|
||||
|
||||
|
||||
#In diarization case, we may want to translate just one speaker, or at least start the sentences there
|
||||
|
||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranslationModel():
|
||||
translator: ctranslate2.Translator
|
||||
tokenizer: dict
|
||||
|
||||
def load_model(src_langs):
|
||||
MODEL = 'nllb-200-distilled-600M-ctranslate2'
|
||||
MODEL_GUY = 'entai2965'
|
||||
huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
translator = ctranslate2.Translator(MODEL,device=device)
|
||||
tokenizer = dict()
|
||||
for src_lang in src_langs:
|
||||
tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
|
||||
return TranslationModel(
|
||||
translator=translator,
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
|
||||
def translate(input, translation_model, tgt_lang):
|
||||
source = translation_model.tokenizer.convert_ids_to_tokens(translation_model.tokenizer.encode(input))
|
||||
target_prefix = [tgt_lang]
|
||||
results = translation_model.translator.translate_batch([source], target_prefix=[target_prefix])
|
||||
target = results[0].hypotheses[0][1:]
|
||||
return translation_model.tokenizer.decode(translation_model.tokenizer.convert_tokens_to_ids(target))
|
||||
|
||||
class OnlineTranslation:
|
||||
def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
|
||||
self.buffer = []
|
||||
self.len_processed_buffer = 0
|
||||
self.translation_remaining = Translation()
|
||||
self.validated = []
|
||||
self.translation_pending_validation = ''
|
||||
self.translation_model = translation_model
|
||||
self.input_languages = input_languages
|
||||
self.output_languages = output_languages
|
||||
|
||||
def compute_common_prefix(self, results):
|
||||
#we dont want want to prune the result for the moment.
|
||||
if not self.buffer:
|
||||
self.buffer = results
|
||||
else:
|
||||
for i in range(min(len(self.buffer), len(results))):
|
||||
if self.buffer[i] != results[i]:
|
||||
self.commited.extend(self.buffer[:i])
|
||||
self.buffer = results[i:]
|
||||
|
||||
def translate(self, input, input_lang=None, output_lang=None):
|
||||
if not input:
|
||||
return ""
|
||||
if input_lang is None:
|
||||
input_lang = self.input_languages[0]
|
||||
if output_lang is None:
|
||||
output_lang = self.output_languages[0]
|
||||
nllb_output_lang = get_nllb_code(output_lang)
|
||||
|
||||
source = self.translation_model.tokenizer[input_lang].convert_ids_to_tokens(self.translation_model.tokenizer[input_lang].encode(input))
|
||||
results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]]) #we can use return_attention=True to try to optimize the stuff.
|
||||
target = results[0].hypotheses[0][1:]
|
||||
results = self.translation_model.tokenizer[input_lang].decode(self.translation_model.tokenizer[input_lang].convert_tokens_to_ids(target))
|
||||
return results
|
||||
|
||||
def translate_tokens(self, tokens):
|
||||
if tokens:
|
||||
text = ' '.join([token.text for token in tokens])
|
||||
start = tokens[0].start
|
||||
end = tokens[-1].end
|
||||
translated_text = self.translate(text)
|
||||
translation = Translation(
|
||||
text=translated_text,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
return translation
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def insert_tokens(self, tokens):
|
||||
self.buffer.extend(tokens)
|
||||
pass
|
||||
|
||||
def process(self):
|
||||
i = 0
|
||||
if len(self.buffer) < self.len_processed_buffer + 3: #nothing new to process
|
||||
return self.validated + [self.translation_remaining]
|
||||
while i < len(self.buffer):
|
||||
if self.buffer[i].text in PUNCTUATION_MARKS:
|
||||
translation_sentence = self.translate_tokens(self.buffer[:i+1])
|
||||
self.validated.append(translation_sentence)
|
||||
self.buffer = self.buffer[i+1:]
|
||||
i = 0
|
||||
else:
|
||||
i+=1
|
||||
self.translation_remaining = self.translate_tokens(self.buffer)
|
||||
self.len_processed_buffer = len(self.buffer)
|
||||
return self.validated + [self.translation_remaining]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
output_lang = 'fr'
|
||||
input_lang = "en"
|
||||
|
||||
|
||||
test_string = """
|
||||
Transcription technology has improved so much in the past few years. Have you noticed how accurate real-time speech-to-text is now?
|
||||
"""
|
||||
test = test_string.split(' ')
|
||||
step = len(test) // 3
|
||||
|
||||
shared_model = load_model([input_lang])
|
||||
online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang])
|
||||
|
||||
for id in range(5):
|
||||
val = test[id*step : (id+1)*step]
|
||||
val_str = ' '.join(val)
|
||||
result = online_translation.translate(val_str)
|
||||
print(result)
|
||||
|
||||
|
||||
|
||||
|
||||
# print(result)
|
||||
BIN
whisperlivekit/vad_models/silero_vad.jit
Normal file
BIN
whisperlivekit/vad_models/silero_vad.onnx
Normal file
BIN
whisperlivekit/vad_models/silero_vad_16k_op15.onnx
Normal file
BIN
whisperlivekit/vad_models/silero_vad_half.onnx
Normal file
@@ -72,6 +72,12 @@
|
||||
--label-trans-text: #111111;
|
||||
}
|
||||
|
||||
html.is-extension
|
||||
{
|
||||
width: 350px;
|
||||
height: 500px;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
|
||||
margin: 0;
|
||||
@@ -191,6 +197,14 @@ body {
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
gap: 15px;
|
||||
position: relative;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.buttons-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 15px;
|
||||
}
|
||||
|
||||
.settings {
|
||||
@@ -200,6 +214,66 @@ body {
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.settings-toggle {
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
border: none;
|
||||
border-radius: 50%;
|
||||
background-color: var(--button-bg);
|
||||
border: 1px solid var(--button-border);
|
||||
cursor: pointer;
|
||||
display: none;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.settings-toggle:hover {
|
||||
background-color: var(--chip-bg);
|
||||
}
|
||||
|
||||
.settings-toggle.active {
|
||||
background-color: var(--chip-bg);
|
||||
}
|
||||
|
||||
.settings-toggle img {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
}
|
||||
|
||||
@media (max-width: 10000px) {
|
||||
.settings-toggle {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.settings {
|
||||
display: none;
|
||||
background: var(--bg);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 18px;
|
||||
padding: 12px;
|
||||
}
|
||||
|
||||
.settings.visible {
|
||||
display: flex;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 600px) {
|
||||
.settings-container {
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.buttons-container {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
gap: 15px;
|
||||
}
|
||||
}
|
||||
|
||||
.field {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
@@ -346,7 +420,7 @@ label {
|
||||
|
||||
.label_diarization {
|
||||
background-color: var(--chip-bg);
|
||||
border-radius: 8px 8px 8px 8px;
|
||||
border-radius: 100px;
|
||||
padding: 2px 10px;
|
||||
margin-left: 10px;
|
||||
display: inline-block;
|
||||
@@ -358,7 +432,7 @@ label {
|
||||
|
||||
.label_transcription {
|
||||
background-color: var(--chip-bg);
|
||||
border-radius: 8px 8px 8px 8px;
|
||||
border-radius: 100px;
|
||||
padding: 2px 10px;
|
||||
display: inline-block;
|
||||
white-space: nowrap;
|
||||
@@ -370,16 +444,20 @@ label {
|
||||
|
||||
.label_translation {
|
||||
background-color: var(--chip-bg);
|
||||
display: inline-flex;
|
||||
border-radius: 10px;
|
||||
padding: 4px 8px;
|
||||
margin-top: 4px;
|
||||
font-size: 14px;
|
||||
color: var(--text);
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.lag-diarization-value {
|
||||
margin-left: 10px;
|
||||
}
|
||||
|
||||
.label_translation img {
|
||||
margin-top: 2px;
|
||||
}
|
||||
@@ -391,7 +469,7 @@ label {
|
||||
|
||||
#timeInfo {
|
||||
color: var(--muted);
|
||||
margin-left: 10px;
|
||||
margin-left: 0px;
|
||||
}
|
||||
|
||||
.textcontent {
|
||||
@@ -405,7 +483,6 @@ label {
|
||||
|
||||
.buffer_diarization {
|
||||
color: var(--label-dia-text);
|
||||
margin-left: 4px;
|
||||
}
|
||||
|
||||
.buffer_transcription {
|
||||
@@ -413,6 +490,11 @@ label {
|
||||
margin-left: 4px;
|
||||
}
|
||||
|
||||
.buffer_translation {
|
||||
color: #a0a0a0;
|
||||
margin-left: 6px;
|
||||
}
|
||||
|
||||
.spinner {
|
||||
display: inline-block;
|
||||
width: 8px;
|
||||
@@ -438,7 +520,6 @@ label {
|
||||
font-size: 13px;
|
||||
border-radius: 30px;
|
||||
padding: 2px 10px;
|
||||
display: none;
|
||||
}
|
||||
|
||||
.loading {
|
||||
@@ -451,7 +532,7 @@ label {
|
||||
}
|
||||
|
||||
/* for smaller screens */
|
||||
@media (max-width: 768px) {
|
||||
@media (max-width: 200px) {
|
||||
.header-container {
|
||||
padding: 15px;
|
||||
}
|
||||
@@ -461,6 +542,10 @@ label {
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.buttons-container {
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.settings {
|
||||
justify-content: center;
|
||||
gap: 8px;
|
||||
@@ -515,3 +600,31 @@ label {
|
||||
padding: 10px;
|
||||
}
|
||||
}
|
||||
|
||||
.label_language {
|
||||
background-color: var(--chip-bg);
|
||||
margin-bottom: 0px;
|
||||
border-radius: 100px;
|
||||
padding: 2px 8px;
|
||||
margin-left: 10px;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
font-size: 14px;
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
|
||||
.speaker-badge {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
margin-left: -5px;
|
||||
border-radius: 50%;
|
||||
font-size: 11px;
|
||||
line-height: 1;
|
||||
font-weight: 800;
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
@@ -5,12 +5,13 @@
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>WhisperLiveKit</title>
|
||||
<link rel="stylesheet" href="/web/live_transcription.css" />
|
||||
<link rel="stylesheet" href="live_transcription.css" />
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div class="header-container">
|
||||
<div class="settings-container">
|
||||
<div class="buttons-container">
|
||||
<button id="recordButton">
|
||||
<div class="shape-container">
|
||||
<div class="shape"></div>
|
||||
@@ -23,6 +24,11 @@
|
||||
</div>
|
||||
</button>
|
||||
|
||||
<button id="settingsToggle" class="settings-toggle" title="Show/hide settings">
|
||||
<img src="web/src/settings.svg" alt="Settings" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="settings">
|
||||
<div class="field">
|
||||
<label for="websocketInput">Websocket URL</label>
|
||||
@@ -67,7 +73,7 @@
|
||||
<div id="linesTranscript"></div>
|
||||
</div>
|
||||
|
||||
<script src="/web/live_transcription.js"></script>
|
||||
<script src="live_transcription.js"></script>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
@@ -1,4 +1,8 @@
|
||||
/* Theme, WebSocket, recording, rendering logic extracted from inline script and adapted for segmented theme control and WS caption */
|
||||
const isExtension = typeof chrome !== 'undefined' && chrome.runtime && chrome.runtime.getURL;
|
||||
if (isExtension) {
|
||||
document.documentElement.classList.add('is-extension');
|
||||
}
|
||||
const isWebContext = !isExtension;
|
||||
|
||||
let isRecording = false;
|
||||
let websocket = null;
|
||||
@@ -22,6 +26,11 @@ let lastReceivedData = null;
|
||||
let lastSignature = null;
|
||||
let availableMicrophones = [];
|
||||
let selectedMicrophoneId = null;
|
||||
let serverUseAudioWorklet = null;
|
||||
let configReadyResolve;
|
||||
const configReady = new Promise((r) => (configReadyResolve = r));
|
||||
let outputAudioContext = null;
|
||||
let audioSource = null;
|
||||
|
||||
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
|
||||
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
|
||||
@@ -37,6 +46,26 @@ const timerElement = document.querySelector(".timer");
|
||||
const themeRadios = document.querySelectorAll('input[name="theme"]');
|
||||
const microphoneSelect = document.getElementById("microphoneSelect");
|
||||
|
||||
const settingsToggle = document.getElementById("settingsToggle");
|
||||
const settingsDiv = document.querySelector(".settings");
|
||||
|
||||
// if (isExtension) {
|
||||
// chrome.runtime.onInstalled.addListener((details) => {
|
||||
// if (details.reason.search(/install/g) === -1) {
|
||||
// return;
|
||||
// }
|
||||
// chrome.tabs.create({
|
||||
// url: chrome.runtime.getURL("welcome.html"),
|
||||
// active: true
|
||||
// });
|
||||
// });
|
||||
// }
|
||||
|
||||
const translationIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="12px" viewBox="0 -960 960 960" width="12px" fill="#5f6368"><path d="m603-202-34 97q-4 11-14 18t-22 7q-20 0-32.5-16.5T496-133l152-402q5-11 15-18t22-7h30q12 0 22 7t15 18l152 403q8 19-4 35.5T868-80q-13 0-22.5-7T831-106l-34-96H603ZM362-401 188-228q-11 11-27.5 11.5T132-228q-11-11-11-28t11-28l174-174q-35-35-63.5-80T190-640h84q20 39 40 68t48 58q33-33 68.5-92.5T484-720H80q-17 0-28.5-11.5T40-760q0-17 11.5-28.5T80-800h240v-40q0-17 11.5-28.5T360-880q17 0 28.5 11.5T400-840v40h240q17 0 28.5 11.5T680-760q0 17-11.5 28.5T640-720h-76q-21 72-63 148t-83 116l96 98-30 82-122-125Zm266 129h144l-72-204-72 204Z"/></svg>`
|
||||
const silenceIcon = `<svg xmlns="http://www.w3.org/2000/svg" style="vertical-align: text-bottom;" height="14px" viewBox="0 -960 960 960" width="14px" fill="#5f6368"><path d="M514-556 320-752q9-3 19-5.5t21-2.5q66 0 113 47t47 113q0 11-1.5 22t-4.5 22ZM40-200v-32q0-33 17-62t47-44q51-26 115-44t141-18q26 0 49.5 2.5T456-392l-56-54q-9 3-19 4.5t-21 1.5q-66 0-113-47t-47-113q0-11 1.5-21t4.5-19L84-764q-11-11-11-28t11-28q12-12 28.5-12t27.5 12l675 685q11 11 11.5 27.5T816-80q-11 13-28 12.5T759-80L641-200h39q0 33-23.5 56.5T600-120H120q-33 0-56.5-23.5T40-200Zm80 0h480v-32q0-14-4.5-19.5T580-266q-36-18-92.5-36T360-320q-71 0-127.5 18T140-266q-9 5-14.5 14t-5.5 20v32Zm240 0Zm560-400q0 69-24.5 131.5T829-355q-12 14-30 15t-32-13q-13-13-12-31t12-33q30-38 46.5-85t16.5-98q0-51-16.5-97T767-781q-12-15-12.5-33t12.5-32q13-14 31.5-13.5T829-845q42 51 66.5 113.5T920-600Zm-182 0q0 32-10 61.5T700-484q-11 15-29.5 15.5T638-482q-13-13-13.5-31.5T633-549q6-11 9.5-24t3.5-27q0-14-3.5-27t-9.5-25q-9-17-8.5-35t13.5-31q14-14 32.5-13.5T700-716q18 25 28 54.5t10 61.5Z"/></svg>`;
|
||||
const languageIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="12" viewBox="0 -960 960 960" width="12" fill="#5f6368"><path d="M480-80q-82 0-155-31.5t-127.5-86Q143-252 111.5-325T80-480q0-83 31.5-155.5t86-127Q252-817 325-848.5T480-880q83 0 155.5 31.5t127 86q54.5 54.5 86 127T880-480q0 82-31.5 155t-86 127.5q-54.5 54.5-127 86T480-80Zm0-82q26-36 45-75t31-83H404q12 44 31 83t45 75Zm-104-16q-18-33-31.5-68.5T322-320H204q29 50 72.5 87t99.5 55Zm208 0q56-18 99.5-55t72.5-87H638q-9 38-22.5 73.5T584-178ZM170-400h136q-3-20-4.5-39.5T300-480q0-21 1.5-40.5T306-560H170q-5 20-7.5 39.5T160-480q0 21 2.5 40.5T170-400Zm216 0h188q3-20 4.5-39.5T580-480q0-21-1.5-40.5T574-560H386q-3 20-4.5 39.5T380-480q0 21 1.5 40.5T386-400Zm268 0h136q5-20 7.5-39.5T800-480q0-21-2.5-40.5T790-560H654q3 20 4.5 39.5T660-480q0 21-1.5 40.5T654-400Zm-16-240h118q-29-50-72.5-87T584-782q18 33 31.5 68.5T638-640Zm-234 0h152q-12-44-31-83t-45-75q-26 36-45 75t-31 83Zm-200 0h118q9-38 22.5-73.5T376-782q-56 18-99.5 55T204-640Z"/></svg>`
|
||||
const speakerIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="16px" style="vertical-align: text-bottom;" viewBox="0 -960 960 960" width="16px" fill="#5f6368"><path d="M480-480q-66 0-113-47t-47-113q0-66 47-113t113-47q66 0 113 47t47 113q0 66-47 113t-113 47ZM160-240v-32q0-34 17.5-62.5T224-378q62-31 126-46.5T480-440q66 0 130 15.5T736-378q29 15 46.5 43.5T800-272v32q0 33-23.5 56.5T720-160H240q-33 0-56.5-23.5T160-240Zm80 0h480v-32q0-11-5.5-20T700-306q-54-27-109-40.5T480-360q-56 0-111 13.5T260-306q-9 5-14.5 14t-5.5 20v32Zm240-320q33 0 56.5-23.5T560-640q0-33-23.5-56.5T480-720q-33 0-56.5 23.5T400-640q0 33 23.5 56.5T480-560Zm0-80Zm0 400Z"/></svg>`;
|
||||
|
||||
function getWaveStroke() {
|
||||
const styles = getComputedStyle(document.documentElement);
|
||||
const v = styles.getPropertyValue("--wave-stroke").trim();
|
||||
@@ -148,10 +177,16 @@ function fmt1(x) {
|
||||
return Number.isFinite(n) ? n.toFixed(1) : x;
|
||||
}
|
||||
|
||||
// Default WebSocket URL computation
|
||||
const host = window.location.hostname || "localhost";
|
||||
const port = window.location.port;
|
||||
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
|
||||
let host, port, protocol;
|
||||
port = 8000;
|
||||
if (isExtension) {
|
||||
host = "localhost";
|
||||
protocol = "ws";
|
||||
} else {
|
||||
host = window.location.hostname || "localhost";
|
||||
port = window.location.port;
|
||||
protocol = window.location.protocol === "https:" ? "wss" : "ws";
|
||||
}
|
||||
const defaultWebSocketUrl = `${protocol}://${host}${port ? ":" + port : ""}/asr`;
|
||||
|
||||
// Populate default caption and input
|
||||
@@ -201,6 +236,7 @@ function setupWebSocket() {
|
||||
lastReceivedData.lines || [],
|
||||
lastReceivedData.buffer_diarization || "",
|
||||
lastReceivedData.buffer_transcription || "",
|
||||
lastReceivedData.buffer_translation || "",
|
||||
0,
|
||||
0,
|
||||
true
|
||||
@@ -228,6 +264,14 @@ function setupWebSocket() {
|
||||
|
||||
websocket.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
if (data.type === "config") {
|
||||
serverUseAudioWorklet = !!data.useAudioWorklet;
|
||||
statusText.textContent = serverUseAudioWorklet
|
||||
? "Connected. Using AudioWorklet (PCM)."
|
||||
: "Connected. Using MediaRecorder (WebM).";
|
||||
if (configReadyResolve) configReadyResolve();
|
||||
return;
|
||||
}
|
||||
|
||||
if (data.type === "ready_to_stop") {
|
||||
console.log("Ready to stop received, finalizing display and closing WebSocket.");
|
||||
@@ -238,6 +282,7 @@ function setupWebSocket() {
|
||||
lastReceivedData.lines || [],
|
||||
lastReceivedData.buffer_diarization || "",
|
||||
lastReceivedData.buffer_transcription || "",
|
||||
lastReceivedData.buffer_translation || "",
|
||||
0,
|
||||
0,
|
||||
true
|
||||
@@ -258,6 +303,7 @@ function setupWebSocket() {
|
||||
lines = [],
|
||||
buffer_transcription = "",
|
||||
buffer_diarization = "",
|
||||
buffer_translation = "",
|
||||
remaining_time_transcription = 0,
|
||||
remaining_time_diarization = 0,
|
||||
status = "active_transcription",
|
||||
@@ -267,6 +313,7 @@ function setupWebSocket() {
|
||||
lines,
|
||||
buffer_diarization,
|
||||
buffer_transcription,
|
||||
buffer_translation,
|
||||
remaining_time_diarization,
|
||||
remaining_time_transcription,
|
||||
false,
|
||||
@@ -280,6 +327,7 @@ function renderLinesWithBuffer(
|
||||
lines,
|
||||
buffer_diarization,
|
||||
buffer_transcription,
|
||||
buffer_translation,
|
||||
remaining_time_diarization,
|
||||
remaining_time_transcription,
|
||||
isFinalizing = false,
|
||||
@@ -295,9 +343,10 @@ function renderLinesWithBuffer(
|
||||
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
|
||||
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
|
||||
const signature = JSON.stringify({
|
||||
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end })),
|
||||
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end, detected_language: it.detected_language })),
|
||||
buffer_transcription: buffer_transcription || "",
|
||||
buffer_diarization: buffer_diarization || "",
|
||||
buffer_translation: buffer_translation,
|
||||
status: current_status,
|
||||
showLoading,
|
||||
showTransLag,
|
||||
@@ -324,32 +373,29 @@ function renderLinesWithBuffer(
|
||||
|
||||
let speakerLabel = "";
|
||||
if (item.speaker === -2) {
|
||||
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
speakerLabel = `<span class="silence">${silenceIcon}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
} else if (item.speaker == 0 && !isFinalizing) {
|
||||
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
|
||||
remaining_time_diarization
|
||||
)}</span> second(s) of audio are undergoing diarization</span></span>`;
|
||||
} else if (item.speaker !== 0) {
|
||||
speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
const speakerNum = `<span class="speaker-badge">${item.speaker}</span>`;
|
||||
speakerLabel = `<span id="speaker">${speakerIcon}${speakerNum}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
|
||||
if (item.detected_language) {
|
||||
speakerLabel += `<span class="label_language">${languageIcon}<span>${item.detected_language}</span></span>`;
|
||||
}
|
||||
}
|
||||
|
||||
let currentLineText = item.text || "";
|
||||
|
||||
if (item.translation) {
|
||||
currentLineText += `<div class="label_translation">
|
||||
<img src="/web/src/translate.svg" alt="Translation" width="12" height="12" />
|
||||
<span>${item.translation}</span>
|
||||
</div>`;
|
||||
}
|
||||
|
||||
if (idx === lines.length - 1) {
|
||||
if (!isFinalizing && item.speaker !== -2) {
|
||||
if (remaining_time_transcription > 0) {
|
||||
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(
|
||||
remaining_time_transcription
|
||||
)}</span>s</span></span>`;
|
||||
}
|
||||
if (buffer_diarization && remaining_time_diarization > 0) {
|
||||
|
||||
if (buffer_diarization && remaining_time_diarization) {
|
||||
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(
|
||||
remaining_time_diarization
|
||||
)}</span>s</span></span>`;
|
||||
@@ -374,6 +420,25 @@ function renderLinesWithBuffer(
|
||||
}
|
||||
}
|
||||
}
|
||||
let translationContent = "";
|
||||
if (item.translation) {
|
||||
translationContent += item.translation.trim();
|
||||
}
|
||||
if (idx === lines.length - 1 && buffer_translation) {
|
||||
const bufferPiece = isFinalizing
|
||||
? buffer_translation
|
||||
: `<span class="buffer_translation">${buffer_translation}</span>`;
|
||||
translationContent += translationContent ? `${bufferPiece}` : bufferPiece;
|
||||
}
|
||||
if (translationContent.trim().length > 0) {
|
||||
currentLineText += `
|
||||
<div>
|
||||
<div class="label_translation">
|
||||
${translationIcon}
|
||||
<span class="translation_text">${translationContent}</span>
|
||||
</div>
|
||||
</div>`;
|
||||
}
|
||||
|
||||
return currentLineText.trim().length > 0 || speakerLabel.length > 0
|
||||
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
|
||||
@@ -447,11 +512,44 @@ async function startRecording() {
|
||||
console.log("Error acquiring wake lock.");
|
||||
}
|
||||
|
||||
let stream;
|
||||
|
||||
// chromium extension. in the future, both chrome page audio and mic will be used
|
||||
if (isExtension) {
|
||||
try {
|
||||
stream = await new Promise((resolve, reject) => {
|
||||
chrome.tabCapture.capture({audio: true}, (s) => {
|
||||
if (s) {
|
||||
resolve(s);
|
||||
} else {
|
||||
reject(new Error('Tab capture failed or not available'));
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
try {
|
||||
outputAudioContext = new (window.AudioContext || window.webkitAudioContext)();
|
||||
audioSource = outputAudioContext.createMediaStreamSource(stream);
|
||||
audioSource.connect(outputAudioContext.destination);
|
||||
} catch (audioError) {
|
||||
console.warn('could not preserve system audio:', audioError);
|
||||
}
|
||||
|
||||
statusText.textContent = "Using tab audio capture.";
|
||||
} catch (tabError) {
|
||||
console.log('Tab capture not available, falling back to microphone', tabError);
|
||||
const audioConstraints = selectedMicrophoneId
|
||||
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
|
||||
: { audio: true };
|
||||
|
||||
const stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
|
||||
stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
|
||||
statusText.textContent = "Using microphone audio.";
|
||||
}
|
||||
} else if (isWebContext) {
|
||||
const audioConstraints = selectedMicrophoneId
|
||||
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
|
||||
: { audio: true };
|
||||
stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
|
||||
}
|
||||
|
||||
audioContext = new (window.AudioContext || window.webkitAudioContext)();
|
||||
analyser = audioContext.createAnalyser();
|
||||
@@ -459,6 +557,7 @@ async function startRecording() {
|
||||
microphone = audioContext.createMediaStreamSource(stream);
|
||||
microphone.connect(analyser);
|
||||
|
||||
if (serverUseAudioWorklet) {
|
||||
if (!audioContext.audioWorklet) {
|
||||
throw new Error("AudioWorklet is not supported in this browser");
|
||||
}
|
||||
@@ -491,6 +590,21 @@ async function startRecording() {
|
||||
[ab]
|
||||
);
|
||||
};
|
||||
} else {
|
||||
try {
|
||||
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
|
||||
} catch (e) {
|
||||
recorder = new MediaRecorder(stream);
|
||||
}
|
||||
recorder.ondataavailable = (e) => {
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
if (e.data && e.data.size > 0) {
|
||||
websocket.send(e.data);
|
||||
}
|
||||
}
|
||||
};
|
||||
recorder.start(chunkDuration);
|
||||
}
|
||||
|
||||
startTime = Date.now();
|
||||
timerInterval = setInterval(updateTimer, 1000);
|
||||
@@ -528,6 +642,14 @@ async function stopRecording() {
|
||||
statusText.textContent = "Recording stopped. Processing final audio...";
|
||||
}
|
||||
|
||||
if (recorder) {
|
||||
try {
|
||||
recorder.stop();
|
||||
} catch (e) {
|
||||
}
|
||||
recorder = null;
|
||||
}
|
||||
|
||||
if (recorderWorker) {
|
||||
recorderWorker.terminate();
|
||||
recorderWorker = null;
|
||||
@@ -561,6 +683,16 @@ async function stopRecording() {
|
||||
audioContext = null;
|
||||
}
|
||||
|
||||
if (audioSource) {
|
||||
audioSource.disconnect();
|
||||
audioSource = null;
|
||||
}
|
||||
|
||||
if (outputAudioContext && outputAudioContext.state !== "closed") {
|
||||
outputAudioContext.close()
|
||||
outputAudioContext = null;
|
||||
}
|
||||
|
||||
if (animationFrame) {
|
||||
cancelAnimationFrame(animationFrame);
|
||||
animationFrame = null;
|
||||
@@ -586,9 +718,11 @@ async function toggleRecording() {
|
||||
console.log("Connecting to WebSocket");
|
||||
try {
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
await configReady;
|
||||
await startRecording();
|
||||
} else {
|
||||
await setupWebSocket();
|
||||
await configReady;
|
||||
await startRecording();
|
||||
}
|
||||
} catch (err) {
|
||||
@@ -610,7 +744,7 @@ function updateUI() {
|
||||
statusText.textContent = "Please wait for processing to complete...";
|
||||
}
|
||||
} else if (isRecording) {
|
||||
statusText.textContent = "Recording...";
|
||||
statusText.textContent = "";
|
||||
} else {
|
||||
if (
|
||||
statusText.textContent !== "Finished processing audio! Ready to record again." &&
|
||||
@@ -644,3 +778,40 @@ navigator.mediaDevices.addEventListener('devicechange', async () => {
|
||||
console.log("Error re-enumerating microphones:", error);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
settingsToggle.addEventListener("click", () => {
|
||||
settingsDiv.classList.toggle("visible");
|
||||
settingsToggle.classList.toggle("active");
|
||||
});
|
||||
|
||||
if (isExtension) {
|
||||
async function checkAndRequestPermissions() {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
|
||||
const permissionDisplay = document.getElementById("audioPermission");
|
||||
if (permissionDisplay) {
|
||||
permissionDisplay.innerText = `MICROPHONE: ${micPermission.state}`;
|
||||
}
|
||||
|
||||
// if (micPermission.state !== "granted") {
|
||||
// chrome.tabs.create({ url: "welcome.html" });
|
||||
// }
|
||||
|
||||
const intervalId = setInterval(async () => {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
if (micPermission.state === "granted") {
|
||||
if (permissionDisplay) {
|
||||
permissionDisplay.innerText = `MICROPHONE: ${micPermission.state}`;
|
||||
}
|
||||
clearInterval(intervalId);
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
|
||||
void checkAndRequestPermissions();
|
||||
}
|
||||
|
||||
1
whisperlivekit/web/src/language.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-80q-82 0-155-31.5t-127.5-86Q143-252 111.5-325T80-480q0-83 31.5-155.5t86-127Q252-817 325-848.5T480-880q83 0 155.5 31.5t127 86q54.5 54.5 86 127T880-480q0 82-31.5 155t-86 127.5q-54.5 54.5-127 86T480-80Zm0-82q26-36 45-75t31-83H404q12 44 31 83t45 75Zm-104-16q-18-33-31.5-68.5T322-320H204q29 50 72.5 87t99.5 55Zm208 0q56-18 99.5-55t72.5-87H638q-9 38-22.5 73.5T584-178ZM170-400h136q-3-20-4.5-39.5T300-480q0-21 1.5-40.5T306-560H170q-5 20-7.5 39.5T160-480q0 21 2.5 40.5T170-400Zm216 0h188q3-20 4.5-39.5T580-480q0-21-1.5-40.5T574-560H386q-3 20-4.5 39.5T380-480q0 21 1.5 40.5T386-400Zm268 0h136q5-20 7.5-39.5T800-480q0-21-2.5-40.5T790-560H654q3 20 4.5 39.5T660-480q0 21-1.5 40.5T654-400Zm-16-240h118q-29-50-72.5-87T584-782q18 33 31.5 68.5T638-640Zm-234 0h152q-12-44-31-83t-45-75q-26 36-45 75t-31 83Zm-200 0h118q9-38 22.5-73.5T376-782q-56 18-99.5 55T204-640Z"/></svg>
|
||||
|
After Width: | Height: | Size: 976 B |
1
whisperlivekit/web/src/silence.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M514-556 320-752q9-3 19-5.5t21-2.5q66 0 113 47t47 113q0 11-1.5 22t-4.5 22ZM40-200v-32q0-33 17-62t47-44q51-26 115-44t141-18q26 0 49.5 2.5T456-392l-56-54q-9 3-19 4.5t-21 1.5q-66 0-113-47t-47-113q0-11 1.5-21t4.5-19L84-764q-11-11-11-28t11-28q12-12 28.5-12t27.5 12l675 685q11 11 11.5 27.5T816-80q-11 13-28 12.5T759-80L641-200h39q0 33-23.5 56.5T600-120H120q-33 0-56.5-23.5T40-200Zm80 0h480v-32q0-14-4.5-19.5T580-266q-36-18-92.5-36T360-320q-71 0-127.5 18T140-266q-9 5-14.5 14t-5.5 20v32Zm240 0Zm560-400q0 69-24.5 131.5T829-355q-12 14-30 15t-32-13q-13-13-12-31t12-33q30-38 46.5-85t16.5-98q0-51-16.5-97T767-781q-12-15-12.5-33t12.5-32q13-14 31.5-13.5T829-845q42 51 66.5 113.5T920-600Zm-182 0q0 32-10 61.5T700-484q-11 15-29.5 15.5T638-482q-13-13-13.5-31.5T633-549q6-11 9.5-24t3.5-27q0-14-3.5-27t-9.5-25q-9-17-8.5-35t13.5-31q14-14 32.5-13.5T700-716q18 25 28 54.5t10 61.5Z"/></svg>
|
||||
|
After Width: | Height: | Size: 984 B |
1
whisperlivekit/web/src/speaker.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-480q-66 0-113-47t-47-113q0-66 47-113t113-47q66 0 113 47t47 113q0 66-47 113t-113 47ZM160-240v-32q0-34 17.5-62.5T224-378q62-31 126-46.5T480-440q66 0 130 15.5T736-378q29 15 46.5 43.5T800-272v32q0 33-23.5 56.5T720-160H240q-33 0-56.5-23.5T160-240Zm80 0h480v-32q0-11-5.5-20T700-306q-54-27-109-40.5T480-360q-56 0-111 13.5T260-306q-9 5-14.5 14t-5.5 20v32Zm240-320q33 0 56.5-23.5T560-640q0-33-23.5-56.5T480-720q-33 0-56.5 23.5T400-640q0 33 23.5 56.5T480-560Zm0-80Zm0 400Z"/></svg>
|
||||
|
After Width: | Height: | Size: 592 B |
@@ -23,6 +23,24 @@ def get_inline_ui_html():
|
||||
with resources.files('whisperlivekit.web').joinpath('live_transcription.js').open('r', encoding='utf-8') as f:
|
||||
js_content = f.read()
|
||||
|
||||
with resources.files('whisperlivekit.web').joinpath('pcm_worklet.js').open('r', encoding='utf-8') as f:
|
||||
worklet_code = f.read()
|
||||
with resources.files('whisperlivekit.web').joinpath('recorder_worker.js').open('r', encoding='utf-8') as f:
|
||||
worker_code = f.read()
|
||||
|
||||
js_content = js_content.replace(
|
||||
'await audioContext.audioWorklet.addModule("/web/pcm_worklet.js");',
|
||||
'const workletBlob = new Blob([`' + worklet_code + '`], { type: "application/javascript" });\n' +
|
||||
'const workletUrl = URL.createObjectURL(workletBlob);\n' +
|
||||
'await audioContext.audioWorklet.addModule(workletUrl);'
|
||||
)
|
||||
js_content = js_content.replace(
|
||||
'recorderWorker = new Worker("/web/recorder_worker.js");',
|
||||
'const workerBlob = new Blob([`' + worker_code + '`], { type: "application/javascript" });\n' +
|
||||
'const workerUrl = URL.createObjectURL(workerBlob);\n' +
|
||||
'recorderWorker = new Worker(workerUrl);'
|
||||
)
|
||||
|
||||
# SVG files
|
||||
with resources.files('whisperlivekit.web').joinpath('src', 'system_mode.svg').open('r', encoding='utf-8') as f:
|
||||
system_svg = f.read()
|
||||
@@ -33,15 +51,18 @@ def get_inline_ui_html():
|
||||
with resources.files('whisperlivekit.web').joinpath('src', 'dark_mode.svg').open('r', encoding='utf-8') as f:
|
||||
dark_svg = f.read()
|
||||
dark_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(dark_svg.encode('utf-8')).decode('utf-8')}"
|
||||
with resources.files('whisperlivekit.web').joinpath('src', 'settings.svg').open('r', encoding='utf-8') as f:
|
||||
settings = f.read()
|
||||
settings_uri = f"data:image/svg+xml;base64,{base64.b64encode(settings.encode('utf-8')).decode('utf-8')}"
|
||||
|
||||
# Replace external references
|
||||
html_content = html_content.replace(
|
||||
'<link rel="stylesheet" href="/web/live_transcription.css" />',
|
||||
'<link rel="stylesheet" href="live_transcription.css" />',
|
||||
f'<style>\n{css_content}\n</style>'
|
||||
)
|
||||
|
||||
html_content = html_content.replace(
|
||||
'<script src="/web/live_transcription.js"></script>',
|
||||
'<script src="live_transcription.js"></script>',
|
||||
f'<script>\n{js_content}\n</script>'
|
||||
)
|
||||
|
||||
@@ -61,6 +82,11 @@ def get_inline_ui_html():
|
||||
f'<img src="{dark_data_uri}" alt="" />'
|
||||
)
|
||||
|
||||
html_content = html_content.replace(
|
||||
'<img src="web/src/settings.svg" alt="Settings" />',
|
||||
f'<img src="{settings_uri}" alt="" />'
|
||||
)
|
||||
|
||||
return html_content
|
||||
|
||||
except Exception as e:
|
||||
|
||||
463
whisperlivekit/whisper/__init__.py
Normal file
@@ -0,0 +1,463 @@
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from torch import Tensor
|
||||
|
||||
from whisperlivekit.whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from whisperlivekit.whisper.decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||
from whisperlivekit.whisper.model import ModelDimensions, Whisper
|
||||
from whisperlivekit.whisper.transcribe import transcribe
|
||||
from whisperlivekit.whisper.version import __version__
|
||||
|
||||
_MODELS = {
|
||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
||||
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
|
||||
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
|
||||
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
|
||||
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
|
||||
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
|
||||
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
|
||||
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
|
||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||
"large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
||||
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
||||
}
|
||||
|
||||
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
|
||||
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
|
||||
_ALIGNMENT_HEADS = {
|
||||
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
|
||||
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
|
||||
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
|
||||
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
|
||||
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
|
||||
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
|
||||
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
|
||||
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
|
||||
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
|
||||
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
||||
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||
"large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
||||
"turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
||||
}
|
||||
|
||||
|
||||
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
download_target = os.path.join(root, os.path.basename(url))
|
||||
|
||||
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
||||
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
with open(download_target, "rb") as f:
|
||||
model_bytes = f.read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return model_bytes if in_memory else download_target
|
||||
else:
|
||||
warnings.warn(
|
||||
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
||||
)
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(
|
||||
total=int(source.info().get("Content-Length")),
|
||||
ncols=80,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||
)
|
||||
|
||||
return model_bytes if in_memory else download_target
|
||||
|
||||
|
||||
def available_models() -> List[str]:
|
||||
"""Returns the names of available models"""
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
|
||||
"""
|
||||
attempt to infer ModelDimensions from a HF style config.json located
|
||||
next to the given checkpoint, usefull for distilled models
|
||||
"""
|
||||
candidates = []
|
||||
if os.path.isdir(path):
|
||||
candidates.append(os.path.join(path, "config.json"))
|
||||
else:
|
||||
candidates.append(os.path.join(os.path.dirname(path), "config.json"))
|
||||
|
||||
for candidate in candidates:
|
||||
if not os.path.isfile(candidate):
|
||||
continue
|
||||
with open(candidate, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
|
||||
try:
|
||||
return ModelDimensions(
|
||||
n_mels=config["num_mel_bins"],
|
||||
n_audio_ctx=config["max_source_positions"],
|
||||
n_audio_state=config["d_model"],
|
||||
n_audio_head=config["encoder_attention_heads"],
|
||||
n_audio_layer=config.get("encoder_layers")
|
||||
or config["num_hidden_layers"],
|
||||
n_vocab=config["vocab_size"],
|
||||
n_text_ctx=config["max_target_positions"],
|
||||
n_text_state=config["d_model"],
|
||||
n_text_head=config["decoder_attention_heads"],
|
||||
n_text_layer=config["decoder_layers"],
|
||||
)
|
||||
except KeyError as err:
|
||||
warnings.warn(f"Missing key {err} in HuggingFace config {candidate}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
converts a HF checkpoint state_dict into the naming convention used by
|
||||
default whisper
|
||||
"""
|
||||
|
||||
if not any(k.startswith("model.") for k in state_dict):
|
||||
return state_dict
|
||||
|
||||
def map_block(prefix: str, target_prefix: str, remainder: str) -> Optional[str]:
|
||||
if remainder.startswith("self_attn."):
|
||||
suffix = remainder.split(".", 1)[1]
|
||||
mapping = {
|
||||
"q_proj": "attn.query",
|
||||
"k_proj": "attn.key",
|
||||
"v_proj": "attn.value",
|
||||
"out_proj": "attn.out",
|
||||
}
|
||||
stem = mapping.get(suffix.split(".")[0])
|
||||
if stem:
|
||||
rest = suffix.split(".", 1)[1] if "." in suffix else ""
|
||||
return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
|
||||
elif remainder == "self_attn_layer_norm.weight":
|
||||
return f"{target_prefix}.attn_ln.weight"
|
||||
elif remainder == "self_attn_layer_norm.bias":
|
||||
return f"{target_prefix}.attn_ln.bias"
|
||||
elif remainder.startswith("encoder_attn."):
|
||||
suffix = remainder.split(".", 1)[1]
|
||||
mapping = {
|
||||
"q_proj": "cross_attn.query",
|
||||
"k_proj": "cross_attn.key",
|
||||
"v_proj": "cross_attn.value",
|
||||
"out_proj": "cross_attn.out",
|
||||
}
|
||||
stem = mapping.get(suffix.split(".", 1)[0])
|
||||
if stem:
|
||||
rest = suffix.split(".", 1)[1] if "." in suffix else ""
|
||||
return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
|
||||
elif remainder == "encoder_attn_layer_norm.weight":
|
||||
return f"{target_prefix}.cross_attn_ln.weight"
|
||||
elif remainder == "encoder_attn_layer_norm.bias":
|
||||
return f"{target_prefix}.cross_attn_ln.bias"
|
||||
elif remainder.startswith("fc1."):
|
||||
return f"{target_prefix}.mlp.0.{remainder.split('.',1)[1]}"
|
||||
elif remainder.startswith("fc2."):
|
||||
return f"{target_prefix}.mlp.2.{remainder.split('.',1)[1]}"
|
||||
elif remainder == "final_layer_norm.weight":
|
||||
return f"{target_prefix}.mlp_ln.weight"
|
||||
elif remainder == "final_layer_norm.bias":
|
||||
return f"{target_prefix}.mlp_ln.bias"
|
||||
return None
|
||||
|
||||
converted = {}
|
||||
for key, value in state_dict.items():
|
||||
if not key.startswith("model."):
|
||||
continue
|
||||
subkey = key[len("model.") :]
|
||||
|
||||
if subkey.startswith("encoder.layers."):
|
||||
parts = subkey.split(".")
|
||||
layer_idx = parts[2]
|
||||
remainder = ".".join(parts[3:])
|
||||
mapped = map_block(subkey, f"encoder.blocks.{layer_idx}", remainder)
|
||||
elif subkey.startswith("decoder.layers."):
|
||||
parts = subkey.split(".")
|
||||
layer_idx = parts[2]
|
||||
remainder = ".".join(parts[3:])
|
||||
mapped = map_block(subkey, f"decoder.blocks.{layer_idx}", remainder)
|
||||
elif subkey.startswith("encoder.conv") or subkey.startswith("decoder.conv"):
|
||||
mapped = subkey
|
||||
elif subkey == "encoder.embed_positions.weight":
|
||||
mapped = "encoder.positional_embedding"
|
||||
elif subkey == "decoder.embed_positions.weight":
|
||||
mapped = "decoder.positional_embedding"
|
||||
elif subkey == "encoder.layer_norm.weight":
|
||||
mapped = "encoder.ln_post.weight"
|
||||
elif subkey == "encoder.layer_norm.bias":
|
||||
mapped = "encoder.ln_post.bias"
|
||||
elif subkey.startswith("decoder.embed_tokens."):
|
||||
mapped = subkey.replace("embed_tokens", "token_embedding", 1)
|
||||
elif subkey == "decoder.layer_norm.weight":
|
||||
mapped = "decoder.ln.weight"
|
||||
elif subkey == "decoder.layer_norm.bias":
|
||||
mapped = "decoder.ln.bias"
|
||||
else:
|
||||
mapped = None
|
||||
|
||||
if mapped:
|
||||
converted[mapped] = value
|
||||
|
||||
return converted if converted else state_dict
|
||||
|
||||
|
||||
def _load_lora_state(lora_path: str):
|
||||
safe_path = os.path.join(lora_path, "adapter_model.safetensors")
|
||||
bin_path = os.path.join(lora_path, "adapter_model.bin")
|
||||
if os.path.isfile(safe_path):
|
||||
try:
|
||||
from safetensors.torch import load_file
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Loading LoRA adapters stored as .safetensors requires the `safetensors` package."
|
||||
) from exc
|
||||
return load_file(safe_path)
|
||||
if os.path.isfile(bin_path):
|
||||
return torch.load(bin_path, map_location="cpu")
|
||||
raise FileNotFoundError(
|
||||
f"No adapter weights found under {lora_path}. Expected adapter_model.safetensors or adapter_model.bin."
|
||||
)
|
||||
|
||||
|
||||
def _collapse_hf_module_name(module: str):
|
||||
if module.startswith("base_model."):
|
||||
module = module[len("base_model.") :]
|
||||
if module.startswith("model.model."):
|
||||
module = module[len("model.") :]
|
||||
if not module.startswith("model."):
|
||||
module = f"model.{module}"
|
||||
return module
|
||||
|
||||
|
||||
def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str]):
|
||||
if not lora_path:
|
||||
return
|
||||
|
||||
config_path = os.path.join(lora_path, "adapter_config.json")
|
||||
if not os.path.isfile(config_path):
|
||||
raise FileNotFoundError(f"Missing adapter_config.json inside {lora_path}")
|
||||
with open(config_path, "r", encoding="utf-8") as handle:
|
||||
config = json.load(handle)
|
||||
if config.get("peft_type") != "LORA":
|
||||
raise ValueError("Only LoRA adapters are supported.")
|
||||
|
||||
r = config.get("r")
|
||||
alpha = config.get("lora_alpha") or config.get("alpha")
|
||||
if not r or not alpha:
|
||||
raise ValueError("LoRA config must include `r` and `lora_alpha`.")
|
||||
scaling = alpha / r
|
||||
|
||||
adapter_state = _load_lora_state(lora_path)
|
||||
lora_layers: Dict[str, Dict[str, Tensor]] = {}
|
||||
for key, tensor in adapter_state.items():
|
||||
if key.endswith("lora_A.weight"):
|
||||
module = key[: -len(".lora_A.weight")]
|
||||
lora_layers.setdefault(module, {})["A"] = tensor
|
||||
elif key.endswith("lora_B.weight"):
|
||||
module = key[: -len(".lora_B.weight")]
|
||||
lora_layers.setdefault(module, {})["B"] = tensor
|
||||
|
||||
if not lora_layers:
|
||||
raise ValueError(f"No LoRA tensors found in {lora_path}")
|
||||
|
||||
for module, parts in lora_layers.items():
|
||||
if "A" not in parts or "B" not in parts:
|
||||
raise ValueError(f"Incomplete LoRA tensors for module '{module}'")
|
||||
|
||||
hf_module = _collapse_hf_module_name(module)
|
||||
hf_weight_key = f"{hf_module}.weight"
|
||||
|
||||
delta = parts["B"] @ parts["A"]
|
||||
delta = delta * scaling
|
||||
|
||||
converted = _convert_hf_state_dict({hf_weight_key: delta})
|
||||
if not converted:
|
||||
raise KeyError(f"Failed to map LoRA module '{module}' into Whisper state dict.")
|
||||
target_name, delta_tensor = next(iter(converted.items()))
|
||||
if target_name not in state_dict:
|
||||
raise KeyError(
|
||||
f"LoRA module '{module}' mapped to '{target_name}', but the base model has no such parameter."
|
||||
)
|
||||
|
||||
state_dict[target_name] = state_dict[target_name] + delta_tensor.to(
|
||||
dtype=state_dict[target_name].dtype, device=state_dict[target_name].device
|
||||
)
|
||||
|
||||
|
||||
def load_model(
|
||||
name: str,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
download_root: str = None,
|
||||
in_memory: bool = False,
|
||||
decoder_only: bool = False,
|
||||
custom_alignment_heads: Optional[str] = None,
|
||||
lora_path: Optional[str] = None,
|
||||
) -> Whisper:
|
||||
"""
|
||||
Load a Whisper ASR model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
one of the official model names listed by `whisper.available_models()`, or
|
||||
path to a model checkpoint containing the model dimensions and the model state_dict.
|
||||
device : Union[str, torch.device]
|
||||
the PyTorch device to put the model into
|
||||
download_root: str
|
||||
path to download the model files; by default, it uses "~/.cache/whisper"
|
||||
in_memory: bool
|
||||
whether to preload the model weights into host memory
|
||||
lora_path: str
|
||||
optional directory containing PEFT LoRA adapter weights (adapter_config + adapter_model)
|
||||
|
||||
Returns
|
||||
-------
|
||||
model : Whisper
|
||||
The Whisper ASR model instance
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if download_root is None:
|
||||
default = os.path.join(os.path.expanduser("~"), ".cache")
|
||||
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
|
||||
if name in _MODELS:
|
||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||
elif os.path.isfile(name):
|
||||
checkpoint_file = open(name, "rb").read() if in_memory else name
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Model {name} not found; available models = {available_models()}"
|
||||
)
|
||||
|
||||
alignment_heads = _ALIGNMENT_HEADS.get(name, None)
|
||||
if custom_alignment_heads:
|
||||
alignment_heads = custom_alignment_heads.encode()
|
||||
|
||||
if isinstance(checkpoint_file, Path) and checkpoint_file.suffix == '.safetensors':
|
||||
try:
|
||||
from safetensors.torch import load_file
|
||||
except ImportError:
|
||||
raise ImportError("Please install safetensors to load .safetensors model files: `pip install safetensors`")
|
||||
if in_memory:
|
||||
checkpoint = load_file(checkpoint_file, device=device)
|
||||
else:
|
||||
checkpoint = load_file(checkpoint_file, device=device)
|
||||
else:
|
||||
with (
|
||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
||||
) as fp:
|
||||
checkpoint = torch.load(fp, map_location=device)
|
||||
del checkpoint_file
|
||||
|
||||
dims_cfg = checkpoint.get("dims") if isinstance(checkpoint, dict) else None
|
||||
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
|
||||
state_dict = checkpoint["model_state_dict"]
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
state_dict = _convert_hf_state_dict(state_dict)
|
||||
_apply_lora_adapter(state_dict, lora_path)
|
||||
|
||||
if dims_cfg is not None:
|
||||
dims = ModelDimensions(**dims_cfg)
|
||||
else:
|
||||
dims = _infer_dims_from_config(name)
|
||||
if dims is None:
|
||||
raise RuntimeError(
|
||||
"Could not determine model dimensions. "
|
||||
"Ensure the checkpoint includes 'dims' or a HuggingFace config.json is present."
|
||||
)
|
||||
if not isinstance(state_dict, dict):
|
||||
state_dict = checkpoint
|
||||
|
||||
model = Whisper(dims, decoder_only=decoder_only)
|
||||
|
||||
if decoder_only:
|
||||
state_dict = {
|
||||
k: v for k, v in state_dict.items()
|
||||
if 'encoder' not in k
|
||||
}
|
||||
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
if alignment_heads is not None:
|
||||
model.set_alignment_heads(alignment_heads)
|
||||
|
||||
return model.to(device)
|
||||
|
||||
|
||||
def convert_encoder_to_coreml(
|
||||
model_name = "base",
|
||||
output_path= "whisper_encoder.mlpackage",
|
||||
dummy_frames = 3000, #Number of time frames to use for the dummy mel input during tracing
|
||||
precision = "float16",
|
||||
):
|
||||
|
||||
import coremltools as ct
|
||||
model = load_model(model_name, device="cpu", decoder_only=False)
|
||||
encoder = model.encoder.eval().cpu()
|
||||
|
||||
dummy_input = torch.randn(
|
||||
1,
|
||||
model.dims.n_mels,
|
||||
dummy_frames,
|
||||
dtype=next(encoder.parameters()).dtype,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
traced_encoder = torch.jit.trace(encoder, dummy_input)
|
||||
|
||||
precision_map = {
|
||||
"float16": ct.precision.FLOAT16,
|
||||
"fp16": ct.precision.FLOAT16,
|
||||
"float32": ct.precision.FLOAT32,
|
||||
"fp32": ct.precision.FLOAT32,
|
||||
}
|
||||
coreml_precision = precision_map[precision.lower()]
|
||||
|
||||
mlmodel = ct.convert(
|
||||
traced_encoder,
|
||||
inputs=[ct.TensorType(name="mel", shape=dummy_input.shape)],
|
||||
convert_to= "mlprogram",
|
||||
compute_precision=coreml_precision,
|
||||
)
|
||||
|
||||
output_path = Path(output_path)
|
||||
mlmodel.save(str(output_path))
|
||||
return output_path
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# convert_encoder_to_coreml(model_name="tiny", output_path="whisper_encoder.mlpackage", dummy_frames=3000, precision="float16", convert_to="mlprogram")
|
||||
@@ -1,110 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import sys
|
||||
import numpy as np
|
||||
import librosa
|
||||
from functools import lru_cache
|
||||
import time
|
||||
import logging
|
||||
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
|
||||
","
|
||||
)
|
||||
|
||||
|
||||
def create_tokenizer(lan):
|
||||
"""returns an object that has split function that works like the one of MosesTokenizer"""
|
||||
|
||||
assert (
|
||||
lan in WHISPER_LANG_CODES
|
||||
), "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
|
||||
|
||||
if lan == "uk":
|
||||
import tokenize_uk
|
||||
|
||||
class UkrainianTokenizer:
|
||||
def split(self, text):
|
||||
return tokenize_uk.tokenize_sents(text)
|
||||
|
||||
return UkrainianTokenizer()
|
||||
|
||||
# supported by fast-mosestokenizer
|
||||
if (
|
||||
lan
|
||||
in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
|
||||
):
|
||||
from mosestokenizer import MosesSentenceSplitter
|
||||
|
||||
return MosesSentenceSplitter(lan)
|
||||
|
||||
# the following languages are in Whisper, but not in wtpsplit:
|
||||
if (
|
||||
lan
|
||||
in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
|
||||
):
|
||||
logger.debug(
|
||||
f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
|
||||
)
|
||||
lan = None
|
||||
|
||||
from wtpsplit import WtP
|
||||
|
||||
# downloads the model from huggingface on the first use
|
||||
wtp = WtP("wtp-canine-s-12l-no-adapters")
|
||||
|
||||
class WtPtok:
|
||||
def split(self, sent):
|
||||
return wtp.split(sent, lang_code=lan)
|
||||
|
||||
return WtPtok()
|
||||
|
||||
|
||||
def backend_factory(args):
|
||||
backend = args.backend
|
||||
if backend == "openai-api":
|
||||
logger.debug("Using OpenAI API.")
|
||||
asr = OpenaiApiASR(lan=args.lan)
|
||||
else:
|
||||
if backend == "faster-whisper":
|
||||
asr_cls = FasterWhisperASR
|
||||
elif backend == "mlx-whisper":
|
||||
asr_cls = MLXWhisper
|
||||
else:
|
||||
asr_cls = WhisperTimestampedASR
|
||||
|
||||
# Only for FasterWhisperASR and WhisperTimestampedASR
|
||||
size = args.model
|
||||
t = time.time()
|
||||
logger.info(f"Loading Whisper {size} model for language {args.lan}...")
|
||||
asr = asr_cls(
|
||||
modelsize=size,
|
||||
lan=args.lan,
|
||||
cache_dir=getattr(args, 'model_cache_dir', None),
|
||||
model_dir=getattr(args, 'model_dir', None),
|
||||
)
|
||||
e = time.time()
|
||||
logger.info(f"done. It took {round(e-t,2)} seconds.")
|
||||
|
||||
# Apply common configurations
|
||||
if getattr(args, "vad", False): # Checks if VAD argument is present and True
|
||||
logger.info("Setting VAD filter")
|
||||
asr.use_vad()
|
||||
|
||||
language = args.lan
|
||||
if args.task == "translate":
|
||||
if backend != "simulstreaming":
|
||||
asr.set_translate_task()
|
||||
tgt_language = "en" # Whisper translates into English
|
||||
else:
|
||||
tgt_language = language # Whisper transcribes in this language
|
||||
|
||||
# Create the tokenizer
|
||||
if args.buffer_trimming == "sentence":
|
||||
tokenizer = create_tokenizer(tgt_language)
|
||||
else:
|
||||
tokenizer = None
|
||||
return asr, tokenizer
|
||||