3 Commits

Author SHA1 Message Date
Dominik Macháček
5b5805231e better words in README.md 2024-01-17 16:32:12 +01:00
Dominik Macháček
cdb2a0ba17 seamless done 2024-01-17 16:22:28 +01:00
Dominik Macháček
1703969432 seamless streaming integrated 2024-01-16 16:45:33 +01:00
107 changed files with 1348 additions and 117737 deletions

30
.gitignore vendored
View File

@@ -55,6 +55,22 @@ coverage.xml
*.mo *.mo
*.pot *.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder # PyBuilder
target/ target/
@@ -111,17 +127,3 @@ dmypy.json
# Pyre type checker # Pyre type checker
.pyre/ .pyre/
*.wav
run_*.sh
# Downloaded models
*.pt
# Debug & testing
test_*.py
launch.json
.DS_Store
test/*
nllb-200-distilled-600M-ctranslate2/*
*.mp3

View File

@@ -1,46 +0,0 @@
# Contributing
Thank you for considering contributing ! We appreciate your time and effort to help make this project better.
## Before You Start
1. **Search for Existing Issues or Discussions:**
- Before opening a new issue or discussion, please check if there's already an existing one related to your topic. This helps avoid duplicates and keeps discussions centralized.
2. **Discuss Your Contribution:**
- If you plan to make a significant change, it's advisable to discuss it in an issue first. This ensures that your contribution aligns with the project's goals and avoids duplicated efforts.
3. **General questions about whisper streaming web:**
- For general questions about whisper streaming web, use the discussion space on GitHub. This helps in fostering a collaborative environment and encourages knowledge-sharing.
## Opening Issues
If you encounter a problem with WhisperLiveKit or want to suggest an improvement, please follow these guidelines when opening an issue:
- **Bug Reports:**
- Clearly describe the error. **Please indicate the parameters you use, especially the model(s)**
- Provide a minimal, reproducible example that demonstrates the issue.
- **Feature Requests:**
- Clearly outline the new feature you are proposing.
- Explain how it would benefit the project.
## Opening Pull Requests
We welcome and appreciate contributions! To ensure a smooth review process, please follow these guidelines when opening a pull request:
- **Commit Messages:**
- Write clear and concise commit messages, explaining the purpose of each change.
- **Documentation:**
- Update documentation when introducing new features or making changes that impact existing functionality.
- **Tests:**
- If applicable, add or update tests to cover your changes.
- **Discuss Before Major Changes:**
- If your PR includes significant changes, discuss it in an issue first.
## Thank You
Your contributions make WhisperLiveKit better for everyone. Thank you for your time and dedication!

View File

@@ -1,91 +0,0 @@
# 1. Simulstreaming: Decouple the encoder for faster inference
Simulstreaming encoder time (whisperlivekit/simul_whisper/simul_whisper.py l. 397) experimentations :
On macOS Apple Silicon M4 :
| Encoder | base.en | small |
|--------|---------|-------|
| WHISPER (no modification) | 0.35s | 1.09s |
| FASTER_WHISPER | 0.4s | 1.20s |
| MLX_WHISPER | 0.07s | 0.20s |
Memory saved by only loading encoder for optimized framework:
For tiny.en, mlx whisper:
Sizes MLX whisper:
Decoder weights: 59110771 bytes
Encoder weights: 15268874 bytes
# 2. Translation: Faster model for each system
## Benchmark Results
Testing on MacBook M3 with NLLB-200-distilled-600M model:
### Standard Transformers vs CTranslate2
| Test Text | Standard Inference Time | CTranslate2 Inference Time | Speedup |
|-----------|-------------------------|---------------------------|---------|
| UN Chief says there is no military solution in Syria | 0.9395s | 2.0472s | 0.5x |
| The rapid advancement of AI technology is transforming various industries | 0.7171s | 1.7516s | 0.4x |
| Climate change poses a significant threat to global ecosystems | 0.8533s | 1.8323s | 0.5x |
| International cooperation is essential for addressing global challenges | 0.7209s | 1.3575s | 0.5x |
| The development of renewable energy sources is crucial for a sustainable future | 0.8760s | 1.5589s | 0.6x |
**Results:**
- Total Standard time: 4.1068s
- Total CTranslate2 time: 8.5476s
- CTranslate2 is slower on this system --> Use Transformers, and ideally we would have an mlx implementation.
# 3. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm
Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions.
## Problem Statement
- Input: `self.total_preds` with shape `(x, x, 4)` - predictions for 4 speakers
- Output: Constrained predictions with shape `(x, x, 2)` - predictions for 2 speakers
#
### Initial Setup
For each time step `i`, we have a ranking of 4 speaker predictions (1-4). When only 2 speakers are present, the model will have close predictions for the 2 active speaker positions.
Instead of `np.argmax(preds_np, axis=1)`, we take the top 2 predictions and build a dynamic 4→2 mapping that can evolve over time.
### Algorithm
```python
top_2_speakers = np.argsort(preds_np, axis=1)[:, -2:]
```
- `DS_a_{i}`: Top detected speaker for prediction i
- `DS_b_{i}`: Second detected speaker for prediction i
- `AS_{i}`: Attributed speaker for prediction i
- `GTS_A`: Ground truth speaker A
- `GTS_B`: Ground truth speaker B
- `DIST(a, b)`: Distance between detected speakers a and b
3. **Attribution Logic**
```
AS_0 ← A
AS_1 ← B
IF DIST(DS_a_0, DS_a_1) < DIST(DS_a_0, DS_a_2) AND
DIST(DS_a_0, DS_a_1) < DIST(DS_a_1, DS_a_2):
# Likely that DS_a_0 = DS_a_1 (same speaker)
AS_1 ← A
AS_2 ← B
ELIF DIST(DS_a_0, DS_a_2) < DIST(DS_a_0, DS_a_1) AND
DIST(DS_a_0, DS_a_2) < DIST(DS_a_1, DS_a_2):
AS_2 ← A
ELSE:
AS_2 ← B
to finish
```

View File

@@ -1,84 +0,0 @@
FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
WORKDIR /app
ARG EXTRAS
ARG HF_PRECACHE_DIR
ARG HF_TKN_FILE
RUN apt-get update && \
apt-get install -y --no-install-recommends \
python3 \
python3-pip \
python3-venv \
ffmpeg \
git \
build-essential \
python3-dev \
ca-certificates && \
rm -rf /var/lib/apt/lists/*
RUN python3 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# timeout/retries for large torch wheels
RUN pip3 install --upgrade pip setuptools wheel && \
pip3 --disable-pip-version-check install --timeout=120 --retries=5 \
--index-url https://download.pytorch.org/whl/cu129 \
torch torchaudio \
|| (echo "Initial install failed — retrying with extended timeout..." && \
pip3 --disable-pip-version-check install --timeout=300 --retries=3 \
--index-url https://download.pytorch.org/whl/cu129 \
torch torchvision torchaudio)
COPY . .
# Install WhisperLiveKit directly, allowing for optional dependencies
# Example: --build-arg EXTRAS="translation"
RUN if [ -n "$EXTRAS" ]; then \
echo "Installing with extras: [$EXTRAS]"; \
pip install --no-cache-dir "whisperlivekit[$EXTRAS]"; \
else \
echo "Installing base package only"; \
pip install --no-cache-dir whisperlivekit; \
fi
# In-container caching for Hugging Face models by:
# A) Make the cache directory persistent via an anonymous volume.
# Note: This only persists for a single, named container. This is
# only for convenience at de/test stage.
# For prod, it is better to use a named volume via host mount/k8s.
VOLUME ["/root/.cache/huggingface/hub"]
# or
# B) Conditionally copy a local pre-cache from the build context to the
# container's cache via the HF_PRECACHE_DIR build-arg.
# WARNING: This will copy ALL files in the pre-cache location.
# Conditionally copy a cache directory if provided
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
mkdir -p /root/.cache/huggingface/hub && \
cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
else \
echo "No local Hugging Face cache specified, skipping copy"; \
fi
# Conditionally copy a Hugging Face token if provided. Useful for Diart backend (pyannote audio models)
RUN if [ -n "$HF_TKN_FILE" ]; then \
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
mkdir -p /root/.cache/huggingface && \
cp $HF_TKN_FILE /root/.cache/huggingface/token; \
else \
echo "No Hugging Face token file specified, skipping token setup"; \
fi
EXPOSE 8000
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
CMD ["--model", "medium"]

View File

@@ -1,61 +0,0 @@
FROM python:3.13-slim
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
WORKDIR /app
ARG EXTRAS
ARG HF_PRECACHE_DIR
ARG HF_TKN_FILE
RUN apt-get update && \
apt-get install -y --no-install-recommends \
ffmpeg \
git \
build-essential \
python3-dev && \
rm -rf /var/lib/apt/lists/*
# Install CPU-only PyTorch
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
COPY . .
# Install WhisperLiveKit directly, allowing for optional dependencies
RUN if [ -n "$EXTRAS" ]; then \
echo "Installing with extras: [$EXTRAS]"; \
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
else \
echo "Installing base package only"; \
pip install --no-cache-dir whisperlivekit; \
fi
# Enable in-container caching for Hugging Face models
VOLUME ["/root/.cache/huggingface/hub"]
# Conditionally copy a local pre-cache from the build context
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
mkdir -p /root/.cache/huggingface/hub && \
cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
else \
echo "No local Hugging Face cache specified, skipping copy"; \
fi
# Conditionally copy a Hugging Face token if provided
RUN if [ -n "$HF_TKN_FILE" ]; then \
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
mkdir -p /root/.cache/huggingface && \
cp $HF_TKN_FILE /root/.cache/huggingface/token; \
else \
echo "No Hugging Face token file specified, skipping token setup"; \
fi
# Expose port for the transcription server
EXPOSE 8000
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
# Default args - you might want to use a smaller model for CPU
CMD ["--model", "tiny"]

223
LICENSE
View File

@@ -1,210 +1,21 @@
Apache License MIT License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION Copyright (c) 2023 ÚFAL
1. Definitions. Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
"License" shall mean the terms and conditions for use, reproduction, The above copyright notice and this permission notice shall be included in all
and distribution as defined by Sections 1 through 9 of this document. copies or substantial portions of the Software.
"Licensor" shall mean the copyright owner or entity authorized by THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
the copyright owner that is granting the License. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
"Legal Entity" shall mean the union of the acting entity and all AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
other entities that control, are controlled by, or are under common LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
control with that entity. For the purposes of this definition, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
"control" means (i) the power, direct or indirect, to cause the SOFTWARE.
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2025 Quentin Fuxa
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
---
## Based on:
- **SimulWhisper** by Speech and Audio Technology LAB of Tsinghua University Apache-2.0 https://github.com/ufal/SimulStreaming
- **SimulStreaming** by ÚFAL MIT License https://github.com/ufal/SimulStreaming
- **NeMo** by NVidia - Apache-2.0 - https://github.com/NVIDIA-NeMo/NeMo
- **whisper_streaming** by ÚFAL MIT License https://github.com/ufal/whisper_streaming.
- **silero-vad** by Snakers4 MIT License https://github.com/snakers4/silero-vad.
- **Diart** by juanmc2005 MIT License https://github.com/juanmc2005/diart.

448
README.md
View File

@@ -1,275 +1,241 @@
<h1 align="center">WLK</h1> # whisper_streaming
<p align="center"><b>WhisperLiveKit: Ultra-low-latency, self-hosted speech-to-text with speaker identification</b></p> Whisper realtime streaming for long speech-to-text transcription and translation
**Turning Whisper into Real-Time Transcription System**
Demonstration paper, by Dominik Macháček, Raj Dabre, Ondřej Bojar, 2023
Abstract: Whisper is one of the recent state-of-the-art multilingual speech recognition and translation models, however, it is not designed for real time transcription. In this paper, we build on top of Whisper and create Whisper-Streaming, an implementation of real-time speech transcription and translation of Whisper-like models. Whisper-Streaming uses local agreement policy with self-adaptive latency to enable streaming transcription. We show that Whisper-Streaming achieves high quality and 3.3 seconds latency on unsegmented long-form speech transcription test set, and we demonstrate its robustness and practical usability as a component in live transcription service at a multilingual conference.
<p align="center"> Paper in proceedings: http://www.afnlp.org/conferences/ijcnlp2023/proceedings/main-demo/cdrom/pdf/2023.ijcnlp-demo.3.pdf
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
</p>
Demo video: https://player.vimeo.com/video/840442741
<p align="center"> [Slides](http://ufallab.ms.mff.cuni.cz/~machacek/pre-prints/AACL23-2.11.2023-Turning-Whisper-oral.pdf) -- 15 minutes oral presentation at IJCNLP-AACL 2023
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
<a href="https://huggingface.co/qfuxa/whisper-base-french-lora">
<img alt="Hugging Face Weights" src="https://img.shields.io/badge/🤗-Hugging%20Face%20Weights-yellow" />
</a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-Apache 2.0-dark_green"></a>
</p>
Please, cite us. [Bibtex citation](http://www.afnlp.org/conferences/ijcnlp2023/proceedings/main-demo/cdrom/bib/2023.ijcnlp-demo.3.bib):
#### Powered by Leading Research:
- Simul-[Whisper](https://arxiv.org/pdf/2406.10052)/[Streaming](https://arxiv.org/abs/2506.17077) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
- [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simulatenous translation from & to 200 languages.
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - Enterprise-grade Voice Activity Detection
> **Why not just run a simple Whisper model on every audio batch?** Whisper is designed for complete utterances, not real-time chunks. Processing small segments loses context, cuts off words mid-syllable, and produces poor transcription. WhisperLiveKit uses state-of-the-art simultaneous speech research for intelligent buffering and incremental processing.
### Architecture
<img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />
*The backend supports multiple concurrent users. Voice Activity Detection reduces overhead when no voice is detected.*
### Installation & Quick Start
```bash
pip install whisperlivekit
``` ```
> You can also clone the repo and `pip install -e .` for the latest version. @InProceedings{machacek-dabre-bojar:2023:ijcnlp,
author = {Macháček, Dominik and Dabre, Raj and Bojar, Ondřej},
#### Quick Start title = {Turning Whisper into Real-Time Transcription System},
1. **Start the transcription server:** booktitle = {System Demonstrations},
```bash month = {November},
wlk --model base --language en year = {2023},
``` address = {Bali, Indonesia},
publisher = {Asian Federation of Natural Language Processing},
2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time! pages = {17--24},
}
> - See [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
> - Check the [troubleshooting guide](docs/troubleshooting.md) for step-by-step fixes collected from recent GPU setup/env issues.
> - The CLI entry point is exposed as both `wlk` and `whisperlivekit-server`; they are equivalent.
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
#### Use it to capture audio from web pages.
Go to `chrome-extension` for instructions.
<p align="center">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="600">
</p>
#### Optional Dependencies
| Optional | `pip install` |
|-----------|-------------|
| **Windows/Linux optimizations** | `faster-whisper` |
| **Apple Silicon optimizations** | `mlx-whisper` |
| **Translation** | `nllw` |
| **Speaker diarization** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| OpenAI API | `openai` |
| *[Not recommanded]* Speaker diarization with Diart | `diart` |
See **Parameters & Configuration** below on how to use them.
### Usage Examples
**Command-line Interface**: Start the transcription server with various options:
```bash
# Large model and translate from french to danish
wlk --model large-v3 --language fr --target-language da
# Diarization and server listening on */80
wlk --host 0.0.0.0 --port 80 --model medium --diarization --language fr
``` ```
## Installation
**Python API Integration**: Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes. 1) ``pip install librosa`` -- audio processing library
```python 2) **Whisper backend**.
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect Two alternative backends are integrated. The most recommended one is [faster-whisper](https://github.com/guillaumekln/faster-whisper) with GPU support. Follow their instructions for NVIDIA libraries -- we succeeded with CUDNN 8.5.0 and CUDA 11.7. Install with `pip install faster-whisper`.
from fastapi.responses import HTMLResponse
from whisperlivekit import AudioProcessor, TranscriptionEngine, parse_args Alternative, less restrictive, but slower backend is [whisper-timestamped](https://github.com/linto-ai/whisper-timestamped): `pip install git+https://github.com/linto-ai/whisper-timestamped`
transcription_engine = None The backend is loaded only when chosen. The unused one does not have to be installed.
@asynccontextmanager Or: **Seamless Streaming** -- alternative to Whisper, wrapped to enable the same operation modes and input/output format.
async def lifespan(app: FastAPI):
global transcription_engine
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
yield
app = FastAPI(lifespan=lifespan) `pip install fairseq2 pydub sentencepiece git+https://github.com/facebookresearch/seamless_communication.git`
async def handle_websocket_results(websocket: WebSocket, results_generator): Installation suggested [here](https://github.com/facebookresearch/seamless_communication/blob/main/Seamless_Tutorial.ipynb), for special torch version cases refer to [fairseq2](https://github.com/facebookresearch/fairseq2#variants).
async for response in results_generator:
await websocket.send_json(response)
await websocket.send_json({"type": "ready_to_stop"})
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
global transcription_engine
# Create a new AudioProcessor for each connection, passing the shared engine 3) Optional, not recommended: sentence segmenter (aka sentence tokenizer)
audio_processor = AudioProcessor(transcription_engine=transcription_engine)
results_generator = await audio_processor.create_tasks() Two buffer trimming options are integrated and evaluated for Whisper backends. They have impact on
results_task = asyncio.create_task(handle_websocket_results(websocket, results_generator)) the quality and latency. The default "segment" option performs better according
await websocket.accept() to our tests and does not require any sentence segmentation installed.
while True:
message = await websocket.receive_bytes() The other option, "sentence" -- trimming at the end of confirmed sentences,
await audio_processor.process_audio(message) requires sentence segmenter installed. It splits punctuated text to sentences by full
stops, avoiding the dots that are not full stops. The segmenters are language
specific. The unused one does not have to be installed. We integrate the
following segmenters, but suggestions for better alternatives are welcome.
- `pip install opus-fast-mosestokenizer` for the languages with codes `as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh`
- `pip install tokenize_uk` for Ukrainian -- `uk`
- for other languages, we integrate a good performing multi-lingual model of `wtpslit`. It requires `pip install torch wtpsplit`, and its neural model `wtp-canine-s-12l-no-adapters`. It is downloaded to the default huggingface cache during the first use.
- we did not find a segmenter for languages `as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt` that are supported by Whisper and not by wtpsplit. The default fallback option for them is wtpsplit with unspecified language. Alternative suggestions welcome.
In case of installation issues of opus-fast-mosestokenizer, especially on Windows and Mac, we recommend using only the "segment" option that does not require it.
## Usage
### Real-time simulation from audio file
```
usage: whisper_online.py [-h] [--min-chunk-size MIN_CHUNK_SIZE] [--model {tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large}] [--model_cache_dir MODEL_CACHE_DIR]
[--model_dir MODEL_DIR] [--lan LAN] [--task {transcribe,translate}] [--backend {faster-whisper,whisper_timestamped,seamless}] [--vad] [--buffer_trimming {sentence,segment}]
[--buffer_trimming_sec BUFFER_TRIMMING_SEC] [--start_at START_AT] [--offline] [--comp_unaware]
audio_path
positional arguments:
audio_path Filename of 16kHz mono channel wav, on which live streaming is simulated.
options:
-h, --help show this help message and exit
--min-chunk-size MIN_CHUNK_SIZE
Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received
by this time. Applicable both to Whisper and seamless.
--model {tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large}
Name size of the Whisper model to use (default: large-v2). The model is automatically downloaded from the model hub if not present in model cache dir. Not applicable to seamless.
--model_cache_dir MODEL_CACHE_DIR
Overriding the default model cache dir where models downloaded from the hub are saved. Not applicable to seamless.
--model_dir MODEL_DIR
Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter. Not applicable to seamless.
--lan LAN, --language LAN
Language code for transcription, e.g. en,de,cs. Seamless backend has its own 3-letter language codes, e.g. eng, deu, ces.
--task {transcribe,translate}
Transcribe or translate.
--backend {faster-whisper,whisper_timestamped,seamless}
Load only this backend for Whisper processing, or Seamless Streaming.
--vad Use VAD = voice activity detection, with the default parameters. Not applicable to seamless.
--buffer_trimming {sentence,segment}
Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter
must be installed for "sentence" option. Not applicable to seamless.
--buffer_trimming_sec BUFFER_TRIMMING_SEC
Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered. Not applicable to seamless.
--start_at START_AT Start processing audio at this time.
--offline Offline mode.
--comp_unaware Computationally unaware simulation.
``` ```
**Frontend Implementation**: The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it using `from whisperlivekit import get_inline_ui_html` & `page = get_inline_ui_html()` Example:
It simulates realtime processing from a pre-recorded mono 16k wav file.
## Parameters & Configuration ```
python3 whisper_online.py en-demo16.wav --language en --min-chunk-size 1 > out.txt
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `small` |
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `None` |
| `--language` | List [here](docs/supported_languages.md). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
| `--diarization` | Enable speaker identification | `False` |
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
| `--no-vac` | Disable Voice Activity Controller. NOT ADVISED | `False` |
| `--no-vad` | Disable Voice Activity Detection. NOT ADVISED | `False` |
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
| `--host` | Server host address | `localhost` |
| `--port` | Server port | `8000` |
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
| `--forwarded-allow-ips` | Ip or Ips allowed to reverse proxy the whisperlivekit-server. Supported types are IP Addresses (e.g. 127.0.0.1), IP Networks (e.g. 10.100.0.0/16), or Literals (e.g. /path/to/socket.sock) | `None` |
| `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder | `False` |
| `--lora-path` | Path or Hugging Face repo ID for LoRA adapter weights (e.g., `qfuxa/whisper-base-french-lora`). Only works with native Whisper backend (`--backend whisper`) | `None` |
| Translation options | Description | Default |
|-----------|-------------|---------|
| `--nllb-backend` | `transformers` or `ctranslate2` | `ctranslate2` |
| `--nllb-size` | `600M` or `1.3B` | `600M` |
| Diarization options | Description | Default |
|-----------|-------------|---------|
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
| `--disable-punctuation-split` | [NOT FUNCTIONAL IN 0.2.15 / 0.2.16] Disable punctuation based splits. See #214 | `False` |
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
| SimulStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used. Use `scripts/determine_alignment_heads.py` to extract them. <img src="scripts/alignment_heads.png" alt="WhisperLiveKit Demo" width="300">
| `None` |
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
| `--audio-max-len` | Maximum audio buffer length (seconds) | `30.0` |
| `--audio-min-len` | Minimum audio length to process (seconds) | `0.0` |
| `--cif-ckpt-path` | Path to CIF model for word boundary detection | `None` |
| `--never-fire` | Never truncate incomplete words | `False` |
| `--init-prompt` | Initial prompt for the model | `None` |
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
| `--max-context-tokens` | Maximum context tokens | Depends on model used, but usually 448. |
| WhisperStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
> For diarization using Diart, you need to accept user conditions [here](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model, [here](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model and [here](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model. **Then**, login to HuggingFace: `huggingface-cli login`
### 🚀 Deployment Guide
To deploy WhisperLiveKit in production:
1. **Server Setup**: Install production ASGI server & launch with multiple workers
```bash
pip install uvicorn gunicorn
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
```
2. **Frontend**: Host your customized version of the `html` example & ensure WebSocket connection points correctly
3. **Nginx Configuration** (recommended for production):
```nginx
server {
listen 80;
server_name your-domain.com;
location / {
proxy_pass http://localhost:8000;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
}}
```
4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in WebSocket URL
## 🐋 Docker
Deploy the application easily using Docker with GPU or CPU support.
### Prerequisites
- Docker installed on your system
- For GPU support: NVIDIA Docker runtime installed
### Quick Start
**With GPU acceleration (recommended):**
```bash
docker build -t wlk .
docker run --gpus all -p 8000:8000 --name wlk wlk
``` ```
**CPU only:** Simulation modes:
```bash
docker build -f Dockerfile.cpu -t wlk . - default mode, no special option: real-time simulation from file, computationally aware. The chunk size is `MIN_CHUNK_SIZE` or larger, if more audio arrived during last update computation.
docker run -p 8000:8000 --name wlk wlk
- `--comp_unaware` option: computationally unaware simulation. It means that the timer that counts the emission times "stops" when the model is computing. The chunk size is always `MIN_CHUNK_SIZE`. The latency is caused only by the model being unable to confirm the output, e.g. because of language ambiguity etc., and not because of slow hardware or suboptimal implementation. We implement this feature for finding the lower bound for latency.
- `--start_at START_AT`: Start processing audio at this time. The first update receives the whole audio by `START_AT`. It is useful for debugging, e.g. when we observe a bug in a specific time in audio file, and want to reproduce it quickly, without long waiting.
- `--offline` option: It processes the whole audio file at once, in offline mode. We implement it to find out the lowest possible WER on given audio file.
### Output format
```
2691.4399 300 1380 Chairman, thank you.
6914.5501 1940 4940 If the debate today had a
9019.0277 5160 7160 the subject the situation in
10065.1274 7180 7480 Gaza
11058.3558 7480 9460 Strip, I might
12224.3731 9460 9760 have
13555.1929 9760 11060 joined Mrs.
14928.5479 11140 12240 De Kaiser and all the
16588.0787 12240 12560 other
18324.9285 12560 14420 colleagues across the
``` ```
### Advanced Usage [See description here](https://github.com/ufal/whisper_streaming/blob/d915d790a62d7be4e7392dde1480e7981eb142ae/whisper_online.py#L361)
**Custom configuration:** ### As a module
```bash
# Example with custom model and language TL;DR: use OnlineASRProcessor object and its methods insert_audio_chunk and process_iter.
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
The code whisper_online.py is nicely commented, read it as the full documentation.
This pseudocode describes the interface that we suggest for your implementation. You can implement any features that you need for your application.
```
from whisper_online import *
src_lan = "en" # source language
tgt_lan = "en" # target language -- same as source for ASR, "en" if translate task is used
asr = FasterWhisperASR(lan, "large-v2") # loads and wraps Whisper model
# set options:
# asr.set_translate_task() # it will translate from lan into English
# asr.use_vad() # set using VAD
online = OnlineASRProcessor(asr) # create processing object with default buffer trimming option
while audio_has_not_ended: # processing loop:
a = # receive new audio chunk (and e.g. wait for min_chunk_size seconds first, ...)
online.insert_audio_chunk(a)
o = online.process_iter()
print(o) # do something with current partial output
# at the end of this audio processing
o = online.finish()
print(o) # do something with the last output
online.init() # refresh if you're going to re-use the object for the next audio
``` ```
### Memory Requirements ### Server -- real-time from mic
- **Large models**: Ensure your Docker runtime has sufficient memory allocated
`whisper_online_server.py` has the same model options as `whisper_online.py`, plus `--host` and `--port` of the TCP connection. See help message (`-h` option).
Client example:
```
arecord -f S16_LE -c1 -r 16000 -t raw -D default | nc localhost 43001
```
- arecord sends realtime audio from a sound device (e.g. mic), in raw audio format -- 16000 sampling rate, mono channel, S16\_LE -- signed 16-bit integer low endian. (use the alternative to arecord that works for you)
- nc is netcat with server's host and port
#### Customization ## Background
Default Whisper is intended for audio chunks of at most 30 seconds that contain
one full sentence. Longer audio files must be split to shorter chunks and
merged with "init prompt". In low latency simultaneous streaming mode, the
simple and naive chunking fixed-sized windows does not work well, it can split
a word in the middle. It is also necessary to know when the transcribt is
stable, should be confirmed ("commited") and followed up, and when the future
content makes the transcript clearer.
For that, there is LocalAgreement-n policy: if n consecutive updates, each with
a newly available audio stream chunk, agree on a prefix transcript, it is
confirmed. (Reference: CUNI-KIT at IWSLT 2022 etc.)
In this project, we re-use the idea of Peter Polák from this demo:
https://github.com/pe-trik/transformers/blob/online_decode/examples/pytorch/online-decoding/whisper-online-demo.py
However, it doesn't do any sentence segmentation, but Whisper produces
punctuation and the libraries `faster-whisper` and `whisper_transcribed` make
word-level timestamps. In short: we
consecutively process new audio chunks, emit the transcripts that are confirmed
by 2 iterations, and scroll the audio processing buffer on a timestamp of a
confirmed complete sentence. The processing audio buffer is not too long and
the processing is fast.
In more detail: we use the init prompt, we handle the inaccurate timestamps, we
re-process confirmed sentence prefixes and skip them, making sure they don't
overlap, and we limit the processing buffer window.
Contributions are welcome.
### Performance evaluation
[See the paper.](http://www.afnlp.org/conferences/ijcnlp2023/proceedings/main-demo/cdrom/pdf/2023.ijcnlp-demo.3.pdf)
## Contact
Dominik Macháček, machacek@ufal.mff.cuni.cz
- `--build-arg` Options:
- `EXTRAS="translation"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
- `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start
- `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models
## 🔮 Use Cases
Capture discussions in real-time for meeting transcription, help hearing-impaired users follow conversations through accessibility tools, transcribe podcasts or videos automatically for content creation, transcribe support calls with speaker identification for customer service...

Binary file not shown.

Before

Width:  |  Height:  |  Size: 422 KiB

View File

@@ -1,19 +0,0 @@
## WhisperLiveKit Chrome Extension v0.1.1
Capture the audio of your current tab, transcribe diarize and translate it using WhisperliveKit, in Chrome and other Chromium-based browsers.
> Currently, only the tab audio is captured; your microphone audio is not recorded.
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">
## Running this extension
1. Run `python scripts/sync_extension.py` to copy frontend files to the `chrome-extension` directory.
2. Load the `chrome-extension` directory in Chrome as an unpacked extension.
## Devs:
- Impossible to capture audio from tabs if extension is a pannel, unfortunately:
- https://issues.chromium.org/issues/40926394
- https://groups.google.com/a/chromium.org/g/chromium-extensions/c/DET2SXCFnDg
- https://issues.chromium.org/issues/40916430
- To capture microphone in an extension, there are tricks: https://github.com/justinmann/sidepanel-audio-issue , https://medium.com/@lynchee.owo/how-to-enable-microphone-access-in-chrome-extensions-by-code-924295170080 (comments)

View File

@@ -1,9 +0,0 @@
chrome.runtime.onInstalled.addListener((details) => {
if (details.reason.search(/install/g) === -1) {
return
}
chrome.tabs.create({
url: chrome.runtime.getURL("welcome.html"),
active: true
})
})

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.8 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 376 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 823 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.4 KiB

View File

@@ -1,23 +0,0 @@
{
"manifest_version": 3,
"name": "WhisperLiveKit Tab Capture",
"version": "1.0",
"description": "Capture and transcribe audio from browser tabs using WhisperLiveKit.",
"icons": {
"16": "icons/icon16.png",
"32": "icons/icon32.png",
"48": "icons/icon48.png",
"128": "icons/icon128.png"
},
"action": {
"default_title": "WhisperLiveKit Tab Capture",
"default_popup": "live_transcription.html"
},
"permissions": [
"scripting",
"tabCapture",
"offscreen",
"activeTab",
"storage"
]
}

View File

@@ -1,12 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<title>Request Permissions</title>
<script src="requestPermissions.js"></script>
</head>
<body>
This page exists to workaround an issue with Chrome that blocks permission
requests from chrome extensions
<button id="requestMicrophone">Request Microphone</button>
</body>
</html>

View File

@@ -1,17 +0,0 @@
/**
* Requests user permission for microphone access.
* @returns {Promise<void>} A Promise that resolves when permission is granted or rejects with an error.
*/
async function getUserPermission() {
console.log("Getting user permission for microphone access...");
await navigator.mediaDevices.getUserMedia({ audio: true });
const micPermission = await navigator.permissions.query({
name: "microphone",
});
if (micPermission.state == "granted") {
window.close();
}
}
// Call the function to request microphone permission
getUserPermission();

View File

@@ -1,29 +0,0 @@
console.log("sidepanel.js");
async function run() {
const micPermission = await navigator.permissions.query({
name: "microphone",
});
document.getElementById(
"audioPermission"
).innerText = `MICROPHONE: ${micPermission.state}`;
if (micPermission.state !== "granted") {
chrome.tabs.create({ url: "requestPermissions.html" });
}
const intervalId = setInterval(async () => {
const micPermission = await navigator.permissions.query({
name: "microphone",
});
if (micPermission.state === "granted") {
document.getElementById(
"audioPermission"
).innerText = `MICROPHONE: ${micPermission.state}`;
clearInterval(intervalId);
}
}, 100);
}
void run();

BIN
demo.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 985 KiB

View File

@@ -1,264 +0,0 @@
# WhisperLiveKit WebSocket API Documentation
> !! **Note**: The new API structure described in this document is currently under deployment.
This documentation is intended for devs who want to build custom frontends.
WLK provides real-time speech transcription, speaker diarization, and translation through a WebSocket API. The server sends incremental updates as audio is processed, allowing clients to display live transcription results with minimal latency.
---
## Legacy API (Current)
### Message Structure
The current API sends complete state snapshots on each update (several time per second)
```typescript
{
"type": str,
"status": str,
"lines": [
{
"speaker": int,
"text": str,
"start": float,
"end": float,
"translation": str | null,
"detected_language": str
}
],
"buffer_transcription": str,
"buffer_diarization": str,
"remaining_time_transcription": float,
"remaining_time_diarization": float
}
```
---
## New API (Under Development)
### Philosophy
Principles:
- **Incremental Updates**: Only updates and new segments are sent
- **Ephemeral Buffers**: Temporary, unvalidated data displayed in real-time but overwritten on next update, at speaker level
## Message Format
```typescript
{
"type": "transcript_update",
"status": "active_transcription" | "no_audio_detected",
"segments": [
{
"id": number,
"speaker": number,
"text": string,
"start_speaker": float,
"start": float,
"end": float,
"language": string | null,
"translation": string,
"words": [
{
"text": string,
"start": float,
"end": float,
"validated": {
"text": boolean,
"speaker": boolean,
}
}
],
"buffer": {
"transcription": string,
"diarization": string,
"translation": string
}
}
],
"metadata": {
"remaining_time_transcription": float,
"remaining_time_diarization": float
}
}
```
### Other Message Types
#### Config Message (sent on connection)
```json
{
"type": "config",
"useAudioWorklet": true / false
}
```
#### Ready to Stop Message (sent after processing complete)
```json
{
"type": "ready_to_stop"
}
```
---
## Field Descriptions
### Segment Fields
| Field | Type | Description |
|-------|------|-------------|
| `id` | `number` | Unique identifier for this segment. Used by clients to update specific segments efficiently. |
| `speaker` | `number` | Speaker ID (1, 2, 3...). Special value `-2` indicates silence. |
| `text` | `string` | Validated transcription text for this update. Should be **appended** to the segment's text on the client side. |
| `start_speaker` | `float` | Timestamp (seconds) when this speaker segment began. |
| `start` | `float` | Timestamp (seconds) of the first word in this update. |
| `end` | `float` | Timestamp (seconds) of the last word in this update. |
| `language` | `string \| null` | ISO language code (e.g., "en", "fr"). `null` until language is detected. |
| `translation` | `string` | Validated translation text for this update. Should be **appended** to the segment's translation on the client side. |
| `words` | `Array` | Array of word-level objects with timing and validation information. |
| `buffer` | `Object` | Per-segment temporary buffers, see below |
### Word Object
| Field | Type | Description |
|-------|------|-------------|
| `text` | `string` | The word text. |
| `start` | `number` | Start timestamp (seconds) of this word. |
| `end` | `number` | End timestamp (seconds) of this word. |
| `validated.text` | `boolean` | Whether the transcription text has been validated. if false, word is also in buffer: transcription |
| `validated.speaker` | `boolean` | Whether the speaker assignment has been validated. if false, word is also in buffer: diarization |
| `validated.language` | `boolean` | Whether the language detection has been validated. if false, word is also in buffer: translation |
### Buffer Object (Per-Segment)
Buffers are **ephemeral**. They should be displayed to the user but not stored permanently in the frontend. Each update may contain a completely different buffer value, and previous buffer is likely to be in the next validated text.
| Field | Type | Description |
|-------|------|-------------|
| `transcription` | `string` | Pending transcription text. Displayed immediately but **overwritten** on next update. |
| `diarization` | `string` | Pending diarization text (text waiting for speaker assignment). Displayed immediately but **overwritten** on next update. |
| `translation` | `string` | Pending translation text. Displayed immediately but **overwritten** on next update. |
### Metadata Fields
| Field | Type | Description |
|-------|------|-------------|
| `remaining_time_transcription` | `float` | Seconds of audio waiting for transcription processing. |
| `remaining_time_diarization` | `float` | Seconds of audio waiting for speaker diarization. |
### Status Values
| Status | Description |
|--------|-------------|
| `active_transcription` | Normal operation, transcription is active. |
| `no_audio_detected` | No audio has been detected yet. |
---
## Update Behavior
### Incremental Updates
The API sends **only changed or new segments**. Clients should:
1. Maintain a local map of segments by ID
2. When receiving an update, merge/update segments by ID
3. Render only the changed segments
### Language Detection
When language is detected for a segment:
```jsonc
// Update 1: No language yet
{
"segments": [
{"id": 1, "speaker": 1, "text": "May see", "language": null}
]
}
// Update 2: Same segment ID, language now detected
{
"segments": [
{"id": 1, "speaker": 1, "text": "Merci", "language": "fr"}
]
}
```
**Client behavior**: **Replace** the existing segment with the same ID.
### Buffer Behavior
Buffers are **per-segment** to handle multi-speaker scenarios correctly.
#### Example: Translation with diarization and translation
```jsonc
// Update 1
{
"segments": [
{
"id": 1,
"speaker": 1,
"text": "Hello world, how are",
"translation": "",
"buffer": {
"transcription": "",
"diarization": " you on",
"translation": "Bonjour le monde"
}
}
]
}
// ==== Frontend ====
// <SPEAKER>1</SPEAKER>
// <TRANSCRIPTION>Hello world, how are <DIARIZATION BUFFER> you on</DIARIZATION BUFFER></TRANSCRIPTION>
// <TRANSLATION><TRANSLATION BUFFER>Bonjour le monde</TRANSLATION BUFFER></TRANSLATION>
// Update 2
{
"segments": [
{
"id": 1,
"speaker": 1,
"text": " you on this",
"translation": "Bonjour tout le monde",
"buffer": {
"transcription": "",
"diarization": " beautiful day",
"translation": ",comment"
}
},
]
}
// ==== Frontend ====
// <SPEAKER>1</SPEAKER>
// <TRANSCRIPTION>Hello world, how are you on this<DIARIZATION BUFFER> beautiful day</DIARIZATION BUFFER></TRANSCRIPTION>
// <TRANSLATION>Bonjour tout le monde<TRANSLATION BUFFER>, comment</TRANSLATION BUFFER><TRANSLATION>
```
### Silence Segments
Silence is represented with the speaker id = `-2`:
```jsonc
{
"id": 5,
"speaker": -2,
"text": "",
"start": 10.5,
"end": 12.3
}
```

View File

@@ -1,71 +0,0 @@
### Alignment between STT Tokens and Diarization Segments
- Example 1: The punctuation from STT and the speaker change from Diariation come in the prediction `t`
- Example 2: The punctuation from STT comes from prediction `t`, but the speaker change from Diariation come in the prediction `t-1`
- Example 3: The punctuation from STT comes from prediction `t-1`, but the speaker change from Diariation come in the prediction `t`
> `#` Is the split between the `t-1` prediction and `t` prediction.
## Example 1:
```text
punctuations_segments : __#_______.__________________!____
diarization_segments:
SPK1 __#____________
SPK2 # ___________________
-->
ALIGNED SPK1 __#_______.
ALIGNED SPK2 # __________________!____
t-1 output:
SPK1: __#
SPK2: NO
DIARIZATION BUFFER: NO
t output:
SPK1: __#__.
SPK2: __________________!____
DIARIZATION BUFFER: No
```
## Example 2:
```text
punctuations_segments : _____#__.___________
diarization_segments:
SPK1 ___ #
SPK2 __#______________
-->
ALIGNED SPK1 _____#__.
ALIGNED SPK2 # ___________
t-1 output:
SPK1: ___ #
SPK2:
DIARIZATION BUFFER: __#
t output:
SPK1: __#__.
SPK2: ___________
DIARIZATION BUFFER: No
```
## Example 3:
```text
punctuations_segments : ___.__#__________
diarization_segments:
SPK1 ______#__
SPK2 # ________
-->
ALIGNED SPK1 ___. #
ALIGNED SPK2 __#__________
t-1 output:
SPK1: ___. #
SPK2:
DIARIZATION BUFFER: __#
t output:
SPK1: #
SPK2: __#___________
DIARIZATION BUFFER: NO
```

View File

@@ -1,106 +0,0 @@
# Models and Model Paths
## Defaults
**Default Whisper Model**: `base`
When no model is specified, WhisperLiveKit uses the `base` model, which provides a good balance of speed and accuracy for most use cases.
**Default Model Cache Directory**: `~/.cache/whisper`
Models are automatically downloaded from OpenAI's model hub and cached in this directory. You can override this with `--model_cache_dir`.
**Default Translation Model**: `600M` (NLLB-200-distilled)
When translation is enabled, the 600M distilled NLLB model is used by default. This provides good quality with minimal resource usage.
**Default Translation Backend**: `transformers`
The translation backend defaults to Transformers. On Apple Silicon, this automatically uses MPS acceleration for better performance.
---
## Available Whisper model sizes:
| Available Model | Speed | Accuracy | Multilingual | Translation | Hardware Requirements | Best Use Case |
|--------------------|----------|-----------|--------------|-------------|----------------------|----------------------------------|
| tiny(.en) | Fastest | Basic | Yes/No | Yes/No | ~1GB VRAM | Real-time, low resources |
| base(.en) | Fast | Good | Yes/No | Yes/No | ~1GB VRAM | Balanced performance |
| small(.en) | Medium | Better | Yes/No | Yes/No | ~2GB VRAM | Quality on limited hardware |
| medium(.en) | Slow | High | Yes/No | Yes/No | ~5GB VRAM | High quality, moderate resources |
| large-v2 | Slowest | Excellent | Yes | Yes | ~10GB VRAM | Good overall accuracy & language support |
| large-v3 | Slowest | Excellent | Yes | Yes | ~10GB VRAM | Best overall accuracy & language support |
| large-v3-turbo | Fast | Excellent | Yes | No | ~6GB VRAM | Fast, high-quality transcription |
### How to choose?
#### Language Support
- **English only**: Use `.en` (ex: `base.en`) models for better accuracy and faster processing when you only need English transcription
- **Multilingual**: Do not use `.en` models.
#### Special Cases
- **No translation needed**: Use `large-v3-turbo`
- Same transcription quality as `large-v2` but significantly faster
- **Important**: Does not translate correctly, only transcribes
### Additional Considerations
**Model Performance**:
- Accuracy improves significantly from tiny to large models
- English-only models are ~10-15% more accurate for English audio
- Newer versions (v2, v3) have better punctuation and formatting
**Audio Quality Impact**:
- Clean, clear audio: smaller models may suffice
- Noisy, accented, or technical audio: larger models recommended
- Phone/low-quality audio: use at least `small` model
_______________________
# Custom Models:
The `--model-path` parameter accepts:
## File Path
- **`.pt` / `.bin` / `.safetensor` formats** Should be openable by pytorch/safetensor.
## Directory Path (recommended)
Must contain:
- **`.pt` / `.bin` / `.safetensor` file** (required for decoder)
May optionally contain:
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
## Hugging Face Repo ID
- Provide the repo ID (e.g. `openai/whisper-large-v3`) and WhisperLiveKit will download and cache the snapshot automatically. For gated repos, authenticate via `huggingface-cli login` first.
To improve speed/reduce hallucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` to pass them to WLK. If not, alignment heads are set to be all the heads of the last half layer of decoder.
_______________________
# Translation Models and Backend
**Language Support**: ~200 languages
## Distilled Model Sizes Available
| Model | Size | Parameters | VRAM (FP16) | VRAM (INT8) | Quality |
|-------|------|------------|-------------|-------------|---------|
| 600M | 2.46 GB | 600M | ~1.5GB | ~800MB | Good, understandable |
| 1.3B | 5.48 GB | 1.3B | ~3GB | ~1.5GB | Better accuracy, context |
**Quality Impact**: 1.3B has ~15-25% better BLEU scores vs 600M across language pairs.
## Backend Performance
| Backend | Speed vs Base | Memory Usage | Quality Loss |
|---------|---------------|--------------|--------------|
| CTranslate2 | 6-10x faster | 40-60% less | ~5% BLEU drop |
| Transformers | Baseline | High | None |
| Transformers + MPS (on Apple Silicon) | 2x faster | Medium | None |
**Metrics**:
- CTranslate2: 50-100+ tokens/sec
- Transformers: 10-30 tokens/sec
- Apple Silicon with MPS: Up to 2x faster than CTranslate2

View File

@@ -1,373 +0,0 @@
# Transcription: Supported Language
WLK supports transcription in the following languages:
| ISO Code | Language Name |
|----------|---------------------|
| en | English |
| zh | Chinese |
| de | German |
| es | Spanish |
| ru | Russian |
| ko | Korean |
| fr | French |
| ja | Japanese |
| pt | Portuguese |
| tr | Turkish |
| pl | Polish |
| ca | Catalan |
| nl | Dutch |
| ar | Arabic |
| sv | Swedish |
| it | Italian |
| id | Indonesian |
| hi | Hindi |
| fi | Finnish |
| vi | Vietnamese |
| he | Hebrew |
| uk | Ukrainian |
| el | Greek |
| ms | Malay |
| cs | Czech |
| ro | Romanian |
| da | Danish |
| hu | Hungarian |
| ta | Tamil |
| no | Norwegian |
| th | Thai |
| ur | Urdu |
| hr | Croatian |
| bg | Bulgarian |
| lt | Lithuanian |
| la | Latin |
| mi | Maori |
| ml | Malayalam |
| cy | Welsh |
| sk | Slovak |
| te | Telugu |
| fa | Persian |
| lv | Latvian |
| bn | Bengali |
| sr | Serbian |
| az | Azerbaijani |
| sl | Slovenian |
| kn | Kannada |
| et | Estonian |
| mk | Macedonian |
| br | Breton |
| eu | Basque |
| is | Icelandic |
| hy | Armenian |
| ne | Nepali |
| mn | Mongolian |
| bs | Bosnian |
| kk | Kazakh |
| sq | Albanian |
| sw | Swahili |
| gl | Galician |
| mr | Marathi |
| pa | Punjabi |
| si | Sinhala |
| km | Khmer |
| sn | Shona |
| yo | Yoruba |
| so | Somali |
| af | Afrikaans |
| oc | Occitan |
| ka | Georgian |
| be | Belarusian |
| tg | Tajik |
| sd | Sindhi |
| gu | Gujarati |
| am | Amharic |
| yi | Yiddish |
| lo | Lao |
| uz | Uzbek |
| fo | Faroese |
| ht | Haitian Creole |
| ps | Pashto |
| tk | Turkmen |
| nn | Nynorsk |
| mt | Maltese |
| sa | Sanskrit |
| lb | Luxembourgish |
| my | Myanmar |
| bo | Tibetan |
| tl | Tagalog |
| mg | Malagasy |
| as | Assamese |
| tt | Tatar |
| haw | Hawaiian |
| ln | Lingala |
| ha | Hausa |
| ba | Bashkir |
| jw | Javanese |
| su | Sundanese |
| yue | Cantonese |
# Translation: Supported Languages
WLK supports translation into **201 languages** from the FLORES-200 dataset through the [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) translation system.
## How to Specify Languages
You can specify languages in **three different ways**:
1. **Language Name** (case-insensitive): `"English"`, `"French"`, `"Spanish"`
2. **ISO Language Code**: `"en"`, `"fr"`, `"es"`
3. **NLLB Code** (FLORES-200): `"eng_Latn"`, `"fra_Latn"`, `"spa_Latn"`
## Usage Examples
### Command Line
```bash
# Using language name
whisperlivekit-server --target-language "French"
# Using ISO code
whisperlivekit-server --target-language fr
# Using NLLB code
whisperlivekit-server --target-language fra_Latn
```
### Python API
```python
from nllw.translation import get_language_info
# Get language information by name
lang_info = get_language_info("French")
print(lang_info)
# {'name': 'French', 'nllb': 'fra_Latn', 'language_code': 'fr'}
# Get language information by ISO code
lang_info = get_language_info("fr")
# Get language information by NLLB code
lang_info = get_language_info("fra_Latn")
# All three return the same result
```
## Complete Language List
The following table lists all 201 supported languages with their corresponding codes:
| Language Name | ISO Code | NLLB Code |
|---------------|----------|-----------|
| Acehnese (Arabic script) | ace_Arab | ace_Arab |
| Acehnese (Latin script) | ace_Latn | ace_Latn |
| Mesopotamian Arabic | acm_Arab | acm_Arab |
| Ta'izzi-Adeni Arabic | acq_Arab | acq_Arab |
| Tunisian Arabic | aeb_Arab | aeb_Arab |
| Afrikaans | af | afr_Latn |
| South Levantine Arabic | ajp_Arab | ajp_Arab |
| Akan | ak | aka_Latn |
| Tosk Albanian | als | als_Latn |
| Amharic | am | amh_Ethi |
| North Levantine Arabic | apc_Arab | apc_Arab |
| Modern Standard Arabic | ar | arb_Arab |
| Modern Standard Arabic (Romanized) | arb_Latn | arb_Latn |
| Najdi Arabic | ars_Arab | ars_Arab |
| Moroccan Arabic | ary_Arab | ary_Arab |
| Egyptian Arabic | arz_Arab | arz_Arab |
| Assamese | as | asm_Beng |
| Asturian | ast | ast_Latn |
| Awadhi | awa | awa_Deva |
| Central Aymara | ay | ayr_Latn |
| South Azerbaijani | azb | azb_Arab |
| North Azerbaijani | az | azj_Latn |
| Bashkir | ba | bak_Cyrl |
| Bambara | bm | bam_Latn |
| Balinese | ban | ban_Latn |
| Belarusian | be | bel_Cyrl |
| Bemba | bem | bem_Latn |
| Bengali | bn | ben_Beng |
| Bhojpuri | bho | bho_Deva |
| Banjar (Arabic script) | bjn_Arab | bjn_Arab |
| Banjar (Latin script) | bjn_Latn | bjn_Latn |
| Standard Tibetan | bo | bod_Tibt |
| Bosnian | bs | bos_Latn |
| Buginese | bug | bug_Latn |
| Bulgarian | bg | bul_Cyrl |
| Catalan | ca | cat_Latn |
| Cebuano | ceb | ceb_Latn |
| Czech | cs | ces_Latn |
| Chokwe | cjk | cjk_Latn |
| Central Kurdish | ckb | ckb_Arab |
| Crimean Tatar | crh | crh_Latn |
| Welsh | cy | cym_Latn |
| Danish | da | dan_Latn |
| German | de | deu_Latn |
| Southwestern Dinka | dik | dik_Latn |
| Dyula | dyu | dyu_Latn |
| Dzongkha | dz | dzo_Tibt |
| Greek | el | ell_Grek |
| English | en | eng_Latn |
| Esperanto | eo | epo_Latn |
| Estonian | et | est_Latn |
| Basque | eu | eus_Latn |
| Ewe | ee | ewe_Latn |
| Faroese | fo | fao_Latn |
| Fijian | fj | fij_Latn |
| Finnish | fi | fin_Latn |
| Fon | fon | fon_Latn |
| French | fr | fra_Latn |
| Friulian | fur-IT | fur_Latn |
| Nigerian Fulfulde | fuv | fuv_Latn |
| West Central Oromo | om | gaz_Latn |
| Scottish Gaelic | gd | gla_Latn |
| Irish | ga-IE | gle_Latn |
| Galician | gl | glg_Latn |
| Guarani | gn | grn_Latn |
| Gujarati | gu-IN | guj_Gujr |
| Haitian Creole | ht | hat_Latn |
| Hausa | ha | hau_Latn |
| Hebrew | he | heb_Hebr |
| Hindi | hi | hin_Deva |
| Chhattisgarhi | hne | hne_Deva |
| Croatian | hr | hrv_Latn |
| Hungarian | hu | hun_Latn |
| Armenian | hy-AM | hye_Armn |
| Igbo | ig | ibo_Latn |
| Ilocano | ilo | ilo_Latn |
| Indonesian | id | ind_Latn |
| Icelandic | is | isl_Latn |
| Italian | it | ita_Latn |
| Javanese | jv | jav_Latn |
| Japanese | ja | jpn_Jpan |
| Kabyle | kab | kab_Latn |
| Jingpho | kac | kac_Latn |
| Kamba | kam | kam_Latn |
| Kannada | kn | kan_Knda |
| Kashmiri (Arabic script) | kas_Arab | kas_Arab |
| Kashmiri (Devanagari script) | kas_Deva | kas_Deva |
| Georgian | ka | kat_Geor |
| Kazakh | kk | kaz_Cyrl |
| Kabiyè | kbp | kbp_Latn |
| Kabuverdianu | kea | kea_Latn |
| Halh Mongolian | mn | khk_Cyrl |
| Khmer | km | khm_Khmr |
| Kikuyu | ki | kik_Latn |
| Kinyarwanda | rw | kin_Latn |
| Kyrgyz | ky | kir_Cyrl |
| Kimbundu | kmb | kmb_Latn |
| Northern Kurdish | kmr | kmr_Latn |
| Central Kanuri (Arabic script) | knc_Arab | knc_Arab |
| Central Kanuri (Latin script) | knc_Latn | knc_Latn |
| Kikongo | kg | kon_Latn |
| Korean | ko | kor_Hang |
| Lao | lo | lao_Laoo |
| Ligurian | lij | lij_Latn |
| Limburgish | li | lim_Latn |
| Lingala | ln | lin_Latn |
| Lithuanian | lt | lit_Latn |
| Lombard | lmo | lmo_Latn |
| Latgalian | ltg | ltg_Latn |
| Luxembourgish | lb | ltz_Latn |
| Luba-Kasai | lua | lua_Latn |
| Ganda | lg | lug_Latn |
| Luo | luo | luo_Latn |
| Mizo | lus | lus_Latn |
| Standard Latvian | lv | lvs_Latn |
| Magahi | mag | mag_Deva |
| Maithili | mai | mai_Deva |
| Malayalam | ml-IN | mal_Mlym |
| Marathi | mr | mar_Deva |
| Minangkabau (Arabic script) | min_Arab | min_Arab |
| Minangkabau (Latin script) | min_Latn | min_Latn |
| Macedonian | mk | mkd_Cyrl |
| Maltese | mt | mlt_Latn |
| Meitei (Bengali script) | mni | mni_Beng |
| Mossi | mos | mos_Latn |
| Maori | mi | mri_Latn |
| Burmese | my | mya_Mymr |
| Dutch | nl | nld_Latn |
| Norwegian Nynorsk | nn-NO | nno_Latn |
| Norwegian Bokmål | nb | nob_Latn |
| Nepali | ne-NP | npi_Deva |
| Northern Sotho | nso | nso_Latn |
| Nuer | nus | nus_Latn |
| Nyanja | ny | nya_Latn |
| Occitan | oc | oci_Latn |
| Odia | or | ory_Orya |
| Pangasinan | pag | pag_Latn |
| Eastern Panjabi | pa | pan_Guru |
| Papiamento | pap | pap_Latn |
| Southern Pashto | pbt | pbt_Arab |
| Western Persian | fa | pes_Arab |
| Plateau Malagasy | mg | plt_Latn |
| Polish | pl | pol_Latn |
| Portuguese | pt-PT | por_Latn |
| Dari | fa-AF | prs_Arab |
| Ayacucho Quechua | qu | quy_Latn |
| Romanian | ro | ron_Latn |
| Rundi | rn | run_Latn |
| Russian | ru | rus_Cyrl |
| Sango | sg | sag_Latn |
| Sanskrit | sa | san_Deva |
| Santali | sat | sat_Olck |
| Sicilian | scn | scn_Latn |
| Shan | shn | shn_Mymr |
| Sinhala | si-LK | sin_Sinh |
| Slovak | sk | slk_Latn |
| Slovenian | sl | slv_Latn |
| Samoan | sm | smo_Latn |
| Shona | sn | sna_Latn |
| Sindhi | sd | snd_Arab |
| Somali | so | som_Latn |
| Southern Sotho | st | sot_Latn |
| Spanish | es-ES | spa_Latn |
| Sardinian | sc | srd_Latn |
| Serbian | sr | srp_Cyrl |
| Swati | ss | ssw_Latn |
| Sundanese | su | sun_Latn |
| Swedish | sv-SE | swe_Latn |
| Swahili | sw | swh_Latn |
| Silesian | szl | szl_Latn |
| Tamil | ta | tam_Taml |
| Tamasheq (Latin script) | taq_Latn | taq_Latn |
| Tamasheq (Tifinagh script) | taq_Tfng | taq_Tfng |
| Tatar | tt-RU | tat_Cyrl |
| Telugu | te | tel_Telu |
| Tajik | tg | tgk_Cyrl |
| Tagalog | tl | tgl_Latn |
| Thai | th | tha_Thai |
| Tigrinya | ti | tir_Ethi |
| Tok Pisin | tpi | tpi_Latn |
| Tswana | tn | tsn_Latn |
| Tsonga | ts | tso_Latn |
| Turkmen | tk | tuk_Latn |
| Tumbuka | tum | tum_Latn |
| Turkish | tr | tur_Latn |
| Twi | tw | twi_Latn |
| Central Atlas Tamazight | tzm | tzm_Tfng |
| Uyghur | ug | uig_Arab |
| Ukrainian | uk | ukr_Cyrl |
| Umbundu | umb | umb_Latn |
| Urdu | ur | urd_Arab |
| Northern Uzbek | uz | uzn_Latn |
| Venetian | vec | vec_Latn |
| Vietnamese | vi | vie_Latn |
| Waray | war | war_Latn |
| Wolof | wo | wol_Latn |
| Xhosa | xh | xho_Latn |
| Eastern Yiddish | yi | ydd_Hebr |
| Yoruba | yo | yor_Latn |
| Yue Chinese | yue | yue_Hant |
| Chinese (Simplified) | zh-CN | zho_Hans |
| Chinese (Traditional) | zh-TW | zho_Hant |
| Standard Malay | ms | zsm_Latn |
| Zulu | zu | zul_Latn |
## Special Features
### Multiple Script Support
Several languages are available in multiple scripts (e.g., Arabic and Latin):
- **Acehnese**: Arabic (`ace_Arab`) and Latin (`ace_Latn`)
- **Banjar**: Arabic (`bjn_Arab`) and Latin (`bjn_Latn`)
- **Kashmiri**: Arabic (`kas_Arab`) and Devanagari (`kas_Deva`)
- **Minangkabau**: Arabic (`min_Arab`) and Latin (`min_Latn`)
- **Tamasheq**: Latin (`taq_Latn`) and Tifinagh (`taq_Tfng`)
- **Central Kanuri**: Arabic (`knc_Arab`) and Latin (`knc_Latn`)

View File

@@ -1,43 +0,0 @@
# Technical Integration Guide
This document introduce how to reuse the core components when you do **not** want to ship the bundled frontend, FastAPI server, or even the provided CLI.
---
## 1. Runtime Components
| Layer | File(s) | Purpose |
|-------|---------|---------|
| Transport | `whisperlivekit/basic_server.py`, any ASGI/WebSocket server | Accepts audio over WebSocket (MediaRecorder WebM or raw PCM chunks) and streams JSON updates back |
| Audio processing | `whisperlivekit/audio_processor.py` | Buffers audio, orchestrates transcription, diarization, translation, handles FFmpeg/PCM input |
| Engines | `whisperlivekit/core.py`, `whisperlivekit/simul_whisper/*`, `whisperlivekit/local_agreement/*` | Load models once (SimulStreaming or LocalAgreement), expose `TranscriptionEngine` and helpers |
| Frontends | `whisperlivekit/web/*`, `chrome-extension/*` | Optional UI layers feeding the WebSocket endpoint |
**Key idea:** The server boundary is just `AudioProcessor.process_audio()` for incoming bytes and the async generator returned by `AudioProcessor.create_tasks()` for outgoing updates (`FrontData`). Everything else is optional.
---
## 2. Running Without the Bundled Frontend
1. Start the server/engine however you like:
```bash
wlk --model small --language en --host 0.0.0.0 --port 9000
# or launch your own app that instantiates TranscriptionEngine(...)
```
2. Build your own client (browser, mobile, desktop) that:
- Opens `ws(s)://<host>:<port>/asr`
- Sends either MediaRecorder/Opus WebM blobs **or** raw PCM (`--pcm-input` on the server tells the client to use the AudioWorklet).
- Consumes the JSON payload defined in `docs/API.md`.
---
## 3. Running Without FastAPI
`whisperlivekit/basic_server.py` is just an example. Any async framework works, as long as you:
1. Create a global `TranscriptionEngine` (expensive to initialize; reuse it).
2. Instantiate `AudioProcessor(transcription_engine=engine)` for each connection.
3. Call `create_tasks()` to get the async generator, `process_audio()` with incoming bytes, and ensure `cleanup()` runs when the client disconnects.
If you prefer to send compressed audio, instantiate `AudioProcessor(pcm_input=False)` and pipe encoded chunks through `FFmpegManager` transparently. Just ensure `ffmpeg` is available.

View File

@@ -1,140 +0,0 @@
# Troubleshooting
## GPU drivers & cuDNN visibility
### Linux error: `Unable to load libcudnn_ops.so* / cudnnCreateTensorDescriptor`
> Reported in issue #271 (Arch/CachyOS)
`faster-whisper` (used for the SimulStreaming encoder) dynamically loads cuDNN.
If the runtime cannot find `libcudnn_*`, verify that CUDA and cuDNN match the PyTorch build you installed:
1. **Install CUDA + cuDNN** (Arch/CachyOS example):
```bash
sudo pacman -S cuda cudnn
sudo ldconfig
```
2. **Make sure the shared objects are visible**:
```bash
ls /usr/lib/libcudnn*
```
3. **Check what CUDA version PyTorch expects** and match that with the driver you installed:
```bash
python - <<'EOF'
import torch
print(torch.version.cuda)
EOF
nvcc --version
```
4. If you installed CUDA in a non-default location, export `CUDA_HOME` and add `$CUDA_HOME/lib64` to `LD_LIBRARY_PATH`.
Once the CUDA/cuDNN versions match, `whisperlivekit-server` starts normally.
### Windows error: `Could not locate cudnn_ops64_9.dll`
> Reported in issue #286 (Conda on Windows)
PyTorch bundles cuDNN DLLs inside your environment (`<env>\Lib\site-packages\torch\lib`).
When `ctranslate2` or `faster-whisper` cannot find `cudnn_ops64_9.dll`:
1. Locate the DLL shipped with PyTorch, e.g.
```
E:\conda\envs\WhisperLiveKit\Lib\site-packages\torch\lib\cudnn_ops64_9.dll
```
2. Add that directory to your `PATH` **or** copy the `cudnn_*64_9.dll` files into a directory that is already on `PATH` (such as the environment's `Scripts/` folder).
3. Restart the shell before launching `wlk`.
Installing NVIDIA's standalone cuDNN 9.x and pointing `PATH`/`CUDNN_PATH` to it works as well, but is usually not required.
---
## PyTorch / CTranslate2 GPU builds
### `Torch not compiled with CUDA enabled`
> Reported in issue #284
If `torch.zeros(1).cuda()` raises that assertion it means you installed a CPU-only wheel.
Install the GPU-enabled wheels that match your CUDA toolkit:
```bash
pip install --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu130
```
Replace `cu130` with the CUDA version supported by your driver (see [PyTorch install selector](https://pytorch.org/get-started/locally/)).
Validate with:
```python
import torch
print(torch.cuda.is_available(), torch.cuda.get_device_name())
```
### `CTranslate2 device count: 0` or `Could not infer dtype of ctranslate2._ext.StorageView`
> Follow-up in issue #284
`ctranslate2` publishes separate CPU and CUDA wheels. The default `pip install ctranslate2` brings the CPU build, which makes WhisperLiveKit fall back to CPU tensors and leads to the dtype error above.
1. Uninstall the CPU build: `pip uninstall -y ctranslate2`.
2. Install the CUDA wheel that matches your toolkit (example for CUDA 13.0):
```bash
pip install ctranslate2==4.5.0 -f https://opennmt.net/ctranslate2/whl/cu130
```
(See the [CTranslate2 installation table](https://opennmt.net/CTranslate2/installation.html) for other CUDA versions.)
3. Verify:
```python
import ctranslate2
print("CUDA devices:", ctranslate2.get_cuda_device_count())
print("CUDA compute types:", ctranslate2.get_supported_compute_types("cuda", 0))
```
**Note for aarch64 systems (e.g., NVIDIA DGX Spark):** Pre-built CUDA wheels may not be available for all CUDA versions on ARM architectures. If the wheel installation fails, you may need to compile CTranslate2 from source with CUDA support enabled.
If you intentionally want CPU inference, run `wlk --backend whisper` to avoid mixing CPU-only CTranslate2 with a GPU Torch build.
---
## Hopper / Blackwell (`sm_121a`) systems
> Reported in issues #276 and #284 (NVIDIA DGX Spark)
CUDA 12.1a GPUs (e.g., NVIDIA GB10 on DGX Spark) ship before some toolchains know about the architecture ID, so Triton/PTXAS need manual configuration.
### Error: `ptxas fatal : Value 'sm_121a' is not defined for option 'gpu-name'`
If you encounter this error after compiling CTranslate2 from source on aarch64 systems, Triton's bundled `ptxas` may not support the `sm_121a` architecture. The solution is to replace Triton's `ptxas` with the system's CUDA `ptxas`:
```bash
# Find your Python environment's Triton directory
python -c "import triton; import os; print(os.path.dirname(triton.__file__))"
# Copy the system ptxas to Triton's backend directory
# Replace <triton_path> with the output above
cp /usr/local/cuda/bin/ptxas <triton_path>/backends/nvidia/bin/ptxas
```
For example, in a virtual environment:
```bash
cp /usr/local/cuda/bin/ptxas ~/wlk/lib/python3.12/site-packages/triton/backends/nvidia/bin/ptxas
```
**Note:** On DGX Spark systems, CUDA is typically already in `PATH` (`/usr/local/cuda/bin`), so explicit `CUDA_HOME` and `PATH` exports may not be necessary. Verify with `which ptxas` before copying.
### Alternative: Environment variable approach
If the above doesn't work, you can try setting environment variables (though this may not resolve the `sm_121a` issue on all systems):
```bash
export CUDA_HOME="/usr/local/cuda-13.0"
export PATH="$CUDA_HOME/bin:$PATH"
export LD_LIBRARY_PATH="$CUDA_HOME/lib64:$LD_LIBRARY_PATH"
# Tell Triton where the new ptxas lives
export TRITON_PTXAS_PATH="$CUDA_HOME/bin/ptxas"
# Force PyTorch to JIT kernels for all needed architectures
export TORCH_CUDA_ARCH_LIST="8.0 9.0 10.0 12.0 12.1a"
```
After applying the fix, restart `wlk`. Incoming streams will now compile kernels targeting `sm_121a` without crashing.
---
Need help with another recurring issue? Open a GitHub discussion or PR and reference this document so we can keep it current.

94
line_packet.py Normal file
View File

@@ -0,0 +1,94 @@
#!/usr/bin/env python3
"""Functions for sending and receiving individual lines of text over a socket.
Used by marian-server-server.py to communicate with the Marian worker.
A line is transmitted using one or more fixed-size packets of UTF-8 bytes
containing:
- Zero or more bytes of UTF-8, excluding \n and \0, followed by
- Zero or more \0 bytes as required to pad the packet to PACKET_SIZE
"""
PACKET_SIZE = 65536
def send_one_line(socket, text):
"""Sends a line of text over the given socket.
The 'text' argument should contain a single line of text (line break
characters are optional). Line boundaries are determined by Python's
str.splitlines() function [1]. We also count '\0' as a line terminator.
If 'text' contains multiple lines then only the first will be sent.
If the send fails then an exception will be raised.
[1] https://docs.python.org/3.5/library/stdtypes.html#str.splitlines
Args:
socket: a socket object.
text: string containing a line of text for transmission.
"""
text.replace('\0', '\n')
lines = text.splitlines()
first_line = '' if len(lines) == 0 else lines[0]
# TODO Is there a better way of handling bad input than 'replace'?
data = first_line.encode('utf-8', errors='replace') + b'\n\0'
for offset in range(0, len(data), PACKET_SIZE):
bytes_remaining = len(data) - offset
if bytes_remaining < PACKET_SIZE:
padding_length = PACKET_SIZE - bytes_remaining
packet = data[offset:] + b'\0' * padding_length
else:
packet = data[offset:offset+PACKET_SIZE]
socket.sendall(packet)
def receive_one_line(socket):
"""Receives a line of text from the given socket.
This function will (attempt to) receive a single line of text. If data is
currently unavailable then it will block until data becomes available or
the sender has closed the connection (in which case it will return an
empty string).
The string should not contain any newline characters, but if it does then
only the first line will be returned.
Args:
socket: a socket object.
Returns:
A string representing a single line with a terminating newline or
None if the connection has been closed.
"""
data = b''
while True:
packet = socket.recv(PACKET_SIZE)
if not packet: # Connection has been closed.
return None
data += packet
if b'\0' in packet:
break
# TODO Is there a better way of handling bad input than 'replace'?
text = data.decode('utf-8', errors='replace').strip('\0')
lines = text.split('\n')
return lines[0] + '\n'
def receive_lines(socket):
try:
data = socket.recv(PACKET_SIZE)
except BlockingIOError:
return []
if data is None: # Connection has been closed.
return None
# TODO Is there a better way of handling bad input than 'replace'?
text = data.decode('utf-8', errors='replace').strip('\0')
lines = text.split('\n')
if len(lines)==1 and not lines[0]:
return None
return lines

View File

@@ -1,72 +0,0 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "whisperlivekit"
version = "0.2.17.post1"
description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md"
authors = [
{ name = "Quentin Fuxa" }
]
license = { file = "LICENSE" }
requires-python = ">=3.9"
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: 3.15",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Multimedia :: Sound/Audio :: Speech"
]
dependencies = [
"fastapi",
"librosa",
"soundfile",
"uvicorn",
"websockets",
"torchaudio>=2.0.0",
"torch>=2.0.0",
"huggingface-hub>=0.25.0",
"faster-whisper>=1.2.0",
"tqdm",
"tiktoken",
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
]
[project.optional-dependencies]
translation = ["nllw"]
sentence_tokenizer = ["mosestokenizer", "wtpsplit"]
[project.urls]
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
[project.scripts]
whisperlivekit-server = "whisperlivekit.basic_server:main"
wlk = "whisperlivekit.basic_server:main"
[tool.setuptools]
packages = [
"whisperlivekit",
"whisperlivekit.diarization",
"whisperlivekit.simul_whisper",
"whisperlivekit.simul_whisper.mlx",
"whisperlivekit.whisper",
"whisperlivekit.whisper.assets",
"whisperlivekit.whisper.normalizers",
"whisperlivekit.web",
"whisperlivekit.local_agreement",
"whisperlivekit.silero_vad_models"
]
[tool.setuptools.package-data]
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
"whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]

Binary file not shown.

Before

Width:  |  Height:  |  Size: 276 KiB

View File

@@ -1,153 +0,0 @@
#!/usr/bin/env python3
"""
Convert a Hugging Face style Whisper checkpoint into a WhisperLiveKit .pt file.
Optionally shrink the supported audio chunk length (in seconds) by trimming the
encoder positional embeddings and updating the stored model dimensions.
"""
import argparse
import json
import os
from pathlib import Path
from typing import Dict, Tuple
import torch
from whisperlivekit.whisper import _convert_hf_state_dict
from whisperlivekit.whisper.audio import HOP_LENGTH, SAMPLE_RATE
from whisperlivekit.whisper.model import ModelDimensions
from whisperlivekit.whisper.utils import exact_div
def _load_state_dict(repo_path: Path) -> Dict[str, torch.Tensor]:
safetensor_path = repo_path / "model.safetensors"
bin_path = repo_path / "pytorch_model.bin"
if safetensor_path.is_file():
try:
from safetensors.torch import load_file # type: ignore
except Exception as exc: # pragma: no cover - import guard
raise RuntimeError(
"Install safetensors to load model.safetensors "
"(pip install safetensors)"
) from exc
return load_file(str(safetensor_path))
if bin_path.is_file():
return torch.load(bin_path, map_location="cpu")
raise FileNotFoundError(
f"Could not find model.safetensors or pytorch_model.bin under {repo_path}"
)
def _load_config(repo_path: Path) -> Dict:
config_path = repo_path / "config.json"
if not config_path.is_file():
raise FileNotFoundError(
f"Hugging Face checkpoint at {repo_path} is missing config.json"
)
with open(config_path, "r", encoding="utf-8") as fp:
return json.load(fp)
def _derive_audio_ctx(chunk_length: float) -> Tuple[int, int]:
n_samples = int(round(chunk_length * SAMPLE_RATE))
expected_samples = chunk_length * SAMPLE_RATE
if abs(n_samples - expected_samples) > 1e-6:
raise ValueError(
"chunk_length must align with sample rate so that "
"chunk_length * SAMPLE_RATE is an integer"
)
n_frames = exact_div(n_samples, HOP_LENGTH)
n_audio_ctx = exact_div(n_frames, 2)
return n_frames, n_audio_ctx
def _build_dims(config: Dict, chunk_length: float) -> Dict:
base_dims = ModelDimensions(
n_mels=config["num_mel_bins"],
n_audio_ctx=config["max_source_positions"],
n_audio_state=config["d_model"],
n_audio_head=config["encoder_attention_heads"],
n_audio_layer=config.get("encoder_layers") or config["num_hidden_layers"],
n_vocab=config["vocab_size"],
n_text_ctx=config["max_target_positions"],
n_text_state=config["d_model"],
n_text_head=config["decoder_attention_heads"],
n_text_layer=config["decoder_layers"],
).__dict__.copy()
_, n_audio_ctx = _derive_audio_ctx(chunk_length)
base_dims["n_audio_ctx"] = n_audio_ctx
base_dims["chunk_length"] = chunk_length
return base_dims
def _trim_positional_embedding(
state_dict: Dict[str, torch.Tensor], target_ctx: int
) -> None:
key = "encoder.positional_embedding"
if key not in state_dict:
raise KeyError(f"{key} missing from converted state dict")
tensor = state_dict[key]
if tensor.shape[0] < target_ctx:
raise ValueError(
f"Cannot increase encoder ctx from {tensor.shape[0]} to {target_ctx}"
)
if tensor.shape[0] == target_ctx:
return
state_dict[key] = tensor[:target_ctx].contiguous()
def convert_checkpoint(hf_path: Path, output_path: Path, chunk_length: float) -> None:
state_dict = _load_state_dict(hf_path)
converted = _convert_hf_state_dict(state_dict)
config = _load_config(hf_path)
dims = _build_dims(config, chunk_length)
_trim_positional_embedding(converted, dims["n_audio_ctx"])
package = {"dims": dims, "model_state_dict": converted}
output_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(package, output_path)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Convert Hugging Face Whisper checkpoint to WhisperLiveKit format."
)
parser.add_argument(
"hf_path",
type=str,
help="Path to the cloned Hugging Face repository (e.g. whisper-tiny.en)",
)
parser.add_argument(
"--output",
type=str,
default="converted-whisper.pt",
help="Destination path for the .pt file",
)
parser.add_argument(
"--chunk-length",
type=float,
default=30.0,
help="Audio chunk length in seconds to support (default: 30)",
)
return parser.parse_args()
def main():
args = parse_args()
hf_path = Path(os.path.expanduser(args.hf_path)).resolve()
output_path = Path(os.path.expanduser(args.output)).resolve()
convert_checkpoint(hf_path, output_path, args.chunk_length)
print(f"Saved converted checkpoint to {output_path}")
if __name__ == "__main__":
main()

View File

@@ -1,294 +0,0 @@
"""Determine alignment heads for a variants, such as distilled model"""
from __future__ import annotations
import argparse
import base64
import gzip
import io
import math
import pathlib
import sys
from typing import List, Optional, Sequence, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import torch
from datasets import Audio as DatasetAudio
from datasets import load_dataset
REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
WHISPER_ROOT = REPO_ROOT / "whisper"
sys.path.insert(0, str(REPO_ROOT))
sys.path.insert(0, str(WHISPER_ROOT))
from whisper import load_model
from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
from whisper.tokenizer import get_tokenizer
AudioInput = Union[str, pathlib.Path, np.ndarray, torch.Tensor]
def load_dataset_clips(name, config, split, limit):
ds = load_dataset(name, config, split=split)
ds = ds.cast_column("audio", DatasetAudio(decode=False))
clips = []
for idx, row in enumerate(ds):
if limit is not None and idx >= limit:
break
audio_field = row["audio"]
transcript = row["text"]
waveform_np, _ = sf.read(io.BytesIO(audio_field["bytes"]), dtype="float32")
if waveform_np.ndim > 1:
waveform_np = waveform_np.mean(axis=1)
waveform = waveform_np
transcript = str(transcript)
clips.append((waveform, transcript))
return clips
def load_clips(args):
return load_dataset_clips(
args.dataset,
args.dataset_config,
args.dataset_split,
args.dataset_num_samples,
)
def _waveform_from_source(source: AudioInput) -> torch.Tensor:
waveform = torch.from_numpy(source.astype(np.float32, copy=False))
return waveform
def _parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
type=str,
default="pytorch_model.bin",
)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Torch device to run on",
)
parser.add_argument(
"--dataset",
type=str,
default="librispeech_asr"
)
parser.add_argument(
"--dataset-config",
type=str,
default="clean"
)
parser.add_argument(
"--dataset-split",
type=str,
default="validation[:1%]",
)
parser.add_argument(
"--dataset-num-samples",
type=int,
default=16,
)
parser.add_argument(
"--threshold",
type=float,
default=1.5,
help="Z score threshold for a head to be selected",
)
parser.add_argument(
"--votes",
type=float,
default=0.75,
help="percentage of clips that must vote for a head",
)
parser.add_argument(
"--output",
type=str,
default="alignment_heads.b85",
)
parser.add_argument(
"--visualize-top-k",
type=int,
default=32,
)
return parser.parse_args()
def collect_heads(
model,
tokenizer,
clips: Sequence[Tuple[AudioInput, str]],
threshold: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
device = model.device
votes = torch.zeros(model.dims.n_text_layer, model.dims.n_text_head, device=device)
strengths = torch.zeros_like(votes)
for audio_source, transcript in clips:
waveform = pad_or_trim(_waveform_from_source(audio_source))
mel = log_mel_spectrogram(waveform, device=device)
tokens = torch.tensor(
[
*tokenizer.sot_sequence,
tokenizer.no_timestamps,
*tokenizer.encode(transcript),
tokenizer.eot,
],
device=device,
)
qks = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, __, outputs, index=i: qks.__setitem__(index, outputs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]
with torch.no_grad():
model(mel.unsqueeze(0), tokens.unsqueeze(0))
for hook in hooks:
hook.remove()
for layer_idx, tensor in enumerate(qks):
if tensor is None:
continue
tensor = tensor[:, :, : mel.shape[-1] // 2]
tensor = tensor.softmax(dim=-1)
peak = tensor.max(dim=-1).values # [heads, tokens]
strengths[layer_idx] += peak.mean(dim=-1)
zscore = (peak - peak.mean(dim=-1, keepdim=True)) / (
peak.std(dim=-1, keepdim=True, unbiased=False) + 1e-6
)
mask = (zscore > 3).any(dim=-1)
votes[layer_idx] += mask.float()
votes /= len(clips)
strengths /= len(clips)
return votes, strengths
def _select_heads_for_visualization(selection, strengths, top_k):
selected = torch.nonzero(selection, as_tuple=False)
if selected.numel() == 0:
return []
entries = [
(int(layer.item()), int(head.item()), float(strengths[layer, head].item()))
for layer, head in selected
]
entries.sort(key=lambda item: item[2], reverse=True)
return entries[:top_k]
def _extract_heatmaps(
model,
tokenizer,
clip: Tuple[AudioInput, str],
heads: Sequence[Tuple[int, int, float]],
) -> dict:
if not heads:
return {}
target_map = {}
for layer, head, _ in heads:
target_map.setdefault(layer, set()).add(head)
waveform = pad_or_trim(_waveform_from_source(clip[0]))
mel = log_mel_spectrogram(waveform, device=model.device)
transcript = clip[1]
tokens = torch.tensor(
[
*tokenizer.sot_sequence,
tokenizer.no_timestamps,
*tokenizer.encode(transcript),
tokenizer.eot,
],
device=model.device,
)
QKs = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, __, outputs, index=i: QKs.__setitem__(index, outputs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]
with torch.no_grad():
model(mel.unsqueeze(0), tokens.unsqueeze(0))
for hook in hooks:
hook.remove()
heatmaps = {}
for layer_idx, tensor in enumerate(QKs):
if tensor is None or layer_idx not in target_map:
continue
tensor = tensor[:, :, : mel.shape[-1] // 2]
tensor = tensor.softmax(dim=-1).cpu()
for head_idx in target_map[layer_idx]:
heatmaps[(layer_idx, head_idx)] = tensor[head_idx]
return heatmaps
def _plot_heatmaps(
heads, heatmaps, output_path):
cols = min(3, len(heads))
rows = math.ceil(len(heads) / cols)
fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3.2 * rows), squeeze=False)
for idx, (layer, head, score) in enumerate(heads):
ax = axes[idx // cols][idx % cols]
mat = heatmaps.get((layer, head))
if mat is None:
ax.axis("off")
continue
im = ax.imshow(mat.to(torch.float32).numpy(), aspect="auto", origin="lower")
ax.set_title(f"L{layer} H{head} · score {score:.2f}")
ax.set_xlabel("time")
ax.set_ylabel("tokens")
for j in range(len(heads), rows * cols):
axes[j // cols][j % cols].axis("off")
fig.tight_layout()
fig.savefig(output_path, dpi=200)
plt.close(fig)
def _dump_mask(mask: torch.Tensor, output_path: str):
payload = mask.numpy().astype(np.bool_)
blob = base64.b85encode(gzip.compress(payload.tobytes()))
with open(output_path, "wb") as f:
f.write(blob)
def main():
args = _parse_args()
model = load_model(args.model, device=args.device)
model.eval()
tokenizer = get_tokenizer(multilingual=model.is_multilingual)
clips = load_clips(args)
votes, strengths = collect_heads(model, tokenizer, clips, args.threshold)
# selection = votes > 0.5
selection = strengths > 0.05
_dump_mask(selection.cpu(), args.output)
viz_heads = _select_heads_for_visualization(selection, strengths, args.visualize_top_k)
heatmaps = _extract_heatmaps(model, tokenizer, clips[0], viz_heads)
_plot_heatmaps(viz_heads, heatmaps, "alignment_heads.png")
if __name__ == "__main__":
main()

View File

@@ -1,40 +0,0 @@
"""Copy core files from web directory to Chrome extension directory."""
import os
import shutil
from pathlib import Path
def sync_extension_files():
web_dir = Path("whisperlivekit/web")
extension_dir = Path("chrome-extension")
files_to_sync = [
"live_transcription.html", "live_transcription.js", "live_transcription.css"
]
svg_files = [
"system_mode.svg",
"light_mode.svg",
"dark_mode.svg",
"settings.svg"
]
for file in files_to_sync:
src_path = web_dir / file
dest_path = extension_dir / file
dest_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(src_path, dest_path)
for svg_file in svg_files:
src_path = web_dir / "src" / svg_file
dest_path = extension_dir / "web" / "src" / svg_file
dest_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(src_path, dest_path)
if __name__ == "__main__":
sync_extension_files()

172
seamless_integration.py Normal file
View File

@@ -0,0 +1,172 @@
#!/usr/bin/env python3
import sys
import numpy as np
# code extracted from https://github.com/facebookresearch/seamless_communication/blob/main/Seamless_Tutorial.ipynb :
from simuleval.data.segments import SpeechSegment, EmptySegment
from simuleval.utils.arguments import cli_argument_list
from simuleval import options
from typing import Union, List
from simuleval.data.segments import Segment, TextSegment
from simuleval.agents.pipeline import TreeAgentPipeline
from simuleval.agents.states import AgentStates
SAMPLE_RATE = 16000
def reset_states(system, states):
if isinstance(system, TreeAgentPipeline):
states_iter = states.values()
else:
states_iter = states
for state in states_iter:
state.reset()
def get_states_root(system, states) -> AgentStates:
if isinstance(system, TreeAgentPipeline):
# self.states is a dict
return states[system.source_module]
else:
# self.states is a list
return system.states[0]
def build_streaming_system(model_configs, agent_class):
parser = options.general_parser()
parser.add_argument("-f", "--f", help="a dummy argument to fool ipython", default="1")
agent_class.add_args(parser)
args, _ = parser.parse_known_args(cli_argument_list(model_configs))
system = agent_class.from_args(args)
return system
class OutputSegments:
def __init__(self, segments: Union[List[Segment], Segment]):
if isinstance(segments, Segment):
segments = [segments]
self.segments: List[Segment] = [s for s in segments]
@property
def is_empty(self):
return all(segment.is_empty for segment in self.segments)
@property
def finished(self):
return all(segment.finished for segment in self.segments)
######################
# fixing DetokenizerAgent -- it strips output segment.content last space, but sometimes a word is split into more segments. Simple joining with spaces would be wrong.
from seamless_communication.streaming.agents.detokenizer import DetokenizerAgent
from seamless_communication.streaming.agents.offline_w2v_bert_encoder import (
OfflineWav2VecBertEncoderAgent,
)
from seamless_communication.streaming.agents.online_feature_extractor import (
OnlineFeatureExtractorAgent,
)
from seamless_communication.streaming.agents.online_text_decoder import (
MMASpeechToTextDecoderAgent,
)
from seamless_communication.streaming.agents.silero_vad import SileroVADAgent
from seamless_communication.streaming.agents.unity_pipeline import UnitYAgentPipeline
class FixDetokenizerAgent(DetokenizerAgent):
def decode(self, x: str) -> str:
return x.replace(" ", "").replace("\u2581", " ") # .strip() is removed
class FixSeamlessStreamingS2TVADAgent(UnitYAgentPipeline):
pipeline = [
SileroVADAgent,
OnlineFeatureExtractorAgent,
OfflineWav2VecBertEncoderAgent,
MMASpeechToTextDecoderAgent,
FixDetokenizerAgent,
]
##################################
# the next pieces of are copypasted from the tutorial and put to the corresponding methods
#class SeamlessProcessor(OnlineASRProcessorBase): # TODO: there should be a common base class. But the code would not be simple anymore.
class SeamlessProcessor:
'''
Wrapping SeamlessStreaming for the same operation modes as
Whisper-Streaming's OnlineASRProcessor.
'''
def __init__(self, tgt_lan, task, logfile=sys.stderr):
'''
tgt_lan: must be 3-letter language code that Seamless-Streaming supports for text output mode.
task: see below
logfile
'''
if task in ("transcribe","asr"):
task_arg = "asr"
elif task in ("translate","s2tt"):
task_arg = "s2tt"
else:
raise ValueError("task argument must be 'transcribe' or 'translate', or 'asr' or 's2tt'")
self.logfile = logfile
agent_class = FixSeamlessStreamingS2TVADAgent
model_configs = dict(
source_segment_size=320,
device="cuda:0",
dtype="fp16",
min_starting_wait_w2vbert=192,
decision_threshold=0.5,
min_unit_chunk_size=50,
no_early_stop=True,
max_len_a=0,
max_len_b=100,
task=task_arg,
tgt_lang=tgt_lan,
block_ngrams=True,
detokenize_only=True,
)
self.tgt_lan = tgt_lan
self.system = build_streaming_system(model_configs, agent_class)
self.system_states = self.system.build_states()
self.init()
def init(self):
reset_states(self.system, self.system_states)
self.audio_buffer = np.array([],dtype=np.float32)
self.beg, self.end = 0, 0
def insert_audio_chunk(self, audio):
self.audio_buffer = np.append(self.audio_buffer, audio)
def process_segment(self, input_segment):
output_segments = OutputSegments(self.system.pushpop(input_segment, self.system_states))
out = []
for segment in output_segments.segments:
if not segment.is_empty:
out.append(segment.content)
if output_segments.finished:
print("End of VAD segment",file=self.logfile)
reset_states(self.system, self.system_states)
if out:
b = self.beg
self.beg = self.end
o = "".join(out)
return (b, self.end, "".join(out))
return (None, None, "")
def process_iter(self, finished=False):
input_segment = SpeechSegment(
content=self.audio_buffer,
sample_rate=SAMPLE_RATE,
finished=finished,
)
self.audio_buffer = np.array([],dtype=np.float32)
input_segment.tgt_lang = self.tgt_lan
self.end += (len(input_segment.content)/SAMPLE_RATE)
return self.process_segment(input_segment)
def finish(self):
return self.process_iter(finished=True)

628
whisper_online.py Normal file
View File

@@ -0,0 +1,628 @@
#!/usr/bin/env python3
import sys
import numpy as np
import librosa
from functools import lru_cache
import time
@lru_cache
def load_audio(fname):
a, _ = librosa.load(fname, sr=16000)
return a
def load_audio_chunk(fname, beg, end):
audio = load_audio(fname)
beg_s = int(beg*16000)
end_s = int(end*16000)
return audio[beg_s:end_s]
# Whisper backend
class ASRBase:
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
# "" for faster-whisper because it emits the spaces when neeeded)
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
self.logfile = logfile
self.transcribe_kargs = {}
self.original_language = lan
self.model = self.load_model(modelsize, cache_dir, model_dir)
def load_model(self, modelsize, cache_dir):
raise NotImplemented("must be implemented in the child class")
def transcribe(self, audio, init_prompt=""):
raise NotImplemented("must be implemented in the child class")
def use_vad(self):
raise NotImplemented("must be implemented in the child class")
class WhisperTimestampedASR(ASRBase):
"""Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper.
On the other hand, the installation for GPU could be easier.
"""
sep = " "
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
import whisper
from whisper_timestamped import transcribe_timestamped
self.transcribe_timestamped = transcribe_timestamped
if model_dir is not None:
print("ignoring model_dir, not implemented",file=self.logfile)
return whisper.load_model(modelsize, download_root=cache_dir)
def transcribe(self, audio, init_prompt=""):
result = self.transcribe_timestamped(self.model,
audio, language=self.original_language,
initial_prompt=init_prompt, verbose=None,
condition_on_previous_text=True, **self.transcribe_kargs)
return result
def ts_words(self,r):
# return: transcribe result object to [(beg,end,"word1"), ...]
o = []
for s in r["segments"]:
for w in s["words"]:
t = (w["start"],w["end"],w["text"])
o.append(t)
return o
def segments_end_ts(self, res):
return [s["end"] for s in res["segments"]]
def use_vad(self):
self.transcribe_kargs["vad"] = True
def set_translate_task(self):
self.transcribe_kargs["task"] = "translate"
class FasterWhisperASR(ASRBase):
"""Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version.
"""
sep = ""
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
from faster_whisper import WhisperModel
if model_dir is not None:
print(f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.",file=self.logfile)
model_size_or_path = model_dir
elif modelsize is not None:
model_size_or_path = modelsize
else:
raise ValueError("modelsize or model_dir parameter must be set")
# this worked fast and reliably on NVIDIA L40
model = WhisperModel(model_size_or_path, device="cuda", compute_type="float16", download_root=cache_dir)
# or run on GPU with INT8
# tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
#model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
# or run on CPU with INT8
# tested: works, but slow, appx 10-times than cuda FP16
# model = WhisperModel(modelsize, device="cpu", compute_type="int8") #, download_root="faster-disk-cache-dir/")
return model
def transcribe(self, audio, init_prompt=""):
# tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
segments, info = self.model.transcribe(audio, language=self.original_language, initial_prompt=init_prompt, beam_size=5, word_timestamps=True, condition_on_previous_text=True, **self.transcribe_kargs)
return list(segments)
def ts_words(self, segments):
o = []
for segment in segments:
for word in segment.words:
# not stripping the spaces -- should not be merged with them!
w = word.word
t = (word.start, word.end, w)
o.append(t)
return o
def segments_end_ts(self, res):
return [s.end for s in res]
def use_vad(self):
self.transcribe_kargs["vad_filter"] = True
def set_translate_task(self):
self.transcribe_kargs["task"] = "translate"
class HypothesisBuffer:
def __init__(self, logfile=sys.stderr):
self.commited_in_buffer = []
self.buffer = []
self.new = []
self.last_commited_time = 0
self.last_commited_word = None
self.logfile = logfile
def insert(self, new, offset):
# compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
# the new tail is added to self.new
new = [(a+offset,b+offset,t) for a,b,t in new]
self.new = [(a,b,t) for a,b,t in new if a > self.last_commited_time-0.1]
if len(self.new) >= 1:
a,b,t = self.new[0]
if abs(a - self.last_commited_time) < 1:
if self.commited_in_buffer:
# it's going to search for 1, 2, ..., 5 consecutive words (n-grams) that are identical in commited and new. If they are, they're dropped.
cn = len(self.commited_in_buffer)
nn = len(self.new)
for i in range(1,min(min(cn,nn),5)+1): # 5 is the maximum
c = " ".join([self.commited_in_buffer[-j][2] for j in range(1,i+1)][::-1])
tail = " ".join(self.new[j-1][2] for j in range(1,i+1))
if c == tail:
print("removing last",i,"words:",file=self.logfile)
for j in range(i):
print("\t",self.new.pop(0),file=self.logfile)
break
def flush(self):
# returns commited chunk = the longest common prefix of 2 last inserts.
commit = []
while self.new:
na, nb, nt = self.new[0]
if len(self.buffer) == 0:
break
if nt == self.buffer[0][2]:
commit.append((na,nb,nt))
self.last_commited_word = nt
self.last_commited_time = nb
self.buffer.pop(0)
self.new.pop(0)
else:
break
self.buffer = self.new
self.new = []
self.commited_in_buffer.extend(commit)
return commit
def pop_commited(self, time):
while self.commited_in_buffer and self.commited_in_buffer[0][1] <= time:
self.commited_in_buffer.pop(0)
def complete(self):
return self.buffer
class OnlineASRProcessorBase:
'''Showing minimum common public interface for various specialized subclasses.'''
def init(self):
raise NotImplemented()
def insert_audio_chunk(self, audio):
raise NotImplemented()
def process_iter(self):
raise NotImplemented()
def finish(self):
raise NotImplemented()
class OnlineASRProcessor(OnlineASRProcessorBase):
SAMPLING_RATE = 16000
def __init__(self, asr, tokenizer=None, buffer_trimming=("segment", 15), logfile=sys.stderr):
"""asr: WhisperASR object
tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer. It can be None, if "segment" buffer trimming option is used, then tokenizer is not used at all.
("segment", 15)
buffer_trimming: a pair of (option, seconds), where option is either "sentence" or "segment", and seconds is a number. Buffer is trimmed if it is longer than "seconds" threshold. Default is the most recommended option.
logfile: where to store the log.
"""
self.asr = asr
self.tokenizer = tokenizer
self.logfile = logfile
self.init()
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
def init(self):
"""run this when starting or restarting processing"""
self.audio_buffer = np.array([],dtype=np.float32)
self.buffer_time_offset = 0
self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
self.commited = []
self.last_chunked_at = 0
self.silence_iters = 0
def insert_audio_chunk(self, audio):
self.audio_buffer = np.append(self.audio_buffer, audio)
def prompt(self):
"""Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer.
"context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons.
"""
k = max(0,len(self.commited)-1)
while k > 0 and self.commited[k-1][1] > self.last_chunked_at:
k -= 1
p = self.commited[:k]
p = [t for _,_,t in p]
prompt = []
l = 0
while p and l < 200: # 200 characters prompt size
x = p.pop(-1)
l += len(x)+1
prompt.append(x)
non_prompt = self.commited[k:]
return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(t for _,_,t in non_prompt)
def process_iter(self):
"""Runs on the current audio buffer.
Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, "").
The non-emty text is confirmed (committed) partial transcript.
"""
prompt, non_prompt = self.prompt()
print("PROMPT:", prompt, file=self.logfile)
print("CONTEXT:", non_prompt, file=self.logfile)
print(f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}",file=self.logfile)
res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
# transform to [(beg,end,"word1"), ...]
tsw = self.asr.ts_words(res)
self.transcript_buffer.insert(tsw, self.buffer_time_offset)
o = self.transcript_buffer.flush()
self.commited.extend(o)
print(">>>>COMPLETE NOW:",self.to_flush(o),file=self.logfile,flush=True)
print("INCOMPLETE:",self.to_flush(self.transcript_buffer.complete()),file=self.logfile,flush=True)
# there is a newly confirmed text
if o and self.buffer_trimming_way == "sentence": # trim the completed sentences
if len(self.audio_buffer)/self.SAMPLING_RATE > self.buffer_trimming_sec: # longer than this
self.chunk_completed_sentence()
if self.buffer_trimming_way == "segment":
s = self.buffer_trimming_sec # trim the completed segments longer than s,
else:
s = 30 # if the audio buffer is longer than 30s, trim it
if len(self.audio_buffer)/self.SAMPLING_RATE > s:
self.chunk_completed_segment(res)
# alternative: on any word
#l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
# let's find commited word that is less
#k = len(self.commited)-1
#while k>0 and self.commited[k][1] > l:
# k -= 1
#t = self.commited[k][1]
print(f"chunking segment",file=self.logfile)
#self.chunk_at(t)
print(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}",file=self.logfile)
return self.to_flush(o)
def chunk_completed_sentence(self):
if self.commited == []: return
print(self.commited,file=self.logfile)
sents = self.words_to_sentences(self.commited)
for s in sents:
print("\t\tSENT:",s,file=self.logfile)
if len(sents) < 2:
return
while len(sents) > 2:
sents.pop(0)
# we will continue with audio processing at this timestamp
chunk_at = sents[-2][1]
print(f"--- sentence chunked at {chunk_at:2.2f}",file=self.logfile)
self.chunk_at(chunk_at)
def chunk_completed_segment(self, res):
if self.commited == []: return
ends = self.asr.segments_end_ts(res)
t = self.commited[-1][1]
if len(ends) > 1:
e = ends[-2]+self.buffer_time_offset
while len(ends) > 2 and e > t:
ends.pop(-1)
e = ends[-2]+self.buffer_time_offset
if e <= t:
print(f"--- segment chunked at {e:2.2f}",file=self.logfile)
self.chunk_at(e)
else:
print(f"--- last segment not within commited area",file=self.logfile)
else:
print(f"--- not enough segments to chunk",file=self.logfile)
def chunk_at(self, time):
"""trims the hypothesis and audio buffer at "time"
"""
self.transcript_buffer.pop_commited(time)
cut_seconds = time - self.buffer_time_offset
self.audio_buffer = self.audio_buffer[int(cut_seconds*self.SAMPLING_RATE):]
self.buffer_time_offset = time
self.last_chunked_at = time
def words_to_sentences(self, words):
"""Uses self.tokenizer for sentence segmentation of words.
Returns: [(beg,end,"sentence 1"),...]
"""
cwords = [w for w in words]
t = " ".join(o[2] for o in cwords)
s = self.tokenizer.split(t)
out = []
while s:
beg = None
end = None
sent = s.pop(0).strip()
fsent = sent
while cwords:
b,e,w = cwords.pop(0)
w = w.strip()
if beg is None and sent.startswith(w):
beg = b
elif end is None and sent == w:
end = e
out.append((beg,end,fsent))
break
sent = sent[len(w):].strip()
return out
def finish(self):
"""Flush the incomplete text when the whole processing ends.
Returns: the same format as self.process_iter()
"""
o = self.transcript_buffer.complete()
f = self.to_flush(o)
print("last, noncommited:",f,file=self.logfile)
return f
def to_flush(self, sents, sep=None, offset=0, ):
# concatenates the timestamped words or sentences into one sequence that is flushed in one line
# sents: [(beg1, end1, "sentence1"), ...] or [] if empty
# return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
if sep is None:
sep = self.asr.sep
t = sep.join(s[2] for s in sents)
if len(sents) == 0:
b = None
e = None
else:
b = offset + sents[0][0]
e = offset + sents[-1][1]
return (b,e,t)
WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(",")
def create_tokenizer(lan):
"""returns an object that has split function that works like the one of MosesTokenizer"""
assert lan in WHISPER_LANG_CODES, "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
if lan == "uk":
import tokenize_uk
class UkrainianTokenizer:
def split(self, text):
return tokenize_uk.tokenize_sents(text)
return UkrainianTokenizer()
# supported by fast-mosestokenizer
if lan in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split():
from mosestokenizer import MosesTokenizer
return MosesTokenizer(lan)
# the following languages are in Whisper, but not in wtpsplit:
if lan in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split():
print(f"{lan} code is not supported by wtpsplit. Going to use None lang_code option.", file=sys.stderr)
lan = None
from wtpsplit import WtP
# downloads the model from huggingface on the first use
wtp = WtP("wtp-canine-s-12l-no-adapters")
class WtPtok:
def split(self, sent):
return wtp.split(sent, lang_code=lan)
return WtPtok()
def add_shared_args(parser):
"""shared args for simulation (this entry point) and server
parser: argparse.ArgumentParser object
"""
parser.add_argument('--min-chunk-size', type=float, default=1.0, help='Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time. Applicable both to Whisper and seamless.')
parser.add_argument('--model', type=str, default='large-v2', choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large".split(","),help="Name size of the Whisper model to use (default: large-v2). The model is automatically downloaded from the model hub if not present in model cache dir. Not applicable to seamless.")
parser.add_argument('--model_cache_dir', type=str, default=None, help="Overriding the default model cache dir where models downloaded from the hub are saved. Not applicable to seamless.")
parser.add_argument('--model_dir', type=str, default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter. Not applicable to seamless.")
parser.add_argument('--lan', '--language', type=str, default='en', help="Language code for transcription, e.g. en,de,cs. Seamless backend has its own 3-letter language codes, e.g. eng, deu, ces.")
parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe","translate"],help="Transcribe or translate.")
parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "seamless"],help='Load only this backend for Whisper processing, or Seamless Streaming.')
parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters. Not applicable to seamless.')
parser.add_argument('--buffer_trimming', type=str, default="segment", choices=["sentence", "segment"],help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option. Not applicable to seamless.')
parser.add_argument('--buffer_trimming_sec', type=float, default=15, help='Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered. Not applicable to seamless.')
## main:
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('audio_path', type=str, help="Filename of 16kHz mono channel wav, on which live streaming is simulated.")
add_shared_args(parser)
parser.add_argument('--start_at', type=float, default=0.0, help='Start processing audio at this time.')
parser.add_argument('--offline', action="store_true", default=False, help='Offline mode.')
parser.add_argument('--comp_unaware', action="store_true", default=False, help='Computationally unaware simulation.')
args = parser.parse_args()
# reset to store stderr to different file stream, e.g. open(os.devnull,"w")
logfile = sys.stderr
if args.offline and args.comp_unaware:
print("No or one option from --offline and --comp_unaware are available, not both. Exiting.",file=logfile)
sys.exit(1)
audio_path = args.audio_path
SAMPLING_RATE = 16000
duration = len(load_audio(audio_path))/SAMPLING_RATE
print("Audio duration is: %2.2f seconds" % duration, file=logfile)
size = args.model
language = args.lan
min_chunk = args.min_chunk_size
if args.backend != "seamless":
# loading Whisper model
t = time.time()
print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
if args.backend == "faster-whisper":
asr_cls = FasterWhisperASR
elif args.backend == "whisper_timestamped":
asr_cls = WhisperTimestampedASR
asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
e = time.time()
print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
if args.vad:
print("setting VAD filter",file=logfile)
asr.use_vad()
if args.task == "translate":
asr.set_translate_task()
tgt_language = "en" # Whisper translates into English
else:
tgt_language = language # Whisper transcribes in this language
if args.buffer_trimming == "sentence":
tokenizer = create_tokenizer(tgt_language)
else:
tokenizer = None
online = OnlineASRProcessor(asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))
# load the audio into the LRU cache before we start the timer
a = load_audio_chunk(audio_path,0,1)
# warm up the ASR, because the very first transcribe takes much more time than the other
asr.transcribe(a)
else:
print(f"Loading Seamless Streaming backend model",file=logfile,flush=True)
from seamless_integration import SeamlessProcessor
online = SeamlessProcessor(language, args.task, logfile=logfile)
beg = args.start_at
start = time.time()-beg
def output_transcript(o, now=None):
# output format in stdout is like:
# 4186.3606 0 1720 Takhle to je
# - the first three words are:
# - emission time from beginning of processing, in milliseconds
# - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
# - the next words: segment transcript
if now is None:
now = time.time()-start
if o[0] is not None:
print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=logfile,flush=True)
print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),flush=True)
else:
print(o,file=logfile,flush=True)
if args.offline: ## offline mode processing (for testing/debugging)
a = load_audio(audio_path)
online.insert_audio_chunk(a)
try:
o = online.process_iter()
except AssertionError:
print("assertion error",file=logfile)
pass
else:
output_transcript(o)
now = None
elif args.comp_unaware: # computational unaware mode
end = beg + min_chunk
while True:
a = load_audio_chunk(audio_path,beg,end)
online.insert_audio_chunk(a)
try:
o = online.process_iter()
except AssertionError:
print("assertion error",file=logfile)
pass
else:
output_transcript(o, now=end)
print(f"## last processed {end:.2f}s",file=logfile,flush=True)
if end >= duration:
break
beg = end
if end + min_chunk > duration:
end = duration
else:
end += min_chunk
now = duration
else: # online = simultaneous mode
end = 0
while True:
now = time.time() - start
if now < end+min_chunk:
time.sleep(min_chunk+end-now)
end = time.time() - start
a = load_audio_chunk(audio_path,beg,end)
beg = end
online.insert_audio_chunk(a)
try:
o = online.process_iter()
except AssertionError:
print("assertion error",file=logfile)
pass
else:
output_transcript(o)
now = time.time() - start
print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=logfile,flush=True)
if end >= duration:
break
now = None
o = online.finish()
output_transcript(o, now=now)

214
whisper_online_server.py Normal file
View File

@@ -0,0 +1,214 @@
#!/usr/bin/env python3
from whisper_online import *
import sys
import argparse
import os
parser = argparse.ArgumentParser()
# server options
parser.add_argument("--host", type=str, default='localhost')
parser.add_argument("--port", type=int, default=43007)
# options from whisper_online
add_shared_args(parser)
args = parser.parse_args()
# setting whisper object by args
SAMPLING_RATE = 16000
language = args.lan
min_chunk = args.min_chunk_size
if args.backend != "seamless": # loading Whisper backend
size = args.model
t = time.time()
print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",flush=True)
if args.backend == "faster-whisper":
from faster_whisper import WhisperModel
asr_cls = FasterWhisperASR
else:
import whisper
import whisper_timestamped
# from whisper_timestamped_model import WhisperTimestampedASR
asr_cls = WhisperTimestampedASR
asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
if args.task == "translate":
asr.set_translate_task()
tgt_language = "en"
else:
tgt_language = language
e = time.time()
print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr)
if args.vad:
print("setting VAD filter",file=sys.stderr)
asr.use_vad()
demo_audio_path = "cs-maji-2.16k.wav"
if os.path.exists(demo_audio_path):
# load the audio into the LRU cache before we start the timer
a = load_audio_chunk(demo_audio_path,0,1)
# TODO: it should be tested whether it's meaningful
# warm up the ASR, because the very first transcribe takes much more time than the other
asr.transcribe(a)
else:
print("Whisper is not warmed up",file=sys.stderr)
if args.buffer_trimming == "sentence":
tokenizer = create_tokenizer(tgt_language)
else:
tokenizer = None
online = OnlineASRProcessor(asr,tokenizer,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))
else: # seamless backend:
print(f"Loading Seamless Streaming backend model",file=sys.stderr,flush=True)
from seamless_integration import SeamlessProcessor
online = SeamlessProcessor(language, args.task, logfile=sys.stderr)
######### Server objects
import line_packet
import socket
import logging
class Connection:
'''it wraps conn object'''
PACKET_SIZE = 65536
def __init__(self, conn):
self.conn = conn
self.last_line = ""
self.conn.setblocking(True)
def send(self, line):
'''it doesn't send the same line twice, because it was problematic in online-text-flow-events'''
if line == self.last_line:
return
line_packet.send_one_line(self.conn, line)
self.last_line = line
def receive_lines(self):
in_line = line_packet.receive_lines(self.conn)
return in_line
def non_blocking_receive_audio(self):
r = self.conn.recv(self.PACKET_SIZE)
return r
import io
import soundfile
# wraps socket and ASR object, and serves one client connection.
# next client should be served by a new instance of this object
class ServerProcessor:
def __init__(self, c, online_asr_proc, min_chunk):
self.connection = c
self.online_asr_proc = online_asr_proc
self.min_chunk = min_chunk
self.last_end = None
def receive_audio_chunk(self):
# receive all audio that is available by this time
# blocks operation if less than self.min_chunk seconds is available
# unblocks if connection is closed or a chunk is available
out = []
while sum(len(x) for x in out) < self.min_chunk*SAMPLING_RATE:
raw_bytes = self.connection.non_blocking_receive_audio()
print(raw_bytes[:10])
print(len(raw_bytes))
if not raw_bytes:
break
sf = soundfile.SoundFile(io.BytesIO(raw_bytes), channels=1,endian="LITTLE",samplerate=SAMPLING_RATE, subtype="PCM_16",format="RAW")
audio, _ = librosa.load(sf,sr=SAMPLING_RATE)
out.append(audio)
if not out:
return None
return np.concatenate(out)
def format_output_transcript(self,o):
# output format in stdout is like:
# 0 1720 Takhle to je
# - the first two words are:
# - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
# - the next words: segment transcript
# This function differs from whisper_online.output_transcript in the following:
# succeeding [beg,end] intervals are not overlapping because ELITR protocol (implemented in online-text-flow events) requires it.
# Therefore, beg, is max of previous end and current beg outputed by Whisper.
# Usually it differs negligibly, by appx 20 ms.
if o[0] is not None:
beg, end = o[0]*1000,o[1]*1000
if self.last_end is not None:
beg = max(beg, self.last_end)
self.last_end = end
print("%1.0f %1.0f %s" % (beg,end,o[2]),flush=True,file=sys.stderr)
return "%1.0f %1.0f %s" % (beg,end,o[2])
else:
print(o,file=sys.stderr,flush=True)
return None
def send_result(self, o):
msg = self.format_output_transcript(o)
if msg is not None:
self.connection.send(msg)
def process(self):
# handle one client connection
self.online_asr_proc.init()
while True:
a = self.receive_audio_chunk()
if a is None:
print("break here",file=sys.stderr)
break
self.online_asr_proc.insert_audio_chunk(a)
o = online.process_iter()
try:
self.send_result(o)
except BrokenPipeError:
print("broken pipe -- connection closed?",file=sys.stderr)
break
# o = online.finish() # this should be working
# self.send_result(o)
# Start logging.
level = logging.INFO
logging.basicConfig(level=level, format='whisper-server-%(levelname)s: %(message)s')
# server loop
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind((args.host, args.port))
s.listen(1)
logging.info('INFO: Listening on'+str((args.host, args.port)))
while True:
conn, addr = s.accept()
logging.info('INFO: Connected to client on {}'.format(addr))
connection = Connection(conn)
proc = ServerProcessor(connection, online, min_chunk)
proc.process()
conn.close()
logging.info('INFO: Connection to client closed')
logging.info('INFO: Connection closed, terminating.')

View File

@@ -1,13 +0,0 @@
from .audio_processor import AudioProcessor
from .core import TranscriptionEngine
from .parse_args import parse_args
from .web.web_interface import get_inline_ui_html, get_web_interface_html
__all__ = [
"TranscriptionEngine",
"AudioProcessor",
"parse_args",
"get_web_interface_html",
"get_inline_ui_html",
"download_simulstreaming_backend",
]

View File

@@ -1,634 +0,0 @@
import asyncio
import logging
import traceback
from time import time
from typing import Any, AsyncGenerator, List, Optional, Union
import numpy as np
from whisperlivekit.core import (TranscriptionEngine,
online_diarization_factory, online_factory,
online_translation_factory)
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from whisperlivekit.silero_vad_iterator import FixedVADIterator, OnnxWrapper, load_jit_vad
from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
Segment, Silence, State, Transcript)
from whisperlivekit.tokens_alignment import TokensAlignment
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
SENTINEL = object() # unique sentinel object for end of stream marker
MIN_DURATION_REAL_SILENCE = 5
async def get_all_from_queue(queue: asyncio.Queue) -> Union[object, Silence, np.ndarray, List[Any]]:
items: List[Any] = []
first_item = await queue.get()
queue.task_done()
if first_item is SENTINEL:
return first_item
if isinstance(first_item, Silence):
return first_item
items.append(first_item)
while True:
if not queue._queue:
break
next_item = queue._queue[0]
if next_item is SENTINEL:
break
if isinstance(next_item, Silence):
break
items.append(await queue.get())
queue.task_done()
if isinstance(items[0], np.ndarray):
return np.concatenate(items)
else: #translation
return items
class AudioProcessor:
"""
Processes audio streams for transcription and diarization.
Handles audio processing, state management, and result formatting.
"""
def __init__(self, **kwargs: Any) -> None:
"""Initialize the audio processor with configuration, models, and state."""
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
models = kwargs['transcription_engine']
else:
models = TranscriptionEngine(**kwargs)
# Audio processing settings
self.args = models.args
self.sample_rate = 16000
self.channels = 1
self.samples_per_sec = int(self.sample_rate * self.args.min_chunk_size)
self.bytes_per_sample = 2
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
self.is_pcm_input = self.args.pcm_input
# State management
self.is_stopping: bool = False
self.current_silence: Optional[Silence] = None
self.state: State = State()
self.lock: asyncio.Lock = asyncio.Lock()
self.sep: str = " " # Default separator
self.last_response_content: FrontData = FrontData()
self.tokens_alignment: TokensAlignment = TokensAlignment(self.state, self.args, self.sep)
self.beg_loop: Optional[float] = None
# Models and processing
self.asr: Any = models.asr
self.vac: Optional[FixedVADIterator] = None
if self.args.vac:
if models.vac_session is not None:
vac_model = OnnxWrapper(session=models.vac_session)
self.vac = FixedVADIterator(vac_model)
else:
self.vac = FixedVADIterator(load_jit_vad())
self.ffmpeg_manager: Optional[FFmpegManager] = None
self.ffmpeg_reader_task: Optional[asyncio.Task] = None
self._ffmpeg_error: Optional[str] = None
if not self.is_pcm_input:
self.ffmpeg_manager = FFmpegManager(
sample_rate=self.sample_rate,
channels=self.channels
)
async def handle_ffmpeg_error(error_type: str):
logger.error(f"FFmpeg error: {error_type}")
self._ffmpeg_error = error_type
self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
self.transcription_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.transcription else None
self.diarization_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.diarization else None
self.translation_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.target_language else None
self.pcm_buffer: bytearray = bytearray()
self.total_pcm_samples: int = 0
self.transcription_task: Optional[asyncio.Task] = None
self.diarization_task: Optional[asyncio.Task] = None
self.translation_task: Optional[asyncio.Task] = None
self.watchdog_task: Optional[asyncio.Task] = None
self.all_tasks_for_cleanup: List[asyncio.Task] = []
self.transcription: Optional[Any] = None
self.translation: Optional[Any] = None
self.diarization: Optional[Any] = None
if self.args.transcription:
self.transcription = online_factory(self.args, models.asr)
self.sep = self.transcription.asr.sep
if self.args.diarization:
self.diarization = online_diarization_factory(self.args, models.diarization_model)
if models.translation_model:
self.translation = online_translation_factory(self.args, models.translation_model)
async def _push_silence_event(self) -> None:
if self.transcription_queue:
await self.transcription_queue.put(self.current_silence)
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(self.current_silence)
if self.translation_queue:
await self.translation_queue.put(self.current_silence)
async def _begin_silence(self) -> None:
if self.current_silence:
return
now = time() - self.beg_loop
self.current_silence = Silence(
is_starting=True, start=now
)
await self._push_silence_event()
async def _end_silence(self) -> None:
if not self.current_silence:
return
now = time() - self.beg_loop
self.current_silence.end = now
self.current_silence.is_starting=False
self.current_silence.has_ended=True
self.current_silence.compute_duration()
if self.current_silence.duration > MIN_DURATION_REAL_SILENCE:
self.state.new_tokens.append(self.current_silence)
await self._push_silence_event()
self.current_silence = None
async def _enqueue_active_audio(self, pcm_chunk: np.ndarray) -> None:
if pcm_chunk is None or pcm_chunk.size == 0:
return
if self.transcription_queue:
await self.transcription_queue.put(pcm_chunk.copy())
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(pcm_chunk.copy())
def _slice_before_silence(self, pcm_array: np.ndarray, chunk_sample_start: int, silence_sample: Optional[int]) -> Optional[np.ndarray]:
if silence_sample is None:
return None
relative_index = int(silence_sample - chunk_sample_start)
if relative_index <= 0:
return None
split_index = min(relative_index, len(pcm_array))
if split_index <= 0:
return None
return pcm_array[:split_index]
def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray:
"""Convert PCM buffer in s16le format to normalized NumPy array."""
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
async def get_current_state(self) -> State:
"""Get current state."""
async with self.lock:
current_time = time()
remaining_transcription = 0
if self.state.end_buffer > 0:
remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1))
remaining_diarization = 0
if self.state.tokens:
latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1))
self.state.remaining_time_transcription = remaining_transcription
self.state.remaining_time_diarization = remaining_diarization
return self.state
async def ffmpeg_stdout_reader(self) -> None:
"""Read audio data from FFmpeg stdout and process it into the PCM pipeline."""
beg = time()
while True:
try:
if self.is_stopping:
logger.info("Stopping ffmpeg_stdout_reader due to stopping flag.")
break
state = await self.ffmpeg_manager.get_state() if self.ffmpeg_manager else FFmpegState.STOPPED
if state == FFmpegState.FAILED:
logger.error("FFmpeg is in FAILED state, cannot read data")
break
elif state == FFmpegState.STOPPED:
logger.info("FFmpeg is stopped")
break
elif state != FFmpegState.RUNNING:
await asyncio.sleep(0.1)
continue
current_time = time()
elapsed_time = max(0.0, current_time - beg)
buffer_size = max(int(32000 * elapsed_time), 4096) # dynamic read
beg = current_time
chunk = await self.ffmpeg_manager.read_data(buffer_size)
if not chunk:
# No data currently available
await asyncio.sleep(0.05)
continue
self.pcm_buffer.extend(chunk)
await self.handle_pcm_data()
except asyncio.CancelledError:
logger.info("ffmpeg_stdout_reader cancelled.")
break
except Exception as e:
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
logger.debug(f"Traceback: {traceback.format_exc()}")
await asyncio.sleep(0.2)
logger.info("FFmpeg stdout processing finished. Signaling downstream processors if needed.")
if self.transcription_queue:
await self.transcription_queue.put(SENTINEL)
if self.diarization:
await self.diarization_queue.put(SENTINEL)
if self.translation:
await self.translation_queue.put(SENTINEL)
async def transcription_processor(self) -> None:
"""Process audio chunks for transcription."""
cumulative_pcm_duration_stream_time = 0.0
while True:
try:
# item = await self.transcription_queue.get()
item = await get_all_from_queue(self.transcription_queue)
if item is SENTINEL:
logger.debug("Transcription processor received sentinel. Finishing.")
break
asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
transcription_lag_s = max(0.0, time() - self.beg_loop - self.state.end_buffer)
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
new_tokens = []
current_audio_processed_upto = self.state.end_buffer
if isinstance(item, Silence):
if item.is_starting:
new_tokens, current_audio_processed_upto = await asyncio.to_thread(
self.transcription.start_silence
)
asr_processing_logs += f" + Silence starting"
if item.has_ended:
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
cumulative_pcm_duration_stream_time += item.duration
current_audio_processed_upto = cumulative_pcm_duration_stream_time
self.transcription.end_silence(item.duration, self.state.tokens[-1].end if self.state.tokens else 0)
if self.state.tokens:
asr_processing_logs += f" | last_end = {self.state.tokens[-1].end} |"
logger.info(asr_processing_logs)
new_tokens = new_tokens or []
current_audio_processed_upto = max(current_audio_processed_upto, stream_time_end_of_current_pcm)
elif isinstance(item, ChangeSpeaker):
self.transcription.new_speaker(item)
continue
elif isinstance(item, np.ndarray):
pcm_array = item
logger.info(asr_processing_logs)
cumulative_pcm_duration_stream_time += len(pcm_array) / self.sample_rate
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter)
new_tokens = new_tokens or []
_buffer_transcript = self.transcription.get_buffer()
buffer_text = _buffer_transcript.text
if new_tokens:
validated_text = self.sep.join([t.text for t in new_tokens])
if buffer_text.startswith(validated_text):
_buffer_transcript.text = buffer_text[len(validated_text):].lstrip()
candidate_end_times = [self.state.end_buffer]
if new_tokens:
candidate_end_times.append(new_tokens[-1].end)
if _buffer_transcript.end is not None:
candidate_end_times.append(_buffer_transcript.end)
candidate_end_times.append(current_audio_processed_upto)
async with self.lock:
self.state.tokens.extend(new_tokens)
self.state.buffer_transcription = _buffer_transcript
self.state.end_buffer = max(candidate_end_times)
self.state.new_tokens.extend(new_tokens)
self.state.new_tokens_buffer = _buffer_transcript
if self.translation_queue:
for token in new_tokens:
await self.translation_queue.put(token)
except Exception as e:
logger.warning(f"Exception in transcription_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
if 'pcm_array' in locals() and pcm_array is not SENTINEL : # Check if pcm_array was assigned from queue
self.transcription_queue.task_done()
if self.is_stopping:
logger.info("Transcription processor finishing due to stopping flag.")
if self.diarization_queue:
await self.diarization_queue.put(SENTINEL)
if self.translation_queue:
await self.translation_queue.put(SENTINEL)
logger.info("Transcription processor task finished.")
async def diarization_processor(self) -> None:
while True:
try:
item = await get_all_from_queue(self.diarization_queue)
if item is SENTINEL:
break
elif type(item) is Silence:
if item.has_ended:
self.diarization.insert_silence(item.duration)
continue
self.diarization.insert_audio_chunk(item)
diarization_segments = await self.diarization.diarize()
diar_end = 0.0
if diarization_segments:
diar_end = max(getattr(s, "end", 0.0) for s in diarization_segments)
async with self.lock:
self.state.new_diarization = diarization_segments
self.state.end_attributed_speaker = max(self.state.end_attributed_speaker, diar_end)
except Exception as e:
logger.warning(f"Exception in diarization_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
logger.info("Diarization processor task finished.")
async def translation_processor(self) -> None:
# the idea is to ignore diarization for the moment. We use only transcription tokens.
# And the speaker is attributed given the segments used for the translation
# in the future we want to have different languages for each speaker etc, so it will be more complex.
while True:
try:
item = await get_all_from_queue(self.translation_queue)
if item is SENTINEL:
logger.debug("Translation processor received sentinel. Finishing.")
break
elif type(item) is Silence:
if item.is_starting:
new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
if item.has_ended:
self.translation.insert_silence(item.duration)
continue
elif isinstance(item, ChangeSpeaker):
new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
pass
else:
self.translation.insert_tokens(item)
new_translation, new_translation_buffer = await asyncio.to_thread(self.translation.process)
async with self.lock:
self.state.new_translation.append(new_translation)
self.state.new_translation_buffer = new_translation_buffer
except Exception as e:
logger.warning(f"Exception in translation_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
logger.info("Translation processor task finished.")
async def results_formatter(self) -> AsyncGenerator[FrontData, None]:
"""Format processing results for output."""
while True:
try:
if self._ffmpeg_error:
yield FrontData(status="error", error=f"FFmpeg error: {self._ffmpeg_error}")
self._ffmpeg_error = None
await asyncio.sleep(1)
continue
self.tokens_alignment.update()
lines, buffer_diarization_text, buffer_translation_text = self.tokens_alignment.get_lines(
diarization=self.args.diarization,
translation=bool(self.translation),
current_silence=self.current_silence
)
state = await self.get_current_state()
buffer_transcription_text = state.buffer_transcription.text if state.buffer_transcription else ''
response_status = "active_transcription"
if not lines and not buffer_transcription_text and not buffer_diarization_text:
response_status = "no_audio_detected"
response = FrontData(
status=response_status,
lines=lines,
buffer_transcription=buffer_transcription_text,
buffer_diarization=buffer_diarization_text,
buffer_translation=buffer_translation_text,
remaining_time_transcription=state.remaining_time_transcription,
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
)
should_push = (response != self.last_response_content)
if should_push:
yield response
self.last_response_content = response
if self.is_stopping and self._processing_tasks_done():
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
return
await asyncio.sleep(0.05)
except Exception as e:
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
await asyncio.sleep(0.5)
async def create_tasks(self) -> AsyncGenerator[FrontData, None]:
"""Create and start processing tasks."""
self.all_tasks_for_cleanup = []
processing_tasks_for_watchdog: List[asyncio.Task] = []
# If using FFmpeg (non-PCM input), start it and spawn stdout reader
if not self.is_pcm_input:
success = await self.ffmpeg_manager.start()
if not success:
logger.error("Failed to start FFmpeg manager")
async def error_generator() -> AsyncGenerator[FrontData, None]:
yield FrontData(
status="error",
error="FFmpeg failed to start. Please check that FFmpeg is installed."
)
return error_generator()
self.ffmpeg_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task)
processing_tasks_for_watchdog.append(self.ffmpeg_reader_task)
if self.transcription:
self.transcription_task = asyncio.create_task(self.transcription_processor())
self.all_tasks_for_cleanup.append(self.transcription_task)
processing_tasks_for_watchdog.append(self.transcription_task)
if self.diarization:
self.diarization_task = asyncio.create_task(self.diarization_processor())
self.all_tasks_for_cleanup.append(self.diarization_task)
processing_tasks_for_watchdog.append(self.diarization_task)
if self.translation:
self.translation_task = asyncio.create_task(self.translation_processor())
self.all_tasks_for_cleanup.append(self.translation_task)
processing_tasks_for_watchdog.append(self.translation_task)
# Monitor overall system health
self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog))
self.all_tasks_for_cleanup.append(self.watchdog_task)
return self.results_formatter()
async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None:
"""Monitors the health of critical processing tasks."""
tasks_remaining: List[asyncio.Task] = [task for task in tasks_to_monitor if task]
while True:
try:
if not tasks_remaining:
logger.info("Watchdog task finishing: all monitored tasks completed.")
return
await asyncio.sleep(10)
for i, task in enumerate(list(tasks_remaining)):
if task.done():
exc = task.exception()
task_name = task.get_name() if hasattr(task, 'get_name') else f"Monitored Task {i}"
if exc:
logger.error(f"{task_name} unexpectedly completed with exception: {exc}")
else:
logger.info(f"{task_name} completed normally.")
tasks_remaining.remove(task)
except asyncio.CancelledError:
logger.info("Watchdog task cancelled.")
break
except Exception as e:
logger.error(f"Error in watchdog task: {e}", exc_info=True)
async def cleanup(self) -> None:
"""Clean up resources when processing is complete."""
logger.info("Starting cleanup of AudioProcessor resources.")
self.is_stopping = True
for task in self.all_tasks_for_cleanup:
if task and not task.done():
task.cancel()
created_tasks = [t for t in self.all_tasks_for_cleanup if t]
if created_tasks:
await asyncio.gather(*created_tasks, return_exceptions=True)
logger.info("All processing tasks cancelled or finished.")
if not self.is_pcm_input and self.ffmpeg_manager:
try:
await self.ffmpeg_manager.stop()
logger.info("FFmpeg manager stopped.")
except Exception as e:
logger.warning(f"Error stopping FFmpeg manager: {e}")
if self.diarization:
self.diarization.close()
logger.info("AudioProcessor cleanup complete.")
def _processing_tasks_done(self) -> bool:
"""Return True when all active processing tasks have completed."""
tasks_to_check = [
self.transcription_task,
self.diarization_task,
self.translation_task,
self.ffmpeg_reader_task,
]
return all(task.done() for task in tasks_to_check if task)
async def process_audio(self, message: Optional[bytes]) -> None:
"""Process incoming audio data."""
if not self.beg_loop:
self.beg_loop = time()
self.current_silence = Silence(start=0.0, is_starting=True)
self.tokens_alignment.beg_loop = self.beg_loop
if not message:
logger.info("Empty audio message received, initiating stop sequence.")
self.is_stopping = True
if self.transcription_queue:
await self.transcription_queue.put(SENTINEL)
if not self.is_pcm_input and self.ffmpeg_manager:
await self.ffmpeg_manager.stop()
return
if self.is_stopping:
logger.warning("AudioProcessor is stopping. Ignoring incoming audio.")
return
if self.is_pcm_input:
self.pcm_buffer.extend(message)
await self.handle_pcm_data()
else:
if not self.ffmpeg_manager:
logger.error("FFmpeg manager not initialized for non-PCM input.")
return
success = await self.ffmpeg_manager.write_data(message)
if not success:
ffmpeg_state = await self.ffmpeg_manager.get_state()
if ffmpeg_state == FFmpegState.FAILED:
logger.error("FFmpeg is in FAILED state, cannot process audio")
else:
logger.warning("Failed to write audio data to FFmpeg")
async def handle_pcm_data(self) -> None:
# Process when enough data
if len(self.pcm_buffer) < self.bytes_per_sec:
return
if len(self.pcm_buffer) > self.max_bytes_per_sec:
logger.warning(
f"Audio buffer too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f}s. "
f"Consider using a smaller model."
)
chunk_size = min(len(self.pcm_buffer), self.max_bytes_per_sec)
aligned_chunk_size = (chunk_size // self.bytes_per_sample) * self.bytes_per_sample
if aligned_chunk_size == 0:
return
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size])
self.pcm_buffer = self.pcm_buffer[aligned_chunk_size:]
num_samples = len(pcm_array)
chunk_sample_start = self.total_pcm_samples
chunk_sample_end = chunk_sample_start + num_samples
res = None
if self.args.vac:
res = self.vac(pcm_array)
if res is not None:
if "start" in res and self.current_silence:
await self._end_silence()
if "end" in res and not self.current_silence:
pre_silence_chunk = self._slice_before_silence(
pcm_array, chunk_sample_start, res.get("end")
)
if pre_silence_chunk is not None and pre_silence_chunk.size > 0:
await self._enqueue_active_audio(pre_silence_chunk)
await self._begin_silence()
if not self.current_silence:
await self._enqueue_active_audio(pcm_array)
self.total_pcm_samples = chunk_sample_end
if not self.args.transcription and not self.args.diarization:
await asyncio.sleep(0.1)

View File

@@ -1,41 +0,0 @@
import importlib.util
import logging
import platform
logger = logging.getLogger(__name__)
def module_available(module_name):
"""Return True if the given module can be imported."""
return importlib.util.find_spec(module_name) is not None
def mlx_backend_available(warn_on_missing = False):
is_macos = platform.system() == "Darwin"
is_arm = platform.machine() == "arm64"
available = (
is_macos
and is_arm
and module_available("mlx_whisper")
)
if not available and warn_on_missing and is_macos and is_arm:
logger.warning(
"=" * 50
+ "\nMLX Whisper not found but you are on Apple Silicon. "
"Consider installing mlx-whisper for better performance: "
"`pip install mlx-whisper`\n"
+ "=" * 50
)
return available
def faster_backend_available(warn_on_missing = False):
available = module_available("faster_whisper")
if not available and warn_on_missing and platform.system() != "Darwin":
logger.warning(
"=" * 50
+ "\nFaster-Whisper not found. Consider installing faster-whisper "
"for better performance: `pip install faster-whisper`\n"
+ "=" * 50
)
return available

View File

@@ -1,130 +0,0 @@
import asyncio
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from whisperlivekit import (AudioProcessor, TranscriptionEngine,
get_inline_ui_html, parse_args)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger().setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
args = parse_args()
transcription_engine = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global transcription_engine
transcription_engine = TranscriptionEngine(
**vars(args),
)
yield
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def get():
return HTMLResponse(get_inline_ui_html())
async def handle_websocket_results(websocket, results_generator):
"""Consumes results from the audio processor and sends them via WebSocket."""
try:
async for response in results_generator:
await websocket.send_json(response.to_dict())
# when the results_generator finishes it means all audio has been processed
logger.info("Results generator finished. Sending 'ready_to_stop' to client.")
await websocket.send_json({"type": "ready_to_stop"})
except WebSocketDisconnect:
logger.info("WebSocket disconnected while handling results (client likely closed connection).")
except Exception as e:
logger.exception(f"Error in WebSocket results handler: {e}")
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
global transcription_engine
audio_processor = AudioProcessor(
transcription_engine=transcription_engine,
)
await websocket.accept()
logger.info("WebSocket connection opened.")
try:
await websocket.send_json({"type": "config", "useAudioWorklet": bool(args.pcm_input)})
except Exception as e:
logger.warning(f"Failed to send config to client: {e}")
results_generator = await audio_processor.create_tasks()
websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
try:
while True:
message = await websocket.receive_bytes()
await audio_processor.process_audio(message)
except KeyError as e:
if 'bytes' in str(e):
logger.warning(f"Client has closed the connection.")
else:
logger.error(f"Unexpected KeyError in websocket_endpoint: {e}", exc_info=True)
except WebSocketDisconnect:
logger.info("WebSocket disconnected by client during message receiving loop.")
except Exception as e:
logger.error(f"Unexpected error in websocket_endpoint main loop: {e}", exc_info=True)
finally:
logger.info("Cleaning up WebSocket endpoint...")
if not websocket_task.done():
websocket_task.cancel()
try:
await websocket_task
except asyncio.CancelledError:
logger.info("WebSocket results handler task was cancelled.")
except Exception as e:
logger.warning(f"Exception while awaiting websocket_task completion: {e}")
await audio_processor.cleanup()
logger.info("WebSocket endpoint cleaned up successfully.")
def main():
"""Entry point for the CLI command."""
import uvicorn
uvicorn_kwargs = {
"app": "whisperlivekit.basic_server:app",
"host":args.host,
"port":args.port,
"reload": False,
"log_level": "info",
"lifespan": "on",
}
ssl_kwargs = {}
if args.ssl_certfile or args.ssl_keyfile:
if not (args.ssl_certfile and args.ssl_keyfile):
raise ValueError("Both --ssl-certfile and --ssl-keyfile must be specified together.")
ssl_kwargs = {
"ssl_certfile": args.ssl_certfile,
"ssl_keyfile": args.ssl_keyfile
}
if ssl_kwargs:
uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs}
if args.forwarded_allow_ips:
uvicorn_kwargs = { **uvicorn_kwargs, "forwarded_allow_ips" : args.forwarded_allow_ips }
uvicorn.run(**uvicorn_kwargs)
if __name__ == "__main__":
main()

View File

@@ -1,212 +0,0 @@
import logging
import sys
import threading
from argparse import Namespace
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
from whisperlivekit.local_agreement.whisper_online import backend_factory
from whisperlivekit.simul_whisper import SimulStreamingASR
def update_with_kwargs(_dict, kwargs):
_dict.update({
k: v for k, v in kwargs.items() if k in _dict
})
return _dict
logger = logging.getLogger(__name__)
class TranscriptionEngine:
_instance = None
_initialized = False
_lock = threading.Lock() # Thread-safe singleton lock
def __new__(cls, *args, **kwargs):
# Double-checked locking pattern for thread-safe singleton
if cls._instance is None:
with cls._lock:
# Check again inside lock to prevent race condition
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, **kwargs):
# Thread-safe initialization check
with TranscriptionEngine._lock:
if TranscriptionEngine._initialized:
return
# Set flag immediately to prevent re-initialization
TranscriptionEngine._initialized = True
# Perform initialization outside lock to avoid holding lock during slow operations
global_params = {
"host": "localhost",
"port": 8000,
"diarization": False,
"punctuation_split": False,
"target_language": "",
"vac": True,
"vac_chunk_size": 0.04,
"log_level": "DEBUG",
"ssl_certfile": None,
"ssl_keyfile": None,
"forwarded_allow_ips": None,
"transcription": True,
"vad": True,
"pcm_input": False,
"disable_punctuation_split" : False,
"diarization_backend": "sortformer",
"backend_policy": "simulstreaming",
"backend": "auto",
}
global_params = update_with_kwargs(global_params, kwargs)
transcription_common_params = {
"warmup_file": None,
"min_chunk_size": 0.1,
"model_size": "base",
"model_cache_dir": None,
"model_dir": None,
"model_path": None,
"lora_path": None,
"lan": "auto",
"direct_english_translation": False,
}
transcription_common_params = update_with_kwargs(transcription_common_params, kwargs)
if transcription_common_params['model_size'].endswith(".en"):
transcription_common_params["lan"] = "en"
if 'no_transcription' in kwargs:
global_params['transcription'] = not global_params['no_transcription']
if 'no_vad' in kwargs:
global_params['vad'] = not kwargs['no_vad']
if 'no_vac' in kwargs:
global_params['vac'] = not kwargs['no_vac']
self.args = Namespace(**{**global_params, **transcription_common_params})
self.asr = None
self.tokenizer = None
self.diarization = None
self.vac_session = None
if self.args.vac:
from whisperlivekit.silero_vad_iterator import is_onnx_available
if is_onnx_available():
from whisperlivekit.silero_vad_iterator import load_onnx_session
self.vac_session = load_onnx_session()
else:
logger.warning(
"onnxruntime not installed. VAC will use JIT model which is loaded per-session. "
"For multi-user scenarios, install onnxruntime: pip install onnxruntime"
)
backend_policy = self.args.backend_policy
if self.args.transcription:
if backend_policy == "simulstreaming":
simulstreaming_params = {
"disable_fast_encoder": False,
"custom_alignment_heads": None,
"frame_threshold": 25,
"beams": 1,
"decoder_type": None,
"audio_max_len": 20.0,
"audio_min_len": 0.0,
"cif_ckpt_path": None,
"never_fire": False,
"init_prompt": None,
"static_init_prompt": None,
"max_context_tokens": None,
}
simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
self.tokenizer = None
self.asr = SimulStreamingASR(
**transcription_common_params,
**simulstreaming_params,
backend=self.args.backend,
)
logger.info(
"Using SimulStreaming policy with %s backend",
getattr(self.asr, "encoder_backend", "whisper"),
)
else:
whisperstreaming_params = {
"buffer_trimming": "segment",
"confidence_validation": False,
"buffer_trimming_sec": 15,
}
whisperstreaming_params = update_with_kwargs(whisperstreaming_params, kwargs)
self.asr = backend_factory(
backend=self.args.backend,
**transcription_common_params,
**whisperstreaming_params,
)
logger.info(
"Using LocalAgreement policy with %s backend",
getattr(self.asr, "backend_choice", self.asr.__class__.__name__),
)
if self.args.diarization:
if self.args.diarization_backend == "diart":
from whisperlivekit.diarization.diart_backend import \
DiartDiarization
diart_params = {
"segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding",
}
diart_params = update_with_kwargs(diart_params, kwargs)
self.diarization_model = DiartDiarization(
block_duration=self.args.min_chunk_size,
**diart_params
)
elif self.args.diarization_backend == "sortformer":
from whisperlivekit.diarization.sortformer_backend import \
SortformerDiarization
self.diarization_model = SortformerDiarization()
self.translation_model = None
if self.args.target_language:
if self.args.lan == 'auto' and backend_policy != "simulstreaming":
raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming')
else:
try:
from nllw import load_model
except:
raise Exception('To use translation, you must install nllw: `pip install nllw`')
translation_params = {
"nllb_backend": "transformers",
"nllb_size": "600M"
}
translation_params = update_with_kwargs(translation_params, kwargs)
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
def online_factory(args, asr):
if args.backend_policy == "simulstreaming":
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
return SimulStreamingOnlineProcessor(asr)
return OnlineASRProcessor(asr)
def online_diarization_factory(args, diarization_backend):
if args.diarization_backend == "diart":
online = diarization_backend
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended
if args.diarization_backend == "sortformer":
from whisperlivekit.diarization.sortformer_backend import \
SortformerDiarizationOnline
online = SortformerDiarizationOnline(shared_model=diarization_backend)
return online
def online_translation_factory(args, translation_model):
#should be at speaker level in the future:
#one shared nllb model for all speaker
#one tokenizer per speaker/language
from nllw import OnlineTranslation
return OnlineTranslation(translation_model, [args.lan], [args.target_language])

View File

@@ -1,288 +0,0 @@
import asyncio
import logging
import re
import threading
import time
from queue import Empty, SimpleQueue
from typing import Any, List, Tuple
import diart.models as m
import numpy as np
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import StreamingInference
from diart.sources import AudioSource, MicrophoneAudioSource
from pyannote.core import Annotation
from rx.core import Observer
from whisperlivekit.timed_objects import SpeakerSegment
logger = logging.getLogger(__name__)
def extract_number(s: str) -> int:
m = re.search(r'\d+', s)
return int(m.group()) if m else None
class DiarizationObserver(Observer):
"""Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""
def __init__(self):
self.diarization_segments = []
self.processed_time = 0
self.segment_lock = threading.Lock()
self.global_time_offset = 0.0
def on_next(self, value: Tuple[Annotation, Any]):
annotation, audio = value
logger.debug("\n--- New Diarization Result ---")
duration = audio.extent.end - audio.extent.start
logger.debug(f"Audio segment: {audio.extent.start:.2f}s - {audio.extent.end:.2f}s (duration: {duration:.2f}s)")
logger.debug(f"Audio shape: {audio.data.shape}")
with self.segment_lock:
if audio.extent.end > self.processed_time:
self.processed_time = audio.extent.end
if annotation and len(annotation._labels) > 0:
logger.debug("\nSpeaker segments:")
for speaker, label in annotation._labels.items():
for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]):
print(f" {speaker}: {start:.2f}s-{end:.2f}s")
self.diarization_segments.append(SpeakerSegment(
speaker=speaker,
start=start + self.global_time_offset,
end=end + self.global_time_offset
))
else:
logger.debug("\nNo speakers detected in this segment")
def get_segments(self) -> List[SpeakerSegment]:
"""Get a copy of the current speaker segments."""
with self.segment_lock:
return self.diarization_segments.copy()
def clear_old_segments(self, older_than: float = 30.0):
"""Clear segments older than the specified time."""
with self.segment_lock:
current_time = self.processed_time
self.diarization_segments = [
segment for segment in self.diarization_segments
if current_time - segment.end < older_than
]
def on_error(self, error):
"""Handle an error in the stream."""
logger.debug(f"Error in diarization stream: {error}")
def on_completed(self):
"""Handle the completion of the stream."""
logger.debug("Diarization stream completed")
class WebSocketAudioSource(AudioSource):
"""
Buffers incoming audio and releases it in fixed-size chunks at regular intervals.
"""
def __init__(self, uri: str = "websocket", sample_rate: int = 16000, block_duration: float = 0.5):
super().__init__(uri, sample_rate)
self.block_duration = block_duration
self.block_size = int(np.rint(block_duration * sample_rate))
self._queue = SimpleQueue()
self._buffer = np.array([], dtype=np.float32)
self._buffer_lock = threading.Lock()
self._closed = False
self._close_event = threading.Event()
self._processing_thread = None
self._last_chunk_time = time.time()
def read(self):
"""Start processing buffered audio and emit fixed-size chunks."""
self._processing_thread = threading.Thread(target=self._process_chunks)
self._processing_thread.daemon = True
self._processing_thread.start()
self._close_event.wait()
if self._processing_thread:
self._processing_thread.join(timeout=2.0)
def _process_chunks(self):
"""Process audio from queue and emit fixed-size chunks at regular intervals."""
while not self._closed:
try:
audio_chunk = self._queue.get(timeout=0.1)
with self._buffer_lock:
self._buffer = np.concatenate([self._buffer, audio_chunk])
while len(self._buffer) >= self.block_size:
chunk = self._buffer[:self.block_size]
self._buffer = self._buffer[self.block_size:]
current_time = time.time()
time_since_last = current_time - self._last_chunk_time
if time_since_last < self.block_duration:
time.sleep(self.block_duration - time_since_last)
chunk_reshaped = chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self._last_chunk_time = time.time()
except Empty:
with self._buffer_lock:
if len(self._buffer) > 0 and time.time() - self._last_chunk_time > self.block_duration:
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
padded_chunk[:len(self._buffer)] = self._buffer
self._buffer = np.array([], dtype=np.float32)
chunk_reshaped = padded_chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self._last_chunk_time = time.time()
except Exception as e:
logger.error(f"Error in audio processing thread: {e}")
self.stream.on_error(e)
break
with self._buffer_lock:
if len(self._buffer) > 0:
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
padded_chunk[:len(self._buffer)] = self._buffer
chunk_reshaped = padded_chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self.stream.on_completed()
def close(self):
if not self._closed:
self._closed = True
self._close_event.set()
def push_audio(self, chunk: np.ndarray):
"""Add audio chunk to the processing queue."""
if not self._closed:
if chunk.ndim > 1:
chunk = chunk.flatten()
self._queue.put(chunk)
logger.debug(f'Added chunk to queue with {len(chunk)} samples')
class DiartDiarization:
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 1.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "pyannote/embedding"):
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
if config is None:
config = SpeakerDiarizationConfig(
segmentation=segmentation_model,
embedding=embedding_model,
)
self.pipeline = SpeakerDiarization(config=config)
self.observer = DiarizationObserver()
if use_microphone:
self.source = MicrophoneAudioSource(block_duration=block_duration)
self.custom_source = None
else:
self.custom_source = WebSocketAudioSource(
uri="websocket_source",
sample_rate=sample_rate,
block_duration=block_duration
)
self.source = self.custom_source
self.inference = StreamingInference(
pipeline=self.pipeline,
source=self.source,
do_plot=False,
show_progress=False,
)
self.inference.attach_observers(self.observer)
asyncio.get_event_loop().run_in_executor(None, self.inference)
def insert_silence(self, silence_duration):
self.observer.global_time_offset += silence_duration
async def diarize(self, pcm_array: np.ndarray):
"""
Process audio data for diarization.
Only used when working with WebSocketAudioSource.
"""
if self.custom_source:
self.custom_source.push_audio(pcm_array)
# self.observer.clear_old_segments()
def close(self):
"""Close the audio source."""
if self.custom_source:
self.custom_source.close()
def concatenate_speakers(segments):
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
for segment in segments:
speaker = extract_number(segment.speaker) + 1
if segments_concatenated[-1]['speaker'] != speaker:
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
else:
segments_concatenated[-1]['end'] = segment.end
# print("Segments concatenated:")
# for entry in segments_concatenated:
# print(f"Speaker {entry['speaker']}: {entry['begin']:.2f}s - {entry['end']:.2f}s")
return segments_concatenated
def add_speaker_to_tokens(segments, tokens):
"""
Assign speakers to tokens based on diarization segments, with punctuation-aware boundary adjustment.
"""
punctuation_marks = {'.', '!', '?'}
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
segments_concatenated = concatenate_speakers(segments)
for ind, segment in enumerate(segments_concatenated):
for i, punctuation_token in enumerate(punctuation_tokens):
if punctuation_token.start > segment['end']:
after_length = punctuation_token.start - segment['end']
before_length = segment['end'] - punctuation_tokens[i - 1].end
if before_length > after_length:
segment['end'] = punctuation_token.start
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
else:
segment['end'] = punctuation_tokens[i - 1].end
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
break
last_end = 0.0
for token in tokens:
start = max(last_end + 0.01, token.start)
token.start = start
token.end = max(start, token.end)
last_end = token.end
ind_last_speaker = 0
for segment in segments_concatenated:
for i, token in enumerate(tokens[ind_last_speaker:]):
if token.end <= segment['end']:
token.speaker = segment['speaker']
ind_last_speaker = i + 1
# print(
# f"Token '{token.text}' ('begin': {token.start:.2f}, 'end': {token.end:.2f}) "
# f"assigned to Speaker {segment['speaker']} ('segment': {segment['begin']:.2f}-{segment['end']:.2f})"
# )
elif token.start > segment['end']:
break
return tokens
def visualize_tokens(tokens):
conversation = [{"speaker": -1, "text": ""}]
for token in tokens:
speaker = conversation[-1]['speaker']
if token.speaker != speaker:
conversation.append({"speaker": token.speaker, "text": token.text})
else:
conversation[-1]['text'] += token.text
print("Conversation:")
for entry in conversation:
print(f"Speaker {entry['speaker']}: {entry['text']}")

View File

@@ -1,331 +0,0 @@
import logging
import threading
import time
import wave
from queue import Empty, SimpleQueue
from typing import List, Optional
import numpy as np
import torch
from whisperlivekit.timed_objects import SpeakerSegment
logger = logging.getLogger(__name__)
try:
from nemo.collections.asr.models import SortformerEncLabelModel
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
except ImportError:
raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
class StreamingSortformerState:
"""
This class creates a class instance that will be used to store the state of the
streaming Sortformer model.
Attributes:
spkcache (torch.Tensor): Speaker cache to store embeddings from start
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
mean_sil_emb (torch.Tensor): Mean silence embedding
n_sil_frames (torch.Tensor): Number of silence frames
"""
def __init__(self):
self.spkcache = None # Speaker cache to store embeddings from start
self.spkcache_lengths = None
self.spkcache_preds = None # speaker cache predictions
self.fifo = None # to save the embedding from the latest chunks
self.fifo_lengths = None
self.fifo_preds = None
self.spk_perm = None
self.mean_sil_emb = None
self.n_sil_frames = None
class SortformerDiarization:
def __init__(self, model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2"):
"""
Stores the shared streaming Sortformer diarization model. Used when a new online_diarization is initialized.
"""
self._load_model(model_name)
def _load_model(self, model_name: str):
"""Load and configure the Sortformer model for streaming."""
try:
self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
self.diar_model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.diar_model.to(device)
## to test
# for name, param in self.diar_model.named_parameters():
# if param.device != device:
# raise RuntimeError(f"Parameter {name} is on {param.device} but should be on {device}")
logger.info(f"Using {device.type.upper()} for Sortformer model")
self.diar_model.sortformer_modules.chunk_len = 10
self.diar_model.sortformer_modules.subsampling_factor = 10
self.diar_model.sortformer_modules.chunk_right_context = 0
self.diar_model.sortformer_modules.chunk_left_context = 10
self.diar_model.sortformer_modules.spkcache_len = 188
self.diar_model.sortformer_modules.fifo_len = 188
self.diar_model.sortformer_modules.spkcache_update_period = 144
self.diar_model.sortformer_modules.log = False
self.diar_model.sortformer_modules._check_streaming_parameters()
except Exception as e:
logger.error(f"Failed to load Sortformer model: {e}")
raise
class SortformerDiarizationOnline:
def __init__(self, shared_model, sample_rate: int = 16000):
"""
Initialize the streaming Sortformer diarization system.
Args:
sample_rate: Audio sample rate (default: 16000)
model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2")
"""
self.sample_rate = sample_rate
self.diarization_segments = []
self.diar_segments = []
self.buffer_audio = np.array([], dtype=np.float32)
self.segment_lock = threading.Lock()
self.global_time_offset = 0.0
self.debug = False
self.diar_model = shared_model.diar_model
self.audio2mel = AudioToMelSpectrogramPreprocessor(
window_size=0.025,
normalize="NA",
n_fft=512,
features=128,
pad_to=0
)
self.audio2mel.to(self.diar_model.device)
self.chunk_duration_seconds = (
self.diar_model.sortformer_modules.chunk_len *
self.diar_model.sortformer_modules.subsampling_factor *
self.diar_model.preprocessor._cfg.window_stride
)
self._init_streaming_state()
self._previous_chunk_features = None
self._chunk_index = 0
self._len_prediction = None
# Audio buffer to store PCM chunks for debugging
self.audio_buffer = []
# Buffer for accumulating audio chunks until reaching chunk_duration_seconds
self.audio_chunk_buffer = []
self.accumulated_duration = 0.0
logger.info("SortformerDiarization initialized successfully")
def _init_streaming_state(self):
"""Initialize the streaming state for the model."""
batch_size = 1
device = self.diar_model.device
self.streaming_state = StreamingSortformerState()
self.streaming_state.spkcache = torch.zeros(
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.fc_d_model),
device=device
)
self.streaming_state.spkcache_preds = torch.zeros(
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.n_spk),
device=device
)
self.streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
self.streaming_state.fifo = torch.zeros(
(batch_size, self.diar_model.sortformer_modules.fifo_len, self.diar_model.sortformer_modules.fc_d_model),
device=device
)
self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device)
self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)
def insert_silence(self, silence_duration: Optional[float]):
"""
Insert silence period by adjusting the global time offset.
Args:
silence_duration: Duration of silence in seconds
"""
with self.segment_lock:
self.global_time_offset += silence_duration
logger.debug(f"Inserted silence of {silence_duration:.2f}s, new offset: {self.global_time_offset:.2f}s")
def insert_audio_chunk(self, pcm_array: np.ndarray):
if self.debug:
self.audio_buffer.append(pcm_array.copy())
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
async def diarize(self):
"""
Process audio data for diarization in streaming fashion.
Args:
pcm_array: Audio data as numpy array
"""
threshold = int(self.chunk_duration_seconds * self.sample_rate)
if not len(self.buffer_audio) >= threshold:
return []
audio = self.buffer_audio[:threshold]
self.buffer_audio = self.buffer_audio[threshold:]
device = self.diar_model.device
audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0)
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device)
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
audio_signal_chunk, audio_signal_length_chunk
)
processed_signal_chunk = processed_signal_chunk.to(device)
processed_signal_length_chunk = processed_signal_length_chunk.to(device)
if self._previous_chunk_features is not None:
to_add = self._previous_chunk_features[:, :, -99:].to(device)
total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device)
else:
total_features = processed_signal_chunk.to(device)
self._previous_chunk_features = processed_signal_chunk.to(device)
chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device)
with torch.inference_mode():
left_offset = 8 if self._chunk_index > 0 else 0
right_offset = 8
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device),
streaming_state=self.streaming_state,
total_preds=self.total_preds,
left_offset=left_offset,
right_offset=right_offset,
)
new_segments = self._process_predictions()
self._chunk_index += 1
return new_segments
def _process_predictions(self):
"""Process model predictions and convert to speaker segments."""
preds_np = self.total_preds[0].cpu().numpy()
active_speakers = np.argmax(preds_np, axis=1)
if self._len_prediction is None:
self._len_prediction = len(active_speakers) #12
frame_duration = self.chunk_duration_seconds / self._len_prediction
current_chunk_preds = active_speakers[-self._len_prediction:]
new_segments = []
with self.segment_lock:
base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset
current_spk = current_chunk_preds[0]
start_time = round(base_time, 2)
for idx, spk in enumerate(current_chunk_preds):
current_time = round(base_time + idx * frame_duration, 2)
if spk != current_spk:
new_segments.append(SpeakerSegment(
speaker=current_spk,
start=start_time,
end=current_time
))
start_time = current_time
current_spk = spk
new_segments.append(
SpeakerSegment(
speaker=current_spk,
start=start_time,
end=current_time
)
)
return new_segments
def get_segments(self) -> List[SpeakerSegment]:
"""Get a copy of the current speaker segments."""
with self.segment_lock:
return self.diarization_segments.copy()
def close(self):
"""Close the diarization system and clean up resources."""
logger.info("Closing SortformerDiarization")
with self.segment_lock:
self.diarization_segments.clear()
if self.debug:
concatenated_audio = np.concatenate(self.audio_buffer)
audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
with wave.open("diarization_audio.wav", "wb") as wav_file:
wav_file.setnchannels(1) # mono audio
wav_file.setsampwidth(2) # 2 bytes per sample (int16)
wav_file.setframerate(self.sample_rate)
wav_file.writeframes(audio_data_int16.tobytes())
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
def extract_number(s: str) -> int:
"""Extract number from speaker string (compatibility function)."""
import re
m = re.search(r'\d+', s)
return int(m.group()) if m else 0
if __name__ == '__main__':
import asyncio
import librosa
async def main():
"""TEST ONLY."""
an4_audio = 'diarization_audio.wav'
signal, sr = librosa.load(an4_audio, sr=16000)
signal = signal[:16000*30]
print("\n" + "=" * 50)
print("ground truth:")
print("Speaker 0: 0:00 - 0:09")
print("Speaker 1: 0:09 - 0:19")
print("Speaker 2: 0:19 - 0:25")
print("Speaker 0: 0:25 - 0:30")
print("=" * 50)
diarization_backend = SortformerDiarization()
diarization = SortformerDiarizationOnline(shared_model = diarization_backend)
chunk_size = 1600
for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size]
new_segments = await diarization.diarize(chunk)
print(f"Processed chunk {i // chunk_size + 1}")
print(new_segments)
segments = diarization.get_segments()
print("\nDiarization results:")
for segment in segments:
print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
asyncio.run(main())

View File

@@ -1,197 +0,0 @@
import asyncio
import contextlib
import logging
from enum import Enum
from typing import Callable, Optional
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
ERROR_INSTALL_INSTRUCTIONS = f"""
{'='*50}
FFmpeg is not installed or not found in your system's PATH.
Alternative Solution: You can still use WhisperLiveKit without FFmpeg by adding the --pcm-input parameter. Note that when using this option, audio will not be compressed between the frontend and backend, which may result in higher bandwidth usage.
If you want to install FFmpeg:
# Ubuntu/Debian:
sudo apt update && sudo apt install ffmpeg
# macOS (using Homebrew):
brew install ffmpeg
# Windows:
# 1. Download the latest static build from https://ffmpeg.org/download.html
# 2. Extract the archive (e.g., to C:\\FFmpeg).
# 3. Add the 'bin' directory (e.g., C:\\FFmpeg\\bin) to your system's PATH environment variable.
After installation, please restart the application.
{'='*50}
"""
class FFmpegState(Enum):
STOPPED = "stopped"
STARTING = "starting"
RUNNING = "running"
RESTARTING = "restarting"
FAILED = "failed"
class FFmpegManager:
def __init__(self, sample_rate: int = 16000, channels: int = 1):
self.sample_rate = sample_rate
self.channels = channels
self.process: Optional[asyncio.subprocess.Process] = None
self._stderr_task: Optional[asyncio.Task] = None
self.on_error_callback: Optional[Callable[[str], None]] = None
self.state = FFmpegState.STOPPED
self._state_lock = asyncio.Lock()
async def start(self) -> bool:
async with self._state_lock:
if self.state != FFmpegState.STOPPED:
logger.warning(f"FFmpeg already running in state: {self.state}")
return False
self.state = FFmpegState.STARTING
try:
cmd = [
"ffmpeg",
"-hide_banner",
"-loglevel", "error",
"-i", "pipe:0",
"-f", "s16le",
"-acodec", "pcm_s16le",
"-ac", str(self.channels),
"-ar", str(self.sample_rate),
"pipe:1"
]
self.process = await asyncio.create_subprocess_exec(
*cmd,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
self._stderr_task = asyncio.create_task(self._drain_stderr())
async with self._state_lock:
self.state = FFmpegState.RUNNING
logger.info("FFmpeg started.")
return True
except FileNotFoundError:
logger.error(ERROR_INSTALL_INSTRUCTIONS)
async with self._state_lock:
self.state = FFmpegState.FAILED
if self.on_error_callback:
await self.on_error_callback("ffmpeg_not_found")
return False
except Exception as e:
logger.error(f"Error starting FFmpeg: {e}")
async with self._state_lock:
self.state = FFmpegState.FAILED
if self.on_error_callback:
await self.on_error_callback("start_failed")
return False
async def stop(self):
async with self._state_lock:
if self.state == FFmpegState.STOPPED:
return
self.state = FFmpegState.STOPPED
if self.process:
if self.process.stdin and not self.process.stdin.is_closing():
self.process.stdin.close()
await self.process.stdin.wait_closed()
await self.process.wait()
self.process = None
if self._stderr_task:
self._stderr_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._stderr_task
logger.info("FFmpeg stopped.")
async def write_data(self, data: bytes) -> bool:
async with self._state_lock:
if self.state != FFmpegState.RUNNING:
logger.warning(f"Cannot write, FFmpeg state: {self.state}")
return False
try:
self.process.stdin.write(data)
await self.process.stdin.drain()
return True
except Exception as e:
logger.error(f"Error writing to FFmpeg: {e}")
if self.on_error_callback:
await self.on_error_callback("write_error")
return False
async def read_data(self, size: int) -> Optional[bytes]:
async with self._state_lock:
if self.state != FFmpegState.RUNNING:
logger.warning(f"Cannot read, FFmpeg state: {self.state}")
return None
try:
data = await asyncio.wait_for(
self.process.stdout.read(size),
timeout=20.0
)
return data
except asyncio.TimeoutError:
logger.warning("FFmpeg read timeout.")
return None
except Exception as e:
logger.error(f"Error reading from FFmpeg: {e}")
if self.on_error_callback:
await self.on_error_callback("read_error")
return None
async def get_state(self) -> FFmpegState:
async with self._state_lock:
return self.state
async def restart(self) -> bool:
async with self._state_lock:
if self.state == FFmpegState.RESTARTING:
logger.warning("Restart already in progress.")
return False
self.state = FFmpegState.RESTARTING
logger.info("Restarting FFmpeg...")
try:
await self.stop()
await asyncio.sleep(1) # short delay before restarting
return await self.start()
except Exception as e:
logger.error(f"Error during FFmpeg restart: {e}")
async with self._state_lock:
self.state = FFmpegState.FAILED
if self.on_error_callback:
await self.on_error_callback("restart_failed")
return False
async def _drain_stderr(self):
try:
while True:
if not self.process or not self.process.stderr:
break
line = await self.process.stderr.readline()
if not line:
break
logger.debug(f"FFmpeg stderr: {line.decode(errors='ignore').strip()}")
except asyncio.CancelledError:
logger.info("FFmpeg stderr drain task cancelled.")
except Exception as e:
logger.error(f"Error draining FFmpeg stderr: {e}")

View File

@@ -1,303 +0,0 @@
import io
import logging
import math
import sys
from typing import List
import numpy as np
import soundfile as sf
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe
logger = logging.getLogger(__name__)
class ASRBase:
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
# "" for faster-whisper because it emits the spaces when needed)
def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, lora_path=None, logfile=sys.stderr):
self.logfile = logfile
self.transcribe_kargs = {}
self.lora_path = lora_path
if lan == "auto":
self.original_language = None
else:
self.original_language = lan
self.model = self.load_model(model_size, cache_dir, model_dir)
def with_offset(self, offset: float) -> ASRToken:
# This method is kept for compatibility (typically you will use ASRToken.with_offset)
return ASRToken(self.start + offset, self.end + offset, self.text)
def __repr__(self):
return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"
def load_model(self, model_size, cache_dir, model_dir):
raise NotImplementedError("must be implemented in the child class")
def transcribe(self, audio, init_prompt=""):
raise NotImplementedError("must be implemented in the child class")
def use_vad(self):
raise NotImplementedError("must be implemented in the child class")
class WhisperASR(ASRBase):
"""Uses WhisperLiveKit's built-in Whisper implementation."""
sep = " "
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
from whisperlivekit.whisper import load_model as load_whisper_model
if model_dir is not None:
resolved_path = resolve_model_path(model_dir)
if resolved_path.is_dir():
model_info = detect_model_format(resolved_path)
if not model_info.has_pytorch:
raise FileNotFoundError(
f"No supported PyTorch checkpoint found under {resolved_path}"
)
logger.debug(f"Loading Whisper model from custom path {resolved_path}")
return load_whisper_model(str(resolved_path), lora_path=self.lora_path)
if model_size is None:
raise ValueError("Either model_size or model_dir must be set for WhisperASR")
return load_whisper_model(model_size, download_root=cache_dir, lora_path=self.lora_path)
def transcribe(self, audio, init_prompt=""):
options = dict(self.transcribe_kargs)
options.pop("vad", None)
options.pop("vad_filter", None)
language = self.original_language if self.original_language else None
result = whisper_transcribe(
self.model,
audio,
language=language,
initial_prompt=init_prompt,
condition_on_previous_text=True,
word_timestamps=True,
**options,
)
return result
def ts_words(self, r) -> List[ASRToken]:
"""
Converts the Whisper result to a list of ASRToken objects.
"""
tokens = []
for segment in r["segments"]:
for word in segment["words"]:
token = ASRToken(
word["start"],
word["end"],
word["word"],
probability=word.get("probability"),
)
tokens.append(token)
return tokens
def segments_end_ts(self, res) -> List[float]:
return [segment["end"] for segment in res["segments"]]
def use_vad(self):
logger.warning("VAD is not currently supported for WhisperASR backend and will be ignored.")
class FasterWhisperASR(ASRBase):
"""Uses faster-whisper as the backend."""
sep = ""
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
from faster_whisper import WhisperModel
if model_dir is not None:
resolved_path = resolve_model_path(model_dir)
logger.debug(f"Loading faster-whisper model from {resolved_path}. "
f"model_size and cache_dir parameters are not used.")
model_size_or_path = str(resolved_path)
elif model_size is not None:
model_size_or_path = model_size
else:
raise ValueError("Either model_size or model_dir must be set")
device = "auto" # Allow CTranslate2 to decide available device
compute_type = "auto" # Allow CTranslate2 to decide faster compute type
model = WhisperModel(
model_size_or_path,
device=device,
compute_type=compute_type,
download_root=cache_dir,
)
return model
def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
segments, info = self.model.transcribe(
audio,
language=self.original_language,
initial_prompt=init_prompt,
beam_size=5,
word_timestamps=True,
condition_on_previous_text=True,
**self.transcribe_kargs,
)
return list(segments)
def ts_words(self, segments) -> List[ASRToken]:
tokens = []
for segment in segments:
if segment.no_speech_prob > 0.9:
continue
for word in segment.words:
token = ASRToken(word.start, word.end, word.word)
tokens.append(token)
return tokens
def segments_end_ts(self, segments) -> List[float]:
return [segment.end for segment in segments]
def use_vad(self):
self.transcribe_kargs["vad_filter"] = True
class MLXWhisper(ASRBase):
"""
Uses MLX Whisper optimized for Apple Silicon.
"""
sep = ""
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
import mlx.core as mx
from mlx_whisper.transcribe import ModelHolder, transcribe
if model_dir is not None:
resolved_path = resolve_model_path(model_dir)
logger.debug(f"Loading MLX Whisper model from {resolved_path}. model_size parameter is not used.")
model_size_or_path = str(resolved_path)
elif model_size is not None:
model_size_or_path = self.translate_model_name(model_size)
logger.debug(f"Loading whisper model {model_size}. You use mlx whisper, so {model_size_or_path} will be used.")
else:
raise ValueError("Either model_size or model_dir must be set")
self.model_size_or_path = model_size_or_path
dtype = mx.float16
ModelHolder.get_model(model_size_or_path, dtype)
return transcribe
def translate_model_name(self, model_name):
model_mapping = {
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
"tiny": "mlx-community/whisper-tiny-mlx",
"base.en": "mlx-community/whisper-base.en-mlx",
"base": "mlx-community/whisper-base-mlx",
"small.en": "mlx-community/whisper-small.en-mlx",
"small": "mlx-community/whisper-small-mlx",
"medium.en": "mlx-community/whisper-medium.en-mlx",
"medium": "mlx-community/whisper-medium-mlx",
"large-v1": "mlx-community/whisper-large-v1-mlx",
"large-v2": "mlx-community/whisper-large-v2-mlx",
"large-v3": "mlx-community/whisper-large-v3-mlx",
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
"large": "mlx-community/whisper-large-mlx",
}
mlx_model_path = model_mapping.get(model_name)
if mlx_model_path:
return mlx_model_path
else:
raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")
def transcribe(self, audio, init_prompt=""):
if self.transcribe_kargs:
logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
segments = self.model(
audio,
language=self.original_language,
initial_prompt=init_prompt,
word_timestamps=True,
condition_on_previous_text=True,
path_or_hf_repo=self.model_size_or_path,
)
return segments.get("segments", [])
def ts_words(self, segments) -> List[ASRToken]:
tokens = []
for segment in segments:
if segment.get("no_speech_prob", 0) > 0.9:
continue
for word in segment.get("words", []):
probability=word["probability"]
token = ASRToken(word["start"], word["end"], word["word"])
tokens.append(token)
return tokens
def segments_end_ts(self, res) -> List[float]:
return [s["end"] for s in res]
def use_vad(self):
self.transcribe_kargs["vad_filter"] = True
class OpenaiApiASR(ASRBase):
"""Uses OpenAI's Whisper API for transcription."""
def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
self.logfile = logfile
self.modelname = "whisper-1"
self.original_language = None if lan == "auto" else lan
self.response_format = "verbose_json"
self.temperature = temperature
self.load_model()
self.use_vad_opt = False
self.direct_english_translation = False
def load_model(self, *args, **kwargs):
from openai import OpenAI
self.client = OpenAI()
self.transcribed_seconds = 0
def ts_words(self, segments) -> List[ASRToken]:
"""
Converts OpenAI API response words into ASRToken objects while
optionally skipping words that fall into no-speech segments.
"""
no_speech_segments = []
if self.use_vad_opt:
for segment in segments.segments:
if segment.no_speech_prob > 0.8:
no_speech_segments.append((segment.start, segment.end))
tokens = []
for word in segments.words:
start = word.start
end = word.end
if any(s[0] <= start <= s[1] for s in no_speech_segments):
continue
tokens.append(ASRToken(start, end, word.word))
return tokens
def segments_end_ts(self, res) -> List[float]:
return [s.end for s in res.words]
def transcribe(self, audio_data, prompt=None, *args, **kwargs):
buffer = io.BytesIO()
buffer.name = "temp.wav"
sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
buffer.seek(0)
self.transcribed_seconds += math.ceil(len(audio_data) / 16000)
params = {
"model": self.modelname,
"file": buffer,
"response_format": self.response_format,
"temperature": self.temperature,
"timestamp_granularities": ["word", "segment"],
}
if not self.direct_english_translation and self.original_language:
params["language"] = self.original_language
if prompt:
params["prompt"] = prompt
proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions
transcript = proc.create(**params)
logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
return transcript
def use_vad(self):
self.use_vad_opt = True

View File

@@ -1,423 +0,0 @@
import logging
import sys
from typing import List, Optional, Tuple
import numpy as np
from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
logger = logging.getLogger(__name__)
class HypothesisBuffer:
"""
Buffer to store and process ASR hypothesis tokens.
It holds:
- committed_in_buffer: tokens that have been confirmed (committed)
- buffer: the last hypothesis that is not yet committed
- new: new tokens coming from the recognizer
"""
def __init__(self, logfile=sys.stderr, confidence_validation=False):
self.confidence_validation = confidence_validation
self.committed_in_buffer: List[ASRToken] = []
self.buffer: List[ASRToken] = []
self.new: List[ASRToken] = []
self.last_committed_time = 0.0
self.last_committed_word: Optional[str] = None
self.logfile = logfile
def insert(self, new_tokens: List[ASRToken], offset: float):
"""
Insert new tokens (after applying a time offset) and compare them with the
already committed tokens. Only tokens that extend the committed hypothesis
are added.
"""
# Apply the offset to each token.
new_tokens = [token.with_offset(offset) for token in new_tokens]
# Only keep tokens that are roughly “new”
self.new = [token for token in new_tokens if token.start > self.last_committed_time - 0.1]
if self.new:
first_token = self.new[0]
if abs(first_token.start - self.last_committed_time) < 1:
if self.committed_in_buffer:
committed_len = len(self.committed_in_buffer)
new_len = len(self.new)
# Try to match 1 to 5 consecutive tokens
max_ngram = min(min(committed_len, new_len), 5)
for i in range(1, max_ngram + 1):
committed_ngram = " ".join(token.text for token in self.committed_in_buffer[-i:])
new_ngram = " ".join(token.text for token in self.new[:i])
if committed_ngram == new_ngram:
removed = []
for _ in range(i):
removed_token = self.new.pop(0)
removed.append(repr(removed_token))
logger.debug(f"Removing last {i} words: {' '.join(removed)}")
break
def flush(self) -> List[ASRToken]:
"""
Returns the committed chunk, defined as the longest common prefix
between the previous hypothesis and the new tokens.
"""
committed: List[ASRToken] = []
while self.new:
current_new = self.new[0]
if self.confidence_validation and current_new.probability and current_new.probability > 0.95:
committed.append(current_new)
self.last_committed_word = current_new.text
self.last_committed_time = current_new.end
self.new.pop(0)
self.buffer.pop(0) if self.buffer else None
elif not self.buffer:
break
elif current_new.text == self.buffer[0].text:
committed.append(current_new)
self.last_committed_word = current_new.text
self.last_committed_time = current_new.end
self.buffer.pop(0)
self.new.pop(0)
else:
break
self.buffer = self.new
self.new = []
self.committed_in_buffer.extend(committed)
return committed
def pop_committed(self, time: float):
"""
Remove tokens (from the beginning) that have ended before `time`.
"""
while self.committed_in_buffer and self.committed_in_buffer[0].end <= time:
self.committed_in_buffer.pop(0)
class OnlineASRProcessor:
"""
Processes incoming audio in a streaming fashion, calling the ASR system
periodically, and uses a hypothesis buffer to commit and trim recognized text.
The processor supports two types of buffer trimming:
- "sentence": trims at sentence boundaries (using a sentence tokenizer)
- "segment": trims at fixed segment durations.
"""
SAMPLING_RATE = 16000
def __init__(
self,
asr,
logfile=sys.stderr,
):
"""
asr: An ASR system object (for example, a WhisperASR instance) that
provides a `transcribe` method, a `ts_words` method (to extract tokens),
a `segments_end_ts` method, and a separator attribute `sep`.
tokenize_method: A function that receives text and returns a list of sentence strings.
buffer_trimming: A tuple (option, seconds), where option is either "sentence" or "segment".
"""
self.asr = asr
self.tokenize = asr.tokenizer
self.logfile = logfile
self.confidence_validation = asr.confidence_validation
self.global_time_offset = 0.0
self.init()
self.buffer_trimming_way = asr.buffer_trimming
self.buffer_trimming_sec = asr.buffer_trimming_sec
if self.buffer_trimming_way not in ["sentence", "segment"]:
raise ValueError("buffer_trimming must be either 'sentence' or 'segment'")
if self.buffer_trimming_sec <= 0:
raise ValueError("buffer_trimming_sec must be positive")
elif self.buffer_trimming_sec > 30:
logger.warning(
f"buffer_trimming_sec is set to {self.buffer_trimming_sec}, which is very long. It may cause OOM."
)
def init(self, offset: Optional[float] = None):
"""Initialize or reset the processing buffers."""
self.audio_buffer = np.array([], dtype=np.float32)
self.transcript_buffer = HypothesisBuffer(logfile=self.logfile, confidence_validation=self.confidence_validation)
self.buffer_time_offset = offset if offset is not None else 0.0
self.transcript_buffer.last_committed_time = self.buffer_time_offset
self.committed: List[ASRToken] = []
self.time_of_last_asr_output = 0.0
def get_audio_buffer_end_time(self) -> float:
"""Returns the absolute end time of the current audio_buffer."""
return self.buffer_time_offset + (len(self.audio_buffer) / self.SAMPLING_RATE)
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
"""Append an audio chunk (a numpy array) to the current audio buffer."""
self.audio_buffer = np.append(self.audio_buffer, audio)
def start_silence(self):
if self.audio_buffer.size == 0:
return [], self.get_audio_buffer_end_time()
return self.process_iter()
def end_silence(self, silence_duration: Optional[float], offset: float):
if not silence_duration or silence_duration <= 0:
return
long_silence = silence_duration >= 5
if not long_silence:
gap_samples = int(self.SAMPLING_RATE * silence_duration)
if gap_samples > 0:
gap_silence = np.zeros(gap_samples, dtype=np.float32)
self.insert_audio_chunk(gap_silence)
else:
self.init(offset=silence_duration + offset)
self.global_time_offset += silence_duration
def insert_silence(self, silence_duration, offset):
"""
Backwards compatibility shim for legacy callers that still use insert_silence.
"""
self.end_silence(silence_duration, offset)
def prompt(self) -> Tuple[str, str]:
"""
Returns a tuple: (prompt, context), where:
- prompt is a 200-character suffix of committed text that falls
outside the current audio buffer.
- context is the committed text within the current audio buffer.
"""
k = len(self.committed)
while k > 0 and self.committed[k - 1].end > self.buffer_time_offset:
k -= 1
prompt_tokens = self.committed[:k]
prompt_words = [token.text for token in prompt_tokens]
prompt_list = []
length_count = 0
# Use the last words until reaching 200 characters.
while prompt_words and length_count < 200:
word = prompt_words.pop(-1)
length_count += len(word) + 1
prompt_list.append(word)
non_prompt_tokens = self.committed[k:]
context_text = self.asr.sep.join(token.text for token in non_prompt_tokens)
return self.asr.sep.join(prompt_list[::-1]), context_text
def get_buffer(self):
"""
Get the unvalidated buffer in string format.
"""
return self.concatenate_tokens(self.transcript_buffer.buffer)
def process_iter(self) -> Tuple[List[ASRToken], float]:
"""
Processes the current audio buffer.
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
"""
current_audio_processed_upto = self.get_audio_buffer_end_time()
prompt_text, _ = self.prompt()
logger.debug(
f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds from {self.buffer_time_offset:.2f}"
)
res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt_text)
tokens = self.asr.ts_words(res)
self.transcript_buffer.insert(tokens, self.buffer_time_offset)
committed_tokens = self.transcript_buffer.flush()
self.committed.extend(committed_tokens)
if committed_tokens:
self.time_of_last_asr_output = self.committed[-1].end
completed = self.concatenate_tokens(committed_tokens)
logger.debug(f">>>> COMPLETE NOW: {completed.text}")
incomp = self.concatenate_tokens(self.transcript_buffer.buffer)
logger.debug(f"INCOMPLETE: {incomp.text}")
buffer_duration = len(self.audio_buffer) / self.SAMPLING_RATE
if not committed_tokens and buffer_duration > self.buffer_trimming_sec:
time_since_last_output = self.get_audio_buffer_end_time() - self.time_of_last_asr_output
if time_since_last_output > self.buffer_trimming_sec:
logger.warning(
f"No ASR output for {time_since_last_output:.2f}s. "
f"Resetting buffer to prevent freezing."
)
self.init(offset=self.get_audio_buffer_end_time())
return [], current_audio_processed_upto
if committed_tokens and self.buffer_trimming_way == "sentence":
if len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec:
self.chunk_completed_sentence()
s = self.buffer_trimming_sec if self.buffer_trimming_way == "segment" else 30
if len(self.audio_buffer) / self.SAMPLING_RATE > s:
self.chunk_completed_segment(res)
logger.debug("Chunking segment")
logger.debug(
f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
)
if self.global_time_offset:
for token in committed_tokens:
token = token.with_offset(self.global_time_offset)
return committed_tokens, current_audio_processed_upto
def chunk_completed_sentence(self):
"""
If the committed tokens form at least two sentences, chunk the audio
buffer at the end time of the penultimate sentence.
Also ensures chunking happens if audio buffer exceeds a time limit.
"""
buffer_duration = len(self.audio_buffer) / self.SAMPLING_RATE
if not self.committed:
if buffer_duration > self.buffer_trimming_sec:
chunk_time = self.buffer_time_offset + (buffer_duration / 2)
logger.debug(f"--- No speech detected, forced chunking at {chunk_time:.2f}")
self.chunk_at(chunk_time)
return
logger.debug("COMPLETED SENTENCE: " + " ".join(token.text for token in self.committed))
sentences = self.words_to_sentences(self.committed)
for sentence in sentences:
logger.debug(f"\tSentence: {sentence.text}")
chunk_done = False
if len(sentences) >= 2:
while len(sentences) > 2:
sentences.pop(0)
chunk_time = sentences[-2].end
logger.debug(f"--- Sentence chunked at {chunk_time:.2f}")
self.chunk_at(chunk_time)
chunk_done = True
if not chunk_done and buffer_duration > self.buffer_trimming_sec:
last_committed_time = self.committed[-1].end
logger.debug(f"--- Not enough sentences, chunking at last committed time {last_committed_time:.2f}")
self.chunk_at(last_committed_time)
def chunk_completed_segment(self, res):
"""
Chunk the audio buffer based on segment-end timestamps reported by the ASR.
Also ensures chunking happens if audio buffer exceeds a time limit.
"""
buffer_duration = len(self.audio_buffer) / self.SAMPLING_RATE
if not self.committed:
if buffer_duration > self.buffer_trimming_sec:
chunk_time = self.buffer_time_offset + (buffer_duration / 2)
logger.debug(f"--- No speech detected, forced chunking at {chunk_time:.2f}")
self.chunk_at(chunk_time)
return
logger.debug("Processing committed tokens for segmenting")
ends = self.asr.segments_end_ts(res)
last_committed_time = self.committed[-1].end
chunk_done = False
if len(ends) > 1:
logger.debug("Multiple segments available for chunking")
e = ends[-2] + self.buffer_time_offset
while len(ends) > 2 and e > last_committed_time:
ends.pop(-1)
e = ends[-2] + self.buffer_time_offset
if e <= last_committed_time:
logger.debug(f"--- Segment chunked at {e:.2f}")
self.chunk_at(e)
chunk_done = True
else:
logger.debug("--- Last segment not within committed area")
else:
logger.debug("--- Not enough segments to chunk")
if not chunk_done and buffer_duration > self.buffer_trimming_sec:
logger.debug(f"--- Buffer too large, chunking at last committed time {last_committed_time:.2f}")
self.chunk_at(last_committed_time)
logger.debug("Segment chunking complete")
def chunk_at(self, time: float):
"""
Trim both the hypothesis and audio buffer at the given time.
"""
logger.debug(f"Chunking at {time:.2f}s")
logger.debug(
f"Audio buffer length before chunking: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f}s"
)
self.transcript_buffer.pop_committed(time)
cut_seconds = time - self.buffer_time_offset
self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE):]
self.buffer_time_offset = time
logger.debug(
f"Audio buffer length after chunking: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f}s"
)
def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
"""
Converts a list of tokens to a list of Sentence objects using the provided
sentence tokenizer.
"""
if not tokens:
return []
full_text = " ".join(token.text for token in tokens)
if self.tokenize:
try:
sentence_texts = self.tokenize(full_text)
except Exception as e:
# Some tokenizers (e.g., MosesSentenceSplitter) expect a list input.
try:
sentence_texts = self.tokenize([full_text])
except Exception as e2:
raise ValueError("Tokenization failed") from e2
else:
sentence_texts = [full_text]
sentences: List[Sentence] = []
token_index = 0
for sent_text in sentence_texts:
sent_text = sent_text.strip()
if not sent_text:
continue
sent_tokens = []
accumulated = ""
# Accumulate tokens until roughly matching the length of the sentence text.
while token_index < len(tokens) and len(accumulated) < len(sent_text):
token = tokens[token_index]
accumulated = (accumulated + " " + token.text).strip() if accumulated else token.text
sent_tokens.append(token)
token_index += 1
if sent_tokens:
sentence = Sentence(
start=sent_tokens[0].start,
end=sent_tokens[-1].end,
text=" ".join(t.text for t in sent_tokens),
)
sentences.append(sentence)
return sentences
def finish(self) -> Tuple[List[ASRToken], float]:
"""
Flush the remaining transcript when processing ends.
Returns a tuple: (list of remaining ASRToken objects, float representing the final audio processed up to time).
"""
remaining_tokens = self.transcript_buffer.buffer
logger.debug(f"Final non-committed tokens: {remaining_tokens}")
final_processed_upto = self.buffer_time_offset + (len(self.audio_buffer) / self.SAMPLING_RATE)
self.buffer_time_offset = final_processed_upto
return remaining_tokens, final_processed_upto
def concatenate_tokens(
self,
tokens: List[ASRToken],
sep: Optional[str] = None,
offset: float = 0
) -> Transcript:
sep = sep if sep is not None else self.asr.sep
text = sep.join(token.text for token in tokens)
# probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
if tokens:
start = offset + tokens[0].start
end = offset + tokens[-1].end
else:
start = None
end = None
return Transcript(start, end, text)

View File

@@ -1,206 +0,0 @@
#!/usr/bin/env python3
import logging
import platform
import sys
import time
from functools import lru_cache
import librosa
import numpy as np
from whisperlivekit.backend_support import (faster_backend_available,
mlx_backend_available)
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
from whisperlivekit.warmup import warmup_asr
from .backends import FasterWhisperASR, MLXWhisper, OpenaiApiASR, WhisperASR
logger = logging.getLogger(__name__)
WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
","
)
def create_tokenizer(lan):
"""returns an object that has split function that works like the one of MosesTokenizer"""
assert (
lan in WHISPER_LANG_CODES
), "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
if lan == "uk":
import tokenize_uk
class UkrainianTokenizer:
def split(self, text):
return tokenize_uk.tokenize_sents(text)
return UkrainianTokenizer()
# supported by fast-mosestokenizer
if (
lan
in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
):
from mosestokenizer import MosesSentenceSplitter
return MosesSentenceSplitter(lan)
# the following languages are in Whisper, but not in wtpsplit:
if (
lan
in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
):
logger.debug(
f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
)
lan = None
from wtpsplit import WtP
# downloads the model from huggingface on the first use
wtp = WtP("wtp-canine-s-12l-no-adapters")
class WtPtok:
def split(self, sent):
return wtp.split(sent, lang_code=lan)
return WtPtok()
def backend_factory(
backend,
lan,
model_size,
model_cache_dir,
model_dir,
model_path,
lora_path,
direct_english_translation,
buffer_trimming,
buffer_trimming_sec,
confidence_validation,
warmup_file=None,
min_chunk_size=None,
):
backend_choice = backend
custom_reference = model_path or model_dir
resolved_root = None
has_mlx_weights = False
has_fw_weights = False
has_pytorch = False
if custom_reference:
resolved_root = resolve_model_path(custom_reference)
if resolved_root.is_dir():
model_info = detect_model_format(resolved_root)
has_mlx_weights = model_info.compatible_whisper_mlx
has_fw_weights = model_info.compatible_faster_whisper
has_pytorch = model_info.has_pytorch
else:
# Single file provided
has_pytorch = True
if backend_choice == "openai-api":
logger.debug("Using OpenAI API.")
asr = OpenaiApiASR(lan=lan)
else:
backend_choice = _normalize_backend_choice(
backend_choice,
resolved_root,
has_mlx_weights,
has_fw_weights,
)
if backend_choice == "faster-whisper":
asr_cls = FasterWhisperASR
if resolved_root is not None and not resolved_root.is_dir():
raise ValueError("Faster-Whisper backend expects a directory with CTranslate2 weights.")
model_override = str(resolved_root) if resolved_root is not None else None
elif backend_choice == "mlx-whisper":
asr_cls = MLXWhisper
if resolved_root is not None and not resolved_root.is_dir():
raise ValueError("MLX Whisper backend expects a directory containing MLX weights.")
model_override = str(resolved_root) if resolved_root is not None else None
else:
asr_cls = WhisperASR
model_override = str(resolved_root) if resolved_root is not None else None
if custom_reference and not has_pytorch:
raise FileNotFoundError(
f"No PyTorch checkpoint found under {resolved_root or custom_reference}"
)
t = time.time()
logger.info(f"Loading Whisper {model_size} model for language {lan} using backend {backend_choice}...")
asr = asr_cls(
model_size=model_size,
lan=lan,
cache_dir=model_cache_dir,
model_dir=model_override,
lora_path=lora_path if backend_choice == "whisper" else None,
)
e = time.time()
logger.info(f"done. It took {round(e-t,2)} seconds.")
if direct_english_translation:
tgt_language = "en" # Whisper translates into English
else:
tgt_language = lan # Whisper transcribes in this language
# Create the tokenizer
if buffer_trimming == "sentence":
tokenizer = create_tokenizer(tgt_language)
else:
tokenizer = None
warmup_asr(asr, warmup_file)
asr.confidence_validation = confidence_validation
asr.tokenizer = tokenizer
asr.buffer_trimming = buffer_trimming
asr.buffer_trimming_sec = buffer_trimming_sec
asr.backend_choice = backend_choice
return asr
def _normalize_backend_choice(
preferred_backend,
resolved_root,
has_mlx_weights,
has_fw_weights,
):
backend_choice = preferred_backend
if backend_choice == "auto":
if mlx_backend_available(warn_on_missing=True) and (resolved_root is None or has_mlx_weights):
return "mlx-whisper"
if faster_backend_available(warn_on_missing=True) and (resolved_root is None or has_fw_weights):
return "faster-whisper"
return "whisper"
if backend_choice == "mlx-whisper":
if not mlx_backend_available():
raise RuntimeError("mlx-whisper backend requested but mlx-whisper is not installed.")
if resolved_root is not None and not has_mlx_weights:
raise FileNotFoundError(
f"mlx-whisper backend requested but no MLX weights were found under {resolved_root}"
)
if platform.system() != "Darwin":
logger.warning("mlx-whisper backend requested on a non-macOS system; this may fail.")
return backend_choice
if backend_choice == "faster-whisper":
if not faster_backend_available():
raise RuntimeError("faster-whisper backend requested but faster-whisper is not installed.")
if resolved_root is not None and not has_fw_weights:
raise FileNotFoundError(
f"faster-whisper backend requested but no Faster-Whisper weights were found under {resolved_root}"
)
return backend_choice
if backend_choice == "whisper":
return backend_choice
raise ValueError(f"Unknown backend '{preferred_backend}' for LocalAgreement.")

View File

@@ -1,215 +0,0 @@
import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple, Union
@dataclass
class ModelInfo:
"""Information about detected model format and files in a directory."""
path: Optional[Path] = None
pytorch_files: List[Path] = field(default_factory=list)
compatible_whisper_mlx: bool = False
compatible_faster_whisper: bool = False
@property
def has_pytorch(self) -> bool:
return len(self.pytorch_files) > 0
@property
def is_sharded(self) -> bool:
return len(self.pytorch_files) > 1
@property
def primary_pytorch_file(self) -> Optional[Path]:
"""Return the primary PyTorch file (or first shard for sharded models)."""
if not self.pytorch_files:
return None
return self.pytorch_files[0]
#regex pattern for sharded model files such as: model-00001-of-00002.safetensors or pytorch_model-00001-of-00002.bin
SHARDED_PATTERN = re.compile(r"^(.+)-(\d{5})-of-(\d{5})\.(safetensors|bin)$")
FASTER_WHISPER_MARKERS = {"model.bin", "encoder.bin", "decoder.bin"}
MLX_WHISPER_MARKERS = {"weights.npz", "weights.safetensors"}
CT2_INDICATOR_FILES = {"vocabulary.json", "vocabulary.txt", "shared_vocabulary.json"}
def _is_ct2_model_bin(directory: Path, filename: str) -> bool:
"""
Determine if model.bin/encoder.bin/decoder.bin is a CTranslate2 model.
CTranslate2 models have specific companion files that distinguish them
from PyTorch .bin files.
"""
n_indicators = 0
for indicator in CT2_INDICATOR_FILES: #test 1
if (directory / indicator).exists():
n_indicators += 1
if n_indicators == 0:
return False
config_path = directory / "config.json" #test 2
if config_path.exists():
try:
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
if config.get("model_type") == "whisper": #test 2
return False
except (json.JSONDecodeError, IOError):
pass
return True
def _collect_pytorch_files(directory: Path) -> List[Path]:
"""
Collect all PyTorch checkpoint files from a directory.
Handles:
- Single files: model.safetensors, pytorch_model.bin, *.pt
- Sharded files: model-00001-of-00002.safetensors, pytorch_model-00001-of-00002.bin
- Index-based sharded models (reads index file to find shards)
Returns files sorted appropriately (shards in order, or single file).
"""
for index_name in ["model.safetensors.index.json", "pytorch_model.bin.index.json"]:
index_path = directory / index_name
if index_path.exists():
try:
with open(index_path, "r", encoding="utf-8") as f:
index_data = json.load(f)
weight_map = index_data.get("weight_map", {})
if weight_map:
shard_names = sorted(set(weight_map.values()))
shards = [directory / name for name in shard_names if (directory / name).exists()]
if shards:
return shards
except (json.JSONDecodeError, IOError):
pass
sharded_groups = {}
single_files = {}
for file in directory.iterdir():
if not file.is_file():
continue
filename = file.name
suffix = file.suffix.lower()
if filename.startswith("adapter_"):
continue
match = SHARDED_PATTERN.match(filename)
if match:
base_name, shard_idx, total_shards, ext = match.groups()
key = (base_name, ext, int(total_shards))
if key not in sharded_groups:
sharded_groups[key] = []
sharded_groups[key].append((int(shard_idx), file))
continue
if filename == "model.safetensors":
single_files[0] = file # Highest priority
elif filename == "pytorch_model.bin":
single_files[1] = file
elif suffix == ".pt":
single_files[2] = file
elif suffix == ".safetensors" and not filename.startswith("adapter"):
single_files[3] = file
for (base_name, ext, total_shards), shards in sharded_groups.items():
if len(shards) == total_shards:
return [path for _, path in sorted(shards)]
for priority in sorted(single_files.keys()):
return [single_files[priority]]
return []
def detect_model_format(model_path: Union[str, Path]) -> ModelInfo:
"""
Detect the model format in a given path.
This function analyzes a file or directory to determine:
- What PyTorch checkpoint files are available (including sharded models)
- Whether the directory contains MLX Whisper weights
- Whether the directory contains Faster-Whisper (CTranslate2) weights
Args:
model_path: Path to a model file or directory
Returns:
ModelInfo with detected format information
"""
path = Path(model_path)
info = ModelInfo(path=path)
if path.is_file():
suffix = path.suffix.lower()
if suffix in {".pt", ".safetensors", ".bin"}:
info.pytorch_files = [path]
return info
if not path.is_dir():
return info
for file in path.iterdir():
if not file.is_file():
continue
filename = file.name.lower()
if filename in MLX_WHISPER_MARKERS:
info.compatible_whisper_mlx = True
if filename in FASTER_WHISPER_MARKERS:
if _is_ct2_model_bin(path, filename):
info.compatible_faster_whisper = True
info.pytorch_files = _collect_pytorch_files(path)
return info
def model_path_and_type(model_path: Union[str, Path]) -> Tuple[Optional[Path], bool, bool]:
"""
Inspect the provided path and determine which model formats are available.
This is a compatibility wrapper around detect_model_format().
Returns:
pytorch_path: Path to a PyTorch checkpoint (first shard for sharded models, or None).
compatible_whisper_mlx: True if MLX weights exist in this folder.
compatible_faster_whisper: True if Faster-Whisper (CTranslate2) weights exist.
"""
info = detect_model_format(model_path)
return info.primary_pytorch_file, info.compatible_whisper_mlx, info.compatible_faster_whisper
def resolve_model_path(model_path: Union[str, Path]) -> Path:
"""
Return a local path for the provided model reference.
If the path does not exist locally, it is treated as a Hugging Face repo id
and downloaded via snapshot_download.
"""
path = Path(model_path).expanduser()
if path.exists():
return path
try:
from huggingface_hub import snapshot_download
except ImportError as exc:
raise FileNotFoundError(
f"Model path '{model_path}' does not exist locally and huggingface_hub "
"is not installed to download it."
) from exc
downloaded_path = Path(snapshot_download(repo_id=str(model_path)))
return downloaded_path

View File

@@ -1,332 +0,0 @@
from argparse import ArgumentParser
def parse_args():
parser = ArgumentParser(description="Whisper FastAPI Online Server")
parser.add_argument(
"--host",
type=str,
default="localhost",
help="The host address to bind the server to.",
)
parser.add_argument(
"--port", type=int, default=8000, help="The port number to bind the server to."
)
parser.add_argument(
"--warmup-file",
type=str,
default=None,
dest="warmup_file",
help="""
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
If empty, no warmup is performed.
""",
)
parser.add_argument(
"--confidence-validation",
action="store_true",
help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
)
parser.add_argument(
"--diarization",
action="store_true",
default=False,
help="Enable speaker diarization.",
)
parser.add_argument(
"--punctuation-split",
action="store_true",
default=False,
help="Use punctuation marks from transcription to improve speaker boundary detection. Requires both transcription and diarization to be enabled.",
)
parser.add_argument(
"--segmentation-model",
type=str,
default="pyannote/segmentation-3.0",
help="Hugging Face model ID for pyannote.audio segmentation model.",
)
parser.add_argument(
"--embedding-model",
type=str,
default="pyannote/embedding",
help="Hugging Face model ID for pyannote.audio embedding model.",
)
parser.add_argument(
"--diarization-backend",
type=str,
default="sortformer",
choices=["sortformer", "diart"],
help="The diarization backend to use.",
)
parser.add_argument(
"--no-transcription",
action="store_true",
help="Disable transcription to only see live diarization results.",
)
parser.add_argument(
"--disable-punctuation-split",
action="store_true",
help="Disable the split parameter.",
)
parser.add_argument(
"--min-chunk-size",
type=float,
default=0.1,
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
)
parser.add_argument(
"--model",
type=str,
default="base",
dest='model_size',
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
)
parser.add_argument(
"--model_cache_dir",
type=str,
default=None,
help="Overriding the default model cache dir where models downloaded from the hub are saved",
)
parser.add_argument(
"--model_dir",
type=str,
default=None,
help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
)
parser.add_argument(
"--lora-path",
type=str,
default=None,
dest="lora_path",
help="Path or Hugging Face repo ID for LoRA adapter weights (e.g., QuentinFuxa/whisper-base-french-lora). Only works with native Whisper backend.",
)
parser.add_argument(
"--lan",
"--language",
type=str,
default="auto",
dest='lan',
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
)
parser.add_argument(
"--direct-english-translation",
action="store_true",
default=False,
help="Use Whisper to directly translate to english.",
)
parser.add_argument(
"--target-language",
type=str,
default="",
dest="target_language",
help="Target language for translation. Not functional yet.",
)
parser.add_argument(
"--backend-policy",
type=str,
default="simulstreaming",
choices=["1", "2", "simulstreaming", "localagreement"],
help="Select the streaming policy: 1 or 'simulstreaming' for AlignAtt, 2 or 'localagreement' for LocalAgreement.",
)
parser.add_argument(
"--backend",
type=str,
default="auto",
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api"],
help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API.",
)
parser.add_argument(
"--no-vac",
action="store_true",
default=False,
help="Disable VAC = voice activity controller.",
)
parser.add_argument(
"--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
)
parser.add_argument(
"--no-vad",
action="store_true",
help="Disable VAD (voice activity detection).",
)
parser.add_argument(
"--buffer_trimming",
type=str,
default="segment",
choices=["sentence", "segment"],
help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.',
)
parser.add_argument(
"--buffer_trimming_sec",
type=float,
default=15,
help="Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.",
)
parser.add_argument(
"-l",
"--log-level",
dest="log_level",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Set the log level",
default="DEBUG",
)
parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
parser.add_argument("--forwarded-allow-ips", type=str, help="Allowed ips for reverse proxying.", default=None)
parser.add_argument(
"--pcm-input",
action="store_true",
default=False,
help="If set, raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder."
)
# SimulStreaming-specific arguments
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')
simulstreaming_group.add_argument(
"--disable-fast-encoder",
action="store_true",
default=False,
dest="disable_fast_encoder",
help="Disable Faster Whisper or MLX Whisper backends for encoding (if installed). Slower but helpful when GPU memory is limited",
)
simulstreaming_group.add_argument(
"--custom-alignment-heads",
type=str,
default=None,
help="Use your own alignment heads, useful when `--model-dir` is used",
)
simulstreaming_group.add_argument(
"--frame-threshold",
type=int,
default=25,
dest="frame_threshold",
help="Threshold for the attention-guided decoding. The AlignAtt policy will decode only until this number of frames from the end of audio. In frames: one frame is 0.02 seconds for large-v3 model.",
)
simulstreaming_group.add_argument(
"--beams",
"-b",
type=int,
default=1,
help="Number of beams for beam search decoding. If 1, GreedyDecoder is used.",
)
simulstreaming_group.add_argument(
"--decoder",
type=str,
default=None,
dest="decoder_type",
choices=["beam", "greedy"],
help="Override automatic selection of beam or greedy decoder. If beams > 1 and greedy: invalid.",
)
simulstreaming_group.add_argument(
"--audio-max-len",
type=float,
default=30.0,
dest="audio_max_len",
help="Max length of the audio buffer, in seconds.",
)
simulstreaming_group.add_argument(
"--audio-min-len",
type=float,
default=0.0,
dest="audio_min_len",
help="Skip processing if the audio buffer is shorter than this length, in seconds. Useful when the --min-chunk-size is small.",
)
simulstreaming_group.add_argument(
"--cif-ckpt-path",
type=str,
default=None,
dest="cif_ckpt_path",
help="The file path to the Simul-Whisper's CIF model checkpoint that detects whether there is end of word at the end of the chunk. If not, the last decoded space-separated word is truncated because it is often wrong -- transcribing a word in the middle. The CIF model adapted for the Whisper model version should be used. Find the models in https://github.com/backspacetg/simul_whisper/tree/main/cif_models . Note that there is no model for large-v3.",
)
simulstreaming_group.add_argument(
"--never-fire",
action="store_true",
default=False,
dest="never_fire",
help="Override the CIF model. If True, the last word is NEVER truncated, no matter what the CIF model detects. If False: if CIF model path is set, the last word is SOMETIMES truncated, depending on the CIF detection. Otherwise, if the CIF model path is not set, the last word is ALWAYS trimmed.",
)
simulstreaming_group.add_argument(
"--init-prompt",
type=str,
default=None,
dest="init_prompt",
help="Init prompt for the model. It should be in the target language.",
)
simulstreaming_group.add_argument(
"--static-init-prompt",
type=str,
default=None,
dest="static_init_prompt",
help="Do not scroll over this text. It can contain terminology that should be relevant over all document.",
)
simulstreaming_group.add_argument(
"--max-context-tokens",
type=int,
default=None,
dest="max_context_tokens",
help="Max context tokens for the model. Default is 0.",
)
simulstreaming_group.add_argument(
"--model-path",
type=str,
default=None,
dest="model_path",
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
)
simulstreaming_group.add_argument(
"--nllb-backend",
type=str,
default="transformers",
help="transformers or ctranslate2",
)
simulstreaming_group.add_argument(
"--nllb-size",
type=str,
default="600M",
help="600M or 1.3B",
)
args = parser.parse_args()
args.transcription = not args.no_transcription
args.vad = not args.no_vad
delattr(args, 'no_transcription')
delattr(args, 'no_vad')
if args.backend_policy == "1":
args.backend_policy = "simulstreaming"
elif args.backend_policy == "2":
args.backend_policy = "localagreement"
return args

View File

@@ -1,326 +0,0 @@
import warnings
from pathlib import Path
import numpy as np
import torch
"""
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
"""
def is_onnx_available() -> bool:
"""Check if onnxruntime is installed."""
try:
import onnxruntime
return True
except ImportError:
return False
def init_jit_model(model_path: str, device=torch.device('cpu')):
"""Load a JIT model from file."""
model = torch.jit.load(model_path, map_location=device)
model.eval()
return model
class OnnxSession():
"""
Shared ONNX session for Silero VAD model (stateless).
"""
def __init__(self, path, force_onnx_cpu=False):
import onnxruntime
opts = onnxruntime.SessionOptions()
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 1
if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
else:
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
self.path = path
if '16k' in path:
warnings.warn('This model support only 16000 sampling rate!')
self.sample_rates = [16000]
else:
self.sample_rates = [8000, 16000]
class OnnxWrapper():
"""
ONNX Runtime wrapper for Silero VAD model with per-instance state.
"""
def __init__(self, session: OnnxSession, force_onnx_cpu=False):
self._shared_session = session
self.sample_rates = session.sample_rates
self.reset_states()
@property
def session(self):
return self._shared_session.session
def _validate_input(self, x, sr: int):
if x.dim() == 1:
x = x.unsqueeze(0)
if x.dim() > 2:
raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
if sr != 16000 and (sr % 16000 == 0):
step = sr // 16000
x = x[:,::step]
sr = 16000
if sr not in self.sample_rates:
raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
if sr / x.shape[1] > 31.25:
raise ValueError("Input audio chunk is too short")
return x, sr
def reset_states(self, batch_size=1):
self._state = torch.zeros((2, batch_size, 128)).float()
self._context = torch.zeros(0)
self._last_sr = 0
self._last_batch_size = 0
def __call__(self, x, sr: int):
x, sr = self._validate_input(x, sr)
num_samples = 512 if sr == 16000 else 256
if x.shape[-1] != num_samples:
raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
batch_size = x.shape[0]
context_size = 64 if sr == 16000 else 32
if not self._last_batch_size:
self.reset_states(batch_size)
if (self._last_sr) and (self._last_sr != sr):
self.reset_states(batch_size)
if (self._last_batch_size) and (self._last_batch_size != batch_size):
self.reset_states(batch_size)
if not len(self._context):
self._context = torch.zeros(batch_size, context_size)
x = torch.cat([self._context, x], dim=1)
if sr in [8000, 16000]:
ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
ort_outs = self.session.run(None, ort_inputs)
out, state = ort_outs
self._state = torch.from_numpy(state)
else:
raise ValueError()
self._context = x[..., -context_size:]
self._last_sr = sr
self._last_batch_size = batch_size
out = torch.from_numpy(out)
return out
def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Path:
"""Get the path to the ONNX model file."""
available_ops = [15, 16]
if opset_version not in available_ops:
raise Exception(f'Available ONNX opset_version: {available_ops}')
if model_path is None:
current_dir = Path(__file__).parent
data_dir = current_dir / 'silero_vad_models'
if opset_version == 16:
model_name = 'silero_vad.onnx'
else:
model_name = f'silero_vad_16k_op{opset_version}.onnx'
model_path = data_dir / model_name
if not model_path.exists():
raise FileNotFoundError(
f"Model file not found: {model_path}\n"
f"Please ensure the whisperlivekit/silero_vad_models/ directory contains the model files."
)
else:
model_path = Path(model_path)
return model_path
def load_onnx_session(model_path: str = None, opset_version: int = 16, force_onnx_cpu: bool = True) -> OnnxSession:
"""
Load a shared ONNX session for Silero VAD.
"""
path = _get_onnx_model_path(model_path, opset_version)
return OnnxSession(str(path), force_onnx_cpu=force_onnx_cpu)
def load_jit_vad(model_path: str = None):
"""
Load Silero VAD model in JIT format.
"""
if model_path is None:
current_dir = Path(__file__).parent
data_dir = current_dir / 'silero_vad_models'
model_name = 'silero_vad.jit'
model_path = data_dir / model_name
if not model_path.exists():
raise FileNotFoundError(
f"Model file not found: {model_path}\n"
f"Please ensure the whisperlivekit/silero_vad_models/ directory contains the model files."
)
else:
model_path = Path(model_path)
model = init_jit_model(str(model_path))
return model
class VADIterator:
"""
Voice Activity Detection iterator for streaming audio.
This is the Silero VAD v6 implementation.
"""
def __init__(self,
model,
threshold: float = 0.5,
sampling_rate: int = 16000,
min_silence_duration_ms: int = 100,
speech_pad_ms: int = 30
):
"""
Class for stream imitation
Parameters
----------
model: preloaded .jit/.onnx silero VAD model
threshold: float (default - 0.5)
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
sampling_rate: int (default - 16000)
Currently silero VAD models support 8000 and 16000 sample rates
min_silence_duration_ms: int (default - 100 milliseconds)
In the end of each speech chunk wait for min_silence_duration_ms before separating it
speech_pad_ms: int (default - 30 milliseconds)
Final speech chunks are padded by speech_pad_ms each side
"""
self.model = model
self.threshold = threshold
self.sampling_rate = sampling_rate
if sampling_rate not in [8000, 16000]:
raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
self.reset_states()
def reset_states(self):
self.model.reset_states()
self.triggered = False
self.temp_end = 0
self.current_sample = 0
@torch.no_grad()
def __call__(self, x, return_seconds=False, time_resolution: int = 1):
"""
x: torch.Tensor
audio chunk (see examples in repo)
return_seconds: bool (default - False)
whether return timestamps in seconds (default - samples)
time_resolution: int (default - 1)
time resolution of speech coordinates when requested as seconds
"""
if not torch.is_tensor(x):
try:
x = torch.Tensor(x)
except:
raise TypeError("Audio cannot be casted to tensor. Cast it manually")
window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
self.current_sample += window_size_samples
speech_prob = self.model(x, self.sampling_rate).item()
if (speech_prob >= self.threshold) and self.temp_end:
self.temp_end = 0
if (speech_prob >= self.threshold) and not self.triggered:
self.triggered = True
speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)}
if (speech_prob < self.threshold - 0.15) and self.triggered:
if not self.temp_end:
self.temp_end = self.current_sample
if self.current_sample - self.temp_end < self.min_silence_samples:
return None
else:
speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
self.temp_end = 0
self.triggered = False
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)}
return None
class FixedVADIterator(VADIterator):
"""
Fixed VAD Iterator that handles variable-length audio chunks, not only exactly 512 frames at once.
"""
def reset_states(self):
super().reset_states()
self.buffer = np.array([], dtype=np.float32)
def __call__(self, x, return_seconds=False):
self.buffer = np.append(self.buffer, x)
ret = None
while len(self.buffer) >= 512:
r = super().__call__(self.buffer[:512], return_seconds=return_seconds)
self.buffer = self.buffer[512:]
if ret is None:
ret = r
elif r is not None:
if "end" in r:
ret["end"] = r["end"]
if "start" in r:
ret["start"] = r["start"]
if "end" in ret:
del ret["end"]
return ret if ret != {} else None
if __name__ == "__main__":
# vad = FixedVADIterator(load_jit_vad())
vad = FixedVADIterator(OnnxWrapper(session=load_onnx_session()))
audio_buffer = np.array([0] * 512, dtype=np.float32)
result = vad(audio_buffer)
print(f" 512 samples: {result}")
# test with 511 samples
audio_buffer = np.array([0] * 511, dtype=np.float32)
result = vad(audio_buffer)
print(f" 511 samples: {result}")

View File

@@ -1,6 +0,0 @@
from .backend import SimulStreamingASR, SimulStreamingOnlineProcessor
__all__ = [
"SimulStreamingASR",
"SimulStreamingOnlineProcessor",
]

View File

@@ -1,368 +0,0 @@
import gc
import logging
import os
import platform
import sys
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
import torch
from whisperlivekit.backend_support import (faster_backend_available,
mlx_backend_available)
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
from whisperlivekit.warmup import load_file
from whisperlivekit.whisper import load_model, tokenizer
from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
logger = logging.getLogger(__name__)
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
if HAS_MLX_WHISPER:
from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
from .mlx import MLXAlignAtt
else:
mlx_model_mapping = {}
MLXAlignAtt = None
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
if HAS_FASTER_WHISPER:
from faster_whisper import WhisperModel
else:
WhisperModel = None
MIN_DURATION_REAL_SILENCE = 5
class SimulStreamingOnlineProcessor:
"""Online processor for SimulStreaming ASR."""
SAMPLING_RATE = 16000
def __init__(self, asr, logfile=sys.stderr):
self.asr = asr
self.logfile = logfile
self.end = 0.0
self.buffer = []
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
self.model = self._create_alignatt()
if asr.tokenizer:
self.model.tokenizer = asr.tokenizer
self.model.state.tokenizer = asr.tokenizer
def _create_alignatt(self):
"""Create the AlignAtt decoder instance based on ASR mode."""
if self.asr.use_full_mlx and HAS_MLX_WHISPER:
return MLXAlignAtt(cfg=self.asr.cfg, mlx_model=self.asr.mlx_model)
else:
return AlignAtt(
cfg=self.asr.cfg,
loaded_model=self.asr.shared_model,
mlx_encoder=self.asr.mlx_encoder,
fw_encoder=self.asr.fw_encoder,
)
def start_silence(self):
tokens, processed_upto = self.process_iter(is_last=True)
return tokens, processed_upto
def end_silence(self, silence_duration, offset):
"""Handle silence period."""
self.end += silence_duration
long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
if not long_silence:
gap_len = int(16000 * silence_duration)
if gap_len > 0:
if self.asr.use_full_mlx:
gap_silence = np.zeros(gap_len, dtype=np.float32)
else:
gap_silence = torch.zeros(gap_len)
self.model.insert_audio(gap_silence)
if long_silence:
self.model.refresh_segment(complete=True)
self.model.global_time_offset = silence_duration + offset
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
"""Append an audio chunk to be processed by SimulStreaming."""
self.end = audio_stream_end_time
if self.asr.use_full_mlx:
self.model.insert_audio(audio)
else:
audio_tensor = torch.from_numpy(audio).float()
self.model.insert_audio(audio_tensor)
def new_speaker(self, change_speaker: ChangeSpeaker):
"""Handle speaker change event."""
self.process_iter(is_last=True)
self.model.refresh_segment(complete=True)
self.model.speaker = change_speaker.speaker
self.model.global_time_offset = change_speaker.start
def get_buffer(self):
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
return concat_buffer
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
"""
Process accumulated audio chunks using SimulStreaming.
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
"""
try:
timestamped_words = self.model.infer(is_last=is_last)
if not timestamped_words:
return [], self.end
if self.model.cfg.language == "auto" and timestamped_words[0].detected_language is None:
self.buffer.extend(timestamped_words)
return [], self.end
self.committed.extend(timestamped_words)
self.buffer = []
return timestamped_words, self.end
except Exception as e:
logger.exception(f"SimulStreaming processing error: {e}")
return [], self.end
def warmup(self, audio, init_prompt=""):
"""Warmup the SimulStreaming model."""
try:
if self.asr.use_full_mlx:
# MLX mode: ensure numpy array
if hasattr(audio, 'numpy'):
audio = audio.numpy()
self.model.insert_audio(audio)
self.model.infer(True)
self.model.refresh_segment(complete=True)
logger.info("SimulStreaming model warmed up successfully")
except Exception as e:
logger.exception(f"SimulStreaming warmup failed: {e}")
def __del__(self):
gc.collect()
if not getattr(self.asr, 'use_full_mlx', True) and torch is not None:
try:
torch.cuda.empty_cache()
except Exception:
pass
class SimulStreamingASR:
"""SimulStreaming backend with AlignAtt policy."""
sep = ""
def __init__(self, logfile=sys.stderr, **kwargs):
self.logfile = logfile
self.transcribe_kargs = {}
for key, value in kwargs.items():
setattr(self, key, value)
if self.decoder_type is None:
self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
self.fast_encoder = False
self._resolved_model_path = None
self.encoder_backend = "whisper"
self.use_full_mlx = getattr(self, "use_full_mlx", False)
preferred_backend = getattr(self, "backend", "auto")
compatible_whisper_mlx, compatible_faster_whisper = True, True
if self.model_path:
resolved_model_path = resolve_model_path(self.model_path)
self._resolved_model_path = resolved_model_path
self.model_path = str(resolved_model_path)
model_info = detect_model_format(resolved_model_path)
compatible_whisper_mlx = model_info.compatible_whisper_mlx
compatible_faster_whisper = model_info.compatible_faster_whisper
if not self.use_full_mlx and not model_info.has_pytorch:
raise FileNotFoundError(
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
)
self.model_name = resolved_model_path.name if resolved_model_path.is_dir() else resolved_model_path.stem
elif self.model_size is not None:
self.model_name = self.model_size
else:
raise ValueError("Either model_size or model_path must be specified for SimulStreaming.")
is_multilingual = not self.model_name.endswith(".en")
self.encoder_backend = self._resolve_encoder_backend(
preferred_backend,
compatible_whisper_mlx,
compatible_faster_whisper,
)
self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper")
if self.encoder_backend == "whisper":
self.disable_fast_encoder = True
if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin":
if not hasattr(self, '_full_mlx_disabled'):
self.use_full_mlx = True
self.cfg = AlignAttConfig(
tokenizer_is_multilingual= is_multilingual,
segment_length=self.min_chunk_size,
frame_threshold=self.frame_threshold,
language=self.lan,
audio_max_len=self.audio_max_len,
audio_min_len=self.audio_min_len,
cif_ckpt_path=self.cif_ckpt_path,
decoder_type="beam",
beam_size=self.beams,
task=self.direct_english_translation,
never_fire=self.never_fire,
init_prompt=self.init_prompt,
max_context_tokens=self.max_context_tokens,
static_init_prompt=self.static_init_prompt,
)
# Set up tokenizer for translation if needed
if self.direct_english_translation:
self.tokenizer = self.set_translate_task()
else:
self.tokenizer = None
self.mlx_encoder, self.fw_encoder, self.mlx_model = None, None, None
self.shared_model = None
if self.use_full_mlx and HAS_MLX_WHISPER:
logger.info('MLX Whisper backend used.')
if self._resolved_model_path is not None:
mlx_model_path = str(self._resolved_model_path)
else:
mlx_model_path = mlx_model_mapping.get(self.model_name)
if not mlx_model_path:
raise FileNotFoundError(
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
)
self.mlx_model = load_mlx_model(path_or_hf_repo=mlx_model_path)
self._warmup_mlx_model()
elif self.encoder_backend == "mlx-whisper":
# hybrid mode: mlx encoder + pytorch decoder
logger.info('SimulStreaming will use MLX Whisper encoder with PyTorch decoder.')
if self._resolved_model_path is not None:
mlx_model_path = str(self._resolved_model_path)
else:
mlx_model_path = mlx_model_mapping.get(self.model_name)
if not mlx_model_path:
raise FileNotFoundError(
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
)
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path)
self.shared_model = self.load_model()
elif self.encoder_backend == "faster-whisper":
print('SimulStreaming will use Faster Whisper for the encoder.')
if self._resolved_model_path is not None:
fw_model = str(self._resolved_model_path)
else:
fw_model = self.model_name
self.fw_encoder = WhisperModel(
fw_model,
device='auto',
compute_type='auto',
)
self.shared_model = self.load_model()
else:
self.shared_model = self.load_model()
def _warmup_mlx_model(self):
"""Warmup the full MLX model."""
warmup_audio = load_file(self.warmup_file)
if warmup_audio is not None:
temp_model = MLXAlignAtt(
cfg=self.cfg,
mlx_model=self.mlx_model,
)
temp_model.warmup(warmup_audio)
logger.info("Full MLX model warmed up successfully")
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
choice = preferred_backend or "auto"
if self.disable_fast_encoder:
return "whisper"
if choice == "whisper":
return "whisper"
if choice == "mlx-whisper":
if not self._can_use_mlx(compatible_whisper_mlx):
raise RuntimeError("mlx-whisper backend requested but MLX Whisper is unavailable or incompatible with the provided model.")
return "mlx-whisper"
if choice == "faster-whisper":
if not self._can_use_faster(compatible_faster_whisper):
raise RuntimeError("faster-whisper backend requested but Faster-Whisper is unavailable or incompatible with the provided model.")
return "faster-whisper"
if choice == "openai-api":
raise ValueError("openai-api backend is only supported with the LocalAgreement policy.")
# auto mode
if platform.system() == "Darwin" and self._can_use_mlx(compatible_whisper_mlx):
return "mlx-whisper"
if self._can_use_faster(compatible_faster_whisper):
return "faster-whisper"
return "whisper"
def _has_custom_model_path(self):
return self._resolved_model_path is not None
def _can_use_mlx(self, compatible_whisper_mlx):
if not HAS_MLX_WHISPER:
return False
if self._has_custom_model_path():
return compatible_whisper_mlx
return self.model_name in mlx_model_mapping
def _can_use_faster(self, compatible_faster_whisper):
if not HAS_FASTER_WHISPER:
return False
if self._has_custom_model_path():
return compatible_faster_whisper
return True
def load_model(self):
model_ref = str(self._resolved_model_path) if self._resolved_model_path else self.model_name
lora_path = getattr(self, 'lora_path', None)
whisper_model = load_model(
name=model_ref,
download_root=None,
decoder_only=self.fast_encoder,
custom_alignment_heads=self.custom_alignment_heads,
lora_path=lora_path,
)
warmup_audio = load_file(self.warmup_file)
if warmup_audio is not None:
warmup_audio = torch.from_numpy(warmup_audio).float()
if self.fast_encoder:
temp_model = AlignAtt(
cfg=self.cfg,
loaded_model=whisper_model,
mlx_encoder=self.mlx_encoder,
fw_encoder=self.fw_encoder,
)
temp_model.warmup(warmup_audio)
else:
whisper_model.transcribe(warmup_audio, language=self.lan if self.lan != 'auto' else None)
return whisper_model
def set_translate_task(self):
"""Set up translation task."""
if self.cfg.language == 'auto':
raise Exception('Translation cannot be done with language = auto')
return tokenizer.get_tokenizer(
multilingual=True,
language=self.cfg.language,
num_languages=99,
task="translate"
)
def transcribe(self, audio):
"""
Warmup is done directly in load_model
"""
pass

View File

@@ -1,32 +0,0 @@
from torch import Tensor
from whisperlivekit.whisper.decoding import PyTorchInference
class BeamPyTorchInference(PyTorchInference):
"""Extension of PyTorchInference for beam search with cross-attention support."""
def _kv_cache_ids(self):
"""Get cache_id strings for self-attention key/value modules."""
key_ids = [block.attn.key_cache_id for block in self.model.decoder.blocks]
value_ids = [block.attn.value_cache_id for block in self.model.decoder.blocks]
return key_ids + value_ids
def rearrange_kv_cache(self, source_indices):
if source_indices != list(range(len(source_indices))):
for cache_id in self._kv_cache_ids():
if cache_id in self.kv_cache:
self.kv_cache[cache_id] = self.kv_cache[cache_id][source_indices].detach()
def logits(
self,
tokens: Tensor,
audio_features: Tensor,
return_cross_attn: bool = False,
):
"""Get logits, optionally returning cross-attention weights."""
return self.model.decoder(
tokens, audio_features,
kv_cache=self.kv_cache,
return_cross_attn=return_cross_attn,
)

View File

@@ -1,24 +0,0 @@
from dataclasses import dataclass, field
from typing import Literal
@dataclass
class AlignAttConfig():
eval_data_path: str = "tmp"
segment_length: float = field(default=1.0, metadata = {"help": "in second"})
frame_threshold: int = 4
rewind_threshold: int = 200
audio_max_len: float = 20.0
cif_ckpt_path: str = ""
never_fire: bool = False
language: str = field(default="zh")
nonspeech_prob: float = 0.5
audio_min_len: float = 1.0
decoder_type: Literal["greedy","beam"] = "greedy"
beam_size: int = 5
task: Literal["transcribe","translate"] = "transcribe"
tokenizer_is_multilingual: bool = False
init_prompt: str = field(default=None)
static_init_prompt: str = field(default=None)
max_context_tokens: int = field(default=None)

View File

@@ -1,95 +0,0 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import torch
@dataclass
class DecoderState:
kv_cache: Dict[str, torch.Tensor] = field(default_factory=dict)
tokenizer: Any = None
detected_language: Optional[str] = None
reset_tokenizer_to_auto_next_call: bool = False
tokens: List[torch.Tensor] = field(default_factory=list)
initial_tokens: Optional[torch.Tensor] = None
initial_token_length: int = 0
sot_index: int = 0
align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
num_align_heads: int = 0
segments: List[torch.Tensor] = field(default_factory=list)
context: Any = None
pending_incomplete_tokens: List[int] = field(default_factory=list)
global_time_offset: float = 0.0
cumulative_time_offset: float = 0.0
first_timestamp: Optional[float] = None
last_attend_frame: int = 0
speaker: int = -1
log_segments: int = 0
CIFLinear: Optional[torch.nn.Module] = None
always_fire: bool = False
never_fire: bool = False
suppress_tokens_fn: Any = None
token_decoder: Any = None
decoder_type: str = "greedy"
inference: Any = None
def clean_cache(self):
"""Clean the kv_cache after each inference step."""
# Explicitly delete tensor references to free GPU memory
if self.kv_cache:
for key in list(self.kv_cache.keys()):
tensor = self.kv_cache.pop(key, None)
if tensor is not None:
del tensor
# Clear the dict
self.kv_cache.clear()
# Force GPU cache cleanup (only if CUDA is available)
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
if self.decoder_type == "beam" and self.inference is not None:
# Create NEW dict instead of sharing reference
self.inference.kv_cache = {}
if self.token_decoder is not None:
self.token_decoder.reset()
def reset(self, rewind_threshold: int = 200):
"""
Reset transient state for a new segment.
Args:
rewind_threshold: Value for resetting last_attend_frame
"""
self.last_attend_frame = -rewind_threshold
self.cumulative_time_offset = 0.0
self.pending_incomplete_tokens = []
self.log_segments += 1
def full_reset(self, rewind_threshold: int = 200):
"""
Full reset including audio segments and tokens.
Args:
rewind_threshold: Value for resetting last_attend_frame
"""
self.reset(rewind_threshold)
self.segments = []
self.tokens = []
self.kv_cache = {}
self.first_timestamp = None

View File

@@ -1,65 +0,0 @@
import torch
# code for the end-of-word detection based on the CIF model proposed in Simul-Whisper
def load_cif(cfg, n_audio_state, device):
"""cfg: AlignAttConfig, n_audio_state: int, device: torch.device"""
cif_linear = torch.nn.Linear(n_audio_state, 1)
if cfg.cif_ckpt_path is None or not cfg.cif_ckpt_path:
if cfg.never_fire:
never_fire = True
always_fire = False
else:
always_fire = True
never_fire = False
else:
always_fire = False
never_fire = cfg.never_fire
checkpoint = torch.load(cfg.cif_ckpt_path)
cif_linear.load_state_dict(checkpoint)
cif_linear.to(device)
return cif_linear, always_fire, never_fire
# from https://github.com/dqqcasia/mosst/blob/master/fairseq/models/speech_to_text/convtransformer_wav2vec_cif.py
def resize(alphas, target_lengths, threshold=0.999):
"""
alpha in thresh=1.0 | (0.0, +0.21)
target_lengths: if None, apply round and resize, else apply scaling
"""
# sum
_num = alphas.sum(-1)
num = target_lengths.float()
# scaling
_alphas = alphas * (num / _num)[:, None].repeat(1, alphas.size(1))
# rm attention value that exceeds threashold
count = 0
while len(torch.where(_alphas > threshold)[0]):
count += 1
if count > 10:
break
xs, ys = torch.where(_alphas > threshold)
for x, y in zip(xs, ys):
if _alphas[x][y] >= threshold:
mask = _alphas[x].ne(0).float()
mean = 0.5 * _alphas[x].sum() / mask.sum()
_alphas[x] = _alphas[x] * 0.5 + mean * mask
return _alphas, _num
def fire_at_boundary(chunked_encoder_feature: torch.Tensor, cif_linear):
content_mel_len = chunked_encoder_feature.shape[1] # B, T, D
alphas = cif_linear(chunked_encoder_feature).squeeze(dim=2) # B, T
alphas = torch.sigmoid(alphas)
decode_length = torch.round(alphas.sum(-1)).int()
alphas, _ = resize(alphas, decode_length)
alphas = alphas.squeeze(0) # (T, )
threshold = 0.999
integrate = torch.cumsum(alphas[:-1], dim=0) # ignore the peak value at the end of the content chunk
exceed_count = integrate[-1] // threshold
integrate = integrate - exceed_count*1.0 # minus 1 every time intergrate exceed the threshold
important_positions = (integrate >= 0).nonzero(as_tuple=True)[0]
if important_positions.numel() == 0:
return False
else:
return important_positions[0] >= content_mel_len-2

View File

@@ -1,11 +0,0 @@
from .decoder_state import MLXDecoderState
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
from .simul_whisper import MLXAlignAtt
__all__ = [
"MLXAlignAtt",
"MLXBeamSearchDecoder",
"MLXDecoderState",
"MLXGreedyDecoder",
"MLXInference",
]

View File

@@ -1,76 +0,0 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx
import numpy as np
@dataclass
class MLXDecoderState:
"""
mlx kv cache format: List of ((k, v), (cross_k, cross_v)) tuples per layer,
where each element is a tuple of mx.arrays.
"""
kv_cache: Optional[List[Tuple[Tuple[mx.array, mx.array], Tuple[mx.array, mx.array]]]] = None
tokenizer: Any = None
detected_language: Optional[str] = None
reset_tokenizer_to_auto_next_call: bool = False
tokens: List[mx.array] = field(default_factory=list)
initial_tokens: Optional[mx.array] = None
initial_token_length: int = 0
sot_index: int = 0
align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
num_align_heads: int = 0
segments: List[np.ndarray] = field(default_factory=list)
context: Any = None
pending_incomplete_tokens: List[int] = field(default_factory=list)
global_time_offset: float = 0.0
cumulative_time_offset: float = 0.0
first_timestamp: Optional[float] = None
last_attend_frame: int = 0
speaker: int = -1
log_segments: int = 0
cif_weights: Optional[mx.array] = None
always_fire: bool = False
never_fire: bool = False
suppress_tokens: Optional[Tuple[int, ...]] = None
token_decoder: Any = None
decoder_type: str = "greedy"
inference: Any = None
def clean_cache(self):
self.kv_cache = None
if self.decoder_type == "beam" and self.inference is not None:
self.inference.kv_cache = None
if self.token_decoder is not None:
self.token_decoder.reset()
def reset(self, rewind_threshold: int = 200):
self.last_attend_frame = -rewind_threshold
self.cumulative_time_offset = 0.0
self.pending_incomplete_tokens = []
self.log_segments += 1
def full_reset(self, rewind_threshold: int = 200):
"""
Full reset including audio segments and tokens.
Args:
rewind_threshold: Value for resetting last_attend_frame
"""
self.reset(rewind_threshold)
self.segments = []
self.tokens = []
self.kv_cache = None
self.first_timestamp = None

View File

@@ -1,219 +0,0 @@
"""
MLX-native token decoders for streaming ASR.
"""
from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx
import numpy as np
class MLXGreedyDecoder:
"""Greedy decoder using MLX operations."""
def __init__(self, temperature: float, eot: int):
self.temperature = temperature
self.eot = eot
def update(
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
) -> Tuple[mx.array, bool]:
"""
Update tokens with next predicted token.
Args:
tokens: Current token sequence, shape (batch, seq_len)
logits: Logits for next token, shape (batch, vocab_size)
sum_logprobs: Cumulative log probabilities, shape (batch,)
Returns:
Updated tokens and completion flag
"""
if self.temperature == 0:
next_tokens = mx.argmax(logits, axis=-1)
else:
probs = mx.softmax(logits / self.temperature, axis=-1)
next_tokens = mx.random.categorical(mx.log(probs + 1e-10))
logprobs = mx.softmax(logits, axis=-1)
logprobs = mx.log(logprobs + 1e-10)
batch_size = logprobs.shape[0]
current_logprobs = logprobs[mx.arange(batch_size), next_tokens]
mask = (tokens[:, -1] != self.eot).astype(mx.float32)
sum_logprobs = sum_logprobs + current_logprobs * mask
eot_mask = (tokens[:, -1] == self.eot)
next_tokens = mx.where(eot_mask, mx.array(self.eot), next_tokens)
tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=1)
completed = bool(mx.all(tokens[:, -1] == self.eot))
return tokens, completed
def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
"""Finalize decoding by ensuring EOT at end."""
eot_column = mx.full((tokens.shape[0], 1), self.eot, dtype=tokens.dtype)
tokens = mx.concatenate([tokens, eot_column], axis=1)
return tokens, sum_logprobs.tolist()
class MLXBeamSearchDecoder:
"""Beam search decoder using MLX operations."""
def __init__(
self,
beam_size: int,
eot: int,
inference: Any,
patience: Optional[float] = None,
):
self.beam_size = beam_size
self.eot = eot
self.inference = inference
self.patience = patience or 1.0
self.max_candidates: int = round(beam_size * self.patience)
self.finished_sequences: Optional[List[Dict]] = None
assert (
self.max_candidates > 0
), f"Invalid beam size ({beam_size}) or patience ({patience})"
def reset(self):
"""Reset finished sequences for new segment."""
self.finished_sequences = None
def update(
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
) -> Tuple[mx.array, bool]:
"""
Update tokens using beam search.
Args:
tokens: Current token sequences, shape (batch * beam_size, seq_len)
logits: Logits for next token, shape (batch * beam_size, vocab_size)
sum_logprobs: Cumulative log probabilities, shape (batch * beam_size,)
Returns:
Updated tokens and completion flag
"""
if tokens.shape[0] % self.beam_size != 0:
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
n_audio = tokens.shape[0] // self.beam_size
if self.finished_sequences is None:
self.finished_sequences = [{} for _ in range(n_audio)]
logprobs = mx.softmax(logits, axis=-1)
logprobs = mx.log(logprobs + 1e-10)
logprobs_np = np.array(logprobs)
tokens_np = np.array(tokens)
sum_logprobs_np = np.array(sum_logprobs)
next_tokens, source_indices, finished_sequences = [], [], []
new_sum_logprobs = []
for i in range(n_audio):
scores, sources, finished = {}, {}, {}
for j in range(self.beam_size):
idx = i * self.beam_size + j
prefix = tokens_np[idx].tolist()
top_k_indices = np.argsort(logprobs_np[idx])[-self.beam_size - 1:][::-1]
for token_idx in top_k_indices:
logprob = logprobs_np[idx, token_idx]
new_logprob = sum_logprobs_np[idx] + logprob
sequence = tuple(prefix + [int(token_idx)])
scores[sequence] = new_logprob
sources[sequence] = idx
saved = 0
for sequence in sorted(scores, key=scores.get, reverse=True):
if sequence[-1] == self.eot:
finished[sequence] = scores[sequence]
else:
new_sum_logprobs.append(scores[sequence])
next_tokens.append(sequence)
source_indices.append(sources[sequence])
saved += 1
if saved == self.beam_size:
break
finished_sequences.append(finished)
tokens = mx.array(np.array(next_tokens, dtype=np.int32))
sum_logprobs = mx.array(np.array(new_sum_logprobs, dtype=np.float32))
self.inference.rearrange_kv_cache(source_indices)
assert len(self.finished_sequences) == len(finished_sequences)
for previously_finished, newly_finished in zip(
self.finished_sequences, finished_sequences
):
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
if len(previously_finished) >= self.max_candidates:
break
previously_finished[seq] = newly_finished[seq]
completed = all(
len(sequences) >= self.max_candidates
for sequences in self.finished_sequences
)
return tokens, completed
def finalize(self, preceding_tokens: mx.array, sum_logprobs: mx.array):
"""Finalize beam search by selecting best sequences."""
preceding_tokens_np = np.array(preceding_tokens)
sum_logprobs_np = np.array(sum_logprobs)
n_audio = preceding_tokens_np.shape[0] // self.beam_size
tokens_list: List[List[int]] = [[] for _ in range(n_audio)]
sum_logprobs_list: List[float] = [0.0] * n_audio
for i, sequences in enumerate(self.finished_sequences):
if sequences:
best_seq = max(sequences, key=sequences.get)
tokens_list[i] = list(best_seq)
sum_logprobs_list[i] = sequences[best_seq]
else:
idx = i * self.beam_size
tokens_list[i] = preceding_tokens_np[idx].tolist() + [self.eot]
sum_logprobs_list[i] = float(sum_logprobs_np[idx])
max_len = max(len(t) for t in tokens_list)
for i, t in enumerate(tokens_list):
tokens_list[i] = t + [self.eot] * (max_len - len(t))
tokens = mx.array(np.array(tokens_list, dtype=np.int32))
return tokens, sum_logprobs_list
class MLXInference:
"""MLX inference wrapper for beam search KV cache management."""
def __init__(self, model, initial_token_length: int):
self.model = model
self.initial_token_length = initial_token_length
self.kv_cache = None
def rearrange_kv_cache(self, source_indices: List[int]):
"""Rearrange KV cache based on beam search source indices."""
if self.kv_cache is None:
return
if source_indices == list(range(len(source_indices))):
return
source_indices_mx = mx.array(source_indices, dtype=mx.int32)
new_cache = []
for layer_cache in self.kv_cache:
(k, v), (cross_k, cross_v) = layer_cache
new_k = k[source_indices_mx]
new_v = v[source_indices_mx]
new_cache.append(((new_k, new_v), (cross_k, cross_v)))
self.kv_cache = new_cache
def logits(
self,
tokens: mx.array,
audio_features: mx.array,
) -> Tuple[mx.array, List]:
"""Get logits from decoder with KV cache."""
logits, self.kv_cache, cross_qk = self.model.decoder(
tokens, audio_features, kv_cache=self.kv_cache
)
return logits, cross_qk

View File

@@ -1,752 +0,0 @@
"""
MLX whisper AlignAtt streaming decoder
"""
import logging
from time import time
from typing import Any, List, Optional, Tuple
import mlx.core as mx
import numpy as np
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper import DecodingOptions, tokenizer
from whisperlivekit.whisper.audio import N_FRAMES, N_SAMPLES, TOKENS_PER_SECOND
from ..config import AlignAttConfig
from .decoder_state import MLXDecoderState
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
DEC_PAD = 50257
logger = logging.getLogger(__name__)
class MLXTokenBuffer: #should try to make it heritate from classic simul whisper class
"""Token buffer for MLX-based decoding."""
def __init__(self, text="", tokenizer=None, prefix_token_ids=None):
self.text = text
self.prefix_token_ids = prefix_token_ids or []
self.tokenizer = tokenizer
self.pending_token_ids = []
def as_token_ids(self, tokenizer=None):
if tokenizer is None:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer is not set.")
return self.prefix_token_ids + tokenizer.encode(self.text)
def as_mlx_array(self) -> mx.array:
"""Return tokens as MLX array."""
tok_ids = self.as_token_ids()
return mx.array([tok_ids], dtype=mx.int32)
def as_mlx_array_beam(self, beam: int) -> mx.array:
"""Return tokens as MLX array repeated for beam search."""
t = self.as_mlx_array()
return mx.repeat(t, beam, axis=0)
def as_text(self):
return self.text
@staticmethod
def empty(*a, **kw):
return MLXTokenBuffer(*a, **kw)
@staticmethod
def from_text(text, *a, **kw):
return MLXTokenBuffer(*a, text=text, **kw)
def is_empty(self):
return self.text is None or self.text == ""
def trim_words(self, num=1, after=0):
"""Trim words from the beginning of the context."""
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
ids = tokenizer.encode(self.text[after:])
words, wids = self.tokenizer.split_to_word_tokens(ids)
if not words:
return 0
self.text = self.text[:after] + "".join(words[num:])
return sum(len(wi) for wi in wids[:num])
def append_token_ids(self, token_ids):
"""Append token IDs to the buffer, handling incomplete UTF-8."""
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
all_tokens = self.pending_token_ids + token_ids
decoded = tokenizer.decode(all_tokens)
replacement_char = "\ufffd"
if replacement_char in decoded:
if len(all_tokens) > 1:
decoded_partial = tokenizer.decode(all_tokens[:-1])
if replacement_char not in decoded_partial:
self.text += decoded_partial
self.pending_token_ids = [all_tokens[-1]]
else:
self.pending_token_ids = all_tokens
else:
self.pending_token_ids = all_tokens
else:
self.text += decoded
self.pending_token_ids = []
def mlx_median_filter(x: mx.array, filter_width: int) -> mx.array:
"""
Apply median filter along the last axis.
Args:
x: Input array of shape (..., T)
filter_width: Width of the median filter (should be odd)
Returns:
Filtered array of same shape
"""
if filter_width <= 1:
return x
pad_width = filter_width // 2
shape = x.shape
left_pad = mx.repeat(x[..., :1], pad_width, axis=-1)
right_pad = mx.repeat(x[..., -1:], pad_width, axis=-1)
x_padded = mx.concatenate([left_pad, x, right_pad], axis=-1)
result_shape = list(shape)
result = []
for i in range(shape[-1]):
window = x_padded[..., i:i + filter_width]
sorted_window = mx.sort(window, axis=-1)
median_val = sorted_window[..., filter_width // 2:filter_width // 2 + 1]
result.append(median_val)
return mx.concatenate(result, axis=-1)
class MLXAlignAtt:
"""
MLX-native Alignment-based Attention decoder for SimulStreaming.
This class runs entirely on MLX, with no PyTorch dependencies for inference.
"""
@property
def speaker(self):
return self.state.speaker
@speaker.setter
def speaker(self, value):
self.state.speaker = value
@property
def global_time_offset(self):
return self.state.global_time_offset
@global_time_offset.setter
def global_time_offset(self, value):
self.state.global_time_offset = value
def __init__(
self,
cfg: AlignAttConfig,
mlx_model: Any,
) -> None:
"""
Initialize MLX AlignAtt decoder.
Args:
cfg: AlignAtt configuration
mlx_model: MLX Whisper model (full model, not just encoder)
"""
self.model = mlx_model
self.cfg = cfg
logger.info(f"MLX Model dimensions: {self.model.dims}")
self.decode_options = DecodingOptions(
language=cfg.language,
without_timestamps=True,
task=cfg.task
)
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks)
if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
else:
self.max_context_tokens = self.cfg.max_context_tokens
# Initialize per-session state
self.state = MLXDecoderState()
self._init_state(cfg)
def _init_state(self, cfg: AlignAttConfig):
"""Initialize the per-session decoder state."""
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
self.state.tokenizer = self.tokenizer
self.state.detected_language = cfg.language if cfg.language != "auto" else None
self.state.global_time_offset = 0.0
self.state.last_attend_frame = -cfg.rewind_threshold
self.state.speaker = -1
if cfg.cif_ckpt_path is None or not cfg.cif_ckpt_path:
if cfg.never_fire:
self.state.never_fire = True
self.state.always_fire = False
else:
self.state.always_fire = True
self.state.never_fire = False
else:
logger.warning("CIF checkpoint provided but MLX CIF not implemented. Using always_fire=True")
self.state.always_fire = True
self.state.never_fire = cfg.never_fire
self._build_alignment_source()
suppress_tokens = [
self.tokenizer.transcribe,
self.tokenizer.translate,
self.tokenizer.sot,
self.tokenizer.sot_prev,
self.tokenizer.sot_lm,
self.tokenizer.no_timestamps,
] + list(self.tokenizer.all_language_tokens)
if self.tokenizer.no_speech is not None:
suppress_tokens.append(self.tokenizer.no_speech)
self.state.suppress_tokens = tuple(sorted(set(suppress_tokens)))
logger.debug(f"Suppress tokens: {self.state.suppress_tokens}")
self.init_tokens()
self.init_context()
self.state.decoder_type = cfg.decoder_type
if cfg.decoder_type == "greedy":
logger.info("Using MLX greedy decoder")
self.state.token_decoder = MLXGreedyDecoder(0.0, self.tokenizer.eot)
elif cfg.decoder_type == "beam":
logger.info("Using MLX beam decoder")
self.state.inference = MLXInference(self.model, self.state.initial_token_length)
self.state.token_decoder = MLXBeamSearchDecoder(
inference=self.state.inference,
eot=self.tokenizer.eot,
beam_size=cfg.beam_size
)
def _build_alignment_source(self):
"""Build alignment source mapping from model's alignment_heads."""
self.state.align_source = {}
self.state.num_align_heads = 0
alignment_heads = self.model.alignment_heads
if alignment_heads is None:
logger.warning("No alignment heads found in model")
return
if hasattr(alignment_heads, 'tolist'):
heads_list = alignment_heads.tolist()
else:
heads_list = np.array(alignment_heads).tolist()
for layer_rank, head_id in heads_list:
layer_rank = int(layer_rank)
head_id = int(head_id)
heads = self.state.align_source.get(layer_rank, [])
heads.append((self.state.num_align_heads, head_id))
self.state.align_source[layer_rank] = heads
self.state.num_align_heads += 1
def warmup(self, audio: np.ndarray):
"""Warmup the model with sample audio."""
try:
self.insert_audio(audio)
self.infer(is_last=True)
self.refresh_segment(complete=True)
logger.info("MLX model warmed up successfully")
except Exception as e:
logger.exception(f"MLX model warmup failed: {e}")
def create_tokenizer(self, language=None):
"""Create tokenizer for the given language."""
self.tokenizer = tokenizer.get_tokenizer(
multilingual=self.tokenizer_is_multilingual,
language=language,
num_languages=self.model.num_languages,
task=self.decode_options.task
)
self.state.tokenizer = self.tokenizer
def init_context(self):
"""Initialize context buffer."""
kw = {
'tokenizer': self.tokenizer,
'prefix_token_ids': [self.tokenizer.sot_prev]
}
self.state.context = MLXTokenBuffer.empty(**kw)
if self.cfg.static_init_prompt is not None:
self.state.context = MLXTokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
if self.cfg.init_prompt is not None:
self.state.context.text += self.cfg.init_prompt
def init_tokens(self):
"""Initialize token sequence."""
logger.debug(f"init tokens, {len(self.state.segments)}")
self.state.initial_tokens = mx.array(
[self.tokenizer.sot_sequence_including_notimestamps],
dtype=mx.int32
)
self.state.initial_token_length = self.state.initial_tokens.shape[1]
self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
logger.debug(f"init tokens after, {len(self.state.segments)}")
self.state.tokens = [self.state.initial_tokens]
def trim_context(self):
"""Trim context if too long."""
logger.info("Trimming context")
c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids)
logger.info(f"Context text: {self.state.context.as_text()}")
l = sum(t.shape[1] for t in self.state.tokens) + c
if self.cfg.static_init_prompt is None:
after = 0
else:
after = len(self.cfg.static_init_prompt)
while c > self.max_context_tokens or l > self.max_text_len - 20:
t = self.state.context.trim_words(after=after)
l -= t
c -= t
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
if t == 0:
break
logger.info(f"Context after trim: {self.state.context.text} (len: {l})")
def refresh_segment(self, complete=False):
"""Refresh segment state."""
logger.debug("Refreshing segment:")
self.init_tokens()
self.state.last_attend_frame = -self.cfg.rewind_threshold
self.state.cumulative_time_offset = 0.0
self.init_context()
logger.debug(f"Context: {self.state.context}")
if not complete and len(self.state.segments) > 2:
self.state.segments = self.state.segments[-2:]
else:
logger.debug("removing all segments.")
self.state.segments = []
self.state.log_segments += 1
self.state.pending_incomplete_tokens = []
def fire_at_boundary(self, chunked_encoder_feature: mx.array) -> bool:
"""Check if we should fire at word boundary (CIF-based)."""
if self.state.always_fire:
return True
if self.state.never_fire:
return False
return True
def _current_tokens(self) -> mx.array:
"""Get current token sequence for decoding."""
toks = self.state.tokens
if toks[0].shape[0] == 1:
toks[0] = mx.repeat(toks[0], self.cfg.beam_size, axis=0)
if not self.state.context.is_empty():
context_toks = self.state.context.as_mlx_array_beam(self.cfg.beam_size)
toks = [context_toks] + toks
# Concatenate all tokens
if len(toks) > 1:
current_tokens = mx.concatenate(toks, axis=1)
else:
current_tokens = toks[0]
logger.debug("debug print current_tokens:")
self.debug_print_tokens(current_tokens)
return current_tokens
def debug_print_tokens(self, tokens: mx.array):
"""Debug print token sequences."""
tokens_np = np.array(tokens)
for i in range(min(self.cfg.beam_size, tokens_np.shape[0])):
logger.debug(self.tokenizer.decode_with_timestamps(tokens_np[i].tolist()))
def segments_len(self) -> float:
"""Get total length of audio segments in seconds."""
return sum(s.shape[0] for s in self.state.segments) / 16000
def _apply_minseglen(self) -> bool:
"""Check if we have enough audio to process."""
segments_len = self.segments_len()
if segments_len < self.cfg.audio_min_len:
logger.debug("waiting for next segment")
return False
return True
def insert_audio(self, segment: np.ndarray = None):
"""Insert audio segment into buffer."""
if segment is not None:
if hasattr(segment, 'numpy'):
segment = segment.numpy()
self.state.segments.append(segment)
removed_len = 0
segments_len = self.segments_len()
while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
removed_len = self.state.segments[0].shape[0] / 16000
segments_len -= removed_len
self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
self.state.cumulative_time_offset += removed_len
self.state.segments = self.state.segments[1:]
logger.debug(f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, cumulative offset: {self.state.cumulative_time_offset:.2f}s")
if len(self.state.tokens) > 1:
# Convert MLX array to list for context
token_list = np.array(self.state.tokens[1][0, :]).tolist()
self.state.context.append_token_ids(token_list)
self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]
return removed_len
def _clean_cache(self):
"""Clean the kv_cache after each inference step."""
self.state.clean_cache()
def _suppress_tokens(self, logits: mx.array) -> mx.array:
"""Apply token suppression to logits."""
if self.state.suppress_tokens:
suppress_indices = mx.array(list(self.state.suppress_tokens), dtype=mx.int32)
logits = logits.at[:, suppress_indices].add(-float('inf'))
return logits
def lang_id(self, encoder_features: mx.array) -> Tuple[mx.array, List[dict]]:
"""Language detection from encoder features."""
n_audio = encoder_features.shape[0]
x = mx.array([[self.tokenizer.sot]] * n_audio, dtype=mx.int32)
logits, _, _ = self.model.decoder(x, encoder_features, kv_cache=None)
logits = logits[:, 0]
mask = mx.ones(logits.shape[-1], dtype=mx.bool_)
language_token_indices = mx.array(list(self.tokenizer.all_language_tokens), dtype=mx.int32)
mask = mask.at[language_token_indices].add(False)
logits = mx.where(mask, mx.array(-float('inf')), logits)
language_tokens = mx.argmax(logits, axis=-1)
language_token_probs = mx.softmax(logits, axis=-1)
probs_np = np.array(language_token_probs)
language_probs = [
{
c: float(probs_np[i, j])
for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
}
for i in range(n_audio)
]
self._clean_cache()
return language_tokens, language_probs
def infer(self, is_last: bool = False) -> List[ASRToken]:
"""
Main inference method.
Args:
is_last: Whether this is the final chunk
Returns:
List of timestamped ASR tokens
"""
new_segment = True
if len(self.state.segments) == 0:
logger.debug("No segments, nothing to do")
return []
if not self._apply_minseglen():
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
return []
if len(self.state.segments) > 1:
input_segments = np.concatenate(self.state.segments, axis=0)
else:
input_segments = self.state.segments[0]
beg_encode = time()
mlx_mel_padded = mlx_log_mel_spectrogram(
audio=input_segments,
n_mels=self.model.dims.n_mels,
padding=N_SAMPLES
)
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
encoder_feature = self.model.encoder(mlx_mel[None])
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0]) / 2)
mx.eval(encoder_feature)
end_encode = time()
logger.debug(f'MLX Encoder duration: {end_encode - beg_encode:.3f}s')
if self.cfg.language == "auto" and self.state.detected_language is None and self.state.first_timestamp:
seconds_since_start = self.segments_len() - self.state.first_timestamp
if seconds_since_start >= 2.0:
language_tokens, language_probs = self.lang_id(encoder_feature)
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
print(f"Detected language: {top_lan} with p={p:.4f}")
self.create_tokenizer(top_lan)
self.state.last_attend_frame = -self.cfg.rewind_threshold
self.state.cumulative_time_offset = 0.0
self.init_tokens()
self.init_context()
self.state.detected_language = top_lan
logger.info(f"Tokenizer language: {self.tokenizer.language}")
self.trim_context()
current_tokens = self._current_tokens()
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
sum_logprobs = mx.zeros((self.cfg.beam_size,), dtype=mx.float32)
completed = False
attn_of_alignment_heads = None
most_attended_frame = None
token_len_before_decoding = current_tokens.shape[1]
l_absolute_timestamps = []
accumulated_cross_attns = []
audio_duration_s = self.segments_len()
max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0))
tokens_produced_this_chunk = 0
while not completed and current_tokens.shape[1] < self.max_text_len:
tokens_produced_this_chunk += 1
if tokens_produced_this_chunk > max_tokens_per_chunk:
logger.warning(f"[Loop Detection] Too many tokens ({tokens_produced_this_chunk}) for {audio_duration_s:.2f}s audio. Breaking.")
current_tokens = current_tokens[:, :token_len_before_decoding]
break
if new_segment:
tokens_for_logits = current_tokens
else:
tokens_for_logits = current_tokens[:, -1:]
if self.state.decoder_type == "greedy":
logits, self.state.kv_cache, cross_qk = self.model.decoder(
tokens_for_logits, encoder_feature, kv_cache=self.state.kv_cache
)
else:
logits, cross_qk = self.state.inference.logits(tokens_for_logits, encoder_feature)
mx.eval(logits)
accumulated_cross_attns.append(cross_qk)
if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1)
no_speech_probs = np.array(probs_at_sot[:, self.tokenizer.no_speech]).tolist()
if no_speech_probs[0] > self.cfg.nonspeech_prob:
logger.info("no speech, stop")
break
logits = logits[:, -1, :] # Last token logits
# Suppress tokens at segment start
if new_segment:
blank_tokens = self.tokenizer.encode(" ") + [self.tokenizer.eot]
logits = logits.at[:, blank_tokens].add(-float('inf'))
new_segment = False
logits = self._suppress_tokens(logits)
current_tokens, completed = self.state.token_decoder.update(
current_tokens, logits, sum_logprobs
)
mx.eval(current_tokens)
logger.debug(f"Decoding completed: {completed}")
self.debug_print_tokens(current_tokens)
attn_of_alignment_heads = self._process_cross_attention(
accumulated_cross_attns, content_mel_len
)
most_attended_frames = mx.argmax(attn_of_alignment_heads[:, -1, :], axis=-1)
most_attended_frames_np = np.array(most_attended_frames)
absolute_timestamps = [
(frame * 0.02 + self.state.cumulative_time_offset)
for frame in most_attended_frames_np.tolist()
]
logger.debug(str(most_attended_frames_np.tolist()) + " most att frames")
logger.debug(f"Absolute timestamps: {absolute_timestamps}")
most_attended_frame = int(most_attended_frames_np[0])
l_absolute_timestamps.append(absolute_timestamps[0])
if completed:
current_tokens = current_tokens[:, :-1]
break
if not is_last and self.state.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
current_tokens_np = np.array(current_tokens)
if current_tokens.shape[1] > 1 and current_tokens_np[0, -2] >= DEC_PAD:
logger.debug("omit rewinding from special tokens")
self.state.last_attend_frame = most_attended_frame
else:
logger.debug(f"[rewind detected] current: {most_attended_frame}, last: {self.state.last_attend_frame}")
self.state.last_attend_frame = -self.cfg.rewind_threshold
current_tokens = mx.concatenate(self.state.tokens, axis=1) if len(self.state.tokens) > 0 else self.state.tokens[0]
break
else:
self.state.last_attend_frame = most_attended_frame
if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
current_tokens = current_tokens[:, :-1]
break
tokens_to_split = np.array(current_tokens[0, token_len_before_decoding:]).tolist()
if self.state.pending_incomplete_tokens:
logger.debug(f"[UTF-8 Fix] Prepending pending tokens: {self.state.pending_incomplete_tokens}")
tokens_to_split = self.state.pending_incomplete_tokens + tokens_to_split
if fire_detected or is_last:
new_hypothesis = tokens_to_split
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
else:
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split)
if len(split_words) > 1:
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
else:
new_hypothesis = []
logger.debug(f"new_hypothesis: {new_hypothesis}")
new_tokens = mx.array([new_hypothesis], dtype=mx.int32)
new_tokens = mx.repeat(new_tokens, self.cfg.beam_size, axis=0)
self.state.tokens.append(new_tokens)
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
self._clean_cache()
if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None:
self.state.first_timestamp = l_absolute_timestamps[0]
timestamped_words = []
timestamp_idx = 0
replacement_char = "\ufffd"
for word, word_tokens in zip(split_words, split_tokens):
if replacement_char in word:
logger.warning(f"[UTF-8 Filter] Skipping: {repr(word)}")
timestamp_idx += len(word_tokens)
continue
try:
current_timestamp = l_absolute_timestamps[timestamp_idx]
except IndexError:
pass
timestamp_idx += len(word_tokens)
timestamp_entry = ASRToken(
start=round(current_timestamp, 2),
end=round(current_timestamp + 0.1, 2),
text=word,
speaker=self.state.speaker,
detected_language=self.state.detected_language
).with_offset(self.state.global_time_offset)
timestamped_words.append(timestamp_entry)
self.state.pending_incomplete_tokens = []
MAX_PENDING_TOKENS = 10
if split_words and replacement_char in split_words[-1]:
if len(split_tokens[-1]) <= MAX_PENDING_TOKENS:
self.state.pending_incomplete_tokens = split_tokens[-1]
logger.debug(f"[UTF-8 Fix] Holding incomplete tokens")
else:
logger.warning(f"[UTF-8 Fix] Skipping too many tokens")
return timestamped_words
def _process_cross_attention(
self,
cross_attns: List[List[mx.array]],
content_mel_len: int
) -> mx.array:
"""
Process cross-attention weights for alignment.
Args:
cross_attns: List of cross-attention from each forward pass
Each element is a list of mx.arrays per layer
content_mel_len: Length of actual audio content
Returns:
Processed attention tensor, shape (batch, seq_len, content_mel_len)
"""
attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
num_decoder_layers = self.num_decoder_layers
if cross_attns and isinstance(cross_attns[0], list):
flattened_attns = [attn for layer_list in cross_attns for attn in layer_list]
else:
flattened_attns = cross_attns
for idx, attn_mat in enumerate(flattened_attns):
if attn_mat is None:
continue
layer_rank = idx % num_decoder_layers
align_heads_in_layer = self.state.align_source.get(layer_rank, [])
if len(align_heads_in_layer) == 0:
continue
attn_mat = mx.softmax(attn_mat, axis=-1)
for align_head_rank, head_id in align_heads_in_layer:
if self.cfg.beam_size == 1:
if attn_mat.ndim == 4:
a = attn_mat[0, head_id, :, :]
else:
a = attn_mat[head_id, :, :]
a = a[None, :, :]
else:
a = attn_mat[:, head_id, :, :]
attn_of_alignment_heads[align_head_rank].append(a)
tmp = []
for mat in attn_of_alignment_heads:
if mat:
t = mx.concatenate(mat, axis=1)
tmp.append(t)
if not tmp:
return mx.zeros((self.cfg.beam_size, 1, content_mel_len))
attn_of_alignment_heads = mx.stack(tmp, axis=1)
std = mx.std(attn_of_alignment_heads, axis=-2, keepdims=True)
mean = mx.mean(attn_of_alignment_heads, axis=-2, keepdims=True)
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)
attn_of_alignment_heads = mlx_median_filter(attn_of_alignment_heads, 7)
attn_of_alignment_heads = mx.mean(attn_of_alignment_heads, axis=1)
attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
mx.eval(attn_of_alignment_heads)
return attn_of_alignment_heads

View File

@@ -1,107 +0,0 @@
import json
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from mlx_whisper import whisper
mlx_model_mapping = {
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
"tiny": "mlx-community/whisper-tiny-mlx",
"base.en": "mlx-community/whisper-base.en-mlx",
"base": "mlx-community/whisper-base-mlx",
"small.en": "mlx-community/whisper-small.en-mlx",
"small": "mlx-community/whisper-small-mlx",
"medium.en": "mlx-community/whisper-medium.en-mlx",
"medium": "mlx-community/whisper-medium-mlx",
"large-v1": "mlx-community/whisper-large-v1-mlx",
"large-v2": "mlx-community/whisper-large-v2-mlx",
"large-v3": "mlx-community/whisper-large-v3-mlx",
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
"large": "mlx-community/whisper-large-mlx",
}
def load_mlx_encoder(
path_or_hf_repo: str,
dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
with open(str(model_path / "config.json"), "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
quantization = config.pop("quantization", None)
model_args = whisper.ModelDimensions(**config)
wf = model_path / "weights.safetensors"
if not wf.exists():
wf = model_path / "weights.npz"
weights = mx.load(str(wf))
model = whisper.Whisper(model_args, dtype)
if quantization is not None:
class_predicate = (
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
and f"{p}.scales" in weights
)
nn.quantize(model, **quantization, class_predicate=class_predicate)
weights = tree_unflatten(list(weights.items()))
# we only want to load the encoder weights here.
# Size examples: for tiny.en,
# Decoder weights: 59110771 bytes
# Encoder weights: 15268874 bytes
encoder_weights = {}
encoder_weights['encoder'] = weights['encoder']
del(weights)
model.update(encoder_weights)
mx.eval(model.parameters())
return model
def load_mlx_model(
path_or_hf_repo: str,
dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
with open(str(model_path / "config.json"), "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
quantization = config.pop("quantization", None)
model_args = whisper.ModelDimensions(**config)
wf = model_path / "weights.safetensors"
if not wf.exists():
wf = model_path / "weights.npz"
weights = mx.load(str(wf))
model = whisper.Whisper(model_args, dtype)
if quantization is not None:
class_predicate = (
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
and f"{p}.scales" in weights
)
nn.quantize(model, **quantization, class_predicate=class_predicate)
weights = tree_unflatten(list(weights.items()))
model.update(weights)
mx.eval(model.parameters())
return model

View File

@@ -1,721 +0,0 @@
import logging
import os
from time import time
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from whisperlivekit.backend_support import (faster_backend_available,
mlx_backend_available)
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper import DecodingOptions, tokenizer
from whisperlivekit.whisper.audio import (N_FRAMES, N_SAMPLES,
TOKENS_PER_SECOND,
log_mel_spectrogram, pad_or_trim)
from whisperlivekit.whisper.decoding import (BeamSearchDecoder, GreedyDecoder,
SuppressTokens)
from whisperlivekit.whisper.timing import median_filter
from ..timed_objects import PUNCTUATION_MARKS
from .beam import BeamPyTorchInference
from .config import AlignAttConfig
from .decoder_state import DecoderState
from .eow_detection import fire_at_boundary, load_cif
from .token_buffer import TokenBuffer
DEC_PAD = 50257
logger = logging.getLogger(__name__)
if mlx_backend_available():
from mlx_whisper.audio import \
log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
if faster_backend_available():
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
from faster_whisper.feature_extractor import FeatureExtractor
USE_MLCORE = False
def load_coreml_encoder():
try:
from coremltools.models import MLModel
except ImportError:
logger.warning("coremltools is not installed")
return None
COREML_ENCODER_PATH = os.environ.get("MLCORE_ENCODER_PATH", "whisperlivekit/whisper/whisper_encoder.mlpackage")
_coreml_encoder = MLModel(COREML_ENCODER_PATH)
spec = _coreml_encoder.get_spec()
_coreml_input_name = spec.description.input[0].name if spec.description.input else "mel"
_coreml_output_name = spec.description.output[0].name if spec.description.output else None
return _coreml_encoder, _coreml_input_name, _coreml_output_name
class AlignAtt:
"""
Alignment-based Attention decoder for SimulStreaming.
This class is now hookless - the model can be shared across multiple
sessions, with each session maintaining its own DecoderState.
"""
# Property accessors for backward compatibility
@property
def speaker(self):
return self.state.speaker
@speaker.setter
def speaker(self, value):
self.state.speaker = value
@property
def global_time_offset(self):
return self.state.global_time_offset
@global_time_offset.setter
def global_time_offset(self, value):
self.state.global_time_offset = value
def __init__(
self,
cfg: AlignAttConfig,
loaded_model=None,
mlx_encoder=None,
fw_encoder=None,
) -> None:
# Shared model reference (can be shared across sessions)
self.model = loaded_model
self.mlx_encoder = mlx_encoder
self.fw_encoder = fw_encoder
if fw_encoder:
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
self.coreml_encoder_tuple = None
if USE_MLCORE:
self.coreml_encoder_tuple = load_coreml_encoder()
self.use_mlcore = self.coreml_encoder_tuple is not None
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger.info(f"Model dimensions: {self.model.dims}")
self.decode_options = DecodingOptions(
language=cfg.language,
without_timestamps=True,
task=cfg.task
)
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks)
self.cfg = cfg
if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
else:
self.max_context_tokens = self.cfg.max_context_tokens
# Initialize per-session state
self.state = DecoderState()
self._init_state(cfg)
def _init_state(self, cfg: AlignAttConfig):
"""Initialize the per-session decoder state."""
# Create tokenizer
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
self.state.tokenizer = self.tokenizer
self.state.detected_language = cfg.language if cfg.language != "auto" else None
# Timing state
self.state.global_time_offset = 0.0
self.state.last_attend_frame = -cfg.rewind_threshold
self.state.speaker = -1
# CIF helpers for end-of-word boundary detection
self.state.CIFLinear, self.state.always_fire, self.state.never_fire = load_cif(
cfg,
n_audio_state=self.model.dims.n_audio_state,
device=self.model.device
)
# Build alignment source mapping from model's alignment_heads
self.state.align_source = {}
self.state.num_align_heads = 0
for layer_rank, head_id in self.model.alignment_heads.indices().T:
layer_rank = layer_rank.item()
heads = self.state.align_source.get(layer_rank, [])
heads.append((self.state.num_align_heads, head_id.item()))
self.state.align_source[layer_rank] = heads
self.state.num_align_heads += 1
# Build suppress tokens function
suppress_tokens = [
self.tokenizer.transcribe,
self.tokenizer.translate,
self.tokenizer.sot,
self.tokenizer.sot_prev,
self.tokenizer.sot_lm,
self.tokenizer.no_timestamps,
] + list(self.tokenizer.all_language_tokens)
if self.tokenizer.no_speech is not None:
suppress_tokens.append(self.tokenizer.no_speech)
suppress_tokens = tuple(sorted(set(suppress_tokens)))
logger.debug(f"Suppress tokens: {suppress_tokens}")
sup_tokens = SuppressTokens(suppress_tokens)
self.state.suppress_tokens_fn = lambda logits: sup_tokens.apply(logits, None)
# Initialize tokens
self.init_tokens()
self.init_context()
# Set up decoder type
self.state.decoder_type = cfg.decoder_type
if cfg.decoder_type == "greedy":
logger.info("Using greedy decoder")
self.state.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot)
elif cfg.decoder_type == "beam":
logger.info("Using beam decoder")
self.state.inference = BeamPyTorchInference(self.model, self.state.initial_token_length)
self.state.inference.kv_cache = self.state.kv_cache
self.state.token_decoder = BeamSearchDecoder(
inference=self.state.inference,
eot=self.tokenizer.eot,
beam_size=cfg.beam_size
)
def warmup(self, audio):
try:
self.insert_audio(audio)
self.infer(is_last=True)
self.refresh_segment(complete=True)
logger.info("Model warmed up successfully")
except Exception as e:
logger.exception(f"Model warmup failed: {e}")
def create_tokenizer(self, language=None):
self.tokenizer = tokenizer.get_tokenizer(
multilingual=self.tokenizer_is_multilingual,
language=language,
num_languages=self.model.num_languages,
task=self.decode_options.task
)
self.state.tokenizer = self.tokenizer
def init_context(self):
kw = {'tokenizer': self.tokenizer,
'device': self.model.device,
'prefix_token_ids': [self.tokenizer.sot_prev]}
self.state.context = TokenBuffer.empty(**kw)
if self.cfg.static_init_prompt is not None:
self.state.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
if self.cfg.init_prompt is not None:
self.state.context.text += self.cfg.init_prompt
def init_tokens(self):
logger.debug(f"init tokens, {len(self.state.segments)}")
# init tokens (mandatory prompt)
self.state.initial_tokens = torch.tensor(
self.tokenizer.sot_sequence_including_notimestamps,
dtype=torch.long,
device=self.model.device).unsqueeze(0)
self.state.initial_token_length = self.state.initial_tokens.shape[1]
self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
logger.debug(f"init tokens after, {len(self.state.segments)}")
self.state.tokens = [self.state.initial_tokens]
def trim_context(self):
logger.info("Trimming context")
c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids)
logger.info(f"Context text: {self.state.context.as_text()}")
l = sum(t.shape[1] for t in self.state.tokens) + c
if self.cfg.static_init_prompt is None:
after = 0
else:
after = len(self.cfg.static_init_prompt)
while c > self.max_context_tokens or l > self.max_text_len - 20:
t = self.state.context.trim_words(after=after)
l -= t
c -= t
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
if t == 0:
break
logger.info(f"Context after trim: {self.state.context.text} (len: {l})")
def logits(
self,
tokens: torch.Tensor,
audio_features: torch.Tensor,
return_cross_attn: bool = False
):
"""Get logits from decoder, optionally returning cross-attention weights."""
if self.state.decoder_type == "greedy":
return self.model.decoder(
tokens, audio_features,
kv_cache=self.state.kv_cache,
return_cross_attn=return_cross_attn
)
else:
logger.debug(f"Logits shape: {tokens.shape}")
return self.state.inference.logits(
tokens, audio_features,
return_cross_attn=return_cross_attn
)
def refresh_segment(self, complete=False):
logger.debug("Refreshing segment:")
self.init_tokens()
self.state.last_attend_frame = -self.cfg.rewind_threshold
self.state.cumulative_time_offset = 0.0
self.init_context()
logger.debug(f"Context: {self.state.context}")
if not complete and len(self.state.segments) > 2:
self.state.segments = self.state.segments[-2:]
else:
logger.debug("removing all segments.")
self.state.segments = []
self.state.log_segments += 1
self.state.pending_incomplete_tokens = []
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
if self.state.always_fire:
return True
if self.state.never_fire:
return False
return fire_at_boundary(chunked_encoder_feature, self.state.CIFLinear)
def _current_tokens(self):
toks = self.state.tokens
# very first infer: duplicate start of seq to beam_size
if toks[0].shape[0] == 1:
toks[0] = toks[0].repeat_interleave(self.cfg.beam_size, dim=0)
if not self.state.context.is_empty():
context_toks = self.state.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device)
toks = [context_toks] + toks
# make it one tensor
if len(toks) > 1:
current_tokens = torch.cat(toks, dim=1)
else:
current_tokens = toks[0]
logger.debug("debug print current_tokens:")
self.debug_print_tokens(current_tokens)
return current_tokens
def debug_print_tokens(self, tokens):
for i in range(self.cfg.beam_size):
logger.debug(self.tokenizer.decode_with_timestamps(tokens[i].tolist()))
### audio buffer
def segments_len(self):
segments_len = sum(s.shape[0] for s in self.state.segments) / 16000
return segments_len
def _apply_minseglen(self):
segments_len = self.segments_len()
# wait for long enough audio to start
if segments_len < self.cfg.audio_min_len:
logger.debug("waiting for next segment")
return False
return True
def insert_audio(self, segment=None):
if segment is not None:
self.state.segments.append(segment)
removed_len = 0
# len of audio is bigger than buffer_len. Going to remove the first segment
segments_len = self.segments_len()
while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
removed_len = self.state.segments[0].shape[0] / 16000
segments_len -= removed_len
self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
self.state.cumulative_time_offset += removed_len # Track cumulative time removed
self.state.segments = self.state.segments[1:]
logger.debug(f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, cumulative offset: {self.state.cumulative_time_offset:.2f}s")
if len(self.state.tokens) > 1:
self.state.context.append_token_ids(self.state.tokens[1][0, :].tolist())
self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]
return removed_len
def _clean_cache(self):
"""Clean the kv_cache after each inference step."""
self.state.clean_cache()
@torch.no_grad()
def lang_id(self, encoder_features):
"""Language detection from encoder features.
This code is trimmed and copy-pasted from whisper.decoding.detect_language.
"""
# forward pass using a single token, startoftranscript
n_audio = encoder_features.shape[0]
x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1]
# Note: don't use kv_cache for language detection
logits = self.model.logits(x, encoder_features)[:, 0]
# collect detected languages; suppress all non-language tokens
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
mask[list(self.tokenizer.all_language_tokens)] = False
logits[:, mask] = -np.inf
language_tokens = logits.argmax(dim=-1)
language_token_probs = logits.softmax(dim=-1).cpu()
language_probs = [
{
c: language_token_probs[i, j].item()
for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
}
for i in range(n_audio)
]
single = encoder_features.ndim == 2
if single:
language_tokens = language_tokens[0]
language_probs = language_probs[0]
self._clean_cache()
return language_tokens, language_probs
### transcription / translation
@torch.no_grad()
def infer(self, is_last=False):
new_segment = True
if len(self.state.segments) == 0:
logger.debug("No segments, nothing to do")
return []
if not self._apply_minseglen():
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
input_segments = torch.cat(self.state.segments, dim=0)
return []
# input_segments is concatenation of audio, it's one array
if len(self.state.segments) > 1:
input_segments = torch.cat(self.state.segments, dim=0)
else:
input_segments = self.state.segments[0]
beg_encode = time()
if self.use_mlcore:
coreml_encoder, coreml_input_name, coreml_output_name = self.coreml_encoder_tuple
mel_padded = log_mel_spectrogram(
input_segments,
n_mels=self.model.dims.n_mels,
padding=N_SAMPLES,
device="cpu",
).unsqueeze(0)
mel = pad_or_trim(mel_padded, N_FRAMES)
content_mel_len = int((mel_padded.shape[2] - mel.shape[2]) / 2)
mel_np = np.ascontiguousarray(mel.numpy())
ml_inputs = {coreml_input_name or "mel": mel_np}
coreml_outputs = coreml_encoder.predict(ml_inputs)
if coreml_output_name and coreml_output_name in coreml_outputs:
encoder_feature_np = coreml_outputs[coreml_output_name]
else:
encoder_feature_np = next(iter(coreml_outputs.values()))
encoder_feature = torch.as_tensor(
np.array(encoder_feature_np),
device=self.device,
)
if self.mlx_encoder:
mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
encoder_feature = torch.as_tensor(mlx_encoder_feature)
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
elif self.fw_encoder:
audio_length_seconds = len(input_segments) / 16000
content_mel_len = int(audio_length_seconds * 100)//2
mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
encoder_feature_ctranslate = self.fw_encoder.encode(mel)
if self.device == 'cpu': #it seems that on gpu, passing StorageView to torch.as_tensor fails and wrapping in the array works
encoder_feature_ctranslate = np.array(encoder_feature_ctranslate)
try:
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
except TypeError: # Normally the cpu condition should prevent having exceptions, but just in case:
encoder_feature = torch.as_tensor(np.array(encoder_feature_ctranslate), device=self.device)
else:
# mel + padding to 30s
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
device=self.device).unsqueeze(0)
# trim to 3000
mel = pad_or_trim(mel_padded, N_FRAMES)
# the len of actual audio
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
encoder_feature = self.model.encoder(mel)
end_encode = time()
# print('Encoder duration:', end_encode-beg_encode)
if self.cfg.language == "auto" and self.state.detected_language is None and self.state.first_timestamp:
seconds_since_start = self.segments_len() - self.state.first_timestamp
if seconds_since_start >= 2.0:
language_tokens, language_probs = self.lang_id(encoder_feature)
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
print(f"Detected language: {top_lan} with p={p:.4f}")
self.create_tokenizer(top_lan)
self.state.last_attend_frame = -self.cfg.rewind_threshold
self.state.cumulative_time_offset = 0.0
self.init_tokens()
self.init_context()
self.state.detected_language = top_lan
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
self.trim_context()
current_tokens = self._current_tokens()
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device)
completed = False
# punctuation_stop = False
attn_of_alignment_heads = None
most_attended_frame = None
token_len_before_decoding = current_tokens.shape[1]
l_absolute_timestamps = []
accumulated_cross_attns = []
audio_duration_s = self.segments_len()
max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0)) # 2x margin, min 50
tokens_produced_this_chunk = 0
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
tokens_produced_this_chunk += 1
if tokens_produced_this_chunk > max_tokens_per_chunk:
logger.warning(f"[Loop Detection] Too many tokens ({tokens_produced_this_chunk}) for {audio_duration_s:.2f}s audio. Breaking.")
current_tokens = current_tokens[:, :token_len_before_decoding] # Discard all new tokens
break
if new_segment:
tokens_for_logits = current_tokens
else:
# only need to use the last token except in the first forward pass
tokens_for_logits = current_tokens[:, -1:]
# Get logits and cross-attention weights from decoder
result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True)
logits, cross_attns = result
# Accumulate cross-attention from this forward pass
accumulated_cross_attns.append(cross_attns)
if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
if no_speech_probs[0] > self.cfg.nonspeech_prob:
logger.info("no speech, stop")
break
logits = logits[:, -1, :] # logits for the last token
# suppress blank tokens only at the beginning of the segment
if new_segment:
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
new_segment = False
self.state.suppress_tokens_fn(logits)
current_tokens, completed = self.state.token_decoder.update(current_tokens, logits, sum_logprobs)
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
self.debug_print_tokens(current_tokens)
# Process accumulated cross-attention weights for alignment
attn_of_alignment_heads = self._process_cross_attention(accumulated_cross_attns, content_mel_len)
# for each beam, the most attended frame is:
most_attended_frames = torch.argmax(attn_of_alignment_heads[:, -1, :], dim=-1)
# Calculate absolute timestamps accounting for cumulative offset
absolute_timestamps = [
(frame * 0.02 + self.state.cumulative_time_offset)
for frame in most_attended_frames.tolist()
]
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.state.cumulative_time_offset:.2f}s)")
most_attended_frame = most_attended_frames[0].item()
l_absolute_timestamps.append(absolute_timestamps[0])
logger.debug("current tokens" + str(current_tokens.shape))
if completed:
# stripping the last token, the eot
current_tokens = current_tokens[:, :-1]
break
# for some rare cases where the attention fails
if not is_last and self.state.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
if current_tokens.shape[1] > 1 and current_tokens[0, -2] >= DEC_PAD:
logger.debug("omit rewinding from special tokens")
self.state.last_attend_frame = most_attended_frame
else:
logger.debug(
f"[rewind detected] current attention pos: {most_attended_frame}, "
f"last attention pos: {self.state.last_attend_frame}; omit this segment")
self.state.last_attend_frame = -self.cfg.rewind_threshold
current_tokens = torch.cat(self.state.tokens, dim=1) if len(self.state.tokens) > 0 else self.state.tokens[0]
break
else:
self.state.last_attend_frame = most_attended_frame
if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
# stripping the last token, the one that is attended too close to the end
current_tokens = current_tokens[:, :-1]
break
# debug print
for i in range(self.cfg.beam_size):
logger.debug("attn: {}, current pos: {}, current token: {}({})".format(
attn_of_alignment_heads.shape if attn_of_alignment_heads is not None else None,
most_attended_frames[i],
current_tokens[i, -1].item(),
self.tokenizer.decode([current_tokens[i, -1].item()])
))
tokens_to_split = current_tokens[0, token_len_before_decoding:]
# Prepend pending tokens from previous chunk if any
if self.state.pending_incomplete_tokens:
logger.debug(f"[UTF-8 Fix] Prepending {len(self.state.pending_incomplete_tokens)} pending tokens: {self.state.pending_incomplete_tokens}")
pending_tensor = torch.tensor(self.state.pending_incomplete_tokens, dtype=torch.long, device=self.device)
tokens_to_split = torch.cat([pending_tensor, tokens_to_split])
if fire_detected or is_last:
new_hypothesis = tokens_to_split.flatten().tolist()
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
else:
# going to truncate the tokens after the last space
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split.tolist())
if len(split_words) > 1:
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
else:
new_hypothesis = []
logger.debug(f"new_hypothesis: {new_hypothesis}")
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
device=self.device,
)
self.state.tokens.append(new_tokens)
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
self._clean_cache()
if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None:
self.state.first_timestamp = l_absolute_timestamps[0]
timestamped_words = []
timestamp_idx = 0
replacement_char = "\ufffd"
for word, word_tokens in zip(split_words, split_tokens):
# Skip words containing incomplete UTF-8 from client output
if replacement_char in word:
logger.warning(f"[UTF-8 Filter] Skipping incomplete word from client output: {repr(word)}")
timestamp_idx += len(word_tokens)
continue
try:
current_timestamp = l_absolute_timestamps[timestamp_idx]
except IndexError:
# Use last timestamp if index out of range
logger.warning(f"Timestamp index {timestamp_idx} out of range, using last timestamp")
current_timestamp = l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0
timestamp_idx += len(word_tokens)
timestamp_entry = ASRToken(
start=round(current_timestamp, 2),
end=round(current_timestamp + 0.1, 2),
text=word,
speaker=self.state.speaker,
detected_language=self.state.detected_language
).with_offset(
self.state.global_time_offset
)
timestamped_words.append(timestamp_entry)
# Hold incomplete tokens for next chunk (with limit to prevent hallucination accumulation)
self.state.pending_incomplete_tokens = []
MAX_PENDING_TOKENS = 10 # Real incomplete UTF-8 chars are at most a few tokens
if split_words and replacement_char in split_words[-1]:
if len(split_tokens[-1]) <= MAX_PENDING_TOKENS:
self.state.pending_incomplete_tokens = split_tokens[-1]
logger.debug(f"[UTF-8 Fix] Holding {len(self.state.pending_incomplete_tokens)} incomplete tokens for next chunk")
else:
logger.warning(f"[UTF-8 Fix] Skipping {len(split_tokens[-1])} tokens (exceeds limit of {MAX_PENDING_TOKENS}, likely hallucination)")
return timestamped_words
def _process_cross_attention(
self,
cross_attns: List[torch.Tensor],
content_mel_len: int
) -> torch.Tensor:
"""
Process cross-attention weights from decoder layers for alignment.
Args:
cross_attns: List of cross-attention tensors from each decoder layer.
Each tensor has shape (batch, n_head, seq_len, audio_len)
content_mel_len: Length of actual audio content in mel frames
Returns processed attention tensor for alignment, shape (batch, seq_len, content_mel_len)
"""
attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
num_decoder_layers = len(self.model.decoder.blocks)
if cross_attns and isinstance(cross_attns[0], list):
flattened_attns: List[torch.Tensor] = [attn for layer_list in cross_attns for attn in layer_list]
else:
flattened_attns = cross_attns
for idx, attn_mat in enumerate(flattened_attns):
layer_rank = idx % num_decoder_layers
# attn_mat shape: (batch, n_head, seq_len, audio_len) or (n_head, seq_len, audio_len) for batch=1
align_heads_in_layer = self.state.align_source.get(layer_rank, [])
if len(align_heads_in_layer) == 0:
continue
attn_mat = F.softmax(attn_mat, dim=-1)
for align_head_rank, head_id in align_heads_in_layer:
if self.cfg.beam_size == 1:
# (n_head, seq_len, audio_len) when squeezed
if attn_mat.dim() == 4:
a = attn_mat[0, head_id, :, :] # (seq_len, audio_len)
else:
a = attn_mat[head_id, :, :]
a = a.unsqueeze(0) # (1, seq_len, audio_len)
else:
# attn_mat: (batch, n_head, seq_len, audio_len)
a = attn_mat[:, head_id, :, :] # (batch, seq_len, audio_len)
attn_of_alignment_heads[align_head_rank].append(a)
tmp = []
for mat in attn_of_alignment_heads:
if mat:
t = torch.cat(mat, dim=1) # (batch, total_seq_len, audio_len)
tmp.append(t)
if not tmp:
return torch.zeros(self.cfg.beam_size, 1, content_mel_len, device=self.device)
# stck al heads: (batch, num_align_heads, seq_len, audio_len)
attn_of_alignment_heads = torch.stack(tmp, dim=1)
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7)
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
return attn_of_alignment_heads

View File

@@ -1,96 +0,0 @@
import sys
import torch
class TokenBuffer:
def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):
self.text = text
self.prefix_token_ids = prefix_token_ids
self.tokenizer = tokenizer
self.device = device
self.pending_token_ids = []
def as_token_ids(self, tokenizer=None):
if tokenizer is None:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer is not set.")
return self.prefix_token_ids + tokenizer.encode(self.text)
def as_tensor(self, device=None):
if device is None:
device = self.device
if device is None:
raise ValueError("Device is not set.")
tok_ids = self.as_token_ids()
return torch.tensor(tok_ids,
dtype=torch.long, device=device).unsqueeze(0)
def as_tensor_beam(self, beam, device=None):
t = self.as_tensor(device=device)
return t.repeat_interleave(beam, dim=0)
def as_text(self):
return self.text
@staticmethod
def empty(*a, **kw):
return TokenBuffer(*a,**kw)
@staticmethod
def from_text(text, *a, **kw):
return TokenBuffer(*a, text=text, **kw)
def is_empty(self):
return self.text is None or self.text == ""
def trim_words(self, num=1, after=0):
'''
num: how many words to trim from the beginning
after: how many characters to skip (length of the static prompt)
'''
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
ids = tokenizer.encode(self.text[after:])
words, wids = self.tokenizer.split_to_word_tokens(ids)
# print(words, file=sys.stderr)
# print(wids, file=sys.stderr)
if not words:
return 0
self.text = self.text[:after] + "".join(words[num:])
return sum(len(wi) for wi in wids[:num])
def append_token_ids(self, token_ids):
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
all_tokens = self.pending_token_ids + token_ids
decoded = tokenizer.decode(all_tokens)
replacement_char = "\ufffd"
if replacement_char in decoded:
if len(all_tokens) > 1:
decoded_partial = tokenizer.decode(all_tokens[:-1])
if replacement_char not in decoded_partial:
self.text += decoded_partial
self.pending_token_ids = [all_tokens[-1]]
else:
self.pending_token_ids = all_tokens
else:
self.pending_token_ids = all_tokens
else:
self.text += decoded
self.pending_token_ids = []
def as_split_word_tokens(self):
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
ids = tokenizer.encode(self.text)
return tokenizer.split_to_word_tokens(ids)

View File

@@ -1,139 +0,0 @@
"""
Thread Safety Configuration for WhisperLiveKit
This module provides thread safety configuration and utilities.
Environment Variables:
WHISPERLIVEKIT_MODEL_LOCK: Enable/disable model locking (default: 1)
Set to "0" to disable for single-connection deployments
WHISPERLIVEKIT_LOCK_TIMEOUT: Lock acquisition timeout in seconds (default: 30)
Usage:
# Enable model locking (default)
export WHISPERLIVEKIT_MODEL_LOCK=1
# Disable for single-connection deployment
export WHISPERLIVEKIT_MODEL_LOCK=0
# Custom timeout
export WHISPERLIVEKIT_LOCK_TIMEOUT=60
"""
import os
import logging
import threading
logger = logging.getLogger(__name__)
# Configuration
USE_MODEL_LOCK = os.environ.get("WHISPERLIVEKIT_MODEL_LOCK", "1") == "1"
LOCK_TIMEOUT = float(os.environ.get("WHISPERLIVEKIT_LOCK_TIMEOUT", "30.0"))
# Global model lock
_model_lock = threading.Lock()
# Log configuration on import
if USE_MODEL_LOCK:
logger.info(f"Model locking ENABLED (timeout: {LOCK_TIMEOUT}s)")
logger.info("For single-connection deployments, set WHISPERLIVEKIT_MODEL_LOCK=0")
else:
logger.warning("Model locking DISABLED - only safe for single-connection deployments")
def get_model_lock():
"""Get the global model lock instance"""
return _model_lock
def acquire_model_lock(timeout=None):
"""
Acquire model lock with timeout.
Args:
timeout: Lock acquisition timeout (default: use LOCK_TIMEOUT)
Returns:
bool: True if lock acquired, False on timeout
"""
if not USE_MODEL_LOCK:
return True
timeout = timeout or LOCK_TIMEOUT
acquired = _model_lock.acquire(timeout=timeout)
if not acquired:
logger.error(f"Failed to acquire model lock within {timeout}s")
return acquired
def release_model_lock():
"""Release model lock"""
if not USE_MODEL_LOCK:
return
try:
_model_lock.release()
except RuntimeError:
# Lock not held - this is fine
pass
class ModelLockContext:
"""Context manager for model lock"""
def __init__(self, timeout=None):
self.timeout = timeout
self.acquired = False
def __enter__(self):
self.acquired = acquire_model_lock(self.timeout)
return self.acquired
def __exit__(self, exc_type, exc_val, exc_tb):
if self.acquired:
release_model_lock()
return False
# Concurrency recommendations
RECOMMENDED_CONNECTIONS_PER_WORKER = 1 if USE_MODEL_LOCK else 1
RECOMMENDED_WORKERS = 4
def print_deployment_recommendations():
"""Print recommended deployment configuration"""
print("\n" + "="*60)
print("WhisperLiveKit Deployment Recommendations")
print("="*60)
if USE_MODEL_LOCK:
print("⚠️ Model locking is ENABLED")
print(" This serializes inference across connections.")
print()
print("Recommended deployment:")
print(f" gunicorn -w {RECOMMENDED_WORKERS} \\")
print(" -k uvicorn.workers.UvicornWorker \\")
print(" --worker-connections 1 \\")
print(" whisperlivekit.basic_server:app")
print()
print("Expected capacity:")
print(f" - {RECOMMENDED_WORKERS} concurrent users (1 per worker)")
print(f" - Memory: ~{RECOMMENDED_WORKERS}x model size")
else:
print("✅ Model locking is DISABLED")
print(" ⚠️ ONLY safe for single-connection deployments")
print()
print("Recommended deployment:")
print(" uvicorn whisperlivekit.basic_server:app \\")
print(" --host 0.0.0.0 --port 8000 \\")
print(" --workers 1")
print()
print("Expected capacity:")
print(" - 1 concurrent user only")
print("="*60 + "\n")
if __name__ == "__main__":
print_deployment_recommendations()

View File

@@ -1,229 +0,0 @@
from dataclasses import dataclass, field
from datetime import timedelta
from typing import Any, Dict, List, Optional, Union
PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''}
def format_time(seconds: float) -> str:
"""Format seconds as HH:MM:SS."""
return str(timedelta(seconds=int(seconds)))
@dataclass
class Timed:
start: Optional[float] = 0
end: Optional[float] = 0
@dataclass
class TimedText(Timed):
text: Optional[str] = ''
speaker: Optional[int] = -1
detected_language: Optional[str] = None
def has_punctuation(self) -> bool:
return any(char in PUNCTUATION_MARKS for char in self.text.strip())
def is_within(self, other: 'TimedText') -> bool:
return other.contains_timespan(self)
def duration(self) -> float:
return self.end - self.start
def contains_timespan(self, other: 'TimedText') -> bool:
return self.start <= other.start and self.end >= other.end
def __bool__(self) -> bool:
return bool(self.text)
def __str__(self) -> str:
return str(self.text)
@dataclass()
class ASRToken(TimedText):
def with_offset(self, offset: float) -> "ASRToken":
"""Return a new token with the time offset added."""
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language)
def is_silence(self) -> bool:
return False
@dataclass
class Sentence(TimedText):
pass
@dataclass
class Transcript(TimedText):
"""
represents a concatenation of several ASRToken
"""
@classmethod
def from_tokens(
cls,
tokens: List[ASRToken],
sep: Optional[str] = None,
offset: float = 0
) -> "Transcript":
"""Collapse multiple ASR tokens into a single transcript span."""
sep = sep if sep is not None else ' '
text = sep.join(token.text for token in tokens)
if tokens:
start = offset + tokens[0].start
end = offset + tokens[-1].end
else:
start = None
end = None
return cls(start, end, text)
@dataclass
class SpeakerSegment(Timed):
"""Represents a segment of audio attributed to a specific speaker.
No text nor probability is associated with this segment.
"""
speaker: Optional[int] = -1
pass
@dataclass
class Translation(TimedText):
pass
@dataclass
class Silence():
start: Optional[float] = None
end: Optional[float] = None
duration: Optional[float] = None
is_starting: bool = False
has_ended: bool = False
def compute_duration(self) -> Optional[float]:
if self.start is None or self.end is None:
return None
self.duration = self.end - self.start
return self.duration
def is_silence(self) -> bool:
return True
@dataclass
class Segment(TimedText):
"""Generic contiguous span built from tokens or silence markers."""
start: Optional[float]
end: Optional[float]
text: Optional[str]
speaker: Optional[str]
tokens: Optional[ASRToken] = None
translation: Optional[Translation] = None
@classmethod
def from_tokens(
cls,
tokens: List[Union[ASRToken, Silence]],
is_silence: bool = False
) -> Optional["Segment"]:
"""Return a normalized segment representing the provided tokens."""
if not tokens:
return None
start_token = tokens[0]
end_token = tokens[-1]
if is_silence:
return cls(
start=start_token.start,
end=end_token.end,
text=None,
speaker=-2
)
else:
return cls(
start=start_token.start,
end=end_token.end,
text=''.join(token.text for token in tokens),
speaker=-1,
detected_language=start_token.detected_language
)
def is_silence(self) -> bool:
"""True when this segment represents a silence gap."""
return self.speaker == -2
def to_dict(self) -> Dict[str, Any]:
"""Serialize the segment for frontend consumption."""
_dict: Dict[str, Any] = {
'speaker': int(self.speaker) if self.speaker != -1 else 1,
'text': self.text,
'start': format_time(self.start),
'end': format_time(self.end),
}
if self.translation:
_dict['translation'] = self.translation
if self.detected_language:
_dict['detected_language'] = self.detected_language
return _dict
@dataclass
class PuncSegment(Segment):
pass
class SilentSegment(Segment):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.speaker = -2
self.text = ''
@dataclass
class FrontData():
status: str = ''
error: str = ''
lines: list[Segment] = field(default_factory=list)
buffer_transcription: str = ''
buffer_diarization: str = ''
buffer_translation: str = ''
remaining_time_transcription: float = 0.
remaining_time_diarization: float = 0.
def to_dict(self) -> Dict[str, Any]:
"""Serialize the front-end data payload."""
_dict: Dict[str, Any] = {
'status': self.status,
'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)],
'buffer_transcription': self.buffer_transcription,
'buffer_diarization': self.buffer_diarization,
'buffer_translation': self.buffer_translation,
'remaining_time_transcription': self.remaining_time_transcription,
'remaining_time_diarization': self.remaining_time_diarization,
}
if self.error:
_dict['error'] = self.error
return _dict
@dataclass
class ChangeSpeaker:
speaker: int
start: int
@dataclass
class State():
"""Unified state class for audio processing.
Contains both persistent state (tokens, buffers) and temporary update buffers
(new_* fields) that are consumed by TokensAlignment.
"""
# Persistent state
tokens: List[ASRToken] = field(default_factory=list)
buffer_transcription: Transcript = field(default_factory=Transcript)
end_buffer: float = 0.0
end_attributed_speaker: float = 0.0
remaining_time_transcription: float = 0.0
remaining_time_diarization: float = 0.0
# Temporary update buffers (consumed by TokensAlignment.update())
new_tokens: List[Union[ASRToken, Silence]] = field(default_factory=list)
new_translation: List[Any] = field(default_factory=list)
new_diarization: List[Any] = field(default_factory=list)
new_tokens_buffer: List[Any] = field(default_factory=list) # only when local agreement
new_translation_buffer= TimedText()

View File

@@ -1,219 +0,0 @@
from time import time
from typing import Any, List, Optional, Tuple, Union
from whisperlivekit.timed_objects import (ASRToken, Segment, PuncSegment, Silence,
SilentSegment, SpeakerSegment,
TimedText)
class TokensAlignment:
def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None:
self.state = state
self.diarization = args.diarization
self._tokens_index: int = 0
self._diarization_index: int = 0
self._translation_index: int = 0
self.all_tokens: List[ASRToken] = []
self.all_diarization_segments: List[SpeakerSegment] = []
self.all_translation_segments: List[Any] = []
self.new_tokens: List[ASRToken] = []
self.new_diarization: List[SpeakerSegment] = []
self.new_translation: List[Any] = []
self.new_translation_buffer: Union[TimedText, str] = TimedText()
self.new_tokens_buffer: List[Any] = []
self.sep: str = sep if sep is not None else ' '
self.beg_loop: Optional[float] = None
self.validated_segments: List[Segment] = []
self.current_line_tokens: List[ASRToken] = []
self.diarization_buffer: List[ASRToken] = []
self.last_punctuation = None
self.last_uncompleted_punc_segment: PuncSegment = None
self.unvalidated_tokens: PuncSegment = []
def update(self) -> None:
"""Drain state buffers into the running alignment context."""
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
self.new_diarization, self.state.new_diarization = self.state.new_diarization, []
self.new_translation, self.state.new_translation = self.state.new_translation, []
self.new_tokens_buffer, self.state.new_tokens_buffer = self.state.new_tokens_buffer, []
self.all_tokens.extend(self.new_tokens)
self.all_diarization_segments.extend(self.new_diarization)
self.all_translation_segments.extend(self.new_translation)
self.new_translation_buffer = self.state.new_translation_buffer
def add_translation(self, segment: Segment) -> None:
"""Append translated text segments that overlap with a segment."""
if segment.translation is None:
segment.translation = ''
for ts in self.all_translation_segments:
if ts.is_within(segment):
segment.translation += ts.text + (self.sep if ts.text else '')
elif segment.translation:
break
def compute_punctuations_segments(self, tokens: Optional[List[ASRToken]] = None) -> List[PuncSegment]:
"""Group tokens into segments split by punctuation and explicit silence."""
segments = []
segment_start_idx = 0
for i, token in enumerate(self.all_tokens):
if token.is_silence():
previous_segment = PuncSegment.from_tokens(
tokens=self.all_tokens[segment_start_idx: i],
)
if previous_segment:
segments.append(previous_segment)
segment = PuncSegment.from_tokens(
tokens=[token],
is_silence=True
)
segments.append(segment)
segment_start_idx = i+1
else:
if token.has_punctuation():
segment = PuncSegment.from_tokens(
tokens=self.all_tokens[segment_start_idx: i+1],
)
segments.append(segment)
segment_start_idx = i+1
final_segment = PuncSegment.from_tokens(
tokens=self.all_tokens[segment_start_idx:],
)
if final_segment:
segments.append(final_segment)
return segments
def compute_new_punctuations_segments(self) -> List[PuncSegment]:
new_punc_segments = []
segment_start_idx = 0
self.unvalidated_tokens += self.new_tokens
for i, token in enumerate(self.unvalidated_tokens):
if token.is_silence():
previous_segment = PuncSegment.from_tokens(
tokens=self.unvalidated_tokens[segment_start_idx: i],
)
if previous_segment:
new_punc_segments.append(previous_segment)
segment = PuncSegment.from_tokens(
tokens=[token],
is_silence=True
)
new_punc_segments.append(segment)
segment_start_idx = i+1
else:
if token.has_punctuation():
segment = PuncSegment.from_tokens(
tokens=self.unvalidated_tokens[segment_start_idx: i+1],
)
new_punc_segments.append(segment)
segment_start_idx = i+1
self.unvalidated_tokens = self.unvalidated_tokens[segment_start_idx:]
return new_punc_segments
def concatenate_diar_segments(self) -> List[SpeakerSegment]:
"""Merge consecutive diarization slices that share the same speaker."""
if not self.all_diarization_segments:
return []
merged = [self.all_diarization_segments[0]]
for segment in self.all_diarization_segments[1:]:
if segment.speaker == merged[-1].speaker:
merged[-1].end = segment.end
else:
merged.append(segment)
return merged
@staticmethod
def intersection_duration(seg1: TimedText, seg2: TimedText) -> float:
"""Return the overlap duration between two timed segments."""
start = max(seg1.start, seg2.start)
end = min(seg1.end, seg2.end)
return max(0, end - start)
def get_lines_diarization(self) -> Tuple[List[Segment], str]:
"""Build segments when diarization is enabled and track overflow buffer."""
diarization_buffer = ''
punctuation_segments = self.compute_punctuations_segments()
diarization_segments = self.concatenate_diar_segments()
for punctuation_segment in punctuation_segments:
if not punctuation_segment.is_silence():
if diarization_segments and punctuation_segment.start >= diarization_segments[-1].end:
diarization_buffer += punctuation_segment.text
else:
max_overlap = 0.0
max_overlap_speaker = 1
for diarization_segment in diarization_segments:
intersec = self.intersection_duration(punctuation_segment, diarization_segment)
if intersec > max_overlap:
max_overlap = intersec
max_overlap_speaker = diarization_segment.speaker + 1
punctuation_segment.speaker = max_overlap_speaker
segments = []
if punctuation_segments:
segments = [punctuation_segments[0]]
for segment in punctuation_segments[1:]:
if segment.speaker == segments[-1].speaker:
if segments[-1].text:
segments[-1].text += segment.text
segments[-1].end = segment.end
else:
segments.append(segment)
return segments, diarization_buffer
def get_lines(
self,
diarization: bool = False,
translation: bool = False,
current_silence: Optional[Silence] = None
) -> Tuple[List[Segment], str, Union[str, TimedText]]:
"""Return the formatted segments plus buffers, optionally with diarization/translation."""
if diarization:
segments, diarization_buffer = self.get_lines_diarization()
else:
diarization_buffer = ''
for token in self.new_tokens:
if token.is_silence():
if self.current_line_tokens:
self.validated_segments.append(Segment().from_tokens(self.current_line_tokens))
self.current_line_tokens = []
end_silence = token.end if token.has_ended else time() - self.beg_loop
if self.validated_segments and self.validated_segments[-1].is_silence():
self.validated_segments[-1].end = end_silence
else:
self.validated_segments.append(SilentSegment(
start=token.start,
end=end_silence
))
else:
self.current_line_tokens.append(token)
segments = list(self.validated_segments)
if self.current_line_tokens:
segments.append(Segment().from_tokens(self.current_line_tokens))
if current_silence:
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
if segments and segments[-1].is_silence():
segments[-1] = SilentSegment(start=segments[-1].start, end=end_silence)
else:
segments.append(SilentSegment(
start=current_silence.start,
end=end_silence
))
if translation:
[self.add_translation(segment) for segment in segments if not segment.is_silence()]
return segments, diarization_buffer, self.new_translation_buffer.text

View File

@@ -1,52 +0,0 @@
import logging
logger = logging.getLogger(__name__)
def load_file(warmup_file=None, timeout=5):
import os
import tempfile
import urllib.request
import librosa
if warmup_file == "":
logger.info(f"Skipping warmup.")
return None
# Download JFK sample if not already present
if warmup_file is None:
jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
temp_dir = tempfile.gettempdir()
warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")
if not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
try:
logger.debug(f"Downloading warmup file from {jfk_url}")
with urllib.request.urlopen(jfk_url, timeout=timeout) as r, open(warmup_file, "wb") as f:
f.write(r.read())
except Exception as e:
logger.warning(f"Warmup file download failed: {e}.")
return None
# Validate file and load
if not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
logger.warning(f"Warmup file {warmup_file} is invalid or missing.")
return None
try:
audio, _ = librosa.load(warmup_file, sr=16000)
return audio
except Exception as e:
logger.warning(f"Failed to load warmup file: {e}")
return None
def warmup_asr(asr, warmup_file=None, timeout=5):
"""
Warmup the ASR model by transcribing a short audio file.
"""
audio = load_file(warmup_file=warmup_file, timeout=timeout)
if audio is None:
logger.warning("Warmup file unavailable. Skipping ASR warmup.")
return
asr.transcribe(audio)
logger.info("ASR model is warmed up.")

View File

@@ -1,630 +0,0 @@
:root {
--bg: #ffffff;
--text: #111111;
--muted: #666666;
--border: #e5e5e5;
--chip-bg: rgba(0, 0, 0, 0.04);
--chip-text: #000000;
--spinner-border: #8d8d8d5c;
--spinner-top: #b0b0b0;
--silence-bg: #f3f3f3;
--loading-bg: rgba(255, 77, 77, 0.06);
--button-bg: #ffffff;
--button-border: #e9e9e9;
--wave-stroke: #000000;
--label-dia-text: #868686;
--label-trans-text: #111111;
}
@media (prefers-color-scheme: dark) {
:root:not([data-theme="light"]) {
--bg: #0b0b0b;
--text: #e6e6e6;
--muted: #9aa0a6;
--border: #333333;
--chip-bg: rgba(255, 255, 255, 0.08);
--chip-text: #e6e6e6;
--spinner-border: #555555;
--spinner-top: #dddddd;
--silence-bg: #1a1a1a;
--loading-bg: rgba(255, 77, 77, 0.12);
--button-bg: #111111;
--button-border: #333333;
--wave-stroke: #e6e6e6;
--label-dia-text: #b3b3b3;
--label-trans-text: #ffffff;
}
}
:root[data-theme="dark"] {
--bg: #0b0b0b;
--text: #e6e6e6;
--muted: #9aa0a6;
--border: #333333;
--chip-bg: rgba(255, 255, 255, 0.08);
--chip-text: #e6e6e6;
--spinner-border: #555555;
--spinner-top: #dddddd;
--silence-bg: #1a1a1a;
--loading-bg: rgba(255, 77, 77, 0.12);
--button-bg: #111111;
--button-border: #333333;
--wave-stroke: #e6e6e6;
--label-dia-text: #b3b3b3;
--label-trans-text: #ffffff;
}
:root[data-theme="light"] {
--bg: #ffffff;
--text: #111111;
--muted: #666666;
--border: #e5e5e5;
--chip-bg: rgba(0, 0, 0, 0.04);
--chip-text: #000000;
--spinner-border: #8d8d8d5c;
--spinner-top: #b0b0b0;
--silence-bg: #f3f3f3;
--loading-bg: rgba(255, 77, 77, 0.06);
--button-bg: #ffffff;
--button-border: #e9e9e9;
--wave-stroke: #000000;
--label-dia-text: #868686;
--label-trans-text: #111111;
}
html.is-extension
{
width: 350px;
height: 500px;
}
body {
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
margin: 0;
text-align: center;
background-color: var(--bg);
color: var(--text);
height: 100vh;
display: flex;
flex-direction: column;
}
/* Record button */
#recordButton {
width: 50px;
height: 50px;
border: none;
border-radius: 50%;
background-color: var(--button-bg);
cursor: pointer;
transition: all 0.3s ease;
border: 1px solid var(--button-border);
display: flex;
align-items: center;
justify-content: center;
position: relative;
}
#recordButton.recording {
width: 180px;
border-radius: 40px;
justify-content: flex-start;
padding-left: 20px;
}
#recordButton:active {
transform: scale(0.95);
}
.shape-container {
width: 25px;
height: 25px;
display: flex;
align-items: center;
justify-content: center;
flex-shrink: 0;
}
.shape {
width: 25px;
height: 25px;
background-color: rgb(209, 61, 53);
border-radius: 50%;
transition: all 0.3s ease;
}
#recordButton:disabled .shape {
background-color: #6e6d6d;
}
#recordButton.recording .shape {
border-radius: 5px;
width: 25px;
height: 25px;
}
/* Recording elements */
.recording-info {
display: none;
align-items: center;
margin-left: 15px;
flex-grow: 1;
}
#recordButton.recording .recording-info {
display: flex;
}
.wave-container {
width: 60px;
height: 30px;
position: relative;
display: flex;
align-items: center;
justify-content: center;
}
#waveCanvas {
width: 100%;
height: 100%;
}
.timer {
font-size: 14px;
font-weight: 500;
color: var(--text);
margin-left: 10px;
}
#status {
margin-top: 15px;
font-size: 16px;
color: var(--text);
margin-bottom: 0;
}
.header-container {
position: sticky;
top: 0;
background-color: var(--bg);
z-index: 100;
padding: 20px;
}
/* Settings */
.settings-container {
display: flex;
justify-content: center;
align-items: center;
gap: 15px;
position: relative;
flex-wrap: wrap;
}
.buttons-container {
display: flex;
align-items: center;
gap: 15px;
}
.settings {
display: flex;
flex-wrap: wrap;
align-items: flex-start;
gap: 12px;
}
.settings-toggle {
width: 40px;
height: 40px;
border: none;
border-radius: 50%;
background-color: var(--button-bg);
border: 1px solid var(--button-border);
cursor: pointer;
display: none;
align-items: center;
justify-content: center;
transition: all 0.2s ease;
}
.settings-toggle:hover {
background-color: var(--chip-bg);
}
.settings-toggle.active {
background-color: var(--chip-bg);
}
.settings-toggle img {
width: 20px;
height: 20px;
}
@media (max-width: 10000px) {
.settings-toggle {
display: flex;
}
.settings {
display: none;
background: var(--bg);
border: 1px solid var(--border);
border-radius: 18px;
padding: 12px;
}
.settings.visible {
display: flex;
}
}
@media (max-width: 600px) {
.settings-container {
flex-direction: column;
align-items: center;
gap: 10px;
}
.buttons-container {
display: flex;
justify-content: center;
align-items: center;
gap: 15px;
}
}
.field {
display: flex;
flex-direction: column;
align-items: flex-start;
gap: 3px;
}
#chunkSelector,
#websocketInput,
#themeSelector,
#microphoneSelect {
font-size: 16px;
padding: 5px 8px;
border-radius: 8px;
border: 1px solid var(--border);
background-color: var(--button-bg);
color: var(--text);
max-height: 30px;
}
#microphoneSelect {
width: 100%;
max-width: 190px;
min-width: 120px;
}
#chunkSelector:focus,
#websocketInput:focus,
#themeSelector:focus,
#microphoneSelect:focus {
outline: none;
border-color: #007bff;
box-shadow: 0 0 0 3px rgba(0, 123, 255, 0.15);
}
label {
font-size: 13px;
color: var(--muted);
}
.ws-default {
font-size: 12px;
color: var(--muted);
}
/* Segmented pill control for Theme */
.segmented {
display: inline-flex;
align-items: stretch;
border: 1px solid var(--button-border);
background-color: var(--button-bg);
border-radius: 999px;
overflow: hidden;
}
.segmented input[type="radio"] {
position: absolute;
opacity: 0;
pointer-events: none;
}
.theme-selector-container {
display: flex;
align-items: center;
margin-top: 17px;
}
.segmented label {
display: inline-flex;
align-items: center;
gap: 6px;
padding: 6px 12px;
font-size: 14px;
color: var(--muted);
cursor: pointer;
user-select: none;
transition: background-color 0.2s ease, color 0.2s ease;
}
.segmented label span {
display: none;
}
.segmented label:hover span {
display: inline;
}
.segmented label:hover {
background-color: var(--chip-bg);
}
.segmented img {
width: 16px;
height: 16px;
}
.segmented input[type="radio"]:checked + label {
background-color: var(--chip-bg);
color: var(--text);
}
.segmented input[type="radio"]:focus-visible + label,
.segmented input[type="radio"]:focus + label {
outline: 2px solid #007bff;
outline-offset: 2px;
border-radius: 999px;
}
.transcript-container {
flex: 1;
overflow-y: auto;
padding: 20px;
scrollbar-width: none;
-ms-overflow-style: none;
}
.transcript-container::-webkit-scrollbar {
display: none;
}
/* Transcript area */
#linesTranscript {
margin: 0 auto;
max-width: 700px;
text-align: left;
font-size: 16px;
}
#linesTranscript p {
margin: 0px 0;
}
#linesTranscript strong {
color: var(--text);
}
#speaker {
border: 1px solid var(--border);
border-radius: 100px;
padding: 2px 10px;
font-size: 14px;
margin-bottom: 0px;
}
.label_diarization {
background-color: var(--chip-bg);
border-radius: 100px;
padding: 2px 10px;
margin-left: 10px;
display: inline-block;
white-space: nowrap;
font-size: 14px;
margin-bottom: 0px;
color: var(--label-dia-text);
}
.label_transcription {
background-color: var(--chip-bg);
border-radius: 100px;
padding: 2px 10px;
display: inline-block;
white-space: nowrap;
margin-left: 10px;
font-size: 14px;
margin-bottom: 0px;
color: var(--label-trans-text);
}
.label_translation {
background-color: var(--chip-bg);
display: inline-flex;
border-radius: 10px;
padding: 4px 8px;
margin-top: 4px;
font-size: 14px;
color: var(--text);
align-items: flex-start;
gap: 4px;
}
.lag-diarization-value {
margin-left: 10px;
}
.label_translation img {
margin-top: 2px;
}
.label_translation img {
width: 12px;
height: 12px;
}
#timeInfo {
color: var(--muted);
margin-left: 0px;
}
.textcontent {
font-size: 16px;
padding-left: 10px;
margin-bottom: 10px;
margin-top: 1px;
padding-top: 5px;
border-radius: 0px 0px 0px 10px;
}
.buffer_diarization {
color: var(--label-dia-text);
}
.buffer_transcription {
color: #7474748c;
margin-left: 4px;
}
.buffer_translation {
color: #a0a0a0;
margin-left: 6px;
}
.spinner {
display: inline-block;
width: 8px;
height: 8px;
border: 2px solid var(--spinner-border);
border-top: 2px solid var(--spinner-top);
border-radius: 50%;
animation: spin 0.7s linear infinite;
vertical-align: middle;
margin-bottom: 2px;
margin-right: 5px;
}
@keyframes spin {
to {
transform: rotate(360deg);
}
}
.silence {
color: var(--muted);
background-color: var(--silence-bg);
font-size: 13px;
border-radius: 30px;
padding: 2px 10px;
}
.loading {
color: var(--muted);
background-color: var(--loading-bg);
border-radius: 8px 8px 8px 0px;
padding: 2px 10px;
font-size: 14px;
margin-bottom: 0px;
}
/* for smaller screens */
@media (max-width: 200px) {
.header-container {
padding: 15px;
}
.settings-container {
flex-direction: column;
gap: 10px;
}
.buttons-container {
gap: 10px;
}
.settings {
justify-content: center;
gap: 8px;
}
.field {
align-items: center;
}
#websocketInput,
#microphoneSelect {
min-width: 100px;
max-width: 160px;
}
.theme-selector-container {
margin-top: 10px;
}
.transcript-container {
padding: 15px;
}
}
@media (max-width: 480px) {
.header-container {
padding: 10px;
}
.settings {
flex-direction: column;
align-items: center;
gap: 6px;
}
#websocketInput,
#microphoneSelect {
max-width: 140px;
}
.segmented label {
padding: 4px 8px;
font-size: 12px;
}
.segmented img {
width: 14px;
height: 14px;
}
.transcript-container {
padding: 10px;
}
}
.label_language {
background-color: var(--chip-bg);
margin-bottom: 0px;
border-radius: 100px;
padding: 2px 8px;
margin-left: 10px;
display: inline-flex;
align-items: center;
gap: 4px;
font-size: 14px;
color: var(--muted);
}
.speaker-badge {
display: inline-flex;
align-items: center;
justify-content: center;
width: 16px;
height: 16px;
margin-left: -5px;
border-radius: 50%;
font-size: 11px;
line-height: 1;
font-weight: 800;
color: var(--muted);
}

View File

@@ -1,79 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>WhisperLiveKit</title>
<link rel="stylesheet" href="live_transcription.css" />
</head>
<body>
<div class="header-container">
<div class="settings-container">
<div class="buttons-container">
<button id="recordButton">
<div class="shape-container">
<div class="shape"></div>
</div>
<div class="recording-info">
<div class="wave-container">
<canvas id="waveCanvas"></canvas>
</div>
<div class="timer">00:00</div>
</div>
</button>
<button id="settingsToggle" class="settings-toggle" title="Show/hide settings">
<img src="web/src/settings.svg" alt="Settings" />
</button>
</div>
<div class="settings">
<div class="field">
<label for="websocketInput">Websocket URL</label>
<input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
</div>
<div class="field">
<label id="microphoneSelectLabel" for="microphoneSelect">Select Microphone</label>
<select id="microphoneSelect">
<option value="">Default Microphone</option>
</select>
</div>
<div class="theme-selector-container">
<div class="segmented" role="radiogroup" aria-label="Theme selector">
<input type="radio" id="theme-system" name="theme" value="system" />
<label for="theme-system" title="System">
<img src="/web/src/system_mode.svg" alt="" />
<span>System</span>
</label>
<input type="radio" id="theme-light" name="theme" value="light" />
<label for="theme-light" title="Light">
<img src="/web/src/light_mode.svg" alt="" />
<span>Light</span>
</label>
<input type="radio" id="theme-dark" name="theme" value="dark" />
<label for="theme-dark" title="Dark">
<img src="/web/src/dark_mode.svg" alt="" />
<span>Dark</span>
</label>
</div>
</div>
</div>
</div>
<p id="status"></p>
</div>
<div class="transcript-container">
<div id="linesTranscript"></div>
</div>
<script src="live_transcription.js"></script>
</body>
</html>

View File

@@ -1,817 +0,0 @@
const isExtension = typeof chrome !== 'undefined' && chrome.runtime && chrome.runtime.getURL;
if (isExtension) {
document.documentElement.classList.add('is-extension');
}
const isWebContext = !isExtension;
let isRecording = false;
let websocket = null;
let recorder = null;
let chunkDuration = 100;
let websocketUrl = "ws://localhost:8000/asr";
let userClosing = false;
let wakeLock = null;
let startTime = null;
let timerInterval = null;
let audioContext = null;
let analyser = null;
let microphone = null;
let workletNode = null;
let recorderWorker = null;
let waveCanvas = document.getElementById("waveCanvas");
let waveCtx = waveCanvas.getContext("2d");
let animationFrame = null;
let waitingForStop = false;
let lastReceivedData = null;
let lastSignature = null;
let availableMicrophones = [];
let selectedMicrophoneId = null;
let serverUseAudioWorklet = null;
let configReadyResolve;
const configReady = new Promise((r) => (configReadyResolve = r));
let outputAudioContext = null;
let audioSource = null;
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
const statusText = document.getElementById("status");
const recordButton = document.getElementById("recordButton");
const chunkSelector = document.getElementById("chunkSelector");
const websocketInput = document.getElementById("websocketInput");
const websocketDefaultSpan = document.getElementById("wsDefaultUrl");
const linesTranscriptDiv = document.getElementById("linesTranscript");
const timerElement = document.querySelector(".timer");
const themeRadios = document.querySelectorAll('input[name="theme"]');
const microphoneSelect = document.getElementById("microphoneSelect");
const settingsToggle = document.getElementById("settingsToggle");
const settingsDiv = document.querySelector(".settings");
// if (isExtension) {
// chrome.runtime.onInstalled.addListener((details) => {
// if (details.reason.search(/install/g) === -1) {
// return;
// }
// chrome.tabs.create({
// url: chrome.runtime.getURL("welcome.html"),
// active: true
// });
// });
// }
const translationIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="12px" viewBox="0 -960 960 960" width="12px" fill="#5f6368"><path d="m603-202-34 97q-4 11-14 18t-22 7q-20 0-32.5-16.5T496-133l152-402q5-11 15-18t22-7h30q12 0 22 7t15 18l152 403q8 19-4 35.5T868-80q-13 0-22.5-7T831-106l-34-96H603ZM362-401 188-228q-11 11-27.5 11.5T132-228q-11-11-11-28t11-28l174-174q-35-35-63.5-80T190-640h84q20 39 40 68t48 58q33-33 68.5-92.5T484-720H80q-17 0-28.5-11.5T40-760q0-17 11.5-28.5T80-800h240v-40q0-17 11.5-28.5T360-880q17 0 28.5 11.5T400-840v40h240q17 0 28.5 11.5T680-760q0 17-11.5 28.5T640-720h-76q-21 72-63 148t-83 116l96 98-30 82-122-125Zm266 129h144l-72-204-72 204Z"/></svg>`
const silenceIcon = `<svg xmlns="http://www.w3.org/2000/svg" style="vertical-align: text-bottom;" height="14px" viewBox="0 -960 960 960" width="14px" fill="#5f6368"><path d="M514-556 320-752q9-3 19-5.5t21-2.5q66 0 113 47t47 113q0 11-1.5 22t-4.5 22ZM40-200v-32q0-33 17-62t47-44q51-26 115-44t141-18q26 0 49.5 2.5T456-392l-56-54q-9 3-19 4.5t-21 1.5q-66 0-113-47t-47-113q0-11 1.5-21t4.5-19L84-764q-11-11-11-28t11-28q12-12 28.5-12t27.5 12l675 685q11 11 11.5 27.5T816-80q-11 13-28 12.5T759-80L641-200h39q0 33-23.5 56.5T600-120H120q-33 0-56.5-23.5T40-200Zm80 0h480v-32q0-14-4.5-19.5T580-266q-36-18-92.5-36T360-320q-71 0-127.5 18T140-266q-9 5-14.5 14t-5.5 20v32Zm240 0Zm560-400q0 69-24.5 131.5T829-355q-12 14-30 15t-32-13q-13-13-12-31t12-33q30-38 46.5-85t16.5-98q0-51-16.5-97T767-781q-12-15-12.5-33t12.5-32q13-14 31.5-13.5T829-845q42 51 66.5 113.5T920-600Zm-182 0q0 32-10 61.5T700-484q-11 15-29.5 15.5T638-482q-13-13-13.5-31.5T633-549q6-11 9.5-24t3.5-27q0-14-3.5-27t-9.5-25q-9-17-8.5-35t13.5-31q14-14 32.5-13.5T700-716q18 25 28 54.5t10 61.5Z"/></svg>`;
const languageIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="12" viewBox="0 -960 960 960" width="12" fill="#5f6368"><path d="M480-80q-82 0-155-31.5t-127.5-86Q143-252 111.5-325T80-480q0-83 31.5-155.5t86-127Q252-817 325-848.5T480-880q83 0 155.5 31.5t127 86q54.5 54.5 86 127T880-480q0 82-31.5 155t-86 127.5q-54.5 54.5-127 86T480-80Zm0-82q26-36 45-75t31-83H404q12 44 31 83t45 75Zm-104-16q-18-33-31.5-68.5T322-320H204q29 50 72.5 87t99.5 55Zm208 0q56-18 99.5-55t72.5-87H638q-9 38-22.5 73.5T584-178ZM170-400h136q-3-20-4.5-39.5T300-480q0-21 1.5-40.5T306-560H170q-5 20-7.5 39.5T160-480q0 21 2.5 40.5T170-400Zm216 0h188q3-20 4.5-39.5T580-480q0-21-1.5-40.5T574-560H386q-3 20-4.5 39.5T380-480q0 21 1.5 40.5T386-400Zm268 0h136q5-20 7.5-39.5T800-480q0-21-2.5-40.5T790-560H654q3 20 4.5 39.5T660-480q0 21-1.5 40.5T654-400Zm-16-240h118q-29-50-72.5-87T584-782q18 33 31.5 68.5T638-640Zm-234 0h152q-12-44-31-83t-45-75q-26 36-45 75t-31 83Zm-200 0h118q9-38 22.5-73.5T376-782q-56 18-99.5 55T204-640Z"/></svg>`
const speakerIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="16px" style="vertical-align: text-bottom;" viewBox="0 -960 960 960" width="16px" fill="#5f6368"><path d="M480-480q-66 0-113-47t-47-113q0-66 47-113t113-47q66 0 113 47t47 113q0 66-47 113t-113 47ZM160-240v-32q0-34 17.5-62.5T224-378q62-31 126-46.5T480-440q66 0 130 15.5T736-378q29 15 46.5 43.5T800-272v32q0 33-23.5 56.5T720-160H240q-33 0-56.5-23.5T160-240Zm80 0h480v-32q0-11-5.5-20T700-306q-54-27-109-40.5T480-360q-56 0-111 13.5T260-306q-9 5-14.5 14t-5.5 20v32Zm240-320q33 0 56.5-23.5T560-640q0-33-23.5-56.5T480-720q-33 0-56.5 23.5T400-640q0 33 23.5 56.5T480-560Zm0-80Zm0 400Z"/></svg>`;
function getWaveStroke() {
const styles = getComputedStyle(document.documentElement);
const v = styles.getPropertyValue("--wave-stroke").trim();
return v || "#000";
}
let waveStroke = getWaveStroke();
function updateWaveStroke() {
waveStroke = getWaveStroke();
}
function applyTheme(pref) {
if (pref === "light") {
document.documentElement.setAttribute("data-theme", "light");
} else if (pref === "dark") {
document.documentElement.setAttribute("data-theme", "dark");
} else {
document.documentElement.removeAttribute("data-theme");
}
updateWaveStroke();
}
// Persisted theme preference
const savedThemePref = localStorage.getItem("themePreference") || "system";
applyTheme(savedThemePref);
if (themeRadios.length) {
themeRadios.forEach((r) => {
r.checked = r.value === savedThemePref;
r.addEventListener("change", () => {
if (r.checked) {
localStorage.setItem("themePreference", r.value);
applyTheme(r.value);
}
});
});
}
// React to OS theme changes when in "system" mode
const darkMq = window.matchMedia && window.matchMedia("(prefers-color-scheme: dark)");
const handleOsThemeChange = () => {
const pref = localStorage.getItem("themePreference") || "system";
if (pref === "system") updateWaveStroke();
};
if (darkMq && darkMq.addEventListener) {
darkMq.addEventListener("change", handleOsThemeChange);
} else if (darkMq && darkMq.addListener) {
// deprecated, but included for Safari compatibility
darkMq.addListener(handleOsThemeChange);
}
async function enumerateMicrophones() {
try {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
stream.getTracks().forEach(track => track.stop());
const devices = await navigator.mediaDevices.enumerateDevices();
availableMicrophones = devices.filter(device => device.kind === 'audioinput');
populateMicrophoneSelect();
console.log(`Found ${availableMicrophones.length} microphone(s)`);
} catch (error) {
console.error('Error enumerating microphones:', error);
statusText.textContent = "Error accessing microphones. Please grant permission.";
}
}
function populateMicrophoneSelect() {
if (!microphoneSelect) return;
microphoneSelect.innerHTML = '<option value="">Default Microphone</option>';
availableMicrophones.forEach((device, index) => {
const option = document.createElement('option');
option.value = device.deviceId;
option.textContent = device.label || `Microphone ${index + 1}`;
microphoneSelect.appendChild(option);
});
const savedMicId = localStorage.getItem('selectedMicrophone');
if (savedMicId && availableMicrophones.some(mic => mic.deviceId === savedMicId)) {
microphoneSelect.value = savedMicId;
selectedMicrophoneId = savedMicId;
}
}
function handleMicrophoneChange() {
selectedMicrophoneId = microphoneSelect.value || null;
localStorage.setItem('selectedMicrophone', selectedMicrophoneId || '');
const selectedDevice = availableMicrophones.find(mic => mic.deviceId === selectedMicrophoneId);
const deviceName = selectedDevice ? selectedDevice.label : 'Default Microphone';
console.log(`Selected microphone: ${deviceName}`);
statusText.textContent = `Microphone changed to: ${deviceName}`;
if (isRecording) {
statusText.textContent = "Switching microphone... Please wait.";
stopRecording().then(() => {
setTimeout(() => {
toggleRecording();
}, 1000);
});
}
}
// Helpers
function fmt1(x) {
const n = Number(x);
return Number.isFinite(n) ? n.toFixed(1) : x;
}
let host, port, protocol;
port = 8000;
if (isExtension) {
host = "localhost";
protocol = "ws";
} else {
host = window.location.hostname || "localhost";
port = window.location.port;
protocol = window.location.protocol === "https:" ? "wss" : "ws";
}
const defaultWebSocketUrl = `${protocol}://${host}${port ? ":" + port : ""}/asr`;
// Populate default caption and input
if (websocketDefaultSpan) websocketDefaultSpan.textContent = defaultWebSocketUrl;
websocketInput.value = defaultWebSocketUrl;
websocketUrl = defaultWebSocketUrl;
// Optional chunk selector (guard for presence)
if (chunkSelector) {
chunkSelector.addEventListener("change", () => {
chunkDuration = parseInt(chunkSelector.value);
});
}
// WebSocket input change handling
websocketInput.addEventListener("change", () => {
const urlValue = websocketInput.value.trim();
if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
return;
}
websocketUrl = urlValue;
statusText.textContent = "WebSocket URL updated. Ready to connect.";
});
function setupWebSocket() {
return new Promise((resolve, reject) => {
try {
websocket = new WebSocket(websocketUrl);
} catch (error) {
statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
reject(error);
return;
}
websocket.onopen = () => {
statusText.textContent = "Connected to server.";
resolve();
};
websocket.onclose = () => {
if (userClosing) {
if (waitingForStop) {
statusText.textContent = "Processing finalized or connection closed.";
if (lastReceivedData) {
renderLinesWithBuffer(
lastReceivedData.lines || [],
lastReceivedData.buffer_diarization || "",
lastReceivedData.buffer_transcription || "",
lastReceivedData.buffer_translation || "",
0,
0,
true
);
}
}
} else {
statusText.textContent = "Disconnected from the WebSocket server. (Check logs if model is loading.)";
if (isRecording) {
stopRecording();
}
}
isRecording = false;
waitingForStop = false;
userClosing = false;
lastReceivedData = null;
websocket = null;
updateUI();
};
websocket.onerror = () => {
statusText.textContent = "Error connecting to WebSocket.";
reject(new Error("Error connecting to WebSocket"));
};
websocket.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.type === "config") {
serverUseAudioWorklet = !!data.useAudioWorklet;
statusText.textContent = serverUseAudioWorklet
? "Connected. Using AudioWorklet (PCM)."
: "Connected. Using MediaRecorder (WebM).";
if (configReadyResolve) configReadyResolve();
return;
}
if (data.type === "ready_to_stop") {
console.log("Ready to stop received, finalizing display and closing WebSocket.");
waitingForStop = false;
if (lastReceivedData) {
renderLinesWithBuffer(
lastReceivedData.lines || [],
lastReceivedData.buffer_diarization || "",
lastReceivedData.buffer_transcription || "",
lastReceivedData.buffer_translation || "",
0,
0,
true
);
}
statusText.textContent = "Finished processing audio! Ready to record again.";
recordButton.disabled = false;
if (websocket) {
websocket.close();
}
return;
}
lastReceivedData = data;
const {
lines = [],
buffer_transcription = "",
buffer_diarization = "",
buffer_translation = "",
remaining_time_transcription = 0,
remaining_time_diarization = 0,
status = "active_transcription",
} = data;
renderLinesWithBuffer(
lines,
buffer_diarization,
buffer_transcription,
buffer_translation,
remaining_time_diarization,
remaining_time_transcription,
false,
status
);
};
});
}
function renderLinesWithBuffer(
lines,
buffer_diarization,
buffer_transcription,
buffer_translation,
remaining_time_diarization,
remaining_time_transcription,
isFinalizing = false,
current_status = "active_transcription"
) {
if (current_status === "no_audio_detected") {
linesTranscriptDiv.innerHTML =
"<p style='text-align: center; color: var(--muted); margin-top: 20px;'><em>No audio detected...</em></p>";
return;
}
const showLoading = !isFinalizing && (lines || []).some((it) => it.speaker == 0);
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
const signature = JSON.stringify({
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end, detected_language: it.detected_language })),
buffer_transcription: buffer_transcription || "",
buffer_diarization: buffer_diarization || "",
buffer_translation: buffer_translation,
status: current_status,
showLoading,
showTransLag,
showDiaLag,
isFinalizing: !!isFinalizing,
});
if (lastSignature === signature) {
const t = document.querySelector(".lag-transcription-value");
if (t) t.textContent = fmt1(remaining_time_transcription);
const d = document.querySelector(".lag-diarization-value");
if (d) d.textContent = fmt1(remaining_time_diarization);
const ld = document.querySelector(".loading-diarization-value");
if (ld) ld.textContent = fmt1(remaining_time_diarization);
return;
}
lastSignature = signature;
const linesHtml = (lines || [])
.map((item, idx) => {
let timeInfo = "";
if (item.start !== undefined && item.end !== undefined) {
timeInfo = ` ${item.start} - ${item.end}`;
}
let speakerLabel = "";
if (item.speaker === -2) {
speakerLabel = `<span class="silence">${silenceIcon}<span id='timeInfo'>${timeInfo}</span></span>`;
} else if (item.speaker == 0 && !isFinalizing) {
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
remaining_time_diarization
)}</span> second(s) of audio are undergoing diarization</span></span>`;
} else if (item.speaker !== 0) {
const speakerNum = `<span class="speaker-badge">${item.speaker}</span>`;
speakerLabel = `<span id="speaker">${speakerIcon}${speakerNum}<span id='timeInfo'>${timeInfo}</span></span>`;
if (item.detected_language) {
speakerLabel += `<span class="label_language">${languageIcon}<span>${item.detected_language}</span></span>`;
}
}
let currentLineText = item.text || "";
if (idx === lines.length - 1) {
if (!isFinalizing && item.speaker !== -2) {
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(
remaining_time_transcription
)}</span>s</span></span>`;
if (buffer_diarization && remaining_time_diarization) {
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(
remaining_time_diarization
)}</span>s</span></span>`;
}
}
if (buffer_diarization) {
if (isFinalizing) {
currentLineText +=
(currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
} else {
currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
}
}
if (buffer_transcription) {
if (isFinalizing) {
currentLineText +=
(currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") +
buffer_transcription.trim();
} else {
currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
}
}
}
let translationContent = "";
if (item.translation) {
translationContent += item.translation.trim();
}
if (idx === lines.length - 1 && buffer_translation) {
const bufferPiece = isFinalizing
? buffer_translation
: `<span class="buffer_translation">${buffer_translation}</span>`;
translationContent += translationContent ? `${bufferPiece}` : bufferPiece;
}
if (translationContent.trim().length > 0) {
currentLineText += `
<div>
<div class="label_translation">
${translationIcon}
<span class="translation_text">${translationContent}</span>
</div>
</div>`;
}
return currentLineText.trim().length > 0 || speakerLabel.length > 0
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
: `<p>${speakerLabel}<br/></p>`;
})
.join("");
linesTranscriptDiv.innerHTML = linesHtml;
const transcriptContainer = document.querySelector('.transcript-container');
if (transcriptContainer) {
transcriptContainer.scrollTo({ top: transcriptContainer.scrollHeight, behavior: "smooth" });
}
}
function updateTimer() {
if (!startTime) return;
const elapsed = Math.floor((Date.now() - startTime) / 1000);
const minutes = Math.floor(elapsed / 60).toString().padStart(2, "0");
const seconds = (elapsed % 60).toString().padStart(2, "0");
timerElement.textContent = `${minutes}:${seconds}`;
}
function drawWaveform() {
if (!analyser) return;
const bufferLength = analyser.frequencyBinCount;
const dataArray = new Uint8Array(bufferLength);
analyser.getByteTimeDomainData(dataArray);
waveCtx.clearRect(
0,
0,
waveCanvas.width / (window.devicePixelRatio || 1),
waveCanvas.height / (window.devicePixelRatio || 1)
);
waveCtx.lineWidth = 1;
waveCtx.strokeStyle = waveStroke;
waveCtx.beginPath();
const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength;
let x = 0;
for (let i = 0; i < bufferLength; i++) {
const v = dataArray[i] / 128.0;
const y = (v * (waveCanvas.height / (window.devicePixelRatio || 1))) / 2;
if (i === 0) {
waveCtx.moveTo(x, y);
} else {
waveCtx.lineTo(x, y);
}
x += sliceWidth;
}
waveCtx.lineTo(
waveCanvas.width / (window.devicePixelRatio || 1),
(waveCanvas.height / (window.devicePixelRatio || 1)) / 2
);
waveCtx.stroke();
animationFrame = requestAnimationFrame(drawWaveform);
}
async function startRecording() {
try {
try {
wakeLock = await navigator.wakeLock.request("screen");
} catch (err) {
console.log("Error acquiring wake lock.");
}
let stream;
// chromium extension. in the future, both chrome page audio and mic will be used
if (isExtension) {
try {
stream = await new Promise((resolve, reject) => {
chrome.tabCapture.capture({audio: true}, (s) => {
if (s) {
resolve(s);
} else {
reject(new Error('Tab capture failed or not available'));
}
});
});
try {
outputAudioContext = new (window.AudioContext || window.webkitAudioContext)();
audioSource = outputAudioContext.createMediaStreamSource(stream);
audioSource.connect(outputAudioContext.destination);
} catch (audioError) {
console.warn('could not preserve system audio:', audioError);
}
statusText.textContent = "Using tab audio capture.";
} catch (tabError) {
console.log('Tab capture not available, falling back to microphone', tabError);
const audioConstraints = selectedMicrophoneId
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
: { audio: true };
stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
statusText.textContent = "Using microphone audio.";
}
} else if (isWebContext) {
const audioConstraints = selectedMicrophoneId
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
: { audio: true };
stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
}
audioContext = new (window.AudioContext || window.webkitAudioContext)();
analyser = audioContext.createAnalyser();
analyser.fftSize = 256;
microphone = audioContext.createMediaStreamSource(stream);
microphone.connect(analyser);
if (serverUseAudioWorklet) {
if (!audioContext.audioWorklet) {
throw new Error("AudioWorklet is not supported in this browser");
}
await audioContext.audioWorklet.addModule("/web/pcm_worklet.js");
workletNode = new AudioWorkletNode(audioContext, "pcm-forwarder", { numberOfInputs: 1, numberOfOutputs: 0, channelCount: 1 });
microphone.connect(workletNode);
recorderWorker = new Worker("/web/recorder_worker.js");
recorderWorker.postMessage({
command: "init",
config: {
sampleRate: audioContext.sampleRate,
},
});
recorderWorker.onmessage = (e) => {
if (websocket && websocket.readyState === WebSocket.OPEN) {
websocket.send(e.data.buffer);
}
};
workletNode.port.onmessage = (e) => {
const data = e.data;
const ab = data instanceof ArrayBuffer ? data : data.buffer;
recorderWorker.postMessage(
{
command: "record",
buffer: ab,
},
[ab]
);
};
} else {
try {
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
} catch (e) {
recorder = new MediaRecorder(stream);
}
recorder.ondataavailable = (e) => {
if (websocket && websocket.readyState === WebSocket.OPEN) {
if (e.data && e.data.size > 0) {
websocket.send(e.data);
}
}
};
recorder.start(chunkDuration);
}
startTime = Date.now();
timerInterval = setInterval(updateTimer, 1000);
drawWaveform();
isRecording = true;
updateUI();
} catch (err) {
if (window.location.hostname === "0.0.0.0") {
statusText.textContent =
"Error accessing microphone. Browsers may block microphone access on 0.0.0.0. Try using localhost:8000 instead.";
} else {
statusText.textContent = "Error accessing microphone. Please allow microphone access.";
}
console.error(err);
}
}
async function stopRecording() {
if (wakeLock) {
try {
await wakeLock.release();
} catch (e) {
// ignore
}
wakeLock = null;
}
userClosing = true;
waitingForStop = true;
if (websocket && websocket.readyState === WebSocket.OPEN) {
const emptyBlob = new Blob([], { type: "audio/webm" });
websocket.send(emptyBlob);
statusText.textContent = "Recording stopped. Processing final audio...";
}
if (recorder) {
try {
recorder.stop();
} catch (e) {
}
recorder = null;
}
if (recorderWorker) {
recorderWorker.terminate();
recorderWorker = null;
}
if (workletNode) {
try {
workletNode.port.onmessage = null;
} catch (e) {}
try {
workletNode.disconnect();
} catch (e) {}
workletNode = null;
}
if (microphone) {
microphone.disconnect();
microphone = null;
}
if (analyser) {
analyser = null;
}
if (audioContext && audioContext.state !== "closed") {
try {
await audioContext.close();
} catch (e) {
console.warn("Could not close audio context:", e);
}
audioContext = null;
}
if (audioSource) {
audioSource.disconnect();
audioSource = null;
}
if (outputAudioContext && outputAudioContext.state !== "closed") {
outputAudioContext.close()
outputAudioContext = null;
}
if (animationFrame) {
cancelAnimationFrame(animationFrame);
animationFrame = null;
}
if (timerInterval) {
clearInterval(timerInterval);
timerInterval = null;
}
timerElement.textContent = "00:00";
startTime = null;
isRecording = false;
updateUI();
}
async function toggleRecording() {
if (!isRecording) {
if (waitingForStop) {
console.log("Waiting for stop, early return");
return;
}
console.log("Connecting to WebSocket");
try {
if (websocket && websocket.readyState === WebSocket.OPEN) {
await configReady;
await startRecording();
} else {
await setupWebSocket();
await configReady;
await startRecording();
}
} catch (err) {
statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
console.error(err);
}
} else {
console.log("Stopping recording");
stopRecording();
}
}
function updateUI() {
recordButton.classList.toggle("recording", isRecording);
recordButton.disabled = waitingForStop;
if (waitingForStop) {
if (statusText.textContent !== "Recording stopped. Processing final audio...") {
statusText.textContent = "Please wait for processing to complete...";
}
} else if (isRecording) {
statusText.textContent = "";
} else {
if (
statusText.textContent !== "Finished processing audio! Ready to record again." &&
statusText.textContent !== "Processing finalized or connection closed."
) {
statusText.textContent = "Click to start transcription";
}
}
if (!waitingForStop) {
recordButton.disabled = false;
}
}
recordButton.addEventListener("click", toggleRecording);
if (microphoneSelect) {
microphoneSelect.addEventListener("change", handleMicrophoneChange);
}
document.addEventListener('DOMContentLoaded', async () => {
try {
await enumerateMicrophones();
} catch (error) {
console.log("Could not enumerate microphones on load:", error);
}
});
navigator.mediaDevices.addEventListener('devicechange', async () => {
console.log('Device change detected, re-enumerating microphones');
try {
await enumerateMicrophones();
} catch (error) {
console.log("Error re-enumerating microphones:", error);
}
});
settingsToggle.addEventListener("click", () => {
settingsDiv.classList.toggle("visible");
settingsToggle.classList.toggle("active");
});
if (isExtension) {
async function checkAndRequestPermissions() {
const micPermission = await navigator.permissions.query({
name: "microphone",
});
const permissionDisplay = document.getElementById("audioPermission");
if (permissionDisplay) {
permissionDisplay.innerText = `MICROPHONE: ${micPermission.state}`;
}
// if (micPermission.state !== "granted") {
// chrome.tabs.create({ url: "welcome.html" });
// }
const intervalId = setInterval(async () => {
const micPermission = await navigator.permissions.query({
name: "microphone",
});
if (micPermission.state === "granted") {
if (permissionDisplay) {
permissionDisplay.innerText = `MICROPHONE: ${micPermission.state}`;
}
clearInterval(intervalId);
}
}, 100);
}
void checkAndRequestPermissions();
}

View File

@@ -1,16 +0,0 @@
class PCMForwarder extends AudioWorkletProcessor {
process(inputs) {
const input = inputs[0];
if (input && input[0] && input[0].length) {
// Forward mono channel (0). If multi-channel, downmixing can be added here.
const channelData = input[0];
const copy = new Float32Array(channelData.length);
copy.set(channelData);
this.port.postMessage(copy, [copy.buffer]);
}
// Keep processor alive
return true;
}
}
registerProcessor('pcm-forwarder', PCMForwarder);

View File

@@ -1,58 +0,0 @@
let sampleRate = 48000;
let targetSampleRate = 16000;
self.onmessage = function (e) {
switch (e.data.command) {
case 'init':
init(e.data.config);
break;
case 'record':
record(e.data.buffer);
break;
}
};
function init(config) {
sampleRate = config.sampleRate;
targetSampleRate = config.targetSampleRate || 16000;
}
function record(inputBuffer) {
const buffer = new Float32Array(inputBuffer);
const resampledBuffer = resample(buffer, sampleRate, targetSampleRate);
const pcmBuffer = toPCM(resampledBuffer);
self.postMessage({ buffer: pcmBuffer }, [pcmBuffer]);
}
function resample(buffer, from, to) {
if (from === to) {
return buffer;
}
const ratio = from / to;
const newLength = Math.round(buffer.length / ratio);
const result = new Float32Array(newLength);
let offsetResult = 0;
let offsetBuffer = 0;
while (offsetResult < result.length) {
const nextOffsetBuffer = Math.round((offsetResult + 1) * ratio);
let accum = 0, count = 0;
for (let i = offsetBuffer; i < nextOffsetBuffer && i < buffer.length; i++) {
accum += buffer[i];
count++;
}
result[offsetResult] = accum / count;
offsetResult++;
offsetBuffer = nextOffsetBuffer;
}
return result;
}
function toPCM(input) {
const buffer = new ArrayBuffer(input.length * 2);
const view = new DataView(buffer);
for (let i = 0; i < input.length; i++) {
const s = Math.max(-1, Math.min(1, input[i]));
view.setInt16(i * 2, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
}
return buffer;
}

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-120q-151 0-255.5-104.5T120-480q0-138 90-239.5T440-838q13-2 23 3.5t16 14.5q6 9 6.5 21t-7.5 23q-17 26-25.5 55t-8.5 61q0 90 63 153t153 63q31 0 61.5-9t54.5-25q11-7 22.5-6.5T819-479q10 5 15.5 15t3.5 24q-14 138-117.5 229T480-120Zm0-80q88 0 158-48.5T740-375q-20 5-40 8t-40 3q-123 0-209.5-86.5T364-660q0-20 3-40t8-40q-78 32-126.5 102T200-480q0 116 82 198t198 82Zm-10-270Z"/></svg>

Before

Width:  |  Height:  |  Size: 493 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-80q-82 0-155-31.5t-127.5-86Q143-252 111.5-325T80-480q0-83 31.5-155.5t86-127Q252-817 325-848.5T480-880q83 0 155.5 31.5t127 86q54.5 54.5 86 127T880-480q0 82-31.5 155t-86 127.5q-54.5 54.5-127 86T480-80Zm0-82q26-36 45-75t31-83H404q12 44 31 83t45 75Zm-104-16q-18-33-31.5-68.5T322-320H204q29 50 72.5 87t99.5 55Zm208 0q56-18 99.5-55t72.5-87H638q-9 38-22.5 73.5T584-178ZM170-400h136q-3-20-4.5-39.5T300-480q0-21 1.5-40.5T306-560H170q-5 20-7.5 39.5T160-480q0 21 2.5 40.5T170-400Zm216 0h188q3-20 4.5-39.5T580-480q0-21-1.5-40.5T574-560H386q-3 20-4.5 39.5T380-480q0 21 1.5 40.5T386-400Zm268 0h136q5-20 7.5-39.5T800-480q0-21-2.5-40.5T790-560H654q3 20 4.5 39.5T660-480q0 21-1.5 40.5T654-400Zm-16-240h118q-29-50-72.5-87T584-782q18 33 31.5 68.5T638-640Zm-234 0h152q-12-44-31-83t-45-75q-26 36-45 75t-31 83Zm-200 0h118q9-38 22.5-73.5T376-782q-56 18-99.5 55T204-640Z"/></svg>

Before

Width:  |  Height:  |  Size: 976 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-360q50 0 85-35t35-85q0-50-35-85t-85-35q-50 0-85 35t-35 85q0 50 35 85t85 35Zm0 80q-83 0-141.5-58.5T280-480q0-83 58.5-141.5T480-680q83 0 141.5 58.5T680-480q0 83-58.5 141.5T480-280ZM80-440q-17 0-28.5-11.5T40-480q0-17 11.5-28.5T80-520h80q17 0 28.5 11.5T200-480q0 17-11.5 28.5T160-440H80Zm720 0q-17 0-28.5-11.5T760-480q0-17 11.5-28.5T800-520h80q17 0 28.5 11.5T920-480q0 17-11.5 28.5T880-440h-80ZM480-760q-17 0-28.5-11.5T440-800v-80q0-17 11.5-28.5T480-920q17 0 28.5 11.5T520-880v80q0 17-11.5 28.5T480-760Zm0 720q-17 0-28.5-11.5T440-80v-80q0-17 11.5-28.5T480-200q17 0 28.5 11.5T520-160v80q0 17-11.5 28.5T480-40ZM226-678l-43-42q-12-11-11.5-28t11.5-29q12-12 29-12t28 12l42 43q11 12 11 28t-11 28q-11 12-27.5 11.5T226-678Zm494 495-42-43q-11-12-11-28.5t11-27.5q11-12 27.5-11.5T734-282l43 42q12 11 11.5 28T777-183q-12 12-29 12t-28-12Zm-42-495q-12-11-11.5-27.5T678-734l42-43q11-12 28-11.5t29 11.5q12 12 12 29t-12 28l-43 42q-12 11-28 11t-28-11ZM183-183q-12-12-12-29t12-28l43-42q12-11 28.5-11t27.5 11q12 11 11.5 27.5T282-226l-42 43q-11 12-28 11.5T183-183Zm297-297Z"/></svg>

Before

Width:  |  Height:  |  Size: 1.2 KiB

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M433-80q-27 0-46.5-18T363-142l-9-66q-13-5-24.5-12T307-235l-62 26q-25 11-50 2t-39-32l-47-82q-14-23-8-49t27-43l53-40q-1-7-1-13.5v-27q0-6.5 1-13.5l-53-40q-21-17-27-43t8-49l47-82q14-23 39-32t50 2l62 26q11-8 23-15t24-12l9-66q4-26 23.5-44t46.5-18h94q27 0 46.5 18t23.5 44l9 66q13 5 24.5 12t22.5 15l62-26q25-11 50-2t39 32l47 82q14 23 8 49t-27 43l-53 40q1 7 1 13.5v27q0 6.5-2 13.5l53 40q21 17 27 43t-8 49l-48 82q-14 23-39 32t-50-2l-60-26q-11 8-23 15t-24 12l-9 66q-4 26-23.5 44T527-80h-94Zm7-80h79l14-106q31-8 57.5-23.5T639-327l99 41 39-68-86-65q5-14 7-29.5t2-31.5q0-16-2-31.5t-7-29.5l86-65-39-68-99 42q-22-23-48.5-38.5T533-694l-13-106h-79l-14 106q-31 8-57.5 23.5T321-633l-99-41-39 68 86 64q-5 15-7 30t-2 32q0 16 2 31t7 30l-86 65 39 68 99-42q22 23 48.5 38.5T427-266l13 106Zm42-180q58 0 99-41t41-99q0-58-41-99t-99-41q-59 0-99.5 41T342-480q0 58 40.5 99t99.5 41Zm-2-140Z"/></svg>

Before

Width:  |  Height:  |  Size: 982 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M514-556 320-752q9-3 19-5.5t21-2.5q66 0 113 47t47 113q0 11-1.5 22t-4.5 22ZM40-200v-32q0-33 17-62t47-44q51-26 115-44t141-18q26 0 49.5 2.5T456-392l-56-54q-9 3-19 4.5t-21 1.5q-66 0-113-47t-47-113q0-11 1.5-21t4.5-19L84-764q-11-11-11-28t11-28q12-12 28.5-12t27.5 12l675 685q11 11 11.5 27.5T816-80q-11 13-28 12.5T759-80L641-200h39q0 33-23.5 56.5T600-120H120q-33 0-56.5-23.5T40-200Zm80 0h480v-32q0-14-4.5-19.5T580-266q-36-18-92.5-36T360-320q-71 0-127.5 18T140-266q-9 5-14.5 14t-5.5 20v32Zm240 0Zm560-400q0 69-24.5 131.5T829-355q-12 14-30 15t-32-13q-13-13-12-31t12-33q30-38 46.5-85t16.5-98q0-51-16.5-97T767-781q-12-15-12.5-33t12.5-32q13-14 31.5-13.5T829-845q42 51 66.5 113.5T920-600Zm-182 0q0 32-10 61.5T700-484q-11 15-29.5 15.5T638-482q-13-13-13.5-31.5T633-549q6-11 9.5-24t3.5-27q0-14-3.5-27t-9.5-25q-9-17-8.5-35t13.5-31q14-14 32.5-13.5T700-716q18 25 28 54.5t10 61.5Z"/></svg>

Before

Width:  |  Height:  |  Size: 984 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-480q-66 0-113-47t-47-113q0-66 47-113t113-47q66 0 113 47t47 113q0 66-47 113t-113 47ZM160-240v-32q0-34 17.5-62.5T224-378q62-31 126-46.5T480-440q66 0 130 15.5T736-378q29 15 46.5 43.5T800-272v32q0 33-23.5 56.5T720-160H240q-33 0-56.5-23.5T160-240Zm80 0h480v-32q0-11-5.5-20T700-306q-54-27-109-40.5T480-360q-56 0-111 13.5T260-306q-9 5-14.5 14t-5.5 20v32Zm240-320q33 0 56.5-23.5T560-640q0-33-23.5-56.5T480-720q-33 0-56.5 23.5T400-640q0 33 23.5 56.5T480-560Zm0-80Zm0 400Z"/></svg>

Before

Width:  |  Height:  |  Size: 592 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M396-396q-32-32-58.5-67T289-537q-5 14-6.5 28.5T281-480q0 83 58 141t141 58q14 0 28.5-2t28.5-6q-39-22-74-48.5T396-396Zm85 196q-56 0-107-21t-91-61q-40-40-61-91t-21-107q0-51 17-97.5t50-84.5q13-14 32-9.5t27 24.5q21 55 52.5 104t73.5 91q42 42 91 73.5T648-326q20 8 24.5 27t-9.5 32q-38 33-84.5 50T481-200Zm223-192q-16-5-23-20.5t-4-32.5q9-48-6-94.5T621-621q-35-35-80.5-49.5T448-677q-17 3-32-4t-21-23q-6-16 1.5-31t23.5-19q69-15 138 4.5T679-678q51 51 71 120t5 138q-4 17-19 25t-32 3ZM480-840q-17 0-28.5-11.5T440-880v-40q0-17 11.5-28.5T480-960q17 0 28.5 11.5T520-920v40q0 17-11.5 28.5T480-840Zm0 840q-17 0-28.5-11.5T440-40v-40q0-17 11.5-28.5T480-120q17 0 28.5 11.5T520-80v40q0 17-11.5 28.5T480 0Zm255-734q-12-12-12-28.5t12-28.5l28-28q11-11 27.5-11t28.5 11q12 12 12 28.5T819-762l-28 28q-12 12-28 12t-28-12ZM141-141q-12-12-12-28.5t12-28.5l28-28q12-12 28-12t28 12q12 12 12 28.5T225-169l-28 28q-11 11-27.5 11T141-141Zm739-299q-17 0-28.5-11.5T840-480q0-17 11.5-28.5T880-520h40q17 0 28.5 11.5T960-480q0 17-11.5 28.5T920-440h-40Zm-840 0q-17 0-28.5-11.5T0-480q0-17 11.5-28.5T40-520h40q17 0 28.5 11.5T120-480q0 17-11.5 28.5T80-440H40Zm779 299q-12 12-28.5 12T762-141l-28-28q-12-12-12-28t12-28q12-12 28.5-12t28.5 12l28 28q11 11 11 27.5T819-141ZM226-735q-12 12-28.5 12T169-735l-28-28q-11-11-11-27.5t11-28.5q12-12 28.5-12t28.5 12l28 28q12 12 12 28t-12 28Zm170 339Z"/></svg>

Before

Width:  |  Height:  |  Size: 1.4 KiB

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="m603-202-34 97q-4 11-14 18t-22 7q-20 0-32.5-16.5T496-133l152-402q5-11 15-18t22-7h30q12 0 22 7t15 18l152 403q8 19-4 35.5T868-80q-13 0-22.5-7T831-106l-34-96H603ZM362-401 188-228q-11 11-27.5 11.5T132-228q-11-11-11-28t11-28l174-174q-35-35-63.5-80T190-640h84q20 39 40 68t48 58q33-33 68.5-92.5T484-720H80q-17 0-28.5-11.5T40-760q0-17 11.5-28.5T80-800h240v-40q0-17 11.5-28.5T360-880q17 0 28.5 11.5T400-840v40h240q17 0 28.5 11.5T680-760q0 17-11.5 28.5T640-720h-76q-21 72-63 148t-83 116l96 98-30 82-122-125Zm266 129h144l-72-204-72 204Z"/></svg>

Before

Width:  |  Height:  |  Size: 650 B

View File

@@ -1,116 +0,0 @@
import base64
import importlib.resources as resources
import logging
logger = logging.getLogger(__name__)
def get_web_interface_html():
"""Loads the HTML for the web interface using importlib.resources."""
try:
with resources.files('whisperlivekit.web').joinpath('live_transcription.html').open('r', encoding='utf-8') as f:
return f.read()
except Exception as e:
logger.error(f"Error loading web interface HTML: {e}")
return "<html><body><h1>Error loading interface</h1></body></html>"
def get_inline_ui_html():
"""Returns the complete web interface HTML with all assets embedded in a single call."""
try:
with resources.files('whisperlivekit.web').joinpath('live_transcription.html').open('r', encoding='utf-8') as f:
html_content = f.read()
with resources.files('whisperlivekit.web').joinpath('live_transcription.css').open('r', encoding='utf-8') as f:
css_content = f.read()
with resources.files('whisperlivekit.web').joinpath('live_transcription.js').open('r', encoding='utf-8') as f:
js_content = f.read()
with resources.files('whisperlivekit.web').joinpath('pcm_worklet.js').open('r', encoding='utf-8') as f:
worklet_code = f.read()
with resources.files('whisperlivekit.web').joinpath('recorder_worker.js').open('r', encoding='utf-8') as f:
worker_code = f.read()
js_content = js_content.replace(
'await audioContext.audioWorklet.addModule("/web/pcm_worklet.js");',
'const workletBlob = new Blob([`' + worklet_code + '`], { type: "application/javascript" });\n' +
'const workletUrl = URL.createObjectURL(workletBlob);\n' +
'await audioContext.audioWorklet.addModule(workletUrl);'
)
js_content = js_content.replace(
'recorderWorker = new Worker("/web/recorder_worker.js");',
'const workerBlob = new Blob([`' + worker_code + '`], { type: "application/javascript" });\n' +
'const workerUrl = URL.createObjectURL(workerBlob);\n' +
'recorderWorker = new Worker(workerUrl);'
)
# SVG files
with resources.files('whisperlivekit.web').joinpath('src', 'system_mode.svg').open('r', encoding='utf-8') as f:
system_svg = f.read()
system_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(system_svg.encode('utf-8')).decode('utf-8')}"
with resources.files('whisperlivekit.web').joinpath('src', 'light_mode.svg').open('r', encoding='utf-8') as f:
light_svg = f.read()
light_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(light_svg.encode('utf-8')).decode('utf-8')}"
with resources.files('whisperlivekit.web').joinpath('src', 'dark_mode.svg').open('r', encoding='utf-8') as f:
dark_svg = f.read()
dark_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(dark_svg.encode('utf-8')).decode('utf-8')}"
with resources.files('whisperlivekit.web').joinpath('src', 'settings.svg').open('r', encoding='utf-8') as f:
settings = f.read()
settings_uri = f"data:image/svg+xml;base64,{base64.b64encode(settings.encode('utf-8')).decode('utf-8')}"
# Replace external references
html_content = html_content.replace(
'<link rel="stylesheet" href="live_transcription.css" />',
f'<style>\n{css_content}\n</style>'
)
html_content = html_content.replace(
'<script src="live_transcription.js"></script>',
f'<script>\n{js_content}\n</script>'
)
# Replace SVG references
html_content = html_content.replace(
'<img src="/web/src/system_mode.svg" alt="" />',
f'<img src="{system_data_uri}" alt="" />'
)
html_content = html_content.replace(
'<img src="/web/src/light_mode.svg" alt="" />',
f'<img src="{light_data_uri}" alt="" />'
)
html_content = html_content.replace(
'<img src="/web/src/dark_mode.svg" alt="" />',
f'<img src="{dark_data_uri}" alt="" />'
)
html_content = html_content.replace(
'<img src="web/src/settings.svg" alt="Settings" />',
f'<img src="{settings_uri}" alt="" />'
)
return html_content
except Exception as e:
logger.error(f"Error creating embedded web interface: {e}")
return "<html><body><h1>Error loading embedded interface</h1></body></html>"
if __name__ == '__main__':
import pathlib
import uvicorn
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from starlette.staticfiles import StaticFiles
import whisperlivekit.web as webpkg
app = FastAPI()
web_dir = pathlib.Path(webpkg.__file__).parent
app.mount("/web", StaticFiles(directory=str(web_dir)), name="web")
@app.get("/")
async def get():
return HTMLResponse(get_inline_ui_html())
uvicorn.run(app=app)

View File

@@ -1,642 +0,0 @@
import hashlib
import io
import json
import os
import urllib
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Union
import torch
from torch import Tensor
from tqdm import tqdm
from whisperlivekit.whisper.audio import (load_audio, log_mel_spectrogram,
pad_or_trim)
from whisperlivekit.whisper.decoding import (DecodingOptions, DecodingResult,
decode, detect_language)
from whisperlivekit.whisper.model import ModelDimensions, Whisper
from whisperlivekit.whisper.transcribe import transcribe
from whisperlivekit.whisper.version import __version__
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
}
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
"turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
}
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, os.path.basename(url))
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
with open(download_target, "rb") as f:
model_bytes = f.read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
else:
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
return model_bytes if in_memory else download_target
def available_models() -> List[str]:
"""Returns the names of available models"""
return list(_MODELS.keys())
def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
"""
attempt to infer ModelDimensions from a HF style config.json located
next to the given checkpoint, usefull for distilled models/MLX models.
"""
candidates = []
if os.path.isdir(path):
candidates.append(os.path.join(path, "config.json"))
else:
candidates.append(os.path.join(os.path.dirname(path), "config.json"))
for candidate in candidates:
if not os.path.isfile(candidate):
continue
with open(candidate, "r", encoding="utf-8") as f:
config = json.load(f)
# native Whisper format
native_keys = ["n_mels", "n_audio_ctx", "n_audio_state", "n_audio_head",
"n_audio_layer", "n_vocab", "n_text_ctx", "n_text_state",
"n_text_head", "n_text_layer"]
if all(k in config for k in native_keys):
return ModelDimensions(
n_mels=config["n_mels"],
n_audio_ctx=config["n_audio_ctx"],
n_audio_state=config["n_audio_state"],
n_audio_head=config["n_audio_head"],
n_audio_layer=config["n_audio_layer"],
n_vocab=config["n_vocab"],
n_text_ctx=config["n_text_ctx"],
n_text_state=config["n_text_state"],
n_text_head=config["n_text_head"],
n_text_layer=config["n_text_layer"],
)
# HuggingFace format
try:
return ModelDimensions(
n_mels=config["num_mel_bins"],
n_audio_ctx=config["max_source_positions"],
n_audio_state=config["d_model"],
n_audio_head=config["encoder_attention_heads"],
n_audio_layer=config.get("encoder_layers")
or config["num_hidden_layers"],
n_vocab=config["vocab_size"],
n_text_ctx=config["max_target_positions"],
n_text_state=config["d_model"],
n_text_head=config["decoder_attention_heads"],
n_text_layer=config["decoder_layers"],
)
except KeyError as err:
warnings.warn(f"Missing key {err} in HuggingFace config {candidate}")
return None
return None
def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
converts a HF checkpoint state_dict into the naming convention used by
default whisper
"""
if not any(k.startswith("model.") for k in state_dict):
return state_dict
def map_block(prefix: str, target_prefix: str, remainder: str) -> Optional[str]:
if remainder.startswith("self_attn."):
suffix = remainder.split(".", 1)[1]
mapping = {
"q_proj": "attn.query",
"k_proj": "attn.key",
"v_proj": "attn.value",
"out_proj": "attn.out",
}
stem = mapping.get(suffix.split(".")[0])
if stem:
rest = suffix.split(".", 1)[1] if "." in suffix else ""
return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
elif remainder == "self_attn_layer_norm.weight":
return f"{target_prefix}.attn_ln.weight"
elif remainder == "self_attn_layer_norm.bias":
return f"{target_prefix}.attn_ln.bias"
elif remainder.startswith("encoder_attn."):
suffix = remainder.split(".", 1)[1]
mapping = {
"q_proj": "cross_attn.query",
"k_proj": "cross_attn.key",
"v_proj": "cross_attn.value",
"out_proj": "cross_attn.out",
}
stem = mapping.get(suffix.split(".", 1)[0])
if stem:
rest = suffix.split(".", 1)[1] if "." in suffix else ""
return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
elif remainder == "encoder_attn_layer_norm.weight":
return f"{target_prefix}.cross_attn_ln.weight"
elif remainder == "encoder_attn_layer_norm.bias":
return f"{target_prefix}.cross_attn_ln.bias"
elif remainder.startswith("fc1."):
return f"{target_prefix}.mlp.0.{remainder.split('.',1)[1]}"
elif remainder.startswith("fc2."):
return f"{target_prefix}.mlp.2.{remainder.split('.',1)[1]}"
elif remainder == "final_layer_norm.weight":
return f"{target_prefix}.mlp_ln.weight"
elif remainder == "final_layer_norm.bias":
return f"{target_prefix}.mlp_ln.bias"
return None
converted = {}
for key, value in state_dict.items():
if not key.startswith("model."):
continue
subkey = key[len("model.") :]
if subkey.startswith("encoder.layers."):
parts = subkey.split(".")
layer_idx = parts[2]
remainder = ".".join(parts[3:])
mapped = map_block(subkey, f"encoder.blocks.{layer_idx}", remainder)
elif subkey.startswith("decoder.layers."):
parts = subkey.split(".")
layer_idx = parts[2]
remainder = ".".join(parts[3:])
mapped = map_block(subkey, f"decoder.blocks.{layer_idx}", remainder)
elif subkey.startswith("encoder.conv") or subkey.startswith("decoder.conv"):
mapped = subkey
elif subkey == "encoder.embed_positions.weight":
mapped = "encoder.positional_embedding"
elif subkey == "decoder.embed_positions.weight":
mapped = "decoder.positional_embedding"
elif subkey == "encoder.layer_norm.weight":
mapped = "encoder.ln_post.weight"
elif subkey == "encoder.layer_norm.bias":
mapped = "encoder.ln_post.bias"
elif subkey.startswith("decoder.embed_tokens."):
mapped = subkey.replace("embed_tokens", "token_embedding", 1)
elif subkey == "decoder.layer_norm.weight":
mapped = "decoder.ln.weight"
elif subkey == "decoder.layer_norm.bias":
mapped = "decoder.ln.bias"
else:
mapped = None
if mapped:
converted[mapped] = value
return converted if converted else state_dict
def _convert_mlx_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Converts an mlx whisper checkpoint to a default openai whisper one
"""
if not any("mlp1" in k or "mlp2" in k for k in state_dict):
return state_dict
converted = {}
for key, value in state_dict.items():
if key == "alignment_heads":
continue
new_key = key.replace(".mlp1.", ".mlp.0.").replace(".mlp2.", ".mlp.2.")
converted[new_key] = value
return converted
def _load_lora_state(lora_path: str):
safe_path = os.path.join(lora_path, "adapter_model.safetensors")
bin_path = os.path.join(lora_path, "adapter_model.bin")
if os.path.isfile(safe_path):
try:
from safetensors.torch import load_file
except ImportError as exc:
raise ImportError(
"Loading LoRA adapters stored as .safetensors requires the `safetensors` package."
) from exc
return load_file(safe_path)
if os.path.isfile(bin_path):
return torch.load(bin_path, map_location="cpu")
raise FileNotFoundError(
f"No adapter weights found under {lora_path}. Expected adapter_model.safetensors or adapter_model.bin."
)
def _collapse_hf_module_name(module: str):
if module.startswith("base_model."):
module = module[len("base_model.") :]
if module.startswith("model.model."):
module = module[len("model.") :]
if not module.startswith("model."):
module = f"model.{module}"
return module
def _resolve_lora_path(lora_path: Optional[str]) -> Optional[str]:
"""
Resolve LoRA adapter path - handles both local paths and HuggingFace repo IDs.
If lora_path is a local directory containing adapter files, returns it as-is.
If lora_path looks like a HuggingFace repo ID (contains '/'), downloads and caches it.
"""
if not lora_path:
return None
# Check if it's already a valid local path
if os.path.isdir(lora_path):
config_path = os.path.join(lora_path, "adapter_config.json")
if os.path.isfile(config_path):
return lora_path
# Try to download from HuggingFace Hub
if "/" in lora_path:
try:
from huggingface_hub import snapshot_download
local_path = snapshot_download(
repo_id=lora_path,
allow_patterns=["adapter_config.json", "adapter_model.*"],
)
return local_path
except Exception as e:
raise FileNotFoundError(
f"Could not find LoRA adapter at local path or HuggingFace Hub: {lora_path}. Error: {e}"
)
raise FileNotFoundError(
f"LoRA path '{lora_path}' is not a valid local directory or HuggingFace repo ID."
)
def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str]):
if not lora_path:
return
# Resolve path (handles HuggingFace Hub download)
lora_path = _resolve_lora_path(lora_path)
if not lora_path:
return
config_path = os.path.join(lora_path, "adapter_config.json")
if not os.path.isfile(config_path):
raise FileNotFoundError(f"Missing adapter_config.json inside {lora_path}")
with open(config_path, "r", encoding="utf-8") as handle:
config = json.load(handle)
if config.get("peft_type") != "LORA":
raise ValueError("Only LoRA adapters are supported.")
r = config.get("r")
alpha = config.get("lora_alpha") or config.get("alpha")
if not r or not alpha:
raise ValueError("LoRA config must include `r` and `lora_alpha`.")
scaling = alpha / r
adapter_state = _load_lora_state(lora_path)
lora_layers: Dict[str, Dict[str, Tensor]] = {}
for key, tensor in adapter_state.items():
if key.endswith("lora_A.weight"):
module = key[: -len(".lora_A.weight")]
lora_layers.setdefault(module, {})["A"] = tensor
elif key.endswith("lora_B.weight"):
module = key[: -len(".lora_B.weight")]
lora_layers.setdefault(module, {})["B"] = tensor
if not lora_layers:
raise ValueError(f"No LoRA tensors found in {lora_path}")
for module, parts in lora_layers.items():
if "A" not in parts or "B" not in parts:
raise ValueError(f"Incomplete LoRA tensors for module '{module}'")
hf_module = _collapse_hf_module_name(module)
hf_weight_key = f"{hf_module}.weight"
delta = parts["B"] @ parts["A"]
delta = delta * scaling
converted = _convert_hf_state_dict({hf_weight_key: delta})
if not converted:
raise KeyError(f"Failed to map LoRA module '{module}' into Whisper state dict.")
target_name, delta_tensor = next(iter(converted.items()))
if target_name not in state_dict:
raise KeyError(
f"LoRA module '{module}' mapped to '{target_name}', but the base model has no such parameter."
)
state_dict[target_name] = state_dict[target_name] + delta_tensor.to(
dtype=state_dict[target_name].dtype, device=state_dict[target_name].device
)
def _load_checkpoint(
file_path: Union[str, Path],
device: str,
in_memory: bool = False,
checkpoint_bytes: Optional[bytes] = None,
) -> Dict[str, torch.Tensor]:
"""
Load a checkpoint from a single file.
Handles .pt, .bin, and .safetensors formats.
"""
if checkpoint_bytes is not None:
with io.BytesIO(checkpoint_bytes) as fp:
return torch.load(fp, map_location=device)
file_path = Path(file_path)
suffix = file_path.suffix.lower()
if suffix == '.safetensors':
try:
from safetensors.torch import load_file
except ImportError:
raise ImportError(
"Please install safetensors to load .safetensors model files: `pip install safetensors`"
)
return load_file(str(file_path), device=device)
else:
if in_memory:
with open(file_path, "rb") as f:
checkpoint_bytes = f.read()
with io.BytesIO(checkpoint_bytes) as fp:
return torch.load(fp, map_location=device)
else:
with open(file_path, "rb") as fp:
return torch.load(fp, map_location=device)
def _load_sharded_checkpoint(
shard_files: List[Path],
device: str,
) -> Dict[str, torch.Tensor]:
"""
Load a sharded checkpoint (multiple .safetensors or .bin files).
Merges all shards into a single state dict.
"""
merged_state_dict = {}
first_suffix = shard_files[0].suffix.lower()
if first_suffix == '.safetensors':
try:
from safetensors.torch import load_file
except ImportError:
raise ImportError(
"Please install safetensors to load sharded .safetensors model: `pip install safetensors`"
)
for shard_path in shard_files:
shard_dict = load_file(str(shard_path), device=device)
merged_state_dict.update(shard_dict)
else:
for shard_path in shard_files:
with open(shard_path, "rb") as fp:
shard_dict = torch.load(fp, map_location=device)
if isinstance(shard_dict, dict):
merged_state_dict.update(shard_dict)
return merged_state_dict
def load_model(
name: str,
device: Optional[Union[str, torch.device]] = None,
download_root: str = None,
in_memory: bool = False,
decoder_only: bool = False,
custom_alignment_heads: Optional[str] = None,
lora_path: Optional[str] = None,
) -> Whisper:
"""
Load a Whisper ASR model
Parameters
----------
name : str
one of the official model names listed by `whisper.available_models()`, or
path to a model checkpoint containing the model dimensions and the model state_dict.
Can be a single file (.pt, .bin, .safetensors), a directory containing model files,
or a sharded model directory with files like model-00001-of-00002.safetensors.
device : Union[str, torch.device]
the PyTorch device to put the model into
download_root: str
path to download the model files; by default, it uses "~/.cache/whisper"
in_memory: bool
whether to preload the model weights into host memory
lora_path: str
optional directory containing PEFT LoRA adapter weights (adapter_config + adapter_model)
Returns
-------
model : Whisper
The Whisper ASR model instance
"""
from whisperlivekit.model_paths import detect_model_format
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
checkpoint = None
model_path_for_config = name # Used to find config.json for dims inference
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
if in_memory:
checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_file)
else:
checkpoint = _load_checkpoint(checkpoint_file, device)
elif os.path.isfile(name):
if in_memory:
with open(name, "rb") as f:
checkpoint_bytes = f.read()
checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_bytes)
else:
checkpoint = _load_checkpoint(name, device)
model_path_for_config = name
elif os.path.isdir(name):
model_info = detect_model_format(name)
if not model_info.has_pytorch:
raise RuntimeError(
f"No PyTorch checkpoint found in directory {name}. "
f"Expected .pt, .bin, or .safetensors file(s)."
)
if model_info.is_sharded:
checkpoint = _load_sharded_checkpoint(model_info.pytorch_files, device)
else:
single_file = model_info.pytorch_files[0]
if in_memory:
with open(single_file, "rb") as f:
checkpoint_bytes = f.read()
checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_bytes)
else:
checkpoint = _load_checkpoint(single_file, device)
model_path_for_config = name
else:
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
)
alignment_heads = _ALIGNMENT_HEADS.get(name, None)
if custom_alignment_heads:
alignment_heads = custom_alignment_heads.encode()
dims_cfg = checkpoint.get("dims") if isinstance(checkpoint, dict) else None
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
state_dict = checkpoint["model_state_dict"]
else:
state_dict = checkpoint
if alignment_heads is None and "alignment_heads" in state_dict:
alignment_heads = state_dict["alignment_heads"]
state_dict = _convert_hf_state_dict(state_dict)
state_dict = _convert_mlx_state_dict(state_dict)
_apply_lora_adapter(state_dict, lora_path)
if dims_cfg is not None:
dims = ModelDimensions(**dims_cfg)
else:
dims = _infer_dims_from_config(model_path_for_config)
if dims is None:
raise RuntimeError(
"Could not determine model dimensions. "
"Ensure the checkpoint includes 'dims' or a HuggingFace config.json is present."
)
if not isinstance(state_dict, dict):
state_dict = checkpoint
model = Whisper(dims, decoder_only=decoder_only)
if decoder_only:
state_dict = {
k: v for k, v in state_dict.items()
if 'encoder' not in k
}
model.load_state_dict(state_dict)
if alignment_heads is not None:
if isinstance(alignment_heads, bytes):
model.set_alignment_heads(alignment_heads)
elif isinstance(alignment_heads, torch.Tensor): #for mlx whisper
mask = torch.zeros(dims.n_text_layer, dims.n_text_head, dtype=torch.bool)
for layer, head in alignment_heads.tolist():
mask[layer, head] = True
model.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
return model.to(device)
def convert_encoder_to_coreml(
model_name = "base",
output_path= "whisper_encoder.mlpackage",
dummy_frames = 3000, #Number of time frames to use for the dummy mel input during tracing
precision = "float16",
):
import coremltools as ct
model = load_model(model_name, device="cpu", decoder_only=False)
encoder = model.encoder.eval().cpu()
dummy_input = torch.randn(
1,
model.dims.n_mels,
dummy_frames,
dtype=next(encoder.parameters()).dtype,
)
with torch.no_grad():
traced_encoder = torch.jit.trace(encoder, dummy_input)
precision_map = {
"float16": ct.precision.FLOAT16,
"fp16": ct.precision.FLOAT16,
"float32": ct.precision.FLOAT32,
"fp32": ct.precision.FLOAT32,
}
coreml_precision = precision_map[precision.lower()]
mlmodel = ct.convert(
traced_encoder,
inputs=[ct.TensorType(name="mel", shape=dummy_input.shape)],
convert_to= "mlprogram",
compute_precision=coreml_precision,
)
output_path = Path(output_path)
mlmodel.save(str(output_path))
return output_path
# if __name__ == "__main__":
# convert_encoder_to_coreml(model_name="tiny", output_path="whisper_encoder.mlpackage", dummy_frames=3000, precision="float16", convert_to="mlprogram")

View File

@@ -1,3 +0,0 @@
from .transcribe import cli
cli()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,157 +0,0 @@
import os
from functools import lru_cache
from subprocess import CalledProcessError, run
from typing import Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from .utils import exact_div
# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
Open an audio file and read as mono waveform, resampling as necessary
Parameters
----------
file: str
The audio file to open
sr: int
The sample rate to resample the audio if necessary
Returns
-------
A NumPy array containing the audio waveform, in float32 dtype.
"""
# This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
# fmt: off
cmd = [
"ffmpeg",
"-nostdin",
"-threads", "0",
"-i", file,
"-f", "s16le",
"-ac", "1",
"-acodec", "pcm_s16le",
"-ar", str(sr),
"-"
]
# fmt: on
try:
out = run(cmd, capture_output=True, check=True).stdout
except CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
array = array.index_select(
dim=axis, index=torch.arange(length, device=array.device)
)
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
else:
if array.shape[axis] > length:
array = array.take(indices=range(length), axis=axis)
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = np.pad(array, pad_widths)
return array
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using:
np.savez_compressed(
"mel_filters.npz",
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
)
"""
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
with np.load(filters_path, allow_pickle=False) as f:
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
def log_mel_spectrogram(
audio: Union[str, np.ndarray, torch.Tensor],
n_mels: int = 80,
padding: int = 0,
device: Optional[Union[str, torch.device]] = None,
):
"""
Compute the log-Mel spectrogram of
Parameters
----------
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
n_mels: int
The number of Mel-frequency filters, only 80 and 128 are supported
padding: int
Number of zero samples to pad to the right
device: Optional[Union[str, torch.device]]
If given, the audio tensor is moved to this device before STFT
Returns
-------
torch.Tensor, shape = (n_mels, n_frames)
A Tensor that contains the Mel spectrogram
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
if device is not None:
audio = audio.to(device)
if padding > 0:
audio = F.pad(audio, (0, padding))
window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2
filters = mel_filters(audio.device, n_mels)
mel_spec = filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec

View File

@@ -1,821 +0,0 @@
from dataclasses import dataclass, field, replace
from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence,
Tuple, Union)
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical
from .audio import CHUNK_LENGTH
from .tokenizer import Tokenizer, get_tokenizer
from .utils import compression_ratio
if TYPE_CHECKING:
from .model import Whisper
@torch.no_grad()
def detect_language(
model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
) -> Tuple[Tensor, List[dict]]:
"""
Detect the spoken language in the audio, and return them as list of strings, along with the ids
of the most probable language tokens and the probability distribution over all language tokens.
This is performed outside the main decode loop in order to not interfere with kv-caching.
Returns
-------
language_tokens : Tensor, shape = (n_audio,)
ids of the most probable language tokens, which appears after the startoftranscript token.
language_probs : List[Dict[str, float]], length = n_audio
list of dictionaries containing the probability distribution over all languages.
"""
if tokenizer is None:
tokenizer = get_tokenizer(
model.is_multilingual, num_languages=model.num_languages
)
if (
tokenizer.language is None
or tokenizer.language_token not in tokenizer.sot_sequence
):
raise ValueError(
"This model doesn't have language tokens so it can't perform lang id"
)
single = mel.ndim == 2
if single:
mel = mel.unsqueeze(0)
# skip encoder forward pass if already-encoded audio features were given
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
mel = model.encoder(mel)
# forward pass using a single token, startoftranscript
n_audio = mel.shape[0]
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
logits = model.logits(x, mel)[:, 0]
# collect detected languages; suppress all non-language tokens
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
mask[list(tokenizer.all_language_tokens)] = False
logits[:, mask] = -np.inf
language_tokens = logits.argmax(dim=-1)
language_token_probs = logits.softmax(dim=-1).cpu()
language_probs = [
{
c: language_token_probs[i, j].item()
for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
}
for i in range(n_audio)
]
if single:
language_tokens = language_tokens[0]
language_probs = language_probs[0]
return language_tokens, language_probs
@dataclass(frozen=True)
class DecodingOptions:
# whether to perform X->X "transcribe" or X->English "translate"
task: str = "transcribe"
# language that the audio is in; uses detected language if None
language: Optional[str] = None
# sampling-related options
temperature: float = 0.0
sample_len: Optional[int] = None # maximum number of tokens to sample
best_of: Optional[int] = None # number of independent sample trajectories, if t > 0
beam_size: Optional[int] = None # number of beams in beam search, if t == 0
patience: Optional[float] = None # patience in beam search (arxiv:2204.05424)
# "alpha" in Google NMT, or None for length norm, when ranking generations
# to select which to return among the beams or best-of-N samples
length_penalty: Optional[float] = None
# text or tokens to feed as the prompt or the prefix; for more info:
# https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
prompt: Optional[Union[str, List[int]]] = None # for the previous context
prefix: Optional[Union[str, List[int]]] = None # to prefix the current context
# list of tokens ids (or comma-separated token ids) to suppress
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
suppress_blank: bool = True # this will suppress blank outputs
# timestamp sampling options
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
max_initial_timestamp: Optional[float] = 1.0
# implementation details
fp16: bool = True # use fp16 for most of the calculation
@dataclass(frozen=True)
class DecodingResult:
audio_features: Tensor
language: str
language_probs: Optional[Dict[str, float]] = None
tokens: List[int] = field(default_factory=list)
text: str = ""
avg_logprob: float = np.nan
no_speech_prob: float = np.nan
temperature: float = np.nan
compression_ratio: float = np.nan
class Inference:
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
"""Perform a forward pass on the decoder and return per-token logits"""
raise NotImplementedError
def rearrange_kv_cache(self, source_indices) -> None:
"""Update the key-value cache according to the updated beams"""
raise NotImplementedError
def cleanup_caching(self) -> None:
"""Clean up any resources or hooks after decoding is finished"""
pass
class PyTorchInference(Inference):
def __init__(self, model: "Whisper", initial_token_length: int):
self.model: "Whisper" = model
self.initial_token_length = initial_token_length
self.kv_cache = {}
self.kv_cache_ids = []
for block in self.model.decoder.blocks:
self.kv_cache_ids.append(block.attn.key_cache_id)
self.kv_cache_ids.append(block.attn.value_cache_id)
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
if tokens.shape[-1] > self.initial_token_length:
# only need to use the last token except in the first forward pass
tokens = tokens[:, -1:]
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
def cleanup_caching(self):
self.kv_cache = {}
def rearrange_kv_cache(self, source_indices):
if source_indices != list(range(len(source_indices))):
for cache_id in self.kv_cache_ids:
if cache_id in self.kv_cache:
# update the key/value cache to contain the selected sequences
self.kv_cache[cache_id] = self.kv_cache[cache_id][source_indices].detach()
class SequenceRanker:
def rank(
self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
) -> List[int]:
"""
Given a list of groups of samples and their cumulative log probabilities,
return the indices of the samples in each group to select as the final result
"""
raise NotImplementedError
class MaximumLikelihoodRanker(SequenceRanker):
"""
Select the sample with the highest log probabilities, penalized using either
a simple length normalization or Google NMT paper's length penalty
"""
def __init__(self, length_penalty: Optional[float]):
self.length_penalty = length_penalty
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
def scores(logprobs, lengths):
result = []
for logprob, length in zip(logprobs, lengths):
if self.length_penalty is None:
penalty = length
else:
# from the Google NMT paper
penalty = ((5 + length) / 6) ** self.length_penalty
result.append(logprob / penalty)
return result
# get the sequence with the highest score
lengths = [[len(t) for t in s] for s in tokens]
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
class TokenDecoder:
def reset(self):
"""Initialize any stateful variables for decoding a new sequence"""
def update(
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
) -> Tuple[Tensor, bool]:
"""Specify how to select the next token, based on the current trace and logits
Parameters
----------
tokens : Tensor, shape = (n_batch, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence tokens
logits : Tensor, shape = (n_batch, vocab_size)
per-token logits of the probability distribution at the current step
sum_logprobs : Tensor, shape = (n_batch)
cumulative log probabilities for each sequence
Returns
-------
tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
the tokens, appended with the selected next token
completed : bool
True if all sequences has reached the end of text
"""
raise NotImplementedError
def finalize(
self, tokens: Tensor, sum_logprobs: Tensor
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
"""Finalize search and return the final candidate sequences
Parameters
----------
tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence
sum_logprobs : Tensor, shape = (n_audio, n_group)
cumulative log probabilities for each sequence
Returns
-------
tokens : Sequence[Sequence[Tensor]], length = n_audio
sequence of Tensors containing candidate token sequences, for each audio input
sum_logprobs : List[List[float]], length = n_audio
sequence of cumulative log probabilities corresponding to the above
"""
raise NotImplementedError
class GreedyDecoder(TokenDecoder):
def __init__(self, temperature: float, eot: int):
self.temperature = temperature
self.eot = eot
def update(
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
) -> Tuple[Tensor, bool]:
if self.temperature == 0:
next_tokens = logits.argmax(dim=-1)
else:
next_tokens = Categorical(logits=logits / self.temperature).sample()
logprobs = F.log_softmax(logits.float(), dim=-1)
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
next_tokens[tokens[:, -1] == self.eot] = self.eot
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
completed = (tokens[:, -1] == self.eot).all()
return tokens, completed
def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
# make sure each sequence has at least one EOT token at the end
tokens = F.pad(tokens, (0, 1), value=self.eot)
return tokens, sum_logprobs.tolist()
class BeamSearchDecoder(TokenDecoder):
def __init__(
self,
beam_size: int,
eot: int,
inference: Inference,
patience: Optional[float] = None,
):
self.beam_size = beam_size
self.eot = eot
self.inference = inference
self.patience = patience or 1.0
self.max_candidates: int = round(beam_size * self.patience)
self.finished_sequences = None
assert (
self.max_candidates > 0
), f"Invalid beam size ({beam_size}) or patience ({patience})"
def reset(self):
self.finished_sequences = None
def update(
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
) -> Tuple[Tensor, bool]:
if tokens.shape[0] % self.beam_size != 0:
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
n_audio = tokens.shape[0] // self.beam_size
if self.finished_sequences is None: # for the first update
self.finished_sequences = [{} for _ in range(n_audio)]
logprobs = F.log_softmax(logits.float(), dim=-1)
next_tokens, source_indices, finished_sequences = [], [], []
for i in range(n_audio):
scores, sources, finished = {}, {}, {}
# STEP 1: calculate the cumulative log probabilities for possible candidates
for j in range(self.beam_size):
idx = i * self.beam_size + j
prefix = tokens[idx].tolist()
for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
new_logprob = (sum_logprobs[idx] + logprob).item()
sequence = tuple(prefix + [token.item()])
scores[sequence] = new_logprob
sources[sequence] = idx
# STEP 2: rank the candidates and keep the top beam_size sequences for each audio
saved = 0
for sequence in sorted(scores, key=scores.get, reverse=True):
if sequence[-1] == self.eot:
finished[sequence] = scores[sequence]
else:
sum_logprobs[len(next_tokens)] = scores[sequence]
next_tokens.append(sequence)
source_indices.append(sources[sequence])
saved += 1
if saved == self.beam_size:
break
finished_sequences.append(finished)
tokens = torch.tensor(next_tokens, device=tokens.device)
self.inference.rearrange_kv_cache(source_indices)
# add newly finished sequences to self.finished_sequences
assert len(self.finished_sequences) == len(finished_sequences)
for previously_finished, newly_finished in zip(
self.finished_sequences, finished_sequences
):
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
if len(previously_finished) >= self.max_candidates:
break # the candidate list is full
previously_finished[seq] = newly_finished[seq]
# mark as completed if all audio has enough number of samples
completed = all(
len(sequences) >= self.max_candidates
for sequences in self.finished_sequences
)
return tokens, completed
def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
# collect all finished sequences, including patience, and add unfinished ones if not enough
sum_logprobs = sum_logprobs.cpu()
for i, sequences in enumerate(self.finished_sequences):
if (
len(sequences) < self.beam_size
): # when not enough sequences are finished
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
sequence = preceding_tokens[i, j].tolist() + [self.eot]
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
if len(sequences) >= self.beam_size:
break
tokens: List[List[Tensor]] = [
[torch.tensor(seq) for seq in sequences.keys()]
for sequences in self.finished_sequences
]
sum_logprobs: List[List[float]] = [
list(sequences.values()) for sequences in self.finished_sequences
]
return tokens, sum_logprobs
class LogitFilter:
def apply(self, logits: Tensor, tokens: Tensor) -> None:
"""Apply any filtering or masking to logits in-place
Parameters
----------
logits : Tensor, shape = (n_batch, vocab_size)
per-token logits of the probability distribution at the current step
tokens : Tensor, shape = (n_batch, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence tokens
"""
raise NotImplementedError
class SuppressBlank(LogitFilter):
def __init__(self, tokenizer: Tokenizer, sample_begin: int):
self.tokenizer = tokenizer
self.sample_begin = sample_begin
def apply(self, logits: Tensor, tokens: Tensor):
if tokens.shape[1] == self.sample_begin:
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
class SuppressTokens(LogitFilter):
def __init__(self, suppress_tokens: Sequence[int]):
self.suppress_tokens = list(suppress_tokens)
def apply(self, logits: Tensor, tokens: Tensor):
logits[:, self.suppress_tokens] = -np.inf
class ApplyTimestampRules(LogitFilter):
def __init__(
self,
tokenizer: Tokenizer,
sample_begin: int,
max_initial_timestamp_index: Optional[int],
):
self.tokenizer = tokenizer
self.sample_begin = sample_begin
self.max_initial_timestamp_index = max_initial_timestamp_index
def apply(self, logits: Tensor, tokens: Tensor):
# suppress <|notimestamps|> which is handled by without_timestamps
if self.tokenizer.no_timestamps is not None:
logits[:, self.tokenizer.no_timestamps] = -np.inf
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
for k in range(tokens.shape[0]):
sampled_tokens = tokens[k, self.sample_begin :]
seq = [t for t in sampled_tokens.tolist()]
last_was_timestamp = (
len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
)
penultimate_was_timestamp = (
len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
)
if last_was_timestamp:
if penultimate_was_timestamp: # has to be non-timestamp
logits[k, self.tokenizer.timestamp_begin :] = -np.inf
else: # cannot be normal text tokens
logits[k, : self.tokenizer.eot] = -np.inf
timestamps = sampled_tokens[
sampled_tokens.ge(self.tokenizer.timestamp_begin)
]
if timestamps.numel() > 0:
# timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
# also force each segment to have a nonzero length, to prevent infinite looping
if last_was_timestamp and not penultimate_was_timestamp:
timestamp_last = timestamps[-1]
else:
timestamp_last = timestamps[-1] + 1
logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf
if tokens.shape[1] == self.sample_begin:
# suppress generating non-timestamp tokens at the beginning
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
# apply the `max_initial_timestamp` option
if self.max_initial_timestamp_index is not None:
last_allowed = (
self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
)
logits[:, last_allowed + 1 :] = -np.inf
# if sum of probability over timestamps is above any other token, sample timestamp
logprobs = F.log_softmax(logits.float(), dim=-1)
for k in range(tokens.shape[0]):
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
dim=-1
)
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
if timestamp_logprob > max_text_token_logprob:
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
class DecodingTask:
inference: Inference
sequence_ranker: SequenceRanker
decoder: TokenDecoder
logit_filters: List[LogitFilter]
def __init__(self, model: "Whisper", options: DecodingOptions):
self.model = model
language = options.language or "en"
tokenizer = get_tokenizer(
model.is_multilingual,
num_languages=model.num_languages,
language=language,
task=options.task,
)
self.tokenizer: Tokenizer = tokenizer
self.options: DecodingOptions = self._verify_options(options)
self.n_group: int = options.beam_size or options.best_of or 1
self.n_ctx: int = model.dims.n_text_ctx
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
if self.options.without_timestamps:
self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
self.initial_tokens: Tuple[int] = self._get_initial_tokens()
self.sample_begin: int = len(self.initial_tokens)
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
# inference: implements the forward pass through the decoder, including kv caching
self.inference = PyTorchInference(model, len(self.initial_tokens))
# sequence ranker: implements how to rank a group of sampled sequences
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
# decoder: implements how to select the next tokens, given the autoregressive distribution
if options.beam_size is not None:
self.decoder = BeamSearchDecoder(
options.beam_size, tokenizer.eot, self.inference, options.patience
)
else:
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
# logit filters: applies various rules to suppress or penalize certain tokens
self.logit_filters = []
if self.options.suppress_blank:
self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
if self.options.suppress_tokens:
self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
if not options.without_timestamps:
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
max_initial_timestamp_index = None
if options.max_initial_timestamp:
max_initial_timestamp_index = round(
self.options.max_initial_timestamp / precision
)
self.logit_filters.append(
ApplyTimestampRules(
tokenizer, self.sample_begin, max_initial_timestamp_index
)
)
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
if options.beam_size is not None and options.best_of is not None:
raise ValueError("beam_size and best_of can't be given together")
if options.temperature == 0:
if options.best_of is not None:
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
if options.patience is not None and options.beam_size is None:
raise ValueError("patience requires beam_size to be given")
if options.length_penalty is not None and not (
0 <= options.length_penalty <= 1
):
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
return options
def _get_initial_tokens(self) -> Tuple[int]:
tokens = list(self.sot_sequence)
if prefix := self.options.prefix:
prefix_tokens = (
self.tokenizer.encode(" " + prefix.strip())
if isinstance(prefix, str)
else prefix
)
if self.sample_len is not None:
max_prefix_len = self.n_ctx // 2 - self.sample_len
prefix_tokens = prefix_tokens[-max_prefix_len:]
tokens = tokens + prefix_tokens
if prompt := self.options.prompt:
prompt_tokens = (
self.tokenizer.encode(" " + prompt.strip())
if isinstance(prompt, str)
else prompt
)
tokens = (
[self.tokenizer.sot_prev]
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
+ tokens
)
return tuple(tokens)
def _get_suppress_tokens(self) -> Tuple[int]:
suppress_tokens = self.options.suppress_tokens
if isinstance(suppress_tokens, str):
suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
if -1 in suppress_tokens:
suppress_tokens = [t for t in suppress_tokens if t >= 0]
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
elif suppress_tokens is None or len(suppress_tokens) == 0:
suppress_tokens = [] # interpret empty string as an empty list
else:
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
suppress_tokens.extend(
[
self.tokenizer.transcribe,
self.tokenizer.translate,
self.tokenizer.sot,
self.tokenizer.sot_prev,
self.tokenizer.sot_lm,
]
)
if self.tokenizer.no_speech is not None:
# no-speech probability is collected separately
suppress_tokens.append(self.tokenizer.no_speech)
return tuple(sorted(set(suppress_tokens)))
def _get_audio_features(self, mel: Tensor):
if self.options.fp16:
mel = mel.half()
if mel.shape[-2:] == (
self.model.dims.n_audio_ctx,
self.model.dims.n_audio_state,
):
# encoded audio features are given; skip audio encoding
audio_features = mel
else:
audio_features = self.model.encoder(mel)
if audio_features.dtype != (
torch.float16 if self.options.fp16 else torch.float32
):
return TypeError(
f"audio_features has an incorrect dtype: {audio_features.dtype}"
)
return audio_features
def _detect_language(self, audio_features: Tensor, tokens: Tensor):
languages = [self.options.language] * audio_features.shape[0]
lang_probs = None
if self.options.language is None or self.options.task == "lang_id":
lang_tokens, lang_probs = self.model.detect_language(
audio_features, self.tokenizer
)
languages = [max(probs, key=probs.get) for probs in lang_probs]
if self.options.language is None:
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
return languages, lang_probs
def _main_loop(self, audio_features: Tensor, tokens: Tensor):
n_batch = tokens.shape[0]
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
no_speech_probs = [np.nan] * n_batch
try:
for i in range(self.sample_len):
logits = self.inference.logits(tokens, audio_features)
if (
i == 0 and self.tokenizer.no_speech is not None
): # save no_speech_probs
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
# now we need to consider the logits at the last token only
logits = logits[:, -1]
# apply the logit filters, e.g. for suppressing or applying penalty to
for logit_filter in self.logit_filters:
logit_filter.apply(logits, tokens)
# expand the tokens tensor with the selected next tokens
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
if completed or tokens.shape[-1] > self.n_ctx:
break
finally:
self.inference.cleanup_caching()
return tokens, sum_logprobs, no_speech_probs
@torch.no_grad()
def run(self, mel: Tensor) -> List[DecodingResult]:
self.decoder.reset()
tokenizer: Tokenizer = self.tokenizer
n_audio: int = mel.shape[0]
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
# detect language if requested, overwriting the language token
languages, language_probs = self._detect_language(audio_features, tokens)
if self.options.task == "lang_id":
return [
DecodingResult(
audio_features=features, language=language, language_probs=probs
)
for features, language, probs in zip(
audio_features, languages, language_probs
)
]
# repeat text tensors by the group size, for beam search or best-of-n sampling
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
# call the main sampling loop
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
# reshape the tensors to have (n_audio, n_group) as the first two dimensions
audio_features = audio_features[:: self.n_group]
no_speech_probs = no_speech_probs[:: self.n_group]
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
tokens = tokens.reshape(n_audio, self.n_group, -1)
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
# get the final candidates for each group, and slice between the first sampled token and EOT
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
tokens: List[List[Tensor]] = [
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
for s in tokens
]
# select the top-ranked sample in each group
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
avg_logprobs: List[float] = [
lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
]
fields = (
texts,
languages,
tokens,
audio_features,
avg_logprobs,
no_speech_probs,
)
if len(set(map(len, fields))) != 1:
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
return [
DecodingResult(
audio_features=features,
language=language,
tokens=tokens,
text=text,
avg_logprob=avg_logprob,
no_speech_prob=no_speech_prob,
temperature=self.options.temperature,
compression_ratio=compression_ratio(text),
)
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
*fields
)
]
@torch.no_grad()
def decode(
model: "Whisper",
mel: Tensor,
options: DecodingOptions = DecodingOptions(),
**kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
"""
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
Parameters
----------
model: Whisper
the Whisper model instance
mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
A tensor containing the Mel spectrogram(s)
options: DecodingOptions
A dataclass that contains all necessary options for decoding 30-second segments
Returns
-------
result: Union[DecodingResult, List[DecodingResult]]
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
"""
if single := mel.ndim == 2:
mel = mel.unsqueeze(0)
if kwargs:
options = replace(options, **kwargs)
result = DecodingTask(model, options).run(mel)
return result[0] if single else result

View File

@@ -1,407 +0,0 @@
import base64
import gzip
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function
try:
from torch.nn.functional import scaled_dot_product_attention
SDPA_AVAILABLE = True
except (ImportError, RuntimeError, OSError):
scaled_dot_product_attention = None
SDPA_AVAILABLE = False
@dataclass
class ModelDimensions:
n_mels: int
n_audio_ctx: int
n_audio_state: int
n_audio_head: int
n_audio_layer: int
n_vocab: int
n_text_ctx: int
n_text_state: int
n_text_head: int
n_text_layer: int
class LayerNorm(nn.LayerNorm):
def forward(self, x: Tensor) -> Tensor:
return super().forward(x.float()).type(x.dtype)
class Linear(nn.Linear):
def forward(self, x: Tensor) -> Tensor:
return F.linear(
x,
self.weight.to(x.dtype),
None if self.bias is None else self.bias.to(x.dtype),
)
class Conv1d(nn.Conv1d):
def _conv_forward(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
) -> Tensor:
return super()._conv_forward(
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
)
def sinusoids(length, channels, max_timescale=10000):
"""Returns sinusoids for positional embedding"""
assert channels % 2 == 0
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
@contextmanager
def disable_sdpa():
prev_state = MultiHeadAttention.use_sdpa
try:
MultiHeadAttention.use_sdpa = False
yield
finally:
MultiHeadAttention.use_sdpa = prev_state
class MultiHeadAttention(nn.Module):
use_sdpa = False # Disable SDPA to ensure qk is always computed when needed
def __init__(self, n_state: int, n_head: int, cache_id: str = "", n_text_ctx: int = 448):
super().__init__()
self.n_head = n_head
self.n_text_ctx = n_text_ctx
self.query = Linear(n_state, n_state)
self.key = Linear(n_state, n_state, bias=False)
self.value = Linear(n_state, n_state)
self.out = Linear(n_state, n_state)
self.cache_id = cache_id
# Cache IDs for key and value (used with dict-based kv_cache)
self.key_cache_id = f"{cache_id}_key"
self.value_cache_id = f"{cache_id}_value"
# Keep these for backward compatibility with hook-based caching
self.key.cache_id = self.key_cache_id
self.value.cache_id = self.value_cache_id
def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
q = self.query(x)
if xa is None:
# Self-attention
k = self.key(x)
v = self.value(x)
if kv_cache is not None:
k, v = self._update_self_attn_cache(k, v, kv_cache)
else:
# Cross-attention: compute once and cache, or reuse from cache
if kv_cache is not None and self.key_cache_id in kv_cache:
k = kv_cache[self.key_cache_id]
v = kv_cache[self.value_cache_id]
else:
k = self.key(xa)
v = self.value(xa)
if kv_cache is not None:
kv_cache[self.key_cache_id] = k
kv_cache[self.value_cache_id] = v
wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), qk
def _update_self_attn_cache(
self, k: Tensor, v: Tensor, kv_cache: dict
) -> Tuple[Tensor, Tensor]:
"""Update self-attention kv cache by concatenating new k,v with cached values."""
if self.key_cache_id not in kv_cache or k.shape[1] > self.n_text_ctx:
# First token or context overflow: save as-is
kv_cache[self.key_cache_id] = k.detach()
kv_cache[self.value_cache_id] = v.detach()
else:
# Concatenate with existing cache
cached_k = kv_cache[self.key_cache_id]
cached_v = kv_cache[self.value_cache_id]
k = torch.cat([cached_k, k], dim=1).detach()
v = torch.cat([cached_v, v], dim=1).detach()
kv_cache[self.key_cache_id] = k
kv_cache[self.value_cache_id] = v
return k, v
def qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
a = scaled_dot_product_attention(
q, k, v, is_causal=mask is not None and n_ctx > 1
)
out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
qk = None
else:
qk = (q * scale) @ (k * scale).transpose(-1, -2)
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.float()
w = F.softmax(qk, dim=-1).to(q.dtype)
out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
qk = qk.detach()
return out, qk
class ResidualAttentionBlock(nn.Module):
def __init__(
self, n_state: int, n_head: int, cross_attention: bool = False,
cache_id: str = "", n_text_ctx: int = 448
):
super().__init__()
self.attn = MultiHeadAttention(
n_state, n_head, cache_id=f"{cache_id}_self_attn", n_text_ctx=n_text_ctx
)
self.attn_ln = LayerNorm(n_state)
self.cross_attn = (
MultiHeadAttention(
n_state, n_head, cache_id=f"{cache_id}_cross_attn", n_text_ctx=n_text_ctx
) if cross_attention else None
)
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
self.mlp = nn.Sequential(
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
)
self.mlp_ln = LayerNorm(n_state)
def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Returns:
x: The output tensor
cross_attn_qk: Cross-attention weights (if cross_attn exists), else None
"""
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
cross_attn_qk = None
if self.cross_attn:
cross_out, cross_attn_qk = self.cross_attn(
self.cross_attn_ln(x), xa, kv_cache=kv_cache
)
x = x + cross_out
x = x + self.mlp(self.mlp_ln(x))
return x, cross_attn_qk
class AudioEncoder(nn.Module):
def __init__(
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head, cache_id=f"enc_layer{i}") for i in range(n_layer)]
)
self.ln_post = LayerNorm(n_state)
def forward(self, x: Tensor):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
x = (x + self.positional_embedding).to(x.dtype)
for block in self.blocks:
x, _ = block(x) # Encoder blocks don't have cross-attention
x = self.ln_post(x)
return x
class TextDecoder(nn.Module):
def __init__(
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.n_ctx = n_ctx
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[
ResidualAttentionBlock(
n_state, n_head, cross_attention=True,
cache_id=f"dec_layer{i}", n_text_ctx=n_ctx
)
for i in range(n_layer)
]
)
self.ln = LayerNorm(n_state)
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False)
def forward(
self,
x: Tensor,
xa: Tensor,
kv_cache: Optional[dict] = None,
return_cross_attn: bool = False,
):
"""
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
the text tokens
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
the encoded audio features to be attended on
kv_cache : Optional[dict]
Dictionary to store/retrieve key-value cache for efficient decoding
return_cross_attn : bool
If True, return cross-attention weights from all decoder layers
Returns
-------
logits : Tensor
The output logits
cross_attns : Optional[List[Tensor]]
List of cross-attention weights per layer (only if return_cross_attn=True)
"""
# Calculate offset from self-attention cache (not cross-attention which has audio length)
offset = 0
if kv_cache:
# Use the first decoder block's self-attention key cache to get token position
first_self_attn_key = self.blocks[0].attn.key_cache_id
if first_self_attn_key in kv_cache:
offset = kv_cache[first_self_attn_key].shape[1]
x = (
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
)
x = x.to(xa.dtype)
cross_attns = [] if return_cross_attn else None
for block in self.blocks:
x, cross_attn_qk = block(x, xa, mask=self.mask, kv_cache=kv_cache)
if return_cross_attn and cross_attn_qk is not None:
cross_attns.append(cross_attn_qk)
x = self.ln(x)
logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()
if return_cross_attn:
return logits, cross_attns
return logits
class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions, decoder_only: bool = False):
super().__init__()
self.dims = dims
if not decoder_only:
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
)
self.decoder = TextDecoder(
self.dims.n_vocab,
self.dims.n_text_ctx,
self.dims.n_text_state,
self.dims.n_text_head,
self.dims.n_text_layer,
)
# use the last half among the decoder layers for time alignment by default;
# to use a specific set of heads, see `set_alignment_heads()` below.
all_heads = torch.zeros(
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
)
all_heads[self.dims.n_text_layer // 2 :] = True
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
def set_alignment_heads(self, dump: bytes):
array = np.frombuffer(
gzip.decompress(base64.b85decode(dump)), dtype=bool
).copy()
mask = torch.from_numpy(array).reshape(
self.dims.n_text_layer, self.dims.n_text_head
)
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
def embed_audio(self, mel: torch.Tensor):
return self.encoder(mel)
def logits(
self,
tokens: torch.Tensor,
audio_features: torch.Tensor,
kv_cache: Optional[dict] = None,
return_cross_attn: bool = False,
):
return self.decoder(
tokens, audio_features,
kv_cache=kv_cache,
return_cross_attn=return_cross_attn
)
def forward(
self, mel: torch.Tensor, tokens: torch.Tensor
) -> Dict[str, torch.Tensor]:
return self.decoder(tokens, self.encoder(mel))
@property
def device(self):
return next(self.parameters()).device
@property
def is_multilingual(self):
return self.dims.n_vocab >= 51865
@property
def num_languages(self):
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
detect_language = detect_language_function
transcribe = transcribe_function
decode = decode_function

View File

@@ -1,2 +0,0 @@
from .basic import BasicTextNormalizer as BasicTextNormalizer
from .english import EnglishTextNormalizer as EnglishTextNormalizer

View File

@@ -1,80 +0,0 @@
import re
import unicodedata
import regex
# non-ASCII letters that are not separated by "NFKD" normalization
ADDITIONAL_DIACRITICS = {
"œ": "oe",
"Œ": "OE",
"ø": "o",
"Ø": "O",
"æ": "ae",
"Æ": "AE",
"ß": "ss",
"": "SS",
"đ": "d",
"Đ": "D",
"ð": "d",
"Ð": "D",
"þ": "th",
"Þ": "th",
"ł": "l",
"Ł": "L",
}
def remove_symbols_and_diacritics(s: str, keep=""):
"""
Replace any other markers, symbols, and punctuations with a space,
and drop any diacritics (category 'Mn' and some manual mappings)
"""
return "".join(
(
c
if c in keep
else (
ADDITIONAL_DIACRITICS[c]
if c in ADDITIONAL_DIACRITICS
else (
""
if unicodedata.category(c) == "Mn"
else " " if unicodedata.category(c)[0] in "MSP" else c
)
)
)
for c in unicodedata.normalize("NFKD", s)
)
def remove_symbols(s: str):
"""
Replace any other markers, symbols, punctuations with a space, keeping diacritics
"""
return "".join(
" " if unicodedata.category(c)[0] in "MSP" else c
for c in unicodedata.normalize("NFKC", s)
)
class BasicTextNormalizer:
def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
self.clean = (
remove_symbols_and_diacritics if remove_diacritics else remove_symbols
)
self.split_letters = split_letters
def __call__(self, s: str):
s = s.lower()
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
s = self.clean(s).lower()
if self.split_letters:
s = " ".join(regex.findall(r"\X", s, regex.U))
s = re.sub(
r"\s+", " ", s
) # replace any successive whitespace characters with a space
return s

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More