21 Commits

Author SHA1 Message Date
Quentin Fuxa
8bc0937c46 Update README section on powered research 2026-03-06 18:46:07 +01:00
Quentin Fuxa
929cf7a26b add link to AlignAtt interactive playground 2026-03-06 18:43:25 +01:00
Quentin Fuxa
abfaf06203 Merge branch 'main' of https://github.com/QuentinFuxa/WhisperLiveKit 2026-03-04 18:17:23 +01:00
Quentin Fuxa
d1fe932241 Apply DRY method v0 - to try to catch and resolve infinite loops such as in #338 2026-03-03 22:52:00 +01:00
Quentin Fuxa
c112ceffb6 Merge pull request #342 from mnicnc404/fix/whisper-tokenizer-index-error
fix(whisper/tokenizer): prevent IndexError from crashing multilingual…
2026-03-02 20:36:58 +01:00
Quentin Fuxa
4917406e06 Merge pull request #341 from AymurAI/feat/uv-deps-resolution
deps/docker: align python support, deterministic deps resolution & docker images releases
2026-03-02 20:34:49 +01:00
Chingning Chen
b63f54e838 fix(whisper/tokenizer): prevent IndexError from crashing multilingual streams
This fix addresses a critical bug in the Whisper tokenizer that causes
the transcription server to crash with an `IndexError: string index out
of range` when streaming audio in languages utilizing multi-byte UTF-8
characters (e.g., Cantonese, Japanese, Mandarin).

When a 3-byte character is cut off at the boundary of an audio chunk,
incomplete bytes are decoded into a single Unicode replacement character
(`\ufffd`), artificially shortening the string and breaking the offset
mapping assumed by `split_tokens_on_unicode`.

This ports the upstream fix from SYSTRAN/faster-whisper (PR #111) to add
a strict bounds check before accessing the string index, allowing
incomplete bytes to be safely caught and handled in the next chunk.
2026-03-02 15:31:43 +08:00
jedzill4
c56a53fbf4 deps(mlx-groups): add optional dependencies for Apple Silicon MLX backends 2026-03-01 20:05:52 -03:00
Quentin Fuxa
66e58624b9 disable MLXAlignAtt which fails on special characters 2026-03-01 11:52:00 +01:00
jedzill4
9366e067f9 deps(pyproject): add torch and torchaudio to main dependencies 2026-02-27 19:19:18 -03:00
jedzill4
866c25670c deps(docker): change CUDA base image to runtime version 2026-02-27 19:16:29 -03:00
jedzill4
2553ef283e deps(docker): fix dependency group for cu129 image
- Changed the extras for cu129-diarization-sortformer from gpu-cu129 to cu129.
- This aligns the dependency with the correct naming convention for consistency.
2026-02-25 21:49:08 -03:00
jedzill4
73e7fafc48 feat(tests): python matrix support test
- Introduced a new argument for selecting the diarization backend in the engine creation.
- Enhanced the `create_engine` function to accept and utilize the specified diarization backend.
- Updated the test runner to accommodate the new backend option for improved flexibility.
2026-02-25 21:35:41 -03:00
jedzill4
bbcebcb1fe deps(sortformer): adjust nemo-toolkit version constraints
- Updated the version constraint for `diarization-sortformer` to restrict it to Python 3.10 and below.
2026-02-25 21:33:00 -03:00
jedzill4
4bb58dc7aa deps(diart): improve diart dependency tree. rename gpu-cu129 dependency group to cu129 2026-02-25 20:27:26 -03:00
jedzill4
27ca028479 ci(github): add GitHub Actions workflows for Docker image publishing and support matrix
- Introduced a workflow to publish Docker images on tag push and manual triggers.
- Added a support matrix workflow to test across multiple OS and Python versions.
2026-02-25 14:27:51 -03:00
jedzill4
d24805cc18 🚀 chore (docker): update docker images improving caching and using uv as python package manager 2026-02-25 14:22:43 -03:00
jedzill4
994ce21365 📌 chore(deps): pin dependences to python 3.11 to 3.13 due dependency resolution matrix 2026-02-25 14:21:19 -03:00
jedzill4
132823dc09 deps: improve deps dependency resolution (wip) 2026-02-24 20:15:53 -03:00
jedzill4
d6d8c2635f chore: use uv as python project manager to improve dependency resolution 2026-02-23 22:16:32 -03:00
Quentin Fuxa
8fedeb9fed Merge pull request #340 from QuentinFuxa/voxtral_tests
feat: voxtral-mlx backend, benchmark suite, unit tests, runtime metrics
2026-02-23 10:37:40 +01:00
15 changed files with 7641 additions and 165 deletions

13
.dockerignore Normal file
View File

@@ -0,0 +1,13 @@
.git
.github
.venv
__pycache__
*.pyc
.pytest_cache
.mypy_cache
.ruff_cache
.cache
.tmp
.secrets
dist
build

61
.github/workflows/publish-docker.yml vendored Normal file
View File

@@ -0,0 +1,61 @@
name: Publish Docker Images
on:
push:
tags:
- "v*"
workflow_dispatch:
inputs:
tag:
description: "Image tag to publish (without image suffix)"
required: true
type: string
permissions:
contents: read
packages: write
jobs:
docker:
runs-on: ubuntu-latest
env:
IMAGE_TAG: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.tag || github.ref_name }}
strategy:
fail-fast: false
matrix:
include:
- image_suffix: cpu-diarization-sortformer
dockerfile: Dockerfile.cpu
extras: cpu,diarization-sortformer
- image_suffix: cu129-diarization-sortformer
dockerfile: Dockerfile
extras: cu129,diarization-sortformer
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set lowercase owner
id: owner
run: echo "value=${GITHUB_REPOSITORY_OWNER,,}" >> "${GITHUB_OUTPUT}"
- name: Login to GHCR
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Setup Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build and push image
uses: docker/build-push-action@v6
with:
context: .
file: ./${{ matrix.dockerfile }}
push: true
build-args: |
EXTRAS=${{ matrix.extras }}
tags: |
ghcr.io/${{ steps.owner.outputs.value }}/whisperlivekit:${{ env.IMAGE_TAG }}-${{ matrix.image_suffix }}
ghcr.io/${{ steps.owner.outputs.value }}/whisperlivekit:latest-${{ matrix.image_suffix }}

View File

@@ -1,86 +1,74 @@
FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04 FROM ghcr.io/astral-sh/uv:0.10.4 AS uvbin
# --- MARK: Builder Stage
FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04 AS builder-gpu
ENV DEBIAN_FRONTEND=noninteractive ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1 ENV PYTHONUNBUFFERED=1
WORKDIR /app WORKDIR /app
ARG EXTRAS RUN apt-get update && \
ARG HF_PRECACHE_DIR apt-get install -y --no-install-recommends \
ARG HF_TKN_FILE build-essential \
python3-dev && \
rm -rf /var/lib/apt/lists/*
# Install UV and set up the environment
COPY --from=uvbin /uv /uvx /bin/
ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy UV_NO_DEV=1
ENV UV_PYTHON_PREFERENCE=only-managed
ENV UV_PYTHON_INSTALL_DIR=/python
RUN uv python install 3.12
# Install dependencies first to leverage caching
ARG EXTRAS=cu129
COPY pyproject.toml uv.lock /app/
RUN set -eux; \
set --; \
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
set -- "$@" --extra "$extra"; \
done; \
uv sync --frozen --no-install-project --no-editable --no-cache "$@"
# Copy the source code and install the package only
COPY whisperlivekit /app/whisperlivekit
RUN set -eux; \
set --; \
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
set -- "$@" --extra "$extra"; \
done; \
uv sync --frozen --no-editable --no-cache "$@"
# --- MARK: Runtime Stage
FROM nvidia/cuda:12.9.1-cudnn-runtime-ubuntu24.04
ENV DEBIAN_FRONTEND=noninteractive
WORKDIR /app
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y --no-install-recommends \ apt-get install -y --no-install-recommends \
python3 \ ffmpeg &&\
python3-pip \
python3-venv \
ffmpeg \
git \
build-essential \
python3-dev \
ca-certificates && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
RUN python3 -m venv /opt/venv # Copy UV binaries
ENV PATH="/opt/venv/bin:$PATH" COPY --from=uvbin /uv /uvx /bin/
# timeout/retries for large torch wheels # Copy the Python version
RUN pip3 install --upgrade pip setuptools wheel && \ COPY --from=builder-gpu --chown=python:python /python /python
pip3 --disable-pip-version-check install --timeout=120 --retries=5 \
--index-url https://download.pytorch.org/whl/cu129 \
torch torchaudio \
|| (echo "Initial install failed — retrying with extended timeout..." && \
pip3 --disable-pip-version-check install --timeout=300 --retries=3 \
--index-url https://download.pytorch.org/whl/cu129 \
torch torchvision torchaudio)
COPY . . # Copy the virtual environment with all dependencies installed
COPY --from=builder-gpu /app/.venv /app/.venv
# Install WhisperLiveKit directly, allowing for optional dependencies
# Example: --build-arg EXTRAS="translation"
RUN if [ -n "$EXTRAS" ]; then \
echo "Installing with extras: [$EXTRAS]"; \
pip install --no-cache-dir "whisperlivekit[$EXTRAS]"; \
else \
echo "Installing base package only"; \
pip install --no-cache-dir whisperlivekit; \
fi
# In-container caching for Hugging Face models by:
# A) Make the cache directory persistent via an anonymous volume.
# Note: This only persists for a single, named container. This is
# only for convenience at de/test stage.
# For prod, it is better to use a named volume via host mount/k8s.
VOLUME ["/root/.cache/huggingface/hub"]
# or
# B) Conditionally copy a local pre-cache from the build context to the
# container's cache via the HF_PRECACHE_DIR build-arg.
# WARNING: This will copy ALL files in the pre-cache location.
# Conditionally copy a cache directory if provided
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
mkdir -p /root/.cache/huggingface/hub && \
cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
else \
echo "No local Hugging Face cache specified, skipping copy"; \
fi
# Conditionally copy a Hugging Face token if provided. Useful for Diart backend (pyannote audio models)
RUN if [ -n "$HF_TKN_FILE" ]; then \
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
mkdir -p /root/.cache/huggingface && \
cp $HF_TKN_FILE /root/.cache/huggingface/token; \
else \
echo "No Hugging Face token file specified, skipping token setup"; \
fi
EXPOSE 8000 EXPOSE 8000
ENV PATH="/app/.venv/bin:$PATH"
ENV UV_PYTHON_DOWNLOADS=0
HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \ HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \
CMD python3 -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/')" || exit 1 CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/')" || exit 1
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"] ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]

View File

@@ -1,62 +1,74 @@
FROM python:3.13-slim FROM ghcr.io/astral-sh/uv:0.10.4 AS uvbin
# --- MARK: Builder Stage
FROM debian:bookworm-slim AS builder-cpu
ENV DEBIAN_FRONTEND=noninteractive ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1 ENV PYTHONUNBUFFERED=1
WORKDIR /app WORKDIR /app
ARG EXTRAS
ARG HF_PRECACHE_DIR
ARG HF_TKN_FILE
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y --no-install-recommends \ apt-get install -y --no-install-recommends \
ffmpeg \
git \
build-essential \ build-essential \
python3-dev && \ python3-dev && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
# Install CPU-only PyTorch # Install UV and set up the environment
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu COPY --from=uvbin /uv /uvx /bin/
COPY . . ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy UV_NO_DEV=1
ENV UV_PYTHON_PREFERENCE=only-managed
ENV UV_PYTHON_INSTALL_DIR=/python
# Install WhisperLiveKit directly, allowing for optional dependencies RUN uv python install 3.12
RUN if [ -n "$EXTRAS" ]; then \
echo "Installing with extras: [$EXTRAS]"; \
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
else \
echo "Installing base package only"; \
pip install --no-cache-dir whisperlivekit; \
fi
# Enable in-container caching for Hugging Face models # Install dependencies first to leverage caching
VOLUME ["/root/.cache/huggingface/hub"] ARG EXTRAS=cpu
COPY pyproject.toml uv.lock /app/
RUN set -eux; \
set --; \
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
set -- "$@" --extra "$extra"; \
done; \
uv sync --frozen --no-install-project --no-editable --no-cache "$@"
# Conditionally copy a local pre-cache from the build context # Copy the source code and install the package only
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \ COPY whisperlivekit /app/whisperlivekit
echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \ RUN set -eux; \
mkdir -p /root/.cache/huggingface/hub && \ set --; \
cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \ for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
else \ set -- "$@" --extra "$extra"; \
echo "No local Hugging Face cache specified, skipping copy"; \ done; \
fi uv sync --frozen --no-editable --no-cache "$@"
# Conditionally copy a Hugging Face token if provided # --- MARK: Runtime Stage
RUN if [ -n "$HF_TKN_FILE" ]; then \ FROM debian:bookworm-slim
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
mkdir -p /root/.cache/huggingface && \ ENV DEBIAN_FRONTEND=noninteractive
cp $HF_TKN_FILE /root/.cache/huggingface/token; \
else \ WORKDIR /app
echo "No Hugging Face token file specified, skipping token setup"; \
fi RUN apt-get update && \
apt-get install -y --no-install-recommends \
ffmpeg &&\
rm -rf /var/lib/apt/lists/*
# Copy UV binaries
COPY --from=uvbin /uv /uvx /bin/
# Copy the Python version
COPY --from=builder-cpu --chown=python:python /python /python
# Copy the virtual environment with all dependencies installed
COPY --from=builder-cpu /app/.venv /app/.venv
# Expose port for the transcription server
EXPOSE 8000 EXPOSE 8000
ENV PATH="/app/.venv/bin:$PATH"
ENV UV_PYTHON_DOWNLOADS=0
HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \ HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \
CMD python3 -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/')" || exit 1 CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/')" || exit 1
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"] ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]

View File

@@ -18,9 +18,10 @@
</p> </p>
#### Powered by Leading Research: ### Powered by Leading Research:
- Simul-[Whisper](https://arxiv.org/pdf/2406.10052)/[Streaming](https://arxiv.org/abs/2506.17077) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408) **See the interactive playground in [this repo](https://github.com/QuentinFuxa/streamlit-d3-network) to explore how AlignAtt works**
- Simul-[Whisper](https://arxiv.org/pdf/2406.10052)/[Streaming](https://arxiv.org/abs/2506.17077) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408).
- [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simulatenous translation from & to 200 languages. - [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simulatenous translation from & to 200 languages.
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf) - [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization - [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
@@ -72,15 +73,29 @@ Go to `chrome-extension` for instructions.
#### Optional Dependencies #### Optional Dependencies
| Optional | `pip install` | | Feature | `uv sync` | `pip install -e` |
|-----------|-------------| |-----------|-------------|-------------|
| **Windows/Linux optimizations** | `faster-whisper` | | **Apple Silicon MLX Whisper backend** | `uv sync --extra mlx-whisper` | `pip install -e ".[mlx-whisper]"` |
| **Apple Silicon optimizations** | `mlx-whisper` | | **Voxtral (MLX backend, Apple Silicon)** | `uv sync --extra voxtral-mlx` | `pip install -e ".[voxtral-mlx]"` |
| **Voxtral (multilingual, auto-detect)** | `transformers torch` (or use built-in `voxtral-mlx` on Apple Silicon) | | **CPU PyTorch stack** | `uv sync --extra cpu` | `pip install -e ".[cpu]"` |
| **Translation** | `nllw` | | **CUDA 12.9 PyTorch stack** | `uv sync --extra cu129` | `pip install -e ".[cu129]"` |
| **Speaker diarization** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` | | **Translation** | `uv sync --extra translation` | `pip install -e ".[translation]"` |
| OpenAI API | `openai` | | **Sentence tokenizer** | `uv sync --extra sentence_tokenizer` | `pip install -e ".[sentence_tokenizer]"` |
| *[Not recommanded]* Speaker diarization with Diart | `diart` | | **Voxtral (HF backend)** | `uv sync --extra voxtral-hf` | `pip install -e ".[voxtral-hf]"` |
| **Speaker diarization (Sortformer / NeMo)** | `uv sync --extra diarization-sortformer` | `pip install -e ".[diarization-sortformer]"` |
| *[Not recommended]* Speaker diarization with Diart | `uv sync --extra diarization-diart` | `pip install -e ".[diarization-diart]"` |
Supported GPU profiles:
```bash
# Profile A: Sortformer diarization
uv sync --extra cu129 --extra diarization-sortformer
# Profile B: Voxtral HF + translation
uv sync --extra cu129 --extra voxtral-hf --extra translation
```
`voxtral-hf` and `diarization-sortformer` are intentionally incompatible extras and must be installed in separate environments.
See **Parameters & Configuration** below on how to use them. See **Parameters & Configuration** below on how to use them.
@@ -102,6 +117,7 @@ detection is more reliable and does not bias towards English.
```bash ```bash
# Apple Silicon (native MLX, recommended) # Apple Silicon (native MLX, recommended)
pip install -e ".[voxtral-mlx]"
wlk --backend voxtral-mlx wlk --backend voxtral-mlx
# Linux/GPU (HuggingFace transformers) # Linux/GPU (HuggingFace transformers)
@@ -279,7 +295,7 @@ docker run --gpus all -p 8000:8000 --name wlk wlk
**CPU only:** **CPU only:**
```bash ```bash
docker build -f Dockerfile.cpu -t wlk . docker build -f Dockerfile.cpu -t wlk --build-arg EXTRAS="cpu" .
docker run -p 8000:8000 --name wlk wlk docker run -p 8000:8000 --name wlk wlk
``` ```
@@ -291,6 +307,18 @@ docker run -p 8000:8000 --name wlk wlk
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
``` ```
**Compose (recommended for cache + token wiring):**
```bash
# GPU Sortformer profile
docker compose up --build wlk-gpu-sortformer
# GPU Voxtral profile
docker compose up --build wlk-gpu-voxtral
# CPU service
docker compose up --build wlk-cpu
```
### Memory Requirements ### Memory Requirements
- **Large models**: Ensure your Docker runtime has sufficient memory allocated - **Large models**: Ensure your Docker runtime has sufficient memory allocated
@@ -298,9 +326,10 @@ docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
#### Customization #### Customization
- `--build-arg` Options: - `--build-arg` Options:
- `EXTRAS="translation"` - Add extras to the image's installation (no spaces). Remember to set necessary container options! - `EXTRAS="cu129,diarization-sortformer"` - GPU Sortformer profile extras.
- `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start - `EXTRAS="cu129,voxtral-hf,translation"` - GPU Voxtral profile extras.
- `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models - `EXTRAS="cpu,diarization-diart,translation"` - CPU profile extras.
- Hugging Face cache + token are configured in `compose.yml` using a named volume and `HF_TKN_FILE` (default: `./token`).
## Testing & Benchmarks ## Testing & Benchmarks

52
compose.yml Normal file
View File

@@ -0,0 +1,52 @@
services:
wlk-gpu-sortformer:
build:
context: .
dockerfile: Dockerfile
args:
EXTRAS: ${GPU_SORTFORMER_EXTRAS:-cu129,diarization-sortformer}
image: wlk:gpu-sortformer
gpus: all
ports:
- "8000:8000"
volumes:
- hf-cache:/root/.cache/huggingface/hub
# - ${HF_TKN_FILE:-./token}:/root/.cache/huggingface/token:ro
environment:
- HF_TOKEN
command: ["--model", "medium", "--diarization", "--pcm-input"]
wlk-gpu-voxtral:
build:
context: .
dockerfile: Dockerfile
args:
EXTRAS: ${GPU_VOXTRAL_EXTRAS:-cu129,voxtral-hf,translation}
image: wlk:gpu-voxtral
gpus: all
ports:
- "8001:8000"
volumes:
- hf-cache:/root/.cache/huggingface/hub
# - ${HF_TKN_FILE:-./token}:/root/.cache/huggingface/token:ro
environment:
- HF_TOKEN
command: ["--backend", "voxtral", "--pcm-input"]
wlk-cpu:
build:
context: .
dockerfile: Dockerfile.cpu
args:
EXTRAS: ${CPU_EXTRAS:-cpu,diarization-diart,translation}
image: wlk:cpu
ports:
- "8000:8000"
volumes:
- hf-cache:/root/.cache/huggingface/hub
# - ${HF_TKN_FILE:-./token}:/root/.cache/huggingface/token:ro
environment:
- HF_TOKEN
volumes:
hf-cache:

View File

@@ -7,24 +7,18 @@ name = "whisperlivekit"
version = "0.2.19" version = "0.2.19"
description = "Real-time speech-to-text with speaker diarization using Whisper" description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md" readme = "README.md"
authors = [ authors = [{ name = "Quentin Fuxa" }]
{ name = "Quentin Fuxa" }
]
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.9" requires-python = ">=3.11, <3.14"
classifiers = [ classifiers = [
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
"Intended Audience :: Developers", "Intended Audience :: Developers",
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: 3.15",
"Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Multimedia :: Sound/Audio :: Speech" "Topic :: Multimedia :: Sound/Audio :: Speech",
] ]
dependencies = [ dependencies = [
"fastapi", "fastapi",
@@ -32,20 +26,91 @@ dependencies = [
"soundfile", "soundfile",
"uvicorn", "uvicorn",
"websockets", "websockets",
"torchaudio>=2.0.0",
"torch>=2.0.0",
"huggingface-hub>=0.25.0", "huggingface-hub>=0.25.0",
"faster-whisper>=1.2.0", "faster-whisper>=1.2.0",
"torch>=2.0.0",
"torchaudio>=2.0.0",
"tqdm", "tqdm",
"tiktoken", "tiktoken",
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
] ]
[project.optional-dependencies] [project.optional-dependencies]
test = ["pytest>=7.0", "pytest-asyncio>=0.21"] test = ["pytest>=7.0", "pytest-asyncio>=0.21"]
translation = ["nllw"] translation = ["nllw"]
sentence_tokenizer = ["mosestokenizer", "wtpsplit"] sentence_tokenizer = ["mosestokenizer", "wtpsplit"]
voxtral-hf = ["transformers>=5.2.0", "mistral-common[audio]"] mlx-whisper = [
'mlx>=0.11.0; sys_platform == "darwin" and platform_machine == "arm64"',
'mlx-whisper>=0.4.0; sys_platform == "darwin" and platform_machine == "arm64"',
]
voxtral-mlx = [
'mlx>=0.11.0; sys_platform == "darwin" and platform_machine == "arm64"',
'mlx-whisper>=0.4.0; sys_platform == "darwin" and platform_machine == "arm64"',
"mistral-common[audio]",
]
voxtral-hf = [
"transformers>=5.2.0; python_version >= '3.10'",
"mistral-common[audio]",
"accelerate>=0.12",
]
cpu = ["torch>=2.0.0", "torchaudio>=2.0.0"]
cu129 = [
"torch>=2.0.0",
"torchaudio>=2.0.0",
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")',
]
diarization-sortformer = [
"nemo-toolkit[asr]>2.4; python_version >= '3.10' and python_version < '3.13'",
]
diarization-diart = [
"diart",
"torch<2.9.0",
"torchaudio<2.9.0",
"torchvision<0.24.0",
]
[dependency-groups]
dev = ["rich>=14.3.3"]
[tool.uv]
conflicts = [
[
{ extra = "cpu" },
{ extra = "cu129" },
],
[
{ extra = "diarization-diart" },
{ extra = "cu129" },
],
[
{ extra = "voxtral-hf" },
{ extra = "diarization-sortformer" },
],
]
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", extra = "cpu", marker = "platform_system != 'Darwin'" },
{ index = "pytorch-cpu", extra = "diarization-diart", marker = "platform_system != 'Darwin'" },
{ index = "pytorch-cu129", extra = "cu129", marker = "platform_system == 'Linux' and platform_machine == 'x86_64'" },
]
torchaudio = [
{ index = "pytorch-cpu", extra = "cpu", marker = "platform_system != 'Darwin'" },
{ index = "pytorch-cpu", extra = "diarization-diart", marker = "platform_system != 'Darwin'" },
{ index = "pytorch-cu129", extra = "cu129", marker = "platform_system == 'Linux' and platform_machine == 'x86_64'" },
]
torchvision = [
{ index = "pytorch-cpu", extra = "diarization-diart", marker = "platform_system != 'Darwin'" },
]
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
[[tool.uv.index]]
name = "pytorch-cu129"
url = "https://download.pytorch.org/whl/cu129"
explicit = true
[project.urls] [project.urls]
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit" Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
@@ -66,7 +131,7 @@ packages = [
"whisperlivekit.web", "whisperlivekit.web",
"whisperlivekit.local_agreement", "whisperlivekit.local_agreement",
"whisperlivekit.voxtral_mlx", "whisperlivekit.voxtral_mlx",
"whisperlivekit.silero_vad_models" "whisperlivekit.silero_vad_models",
] ]
[tool.setuptools.package-data] [tool.setuptools.package-data]

View File

@@ -0,0 +1,580 @@
#!/usr/bin/env python3
"""Offline Python support matrix runner for WhisperLiveKit."""
from __future__ import annotations
import argparse
import os
import shlex
import shutil
import subprocess
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
try:
from rich.console import Console
from rich.table import Table
HAS_RICH = True
except Exception:
HAS_RICH = False
SAMPLE_URL = (
"https://github.com/pyannote/pyannote-audio/raw/develop/tutorials/assets/sample.wav"
)
SAMPLE_PATH = Path("audio_tests/support-matrix-sample.wav")
DEFAULT_LOGS_DIR = Path("outputs/python-matrix/logs")
PYTHON_VERSIONS = ("3.11", "3.12", "3.13")
CONSOLE = Console() if HAS_RICH else None
@dataclass(frozen=True)
class MatrixRow:
row_id: str
extras: tuple[str, ...]
backend: str
policy: str
diarization_backend: str
requires_gpu: bool = False
CASES = (
MatrixRow(
row_id="fw-diart-cpu",
extras=("test", "cpu", "diarization-diart"),
backend="faster-whisper",
policy="simulstreaming",
diarization_backend="diart",
),
MatrixRow(
row_id="fw-sortformer-cpu",
extras=("test", "cpu", "diarization-sortformer"),
backend="faster-whisper",
policy="simulstreaming",
diarization_backend="sortformer",
),
MatrixRow(
row_id="fw-sortformer-gpu",
extras=("test", "cu129", "diarization-sortformer"),
backend="faster-whisper",
policy="simulstreaming",
diarization_backend="sortformer",
requires_gpu=True,
),
MatrixRow(
row_id="voxtral-diart-cpu",
extras=("test", "cpu", "voxtral-hf", "diarization-diart"),
backend="voxtral",
policy="voxtral",
diarization_backend="diart",
),
)
EXPECTED_FAILURE_CASES = {
("3.11", "voxtral-diart-cpu"): "known_unstable_voxtral_diart_cpu",
("3.12", "voxtral-diart-cpu"): "known_unstable_voxtral_diart_cpu",
}
UNSUPPORTED_CASES = {
("3.13", "fw-sortformer-cpu"): "unsupported_py313_sortformer_protobuf",
("3.13", "fw-sortformer-gpu"): "unsupported_py313_sortformer_protobuf",
}
@dataclass(frozen=True)
class CaseResult:
python_version: str
row_id: str
status: Literal["PASS", "FAIL", "N/A"]
reason: str
duration_sec: float
hint: str = ""
log_path: str = ""
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Minimal WhisperLiveKit offline support matrix"
)
parser.add_argument(
"--timeout-sec",
type=int,
default=300,
help="Per-case timeout in seconds (default: 300)",
)
parser.add_argument(
"--logs-dir",
default=str(DEFAULT_LOGS_DIR),
help="Directory where per-case logs are written (default: outputs/python-matrix/logs)",
)
return parser.parse_args()
def safe_slug(text: str) -> str:
return text.replace("=", "-").replace("|", "__").replace("/", "-").replace(" ", "-")
def status_style(status: str) -> str:
if status == "PASS":
return "green"
if status == "FAIL":
return "bold red"
if status == "N/A":
return "yellow"
return "white"
def print_line(message: str, style: str | None = None) -> None:
if CONSOLE is None:
print(message)
return
if style:
CONSOLE.print(message, style=style, highlight=False)
else:
CONSOLE.print(message, highlight=False)
def tail_text(text: str | None, max_chars: int = 220) -> str:
if not text:
return ""
normalized = " ".join(text.split())
if len(normalized) <= max_chars:
return normalized
return normalized[-max_chars:]
def run_command(
cmd: list[str],
cwd: Path,
env: dict[str, str],
timeout: int | None = None,
log_path: Path | None = None,
log_section: str | None = None,
) -> subprocess.CompletedProcess[str]:
def _append_log(
*,
command: list[str],
section: str,
returncode: int | None,
stdout: str | None,
stderr: str | None,
timed_out: bool = False,
) -> None:
if log_path is None:
return
log_path.parent.mkdir(parents=True, exist_ok=True)
with log_path.open("a", encoding="utf-8") as f:
f.write(f"\n=== {section} ===\n")
f.write(f"$ {shlex.join(command)}\n")
if timed_out:
f.write("status: timeout\n")
else:
f.write(f"status: exit_code={returncode}\n")
if stdout:
f.write("--- stdout ---\n")
f.write(stdout)
if not stdout.endswith("\n"):
f.write("\n")
if stderr:
f.write("--- stderr ---\n")
f.write(stderr)
if not stderr.endswith("\n"):
f.write("\n")
section = log_section or "command"
try:
proc = subprocess.run(
cmd,
cwd=str(cwd),
env=env,
text=True,
capture_output=True,
check=False,
timeout=timeout,
)
except subprocess.TimeoutExpired as exc:
_append_log(
command=cmd,
section=section,
returncode=None,
stdout=exc.stdout if isinstance(exc.stdout, str) else None,
stderr=exc.stderr if isinstance(exc.stderr, str) else None,
timed_out=True,
)
raise
_append_log(
command=cmd,
section=section,
returncode=proc.returncode,
stdout=proc.stdout,
stderr=proc.stderr,
)
return proc
def detect_gpu_available() -> bool:
try:
proc = subprocess.run(
["nvidia-smi", "-L"],
text=True,
capture_output=True,
check=False,
timeout=10,
)
except (FileNotFoundError, subprocess.TimeoutExpired):
return False
return proc.returncode == 0
def download_sample(repo_root: Path) -> Path:
target = repo_root / SAMPLE_PATH
target.parent.mkdir(parents=True, exist_ok=True)
cmd = [
"curl",
"--fail",
"--location",
"--silent",
"--show-error",
SAMPLE_URL,
"--output",
str(target),
]
proc = run_command(cmd, cwd=repo_root, env=os.environ.copy())
if proc.returncode != 0:
hint = tail_text(proc.stderr or proc.stdout)
raise RuntimeError(f"sample_download_failed: {hint}")
return target
def sync_case_environment(
repo_root: Path,
python_version: str,
row: MatrixRow,
env_dir: Path,
log_path: Path,
) -> tuple[bool, str]:
cmd = ["uv", "sync", "--python", python_version, "--no-dev"]
for extra in row.extras:
cmd.extend(["--extra", extra])
env = os.environ.copy()
env["UV_PROJECT_ENVIRONMENT"] = str(env_dir)
proc = run_command(
cmd,
cwd=repo_root,
env=env,
log_path=log_path,
log_section="sync",
)
if proc.returncode != 0:
return False, tail_text(proc.stderr or proc.stdout)
return True, ""
def apply_expected_failure_policy(result: CaseResult) -> CaseResult:
expected_reason = EXPECTED_FAILURE_CASES.get((result.python_version, result.row_id))
if result.status != "FAIL" or not expected_reason:
return result
override_hint = result.hint
if result.reason:
override_hint = (
f"expected_failure_override original_reason={result.reason}; {override_hint}"
if override_hint
else f"expected_failure_override original_reason={result.reason}"
)
return CaseResult(
python_version=result.python_version,
row_id=result.row_id,
status="N/A",
reason=expected_reason,
duration_sec=result.duration_sec,
hint=override_hint,
log_path=result.log_path,
)
def build_offline_command(
python_version: str,
row: MatrixRow,
sample_audio: Path,
timeout_sec: int,
) -> tuple[list[str], int | None]:
base_cmd = [
"uv",
"run",
"--python",
python_version,
"--no-sync",
"python",
"test_backend_offline.py",
"--backend",
row.backend,
"--policy",
row.policy,
"--audio",
str(sample_audio),
"--model",
"tiny",
"--diarization",
"--diarization-backend",
row.diarization_backend,
"--lan",
"en",
"--no-realtime",
]
if shutil.which("timeout"):
return ["timeout", str(timeout_sec), *base_cmd], None
return base_cmd, timeout_sec
def run_case(
repo_root: Path,
python_version: str,
row: MatrixRow,
sample_audio: Path,
timeout_sec: int,
gpu_available: bool,
logs_dir: Path,
) -> CaseResult:
start = time.monotonic()
case_slug = safe_slug(f"py{python_version}-{row.row_id}")
log_path = logs_dir / f"run-{case_slug}.log"
log_path.parent.mkdir(parents=True, exist_ok=True)
log_path.write_text("", encoding="utf-8")
unsupported_reason = UNSUPPORTED_CASES.get((python_version, row.row_id))
if unsupported_reason:
log_path.write_text(
f"[matrix] precheck_short_circuit status=N/A reason={unsupported_reason}\n",
encoding="utf-8",
)
return CaseResult(
python_version=python_version,
row_id=row.row_id,
status="N/A",
reason=unsupported_reason,
duration_sec=0.0,
hint="unsupported_case_precheck",
log_path=str(log_path),
)
if row.requires_gpu and not gpu_available:
return CaseResult(
python_version=python_version,
row_id=row.row_id,
status="N/A",
reason="gpu_unavailable",
duration_sec=0.0,
hint="nvidia-smi unavailable or failed",
log_path=str(log_path),
)
env_dir = repo_root / ".matrix-envs" / safe_slug(f"py{python_version}-{row.row_id}")
sync_ok, sync_hint = sync_case_environment(
repo_root,
python_version,
row,
env_dir,
log_path=log_path,
)
if not sync_ok:
return CaseResult(
python_version=python_version,
row_id=row.row_id,
status="FAIL",
reason="dependency_sync_failed",
duration_sec=round(time.monotonic() - start, 3),
hint=sync_hint,
log_path=str(log_path),
)
cmd, process_timeout = build_offline_command(
python_version, row, sample_audio, timeout_sec
)
env = os.environ.copy()
env["UV_PROJECT_ENVIRONMENT"] = str(env_dir)
if row.requires_gpu:
env.pop("CUDA_VISIBLE_DEVICES", None)
else:
env["CUDA_VISIBLE_DEVICES"] = ""
try:
proc = run_command(
cmd,
cwd=repo_root,
env=env,
timeout=process_timeout,
log_path=log_path,
log_section="offline",
)
except subprocess.TimeoutExpired as exc:
return CaseResult(
python_version=python_version,
row_id=row.row_id,
status="FAIL",
reason="offline_timeout",
duration_sec=round(time.monotonic() - start, 3),
hint=tail_text((exc.stderr or "") if isinstance(exc.stderr, str) else ""),
log_path=str(log_path),
)
hint = tail_text(proc.stderr or proc.stdout)
if proc.returncode == 0:
return CaseResult(
python_version=python_version,
row_id=row.row_id,
status="PASS",
reason="ok",
duration_sec=round(time.monotonic() - start, 3),
hint=hint,
log_path=str(log_path),
)
reason = "offline_timeout" if proc.returncode == 124 else "offline_run_failed"
return CaseResult(
python_version=python_version,
row_id=row.row_id,
status="FAIL",
reason=reason,
duration_sec=round(time.monotonic() - start, 3),
hint=hint,
log_path=str(log_path),
)
def print_summary(results: list[CaseResult]) -> None:
pass_count = sum(1 for row in results if row.status == "PASS")
fail_count = sum(1 for row in results if row.status == "FAIL")
na_count = sum(1 for row in results if row.status == "N/A")
if CONSOLE is None:
print("\n[matrix] results")
print("python | row | status | reason | duration_s")
print("---|---|---|---|---")
for result in results:
print(
f"{result.python_version} | {result.row_id} | {result.status} | "
f"{result.reason} | {result.duration_sec:.3f}"
)
print(
f"\n[matrix] summary pass={pass_count} fail={fail_count} "
f"na={na_count} total={len(results)}"
)
else:
table = Table(title="Support Matrix Results")
table.add_column("Python", style="cyan", no_wrap=True)
table.add_column("Row", style="white")
table.add_column("Status", no_wrap=True)
table.add_column("Reason")
table.add_column("Duration (s)", justify="right", no_wrap=True)
for result in results:
table.add_row(
result.python_version,
result.row_id,
f"[{status_style(result.status)}]{result.status}[/{status_style(result.status)}]",
result.reason,
f"{result.duration_sec:.3f}",
)
CONSOLE.print()
CONSOLE.print(table)
CONSOLE.print(
f"[bold]Summary[/bold] "
f"pass=[green]{pass_count}[/green] "
f"fail=[bold red]{fail_count}[/bold red] "
f"na=[yellow]{na_count}[/yellow] "
f"total={len(results)}"
)
diagnostics = [row for row in results if row.status in {"FAIL", "N/A"} and row.hint]
if diagnostics:
if CONSOLE is None:
print("\n[matrix] diagnostics (failed/n-a cases)")
for row in diagnostics:
print(
f"- py={row.python_version} row={row.row_id} "
f"status={row.status} reason={row.reason}"
)
print(f" hint: {row.hint}")
if row.log_path:
print(f" log: {row.log_path}")
else:
diagnostics_table = Table(title="Diagnostics (FAIL / N/A)")
diagnostics_table.add_column("Case", style="cyan")
diagnostics_table.add_column("Status", no_wrap=True)
diagnostics_table.add_column("Reason")
diagnostics_table.add_column("Hint")
diagnostics_table.add_column("Log")
for row in diagnostics:
diagnostics_table.add_row(
f"py={row.python_version} {row.row_id}",
f"[{status_style(row.status)}]{row.status}[/{status_style(row.status)}]",
row.reason,
row.hint,
row.log_path,
)
CONSOLE.print()
CONSOLE.print(diagnostics_table)
def main() -> int:
args = parse_args()
if args.timeout_sec <= 0:
print("[matrix] error: --timeout-sec must be > 0", file=sys.stderr)
return 1
repo_root = Path(__file__).resolve().parents[1]
logs_dir = (repo_root / args.logs_dir).resolve()
logs_dir.mkdir(parents=True, exist_ok=True)
print_line(f"[matrix] repo_root={repo_root}", style="cyan")
print_line(f"[matrix] timeout_sec={args.timeout_sec}", style="cyan")
print_line(f"[matrix] logs_dir={logs_dir}", style="cyan")
try:
sample_audio = download_sample(repo_root)
except Exception as exc: # pragma: no cover - straightforward failure path
if CONSOLE is None:
print(f"[matrix] sample_download_failed: {exc}", file=sys.stderr)
else:
CONSOLE.print(
f"[matrix] sample_download_failed: {exc}",
style="bold red",
highlight=False,
)
return 1
print_line(f"[matrix] sample_audio={sample_audio}", style="cyan")
gpu_available = detect_gpu_available()
print_line(f"[matrix] gpu_available={gpu_available}", style="cyan")
results: list[CaseResult] = []
for python_version in PYTHON_VERSIONS:
for row in CASES:
print_line(
f"\n[matrix] running py={python_version} row={row.row_id}", style="blue"
)
result = run_case(
repo_root=repo_root,
python_version=python_version,
row=row,
sample_audio=sample_audio,
timeout_sec=args.timeout_sec,
gpu_available=gpu_available,
logs_dir=logs_dir,
)
result = apply_expected_failure_policy(result)
results.append(result)
print_line(
f"[matrix] {result.status} py={result.python_version} "
f"row={result.row_id} reason={result.reason} duration={result.duration_sec:.3f}s",
style=status_style(result.status),
)
if result.log_path:
print_line(f"[matrix] log={result.log_path}", style="dim")
print_summary(results)
fail_count = sum(1 for row in results if row.status == "FAIL")
return 1 if fail_count else 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -150,7 +150,10 @@ def float32_to_s16le_bytes(audio: np.ndarray) -> bytes:
def create_engine( def create_engine(
backend: str, model_size: str, lan: str, backend: str, model_size: str, lan: str,
diarization: bool = False, vac: bool = True, policy: str = "", diarization: bool = False,
diarization_backend: str = "",
vac: bool = True,
policy: str = "",
): ):
"""Create a TranscriptionEngine with the given backend config.""" """Create a TranscriptionEngine with the given backend config."""
import gc import gc
@@ -169,6 +172,8 @@ def create_engine(
transcription=True, transcription=True,
diarization=diarization, diarization=diarization,
) )
if diarization_backend:
kwargs["diarization_backend"] = diarization_backend
if model_size: if model_size:
kwargs["model_size"] = model_size kwargs["model_size"] = model_size
if policy: if policy:
@@ -179,13 +184,18 @@ def create_engine(
def _extract_text_from_response(response_dict: dict) -> str: def _extract_text_from_response(response_dict: dict) -> str:
"""Extract full transcription text from a FrontData dict.""" """Extract full transcription text from a FrontData dict."""
def _strip_or_empty(value: object) -> str:
return value.strip() if isinstance(value, str) else ""
segments = response_dict.get("lines", []) segments = response_dict.get("lines", [])
full_text = " ".join( full_text = " ".join(
seg.get("text", "").strip() text
for seg in segments for seg in segments
if seg.get("text", "").strip() if isinstance(seg, dict)
for text in [_strip_or_empty(seg.get("text"))]
if text
) )
buf = response_dict.get("buffer_transcription", "").strip() buf = _strip_or_empty(response_dict.get("buffer_transcription"))
if buf: if buf:
full_text = f"{full_text} {buf}".strip() if full_text else buf full_text = f"{full_text} {buf}".strip() if full_text else buf
return full_text return full_text
@@ -236,7 +246,8 @@ async def run_test(
# Only print when transcription text actually changes # Only print when transcription text actually changes
current_text = _extract_text_from_response(d) current_text = _extract_text_from_response(d)
if current_text and current_text != last_printed_text: if current_text and current_text != last_printed_text:
buf = d.get("buffer_transcription", "").strip() buf = d.get("buffer_transcription")
buf = buf.strip() if isinstance(buf, str) else ""
committed = current_text committed = current_text
if buf and committed.endswith(buf): if buf and committed.endswith(buf):
committed = committed[:-len(buf)].strip() committed = committed[:-len(buf)].strip()
@@ -686,6 +697,12 @@ def main():
"--diarization", action="store_true", "--diarization", action="store_true",
help="Enable speaker diarization.", help="Enable speaker diarization.",
) )
parser.add_argument(
"--diarization-backend",
default="",
choices=["diart", "sortformer"],
help="Diarization backend when --diarization is enabled.",
)
parser.add_argument( parser.add_argument(
"--benchmark", action="store_true", "--benchmark", action="store_true",
help="Run benchmark across all detected backend+policy combinations.", help="Run benchmark across all detected backend+policy combinations.",
@@ -748,7 +765,10 @@ def main():
logger.info(f"Creating {args.backend} engine...") logger.info(f"Creating {args.backend} engine...")
engine = create_engine( engine = create_engine(
args.backend, args.model_size, args.lan, args.backend, args.model_size, args.lan,
diarization=args.diarization, vac=vac, policy=policy, diarization=args.diarization,
diarization_backend=args.diarization_backend,
vac=vac,
policy=policy,
) )
logger.info("Engine ready.") logger.info("Engine ready.")

6575
uv.lock generated Normal file

File diff suppressed because one or more lines are too long

View File

@@ -120,6 +120,7 @@ class AlignAttBase(ABC):
self.state.segments = [] self.state.segments = []
self.state.log_segments += 1 self.state.log_segments += 1
self.state.pending_incomplete_tokens = [] self.state.pending_incomplete_tokens = []
self.state.pending_retries = 0
def segments_len(self): def segments_len(self):
return sum(s.shape[0] for s in self.state.segments) / 16000 return sum(s.shape[0] for s in self.state.segments) / 16000
@@ -223,6 +224,7 @@ class AlignAttBase(ABC):
new_segment = False new_segment = False
logits = self._apply_token_suppression(logits) logits = self._apply_token_suppression(logits)
logits = self._apply_dry_penalty(logits, current_tokens)
current_tokens, completed = self._update_tokens( current_tokens, completed = self._update_tokens(
current_tokens, logits, sum_logprobs current_tokens, logits, sum_logprobs
) )
@@ -326,9 +328,13 @@ class AlignAttBase(ABC):
for word, word_tokens in zip(split_words, split_tokens): for word, word_tokens in zip(split_words, split_tokens):
if replacement_char in word: if replacement_char in word:
logger.warning(f"[UTF-8 Filter] Skipping: {repr(word)}") cleaned = word.replace(replacement_char, "")
if not cleaned.strip():
logger.debug(f"[UTF-8 Filter] Skipping: {repr(word)}")
timestamp_idx += len(word_tokens) timestamp_idx += len(word_tokens)
continue continue
logger.debug(f"[UTF-8 Filter] Cleaned {repr(word)} -> {repr(cleaned)}")
word = cleaned
try: try:
current_timestamp = l_absolute_timestamps[timestamp_idx] current_timestamp = l_absolute_timestamps[timestamp_idx]
@@ -354,21 +360,84 @@ class AlignAttBase(ABC):
def _handle_pending_tokens(self, split_words, split_tokens): def _handle_pending_tokens(self, split_words, split_tokens):
"""Handle incomplete UTF-8 tokens for next chunk.""" """Handle incomplete UTF-8 tokens for next chunk."""
self.state.pending_incomplete_tokens = []
MAX_PENDING_TOKENS = 10 MAX_PENDING_TOKENS = 10
MAX_PENDING_RETRIES = 2
replacement_char = "\ufffd" replacement_char = "\ufffd"
if split_words and replacement_char in split_words[-1]: if split_words and replacement_char in split_words[-1]:
if len(split_tokens[-1]) <= MAX_PENDING_TOKENS: self.state.pending_retries += 1
if self.state.pending_retries > MAX_PENDING_RETRIES:
logger.warning(
f"[UTF-8 Fix] Dropping {len(split_tokens[-1])} incomplete tokens "
f"after {MAX_PENDING_RETRIES} retries (won't resolve)"
)
self.state.pending_incomplete_tokens = []
self.state.pending_retries = 0
elif len(split_tokens[-1]) <= MAX_PENDING_TOKENS:
self.state.pending_incomplete_tokens = split_tokens[-1] self.state.pending_incomplete_tokens = split_tokens[-1]
logger.debug( logger.debug(
f"[UTF-8 Fix] Holding {len(self.state.pending_incomplete_tokens)} " f"[UTF-8 Fix] Holding {len(self.state.pending_incomplete_tokens)} "
f"incomplete tokens for next chunk" f"incomplete tokens for next chunk (retry {self.state.pending_retries})"
) )
else: else:
logger.warning( logger.warning(
f"[UTF-8 Fix] Skipping {len(split_tokens[-1])} tokens " f"[UTF-8 Fix] Skipping {len(split_tokens[-1])} tokens "
f"(exceeds limit of {MAX_PENDING_TOKENS}, likely hallucination)" f"(exceeds limit of {MAX_PENDING_TOKENS}, likely hallucination)"
) )
self.state.pending_incomplete_tokens = []
self.state.pending_retries = 0
else:
self.state.pending_incomplete_tokens = []
self.state.pending_retries = 0
# === Repetition penalty ===
def _apply_dry_penalty(self, logits, current_tokens):
"""DRY penalty v0: penalize tokens that would extend a verbatim repetition.
See https://github.com/oobabooga/text-generation-webui/pull/5677
Scans the decoded sequence for positions where the current suffix already
appeared --> for each such match, the token that followed it in the past is
penalised exponentially with the match length
"""
eot = self.tokenizer.eot
seq = current_tokens[0].tolist()
if len(seq) < 5:
return logits
last = seq[-1]
if last >= eot:
return logits
penalties = {}
for i in range(len(seq) - 2, -1, -1):
if seq[i] != last:
continue
next_tok = seq[i + 1]
if next_tok >= eot:
continue
length = 1
while length < 50:
j, k = i - length, len(seq) - 1 - length
if j < 0 or k <= i:
break
if seq[j] != seq[k] or seq[j] >= eot:
break
length += 1
if next_tok not in penalties or length > penalties[next_tok]:
penalties[next_tok] = length
if penalties:
max_len = max(penalties.values())
if max_len >= 4:
logger.debug(f"[DRY] penalising {len(penalties)} tokens (longest match: {max_len})")
for tok, length in penalties.items():
if length >= 2:
logits[:, tok] = logits[:, tok] - 1.0 * 2.0 ** (length - 2)
return logits
# === Abstract methods — subclass must implement === # === Abstract methods — subclass must implement ===

View File

@@ -200,9 +200,12 @@ class SimulStreamingASR:
if self.encoder_backend == "whisper": if self.encoder_backend == "whisper":
self.disable_fast_encoder = True self.disable_fast_encoder = True
if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin": # MLX full decoder disabled by default — MLXAlignAtt has known issues
if not hasattr(self, '_full_mlx_disabled'): # with token generation after punctuation. Users can opt-in with
self.use_full_mlx = True # --use-full-mlx if they want to test it.
# if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin":
# if not hasattr(self, '_full_mlx_disabled'):
# self.use_full_mlx = True
self.cfg = AlignAttConfig( self.cfg = AlignAttConfig(
tokenizer_is_multilingual= is_multilingual, tokenizer_is_multilingual= is_multilingual,

View File

@@ -25,6 +25,7 @@ class DecoderState:
context: Any = None context: Any = None
pending_incomplete_tokens: List[int] = field(default_factory=list) pending_incomplete_tokens: List[int] = field(default_factory=list)
pending_retries: int = 0
global_time_offset: float = 0.0 global_time_offset: float = 0.0
cumulative_time_offset: float = 0.0 cumulative_time_offset: float = 0.0
@@ -78,6 +79,7 @@ class DecoderState:
self.last_attend_frame = -rewind_threshold self.last_attend_frame = -rewind_threshold
self.cumulative_time_offset = 0.0 self.cumulative_time_offset = 0.0
self.pending_incomplete_tokens = [] self.pending_incomplete_tokens = []
self.pending_retries = 0
self.log_segments += 1 self.log_segments += 1
def full_reset(self, rewind_threshold: int = 200): def full_reset(self, rewind_threshold: int = 200):

View File

@@ -29,6 +29,7 @@ class MLXDecoderState:
context: Any = None context: Any = None
pending_incomplete_tokens: List[int] = field(default_factory=list) pending_incomplete_tokens: List[int] = field(default_factory=list)
pending_retries: int = 0
global_time_offset: float = 0.0 global_time_offset: float = 0.0
cumulative_time_offset: float = 0.0 cumulative_time_offset: float = 0.0
@@ -59,6 +60,7 @@ class MLXDecoderState:
self.last_attend_frame = -rewind_threshold self.last_attend_frame = -rewind_threshold
self.cumulative_time_offset = 0.0 self.cumulative_time_offset = 0.0
self.pending_incomplete_tokens = [] self.pending_incomplete_tokens = []
self.pending_retries = 0
self.log_segments += 1 self.log_segments += 1
def full_reset(self, rewind_threshold: int = 200): def full_reset(self, rewind_threshold: int = 200):

View File

@@ -296,10 +296,15 @@ class Tokenizer:
current_tokens.append(token) current_tokens.append(token)
decoded = self.decode_with_timestamps(current_tokens) decoded = self.decode_with_timestamps(current_tokens)
if ( try:
replacement_char not in decoded replacement_char_index = decoded.index(replacement_char)
or decoded_full[unicode_offset + decoded.index(replacement_char)] replacement_char_index += unicode_offset
== replacement_char except ValueError:
replacement_char_index = None
if replacement_char_index is None or (
replacement_char_index < len(decoded_full)
and decoded_full[replacement_char_index] == replacement_char
): ):
words.append(decoded) words.append(decoded)
word_tokens.append(current_tokens) word_tokens.append(current_tokens)