mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8bc0937c46 | ||
|
|
929cf7a26b | ||
|
|
abfaf06203 | ||
|
|
d1fe932241 | ||
|
|
c112ceffb6 | ||
|
|
4917406e06 | ||
|
|
b63f54e838 | ||
|
|
c56a53fbf4 | ||
|
|
66e58624b9 | ||
|
|
9366e067f9 | ||
|
|
866c25670c | ||
|
|
2553ef283e | ||
|
|
73e7fafc48 | ||
|
|
bbcebcb1fe | ||
|
|
4bb58dc7aa | ||
|
|
27ca028479 | ||
|
|
d24805cc18 | ||
|
|
994ce21365 | ||
|
|
132823dc09 | ||
|
|
d6d8c2635f |
13
.dockerignore
Normal file
13
.dockerignore
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
.git
|
||||||
|
.github
|
||||||
|
.venv
|
||||||
|
__pycache__
|
||||||
|
*.pyc
|
||||||
|
.pytest_cache
|
||||||
|
.mypy_cache
|
||||||
|
.ruff_cache
|
||||||
|
.cache
|
||||||
|
.tmp
|
||||||
|
.secrets
|
||||||
|
dist
|
||||||
|
build
|
||||||
61
.github/workflows/publish-docker.yml
vendored
Normal file
61
.github/workflows/publish-docker.yml
vendored
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
name: Publish Docker Images
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- "v*"
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
tag:
|
||||||
|
description: "Image tag to publish (without image suffix)"
|
||||||
|
required: true
|
||||||
|
type: string
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
packages: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
docker:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
IMAGE_TAG: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.tag || github.ref_name }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- image_suffix: cpu-diarization-sortformer
|
||||||
|
dockerfile: Dockerfile.cpu
|
||||||
|
extras: cpu,diarization-sortformer
|
||||||
|
- image_suffix: cu129-diarization-sortformer
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
extras: cu129,diarization-sortformer
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set lowercase owner
|
||||||
|
id: owner
|
||||||
|
run: echo "value=${GITHUB_REPOSITORY_OWNER,,}" >> "${GITHUB_OUTPUT}"
|
||||||
|
|
||||||
|
- name: Login to GHCR
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
registry: ghcr.io
|
||||||
|
username: ${{ github.actor }}
|
||||||
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
- name: Setup Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Build and push image
|
||||||
|
uses: docker/build-push-action@v6
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
file: ./${{ matrix.dockerfile }}
|
||||||
|
push: true
|
||||||
|
build-args: |
|
||||||
|
EXTRAS=${{ matrix.extras }}
|
||||||
|
tags: |
|
||||||
|
ghcr.io/${{ steps.owner.outputs.value }}/whisperlivekit:${{ env.IMAGE_TAG }}-${{ matrix.image_suffix }}
|
||||||
|
ghcr.io/${{ steps.owner.outputs.value }}/whisperlivekit:latest-${{ matrix.image_suffix }}
|
||||||
124
Dockerfile
124
Dockerfile
@@ -1,86 +1,74 @@
|
|||||||
FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04
|
FROM ghcr.io/astral-sh/uv:0.10.4 AS uvbin
|
||||||
|
|
||||||
|
# --- MARK: Builder Stage
|
||||||
|
FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04 AS builder-gpu
|
||||||
ENV DEBIAN_FRONTEND=noninteractive
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
ENV PYTHONUNBUFFERED=1
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
ARG EXTRAS
|
RUN apt-get update && \
|
||||||
ARG HF_PRECACHE_DIR
|
apt-get install -y --no-install-recommends \
|
||||||
ARG HF_TKN_FILE
|
build-essential \
|
||||||
|
python3-dev && \
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install UV and set up the environment
|
||||||
|
COPY --from=uvbin /uv /uvx /bin/
|
||||||
|
|
||||||
|
ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy UV_NO_DEV=1
|
||||||
|
ENV UV_PYTHON_PREFERENCE=only-managed
|
||||||
|
ENV UV_PYTHON_INSTALL_DIR=/python
|
||||||
|
|
||||||
|
RUN uv python install 3.12
|
||||||
|
|
||||||
|
# Install dependencies first to leverage caching
|
||||||
|
ARG EXTRAS=cu129
|
||||||
|
COPY pyproject.toml uv.lock /app/
|
||||||
|
RUN set -eux; \
|
||||||
|
set --; \
|
||||||
|
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
|
||||||
|
set -- "$@" --extra "$extra"; \
|
||||||
|
done; \
|
||||||
|
uv sync --frozen --no-install-project --no-editable --no-cache "$@"
|
||||||
|
|
||||||
|
# Copy the source code and install the package only
|
||||||
|
COPY whisperlivekit /app/whisperlivekit
|
||||||
|
RUN set -eux; \
|
||||||
|
set --; \
|
||||||
|
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
|
||||||
|
set -- "$@" --extra "$extra"; \
|
||||||
|
done; \
|
||||||
|
uv sync --frozen --no-editable --no-cache "$@"
|
||||||
|
|
||||||
|
# --- MARK: Runtime Stage
|
||||||
|
FROM nvidia/cuda:12.9.1-cudnn-runtime-ubuntu24.04
|
||||||
|
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
apt-get install -y --no-install-recommends \
|
apt-get install -y --no-install-recommends \
|
||||||
python3 \
|
ffmpeg &&\
|
||||||
python3-pip \
|
rm -rf /var/lib/apt/lists/*
|
||||||
python3-venv \
|
|
||||||
ffmpeg \
|
|
||||||
git \
|
|
||||||
build-essential \
|
|
||||||
python3-dev \
|
|
||||||
ca-certificates && \
|
|
||||||
rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
RUN python3 -m venv /opt/venv
|
# Copy UV binaries
|
||||||
ENV PATH="/opt/venv/bin:$PATH"
|
COPY --from=uvbin /uv /uvx /bin/
|
||||||
|
|
||||||
# timeout/retries for large torch wheels
|
# Copy the Python version
|
||||||
RUN pip3 install --upgrade pip setuptools wheel && \
|
COPY --from=builder-gpu --chown=python:python /python /python
|
||||||
pip3 --disable-pip-version-check install --timeout=120 --retries=5 \
|
|
||||||
--index-url https://download.pytorch.org/whl/cu129 \
|
|
||||||
torch torchaudio \
|
|
||||||
|| (echo "Initial install failed — retrying with extended timeout..." && \
|
|
||||||
pip3 --disable-pip-version-check install --timeout=300 --retries=3 \
|
|
||||||
--index-url https://download.pytorch.org/whl/cu129 \
|
|
||||||
torch torchvision torchaudio)
|
|
||||||
|
|
||||||
COPY . .
|
# Copy the virtual environment with all dependencies installed
|
||||||
|
COPY --from=builder-gpu /app/.venv /app/.venv
|
||||||
# Install WhisperLiveKit directly, allowing for optional dependencies
|
|
||||||
# Example: --build-arg EXTRAS="translation"
|
|
||||||
RUN if [ -n "$EXTRAS" ]; then \
|
|
||||||
echo "Installing with extras: [$EXTRAS]"; \
|
|
||||||
pip install --no-cache-dir "whisperlivekit[$EXTRAS]"; \
|
|
||||||
else \
|
|
||||||
echo "Installing base package only"; \
|
|
||||||
pip install --no-cache-dir whisperlivekit; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
# In-container caching for Hugging Face models by:
|
|
||||||
# A) Make the cache directory persistent via an anonymous volume.
|
|
||||||
# Note: This only persists for a single, named container. This is
|
|
||||||
#       only for convenience at dev/test stage.
|
|
||||||
# For prod, it is better to use a named volume via host mount/k8s.
|
|
||||||
VOLUME ["/root/.cache/huggingface/hub"]
|
|
||||||
|
|
||||||
|
|
||||||
# or
|
|
||||||
# B) Conditionally copy a local pre-cache from the build context to the
|
|
||||||
# container's cache via the HF_PRECACHE_DIR build-arg.
|
|
||||||
# WARNING: This will copy ALL files in the pre-cache location.
|
|
||||||
|
|
||||||
# Conditionally copy a cache directory if provided
|
|
||||||
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
|
|
||||||
echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
|
|
||||||
mkdir -p /root/.cache/huggingface/hub && \
|
|
||||||
cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
|
|
||||||
else \
|
|
||||||
echo "No local Hugging Face cache specified, skipping copy"; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Conditionally copy a Hugging Face token if provided. Useful for Diart backend (pyannote audio models)
|
|
||||||
RUN if [ -n "$HF_TKN_FILE" ]; then \
|
|
||||||
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
|
|
||||||
mkdir -p /root/.cache/huggingface && \
|
|
||||||
cp $HF_TKN_FILE /root/.cache/huggingface/token; \
|
|
||||||
else \
|
|
||||||
echo "No Hugging Face token file specified, skipping token setup"; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
EXPOSE 8000
|
EXPOSE 8000
|
||||||
|
|
||||||
|
ENV PATH="/app/.venv/bin:$PATH"
|
||||||
|
ENV UV_PYTHON_DOWNLOADS=0
|
||||||
|
|
||||||
HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \
|
||||||
CMD python3 -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/')" || exit 1
|
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/')" || exit 1
|
||||||
|
|
||||||
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
|
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
|
||||||
|
|
||||||
|
|||||||
102
Dockerfile.cpu
102
Dockerfile.cpu
@@ -1,64 +1,76 @@
|
|||||||
FROM python:3.13-slim
|
FROM ghcr.io/astral-sh/uv:0.10.4 AS uvbin
|
||||||
|
|
||||||
|
# --- MARK: Builder Stage
|
||||||
|
FROM debian:bookworm-slim AS builder-cpu
|
||||||
ENV DEBIAN_FRONTEND=noninteractive
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
ENV PYTHONUNBUFFERED=1
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
ARG EXTRAS
|
RUN apt-get update && \
|
||||||
ARG HF_PRECACHE_DIR
|
apt-get install -y --no-install-recommends \
|
||||||
ARG HF_TKN_FILE
|
build-essential \
|
||||||
|
python3-dev && \
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install UV and set up the environment
|
||||||
|
COPY --from=uvbin /uv /uvx /bin/
|
||||||
|
|
||||||
|
ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy UV_NO_DEV=1
|
||||||
|
ENV UV_PYTHON_PREFERENCE=only-managed
|
||||||
|
ENV UV_PYTHON_INSTALL_DIR=/python
|
||||||
|
|
||||||
|
RUN uv python install 3.12
|
||||||
|
|
||||||
|
# Install dependencies first to leverage caching
|
||||||
|
ARG EXTRAS=cpu
|
||||||
|
COPY pyproject.toml uv.lock /app/
|
||||||
|
RUN set -eux; \
|
||||||
|
set --; \
|
||||||
|
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
|
||||||
|
set -- "$@" --extra "$extra"; \
|
||||||
|
done; \
|
||||||
|
uv sync --frozen --no-install-project --no-editable --no-cache "$@"
|
||||||
|
|
||||||
|
# Copy the source code and install the package only
|
||||||
|
COPY whisperlivekit /app/whisperlivekit
|
||||||
|
RUN set -eux; \
|
||||||
|
set --; \
|
||||||
|
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
|
||||||
|
set -- "$@" --extra "$extra"; \
|
||||||
|
done; \
|
||||||
|
uv sync --frozen --no-editable --no-cache "$@"
|
||||||
|
|
||||||
|
# --- MARK: Runtime Stage
|
||||||
|
FROM debian:bookworm-slim
|
||||||
|
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
apt-get install -y --no-install-recommends \
|
apt-get install -y --no-install-recommends \
|
||||||
ffmpeg \
|
ffmpeg &&\
|
||||||
git \
|
rm -rf /var/lib/apt/lists/*
|
||||||
build-essential \
|
|
||||||
python3-dev && \
|
|
||||||
rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
# Install CPU-only PyTorch
|
# Copy UV binaries
|
||||||
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
COPY --from=uvbin /uv /uvx /bin/
|
||||||
|
|
||||||
COPY . .
|
# Copy the Python version
|
||||||
|
COPY --from=builder-cpu --chown=python:python /python /python
|
||||||
|
|
||||||
# Install WhisperLiveKit directly, allowing for optional dependencies
|
# Copy the virtual environment with all dependencies installed
|
||||||
RUN if [ -n "$EXTRAS" ]; then \
|
COPY --from=builder-cpu /app/.venv /app/.venv
|
||||||
echo "Installing with extras: [$EXTRAS]"; \
|
|
||||||
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
|
|
||||||
else \
|
|
||||||
echo "Installing base package only"; \
|
|
||||||
pip install --no-cache-dir whisperlivekit; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Enable in-container caching for Hugging Face models
|
|
||||||
VOLUME ["/root/.cache/huggingface/hub"]
|
|
||||||
|
|
||||||
# Conditionally copy a local pre-cache from the build context
|
|
||||||
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
|
|
||||||
echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
|
|
||||||
mkdir -p /root/.cache/huggingface/hub && \
|
|
||||||
cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
|
|
||||||
else \
|
|
||||||
echo "No local Hugging Face cache specified, skipping copy"; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Conditionally copy a Hugging Face token if provided
|
|
||||||
RUN if [ -n "$HF_TKN_FILE" ]; then \
|
|
||||||
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
|
|
||||||
mkdir -p /root/.cache/huggingface && \
|
|
||||||
cp $HF_TKN_FILE /root/.cache/huggingface/token; \
|
|
||||||
else \
|
|
||||||
echo "No Hugging Face token file specified, skipping token setup"; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Expose port for the transcription server
|
|
||||||
EXPOSE 8000
|
EXPOSE 8000
|
||||||
|
|
||||||
|
ENV PATH="/app/.venv/bin:$PATH"
|
||||||
|
ENV UV_PYTHON_DOWNLOADS=0
|
||||||
|
|
||||||
HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \
|
||||||
CMD python3 -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/')" || exit 1
|
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/')" || exit 1
|
||||||
|
|
||||||
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
|
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
|
||||||
|
|
||||||
# Default args - you might want to use a smaller model for CPU
|
# Default args - you might want to use a smaller model for CPU
|
||||||
CMD ["--model", "tiny"]
|
CMD ["--model", "tiny"]
|
||||||
|
|||||||
59
README.md
59
README.md
@@ -18,9 +18,10 @@
|
|||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
|
||||||
#### Powered by Leading Research:
|
### Powered by Leading Research:
|
||||||
|
|
||||||
- Simul-[Whisper](https://arxiv.org/pdf/2406.10052)/[Streaming](https://arxiv.org/abs/2506.17077) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
|
**See the interactive playground in [this repo](https://github.com/QuentinFuxa/streamlit-d3-network) to explore how AlignAtt works**
|
||||||
|
- Simul-[Whisper](https://arxiv.org/pdf/2406.10052)/[Streaming](https://arxiv.org/abs/2506.17077) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408).
|
||||||
- [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simultaneous translation from & to 200 languages.
|
- [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simultaneous translation from & to 200 languages.
|
||||||
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
|
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
|
||||||
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
|
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
|
||||||
@@ -72,15 +73,29 @@ Go to `chrome-extension` for instructions.
|
|||||||
|
|
||||||
#### Optional Dependencies
|
#### Optional Dependencies
|
||||||
|
|
||||||
| Optional | `pip install` |
|
| Feature | `uv sync` | `pip install -e` |
|
||||||
|-----------|-------------|
|
|-----------|-------------|-------------|
|
||||||
| **Windows/Linux optimizations** | `faster-whisper` |
|
| **Apple Silicon MLX Whisper backend** | `uv sync --extra mlx-whisper` | `pip install -e ".[mlx-whisper]"` |
|
||||||
| **Apple Silicon optimizations** | `mlx-whisper` |
|
| **Voxtral (MLX backend, Apple Silicon)** | `uv sync --extra voxtral-mlx` | `pip install -e ".[voxtral-mlx]"` |
|
||||||
| **Voxtral (multilingual, auto-detect)** | `transformers torch` (or use built-in `voxtral-mlx` on Apple Silicon) |
|
| **CPU PyTorch stack** | `uv sync --extra cpu` | `pip install -e ".[cpu]"` |
|
||||||
| **Translation** | `nllw` |
|
| **CUDA 12.9 PyTorch stack** | `uv sync --extra cu129` | `pip install -e ".[cu129]"` |
|
||||||
| **Speaker diarization** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
|
| **Translation** | `uv sync --extra translation` | `pip install -e ".[translation]"` |
|
||||||
| OpenAI API | `openai` |
|
| **Sentence tokenizer** | `uv sync --extra sentence_tokenizer` | `pip install -e ".[sentence_tokenizer]"` |
|
||||||
| *[Not recommanded]* Speaker diarization with Diart | `diart` |
|
| **Voxtral (HF backend)** | `uv sync --extra voxtral-hf` | `pip install -e ".[voxtral-hf]"` |
|
||||||
|
| **Speaker diarization (Sortformer / NeMo)** | `uv sync --extra diarization-sortformer` | `pip install -e ".[diarization-sortformer]"` |
|
||||||
|
| *[Not recommended]* Speaker diarization with Diart | `uv sync --extra diarization-diart` | `pip install -e ".[diarization-diart]"` |
|
||||||
|
|
||||||
|
Supported GPU profiles:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Profile A: Sortformer diarization
|
||||||
|
uv sync --extra cu129 --extra diarization-sortformer
|
||||||
|
|
||||||
|
# Profile B: Voxtral HF + translation
|
||||||
|
uv sync --extra cu129 --extra voxtral-hf --extra translation
|
||||||
|
```
|
||||||
|
|
||||||
|
`voxtral-hf` and `diarization-sortformer` are intentionally incompatible extras and must be installed in separate environments.
|
||||||
|
|
||||||
See **Parameters & Configuration** below on how to use them.
|
See **Parameters & Configuration** below on how to use them.
|
||||||
|
|
||||||
@@ -102,6 +117,7 @@ detection is more reliable and does not bias towards English.
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Apple Silicon (native MLX, recommended)
|
# Apple Silicon (native MLX, recommended)
|
||||||
|
pip install -e ".[voxtral-mlx]"
|
||||||
wlk --backend voxtral-mlx
|
wlk --backend voxtral-mlx
|
||||||
|
|
||||||
# Linux/GPU (HuggingFace transformers)
|
# Linux/GPU (HuggingFace transformers)
|
||||||
@@ -279,7 +295,7 @@ docker run --gpus all -p 8000:8000 --name wlk wlk
|
|||||||
|
|
||||||
**CPU only:**
|
**CPU only:**
|
||||||
```bash
|
```bash
|
||||||
docker build -f Dockerfile.cpu -t wlk .
|
docker build -f Dockerfile.cpu -t wlk --build-arg EXTRAS="cpu" .
|
||||||
docker run -p 8000:8000 --name wlk wlk
|
docker run -p 8000:8000 --name wlk wlk
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -291,6 +307,18 @@ docker run -p 8000:8000 --name wlk wlk
|
|||||||
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
|
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Compose (recommended for cache + token wiring):**
|
||||||
|
```bash
|
||||||
|
# GPU Sortformer profile
|
||||||
|
docker compose up --build wlk-gpu-sortformer
|
||||||
|
|
||||||
|
# GPU Voxtral profile
|
||||||
|
docker compose up --build wlk-gpu-voxtral
|
||||||
|
|
||||||
|
# CPU service
|
||||||
|
docker compose up --build wlk-cpu
|
||||||
|
```
|
||||||
|
|
||||||
### Memory Requirements
|
### Memory Requirements
|
||||||
- **Large models**: Ensure your Docker runtime has sufficient memory allocated
|
- **Large models**: Ensure your Docker runtime has sufficient memory allocated
|
||||||
|
|
||||||
@@ -298,9 +326,10 @@ docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
|
|||||||
#### Customization
|
#### Customization
|
||||||
|
|
||||||
- `--build-arg` Options:
|
- `--build-arg` Options:
|
||||||
- `EXTRAS="translation"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
|
- `EXTRAS="cu129,diarization-sortformer"` - GPU Sortformer profile extras.
|
||||||
- `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start
|
- `EXTRAS="cu129,voxtral-hf,translation"` - GPU Voxtral profile extras.
|
||||||
- `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models
|
- `EXTRAS="cpu,diarization-diart,translation"` - CPU profile extras.
|
||||||
|
- Hugging Face cache + token are configured in `compose.yml` using a named volume and `HF_TKN_FILE` (default: `./token`).
|
||||||
|
|
||||||
## Testing & Benchmarks
|
## Testing & Benchmarks
|
||||||
|
|
||||||
|
|||||||
52
compose.yml
Normal file
52
compose.yml
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
services:
|
||||||
|
wlk-gpu-sortformer:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
args:
|
||||||
|
EXTRAS: ${GPU_SORTFORMER_EXTRAS:-cu129,diarization-sortformer}
|
||||||
|
image: wlk:gpu-sortformer
|
||||||
|
gpus: all
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
volumes:
|
||||||
|
- hf-cache:/root/.cache/huggingface/hub
|
||||||
|
# - ${HF_TKN_FILE:-./token}:/root/.cache/huggingface/token:ro
|
||||||
|
environment:
|
||||||
|
- HF_TOKEN
|
||||||
|
command: ["--model", "medium", "--diarization", "--pcm-input"]
|
||||||
|
|
||||||
|
wlk-gpu-voxtral:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
args:
|
||||||
|
EXTRAS: ${GPU_VOXTRAL_EXTRAS:-cu129,voxtral-hf,translation}
|
||||||
|
image: wlk:gpu-voxtral
|
||||||
|
gpus: all
|
||||||
|
ports:
|
||||||
|
- "8001:8000"
|
||||||
|
volumes:
|
||||||
|
- hf-cache:/root/.cache/huggingface/hub
|
||||||
|
# - ${HF_TKN_FILE:-./token}:/root/.cache/huggingface/token:ro
|
||||||
|
environment:
|
||||||
|
- HF_TOKEN
|
||||||
|
command: ["--backend", "voxtral", "--pcm-input"]
|
||||||
|
|
||||||
|
wlk-cpu:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile.cpu
|
||||||
|
args:
|
||||||
|
EXTRAS: ${CPU_EXTRAS:-cpu,diarization-diart,translation}
|
||||||
|
image: wlk:cpu
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
volumes:
|
||||||
|
- hf-cache:/root/.cache/huggingface/hub
|
||||||
|
# - ${HF_TKN_FILE:-./token}:/root/.cache/huggingface/token:ro
|
||||||
|
environment:
|
||||||
|
- HF_TOKEN
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
hf-cache:
|
||||||
@@ -7,24 +7,18 @@ name = "whisperlivekit"
|
|||||||
version = "0.2.19"
|
version = "0.2.19"
|
||||||
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [
|
authors = [{ name = "Quentin Fuxa" }]
|
||||||
{ name = "Quentin Fuxa" }
|
|
||||||
]
|
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.11, <3.14"
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Development Status :: 4 - Beta",
|
"Development Status :: 4 - Beta",
|
||||||
"Intended Audience :: Developers",
|
"Intended Audience :: Developers",
|
||||||
"License :: OSI Approved :: MIT License",
|
"License :: OSI Approved :: MIT License",
|
||||||
"Programming Language :: Python :: 3.9",
|
|
||||||
"Programming Language :: Python :: 3.10",
|
|
||||||
"Programming Language :: Python :: 3.11",
|
"Programming Language :: Python :: 3.11",
|
||||||
"Programming Language :: Python :: 3.12",
|
"Programming Language :: Python :: 3.12",
|
||||||
"Programming Language :: Python :: 3.13",
|
"Programming Language :: Python :: 3.13",
|
||||||
"Programming Language :: Python :: 3.14",
|
|
||||||
"Programming Language :: Python :: 3.15",
|
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
"Topic :: Multimedia :: Sound/Audio :: Speech"
|
"Topic :: Multimedia :: Sound/Audio :: Speech",
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"fastapi",
|
"fastapi",
|
||||||
@@ -32,20 +26,91 @@ dependencies = [
|
|||||||
"soundfile",
|
"soundfile",
|
||||||
"uvicorn",
|
"uvicorn",
|
||||||
"websockets",
|
"websockets",
|
||||||
"torchaudio>=2.0.0",
|
|
||||||
"torch>=2.0.0",
|
|
||||||
"huggingface-hub>=0.25.0",
|
"huggingface-hub>=0.25.0",
|
||||||
"faster-whisper>=1.2.0",
|
"faster-whisper>=1.2.0",
|
||||||
|
"torch>=2.0.0",
|
||||||
|
"torchaudio>=2.0.0",
|
||||||
"tqdm",
|
"tqdm",
|
||||||
"tiktoken",
|
"tiktoken",
|
||||||
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
test = ["pytest>=7.0", "pytest-asyncio>=0.21"]
|
test = ["pytest>=7.0", "pytest-asyncio>=0.21"]
|
||||||
translation = ["nllw"]
|
translation = ["nllw"]
|
||||||
sentence_tokenizer = ["mosestokenizer", "wtpsplit"]
|
sentence_tokenizer = ["mosestokenizer", "wtpsplit"]
|
||||||
voxtral-hf = ["transformers>=5.2.0", "mistral-common[audio]"]
|
mlx-whisper = [
|
||||||
|
'mlx>=0.11.0; sys_platform == "darwin" and platform_machine == "arm64"',
|
||||||
|
'mlx-whisper>=0.4.0; sys_platform == "darwin" and platform_machine == "arm64"',
|
||||||
|
]
|
||||||
|
voxtral-mlx = [
|
||||||
|
'mlx>=0.11.0; sys_platform == "darwin" and platform_machine == "arm64"',
|
||||||
|
'mlx-whisper>=0.4.0; sys_platform == "darwin" and platform_machine == "arm64"',
|
||||||
|
"mistral-common[audio]",
|
||||||
|
]
|
||||||
|
voxtral-hf = [
|
||||||
|
"transformers>=5.2.0; python_version >= '3.10'",
|
||||||
|
"mistral-common[audio]",
|
||||||
|
"accelerate>=0.12",
|
||||||
|
]
|
||||||
|
cpu = ["torch>=2.0.0", "torchaudio>=2.0.0"]
|
||||||
|
cu129 = [
|
||||||
|
"torch>=2.0.0",
|
||||||
|
"torchaudio>=2.0.0",
|
||||||
|
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")',
|
||||||
|
]
|
||||||
|
diarization-sortformer = [
|
||||||
|
"nemo-toolkit[asr]>2.4; python_version >= '3.10' and python_version < '3.13'",
|
||||||
|
]
|
||||||
|
diarization-diart = [
|
||||||
|
"diart",
|
||||||
|
"torch<2.9.0",
|
||||||
|
"torchaudio<2.9.0",
|
||||||
|
"torchvision<0.24.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = ["rich>=14.3.3"]
|
||||||
|
|
||||||
|
[tool.uv]
|
||||||
|
conflicts = [
|
||||||
|
[
|
||||||
|
{ extra = "cpu" },
|
||||||
|
{ extra = "cu129" },
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{ extra = "diarization-diart" },
|
||||||
|
{ extra = "cu129" },
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{ extra = "voxtral-hf" },
|
||||||
|
{ extra = "diarization-sortformer" },
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.uv.sources]
|
||||||
|
torch = [
|
||||||
|
{ index = "pytorch-cpu", extra = "cpu", marker = "platform_system != 'Darwin'" },
|
||||||
|
{ index = "pytorch-cpu", extra = "diarization-diart", marker = "platform_system != 'Darwin'" },
|
||||||
|
{ index = "pytorch-cu129", extra = "cu129", marker = "platform_system == 'Linux' and platform_machine == 'x86_64'" },
|
||||||
|
]
|
||||||
|
torchaudio = [
|
||||||
|
{ index = "pytorch-cpu", extra = "cpu", marker = "platform_system != 'Darwin'" },
|
||||||
|
{ index = "pytorch-cpu", extra = "diarization-diart", marker = "platform_system != 'Darwin'" },
|
||||||
|
{ index = "pytorch-cu129", extra = "cu129", marker = "platform_system == 'Linux' and platform_machine == 'x86_64'" },
|
||||||
|
]
|
||||||
|
torchvision = [
|
||||||
|
{ index = "pytorch-cpu", extra = "diarization-diart", marker = "platform_system != 'Darwin'" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[tool.uv.index]]
|
||||||
|
name = "pytorch-cpu"
|
||||||
|
url = "https://download.pytorch.org/whl/cpu"
|
||||||
|
explicit = true
|
||||||
|
|
||||||
|
[[tool.uv.index]]
|
||||||
|
name = "pytorch-cu129"
|
||||||
|
url = "https://download.pytorch.org/whl/cu129"
|
||||||
|
explicit = true
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
|
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
|
||||||
@@ -66,7 +131,7 @@ packages = [
|
|||||||
"whisperlivekit.web",
|
"whisperlivekit.web",
|
||||||
"whisperlivekit.local_agreement",
|
"whisperlivekit.local_agreement",
|
||||||
"whisperlivekit.voxtral_mlx",
|
"whisperlivekit.voxtral_mlx",
|
||||||
"whisperlivekit.silero_vad_models"
|
"whisperlivekit.silero_vad_models",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
|
|||||||
580
scripts/python_support_matrix.py
Normal file
580
scripts/python_support_matrix.py
Normal file
@@ -0,0 +1,580 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Offline Python support matrix runner for WhisperLiveKit."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import shlex
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
try:
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
HAS_RICH = True
|
||||||
|
except Exception:
|
||||||
|
HAS_RICH = False
|
||||||
|
|
||||||
|
# Remote location of the sample WAV clip.
# NOTE(review): presumably downloaded to SAMPLE_PATH by the runner - confirm
# against the code that consumes these two constants.
SAMPLE_URL = (
    "https://github.com/pyannote/pyannote-audio/raw/develop/tutorials/assets/sample.wav"
)
# Relative path of the local sample clip.
SAMPLE_PATH = Path("audio_tests/support-matrix-sample.wav")
# Default for the --logs-dir command-line option.
DEFAULT_LOGS_DIR = Path("outputs/python-matrix/logs")
# Python versions listed for the matrix.
PYTHON_VERSIONS = ("3.11", "3.12", "3.13")
# Shared rich console, or None when rich could not be imported.
CONSOLE = None if not HAS_RICH else Console()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class MatrixRow:
    """One immutable configuration row of the support matrix."""

    # Identifier of the row; also used as the second element of the keys in
    # EXPECTED_FAILURE_CASES / UNSUPPORTED_CASES.
    row_id: str
    # Names of the optional-dependency extras associated with this row.
    extras: tuple[str, ...]
    # Transcription backend name (e.g. "faster-whisper", "voxtral").
    backend: str
    # Streaming policy name (e.g. "simulstreaming", "voxtral").
    policy: str
    # Diarization backend name (e.g. "diart", "sortformer").
    diarization_backend: str
    # Whether the row needs a GPU; rows are CPU-capable unless flagged.
    requires_gpu: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
# Every configuration row of the support matrix, in declaration order.
CASES = (
    # faster-whisper + Diart diarization on the CPU torch stack.
    MatrixRow(
        row_id="fw-diart-cpu",
        extras=("test", "cpu", "diarization-diart"),
        backend="faster-whisper",
        policy="simulstreaming",
        diarization_backend="diart",
    ),
    # faster-whisper + Sortformer diarization on the CPU torch stack.
    MatrixRow(
        row_id="fw-sortformer-cpu",
        extras=("test", "cpu", "diarization-sortformer"),
        backend="faster-whisper",
        policy="simulstreaming",
        diarization_backend="sortformer",
    ),
    # Same as above on the CUDA 12.9 stack; the only row flagged as GPU-only.
    MatrixRow(
        row_id="fw-sortformer-gpu",
        extras=("test", "cu129", "diarization-sortformer"),
        backend="faster-whisper",
        policy="simulstreaming",
        diarization_backend="sortformer",
        requires_gpu=True,
    ),
    # Voxtral (HF extra) + Diart diarization on the CPU torch stack.
    MatrixRow(
        row_id="voxtral-diart-cpu",
        extras=("test", "cpu", "voxtral-hf", "diarization-diart"),
        backend="voxtral",
        policy="voxtral",
        diarization_backend="diart",
    ),
)
|
||||||
|
|
||||||
|
# (python_version, row_id) -> reason code for combinations recorded as
# expected failures.
EXPECTED_FAILURE_CASES = {
    (py_version, "voxtral-diart-cpu"): "known_unstable_voxtral_diart_cpu"
    for py_version in ("3.11", "3.12")
}
# (python_version, row_id) -> reason code for combinations recorded as
# unsupported.
UNSUPPORTED_CASES = {
    ("3.13", row_id): "unsupported_py313_sortformer_protobuf"
    for row_id in ("fw-sortformer-cpu", "fw-sortformer-gpu")
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class CaseResult:
    """Outcome of one (python_version, row) matrix case."""

    python_version: str
    row_id: str
    # PASS / FAIL, or N/A for cases that were skipped or downgraded.
    status: Literal["PASS", "FAIL", "N/A"]
    # Short machine-readable cause, e.g. "ok" or "offline_timeout".
    reason: str
    # Wall-clock seconds spent on the case (0.0 for precheck short-circuits).
    duration_sec: float
    # Free-form diagnostic text, typically the tail of the command output.
    hint: str = ""
    # Path of the per-case log file, or "" when none was written.
    log_path: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the support-matrix runner.

    Returns:
        Namespace with ``timeout_sec`` (int) and ``logs_dir`` (str).
    """
    parser = argparse.ArgumentParser(
        description="Minimal WhisperLiveKit offline support matrix"
    )
    parser.add_argument(
        "--timeout-sec",
        type=int,
        default=300,
        # %(default)s keeps the help text in sync with the actual default.
        help="Per-case timeout in seconds (default: %(default)s)",
    )
    parser.add_argument(
        "--logs-dir",
        default=str(DEFAULT_LOGS_DIR),
        help="Directory where per-case logs are written (default: %(default)s)",
    )
    return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def safe_slug(text: str) -> str:
    """Return *text* with ``=``, ``|``, ``/`` and spaces replaced for use in file names.

    ``|`` becomes ``__``; the other three characters become ``-``.
    """
    for needle, substitute in (("=", "-"), ("|", "__"), ("/", "-"), (" ", "-")):
        text = text.replace(needle, substitute)
    return text
|
||||||
|
|
||||||
|
|
||||||
|
def status_style(status: str) -> str:
    """Return the Rich style name used to render *status*.

    Unknown statuses fall back to ``"white"``.
    """
    styles = {"PASS": "green", "FAIL": "bold red", "N/A": "yellow"}
    return styles.get(status, "white")
|
||||||
|
|
||||||
|
|
||||||
|
def print_line(message: str, style: str | None = None) -> None:
    """Print one progress line, styled through Rich when it is available.

    Args:
        message: Literal text to print.
        style: Optional Rich style applied to the whole line; ignored by the
            plain ``print`` fallback.
    """
    if CONSOLE is None:
        print(message)
        return
    # markup=False: messages carry a literal "[matrix]" prefix (and arbitrary
    # command output) that Rich would otherwise parse as a markup tag and drop.
    CONSOLE.print(message, style=style, markup=False, highlight=False)
|
||||||
|
|
||||||
|
|
||||||
|
def tail_text(text: str | None, max_chars: int = 220) -> str:
    """Collapse whitespace in *text* and return at most its last *max_chars* characters.

    Args:
        text: Raw output to summarise; ``None`` or empty yields ``""``.
        max_chars: Maximum length of the result; values <= 0 yield ``""``.

    Returns:
        The single-line tail of *text*.
    """
    # Guard max_chars <= 0 explicitly: ``s[-0:]`` would return the whole string.
    if not text or max_chars <= 0:
        return ""
    normalized = " ".join(text.split())
    return normalized[-max_chars:]
|
||||||
|
|
||||||
|
|
||||||
|
def run_command(
    cmd: list[str],
    cwd: Path,
    env: dict[str, str],
    timeout: int | None = None,
    log_path: Path | None = None,
    log_section: str | None = None,
) -> subprocess.CompletedProcess[str]:
    """Run *cmd*, capture its output as text and optionally append it to a log.

    Args:
        cmd: Argument vector to execute (no shell involved).
        cwd: Working directory for the child process.
        env: Complete environment for the child process.
        timeout: Seconds before the child is killed, or ``None`` for no limit.
        log_path: File to append a transcript to; ``None`` disables logging.
        log_section: Heading for the log entry (defaults to ``"command"``).

    Returns:
        The completed process; a non-zero exit code is NOT raised.

    Raises:
        subprocess.TimeoutExpired: When *timeout* elapses. The entry is logged
            first, and the exception's output attributes are normalised to
            ``str`` before it is re-raised.
    """
    section = log_section or "command"

    def _as_text(value: str | bytes | None) -> str | None:
        # TimeoutExpired carries bytes even when the child was run with text=True.
        if isinstance(value, bytes):
            return value.decode("utf-8", errors="replace")
        return value

    def _append_log(
        *,
        returncode: int | None,
        stdout: str | None,
        stderr: str | None,
        timed_out: bool = False,
    ) -> None:
        if log_path is None:
            return
        log_path.parent.mkdir(parents=True, exist_ok=True)
        with log_path.open("a", encoding="utf-8") as f:
            f.write(f"\n=== {section} ===\n")
            f.write(f"$ {shlex.join(cmd)}\n")
            if timed_out:
                f.write("status: timeout\n")
            else:
                f.write(f"status: exit_code={returncode}\n")
            for label, stream in (("stdout", stdout), ("stderr", stderr)):
                if not stream:
                    continue
                f.write(f"--- {label} ---\n")
                f.write(stream)
                # Keep every section terminated by a newline.
                if not stream.endswith("\n"):
                    f.write("\n")

    try:
        proc = subprocess.run(
            cmd,
            cwd=str(cwd),
            env=env,
            text=True,
            capture_output=True,
            check=False,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired as exc:
        partial_out = _as_text(exc.stdout)
        partial_err = _as_text(exc.stderr)
        _append_log(
            returncode=None,
            stdout=partial_out,
            stderr=partial_err,
            timed_out=True,
        )
        # Hand callers text, consistent with the text=True success path.
        exc.output = partial_out
        exc.stderr = partial_err
        raise

    _append_log(
        returncode=proc.returncode,
        stdout=proc.stdout,
        stderr=proc.stderr,
    )
    return proc
|
||||||
|
|
||||||
|
|
||||||
|
def detect_gpu_available() -> bool:
    """Return True when ``nvidia-smi -L`` runs and exits with status 0.

    A missing or non-executable binary, or a probe that takes longer than
    10 seconds, is reported as "no GPU" instead of raising.
    """
    try:
        proc = subprocess.run(
            ["nvidia-smi", "-L"],
            text=True,
            capture_output=True,
            check=False,
            timeout=10,
        )
    # OSError covers FileNotFoundError as well as PermissionError and similar
    # launch failures, none of which should abort the whole matrix run.
    except (OSError, subprocess.TimeoutExpired):
        return False
    return proc.returncode == 0
|
||||||
|
|
||||||
|
|
||||||
|
def download_sample(repo_root: Path) -> Path:
    """Download the sample clip with ``curl`` into ``repo_root / SAMPLE_PATH``.

    Args:
        repo_root: Repository root; also used as curl's working directory.

    Returns:
        Path of the downloaded file.

    Raises:
        RuntimeError: When curl exits with a non-zero status; the message
            carries the tail of curl's output.
    """
    target = repo_root / SAMPLE_PATH
    target.parent.mkdir(parents=True, exist_ok=True)
    cmd = [
        "curl",
        "--fail",
        "--location",
        "--silent",
        "--show-error",
        SAMPLE_URL,
        "--output",
        str(target),
    ]
    proc = run_command(cmd, cwd=repo_root, env=os.environ.copy())
    if proc.returncode != 0:
        hint = tail_text(proc.stderr or proc.stdout)
        raise RuntimeError(f"sample_download_failed: {hint}")
    return target
|
||||||
|
|
||||||
|
|
||||||
|
def sync_case_environment(
    repo_root: Path,
    python_version: str,
    row: MatrixRow,
    env_dir: Path,
    log_path: Path,
) -> tuple[bool, str]:
    """Run ``uv sync`` for one case into its own project environment.

    Args:
        repo_root: Directory the command is run from.
        python_version: Interpreter version passed to ``--python``.
        row: Matrix row whose ``extras`` are installed.
        env_dir: Directory exported as ``UV_PROJECT_ENVIRONMENT``.
        log_path: Log file receiving the "sync" section.

    Returns:
        ``(True, "")`` on success, otherwise ``(False, hint)`` where *hint*
        is the tail of the command output.
    """
    cmd = ["uv", "sync", "--python", python_version, "--no-dev"]
    for extra in row.extras:
        cmd += ["--extra", extra]

    env = os.environ.copy()
    env["UV_PROJECT_ENVIRONMENT"] = str(env_dir)

    proc = run_command(
        cmd,
        cwd=repo_root,
        env=env,
        log_path=log_path,
        log_section="sync",
    )
    if proc.returncode != 0:
        return False, tail_text(proc.stderr or proc.stdout)
    return True, ""
|
||||||
|
|
||||||
|
|
||||||
|
def apply_expected_failure_policy(result: CaseResult) -> CaseResult:
    """Downgrade a FAIL listed in EXPECTED_FAILURE_CASES to N/A.

    Any other result is returned unchanged. For a downgraded result the
    original reason is preserved at the front of the hint.
    """
    expected_reason = EXPECTED_FAILURE_CASES.get((result.python_version, result.row_id))
    if result.status != "FAIL" or not expected_reason:
        return result

    override_hint = result.hint
    if result.reason:
        prefix = f"expected_failure_override original_reason={result.reason}"
        override_hint = f"{prefix}; {override_hint}" if override_hint else prefix

    return CaseResult(
        python_version=result.python_version,
        row_id=result.row_id,
        status="N/A",
        reason=expected_reason,
        duration_sec=result.duration_sec,
        hint=override_hint,
        log_path=result.log_path,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def build_offline_command(
    python_version: str,
    row: MatrixRow,
    sample_audio: Path,
    timeout_sec: int,
) -> tuple[list[str], int | None]:
    """Build the ``uv run`` command line for one offline test case.

    Args:
        python_version: Interpreter version passed to ``--python``.
        row: Matrix row providing backend, policy and diarization backend.
        sample_audio: Audio file passed to ``--audio``.
        timeout_sec: Per-case time limit in seconds.

    Returns:
        ``(command, process_timeout)``. When an external ``timeout`` utility
        is used the command is wrapped with it and ``process_timeout`` is
        ``None``; otherwise the bare command is returned together with
        ``timeout_sec`` so the caller can enforce the limit itself.
    """
    base_cmd = [
        "uv",
        "run",
        "--python",
        python_version,
        "--no-sync",
        "python",
        "test_backend_offline.py",
        "--backend",
        row.backend,
        "--policy",
        row.policy,
        "--audio",
        str(sample_audio),
        "--model",
        "tiny",
        "--diarization",
        "--diarization-backend",
        row.diarization_backend,
        "--lan",
        "en",
        "--no-realtime",
    ]
    # Windows ships an unrelated ``timeout.exe`` (a sleep utility), so the
    # wrapper is only used on non-Windows hosts.
    if os.name != "nt" and shutil.which("timeout"):
        return ["timeout", str(timeout_sec), *base_cmd], None
    return base_cmd, timeout_sec
|
||||||
|
|
||||||
|
|
||||||
|
def run_case(
    repo_root: Path,
    python_version: str,
    row: MatrixRow,
    sample_audio: Path,
    timeout_sec: int,
    gpu_available: bool,
    logs_dir: Path,
) -> CaseResult:
    """Execute one matrix case and classify its outcome.

    The case is short-circuited to N/A when it is listed in UNSUPPORTED_CASES
    or needs a GPU that is not available. Otherwise its environment is synced
    and the offline test is run with a time limit.

    Args:
        repo_root: Repository root; working directory for all commands.
        python_version: Interpreter version under test.
        row: Matrix row describing extras/backend/policy.
        sample_audio: Audio file fed to the offline test.
        timeout_sec: Time limit for the offline test, in seconds.
        gpu_available: Result of the host GPU probe.
        logs_dir: Directory receiving ``run-<slug>.log``.

    Returns:
        A CaseResult with status PASS, FAIL or N/A.
    """
    start = time.monotonic()
    case_slug = safe_slug(f"py{python_version}-{row.row_id}")
    log_path = logs_dir / f"run-{case_slug}.log"
    log_path.parent.mkdir(parents=True, exist_ok=True)
    # Start from an empty log so entries from a previous run do not leak in.
    log_path.write_text("", encoding="utf-8")

    def _result(status: str, reason: str, hint: str, *, timed: bool = True) -> CaseResult:
        # Precheck short-circuits report a fixed 0.0 duration.
        elapsed = round(time.monotonic() - start, 3) if timed else 0.0
        return CaseResult(
            python_version=python_version,
            row_id=row.row_id,
            status=status,
            reason=reason,
            duration_sec=elapsed,
            hint=hint,
            log_path=str(log_path),
        )

    unsupported_reason = UNSUPPORTED_CASES.get((python_version, row.row_id))
    if unsupported_reason:
        log_path.write_text(
            f"[matrix] precheck_short_circuit status=N/A reason={unsupported_reason}\n",
            encoding="utf-8",
        )
        return _result("N/A", unsupported_reason, "unsupported_case_precheck", timed=False)

    if row.requires_gpu and not gpu_available:
        # Record the skip in the log, mirroring the unsupported-case precheck.
        log_path.write_text(
            "[matrix] precheck_short_circuit status=N/A reason=gpu_unavailable\n",
            encoding="utf-8",
        )
        return _result("N/A", "gpu_unavailable", "nvidia-smi unavailable or failed", timed=False)

    env_dir = repo_root / ".matrix-envs" / case_slug
    sync_ok, sync_hint = sync_case_environment(
        repo_root,
        python_version,
        row,
        env_dir,
        log_path=log_path,
    )
    if not sync_ok:
        return _result("FAIL", "dependency_sync_failed", sync_hint)

    cmd, process_timeout = build_offline_command(
        python_version, row, sample_audio, timeout_sec
    )
    env = os.environ.copy()
    env["UV_PROJECT_ENVIRONMENT"] = str(env_dir)
    if row.requires_gpu:
        # Drop any inherited device mask for GPU rows.
        env.pop("CUDA_VISIBLE_DEVICES", None)
    else:
        # An empty mask hides CUDA devices from CPU rows.
        env["CUDA_VISIBLE_DEVICES"] = ""

    try:
        proc = run_command(
            cmd,
            cwd=repo_root,
            env=env,
            timeout=process_timeout,
            log_path=log_path,
            log_section="offline",
        )
    except subprocess.TimeoutExpired as exc:
        captured = exc.stderr
        # TimeoutExpired may carry bytes even though the child ran in text mode.
        if isinstance(captured, bytes):
            captured = captured.decode("utf-8", errors="replace")
        return _result("FAIL", "offline_timeout", tail_text(captured))

    hint = tail_text(proc.stderr or proc.stdout)
    if proc.returncode == 0:
        return _result("PASS", "ok", hint)

    # 124 is the exit status of the external ``timeout`` wrapper on expiry.
    reason = "offline_timeout" if proc.returncode == 124 else "offline_run_failed"
    return _result("FAIL", reason, hint)
|
||||||
|
|
||||||
|
|
||||||
|
def print_summary(results: list[CaseResult]) -> None:
    """Print the result table, the totals and diagnostics for FAIL / N/A cases.

    Uses Rich tables when CONSOLE is set and a pipe-separated plain-text
    layout otherwise. Diagnostics are printed only for FAIL or N/A results
    that carry a hint.
    """
    pass_count = sum(1 for res in results if res.status == "PASS")
    fail_count = sum(1 for res in results if res.status == "FAIL")
    na_count = sum(1 for res in results if res.status == "N/A")
    diagnostics = [res for res in results if res.status in {"FAIL", "N/A"} and res.hint]

    if CONSOLE is None:
        print("\n[matrix] results")
        print("python | row | status | reason | duration_s")
        print("---|---|---|---|---")
        for result in results:
            print(
                f"{result.python_version} | {result.row_id} | {result.status} | "
                f"{result.reason} | {result.duration_sec:.3f}"
            )
        print(
            f"\n[matrix] summary pass={pass_count} fail={fail_count} "
            f"na={na_count} total={len(results)}"
        )
        if diagnostics:
            print("\n[matrix] diagnostics (failed/n-a cases)")
            for diag in diagnostics:
                print(
                    f"- py={diag.python_version} row={diag.row_id} "
                    f"status={diag.status} reason={diag.reason}"
                )
                print(f"  hint: {diag.hint}")
                if diag.log_path:
                    print(f"  log: {diag.log_path}")
        return

    # Rich is known to be importable here because CONSOLE was created from it.
    from rich.markup import escape

    def _status_cell(status: str) -> str:
        style = status_style(status)
        return f"[{style}]{status}[/{style}]"

    table = Table(title="Support Matrix Results")
    table.add_column("Python", style="cyan", no_wrap=True)
    table.add_column("Row", style="white")
    table.add_column("Status", no_wrap=True)
    table.add_column("Reason")
    table.add_column("Duration (s)", justify="right", no_wrap=True)
    for result in results:
        table.add_row(
            result.python_version,
            escape(result.row_id),
            _status_cell(result.status),
            # Free-form text is escaped so stray "[...]" is shown literally
            # instead of being parsed as Rich markup.
            escape(result.reason),
            f"{result.duration_sec:.3f}",
        )
    CONSOLE.print()
    CONSOLE.print(table)
    CONSOLE.print(
        f"[bold]Summary[/bold] "
        f"pass=[green]{pass_count}[/green] "
        f"fail=[bold red]{fail_count}[/bold red] "
        f"na=[yellow]{na_count}[/yellow] "
        f"total={len(results)}"
    )

    if not diagnostics:
        return

    diagnostics_table = Table(title="Diagnostics (FAIL / N/A)")
    diagnostics_table.add_column("Case", style="cyan")
    diagnostics_table.add_column("Status", no_wrap=True)
    diagnostics_table.add_column("Reason")
    diagnostics_table.add_column("Hint")
    diagnostics_table.add_column("Log")
    for diag in diagnostics:
        diagnostics_table.add_row(
            escape(f"py={diag.python_version} {diag.row_id}"),
            _status_cell(diag.status),
            escape(diag.reason),
            # Hints are tails of command output and routinely contain brackets.
            escape(diag.hint),
            escape(diag.log_path),
        )
    CONSOLE.print()
    CONSOLE.print(diagnostics_table)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
    """Run every (python version, row) case and print a summary.

    Returns:
        Process exit status: 1 when the arguments are invalid, the sample
        download fails or at least one case ends as FAIL; 0 otherwise.
    """
    args = parse_args()
    if args.timeout_sec <= 0:
        print("[matrix] error: --timeout-sec must be > 0", file=sys.stderr)
        return 1

    # This script lives one directory below the repository root.
    repo_root = Path(__file__).resolve().parents[1]
    logs_dir = (repo_root / args.logs_dir).resolve()
    logs_dir.mkdir(parents=True, exist_ok=True)
    print_line(f"[matrix] repo_root={repo_root}", style="cyan")
    print_line(f"[matrix] timeout_sec={args.timeout_sec}", style="cyan")
    print_line(f"[matrix] logs_dir={logs_dir}", style="cyan")

    try:
        sample_audio = download_sample(repo_root)
    except Exception as exc:  # pragma: no cover - straightforward failure path
        failure = f"[matrix] sample_download_failed: {exc}"
        if CONSOLE is None:
            print(failure, file=sys.stderr)
        else:
            # markup=False: the "[matrix]" prefix and the exception text are
            # literal and must not be parsed as Rich markup.
            CONSOLE.print(failure, style="bold red", markup=False, highlight=False)
        return 1
    print_line(f"[matrix] sample_audio={sample_audio}", style="cyan")

    gpu_available = detect_gpu_available()
    print_line(f"[matrix] gpu_available={gpu_available}", style="cyan")

    results: list[CaseResult] = []
    for python_version in PYTHON_VERSIONS:
        for row in CASES:
            print_line(
                f"\n[matrix] running py={python_version} row={row.row_id}", style="blue"
            )
            result = run_case(
                repo_root=repo_root,
                python_version=python_version,
                row=row,
                sample_audio=sample_audio,
                timeout_sec=args.timeout_sec,
                gpu_available=gpu_available,
                logs_dir=logs_dir,
            )
            # Known-unstable combinations are reported as N/A instead of FAIL.
            result = apply_expected_failure_policy(result)
            results.append(result)
            print_line(
                f"[matrix] {result.status} py={result.python_version} "
                f"row={result.row_id} reason={result.reason} duration={result.duration_sec:.3f}s",
                style=status_style(result.status),
            )
            if result.log_path:
                print_line(f"[matrix] log={result.log_path}", style="dim")

    print_summary(results)
    any_failed = any(res.status == "FAIL" for res in results)
    return 1 if any_failed else 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
|
||||||
@@ -150,7 +150,10 @@ def float32_to_s16le_bytes(audio: np.ndarray) -> bytes:
|
|||||||
|
|
||||||
def create_engine(
|
def create_engine(
|
||||||
backend: str, model_size: str, lan: str,
|
backend: str, model_size: str, lan: str,
|
||||||
diarization: bool = False, vac: bool = True, policy: str = "",
|
diarization: bool = False,
|
||||||
|
diarization_backend: str = "",
|
||||||
|
vac: bool = True,
|
||||||
|
policy: str = "",
|
||||||
):
|
):
|
||||||
"""Create a TranscriptionEngine with the given backend config."""
|
"""Create a TranscriptionEngine with the given backend config."""
|
||||||
import gc
|
import gc
|
||||||
@@ -169,6 +172,8 @@ def create_engine(
|
|||||||
transcription=True,
|
transcription=True,
|
||||||
diarization=diarization,
|
diarization=diarization,
|
||||||
)
|
)
|
||||||
|
if diarization_backend:
|
||||||
|
kwargs["diarization_backend"] = diarization_backend
|
||||||
if model_size:
|
if model_size:
|
||||||
kwargs["model_size"] = model_size
|
kwargs["model_size"] = model_size
|
||||||
if policy:
|
if policy:
|
||||||
@@ -179,13 +184,18 @@ def create_engine(
|
|||||||
|
|
||||||
def _extract_text_from_response(response_dict: dict) -> str:
|
def _extract_text_from_response(response_dict: dict) -> str:
|
||||||
"""Extract full transcription text from a FrontData dict."""
|
"""Extract full transcription text from a FrontData dict."""
|
||||||
|
def _strip_or_empty(value: object) -> str:
|
||||||
|
return value.strip() if isinstance(value, str) else ""
|
||||||
|
|
||||||
segments = response_dict.get("lines", [])
|
segments = response_dict.get("lines", [])
|
||||||
full_text = " ".join(
|
full_text = " ".join(
|
||||||
seg.get("text", "").strip()
|
text
|
||||||
for seg in segments
|
for seg in segments
|
||||||
if seg.get("text", "").strip()
|
if isinstance(seg, dict)
|
||||||
|
for text in [_strip_or_empty(seg.get("text"))]
|
||||||
|
if text
|
||||||
)
|
)
|
||||||
buf = response_dict.get("buffer_transcription", "").strip()
|
buf = _strip_or_empty(response_dict.get("buffer_transcription"))
|
||||||
if buf:
|
if buf:
|
||||||
full_text = f"{full_text} {buf}".strip() if full_text else buf
|
full_text = f"{full_text} {buf}".strip() if full_text else buf
|
||||||
return full_text
|
return full_text
|
||||||
@@ -236,7 +246,8 @@ async def run_test(
|
|||||||
# Only print when transcription text actually changes
|
# Only print when transcription text actually changes
|
||||||
current_text = _extract_text_from_response(d)
|
current_text = _extract_text_from_response(d)
|
||||||
if current_text and current_text != last_printed_text:
|
if current_text and current_text != last_printed_text:
|
||||||
buf = d.get("buffer_transcription", "").strip()
|
buf = d.get("buffer_transcription")
|
||||||
|
buf = buf.strip() if isinstance(buf, str) else ""
|
||||||
committed = current_text
|
committed = current_text
|
||||||
if buf and committed.endswith(buf):
|
if buf and committed.endswith(buf):
|
||||||
committed = committed[:-len(buf)].strip()
|
committed = committed[:-len(buf)].strip()
|
||||||
@@ -686,6 +697,12 @@ def main():
|
|||||||
"--diarization", action="store_true",
|
"--diarization", action="store_true",
|
||||||
help="Enable speaker diarization.",
|
help="Enable speaker diarization.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--diarization-backend",
|
||||||
|
default="",
|
||||||
|
choices=["diart", "sortformer"],
|
||||||
|
help="Diarization backend when --diarization is enabled.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--benchmark", action="store_true",
|
"--benchmark", action="store_true",
|
||||||
help="Run benchmark across all detected backend+policy combinations.",
|
help="Run benchmark across all detected backend+policy combinations.",
|
||||||
@@ -748,7 +765,10 @@ def main():
|
|||||||
logger.info(f"Creating {args.backend} engine...")
|
logger.info(f"Creating {args.backend} engine...")
|
||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
args.backend, args.model_size, args.lan,
|
args.backend, args.model_size, args.lan,
|
||||||
diarization=args.diarization, vac=vac, policy=policy,
|
diarization=args.diarization,
|
||||||
|
diarization_backend=args.diarization_backend,
|
||||||
|
vac=vac,
|
||||||
|
policy=policy,
|
||||||
)
|
)
|
||||||
logger.info("Engine ready.")
|
logger.info("Engine ready.")
|
||||||
|
|
||||||
|
|||||||
6575
uv.lock
generated
Normal file
6575
uv.lock
generated
Normal file
File diff suppressed because one or more lines are too long
@@ -120,6 +120,7 @@ class AlignAttBase(ABC):
|
|||||||
self.state.segments = []
|
self.state.segments = []
|
||||||
self.state.log_segments += 1
|
self.state.log_segments += 1
|
||||||
self.state.pending_incomplete_tokens = []
|
self.state.pending_incomplete_tokens = []
|
||||||
|
self.state.pending_retries = 0
|
||||||
|
|
||||||
def segments_len(self):
|
def segments_len(self):
|
||||||
return sum(s.shape[0] for s in self.state.segments) / 16000
|
return sum(s.shape[0] for s in self.state.segments) / 16000
|
||||||
@@ -223,6 +224,7 @@ class AlignAttBase(ABC):
|
|||||||
new_segment = False
|
new_segment = False
|
||||||
|
|
||||||
logits = self._apply_token_suppression(logits)
|
logits = self._apply_token_suppression(logits)
|
||||||
|
logits = self._apply_dry_penalty(logits, current_tokens)
|
||||||
current_tokens, completed = self._update_tokens(
|
current_tokens, completed = self._update_tokens(
|
||||||
current_tokens, logits, sum_logprobs
|
current_tokens, logits, sum_logprobs
|
||||||
)
|
)
|
||||||
@@ -326,9 +328,13 @@ class AlignAttBase(ABC):
|
|||||||
|
|
||||||
for word, word_tokens in zip(split_words, split_tokens):
|
for word, word_tokens in zip(split_words, split_tokens):
|
||||||
if replacement_char in word:
|
if replacement_char in word:
|
||||||
logger.warning(f"[UTF-8 Filter] Skipping: {repr(word)}")
|
cleaned = word.replace(replacement_char, "")
|
||||||
timestamp_idx += len(word_tokens)
|
if not cleaned.strip():
|
||||||
continue
|
logger.debug(f"[UTF-8 Filter] Skipping: {repr(word)}")
|
||||||
|
timestamp_idx += len(word_tokens)
|
||||||
|
continue
|
||||||
|
logger.debug(f"[UTF-8 Filter] Cleaned {repr(word)} -> {repr(cleaned)}")
|
||||||
|
word = cleaned
|
||||||
|
|
||||||
try:
|
try:
|
||||||
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
||||||
@@ -354,21 +360,84 @@ class AlignAttBase(ABC):
|
|||||||
|
|
||||||
def _handle_pending_tokens(self, split_words, split_tokens):
|
def _handle_pending_tokens(self, split_words, split_tokens):
|
||||||
"""Handle incomplete UTF-8 tokens for next chunk."""
|
"""Handle incomplete UTF-8 tokens for next chunk."""
|
||||||
self.state.pending_incomplete_tokens = []
|
|
||||||
MAX_PENDING_TOKENS = 10
|
MAX_PENDING_TOKENS = 10
|
||||||
|
MAX_PENDING_RETRIES = 2
|
||||||
replacement_char = "\ufffd"
|
replacement_char = "\ufffd"
|
||||||
|
|
||||||
if split_words and replacement_char in split_words[-1]:
|
if split_words and replacement_char in split_words[-1]:
|
||||||
if len(split_tokens[-1]) <= MAX_PENDING_TOKENS:
|
self.state.pending_retries += 1
|
||||||
|
if self.state.pending_retries > MAX_PENDING_RETRIES:
|
||||||
|
logger.warning(
|
||||||
|
f"[UTF-8 Fix] Dropping {len(split_tokens[-1])} incomplete tokens "
|
||||||
|
f"after {MAX_PENDING_RETRIES} retries (won't resolve)"
|
||||||
|
)
|
||||||
|
self.state.pending_incomplete_tokens = []
|
||||||
|
self.state.pending_retries = 0
|
||||||
|
elif len(split_tokens[-1]) <= MAX_PENDING_TOKENS:
|
||||||
self.state.pending_incomplete_tokens = split_tokens[-1]
|
self.state.pending_incomplete_tokens = split_tokens[-1]
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[UTF-8 Fix] Holding {len(self.state.pending_incomplete_tokens)} "
|
f"[UTF-8 Fix] Holding {len(self.state.pending_incomplete_tokens)} "
|
||||||
f"incomplete tokens for next chunk"
|
f"incomplete tokens for next chunk (retry {self.state.pending_retries})"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"[UTF-8 Fix] Skipping {len(split_tokens[-1])} tokens "
|
f"[UTF-8 Fix] Skipping {len(split_tokens[-1])} tokens "
|
||||||
f"(exceeds limit of {MAX_PENDING_TOKENS}, likely hallucination)"
|
f"(exceeds limit of {MAX_PENDING_TOKENS}, likely hallucination)"
|
||||||
)
|
)
|
||||||
|
self.state.pending_incomplete_tokens = []
|
||||||
|
self.state.pending_retries = 0
|
||||||
|
else:
|
||||||
|
self.state.pending_incomplete_tokens = []
|
||||||
|
self.state.pending_retries = 0
|
||||||
|
|
||||||
|
# === Repetition penalty ===
|
||||||
|
|
||||||
|
def _apply_dry_penalty(self, logits, current_tokens):
|
||||||
|
"""DRY penalty v0: penalize tokens that would extend a verbatim repetition.
|
||||||
|
See https://github.com/oobabooga/text-generation-webui/pull/5677
|
||||||
|
|
||||||
|
Scans the decoded sequence for positions where the current suffix already
|
||||||
|
appeared --> for each such match, the token that followed it in the past is
|
||||||
|
penalised exponentially with the match length
|
||||||
|
"""
|
||||||
|
eot = self.tokenizer.eot
|
||||||
|
seq = current_tokens[0].tolist()
|
||||||
|
if len(seq) < 5:
|
||||||
|
return logits
|
||||||
|
|
||||||
|
last = seq[-1]
|
||||||
|
if last >= eot:
|
||||||
|
return logits
|
||||||
|
|
||||||
|
penalties = {}
|
||||||
|
for i in range(len(seq) - 2, -1, -1):
|
||||||
|
if seq[i] != last:
|
||||||
|
continue
|
||||||
|
next_tok = seq[i + 1]
|
||||||
|
if next_tok >= eot:
|
||||||
|
continue
|
||||||
|
|
||||||
|
length = 1
|
||||||
|
while length < 50:
|
||||||
|
j, k = i - length, len(seq) - 1 - length
|
||||||
|
if j < 0 or k <= i:
|
||||||
|
break
|
||||||
|
if seq[j] != seq[k] or seq[j] >= eot:
|
||||||
|
break
|
||||||
|
length += 1
|
||||||
|
|
||||||
|
if next_tok not in penalties or length > penalties[next_tok]:
|
||||||
|
penalties[next_tok] = length
|
||||||
|
|
||||||
|
if penalties:
|
||||||
|
max_len = max(penalties.values())
|
||||||
|
if max_len >= 4:
|
||||||
|
logger.debug(f"[DRY] penalising {len(penalties)} tokens (longest match: {max_len})")
|
||||||
|
for tok, length in penalties.items():
|
||||||
|
if length >= 2:
|
||||||
|
logits[:, tok] = logits[:, tok] - 1.0 * 2.0 ** (length - 2)
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
# === Abstract methods — subclass must implement ===
|
# === Abstract methods — subclass must implement ===
|
||||||
|
|
||||||
|
|||||||
@@ -200,9 +200,12 @@ class SimulStreamingASR:
|
|||||||
if self.encoder_backend == "whisper":
|
if self.encoder_backend == "whisper":
|
||||||
self.disable_fast_encoder = True
|
self.disable_fast_encoder = True
|
||||||
|
|
||||||
if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin":
|
# MLX full decoder disabled by default — MLXAlignAtt has known issues
|
||||||
if not hasattr(self, '_full_mlx_disabled'):
|
# with token generation after punctuation. Users can opt-in with
|
||||||
self.use_full_mlx = True
|
# --use-full-mlx if they want to test it.
|
||||||
|
# if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin":
|
||||||
|
# if not hasattr(self, '_full_mlx_disabled'):
|
||||||
|
# self.use_full_mlx = True
|
||||||
|
|
||||||
self.cfg = AlignAttConfig(
|
self.cfg = AlignAttConfig(
|
||||||
tokenizer_is_multilingual= is_multilingual,
|
tokenizer_is_multilingual= is_multilingual,
|
||||||
|
|||||||
@@ -25,7 +25,8 @@ class DecoderState:
|
|||||||
context: Any = None
|
context: Any = None
|
||||||
|
|
||||||
pending_incomplete_tokens: List[int] = field(default_factory=list)
|
pending_incomplete_tokens: List[int] = field(default_factory=list)
|
||||||
|
pending_retries: int = 0
|
||||||
|
|
||||||
global_time_offset: float = 0.0
|
global_time_offset: float = 0.0
|
||||||
cumulative_time_offset: float = 0.0
|
cumulative_time_offset: float = 0.0
|
||||||
first_timestamp: Optional[float] = None
|
first_timestamp: Optional[float] = None
|
||||||
@@ -78,8 +79,9 @@ class DecoderState:
|
|||||||
self.last_attend_frame = -rewind_threshold
|
self.last_attend_frame = -rewind_threshold
|
||||||
self.cumulative_time_offset = 0.0
|
self.cumulative_time_offset = 0.0
|
||||||
self.pending_incomplete_tokens = []
|
self.pending_incomplete_tokens = []
|
||||||
|
self.pending_retries = 0
|
||||||
self.log_segments += 1
|
self.log_segments += 1
|
||||||
|
|
||||||
def full_reset(self, rewind_threshold: int = 200):
|
def full_reset(self, rewind_threshold: int = 200):
|
||||||
"""
|
"""
|
||||||
Full reset including audio segments and tokens.
|
Full reset including audio segments and tokens.
|
||||||
|
|||||||
@@ -29,7 +29,8 @@ class MLXDecoderState:
|
|||||||
context: Any = None
|
context: Any = None
|
||||||
|
|
||||||
pending_incomplete_tokens: List[int] = field(default_factory=list)
|
pending_incomplete_tokens: List[int] = field(default_factory=list)
|
||||||
|
pending_retries: int = 0
|
||||||
|
|
||||||
global_time_offset: float = 0.0
|
global_time_offset: float = 0.0
|
||||||
cumulative_time_offset: float = 0.0
|
cumulative_time_offset: float = 0.0
|
||||||
first_timestamp: Optional[float] = None
|
first_timestamp: Optional[float] = None
|
||||||
@@ -59,8 +60,9 @@ class MLXDecoderState:
|
|||||||
self.last_attend_frame = -rewind_threshold
|
self.last_attend_frame = -rewind_threshold
|
||||||
self.cumulative_time_offset = 0.0
|
self.cumulative_time_offset = 0.0
|
||||||
self.pending_incomplete_tokens = []
|
self.pending_incomplete_tokens = []
|
||||||
|
self.pending_retries = 0
|
||||||
self.log_segments += 1
|
self.log_segments += 1
|
||||||
|
|
||||||
def full_reset(self, rewind_threshold: int = 200):
|
def full_reset(self, rewind_threshold: int = 200):
|
||||||
"""
|
"""
|
||||||
Full reset including audio segments and tokens.
|
Full reset including audio segments and tokens.
|
||||||
|
|||||||
@@ -296,10 +296,15 @@ class Tokenizer:
|
|||||||
current_tokens.append(token)
|
current_tokens.append(token)
|
||||||
decoded = self.decode_with_timestamps(current_tokens)
|
decoded = self.decode_with_timestamps(current_tokens)
|
||||||
|
|
||||||
if (
|
try:
|
||||||
replacement_char not in decoded
|
replacement_char_index = decoded.index(replacement_char)
|
||||||
or decoded_full[unicode_offset + decoded.index(replacement_char)]
|
replacement_char_index += unicode_offset
|
||||||
== replacement_char
|
except ValueError:
|
||||||
|
replacement_char_index = None
|
||||||
|
|
||||||
|
if replacement_char_index is None or (
|
||||||
|
replacement_char_index < len(decoded_full)
|
||||||
|
and decoded_full[replacement_char_index] == replacement_char
|
||||||
):
|
):
|
||||||
words.append(decoded)
|
words.append(decoded)
|
||||||
word_tokens.append(current_tokens)
|
word_tokens.append(current_tokens)
|
||||||
|
|||||||
Reference in New Issue
Block a user