89 Commits

Author SHA1 Message Date
Quentin Fuxa
6206fff118 0.2.15 2025-11-21 23:52:00 +01:00
Quentin Fuxa
b5067249c0 stt/diar/nllw alignment: internal rework 5 2025-11-20 23:52:00 +01:00
Quentin Fuxa
f4f9831d39 stt/diar/nllw alignment: internal rework 5 2025-11-20 23:52:00 +01:00
Quentin Fuxa
254faaf64c stt/diar/nllw alignment: internal rework 5 2025-11-20 23:52:00 +01:00
Quentin Fuxa
8e7aea4fcf internal rework 4 2025-11-20 23:45:20 +01:00
Quentin Fuxa
270faf2069 internal rework 3 2025-11-20 22:28:30 +01:00
Quentin Fuxa
b7c1cc77cc internal rework 2 2025-11-20 22:06:38 +01:00
Quentin Fuxa
9a45ec221c internal rework 1 2025-11-20 12:58:38 +01:00
Quentin Fuxa
3e13ee6fc3 bump to post4 2025-11-19 21:23:43 +01:00
Quentin Fuxa
b7d20a0ff0 segment attribution in result formatter 2025-11-19 21:10:28 +01:00
Quentin Fuxa
c1bb9c2bde reduce flickering remaining_time_transcription 2025-11-19 19:09:37 +01:00
Quentin Fuxa
11e9def0b2 diarization corrections 2025-11-19 19:06:03 +01:00
Quentin Fuxa
3104f40f6e fixes #279 #278 2025-11-19 18:17:50 +01:00
Quentin Fuxa
e9b4ceeee5 Add audio partial silence in chunks handling. bump to 0.2.14.post3 2025-11-17 22:52:00 +01:00
Quentin Fuxa
437641fb43 reduce min-chunk-size to 0.1, set default model to base 2027-04-25 23:52:00 +02:00
Quentin Fuxa
bfd60b3921 Add audio partial silence in chunks handling. bump to 0.2.14.post2 2025-11-17 22:52:00 +01:00
Quentin Fuxa
1e67bf97f0 improve buffering when use of heavy models 2027-04-25 23:52:00 +02:00
Quentin Fuxa
bbd4fd6cff Merge branch 'improve_EOS_handling' 2025-11-16 22:30:31 +01:00
Quentin Fuxa
28985962a0 Silence handling: finish transcription even if not validated at the BEGINNING of the silence 2025-11-16 22:29:08 +01:00
Quentin Fuxa
a38c103fcd simulstreaming coreml encoder compatibility 2025-11-16 21:24:14 +01:00
Quentin Fuxa
4d2ffb24f8 coreml conversion 2025-11-16 19:11:43 +01:00
Quentin Fuxa
1bbbb7903c lora loader in shared whisper core 2025-11-16 18:44:35 +01:00
Quentin Fuxa
bcffdbc6b3 bump to 0.2.14 2025-11-15 20:19:09 +01:00
Quentin Fuxa
80b77998f9 Refactor backend handling 2025-11-15 19:51:41 +01:00
Quentin Fuxa
d310f7e25f hf compatibility 2025-11-15 18:34:19 +01:00
Quentin Fuxa
8d9be88fe6 translation buffer is now displayed in frontend 2025-11-10 15:22:26 +01:00
Quentin Fuxa
16461052ed task to direct-english-translation 2025-11-10 13:20:26 +01:00
Quentin Fuxa
5491dbd824 last_validated_token handled in state 2025-11-10 13:18:52 +01:00
Quentin Fuxa
13401ffe24 whisper core at root of wlk 2025-11-10 12:17:18 +01:00
Quentin Fuxa
7108d2ddc5 fixes https://github.com/QuentinFuxa/WhisperLiveKit/issues/269 2025-11-09 20:08:18 +01:00
Quentin Fuxa
a732e0903e Add a script to detect alignement heads, usefull for distilled whisper 2025-11-09 18:12:09 +01:00
Quentin Fuxa
0491681be4 Distilled model compatibility with HF config.json to ModelDimensions 2025-11-08 20:20:05 +01:00
Quentin Fuxa
ffe5284764 _processing_tasks_done checks task completion 2025-11-05 23:34:00 +01:00
Quentin Fuxa
41ca17acda to 0.2.13 2025-10-30 23:30:49 +01:00
Quentin Fuxa
06b31f51eb exception when translation and no nllw 2025-10-30 23:30:19 +01:00
Quentin Fuxa
ece02db6a3 Use optional new separate NLLW package for translation 2025-10-30 19:36:28 +01:00
Quentin Fuxa
939a7ebf8b Translation Local Agreement + Cache optimization v0. Not connected yet 2025-10-28 00:16:52 +01:00
Quentin Fuxa
61edb70fff audioProcessor state variables are now uniquely in State dataclass 2025-10-26 18:54:47 +01:00
Quentin Fuxa
4e455b8aab translation now separates validated from output buffer tokens 2025-10-26 18:51:09 +01:00
Quentin Fuxa
9434390ad3 simplify task stopping condition 2025-10-26 17:26:43 +01:00
Quentin Fuxa
65250db92c tensor to list at the stream end 2025-10-26 16:40:12 +01:00
Quentin Fuxa
416dce7975 fixes #261
Co-authored-by: yosagi <11404771+yosagi@users.noreply.github.com>"
2025-10-25 14:20:08 +02:00
Quentin Fuxa
0c5365e7c6 fixes #258 2025-10-24 20:51:16 +02:00
Quentin Fuxa
19e9d76610 fixes #257 2025-10-24 20:39:37 +02:00
Quentin Fuxa
e7b05b0138 migration to silero vad v6: supports onnx 2025-10-23 23:52:00 +02:00
Quentin Fuxa
818c9c37ca README: path to doc for model file format 2025-10-23 20:34:36 +02:00
Quentin Fuxa
714fb3b14a custom faster-whisper/mlx whisper encoder available 2025-10-23 20:33:17 +02:00
Quentin Fuxa
0af379c465 DOC: information about file format 2025-10-23 20:32:05 +02:00
Quentin Fuxa
9c5bb5df19 README: dir to pah
Co-authored-by: David Georg Reichelt <david.reichelt@uni-leipzig.de>
2025-10-23 20:31:12 +02:00
Quentin Fuxa
dc6ea79036 apache license inheritance from simulwhisper and nemo 2025-10-23 20:28:02 +02:00
Quentin Fuxa
21bbb59e31 Merge pull request #250 from ladinu/patch-1
fix broken link
2025-10-15 08:59:02 +02:00
Quentin Fuxa
12a69205ed bump to 0.2.12 2025-10-06 19:59:05 +02:00
Quentin Fuxa
1f684cdd97 fixes #251 2025-10-06 19:53:27 +02:00
Ladinu Chandrasinghe
3467109668 fix broken link 2025-10-05 10:51:41 -07:00
Quentin Fuxa
971f8473eb update api doc 2025-10-05 11:09:47 +02:00
Quentin Fuxa
8434ef5efc update api 2025-10-05 11:09:12 +02:00
Quentin Fuxa
290470dd60 forwarded_allow_ips in core 2025-10-04 23:04:00 +02:00
Quentin Fuxa
425ac7b51d forwarded_allow_ips in core 2025-10-04 23:04:00 +02:00
Quentin Fuxa
0382cfbeba forwarded_allow_ips in core 2025-10-04 23:04:00 +02:00
Quentin Fuxa
9b1e061b32 forwarded_allow_ips in core 2025-10-04 23:04:00 +02:00
Quentin Fuxa
b4abc158b9 Merge pull request #249 from Damrod/add-ip-forwarding-support
fix wss for reverse proxying
2025-10-06 10:20:05 +02:00
Alvaro Ollero
5832d7433d update documentation 2025-10-04 23:18:10 +02:00
Alvaro Ollero
3736458503 Uvicorn exposes a configuration option to enable reverse proxying from a trusted ip. This PR exposes it downstreams to end clients 2025-10-04 22:21:06 +02:00
Quentin Fuxa
374618e050 token speakers are only reattributed for token coming after last_validated_token 2025-10-04 09:52:00 +02:00
Quentin Fuxa
543972ef38 fixes #248 2025-10-04 09:52:00 +02:00
Quentin Fuxa
73f36cc0ef v0 doc new api 2025-10-02 23:04:00 +02:00
Quentin Fuxa
a7db39d999 solves incorrect spacing in buffer diarization 2025-10-02 23:04:00 +02:00
Quentin Fuxa
a153e11fe0 update when self.diarization_before_transcription 2025-09-28 11:04:00 +02:00
Quentin Fuxa
ca6f9246cc force language = en for .en models 2025-09-28 11:04:00 +02:00
Quentin Fuxa
d080d675a8 cutom alignment heads parameter for custom models 2025-09-27 11:04:00 +02:00
Quentin Fuxa
40bff38933 Merge pull request #239 from msghik/feature/fine-tuned-model-support
feat: Allow loading fine-tuned models in simulstreaming
2025-09-29 10:08:26 +02:00
Quentin Fuxa
2fe3ca0188 connect source to output destination when used as chrome extension to keep audio playing 2025-09-27 13:59:44 +02:00
Quentin Fuxa
545ea15c9a ensure buffer size to be a multiple of the element size 2025-09-27 13:58:32 +02:00
Quentin Fuxa
8cbaeecc75 cutom alignment heads parameter for custom models 2025-09-27 11:04:00 +02:00
google-labs-jules[bot]
70e854b346 feat: Allow loading fine-tuned models in simulstreaming
This change modifies the `simulstreaming` backend to support loading fine-tuned Whisper models via the `--model_dir` argument.

The `SimulStreamingASR` class has been updated to:
- Use the `model_dir` path directly to load the model, which is the correct procedure for fine-tuned `.pt` files.
- Automatically disable the `faster-whisper` and `mlx-whisper` fast encoders when `model_dir` is used, as they are not compatible with standard fine-tuned models.

The call site in `core.py` already passed the `model_dir` argument, so no changes were needed there. This change makes the `simulstreaming` backend more flexible and allows users to leverage their own custom models.
2025-09-27 07:29:30 +00:00
Quentin Fuxa
d55490cd27 typo and simpler conditions 2025-09-26 20:38:26 +02:00
Quentin Fuxa
1fa9e1f656 Merge pull request #238 from CorentinvdBdO/fix_install
fix: translation in pyproject
2025-09-26 20:35:29 +02:00
cvandenbroek
994f30e1ed fix: translation in pyproject 2025-09-26 20:08:35 +02:00
Quentin Fuxa
b22478c0b4 correct silences handling when language not auto 2025-09-25 23:20:00 +02:00
Quentin Fuxa
94c34efd90 chrome extension ws default to localhost 2025-09-25 23:04:00 +02:00
Quentin Fuxa
32099b9275 demo extension 2025-09-25 23:59:24 +02:00
Quentin Fuxa
9fc6654a4a common frontend for web/ and chrome extension 2025-09-25 23:14:25 +02:00
Quentin Fuxa
d24c110d55 to 0.2.11 2025-09-24 22:34:01 +02:00
Quentin Fuxa
4dd5d8bf8a translation compatible with auto and detected language 2025-09-22 11:20:00 +02:00
Quentin Fuxa
cd9a32a36b update archi to show fastapi server is independent from core 2025-09-21 11:03:00 +02:00
Quentin Fuxa
6caf3e0485 correct silence handling in translation 2025-09-27 11:58:00 +02:00
Quentin Fuxa
93f002cafb language detection after few seconds working 2025-09-20 11:08:00 +02:00
Quentin Fuxa
c5e30c2c07 svg loaded once in javascript, no more need for StaticFiles 2025-09-20 11:06:00 +02:00
Quentin Fuxa
1c2afb8bd2 svg loaded once in javascript, no more need for StaticFiles 2025-09-20 11:06:00 +02:00
84 changed files with 3863 additions and 3706 deletions

16
.gitignore vendored
View File

@@ -54,21 +54,6 @@ coverage.xml
# Translations # Translations
*.mo *.mo
*.pot *.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder # PyBuilder
target/ target/
@@ -139,3 +124,4 @@ launch.json
.DS_Store .DS_Store
test/* test/*
nllb-200-distilled-600M-ctranslate2/* nllb-200-distilled-600M-ctranslate2/*
*.mp3

226
LICENSE
View File

@@ -1,52 +1,210 @@
# License Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
## Main Software License TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
MIT License 1. Definitions.
Copyright (c) 2025 Quentin Fuxa. "License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
Permission is hereby granted, free of charge, to any person obtaining a copy "Licensor" shall mean the copyright owner or entity authorized by
of this software and associated documentation files (the "Software"), to deal the copyright owner that is granting the License.
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all "Legal Entity" shall mean the union of the acting entity and all
copies or substantial portions of the Software. other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR "You" (or "Your") shall mean an individual or Legal Entity
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, exercising permissions granted by this License.
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
## SimulStreaming Backend License "Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
**When using the SimulStreaming backend (SimulWhisper), additional licensing terms apply:** "Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
SimulStreaming (https://github.com/ufal/SimulStreaming) is dual-licensed: "Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
### 🔹 Non-Commercial Use "Derivative Works" shall mean any work, whether in Source or Object
You may use SimulStreaming under the **PolyForm Noncommercial License 1.0.0** if you obtain the code through the GitHub repository. This license is **free of charge** and comes with **no obligations** for non-commercial users. form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
### 🔸 Commercial Use "Contribution" shall mean any work of authorship, including
Understanding who uses SimulStreaming commercially helps improve and prioritize development. Therefore, **registration is required** for those who acquire a commercial license. the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
Commercial licenses are planned to be **affordable** to SMEs and individuals. They are considering providing commercial licenses either for free or for a symbolic one-time fee, and may also provide additional support. You can share your preference via the [questionnaire](https://forms.cloud.microsoft.com/e/7tCxb4gJfB). "Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
You can also leave your contact [there](https://forms.cloud.microsoft.com/e/7tCxb4gJfB) to be notified when commercial licenses become available. 2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
**Contact for SimulStreaming licensing:** 3. Grant of Patent License. Subject to the terms and conditions of
[Dominik Macháček](https://ufal.mff.cuni.cz/dominik-machacek/), machacek@ufal.mff.cuni.cz this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2025 Quentin Fuxa
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--- ---
## Based on: ## Based on:
- **whisper_streaming** by ÚFAL MIT License https://github.com/ufal/whisper_streaming. The original work by ÚFAL. License: https://github.com/ufal/whisper_streaming/blob/main/LICENSE - **SimulWhisper** by Speech and Audio Technology LAB of Tsinghua University Apache-2.0 https://github.com/ufal/SimulStreaming
- **silero-vad** by Snakers4 MIT License https://github.com/snakers4/silero-vad. The work by Snakers4 (silero-vad). License: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE - **SimulStreaming** by ÚFAL MIT License https://github.com/ufal/SimulStreaming
- **Diart** by juanmc2005 MIT License https://github.com/juanmc2005/diart. The work in Diart by juanmc2005. License: https://github.com/juanmc2005/diart/blob/main/LICENSE - **NeMo** by NVidia - Apache-2.0 - https://github.com/NVIDIA-NeMo/NeMo
- **SimulStreaming** by ÚFAL Dual License (PolyForm Noncommercial License 1.0.0 / Commercial License) https://github.com/ufal/SimulStreaming - **whisper_streaming** by ÚFAL MIT License https://github.com/ufal/whisper_streaming.
- **silero-vad** by Snakers4 MIT License https://github.com/snakers4/silero-vad.
- **Diart** by juanmc2005 MIT License https://github.com/juanmc2005/diart.

View File

@@ -10,16 +10,16 @@
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a> <a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a> <a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a> <a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a> <a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-Apache 2.0-dark_green"></a>
</p> </p>
Real-time speech transcription directly to your browser, with a ready-to-use backend+server and a simple frontend. Real-time transcription directly to your browser, with a ready-to-use backend+server and a simple frontend.
#### Powered by Leading Research: #### Powered by Leading Research:
- [SimulStreaming](https://github.com/ufalSimulStreaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408) - Simul-[Whisper](https://github.com/backspacetg/simul_whisper)/[Streaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
- [NLLB](https://arxiv.org/abs/2207.04672), ([distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2)) (2024) - Translation to more than 100 languages. - [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simulatenous translation from & to 200 languages.
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf) - [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization - [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization - [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
@@ -45,7 +45,7 @@ pip install whisperlivekit
#### Quick Start #### Quick Start
1. **Start the transcription server:** 1. **Start the transcription server:**
```bash ```bash
whisperlivekit-server --model base --language en wlk --model base --language en
``` ```
2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time! 2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time!
@@ -53,6 +53,15 @@ pip install whisperlivekit
> - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages. > - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options. > - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
> - The CLI entry point is exposed as both `wlk` and `whisperlivekit-server`; they are equivalent.
#### Use it to capture audio from web pages.
Go to `chrome-extension` for instructions.
<p align="center">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="600">
</p>
@@ -60,13 +69,12 @@ pip install whisperlivekit
| Optional | `pip install` | | Optional | `pip install` |
|-----------|-------------| |-----------|-------------|
| **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` | | **Windows/Linux optimizations** | `faster-whisper` |
| **Apple Silicon optimized backend** | `mlx-whisper` | | **Apple Silicon optimizations** | `mlx-whisper` |
| **NLLB Translation** | `huggingface_hub` & `transformers` | | **Translation** | `nllw` |
| **Speaker diarization** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| OpenAI API | `openai` |
| *[Not recommanded]* Speaker diarization with Diart | `diart` | | *[Not recommanded]* Speaker diarization with Diart | `diart` |
| *[Not recommanded]* Original Whisper backend | `whisper` |
| *[Not recommanded]* Improved timestamps backend | `whisper-timestamped` |
| OpenAI API backend | `openai` |
See **Parameters & Configuration** below on how to use them. See **Parameters & Configuration** below on how to use them.
@@ -78,10 +86,10 @@ See **Parameters & Configuration** below on how to use them.
```bash ```bash
# Large model and translate from french to danish # Large model and translate from french to danish
whisperlivekit-server --model large-v3 --language fr --target-language da wlk --model large-v3 --language fr --target-language da
# Diarization and server listening on */80 # Diarization and server listening on */80
whisperlivekit-server --host 0.0.0.0 --port 80 --model medium --diarization --language fr wlk --host 0.0.0.0 --port 80 --model medium --diarization --language fr
``` ```
@@ -131,12 +139,13 @@ async def websocket_endpoint(websocket: WebSocket):
| Parameter | Description | Default | | Parameter | Description | Default |
|-----------|-------------|---------| |-----------|-------------|---------|
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md) | `small` | | `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/available_models.md) | `small` |
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` | | `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/models_compatible_formats.md) | `None` |
| `--target-language` | If sets, activates translation using NLLB. Ex: `fr`. [118 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/translation/mapping_languages.py). If you want to translate to english, you should rather use `--task translate`, since Whisper can do it directly. | `None` | | `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
| `--task` | Set to `translate` to translate *only* to english, using Whisper translation. | `transcribe` | | `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
| `--diarization` | Enable speaker identification | `False` | | `--diarization` | Enable speaker identification | `False` |
| `--backend` | Processing backend. You can switch to `faster-whisper` if `simulstreaming` does not work correctly | `simulstreaming` | | `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
| `--no-vac` | Disable Voice Activity Controller | `False` | | `--no-vac` | Disable Voice Activity Controller | `False` |
| `--no-vad` | Disable Voice Activity Detection | `False` | | `--no-vad` | Disable Voice Activity Detection | `False` |
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` | | `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
@@ -144,6 +153,7 @@ async def websocket_endpoint(websocket: WebSocket):
| `--port` | Server port | `8000` | | `--port` | Server port | `8000` |
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` | | `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` | | `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
| `--forwarded-allow-ips` | Ip or Ips allowed to reverse proxy the whisperlivekit-server. Supported types are IP Addresses (e.g. 127.0.0.1), IP Networks (e.g. 10.100.0.0/16), or Literals (e.g. /path/to/socket.sock) | `None` |
| `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder | `False` | | `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder | `False` |
| Translation options | Description | Default | | Translation options | Description | Default |
@@ -161,6 +171,8 @@ async def websocket_endpoint(websocket: WebSocket):
| SimulStreaming backend options | Description | Default | | SimulStreaming backend options | Description | Default |
|-----------|-------------|---------| |-----------|-------------|---------|
| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` | | `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used. Use `scripts/determine_alignment_heads.py` to extract them. <img src="scripts/alignment_heads.png" alt="WhisperLiveKit Demo" width="300">
| `None` |
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` | | `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` | | `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` | | `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
@@ -171,7 +183,6 @@ async def websocket_endpoint(websocket: WebSocket):
| `--init-prompt` | Initial prompt for the model | `None` | | `--init-prompt` | Initial prompt for the model | `None` |
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` | | `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
| `--max-context-tokens` | Maximum context tokens | `None` | | `--max-context-tokens` | Maximum context tokens | `None` |
| `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` |
| `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` | | `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |

Binary file not shown.

Before

Width:  |  Height:  |  Size: 368 KiB

After

Width:  |  Height:  |  Size: 406 KiB

View File

@@ -1,11 +1,13 @@
## WhisperLiveKit Chrome Extension v0.1.0 ## WhisperLiveKit Chrome Extension v0.1.1
Capture the audio of your current tab, transcribe or translate it using WhisperliveKit. **Still unstable** Capture the audio of your current tab, transcribe diarize and translate it using WhisperliveKit, in Chrome and other Chromium-based browsers.
> Currently, only the tab audio is captured; your microphone audio is not recorded.
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730"> <img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">
## Running this extension ## Running this extension
1. Clone this repository. 1. Run `python sync_extension.py` to copy frontend files to the `chrome-extension` directory.
2. Load this directory in Chrome as an unpacked extension. 2. Load the `chrome-extension` directory in Chrome as an unpacked extension.
## Devs: ## Devs:

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.2 MiB

After

Width:  |  Height:  |  Size: 5.8 MiB

View File

@@ -1,669 +0,0 @@
/* Theme, WebSocket, recording, rendering logic extracted from inline script and adapted for segmented theme control and WS caption */
let isRecording = false;
let websocket = null;
let recorder = null;
let chunkDuration = 100;
let websocketUrl = "ws://localhost:8000/asr";
let userClosing = false;
let wakeLock = null;
let startTime = null;
let timerInterval = null;
let audioContext = null;
let analyser = null;
let microphone = null;
let waveCanvas = document.getElementById("waveCanvas");
let waveCtx = waveCanvas.getContext("2d");
let animationFrame = null;
let waitingForStop = false;
let lastReceivedData = null;
let lastSignature = null;
let availableMicrophones = [];
let selectedMicrophoneId = null;
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
const statusText = document.getElementById("status");
const recordButton = document.getElementById("recordButton");
const chunkSelector = document.getElementById("chunkSelector");
const websocketInput = document.getElementById("websocketInput");
const websocketDefaultSpan = document.getElementById("wsDefaultUrl");
const linesTranscriptDiv = document.getElementById("linesTranscript");
const timerElement = document.querySelector(".timer");
const themeRadios = document.querySelectorAll('input[name="theme"]');
const microphoneSelect = document.getElementById("microphoneSelect");
const settingsToggle = document.getElementById("settingsToggle");
const settingsDiv = document.querySelector(".settings");
chrome.runtime.onInstalled.addListener((details) => {
if (details.reason.search(/install/g) === -1) {
return
}
chrome.tabs.create({
url: chrome.runtime.getURL("welcome.html"),
active: true
})
})
function getWaveStroke() {
const styles = getComputedStyle(document.documentElement);
const v = styles.getPropertyValue("--wave-stroke").trim();
return v || "#000";
}
let waveStroke = getWaveStroke();
function updateWaveStroke() {
waveStroke = getWaveStroke();
}
function applyTheme(pref) {
if (pref === "light") {
document.documentElement.setAttribute("data-theme", "light");
} else if (pref === "dark") {
document.documentElement.setAttribute("data-theme", "dark");
} else {
document.documentElement.removeAttribute("data-theme");
}
updateWaveStroke();
}
// Persisted theme preference
const savedThemePref = localStorage.getItem("themePreference") || "system";
applyTheme(savedThemePref);
if (themeRadios.length) {
themeRadios.forEach((r) => {
r.checked = r.value === savedThemePref;
r.addEventListener("change", () => {
if (r.checked) {
localStorage.setItem("themePreference", r.value);
applyTheme(r.value);
}
});
});
}
// React to OS theme changes when in "system" mode
const darkMq = window.matchMedia && window.matchMedia("(prefers-color-scheme: dark)");
const handleOsThemeChange = () => {
const pref = localStorage.getItem("themePreference") || "system";
if (pref === "system") updateWaveStroke();
};
if (darkMq && darkMq.addEventListener) {
darkMq.addEventListener("change", handleOsThemeChange);
} else if (darkMq && darkMq.addListener) {
// deprecated, but included for Safari compatibility
darkMq.addListener(handleOsThemeChange);
}
async function enumerateMicrophones() {
try {
const micPermission = await navigator.permissions.query({
name: "microphone",
});
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
stream.getTracks().forEach(track => track.stop());
const devices = await navigator.mediaDevices.enumerateDevices();
availableMicrophones = devices.filter(device => device.kind === 'audioinput');
populateMicrophoneSelect();
console.log(`Found ${availableMicrophones.length} microphone(s)`);
} catch (error) {
console.error('Error enumerating microphones:', error);
statusText.textContent = "Error accessing microphones. Please grant permission.";
}
}
function populateMicrophoneSelect() {
if (!microphoneSelect) return;
microphoneSelect.innerHTML = '<option value="">Default Microphone</option>';
availableMicrophones.forEach((device, index) => {
const option = document.createElement('option');
option.value = device.deviceId;
option.textContent = device.label || `Microphone ${index + 1}`;
microphoneSelect.appendChild(option);
});
const savedMicId = localStorage.getItem('selectedMicrophone');
if (savedMicId && availableMicrophones.some(mic => mic.deviceId === savedMicId)) {
microphoneSelect.value = savedMicId;
selectedMicrophoneId = savedMicId;
}
}
function handleMicrophoneChange() {
selectedMicrophoneId = microphoneSelect.value || null;
localStorage.setItem('selectedMicrophone', selectedMicrophoneId || '');
const selectedDevice = availableMicrophones.find(mic => mic.deviceId === selectedMicrophoneId);
const deviceName = selectedDevice ? selectedDevice.label : 'Default Microphone';
console.log(`Selected microphone: ${deviceName}`);
statusText.textContent = `Microphone changed to: ${deviceName}`;
if (isRecording) {
statusText.textContent = "Switching microphone... Please wait.";
stopRecording().then(() => {
setTimeout(() => {
toggleRecording();
}, 1000);
});
}
}
// Helpers
function fmt1(x) {
const n = Number(x);
return Number.isFinite(n) ? n.toFixed(1) : x;
}
// Default WebSocket URL computation
const host = window.location.hostname || "localhost";
const port = window.location.port;
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
const defaultWebSocketUrl = websocketUrl;
// Populate default caption and input
if (websocketDefaultSpan) websocketDefaultSpan.textContent = defaultWebSocketUrl;
websocketInput.value = defaultWebSocketUrl;
websocketUrl = defaultWebSocketUrl;
// Optional chunk selector (guard for presence)
if (chunkSelector) {
chunkSelector.addEventListener("change", () => {
chunkDuration = parseInt(chunkSelector.value);
});
}
// WebSocket input change handling
websocketInput.addEventListener("change", () => {
const urlValue = websocketInput.value.trim();
if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
return;
}
websocketUrl = urlValue;
statusText.textContent = "WebSocket URL updated. Ready to connect.";
});
function setupWebSocket() {
return new Promise((resolve, reject) => {
try {
websocket = new WebSocket(websocketUrl);
} catch (error) {
statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
reject(error);
return;
}
websocket.onopen = () => {
statusText.textContent = "Connected to server.";
resolve();
};
websocket.onclose = () => {
if (userClosing) {
if (waitingForStop) {
statusText.textContent = "Processing finalized or connection closed.";
if (lastReceivedData) {
renderLinesWithBuffer(
lastReceivedData.lines || [],
lastReceivedData.buffer_diarization || "",
lastReceivedData.buffer_transcription || "",
0,
0,
true
);
}
}
} else {
statusText.textContent = "Disconnected from the WebSocket server. (Check logs if model is loading.)";
if (isRecording) {
stopRecording();
}
}
isRecording = false;
waitingForStop = false;
userClosing = false;
lastReceivedData = null;
websocket = null;
updateUI();
};
websocket.onerror = () => {
statusText.textContent = "Error connecting to WebSocket.";
reject(new Error("Error connecting to WebSocket"));
};
websocket.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.type === "ready_to_stop") {
console.log("Ready to stop received, finalizing display and closing WebSocket.");
waitingForStop = false;
if (lastReceivedData) {
renderLinesWithBuffer(
lastReceivedData.lines || [],
lastReceivedData.buffer_diarization || "",
lastReceivedData.buffer_transcription || "",
0,
0,
true
);
}
statusText.textContent = "Finished processing audio! Ready to record again.";
recordButton.disabled = false;
if (websocket) {
websocket.close();
}
return;
}
lastReceivedData = data;
const {
lines = [],
buffer_transcription = "",
buffer_diarization = "",
remaining_time_transcription = 0,
remaining_time_diarization = 0,
status = "active_transcription",
} = data;
renderLinesWithBuffer(
lines,
buffer_diarization,
buffer_transcription,
remaining_time_diarization,
remaining_time_transcription,
false,
status
);
};
});
}
function renderLinesWithBuffer(
lines,
buffer_diarization,
buffer_transcription,
remaining_time_diarization,
remaining_time_transcription,
isFinalizing = false,
current_status = "active_transcription"
) {
if (current_status === "no_audio_detected") {
linesTranscriptDiv.innerHTML =
"<p style='text-align: center; color: var(--muted); margin-top: 20px;'><em>No audio detected...</em></p>";
return;
}
const showLoading = !isFinalizing && (lines || []).some((it) => it.speaker == 0);
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
const signature = JSON.stringify({
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end })),
buffer_transcription: buffer_transcription || "",
buffer_diarization: buffer_diarization || "",
status: current_status,
showLoading,
showTransLag,
showDiaLag,
isFinalizing: !!isFinalizing,
});
if (lastSignature === signature) {
const t = document.querySelector(".lag-transcription-value");
if (t) t.textContent = fmt1(remaining_time_transcription);
const d = document.querySelector(".lag-diarization-value");
if (d) d.textContent = fmt1(remaining_time_diarization);
const ld = document.querySelector(".loading-diarization-value");
if (ld) ld.textContent = fmt1(remaining_time_diarization);
return;
}
lastSignature = signature;
const linesHtml = (lines || [])
.map((item, idx) => {
let timeInfo = "";
if (item.start !== undefined && item.end !== undefined) {
timeInfo = ` ${item.start} - ${item.end}`;
}
let speakerLabel = "";
if (item.speaker === -2) {
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
} else if (item.speaker == 0 && !isFinalizing) {
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
remaining_time_diarization
)}</span> second(s) of audio are undergoing diarization</span></span>`;
} else if (item.speaker !== 0) {
speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
}
let currentLineText = item.text || "";
if (idx === lines.length - 1) {
if (!isFinalizing && item.speaker !== -2) {
if (remaining_time_transcription > 0) {
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(
remaining_time_transcription
)}</span>s</span></span>`;
}
if (buffer_diarization && remaining_time_diarization > 0) {
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(
remaining_time_diarization
)}</span>s</span></span>`;
}
}
if (buffer_diarization) {
if (isFinalizing) {
currentLineText +=
(currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
} else {
currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
}
}
if (buffer_transcription) {
if (isFinalizing) {
currentLineText +=
(currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") +
buffer_transcription.trim();
} else {
currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
}
}
}
return currentLineText.trim().length > 0 || speakerLabel.length > 0
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
: `<p>${speakerLabel}<br/></p>`;
})
.join("");
linesTranscriptDiv.innerHTML = linesHtml;
window.scrollTo({ top: document.body.scrollHeight, behavior: "smooth" });
}
function updateTimer() {
if (!startTime) return;
const elapsed = Math.floor((Date.now() - startTime) / 1000);
const minutes = Math.floor(elapsed / 60).toString().padStart(2, "0");
const seconds = (elapsed % 60).toString().padStart(2, "0");
timerElement.textContent = `${minutes}:${seconds}`;
}
function drawWaveform() {
if (!analyser) return;
const bufferLength = analyser.frequencyBinCount;
const dataArray = new Uint8Array(bufferLength);
analyser.getByteTimeDomainData(dataArray);
waveCtx.clearRect(
0,
0,
waveCanvas.width / (window.devicePixelRatio || 1),
waveCanvas.height / (window.devicePixelRatio || 1)
);
waveCtx.lineWidth = 1;
waveCtx.strokeStyle = waveStroke;
waveCtx.beginPath();
const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength;
let x = 0;
for (let i = 0; i < bufferLength; i++) {
const v = dataArray[i] / 128.0;
const y = (v * (waveCanvas.height / (window.devicePixelRatio || 1))) / 2;
if (i === 0) {
waveCtx.moveTo(x, y);
} else {
waveCtx.lineTo(x, y);
}
x += sliceWidth;
}
waveCtx.lineTo(
waveCanvas.width / (window.devicePixelRatio || 1),
(waveCanvas.height / (window.devicePixelRatio || 1)) / 2
);
waveCtx.stroke();
animationFrame = requestAnimationFrame(drawWaveform);
}
async function startRecording() {
try {
try {
wakeLock = await navigator.wakeLock.request("screen");
} catch (err) {
console.log("Error acquiring wake lock.");
}
let stream;
try {
// Try tab capture first
stream = await new Promise((resolve, reject) => {
chrome.tabCapture.capture({audio: true}, (s) => {
if (s) {
resolve(s);
} else {
reject(new Error('Tab capture failed or not available'));
}
});
});
statusText.textContent = "Using tab audio capture.";
} catch (tabError) {
console.log('Tab capture not available, falling back to microphone', tabError);
// Fallback to microphone
const audioConstraints = selectedMicrophoneId
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
: { audio: true };
stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
statusText.textContent = "Using microphone audio.";
}
audioContext = new (window.AudioContext || window.webkitAudioContext)();
analyser = audioContext.createAnalyser();
analyser.fftSize = 256;
microphone = audioContext.createMediaStreamSource(stream);
microphone.connect(analyser);
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
recorder.ondataavailable = (e) => {
if (websocket && websocket.readyState === WebSocket.OPEN) {
websocket.send(e.data);
}
};
recorder.start(chunkDuration);
startTime = Date.now();
timerInterval = setInterval(updateTimer, 1000);
drawWaveform();
isRecording = true;
updateUI();
} catch (err) {
if (window.location.hostname === "0.0.0.0") {
statusText.textContent =
"Error accessing audio input. Browsers may block audio access on 0.0.0.0. Try using localhost:8000 instead.";
} else {
statusText.textContent = "Error accessing audio input. Please check permissions.";
}
console.error(err);
}
}
async function stopRecording() {
if (wakeLock) {
try {
await wakeLock.release();
} catch (e) {
// ignore
}
wakeLock = null;
}
userClosing = true;
waitingForStop = true;
if (websocket && websocket.readyState === WebSocket.OPEN) {
const emptyBlob = new Blob([], { type: "audio/webm" });
websocket.send(emptyBlob);
statusText.textContent = "Recording stopped. Processing final audio...";
}
if (recorder) {
recorder.stop();
recorder = null;
}
if (microphone) {
microphone.disconnect();
microphone = null;
}
if (analyser) {
analyser = null;
}
if (audioContext && audioContext.state !== "closed") {
try {
await audioContext.close();
} catch (e) {
console.warn("Could not close audio context:", e);
}
audioContext = null;
}
if (animationFrame) {
cancelAnimationFrame(animationFrame);
animationFrame = null;
}
if (timerInterval) {
clearInterval(timerInterval);
timerInterval = null;
}
timerElement.textContent = "00:00";
startTime = null;
isRecording = false;
updateUI();
}
async function toggleRecording() {
if (!isRecording) {
if (waitingForStop) {
console.log("Waiting for stop, early return");
return;
}
console.log("Connecting to WebSocket");
try {
if (websocket && websocket.readyState === WebSocket.OPEN) {
await startRecording();
} else {
await setupWebSocket();
await startRecording();
}
} catch (err) {
statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
console.error(err);
}
} else {
console.log("Stopping recording");
stopRecording();
}
}
function updateUI() {
recordButton.classList.toggle("recording", isRecording);
recordButton.disabled = waitingForStop;
if (waitingForStop) {
if (statusText.textContent !== "Recording stopped. Processing final audio...") {
statusText.textContent = "Please wait for processing to complete...";
}
} else if (isRecording) {
statusText.textContent = "Recording...";
} else {
if (
statusText.textContent !== "Finished processing audio! Ready to record again." &&
statusText.textContent !== "Processing finalized or connection closed."
) {
statusText.textContent = "Click to start transcription";
}
}
if (!waitingForStop) {
recordButton.disabled = false;
}
}
recordButton.addEventListener("click", toggleRecording);
if (microphoneSelect) {
microphoneSelect.addEventListener("change", handleMicrophoneChange);
}
// Settings toggle functionality
settingsToggle.addEventListener("click", () => {
settingsDiv.classList.toggle("visible");
settingsToggle.classList.toggle("active");
});
document.addEventListener('DOMContentLoaded', async () => {
try {
await enumerateMicrophones();
} catch (error) {
console.log("Could not enumerate microphones on load:", error);
}
});
navigator.mediaDevices.addEventListener('devicechange', async () => {
console.log('Device change detected, re-enumerating microphones');
try {
await enumerateMicrophones();
} catch (error) {
console.log("Error re-enumerating microphones:", error);
}
});
async function run() {
const micPermission = await navigator.permissions.query({
name: "microphone",
});
document.getElementById(
"audioPermission"
).innerText = `MICROPHONE: ${micPermission.state}`;
if (micPermission.state !== "granted") {
chrome.tabs.create({ url: "welcome.html" });
}
const intervalId = setInterval(async () => {
const micPermission = await navigator.permissions.query({
name: "microphone",
});
if (micPermission.state === "granted") {
document.getElementById(
"audioPermission"
).innerText = `MICROPHONE: ${micPermission.state}`;
clearInterval(intervalId);
}
}, 100);
}
void run();

View File

@@ -3,9 +3,6 @@
"name": "WhisperLiveKit Tab Capture", "name": "WhisperLiveKit Tab Capture",
"version": "1.0", "version": "1.0",
"description": "Capture and transcribe audio from browser tabs using WhisperLiveKit.", "description": "Capture and transcribe audio from browser tabs using WhisperLiveKit.",
"background": {
"service_worker": "background.js"
},
"icons": { "icons": {
"16": "icons/icon16.png", "16": "icons/icon16.png",
"32": "icons/icon32.png", "32": "icons/icon32.png",
@@ -14,7 +11,7 @@
}, },
"action": { "action": {
"default_title": "WhisperLiveKit Tab Capture", "default_title": "WhisperLiveKit Tab Capture",
"default_popup": "popup.html" "default_popup": "live_transcription.html"
}, },
"permissions": [ "permissions": [
"scripting", "scripting",
@@ -22,16 +19,5 @@
"offscreen", "offscreen",
"activeTab", "activeTab",
"storage" "storage"
],
"web_accessible_resources": [
{
"resources": [
"requestPermissions.html",
"requestPermissions.js"
],
"matches": [
"<all_urls>"
]
}
] ]
} }

View File

@@ -1,78 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>WhisperLiveKit</title>
<link rel="stylesheet" href="/web/live_transcription.css" />
</head>
<body>
<div class="settings-container">
<button id="recordButton">
<div class="shape-container">
<div class="shape"></div>
</div>
<div class="recording-info">
<div class="wave-container">
<canvas id="waveCanvas"></canvas>
</div>
<div class="timer">00:00</div>
</div>
</button>
<button id="settingsToggle" class="settings-toggle" title="Show/hide settings">
<img src="/web/src/settings.svg" alt="Settings" />
</button>
<div class="settings">
<div class="field">
<label for="websocketInput">Websocket URL</label>
<input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
</div>
<div class="field">
<label id="microphoneSelectLabel" for="microphoneSelect">Select Microphone</label>
<select id="microphoneSelect">
<option value="">Default Microphone</option>
</select>
<div id="audioPermission"></div>
</div>
<div class="theme-selector-container">
<div class="segmented" role="radiogroup" aria-label="Theme selector">
<input type="radio" id="theme-system" name="theme" value="system" />
<label for="theme-system" title="System">
<img src="/web/src/system_mode.svg" alt="" />
<!-- <span>System</span> -->
</label>
<input type="radio" id="theme-light" name="theme" value="light" />
<label for="theme-light" title="Light">
<img src="/web/src/light_mode.svg" alt="" />
<!-- <span>Light</span> -->
</label>
<input type="radio" id="theme-dark" name="theme" value="dark" />
<label for="theme-dark" title="Dark">
<img src="/web/src/dark_mode.svg" alt="" />
<!-- <span>Dark</span> -->
</label>
</div>
</div>
</div>
</div>
<p id="status"></p>
<div id="linesTranscript"></div>
<script src="live_transcription.js"></script>
</body>
</html>

View File

@@ -1,539 +0,0 @@
:root {
--bg: #ffffff;
--text: #111111;
--muted: #666666;
--border: #e5e5e5;
--chip-bg: rgba(0, 0, 0, 0.04);
--chip-text: #000000;
--spinner-border: #8d8d8d5c;
--spinner-top: #b0b0b0;
--silence-bg: #f3f3f3;
--loading-bg: rgba(255, 77, 77, 0.06);
--button-bg: #ffffff;
--button-border: #e9e9e9;
--wave-stroke: #000000;
--label-dia-text: #868686;
--label-trans-text: #111111;
}
@media (prefers-color-scheme: dark) {
:root:not([data-theme="light"]) {
--bg: #0b0b0b;
--text: #e6e6e6;
--muted: #9aa0a6;
--border: #333333;
--chip-bg: rgba(255, 255, 255, 0.08);
--chip-text: #e6e6e6;
--spinner-border: #555555;
--spinner-top: #dddddd;
--silence-bg: #1a1a1a;
--loading-bg: rgba(255, 77, 77, 0.12);
--button-bg: #111111;
--button-border: #333333;
--wave-stroke: #e6e6e6;
--label-dia-text: #b3b3b3;
--label-trans-text: #ffffff;
}
}
:root[data-theme="dark"] {
--bg: #0b0b0b;
--text: #e6e6e6;
--muted: #9aa0a6;
--border: #333333;
--chip-bg: rgba(255, 255, 255, 0.08);
--chip-text: #e6e6e6;
--spinner-border: #555555;
--spinner-top: #dddddd;
--silence-bg: #1a1a1a;
--loading-bg: rgba(255, 77, 77, 0.12);
--button-bg: #111111;
--button-border: #333333;
--wave-stroke: #e6e6e6;
--label-dia-text: #b3b3b3;
--label-trans-text: #ffffff;
}
:root[data-theme="light"] {
--bg: #ffffff;
--text: #111111;
--muted: #666666;
--border: #e5e5e5;
--chip-bg: rgba(0, 0, 0, 0.04);
--chip-text: #000000;
--spinner-border: #8d8d8d5c;
--spinner-top: #b0b0b0;
--silence-bg: #f3f3f3;
--loading-bg: rgba(255, 77, 77, 0.06);
--button-bg: #ffffff;
--button-border: #e9e9e9;
--wave-stroke: #000000;
--label-dia-text: #868686;
--label-trans-text: #111111;
}
body {
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
margin: 20px;
text-align: center;
background-color: var(--bg);
color: var(--text);
}
.settings-toggle {
margin-top: 4px;
width: 40px;
height: 40px;
border: none;
border-radius: 50%;
background-color: var(--button-bg);
cursor: pointer;
transition: all 0.3s ease;
/* border: 1px solid var(--button-border); */
display: flex;
align-items: center;
justify-content: center;
position: relative;
}
.settings-toggle:hover {
background-color: var(--chip-bg);
}
.settings-toggle img {
width: 24px;
height: 24px;
opacity: 0.7;
transition: opacity 0.2s ease, transform 0.3s ease;
}
.settings-toggle:hover img {
opacity: 1;
}
.settings-toggle.active img {
transform: rotate(80deg);
}
/* Record button */
#recordButton {
width: 50px;
height: 50px;
border: none;
border-radius: 50%;
background-color: var(--button-bg);
cursor: pointer;
transition: all 0.3s ease;
border: 1px solid var(--button-border);
display: flex;
align-items: center;
justify-content: center;
position: relative;
}
#recordButton.recording {
width: 180px;
border-radius: 40px;
justify-content: flex-start;
padding-left: 20px;
}
#recordButton:active {
transform: scale(0.95);
}
.shape-container {
width: 25px;
height: 25px;
display: flex;
align-items: center;
justify-content: center;
flex-shrink: 0;
}
.shape {
width: 25px;
height: 25px;
background-color: rgb(209, 61, 53);
border-radius: 50%;
transition: all 0.3s ease;
}
#recordButton:disabled .shape {
background-color: #6e6d6d;
}
#recordButton.recording .shape {
border-radius: 5px;
width: 25px;
height: 25px;
}
/* Recording elements */
.recording-info {
display: none;
align-items: center;
margin-left: 15px;
flex-grow: 1;
}
#recordButton.recording .recording-info {
display: flex;
}
.wave-container {
width: 60px;
height: 30px;
position: relative;
display: flex;
align-items: center;
justify-content: center;
}
#waveCanvas {
width: 100%;
height: 100%;
}
.timer {
font-size: 14px;
font-weight: 500;
color: var(--text);
margin-left: 10px;
}
#status {
margin-top: 20px;
font-size: 16px;
color: var(--text);
}
/* Settings */
.settings-container {
display: flex;
justify-content: center;
align-items: flex-start;
gap: 15px;
margin-top: 20px;
flex-wrap: wrap;
}
.settings {
display: none;
flex-wrap: wrap;
align-items: flex-start;
gap: 12px;
transition: opacity 0.3s ease;
}
.settings.visible {
display: flex;
}
.field {
display: flex;
flex-direction: column;
align-items: flex-start;
gap: 3px;
}
#chunkSelector,
#websocketInput,
#themeSelector,
#microphoneSelect {
font-size: 16px;
padding: 5px 8px;
border-radius: 8px;
border: 1px solid var(--border);
background-color: var(--button-bg);
color: var(--text);
max-height: 30px;
}
#microphoneSelect {
width: 100%;
max-width: 190px;
min-width: 120px;
}
#chunkSelector:focus,
#websocketInput:focus,
#themeSelector:focus,
#microphoneSelect:focus {
outline: none;
border-color: #007bff;
box-shadow: 0 0 0 3px rgba(0, 123, 255, 0.15);
}
label {
font-size: 13px;
color: var(--muted);
}
.ws-default {
font-size: 12px;
color: var(--muted);
}
/* Segmented pill control for Theme */
.segmented {
display: inline-flex;
align-items: stretch;
border: 1px solid var(--button-border);
background-color: var(--button-bg);
border-radius: 999px;
overflow: hidden;
}
.segmented input[type="radio"] {
position: absolute;
opacity: 0;
pointer-events: none;
}
.theme-selector-container {
display: flex;
align-items: center;
margin-top: 17px;
}
.segmented label {
display: inline-flex;
align-items: center;
gap: 6px;
padding: 6px 12px;
font-size: 14px;
color: var(--muted);
cursor: pointer;
user-select: none;
transition: background-color 0.2s ease, color 0.2s ease;
}
.segmented label span {
display: none;
}
.segmented label:hover span {
display: inline;
}
.segmented label:hover {
background-color: var(--chip-bg);
}
.segmented img {
width: 16px;
height: 16px;
}
.segmented input[type="radio"]:checked + label {
background-color: var(--chip-bg);
color: var(--text);
}
.segmented input[type="radio"]:focus-visible + label,
.segmented input[type="radio"]:focus + label {
outline: 2px solid #007bff;
outline-offset: 2px;
border-radius: 999px;
}
/* Transcript area */
#linesTranscript {
margin: 20px auto;
max-width: 700px;
text-align: left;
font-size: 16px;
}
#linesTranscript p {
margin: 0px 0;
}
#linesTranscript strong {
color: var(--text);
}
#speaker {
border: 1px solid var(--border);
border-radius: 100px;
padding: 2px 10px;
font-size: 14px;
margin-bottom: 0px;
}
.label_diarization {
background-color: var(--chip-bg);
border-radius: 8px 8px 8px 8px;
padding: 2px 10px;
margin-left: 10px;
display: inline-block;
white-space: nowrap;
font-size: 14px;
margin-bottom: 0px;
color: var(--label-dia-text);
}
.label_transcription {
background-color: var(--chip-bg);
border-radius: 8px 8px 8px 8px;
padding: 2px 10px;
display: inline-block;
white-space: nowrap;
margin-left: 10px;
font-size: 14px;
margin-bottom: 0px;
color: var(--label-trans-text);
}
#timeInfo {
color: var(--muted);
margin-left: 10px;
}
.textcontent {
font-size: 16px;
padding-left: 10px;
margin-bottom: 10px;
margin-top: 1px;
padding-top: 5px;
border-radius: 0px 0px 0px 10px;
}
.buffer_diarization {
color: var(--label-dia-text);
margin-left: 4px;
}
.buffer_transcription {
color: #7474748c;
margin-left: 4px;
}
.spinner {
display: inline-block;
width: 8px;
height: 8px;
border: 2px solid var(--spinner-border);
border-top: 2px solid var(--spinner-top);
border-radius: 50%;
animation: spin 0.7s linear infinite;
vertical-align: middle;
margin-bottom: 2px;
margin-right: 5px;
}
@keyframes spin {
to {
transform: rotate(360deg);
}
}
.silence {
color: var(--muted);
background-color: var(--silence-bg);
font-size: 13px;
border-radius: 30px;
padding: 2px 10px;
}
.loading {
color: var(--muted);
background-color: var(--loading-bg);
border-radius: 8px 8px 8px 0px;
padding: 2px 10px;
font-size: 14px;
margin-bottom: 0px;
}
/* for smaller screens */
/* @media (max-width: 450px) {
.settings-container {
flex-direction: column;
gap: 10px;
align-items: center;
}
.settings {
justify-content: center;
gap: 8px;
width: 100%;
}
.field {
align-items: center;
width: 100%;
}
#websocketInput,
#microphoneSelect {
min-width: 200px;
max-width: 100%;
}
.theme-selector-container {
margin-top: 10px;
}
} */
/* @media (max-width: 768px) and (min-width: 451px) {
.settings-container {
gap: 10px;
}
.settings {
gap: 8px;
}
#websocketInput,
#microphoneSelect {
min-width: 150px;
max-width: 300px;
}
} */
/* @media (max-width: 480px) {
body {
margin: 10px;
}
.settings-toggle {
width: 35px;
height: 35px;
}
.settings-toggle img {
width: 20px;
height: 20px;
}
.settings {
flex-direction: column;
align-items: center;
gap: 6px;
}
#websocketInput,
#microphoneSelect {
max-width: 400px;
}
.segmented label {
padding: 4px 8px;
font-size: 12px;
}
.segmented img {
width: 14px;
height: 14px;
}
} */
html
{
width: 400px; /* max: 800px */
height: 600px; /* max: 600px */
border-radius: 10px;
}

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-120q-151 0-255.5-104.5T120-480q0-138 90-239.5T440-838q13-2 23 3.5t16 14.5q6 9 6.5 21t-7.5 23q-17 26-25.5 55t-8.5 61q0 90 63 153t153 63q31 0 61.5-9t54.5-25q11-7 22.5-6.5T819-479q10 5 15.5 15t3.5 24q-14 138-117.5 229T480-120Zm0-80q88 0 158-48.5T740-375q-20 5-40 8t-40 3q-123 0-209.5-86.5T364-660q0-20 3-40t8-40q-78 32-126.5 102T200-480q0 116 82 198t198 82Zm-10-270Z"/></svg>

Before

Width:  |  Height:  |  Size: 493 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-360q50 0 85-35t35-85q0-50-35-85t-85-35q-50 0-85 35t-35 85q0 50 35 85t85 35Zm0 80q-83 0-141.5-58.5T280-480q0-83 58.5-141.5T480-680q83 0 141.5 58.5T680-480q0 83-58.5 141.5T480-280ZM80-440q-17 0-28.5-11.5T40-480q0-17 11.5-28.5T80-520h80q17 0 28.5 11.5T200-480q0 17-11.5 28.5T160-440H80Zm720 0q-17 0-28.5-11.5T760-480q0-17 11.5-28.5T800-520h80q17 0 28.5 11.5T920-480q0 17-11.5 28.5T880-440h-80ZM480-760q-17 0-28.5-11.5T440-800v-80q0-17 11.5-28.5T480-920q17 0 28.5 11.5T520-880v80q0 17-11.5 28.5T480-760Zm0 720q-17 0-28.5-11.5T440-80v-80q0-17 11.5-28.5T480-200q17 0 28.5 11.5T520-160v80q0 17-11.5 28.5T480-40ZM226-678l-43-42q-12-11-11.5-28t11.5-29q12-12 29-12t28 12l42 43q11 12 11 28t-11 28q-11 12-27.5 11.5T226-678Zm494 495-42-43q-11-12-11-28.5t11-27.5q11-12 27.5-11.5T734-282l43 42q12 11 11.5 28T777-183q-12 12-29 12t-28-12Zm-42-495q-12-11-11.5-27.5T678-734l42-43q11-12 28-11.5t29 11.5q12 12 12 29t-12 28l-43 42q-12 11-28 11t-28-11ZM183-183q-12-12-12-29t12-28l43-42q12-11 28.5-11t27.5 11q12 11 11.5 27.5T282-226l-42 43q-11 12-28 11.5T183-183Zm297-297Z"/></svg>

Before

Width:  |  Height:  |  Size: 1.2 KiB

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M433-80q-27 0-46.5-18T363-142l-9-66q-13-5-24.5-12T307-235l-62 26q-25 11-50 2t-39-32l-47-82q-14-23-8-49t27-43l53-40q-1-7-1-13.5v-27q0-6.5 1-13.5l-53-40q-21-17-27-43t8-49l47-82q14-23 39-32t50 2l62 26q11-8 23-15t24-12l9-66q4-26 23.5-44t46.5-18h94q27 0 46.5 18t23.5 44l9 66q13 5 24.5 12t22.5 15l62-26q25-11 50-2t39 32l47 82q14 23 8 49t-27 43l-53 40q1 7 1 13.5v27q0 6.5-2 13.5l53 40q21 17 27 43t-8 49l-48 82q-14 23-39 32t-50-2l-60-26q-11 8-23 15t-24 12l-9 66q-4 26-23.5 44T527-80h-94Zm7-80h79l14-106q31-8 57.5-23.5T639-327l99 41 39-68-86-65q5-14 7-29.5t2-31.5q0-16-2-31.5t-7-29.5l86-65-39-68-99 42q-22-23-48.5-38.5T533-694l-13-106h-79l-14 106q-31 8-57.5 23.5T321-633l-99-41-39 68 86 64q-5 15-7 30t-2 32q0 16 2 31t7 30l-86 65 39 68 99-42q22 23 48.5 38.5T427-266l13 106Zm42-180q58 0 99-41t41-99q0-58-41-99t-99-41q-59 0-99.5 41T342-480q0 58 40.5 99t99.5 41Zm-2-140Z"/></svg>

Before

Width:  |  Height:  |  Size: 982 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M396-396q-32-32-58.5-67T289-537q-5 14-6.5 28.5T281-480q0 83 58 141t141 58q14 0 28.5-2t28.5-6q-39-22-74-48.5T396-396Zm85 196q-56 0-107-21t-91-61q-40-40-61-91t-21-107q0-51 17-97.5t50-84.5q13-14 32-9.5t27 24.5q21 55 52.5 104t73.5 91q42 42 91 73.5T648-326q20 8 24.5 27t-9.5 32q-38 33-84.5 50T481-200Zm223-192q-16-5-23-20.5t-4-32.5q9-48-6-94.5T621-621q-35-35-80.5-49.5T448-677q-17 3-32-4t-21-23q-6-16 1.5-31t23.5-19q69-15 138 4.5T679-678q51 51 71 120t5 138q-4 17-19 25t-32 3ZM480-840q-17 0-28.5-11.5T440-880v-40q0-17 11.5-28.5T480-960q17 0 28.5 11.5T520-920v40q0 17-11.5 28.5T480-840Zm0 840q-17 0-28.5-11.5T440-40v-40q0-17 11.5-28.5T480-120q17 0 28.5 11.5T520-80v40q0 17-11.5 28.5T480 0Zm255-734q-12-12-12-28.5t12-28.5l28-28q11-11 27.5-11t28.5 11q12 12 12 28.5T819-762l-28 28q-12 12-28 12t-28-12ZM141-141q-12-12-12-28.5t12-28.5l28-28q12-12 28-12t28 12q12 12 12 28.5T225-169l-28 28q-11 11-27.5 11T141-141Zm739-299q-17 0-28.5-11.5T840-480q0-17 11.5-28.5T880-520h40q17 0 28.5 11.5T960-480q0 17-11.5 28.5T920-440h-40Zm-840 0q-17 0-28.5-11.5T0-480q0-17 11.5-28.5T40-520h40q17 0 28.5 11.5T120-480q0 17-11.5 28.5T80-440H40Zm779 299q-12 12-28.5 12T762-141l-28-28q-12-12-12-28t12-28q12-12 28.5-12t28.5 12l28 28q11 11 11 27.5T819-141ZM226-735q-12 12-28.5 12T169-735l-28-28q-11-11-11-27.5t11-28.5q12-12 28.5-12t28.5 12l28 28q12 12 12 28t-12 28Zm170 339Z"/></svg>

Before

Width:  |  Height:  |  Size: 1.4 KiB

View File

@@ -1,12 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<title>Welcome</title>
<script src="welcome.js"></script>
</head>
<body>
This page exists to workaround an issue with Chrome that blocks permission
requests from chrome extensions
<!-- <button id="requestMicrophone">Request Microphone</button> -->
</body>
</html>

264
docs/API.md Normal file
View File

@@ -0,0 +1,264 @@
# WhisperLiveKit WebSocket API Documentation
> !! **Note**: The new API structure described in this document is currently under deployment.
This documentation is intended for devs who want to build custom frontends.
WLK provides real-time speech transcription, speaker diarization, and translation through a WebSocket API. The server sends incremental updates as audio is processed, allowing clients to display live transcription results with minimal latency.
---
## Legacy API (Current)
### Message Structure
The current API sends complete state snapshots on each update (several time per second)
```typescript
{
"type": str,
"status": str,
"lines": [
{
"speaker": int,
"text": str,
"start": float,
"end": float,
"translation": str | null,
"detected_language": str
}
],
"buffer_transcription": str,
"buffer_diarization": str,
"remaining_time_transcription": float,
"remaining_time_diarization": float
}
```
---
## New API (Under Development)
### Philosophy
Principles:
- **Incremental Updates**: Only updates and new segments are sent
- **Ephemeral Buffers**: Temporary, unvalidated data displayed in real-time but overwritten on next update, at speaker level
## Message Format
```typescript
{
"type": "transcript_update",
"status": "active_transcription" | "no_audio_detected",
"segments": [
{
"id": number,
"speaker": number,
"text": string,
"start_speaker": float,
"start": float,
"end": float,
"language": string | null,
"translation": string,
"words": [
{
"text": string,
"start": float,
"end": float,
"validated": {
"text": boolean,
"speaker": boolean,
}
}
],
"buffer": {
"transcription": string,
"diarization": string,
"translation": string
}
}
],
"metadata": {
"remaining_time_transcription": float,
"remaining_time_diarization": float
}
}
```
### Other Message Types
#### Config Message (sent on connection)
```json
{
"type": "config",
"useAudioWorklet": true / false
}
```
#### Ready to Stop Message (sent after processing complete)
```json
{
"type": "ready_to_stop"
}
```
---
## Field Descriptions
### Segment Fields
| Field | Type | Description |
|-------|------|-------------|
| `id` | `number` | Unique identifier for this segment. Used by clients to update specific segments efficiently. |
| `speaker` | `number` | Speaker ID (1, 2, 3...). Special value `-2` indicates silence. |
| `text` | `string` | Validated transcription text for this update. Should be **appended** to the segment's text on the client side. |
| `start_speaker` | `float` | Timestamp (seconds) when this speaker segment began. |
| `start` | `float` | Timestamp (seconds) of the first word in this update. |
| `end` | `float` | Timestamp (seconds) of the last word in this update. |
| `language` | `string \| null` | ISO language code (e.g., "en", "fr"). `null` until language is detected. |
| `translation` | `string` | Validated translation text for this update. Should be **appended** to the segment's translation on the client side. |
| `words` | `Array` | Array of word-level objects with timing and validation information. |
| `buffer` | `Object` | Per-segment temporary buffers, see below |
### Word Object
| Field | Type | Description |
|-------|------|-------------|
| `text` | `string` | The word text. |
| `start` | `number` | Start timestamp (seconds) of this word. |
| `end` | `number` | End timestamp (seconds) of this word. |
| `validated.text` | `boolean` | Whether the transcription text has been validated. if false, word is also in buffer: transcription |
| `validated.speaker` | `boolean` | Whether the speaker assignment has been validated. if false, word is also in buffer: diarization |
| `validated.language` | `boolean` | Whether the language detection has been validated. if false, word is also in buffer: translation |
### Buffer Object (Per-Segment)
Buffers are **ephemeral**. They should be displayed to the user but not stored permanently in the frontend. Each update may contain a completely different buffer value, and previous buffer is likely to be in the next validated text.
| Field | Type | Description |
|-------|------|-------------|
| `transcription` | `string` | Pending transcription text. Displayed immediately but **overwritten** on next update. |
| `diarization` | `string` | Pending diarization text (text waiting for speaker assignment). Displayed immediately but **overwritten** on next update. |
| `translation` | `string` | Pending translation text. Displayed immediately but **overwritten** on next update. |
### Metadata Fields
| Field | Type | Description |
|-------|------|-------------|
| `remaining_time_transcription` | `float` | Seconds of audio waiting for transcription processing. |
| `remaining_time_diarization` | `float` | Seconds of audio waiting for speaker diarization. |
### Status Values
| Status | Description |
|--------|-------------|
| `active_transcription` | Normal operation, transcription is active. |
| `no_audio_detected` | No audio has been detected yet. |
---
## Update Behavior
### Incremental Updates
The API sends **only changed or new segments**. Clients should:
1. Maintain a local map of segments by ID
2. When receiving an update, merge/update segments by ID
3. Render only the changed segments
### Language Detection
When language is detected for a segment:
```jsonc
// Update 1: No language yet
{
"segments": [
{"id": 1, "speaker": 1, "text": "May see", "language": null}
]
}
// Update 2: Same segment ID, language now detected
{
"segments": [
{"id": 1, "speaker": 1, "text": "Merci", "language": "fr"}
]
}
```
**Client behavior**: **Replace** the existing segment with the same ID.
### Buffer Behavior
Buffers are **per-segment** to handle multi-speaker scenarios correctly.
#### Example: Translation with diarization and translation
```jsonc
// Update 1
{
"segments": [
{
"id": 1,
"speaker": 1,
"text": "Hello world, how are",
"translation": "",
"buffer": {
"transcription": "",
"diarization": " you on",
"translation": "Bonjour le monde"
}
}
]
}
// ==== Frontend ====
// <SPEAKER>1</SPEAKER>
// <TRANSCRIPTION>Hello world, how are <DIARIZATION BUFFER> you on</DIARIZATION BUFFER></TRANSCRIPTION>
// <TRANSLATION><TRANSLATION BUFFER>Bonjour le monde</TRANSLATION BUFFER></TRANSLATION>
// Update 2
{
"segments": [
{
"id": 1,
"speaker": 1,
"text": " you on this",
"translation": "Bonjour tout le monde",
"buffer": {
"transcription": "",
"diarization": " beautiful day",
"translation": ",comment"
}
},
]
}
// ==== Frontend ====
// <SPEAKER>1</SPEAKER>
// <TRANSCRIPTION>Hello world, how are you on this<DIARIZATION BUFFER> beautiful day</DIARIZATION BUFFER></TRANSCRIPTION>
// <TRANSLATION>Bonjour tout le monde<TRANSLATION BUFFER>, comment</TRANSLATION BUFFER><TRANSLATION>
```
### Silence Segments
Silence is represented with the speaker id = `-2`:
```jsonc
{
"id": 5,
"speaker": -2,
"text": "",
"start": 10.5,
"end": 12.3
}
```

View File

@@ -0,0 +1,71 @@
### Alignment between STT Tokens and Diarization Segments
- Example 1: The punctuation from STT and the speaker change from Diariation come in the prediction `t`
- Example 2: The punctuation from STT comes from prediction `t`, but the speaker change from Diariation come in the prediction `t-1`
- Example 3: The punctuation from STT comes from prediction `t-1`, but the speaker change from Diariation come in the prediction `t`
> `#` Is the split between the `t-1` prediction and `t` prediction.
## Example 1:
```text
punctuations_segments : __#_______.__________________!____
diarization_segments:
SPK1 __#____________
SPK2 # ___________________
-->
ALIGNED SPK1 __#_______.
ALIGNED SPK2 # __________________!____
t-1 output:
SPK1: __#
SPK2: NO
DIARIZATION BUFFER: NO
t output:
SPK1: __#__.
SPK2: __________________!____
DIARIZATION BUFFER: No
```
## Example 2:
```text
punctuations_segments : _____#__.___________
diarization_segments:
SPK1 ___ #
SPK2 __#______________
-->
ALIGNED SPK1 _____#__.
ALIGNED SPK2 # ___________
t-1 output:
SPK1: ___ #
SPK2:
DIARIZATION BUFFER: __#
t output:
SPK1: __#__.
SPK2: ___________
DIARIZATION BUFFER: No
```
## Example 3:
```text
punctuations_segments : ___.__#__________
diarization_segments:
SPK1 ______#__
SPK2 # ________
-->
ALIGNED SPK1 ___. #
ALIGNED SPK2 __#__________
t-1 output:
SPK1: ___. #
SPK2:
DIARIZATION BUFFER: __#
t output:
SPK1: #
SPK2: __#___________
DIARIZATION BUFFER: NO
```

View File

@@ -0,0 +1,19 @@
# Model Path Formats
The `--model-path` parameter accepts:
## File Path
- **`.pt` / `.bin` / `.safetensor` formats** Should be openable by pytorch/safetensor.
## Directory Path (recommended)
Must contain:
- **`.pt` / `.bin` / `.safetensor` file** (required for decoder)
May optionally contain:
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
## Hugging Face Repo ID
- Provide the repo ID (e.g. `openai/whisper-large-v3`) and WhisperLiveKit will download and cache the snapshot automatically. For gated repos, authenticate via `huggingface-cli login` first.
To improve speed/reduce allucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` to pass them to WLK. If not, alignement heads are set to be all the heads of the last half layer of decoder.

265
docs/supported_languages.md Normal file
View File

@@ -0,0 +1,265 @@
# Supported Languages
WhisperLiveKit supports translation into **201 languages** from the FLORES-200 dataset through the NLLB (No Language Left Behind) translation system.
## How to Specify Languages
You can specify languages in **three different ways**:
1. **Language Name** (case-insensitive): `"English"`, `"French"`, `"Spanish"`
2. **ISO Language Code**: `"en"`, `"fr"`, `"es"`
3. **NLLB Code** (FLORES-200): `"eng_Latn"`, `"fra_Latn"`, `"spa_Latn"`
## Usage Examples
### Command Line
```bash
# Using language name
whisperlivekit-server --target-language "French"
# Using ISO code
whisperlivekit-server --target-language fr
# Using NLLB code
whisperlivekit-server --target-language fra_Latn
```
### Python API
```python
from nllw.translation import get_language_info
# Get language information by name
lang_info = get_language_info("French")
print(lang_info)
# {'name': 'French', 'nllb': 'fra_Latn', 'language_code': 'fr'}
# Get language information by ISO code
lang_info = get_language_info("fr")
# Get language information by NLLB code
lang_info = get_language_info("fra_Latn")
# All three return the same result
```
## Complete Language List
The following table lists all 201 supported languages with their corresponding codes:
| Language Name | ISO Code | NLLB Code |
|---------------|----------|-----------|
| Acehnese (Arabic script) | ace_Arab | ace_Arab |
| Acehnese (Latin script) | ace_Latn | ace_Latn |
| Mesopotamian Arabic | acm_Arab | acm_Arab |
| Ta'izzi-Adeni Arabic | acq_Arab | acq_Arab |
| Tunisian Arabic | aeb_Arab | aeb_Arab |
| Afrikaans | af | afr_Latn |
| South Levantine Arabic | ajp_Arab | ajp_Arab |
| Akan | ak | aka_Latn |
| Tosk Albanian | als | als_Latn |
| Amharic | am | amh_Ethi |
| North Levantine Arabic | apc_Arab | apc_Arab |
| Modern Standard Arabic | ar | arb_Arab |
| Modern Standard Arabic (Romanized) | arb_Latn | arb_Latn |
| Najdi Arabic | ars_Arab | ars_Arab |
| Moroccan Arabic | ary_Arab | ary_Arab |
| Egyptian Arabic | arz_Arab | arz_Arab |
| Assamese | as | asm_Beng |
| Asturian | ast | ast_Latn |
| Awadhi | awa | awa_Deva |
| Central Aymara | ay | ayr_Latn |
| South Azerbaijani | azb | azb_Arab |
| North Azerbaijani | az | azj_Latn |
| Bashkir | ba | bak_Cyrl |
| Bambara | bm | bam_Latn |
| Balinese | ban | ban_Latn |
| Belarusian | be | bel_Cyrl |
| Bemba | bem | bem_Latn |
| Bengali | bn | ben_Beng |
| Bhojpuri | bho | bho_Deva |
| Banjar (Arabic script) | bjn_Arab | bjn_Arab |
| Banjar (Latin script) | bjn_Latn | bjn_Latn |
| Standard Tibetan | bo | bod_Tibt |
| Bosnian | bs | bos_Latn |
| Buginese | bug | bug_Latn |
| Bulgarian | bg | bul_Cyrl |
| Catalan | ca | cat_Latn |
| Cebuano | ceb | ceb_Latn |
| Czech | cs | ces_Latn |
| Chokwe | cjk | cjk_Latn |
| Central Kurdish | ckb | ckb_Arab |
| Crimean Tatar | crh | crh_Latn |
| Welsh | cy | cym_Latn |
| Danish | da | dan_Latn |
| German | de | deu_Latn |
| Southwestern Dinka | dik | dik_Latn |
| Dyula | dyu | dyu_Latn |
| Dzongkha | dz | dzo_Tibt |
| Greek | el | ell_Grek |
| English | en | eng_Latn |
| Esperanto | eo | epo_Latn |
| Estonian | et | est_Latn |
| Basque | eu | eus_Latn |
| Ewe | ee | ewe_Latn |
| Faroese | fo | fao_Latn |
| Fijian | fj | fij_Latn |
| Finnish | fi | fin_Latn |
| Fon | fon | fon_Latn |
| French | fr | fra_Latn |
| Friulian | fur-IT | fur_Latn |
| Nigerian Fulfulde | fuv | fuv_Latn |
| West Central Oromo | om | gaz_Latn |
| Scottish Gaelic | gd | gla_Latn |
| Irish | ga-IE | gle_Latn |
| Galician | gl | glg_Latn |
| Guarani | gn | grn_Latn |
| Gujarati | gu-IN | guj_Gujr |
| Haitian Creole | ht | hat_Latn |
| Hausa | ha | hau_Latn |
| Hebrew | he | heb_Hebr |
| Hindi | hi | hin_Deva |
| Chhattisgarhi | hne | hne_Deva |
| Croatian | hr | hrv_Latn |
| Hungarian | hu | hun_Latn |
| Armenian | hy-AM | hye_Armn |
| Igbo | ig | ibo_Latn |
| Ilocano | ilo | ilo_Latn |
| Indonesian | id | ind_Latn |
| Icelandic | is | isl_Latn |
| Italian | it | ita_Latn |
| Javanese | jv | jav_Latn |
| Japanese | ja | jpn_Jpan |
| Kabyle | kab | kab_Latn |
| Jingpho | kac | kac_Latn |
| Kamba | kam | kam_Latn |
| Kannada | kn | kan_Knda |
| Kashmiri (Arabic script) | kas_Arab | kas_Arab |
| Kashmiri (Devanagari script) | kas_Deva | kas_Deva |
| Georgian | ka | kat_Geor |
| Kazakh | kk | kaz_Cyrl |
| Kabiyè | kbp | kbp_Latn |
| Kabuverdianu | kea | kea_Latn |
| Halh Mongolian | mn | khk_Cyrl |
| Khmer | km | khm_Khmr |
| Kikuyu | ki | kik_Latn |
| Kinyarwanda | rw | kin_Latn |
| Kyrgyz | ky | kir_Cyrl |
| Kimbundu | kmb | kmb_Latn |
| Northern Kurdish | kmr | kmr_Latn |
| Central Kanuri (Arabic script) | knc_Arab | knc_Arab |
| Central Kanuri (Latin script) | knc_Latn | knc_Latn |
| Kikongo | kg | kon_Latn |
| Korean | ko | kor_Hang |
| Lao | lo | lao_Laoo |
| Ligurian | lij | lij_Latn |
| Limburgish | li | lim_Latn |
| Lingala | ln | lin_Latn |
| Lithuanian | lt | lit_Latn |
| Lombard | lmo | lmo_Latn |
| Latgalian | ltg | ltg_Latn |
| Luxembourgish | lb | ltz_Latn |
| Luba-Kasai | lua | lua_Latn |
| Ganda | lg | lug_Latn |
| Luo | luo | luo_Latn |
| Mizo | lus | lus_Latn |
| Standard Latvian | lv | lvs_Latn |
| Magahi | mag | mag_Deva |
| Maithili | mai | mai_Deva |
| Malayalam | ml-IN | mal_Mlym |
| Marathi | mr | mar_Deva |
| Minangkabau (Arabic script) | min_Arab | min_Arab |
| Minangkabau (Latin script) | min_Latn | min_Latn |
| Macedonian | mk | mkd_Cyrl |
| Maltese | mt | mlt_Latn |
| Meitei (Bengali script) | mni | mni_Beng |
| Mossi | mos | mos_Latn |
| Maori | mi | mri_Latn |
| Burmese | my | mya_Mymr |
| Dutch | nl | nld_Latn |
| Norwegian Nynorsk | nn-NO | nno_Latn |
| Norwegian Bokmål | nb | nob_Latn |
| Nepali | ne-NP | npi_Deva |
| Northern Sotho | nso | nso_Latn |
| Nuer | nus | nus_Latn |
| Nyanja | ny | nya_Latn |
| Occitan | oc | oci_Latn |
| Odia | or | ory_Orya |
| Pangasinan | pag | pag_Latn |
| Eastern Panjabi | pa | pan_Guru |
| Papiamento | pap | pap_Latn |
| Southern Pashto | pbt | pbt_Arab |
| Western Persian | fa | pes_Arab |
| Plateau Malagasy | mg | plt_Latn |
| Polish | pl | pol_Latn |
| Portuguese | pt-PT | por_Latn |
| Dari | fa-AF | prs_Arab |
| Ayacucho Quechua | qu | quy_Latn |
| Romanian | ro | ron_Latn |
| Rundi | rn | run_Latn |
| Russian | ru | rus_Cyrl |
| Sango | sg | sag_Latn |
| Sanskrit | sa | san_Deva |
| Santali | sat | sat_Olck |
| Sicilian | scn | scn_Latn |
| Shan | shn | shn_Mymr |
| Sinhala | si-LK | sin_Sinh |
| Slovak | sk | slk_Latn |
| Slovenian | sl | slv_Latn |
| Samoan | sm | smo_Latn |
| Shona | sn | sna_Latn |
| Sindhi | sd | snd_Arab |
| Somali | so | som_Latn |
| Southern Sotho | st | sot_Latn |
| Spanish | es-ES | spa_Latn |
| Sardinian | sc | srd_Latn |
| Serbian | sr | srp_Cyrl |
| Swati | ss | ssw_Latn |
| Sundanese | su | sun_Latn |
| Swedish | sv-SE | swe_Latn |
| Swahili | sw | swh_Latn |
| Silesian | szl | szl_Latn |
| Tamil | ta | tam_Taml |
| Tamasheq (Latin script) | taq_Latn | taq_Latn |
| Tamasheq (Tifinagh script) | taq_Tfng | taq_Tfng |
| Tatar | tt-RU | tat_Cyrl |
| Telugu | te | tel_Telu |
| Tajik | tg | tgk_Cyrl |
| Tagalog | tl | tgl_Latn |
| Thai | th | tha_Thai |
| Tigrinya | ti | tir_Ethi |
| Tok Pisin | tpi | tpi_Latn |
| Tswana | tn | tsn_Latn |
| Tsonga | ts | tso_Latn |
| Turkmen | tk | tuk_Latn |
| Tumbuka | tum | tum_Latn |
| Turkish | tr | tur_Latn |
| Twi | tw | twi_Latn |
| Central Atlas Tamazight | tzm | tzm_Tfng |
| Uyghur | ug | uig_Arab |
| Ukrainian | uk | ukr_Cyrl |
| Umbundu | umb | umb_Latn |
| Urdu | ur | urd_Arab |
| Northern Uzbek | uz | uzn_Latn |
| Venetian | vec | vec_Latn |
| Vietnamese | vi | vie_Latn |
| Waray | war | war_Latn |
| Wolof | wo | wol_Latn |
| Xhosa | xh | xho_Latn |
| Eastern Yiddish | yi | ydd_Hebr |
| Yoruba | yo | yor_Latn |
| Yue Chinese | yue | yue_Hant |
| Chinese (Simplified) | zh-CN | zho_Hans |
| Chinese (Traditional) | zh-TW | zho_Hant |
| Standard Malay | ms | zsm_Latn |
| Zulu | zu | zul_Latn |
## Special Features
### Multiple Script Support
Several languages are available in multiple scripts (e.g., Arabic and Latin):
- **Acehnese**: Arabic (`ace_Arab`) and Latin (`ace_Latn`)
- **Banjar**: Arabic (`bjn_Arab`) and Latin (`bjn_Latn`)
- **Kashmiri**: Arabic (`kas_Arab`) and Devanagari (`kas_Deva`)
- **Minangkabau**: Arabic (`min_Arab`) and Latin (`min_Latn`)
- **Tamasheq**: Latin (`taq_Latn`) and Tifinagh (`taq_Tfng`)
- **Central Kanuri**: Arabic (`knc_Arab`) and Latin (`knc_Latn`)

View File

@@ -0,0 +1,43 @@
# Technical Integration Guide
This document introduce how to reuse the core components when you do **not** want to ship the bundled frontend, FastAPI server, or even the provided CLI.
---
## 1. Runtime Components
| Layer | File(s) | Purpose |
|-------|---------|---------|
| Transport | `whisperlivekit/basic_server.py`, any ASGI/WebSocket server | Accepts audio over WebSocket (MediaRecorder WebM or raw PCM chunks) and streams JSON updates back |
| Audio processing | `whisperlivekit/audio_processor.py` | Buffers audio, orchestrates transcription, diarization, translation, handles FFmpeg/PCM input |
| Engines | `whisperlivekit/core.py`, `whisperlivekit/simul_whisper/*`, `whisperlivekit/local_agreement/*` | Load models once (SimulStreaming or LocalAgreement), expose `TranscriptionEngine` and helpers |
| Frontends | `whisperlivekit/web/*`, `chrome-extension/*` | Optional UI layers feeding the WebSocket endpoint |
**Key idea:** The server boundary is just `AudioProcessor.process_audio()` for incoming bytes and the async generator returned by `AudioProcessor.create_tasks()` for outgoing updates (`FrontData`). Everything else is optional.
---
## 2. Running Without the Bundled Frontend
1. Start the server/engine however you like:
```bash
wlk --model small --language en --host 0.0.0.0 --port 9000
# or launch your own app that instantiates TranscriptionEngine(...)
```
2. Build your own client (browser, mobile, desktop) that:
- Opens `ws(s)://<host>:<port>/asr`
- Sends either MediaRecorder/Opus WebM blobs **or** raw PCM (`--pcm-input` on the server tells the client to use the AudioWorklet).
- Consumes the JSON payload defined in `docs/API.md`.
---
## 3. Running Without FastAPI
`whisperlivekit/basic_server.py` is just an example. Any async framework works, as long as you:
1. Create a global `TranscriptionEngine` (expensive to initialize; reuse it).
2. Instantiate `AudioProcessor(transcription_engine=engine)` for each connection.
3. Call `create_tasks()` to get the async generator, `process_audio()` with incoming bytes, and ensure `cleanup()` runs when the client disconnects.
If you prefer to send compressed audio, instantiate `AudioProcessor(pcm_input=False)` and pipe encoded chunks through `FFmpegManager` transparently—just ensure `ffmpeg` is available or be ready to handle the `"ffmpeg_not_found"` error in the streamed `FrontData`.

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "whisperlivekit" name = "whisperlivekit"
version = "0.2.10" version = "0.2.15"
description = "Real-time speech-to-text with speaker diarization using Whisper" description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md" readme = "README.md"
authors = [ authors = [
@@ -30,28 +30,41 @@ dependencies = [
"fastapi", "fastapi",
"librosa", "librosa",
"soundfile", "soundfile",
"faster-whisper",
"uvicorn", "uvicorn",
"websockets", "websockets",
"torchaudio>=2.0.0", "torchaudio>=2.0.0",
"torch>=2.0.0", "torch>=2.0.0",
"huggingface-hub>=0.25.0",
"tqdm", "tqdm",
"tiktoken", "tiktoken",
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")' 'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
] ]
[project.optional-dependencies] [project.optional-dependencies]
sentence = ["mosestokenizer", "wtpsplit"] translation = ["nllw"]
sentence_tokenizer = ["mosestokenizer", "wtpsplit"]
[project.urls] [project.urls]
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit" Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
[project.scripts] [project.scripts]
whisperlivekit-server = "whisperlivekit.basic_server:main" whisperlivekit-server = "whisperlivekit.basic_server:main"
wlk = "whisperlivekit.basic_server:main"
[tool.setuptools] [tool.setuptools]
packages = ["whisperlivekit", "whisperlivekit.diarization", "whisperlivekit.simul_whisper", "whisperlivekit.simul_whisper.whisper", "whisperlivekit.simul_whisper.whisper.assets", "whisperlivekit.simul_whisper.whisper.normalizers", "whisperlivekit.web", "whisperlivekit.whisper_streaming_custom"] packages = [
"whisperlivekit",
"whisperlivekit.diarization",
"whisperlivekit.simul_whisper",
"whisperlivekit.whisper",
"whisperlivekit.whisper.assets",
"whisperlivekit.whisper.normalizers",
"whisperlivekit.web",
"whisperlivekit.local_agreement",
"whisperlivekit.vad_models"
]
[tool.setuptools.package-data] [tool.setuptools.package-data]
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"] whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
"whisperlivekit.simul_whisper.whisper.assets" = ["*.tiktoken", "*.npz"] "whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
"whisperlivekit.vad_models" = ["*.jit", "*.onnx"]

BIN
scripts/alignment_heads.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 276 KiB

View File

@@ -0,0 +1,153 @@
#!/usr/bin/env python3
"""
Convert a Hugging Face style Whisper checkpoint into a WhisperLiveKit .pt file.
Optionally shrink the supported audio chunk length (in seconds) by trimming the
encoder positional embeddings and updating the stored model dimensions.
"""
import argparse
import json
import os
from pathlib import Path
from typing import Dict, Tuple
import torch
from whisperlivekit.whisper.audio import HOP_LENGTH, SAMPLE_RATE
from whisperlivekit.whisper.model import ModelDimensions
from whisperlivekit.whisper.utils import exact_div
from whisperlivekit.whisper import _convert_hf_state_dict
def _load_state_dict(repo_path: Path) -> Dict[str, torch.Tensor]:
safetensor_path = repo_path / "model.safetensors"
bin_path = repo_path / "pytorch_model.bin"
if safetensor_path.is_file():
try:
from safetensors.torch import load_file # type: ignore
except Exception as exc: # pragma: no cover - import guard
raise RuntimeError(
"Install safetensors to load model.safetensors "
"(pip install safetensors)"
) from exc
return load_file(str(safetensor_path))
if bin_path.is_file():
return torch.load(bin_path, map_location="cpu")
raise FileNotFoundError(
f"Could not find model.safetensors or pytorch_model.bin under {repo_path}"
)
def _load_config(repo_path: Path) -> Dict:
config_path = repo_path / "config.json"
if not config_path.is_file():
raise FileNotFoundError(
f"Hugging Face checkpoint at {repo_path} is missing config.json"
)
with open(config_path, "r", encoding="utf-8") as fp:
return json.load(fp)
def _derive_audio_ctx(chunk_length: float) -> Tuple[int, int]:
n_samples = int(round(chunk_length * SAMPLE_RATE))
expected_samples = chunk_length * SAMPLE_RATE
if abs(n_samples - expected_samples) > 1e-6:
raise ValueError(
"chunk_length must align with sample rate so that "
"chunk_length * SAMPLE_RATE is an integer"
)
n_frames = exact_div(n_samples, HOP_LENGTH)
n_audio_ctx = exact_div(n_frames, 2)
return n_frames, n_audio_ctx
def _build_dims(config: Dict, chunk_length: float) -> Dict:
base_dims = ModelDimensions(
n_mels=config["num_mel_bins"],
n_audio_ctx=config["max_source_positions"],
n_audio_state=config["d_model"],
n_audio_head=config["encoder_attention_heads"],
n_audio_layer=config.get("encoder_layers") or config["num_hidden_layers"],
n_vocab=config["vocab_size"],
n_text_ctx=config["max_target_positions"],
n_text_state=config["d_model"],
n_text_head=config["decoder_attention_heads"],
n_text_layer=config["decoder_layers"],
).__dict__.copy()
_, n_audio_ctx = _derive_audio_ctx(chunk_length)
base_dims["n_audio_ctx"] = n_audio_ctx
base_dims["chunk_length"] = chunk_length
return base_dims
def _trim_positional_embedding(
state_dict: Dict[str, torch.Tensor], target_ctx: int
) -> None:
key = "encoder.positional_embedding"
if key not in state_dict:
raise KeyError(f"{key} missing from converted state dict")
tensor = state_dict[key]
if tensor.shape[0] < target_ctx:
raise ValueError(
f"Cannot increase encoder ctx from {tensor.shape[0]} to {target_ctx}"
)
if tensor.shape[0] == target_ctx:
return
state_dict[key] = tensor[:target_ctx].contiguous()
def convert_checkpoint(hf_path: Path, output_path: Path, chunk_length: float) -> None:
state_dict = _load_state_dict(hf_path)
converted = _convert_hf_state_dict(state_dict)
config = _load_config(hf_path)
dims = _build_dims(config, chunk_length)
_trim_positional_embedding(converted, dims["n_audio_ctx"])
package = {"dims": dims, "model_state_dict": converted}
output_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(package, output_path)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Convert Hugging Face Whisper checkpoint to WhisperLiveKit format."
)
parser.add_argument(
"hf_path",
type=str,
help="Path to the cloned Hugging Face repository (e.g. whisper-tiny.en)",
)
parser.add_argument(
"--output",
type=str,
default="converted-whisper.pt",
help="Destination path for the .pt file",
)
parser.add_argument(
"--chunk-length",
type=float,
default=30.0,
help="Audio chunk length in seconds to support (default: 30)",
)
return parser.parse_args()
def main():
args = parse_args()
hf_path = Path(os.path.expanduser(args.hf_path)).resolve()
output_path = Path(os.path.expanduser(args.output)).resolve()
convert_checkpoint(hf_path, output_path, args.chunk_length)
print(f"Saved converted checkpoint to {output_path}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,292 @@
"""Determine alignment heads for a variants, such as distilled model"""
from __future__ import annotations
import argparse
import base64
import gzip
import io
import pathlib
import sys
import math
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from datasets import Audio as DatasetAudio, load_dataset
import soundfile as sf
import matplotlib.pyplot as plt
REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
WHISPER_ROOT = REPO_ROOT / "whisper"
sys.path.insert(0, str(REPO_ROOT))
sys.path.insert(0, str(WHISPER_ROOT))
from whisper import load_model
from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
from whisper.tokenizer import get_tokenizer
AudioInput = Union[str, pathlib.Path, np.ndarray, torch.Tensor]
def load_dataset_clips(name, config, split, limit):
ds = load_dataset(name, config, split=split)
ds = ds.cast_column("audio", DatasetAudio(decode=False))
clips = []
for idx, row in enumerate(ds):
if limit is not None and idx >= limit:
break
audio_field = row["audio"]
transcript = row["text"]
waveform_np, _ = sf.read(io.BytesIO(audio_field["bytes"]), dtype="float32")
if waveform_np.ndim > 1:
waveform_np = waveform_np.mean(axis=1)
waveform = waveform_np
transcript = str(transcript)
clips.append((waveform, transcript))
return clips
def load_clips(args):
return load_dataset_clips(
args.dataset,
args.dataset_config,
args.dataset_split,
args.dataset_num_samples,
)
def _waveform_from_source(source: AudioInput) -> torch.Tensor:
waveform = torch.from_numpy(source.astype(np.float32, copy=False))
return waveform
def _parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
type=str,
default="pytorch_model.bin",
)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Torch device to run on",
)
parser.add_argument(
"--dataset",
type=str,
default="librispeech_asr"
)
parser.add_argument(
"--dataset-config",
type=str,
default="clean"
)
parser.add_argument(
"--dataset-split",
type=str,
default="validation[:1%]",
)
parser.add_argument(
"--dataset-num-samples",
type=int,
default=16,
)
parser.add_argument(
"--threshold",
type=float,
default=1.5,
help="Z score threshold for a head to be selected",
)
parser.add_argument(
"--votes",
type=float,
default=0.75,
help="percentage of clips that must vote for a head",
)
parser.add_argument(
"--output",
type=str,
default="alignment_heads.b85",
)
parser.add_argument(
"--visualize-top-k",
type=int,
default=32,
)
return parser.parse_args()
def collect_heads(
model,
tokenizer,
clips: Sequence[Tuple[AudioInput, str]],
threshold: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
device = model.device
votes = torch.zeros(model.dims.n_text_layer, model.dims.n_text_head, device=device)
strengths = torch.zeros_like(votes)
for audio_source, transcript in clips:
waveform = pad_or_trim(_waveform_from_source(audio_source))
mel = log_mel_spectrogram(waveform, device=device)
tokens = torch.tensor(
[
*tokenizer.sot_sequence,
tokenizer.no_timestamps,
*tokenizer.encode(transcript),
tokenizer.eot,
],
device=device,
)
qks = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, __, outputs, index=i: qks.__setitem__(index, outputs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]
with torch.no_grad():
model(mel.unsqueeze(0), tokens.unsqueeze(0))
for hook in hooks:
hook.remove()
for layer_idx, tensor in enumerate(qks):
if tensor is None:
continue
tensor = tensor[:, :, : mel.shape[-1] // 2]
tensor = tensor.softmax(dim=-1)
peak = tensor.max(dim=-1).values # [heads, tokens]
strengths[layer_idx] += peak.mean(dim=-1)
zscore = (peak - peak.mean(dim=-1, keepdim=True)) / (
peak.std(dim=-1, keepdim=True, unbiased=False) + 1e-6
)
mask = (zscore > 3).any(dim=-1)
votes[layer_idx] += mask.float()
votes /= len(clips)
strengths /= len(clips)
return votes, strengths
def _select_heads_for_visualization(selection, strengths, top_k):
selected = torch.nonzero(selection, as_tuple=False)
if selected.numel() == 0:
return []
entries = [
(int(layer.item()), int(head.item()), float(strengths[layer, head].item()))
for layer, head in selected
]
entries.sort(key=lambda item: item[2], reverse=True)
return entries[:top_k]
def _extract_heatmaps(
model,
tokenizer,
clip: Tuple[AudioInput, str],
heads: Sequence[Tuple[int, int, float]],
) -> dict:
if not heads:
return {}
target_map = {}
for layer, head, _ in heads:
target_map.setdefault(layer, set()).add(head)
waveform = pad_or_trim(_waveform_from_source(clip[0]))
mel = log_mel_spectrogram(waveform, device=model.device)
transcript = clip[1]
tokens = torch.tensor(
[
*tokenizer.sot_sequence,
tokenizer.no_timestamps,
*tokenizer.encode(transcript),
tokenizer.eot,
],
device=model.device,
)
QKs = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, __, outputs, index=i: QKs.__setitem__(index, outputs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]
with torch.no_grad():
model(mel.unsqueeze(0), tokens.unsqueeze(0))
for hook in hooks:
hook.remove()
heatmaps = {}
for layer_idx, tensor in enumerate(QKs):
if tensor is None or layer_idx not in target_map:
continue
tensor = tensor[:, :, : mel.shape[-1] // 2]
tensor = tensor.softmax(dim=-1).cpu()
for head_idx in target_map[layer_idx]:
heatmaps[(layer_idx, head_idx)] = tensor[head_idx]
return heatmaps
def _plot_heatmaps(
heads, heatmaps, output_path):
cols = min(3, len(heads))
rows = math.ceil(len(heads) / cols)
fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3.2 * rows), squeeze=False)
for idx, (layer, head, score) in enumerate(heads):
ax = axes[idx // cols][idx % cols]
mat = heatmaps.get((layer, head))
if mat is None:
ax.axis("off")
continue
im = ax.imshow(mat.to(torch.float32).numpy(), aspect="auto", origin="lower")
ax.set_title(f"L{layer} H{head} · score {score:.2f}")
ax.set_xlabel("time")
ax.set_ylabel("tokens")
for j in range(len(heads), rows * cols):
axes[j // cols][j % cols].axis("off")
fig.tight_layout()
fig.savefig(output_path, dpi=200)
plt.close(fig)
def _dump_mask(mask: torch.Tensor, output_path: str):
payload = mask.numpy().astype(np.bool_)
blob = base64.b85encode(gzip.compress(payload.tobytes()))
with open(output_path, "wb") as f:
f.write(blob)
def main():
args = _parse_args()
model = load_model(args.model, device=args.device)
model.eval()
tokenizer = get_tokenizer(multilingual=model.is_multilingual)
clips = load_clips(args)
votes, strengths = collect_heads(model, tokenizer, clips, args.threshold)
# selection = votes > 0.5
selection = strengths > 0.05
_dump_mask(selection.cpu(), args.output)
viz_heads = _select_heads_for_visualization(selection, strengths, args.visualize_top_k)
heatmaps = _extract_heatmaps(model, tokenizer, clips[0], viz_heads)
_plot_heatmaps(viz_heads, heatmaps, "alignment_heads.png")
if __name__ == "__main__":
main()

39
scripts/sync_extension.py Normal file
View File

@@ -0,0 +1,39 @@
"""Copy core files from web directory to Chrome extension directory."""
import shutil
import os
from pathlib import Path
def sync_extension_files():
web_dir = Path("whisperlivekit/web")
extension_dir = Path("chrome-extension")
files_to_sync = [
"live_transcription.html", "live_transcription.js", "live_transcription.css"
]
svg_files = [
"system_mode.svg",
"light_mode.svg",
"dark_mode.svg",
"settings.svg"
]
for file in files_to_sync:
src_path = web_dir / file
dest_path = extension_dir / file
dest_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(src_path, dest_path)
for svg_file in svg_files:
src_path = web_dir / "src" / svg_file
dest_path = extension_dir / "web" / "src" / svg_file
dest_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(src_path, dest_path)
if __name__ == "__main__":
sync_extension_files()

View File

@@ -1,31 +1,46 @@
import asyncio import asyncio
import numpy as np import numpy as np
from time import time, sleep from time import time
import math
import logging import logging
import traceback import traceback
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript from typing import Optional, Union, List, Any, AsyncGenerator
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript, ChangeSpeaker
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
from whisperlivekit.silero_vad_iterator import FixedVADIterator from whisperlivekit.silero_vad_iterator import FixedVADIterator
from whisperlivekit.results_formater import format_output
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
# Set up logging once from whisperlivekit.tokens_alignment import TokensAlignment
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
SENTINEL = object() # unique sentinel object for end of stream marker SENTINEL = object() # unique sentinel object for end of stream marker
MIN_DURATION_REAL_SILENCE = 5
async def get_all_from_queue(queue: asyncio.Queue) -> Union[object, Silence, np.ndarray, List[Any]]:
items: List[Any] = []
async def get_all_from_queue(queue): first_item = await queue.get()
items = [] queue.task_done()
try: if first_item is SENTINEL:
while True: return first_item
item = queue.get_nowait() if isinstance(first_item, Silence):
items.append(item) return first_item
except asyncio.QueueEmpty: items.append(first_item)
pass
return items while True:
if not queue._queue:
break
next_item = queue._queue[0]
if next_item is SENTINEL:
break
if isinstance(next_item, Silence):
break
items.append(await queue.get())
queue.task_done()
if isinstance(items[0], np.ndarray):
return np.concatenate(items)
else: #translation
return items
class AudioProcessor: class AudioProcessor:
""" """
@@ -33,7 +48,7 @@ class AudioProcessor:
Handles audio processing, state management, and result formatting. Handles audio processing, state management, and result formatting.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs: Any) -> None:
"""Initialize the audio processor with configuration, models, and state.""" """Initialize the audio processor with configuration, models, and state."""
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine): if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
@@ -50,37 +65,29 @@ class AudioProcessor:
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
self.is_pcm_input = self.args.pcm_input self.is_pcm_input = self.args.pcm_input
self.debug = False
# State management # State management
self.is_stopping = False self.is_stopping: bool = False
self.silence = False self.current_silence: Optional[Silence] = None
self.silence_duration = 0.0 self.state: State = State()
self.tokens = [] self.lock: asyncio.Lock = asyncio.Lock()
self.translated_segments = [] self.sep: str = " " # Default separator
self.buffer_transcription = Transcript() self.last_response_content: FrontData = FrontData()
self.buffer_diarization = ""
self.end_buffer = 0 self.tokens_alignment: TokensAlignment = TokensAlignment(self.state, self.args, self.sep)
self.end_attributed_speaker = 0 self.beg_loop: Optional[float] = None
self.lock = asyncio.Lock()
self.beg_loop = None #to deal with a potential little lag at the websocket initialization, this is now set in process_audio
self.sep = " " # Default separator
self.last_response_content = FrontData()
self.last_detected_speaker = None
self.speaker_languages = {}
# Models and processing # Models and processing
self.asr = models.asr self.asr: Any = models.asr
self.tokenizer = models.tokenizer self.vac_model: Any = models.vac_model
self.vac_model = models.vac_model
if self.args.vac: if self.args.vac:
self.vac = FixedVADIterator(models.vac_model) self.vac: Optional[FixedVADIterator] = FixedVADIterator(models.vac_model)
else: else:
self.vac = None self.vac: Optional[FixedVADIterator] = None
self.ffmpeg_manager = None self.ffmpeg_manager: Optional[FFmpegManager] = None
self.ffmpeg_reader_task = None self.ffmpeg_reader_task: Optional[asyncio.Task] = None
self._ffmpeg_error = None self._ffmpeg_error: Optional[str] = None
if not self.is_pcm_input: if not self.is_pcm_input:
self.ffmpeg_manager = FFmpegManager( self.ffmpeg_manager = FFmpegManager(
@@ -92,73 +99,102 @@ class AudioProcessor:
self._ffmpeg_error = error_type self._ffmpeg_error = error_type
self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
self.transcription_queue = asyncio.Queue() if self.args.transcription else None self.transcription_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.transcription else None
self.diarization_queue = asyncio.Queue() if self.args.diarization else None self.diarization_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.diarization else None
self.translation_queue = asyncio.Queue() if self.args.target_language else None self.translation_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.target_language else None
self.pcm_buffer = bytearray() self.pcm_buffer: bytearray = bytearray()
self.total_pcm_samples: int = 0
self.transcription_task: Optional[asyncio.Task] = None
self.diarization_task: Optional[asyncio.Task] = None
self.translation_task: Optional[asyncio.Task] = None
self.watchdog_task: Optional[asyncio.Task] = None
self.all_tasks_for_cleanup: List[asyncio.Task] = []
self.transcription_task = None self.transcription: Optional[Any] = None
self.diarization_task = None self.translation: Optional[Any] = None
self.watchdog_task = None self.diarization: Optional[Any] = None
self.all_tasks_for_cleanup = []
if self.args.transcription: if self.args.transcription:
self.online = online_factory(self.args, models.asr, models.tokenizer) self.transcription = online_factory(self.args, models.asr)
self.sep = self.online.asr.sep self.sep = self.transcription.asr.sep
if self.args.diarization: if self.args.diarization:
self.diarization = online_diarization_factory(self.args, models.diarization_model) self.diarization = online_diarization_factory(self.args, models.diarization_model)
if self.args.target_language: if models.translation_model:
self.online_translation = online_translation_factory(self.args, models.translation_model) self.translation = online_translation_factory(self.args, models.translation_model)
def convert_pcm_to_float(self, pcm_buffer): async def _push_silence_event(self) -> None:
if self.transcription_queue:
await self.transcription_queue.put(self.current_silence)
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(self.current_silence)
if self.translation_queue:
await self.translation_queue.put(self.current_silence)
async def _begin_silence(self) -> None:
if self.current_silence:
return
now = time() - self.beg_loop
self.current_silence = Silence(
is_starting=True, start=now
)
await self._push_silence_event()
async def _end_silence(self) -> None:
if not self.current_silence:
return
now = time() - self.beg_loop
self.current_silence.end = now
self.current_silence.is_starting=False
self.current_silence.has_ended=True
self.current_silence.compute_duration()
if self.current_silence.duration > MIN_DURATION_REAL_SILENCE:
self.state.new_tokens.append(self.current_silence)
await self._push_silence_event()
self.current_silence = None
async def _enqueue_active_audio(self, pcm_chunk: np.ndarray) -> None:
if pcm_chunk is None or pcm_chunk.size == 0:
return
if self.transcription_queue:
await self.transcription_queue.put(pcm_chunk.copy())
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(pcm_chunk.copy())
def _slice_before_silence(self, pcm_array: np.ndarray, chunk_sample_start: int, silence_sample: Optional[int]) -> Optional[np.ndarray]:
if silence_sample is None:
return None
relative_index = int(silence_sample - chunk_sample_start)
if relative_index <= 0:
return None
split_index = min(relative_index, len(pcm_array))
if split_index <= 0:
return None
return pcm_array[:split_index]
def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray:
"""Convert PCM buffer in s16le format to normalized NumPy array.""" """Convert PCM buffer in s16le format to normalized NumPy array."""
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0 return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
async def add_dummy_token(self): async def get_current_state(self) -> State:
"""Placeholder token when no transcription is available."""
async with self.lock:
current_time = time() - self.beg_loop if self.beg_loop else 0
self.tokens.append(ASRToken(
start=current_time, end=current_time + 1,
text=".", speaker=-1, is_dummy=True
))
async def get_current_state(self):
"""Get current state.""" """Get current state."""
async with self.lock: async with self.lock:
current_time = time() current_time = time()
# Calculate remaining times
remaining_transcription = 0 remaining_transcription = 0
if self.end_buffer > 0: if self.state.end_buffer > 0:
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 1)) remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1))
remaining_diarization = 0 remaining_diarization = 0
if self.tokens: if self.state.tokens:
latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0) latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 1)) remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1))
return State( self.state.remaining_time_transcription = remaining_transcription
tokens=self.tokens.copy(), self.state.remaining_time_diarization = remaining_diarization
translated_segments=self.translated_segments.copy(),
buffer_transcription=self.buffer_transcription,
buffer_diarization=self.buffer_diarization,
end_buffer=self.end_buffer,
end_attributed_speaker=self.end_attributed_speaker,
remaining_time_transcription=remaining_transcription,
remaining_time_diarization=remaining_diarization
)
async def reset(self): return self.state
"""Reset all state variables to initial values."""
async with self.lock:
self.tokens = []
self.translated_segments = []
self.buffer_transcription = self.buffer_diarization = Transcript()
self.end_buffer = self.end_attributed_speaker = 0
self.beg_loop = time()
async def ffmpeg_stdout_reader(self): async def ffmpeg_stdout_reader(self) -> None:
"""Read audio data from FFmpeg stdout and process it into the PCM pipeline.""" """Read audio data from FFmpeg stdout and process it into the PCM pipeline."""
beg = time() beg = time()
while True: while True:
@@ -201,56 +237,61 @@ class AudioProcessor:
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
logger.info("FFmpeg stdout processing finished. Signaling downstream processors if needed.") logger.info("FFmpeg stdout processing finished. Signaling downstream processors if needed.")
if self.args.transcription and self.transcription_queue: if self.transcription_queue:
await self.transcription_queue.put(SENTINEL) await self.transcription_queue.put(SENTINEL)
if self.args.diarization and self.diarization_queue: if self.diarization:
await self.diarization_queue.put(SENTINEL) await self.diarization_queue.put(SENTINEL)
if self.args.target_language and self.translation_queue: if self.translation:
await self.translation_queue.put(SENTINEL) await self.translation_queue.put(SENTINEL)
async def transcription_processor(self): async def transcription_processor(self) -> None:
"""Process audio chunks for transcription.""" """Process audio chunks for transcription."""
cumulative_pcm_duration_stream_time = 0.0 cumulative_pcm_duration_stream_time = 0.0
while True: while True:
try: try:
item = await self.transcription_queue.get() # item = await self.transcription_queue.get()
item = await get_all_from_queue(self.transcription_queue)
if item is SENTINEL: if item is SENTINEL:
logger.debug("Transcription processor received sentinel. Finishing.") logger.debug("Transcription processor received sentinel. Finishing.")
self.transcription_queue.task_done()
break break
if not self.online: asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
logger.warning("Transcription processor: self.online not initialized.") transcription_lag_s = max(0.0, time() - self.beg_loop - self.state.end_buffer)
self.transcription_queue.task_done()
continue
asr_internal_buffer_duration_s = len(getattr(self.online, 'audio_buffer', [])) / self.online.SAMPLING_RATE
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |" asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
if type(item) is Silence:
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
if self.tokens:
asr_processing_logs += f" | last_end = {self.tokens[-1].end} |"
logger.info(asr_processing_logs)
cumulative_pcm_duration_stream_time += item.duration
self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
continue
logger.info(asr_processing_logs)
if isinstance(item, np.ndarray):
pcm_array = item
else:
raise Exception('item should be pcm_array')
duration_this_chunk = len(pcm_array) / self.sample_rate
cumulative_pcm_duration_stream_time += duration_this_chunk
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
new_tokens = []
current_audio_processed_upto = self.state.end_buffer
self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm) if isinstance(item, Silence):
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.online.process_iter) if item.is_starting:
new_tokens, current_audio_processed_upto = await asyncio.to_thread(
self.transcription.start_silence
)
asr_processing_logs += f" + Silence starting"
if item.has_ended:
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
cumulative_pcm_duration_stream_time += item.duration
current_audio_processed_upto = cumulative_pcm_duration_stream_time
self.transcription.end_silence(item.duration, self.state.tokens[-1].end if self.state.tokens else 0)
if self.state.tokens:
asr_processing_logs += f" | last_end = {self.state.tokens[-1].end} |"
logger.info(asr_processing_logs)
new_tokens = new_tokens or []
current_audio_processed_upto = max(current_audio_processed_upto, stream_time_end_of_current_pcm)
elif isinstance(item, ChangeSpeaker):
self.transcription.new_speaker(item)
continue
elif isinstance(item, np.ndarray):
pcm_array = item
logger.info(asr_processing_logs)
cumulative_pcm_duration_stream_time += len(pcm_array) / self.sample_rate
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter)
new_tokens = new_tokens or []
_buffer_transcript = self.online.get_buffer() _buffer_transcript = self.transcription.get_buffer()
buffer_text = _buffer_transcript.text buffer_text = _buffer_transcript.text
if new_tokens: if new_tokens:
@@ -258,7 +299,7 @@ class AudioProcessor:
if buffer_text.startswith(validated_text): if buffer_text.startswith(validated_text):
_buffer_transcript.text = buffer_text[len(validated_text):].lstrip() _buffer_transcript.text = buffer_text[len(validated_text):].lstrip()
candidate_end_times = [self.end_buffer] candidate_end_times = [self.state.end_buffer]
if new_tokens: if new_tokens:
candidate_end_times.append(new_tokens[-1].end) candidate_end_times.append(new_tokens[-1].end)
@@ -269,16 +310,15 @@ class AudioProcessor:
candidate_end_times.append(current_audio_processed_upto) candidate_end_times.append(current_audio_processed_upto)
async with self.lock: async with self.lock:
self.tokens.extend(new_tokens) self.state.tokens.extend(new_tokens)
self.buffer_transcription = _buffer_transcript self.state.buffer_transcription = _buffer_transcript
self.end_buffer = max(candidate_end_times) self.state.end_buffer = max(candidate_end_times)
self.state.new_tokens.extend(new_tokens)
self.state.new_tokens_buffer = _buffer_transcript
if self.translation_queue: if self.translation_queue:
for token in new_tokens: for token in new_tokens:
await self.translation_queue.put(token) await self.translation_queue.put(token)
self.transcription_queue.task_done()
except Exception as e: except Exception as e:
logger.warning(f"Exception in transcription_processor: {e}") logger.warning(f"Exception in transcription_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}") logger.warning(f"Traceback: {traceback.format_exc()}")
@@ -295,207 +335,116 @@ class AudioProcessor:
logger.info("Transcription processor task finished.") logger.info("Transcription processor task finished.")
async def diarization_processor(self, diarization_obj): async def diarization_processor(self) -> None:
"""Process audio chunks for speaker diarization."""
buffer_diarization = ""
cumulative_pcm_duration_stream_time = 0.0
while True: while True:
try: try:
item = await self.diarization_queue.get() item = await get_all_from_queue(self.diarization_queue)
if item is SENTINEL: if item is SENTINEL:
logger.debug("Diarization processor received sentinel. Finishing.")
self.diarization_queue.task_done()
break break
elif type(item) is Silence: elif type(item) is Silence:
cumulative_pcm_duration_stream_time += item.duration if item.has_ended:
diarization_obj.insert_silence(item.duration) self.diarization.insert_silence(item.duration)
continue continue
elif isinstance(item, np.ndarray):
pcm_array = item
else:
raise Exception('item should be pcm_array')
# Process diarization self.diarization.insert_audio_chunk(item)
await diarization_obj.diarize(pcm_array) diarization_segments = await self.diarization.diarize()
self.state.new_diarization = diarization_segments
async with self.lock:
self.tokens, last_segment = diarization_obj.assign_speakers_to_tokens(
self.tokens,
use_punctuation_split=self.args.punctuation_split
)
if len(self.tokens) > 0:
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
if buffer_diarization:
self.buffer_diarization = buffer_diarization
# if last_segment is not None and last_segment.speaker != self.last_detected_speaker:
# if not self.speaker_languages.get(last_segment.speaker, None):
# self.last_detected_speaker = last_segment.speaker
# self.online.on_new_speaker(last_segment)
self.diarization_queue.task_done()
except Exception as e: except Exception as e:
logger.warning(f"Exception in diarization_processor: {e}") logger.warning(f"Exception in diarization_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}") logger.warning(f"Traceback: {traceback.format_exc()}")
if 'pcm_array' in locals() and pcm_array is not SENTINEL:
self.diarization_queue.task_done()
logger.info("Diarization processor task finished.") logger.info("Diarization processor task finished.")
async def translation_processor(self, online_translation): async def translation_processor(self) -> None:
# the idea is to ignore diarization for the moment. We use only transcription tokens. # the idea is to ignore diarization for the moment. We use only transcription tokens.
# And the speaker is attributed given the segments used for the translation # And the speaker is attributed given the segments used for the translation
# in the future we want to have different languages for each speaker etc, so it will be more complex. # in the future we want to have different languages for each speaker etc, so it will be more complex.
while True: while True:
try: try:
item = await self.translation_queue.get() #block until at least 1 token item = await get_all_from_queue(self.translation_queue)
if item is SENTINEL: if item is SENTINEL:
logger.debug("Translation processor received sentinel. Finishing.") logger.debug("Translation processor received sentinel. Finishing.")
self.translation_queue.task_done()
break break
elif type(item) is Silence: elif type(item) is Silence:
online_translation.insert_silence(item.duration) if item.is_starting:
continue new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
if item.has_ended:
# get all the available tokens for translation. The more words, the more precise self.translation.insert_silence(item.duration)
tokens_to_process = [item] continue
additional_tokens = await get_all_from_queue(self.translation_queue) elif isinstance(item, ChangeSpeaker):
new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
sentinel_found = False pass
for additional_token in additional_tokens: else:
if additional_token is SENTINEL: self.translation.insert_tokens(item)
sentinel_found = True new_translation, new_translation_buffer = await asyncio.to_thread(self.translation.process)
break async with self.lock:
tokens_to_process.append(additional_token) self.state.new_translation.append(new_translation)
if tokens_to_process: self.state.new_translation_buffer = new_translation_buffer
online_translation.insert_tokens(tokens_to_process)
self.translated_segments = await asyncio.to_thread(online_translation.process)
self.translation_queue.task_done()
for _ in additional_tokens:
self.translation_queue.task_done()
if sentinel_found:
logger.debug("Translation processor received sentinel in batch. Finishing.")
break
except Exception as e: except Exception as e:
logger.warning(f"Exception in translation_processor: {e}") logger.warning(f"Exception in translation_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}") logger.warning(f"Traceback: {traceback.format_exc()}")
if 'token' in locals() and item is not SENTINEL:
self.translation_queue.task_done()
if 'additional_tokens' in locals():
for _ in additional_tokens:
self.translation_queue.task_done()
logger.info("Translation processor task finished.") logger.info("Translation processor task finished.")
async def results_formatter(self): async def results_formatter(self) -> AsyncGenerator[FrontData, None]:
"""Format processing results for output.""" """Format processing results for output."""
while True: while True:
try: try:
# If FFmpeg error occurred, notify front-end
if self._ffmpeg_error: if self._ffmpeg_error:
yield FrontData( yield FrontData(status="error", error=f"FFmpeg error: {self._ffmpeg_error}")
status="error",
error=f"FFmpeg error: {self._ffmpeg_error}"
)
self._ffmpeg_error = None self._ffmpeg_error = None
await asyncio.sleep(1) await asyncio.sleep(1)
continue continue
# Get current state self.tokens_alignment.update()
lines, buffer_diarization_text, buffer_translation_text = self.tokens_alignment.get_lines(
diarization=self.args.diarization,
translation=bool(self.translation),
current_silence=self.current_silence
)
state = await self.get_current_state() state = await self.get_current_state()
# Add dummy tokens if needed buffer_transcription_text = state.buffer_transcription.text if state.buffer_transcription else ''
if (not state.tokens or state.tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
await self.add_dummy_token()
sleep(0.5)
state = await self.get_current_state()
# Format output
lines, undiarized_text, end_w_silence = format_output(
state,
self.silence,
current_time = time() - self.beg_loop if self.beg_loop else None,
args = self.args,
debug = self.debug,
sep=self.sep
)
if end_w_silence:
buffer_transcription = Transcript()
buffer_diarization = Transcript()
else:
buffer_transcription = state.buffer_transcription
buffer_diarization = state.buffer_diarization
# Handle undiarized text
if undiarized_text:
combined = self.sep.join(undiarized_text)
if buffer_transcription:
combined += self.sep
async with self.lock:
self.end_attributed_speaker = state.end_attributed_speaker
if buffer_diarization:
self.buffer_diarization = buffer_diarization
buffer_diarization.text = combined
response_status = "active_transcription" response_status = "active_transcription"
if not state.tokens and not buffer_transcription and not buffer_diarization: if not lines and not buffer_transcription_text and not buffer_diarization_text:
response_status = "no_audio_detected" response_status = "no_audio_detected"
lines = []
elif not lines:
lines = [Line(
speaker=1,
start=state.end_buffer,
end=state.end_buffer
)]
response = FrontData( response = FrontData(
status=response_status, status=response_status,
lines=lines, lines=lines,
buffer_transcription=buffer_transcription.text, buffer_transcription=buffer_transcription_text,
buffer_diarization=buffer_transcription.text, buffer_diarization=buffer_diarization_text,
buffer_translation=buffer_translation_text,
remaining_time_transcription=state.remaining_time_transcription, remaining_time_transcription=state.remaining_time_transcription,
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0 remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
) )
should_push = (response != self.last_response_content) should_push = (response != self.last_response_content)
if should_push and (lines or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"): if should_push:
yield response yield response
self.last_response_content = response self.last_response_content = response
# Check for termination condition if self.is_stopping and self._processing_tasks_done():
if self.is_stopping: logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
all_processors_done = True return
if self.args.transcription and self.transcription_task and not self.transcription_task.done():
all_processors_done = False
if self.args.diarization and self.diarization_task and not self.diarization_task.done():
all_processors_done = False
if all_processors_done:
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
return
await asyncio.sleep(0.05) await asyncio.sleep(0.05)
except Exception as e: except Exception as e:
logger.warning(f"Exception in results_formatter: {e}") logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
logger.warning(f"Traceback: {traceback.format_exc()}")
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
async def create_tasks(self): async def create_tasks(self) -> AsyncGenerator[FrontData, None]:
"""Create and start processing tasks.""" """Create and start processing tasks."""
self.all_tasks_for_cleanup = [] self.all_tasks_for_cleanup = []
processing_tasks_for_watchdog = [] processing_tasks_for_watchdog: List[asyncio.Task] = []
# If using FFmpeg (non-PCM input), start it and spawn stdout reader # If using FFmpeg (non-PCM input), start it and spawn stdout reader
if not self.is_pcm_input: if not self.is_pcm_input:
success = await self.ffmpeg_manager.start() success = await self.ffmpeg_manager.start()
if not success: if not success:
logger.error("Failed to start FFmpeg manager") logger.error("Failed to start FFmpeg manager")
async def error_generator(): async def error_generator() -> AsyncGenerator[FrontData, None]:
yield FrontData( yield FrontData(
status="error", status="error",
error="FFmpeg failed to start. Please check that FFmpeg is installed." error="FFmpeg failed to start. Please check that FFmpeg is installed."
@@ -505,18 +454,18 @@ class AudioProcessor:
self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task) self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task)
processing_tasks_for_watchdog.append(self.ffmpeg_reader_task) processing_tasks_for_watchdog.append(self.ffmpeg_reader_task)
if self.args.transcription and self.online: if self.transcription:
self.transcription_task = asyncio.create_task(self.transcription_processor()) self.transcription_task = asyncio.create_task(self.transcription_processor())
self.all_tasks_for_cleanup.append(self.transcription_task) self.all_tasks_for_cleanup.append(self.transcription_task)
processing_tasks_for_watchdog.append(self.transcription_task) processing_tasks_for_watchdog.append(self.transcription_task)
if self.args.diarization and self.diarization: if self.diarization:
self.diarization_task = asyncio.create_task(self.diarization_processor(self.diarization)) self.diarization_task = asyncio.create_task(self.diarization_processor())
self.all_tasks_for_cleanup.append(self.diarization_task) self.all_tasks_for_cleanup.append(self.diarization_task)
processing_tasks_for_watchdog.append(self.diarization_task) processing_tasks_for_watchdog.append(self.diarization_task)
if self.args.target_language and self.args.lan != 'auto': if self.translation:
self.translation_task = asyncio.create_task(self.translation_processor(self.online_translation)) self.translation_task = asyncio.create_task(self.translation_processor())
self.all_tasks_for_cleanup.append(self.translation_task) self.all_tasks_for_cleanup.append(self.translation_task)
processing_tasks_for_watchdog.append(self.translation_task) processing_tasks_for_watchdog.append(self.translation_task)
@@ -526,13 +475,18 @@ class AudioProcessor:
return self.results_formatter() return self.results_formatter()
async def watchdog(self, tasks_to_monitor): async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None:
"""Monitors the health of critical processing tasks.""" """Monitors the health of critical processing tasks."""
tasks_remaining: List[asyncio.Task] = [task for task in tasks_to_monitor if task]
while True: while True:
try: try:
if not tasks_remaining:
logger.info("Watchdog task finishing: all monitored tasks completed.")
return
await asyncio.sleep(10) await asyncio.sleep(10)
for i, task in enumerate(tasks_to_monitor): for i, task in enumerate(list(tasks_remaining)):
if task.done(): if task.done():
exc = task.exception() exc = task.exception()
task_name = task.get_name() if hasattr(task, 'get_name') else f"Monitored Task {i}" task_name = task.get_name() if hasattr(task, 'get_name') else f"Monitored Task {i}"
@@ -540,6 +494,7 @@ class AudioProcessor:
logger.error(f"{task_name} unexpectedly completed with exception: {exc}") logger.error(f"{task_name} unexpectedly completed with exception: {exc}")
else: else:
logger.info(f"{task_name} completed normally.") logger.info(f"{task_name} completed normally.")
tasks_remaining.remove(task)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("Watchdog task cancelled.") logger.info("Watchdog task cancelled.")
@@ -547,7 +502,7 @@ class AudioProcessor:
except Exception as e: except Exception as e:
logger.error(f"Error in watchdog task: {e}", exc_info=True) logger.error(f"Error in watchdog task: {e}", exc_info=True)
async def cleanup(self): async def cleanup(self) -> None:
"""Clean up resources when processing is complete.""" """Clean up resources when processing is complete."""
logger.info("Starting cleanup of AudioProcessor resources.") logger.info("Starting cleanup of AudioProcessor resources.")
self.is_stopping = True self.is_stopping = True
@@ -566,16 +521,28 @@ class AudioProcessor:
logger.info("FFmpeg manager stopped.") logger.info("FFmpeg manager stopped.")
except Exception as e: except Exception as e:
logger.warning(f"Error stopping FFmpeg manager: {e}") logger.warning(f"Error stopping FFmpeg manager: {e}")
if self.args.diarization and hasattr(self, 'dianization') and hasattr(self.diarization, 'close'): if self.diarization:
self.diarization.close() self.diarization.close()
logger.info("AudioProcessor cleanup complete.") logger.info("AudioProcessor cleanup complete.")
def _processing_tasks_done(self) -> bool:
"""Return True when all active processing tasks have completed."""
tasks_to_check = [
self.transcription_task,
self.diarization_task,
self.translation_task,
self.ffmpeg_reader_task,
]
return all(task.done() for task in tasks_to_check if task)
async def process_audio(self, message):
async def process_audio(self, message: Optional[bytes]) -> None:
"""Process incoming audio data.""" """Process incoming audio data."""
if not self.beg_loop: if not self.beg_loop:
self.beg_loop = time() self.beg_loop = time()
self.current_silence = Silence(start=0.0, is_starting=True)
self.tokens_alignment.beg_loop = self.beg_loop
if not message: if not message:
logger.info("Empty audio message received, initiating stop sequence.") logger.info("Empty audio message received, initiating stop sequence.")
@@ -608,7 +575,7 @@ class AudioProcessor:
else: else:
logger.warning("Failed to write audio data to FFmpeg") logger.warning("Failed to write audio data to FFmpeg")
async def handle_pcm_data(self): async def handle_pcm_data(self) -> None:
# Process when enough data # Process when enough data
if len(self.pcm_buffer) < self.bytes_per_sec: if len(self.pcm_buffer) < self.bytes_per_sec:
return return
@@ -619,44 +586,38 @@ class AudioProcessor:
f"Consider using a smaller model." f"Consider using a smaller model."
) )
# Process audio chunk chunk_size = min(len(self.pcm_buffer), self.max_bytes_per_sec)
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec]) aligned_chunk_size = (chunk_size // self.bytes_per_sample) * self.bytes_per_sample
self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
if aligned_chunk_size == 0:
return
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size])
self.pcm_buffer = self.pcm_buffer[aligned_chunk_size:]
num_samples = len(pcm_array)
chunk_sample_start = self.total_pcm_samples
chunk_sample_end = chunk_sample_start + num_samples
res = None res = None
end_of_audio = False
silence_buffer = None
if self.args.vac: if self.args.vac:
res = self.vac(pcm_array) res = self.vac(pcm_array)
if res is not None: if res is not None:
if res.get("end", 0) > res.get("start", 0): silence_detected = res.get("end", 0) > res.get("start", 0)
end_of_audio = True if silence_detected and not self.current_silence:
elif self.silence: #end of silence pre_silence_chunk = self._slice_before_silence(
self.silence = False pcm_array, chunk_sample_start, res.get("end")
silence_buffer = Silence(duration=time() - self.start_silence) )
if pre_silence_chunk is not None and pre_silence_chunk.size > 0:
await self._enqueue_active_audio(pre_silence_chunk)
await self._begin_silence()
elif self.current_silence:
await self._end_silence()
if silence_buffer: if not self.current_silence:
if self.args.transcription and self.transcription_queue: await self._enqueue_active_audio(pcm_array)
await self.transcription_queue.put(silence_buffer)
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(silence_buffer)
if self.translation_queue:
await self.translation_queue.put(silence_buffer)
if not self.silence: self.total_pcm_samples = chunk_sample_end
if self.args.transcription and self.transcription_queue:
await self.transcription_queue.put(pcm_array.copy())
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(pcm_array.copy())
self.silence_duration = 0.0
if end_of_audio:
self.silence = True
self.start_silence = time()
if not self.args.transcription and not self.args.diarization: if not self.args.transcription and not self.args.diarization:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)

View File

@@ -0,0 +1,41 @@
import importlib.util
import logging
import platform
logger = logging.getLogger(__name__)
def module_available(module_name):
"""Return True if the given module can be imported."""
return importlib.util.find_spec(module_name) is not None
def mlx_backend_available(warn_on_missing = False):
is_macos = platform.system() == "Darwin"
is_arm = platform.machine() == "arm64"
available = (
is_macos
and is_arm
and module_available("mlx_whisper")
)
if not available and warn_on_missing and is_macos and is_arm:
logger.warning(
"=" * 50
+ "\nMLX Whisper not found but you are on Apple Silicon. "
"Consider installing mlx-whisper for better performance: "
"`pip install mlx-whisper`\n"
+ "=" * 50
)
return available
def faster_backend_available(warn_on_missing = False):
available = module_available("faster_whisper")
if not available and warn_on_missing and platform.system() != "Darwin":
logger.warning(
"=" * 50
+ "\nFaster-Whisper not found. Consider installing faster-whisper "
"for better performance: `pip install faster-whisper`\n"
+ "=" * 50
)
return available

View File

@@ -5,9 +5,6 @@ from fastapi.middleware.cors import CORSMiddleware
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args
import asyncio import asyncio
import logging import logging
from starlette.staticfiles import StaticFiles
import pathlib
import whisperlivekit.web as webpkg
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger().setLevel(logging.WARNING) logging.getLogger().setLevel(logging.WARNING)
@@ -33,8 +30,6 @@ app.add_middleware(
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
web_dir = pathlib.Path(webpkg.__file__).parent
app.mount("/web", StaticFiles(directory=str(web_dir)), name="web")
@app.get("/") @app.get("/")
async def get(): async def get():
@@ -123,6 +118,8 @@ def main():
if ssl_kwargs: if ssl_kwargs:
uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs} uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs}
if args.forwarded_allow_ips:
uvicorn_kwargs = { **uvicorn_kwargs, "forwarded_allow_ips" : args.forwarded_allow_ips }
uvicorn.run(**uvicorn_kwargs) uvicorn.run(**uvicorn_kwargs)

View File

@@ -1,12 +1,18 @@
try: from whisperlivekit.local_agreement.whisper_online import backend_factory
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory from whisperlivekit.simul_whisper import SimulStreamingASR
from whisperlivekit.whisper_streaming_custom.online_asr import OnlineASRProcessor from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
except ImportError:
from .whisper_streaming_custom.whisper_online import backend_factory
from .whisper_streaming_custom.online_asr import OnlineASRProcessor
from whisperlivekit.warmup import warmup_asr
from argparse import Namespace from argparse import Namespace
import sys import sys
import logging
def update_with_kwargs(_dict, kwargs):
_dict.update({
k: v for k, v in kwargs.items() if k in _dict
})
return _dict
logger = logging.getLogger(__name__)
class TranscriptionEngine: class TranscriptionEngine:
_instance = None _instance = None
@@ -21,76 +27,51 @@ class TranscriptionEngine:
if TranscriptionEngine._initialized: if TranscriptionEngine._initialized:
return return
defaults = { global_params = {
"host": "localhost", "host": "localhost",
"port": 8000, "port": 8000,
"warmup_file": None,
"diarization": False, "diarization": False,
"punctuation_split": False, "punctuation_split": False,
"min_chunk_size": 0.5,
"model": "tiny",
"model_cache_dir": None,
"model_dir": None,
"lan": "auto",
"task": "transcribe",
"target_language": "", "target_language": "",
"backend": "faster-whisper",
"vac": True, "vac": True,
"vac_onnx": False,
"vac_chunk_size": 0.04, "vac_chunk_size": 0.04,
"log_level": "DEBUG", "log_level": "DEBUG",
"ssl_certfile": None, "ssl_certfile": None,
"ssl_keyfile": None, "ssl_keyfile": None,
"forwarded_allow_ips": None,
"transcription": True, "transcription": True,
"vad": True, "vad": True,
"pcm_input": False, "pcm_input": False,
# whisperstreaming params:
"buffer_trimming": "segment",
"confidence_validation": False,
"buffer_trimming_sec": 15,
# simulstreaming params:
"disable_fast_encoder": False,
"frame_threshold": 25,
"beams": 1,
"decoder_type": None,
"audio_max_len": 20.0,
"audio_min_len": 0.0,
"cif_ckpt_path": None,
"never_fire": False,
"init_prompt": None,
"static_init_prompt": None,
"max_context_tokens": None,
"model_path": './base.pt',
"diarization_backend": "sortformer",
# diarization params:
"disable_punctuation_split" : False, "disable_punctuation_split" : False,
"segmentation_model": "pyannote/segmentation-3.0", "diarization_backend": "sortformer",
"embedding_model": "pyannote/embedding", "backend_policy": "simulstreaming",
"backend": "auto",
# translation params:
"nllb_backend": "ctranslate2",
"nllb_size": "600M"
} }
global_params = update_with_kwargs(global_params, kwargs)
config_dict = {**defaults, **kwargs} transcription_common_params = {
"warmup_file": None,
"min_chunk_size": 0.1,
"model_size": "base",
"model_cache_dir": None,
"model_dir": None,
"model_path": None,
"lan": "auto",
"direct_english_translation": False,
}
transcription_common_params = update_with_kwargs(transcription_common_params, kwargs)
if transcription_common_params['model_size'].endswith(".en"):
transcription_common_params["lan"] = "en"
if 'no_transcription' in kwargs: if 'no_transcription' in kwargs:
config_dict['transcription'] = not kwargs['no_transcription'] global_params['transcription'] = not global_params['no_transcription']
if 'no_vad' in kwargs: if 'no_vad' in kwargs:
config_dict['vad'] = not kwargs['no_vad'] global_params['vad'] = not kwargs['no_vad']
if 'no_vac' in kwargs: if 'no_vac' in kwargs:
config_dict['vac'] = not kwargs['no_vac'] global_params['vac'] = not kwargs['no_vac']
config_dict.pop('no_transcription', None) self.args = Namespace(**{**global_params, **transcription_common_params})
config_dict.pop('no_vad', None)
if 'language' in kwargs:
config_dict['lan'] = kwargs['language']
config_dict.pop('language', None)
self.args = Namespace(**config_dict)
self.asr = None self.asr = None
self.tokenizer = None self.tokenizer = None
@@ -98,77 +79,100 @@ class TranscriptionEngine:
self.vac_model = None self.vac_model = None
if self.args.vac: if self.args.vac:
import torch from whisperlivekit.silero_vad_iterator import load_silero_vad
self.vac_model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad") # Use ONNX if specified, otherwise use JIT (default)
use_onnx = kwargs.get('vac_onnx', False)
self.vac_model = load_silero_vad(onnx=use_onnx)
backend_policy = self.args.backend_policy
if self.args.transcription: if self.args.transcription:
if self.args.backend == "simulstreaming": if backend_policy == "simulstreaming":
from whisperlivekit.simul_whisper import SimulStreamingASR simulstreaming_params = {
"disable_fast_encoder": False,
"custom_alignment_heads": None,
"frame_threshold": 25,
"beams": 1,
"decoder_type": None,
"audio_max_len": 20.0,
"audio_min_len": 0.0,
"cif_ckpt_path": None,
"never_fire": False,
"init_prompt": None,
"static_init_prompt": None,
"max_context_tokens": None,
"preload_model_count": 1,
}
simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
self.tokenizer = None self.tokenizer = None
simulstreaming_kwargs = {}
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count', 'disable_fast_encoder']:
if hasattr(self.args, attr):
simulstreaming_kwargs[attr] = getattr(self.args, attr)
# Add segment_length from min_chunk_size
simulstreaming_kwargs['segment_length'] = getattr(self.args, 'min_chunk_size', 0.5)
simulstreaming_kwargs['task'] = self.args.task
size = self.args.model
self.asr = SimulStreamingASR( self.asr = SimulStreamingASR(
modelsize=size, **transcription_common_params,
lan=self.args.lan, **simulstreaming_params,
cache_dir=getattr(self.args, 'model_cache_dir', None), backend=self.args.backend,
model_dir=getattr(self.args, 'model_dir', None), )
**simulstreaming_kwargs logger.info(
"Using SimulStreaming policy with %s backend",
getattr(self.asr, "encoder_backend", "whisper"),
) )
else: else:
self.asr, self.tokenizer = backend_factory(self.args)
warmup_asr(self.asr, self.args.warmup_file) #for simulstreaming, warmup should be done in the online class not here whisperstreaming_params = {
"buffer_trimming": "segment",
"confidence_validation": False,
"buffer_trimming_sec": 15,
}
whisperstreaming_params = update_with_kwargs(whisperstreaming_params, kwargs)
self.asr = backend_factory(
backend=self.args.backend,
**transcription_common_params,
**whisperstreaming_params,
)
logger.info(
"Using LocalAgreement policy with %s backend",
getattr(self.asr, "backend_choice", self.asr.__class__.__name__),
)
if self.args.diarization: if self.args.diarization:
if self.args.diarization_backend == "diart": if self.args.diarization_backend == "diart":
from whisperlivekit.diarization.diart_backend import DiartDiarization from whisperlivekit.diarization.diart_backend import DiartDiarization
diart_params = {
"segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding",
}
diart_params = update_with_kwargs(diart_params, kwargs)
self.diarization_model = DiartDiarization( self.diarization_model = DiartDiarization(
block_duration=self.args.min_chunk_size, block_duration=self.args.min_chunk_size,
segmentation_model_name=self.args.segmentation_model, **diart_params
embedding_model_name=self.args.embedding_model
) )
elif self.args.diarization_backend == "sortformer": elif self.args.diarization_backend == "sortformer":
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
self.diarization_model = SortformerDiarization() self.diarization_model = SortformerDiarization()
else:
raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}")
self.translation_model = None self.translation_model = None
if self.args.target_language: if self.args.target_language:
if self.args.lan == 'auto': if self.args.lan == 'auto' and backend_policy != "simulstreaming":
raise Exception('Translation cannot be set with language auto') raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming')
else: else:
from whisperlivekit.translation.translation import load_model try:
self.translation_model = load_model([self.args.lan], backend=self.args.nllb_backend, model_size=self.args.nllb_size) #in the future we want to handle different languages for different speakers from nllw import load_model
except:
raise Exception('To use translation, you must install nllw: `pip install nllw`')
translation_params = {
"nllb_backend": "transformers",
"nllb_size": "600M"
}
translation_params = update_with_kwargs(translation_params, kwargs)
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
TranscriptionEngine._initialized = True TranscriptionEngine._initialized = True
def online_factory(args, asr):
def online_factory(args, asr, tokenizer, logfile=sys.stderr): if args.backend_policy == "simulstreaming":
if args.backend == "simulstreaming":
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
online = SimulStreamingOnlineProcessor( online = SimulStreamingOnlineProcessor(asr)
asr,
logfile=logfile,
)
else: else:
online = OnlineASRProcessor( online = OnlineASRProcessor(asr)
asr,
tokenizer,
logfile=logfile,
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
confidence_validation = args.confidence_validation
)
return online return online
@@ -187,5 +191,5 @@ def online_translation_factory(args, translation_model):
#should be at speaker level in the future: #should be at speaker level in the future:
#one shared nllb model for all speaker #one shared nllb model for all speaker
#one tokenizer per speaker/language #one tokenizer per speaker/language
from whisperlivekit.translation.translation import OnlineTranslation from nllw import OnlineTranslation
return OnlineTranslation(translation_model, [args.lan], [args.target_language]) return OnlineTranslation(translation_model, [args.lan], [args.target_language])

View File

@@ -26,7 +26,7 @@ class DiarizationObserver(Observer):
"""Observer that logs all data emitted by the diarization pipeline and stores speaker segments.""" """Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""
def __init__(self): def __init__(self):
self.speaker_segments = [] self.diarization_segments = []
self.processed_time = 0 self.processed_time = 0
self.segment_lock = threading.Lock() self.segment_lock = threading.Lock()
self.global_time_offset = 0.0 self.global_time_offset = 0.0
@@ -48,7 +48,7 @@ class DiarizationObserver(Observer):
for speaker, label in annotation._labels.items(): for speaker, label in annotation._labels.items():
for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]): for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]):
print(f" {speaker}: {start:.2f}s-{end:.2f}s") print(f" {speaker}: {start:.2f}s-{end:.2f}s")
self.speaker_segments.append(SpeakerSegment( self.diarization_segments.append(SpeakerSegment(
speaker=speaker, speaker=speaker,
start=start + self.global_time_offset, start=start + self.global_time_offset,
end=end + self.global_time_offset end=end + self.global_time_offset
@@ -59,14 +59,14 @@ class DiarizationObserver(Observer):
def get_segments(self) -> List[SpeakerSegment]: def get_segments(self) -> List[SpeakerSegment]:
"""Get a copy of the current speaker segments.""" """Get a copy of the current speaker segments."""
with self.segment_lock: with self.segment_lock:
return self.speaker_segments.copy() return self.diarization_segments.copy()
def clear_old_segments(self, older_than: float = 30.0): def clear_old_segments(self, older_than: float = 30.0):
"""Clear segments older than the specified time.""" """Clear segments older than the specified time."""
with self.segment_lock: with self.segment_lock:
current_time = self.processed_time current_time = self.processed_time
self.speaker_segments = [ self.diarization_segments = [
segment for segment in self.speaker_segments segment for segment in self.diarization_segments
if current_time - segment.end < older_than if current_time - segment.end < older_than
] ]
@@ -178,7 +178,6 @@ class DiartDiarization:
self.pipeline = SpeakerDiarization(config=config) self.pipeline = SpeakerDiarization(config=config)
self.observer = DiarizationObserver() self.observer = DiarizationObserver()
self.lag_diart = None
if use_microphone: if use_microphone:
self.source = MicrophoneAudioSource(block_duration=block_duration) self.source = MicrophoneAudioSource(block_duration=block_duration)
@@ -217,32 +216,6 @@ class DiartDiarization:
if self.custom_source: if self.custom_source:
self.custom_source.close() self.custom_source.close()
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> float:
"""
Assign speakers to tokens based on timing overlap with speaker segments.
Uses the segments collected by the observer.
If use_punctuation_split is True, uses punctuation marks to refine speaker boundaries.
"""
segments = self.observer.get_segments()
# Debug logging
logger.debug(f"assign_speakers_to_tokens called with {len(tokens)} tokens")
logger.debug(f"Available segments: {len(segments)}")
for i, seg in enumerate(segments[:5]): # Show first 5 segments
logger.debug(f" Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]")
if not self.lag_diart and segments and tokens:
self.lag_diart = segments[0].start - tokens[0].start
if not use_punctuation_split:
for token in tokens:
for segment in segments:
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
token.speaker = extract_number(segment.speaker) + 1
else:
tokens = add_speaker_to_tokens(segments, tokens)
return tokens, segments[-1]
def concatenate_speakers(segments): def concatenate_speakers(segments):
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}] segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]

View File

@@ -94,11 +94,11 @@ class SortformerDiarizationOnline:
model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2") model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2")
""" """
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.speaker_segments = [] self.diarization_segments = []
self.diar_segments = []
self.buffer_audio = np.array([], dtype=np.float32) self.buffer_audio = np.array([], dtype=np.float32)
self.segment_lock = threading.Lock() self.segment_lock = threading.Lock()
self.global_time_offset = 0.0 self.global_time_offset = 0.0
self.processed_time = 0.0
self.debug = False self.debug = False
self.diar_model = shared_model.diar_model self.diar_model = shared_model.diar_model
@@ -156,11 +156,9 @@ class SortformerDiarizationOnline:
self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device) self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device) self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device)
self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device) self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
# Initialize total predictions tensor
self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device) self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)
def insert_silence(self, silence_duration: float): def insert_silence(self, silence_duration: Optional[float]):
""" """
Insert silence period by adjusting the global time offset. Insert silence period by adjusting the global time offset.
@@ -171,248 +169,111 @@ class SortformerDiarizationOnline:
self.global_time_offset += silence_duration self.global_time_offset += silence_duration
logger.debug(f"Inserted silence of {silence_duration:.2f}s, new offset: {self.global_time_offset:.2f}s") logger.debug(f"Inserted silence of {silence_duration:.2f}s, new offset: {self.global_time_offset:.2f}s")
async def diarize(self, pcm_array: np.ndarray): def insert_audio_chunk(self, pcm_array: np.ndarray):
if self.debug:
self.audio_buffer.append(pcm_array.copy())
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
async def diarize(self):
""" """
Process audio data for diarization in streaming fashion. Process audio data for diarization in streaming fashion.
Args: Args:
pcm_array: Audio data as numpy array pcm_array: Audio data as numpy array
""" """
try:
if self.debug:
self.audio_buffer.append(pcm_array.copy())
threshold = int(self.chunk_duration_seconds * self.sample_rate) threshold = int(self.chunk_duration_seconds * self.sample_rate)
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()]) if not len(self.buffer_audio) >= threshold:
if not len(self.buffer_audio) >= threshold: return []
return
audio = self.buffer_audio[:threshold] audio = self.buffer_audio[:threshold]
self.buffer_audio = self.buffer_audio[threshold:] self.buffer_audio = self.buffer_audio[threshold:]
device = self.diar_model.device device = self.diar_model.device
audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0) audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0)
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device) audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device)
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features( processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
audio_signal_chunk, audio_signal_length_chunk audio_signal_chunk, audio_signal_length_chunk
)
processed_signal_chunk = processed_signal_chunk.to(device)
processed_signal_length_chunk = processed_signal_length_chunk.to(device)
if self._previous_chunk_features is not None:
to_add = self._previous_chunk_features[:, :, -99:].to(device)
total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device)
else:
total_features = processed_signal_chunk.to(device)
self._previous_chunk_features = processed_signal_chunk.to(device)
chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device)
with torch.inference_mode():
left_offset = 8 if self._chunk_index > 0 else 0
right_offset = 8
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device),
streaming_state=self.streaming_state,
total_preds=self.total_preds,
left_offset=left_offset,
right_offset=right_offset,
) )
processed_signal_chunk = processed_signal_chunk.to(device) new_segments = self._process_predictions()
processed_signal_length_chunk = processed_signal_length_chunk.to(device)
if self._previous_chunk_features is not None: self._chunk_index += 1
to_add = self._previous_chunk_features[:, :, -99:].to(device) return new_segments
total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device)
else:
total_features = processed_signal_chunk.to(device)
self._previous_chunk_features = processed_signal_chunk.to(device)
chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device)
with torch.inference_mode():
left_offset = 8 if self._chunk_index > 0 else 0
right_offset = 8
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device),
streaming_state=self.streaming_state,
total_preds=self.total_preds,
left_offset=left_offset,
right_offset=right_offset,
)
# Convert predictions to speaker segments
self._process_predictions()
self._chunk_index += 1
except Exception as e:
logger.error(f"Error in diarize: {e}")
raise
# TODO: Handle case when stream ends with partial buffer (accumulated_duration > 0 but < chunk_duration_seconds)
def _process_predictions(self): def _process_predictions(self):
"""Process model predictions and convert to speaker segments.""" """Process model predictions and convert to speaker segments."""
try: preds_np = self.total_preds[0].cpu().numpy()
preds_np = self.total_preds[0].cpu().numpy() active_speakers = np.argmax(preds_np, axis=1)
active_speakers = np.argmax(preds_np, axis=1)
if self._len_prediction is None: if self._len_prediction is None:
self._len_prediction = len(active_speakers) self._len_prediction = len(active_speakers) #12
# Get predictions for current chunk frame_duration = self.chunk_duration_seconds / self._len_prediction
frame_duration = self.chunk_duration_seconds / self._len_prediction current_chunk_preds = active_speakers[-self._len_prediction:]
current_chunk_preds = active_speakers[-self._len_prediction:]
with self.segment_lock: new_segments = []
# Process predictions into segments
base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset
for idx, spk in enumerate(current_chunk_preds):
start_time = base_time + idx * frame_duration
end_time = base_time + (idx + 1) * frame_duration
# Check if this continues the last segment or starts a new one
if (self.speaker_segments and
self.speaker_segments[-1].speaker == spk and
abs(self.speaker_segments[-1].end - start_time) < frame_duration * 0.5):
# Continue existing segment
self.speaker_segments[-1].end = end_time
else:
# Create new segment
self.speaker_segments.append(SpeakerSegment(
speaker=spk,
start=start_time,
end=end_time
))
# Update processed time
self.processed_time = max(self.processed_time, base_time + self.chunk_duration_seconds)
logger.debug(f"Processed chunk {self._chunk_index}, total segments: {len(self.speaker_segments)}")
except Exception as e:
logger.error(f"Error processing predictions: {e}")
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> list:
"""
Assign speakers to tokens based on timing overlap with speaker segments.
Args:
tokens: List of tokens with timing information
use_punctuation_split: Whether to use punctuation for boundary refinement
Returns:
List of tokens with speaker assignments
Last speaker_segment
"""
with self.segment_lock: with self.segment_lock:
segments = self.speaker_segments.copy() base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset
current_spk = current_chunk_preds[0]
if not segments or not tokens: start_time = round(base_time, 2)
logger.debug("No segments or tokens available for speaker assignment") for idx, spk in enumerate(current_chunk_preds):
return tokens, None current_time = round(base_time + idx * frame_duration, 2)
if spk != current_spk:
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments") new_segments.append(SpeakerSegment(
use_punctuation_split = False speaker=current_spk,
if not use_punctuation_split: start=start_time,
# Simple overlap-based assignment end=current_time
for token in tokens: ))
token.speaker = -1 # Default to no speaker start_time = current_time
for segment in segments: current_spk = spk
# Check for timing overlap new_segments.append(
if not (segment.end <= token.start or segment.start >= token.end): SpeakerSegment(
token.speaker = segment.speaker + 1 # Convert to 1-based indexing speaker=current_spk,
break start=start_time,
else: end=current_time
# Use punctuation-aware assignment (similar to diart_backend) )
tokens = self._add_speaker_to_tokens_with_punctuation(segments, tokens) )
return new_segments
return tokens, segments[-1]
def _add_speaker_to_tokens_with_punctuation(self, segments: List[SpeakerSegment], tokens: list) -> list:
"""
Assign speakers to tokens with punctuation-aware boundary adjustment.
Args:
segments: List of speaker segments
tokens: List of tokens to assign speakers to
Returns:
List of tokens with speaker assignments
"""
punctuation_marks = {'.', '!', '?'}
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
# Convert segments to concatenated format
segments_concatenated = self._concatenate_speakers(segments)
# Adjust segment boundaries based on punctuation
for ind, segment in enumerate(segments_concatenated):
for i, punctuation_token in enumerate(punctuation_tokens):
if punctuation_token.start > segment['end']:
after_length = punctuation_token.start - segment['end']
before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
if before_length > after_length:
segment['end'] = punctuation_token.start
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
else:
segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
break
# Ensure non-overlapping tokens
last_end = 0.0
for token in tokens:
start = max(last_end + 0.01, token.start)
token.start = start
token.end = max(start, token.end)
last_end = token.end
# Assign speakers based on adjusted segments
ind_last_speaker = 0
for segment in segments_concatenated:
for i, token in enumerate(tokens[ind_last_speaker:]):
if token.end <= segment['end']:
token.speaker = segment['speaker']
ind_last_speaker = i + 1
elif token.start > segment['end']:
break
return tokens
def _concatenate_speakers(self, segments: List[SpeakerSegment]) -> List[dict]:
"""
Concatenate consecutive segments from the same speaker.
Args:
segments: List of speaker segments
Returns:
List of concatenated speaker segments
"""
if not segments:
return []
segments_concatenated = [{"speaker": segments[0].speaker + 1, "begin": segments[0].start, "end": segments[0].end}]
for segment in segments[1:]:
speaker = segment.speaker + 1
if segments_concatenated[-1]['speaker'] != speaker:
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
else:
segments_concatenated[-1]['end'] = segment.end
return segments_concatenated
def get_segments(self) -> List[SpeakerSegment]: def get_segments(self) -> List[SpeakerSegment]:
"""Get a copy of the current speaker segments.""" """Get a copy of the current speaker segments."""
with self.segment_lock: with self.segment_lock:
return self.speaker_segments.copy() return self.diarization_segments.copy()
def clear_old_segments(self, older_than: float = 30.0):
"""Clear segments older than the specified time."""
with self.segment_lock:
current_time = self.processed_time
self.speaker_segments = [
segment for segment in self.speaker_segments
if current_time - segment.end < older_than
]
logger.debug(f"Cleared old segments, remaining: {len(self.speaker_segments)}")
def close(self): def close(self):
"""Close the diarization system and clean up resources.""" """Close the diarization system and clean up resources."""
logger.info("Closing SortformerDiarization") logger.info("Closing SortformerDiarization")
with self.segment_lock: with self.segment_lock:
self.speaker_segments.clear() self.diarization_segments.clear()
if self.debug: if self.debug:
concatenated_audio = np.concatenate(self.audio_buffer) concatenated_audio = np.concatenate(self.audio_buffer)
@@ -438,7 +299,7 @@ if __name__ == '__main__':
async def main(): async def main():
"""TEST ONLY.""" """TEST ONLY."""
an4_audio = 'audio_test.mp3' an4_audio = 'diarization_audio.wav'
signal, sr = librosa.load(an4_audio, sr=16000) signal, sr = librosa.load(an4_audio, sr=16000)
signal = signal[:16000*30] signal = signal[:16000*30]
@@ -450,13 +311,15 @@ if __name__ == '__main__':
print("Speaker 0: 0:25 - 0:30") print("Speaker 0: 0:25 - 0:30")
print("=" * 50) print("=" * 50)
diarization = SortformerDiarization(sample_rate=16000) diarization_backend = SortformerDiarization()
diarization = SortformerDiarizationOnline(shared_model = diarization_backend)
chunk_size = 1600 chunk_size = 1600
for i in range(0, len(signal), chunk_size): for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size] chunk = signal[i:i+chunk_size]
await diarization.diarize(chunk) new_segments = await diarization.diarize(chunk)
print(f"Processed chunk {i // chunk_size + 1}") print(f"Processed chunk {i // chunk_size + 1}")
print(new_segments)
segments = diarization.get_segments() segments = diarization.get_segments()
print("\nDiarization results:") print("\nDiarization results:")

View File

@@ -1,205 +0,0 @@
import numpy as np
import torch
import logging
from nemo.collections.asr.models import SortformerEncLabelModel
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
import librosa
logger = logging.getLogger(__name__)
def load_model():
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
diar_model.eval()
if torch.cuda.is_available():
diar_model.to(torch.device("cuda"))
#we target 1 second lag for the moment. chunk_len could be reduced.
diar_model.sortformer_modules.chunk_len = 10
diar_model.sortformer_modules.subsampling_factor = 10 #8 would be better ideally
diar_model.sortformer_modules.chunk_right_context = 0 #no.
diar_model.sortformer_modules.chunk_left_context = 10 #big so it compensiate the problem with no padding later.
diar_model.sortformer_modules.spkcache_len = 188
diar_model.sortformer_modules.fifo_len = 188
diar_model.sortformer_modules.spkcache_update_period = 144
diar_model.sortformer_modules.log = False
diar_model.sortformer_modules._check_streaming_parameters()
audio2mel = AudioToMelSpectrogramPreprocessor(
window_size= 0.025,
normalize="NA",
n_fft=512,
features=128,
pad_to=0) #pad_to 16 works better than 0. On test audio, we detect a third speaker for 1 second with pad_to=0. To solve that : increase left context to 10.
return diar_model, audio2mel
diar_model, audio2mel = load_model()
class StreamingSortformerState:
"""
This class creates a class instance that will be used to store the state of the
streaming Sortformer model.
Attributes:
spkcache (torch.Tensor): Speaker cache to store embeddings from start
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
mean_sil_emb (torch.Tensor): Mean silence embedding
n_sil_frames (torch.Tensor): Number of silence frames
"""
spkcache = None # Speaker cache to store embeddings from start
spkcache_lengths = None #
spkcache_preds = None # speaker cache predictions
fifo = None # to save the embedding from the latest chunks
fifo_lengths = None
fifo_preds = None
spk_perm = None
mean_sil_emb = None
n_sil_frames = None
def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None):
"""
Initializes StreamingSortformerState with empty tensors or zero-valued tensors.
Args:
batch_size (int): Batch size for tensors in streaming state
async_streaming (bool): True for asynchronous update, False for synchronous update
device (torch.device): Device for tensors in streaming state
Returns:
streaming_state (SortformerStreamingState): initialized streaming state
"""
streaming_state = StreamingSortformerState()
if async_streaming:
streaming_state.spkcache = torch.zeros((batch_size, self.spkcache_len, self.fc_d_model), device=device)
streaming_state.spkcache_preds = torch.zeros((batch_size, self.spkcache_len, self.n_spk), device=device)
streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
streaming_state.fifo = torch.zeros((batch_size, self.fifo_len, self.fc_d_model), device=device)
streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
else:
streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device)
streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
return streaming_state
def process_diarization(chunks):
"""
what it does:
1. Preprocessing: Applies dithering and pre-emphasis (high-pass filter) if enabled
2. STFT: Computes the Short-Time Fourier Transform using:
- the window of window_size=0.025 --> size of a window : 400 samples
- the hop parameter : n_window_stride = 0.01 -> every 160 samples, a new window
3. Magnitude Calculation: Converts complex STFT output to magnitude spectrogram
4. Mel Conversion: Applies Mel filterbanks (128 filters in this case) to get Mel spectrogram
5. Logarithm: Takes the log of the Mel spectrogram (if `log=True`)
6. Normalization: Skips normalization since `normalize="NA"`
7. Padding: Pads the time dimension to a multiple of `pad_to` (default 16)
"""
previous_chunk = None
l_chunk_feat_seq_t = []
for chunk in chunks:
audio_signal_chunk = torch.tensor(chunk).unsqueeze(0).to(diar_model.device)
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(diar_model.device)
processed_signal_chunk, processed_signal_length_chunk = audio2mel.get_features(audio_signal_chunk, audio_signal_length_chunk)
if previous_chunk is not None:
to_add = previous_chunk[:, :, -99:]
total = torch.concat([to_add, processed_signal_chunk], dim=2)
else:
total = processed_signal_chunk
previous_chunk = processed_signal_chunk
l_chunk_feat_seq_t.append(torch.transpose(total, 1, 2))
batch_size = 1
streaming_state = init_streaming_state(diar_model.sortformer_modules,
batch_size = batch_size,
async_streaming = True,
device = diar_model.device
)
total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)
chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
l_speakers = [
{'start_time': 0,
'end_time': 0,
'speaker': 0
}
]
len_prediction = None
left_offset = 0
right_offset = 8
for i, chunk_feat_seq_t in enumerate(l_chunk_feat_seq_t):
with torch.inference_mode():
streaming_state, total_preds = diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
streaming_state=streaming_state,
total_preds=total_preds,
left_offset=left_offset,
right_offset=right_offset,
)
left_offset = 8
preds_np = total_preds[0].cpu().numpy()
active_speakers = np.argmax(preds_np, axis=1)
if len_prediction is None:
len_prediction = len(active_speakers) # we want to get the len of 1 prediction
frame_duration = chunk_duration_seconds / len_prediction
active_speakers = active_speakers[-len_prediction:]
for idx, spk in enumerate(active_speakers):
if spk != l_speakers[-1]['speaker']:
l_speakers.append(
{'start_time': (i * chunk_duration_seconds + idx * frame_duration),
'end_time': (i * chunk_duration_seconds + (idx + 1) * frame_duration),
'speaker': spk
})
else:
l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration
"""
Should print
[{'start_time': 0, 'end_time': 8.72, 'speaker': 0},
{'start_time': 8.72, 'end_time': 18.88, 'speaker': 1},
{'start_time': 18.88, 'end_time': 24.96, 'speaker': 2},
{'start_time': 24.96, 'end_time': 31.68, 'speaker': 0}]
"""
for speaker in l_speakers:
print(f"Speaker {speaker['speaker']}: {speaker['start_time']:.2f}s - {speaker['end_time']:.2f}s")
if __name__ == '__main__':
an4_audio = 'audio_test.mp3'
signal, sr = librosa.load(an4_audio, sr=16000)
signal = signal[:16000*30]
# signal = signal[:-(len(signal)%16000)]
print("\n" + "=" * 50)
print("Expected ground truth:")
print("Speaker 0: 0:00 - 0:09")
print("Speaker 1: 0:09 - 0:19")
print("Speaker 2: 0:19 - 0:25")
print("Speaker 0: 0:25 - 0:30")
print("=" * 50)
chunk_size = 16000 # 1 second
chunks = []
for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size]
chunks.append(chunk)
process_diarization(chunks)

View File

@@ -6,19 +6,21 @@ import math
from typing import List from typing import List
import numpy as np import numpy as np
from whisperlivekit.timed_objects import ASRToken from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ASRBase: class ASRBase:
sep = " " # join transcribe words with this character (" " for whisper_timestamped, sep = " " # join transcribe words with this character (" " for whisper_timestamped,
# "" for faster-whisper because it emits the spaces when needed) # "" for faster-whisper because it emits the spaces when needed)
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr): def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
self.logfile = logfile self.logfile = logfile
self.transcribe_kargs = {} self.transcribe_kargs = {}
if lan == "auto": if lan == "auto":
self.original_language = None self.original_language = None
else: else:
self.original_language = lan self.original_language = lan
self.model = self.load_model(modelsize, cache_dir, model_dir) self.model = self.load_model(model_size, cache_dir, model_dir)
def with_offset(self, offset: float) -> ASRToken: def with_offset(self, offset: float) -> ASRToken:
# This method is kept for compatibility (typically you will use ASRToken.with_offset) # This method is kept for compatibility (typically you will use ASRToken.with_offset)
@@ -27,7 +29,7 @@ class ASRBase:
def __repr__(self): def __repr__(self):
return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})" return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"
def load_model(self, modelsize, cache_dir, model_dir): def load_model(self, model_size, cache_dir, model_dir):
raise NotImplementedError("must be implemented in the child class") raise NotImplementedError("must be implemented in the child class")
def transcribe(self, audio, init_prompt=""): def transcribe(self, audio, init_prompt=""):
@@ -37,40 +39,60 @@ class ASRBase:
raise NotImplementedError("must be implemented in the child class") raise NotImplementedError("must be implemented in the child class")
class WhisperTimestampedASR(ASRBase): class WhisperASR(ASRBase):
"""Uses whisper_timestamped as the backend.""" """Uses WhisperLiveKit's built-in Whisper implementation."""
sep = " " sep = " "
def load_model(self, modelsize=None, cache_dir=None, model_dir=None): def load_model(self, model_size=None, cache_dir=None, model_dir=None):
import whisper from whisperlivekit.whisper import load_model as load_model
import whisper_timestamped
from whisper_timestamped import transcribe_timestamped
self.transcribe_timestamped = transcribe_timestamped
if model_dir is not None: if model_dir is not None:
logger.debug("ignoring model_dir, not implemented") resolved_path = resolve_model_path(model_dir)
return whisper.load_model(modelsize, download_root=cache_dir) if resolved_path.is_dir():
pytorch_path, _, _ = model_path_and_type(resolved_path)
if pytorch_path is None:
raise FileNotFoundError(
f"No supported PyTorch checkpoint found under {resolved_path}"
)
resolved_path = pytorch_path
logger.debug(f"Loading Whisper model from custom path {resolved_path}")
return load_model(str(resolved_path))
if model_size is None:
raise ValueError("Either model_size or model_dir must be set for WhisperASR")
return load_model(model_size, download_root=cache_dir)
def transcribe(self, audio, init_prompt=""): def transcribe(self, audio, init_prompt=""):
result = self.transcribe_timestamped( options = dict(self.transcribe_kargs)
options.pop("vad", None)
options.pop("vad_filter", None)
language = self.original_language if self.original_language else None
result = whisper_transcribe(
self.model, self.model,
audio, audio,
language=self.original_language, language=language,
initial_prompt=init_prompt, initial_prompt=init_prompt,
verbose=None,
condition_on_previous_text=True, condition_on_previous_text=True,
**self.transcribe_kargs, word_timestamps=True,
**options,
) )
return result return result
def ts_words(self, r) -> List[ASRToken]: def ts_words(self, r) -> List[ASRToken]:
""" """
Converts the whisper_timestamped result to a list of ASRToken objects. Converts the Whisper result to a list of ASRToken objects.
""" """
tokens = [] tokens = []
for segment in r["segments"]: for segment in r["segments"]:
for word in segment["words"]: for word in segment["words"]:
token = ASRToken(word["start"], word["end"], word["text"]) token = ASRToken(
word["start"],
word["end"],
word["word"],
probability=word.get("probability"),
)
tokens.append(token) tokens.append(token)
return tokens return tokens
@@ -78,27 +100,24 @@ class WhisperTimestampedASR(ASRBase):
return [segment["end"] for segment in res["segments"]] return [segment["end"] for segment in res["segments"]]
def use_vad(self): def use_vad(self):
self.transcribe_kargs["vad"] = True logger.warning("VAD is not currently supported for WhisperASR backend and will be ignored.")
def set_translate_task(self):
self.transcribe_kargs["task"] = "translate"
class FasterWhisperASR(ASRBase): class FasterWhisperASR(ASRBase):
"""Uses faster-whisper as the backend.""" """Uses faster-whisper as the backend."""
sep = "" sep = ""
def load_model(self, modelsize=None, cache_dir=None, model_dir=None): def load_model(self, model_size=None, cache_dir=None, model_dir=None):
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
if model_dir is not None: if model_dir is not None:
logger.debug(f"Loading whisper model from model_dir {model_dir}. " resolved_path = resolve_model_path(model_dir)
f"modelsize and cache_dir parameters are not used.") logger.debug(f"Loading faster-whisper model from {resolved_path}. "
model_size_or_path = model_dir f"model_size and cache_dir parameters are not used.")
elif modelsize is not None: model_size_or_path = str(resolved_path)
model_size_or_path = modelsize elif model_size is not None:
model_size_or_path = model_size
else: else:
raise ValueError("Either modelsize or model_dir must be set") raise ValueError("Either model_size or model_dir must be set")
device = "auto" # Allow CTranslate2 to decide available device device = "auto" # Allow CTranslate2 to decide available device
compute_type = "auto" # Allow CTranslate2 to decide faster compute type compute_type = "auto" # Allow CTranslate2 to decide faster compute type
@@ -139,28 +158,25 @@ class FasterWhisperASR(ASRBase):
def use_vad(self): def use_vad(self):
self.transcribe_kargs["vad_filter"] = True self.transcribe_kargs["vad_filter"] = True
def set_translate_task(self):
self.transcribe_kargs["task"] = "translate"
class MLXWhisper(ASRBase): class MLXWhisper(ASRBase):
""" """
Uses MLX Whisper optimized for Apple Silicon. Uses MLX Whisper optimized for Apple Silicon.
""" """
sep = "" sep = ""
def load_model(self, modelsize=None, cache_dir=None, model_dir=None): def load_model(self, model_size=None, cache_dir=None, model_dir=None):
from mlx_whisper.transcribe import ModelHolder, transcribe from mlx_whisper.transcribe import ModelHolder, transcribe
import mlx.core as mx import mlx.core as mx
if model_dir is not None: if model_dir is not None:
logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.") resolved_path = resolve_model_path(model_dir)
model_size_or_path = model_dir logger.debug(f"Loading MLX Whisper model from {resolved_path}. model_size parameter is not used.")
elif modelsize is not None: model_size_or_path = str(resolved_path)
model_size_or_path = self.translate_model_name(modelsize) elif model_size is not None:
logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.") model_size_or_path = self.translate_model_name(model_size)
logger.debug(f"Loading whisper model {model_size}. You use mlx whisper, so {model_size_or_path} will be used.")
else: else:
raise ValueError("Either modelsize or model_dir must be set") raise ValueError("Either model_size or model_dir must be set")
self.model_size_or_path = model_size_or_path self.model_size_or_path = model_size_or_path
dtype = mx.float16 dtype = mx.float16
@@ -208,7 +224,8 @@ class MLXWhisper(ASRBase):
if segment.get("no_speech_prob", 0) > 0.9: if segment.get("no_speech_prob", 0) > 0.9:
continue continue
for word in segment.get("words", []): for word in segment.get("words", []):
token = ASRToken(word["start"], word["end"], word["word"], probability=word["probability"]) probability=word["probability"]
token = ASRToken(word["start"], word["end"], word["word"])
tokens.append(token) tokens.append(token)
return tokens return tokens
@@ -218,10 +235,6 @@ class MLXWhisper(ASRBase):
def use_vad(self): def use_vad(self):
self.transcribe_kargs["vad_filter"] = True self.transcribe_kargs["vad_filter"] = True
def set_translate_task(self):
self.transcribe_kargs["task"] = "translate"
class OpenaiApiASR(ASRBase): class OpenaiApiASR(ASRBase):
"""Uses OpenAI's Whisper API for transcription.""" """Uses OpenAI's Whisper API for transcription."""
def __init__(self, lan=None, temperature=0, logfile=sys.stderr): def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
@@ -232,7 +245,7 @@ class OpenaiApiASR(ASRBase):
self.temperature = temperature self.temperature = temperature
self.load_model() self.load_model()
self.use_vad_opt = False self.use_vad_opt = False
self.task = "transcribe" self.direct_english_translation = False
def load_model(self, *args, **kwargs): def load_model(self, *args, **kwargs):
from openai import OpenAI from openai import OpenAI
@@ -274,7 +287,7 @@ class OpenaiApiASR(ASRBase):
"temperature": self.temperature, "temperature": self.temperature,
"timestamp_granularities": ["word", "segment"], "timestamp_granularities": ["word", "segment"],
} }
if self.task != "translate" and self.original_language: if not self.direct_english_translation and self.original_language:
params["language"] = self.original_language params["language"] = self.original_language
if prompt: if prompt:
params["prompt"] = prompt params["prompt"] = prompt
@@ -285,6 +298,3 @@ class OpenaiApiASR(ASRBase):
def use_vad(self): def use_vad(self):
self.use_vad_opt = True self.use_vad_opt = True
def set_translate_task(self):
self.task = "translate"

View File

@@ -106,9 +106,6 @@ class OnlineASRProcessor:
def __init__( def __init__(
self, self,
asr, asr,
tokenize_method: Optional[callable] = None,
buffer_trimming: Tuple[str, float] = ("segment", 15),
confidence_validation = False,
logfile=sys.stderr, logfile=sys.stderr,
): ):
""" """
@@ -119,13 +116,14 @@ class OnlineASRProcessor:
buffer_trimming: A tuple (option, seconds), where option is either "sentence" or "segment". buffer_trimming: A tuple (option, seconds), where option is either "sentence" or "segment".
""" """
self.asr = asr self.asr = asr
self.tokenize = tokenize_method self.tokenize = asr.tokenizer
self.logfile = logfile self.logfile = logfile
self.confidence_validation = confidence_validation self.confidence_validation = asr.confidence_validation
self.global_time_offset = 0.0 self.global_time_offset = 0.0
self.init() self.init()
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming self.buffer_trimming_way = asr.buffer_trimming
self.buffer_trimming_sec = asr.buffer_trimming_sec
if self.buffer_trimming_way not in ["sentence", "segment"]: if self.buffer_trimming_way not in ["sentence", "segment"]:
raise ValueError("buffer_trimming must be either 'sentence' or 'segment'") raise ValueError("buffer_trimming must be either 'sentence' or 'segment'")
@@ -153,21 +151,32 @@ class OnlineASRProcessor:
"""Append an audio chunk (a numpy array) to the current audio buffer.""" """Append an audio chunk (a numpy array) to the current audio buffer."""
self.audio_buffer = np.append(self.audio_buffer, audio) self.audio_buffer = np.append(self.audio_buffer, audio)
def insert_silence(self, silence_duration, offset): def start_silence(self):
""" if self.audio_buffer.size == 0:
If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame return [], self.get_audio_buffer_end_time()
""" return self.process_iter()
# if self.transcript_buffer.buffer:
# self.committed.extend(self.transcript_buffer.buffer)
# self.transcript_buffer.buffer = []
if True: #silence_duration < 3: #we want the last audio to be treated to not have a gap. could also be handled in the future in ends_with_silence. def end_silence(self, silence_duration: Optional[float], offset: float):
gap_silence = np.zeros(int(16000 * silence_duration), dtype=np.int16) if not silence_duration or silence_duration <= 0:
self.insert_audio_chunk(gap_silence) return
long_silence = silence_duration >= 5
if not long_silence:
gap_samples = int(self.SAMPLING_RATE * silence_duration)
if gap_samples > 0:
gap_silence = np.zeros(gap_samples, dtype=np.float32)
self.insert_audio_chunk(gap_silence)
else: else:
self.init(offset=silence_duration + offset) self.init(offset=silence_duration + offset)
self.global_time_offset += silence_duration self.global_time_offset += silence_duration
def insert_silence(self, silence_duration, offset):
"""
Backwards compatibility shim for legacy callers that still use insert_silence.
"""
self.end_silence(silence_duration, offset)
def prompt(self) -> Tuple[str, str]: def prompt(self) -> Tuple[str, str]:
""" """
Returns a tuple: (prompt, context), where: Returns a tuple: (prompt, context), where:
@@ -402,11 +411,11 @@ class OnlineASRProcessor:
) -> Transcript: ) -> Transcript:
sep = sep if sep is not None else self.asr.sep sep = sep if sep is not None else self.asr.sep
text = sep.join(token.text for token in tokens) text = sep.join(token.text for token in tokens)
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None # probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
if tokens: if tokens:
start = offset + tokens[0].start start = offset + tokens[0].start
end = offset + tokens[-1].end end = offset + tokens[-1].end
else: else:
start = None start = None
end = None end = None
return Transcript(start, end, text, probability=probability) return Transcript(start, end, text)

View File

@@ -0,0 +1,199 @@
#!/usr/bin/env python3
import sys
import numpy as np
import librosa
from functools import lru_cache
import time
import logging
import platform
from .backends import FasterWhisperASR, MLXWhisper, WhisperASR, OpenaiApiASR
from whisperlivekit.warmup import warmup_asr
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
from whisperlivekit.backend_support import (
mlx_backend_available,
faster_backend_available,
)
logger = logging.getLogger(__name__)
WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
","
)
def create_tokenizer(lan):
"""returns an object that has split function that works like the one of MosesTokenizer"""
assert (
lan in WHISPER_LANG_CODES
), "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
if lan == "uk":
import tokenize_uk
class UkrainianTokenizer:
def split(self, text):
return tokenize_uk.tokenize_sents(text)
return UkrainianTokenizer()
# supported by fast-mosestokenizer
if (
lan
in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
):
from mosestokenizer import MosesSentenceSplitter
return MosesSentenceSplitter(lan)
# the following languages are in Whisper, but not in wtpsplit:
if (
lan
in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
):
logger.debug(
f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
)
lan = None
from wtpsplit import WtP
# downloads the model from huggingface on the first use
wtp = WtP("wtp-canine-s-12l-no-adapters")
class WtPtok:
def split(self, sent):
return wtp.split(sent, lang_code=lan)
return WtPtok()
def backend_factory(
backend,
lan,
model_size,
model_cache_dir,
model_dir,
model_path,
direct_english_translation,
buffer_trimming,
buffer_trimming_sec,
confidence_validation,
warmup_file=None,
min_chunk_size=None,
):
backend_choice = backend
custom_reference = model_path or model_dir
resolved_root = None
pytorch_checkpoint = None
has_mlx_weights = False
has_fw_weights = False
if custom_reference:
resolved_root = resolve_model_path(custom_reference)
if resolved_root.is_dir():
pytorch_checkpoint, has_mlx_weights, has_fw_weights = model_path_and_type(resolved_root)
else:
pytorch_checkpoint = resolved_root
if backend_choice == "openai-api":
logger.debug("Using OpenAI API.")
asr = OpenaiApiASR(lan=lan)
else:
backend_choice = _normalize_backend_choice(
backend_choice,
resolved_root,
has_mlx_weights,
has_fw_weights,
)
if backend_choice == "faster-whisper":
asr_cls = FasterWhisperASR
if resolved_root is not None and not resolved_root.is_dir():
raise ValueError("Faster-Whisper backend expects a directory with CTranslate2 weights.")
model_override = str(resolved_root) if resolved_root is not None else None
elif backend_choice == "mlx-whisper":
asr_cls = MLXWhisper
if resolved_root is not None and not resolved_root.is_dir():
raise ValueError("MLX Whisper backend expects a directory containing MLX weights.")
model_override = str(resolved_root) if resolved_root is not None else None
else:
asr_cls = WhisperASR
model_override = str(pytorch_checkpoint) if pytorch_checkpoint is not None else None
if custom_reference and model_override is None:
raise FileNotFoundError(
f"No PyTorch checkpoint found under {resolved_root or custom_reference}"
)
t = time.time()
logger.info(f"Loading Whisper {model_size} model for language {lan} using backend {backend_choice}...")
asr = asr_cls(
model_size=model_size,
lan=lan,
cache_dir=model_cache_dir,
model_dir=model_override,
)
e = time.time()
logger.info(f"done. It took {round(e-t,2)} seconds.")
if direct_english_translation:
tgt_language = "en" # Whisper translates into English
else:
tgt_language = lan # Whisper transcribes in this language
# Create the tokenizer
if buffer_trimming == "sentence":
tokenizer = create_tokenizer(tgt_language)
else:
tokenizer = None
warmup_asr(asr, warmup_file)
asr.confidence_validation = confidence_validation
asr.tokenizer = tokenizer
asr.buffer_trimming = buffer_trimming
asr.buffer_trimming_sec = buffer_trimming_sec
asr.backend_choice = backend_choice
return asr
def _normalize_backend_choice(
preferred_backend,
resolved_root,
has_mlx_weights,
has_fw_weights,
):
backend_choice = preferred_backend
if backend_choice == "auto":
if mlx_backend_available(warn_on_missing=True) and (resolved_root is None or has_mlx_weights):
return "mlx-whisper"
if faster_backend_available(warn_on_missing=True) and (resolved_root is None or has_fw_weights):
return "faster-whisper"
return "whisper"
if backend_choice == "mlx-whisper":
if not mlx_backend_available():
raise RuntimeError("mlx-whisper backend requested but mlx-whisper is not installed.")
if resolved_root is not None and not has_mlx_weights:
raise FileNotFoundError(
f"mlx-whisper backend requested but no MLX weights were found under {resolved_root}"
)
if platform.system() != "Darwin":
logger.warning("mlx-whisper backend requested on a non-macOS system; this may fail.")
return backend_choice
if backend_choice == "faster-whisper":
if not faster_backend_available():
raise RuntimeError("faster-whisper backend requested but faster-whisper is not installed.")
if resolved_root is not None and not has_fw_weights:
raise FileNotFoundError(
f"faster-whisper backend requested but no Faster-Whisper weights were found under {resolved_root}"
)
return backend_choice
if backend_choice == "whisper":
return backend_choice
raise ValueError(f"Unknown backend '{preferred_backend}' for LocalAgreement.")

View File

@@ -0,0 +1,69 @@
from pathlib import Path
from typing import Optional, Tuple, Union
def model_path_and_type(model_path: Union[str, Path]) -> Tuple[Optional[Path], bool, bool]:
"""
Inspect the provided path and determine which model formats are available.
Returns:
pytorch_path: Path to a PyTorch checkpoint (if present).
compatible_whisper_mlx: True if MLX weights exist in this folder.
compatible_faster_whisper: True if Faster-Whisper (ctranslate2) weights exist.
"""
path = Path(model_path)
compatible_whisper_mlx = False
compatible_faster_whisper = False
pytorch_path: Optional[Path] = None
if path.is_file() and path.suffix.lower() in [".pt", ".safetensors", ".bin"]:
pytorch_path = path
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
if path.is_dir():
for file in path.iterdir():
if not file.is_file():
continue
filename = file.name.lower()
suffix = file.suffix.lower()
if filename in {"weights.npz", "weights.safetensors"}:
compatible_whisper_mlx = True
elif filename in {"model.bin", "encoder.bin", "decoder.bin"}:
compatible_faster_whisper = True
elif suffix in {".pt", ".safetensors"}:
pytorch_path = file
elif filename == "pytorch_model.bin":
pytorch_path = file
if pytorch_path is None:
fallback = path / "pytorch_model.bin"
if fallback.exists():
pytorch_path = fallback
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
def resolve_model_path(model_path: Union[str, Path]) -> Path:
"""
Return a local path for the provided model reference.
If the path does not exist locally, it is treated as a Hugging Face repo id
and downloaded via snapshot_download.
"""
path = Path(model_path).expanduser()
if path.exists():
return path
try:
from huggingface_hub import snapshot_download
except ImportError as exc: # pragma: no cover - optional dependency guard
raise FileNotFoundError(
f"Model path '{model_path}' does not exist locally and huggingface_hub "
"is not installed to download it."
) from exc
downloaded_path = Path(snapshot_download(repo_id=str(model_path)))
return downloaded_path

View File

@@ -81,14 +81,15 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--min-chunk-size", "--min-chunk-size",
type=float, type=float,
default=0.5, default=0.1,
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.", help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
) )
parser.add_argument( parser.add_argument(
"--model", "--model",
type=str, type=str,
default="small", default="base",
dest='model_size',
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.", help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
) )
@@ -109,14 +110,14 @@ def parse_args():
"--language", "--language",
type=str, type=str,
default="auto", default="auto",
dest='lan',
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.", help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
) )
parser.add_argument( parser.add_argument(
"--task", "--direct-english-translation",
type=str, action="store_true",
default="transcribe", default=False,
choices=["transcribe", "translate"], help="Use Whisper to directly translate to english.",
help="Transcribe or translate.",
) )
parser.add_argument( parser.add_argument(
@@ -128,11 +129,18 @@ def parse_args():
) )
parser.add_argument( parser.add_argument(
"--backend", "--backend-policy",
type=str, type=str,
default="simulstreaming", default="simulstreaming",
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api", "simulstreaming"], choices=["1", "2", "simulstreaming", "localagreement"],
help="Load only this backend for Whisper processing.", help="Select the streaming policy: 1 or 'simulstreaming' for AlignAtt, 2 or 'localagreement' for LocalAgreement.",
)
parser.add_argument(
"--backend",
type=str,
default="auto",
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api"],
help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API.",
) )
parser.add_argument( parser.add_argument(
"--no-vac", "--no-vac",
@@ -173,6 +181,7 @@ def parse_args():
) )
parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None) parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None) parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
parser.add_argument("--forwarded-allow-ips", type=str, help="Allowed ips for reverse proxying.", default=None)
parser.add_argument( parser.add_argument(
"--pcm-input", "--pcm-input",
action="store_true", action="store_true",
@@ -190,6 +199,13 @@ def parse_args():
help="Disable Faster Whisper or MLX Whisper backends for encoding (if installed). Slower but helpful when GPU memory is limited", help="Disable Faster Whisper or MLX Whisper backends for encoding (if installed). Slower but helpful when GPU memory is limited",
) )
simulstreaming_group.add_argument(
"--custom-alignment-heads",
type=str,
default=None,
help="Use your own alignment heads, useful when `--model-dir` is used",
)
simulstreaming_group.add_argument( simulstreaming_group.add_argument(
"--frame-threshold", "--frame-threshold",
type=int, type=int,
@@ -290,7 +306,7 @@ def parse_args():
simulstreaming_group.add_argument( simulstreaming_group.add_argument(
"--nllb-backend", "--nllb-backend",
type=str, type=str,
default="ctranslate2", default="transformers",
help="transformers or ctranslate2", help="transformers or ctranslate2",
) )
@@ -308,4 +324,9 @@ def parse_args():
delattr(args, 'no_transcription') delattr(args, 'no_transcription')
delattr(args, 'no_vad') delattr(args, 'no_vad')
if args.backend_policy == "1":
args.backend_policy = "simulstreaming"
elif args.backend_policy == "2":
args.backend_policy = "localagreement"
return args return args

View File

@@ -1,110 +0,0 @@
from whisperlivekit.timed_objects import ASRToken
import re
MIN_SILENCE_DURATION = 4 #in seconds
END_SILENCE_DURATION = 8 #in seconds. you should keep it important to not have false positive when the model lag is important
END_SILENCE_DURATION_VAC = 3 #VAC is good at detecting silences, but we want to skip the smallest silences
def blank_to_silence(tokens):
full_string = ''.join([t.text for t in tokens])
patterns = [re.compile(r'(?:\s*\[BLANK_AUDIO\]\s*)+'), re.compile(r'(?:\s*\[typing\]\s*)+')]
matches = []
for pattern in patterns:
for m in pattern.finditer(full_string):
matches.append({
'start': m.start(),
'end': m.end()
})
if matches:
# cleaned = pattern.sub(' ', full_string).strip()
# print("Cleaned:", cleaned)
cumulated_len = 0
silence_token = None
cleaned_tokens = []
for token in tokens:
if matches:
start = cumulated_len
end = cumulated_len + len(token.text)
cumulated_len = end
if start >= matches[0]['start'] and end <= matches[0]['end']:
if silence_token: #previous token was already silence
silence_token.start = min(silence_token.start, token.start)
silence_token.end = max(silence_token.end, token.end)
else: #new silence
silence_token = ASRToken(
start=token.start,
end=token.end,
speaker=-2,
probability=0.95
)
else:
if silence_token: #there was silence but no more
if silence_token.duration() >= MIN_SILENCE_DURATION:
cleaned_tokens.append(
silence_token
)
silence_token = None
matches.pop(0)
cleaned_tokens.append(token)
# print(cleaned_tokens)
return cleaned_tokens
return tokens
def no_token_to_silence(tokens):
new_tokens = []
silence_token = None
for token in tokens:
if token.speaker == -2:
if new_tokens and new_tokens[-1].speaker == -2: #if token is silence and previous one too
new_tokens[-1].end = token.end
else:
new_tokens.append(token)
last_end = new_tokens[-1].end if new_tokens else 0.0
if token.start - last_end >= MIN_SILENCE_DURATION: #if token is not silence but important gap
if new_tokens and new_tokens[-1].speaker == -2:
new_tokens[-1].end = token.start
else:
silence_token = ASRToken(
start=last_end,
end=token.start,
speaker=-2,
probability=0.95
)
new_tokens.append(silence_token)
if token.speaker != -2:
new_tokens.append(token)
return new_tokens
def ends_with_silence(tokens, current_time, vac_detected_silence):
end_w_silence = False
if not tokens:
return [], end_w_silence
last_token = tokens[-1]
if tokens and current_time and (
current_time - last_token.end >= END_SILENCE_DURATION
or
(current_time - last_token.end >= 3 and vac_detected_silence)
):
end_w_silence = True
if last_token.speaker == -2:
last_token.end = current_time
else:
tokens.append(
ASRToken(
start=tokens[-1].end,
end=current_time,
speaker=-2,
probability=0.95
)
)
return tokens, end_w_silence
def handle_silences(tokens, current_time, vac_detected_silence):
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
tokens = no_token_to_silence(tokens)
tokens, end_w_silence = ends_with_silence(tokens, current_time, vac_detected_silence)
return tokens, end_w_silence

View File

@@ -1,157 +0,0 @@
import logging
from whisperlivekit.remove_silences import handle_silences
from whisperlivekit.timed_objects import Line, format_time
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
CHECK_AROUND = 4
def is_punctuation(token):
if token.is_punctuation():
return True
return False
def next_punctuation_change(i, tokens):
for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)):
if is_punctuation(tokens[ind]):
return ind
return None
def next_speaker_change(i, tokens, speaker):
for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1):
token = tokens[ind]
if is_punctuation(token):
break
if token.speaker != speaker:
return ind, token.speaker
return None, speaker
def new_line(
token,
speaker,
debug_info = ""
):
return Line(
speaker = speaker,
text = token.text + debug_info,
start = token.start,
end = token.end,
)
def append_token_to_last_line(lines, sep, token, debug_info):
if token.text:
lines[-1].text += sep + token.text + debug_info
lines[-1].end = token.end
def format_output(state, silence, current_time, args, debug, sep):
diarization = args.diarization
disable_punctuation_split = args.disable_punctuation_split
tokens = state.tokens
translated_segments = state.translated_segments # Here we will attribute the speakers only based on the timestamps of the segments
end_attributed_speaker = state.end_attributed_speaker
previous_speaker = -1
lines = []
undiarized_text = []
tokens, end_w_silence = handle_silences(tokens, current_time, silence)
last_punctuation = None
for i, token in enumerate(tokens):
speaker = token.speaker
if not diarization and speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
speaker = 1
if diarization and not tokens[-1].speaker == -2:
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
undiarized_text.append(token.text)
continue
elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
speaker = previous_speaker
debug_info = ""
if debug:
debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
if not lines:
lines.append(new_line(token, speaker, debug_info = ""))
continue
else:
previous_speaker = lines[-1].speaker
if is_punctuation(token):
last_punctuation = i
if last_punctuation == i-1:
if speaker != previous_speaker:
# perfect, diarization perfectly aligned
lines.append(new_line(token, speaker, debug_info = ""))
last_punctuation, next_punctuation = None, None
continue
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
if speaker_change_pos:
# Corrects delay:
# That was the idea. Okay haha |SPLIT SPEAKER| that's a good one
# should become:
# That was the idea. |SPLIT SPEAKER| Okay haha that's a good one
lines.append(new_line(token, new_speaker, debug_info = ""))
else:
# No speaker change to come
append_token_to_last_line(lines, sep, token, debug_info)
continue
if speaker != previous_speaker:
if speaker == -2 or previous_speaker == -2: #silences can happen anytime
lines.append(new_line(token, speaker, debug_info = ""))
continue
elif next_punctuation_change(i, tokens):
# Corrects advance:
# Are you |SPLIT SPEAKER| okay? yeah, sure. Absolutely
# should become:
# Are you okay? |SPLIT SPEAKER| yeah, sure. Absolutely
append_token_to_last_line(lines, sep, token, debug_info)
continue
else: #we create a new speaker, but that's no ideal. We are not sure about the split. We prefer to append to previous line
if disable_punctuation_split:
lines.append(new_line(token, speaker, debug_info = ""))
continue
pass
append_token_to_last_line(lines, sep, token, debug_info)
if lines and translated_segments:
unassigned_translated_segments = []
for ts in translated_segments:
assigned = False
for line in lines:
if ts and ts.overlaps_with(line):
if ts.is_within(line):
line.translation += ts.text + ' '
assigned = True
break
else:
ts0, ts1 = ts.approximate_cut_at(line.end)
if ts0 and line.overlaps_with(ts0):
line.translation += ts0.text + ' '
if ts1:
unassigned_translated_segments.append(ts1)
assigned = True
break
if not assigned:
unassigned_translated_segments.append(ts)
if unassigned_translated_segments:
for line in lines:
remaining_segments = []
for ts in unassigned_translated_segments:
if ts and ts.overlaps_with(line):
line.translation += ts.text + ' '
else:
remaining_segments.append(ts)
unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
if state.buffer_transcription and lines:
lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
return lines, undiarized_text, end_w_silence

View File

@@ -1,27 +1,182 @@
import torch import torch
import numpy as np
import warnings
from pathlib import Path
# This is copied from silero-vad's vad_utils.py: """
# https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/utils_vad.py#L340 Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
# (except changed defaults) """
# Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE def init_jit_model(model_path: str, device=torch.device('cpu')):
"""Load a JIT model from file."""
model = torch.jit.load(model_path, map_location=device)
model.eval()
return model
class OnnxWrapper():
"""ONNX Runtime wrapper for Silero VAD model."""
def __init__(self, path, force_onnx_cpu=False):
global np
import numpy as np
import onnxruntime
opts = onnxruntime.SessionOptions()
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 1
if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
else:
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
self.reset_states()
if '16k' in path:
warnings.warn('This model support only 16000 sampling rate!')
self.sample_rates = [16000]
else:
self.sample_rates = [8000, 16000]
def _validate_input(self, x, sr: int):
if x.dim() == 1:
x = x.unsqueeze(0)
if x.dim() > 2:
raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
if sr != 16000 and (sr % 16000 == 0):
step = sr // 16000
x = x[:,::step]
sr = 16000
if sr not in self.sample_rates:
raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
if sr / x.shape[1] > 31.25:
raise ValueError("Input audio chunk is too short")
return x, sr
def reset_states(self, batch_size=1):
self._state = torch.zeros((2, batch_size, 128)).float()
self._context = torch.zeros(0)
self._last_sr = 0
self._last_batch_size = 0
def __call__(self, x, sr: int):
x, sr = self._validate_input(x, sr)
num_samples = 512 if sr == 16000 else 256
if x.shape[-1] != num_samples:
raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
batch_size = x.shape[0]
context_size = 64 if sr == 16000 else 32
if not self._last_batch_size:
self.reset_states(batch_size)
if (self._last_sr) and (self._last_sr != sr):
self.reset_states(batch_size)
if (self._last_batch_size) and (self._last_batch_size != batch_size):
self.reset_states(batch_size)
if not len(self._context):
self._context = torch.zeros(batch_size, context_size)
x = torch.cat([self._context, x], dim=1)
if sr in [8000, 16000]:
ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
ort_outs = self.session.run(None, ort_inputs)
out, state = ort_outs
self._state = torch.from_numpy(state)
else:
raise ValueError()
self._context = x[..., -context_size:]
self._last_sr = sr
self._last_batch_size = batch_size
out = torch.from_numpy(out)
return out
def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: int = 16):
"""
Load Silero VAD model (JIT or ONNX).
Parameters
----------
model_path : str, optional
Path to model file. If None, uses default bundled model.
onnx : bool, default False
Whether to use ONNX runtime (requires onnxruntime package).
opset_version : int, default 16
ONNX opset version (15 or 16). Only used if onnx=True.
Returns
-------
model
Loaded VAD model (JIT or ONNX wrapper)
"""
available_ops = [15, 16]
if onnx and opset_version not in available_ops:
raise Exception(f'Available ONNX opset_version: {available_ops}')
if model_path is None:
current_dir = Path(__file__).parent
data_dir = current_dir / 'vad_models'
if onnx:
if opset_version == 16:
model_name = 'silero_vad.onnx'
else:
model_name = f'silero_vad_16k_op{opset_version}.onnx'
else:
model_name = 'silero_vad.jit'
model_path = data_dir / model_name
if not model_path.exists():
raise FileNotFoundError(
f"Model file not found: {model_path}\n"
f"Please ensure the whisperlivekit/vad_models/ directory contains the model files."
)
else:
model_path = Path(model_path)
if onnx:
try:
model = OnnxWrapper(str(model_path), force_onnx_cpu=True)
except ImportError:
raise ImportError(
"ONNX runtime not available. Install with: pip install onnxruntime\n"
"Or use JIT model by setting onnx=False"
)
else:
model = init_jit_model(str(model_path))
return model
class VADIterator: class VADIterator:
def __init__( """
self, Voice Activity Detection iterator for streaming audio.
model,
threshold: float = 0.5, This is the Silero VAD v6 implementation.
sampling_rate: int = 16000, """
min_silence_duration_ms: int = 500, # makes sense on one recording that I checked
speech_pad_ms: int = 100, # same def __init__(self,
): model,
threshold: float = 0.5,
sampling_rate: int = 16000,
min_silence_duration_ms: int = 100,
speech_pad_ms: int = 30
):
""" """
Class for stream imitation Class for stream imitation
Parameters Parameters
---------- ----------
model: preloaded .jit silero VAD model model: preloaded .jit/.onnx silero VAD model
threshold: float (default - 0.5) threshold: float (default - 0.5)
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH. Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
@@ -42,9 +197,7 @@ class VADIterator:
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
if sampling_rate not in [8000, 16000]: if sampling_rate not in [8000, 16000]:
raise ValueError( raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
"VADIterator does not support sampling rates other than [8000, 16000]"
)
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000 self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000 self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
@@ -57,13 +210,17 @@ class VADIterator:
self.temp_end = 0 self.temp_end = 0
self.current_sample = 0 self.current_sample = 0
def __call__(self, x, return_seconds=False): @torch.no_grad()
def __call__(self, x, return_seconds=False, time_resolution: int = 1):
""" """
x: torch.Tensor x: torch.Tensor
audio chunk (see examples in repo) audio chunk (see examples in repo)
return_seconds: bool (default - False) return_seconds: bool (default - False)
whether return timestamps in seconds (default - samples) whether return timestamps in seconds (default - samples)
time_resolution: int (default - 1)
time resolution of speech coordinates when requested as seconds
""" """
if not torch.is_tensor(x): if not torch.is_tensor(x):
@@ -82,14 +239,8 @@ class VADIterator:
if (speech_prob >= self.threshold) and not self.triggered: if (speech_prob >= self.threshold) and not self.triggered:
self.triggered = True self.triggered = True
speech_start = self.current_sample - self.speech_pad_samples speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
return { return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)}
"start": (
int(speech_start)
if not return_seconds
else round(speech_start / self.sampling_rate, 1)
)
}
if (speech_prob < self.threshold - 0.15) and self.triggered: if (speech_prob < self.threshold - 0.15) and self.triggered:
if not self.temp_end: if not self.temp_end:
@@ -97,30 +248,17 @@ class VADIterator:
if self.current_sample - self.temp_end < self.min_silence_samples: if self.current_sample - self.temp_end < self.min_silence_samples:
return None return None
else: else:
speech_end = self.temp_end + self.speech_pad_samples speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
self.temp_end = 0 self.temp_end = 0
self.triggered = False self.triggered = False
return { return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)}
"end": (
int(speech_end)
if not return_seconds
else round(speech_end / self.sampling_rate, 1)
)
}
return None return None
#######################
# because Silero now requires exactly 512-sized audio chunks
import numpy as np
class FixedVADIterator(VADIterator): class FixedVADIterator(VADIterator):
"""It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once. """
If audio to be processed at once is long and multiple voiced segments detected, Fixed VAD Iterator that handles variable-length audio chunks, not only exactly 512 frames at once.
then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment.
""" """
def reset_states(self): def reset_states(self):
@@ -137,27 +275,20 @@ class FixedVADIterator(VADIterator):
ret = r ret = r
elif r is not None: elif r is not None:
if "end" in r: if "end" in r:
ret["end"] = r["end"] # the latter end ret["end"] = r["end"]
if "start" in r and "end" in ret: # there is an earlier start. if "start" in r and "end" in ret:
# Remove end, merging this segment with the previous one.
del ret["end"] del ret["end"]
return ret if ret != {} else None return ret if ret != {} else None
if __name__ == "__main__": if __name__ == "__main__":
# test/demonstrate the need for FixedVADIterator: model = load_silero_vad(onnx=False)
vad = FixedVADIterator(model)
import torch audio_buffer = np.array([0] * 512, dtype=np.float32)
result = vad(audio_buffer)
print(f" 512 samples: {result}")
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad") # test with 511 samples
vac = FixedVADIterator(model) audio_buffer = np.array([0] * 511, dtype=np.float32)
# vac = VADIterator(model) # the second case crashes with this result = vad(audio_buffer)
# this works: for both
audio_buffer = np.array([0] * (512), dtype=np.float32)
vac(audio_buffer)
# this crashes on the non FixedVADIterator with
# ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
audio_buffer = np.array([0] * (512 - 1), dtype=np.float32)
vac(audio_buffer)

View File

@@ -2,43 +2,39 @@ import sys
import numpy as np import numpy as np
import logging import logging
from typing import List, Tuple, Optional from typing import List, Tuple, Optional
import logging
import platform import platform
from whisperlivekit.timed_objects import ASRToken, Transcript, SpeakerSegment from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
from whisperlivekit.warmup import load_file from whisperlivekit.warmup import load_file
from .whisper import load_model, tokenizer from whisperlivekit.whisper import load_model, tokenizer
from .whisper.audio import TOKENS_PER_SECOND from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
import os import os
import gc import gc
logger = logging.getLogger(__name__) from pathlib import Path
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
from whisperlivekit.backend_support import (
mlx_backend_available,
faster_backend_available,
)
import torch import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
from whisperlivekit.simul_whisper.whisper import tokenizer
try: logger = logging.getLogger(__name__)
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
HAS_MLX_WHISPER = True
except ImportError: HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
if platform.system() == "Darwin" and platform.machine() == "arm64":
print(f"""
{"="*50}
MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper
{"="*50}
""")
HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER: if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
else: else:
try: mlx_model_mapping = {}
from faster_whisper import WhisperModel HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
HAS_FASTER_WHISPER = True if HAS_FASTER_WHISPER:
except ImportError: from faster_whisper import WhisperModel
HAS_FASTER_WHISPER = False else:
WhisperModel = None
MIN_DURATION_REAL_SILENCE = 5
# TOO_MANY_REPETITIONS = 3
class SimulStreamingOnlineProcessor: class SimulStreamingOnlineProcessor:
SAMPLING_RATE = 16000 SAMPLING_RATE = 16000
@@ -47,7 +43,6 @@ class SimulStreamingOnlineProcessor:
self, self,
asr, asr,
logfile=sys.stderr, logfile=sys.stderr,
warmup_file=None
): ):
self.asr = asr self.asr = asr
self.logfile = logfile self.logfile = logfile
@@ -63,23 +58,29 @@ class SimulStreamingOnlineProcessor:
def load_new_backend(self): def load_new_backend(self):
model = self.asr.get_new_model_instance() model = self.asr.get_new_model_instance()
self.model = PaddedAlignAttWhisper( self.model = AlignAtt(
cfg=self.asr.cfg, cfg=self.asr.cfg,
loaded_model=model, loaded_model=model,
mlx_encoder=self.asr.mlx_encoder, mlx_encoder=self.asr.mlx_encoder,
fw_encoder=self.asr.fw_encoder, fw_encoder=self.asr.fw_encoder,
) )
def insert_silence(self, silence_duration, offset): def start_silence(self):
tokens, processed_upto = self.process_iter(is_last=True)
return tokens, processed_upto
def end_silence(self, silence_duration, offset):
""" """
If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame If silences are > MIN_DURATION_REAL_SILENCE, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame
""" """
if silence_duration < 5: self.end += silence_duration
gap_silence = torch.zeros(int(16000*silence_duration)) long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
self.model.insert_audio(gap_silence) if not long_silence:
# self.global_time_offset += silence_duration gap_len = int(16000 * silence_duration)
else: if gap_len > 0:
self.process_iter(is_last=True) #we want to totally process what remains in the buffer. gap_silence = torch.zeros(gap_len)
self.model.insert_audio(gap_silence)
if long_silence:
self.model.refresh_segment(complete=True) self.model.refresh_segment(complete=True)
self.model.global_time_offset = silence_duration + offset self.model.global_time_offset = silence_duration + offset
@@ -93,9 +94,11 @@ class SimulStreamingOnlineProcessor:
self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend. self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend.
self.model.insert_audio(audio_tensor) self.model.insert_audio(audio_tensor)
def on_new_speaker(self, last_segment: SpeakerSegment): def new_speaker(self, change_speaker: ChangeSpeaker):
self.model.on_new_speaker(last_segment) self.process_iter(is_last=True)
self.model.refresh_segment(complete=True) self.model.refresh_segment(complete=True)
self.model.speaker = change_speaker.speaker
self.global_time_offset = change_speaker.start
def get_buffer(self): def get_buffer(self):
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='') concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
@@ -108,9 +111,13 @@ class SimulStreamingOnlineProcessor:
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time). Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
""" """
try: try:
timestamped_words, timestamped_buffer_language = self.model.infer(is_last=is_last) timestamped_words = self.model.infer(is_last=is_last)
self.buffer = timestamped_buffer_language if self.model.cfg.language == "auto" and timestamped_words and timestamped_words[0].detected_language == None:
self.buffer.extend(timestamped_words)
return [], self.end
self.committed.extend(timestamped_words) self.committed.extend(timestamped_words)
self.buffer = []
return timestamped_words, self.end return timestamped_words, self.end
@@ -140,31 +147,34 @@ class SimulStreamingASR():
"""SimulStreaming backend with AlignAtt policy.""" """SimulStreaming backend with AlignAtt policy."""
sep = "" sep = ""
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs): def __init__(self, logfile=sys.stderr, **kwargs):
self.logfile = logfile self.logfile = logfile
self.transcribe_kargs = {} self.transcribe_kargs = {}
self.original_language = lan
self.model_path = kwargs.get('model_path', './large-v3.pt') for key, value in kwargs.items():
self.frame_threshold = kwargs.get('frame_threshold', 25) setattr(self, key, value)
self.audio_max_len = kwargs.get('audio_max_len', 20.0)
self.audio_min_len = kwargs.get('audio_min_len', 0.0) if self.decoder_type is None:
self.segment_length = kwargs.get('segment_length', 0.5) self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
self.beams = kwargs.get('beams', 1)
self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
self.task = kwargs.get('task', 'transcribe')
self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
self.never_fire = kwargs.get('never_fire', False)
self.init_prompt = kwargs.get('init_prompt', None)
self.static_init_prompt = kwargs.get('static_init_prompt', None)
self.max_context_tokens = kwargs.get('max_context_tokens', None)
self.warmup_file = kwargs.get('warmup_file', None)
self.preload_model_count = kwargs.get('preload_model_count', 1)
self.disable_fast_encoder = kwargs.get('disable_fast_encoder', False)
self.fast_encoder = False self.fast_encoder = False
if model_dir is not None: self._resolved_model_path = None
self.model_path = model_dir self.encoder_backend = "whisper"
elif modelsize is not None: preferred_backend = getattr(self, "backend", "auto")
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
if self.model_path:
resolved_model_path = resolve_model_path(self.model_path)
self._resolved_model_path = resolved_model_path
self.model_path = str(resolved_model_path)
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(resolved_model_path)
if self.pytorch_path:
self.model_name = self.pytorch_path.stem
else:
self.model_name = Path(self.model_path).stem
raise FileNotFoundError(
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
)
elif self.model_size is not None:
model_mapping = { model_mapping = {
'tiny': './tiny.pt', 'tiny': './tiny.pt',
'base': './base.pt', 'base': './base.pt',
@@ -179,19 +189,32 @@ class SimulStreamingASR():
'large-v3': './large-v3.pt', 'large-v3': './large-v3.pt',
'large': './large-v3.pt' 'large': './large-v3.pt'
} }
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt') self.model_name = self.model_size
else:
raise ValueError("Either model_size or model_path must be specified for SimulStreaming.")
is_multilingual = not self.model_name.endswith(".en")
self.encoder_backend = self._resolve_encoder_backend(
preferred_backend,
compatible_whisper_mlx,
compatible_faster_whisper,
)
self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper")
if self.encoder_backend == "whisper":
self.disable_fast_encoder = True
self.cfg = AlignAttConfig( self.cfg = AlignAttConfig(
model_path=self.model_path, tokenizer_is_multilingual= is_multilingual,
segment_length=self.segment_length, segment_length=self.min_chunk_size,
frame_threshold=self.frame_threshold, frame_threshold=self.frame_threshold,
language=self.original_language, language=self.lan,
audio_max_len=self.audio_max_len, audio_max_len=self.audio_max_len,
audio_min_len=self.audio_min_len, audio_min_len=self.audio_min_len,
cif_ckpt_path=self.cif_ckpt_path, cif_ckpt_path=self.cif_ckpt_path,
decoder_type="beam", decoder_type="beam",
beam_size=self.beams, beam_size=self.beams,
task=self.task, task=self.direct_english_translation,
never_fire=self.never_fire, never_fire=self.never_fire,
init_prompt=self.init_prompt, init_prompt=self.init_prompt,
max_context_tokens=self.max_context_tokens, max_context_tokens=self.max_context_tokens,
@@ -199,40 +222,93 @@ class SimulStreamingASR():
) )
# Set up tokenizer for translation if needed # Set up tokenizer for translation if needed
if self.task == "translate": if self.direct_english_translation:
self.tokenizer = self.set_translate_task() self.tokenizer = self.set_translate_task()
else: else:
self.tokenizer = None self.tokenizer = None
self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
self.mlx_encoder, self.fw_encoder = None, None self.mlx_encoder, self.fw_encoder = None, None
if not self.disable_fast_encoder: if self.encoder_backend == "mlx-whisper":
if HAS_MLX_WHISPER: print('Simulstreaming will use MLX whisper to increase encoding speed.')
print('Simulstreaming will use MLX whisper for a faster encoder.') if self._resolved_model_path is not None:
mlx_model_name = mlx_model_mapping[self.model_name] mlx_model = str(self._resolved_model_path)
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name) else:
self.fast_encoder = True mlx_model = mlx_model_mapping.get(self.model_name)
elif HAS_FASTER_WHISPER: if not mlx_model:
print('Simulstreaming will use Faster Whisper for the encoder.') raise FileNotFoundError(
self.fw_encoder = WhisperModel( f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
self.model_name,
device='auto',
compute_type='auto',
) )
self.fast_encoder = True self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
elif self.encoder_backend == "faster-whisper":
print('Simulstreaming will use Faster Whisper for the encoder.')
if self._resolved_model_path is not None:
fw_model = str(self._resolved_model_path)
else:
fw_model = self.model_name
self.fw_encoder = WhisperModel(
fw_model,
device='auto',
compute_type='auto',
)
self.models = [self.load_model() for i in range(self.preload_model_count)] self.models = [self.load_model() for i in range(self.preload_model_count)]
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
choice = preferred_backend or "auto"
if self.disable_fast_encoder:
return "whisper"
if choice == "whisper":
return "whisper"
if choice == "mlx-whisper":
if not self._can_use_mlx(compatible_whisper_mlx):
raise RuntimeError("mlx-whisper backend requested but MLX Whisper is unavailable or incompatible with the provided model.")
return "mlx-whisper"
if choice == "faster-whisper":
if not self._can_use_faster(compatible_faster_whisper):
raise RuntimeError("faster-whisper backend requested but Faster-Whisper is unavailable or incompatible with the provided model.")
return "faster-whisper"
if choice == "openai-api":
raise ValueError("openai-api backend is only supported with the LocalAgreement policy.")
# auto mode
if platform.system() == "Darwin" and self._can_use_mlx(compatible_whisper_mlx):
return "mlx-whisper"
if self._can_use_faster(compatible_faster_whisper):
return "faster-whisper"
return "whisper"
def _has_custom_model_path(self):
return self._resolved_model_path is not None
def _can_use_mlx(self, compatible_whisper_mlx):
if not HAS_MLX_WHISPER:
return False
if self._has_custom_model_path():
return compatible_whisper_mlx
return self.model_name in mlx_model_mapping
def _can_use_faster(self, compatible_faster_whisper):
if not HAS_FASTER_WHISPER:
return False
if self._has_custom_model_path():
return compatible_faster_whisper
return True
def load_model(self): def load_model(self):
whisper_model = load_model(name=self.model_name, download_root=self.model_path, decoder_only=self.fast_encoder) whisper_model = load_model(
name=self.pytorch_path if self.pytorch_path else self.model_name,
download_root=self.model_path,
decoder_only=self.fast_encoder,
custom_alignment_heads=self.custom_alignment_heads
)
warmup_audio = load_file(self.warmup_file) warmup_audio = load_file(self.warmup_file)
if warmup_audio is not None: if warmup_audio is not None:
warmup_audio = torch.from_numpy(warmup_audio).float() warmup_audio = torch.from_numpy(warmup_audio).float()
if self.fast_encoder: if self.fast_encoder:
temp_model = PaddedAlignAttWhisper( temp_model = AlignAtt(
cfg=self.cfg, cfg=self.cfg,
loaded_model=whisper_model, loaded_model=whisper_model,
mlx_encoder=self.mlx_encoder, mlx_encoder=self.mlx_encoder,
@@ -243,7 +319,7 @@ class SimulStreamingASR():
else: else:
# For standard encoder, use the original transcribe warmup # For standard encoder, use the original transcribe warmup
warmup_audio = load_file(self.warmup_file) warmup_audio = load_file(self.warmup_file)
whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None) whisper_model.transcribe(warmup_audio, language=self.lan if self.lan != 'auto' else None)
return whisper_model return whisper_model
def get_new_model_instance(self): def get_new_model_instance(self):

View File

@@ -1,4 +1,4 @@
from .whisper.decoding import PyTorchInference from whisperlivekit.whisper.decoding import PyTorchInference
# extention of PyTorchInference for beam search # extention of PyTorchInference for beam search
class BeamPyTorchInference(PyTorchInference): class BeamPyTorchInference(PyTorchInference):

View File

@@ -1,25 +1,8 @@
# This code was originally in simul_whisper/transcriber/simul_whisper.py . It is adapted a lot for SimulStreaming.
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal from typing import Literal
@dataclass @dataclass
class SimulWhisperConfig: class AlignAttConfig():
'''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
model_path: str
language: str = field(default="zh")
nonspeech_prob: float = 0.5
audio_min_len: float = 1.0
decoder_type: Literal["greedy","beam"] = "greedy"
beam_size: int = 5
task: Literal["transcribe","translate"] = "transcribe"
init_prompt: str = field(default=None)
static_init_prompt: str = field(default=None)
max_context_tokens: int = field(default=None)
@dataclass
class AlignAttConfig(SimulWhisperConfig):
'''Options specific to the AlignAtt policy.'''
eval_data_path: str = "tmp" eval_data_path: str = "tmp"
segment_length: float = field(default=1.0, metadata = {"help": "in second"}) segment_length: float = field(default=1.0, metadata = {"help": "in second"})
frame_threshold: int = 4 frame_threshold: int = 4
@@ -27,3 +10,14 @@ class AlignAttConfig(SimulWhisperConfig):
audio_max_len: float = 20.0 audio_max_len: float = 20.0
cif_ckpt_path: str = "" cif_ckpt_path: str = ""
never_fire: bool = False never_fire: bool = False
language: str = field(default="zh")
nonspeech_prob: float = 0.5
audio_min_len: float = 1.0
decoder_type: Literal["greedy","beam"] = "greedy"
beam_size: int = 5
task: Literal["transcribe","translate"] = "transcribe"
tokenizer_is_multilingual: bool = False
init_prompt: str = field(default=None)
static_init_prompt: str = field(default=None)
max_context_tokens: int = field(default=None)

View File

@@ -1,43 +0,0 @@
class Tokens:
def __init__(self, tokens):
self.tokens = tokens
# def clone(self):
# return Tokens(self.tokens.clone())
def __str__(self):
return str(self.tokens.tolist())
def __repr__(self):
return self.__str__()
class BeamTokens(Tokens):
def __init__(self, tokens, beam_size):
self.tokens = tokens
self.beam_size = beam_size
def clone(self):
return BeamTokens(self.tokens.clone())
def __str__(self):
return f"BeamTokens({self.tokens.tolist()}, beam_size={self.beam_size})"
def __repr__(self):
return self.__str__()
def as_text(self, tokenizer):
return tokenizer.decode(self.tokens)
class Logits(Tokens):
def __init__(self, logits):
super().__init__(logits)
# def clone(self):
# return Logits(self.tokens.clone(), self.beam_size)
def __str__(self):
# return "abc"
return f"Logits({self.tokens.shape})"
def __repr__(self):
return self.__str__()

View File

@@ -1,5 +0,0 @@
SIMULSTREAMING_LICENSE = f"""
SimulStreaming backend is dual-licensed:
• Non-Commercial Use: PolyForm Noncommercial License 1.0.0.
• Commercial Use: Check SimulStreaming README (github.com/ufal/SimulStreaming) for more details.
"""

View File

@@ -1,48 +1,57 @@
# This code was originally in simul_whisper/transcriber/simul_whisper.py . It is adapted a lot for SimulStreaming.
import os import os
import logging import logging
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np
from .whisper import load_model, DecodingOptions, tokenizer from whisperlivekit.whisper import DecodingOptions, tokenizer
from .config import AlignAttConfig from .config import AlignAttConfig
from whisperlivekit.timed_objects import ASRToken from whisperlivekit.timed_objects import ASRToken
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES from whisperlivekit.whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
from .whisper.timing import median_filter from whisperlivekit.whisper.timing import median_filter
from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language from whisperlivekit.whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens
from .beam import BeamPyTorchInference from .beam import BeamPyTorchInference
from .eow_detection import fire_at_boundary, load_cif from .eow_detection import fire_at_boundary, load_cif
import os import os
from time import time from time import time
from .token_buffer import TokenBuffer from .token_buffer import TokenBuffer
from whisperlivekit.backend_support import (
mlx_backend_available,
faster_backend_available,
)
import numpy as np
from ..timed_objects import PUNCTUATION_MARKS from ..timed_objects import PUNCTUATION_MARKS
from .generation_progress import *
DEC_PAD = 50257 DEC_PAD = 50257
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if mlx_backend_available():
try:
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
HAS_MLX_WHISPER = True
except ImportError:
HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False
else:
try:
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
from faster_whisper.feature_extractor import FeatureExtractor
HAS_FASTER_WHISPER = True
except ImportError:
HAS_FASTER_WHISPER = False
class PaddedAlignAttWhisper: if faster_backend_available():
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
from faster_whisper.feature_extractor import FeatureExtractor
USE_MLCORE = False
def load_coreml_encoder():
try:
from coremltools.models import MLModel
except ImportError:
logger.warning("coremltools is not installed")
return None
COREML_ENCODER_PATH = os.environ.get("MLCORE_ENCODER_PATH", "whisperlivekit/whisper/whisper_encoder.mlpackage")
_coreml_encoder = MLModel(COREML_ENCODER_PATH)
spec = _coreml_encoder.get_spec()
_coreml_input_name = spec.description.input[0].name if spec.description.input else "mel"
_coreml_output_name = spec.description.output[0].name if spec.description.output else None
return _coreml_encoder, _coreml_input_name, _coreml_output_name
class AlignAtt:
def __init__( def __init__(
self, self,
cfg: AlignAttConfig, cfg: AlignAttConfig,
@@ -51,34 +60,32 @@ class PaddedAlignAttWhisper:
fw_encoder=None, fw_encoder=None,
) -> None: ) -> None:
self.log_segments = 0 self.log_segments = 0
model_name = os.path.basename(cfg.model_path).replace(".pt", "")
model_path = os.path.dirname(os.path.abspath(cfg.model_path))
if loaded_model:
self.model = loaded_model
else:
self.model = load_model(name=model_name, download_root=model_path)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model = loaded_model
self.mlx_encoder = mlx_encoder self.mlx_encoder = mlx_encoder
self.fw_encoder = fw_encoder self.fw_encoder = fw_encoder
if fw_encoder: if fw_encoder:
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels) self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
self.coreml_encoder_tuple = None
if USE_MLCORE:
self.coreml_encoder_tuple = load_coreml_encoder()
self.use_mlcore = self.coreml_encoder_tuple is not None
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger.info(f"Model dimensions: {self.model.dims}") logger.info(f"Model dimensions: {self.model.dims}")
self.speaker = -1
self.decode_options = DecodingOptions( self.decode_options = DecodingOptions(
language = cfg.language, language = cfg.language,
without_timestamps = True, without_timestamps = True,
task=cfg.task task=cfg.task
) )
self.tokenizer_is_multilingual = not model_name.endswith(".en") self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
self.create_tokenizer(cfg.language if cfg.language != "auto" else None) self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
# self.create_tokenizer('en') # self.create_tokenizer('en')
self.detected_language = cfg.language if cfg.language != "auto" else None self.detected_language = cfg.language if cfg.language != "auto" else None
self.global_time_offset = 0.0 self.global_time_offset = 0.0
self.reset_tokenizer_to_auto_next_call = False self.reset_tokenizer_to_auto_next_call = False
self.sentence_start_time = 0.0
self.max_text_len = self.model.dims.n_text_ctx self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks) self.num_decoder_layers = len(self.model.decoder.blocks)
@@ -153,7 +160,7 @@ class PaddedAlignAttWhisper:
self.last_attend_frame = -self.cfg.rewind_threshold self.last_attend_frame = -self.cfg.rewind_threshold
self.cumulative_time_offset = 0.0 self.cumulative_time_offset = 0.0
self.sentence_start_time = self.cumulative_time_offset + self.segments_len() self.first_timestamp = None
if self.cfg.max_context_tokens is None: if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len self.max_context_tokens = self.max_text_len
@@ -174,6 +181,9 @@ class PaddedAlignAttWhisper:
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size) self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
# Tokens to carry over to next chunk for incomplete UTF-8 characters
self.pending_incomplete_tokens = []
def remove_hooks(self): def remove_hooks(self):
for hook in self.l_hooks: for hook in self.l_hooks:
hook.remove() hook.remove()
@@ -256,18 +266,18 @@ class PaddedAlignAttWhisper:
logger.debug("Refreshing segment:") logger.debug("Refreshing segment:")
self.init_tokens() self.init_tokens()
self.last_attend_frame = -self.cfg.rewind_threshold self.last_attend_frame = -self.cfg.rewind_threshold
self.detected_language = None # self.detected_language = None
self.cumulative_time_offset = 0.0 self.cumulative_time_offset = 0.0
self.init_context() self.init_context()
logger.debug(f"Context: {self.context}") logger.debug(f"Context: {self.context}")
if not complete and len(self.segments) > 2: if not complete and len(self.segments) > 2:
logger.debug("keeping last two segments because they are and it is not complete.")
self.segments = self.segments[-2:] self.segments = self.segments[-2:]
else: else:
logger.debug("removing all segments.") logger.debug("removing all segments.")
self.segments = [] self.segments = []
self.log_segments += 1 self.log_segments += 1
self.pending_incomplete_tokens = []
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor): def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
if self.always_fire: return True if self.always_fire: return True
@@ -329,7 +339,7 @@ class PaddedAlignAttWhisper:
self.segments = self.segments[1:] self.segments = self.segments[1:]
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s") logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s")
if len(self.tokens) > 1: if len(self.tokens) > 1:
self.context.append_token_ids(self.tokens[1][0,:]) self.context.append_token_ids(self.tokens[1][0,:].tolist())
self.tokens = [self.initial_tokens] + self.tokens[2:] self.tokens = [self.initial_tokens] + self.tokens[2:]
return removed_len return removed_len
@@ -395,15 +405,28 @@ class PaddedAlignAttWhisper:
else: else:
input_segments = self.segments[0] input_segments = self.segments[0]
# if self.cfg.language == "auto" and self.reset_tokenizer_to_auto_next_call:
# logger.debug("Resetting tokenizer to auto for new sentence.")
# self.create_tokenizer(None)
# self.detected_language = None
# self.init_tokens()
# self.reset_tokenizer_to_auto_next_call = False
# NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder
beg_encode = time() beg_encode = time()
if self.use_mlcore:
coreml_encoder, coreml_input_name, coreml_output_name = self.coreml_encoder_tuple
mel_padded = log_mel_spectrogram(
input_segments,
n_mels=self.model.dims.n_mels,
padding=N_SAMPLES,
device="cpu",
).unsqueeze(0)
mel = pad_or_trim(mel_padded, N_FRAMES)
content_mel_len = int((mel_padded.shape[2] - mel.shape[2]) / 2)
mel_np = np.ascontiguousarray(mel.numpy())
ml_inputs = {coreml_input_name or "mel": mel_np}
coreml_outputs = coreml_encoder.predict(ml_inputs)
if coreml_output_name and coreml_output_name in coreml_outputs:
encoder_feature_np = coreml_outputs[coreml_output_name]
else:
encoder_feature_np = next(iter(coreml_outputs.values()))
encoder_feature = torch.as_tensor(
np.array(encoder_feature_np),
device=self.device,
)
if self.mlx_encoder: if self.mlx_encoder:
mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES) mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2) mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
@@ -434,18 +457,19 @@ class PaddedAlignAttWhisper:
end_encode = time() end_encode = time()
# print('Encoder duration:', end_encode-beg_encode) # print('Encoder duration:', end_encode-beg_encode)
if self.cfg.language == "auto" and self.detected_language is None: if self.cfg.language == "auto" and self.detected_language is None and self.first_timestamp:
seconds_since_start = (self.cumulative_time_offset + self.segments_len()) - self.sentence_start_time seconds_since_start = self.segments_len() - self.first_timestamp
if seconds_since_start >= 3.0: if seconds_since_start >= 2.0:
language_tokens, language_probs = self.lang_id(encoder_feature) language_tokens, language_probs = self.lang_id(encoder_feature)
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1]) top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
print(f"Detected language: {top_lan} with p={p:.4f}") print(f"Detected language: {top_lan} with p={p:.4f}")
self.create_tokenizer(top_lan) self.create_tokenizer(top_lan)
self.refresh_segment(complete=True) self.last_attend_frame = -self.cfg.rewind_threshold
self.cumulative_time_offset = 0.0
self.init_tokens()
self.init_context()
self.detected_language = top_lan self.detected_language = top_lan
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}") logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
else:
logger.debug(f"Skipping language detection: {seconds_since_start:.2f}s < 3.0s")
self.trim_context() self.trim_context()
current_tokens = self._current_tokens() current_tokens = self._current_tokens()
@@ -568,6 +592,12 @@ class PaddedAlignAttWhisper:
tokens_to_split = current_tokens[0, token_len_before_decoding:] tokens_to_split = current_tokens[0, token_len_before_decoding:]
# Prepend pending tokens from previous chunk if any
if self.pending_incomplete_tokens:
logger.debug(f"[UTF-8 Fix] Prepending {len(self.pending_incomplete_tokens)} pending tokens: {self.pending_incomplete_tokens}")
pending_tensor = torch.tensor(self.pending_incomplete_tokens, dtype=torch.long, device=self.device)
tokens_to_split = torch.cat([pending_tensor, tokens_to_split])
if fire_detected or is_last: #or punctuation_stop: if fire_detected or is_last: #or punctuation_stop:
new_hypothesis = tokens_to_split.flatten().tolist() new_hypothesis = tokens_to_split.flatten().tolist()
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis) split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
@@ -590,9 +620,20 @@ class PaddedAlignAttWhisper:
self._clean_cache() self._clean_cache()
if len(l_absolute_timestamps) >=2 and self.first_timestamp is None:
self.first_timestamp = l_absolute_timestamps[0]
timestamped_words = [] timestamped_words = []
timestamp_idx = 0 timestamp_idx = 0
replacement_char = "\ufffd"
for word, word_tokens in zip(split_words, split_tokens): for word, word_tokens in zip(split_words, split_tokens):
# Skip words containing incomplete UTF-8 from client output
if replacement_char in word:
logger.warning(f"[UTF-8 Filter] Skipping incomplete word from client output: {repr(word)}")
timestamp_idx += len(word_tokens)
continue
try: try:
current_timestamp = l_absolute_timestamps[timestamp_idx] current_timestamp = l_absolute_timestamps[timestamp_idx]
except: except:
@@ -600,19 +641,20 @@ class PaddedAlignAttWhisper:
timestamp_idx += len(word_tokens) timestamp_idx += len(word_tokens)
timestamp_entry = ASRToken( timestamp_entry = ASRToken(
start=current_timestamp, start=round(current_timestamp, 2),
end=current_timestamp + 0.1, end=round(current_timestamp + 0.1, 2),
text= word, text= word,
probability=0.95, speaker=self.speaker,
language=self.detected_language detected_language=self.detected_language
).with_offset( ).with_offset(
self.global_time_offset self.global_time_offset
) )
timestamped_words.append(timestamp_entry) timestamped_words.append(timestamp_entry)
if self.detected_language is None and self.cfg.language == "auto": # Hold incomplete tokens for next chunk
timestamped_buffer_language, timestamped_words = timestamped_words, [] self.pending_incomplete_tokens = []
else: if split_words and replacement_char in split_words[-1]:
timestamped_buffer_language = [] self.pending_incomplete_tokens = split_tokens[-1]
logger.warning(f"[UTF-8 Fix] Holding {len(self.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.pending_incomplete_tokens}")
return timestamped_words, timestamped_buffer_language return timestamped_words

View File

@@ -7,6 +7,7 @@ class TokenBuffer:
self.prefix_token_ids = prefix_token_ids self.prefix_token_ids = prefix_token_ids
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.device = device self.device = device
self.pending_token_ids = []
def as_token_ids(self, tokenizer=None): def as_token_ids(self, tokenizer=None):
@@ -64,7 +65,26 @@ class TokenBuffer:
def append_token_ids(self, token_ids): def append_token_ids(self, token_ids):
tokenizer = self.tokenizer tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set." assert tokenizer is not None, "Tokenizer is not set."
self.text += self.tokenizer.decode(token_ids)
all_tokens = self.pending_token_ids + token_ids
decoded = tokenizer.decode(all_tokens)
replacement_char = "\ufffd"
if replacement_char in decoded:
if len(all_tokens) > 1:
decoded_partial = tokenizer.decode(all_tokens[:-1])
if replacement_char not in decoded_partial:
self.text += decoded_partial
self.pending_token_ids = [all_tokens[-1]]
else:
self.pending_token_ids = all_tokens
else:
self.pending_token_ids = all_tokens
else:
self.text += decoded
self.pending_token_ids = []
def as_split_word_tokens(self): def as_split_word_tokens(self):
tokenizer = self.tokenizer tokenizer = self.tokenizer

View File

@@ -1,168 +0,0 @@
import hashlib
import io
import os
import urllib
import warnings
from typing import List, Optional, Union
import torch
from tqdm import tqdm
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import ModelDimensions, Whisper
from .transcribe import transcribe
from .version import __version__
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
}
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
"turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
}
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, os.path.basename(url))
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
with open(download_target, "rb") as f:
model_bytes = f.read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
else:
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
return model_bytes if in_memory else download_target
def available_models() -> List[str]:
"""Returns the names of available models"""
return list(_MODELS.keys())
def load_model(
name: str,
device: Optional[Union[str, torch.device]] = None,
download_root: str = None,
in_memory: bool = False,
decoder_only=False
) -> Whisper:
"""
Load a Whisper ASR model
Parameters
----------
name : str
one of the official model names listed by `whisper.available_models()`, or
path to a model checkpoint containing the model dimensions and the model state_dict.
device : Union[str, torch.device]
the PyTorch device to put the model into
download_root: str
path to download the model files; by default, it uses "~/.cache/whisper"
in_memory: bool
whether to preload the model weights into host memory
Returns
-------
model : Whisper
The Whisper ASR model instance
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
alignment_heads = _ALIGNMENT_HEADS[name]
elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name
alignment_heads = None
else:
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
)
with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims, decoder_only=decoder_only)
if decoder_only:
checkpoint["model_state_dict"] = {
k: v for k, v in checkpoint["model_state_dict"].items()
if 'encoder' not in k
}
model.load_state_dict(checkpoint["model_state_dict"])
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)
return model.to(device)

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, Any, List from typing import Optional, List, Union, Dict, Any
from datetime import timedelta from datetime import timedelta
PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''} PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''}
@@ -8,22 +8,19 @@ def format_time(seconds: float) -> str:
"""Format seconds as HH:MM:SS.""" """Format seconds as HH:MM:SS."""
return str(timedelta(seconds=int(seconds))) return str(timedelta(seconds=int(seconds)))
@dataclass @dataclass
class TimedText: class Timed:
start: Optional[float] = 0 start: Optional[float] = 0
end: Optional[float] = 0 end: Optional[float] = 0
@dataclass
class TimedText(Timed):
text: Optional[str] = '' text: Optional[str] = ''
speaker: Optional[int] = -1 speaker: Optional[int] = -1
probability: Optional[float] = None detected_language: Optional[str] = None
is_dummy: Optional[bool] = False
language: str = None
def is_punctuation(self): def has_punctuation(self) -> bool:
return self.text.strip() in PUNCTUATION_MARKS return any(char in PUNCTUATION_MARKS for char in self.text.strip())
def overlaps_with(self, other: 'TimedText') -> bool:
return not (self.end <= other.start or other.end <= self.start)
def is_within(self, other: 'TimedText') -> bool: def is_within(self, other: 'TimedText') -> bool:
return other.contains_timespan(self) return other.contains_timespan(self)
@@ -31,21 +28,25 @@ class TimedText:
def duration(self) -> float: def duration(self) -> float:
return self.end - self.start return self.end - self.start
def contains_time(self, time: float) -> bool:
return self.start <= time <= self.end
def contains_timespan(self, other: 'TimedText') -> bool: def contains_timespan(self, other: 'TimedText') -> bool:
return self.start <= other.start and self.end >= other.end return self.start <= other.start and self.end >= other.end
def __bool__(self): def __bool__(self) -> bool:
return bool(self.text) return bool(self.text)
def __str__(self) -> str:
return str(self.text)
@dataclass @dataclass()
class ASRToken(TimedText): class ASRToken(TimedText):
def with_offset(self, offset: float) -> "ASRToken": def with_offset(self, offset: float) -> "ASRToken":
"""Return a new token with the time offset added.""" """Return a new token with the time offset added."""
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability) return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language)
def is_silence(self) -> bool:
return False
@dataclass @dataclass
class Sentence(TimedText): class Sentence(TimedText):
@@ -64,70 +65,95 @@ class Transcript(TimedText):
sep: Optional[str] = None, sep: Optional[str] = None,
offset: float = 0 offset: float = 0
) -> "Transcript": ) -> "Transcript":
"""Collapse multiple ASR tokens into a single transcript span."""
sep = sep if sep is not None else ' ' sep = sep if sep is not None else ' '
text = sep.join(token.text for token in tokens) text = sep.join(token.text for token in tokens)
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
if tokens: if tokens:
start = offset + tokens[0].start start = offset + tokens[0].start
end = offset + tokens[-1].end end = offset + tokens[-1].end
else: else:
start = None start = None
end = None end = None
return cls(start, end, text, probability=probability) return cls(start, end, text)
@dataclass @dataclass
class SpeakerSegment(TimedText): class SpeakerSegment(Timed):
"""Represents a segment of audio attributed to a specific speaker. """Represents a segment of audio attributed to a specific speaker.
No text nor probability is associated with this segment. No text nor probability is associated with this segment.
""" """
speaker: Optional[int] = -1
pass pass
@dataclass @dataclass
class Translation(TimedText): class Translation(TimedText):
pass pass
def approximate_cut_at(self, cut_time): @dataclass
""" class Silence():
Each word in text is considered to be of duration (end-start)/len(words in text) start: Optional[float] = None
""" end: Optional[float] = None
if not self.text or not self.contains_time(cut_time): duration: Optional[float] = None
return self, None is_starting: bool = False
has_ended: bool = False
words = self.text.split() def compute_duration(self) -> Optional[float]:
num_words = len(words) if self.start is None or self.end is None:
if num_words == 0: return None
return self, None self.duration = self.end - self.start
return self.duration
duration_per_word = self.duration() / num_words def is_silence(self) -> bool:
return True
cut_word_index = int((cut_time - self.start) / duration_per_word)
if cut_word_index >= num_words:
cut_word_index = num_words -1
text0 = " ".join(words[:cut_word_index])
text1 = " ".join(words[cut_word_index:])
segment0 = Translation(start=self.start, end=cut_time, text=text0)
segment1 = Translation(start=cut_time, end=self.end, text=text1)
return segment0, segment1
@dataclass @dataclass
class Silence(): class Segment(TimedText):
duration: float """Generic contiguous span built from tokens or silence markers."""
start: Optional[float]
end: Optional[float]
text: Optional[str]
speaker: Optional[str]
@classmethod
def from_tokens(
cls,
tokens: List[Union[ASRToken, Silence]],
is_silence: bool = False
) -> Optional["Segment"]:
"""Return a normalized segment representing the provided tokens."""
if not tokens:
return None
start_token = tokens[0]
end_token = tokens[-1]
if is_silence:
return cls(
start=start_token.start,
end=end_token.end,
text=None,
speaker=-2
)
else:
return cls(
start=start_token.start,
end=end_token.end,
text=''.join(token.text for token in tokens),
speaker=-1,
detected_language=start_token.detected_language
)
def is_silence(self) -> bool:
"""True when this segment represents a silence gap."""
return self.speaker == -2
@dataclass @dataclass
class Line(TimedText): class Line(TimedText):
translation: str = '' translation: str = ''
detected_language: str = None
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
_dict = { """Serialize the line for frontend consumption."""
'speaker': int(self.speaker), _dict: Dict[str, Any] = {
'speaker': int(self.speaker) if self.speaker != -1 else 1,
'text': self.text, 'text': self.text,
'start': format_time(self.start), 'start': format_time(self.start),
'end': format_time(self.end), 'end': format_time(self.end),
@@ -138,6 +164,33 @@ class Line(TimedText):
_dict['detected_language'] = self.detected_language _dict['detected_language'] = self.detected_language
return _dict return _dict
def build_from_tokens(self, tokens: List[ASRToken]) -> "Line":
"""Populate line attributes from a contiguous token list."""
self.text = ''.join([token.text for token in tokens])
self.start = tokens[0].start
self.end = tokens[-1].end
self.speaker = 1
self.detected_language = tokens[0].detected_language
return self
def build_from_segment(self, segment: Segment) -> "Line":
"""Populate the line fields from a pre-built segment."""
self.text = segment.text
self.start = segment.start
self.end = segment.end
self.speaker = segment.speaker
self.detected_language = segment.detected_language
return self
def is_silent(self) -> bool:
return self.speaker == -2
class SilentLine(Line):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.speaker = -2
self.text = ''
@dataclass @dataclass
class FrontData(): class FrontData():
@@ -146,15 +199,18 @@ class FrontData():
lines: list[Line] = field(default_factory=list) lines: list[Line] = field(default_factory=list)
buffer_transcription: str = '' buffer_transcription: str = ''
buffer_diarization: str = '' buffer_diarization: str = ''
buffer_translation: str = ''
remaining_time_transcription: float = 0. remaining_time_transcription: float = 0.
remaining_time_diarization: float = 0. remaining_time_diarization: float = 0.
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
_dict = { """Serialize the front-end data payload."""
_dict: Dict[str, Any] = {
'status': self.status, 'status': self.status,
'lines': [line.to_dict() for line in self.lines], 'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)],
'buffer_transcription': self.buffer_transcription, 'buffer_transcription': self.buffer_transcription,
'buffer_diarization': self.buffer_diarization, 'buffer_diarization': self.buffer_diarization,
'buffer_translation': self.buffer_translation,
'remaining_time_transcription': self.remaining_time_transcription, 'remaining_time_transcription': self.remaining_time_transcription,
'remaining_time_diarization': self.remaining_time_diarization, 'remaining_time_diarization': self.remaining_time_diarization,
} }
@@ -162,13 +218,29 @@ class FrontData():
_dict['error'] = self.error _dict['error'] = self.error
return _dict return _dict
@dataclass
class ChangeSpeaker:
speaker: int
start: int
@dataclass @dataclass
class State(): class State():
tokens: list """Unified state class for audio processing.
translated_segments: list
buffer_transcription: str Contains both persistent state (tokens, buffers) and temporary update buffers
buffer_diarization: str (new_* fields) that are consumed by TokensAlignment.
end_buffer: float """
end_attributed_speaker: float # Persistent state
remaining_time_transcription: float tokens: List[ASRToken] = field(default_factory=list)
remaining_time_diarization: float buffer_transcription: Transcript = field(default_factory=Transcript)
end_buffer: float = 0.0
end_attributed_speaker: float = 0.0
remaining_time_transcription: float = 0.0
remaining_time_diarization: float = 0.0
# Temporary update buffers (consumed by TokensAlignment.update())
new_tokens: List[Union[ASRToken, Silence]] = field(default_factory=list)
new_translation: List[Any] = field(default_factory=list)
new_diarization: List[Any] = field(default_factory=list)
new_tokens_buffer: List[Any] = field(default_factory=list) # only when local agreement
new_translation_buffer= TimedText()

View File

@@ -0,0 +1,177 @@
from time import time
from typing import Optional, List, Tuple, Union, Any
from whisperlivekit.timed_objects import Line, SilentLine, ASRToken, SpeakerSegment, Silence, TimedText, Segment
class TokensAlignment:
def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None:
self.state = state
self.diarization = args.diarization
self._tokens_index: int = 0
self._diarization_index: int = 0
self._translation_index: int = 0
self.all_tokens: List[ASRToken] = []
self.all_diarization_segments: List[SpeakerSegment] = []
self.all_translation_segments: List[Any] = []
self.new_tokens: List[ASRToken] = []
self.new_diarization: List[SpeakerSegment] = []
self.new_translation: List[Any] = []
self.new_translation_buffer: Union[TimedText, str] = TimedText()
self.new_tokens_buffer: List[Any] = []
self.sep: str = sep if sep is not None else ' '
self.beg_loop: Optional[float] = None
def update(self) -> None:
"""Drain state buffers into the running alignment context."""
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
self.new_diarization, self.state.new_diarization = self.state.new_diarization, []
self.new_translation, self.state.new_translation = self.state.new_translation, []
self.new_tokens_buffer, self.state.new_tokens_buffer = self.state.new_tokens_buffer, []
self.all_tokens.extend(self.new_tokens)
self.all_diarization_segments.extend(self.new_diarization)
self.all_translation_segments.extend(self.new_translation)
self.new_translation_buffer = self.state.new_translation_buffer
def add_translation(self, line: Line) -> None:
"""Append translated text segments that overlap with a line."""
for ts in self.all_translation_segments:
if ts.is_within(line):
line.translation += ts.text + (self.sep if ts.text else '')
elif line.translation:
break
def compute_punctuations_segments(self, tokens: Optional[List[ASRToken]] = None) -> List[Segment]:
"""Group tokens into segments split by punctuation and explicit silence."""
segments = []
segment_start_idx = 0
for i, token in enumerate(self.all_tokens):
if token.is_silence():
previous_segment = Segment.from_tokens(
tokens=self.all_tokens[segment_start_idx: i],
)
if previous_segment:
segments.append(previous_segment)
segment = Segment.from_tokens(
tokens=[token],
is_silence=True
)
segments.append(segment)
segment_start_idx = i+1
else:
if token.has_punctuation():
segment = Segment.from_tokens(
tokens=self.all_tokens[segment_start_idx: i+1],
)
segments.append(segment)
segment_start_idx = i+1
final_segment = Segment.from_tokens(
tokens=self.all_tokens[segment_start_idx:],
)
if final_segment:
segments.append(final_segment)
return segments
def concatenate_diar_segments(self) -> List[SpeakerSegment]:
"""Merge consecutive diarization slices that share the same speaker."""
if not self.all_diarization_segments:
return []
merged = [self.all_diarization_segments[0]]
for segment in self.all_diarization_segments[1:]:
if segment.speaker == merged[-1].speaker:
merged[-1].end = segment.end
else:
merged.append(segment)
return merged
@staticmethod
def intersection_duration(seg1: TimedText, seg2: TimedText) -> float:
"""Return the overlap duration between two timed segments."""
start = max(seg1.start, seg2.start)
end = min(seg1.end, seg2.end)
return max(0, end - start)
def get_lines_diarization(self) -> Tuple[List[Line], str]:
"""Build lines when diarization is enabled and track overflow buffer."""
diarization_buffer = ''
punctuation_segments = self.compute_punctuations_segments()
diarization_segments = self.concatenate_diar_segments()
for punctuation_segment in punctuation_segments:
if not punctuation_segment.is_silence():
if diarization_segments and punctuation_segment.start >= diarization_segments[-1].end:
diarization_buffer += punctuation_segment.text
else:
max_overlap = 0.0
max_overlap_speaker = 1
for diarization_segment in diarization_segments:
intersec = self.intersection_duration(punctuation_segment, diarization_segment)
if intersec > max_overlap:
max_overlap = intersec
max_overlap_speaker = diarization_segment.speaker + 1
punctuation_segment.speaker = max_overlap_speaker
lines = []
if punctuation_segments:
lines = [Line().build_from_segment(punctuation_segments[0])]
for segment in punctuation_segments[1:]:
if segment.speaker == lines[-1].speaker:
if lines[-1].text:
lines[-1].text += segment.text
lines[-1].end = segment.end
else:
lines.append(Line().build_from_segment(segment))
return lines, diarization_buffer
def get_lines(
self,
diarization: bool = False,
translation: bool = False,
current_silence: Optional[Silence] = None
) -> Tuple[List[Line], str, Union[str, TimedText]]:
"""Return the formatted lines plus buffers, optionally with diarization/translation."""
if diarization:
lines, diarization_buffer = self.get_lines_diarization()
else:
diarization_buffer = ''
lines = []
current_line_tokens = []
for token in self.all_tokens:
if token.is_silence():
if current_line_tokens:
lines.append(Line().build_from_tokens(current_line_tokens))
current_line_tokens = []
end_silence = token.end if token.has_ended else time() - self.beg_loop
if lines and lines[-1].is_silent():
lines[-1].end = end_silence
else:
lines.append(SilentLine(
start = token.start,
end = end_silence
))
else:
current_line_tokens.append(token)
if current_line_tokens:
lines.append(Line().build_from_tokens(current_line_tokens))
if current_silence:
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
if lines and lines[-1].is_silent():
lines[-1].end = end_silence
else:
lines.append(SilentLine(
start = current_silence.start,
end = end_silence
))
if translation:
[self.add_translation(line) for line in lines if not type(line) == Silence]
return lines, diarization_buffer, self.new_translation_buffer.text

View File

@@ -1,60 +0,0 @@
from typing import Sequence, Callable, Any, Optional, Dict
def _detect_tail_repetition(
seq: Sequence[Any],
key: Callable[[Any], Any] = lambda x: x, # extract comparable value
min_block: int = 1, # set to 2 to ignore 1-token loops like "."
max_tail: int = 300, # search window from the end for speed
prefer: str = "longest", # "longest" coverage or "smallest" block
) -> Optional[Dict]:
vals = [key(x) for x in seq][-max_tail:]
n = len(vals)
best = None
# try every possible block length
for b in range(min_block, n // 2 + 1):
block = vals[-b:]
# count how many times this block repeats contiguously at the very end
count, i = 0, n
while i - b >= 0 and vals[i - b:i] == block:
count += 1
i -= b
if count >= 2:
cand = {
"block_size": b,
"count": count,
"start_index": len(seq) - count * b, # in original seq
"end_index": len(seq),
}
if (best is None or
(prefer == "longest" and count * b > best["count"] * best["block_size"]) or
(prefer == "smallest" and b < best["block_size"])):
best = cand
return best
def trim_tail_repetition(
seq: Sequence[Any],
key: Callable[[Any], Any] = lambda x: x,
min_block: int = 1,
max_tail: int = 300,
prefer: str = "longest",
keep: int = 1, # how many copies of the repeating block to keep at the end (0 or 1 are common)
):
"""
Returns a new sequence with repeated tail trimmed.
keep=1 -> keep a single copy of the repeated block.
keep=0 -> remove all copies of the repeated block.
"""
rep = _detect_tail_repetition(seq, key, min_block, max_tail, prefer)
if not rep:
return seq, False # nothing to trim
b, c = rep["block_size"], rep["count"]
if keep < 0:
keep = 0
if keep >= c:
return seq, False # nothing to trim (already <= keep copies)
# new length = total - (copies_to_remove * block_size)
new_len = len(seq) - (c - keep) * b
return seq[:new_len], True

View File

@@ -1,182 +0,0 @@
"""
adapted from https://store.crowdin.com/custom-mt
"""
LANGUAGES = [
{"name": "Afrikaans", "nllb": "afr_Latn", "crowdin": "af"},
{"name": "Akan", "nllb": "aka_Latn", "crowdin": "ak"},
{"name": "Amharic", "nllb": "amh_Ethi", "crowdin": "am"},
{"name": "Assamese", "nllb": "asm_Beng", "crowdin": "as"},
{"name": "Asturian", "nllb": "ast_Latn", "crowdin": "ast"},
{"name": "Bashkir", "nllb": "bak_Cyrl", "crowdin": "ba"},
{"name": "Bambara", "nllb": "bam_Latn", "crowdin": "bm"},
{"name": "Balinese", "nllb": "ban_Latn", "crowdin": "ban"},
{"name": "Belarusian", "nllb": "bel_Cyrl", "crowdin": "be"},
{"name": "Bengali", "nllb": "ben_Beng", "crowdin": "bn"},
{"name": "Bosnian", "nllb": "bos_Latn", "crowdin": "bs"},
{"name": "Bulgarian", "nllb": "bul_Cyrl", "crowdin": "bg"},
{"name": "Catalan", "nllb": "cat_Latn", "crowdin": "ca"},
{"name": "Cebuano", "nllb": "ceb_Latn", "crowdin": "ceb"},
{"name": "Czech", "nllb": "ces_Latn", "crowdin": "cs"},
{"name": "Welsh", "nllb": "cym_Latn", "crowdin": "cy"},
{"name": "Danish", "nllb": "dan_Latn", "crowdin": "da"},
{"name": "German", "nllb": "deu_Latn", "crowdin": "de"},
{"name": "Dzongkha", "nllb": "dzo_Tibt", "crowdin": "dz"},
{"name": "Greek", "nllb": "ell_Grek", "crowdin": "el"},
{"name": "English", "nllb": "eng_Latn", "crowdin": "en"},
{"name": "Esperanto", "nllb": "epo_Latn", "crowdin": "eo"},
{"name": "Estonian", "nllb": "est_Latn", "crowdin": "et"},
{"name": "Basque", "nllb": "eus_Latn", "crowdin": "eu"},
{"name": "Ewe", "nllb": "ewe_Latn", "crowdin": "ee"},
{"name": "Faroese", "nllb": "fao_Latn", "crowdin": "fo"},
{"name": "Fijian", "nllb": "fij_Latn", "crowdin": "fj"},
{"name": "Finnish", "nllb": "fin_Latn", "crowdin": "fi"},
{"name": "French", "nllb": "fra_Latn", "crowdin": "fr"},
{"name": "Friulian", "nllb": "fur_Latn", "crowdin": "fur-IT"},
{"name": "Scottish Gaelic", "nllb": "gla_Latn", "crowdin": "gd"},
{"name": "Irish", "nllb": "gle_Latn", "crowdin": "ga-IE"},
{"name": "Galician", "nllb": "glg_Latn", "crowdin": "gl"},
{"name": "Guarani", "nllb": "grn_Latn", "crowdin": "gn"},
{"name": "Gujarati", "nllb": "guj_Gujr", "crowdin": "gu-IN"},
{"name": "Haitian Creole", "nllb": "hat_Latn", "crowdin": "ht"},
{"name": "Hausa", "nllb": "hau_Latn", "crowdin": "ha"},
{"name": "Hebrew", "nllb": "heb_Hebr", "crowdin": "he"},
{"name": "Hindi", "nllb": "hin_Deva", "crowdin": "hi"},
{"name": "Croatian", "nllb": "hrv_Latn", "crowdin": "hr"},
{"name": "Hungarian", "nllb": "hun_Latn", "crowdin": "hu"},
{"name": "Armenian", "nllb": "hye_Armn", "crowdin": "hy-AM"},
{"name": "Igbo", "nllb": "ibo_Latn", "crowdin": "ig"},
{"name": "Indonesian", "nllb": "ind_Latn", "crowdin": "id"},
{"name": "Icelandic", "nllb": "isl_Latn", "crowdin": "is"},
{"name": "Italian", "nllb": "ita_Latn", "crowdin": "it"},
{"name": "Javanese", "nllb": "jav_Latn", "crowdin": "jv"},
{"name": "Japanese", "nllb": "jpn_Jpan", "crowdin": "ja"},
{"name": "Kabyle", "nllb": "kab_Latn", "crowdin": "kab"},
{"name": "Kannada", "nllb": "kan_Knda", "crowdin": "kn"},
{"name": "Georgian", "nllb": "kat_Geor", "crowdin": "ka"},
{"name": "Kazakh", "nllb": "kaz_Cyrl", "crowdin": "kk"},
{"name": "Khmer", "nllb": "khm_Khmr", "crowdin": "km"},
{"name": "Kinyarwanda", "nllb": "kin_Latn", "crowdin": "rw"},
{"name": "Kyrgyz", "nllb": "kir_Cyrl", "crowdin": "ky"},
{"name": "Korean", "nllb": "kor_Hang", "crowdin": "ko"},
{"name": "Lao", "nllb": "lao_Laoo", "crowdin": "lo"},
{"name": "Ligurian", "nllb": "lij_Latn", "crowdin": "lij"},
{"name": "Limburgish", "nllb": "lim_Latn", "crowdin": "li"},
{"name": "Lingala", "nllb": "lin_Latn", "crowdin": "ln"},
{"name": "Lithuanian", "nllb": "lit_Latn", "crowdin": "lt"},
{"name": "Luxembourgish", "nllb": "ltz_Latn", "crowdin": "lb"},
{"name": "Maithili", "nllb": "mai_Deva", "crowdin": "mai"},
{"name": "Malayalam", "nllb": "mal_Mlym", "crowdin": "ml-IN"},
{"name": "Marathi", "nllb": "mar_Deva", "crowdin": "mr"},
{"name": "Macedonian", "nllb": "mkd_Cyrl", "crowdin": "mk"},
{"name": "Maltese", "nllb": "mlt_Latn", "crowdin": "mt"},
{"name": "Mossi", "nllb": "mos_Latn", "crowdin": "mos"},
{"name": "Maori", "nllb": "mri_Latn", "crowdin": "mi"},
{"name": "Burmese", "nllb": "mya_Mymr", "crowdin": "my"},
{"name": "Dutch", "nllb": "nld_Latn", "crowdin": "nl"},
{"name": "Norwegian Nynorsk", "nllb": "nno_Latn", "crowdin": "nn-NO"},
{"name": "Nepali", "nllb": "npi_Deva", "crowdin": "ne-NP"},
{"name": "Northern Sotho", "nllb": "nso_Latn", "crowdin": "nso"},
{"name": "Occitan", "nllb": "oci_Latn", "crowdin": "oc"},
{"name": "Odia", "nllb": "ory_Orya", "crowdin": "or"},
{"name": "Papiamento", "nllb": "pap_Latn", "crowdin": "pap"},
{"name": "Polish", "nllb": "pol_Latn", "crowdin": "pl"},
{"name": "Portuguese", "nllb": "por_Latn", "crowdin": "pt-PT"},
{"name": "Dari", "nllb": "prs_Arab", "crowdin": "fa-AF"},
{"name": "Romanian", "nllb": "ron_Latn", "crowdin": "ro"},
{"name": "Rundi", "nllb": "run_Latn", "crowdin": "rn"},
{"name": "Russian", "nllb": "rus_Cyrl", "crowdin": "ru"},
{"name": "Sango", "nllb": "sag_Latn", "crowdin": "sg"},
{"name": "Sanskrit", "nllb": "san_Deva", "crowdin": "sa"},
{"name": "Santali", "nllb": "sat_Olck", "crowdin": "sat"},
{"name": "Sinhala", "nllb": "sin_Sinh", "crowdin": "si-LK"},
{"name": "Slovak", "nllb": "slk_Latn", "crowdin": "sk"},
{"name": "Slovenian", "nllb": "slv_Latn", "crowdin": "sl"},
{"name": "Shona", "nllb": "sna_Latn", "crowdin": "sn"},
{"name": "Sindhi", "nllb": "snd_Arab", "crowdin": "sd"},
{"name": "Somali", "nllb": "som_Latn", "crowdin": "so"},
{"name": "Southern Sotho", "nllb": "sot_Latn", "crowdin": "st"},
{"name": "Spanish", "nllb": "spa_Latn", "crowdin": "es-ES"},
{"name": "Sardinian", "nllb": "srd_Latn", "crowdin": "sc"},
{"name": "Swati", "nllb": "ssw_Latn", "crowdin": "ss"},
{"name": "Sundanese", "nllb": "sun_Latn", "crowdin": "su"},
{"name": "Swedish", "nllb": "swe_Latn", "crowdin": "sv-SE"},
{"name": "Swahili", "nllb": "swh_Latn", "crowdin": "sw"},
{"name": "Tamil", "nllb": "tam_Taml", "crowdin": "ta"},
{"name": "Tatar", "nllb": "tat_Cyrl", "crowdin": "tt-RU"},
{"name": "Telugu", "nllb": "tel_Telu", "crowdin": "te"},
{"name": "Tajik", "nllb": "tgk_Cyrl", "crowdin": "tg"},
{"name": "Tagalog", "nllb": "tgl_Latn", "crowdin": "tl"},
{"name": "Thai", "nllb": "tha_Thai", "crowdin": "th"},
{"name": "Tigrinya", "nllb": "tir_Ethi", "crowdin": "ti"},
{"name": "Tswana", "nllb": "tsn_Latn", "crowdin": "tn"},
{"name": "Tsonga", "nllb": "tso_Latn", "crowdin": "ts"},
{"name": "Turkmen", "nllb": "tuk_Latn", "crowdin": "tk"},
{"name": "Turkish", "nllb": "tur_Latn", "crowdin": "tr"},
{"name": "Uyghur", "nllb": "uig_Arab", "crowdin": "ug"},
{"name": "Ukrainian", "nllb": "ukr_Cyrl", "crowdin": "uk"},
{"name": "Venetian", "nllb": "vec_Latn", "crowdin": "vec"},
{"name": "Vietnamese", "nllb": "vie_Latn", "crowdin": "vi"},
{"name": "Wolof", "nllb": "wol_Latn", "crowdin": "wo"},
{"name": "Xhosa", "nllb": "xho_Latn", "crowdin": "xh"},
{"name": "Yoruba", "nllb": "yor_Latn", "crowdin": "yo"},
{"name": "Zulu", "nllb": "zul_Latn", "crowdin": "zu"},
]
NAME_TO_NLLB = {lang["name"]: lang["nllb"] for lang in LANGUAGES}
NAME_TO_CROWDIN = {lang["name"]: lang["crowdin"] for lang in LANGUAGES}
CROWDIN_TO_NLLB = {lang["crowdin"]: lang["nllb"] for lang in LANGUAGES}
NLLB_TO_CROWDIN = {lang["nllb"]: lang["crowdin"] for lang in LANGUAGES}
CROWDIN_TO_NAME = {lang["crowdin"]: lang["name"] for lang in LANGUAGES}
NLLB_TO_NAME = {lang["nllb"]: lang["name"] for lang in LANGUAGES}
def get_nllb_code(crowdin_code):
return CROWDIN_TO_NLLB.get(crowdin_code, None)
def get_crowdin_code(nllb_code):
return NLLB_TO_CROWDIN.get(nllb_code)
def get_language_name_by_crowdin(crowdin_code):
return CROWDIN_TO_NAME.get(crowdin_code)
def get_language_name_by_nllb(nllb_code):
return NLLB_TO_NAME.get(nllb_code)
def get_language_info(identifier, identifier_type="auto"):
if identifier_type == "auto":
for lang in LANGUAGES:
if (lang["name"].lower() == identifier.lower() or
lang["nllb"] == identifier or
lang["crowdin"] == identifier):
return lang
elif identifier_type == "name":
for lang in LANGUAGES:
if lang["name"].lower() == identifier.lower():
return lang
elif identifier_type == "nllb":
for lang in LANGUAGES:
if lang["nllb"] == identifier:
return lang
elif identifier_type == "crowdin":
for lang in LANGUAGES:
if lang["crowdin"] == identifier:
return lang
return None
def list_all_languages():
return [lang["name"] for lang in LANGUAGES]
def list_all_nllb_codes():
return [lang["nllb"] for lang in LANGUAGES]
def list_all_crowdin_codes():
return [lang["crowdin"] for lang in LANGUAGES]

View File

@@ -1,148 +0,0 @@
import logging
import time
import ctranslate2
import torch
import transformers
from dataclasses import dataclass
import huggingface_hub
from whisperlivekit.translation.mapping_languages import get_nllb_code
from whisperlivekit.timed_objects import Translation
logger = logging.getLogger(__name__)
#In diarization case, we may want to translate just one speaker, or at least start the sentences there
MIN_SILENCE_DURATION_DEL_BUFFER = 3 #After a silence of x seconds, we consider the model should not use the buffer, even if the previous
# sentence is not finished.
@dataclass
class TranslationModel():
translator: ctranslate2.Translator
tokenizer: dict
device: str
backend_type: str = 'ctranslate2'
def load_model(src_langs, backend='ctranslate2', model_size='600M'):
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL = f'nllb-200-distilled-{model_size}-ctranslate2'
if backend=='ctranslate2':
MODEL_GUY = 'entai2965'
huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL)
translator = ctranslate2.Translator(MODEL,device=device)
elif backend=='transformers':
translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(f"facebook/nllb-200-distilled-{model_size}")
tokenizer = dict()
for src_lang in src_langs:
tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
return TranslationModel(
translator=translator,
tokenizer=tokenizer,
backend_type=backend,
device = device
)
class OnlineTranslation:
def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
self.buffer = []
self.len_processed_buffer = 0
self.translation_remaining = Translation()
self.validated = []
self.translation_pending_validation = ''
self.translation_model = translation_model
self.input_languages = input_languages
self.output_languages = output_languages
def compute_common_prefix(self, results):
#we dont want want to prune the result for the moment.
if not self.buffer:
self.buffer = results
else:
for i in range(min(len(self.buffer), len(results))):
if self.buffer[i] != results[i]:
self.commited.extend(self.buffer[:i])
self.buffer = results[i:]
def translate(self, input, input_lang=None, output_lang=None):
if not input:
return ""
if input_lang is None:
input_lang = self.input_languages[0]
if output_lang is None:
output_lang = self.output_languages[0]
nllb_output_lang = get_nllb_code(output_lang)
tokenizer = self.translation_model.tokenizer[input_lang]
tokenizer_output = tokenizer(input, return_tensors="pt").to(self.translation_model.device)
if self.translation_model.backend_type == 'ctranslate2':
source = tokenizer.convert_ids_to_tokens(tokenizer_output['input_ids'][0])
results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]])
target = results[0].hypotheses[0][1:]
result = tokenizer.decode(tokenizer.convert_tokens_to_ids(target))
else:
translated_tokens = self.translation_model.translator.generate(**tokenizer_output, forced_bos_token_id=tokenizer.convert_tokens_to_ids(nllb_output_lang))
result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
return result
def translate_tokens(self, tokens):
if tokens:
text = ' '.join([token.text for token in tokens])
start = tokens[0].start
end = tokens[-1].end
translated_text = self.translate(text)
translation = Translation(
text=translated_text,
start=start,
end=end,
)
return translation
return None
def insert_tokens(self, tokens):
self.buffer.extend(tokens)
pass
def process(self):
i = 0
if len(self.buffer) < self.len_processed_buffer + 3: #nothing new to process
return self.validated + [self.translation_remaining]
while i < len(self.buffer):
if self.buffer[i].is_punctuation():
translation_sentence = self.translate_tokens(self.buffer[:i+1])
self.validated.append(translation_sentence)
self.buffer = self.buffer[i+1:]
i = 0
else:
i+=1
self.translation_remaining = self.translate_tokens(self.buffer)
self.len_processed_buffer = len(self.buffer)
return self.validated + [self.translation_remaining]
def insert_silence(self, silence_duration: float):
if silence_duration >= MIN_SILENCE_DURATION_DEL_BUFFER:
self.buffer = []
self.validated += [self.translation_remaining]
if __name__ == '__main__':
output_lang = 'fr'
input_lang = "en"
test_string = """
Transcription technology has improved so much in the past few years. Have you noticed how accurate real-time speech-to-text is now?
"""
test = test_string.split(' ')
step = len(test) // 3
shared_model = load_model([input_lang], backend='ctranslate2')
online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang])
beg_inference = time.time()
for id in range(5):
val = test[id*step : (id+1)*step]
val_str = ' '.join(val)
result = online_translation.translate(val_str)
print(result)
print('inference time:', time.time() - beg_inference)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -72,6 +72,12 @@
--label-trans-text: #111111; --label-trans-text: #111111;
} }
html.is-extension
{
width: 350px;
height: 500px;
}
body { body {
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji'; font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
margin: 0; margin: 0;
@@ -191,6 +197,14 @@ body {
justify-content: center; justify-content: center;
align-items: center; align-items: center;
gap: 15px; gap: 15px;
position: relative;
flex-wrap: wrap;
}
.buttons-container {
display: flex;
align-items: center;
gap: 15px;
} }
.settings { .settings {
@@ -200,6 +214,66 @@ body {
gap: 12px; gap: 12px;
} }
.settings-toggle {
width: 40px;
height: 40px;
border: none;
border-radius: 50%;
background-color: var(--button-bg);
border: 1px solid var(--button-border);
cursor: pointer;
display: none;
align-items: center;
justify-content: center;
transition: all 0.2s ease;
}
.settings-toggle:hover {
background-color: var(--chip-bg);
}
.settings-toggle.active {
background-color: var(--chip-bg);
}
.settings-toggle img {
width: 20px;
height: 20px;
}
@media (max-width: 10000px) {
.settings-toggle {
display: flex;
}
.settings {
display: none;
background: var(--bg);
border: 1px solid var(--border);
border-radius: 18px;
padding: 12px;
}
.settings.visible {
display: flex;
}
}
@media (max-width: 600px) {
.settings-container {
flex-direction: column;
align-items: center;
gap: 10px;
}
.buttons-container {
display: flex;
justify-content: center;
align-items: center;
gap: 15px;
}
}
.field { .field {
display: flex; display: flex;
flex-direction: column; flex-direction: column;
@@ -409,7 +483,6 @@ label {
.buffer_diarization { .buffer_diarization {
color: var(--label-dia-text); color: var(--label-dia-text);
margin-left: 4px;
} }
.buffer_transcription { .buffer_transcription {
@@ -417,6 +490,11 @@ label {
margin-left: 4px; margin-left: 4px;
} }
.buffer_translation {
color: #a0a0a0;
margin-left: 6px;
}
.spinner { .spinner {
display: inline-block; display: inline-block;
width: 8px; width: 8px;
@@ -454,7 +532,7 @@ label {
} }
/* for smaller screens */ /* for smaller screens */
@media (max-width: 768px) { @media (max-width: 200px) {
.header-container { .header-container {
padding: 15px; padding: 15px;
} }
@@ -464,6 +542,10 @@ label {
gap: 10px; gap: 10px;
} }
.buttons-container {
gap: 10px;
}
.settings { .settings {
justify-content: center; justify-content: center;
gap: 8px; gap: 8px;
@@ -522,8 +604,6 @@ label {
.label_language { .label_language {
background-color: var(--chip-bg); background-color: var(--chip-bg);
margin-bottom: 0px; margin-bottom: 0px;
margin-top: 5px;
height: 18.5px;
border-radius: 100px; border-radius: 100px;
padding: 2px 8px; padding: 2px 8px;
margin-left: 10px; margin-left: 10px;
@@ -534,22 +614,6 @@ label {
color: var(--muted); color: var(--muted);
} }
.label_language img {
width: 12px;
height: 12px;
}
.silence-icon {
width: 14px;
height: 14px;
vertical-align: text-bottom;
}
.speaker-icon {
width: 16px;
height: 16px;
vertical-align: text-bottom;
}
.speaker-badge { .speaker-badge {
display: inline-flex; display: inline-flex;

View File

@@ -5,23 +5,29 @@
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>WhisperLiveKit</title> <title>WhisperLiveKit</title>
<link rel="stylesheet" href="/web/live_transcription.css" /> <link rel="stylesheet" href="live_transcription.css" />
</head> </head>
<body> <body>
<div class="header-container"> <div class="header-container">
<div class="settings-container"> <div class="settings-container">
<button id="recordButton"> <div class="buttons-container">
<div class="shape-container"> <button id="recordButton">
<div class="shape"></div> <div class="shape-container">
</div> <div class="shape"></div>
<div class="recording-info">
<div class="wave-container">
<canvas id="waveCanvas"></canvas>
</div> </div>
<div class="timer">00:00</div> <div class="recording-info">
</div> <div class="wave-container">
</button> <canvas id="waveCanvas"></canvas>
</div>
<div class="timer">00:00</div>
</div>
</button>
<button id="settingsToggle" class="settings-toggle" title="Show/hide settings">
<img src="web/src/settings.svg" alt="Settings" />
</button>
</div>
<div class="settings"> <div class="settings">
<div class="field"> <div class="field">
@@ -67,7 +73,7 @@
<div id="linesTranscript"></div> <div id="linesTranscript"></div>
</div> </div>
<script src="/web/live_transcription.js"></script> <script src="live_transcription.js"></script>
</body> </body>
</html> </html>

View File

@@ -1,4 +1,8 @@
/* Theme, WebSocket, recording, rendering logic extracted from inline script and adapted for segmented theme control and WS caption */ const isExtension = typeof chrome !== 'undefined' && chrome.runtime && chrome.runtime.getURL;
if (isExtension) {
document.documentElement.classList.add('is-extension');
}
const isWebContext = !isExtension;
let isRecording = false; let isRecording = false;
let websocket = null; let websocket = null;
@@ -25,6 +29,8 @@ let selectedMicrophoneId = null;
let serverUseAudioWorklet = null; let serverUseAudioWorklet = null;
let configReadyResolve; let configReadyResolve;
const configReady = new Promise((r) => (configReadyResolve = r)); const configReady = new Promise((r) => (configReadyResolve = r));
let outputAudioContext = null;
let audioSource = null;
waveCanvas.width = 60 * (window.devicePixelRatio || 1); waveCanvas.width = 60 * (window.devicePixelRatio || 1);
waveCanvas.height = 30 * (window.devicePixelRatio || 1); waveCanvas.height = 30 * (window.devicePixelRatio || 1);
@@ -40,6 +46,26 @@ const timerElement = document.querySelector(".timer");
const themeRadios = document.querySelectorAll('input[name="theme"]'); const themeRadios = document.querySelectorAll('input[name="theme"]');
const microphoneSelect = document.getElementById("microphoneSelect"); const microphoneSelect = document.getElementById("microphoneSelect");
const settingsToggle = document.getElementById("settingsToggle");
const settingsDiv = document.querySelector(".settings");
// if (isExtension) {
// chrome.runtime.onInstalled.addListener((details) => {
// if (details.reason.search(/install/g) === -1) {
// return;
// }
// chrome.tabs.create({
// url: chrome.runtime.getURL("welcome.html"),
// active: true
// });
// });
// }
const translationIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="12px" viewBox="0 -960 960 960" width="12px" fill="#5f6368"><path d="m603-202-34 97q-4 11-14 18t-22 7q-20 0-32.5-16.5T496-133l152-402q5-11 15-18t22-7h30q12 0 22 7t15 18l152 403q8 19-4 35.5T868-80q-13 0-22.5-7T831-106l-34-96H603ZM362-401 188-228q-11 11-27.5 11.5T132-228q-11-11-11-28t11-28l174-174q-35-35-63.5-80T190-640h84q20 39 40 68t48 58q33-33 68.5-92.5T484-720H80q-17 0-28.5-11.5T40-760q0-17 11.5-28.5T80-800h240v-40q0-17 11.5-28.5T360-880q17 0 28.5 11.5T400-840v40h240q17 0 28.5 11.5T680-760q0 17-11.5 28.5T640-720h-76q-21 72-63 148t-83 116l96 98-30 82-122-125Zm266 129h144l-72-204-72 204Z"/></svg>`
const silenceIcon = `<svg xmlns="http://www.w3.org/2000/svg" style="vertical-align: text-bottom;" height="14px" viewBox="0 -960 960 960" width="14px" fill="#5f6368"><path d="M514-556 320-752q9-3 19-5.5t21-2.5q66 0 113 47t47 113q0 11-1.5 22t-4.5 22ZM40-200v-32q0-33 17-62t47-44q51-26 115-44t141-18q26 0 49.5 2.5T456-392l-56-54q-9 3-19 4.5t-21 1.5q-66 0-113-47t-47-113q0-11 1.5-21t4.5-19L84-764q-11-11-11-28t11-28q12-12 28.5-12t27.5 12l675 685q11 11 11.5 27.5T816-80q-11 13-28 12.5T759-80L641-200h39q0 33-23.5 56.5T600-120H120q-33 0-56.5-23.5T40-200Zm80 0h480v-32q0-14-4.5-19.5T580-266q-36-18-92.5-36T360-320q-71 0-127.5 18T140-266q-9 5-14.5 14t-5.5 20v32Zm240 0Zm560-400q0 69-24.5 131.5T829-355q-12 14-30 15t-32-13q-13-13-12-31t12-33q30-38 46.5-85t16.5-98q0-51-16.5-97T767-781q-12-15-12.5-33t12.5-32q13-14 31.5-13.5T829-845q42 51 66.5 113.5T920-600Zm-182 0q0 32-10 61.5T700-484q-11 15-29.5 15.5T638-482q-13-13-13.5-31.5T633-549q6-11 9.5-24t3.5-27q0-14-3.5-27t-9.5-25q-9-17-8.5-35t13.5-31q14-14 32.5-13.5T700-716q18 25 28 54.5t10 61.5Z"/></svg>`;
const languageIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="12" viewBox="0 -960 960 960" width="12" fill="#5f6368"><path d="M480-80q-82 0-155-31.5t-127.5-86Q143-252 111.5-325T80-480q0-83 31.5-155.5t86-127Q252-817 325-848.5T480-880q83 0 155.5 31.5t127 86q54.5 54.5 86 127T880-480q0 82-31.5 155t-86 127.5q-54.5 54.5-127 86T480-80Zm0-82q26-36 45-75t31-83H404q12 44 31 83t45 75Zm-104-16q-18-33-31.5-68.5T322-320H204q29 50 72.5 87t99.5 55Zm208 0q56-18 99.5-55t72.5-87H638q-9 38-22.5 73.5T584-178ZM170-400h136q-3-20-4.5-39.5T300-480q0-21 1.5-40.5T306-560H170q-5 20-7.5 39.5T160-480q0 21 2.5 40.5T170-400Zm216 0h188q3-20 4.5-39.5T580-480q0-21-1.5-40.5T574-560H386q-3 20-4.5 39.5T380-480q0 21 1.5 40.5T386-400Zm268 0h136q5-20 7.5-39.5T800-480q0-21-2.5-40.5T790-560H654q3 20 4.5 39.5T660-480q0 21-1.5 40.5T654-400Zm-16-240h118q-29-50-72.5-87T584-782q18 33 31.5 68.5T638-640Zm-234 0h152q-12-44-31-83t-45-75q-26 36-45 75t-31 83Zm-200 0h118q9-38 22.5-73.5T376-782q-56 18-99.5 55T204-640Z"/></svg>`
const speakerIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="16px" style="vertical-align: text-bottom;" viewBox="0 -960 960 960" width="16px" fill="#5f6368"><path d="M480-480q-66 0-113-47t-47-113q0-66 47-113t113-47q66 0 113 47t47 113q0 66-47 113t-113 47ZM160-240v-32q0-34 17.5-62.5T224-378q62-31 126-46.5T480-440q66 0 130 15.5T736-378q29 15 46.5 43.5T800-272v32q0 33-23.5 56.5T720-160H240q-33 0-56.5-23.5T160-240Zm80 0h480v-32q0-11-5.5-20T700-306q-54-27-109-40.5T480-360q-56 0-111 13.5T260-306q-9 5-14.5 14t-5.5 20v32Zm240-320q33 0 56.5-23.5T560-640q0-33-23.5-56.5T480-720q-33 0-56.5 23.5T400-640q0 33 23.5 56.5T480-560Zm0-80Zm0 400Z"/></svg>`;
function getWaveStroke() { function getWaveStroke() {
const styles = getComputedStyle(document.documentElement); const styles = getComputedStyle(document.documentElement);
const v = styles.getPropertyValue("--wave-stroke").trim(); const v = styles.getPropertyValue("--wave-stroke").trim();
@@ -151,10 +177,16 @@ function fmt1(x) {
return Number.isFinite(n) ? n.toFixed(1) : x; return Number.isFinite(n) ? n.toFixed(1) : x;
} }
// Default WebSocket URL computation let host, port, protocol;
const host = window.location.hostname || "localhost"; port = 8000;
const port = window.location.port; if (isExtension) {
const protocol = window.location.protocol === "https:" ? "wss" : "ws"; host = "localhost";
protocol = "ws";
} else {
host = window.location.hostname || "localhost";
port = window.location.port;
protocol = window.location.protocol === "https:" ? "wss" : "ws";
}
const defaultWebSocketUrl = `${protocol}://${host}${port ? ":" + port : ""}/asr`; const defaultWebSocketUrl = `${protocol}://${host}${port ? ":" + port : ""}/asr`;
// Populate default caption and input // Populate default caption and input
@@ -200,10 +232,11 @@ function setupWebSocket() {
if (waitingForStop) { if (waitingForStop) {
statusText.textContent = "Processing finalized or connection closed."; statusText.textContent = "Processing finalized or connection closed.";
if (lastReceivedData) { if (lastReceivedData) {
renderLinesWithBuffer( renderLinesWithBuffer(
lastReceivedData.lines || [], lastReceivedData.lines || [],
lastReceivedData.buffer_diarization || "", lastReceivedData.buffer_diarization || "",
lastReceivedData.buffer_transcription || "", lastReceivedData.buffer_transcription || "",
lastReceivedData.buffer_translation || "",
0, 0,
0, 0,
true true
@@ -249,6 +282,7 @@ function setupWebSocket() {
lastReceivedData.lines || [], lastReceivedData.lines || [],
lastReceivedData.buffer_diarization || "", lastReceivedData.buffer_diarization || "",
lastReceivedData.buffer_transcription || "", lastReceivedData.buffer_transcription || "",
lastReceivedData.buffer_translation || "",
0, 0,
0, 0,
true true
@@ -269,6 +303,7 @@ function setupWebSocket() {
lines = [], lines = [],
buffer_transcription = "", buffer_transcription = "",
buffer_diarization = "", buffer_diarization = "",
buffer_translation = "",
remaining_time_transcription = 0, remaining_time_transcription = 0,
remaining_time_diarization = 0, remaining_time_diarization = 0,
status = "active_transcription", status = "active_transcription",
@@ -278,6 +313,7 @@ function setupWebSocket() {
lines, lines,
buffer_diarization, buffer_diarization,
buffer_transcription, buffer_transcription,
buffer_translation,
remaining_time_diarization, remaining_time_diarization,
remaining_time_transcription, remaining_time_transcription,
false, false,
@@ -291,6 +327,7 @@ function renderLinesWithBuffer(
lines, lines,
buffer_diarization, buffer_diarization,
buffer_transcription, buffer_transcription,
buffer_translation,
remaining_time_diarization, remaining_time_diarization,
remaining_time_transcription, remaining_time_transcription,
isFinalizing = false, isFinalizing = false,
@@ -309,6 +346,7 @@ function renderLinesWithBuffer(
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end, detected_language: it.detected_language })), lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end, detected_language: it.detected_language })),
buffer_transcription: buffer_transcription || "", buffer_transcription: buffer_transcription || "",
buffer_diarization: buffer_diarization || "", buffer_diarization: buffer_diarization || "",
buffer_translation: buffer_translation,
status: current_status, status: current_status,
showLoading, showLoading,
showTransLag, showTransLag,
@@ -335,19 +373,17 @@ function renderLinesWithBuffer(
let speakerLabel = ""; let speakerLabel = "";
if (item.speaker === -2) { if (item.speaker === -2) {
const silenceIcon = `<img class="silence-icon" src="/web/src/silence.svg" alt="Silence" />`;
speakerLabel = `<span class="silence">${silenceIcon}<span id='timeInfo'>${timeInfo}</span></span>`; speakerLabel = `<span class="silence">${silenceIcon}<span id='timeInfo'>${timeInfo}</span></span>`;
} else if (item.speaker == 0 && !isFinalizing) { } else if (item.speaker == 0 && !isFinalizing) {
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1( speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
remaining_time_diarization remaining_time_diarization
)}</span> second(s) of audio are undergoing diarization</span></span>`; )}</span> second(s) of audio are undergoing diarization</span></span>`;
} else if (item.speaker !== 0) { } else if (item.speaker !== 0) {
const speakerIcon = `<img class="speaker-icon" src="/web/src/speaker.svg" alt="Speaker ${item.speaker}" />`;
const speakerNum = `<span class="speaker-badge">${item.speaker}</span>`; const speakerNum = `<span class="speaker-badge">${item.speaker}</span>`;
speakerLabel = `<span id="speaker">${speakerIcon}${speakerNum}<span id='timeInfo'>${timeInfo}</span></span>`; speakerLabel = `<span id="speaker">${speakerIcon}${speakerNum}<span id='timeInfo'>${timeInfo}</span></span>`;
if (item.detected_language) { if (item.detected_language) {
speakerLabel += `<span class="label_language"><img src="/web/src/language.svg" alt="Detected language" width="12" height="12" /><span>${item.detected_language}</span></span>`; speakerLabel += `<span class="label_language">${languageIcon}<span>${item.detected_language}</span></span>`;
} }
} }
@@ -355,12 +391,11 @@ function renderLinesWithBuffer(
if (idx === lines.length - 1) { if (idx === lines.length - 1) {
if (!isFinalizing && item.speaker !== -2) { if (!isFinalizing && item.speaker !== -2) {
if (remaining_time_transcription > 0) {
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1( speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(
remaining_time_transcription remaining_time_transcription
)}</span>s</span></span>`; )}</span>s</span></span>`;
}
if (buffer_diarization && remaining_time_diarization > 0) { if (buffer_diarization && remaining_time_diarization) {
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1( speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(
remaining_time_diarization remaining_time_diarization
)}</span>s</span></span>`; )}</span>s</span></span>`;
@@ -385,12 +420,24 @@ function renderLinesWithBuffer(
} }
} }
} }
let translationContent = "";
if (item.translation) { if (item.translation) {
currentLineText += `<div class="label_translation"> translationContent += item.translation.trim();
<img src="/web/src/translate.svg" alt="Translation" width="12" height="12" /> }
<span>${item.translation}</span> if (idx === lines.length - 1 && buffer_translation) {
</div>`; const bufferPiece = isFinalizing
? buffer_translation
: `<span class="buffer_translation">${buffer_translation}</span>`;
translationContent += translationContent ? `${bufferPiece}` : bufferPiece;
}
if (translationContent.trim().length > 0) {
currentLineText += `
<div>
<div class="label_translation">
${translationIcon}
<span class="translation_text">${translationContent}</span>
</div>
</div>`;
} }
return currentLineText.trim().length > 0 || speakerLabel.length > 0 return currentLineText.trim().length > 0 || speakerLabel.length > 0
@@ -465,11 +512,44 @@ async function startRecording() {
console.log("Error acquiring wake lock."); console.log("Error acquiring wake lock.");
} }
const audioConstraints = selectedMicrophoneId let stream;
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
: { audio: true };
const stream = await navigator.mediaDevices.getUserMedia(audioConstraints); // chromium extension. in the future, both chrome page audio and mic will be used
if (isExtension) {
try {
stream = await new Promise((resolve, reject) => {
chrome.tabCapture.capture({audio: true}, (s) => {
if (s) {
resolve(s);
} else {
reject(new Error('Tab capture failed or not available'));
}
});
});
try {
outputAudioContext = new (window.AudioContext || window.webkitAudioContext)();
audioSource = outputAudioContext.createMediaStreamSource(stream);
audioSource.connect(outputAudioContext.destination);
} catch (audioError) {
console.warn('could not preserve system audio:', audioError);
}
statusText.textContent = "Using tab audio capture.";
} catch (tabError) {
console.log('Tab capture not available, falling back to microphone', tabError);
const audioConstraints = selectedMicrophoneId
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
: { audio: true };
stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
statusText.textContent = "Using microphone audio.";
}
} else if (isWebContext) {
const audioConstraints = selectedMicrophoneId
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
: { audio: true };
stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
}
audioContext = new (window.AudioContext || window.webkitAudioContext)(); audioContext = new (window.AudioContext || window.webkitAudioContext)();
analyser = audioContext.createAnalyser(); analyser = audioContext.createAnalyser();
@@ -603,6 +683,16 @@ async function stopRecording() {
audioContext = null; audioContext = null;
} }
if (audioSource) {
audioSource.disconnect();
audioSource = null;
}
if (outputAudioContext && outputAudioContext.state !== "closed") {
outputAudioContext.close()
outputAudioContext = null;
}
if (animationFrame) { if (animationFrame) {
cancelAnimationFrame(animationFrame); cancelAnimationFrame(animationFrame);
animationFrame = null; animationFrame = null;
@@ -654,7 +744,7 @@ function updateUI() {
statusText.textContent = "Please wait for processing to complete..."; statusText.textContent = "Please wait for processing to complete...";
} }
} else if (isRecording) { } else if (isRecording) {
statusText.textContent = "Recording..."; statusText.textContent = "";
} else { } else {
if ( if (
statusText.textContent !== "Finished processing audio! Ready to record again." && statusText.textContent !== "Finished processing audio! Ready to record again." &&
@@ -688,3 +778,40 @@ navigator.mediaDevices.addEventListener('devicechange', async () => {
console.log("Error re-enumerating microphones:", error); console.log("Error re-enumerating microphones:", error);
} }
}); });
settingsToggle.addEventListener("click", () => {
settingsDiv.classList.toggle("visible");
settingsToggle.classList.toggle("active");
});
if (isExtension) {
async function checkAndRequestPermissions() {
const micPermission = await navigator.permissions.query({
name: "microphone",
});
const permissionDisplay = document.getElementById("audioPermission");
if (permissionDisplay) {
permissionDisplay.innerText = `MICROPHONE: ${micPermission.state}`;
}
// if (micPermission.state !== "granted") {
// chrome.tabs.create({ url: "welcome.html" });
// }
const intervalId = setInterval(async () => {
const micPermission = await navigator.permissions.query({
name: "microphone",
});
if (micPermission.state === "granted") {
if (permissionDisplay) {
permissionDisplay.innerText = `MICROPHONE: ${micPermission.state}`;
}
clearInterval(intervalId);
}
}, 100);
}
void checkAndRequestPermissions();
}

View File

@@ -23,6 +23,24 @@ def get_inline_ui_html():
with resources.files('whisperlivekit.web').joinpath('live_transcription.js').open('r', encoding='utf-8') as f: with resources.files('whisperlivekit.web').joinpath('live_transcription.js').open('r', encoding='utf-8') as f:
js_content = f.read() js_content = f.read()
with resources.files('whisperlivekit.web').joinpath('pcm_worklet.js').open('r', encoding='utf-8') as f:
worklet_code = f.read()
with resources.files('whisperlivekit.web').joinpath('recorder_worker.js').open('r', encoding='utf-8') as f:
worker_code = f.read()
js_content = js_content.replace(
'await audioContext.audioWorklet.addModule("/web/pcm_worklet.js");',
'const workletBlob = new Blob([`' + worklet_code + '`], { type: "application/javascript" });\n' +
'const workletUrl = URL.createObjectURL(workletBlob);\n' +
'await audioContext.audioWorklet.addModule(workletUrl);'
)
js_content = js_content.replace(
'recorderWorker = new Worker("/web/recorder_worker.js");',
'const workerBlob = new Blob([`' + worker_code + '`], { type: "application/javascript" });\n' +
'const workerUrl = URL.createObjectURL(workerBlob);\n' +
'recorderWorker = new Worker(workerUrl);'
)
# SVG files # SVG files
with resources.files('whisperlivekit.web').joinpath('src', 'system_mode.svg').open('r', encoding='utf-8') as f: with resources.files('whisperlivekit.web').joinpath('src', 'system_mode.svg').open('r', encoding='utf-8') as f:
system_svg = f.read() system_svg = f.read()
@@ -33,15 +51,18 @@ def get_inline_ui_html():
with resources.files('whisperlivekit.web').joinpath('src', 'dark_mode.svg').open('r', encoding='utf-8') as f: with resources.files('whisperlivekit.web').joinpath('src', 'dark_mode.svg').open('r', encoding='utf-8') as f:
dark_svg = f.read() dark_svg = f.read()
dark_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(dark_svg.encode('utf-8')).decode('utf-8')}" dark_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(dark_svg.encode('utf-8')).decode('utf-8')}"
with resources.files('whisperlivekit.web').joinpath('src', 'settings.svg').open('r', encoding='utf-8') as f:
settings = f.read()
settings_uri = f"data:image/svg+xml;base64,{base64.b64encode(settings.encode('utf-8')).decode('utf-8')}"
# Replace external references # Replace external references
html_content = html_content.replace( html_content = html_content.replace(
'<link rel="stylesheet" href="/web/live_transcription.css" />', '<link rel="stylesheet" href="live_transcription.css" />',
f'<style>\n{css_content}\n</style>' f'<style>\n{css_content}\n</style>'
) )
html_content = html_content.replace( html_content = html_content.replace(
'<script src="/web/live_transcription.js"></script>', '<script src="live_transcription.js"></script>',
f'<script>\n{js_content}\n</script>' f'<script>\n{js_content}\n</script>'
) )
@@ -61,6 +82,11 @@ def get_inline_ui_html():
f'<img src="{dark_data_uri}" alt="" />' f'<img src="{dark_data_uri}" alt="" />'
) )
html_content = html_content.replace(
'<img src="web/src/settings.svg" alt="Settings" />',
f'<img src="{settings_uri}" alt="" />'
)
return html_content return html_content
except Exception as e: except Exception as e:

View File

@@ -0,0 +1,463 @@
import hashlib
import io
import json
import os
import urllib
import warnings
from typing import Dict, List, Optional, Union
import torch
from tqdm import tqdm
from pathlib import Path
from torch import Tensor
from whisperlivekit.whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
from whisperlivekit.whisper.decoding import DecodingOptions, DecodingResult, decode, detect_language
from whisperlivekit.whisper.model import ModelDimensions, Whisper
from whisperlivekit.whisper.transcribe import transcribe
from whisperlivekit.whisper.version import __version__
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
}
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
"turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
}
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, os.path.basename(url))
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
with open(download_target, "rb") as f:
model_bytes = f.read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
else:
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
return model_bytes if in_memory else download_target
def available_models() -> List[str]:
"""Returns the names of available models"""
return list(_MODELS.keys())
def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
"""
attempt to infer ModelDimensions from a HF style config.json located
next to the given checkpoint, usefull for distilled models
"""
candidates = []
if os.path.isdir(path):
candidates.append(os.path.join(path, "config.json"))
else:
candidates.append(os.path.join(os.path.dirname(path), "config.json"))
for candidate in candidates:
if not os.path.isfile(candidate):
continue
with open(candidate, "r", encoding="utf-8") as f:
config = json.load(f)
try:
return ModelDimensions(
n_mels=config["num_mel_bins"],
n_audio_ctx=config["max_source_positions"],
n_audio_state=config["d_model"],
n_audio_head=config["encoder_attention_heads"],
n_audio_layer=config.get("encoder_layers")
or config["num_hidden_layers"],
n_vocab=config["vocab_size"],
n_text_ctx=config["max_target_positions"],
n_text_state=config["d_model"],
n_text_head=config["decoder_attention_heads"],
n_text_layer=config["decoder_layers"],
)
except KeyError as err:
warnings.warn(f"Missing key {err} in HuggingFace config {candidate}")
return None
return None
def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
converts a HF checkpoint state_dict into the naming convention used by
default whisper
"""
if not any(k.startswith("model.") for k in state_dict):
return state_dict
def map_block(prefix: str, target_prefix: str, remainder: str) -> Optional[str]:
if remainder.startswith("self_attn."):
suffix = remainder.split(".", 1)[1]
mapping = {
"q_proj": "attn.query",
"k_proj": "attn.key",
"v_proj": "attn.value",
"out_proj": "attn.out",
}
stem = mapping.get(suffix.split(".")[0])
if stem:
rest = suffix.split(".", 1)[1] if "." in suffix else ""
return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
elif remainder == "self_attn_layer_norm.weight":
return f"{target_prefix}.attn_ln.weight"
elif remainder == "self_attn_layer_norm.bias":
return f"{target_prefix}.attn_ln.bias"
elif remainder.startswith("encoder_attn."):
suffix = remainder.split(".", 1)[1]
mapping = {
"q_proj": "cross_attn.query",
"k_proj": "cross_attn.key",
"v_proj": "cross_attn.value",
"out_proj": "cross_attn.out",
}
stem = mapping.get(suffix.split(".", 1)[0])
if stem:
rest = suffix.split(".", 1)[1] if "." in suffix else ""
return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
elif remainder == "encoder_attn_layer_norm.weight":
return f"{target_prefix}.cross_attn_ln.weight"
elif remainder == "encoder_attn_layer_norm.bias":
return f"{target_prefix}.cross_attn_ln.bias"
elif remainder.startswith("fc1."):
return f"{target_prefix}.mlp.0.{remainder.split('.',1)[1]}"
elif remainder.startswith("fc2."):
return f"{target_prefix}.mlp.2.{remainder.split('.',1)[1]}"
elif remainder == "final_layer_norm.weight":
return f"{target_prefix}.mlp_ln.weight"
elif remainder == "final_layer_norm.bias":
return f"{target_prefix}.mlp_ln.bias"
return None
converted = {}
for key, value in state_dict.items():
if not key.startswith("model."):
continue
subkey = key[len("model.") :]
if subkey.startswith("encoder.layers."):
parts = subkey.split(".")
layer_idx = parts[2]
remainder = ".".join(parts[3:])
mapped = map_block(subkey, f"encoder.blocks.{layer_idx}", remainder)
elif subkey.startswith("decoder.layers."):
parts = subkey.split(".")
layer_idx = parts[2]
remainder = ".".join(parts[3:])
mapped = map_block(subkey, f"decoder.blocks.{layer_idx}", remainder)
elif subkey.startswith("encoder.conv") or subkey.startswith("decoder.conv"):
mapped = subkey
elif subkey == "encoder.embed_positions.weight":
mapped = "encoder.positional_embedding"
elif subkey == "decoder.embed_positions.weight":
mapped = "decoder.positional_embedding"
elif subkey == "encoder.layer_norm.weight":
mapped = "encoder.ln_post.weight"
elif subkey == "encoder.layer_norm.bias":
mapped = "encoder.ln_post.bias"
elif subkey.startswith("decoder.embed_tokens."):
mapped = subkey.replace("embed_tokens", "token_embedding", 1)
elif subkey == "decoder.layer_norm.weight":
mapped = "decoder.ln.weight"
elif subkey == "decoder.layer_norm.bias":
mapped = "decoder.ln.bias"
else:
mapped = None
if mapped:
converted[mapped] = value
return converted if converted else state_dict
def _load_lora_state(lora_path: str):
safe_path = os.path.join(lora_path, "adapter_model.safetensors")
bin_path = os.path.join(lora_path, "adapter_model.bin")
if os.path.isfile(safe_path):
try:
from safetensors.torch import load_file
except ImportError as exc:
raise ImportError(
"Loading LoRA adapters stored as .safetensors requires the `safetensors` package."
) from exc
return load_file(safe_path)
if os.path.isfile(bin_path):
return torch.load(bin_path, map_location="cpu")
raise FileNotFoundError(
f"No adapter weights found under {lora_path}. Expected adapter_model.safetensors or adapter_model.bin."
)
def _collapse_hf_module_name(module: str):
if module.startswith("base_model."):
module = module[len("base_model.") :]
if module.startswith("model.model."):
module = module[len("model.") :]
if not module.startswith("model."):
module = f"model.{module}"
return module
def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str]):
if not lora_path:
return
config_path = os.path.join(lora_path, "adapter_config.json")
if not os.path.isfile(config_path):
raise FileNotFoundError(f"Missing adapter_config.json inside {lora_path}")
with open(config_path, "r", encoding="utf-8") as handle:
config = json.load(handle)
if config.get("peft_type") != "LORA":
raise ValueError("Only LoRA adapters are supported.")
r = config.get("r")
alpha = config.get("lora_alpha") or config.get("alpha")
if not r or not alpha:
raise ValueError("LoRA config must include `r` and `lora_alpha`.")
scaling = alpha / r
adapter_state = _load_lora_state(lora_path)
lora_layers: Dict[str, Dict[str, Tensor]] = {}
for key, tensor in adapter_state.items():
if key.endswith("lora_A.weight"):
module = key[: -len(".lora_A.weight")]
lora_layers.setdefault(module, {})["A"] = tensor
elif key.endswith("lora_B.weight"):
module = key[: -len(".lora_B.weight")]
lora_layers.setdefault(module, {})["B"] = tensor
if not lora_layers:
raise ValueError(f"No LoRA tensors found in {lora_path}")
for module, parts in lora_layers.items():
if "A" not in parts or "B" not in parts:
raise ValueError(f"Incomplete LoRA tensors for module '{module}'")
hf_module = _collapse_hf_module_name(module)
hf_weight_key = f"{hf_module}.weight"
delta = parts["B"] @ parts["A"]
delta = delta * scaling
converted = _convert_hf_state_dict({hf_weight_key: delta})
if not converted:
raise KeyError(f"Failed to map LoRA module '{module}' into Whisper state dict.")
target_name, delta_tensor = next(iter(converted.items()))
if target_name not in state_dict:
raise KeyError(
f"LoRA module '{module}' mapped to '{target_name}', but the base model has no such parameter."
)
state_dict[target_name] = state_dict[target_name] + delta_tensor.to(
dtype=state_dict[target_name].dtype, device=state_dict[target_name].device
)
def load_model(
name: str,
device: Optional[Union[str, torch.device]] = None,
download_root: str = None,
in_memory: bool = False,
decoder_only: bool = False,
custom_alignment_heads: Optional[str] = None,
lora_path: Optional[str] = None,
) -> Whisper:
"""
Load a Whisper ASR model
Parameters
----------
name : str
one of the official model names listed by `whisper.available_models()`, or
path to a model checkpoint containing the model dimensions and the model state_dict.
device : Union[str, torch.device]
the PyTorch device to put the model into
download_root: str
path to download the model files; by default, it uses "~/.cache/whisper"
in_memory: bool
whether to preload the model weights into host memory
lora_path: str
optional directory containing PEFT LoRA adapter weights (adapter_config + adapter_model)
Returns
-------
model : Whisper
The Whisper ASR model instance
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name
else:
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
)
alignment_heads = _ALIGNMENT_HEADS.get(name, None)
if custom_alignment_heads:
alignment_heads = custom_alignment_heads.encode()
if isinstance(checkpoint_file, Path) and checkpoint_file.suffix == '.safetensors':
try:
from safetensors.torch import load_file
except ImportError:
raise ImportError("Please install safetensors to load .safetensors model files: `pip install safetensors`")
if in_memory:
checkpoint = load_file(checkpoint_file, device=device)
else:
checkpoint = load_file(checkpoint_file, device=device)
else:
with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file
dims_cfg = checkpoint.get("dims") if isinstance(checkpoint, dict) else None
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
state_dict = checkpoint["model_state_dict"]
else:
state_dict = checkpoint
state_dict = _convert_hf_state_dict(state_dict)
_apply_lora_adapter(state_dict, lora_path)
if dims_cfg is not None:
dims = ModelDimensions(**dims_cfg)
else:
dims = _infer_dims_from_config(name)
if dims is None:
raise RuntimeError(
"Could not determine model dimensions. "
"Ensure the checkpoint includes 'dims' or a HuggingFace config.json is present."
)
if not isinstance(state_dict, dict):
state_dict = checkpoint
model = Whisper(dims, decoder_only=decoder_only)
if decoder_only:
state_dict = {
k: v for k, v in state_dict.items()
if 'encoder' not in k
}
model.load_state_dict(state_dict)
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)
return model.to(device)
def convert_encoder_to_coreml(
model_name = "base",
output_path= "whisper_encoder.mlpackage",
dummy_frames = 3000, #Number of time frames to use for the dummy mel input during tracing
precision = "float16",
):
import coremltools as ct
model = load_model(model_name, device="cpu", decoder_only=False)
encoder = model.encoder.eval().cpu()
dummy_input = torch.randn(
1,
model.dims.n_mels,
dummy_frames,
dtype=next(encoder.parameters()).dtype,
)
with torch.no_grad():
traced_encoder = torch.jit.trace(encoder, dummy_input)
precision_map = {
"float16": ct.precision.FLOAT16,
"fp16": ct.precision.FLOAT16,
"float32": ct.precision.FLOAT32,
"fp32": ct.precision.FLOAT32,
}
coreml_precision = precision_map[precision.lower()]
mlmodel = ct.convert(
traced_encoder,
inputs=[ct.TensorType(name="mel", shape=dummy_input.shape)],
convert_to= "mlprogram",
compute_precision=coreml_precision,
)
output_path = Path(output_path)
mlmodel.save(str(output_path))
return output_path
# if __name__ == "__main__":
# convert_encoder_to_coreml(model_name="tiny", output_path="whisper_encoder.mlpackage", dummy_frames=3000, precision="float16", convert_to="mlprogram")

View File

@@ -1,110 +0,0 @@
#!/usr/bin/env python3
import sys
import numpy as np
import librosa
from functools import lru_cache
import time
import logging
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
logger = logging.getLogger(__name__)
WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
","
)
def create_tokenizer(lan):
"""returns an object that has split function that works like the one of MosesTokenizer"""
assert (
lan in WHISPER_LANG_CODES
), "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
if lan == "uk":
import tokenize_uk
class UkrainianTokenizer:
def split(self, text):
return tokenize_uk.tokenize_sents(text)
return UkrainianTokenizer()
# supported by fast-mosestokenizer
if (
lan
in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
):
from mosestokenizer import MosesSentenceSplitter
return MosesSentenceSplitter(lan)
# the following languages are in Whisper, but not in wtpsplit:
if (
lan
in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
):
logger.debug(
f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
)
lan = None
from wtpsplit import WtP
# downloads the model from huggingface on the first use
wtp = WtP("wtp-canine-s-12l-no-adapters")
class WtPtok:
def split(self, sent):
return wtp.split(sent, lang_code=lan)
return WtPtok()
def backend_factory(args):
backend = args.backend
if backend == "openai-api":
logger.debug("Using OpenAI API.")
asr = OpenaiApiASR(lan=args.lan)
else:
if backend == "faster-whisper":
asr_cls = FasterWhisperASR
elif backend == "mlx-whisper":
asr_cls = MLXWhisper
else:
asr_cls = WhisperTimestampedASR
# Only for FasterWhisperASR and WhisperTimestampedASR
size = args.model
t = time.time()
logger.info(f"Loading Whisper {size} model for language {args.lan}...")
asr = asr_cls(
modelsize=size,
lan=args.lan,
cache_dir=getattr(args, 'model_cache_dir', None),
model_dir=getattr(args, 'model_dir', None),
)
e = time.time()
logger.info(f"done. It took {round(e-t,2)} seconds.")
# Apply common configurations
if getattr(args, "vad", False): # Checks if VAD argument is present and True
logger.info("Setting VAD filter")
asr.use_vad()
language = args.lan
if args.task == "translate":
if backend != "simulstreaming":
asr.set_translate_task()
tgt_language = "en" # Whisper translates into English
else:
tgt_language = language # Whisper transcribes in this language
# Create the tokenizer
if args.buffer_trimming == "sentence":
tokenizer = create_tokenizer(tgt_language)
else:
tokenizer = None
return asr, tokenizer