38 Commits

Author SHA1 Message Date
Quentin Fuxa
d9a4c8dcb2 Refactor transcription and diarization handling with token-by-token validation. Introduce segment buffers for ephemeral content and update API to return structured segment data. Enhance silence handling and improve web interface for text transcripts. 2025-11-30 16:39:27 +01:00
Quentin Fuxa
4fb735a784 new token treatment only iar 2025-11-30 15:16:36 +01:00
Quentin Fuxa
d2f998cb7e val 2025-11-30 14:37:37 +01:00
Quentin Fuxa
7b18917f2b LoRA archi 2025-11-30 12:30:18 +01:00
Quentin Fuxa
f1113e3eb0 update with LoRA 2025-11-29 18:33:30 +01:00
Quentin Fuxa
cc5f819ce7 hf weights 2025-11-29 17:50:46 +01:00
Quentin Fuxa
82cd24bb75 LoRa path v0 - functional 2025-11-29 17:21:10 +01:00
Quentin Fuxa
d45c397c6a simulstreaming: limit n tokens to prevent hallucinations 2025-11-28 21:41:19 +01:00
Quentin Fuxa
45bf3f57d7 troubleshooting doc for aarch64 systems 2025-11-28 21:40:43 +01:00
Quentin Fuxa
1d88ba9d69 Fixes #294. improve model path backend detection and file extraction 2025-11-27 23:14:00 +01:00
Quentin Fuxa
c0965c6c31 Lines to Segments. Merging dataclasses 2025-11-27 21:54:58 +01:00
Quentin Fuxa
34ddd2ac02 update doc 2025-11-25 23:20:00 +01:00
Quentin Fuxa
345d781e97 update doc 2025-11-25 23:20:00 +01:00
Quentin Fuxa
28cf831701 indicate for context token limits for --max-context-tokens. bump to 0.2.16.dev0 2025-11-25 23:45:15 +01:00
Quentin Fuxa
60c62f8f84 troubleshooting #271 #276 #284 #286 2025-11-25 23:31:46 +01:00
Quentin Fuxa
7faa21f95f alignatt: enable model sharing by removing hooks and centralizing session state. Solves #282
Co-authored-by: Emmanuel Schmidbauer <eschmidbauer@gmail.com>
2025-11-25 23:07:42 +01:00
Quentin Fuxa
4e9f951551 correct silences handling when language not auto 2025-11-20 11:20:00 +01:00
Quentin Fuxa
870141298c isort 2025-11-23 11:20:00 +01:00
Quentin Fuxa
872faa422a correct silences handling when language not auto 2025-11-20 11:20:00 +01:00
Quentin Fuxa
fc9cb66813 disabling vac is not advised 2025-11-23 11:20:00 +01:00
Quentin Fuxa
a175d1a327 fixes silence detected but never reported by silero 2025-11-23 11:20:00 +01:00
Quentin Fuxa
6206fff118 0.2.15 2025-11-21 23:52:00 +01:00
Quentin Fuxa
b5067249c0 stt/diar/nllw alignment: internal rework 5 2025-11-20 23:52:00 +01:00
Quentin Fuxa
f4f9831d39 stt/diar/nllw alignment: internal rework 5 2025-11-20 23:52:00 +01:00
Quentin Fuxa
254faaf64c stt/diar/nllw alignment: internal rework 5 2025-11-20 23:52:00 +01:00
Quentin Fuxa
8e7aea4fcf internal rework 4 2025-11-20 23:45:20 +01:00
Quentin Fuxa
270faf2069 internal rework 3 2025-11-20 22:28:30 +01:00
Quentin Fuxa
b7c1cc77cc internal rework 2 2025-11-20 22:06:38 +01:00
Quentin Fuxa
9a45ec221c internal rework 1 2025-11-20 12:58:38 +01:00
Quentin Fuxa
3e13ee6fc3 bump to post4 2025-11-19 21:23:43 +01:00
Quentin Fuxa
b7d20a0ff0 segment attribution in result formatter 2025-11-19 21:10:28 +01:00
Quentin Fuxa
c1bb9c2bde reduce flickering remaining_time_transcription 2025-11-19 19:09:37 +01:00
Quentin Fuxa
11e9def0b2 diarization corrections 2025-11-19 19:06:03 +01:00
Quentin Fuxa
3104f40f6e fixes #279 #278 2025-11-19 18:17:50 +01:00
Quentin Fuxa
e9b4ceeee5 Add audio partial silence in chunks handling. bump to 0.2.14.post3 2025-11-17 22:52:00 +01:00
Quentin Fuxa
437641fb43 reduce min-chunk-size to 0.1, set default model to base 2027-04-25 23:52:00 +02:00
Quentin Fuxa
bfd60b3921 Add audio partial silence in chunks handling. bump to 0.2.14.post2 2025-11-17 22:52:00 +01:00
Quentin Fuxa
1e67bf97f0 improve buffering when use of heavy models 2027-04-25 23:52:00 +02:00
29 changed files with 1708 additions and 1931 deletions

View File

@@ -37,10 +37,9 @@ RUN pip3 install --upgrade pip setuptools wheel && \
 COPY . .
 # Install WhisperLiveKit directly, allowing for optional dependencies
-# Example: --build-arg EXTRAS="translation"
 RUN if [ -n "$EXTRAS" ]; then \
         echo "Installing with extras: [$EXTRAS]"; \
-        pip install --no-cache-dir "whisperlivekit[$EXTRAS]"; \
+        pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
     else \
         echo "Installing base package only"; \
         pip install --no-cache-dir whisperlivekit; \
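The conditional in this Dockerfile can be exercised outside Docker. A minimal sketch of the same branch logic, which only computes and echoes the pip target instead of installing it (the `EXTRAS` value is illustrative, matching the `--build-arg EXTRAS="translation"` example above):

```shell
# Mirror of the Dockerfile branch: pick the pip target based on EXTRAS.
EXTRAS="translation"   # illustrative value; any extras group works
if [ -n "$EXTRAS" ]; then
    PKG="whisperlivekit[$EXTRAS]"
else
    PKG="whisperlivekit"
fi
echo "$PKG"
```

Passed at build time, this becomes `docker build --build-arg EXTRAS="translation" ...`.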

View File

@@ -147,8 +147,8 @@ async def websocket_endpoint(websocket: WebSocket):
 |-----------|-------------|---------|
 | `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `small` |
 | `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `None` |
-| `--language` | List [here](docs/supported_languages.md). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
-| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
+| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
+| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
 | `--diarization` | Enable speaker identification | `False` |
 | `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
 | `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
@@ -267,7 +267,7 @@ docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
 #### Customization
 - `--build-arg` Options:
-  - `EXTRAS="translation"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
+  - `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
   - `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start
   - `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models

View File

@@ -6,7 +6,7 @@ Capture the audio of your current tab, transcribe diarize and translate it using
 <img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">
 ## Running this extension
-1. Run `python scripts/sync_extension.py` to copy frontend files to the `chrome-extension` directory.
+1. Run `python sync_extension.py` to copy frontend files to the `chrome-extension` directory.
 2. Load the `chrome-extension` directory in Chrome as an unpacked extension.

View File

@@ -1,53 +1,22 @@
 # WhisperLiveKit WebSocket API Documentation
-> !! **Note**: The new API structure described in this document is currently under deployment.
-This documentation is intended for devs who want to build custom frontends.
-WLK provides real-time speech transcription, speaker diarization, and translation through a WebSocket API. The server sends incremental updates as audio is processed, allowing clients to display live transcription results with minimal latency.
+WLK provides real-time speech transcription, speaker diarization, and translation through a WebSocket API. The server sends updates as audio is processed, allowing clients to display live transcription results with minimal latency.
 ---
-## Legacy API (Current)
-### Message Structure
-The current API sends complete state snapshots on each update (several time per second)
-```typescript
-{
-  "type": str,
-  "status": str,
-  "lines": [
-    {
-      "speaker": int,
-      "text": str,
-      "start": float,
-      "end": float,
-      "translation": str | null,
-      "detected_language": str
-    }
-  ],
-  "buffer_transcription": str,
-  "buffer_diarization": str,
-  "remaining_time_transcription": float,
-  "remaining_time_diarization": float
-}
-```
+## Endpoints
+| Endpoint | Description |
+|----------|-------------|
+| `/` | Main web interface with visual styling |
+| `/text` | Simple text-based interface for easy copy/paste (debug/development) |
+| `/asr` | WebSocket endpoint for audio streaming |
 ---
-## New API (Under Development)
-### Philosophy
-Principles:
-- **Incremental Updates**: Only updates and new segments are sent
-- **Ephemeral Buffers**: Temporary, unvalidated data displayed in real-time but overwritten on next update, at speaker level
 ## Message Format
-### Transcript Update (Server → Client)
 ```typescript
 {
@@ -58,22 +27,11 @@ Principles:
   "id": number,
   "speaker": number,
   "text": string,
-  "start_speaker": float,
-  "start": float,
-  "end": float,
+  "start_speaker": string,  // HH:MM:SS format
+  "start": string,          // HH:MM:SS format
+  "end": string,            // HH:MM:SS format
   "language": string | null,
   "translation": string,
-  "words": [
-    {
-      "text": string,
-      "start": float,
-      "end": float,
-      "validated": {
-        "text": boolean,
-        "speaker": boolean,
-      }
-    }
-  ],
   "buffer": {
     "transcription": string,
     "diarization": string,
@@ -94,9 +52,10 @@ Principles:
 ```json
 {
   "type": "config",
-  "useAudioWorklet": true / false
+  "useAudioWorklet": true
 }
 ```
+- `useAudioWorklet`: If `true`, client should use AudioWorklet for PCM streaming. If `false`, use MediaRecorder for WebM.
 #### Ready to Stop Message (sent after processing complete)
 ```json
@@ -104,6 +63,7 @@ Principles:
   "type": "ready_to_stop"
 }
 ```
+Indicates all audio has been processed and the client can safely close the connection.
 ---
@@ -113,152 +73,179 @@ Principles:
 | Field | Type | Description |
 |-------|------|-------------|
-| `id` | `number` | Unique identifier for this segment. Used by clients to update specific segments efficiently. |
+| `id` | `number` | Unique identifier for this segment. |
 | `speaker` | `number` | Speaker ID (1, 2, 3...). Special value `-2` indicates silence. |
-| `text` | `string` | Validated transcription text for this update. Should be **appended** to the segment's text on the client side. |
-| `start_speaker` | `float` | Timestamp (seconds) when this speaker segment began. |
-| `start` | `float` | Timestamp (seconds) of the first word in this update. |
-| `end` | `float` | Timestamp (seconds) of the last word in this update. |
-| `language` | `string \| null` | ISO language code (e.g., "en", "fr"). `null` until language is detected. |
-| `translation` | `string` | Validated translation text for this update. Should be **appended** to the segment's translation on the client side. |
-| `words` | `Array` | Array of word-level objects with timing and validation information. |
-| `buffer` | `Object` | Per-segment temporary buffers, see below |
-### Word Object
-| Field | Type | Description |
-|-------|------|-------------|
-| `text` | `string` | The word text. |
-| `start` | `number` | Start timestamp (seconds) of this word. |
-| `end` | `number` | End timestamp (seconds) of this word. |
-| `validated.text` | `boolean` | Whether the transcription text has been validated. if false, word is also in buffer: transcription |
-| `validated.speaker` | `boolean` | Whether the speaker assignment has been validated. if false, word is also in buffer: diarization |
-| `validated.language` | `boolean` | Whether the language detection has been validated. if false, word is also in buffer: translation |
+| `text` | `string` | Validated transcription text. |
+| `start_speaker` | `string` | Timestamp (HH:MM:SS) when this speaker segment began. |
+| `start` | `string` | Timestamp (HH:MM:SS) of the first word. |
+| `end` | `string` | Timestamp (HH:MM:SS) of the last word. |
+| `language` | `string \| null` | ISO language code (e.g., "en", "fr"). `null` until detected. |
+| `translation` | `string` | Validated translation text. |
+| `buffer` | `Object` | Per-segment temporary buffers (see below). |
 ### Buffer Object (Per-Segment)
-Buffers are **ephemeral**. They should be displayed to the user but not stored permanently in the frontend. Each update may contain a completely different buffer value, and previous buffer is likely to be in the next validated text.
+Buffers are **ephemeral**. They should be displayed to the user but are overwritten on each update. Only the **last non-silent segment** contains buffer content.
 | Field | Type | Description |
 |-------|------|-------------|
-| `transcription` | `string` | Pending transcription text. Displayed immediately but **overwritten** on next update. |
-| `diarization` | `string` | Pending diarization text (text waiting for speaker assignment). Displayed immediately but **overwritten** on next update. |
-| `translation` | `string` | Pending translation text. Displayed immediately but **overwritten** on next update. |
+| `transcription` | `string` | Text pending validation (waiting for more context). |
+| `diarization` | `string` | Text pending speaker assignment (diarization hasn't caught up). |
+| `translation` | `string` | Translation pending validation. |
 ### Metadata Fields
 | Field | Type | Description |
 |-------|------|-------------|
-| `remaining_time_transcription` | `float` | Seconds of audio waiting for transcription processing. |
-| `remaining_time_diarization` | `float` | Seconds of audio waiting for speaker diarization. |
+| `remaining_time_transcription` | `float` | Seconds of audio waiting for transcription. |
+| `remaining_time_diarization` | `float` | Seconds of audio waiting for diarization. |
 ### Status Values
 | Status | Description |
 |--------|-------------|
 | `active_transcription` | Normal operation, transcription is active. |
-| `no_audio_detected` | No audio has been detected yet. |
+| `no_audio_detected` | No audio/speech has been detected yet. |
 ---
-## Update Behavior
-### Incremental Updates
-The API sends **only changed or new segments**. Clients should:
-1. Maintain a local map of segments by ID
-2. When receiving an update, merge/update segments by ID
-3. Render only the changed segments
-### Language Detection
-When language is detected for a segment:
-```jsonc
-// Update 1: No language yet
-{
-  "segments": [
-    {"id": 1, "speaker": 1, "text": "May see", "language": null}
-  ]
-}
-// Update 2: Same segment ID, language now detected
-{
-  "segments": [
-    {"id": 1, "speaker": 1, "text": "Merci", "language": "fr"}
-  ]
-}
-```
-**Client behavior**: **Replace** the existing segment with the same ID.
-### Buffer Behavior
-Buffers are **per-segment** to handle multi-speaker scenarios correctly.
-#### Example: Translation with diarization and translation
-```jsonc
-// Update 1
+## Behavior Notes
+### Silence Handling
+- **Short silences (< 2 seconds)** are filtered out and not displayed.
+- Only significant pauses appear as silence segments with `speaker: -2`.
+- Consecutive same-speaker segments are merged even across short silences.
+### Update Frequency
+- **Active transcription**: ~20 updates/second (every 50ms)
+- **During silence**: ~2 updates/second (every 500ms) to reduce bandwidth
+### Token-by-Token Validation (Diarization Mode)
+When diarization is enabled, text is validated **token-by-token** as soon as diarization covers each token, rather than waiting for punctuation. This provides:
+- Faster text validation
+- More responsive speaker attribution
+- Buffer only contains tokens that diarization hasn't processed yet
+---
+## Example Messages
+### Normal Transcription
+```json
 {
+  "type": "transcript_update",
+  "status": "active_transcription",
   "segments": [
     {
       "id": 1,
       "speaker": 1,
-      "text": "Hello world, how are",
+      "text": "Hello, how are you today?",
+      "start_speaker": "0:00:02",
+      "start": "0:00:02",
+      "end": "0:00:05",
+      "language": "en",
+      "translation": "",
+      "buffer": {
+        "transcription": " I'm doing",
+        "diarization": "",
+        "translation": ""
+      }
+    }
+  ],
+  "metadata": {
+    "remaining_time_transcription": 0.5,
+    "remaining_time_diarization": 0
+  }
+}
+```
+### With Diarization Buffer
+```json
+{
+  "type": "transcript_update",
+  "status": "active_transcription",
+  "segments": [
+    {
+      "id": 1,
+      "speaker": 1,
+      "text": "The meeting starts at nine.",
+      "start_speaker": "0:00:03",
+      "start": "0:00:03",
+      "end": "0:00:06",
+      "language": "en",
       "translation": "",
       "buffer": {
         "transcription": "",
-        "diarization": " you on",
-        "translation": "Bonjour le monde"
+        "diarization": " Let me check my calendar",
+        "translation": ""
       }
     }
-  ]
+  ],
+  "metadata": {
+    "remaining_time_transcription": 0.3,
+    "remaining_time_diarization": 2.1
+  }
 }
-// ==== Frontend ====
-// <SPEAKER>1</SPEAKER>
-// <TRANSCRIPTION>Hello world, how are <DIARIZATION BUFFER> you on</DIARIZATION BUFFER></TRANSCRIPTION>
-// <TRANSLATION><TRANSLATION BUFFER>Bonjour le monde</TRANSLATION BUFFER></TRANSLATION>
-// Update 2
-{
-  "segments": [
-    {
-      "id": 1,
-      "speaker": 1,
-      "text": " you on this",
-      "translation": "Bonjour tout le monde",
-      "buffer": {
-        "transcription": "",
-        "diarization": " beautiful day",
-        "translation": ",comment"
-      }
-    },
-  ]
-}
-// ==== Frontend ====
-// <SPEAKER>1</SPEAKER>
-// <TRANSCRIPTION>Hello world, how are you on this<DIARIZATION BUFFER> beautiful day</DIARIZATION BUFFER></TRANSCRIPTION>
-// <TRANSLATION>Bonjour tout le monde<TRANSLATION BUFFER>, comment</TRANSLATION BUFFER><TRANSLATION>
 ```
-### Silence Segments
-Silence is represented with the speaker id = `-2`:
-```jsonc
+### Silence Segment
+```json
 {
   "id": 5,
   "speaker": -2,
   "text": "",
-  "start": 10.5,
-  "end": 12.3
+  "start_speaker": "0:00:10",
+  "start": "0:00:10",
+  "end": "0:00:15",
+  "language": null,
+  "translation": "",
+  "buffer": {
+    "transcription": "",
+    "diarization": "",
+    "translation": ""
+  }
 }
 ```
+---
+## Text Transcript Endpoint (`/text`)
+The `/text` endpoint provides a simple, monospace text interface designed for:
+- Easy copy/paste of transcripts
+- Debugging and development
+- Integration testing
+Output uses text markers instead of HTML styling:
+```
+[METADATA transcription_lag=0.5s diarization_lag=1.2s]
+[SPEAKER 1] 0:00:03 - 0:00:11 [LANG: en]
+Hello world, how are you doing today?[DIAR_BUFFER] I'm doing fine[/DIAR_BUFFER]
+[SILENCE 0:00:15 - 0:00:18]
+[SPEAKER 2] 0:00:18 - 0:00:22 [LANG: en]
+That's great to hear!
+[TRANSLATION]C'est super à entendre![/TRANSLATION]
+```
+### Markers
+| Marker | Description |
+|--------|-------------|
+| `[SPEAKER N]` | Speaker label with ID |
+| `[SILENCE start - end]` | Silence segment |
+| `[LANG: xx]` | Detected language code |
+| `[DIAR_BUFFER]...[/DIAR_BUFFER]` | Text pending speaker assignment |
+| `[TRANS_BUFFER]...[/TRANS_BUFFER]` | Text pending validation |
+| `[TRANSLATION]...[/TRANSLATION]` | Translation content |
+| `[METADATA ...]` | Lag/timing information |
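A frontend consuming the `/asr` WebSocket can keep a local map of segments keyed on `id` and display each segment's validated text together with its ephemeral buffer. A minimal Python sketch, with the message shape taken from the examples in this document; the replace-by-id merge and the helper names (`apply_update`, `render`) are assumptions, not part of the library:

```python
import json

def apply_update(segments: dict, message: dict) -> None:
    """Merge a transcript_update message into a local map keyed by segment id."""
    if message.get("type") != "transcript_update":
        return
    for seg in message.get("segments", []):
        segments[seg["id"]] = seg  # assumed: each update carries the segment's current state

def render(segments: dict) -> str:
    """Flatten validated text plus the ephemeral diarization buffer for display."""
    parts = []
    for seg in sorted(segments.values(), key=lambda s: s["id"]):
        if seg["speaker"] == -2:  # silence sentinel, not rendered as text
            continue
        buffer = seg.get("buffer", {})
        parts.append(seg["text"] + buffer.get("diarization", ""))
    return " ".join(parts)

update = json.loads('''{"type": "transcript_update", "segments": [
    {"id": 1, "speaker": 1, "text": "Hello, how are you today?",
     "buffer": {"transcription": "", "diarization": " I'm fine", "translation": ""}}]}''')
state: dict = {}
apply_update(state, update)
print(render(state))  # Hello, how are you today? I'm fine
```

Because buffers are overwritten on every update, only `text` is kept in the map; the buffer is re-read from the latest message each time the view is rendered.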

View File

@@ -1,13 +1,73 @@
-### Alignment between STT Tokens and Diarization Segments
-- Example 1: The punctuation from STT and the speaker change from Diariation come in the prediction `t`
-- Example 2: The punctuation from STT comes from prediction `t`, but the speaker change from Diariation come in the prediction `t-1`
-- Example 3: The punctuation from STT comes from prediction `t-1`, but the speaker change from Diariation come in the prediction `t`
-> `#` Is the split between the `t-1` prediction and `t` prediction.
+# Alignment Principles
+This document explains how transcription tokens are aligned with diarization (speaker identification) segments.
+---
+## Token-by-Token Validation
+When diarization is enabled, text is validated **token-by-token** rather than waiting for sentence boundaries. As soon as diarization covers a token's time range, that token is validated and assigned to the appropriate speaker.
+### How It Works
+1. **Transcription produces tokens** with timestamps (start, end)
+2. **Diarization produces speaker segments** with timestamps
+3. **For each token**: Check if diarization has caught up to that token's time
+   - If yes → Find speaker with maximum overlap, validate token
+   - If no → Keep token in "pending" (becomes diarization buffer)
+```
+Timeline:      0s -------- 5s -------- 10s -------- 15s
+               |           |            |            |
+Transcription: [Hello, how are you doing today?]
+               |_______|___|____|_____|_____|_____|
+                tok1  tok2 tok3  tok4  tok5  tok6
+Diarization:   [SPEAKER 1         ][SPEAKER 2        ]
+               |__________________|__________________|
+               0s                 8s                15s
+At time t when diarization covers up to 8s:
+- Tokens 1-4 (0s-7s) → Validated as SPEAKER 1
+- Tokens 5-6 (7s-10s) → In buffer (diarization hasn't caught up)
+```
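The maximum-overlap rule in step 3 can be sketched in a few lines of Python. This is a hypothetical illustration (the `Token` dataclass and `assign_speakers` helper are not names from the codebase), assuming diarization segments as `(speaker, start, end)` tuples and `diar_end` as the last time diarization has covered:

```python
from dataclasses import dataclass

@dataclass
class Token:
    text: str
    start: float
    end: float

def assign_speakers(tokens, diar_segments, diar_end: float):
    """Validate each token whose span diarization has covered; keep the rest pending."""
    validated, pending = [], []
    for tok in tokens:
        if tok.end > diar_end:          # diarization hasn't caught up yet
            pending.append(tok)
            continue
        best_speaker, best_overlap = None, 0.0
        for speaker, seg_start, seg_end in diar_segments:
            overlap = min(tok.end, seg_end) - max(tok.start, seg_start)
            if overlap > best_overlap:  # speaker with maximum time overlap wins
                best_speaker, best_overlap = speaker, overlap
        validated.append((tok, best_speaker))
    return validated, pending

tokens = [Token("Hello,", 0.0, 1.0), Token("how", 1.0, 2.0), Token("today?", 9.0, 10.0)]
diar = [(1, 0.0, 8.0), (2, 8.0, 15.0)]
validated, pending = assign_speakers(tokens, diar, diar_end=8.0)
print([(t.text, spk) for t, spk in validated])  # [('Hello,', 1), ('how', 1)]
print([t.text for t in pending])                # ['today?']
```

The pending list is exactly what the API exposes as the per-segment `diarization` buffer.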
+---
+## Silence Handling
+- **Short silences (< 2 seconds)**: Filtered out, not displayed
+- **Significant silences (≥ 2 seconds)**: Displayed as silence segments with `speaker: -2`
+- **Same speaker across gaps**: Segments are merged even if separated by short silences
+```
+Before filtering:
+[SPK1 0:00-0:03] [SILENCE 0:03-0:04] [SPK1 0:04-0:08]
+After filtering (silence < 2s):
+[SPK1 0:00-0:08] ← Merged into single segment
+```
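The merge rule above can be sketched as a single pass over ordered segments. A minimal illustration, assuming segments as `(speaker, start, end)` tuples and the 2-second threshold stated above (the `merge_segments` helper is hypothetical, not a codebase function):

```python
def merge_segments(segments, min_silence: float = 2.0):
    """Merge consecutive same-speaker segments separated by gaps shorter than min_silence."""
    merged = []
    for speaker, start, end in segments:
        if merged:
            prev_speaker, prev_start, prev_end = merged[-1]
            gap = start - prev_end
            if speaker == prev_speaker and gap < min_silence:
                merged[-1] = (speaker, prev_start, end)  # extend the previous segment
                continue
            if gap >= min_silence:
                merged.append((-2, prev_end, start))     # significant pause → silence segment
        merged.append((speaker, start, end))
    return merged

print(merge_segments([(1, 0.0, 3.0), (1, 4.0, 8.0)]))
# [(1, 0.0, 8.0)]  — the 1s gap is filtered out
```

A gap at or above the threshold is emitted as an explicit `-2` silence segment, matching the API's silence sentinel.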
+---
+## Buffer Types
+| Buffer | Contains | Displayed When |
+|--------|----------|----------------|
+| `transcription` | Text awaiting validation (more context needed) | Always on last segment |
+| `diarization` | Text awaiting speaker assignment | When diarization lags behind transcription |
+| `translation` | Translation awaiting validation | When translation is enabled |
+---
+## Legacy: Punctuation-Based Splitting
+The previous approach split segments at punctuation marks and aligned with diarization at those boundaries. This is now replaced by token-by-token validation for faster, more responsive results.
+### Historical Examples (for reference)
+Example of punctuation-based alignment:
-## Example 1:
 ```text
 punctuations_segments : __#_______.__________________!____
 diarization_segments:
@@ -16,56 +76,6 @@ SPK2 # ___________________
 -->
 ALIGNED SPK1 __#_______.
 ALIGNED SPK2           # __________________!____
-t-1 output:
-SPK1: __#
-SPK2: NO
-DIARIZATION BUFFER: NO
-t output:
-SPK1: __#__.
-SPK2: __________________!____
-DIARIZATION BUFFER: No
 ```
-## Example 2:
-```text
-punctuations_segments : _____#__.___________
-diarization_segments:
-SPK1 ___ #
-SPK2     __#______________
--->
-ALIGNED SPK1 _____#__.
-ALIGNED SPK2          # ___________
-t-1 output:
-SPK1: ___ #
-SPK2:
-DIARIZATION BUFFER: __#
-t output:
-SPK1: __#__.
-SPK2: ___________
-DIARIZATION BUFFER: No
-```
-## Example 3:
-```text
-punctuations_segments : ___.__#__________
-diarization_segments:
-SPK1 ______#__
-SPK2       # ________
--->
-ALIGNED SPK1 ___.  #
-ALIGNED SPK2     __#__________
-t-1 output:
-SPK1: ___. #
-SPK2:
-DIARIZATION BUFFER: __#
-t output:
-SPK1:  #
-SPK2: __#___________
-DIARIZATION BUFFER: NO
-```
+With token-by-token validation, the alignment happens continuously rather than at punctuation boundaries.

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "whisperlivekit"
-version = "0.2.17.post1"
+version = "0.2.16.dev0"
 description = "Real-time speech-to-text with speaker diarization using Whisper"
 readme = "README.md"
 authors = [
@@ -35,7 +35,6 @@ dependencies = [
     "torchaudio>=2.0.0",
     "torch>=2.0.0",
     "huggingface-hub>=0.25.0",
-    "faster-whisper>=1.2.0",
     "tqdm",
     "tiktoken",
     'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
@@ -57,7 +56,6 @@ packages = [
     "whisperlivekit",
     "whisperlivekit.diarization",
     "whisperlivekit.simul_whisper",
-    "whisperlivekit.simul_whisper.mlx",
     "whisperlivekit.whisper",
     "whisperlivekit.whisper.assets",
     "whisperlivekit.whisper.normalizers",

View File

@@ -1,7 +1,7 @@
 from .audio_processor import AudioProcessor
 from .core import TranscriptionEngine
 from .parse_args import parse_args
-from .web.web_interface import get_inline_ui_html, get_web_interface_html
+from .web.web_interface import get_inline_ui_html, get_text_transcript_html, get_web_interface_html
 __all__ = [
     "TranscriptionEngine",
@@ -9,5 +9,6 @@ __all__ = [
     "parse_args",
     "get_web_interface_html",
     "get_inline_ui_html",
+    "get_text_transcript_html",
     "download_simulstreaming_backend",
 ]

View File

@@ -10,7 +10,7 @@ from whisperlivekit.core import (TranscriptionEngine,
                                      online_diarization_factory, online_factory,
                                      online_translation_factory)
 from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
-from whisperlivekit.silero_vad_iterator import FixedVADIterator, OnnxWrapper, load_jit_vad
+from whisperlivekit.silero_vad_iterator import FixedVADIterator
 from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
                                           Segment, Silence, State, Transcript)
 from whisperlivekit.tokens_alignment import TokensAlignment
@@ -32,7 +32,7 @@ async def get_all_from_queue(queue: asyncio.Queue) -> Union[object, Silence, np.
     if isinstance(first_item, Silence):
         return first_item
     items.append(first_item)
     while True:
         if not queue._queue:
             break
@@ -53,15 +53,15 @@ class AudioProcessor:
     Processes audio streams for transcription and diarization.
     Handles audio processing, state management, and result formatting.
     """
     def __init__(self, **kwargs: Any) -> None:
         """Initialize the audio processor with configuration, models, and state."""
         if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
             models = kwargs['transcription_engine']
         else:
             models = TranscriptionEngine(**kwargs)
         # Audio processing settings
         self.args = models.args
         self.sample_rate = 16000
@@ -85,14 +85,12 @@ class AudioProcessor:
         # Models and processing
         self.asr: Any = models.asr
-        self.vac: Optional[FixedVADIterator] = None
-        if self.args.vac:
-            if models.vac_session is not None:
-                vac_model = OnnxWrapper(session=models.vac_session)
-                self.vac = FixedVADIterator(vac_model)
-            else:
-                self.vac = FixedVADIterator(load_jit_vad())
+        self.vac_model: Any = models.vac_model
+        if self.args.vac:
+            self.vac: Optional[FixedVADIterator] = FixedVADIterator(models.vac_model)
+        else:
+            self.vac: Optional[FixedVADIterator] = None
         self.ffmpeg_manager: Optional[FFmpegManager] = None
         self.ffmpeg_reader_task: Optional[asyncio.Task] = None
         self._ffmpeg_error: Optional[str] = None
@@ -106,7 +104,7 @@ class AudioProcessor:
             logger.error(f"FFmpeg error: {error_type}")
             self._ffmpeg_error = error_type
         self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
         self.transcription_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.transcription else None
         self.diarization_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.diarization else None
         self.translation_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.target_language else None
@@ -117,14 +115,14 @@ class AudioProcessor:
         self.translation_task: Optional[asyncio.Task] = None
         self.watchdog_task: Optional[asyncio.Task] = None
         self.all_tasks_for_cleanup: List[asyncio.Task] = []
         self.transcription: Optional[Any] = None
         self.translation: Optional[Any] = None
         self.diarization: Optional[Any] = None
         if self.args.transcription:
             self.transcription = online_factory(self.args, models.asr)
             self.sep = self.transcription.asr.sep
         if self.args.diarization:
             self.diarization = online_diarization_factory(self.args, models.diarization_model)
         if models.translation_model:
@@ -182,24 +180,24 @@ class AudioProcessor:
def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray: def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray:
"""Convert PCM buffer in s16le format to normalized NumPy array.""" """Convert PCM buffer in s16le format to normalized NumPy array."""
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0 return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
async def get_current_state(self) -> State: async def get_current_state(self) -> State:
"""Get current state.""" """Get current state."""
async with self.lock: async with self.lock:
current_time = time() current_time = time()
remaining_transcription = 0 remaining_transcription = 0
if self.state.end_buffer > 0: if self.state.end_buffer > 0:
remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1)) remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1))
remaining_diarization = 0 remaining_diarization = 0
if self.state.tokens: if self.state.tokens:
latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0) latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1)) remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1))
self.state.remaining_time_transcription = remaining_transcription self.state.remaining_time_transcription = remaining_transcription
self.state.remaining_time_diarization = remaining_diarization self.state.remaining_time_diarization = remaining_diarization
return self.state return self.state
async def ffmpeg_stdout_reader(self) -> None: async def ffmpeg_stdout_reader(self) -> None:
@@ -255,7 +253,7 @@ class AudioProcessor:
async def transcription_processor(self) -> None: async def transcription_processor(self) -> None:
"""Process audio chunks for transcription.""" """Process audio chunks for transcription."""
cumulative_pcm_duration_stream_time = 0.0 cumulative_pcm_duration_stream_time = 0.0
while True: while True:
try: try:
# item = await self.transcription_queue.get() # item = await self.transcription_queue.get()
@@ -311,12 +309,12 @@ class AudioProcessor:
if new_tokens: if new_tokens:
candidate_end_times.append(new_tokens[-1].end) candidate_end_times.append(new_tokens[-1].end)
if _buffer_transcript.end is not None: if _buffer_transcript.end is not None:
candidate_end_times.append(_buffer_transcript.end) candidate_end_times.append(_buffer_transcript.end)
candidate_end_times.append(current_audio_processed_upto) candidate_end_times.append(current_audio_processed_upto)
async with self.lock: async with self.lock:
self.state.tokens.extend(new_tokens) self.state.tokens.extend(new_tokens)
self.state.buffer_transcription = _buffer_transcript self.state.buffer_transcription = _buffer_transcript
@@ -326,13 +324,13 @@ class AudioProcessor:
if self.translation_queue: if self.translation_queue:
for token in new_tokens: for token in new_tokens:
await self.translation_queue.put(token) await self.translation_queue.put(token)
except Exception as e: except Exception as e:
logger.warning(f"Exception in transcription_processor: {e}") logger.warning(f"Exception in transcription_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}") logger.warning(f"Traceback: {traceback.format_exc()}")
if 'pcm_array' in locals() and pcm_array is not SENTINEL : # Check if pcm_array was assigned from queue if 'pcm_array' in locals() and pcm_array is not SENTINEL : # Check if pcm_array was assigned from queue
self.transcription_queue.task_done() self.transcription_queue.task_done()
if self.is_stopping: if self.is_stopping:
logger.info("Transcription processor finishing due to stopping flag.") logger.info("Transcription processor finishing due to stopping flag.")
if self.diarization_queue: if self.diarization_queue:
@@ -353,21 +351,18 @@ class AudioProcessor:
if item.has_ended: if item.has_ended:
self.diarization.insert_silence(item.duration) self.diarization.insert_silence(item.duration)
continue continue
self.diarization.insert_audio_chunk(item) self.diarization.insert_audio_chunk(item)
diarization_segments = await self.diarization.diarize() diarization_segments = await self.diarization.diarize()
diar_end = 0.0 self.state.new_diarization = diarization_segments
if diarization_segments:
diar_end = max(getattr(s, "end", 0.0) for s in diarization_segments)
async with self.lock:
self.state.new_diarization = diarization_segments
self.state.end_attributed_speaker = max(self.state.end_attributed_speaker, diar_end)
except Exception as e: except Exception as e:
logger.warning(f"Exception in diarization_processor: {e}") logger.warning(f"Exception in diarization_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}") logger.warning(f"Traceback: {traceback.format_exc()}")
logger.info("Diarization processor task finished.") logger.info("Diarization processor task finished.")
async def translation_processor(self) -> None: async def translation_processor(self) -> None:
# the idea is to ignore diarization for the moment. We use only transcription tokens. # the idea is to ignore diarization for the moment. We use only transcription tokens.
# And the speaker is attributed given the segments used for the translation # And the speaker is attributed given the segments used for the translation
# in the future we want to have different languages for each speaker etc, so it will be more complex. # in the future we want to have different languages for each speaker etc, so it will be more complex.
while True: while True:
@@ -398,6 +393,10 @@ class AudioProcessor:
async def results_formatter(self) -> AsyncGenerator[FrontData, None]: async def results_formatter(self) -> AsyncGenerator[FrontData, None]:
"""Format processing results for output.""" """Format processing results for output."""
# Update intervals
ACTIVE_INTERVAL = 0.05 # 20 updates/sec during active transcription
SILENCE_INTERVAL = 0.5 # 2 updates/sec during silence
while True: while True:
try: try:
if self._ffmpeg_error: if self._ffmpeg_error:
@@ -407,44 +406,62 @@ class AudioProcessor:
continue continue
self.tokens_alignment.update() self.tokens_alignment.update()
lines, buffer_diarization_text, buffer_translation_text = self.tokens_alignment.get_lines( state = await self.get_current_state()
# Get transcription buffer text to pass to get_lines
buffer_transcription_text = state.buffer_transcription.text if state.buffer_transcription else ''
# get_lines now returns segments with per-segment buffers
segments = self.tokens_alignment.get_lines(
diarization=self.args.diarization, diarization=self.args.diarization,
translation=bool(self.translation), translation=bool(self.translation),
current_silence=self.current_silence current_silence=self.current_silence,
buffer_transcription=buffer_transcription_text
) )
state = await self.get_current_state()
buffer_transcription_text = state.buffer_transcription.text if state.buffer_transcription else ''
response_status = "active_transcription" response_status = "active_transcription"
if not lines and not buffer_transcription_text and not buffer_diarization_text: # Check if there's any content (segments with text or buffers)
has_active_content = any(
seg.buffer and (seg.buffer.transcription or seg.buffer.diarization)
for seg in segments if not seg.is_silence()
)
has_any_content = any(
seg.text or (seg.buffer and (seg.buffer.transcription or seg.buffer.diarization))
for seg in segments if not seg.is_silence()
)
if not segments or not has_any_content:
response_status = "no_audio_detected" response_status = "no_audio_detected"
response = FrontData( response = FrontData(
status=response_status, status=response_status,
lines=lines, segments=segments,
buffer_transcription=buffer_transcription_text,
buffer_diarization=buffer_diarization_text,
buffer_translation=buffer_translation_text,
remaining_time_transcription=state.remaining_time_transcription, remaining_time_transcription=state.remaining_time_transcription,
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0 remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
) )
should_push = (response != self.last_response_content) should_push = (response != self.last_response_content)
if should_push: if should_push:
yield response yield response
self.last_response_content = response self.last_response_content = response
if self.is_stopping and self._processing_tasks_done(): if self.is_stopping and self._processing_tasks_done():
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.") logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
return return
await asyncio.sleep(0.05) # Throttle updates during silence: use slower interval when in silence mode
# with no pending buffers (nothing actively being processed)
is_in_silence = self.current_silence is not None
has_pending_work = has_active_content or state.remaining_time_transcription > 0.5
if is_in_silence and not has_pending_work:
await asyncio.sleep(SILENCE_INTERVAL)
else:
await asyncio.sleep(ACTIVE_INTERVAL)
except Exception as e: except Exception as e:
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}") logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
async def create_tasks(self) -> AsyncGenerator[FrontData, None]: async def create_tasks(self) -> AsyncGenerator[FrontData, None]:
"""Create and start processing tasks.""" """Create and start processing tasks."""
self.all_tasks_for_cleanup = [] self.all_tasks_for_cleanup = []
@@ -469,21 +486,21 @@ class AudioProcessor:
self.transcription_task = asyncio.create_task(self.transcription_processor()) self.transcription_task = asyncio.create_task(self.transcription_processor())
self.all_tasks_for_cleanup.append(self.transcription_task) self.all_tasks_for_cleanup.append(self.transcription_task)
processing_tasks_for_watchdog.append(self.transcription_task) processing_tasks_for_watchdog.append(self.transcription_task)
if self.diarization: if self.diarization:
self.diarization_task = asyncio.create_task(self.diarization_processor()) self.diarization_task = asyncio.create_task(self.diarization_processor())
self.all_tasks_for_cleanup.append(self.diarization_task) self.all_tasks_for_cleanup.append(self.diarization_task)
processing_tasks_for_watchdog.append(self.diarization_task) processing_tasks_for_watchdog.append(self.diarization_task)
if self.translation: if self.translation:
self.translation_task = asyncio.create_task(self.translation_processor()) self.translation_task = asyncio.create_task(self.translation_processor())
self.all_tasks_for_cleanup.append(self.translation_task) self.all_tasks_for_cleanup.append(self.translation_task)
processing_tasks_for_watchdog.append(self.translation_task) processing_tasks_for_watchdog.append(self.translation_task)
# Monitor overall system health # Monitor overall system health
self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog)) self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog))
self.all_tasks_for_cleanup.append(self.watchdog_task) self.all_tasks_for_cleanup.append(self.watchdog_task)
return self.results_formatter() return self.results_formatter()
async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None: async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None:
@@ -496,7 +513,7 @@ class AudioProcessor:
return return
await asyncio.sleep(10) await asyncio.sleep(10)
for i, task in enumerate(list(tasks_remaining)): for i, task in enumerate(list(tasks_remaining)):
if task.done(): if task.done():
exc = task.exception() exc = task.exception()
@@ -506,13 +523,13 @@ class AudioProcessor:
else: else:
logger.info(f"{task_name} completed normally.") logger.info(f"{task_name} completed normally.")
tasks_remaining.remove(task) tasks_remaining.remove(task)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("Watchdog task cancelled.") logger.info("Watchdog task cancelled.")
break break
except Exception as e: except Exception as e:
logger.error(f"Error in watchdog task: {e}", exc_info=True) logger.error(f"Error in watchdog task: {e}", exc_info=True)
async def cleanup(self) -> None: async def cleanup(self) -> None:
"""Clean up resources when processing is complete.""" """Clean up resources when processing is complete."""
logger.info("Starting cleanup of AudioProcessor resources.") logger.info("Starting cleanup of AudioProcessor resources.")
@@ -520,7 +537,7 @@ class AudioProcessor:
for task in self.all_tasks_for_cleanup: for task in self.all_tasks_for_cleanup:
if task and not task.done(): if task and not task.done():
task.cancel() task.cancel()
created_tasks = [t for t in self.all_tasks_for_cleanup if t] created_tasks = [t for t in self.all_tasks_for_cleanup if t]
if created_tasks: if created_tasks:
await asyncio.gather(*created_tasks, return_exceptions=True) await asyncio.gather(*created_tasks, return_exceptions=True)
@@ -558,7 +575,7 @@ class AudioProcessor:
if not message: if not message:
logger.info("Empty audio message received, initiating stop sequence.") logger.info("Empty audio message received, initiating stop sequence.")
self.is_stopping = True self.is_stopping = True
if self.transcription_queue: if self.transcription_queue:
await self.transcription_queue.put(SENTINEL) await self.transcription_queue.put(SENTINEL)
@@ -599,7 +616,7 @@ class AudioProcessor:
chunk_size = min(len(self.pcm_buffer), self.max_bytes_per_sec) chunk_size = min(len(self.pcm_buffer), self.max_bytes_per_sec)
aligned_chunk_size = (chunk_size // self.bytes_per_sample) * self.bytes_per_sample aligned_chunk_size = (chunk_size // self.bytes_per_sample) * self.bytes_per_sample
if aligned_chunk_size == 0: if aligned_chunk_size == 0:
return return
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size]) pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size])
@@ -616,7 +633,7 @@ class AudioProcessor:
if res is not None: if res is not None:
if "start" in res and self.current_silence: if "start" in res and self.current_silence:
await self._end_silence() await self._end_silence()
if "end" in res and not self.current_silence: if "end" in res and not self.current_silence:
pre_silence_chunk = self._slice_before_silence( pre_silence_chunk = self._slice_before_silence(
pcm_array, chunk_sample_start, res.get("end") pcm_array, chunk_sample_start, res.get("end")
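The `convert_pcm_to_float` helper and the `aligned_chunk_size` computation in the hunks above are small but load-bearing: s16le bytes are reinterpreted as int16 and normalized into [-1, 1), and byte counts are rounded down to whole samples before slicing. A standalone sketch of both (independent of the class; names here mirror the diff but the script itself is illustrative):

```python
import numpy as np

def convert_pcm_to_float(pcm_buffer: bytes) -> np.ndarray:
    """Convert a PCM buffer in s16le format to a normalized float32 array."""
    return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0

# Two samples: int16 minimum (-32768) and maximum (32767), little-endian.
pcm = (-32768).to_bytes(2, "little", signed=True) + (32767).to_bytes(2, "little", signed=True)
samples = convert_pcm_to_float(pcm)
print(samples.dtype, samples[0])  # float32 -1.0

# Byte counts must be aligned to whole int16 samples before slicing,
# mirroring the aligned_chunk_size computation above.
bytes_per_sample = 2
chunk_size = 7  # e.g. a partial sample arrived over the socket
aligned = (chunk_size // bytes_per_sample) * bytes_per_sample
print(aligned)  # 6
```

Dividing by 32768 means the most negative sample maps exactly to -1.0 while the most positive maps just under 1.0, which is the usual convention for s16le normalization.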


@@ -7,7 +7,7 @@
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import HTMLResponse
 from whisperlivekit import (AudioProcessor, TranscriptionEngine,
-                            get_inline_ui_html, parse_args)
+                            get_inline_ui_html, get_text_transcript_html, parse_args)

 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logging.getLogger().setLevel(logging.WARNING)
@@ -39,6 +39,12 @@ async def get():
     return HTMLResponse(get_inline_ui_html())

+@app.get("/text")
+async def get_text():
+    """Simple text-based transcript view for easy copy/paste."""
+    return HTMLResponse(get_text_transcript_html())
+
 async def handle_websocket_results(websocket, results_generator):
     """Consumes results from the audio processor and sends them via WebSocket."""
     try:


@@ -1,6 +1,5 @@
 import logging
 import sys
-import threading
 from argparse import Namespace

 from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
@@ -20,26 +19,16 @@ logger = logging.getLogger(__name__)

 class TranscriptionEngine:
     _instance = None
     _initialized = False
-    _lock = threading.Lock()  # Thread-safe singleton lock

     def __new__(cls, *args, **kwargs):
-        # Double-checked locking pattern for thread-safe singleton
         if cls._instance is None:
-            with cls._lock:
-                # Check again inside lock to prevent race condition
-                if cls._instance is None:
-                    cls._instance = super().__new__(cls)
+            cls._instance = super().__new__(cls)
         return cls._instance

     def __init__(self, **kwargs):
-        # Thread-safe initialization check
-        with TranscriptionEngine._lock:
-            if TranscriptionEngine._initialized:
-                return
-            # Set flag immediately to prevent re-initialization
-            TranscriptionEngine._initialized = True
-        # Perform initialization outside lock to avoid holding lock during slow operations
+        if TranscriptionEngine._initialized:
+            return
         global_params = {
             "host": "localhost",
             "port": 8000,
@@ -47,6 +36,7 @@ class TranscriptionEngine:
             "punctuation_split": False,
             "target_language": "",
             "vac": True,
+            "vac_onnx": False,
             "vac_chunk_size": 0.04,
             "log_level": "DEBUG",
             "ssl_certfile": None,
@@ -89,19 +79,15 @@ class TranscriptionEngine:
         self.asr = None
         self.tokenizer = None
         self.diarization = None
-        self.vac_session = None
+        self.vac_model = None
         if self.args.vac:
-            from whisperlivekit.silero_vad_iterator import is_onnx_available
-            if is_onnx_available():
-                from whisperlivekit.silero_vad_iterator import load_onnx_session
-                self.vac_session = load_onnx_session()
-            else:
-                logger.warning(
-                    "onnxruntime not installed. VAC will use JIT model which is loaded per-session. "
-                    "For multi-user scenarios, install onnxruntime: pip install onnxruntime"
-                )
+            from whisperlivekit.silero_vad_iterator import load_silero_vad
+            # Use ONNX if specified, otherwise use JIT (default)
+            use_onnx = kwargs.get('vac_onnx', False)
+            self.vac_model = load_silero_vad(onnx=use_onnx)

         backend_policy = self.args.backend_policy
         if self.args.transcription:
             if backend_policy == "simulstreaming":
@@ -183,13 +169,16 @@ class TranscriptionEngine:
             }
             translation_params = update_with_kwargs(translation_params, kwargs)
             self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
+        TranscriptionEngine._initialized = True

 def online_factory(args, asr):
     if args.backend_policy == "simulstreaming":
         from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
-        return SimulStreamingOnlineProcessor(asr)
-    return OnlineASRProcessor(asr)
+        online = SimulStreamingOnlineProcessor(asr)
+    else:
+        online = OnlineASRProcessor(asr)
+    return online

 def online_diarization_factory(args, diarization_backend):
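The hunk above replaces a double-checked-locking singleton with a plain `_initialized` flag check, and moves the flag-setting to the end of `__init__`. For reference, the removed pattern looks like this (an illustrative sketch with a generic `Engine` class, not the project's exact code):

```python
import threading

class Engine:
    """Minimal double-checked-locking singleton (the pattern removed above)."""
    _instance = None
    _initialized = False
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:          # cheap check without taking the lock
            with cls._lock:
                if cls._instance is None:  # authoritative re-check under the lock
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, **kwargs):
        with Engine._lock:
            if Engine._initialized:
                return
            Engine._initialized = True
        # ...slow model loading would happen here, outside the lock...

a, b = Engine(), Engine()
print(a is b)  # True
```

The lock only matters when several threads can race the first construction; if the engine is always built once at server startup, the simpler flag check on the new side of the diff behaves identically.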


@@ -151,7 +151,7 @@ class FasterWhisperASR(ASRBase):
             if segment.no_speech_prob > 0.9:
                 continue
             for word in segment.words:
-                token = ASRToken(word.start, word.end, word.word)
+                token = ASRToken(word.start, word.end, word.word, probability=word.probability)
                 tokens.append(token)
         return tokens
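This hunk threads faster-whisper's per-word `probability` into each `ASRToken`, alongside the existing `no_speech_prob > 0.9` segment filter. A minimal sketch of that conversion logic, using plain dataclasses as stand-ins for the real faster-whisper `Segment`/`Word` objects (the stand-in field names are assumptions for illustration):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Word:                      # stand-in for faster_whisper's Word
    start: float
    end: float
    word: str
    probability: float

@dataclass
class Segment:                   # stand-in for faster_whisper's Segment
    no_speech_prob: float
    words: List[Word]

@dataclass
class ASRToken:
    start: float
    end: float
    text: str
    probability: Optional[float] = None

def to_tokens(segments: List[Segment]) -> List[ASRToken]:
    tokens = []
    for segment in segments:
        if segment.no_speech_prob > 0.9:  # likely silence/noise: drop the whole segment
            continue
        for w in segment.words:
            tokens.append(ASRToken(w.start, w.end, w.word, probability=w.probability))
    return tokens

segs = [
    Segment(0.95, [Word(0.0, 0.2, "uh", 0.3)]),      # filtered out as non-speech
    Segment(0.01, [Word(0.2, 0.5, "hello", 0.98)]),
]
print([t.text for t in to_tokens(segs)])  # ['hello']
```

Carrying the word-level probability forward is what enables the token-by-token validation mentioned in the head commit message.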


@@ -8,15 +8,6 @@ import torch
 Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
 """

-def is_onnx_available() -> bool:
-    """Check if onnxruntime is installed."""
-    try:
-        import onnxruntime
-        return True
-    except ImportError:
-        return False
-
 def init_jit_model(model_path: str, device=torch.device('cpu')):
     """Load a JIT model from file."""
     model = torch.jit.load(model_path, map_location=device)
@@ -24,12 +15,12 @@ def init_jit_model(model_path: str, device=torch.device('cpu')):
     return model

-class OnnxSession():
-    """
-    Shared ONNX session for Silero VAD model (stateless).
-    """
+class OnnxWrapper():
+    """ONNX Runtime wrapper for Silero VAD model."""
     def __init__(self, path, force_onnx_cpu=False):
+        global np
+        import numpy as np
         import onnxruntime

         opts = onnxruntime.SessionOptions()
@@ -41,28 +32,13 @@ class OnnxSession():
         else:
             self.session = onnxruntime.InferenceSession(path, sess_options=opts)

-        self.path = path
+        self.reset_states()
         if '16k' in path:
             warnings.warn('This model support only 16000 sampling rate!')
             self.sample_rates = [16000]
         else:
             self.sample_rates = [8000, 16000]

-class OnnxWrapper():
-    """
-    ONNX Runtime wrapper for Silero VAD model with per-instance state.
-    """
-    def __init__(self, session: OnnxSession, force_onnx_cpu=False):
-        self._shared_session = session
-        self.sample_rates = session.sample_rates
-        self.reset_states()
-
-    @property
-    def session(self):
-        return self._shared_session.session
-
     def _validate_input(self, x, sr: int):
         if x.dim() == 1:
             x = x.unsqueeze(0)
@@ -125,20 +101,38 @@ class OnnxWrapper():
         return out

-def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Path:
-    """Get the path to the ONNX model file."""
+def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: int = 16):
+    """
+    Load Silero VAD model (JIT or ONNX).
+
+    Parameters
+    ----------
+    model_path : str, optional
+        Path to model file. If None, uses default bundled model.
+    onnx : bool, default False
+        Whether to use ONNX runtime (requires onnxruntime package).
+    opset_version : int, default 16
+        ONNX opset version (15 or 16). Only used if onnx=True.
+
+    Returns
+    -------
+    model
+        Loaded VAD model (JIT or ONNX wrapper)
+    """
     available_ops = [15, 16]
-    if opset_version not in available_ops:
+    if onnx and opset_version not in available_ops:
         raise Exception(f'Available ONNX opset_version: {available_ops}')
     if model_path is None:
         current_dir = Path(__file__).parent
         data_dir = current_dir / 'silero_vad_models'
-        if opset_version == 16:
-            model_name = 'silero_vad.onnx'
+        if onnx:
+            if opset_version == 16:
+                model_name = 'silero_vad.onnx'
+            else:
+                model_name = f'silero_vad_16k_op{opset_version}.onnx'
         else:
-            model_name = f'silero_vad_16k_op{opset_version}.onnx'
+            model_name = 'silero_vad.jit'
         model_path = data_dir / model_name
@@ -149,39 +143,17 @@ def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Pat
             )
     else:
         model_path = Path(model_path)
-    return model_path
-
-def load_onnx_session(model_path: str = None, opset_version: int = 16, force_onnx_cpu: bool = True) -> OnnxSession:
-    """
-    Load a shared ONNX session for Silero VAD.
-    """
-    path = _get_onnx_model_path(model_path, opset_version)
-    return OnnxSession(str(path), force_onnx_cpu=force_onnx_cpu)
-
-def load_jit_vad(model_path: str = None):
-    """
-    Load Silero VAD model in JIT format.
-    """
-    if model_path is None:
-        current_dir = Path(__file__).parent
-        data_dir = current_dir / 'silero_vad_models'
-        model_name = 'silero_vad.jit'
-        model_path = data_dir / model_name
-        if not model_path.exists():
-            raise FileNotFoundError(
-                f"Model file not found: {model_path}\n"
-                f"Please ensure the whisperlivekit/silero_vad_models/ directory contains the model files."
-            )
-    else:
-        model_path = Path(model_path)
-    model = init_jit_model(str(model_path))
+    if onnx:
+        try:
+            model = OnnxWrapper(str(model_path), force_onnx_cpu=True)
+        except ImportError:
+            raise ImportError(
+                "ONNX runtime not available. Install with: pip install onnxruntime\n"
+                "Or use JIT model by setting onnx=False"
+            )
+    else:
+        model = init_jit_model(str(model_path))
     return model
@@ -313,14 +285,13 @@ class FixedVADIterator(VADIterator):
 if __name__ == "__main__":
-    # vad = FixedVADIterator(load_jit_vad())
-    vad = FixedVADIterator(OnnxWrapper(session=load_onnx_session()))
+    model = load_silero_vad(onnx=False)
+    vad = FixedVADIterator(model)
     audio_buffer = np.array([0] * 512, dtype=np.float32)
     result = vad(audio_buffer)
     print(f"  512 samples: {result}")
     # test with 511 samples
     audio_buffer = np.array([0] * 511, dtype=np.float32)
     result = vad(audio_buffer)
     print(f"  511 samples: {result}")


@@ -24,11 +24,9 @@ logger = logging.getLogger(__name__)
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True) HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
if HAS_MLX_WHISPER: if HAS_MLX_WHISPER:
from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping from .mlx_encoder import load_mlx_encoder, mlx_model_mapping
from .mlx import MLXAlignAtt
else: else:
mlx_model_mapping = {} mlx_model_mapping = {}
MLXAlignAtt = None
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER) HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
if HAS_FASTER_WHISPER: if HAS_FASTER_WHISPER:
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
@@ -38,49 +36,50 @@ else:
MIN_DURATION_REAL_SILENCE = 5 MIN_DURATION_REAL_SILENCE = 5
class SimulStreamingOnlineProcessor: class SimulStreamingOnlineProcessor:
"""Online processor for SimulStreaming ASR."""
SAMPLING_RATE = 16000 SAMPLING_RATE = 16000
def __init__(self, asr, logfile=sys.stderr): def __init__(
self,
asr,
logfile=sys.stderr,
):
self.asr = asr self.asr = asr
self.logfile = logfile self.logfile = logfile
self.end = 0.0 self.end = 0.0
self.buffer = [] self.buffer = []
self.committed: List[ASRToken] = [] self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = [] self.last_result_tokens: List[ASRToken] = []
self.model = self._create_alignatt() self.load_new_alignatt_instance()
if asr.tokenizer: if asr.tokenizer:
self.model.tokenizer = asr.tokenizer self.model.tokenizer = asr.tokenizer
self.model.state.tokenizer = asr.tokenizer
def _create_alignatt(self): def load_new_alignatt_instance(self):
"""Create the AlignAtt decoder instance based on ASR mode.""" """Initialize AlignAtt decoder using the shared model."""
if self.asr.use_full_mlx and HAS_MLX_WHISPER: self.model = AlignAtt(
return MLXAlignAtt(cfg=self.asr.cfg, mlx_model=self.asr.mlx_model) cfg=self.asr.cfg,
else: loaded_model=self.asr.shared_model,
return AlignAtt( mlx_encoder=self.asr.mlx_encoder,
cfg=self.asr.cfg, fw_encoder=self.asr.fw_encoder,
loaded_model=self.asr.shared_model, )
mlx_encoder=self.asr.mlx_encoder,
fw_encoder=self.asr.fw_encoder,
)
def start_silence(self): def start_silence(self):
tokens, processed_upto = self.process_iter(is_last=True) tokens, processed_upto = self.process_iter(is_last=True)
return tokens, processed_upto return tokens, processed_upto
def end_silence(self, silence_duration, offset): def end_silence(self, silence_duration, offset):
"""Handle silence period.""" """
Handle silence period.
If silence > MIN_DURATION_REAL_SILENCE, do a complete context clear.
Otherwise, insert a small silence and shift the last_attend_frame.
"""
self.end += silence_duration self.end += silence_duration
long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
if not long_silence: if not long_silence:
gap_len = int(16000 * silence_duration) gap_len = int(16000 * silence_duration)
if gap_len > 0: if gap_len > 0:
if self.asr.use_full_mlx: gap_silence = torch.zeros(gap_len)
gap_silence = np.zeros(gap_len, dtype=np.float32)
else:
gap_silence = torch.zeros(gap_len)
self.model.insert_audio(gap_silence) self.model.insert_audio(gap_silence)
if long_silence: if long_silence:
self.model.refresh_segment(complete=True) self.model.refresh_segment(complete=True)
@@ -88,12 +87,11 @@ class SimulStreamingOnlineProcessor:
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time): def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
"""Append an audio chunk to be processed by SimulStreaming.""" """Append an audio chunk to be processed by SimulStreaming."""
self.end = audio_stream_end_time
if self.asr.use_full_mlx: # Convert numpy array to torch tensor
self.model.insert_audio(audio) audio_tensor = torch.from_numpy(audio).float()
else: self.end = audio_stream_end_time # Aligned with whisperstreaming backend behavior
audio_tensor = torch.from_numpy(audio).float() self.model.insert_audio(audio_tensor)
self.model.insert_audio(audio_tensor)
def new_speaker(self, change_speaker: ChangeSpeaker): def new_speaker(self, change_speaker: ChangeSpeaker):
"""Handle speaker change event.""" """Handle speaker change event."""
@@ -132,10 +130,6 @@ class SimulStreamingOnlineProcessor:
def warmup(self, audio, init_prompt=""): def warmup(self, audio, init_prompt=""):
"""Warmup the SimulStreaming model.""" """Warmup the SimulStreaming model."""
try: try:
if self.asr.use_full_mlx:
# MLX mode: ensure numpy array
if hasattr(audio, 'numpy'):
audio = audio.numpy()
self.model.insert_audio(audio) self.model.insert_audio(audio)
self.model.infer(True) self.model.infer(True)
self.model.refresh_segment(complete=True) self.model.refresh_segment(complete=True)
@@ -145,14 +139,9 @@ class SimulStreamingOnlineProcessor:
     def __del__(self):
         gc.collect()
-        if not getattr(self.asr, 'use_full_mlx', True) and torch is not None:
-            try:
-                torch.cuda.empty_cache()
-            except Exception:
-                pass
+        torch.cuda.empty_cache()


-class SimulStreamingASR():
+class SimulStreamingASR:
     """SimulStreaming backend with AlignAtt policy."""
     sep = ""
@@ -169,7 +158,6 @@ class SimulStreamingASR:
         self.fast_encoder = False
         self._resolved_model_path = None
         self.encoder_backend = "whisper"
-        self.use_full_mlx = getattr(self, "use_full_mlx", False)
         preferred_backend = getattr(self, "backend", "auto")
         compatible_whisper_mlx, compatible_faster_whisper = True, True
@@ -182,7 +170,7 @@ class SimulStreamingASR:
             compatible_whisper_mlx = model_info.compatible_whisper_mlx
             compatible_faster_whisper = model_info.compatible_faster_whisper
-            if not self.use_full_mlx and not model_info.has_pytorch:
+            if not model_info.has_pytorch:
                 raise FileNotFoundError(
                     f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
                 )
@@ -202,10 +190,6 @@ class SimulStreamingASR:
         self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper")
         if self.encoder_backend == "whisper":
             self.disable_fast_encoder = True
-        if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin":
-            if not hasattr(self, '_full_mlx_disabled'):
-                self.use_full_mlx = True

         self.cfg = AlignAttConfig(
             tokenizer_is_multilingual= is_multilingual,
@@ -230,36 +214,20 @@ class SimulStreamingASR:
         else:
             self.tokenizer = None

-        self.mlx_encoder, self.fw_encoder, self.mlx_model = None, None, None
-        self.shared_model = None
-        if self.use_full_mlx and HAS_MLX_WHISPER:
-            logger.info('MLX Whisper backend used.')
-            if self._resolved_model_path is not None:
-                mlx_model_path = str(self._resolved_model_path)
-            else:
-                mlx_model_path = mlx_model_mapping.get(self.model_name)
-            if not mlx_model_path:
-                raise FileNotFoundError(
-                    f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
-                )
-            self.mlx_model = load_mlx_model(path_or_hf_repo=mlx_model_path)
-            self._warmup_mlx_model()
-        elif self.encoder_backend == "mlx-whisper":
-            # hybrid mode: mlx encoder + pytorch decoder
-            logger.info('SimulStreaming will use MLX Whisper encoder with PyTorch decoder.')
-            if self._resolved_model_path is not None:
-                mlx_model_path = str(self._resolved_model_path)
-            else:
-                mlx_model_path = mlx_model_mapping.get(self.model_name)
-            if not mlx_model_path:
-                raise FileNotFoundError(
-                    f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
-                )
-            self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path)
-            self.shared_model = self.load_model()
+        self.mlx_encoder, self.fw_encoder = None, None
+        if self.encoder_backend == "mlx-whisper":
+            print('Simulstreaming will use MLX whisper to increase encoding speed.')
+            if self._resolved_model_path is not None:
+                mlx_model = str(self._resolved_model_path)
+            else:
+                mlx_model = mlx_model_mapping.get(self.model_name)
+            if not mlx_model:
+                raise FileNotFoundError(
+                    f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
+                )
+            self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
         elif self.encoder_backend == "faster-whisper":
-            print('SimulStreaming will use Faster Whisper for the encoder.')
+            print('Simulstreaming will use Faster Whisper for the encoder.')
             if self._resolved_model_path is not None:
                 fw_model = str(self._resolved_model_path)
             else:
@@ -269,20 +237,7 @@ class SimulStreamingASR:
                 device='auto',
                 compute_type='auto',
             )
-            self.shared_model = self.load_model()
-        else:
-            self.shared_model = self.load_model()
-
-    def _warmup_mlx_model(self):
-        """Warmup the full MLX model."""
-        warmup_audio = load_file(self.warmup_file)
-        if warmup_audio is not None:
-            temp_model = MLXAlignAtt(
-                cfg=self.cfg,
-                mlx_model=self.mlx_model,
-            )
-            temp_model.warmup(warmup_audio)
-            logger.info("Full MLX model warmed up successfully")
+        self.shared_model = self.load_model()

     def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):

View File

@@ -47,24 +47,9 @@ class DecoderState:
     def clean_cache(self):
         """Clean the kv_cache after each inference step."""
-        # Explicitly delete tensor references to free GPU memory
-        if self.kv_cache:
-            for key in list(self.kv_cache.keys()):
-                tensor = self.kv_cache.pop(key, None)
-                if tensor is not None:
-                    del tensor
-            # Clear the dict
-            self.kv_cache.clear()
-        # Force GPU cache cleanup (only if CUDA is available)
-        import torch
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        self.kv_cache = {}
         if self.decoder_type == "beam" and self.inference is not None:
-            # Create NEW dict instead of sharing reference
-            self.inference.kv_cache = {}
+            self.inference.kv_cache = self.kv_cache
         if self.token_decoder is not None:
             self.token_decoder.reset()
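The simplified `clean_cache` rebinds `self.kv_cache` to a fresh dict and hands the *same* object to the beam-search inference wrapper, so the two views can never drift apart once the cache is repopulated. A minimal stdlib sketch of that aliasing (the class names here are illustrative stand-ins, not the real classes):

```python
class Inference:
    """Stand-in for the beam-search inference wrapper."""
    def __init__(self):
        self.kv_cache = {}

class State:
    """Stand-in for DecoderState: shares one cache dict with inference."""
    def __init__(self, inference):
        self.inference = inference
        self.kv_cache = {}

    def clean_cache(self):
        # Rebinding to a fresh dict drops all tensor references at once;
        # handing the same object to inference keeps both views consistent.
        self.kv_cache = {}
        self.inference.kv_cache = self.kv_cache

state = State(Inference())
state.kv_cache["layer0"] = "stale"
state.clean_cache()
state.kv_cache["layer0"] = "fresh"
print(state.inference.kv_cache["layer0"])  # prints "fresh": same dict object
```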

View File

@@ -1,11 +0,0 @@
from .decoder_state import MLXDecoderState
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
from .simul_whisper import MLXAlignAtt
__all__ = [
"MLXAlignAtt",
"MLXBeamSearchDecoder",
"MLXDecoderState",
"MLXGreedyDecoder",
"MLXInference",
]

View File

@@ -1,76 +0,0 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx
import numpy as np
@dataclass
class MLXDecoderState:
    """
    mlx kv cache format: List of ((k, v), (cross_k, cross_v)) tuples per layer,
    where each element is a tuple of mx.arrays.
    """
    kv_cache: Optional[List[Tuple[Tuple[mx.array, mx.array], Tuple[mx.array, mx.array]]]] = None
    tokenizer: Any = None
    detected_language: Optional[str] = None
    reset_tokenizer_to_auto_next_call: bool = False
    tokens: List[mx.array] = field(default_factory=list)
    initial_tokens: Optional[mx.array] = None
    initial_token_length: int = 0
    sot_index: int = 0
    align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
    num_align_heads: int = 0
    segments: List[np.ndarray] = field(default_factory=list)
    context: Any = None
    pending_incomplete_tokens: List[int] = field(default_factory=list)
    global_time_offset: float = 0.0
    cumulative_time_offset: float = 0.0
    first_timestamp: Optional[float] = None
    last_attend_frame: int = 0
    speaker: int = -1
    log_segments: int = 0
    cif_weights: Optional[mx.array] = None
    always_fire: bool = False
    never_fire: bool = False
    suppress_tokens: Optional[Tuple[int, ...]] = None
    token_decoder: Any = None
    decoder_type: str = "greedy"
    inference: Any = None

    def clean_cache(self):
        self.kv_cache = None
        if self.decoder_type == "beam" and self.inference is not None:
            self.inference.kv_cache = None
        if self.token_decoder is not None:
            self.token_decoder.reset()

    def reset(self, rewind_threshold: int = 200):
        self.last_attend_frame = -rewind_threshold
        self.cumulative_time_offset = 0.0
        self.pending_incomplete_tokens = []
        self.log_segments += 1

    def full_reset(self, rewind_threshold: int = 200):
        """
        Full reset including audio segments and tokens.

        Args:
            rewind_threshold: Value for resetting last_attend_frame
        """
        self.reset(rewind_threshold)
        self.segments = []
        self.tokens = []
        self.kv_cache = None
        self.first_timestamp = None

View File

@@ -1,219 +0,0 @@
"""
MLX-native token decoders for streaming ASR.
"""
from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx
import numpy as np
class MLXGreedyDecoder:
"""Greedy decoder using MLX operations."""
def __init__(self, temperature: float, eot: int):
self.temperature = temperature
self.eot = eot
def update(
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
) -> Tuple[mx.array, bool]:
"""
Update tokens with next predicted token.
Args:
tokens: Current token sequence, shape (batch, seq_len)
logits: Logits for next token, shape (batch, vocab_size)
sum_logprobs: Cumulative log probabilities, shape (batch,)
Returns:
Updated tokens and completion flag
"""
if self.temperature == 0:
next_tokens = mx.argmax(logits, axis=-1)
else:
probs = mx.softmax(logits / self.temperature, axis=-1)
next_tokens = mx.random.categorical(mx.log(probs + 1e-10))
logprobs = mx.softmax(logits, axis=-1)
logprobs = mx.log(logprobs + 1e-10)
batch_size = logprobs.shape[0]
current_logprobs = logprobs[mx.arange(batch_size), next_tokens]
mask = (tokens[:, -1] != self.eot).astype(mx.float32)
sum_logprobs = sum_logprobs + current_logprobs * mask
eot_mask = (tokens[:, -1] == self.eot)
next_tokens = mx.where(eot_mask, mx.array(self.eot), next_tokens)
tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=1)
completed = bool(mx.all(tokens[:, -1] == self.eot))
return tokens, completed
def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
"""Finalize decoding by ensuring EOT at end."""
eot_column = mx.full((tokens.shape[0], 1), self.eot, dtype=tokens.dtype)
tokens = mx.concatenate([tokens, eot_column], axis=1)
return tokens, sum_logprobs.tolist()
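The greedy update above takes the argmax of the last-step logits and freezes any beam that has already emitted EOT, padding it with further EOT tokens until all beams finish. The same control flow in plain Python over list-based logits (a sketch, not the MLX code):

```python
def greedy_update(tokens, logits, eot):
    """Pick the argmax token per beam; beams that already hit EOT keep emitting EOT."""
    next_tokens = []
    for seq, row in zip(tokens, logits):
        if seq[-1] == eot:
            next_tokens.append(eot)  # frozen beam keeps padding with EOT
        else:
            next_tokens.append(max(range(len(row)), key=row.__getitem__))
    tokens = [seq + [tok] for seq, tok in zip(tokens, next_tokens)]
    completed = all(seq[-1] == eot for seq in tokens)
    return tokens, completed

tokens, completed = greedy_update([[5], [9]], [[0.1, 0.7, 0.2], [0.3, 0.3, 0.4]], eot=9)
print(tokens, completed)  # [[5, 1], [9, 9]] False
```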
class MLXBeamSearchDecoder:
"""Beam search decoder using MLX operations."""
def __init__(
self,
beam_size: int,
eot: int,
inference: Any,
patience: Optional[float] = None,
):
self.beam_size = beam_size
self.eot = eot
self.inference = inference
self.patience = patience or 1.0
self.max_candidates: int = round(beam_size * self.patience)
self.finished_sequences: Optional[List[Dict]] = None
assert (
self.max_candidates > 0
), f"Invalid beam size ({beam_size}) or patience ({patience})"
def reset(self):
"""Reset finished sequences for new segment."""
self.finished_sequences = None
def update(
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
) -> Tuple[mx.array, bool]:
"""
Update tokens using beam search.
Args:
tokens: Current token sequences, shape (batch * beam_size, seq_len)
logits: Logits for next token, shape (batch * beam_size, vocab_size)
sum_logprobs: Cumulative log probabilities, shape (batch * beam_size,)
Returns:
Updated tokens and completion flag
"""
if tokens.shape[0] % self.beam_size != 0:
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
n_audio = tokens.shape[0] // self.beam_size
if self.finished_sequences is None:
self.finished_sequences = [{} for _ in range(n_audio)]
logprobs = mx.softmax(logits, axis=-1)
logprobs = mx.log(logprobs + 1e-10)
logprobs_np = np.array(logprobs)
tokens_np = np.array(tokens)
sum_logprobs_np = np.array(sum_logprobs)
next_tokens, source_indices, finished_sequences = [], [], []
new_sum_logprobs = []
for i in range(n_audio):
scores, sources, finished = {}, {}, {}
for j in range(self.beam_size):
idx = i * self.beam_size + j
prefix = tokens_np[idx].tolist()
top_k_indices = np.argsort(logprobs_np[idx])[-self.beam_size - 1:][::-1]
for token_idx in top_k_indices:
logprob = logprobs_np[idx, token_idx]
new_logprob = sum_logprobs_np[idx] + logprob
sequence = tuple(prefix + [int(token_idx)])
scores[sequence] = new_logprob
sources[sequence] = idx
saved = 0
for sequence in sorted(scores, key=scores.get, reverse=True):
if sequence[-1] == self.eot:
finished[sequence] = scores[sequence]
else:
new_sum_logprobs.append(scores[sequence])
next_tokens.append(sequence)
source_indices.append(sources[sequence])
saved += 1
if saved == self.beam_size:
break
finished_sequences.append(finished)
tokens = mx.array(np.array(next_tokens, dtype=np.int32))
sum_logprobs = mx.array(np.array(new_sum_logprobs, dtype=np.float32))
self.inference.rearrange_kv_cache(source_indices)
assert len(self.finished_sequences) == len(finished_sequences)
for previously_finished, newly_finished in zip(
self.finished_sequences, finished_sequences
):
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
if len(previously_finished) >= self.max_candidates:
break
previously_finished[seq] = newly_finished[seq]
completed = all(
len(sequences) >= self.max_candidates
for sequences in self.finished_sequences
)
return tokens, completed
def finalize(self, preceding_tokens: mx.array, sum_logprobs: mx.array):
"""Finalize beam search by selecting best sequences."""
preceding_tokens_np = np.array(preceding_tokens)
sum_logprobs_np = np.array(sum_logprobs)
n_audio = preceding_tokens_np.shape[0] // self.beam_size
tokens_list: List[List[int]] = [[] for _ in range(n_audio)]
sum_logprobs_list: List[float] = [0.0] * n_audio
for i, sequences in enumerate(self.finished_sequences):
if sequences:
best_seq = max(sequences, key=sequences.get)
tokens_list[i] = list(best_seq)
sum_logprobs_list[i] = sequences[best_seq]
else:
idx = i * self.beam_size
tokens_list[i] = preceding_tokens_np[idx].tolist() + [self.eot]
sum_logprobs_list[i] = float(sum_logprobs_np[idx])
max_len = max(len(t) for t in tokens_list)
for i, t in enumerate(tokens_list):
tokens_list[i] = t + [self.eot] * (max_len - len(t))
tokens = mx.array(np.array(tokens_list, dtype=np.int32))
return tokens, sum_logprobs_list
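The expansion loop in `update` can be illustrated with a tiny pure-Python beam step: each hypothesis is extended by candidate tokens, candidates are ranked by cumulative log probability, hypotheses ending in EOT are set aside as finished (they do not consume a beam slot), and only `beam_size` unfinished ones survive. A self-contained sketch with toy scores, not real model logits:

```python
def beam_step(beams, logprobs, beam_size, eot):
    """One beam expansion: rank candidates, set aside EOT-terminated hypotheses.

    beams: list of (tokens, score); logprobs: per-beam dict token -> logprob.
    """
    candidates = []
    for (tokens, score), row in zip(beams, logprobs):
        for token, logprob in row.items():
            candidates.append((score + logprob, tokens + [token]))
    candidates.sort(key=lambda c: c[0], reverse=True)
    survivors, finished = [], []
    for score, tokens in candidates:
        if tokens[-1] == eot:
            finished.append((tokens, score))  # finished beams do not use a slot
        else:
            survivors.append((tokens, score))
            if len(survivors) == beam_size:
                break
    return survivors, finished

survivors, finished = beam_step([([1], 0.0)], [{2: -1.0, 9: -0.5}], beam_size=1, eot=9)
print(survivors, finished)  # [([1, 2], -1.0)] [([1, 9], -0.5)]
```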
class MLXInference:
"""MLX inference wrapper for beam search KV cache management."""
def __init__(self, model, initial_token_length: int):
self.model = model
self.initial_token_length = initial_token_length
self.kv_cache = None
def rearrange_kv_cache(self, source_indices: List[int]):
"""Rearrange KV cache based on beam search source indices."""
if self.kv_cache is None:
return
if source_indices == list(range(len(source_indices))):
return
source_indices_mx = mx.array(source_indices, dtype=mx.int32)
new_cache = []
for layer_cache in self.kv_cache:
(k, v), (cross_k, cross_v) = layer_cache
new_k = k[source_indices_mx]
new_v = v[source_indices_mx]
new_cache.append(((new_k, new_v), (cross_k, cross_v)))
self.kv_cache = new_cache
def logits(
self,
tokens: mx.array,
audio_features: mx.array,
) -> Tuple[mx.array, List]:
"""Get logits from decoder with KV cache."""
logits, self.kv_cache, cross_qk = self.model.decoder(
tokens, audio_features, kv_cache=self.kv_cache
)
return logits, cross_qk
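`rearrange_kv_cache` only permutes the per-beam self-attention entries so each surviving beam continues from the state of the hypothesis it was branched from; the cross-attention cache depends on the audio alone and is left untouched. A list-based sketch of the same reordering:

```python
def rearrange_kv_cache(kv_cache, source_indices):
    """Reorder per-beam self-attention entries; cross-attention is shared across beams."""
    if source_indices == list(range(len(source_indices))):
        return kv_cache  # beams already in order, nothing to move
    new_cache = []
    for (k, v), (cross_k, cross_v) in kv_cache:
        new_k = [k[i] for i in source_indices]
        new_v = [v[i] for i in source_indices]
        new_cache.append(((new_k, new_v), (cross_k, cross_v)))
    return new_cache

layer = ((["k0", "k1", "k2"], ["v0", "v1", "v2"]), ("cross_k", "cross_v"))
cache = rearrange_kv_cache([layer], [2, 2, 0])  # beams 0 and 1 continue from old beam 2
print(cache[0][0][0])  # ['k2', 'k2', 'k0']
```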

View File

@@ -1,752 +0,0 @@
"""
MLX whisper AlignAtt streaming decoder
"""
import logging
from time import time
from typing import Any, List, Optional, Tuple
import mlx.core as mx
import numpy as np
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper import DecodingOptions, tokenizer
from whisperlivekit.whisper.audio import N_FRAMES, N_SAMPLES, TOKENS_PER_SECOND
from ..config import AlignAttConfig
from .decoder_state import MLXDecoderState
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
DEC_PAD = 50257
logger = logging.getLogger(__name__)
class MLXTokenBuffer:  # TODO: consider inheriting from the classic SimulWhisper token buffer class
"""Token buffer for MLX-based decoding."""
def __init__(self, text="", tokenizer=None, prefix_token_ids=None):
self.text = text
self.prefix_token_ids = prefix_token_ids or []
self.tokenizer = tokenizer
self.pending_token_ids = []
def as_token_ids(self, tokenizer=None):
if tokenizer is None:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer is not set.")
return self.prefix_token_ids + tokenizer.encode(self.text)
def as_mlx_array(self) -> mx.array:
"""Return tokens as MLX array."""
tok_ids = self.as_token_ids()
return mx.array([tok_ids], dtype=mx.int32)
def as_mlx_array_beam(self, beam: int) -> mx.array:
"""Return tokens as MLX array repeated for beam search."""
t = self.as_mlx_array()
return mx.repeat(t, beam, axis=0)
def as_text(self):
return self.text
@staticmethod
def empty(*a, **kw):
return MLXTokenBuffer(*a, **kw)
@staticmethod
def from_text(text, *a, **kw):
return MLXTokenBuffer(*a, text=text, **kw)
def is_empty(self):
return self.text is None or self.text == ""
def trim_words(self, num=1, after=0):
"""Trim words from the beginning of the context."""
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
ids = tokenizer.encode(self.text[after:])
words, wids = self.tokenizer.split_to_word_tokens(ids)
if not words:
return 0
self.text = self.text[:after] + "".join(words[num:])
return sum(len(wi) for wi in wids[:num])
def append_token_ids(self, token_ids):
"""Append token IDs to the buffer, handling incomplete UTF-8."""
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
all_tokens = self.pending_token_ids + token_ids
decoded = tokenizer.decode(all_tokens)
replacement_char = "\ufffd"
if replacement_char in decoded:
if len(all_tokens) > 1:
decoded_partial = tokenizer.decode(all_tokens[:-1])
if replacement_char not in decoded_partial:
self.text += decoded_partial
self.pending_token_ids = [all_tokens[-1]]
else:
self.pending_token_ids = all_tokens
else:
self.pending_token_ids = all_tokens
else:
self.text += decoded
self.pending_token_ids = []
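`append_token_ids` guards against flushing text mid-character: if the decoded string contains the replacement character U+FFFD, the trailing token is held pending until its continuation arrives. The same idea at the byte level, using only the stdlib (the real method operates on tokenizer IDs rather than raw bytes):

```python
REPLACEMENT = "\ufffd"

def flush_utf8(pending: bytes, chunk: bytes):
    """Decode what is complete; keep trailing bytes of a split character pending."""
    data = pending + chunk
    text = data.decode("utf-8", errors="replace")
    if REPLACEMENT not in text:
        return text, b""
    # Drop trailing bytes until the remainder decodes cleanly; keep them pending.
    for cut in range(1, min(4, len(data)) + 1):
        head = data[:-cut].decode("utf-8", errors="replace")
        if REPLACEMENT not in head:
            return head, data[-cut:]
    return "", data

# "é" is 0xC3 0xA9 in UTF-8; split it across two chunks.
text, pending = flush_utf8(b"", b"caf\xc3")
text2, pending2 = flush_utf8(pending, b"\xa9")
print(text, pending, text2)  # caf b'\xc3' é
```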
def mlx_median_filter(x: mx.array, filter_width: int) -> mx.array:
"""
Apply median filter along the last axis.
Args:
x: Input array of shape (..., T)
filter_width: Width of the median filter (should be odd)
Returns:
Filtered array of same shape
"""
if filter_width <= 1:
return x
pad_width = filter_width // 2
shape = x.shape
left_pad = mx.repeat(x[..., :1], pad_width, axis=-1)
right_pad = mx.repeat(x[..., -1:], pad_width, axis=-1)
x_padded = mx.concatenate([left_pad, x, right_pad], axis=-1)
result_shape = list(shape)
result = []
for i in range(shape[-1]):
window = x_padded[..., i:i + filter_width]
sorted_window = mx.sort(window, axis=-1)
median_val = sorted_window[..., filter_width // 2:filter_width // 2 + 1]
result.append(median_val)
return mx.concatenate(result, axis=-1)
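`mlx_median_filter` edge-pads the signal and takes the middle element of each sorted window. A plain-Python equivalent for 1-D input shows the effect on isolated spikes:

```python
from statistics import median

def median_filter(x, width):
    """Median-filter a 1-D sequence with edge padding, mirroring the MLX version."""
    if width <= 1:
        return list(x)
    pad = width // 2
    padded = [x[0]] * pad + list(x) + [x[-1]] * pad
    return [median(padded[i:i + width]) for i in range(len(x))]

print(median_filter([1, 9, 1, 1, 8, 1], 3))  # [1, 1, 1, 1, 1, 1] (spikes suppressed)
```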
class MLXAlignAtt:
"""
MLX-native Alignment-based Attention decoder for SimulStreaming.
This class runs entirely on MLX, with no PyTorch dependencies for inference.
"""
@property
def speaker(self):
return self.state.speaker
@speaker.setter
def speaker(self, value):
self.state.speaker = value
@property
def global_time_offset(self):
return self.state.global_time_offset
@global_time_offset.setter
def global_time_offset(self, value):
self.state.global_time_offset = value
def __init__(
self,
cfg: AlignAttConfig,
mlx_model: Any,
) -> None:
"""
Initialize MLX AlignAtt decoder.
Args:
cfg: AlignAtt configuration
mlx_model: MLX Whisper model (full model, not just encoder)
"""
self.model = mlx_model
self.cfg = cfg
logger.info(f"MLX Model dimensions: {self.model.dims}")
self.decode_options = DecodingOptions(
language=cfg.language,
without_timestamps=True,
task=cfg.task
)
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks)
if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
else:
self.max_context_tokens = self.cfg.max_context_tokens
# Initialize per-session state
self.state = MLXDecoderState()
self._init_state(cfg)
def _init_state(self, cfg: AlignAttConfig):
"""Initialize the per-session decoder state."""
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
self.state.tokenizer = self.tokenizer
self.state.detected_language = cfg.language if cfg.language != "auto" else None
self.state.global_time_offset = 0.0
self.state.last_attend_frame = -cfg.rewind_threshold
self.state.speaker = -1
if cfg.cif_ckpt_path is None or not cfg.cif_ckpt_path:
if cfg.never_fire:
self.state.never_fire = True
self.state.always_fire = False
else:
self.state.always_fire = True
self.state.never_fire = False
else:
logger.warning("CIF checkpoint provided but MLX CIF not implemented. Using always_fire=True")
self.state.always_fire = True
self.state.never_fire = cfg.never_fire
self._build_alignment_source()
suppress_tokens = [
self.tokenizer.transcribe,
self.tokenizer.translate,
self.tokenizer.sot,
self.tokenizer.sot_prev,
self.tokenizer.sot_lm,
self.tokenizer.no_timestamps,
] + list(self.tokenizer.all_language_tokens)
if self.tokenizer.no_speech is not None:
suppress_tokens.append(self.tokenizer.no_speech)
self.state.suppress_tokens = tuple(sorted(set(suppress_tokens)))
logger.debug(f"Suppress tokens: {self.state.suppress_tokens}")
self.init_tokens()
self.init_context()
self.state.decoder_type = cfg.decoder_type
if cfg.decoder_type == "greedy":
logger.info("Using MLX greedy decoder")
self.state.token_decoder = MLXGreedyDecoder(0.0, self.tokenizer.eot)
elif cfg.decoder_type == "beam":
logger.info("Using MLX beam decoder")
self.state.inference = MLXInference(self.model, self.state.initial_token_length)
self.state.token_decoder = MLXBeamSearchDecoder(
inference=self.state.inference,
eot=self.tokenizer.eot,
beam_size=cfg.beam_size
)
def _build_alignment_source(self):
"""Build alignment source mapping from model's alignment_heads."""
self.state.align_source = {}
self.state.num_align_heads = 0
alignment_heads = self.model.alignment_heads
if alignment_heads is None:
logger.warning("No alignment heads found in model")
return
if hasattr(alignment_heads, 'tolist'):
heads_list = alignment_heads.tolist()
else:
heads_list = np.array(alignment_heads).tolist()
for layer_rank, head_id in heads_list:
layer_rank = int(layer_rank)
head_id = int(head_id)
heads = self.state.align_source.get(layer_rank, [])
heads.append((self.state.num_align_heads, head_id))
self.state.align_source[layer_rank] = heads
self.state.num_align_heads += 1
def warmup(self, audio: np.ndarray):
"""Warmup the model with sample audio."""
try:
self.insert_audio(audio)
self.infer(is_last=True)
self.refresh_segment(complete=True)
logger.info("MLX model warmed up successfully")
except Exception as e:
logger.exception(f"MLX model warmup failed: {e}")
def create_tokenizer(self, language=None):
"""Create tokenizer for the given language."""
self.tokenizer = tokenizer.get_tokenizer(
multilingual=self.tokenizer_is_multilingual,
language=language,
num_languages=self.model.num_languages,
task=self.decode_options.task
)
self.state.tokenizer = self.tokenizer
def init_context(self):
"""Initialize context buffer."""
kw = {
'tokenizer': self.tokenizer,
'prefix_token_ids': [self.tokenizer.sot_prev]
}
self.state.context = MLXTokenBuffer.empty(**kw)
if self.cfg.static_init_prompt is not None:
self.state.context = MLXTokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
if self.cfg.init_prompt is not None:
self.state.context.text += self.cfg.init_prompt
def init_tokens(self):
"""Initialize token sequence."""
logger.debug(f"init tokens, {len(self.state.segments)}")
self.state.initial_tokens = mx.array(
[self.tokenizer.sot_sequence_including_notimestamps],
dtype=mx.int32
)
self.state.initial_token_length = self.state.initial_tokens.shape[1]
self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
logger.debug(f"init tokens after, {len(self.state.segments)}")
self.state.tokens = [self.state.initial_tokens]
def trim_context(self):
"""Trim context if too long."""
logger.info("Trimming context")
c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids)
logger.info(f"Context text: {self.state.context.as_text()}")
l = sum(t.shape[1] for t in self.state.tokens) + c
if self.cfg.static_init_prompt is None:
after = 0
else:
after = len(self.cfg.static_init_prompt)
while c > self.max_context_tokens or l > self.max_text_len - 20:
t = self.state.context.trim_words(after=after)
l -= t
c -= t
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
if t == 0:
break
logger.info(f"Context after trim: {self.state.context.text} (len: {l})")
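`trim_context` repeatedly drops the oldest words until both the context budget and the decoder's text-window limit are satisfied, never touching the static prefix. A simplified sketch with a single budget and a hypothetical per-word token cost (names here are illustrative):

```python
def trim_context(words, max_tokens, token_cost):
    """Drop the oldest words until the running token count fits the budget."""
    count = sum(token_cost[w] for w in words)
    while words and count > max_tokens:
        count -= token_cost[words[0]]
        words = words[1:]
    return words, count

words, count = trim_context(["hello", "there", "world"], max_tokens=2,
                            token_cost={"hello": 2, "there": 2, "world": 1})
print(words, count)  # ['world'] 1
```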
def refresh_segment(self, complete=False):
"""Refresh segment state."""
logger.debug("Refreshing segment:")
self.init_tokens()
self.state.last_attend_frame = -self.cfg.rewind_threshold
self.state.cumulative_time_offset = 0.0
self.init_context()
logger.debug(f"Context: {self.state.context}")
if not complete and len(self.state.segments) > 2:
self.state.segments = self.state.segments[-2:]
else:
logger.debug("removing all segments.")
self.state.segments = []
self.state.log_segments += 1
self.state.pending_incomplete_tokens = []
def fire_at_boundary(self, chunked_encoder_feature: mx.array) -> bool:
"""Check if we should fire at word boundary (CIF-based)."""
if self.state.always_fire:
return True
if self.state.never_fire:
return False
return True
def _current_tokens(self) -> mx.array:
"""Get current token sequence for decoding."""
toks = self.state.tokens
if toks[0].shape[0] == 1:
toks[0] = mx.repeat(toks[0], self.cfg.beam_size, axis=0)
if not self.state.context.is_empty():
context_toks = self.state.context.as_mlx_array_beam(self.cfg.beam_size)
toks = [context_toks] + toks
# Concatenate all tokens
if len(toks) > 1:
current_tokens = mx.concatenate(toks, axis=1)
else:
current_tokens = toks[0]
logger.debug("debug print current_tokens:")
self.debug_print_tokens(current_tokens)
return current_tokens
def debug_print_tokens(self, tokens: mx.array):
"""Debug print token sequences."""
tokens_np = np.array(tokens)
for i in range(min(self.cfg.beam_size, tokens_np.shape[0])):
logger.debug(self.tokenizer.decode_with_timestamps(tokens_np[i].tolist()))
def segments_len(self) -> float:
"""Get total length of audio segments in seconds."""
return sum(s.shape[0] for s in self.state.segments) / 16000
def _apply_minseglen(self) -> bool:
"""Check if we have enough audio to process."""
segments_len = self.segments_len()
if segments_len < self.cfg.audio_min_len:
logger.debug("waiting for next segment")
return False
return True
def insert_audio(self, segment: np.ndarray = None):
"""Insert audio segment into buffer."""
if segment is not None:
if hasattr(segment, 'numpy'):
segment = segment.numpy()
self.state.segments.append(segment)
removed_len = 0
segments_len = self.segments_len()
while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
removed_len = self.state.segments[0].shape[0] / 16000
segments_len -= removed_len
self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
self.state.cumulative_time_offset += removed_len
self.state.segments = self.state.segments[1:]
logger.debug(f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, cumulative offset: {self.state.cumulative_time_offset:.2f}s")
if len(self.state.tokens) > 1:
# Convert MLX array to list for context
token_list = np.array(self.state.tokens[1][0, :]).tolist()
self.state.context.append_token_ids(token_list)
self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]
return removed_len
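The eviction loop in `insert_audio` keeps the buffer under `audio_max_len` by dropping the oldest segments and accumulating their duration into the cumulative time offset, so frame indices in the trimmed window can still be mapped to absolute timestamps. A lengths-only sketch (segment sizes in samples, 16 kHz as in the source):

```python
SAMPLE_RATE = 16000

def trim_window(segments, max_len_s):
    """Evict oldest segments past max_len_s; return kept segments and evicted duration."""
    offset_s = 0.0
    total_s = sum(segments) / SAMPLE_RATE
    while len(segments) > 1 and total_s > max_len_s:
        removed_s = segments[0] / SAMPLE_RATE
        total_s -= removed_s
        offset_s += removed_s   # later timestamps are shifted by this amount
        segments = segments[1:]
    return segments, offset_s

segments, offset_s = trim_window([16000, 16000, 16000, 16000], max_len_s=2.5)
print(segments, offset_s)  # [16000, 16000] 2.0
```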
def _clean_cache(self):
"""Clean the kv_cache after each inference step."""
self.state.clean_cache()
def _suppress_tokens(self, logits: mx.array) -> mx.array:
"""Apply token suppression to logits."""
if self.state.suppress_tokens:
suppress_indices = mx.array(list(self.state.suppress_tokens), dtype=mx.int32)
logits = logits.at[:, suppress_indices].add(-float('inf'))
return logits
def lang_id(self, encoder_features: mx.array) -> Tuple[mx.array, List[dict]]:
"""Language detection from encoder features."""
n_audio = encoder_features.shape[0]
x = mx.array([[self.tokenizer.sot]] * n_audio, dtype=mx.int32)
logits, _, _ = self.model.decoder(x, encoder_features, kv_cache=None)
logits = logits[:, 0]
mask = mx.ones(logits.shape[-1], dtype=mx.bool_)
language_token_indices = mx.array(list(self.tokenizer.all_language_tokens), dtype=mx.int32)
mask = mask.at[language_token_indices].add(False)
logits = mx.where(mask, mx.array(-float('inf')), logits)
language_tokens = mx.argmax(logits, axis=-1)
language_token_probs = mx.softmax(logits, axis=-1)
probs_np = np.array(language_token_probs)
language_probs = [
{
c: float(probs_np[i, j])
for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
}
for i in range(n_audio)
]
self._clean_cache()
return language_tokens, language_probs
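`lang_id` masks every logit except the language tokens and softmaxes over what remains, then takes the most probable one. In scalar Python (toy logits; the token IDs are illustrative):

```python
from math import exp

def pick_language(logits, language_tokens):
    """Keep only language-token logits, softmax, return the argmax and its probability."""
    masked = {tok: logits[tok] for tok in language_tokens}
    z = sum(exp(v) for v in masked.values())
    probs = {tok: exp(v) / z for tok, v in masked.items()}
    best = max(probs, key=probs.get)
    return best, probs[best]

token, prob = pick_language([5.0, 2.0, 1.0, -1.0], language_tokens=[1, 2])
print(token, round(prob, 3))  # 1 0.731 (logit 5.0 is masked out, not a language token)
```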
def infer(self, is_last: bool = False) -> List[ASRToken]:
"""
Main inference method.
Args:
is_last: Whether this is the final chunk
Returns:
List of timestamped ASR tokens
"""
new_segment = True
if len(self.state.segments) == 0:
logger.debug("No segments, nothing to do")
return []
if not self._apply_minseglen():
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
return []
if len(self.state.segments) > 1:
input_segments = np.concatenate(self.state.segments, axis=0)
else:
input_segments = self.state.segments[0]
beg_encode = time()
mlx_mel_padded = mlx_log_mel_spectrogram(
audio=input_segments,
n_mels=self.model.dims.n_mels,
padding=N_SAMPLES
)
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
encoder_feature = self.model.encoder(mlx_mel[None])
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0]) / 2)
mx.eval(encoder_feature)
end_encode = time()
logger.debug(f'MLX Encoder duration: {end_encode - beg_encode:.3f}s')
if self.cfg.language == "auto" and self.state.detected_language is None and self.state.first_timestamp:
seconds_since_start = self.segments_len() - self.state.first_timestamp
if seconds_since_start >= 2.0:
language_tokens, language_probs = self.lang_id(encoder_feature)
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
print(f"Detected language: {top_lan} with p={p:.4f}")
self.create_tokenizer(top_lan)
self.state.last_attend_frame = -self.cfg.rewind_threshold
self.state.cumulative_time_offset = 0.0
self.init_tokens()
self.init_context()
self.state.detected_language = top_lan
logger.info(f"Tokenizer language: {self.tokenizer.language}")
self.trim_context()
current_tokens = self._current_tokens()
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
sum_logprobs = mx.zeros((self.cfg.beam_size,), dtype=mx.float32)
completed = False
attn_of_alignment_heads = None
most_attended_frame = None
token_len_before_decoding = current_tokens.shape[1]
l_absolute_timestamps = []
accumulated_cross_attns = []
audio_duration_s = self.segments_len()
max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0))
tokens_produced_this_chunk = 0
while not completed and current_tokens.shape[1] < self.max_text_len:
tokens_produced_this_chunk += 1
if tokens_produced_this_chunk > max_tokens_per_chunk:
logger.warning(f"[Loop Detection] Too many tokens ({tokens_produced_this_chunk}) for {audio_duration_s:.2f}s audio. Breaking.")
current_tokens = current_tokens[:, :token_len_before_decoding]
break
        if new_segment:
            tokens_for_logits = current_tokens
        else:
            tokens_for_logits = current_tokens[:, -1:]
        if self.state.decoder_type == "greedy":
            logits, self.state.kv_cache, cross_qk = self.model.decoder(
                tokens_for_logits, encoder_feature, kv_cache=self.state.kv_cache
            )
        else:
            logits, cross_qk = self.state.inference.logits(tokens_for_logits, encoder_feature)
        mx.eval(logits)
        accumulated_cross_attns.append(cross_qk)
        if new_segment and self.tokenizer.no_speech is not None:
            probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1)
            no_speech_probs = np.array(probs_at_sot[:, self.tokenizer.no_speech]).tolist()
            if no_speech_probs[0] > self.cfg.nonspeech_prob:
                logger.info("no speech, stop")
                break
        logits = logits[:, -1, :]  # keep only the last position's logits
        # Suppress blank and end-of-text tokens at segment start
        if new_segment:
            blank_tokens = self.tokenizer.encode(" ") + [self.tokenizer.eot]
            logits = logits.at[:, blank_tokens].add(-float('inf'))
            new_segment = False
        logits = self._suppress_tokens(logits)
        current_tokens, completed = self.state.token_decoder.update(
            current_tokens, logits, sum_logprobs
        )
        mx.eval(current_tokens)
        logger.debug(f"Decoding completed: {completed}")
        self.debug_print_tokens(current_tokens)
        attn_of_alignment_heads = self._process_cross_attention(
            accumulated_cross_attns, content_mel_len
        )
        most_attended_frames = mx.argmax(attn_of_alignment_heads[:, -1, :], axis=-1)
        most_attended_frames_np = np.array(most_attended_frames)
        absolute_timestamps = [
            (frame * 0.02 + self.state.cumulative_time_offset)
            for frame in most_attended_frames_np.tolist()
        ]
        logger.debug(f"Most attended frames: {most_attended_frames_np.tolist()}")
        logger.debug(f"Absolute timestamps: {absolute_timestamps}")
        most_attended_frame = int(most_attended_frames_np[0])
        l_absolute_timestamps.append(absolute_timestamps[0])
        if completed:
            current_tokens = current_tokens[:, :-1]
            break
        if not is_last and self.state.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
            current_tokens_np = np.array(current_tokens)
            if current_tokens.shape[1] > 1 and current_tokens_np[0, -2] >= DEC_PAD:
                logger.debug("omit rewinding from special tokens")
                self.state.last_attend_frame = most_attended_frame
            else:
                logger.debug(f"[rewind detected] current: {most_attended_frame}, last: {self.state.last_attend_frame}")
                self.state.last_attend_frame = -self.cfg.rewind_threshold
                # Fall back to the pre-decoding prefix when no tokens have been committed yet
                # (indexing tokens[0] on an empty list would raise IndexError)
                if len(self.state.tokens) > 0:
                    current_tokens = mx.concatenate(self.state.tokens, axis=1)
                else:
                    current_tokens = current_tokens[:, :token_len_before_decoding]
                break
        else:
            self.state.last_attend_frame = most_attended_frame
        if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
            logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
            current_tokens = current_tokens[:, :-1]
            break
    tokens_to_split = np.array(current_tokens[0, token_len_before_decoding:]).tolist()
    if self.state.pending_incomplete_tokens:
        logger.debug(f"[UTF-8 Fix] Prepending pending tokens: {self.state.pending_incomplete_tokens}")
        tokens_to_split = self.state.pending_incomplete_tokens + tokens_to_split
    if fire_detected or is_last:
        new_hypothesis = tokens_to_split
        split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
    else:
        split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split)
        if len(split_words) > 1:
            new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
        else:
            new_hypothesis = []
    logger.debug(f"new_hypothesis: {new_hypothesis}")
    new_tokens = mx.array([new_hypothesis], dtype=mx.int32)
    new_tokens = mx.repeat(new_tokens, self.cfg.beam_size, axis=0)
    self.state.tokens.append(new_tokens)
    logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
    self._clean_cache()
    if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None:
        self.state.first_timestamp = l_absolute_timestamps[0]
    timestamped_words = []
    timestamp_idx = 0
    replacement_char = "\ufffd"
    for word, word_tokens in zip(split_words, split_tokens):
        if replacement_char in word:
            logger.warning(f"[UTF-8 Filter] Skipping: {repr(word)}")
            timestamp_idx += len(word_tokens)
            continue
        try:
            current_timestamp = l_absolute_timestamps[timestamp_idx]
        except IndexError:
            # Fall back to the last timestamp so current_timestamp is always bound
            current_timestamp = l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0
        timestamp_idx += len(word_tokens)
        timestamp_entry = ASRToken(
            start=round(current_timestamp, 2),
            end=round(current_timestamp + 0.1, 2),
            text=word,
            speaker=self.state.speaker,
            detected_language=self.state.detected_language
        ).with_offset(self.state.global_time_offset)
        timestamped_words.append(timestamp_entry)
    self.state.pending_incomplete_tokens = []
    MAX_PENDING_TOKENS = 10
    if split_words and replacement_char in split_words[-1]:
        if len(split_tokens[-1]) <= MAX_PENDING_TOKENS:
            self.state.pending_incomplete_tokens = split_tokens[-1]
            logger.debug("[UTF-8 Fix] Holding incomplete tokens")
        else:
            logger.warning("[UTF-8 Fix] Skipping too many tokens")
    return timestamped_words
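The pending-token logic above exists because BPE tokens can split a multi-byte UTF-8 character across chunks, so decoding a partial byte sequence yields the replacement character U+FFFD. A minimal standalone sketch of the hold-back rule (the helper name `hold_incomplete` and its list-based API are illustrative, not from the codebase):

```python
REPLACEMENT_CHAR = "\ufffd"

def hold_incomplete(words, word_tokens, max_pending=10):
    """Return (kept_words, pending_tokens): hold back a trailing word that
    decoded to U+FFFD so its tokens can be retried with the next chunk."""
    if words and REPLACEMENT_CHAR in words[-1]:
        if len(word_tokens[-1]) <= max_pending:
            return words[:-1], word_tokens[-1]
        return words[:-1], []  # too many held tokens: drop instead of growing forever
    return words, []

# Decoding only the first byte of a two-byte character yields U+FFFD:
partial = "é".encode("utf-8")[:1].decode("utf-8", errors="replace")
```

The `max_pending` cap mirrors `MAX_PENDING_TOKENS` above: without it, a tokenizer that keeps emitting invalid sequences would accumulate an unbounded pending buffer.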
def _process_cross_attention(
    self,
    cross_attns: List[List[mx.array]],
    content_mel_len: int
) -> mx.array:
    """
    Process cross-attention weights for alignment.

    Args:
        cross_attns: Cross-attention weights from each forward pass.
            Each element is a list of mx.arrays, one per layer.
        content_mel_len: Length of the actual audio content, in frames.

    Returns:
        Processed attention tensor of shape (batch, seq_len, content_mel_len).
    """
    attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
    num_decoder_layers = self.num_decoder_layers
    if cross_attns and isinstance(cross_attns[0], list):
        flattened_attns = [attn for layer_list in cross_attns for attn in layer_list]
    else:
        flattened_attns = cross_attns
    for idx, attn_mat in enumerate(flattened_attns):
        if attn_mat is None:
            continue
        layer_rank = idx % num_decoder_layers
        align_heads_in_layer = self.state.align_source.get(layer_rank, [])
        if len(align_heads_in_layer) == 0:
            continue
        attn_mat = mx.softmax(attn_mat, axis=-1)
        for align_head_rank, head_id in align_heads_in_layer:
            if self.cfg.beam_size == 1:
                if attn_mat.ndim == 4:
                    a = attn_mat[0, head_id, :, :]
                else:
                    a = attn_mat[head_id, :, :]
                a = a[None, :, :]
            else:
                a = attn_mat[:, head_id, :, :]
            attn_of_alignment_heads[align_head_rank].append(a)
    tmp = []
    for mat in attn_of_alignment_heads:
        if mat:
            t = mx.concatenate(mat, axis=1)
            tmp.append(t)
    if not tmp:
        return mx.zeros((self.cfg.beam_size, 1, content_mel_len))
    attn_of_alignment_heads = mx.stack(tmp, axis=1)
    std = mx.std(attn_of_alignment_heads, axis=-2, keepdims=True)
    mean = mx.mean(attn_of_alignment_heads, axis=-2, keepdims=True)
    attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)
    attn_of_alignment_heads = mlx_median_filter(attn_of_alignment_heads, 7)
    attn_of_alignment_heads = mx.mean(attn_of_alignment_heads, axis=1)
    attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
    mx.eval(attn_of_alignment_heads)
    return attn_of_alignment_heads
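`_process_cross_attention` standardizes each alignment head's attention over the token axis, median-filters along the frame axis, averages the heads, and the decode loop then takes the most-attended frame per token at 20 ms per frame. A NumPy sketch of that pipeline (function name and shapes are illustrative; the real code operates on MLX arrays and beam batches):

```python
import numpy as np

FRAME_SECONDS = 0.02  # one encoder frame covers 20 ms of audio

def align_timestamps(attn, time_offset=0.0, filter_width=7):
    """attn: (heads, tokens, frames) cross-attention of the alignment heads.
    Returns one absolute timestamp per decoded token."""
    # Standardize over the token axis, as in the method above
    mean = attn.mean(axis=-2, keepdims=True)
    std = attn.std(axis=-2, keepdims=True)
    attn = (attn - mean) / (std + 1e-8)
    # Median filter along the frame axis to suppress isolated attention spikes
    pad = filter_width // 2
    padded = np.pad(attn, ((0, 0), (0, 0), (pad, pad)), mode="edge")
    windows = np.stack(
        [padded[..., i:i + attn.shape[-1]] for i in range(filter_width)], axis=-1
    )
    attn = np.median(windows, axis=-1)
    # Average heads, then take the most-attended frame per token
    frames = attn.mean(axis=0).argmax(axis=-1)
    return frames * FRAME_SECONDS + time_offset
```

With `filter_width=1` the median filter is a no-op, which makes the frame-to-timestamp conversion easy to check on tiny synthetic inputs.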

View File

@@ -68,40 +68,4 @@ def load_mlx_encoder(
     model.update(encoder_weights)
     mx.eval(model.parameters())
-    return model
-
-
-def load_mlx_model(
-    path_or_hf_repo: str,
-    dtype: mx.Dtype = mx.float32,
-) -> whisper.Whisper:
-    model_path = Path(path_or_hf_repo)
-    if not model_path.exists():
-        model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
-    with open(str(model_path / "config.json"), "r") as f:
-        config = json.loads(f.read())
-    config.pop("model_type", None)
-    quantization = config.pop("quantization", None)
-    model_args = whisper.ModelDimensions(**config)
-    wf = model_path / "weights.safetensors"
-    if not wf.exists():
-        wf = model_path / "weights.npz"
-    weights = mx.load(str(wf))
-    model = whisper.Whisper(model_args, dtype)
-    if quantization is not None:
-        class_predicate = (
-            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
-            and f"{p}.scales" in weights
-        )
-        nn.quantize(model, **quantization, class_predicate=class_predicate)
-    weights = tree_unflatten(list(weights.items()))
-    model.update(weights)
-    mx.eval(model.parameters())
     return model

View File

@@ -626,10 +626,8 @@ class AlignAtt:
             try:
                 current_timestamp = l_absolute_timestamps[timestamp_idx]
-            except IndexError:
-                # Use last timestamp if index out of range
-                logger.warning(f"Timestamp index {timestamp_idx} out of range, using last timestamp")
-                current_timestamp = l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0
+            except:
+                pass
             timestamp_idx += len(word_tokens)
             timestamp_entry = ASRToken(

View File

@@ -1,139 +0,0 @@
"""
Thread Safety Configuration for WhisperLiveKit

This module provides thread safety configuration and utilities.

Environment Variables:
    WHISPERLIVEKIT_MODEL_LOCK: Enable/disable model locking (default: 1)
        Set to "0" to disable for single-connection deployments
    WHISPERLIVEKIT_LOCK_TIMEOUT: Lock acquisition timeout in seconds (default: 30)

Usage:
    # Enable model locking (default)
    export WHISPERLIVEKIT_MODEL_LOCK=1

    # Disable for single-connection deployment
    export WHISPERLIVEKIT_MODEL_LOCK=0

    # Custom timeout
    export WHISPERLIVEKIT_LOCK_TIMEOUT=60
"""
import os
import logging
import threading

logger = logging.getLogger(__name__)

# Configuration
USE_MODEL_LOCK = os.environ.get("WHISPERLIVEKIT_MODEL_LOCK", "1") == "1"
LOCK_TIMEOUT = float(os.environ.get("WHISPERLIVEKIT_LOCK_TIMEOUT", "30.0"))

# Global model lock
_model_lock = threading.Lock()

# Log configuration on import
if USE_MODEL_LOCK:
    logger.info(f"Model locking ENABLED (timeout: {LOCK_TIMEOUT}s)")
    logger.info("For single-connection deployments, set WHISPERLIVEKIT_MODEL_LOCK=0")
else:
    logger.warning("Model locking DISABLED - only safe for single-connection deployments")

def get_model_lock():
    """Get the global model lock instance"""
    return _model_lock

def acquire_model_lock(timeout=None):
    """
    Acquire model lock with timeout.

    Args:
        timeout: Lock acquisition timeout (default: use LOCK_TIMEOUT)

    Returns:
        bool: True if lock acquired, False on timeout
    """
    if not USE_MODEL_LOCK:
        return True
    timeout = timeout or LOCK_TIMEOUT
    acquired = _model_lock.acquire(timeout=timeout)
    if not acquired:
        logger.error(f"Failed to acquire model lock within {timeout}s")
    return acquired

def release_model_lock():
    """Release model lock"""
    if not USE_MODEL_LOCK:
        return
    try:
        _model_lock.release()
    except RuntimeError:
        # Lock not held - this is fine
        pass

class ModelLockContext:
    """Context manager for model lock"""

    def __init__(self, timeout=None):
        self.timeout = timeout
        self.acquired = False

    def __enter__(self):
        self.acquired = acquire_model_lock(self.timeout)
        return self.acquired

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.acquired:
            release_model_lock()
        return False

# Concurrency recommendations
RECOMMENDED_CONNECTIONS_PER_WORKER = 1 if USE_MODEL_LOCK else 1
RECOMMENDED_WORKERS = 4

def print_deployment_recommendations():
    """Print recommended deployment configuration"""
    print("\n" + "="*60)
    print("WhisperLiveKit Deployment Recommendations")
    print("="*60)
    if USE_MODEL_LOCK:
        print("⚠️ Model locking is ENABLED")
        print("   This serializes inference across connections.")
        print()
        print("Recommended deployment:")
        print(f"  gunicorn -w {RECOMMENDED_WORKERS} \\")
        print("    -k uvicorn.workers.UvicornWorker \\")
        print("    --worker-connections 1 \\")
        print("    whisperlivekit.basic_server:app")
        print()
        print("Expected capacity:")
        print(f"  - {RECOMMENDED_WORKERS} concurrent users (1 per worker)")
        print(f"  - Memory: ~{RECOMMENDED_WORKERS}x model size")
    else:
        print("✅ Model locking is DISABLED")
        print("   ⚠️ ONLY safe for single-connection deployments")
        print()
        print("Recommended deployment:")
        print("  uvicorn whisperlivekit.basic_server:app \\")
        print("    --host 0.0.0.0 --port 8000 \\")
        print("    --workers 1")
        print()
        print("Expected capacity:")
        print("  - 1 concurrent user only")
    print("="*60 + "\n")

if __name__ == "__main__":
    print_deployment_recommendations()

View File

@@ -107,6 +107,21 @@ class Silence():
         return True
 
 
+@dataclass
+class SegmentBuffer:
+    """Per-segment buffer for ephemeral/unvalidated content."""
+    transcription: str = ''
+    diarization: str = ''
+    translation: str = ''
+
+    def to_dict(self) -> Dict[str, str]:
+        return {
+            'transcription': self.transcription,
+            'diarization': self.diarization,
+            'translation': self.translation
+        }
+
+
 @dataclass
 class Segment(TimedText):
     """Generic contiguous span built from tokens or silence markers."""
@@ -114,14 +129,18 @@ class Segment(TimedText):
     end: Optional[float]
     text: Optional[str]
     speaker: Optional[str]
+    id: Optional[int] = None
+    start_speaker: Optional[float] = None
     tokens: Optional[ASRToken] = None
     translation: Optional[Translation] = None
+    buffer: Optional[SegmentBuffer] = None
 
     @classmethod
     def from_tokens(
         cls,
         tokens: List[Union[ASRToken, Silence]],
-        is_silence: bool = False
+        is_silence: bool = False,
+        segment_id: Optional[int] = None
     ) -> Optional["Segment"]:
         """Return a normalized segment representing the provided tokens."""
         if not tokens:
@@ -134,7 +153,9 @@ class Segment(TimedText):
                 start=start_token.start,
                 end=end_token.end,
                 text=None,
-                speaker=-2
+                speaker=-2,
+                id=segment_id,
+                start_speaker=start_token.start
             )
         else:
             return cls(
@@ -142,6 +163,8 @@ class Segment(TimedText):
                 end=end_token.end,
                 text=''.join(token.text for token in tokens),
                 speaker=-1,
+                id=segment_id,
+                start_speaker=start_token.start,
                 detected_language=start_token.detected_language
             )
@@ -150,17 +173,18 @@ class Segment(TimedText):
         return self.speaker == -2
 
     def to_dict(self) -> Dict[str, Any]:
-        """Serialize the segment for frontend consumption."""
+        """Serialize the segment for frontend consumption (new API format)."""
         _dict: Dict[str, Any] = {
+            'id': self.id if self.id is not None else 0,
             'speaker': int(self.speaker) if self.speaker != -1 else 1,
-            'text': self.text,
+            'text': self.text or '',
+            'start_speaker': format_time(self.start_speaker) if self.start_speaker is not None else format_time(self.start),
             'start': format_time(self.start),
             'end': format_time(self.end),
+            'language': self.detected_language,
+            'translation': self.translation or '',
+            'buffer': self.buffer.to_dict() if self.buffer else SegmentBuffer().to_dict()
         }
-        if self.translation:
-            _dict['translation'] = self.translation
-        if self.detected_language:
-            _dict['detected_language'] = self.detected_language
         return _dict
@@ -179,23 +203,20 @@ class SilentSegment(Segment):
 class FrontData():
     status: str = ''
     error: str = ''
-    lines: list[Segment] = field(default_factory=list)
-    buffer_transcription: str = ''
-    buffer_diarization: str = ''
-    buffer_translation: str = ''
+    segments: list[Segment] = field(default_factory=list)
     remaining_time_transcription: float = 0.
     remaining_time_diarization: float = 0.
 
     def to_dict(self) -> Dict[str, Any]:
-        """Serialize the front-end data payload."""
+        """Serialize the front-end data payload (new API format)."""
         _dict: Dict[str, Any] = {
+            'type': 'transcript_update',
             'status': self.status,
-            'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)],
-            'buffer_transcription': self.buffer_transcription,
-            'buffer_diarization': self.buffer_diarization,
-            'buffer_translation': self.buffer_translation,
-            'remaining_time_transcription': self.remaining_time_transcription,
-            'remaining_time_diarization': self.remaining_time_diarization,
+            'segments': [seg.to_dict() for seg in self.segments if (seg.text or seg.speaker == -2)],
+            'metadata': {
+                'remaining_time_transcription': self.remaining_time_transcription,
+                'remaining_time_diarization': self.remaining_time_diarization,
+            }
         }
         if self.error:
             _dict['error'] = self.error
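Under the new format, the ephemeral buffers move from top-level `buffer_*` fields into each segment, and the lag values move under a `metadata` wrapper. An illustrative payload built from the `to_dict` shapes in the diff (all values are made up; the time-string format depends on `format_time`):

```python
payload = {
    "type": "transcript_update",
    "status": "active_transcription",
    "segments": [
        {
            "id": 1,
            "speaker": 1,
            "text": "Hello everyone.",
            "start_speaker": "0:00:00",
            "start": "0:00:00",
            "end": "0:00:02",
            "language": "en",
            "translation": "",
            # Ephemeral, unvalidated content now lives on the segment itself
            "buffer": {"transcription": " and wel", "diarization": "", "translation": ""},
        }
    ],
    "metadata": {
        "remaining_time_transcription": 0.4,
        "remaining_time_diarization": 1.2,
    },
}
```

Nesting the buffers per segment lets the frontend render partial text next to the segment it extends, instead of in one shared trailing area.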

View File

@@ -1,12 +1,14 @@
 from time import time
 from typing import Any, List, Optional, Tuple, Union
-from whisperlivekit.timed_objects import (ASRToken, Segment, PuncSegment, Silence,
+from whisperlivekit.timed_objects import (ASRToken, Segment, SegmentBuffer, PuncSegment, Silence,
                                           SilentSegment, SpeakerSegment,
                                           TimedText)
 
 
 class TokensAlignment:
+    # Minimum duration (seconds) for a silence to be displayed
+    MIN_SILENCE_DISPLAY_DURATION = 2.0
 
     def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None:
         self.state = state
@@ -33,7 +35,15 @@ class TokensAlignment:
         self.last_punctuation = None
         self.last_uncompleted_punc_segment: PuncSegment = None
-        self.unvalidated_tokens: PuncSegment = []
+        self.tokens_after_last_punctuation: PuncSegment = []
+        self.all_validated_segments: List[Segment] = []
+        # For token-by-token validation with diarization
+        self.pending_tokens: List[ASRToken] = []
+        self.last_validated_token_end: float = 0.0
+        # Segment ID counter for the new API
+        self._next_segment_id: int = 1
 
     def update(self) -> None:
         """Drain state buffers into the running alignment context."""
@@ -49,8 +59,6 @@ class TokensAlignment:
     def add_translation(self, segment: Segment) -> None:
         """Append translated text segments that overlap with a segment."""
-        if segment.translation is None:
-            segment.translation = ''
         for ts in self.all_translation_segments:
             if ts.is_within(segment):
                 segment.translation += ts.text + (self.sep if ts.text else '')
@@ -93,11 +101,11 @@ class TokensAlignment:
     def compute_new_punctuations_segments(self) -> List[PuncSegment]:
         new_punc_segments = []
         segment_start_idx = 0
-        self.unvalidated_tokens += self.new_tokens
-        for i, token in enumerate(self.unvalidated_tokens):
+        self.tokens_after_last_punctuation += self.new_tokens
+        for i, token in enumerate(self.tokens_after_last_punctuation):
             if token.is_silence():
                 previous_segment = PuncSegment.from_tokens(
-                    tokens=self.unvalidated_tokens[segment_start_idx: i],
+                    tokens=self.tokens_after_last_punctuation[segment_start_idx: i],
                 )
                 if previous_segment:
                     new_punc_segments.append(previous_segment)
@@ -110,12 +118,12 @@ class TokensAlignment:
             else:
                 if token.has_punctuation():
                     segment = PuncSegment.from_tokens(
-                        tokens=self.unvalidated_tokens[segment_start_idx: i+1],
+                        tokens=self.tokens_after_last_punctuation[segment_start_idx: i+1],
                     )
                     new_punc_segments.append(segment)
                     segment_start_idx = i+1
-        self.unvalidated_tokens = self.unvalidated_tokens[segment_start_idx:]
+        self.tokens_after_last_punctuation = self.tokens_after_last_punctuation[segment_start_idx:]
         return new_punc_segments
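`compute_new_punctuations_segments` accumulates tokens since the last punctuation mark and cuts one segment per punctuation boundary, carrying the unterminated tail forward into the next call. The same idea over plain strings (hypothetical helper, simplified to ignore silence tokens):

```python
PUNCTUATION = ".!?"

def split_at_punctuation(pending, new_tokens):
    """Append new tokens, emit one joined segment per punctuation
    boundary, and return (segments, still_pending_tokens)."""
    pending = pending + new_tokens
    segments, start = [], 0
    for i, tok in enumerate(pending):
        if tok and tok[-1] in PUNCTUATION:
            segments.append("".join(pending[start:i + 1]))
            start = i + 1
    return segments, pending[start:]

segs, rest = split_at_punctuation([], [" Hello", " world.", " How", " are"])
```

Keeping the tail as raw tokens (rather than joined text) matters in the real class because downstream steps still need per-token timestamps and speakers.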
@@ -140,64 +148,189 @@ class TokensAlignment:
         return max(0, end - start)
 
-    def get_lines_diarization(self) -> Tuple[List[Segment], str]:
-        """Build segments when diarization is enabled and track overflow buffer."""
-        diarization_buffer = ''
-        punctuation_segments = self.compute_punctuations_segments()
-        diarization_segments = self.concatenate_diar_segments()
-        for punctuation_segment in punctuation_segments:
-            if not punctuation_segment.is_silence():
-                if diarization_segments and punctuation_segment.start >= diarization_segments[-1].end:
-                    diarization_buffer += punctuation_segment.text
-                else:
-                    max_overlap = 0.0
-                    max_overlap_speaker = 1
-                    for diarization_segment in diarization_segments:
-                        intersec = self.intersection_duration(punctuation_segment, diarization_segment)
-                        if intersec > max_overlap:
-                            max_overlap = intersec
-                            max_overlap_speaker = diarization_segment.speaker + 1
-                    punctuation_segment.speaker = max_overlap_speaker
-        segments = []
-        if punctuation_segments:
-            segments = [punctuation_segments[0]]
-            for segment in punctuation_segments[1:]:
-                if segment.speaker == segments[-1].speaker:
-                    if segments[-1].text:
-                        segments[-1].text += segment.text
-                        segments[-1].end = segment.end
-                else:
-                    segments.append(segment)
-        return segments, diarization_buffer
+    def _get_speaker_for_token(self, token: ASRToken, diarization_segments: List[SpeakerSegment]) -> Optional[int]:
+        """Get speaker ID for a token based on diarization overlap. Returns None if not covered."""
+        if not diarization_segments:
+            return None
+        # Check if token is beyond diarization coverage
+        if token.start >= diarization_segments[-1].end:
+            return None
+        # Find speaker with max overlap
+        max_overlap = 0.0
+        best_speaker = None
+        for diar_seg in diarization_segments:
+            overlap = self.intersection_duration(token, diar_seg)
+            if overlap > max_overlap:
+                max_overlap = overlap
+                best_speaker = diar_seg.speaker + 1  # 1-indexed
+        return best_speaker if max_overlap > 0 else None
+
+    def get_lines_diarization(self) -> Tuple[List[Segment], str]:
+        """Build segments with token-by-token validation when diarization covers them."""
+        diarization_segments = self.concatenate_diar_segments()
+        # Add new tokens to pending
+        self.pending_tokens.extend(self.new_tokens)
+        # Process pending tokens - validate those covered by diarization
+        still_pending = []
+        for token in self.pending_tokens:
+            if token.is_silence():
+                # Handle silence tokens
+                silence_duration = (token.end or 0) - (token.start or 0)
+                if silence_duration >= self.MIN_SILENCE_DISPLAY_DURATION:
+                    # Significant silence - add as separate segment
+                    if self.all_validated_segments and not self.all_validated_segments[-1].is_silence():
+                        self.all_validated_segments.append(SilentSegment(
+                            start=token.start,
+                            end=token.end
+                        ))
+                    elif self.all_validated_segments and self.all_validated_segments[-1].is_silence():
+                        # Extend existing silence
+                        self.all_validated_segments[-1].end = token.end
+                    else:
+                        self.all_validated_segments.append(SilentSegment(
+                            start=token.start,
+                            end=token.end
+                        ))
+                # Short silences are ignored (don't go to pending either)
+                continue
+            speaker = self._get_speaker_for_token(token, diarization_segments)
+            if speaker is not None:
+                # Token is covered by diarization - validate it
+                if self.all_validated_segments:
+                    last_seg = self.all_validated_segments[-1]
+                    if not last_seg.is_silence() and last_seg.speaker == speaker:
+                        # Same speaker - append to existing segment
+                        last_seg.text += token.text
+                        last_seg.end = token.end
+                    else:
+                        # Different speaker or after silence - new segment
+                        new_seg = Segment(
+                            start=token.start,
+                            end=token.end,
+                            text=token.text,
+                            speaker=speaker,
+                            start_speaker=token.start,
+                            detected_language=token.detected_language
+                        )
+                        self.all_validated_segments.append(new_seg)
+                else:
+                    # First segment
+                    new_seg = Segment(
+                        start=token.start,
+                        end=token.end,
+                        text=token.text,
+                        speaker=speaker,
+                        start_speaker=token.start,
+                        detected_language=token.detected_language
+                    )
+                    self.all_validated_segments.append(new_seg)
+                self.last_validated_token_end = token.end
+            else:
+                # Token not yet covered by diarization - keep pending
+                still_pending.append(token)
+        self.pending_tokens = still_pending
+        # Build diarization buffer from pending tokens
+        diarization_buffer = ''.join(t.text for t in self.pending_tokens if not t.is_silence())
+        return self.all_validated_segments, diarization_buffer
+
+    def _assign_segment_ids(self, segments: List[Segment]) -> None:
+        """Assign unique IDs to segments that don't have one yet."""
+        for segment in segments:
+            if segment.id is None:
+                segment.id = self._next_segment_id
+                self._next_segment_id += 1
+
+    def _assign_buffers_to_last_segment(
+        self,
+        segments: List[Segment],
+        buffer_transcription: str,
+        buffer_diarization: str,
+        buffer_translation: str
+    ) -> None:
+        """Assign buffer content to the last non-silent segment."""
+        # First, clear ALL buffers (they're ephemeral and shouldn't persist)
+        for segment in segments:
+            segment.buffer = SegmentBuffer()
+        # Find the last non-silent segment and assign buffers to it
+        for segment in reversed(segments):
+            if not segment.is_silence():
+                segment.buffer = SegmentBuffer(
+                    transcription=buffer_transcription,
+                    diarization=buffer_diarization,
+                    translation=buffer_translation
+                )
+                break
+
+    def _filter_and_merge_segments(self, segments: List[Segment]) -> List[Segment]:
+        """Filter parasitic silences and merge consecutive same-speaker segments."""
+        if not segments:
+            return segments
+        result = []
+        for seg in segments:
+            if seg.is_silence():
+                # Filter short silences
+                duration = (seg.end or 0) - (seg.start or 0)
+                if duration < self.MIN_SILENCE_DISPLAY_DURATION:
+                    continue
+                # Merge consecutive silences
+                if result and result[-1].is_silence():
+                    result[-1].end = seg.end
+                    continue
+            else:
+                # Merge same speaker segments (across filtered silences)
+                if result and not result[-1].is_silence() and result[-1].speaker == seg.speaker:
+                    result[-1].text += seg.text
+                    result[-1].end = seg.end
+                    continue
+            result.append(seg)
+        return result
+
     def get_lines(
         self,
         diarization: bool = False,
         translation: bool = False,
-        current_silence: Optional[Silence] = None
-    ) -> Tuple[List[Segment], str, Union[str, TimedText]]:
-        """Return the formatted segments plus buffers, optionally with diarization/translation."""
+        current_silence: Optional[Silence] = None,
+        buffer_transcription: str = ''
+    ) -> List[Segment]:
+        """Return the formatted segments with per-segment buffers, optionally with diarization/translation."""
+        diarization_buffer = ''
         if diarization:
             segments, diarization_buffer = self.get_lines_diarization()
         else:
-            diarization_buffer = ''
             for token in self.new_tokens:
                 if token.is_silence():
-                    if self.current_line_tokens:
-                        self.validated_segments.append(Segment().from_tokens(self.current_line_tokens))
-                        self.current_line_tokens = []
-                    end_silence = token.end if token.has_ended else time() - self.beg_loop
-                    if self.validated_segments and self.validated_segments[-1].is_silence():
-                        self.validated_segments[-1].end = end_silence
-                    else:
-                        self.validated_segments.append(SilentSegment(
-                            start=token.start,
-                            end=end_silence
-                        ))
+                    # Check silence duration before adding
+                    silence_duration = (token.end or 0) - (token.start or 0)
+                    if silence_duration >= self.MIN_SILENCE_DISPLAY_DURATION:
+                        if self.current_line_tokens:
+                            self.validated_segments.append(Segment().from_tokens(self.current_line_tokens))
+                            self.current_line_tokens = []
+                        end_silence = token.end if token.has_ended else time() - self.beg_loop
+                        if self.validated_segments and self.validated_segments[-1].is_silence():
+                            self.validated_segments[-1].end = end_silence
+                        else:
+                            self.validated_segments.append(SilentSegment(
+                                start=token.start,
+                                end=end_silence
+                            ))
                 else:
                     self.current_line_tokens.append(token)
@@ -205,15 +338,37 @@ class TokensAlignment:
         if self.current_line_tokens:
             segments.append(Segment().from_tokens(self.current_line_tokens))
 
+        # Handle current ongoing silence
         if current_silence:
-            end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
-            if segments and segments[-1].is_silence():
-                segments[-1] = SilentSegment(start=segments[-1].start, end=end_silence)
-            else:
-                segments.append(SilentSegment(
-                    start=current_silence.start,
-                    end=end_silence
-                ))
+            silence_duration = (current_silence.end or time() - self.beg_loop) - (current_silence.start or 0)
+            if silence_duration >= self.MIN_SILENCE_DISPLAY_DURATION:
+                end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
+                if segments and segments[-1].is_silence():
+                    segments[-1] = SilentSegment(start=segments[-1].start, end=end_silence)
+                else:
+                    segments.append(SilentSegment(
+                        start=current_silence.start,
+                        end=end_silence
+                    ))
 
         if translation:
             [self.add_translation(segment) for segment in segments if not segment.is_silence()]
-        return segments, diarization_buffer, self.new_translation_buffer.text
+
+        # Get translation buffer text
+        translation_buffer = self.new_translation_buffer.text if self.new_translation_buffer else ''
+
+        # Filter parasitic silences and merge same-speaker segments
+        segments = self._filter_and_merge_segments(segments)
+
+        # Assign unique IDs to all segments
+        self._assign_segment_ids(segments)
+
+        # Assign buffers to the last active segment
+        self._assign_buffers_to_last_segment(
+            segments,
+            buffer_transcription=buffer_transcription,
+            buffer_diarization=diarization_buffer,
+            buffer_translation=translation_buffer
+        )
+        return segments
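`_filter_and_merge_segments` drops silences shorter than `MIN_SILENCE_DISPLAY_DURATION` and merges neighbors that share a speaker, including across a dropped silence. A sketch with plain dicts standing in for `Segment` (speaker `-2` marks silence, as in the diff):

```python
MIN_SILENCE = 2.0  # seconds; mirrors MIN_SILENCE_DISPLAY_DURATION

def filter_and_merge(segments):
    out = []
    for seg in segments:
        if seg["speaker"] == -2:  # silence marker
            if seg["end"] - seg["start"] < MIN_SILENCE:
                continue  # parasitic short silence: drop it
            if out and out[-1]["speaker"] == -2:
                out[-1]["end"] = seg["end"]  # merge consecutive silences
                continue
        elif out and out[-1]["speaker"] == seg["speaker"]:
            # Same speaker (possibly across a dropped silence): merge text
            out[-1]["text"] += seg["text"]
            out[-1]["end"] = seg["end"]
            continue
        out.append(dict(seg))
    return out

merged = filter_and_merge([
    {"speaker": 1, "text": "Hi", "start": 0.0, "end": 1.0},
    {"speaker": -2, "text": "", "start": 1.0, "end": 1.5},   # short silence
    {"speaker": 1, "text": " there.", "start": 1.5, "end": 2.5},
])
```

Because short silences are removed before the same-speaker check, a brief pause inside one speaker's turn no longer splits that turn into two segments.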

View File

@@ -454,8 +454,9 @@ label {
   gap: 4px;
 }
 
-.lag-diarization-value {
-  margin-left: 10px;
+.lag-diarization-value,
+.lag-transcription-value {
+  font-weight: 600;
 }
 
 .label_translation img {

View File

@@ -232,11 +232,8 @@ function setupWebSocket() {
if (waitingForStop) { if (waitingForStop) {
statusText.textContent = "Processing finalized or connection closed."; statusText.textContent = "Processing finalized or connection closed.";
if (lastReceivedData) { if (lastReceivedData) {
renderLinesWithBuffer( renderSegments(
lastReceivedData.lines || [], lastReceivedData.segments || [],
lastReceivedData.buffer_diarization || "",
lastReceivedData.buffer_transcription || "",
lastReceivedData.buffer_translation || "",
0, 0,
0, 0,
true true
@@ -278,11 +275,8 @@ function setupWebSocket() {
waitingForStop = false; waitingForStop = false;
if (lastReceivedData) { if (lastReceivedData) {
renderLinesWithBuffer( renderSegments(
lastReceivedData.lines || [], lastReceivedData.segments || [],
lastReceivedData.buffer_diarization || "",
lastReceivedData.buffer_transcription || "",
lastReceivedData.buffer_translation || "",
0, 0,
0, 0,
true true
@@ -299,21 +293,20 @@ function setupWebSocket() {
lastReceivedData = data; lastReceivedData = data;
// New API format: segments with per-segment buffers, metadata wrapper
const { const {
lines = [], segments = [],
buffer_transcription = "", metadata = {},
buffer_diarization = "",
buffer_translation = "",
remaining_time_transcription = 0,
remaining_time_diarization = 0,
status = "active_transcription", status = "active_transcription",
} = data; } = data;
const {
remaining_time_transcription = 0,
remaining_time_diarization = 0,
} = metadata;
renderLinesWithBuffer( renderSegments(
lines, segments,
buffer_diarization,
buffer_transcription,
buffer_translation,
remaining_time_diarization, remaining_time_diarization,
remaining_time_transcription, remaining_time_transcription,
false, false,
@@ -323,11 +316,8 @@ function setupWebSocket() {
}); });
} }
function renderLinesWithBuffer( function renderSegments(
lines, segments,
buffer_diarization,
buffer_transcription,
buffer_translation,
remaining_time_diarization, remaining_time_diarization,
remaining_time_transcription, remaining_time_transcription,
isFinalizing = false, isFinalizing = false,
@@ -339,33 +329,38 @@ function renderLinesWithBuffer(
     return;
   }
-  const showLoading = !isFinalizing && (lines || []).some((it) => it.speaker == 0);
-  const showTransLag = !isFinalizing && remaining_time_transcription > 0;
-  const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
+  // Build signature for change detection
   const signature = JSON.stringify({
-    lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end, detected_language: it.detected_language })),
-    buffer_transcription: buffer_transcription || "",
-    buffer_diarization: buffer_diarization || "",
-    buffer_translation: buffer_translation,
+    segments: (segments || []).map((it) => ({
+      id: it.id,
+      speaker: it.speaker,
+      text: it.text,
+      start: it.start,
+      end: it.end,
+      language: it.language,
+      buffer: it.buffer || {}
+    })),
     status: current_status,
-    showLoading,
-    showTransLag,
-    showDiaLag,
     isFinalizing: !!isFinalizing,
   });
+  // Only update lag values if signature unchanged
   if (lastSignature === signature) {
     const t = document.querySelector(".lag-transcription-value");
     if (t) t.textContent = fmt1(remaining_time_transcription);
     const d = document.querySelector(".lag-diarization-value");
     if (d) d.textContent = fmt1(remaining_time_diarization);
-    const ld = document.querySelector(".loading-diarization-value");
-    if (ld) ld.textContent = fmt1(remaining_time_diarization);
     return;
   }
   lastSignature = signature;
-  const linesHtml = (lines || [])
+  const segmentsHtml = (segments || [])
     .map((item, idx) => {
+      const buffer = item.buffer || {};
+      const buffer_transcription = buffer.transcription || "";
+      const buffer_diarization = buffer.diarization || "";
+      const buffer_translation = buffer.translation || "";
       let timeInfo = "";
       if (item.start !== undefined && item.end !== undefined) {
         timeInfo = ` ${item.start} - ${item.end}`;
@@ -373,80 +368,78 @@ function renderLinesWithBuffer(
       let speakerLabel = "";
       if (item.speaker === -2) {
+        // Silence segment
         speakerLabel = `<span class="silence">${silenceIcon}<span id='timeInfo'>${timeInfo}</span></span>`;
-      } else if (item.speaker == 0 && !isFinalizing) {
-        speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
-          remaining_time_diarization
-        )}</span> second(s) of audio are undergoing diarization</span></span>`;
       } else if (item.speaker !== 0) {
+        // Normal speaker segment
         const speakerNum = `<span class="speaker-badge">${item.speaker}</span>`;
         speakerLabel = `<span id="speaker">${speakerIcon}${speakerNum}<span id='timeInfo'>${timeInfo}</span></span>`;
-        if (item.detected_language) {
-          speakerLabel += `<span class="label_language">${languageIcon}<span>${item.detected_language}</span></span>`;
+        if (item.language) {
+          speakerLabel += `<span class="label_language">${languageIcon}<span>${item.language}</span></span>`;
         }
       }
       let currentLineText = item.text || "";
+      const isLastSegment = idx === segments.length - 1;
+      const hasBufferContent = buffer_diarization || buffer_transcription;
-      if (idx === lines.length - 1) {
-        if (!isFinalizing && item.speaker !== -2) {
-          speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(
-            remaining_time_transcription
-          )}</span>s</span></span>`;
-          if (buffer_diarization && remaining_time_diarization) {
-            speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(
-              remaining_time_diarization
-            )}</span>s</span></span>`;
-          }
-        }
+      // Show lag indicators on last non-silent segment (without spinners)
+      if (isLastSegment && item.speaker !== -2 && !isFinalizing) {
+        if (remaining_time_transcription > 0) {
+          speakerLabel += `<span class="label_transcription">Transcription lag: <span class="lag-transcription-value">${fmt1(remaining_time_transcription)}</span>s</span>`;
+        }
+        if (buffer_diarization && remaining_time_diarization > 0) {
+          speakerLabel += `<span class="label_diarization">Diarization lag: <span class="lag-diarization-value">${fmt1(remaining_time_diarization)}</span>s</span>`;
+        }
+      }
+      // Render buffers
+      if (hasBufferContent && item.speaker !== -2) {
         if (buffer_diarization) {
           if (isFinalizing) {
-            currentLineText +=
-              (currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
+            currentLineText += (currentLineText.length > 0 ? " " : "") + buffer_diarization.trim();
           } else {
             currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
           }
         }
         if (buffer_transcription) {
           if (isFinalizing) {
-            currentLineText +=
-              (currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") +
-              buffer_transcription.trim();
+            currentLineText += (currentLineText.length > 0 ? " " : "") + buffer_transcription.trim();
           } else {
             currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
           }
         }
       }
+      // Translation
       let translationContent = "";
       if (item.translation) {
         translationContent += item.translation.trim();
       }
-      if (idx === lines.length - 1 && buffer_translation) {
+      if (buffer_translation) {
         const bufferPiece = isFinalizing
           ? buffer_translation
           : `<span class="buffer_translation">${buffer_translation}</span>`;
-        translationContent += translationContent ? `${bufferPiece}` : bufferPiece;
+        translationContent += translationContent ? bufferPiece : bufferPiece;
       }
       if (translationContent.trim().length > 0) {
         currentLineText += `
-          <div>
-            <div class="label_translation">
-              ${translationIcon}
-              <span class="translation_text">${translationContent}</span>
-            </div>
-          </div>`;
+          <div class="label_translation">
+            ${translationIcon}
+            <span class="translation_text">${translationContent}</span>
+          </div>`;
       }
-      return currentLineText.trim().length > 0 || speakerLabel.length > 0
-        ? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
-        : `<p>${speakerLabel}<br/></p>`;
+      if (currentLineText.trim().length > 0 || speakerLabel.length > 0) {
+        return `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`;
+      }
+      return speakerLabel ? `<p>${speakerLabel}</p>` : "";
     })
+    .filter(html => html.length > 0)
     .join("");
-  linesTranscriptDiv.innerHTML = linesHtml;
+  linesTranscriptDiv.innerHTML = segmentsHtml;
   const transcriptContainer = document.querySelector('.transcript-container');
   if (transcriptContainer) {
     transcriptContainer.scrollTo({ top: transcriptContainer.scrollHeight, behavior: "smooth" });
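The refactor above moves the buffers from message-level fields into each segment, so the front end now consumes one structured `transcript_update` payload. The following Python sketch mirrors that payload shape and the signature-based change detection; the field names are taken from the rendering code, while the sample values and the `signature` helper are illustrative assumptions:

```python
import json

# Hypothetical payload mirroring what renderSegments() consumes.
update = {
    "status": "active_transcription",
    "metadata": {
        "remaining_time_transcription": 0.8,
        "remaining_time_diarization": 2.1,
    },
    "segments": [
        {
            "id": 1,
            "speaker": 1,
            "text": "Hello everyone,",
            "start": "0:00", "end": "0:03",
            "language": "en",
            "buffer": {"transcription": " and wel", "diarization": "", "translation": ""},
        },
        {"id": 2, "speaker": -2, "text": "", "start": "0:03", "end": "0:05"},  # silence
    ],
}

def signature(segments, status, is_finalizing=False):
    # Stable string over everything EXCEPT the lag timers, so timer-only
    # updates can skip a full re-render (same idea as the JS signature).
    return json.dumps({
        "segments": [
            {k: s.get(k) for k in ("id", "speaker", "text", "start", "end", "language", "buffer")}
            for s in segments
        ],
        "status": status,
        "isFinalizing": is_finalizing,
    }, sort_keys=True)

a = signature(update["segments"], "active_transcription")
b = signature(update["segments"], "active_transcription")
```

Because the lag values are excluded from the signature, a message that only updates the timers leaves `a == b` and the DOM untouched except for the two lag counters.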


@@ -0,0 +1,377 @@
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>WhisperLiveKit Transcript</title>
  <style>
    :root {
      --bg: #111;
      --text: #ddd;
      --dim: #666;
      --border: #333;
      --active: #e74c3c;
    }
    body {
      font-family: 'SF Mono', 'Monaco', 'Inconsolata', 'Roboto Mono', monospace;
      background: var(--bg);
      color: var(--text);
      margin: 0;
      padding: 2rem;
      font-size: 13px;
      line-height: 1.6;
    }
    .nav {
      display: flex;
      gap: 12px;
      align-items: center;
      margin-bottom: 3rem;
      font-size: 12px;
    }
    button, input, select {
      background: transparent;
      border: 1px solid var(--border);
      color: var(--dim);
      padding: 6px 12px;
      font-family: inherit;
      font-size: inherit;
      border-radius: 4px;
      outline: none;
      transition: all 0.2s;
    }
    button:hover, input:hover, input:focus, select:hover, select:focus {
      border-color: var(--text);
      color: var(--text);
      cursor: pointer;
    }
    select {
      cursor: pointer;
      appearance: none; /* Minimalist look */
      background-image: linear-gradient(45deg, transparent 50%, var(--dim) 50%), linear-gradient(135deg, var(--dim) 50%, transparent 50%);
      background-position: calc(100% - 15px) 50%, calc(100% - 10px) 50%;
      background-size: 5px 5px, 5px 5px;
      background-repeat: no-repeat;
      padding-right: 25px;
    }
    select:hover, select:focus {
      background-image: linear-gradient(45deg, transparent 50%, var(--text) 50%), linear-gradient(135deg, var(--text) 50%, transparent 50%);
    }
    button.recording {
      border-color: var(--active);
      color: var(--active);
    }
    input {
      width: 150px;
      cursor: text;
    }
    #status {
      margin-left: auto;
      color: var(--dim);
    }
    #transcript {
      white-space: pre-wrap;
      word-wrap: break-word;
      max-width: 800px;
      margin: 0 auto;
      outline: none;
    }
    /* Minimal scrollbar */
    ::-webkit-scrollbar { width: 6px; }
    ::-webkit-scrollbar-track { background: transparent; }
    ::-webkit-scrollbar-thumb { background: #222; border-radius: 3px; }
    ::-webkit-scrollbar-thumb:hover { background: #333; }
  </style>
</head>
<body>
  <div class="nav">
    <button id="recordBtn">Record</button>
    <button id="copyBtn">Copy</button>
    <select id="microphoneSelect"></select>
    <input type="text" id="wsUrl" placeholder="WebSocket URL">
    <div id="status">Ready</div>
  </div>
  <div id="transcript"></div>
  <script>
    const recordBtn = document.getElementById('recordBtn');
    const copyBtn = document.getElementById('copyBtn');
    const wsUrlInput = document.getElementById('wsUrl');
    const statusEl = document.getElementById('status');
    const transcriptEl = document.getElementById('transcript');
    const microphoneSelect = document.getElementById('microphoneSelect');

    // Default WebSocket URL
    const protocol = window.location.protocol === 'https:' ? 'wss' : 'ws';
    const host = window.location.hostname || 'localhost';
    const port = window.location.port;
    const defaultUrl = `${protocol}://${host}${port ? ':' + port : ''}/asr`;
    wsUrlInput.value = defaultUrl;

    let websocket = null;
    let isRecording = false;
    let audioContext = null;
    let workletNode = null;
    let recorderWorker = null;
    let microphone = null;
    let useAudioWorklet = false;
    let recorder = null;
    let availableMicrophones = [];
    let selectedMicrophoneId = null;

    async function enumerateMicrophones() {
      try {
        // Request permission first to get labels
        const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
        stream.getTracks().forEach(track => track.stop());
        const devices = await navigator.mediaDevices.enumerateDevices();
        availableMicrophones = devices.filter(device => device.kind === 'audioinput');
        populateMicrophoneSelect();
      } catch (error) {
        console.error('Error enumerating microphones:', error);
        statusEl.textContent = "Mic permission needed";
      }
    }

    function populateMicrophoneSelect() {
      microphoneSelect.innerHTML = '<option value="">Default Microphone</option>';
      availableMicrophones.forEach((device, index) => {
        const option = document.createElement('option');
        option.value = device.deviceId;
        option.textContent = device.label || `Microphone ${index + 1}`;
        microphoneSelect.appendChild(option);
      });
      const savedMicId = localStorage.getItem('selectedMicrophone');
      if (savedMicId && availableMicrophones.some(mic => mic.deviceId === savedMicId)) {
        microphoneSelect.value = savedMicId;
        selectedMicrophoneId = savedMicId;
      }
    }

    function handleMicrophoneChange() {
      selectedMicrophoneId = microphoneSelect.value || null;
      localStorage.setItem('selectedMicrophone', selectedMicrophoneId || '');
      if (isRecording) {
        stopRecording();
        setTimeout(() => {
          startRecording();
        }, 500);
      }
    }
    microphoneSelect.addEventListener('change', handleMicrophoneChange);

    // Initial enumeration
    enumerateMicrophones();
    navigator.mediaDevices.addEventListener('devicechange', enumerateMicrophones);
    function formatSegment(segment) {
      const speaker = segment.speaker;
      const text = segment.text || '';
      const buffer = segment.buffer || {};
      const start = segment.start || '';
      const end = segment.end || '';
      const language = segment.language || '';
      let output = '';
      // Silence marker
      if (speaker === -2) {
        output += `[SILENCE ${start} - ${end}]\n`;
        return output;
      }
      // Speaker header
      output += `[SPEAKER ${speaker}]`;
      if (start && end) output += ` ${start} - ${end}`;
      if (language) output += ` [LANG: ${language}]`;
      output += '\n';
      // Main text
      if (text) {
        output += text;
      }
      // Diarization buffer (text waiting for speaker assignment)
      if (buffer.diarization) {
        output += `[DIAR_BUFFER]${buffer.diarization}[/DIAR_BUFFER]`;
      }
      // Transcription buffer (text waiting for validation)
      if (buffer.transcription) {
        output += `[TRANS_BUFFER]${buffer.transcription}[/TRANS_BUFFER]`;
      }
      output += '\n';
      // Translation
      if (segment.translation) {
        output += `[TRANSLATION]${segment.translation}`;
        if (buffer.translation) {
          output += `[TRANS_BUFFER]${buffer.translation}[/TRANS_BUFFER]`;
        }
        output += `[/TRANSLATION]\n`;
      } else if (buffer.translation) {
        output += `[TRANSLATION][TRANS_BUFFER]${buffer.translation}[/TRANS_BUFFER][/TRANSLATION]\n`;
      }
      return output;
    }

    function renderTranscript(data) {
      const { segments = [], metadata = {}, status: msgStatus } = data;
      if (msgStatus === 'no_audio_detected') {
        // Minimalist: just surface the status instead of rendering a transcript
        statusEl.textContent = 'No audio detected';
        return;
      }
      let output = '';
      // Metadata header
      const remainingTrans = metadata.remaining_time_transcription || 0;
      const remainingDiar = metadata.remaining_time_diarization || 0;
      if (remainingTrans > 0 || remainingDiar > 0) {
        output += `[LAG: trans=${remainingTrans.toFixed(1)}s diar=${remainingDiar.toFixed(1)}s]\n\n`;
      }
      // All segments
      for (const segment of segments) {
        output += formatSegment(segment);
        output += '\n';
      }
      transcriptEl.textContent = output;
      transcriptEl.scrollTop = transcriptEl.scrollHeight;
    }
    async function startRecording() {
      try {
        websocket = new WebSocket(wsUrlInput.value);
        websocket.onopen = async () => {
          statusEl.textContent = 'Connecting...';
        };
        websocket.onmessage = async (event) => {
          const data = JSON.parse(event.data);
          if (data.type === 'config') {
            useAudioWorklet = !!data.useAudioWorklet;
            statusEl.textContent = 'Recording';
            await initAudio();
            return;
          }
          if (data.type === 'ready_to_stop') {
            statusEl.textContent = 'Done';
            return;
          }
          // transcript_update
          renderTranscript(data);
        };
        websocket.onclose = () => {
          statusEl.textContent = 'Disconnected';
          stopRecording(false);
        };
        websocket.onerror = () => {
          statusEl.textContent = 'Error';
        };
      } catch (err) {
        statusEl.textContent = 'Error: ' + err.message;
      }
    }

    async function initAudio() {
      const audioConstraints = selectedMicrophoneId
        ? { audio: { deviceId: { exact: selectedMicrophoneId } } }
        : { audio: true };
      const stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
      audioContext = new (window.AudioContext || window.webkitAudioContext)();
      microphone = audioContext.createMediaStreamSource(stream);
      if (useAudioWorklet) {
        await audioContext.audioWorklet.addModule('/web/pcm_worklet.js');
        workletNode = new AudioWorkletNode(audioContext, 'pcm-forwarder', {
          numberOfInputs: 1, numberOfOutputs: 0, channelCount: 1
        });
        microphone.connect(workletNode);
        recorderWorker = new Worker('/web/recorder_worker.js');
        recorderWorker.postMessage({ command: 'init', config: { sampleRate: audioContext.sampleRate } });
        recorderWorker.onmessage = (e) => {
          if (websocket?.readyState === WebSocket.OPEN) {
            websocket.send(e.data.buffer);
          }
        };
        workletNode.port.onmessage = (e) => {
          const ab = e.data instanceof ArrayBuffer ? e.data : e.data.buffer;
          recorderWorker.postMessage({ command: 'record', buffer: ab }, [ab]);
        };
      } else {
        try {
          recorder = new MediaRecorder(stream, { mimeType: 'audio/webm' });
        } catch {
          recorder = new MediaRecorder(stream);
        }
        recorder.ondataavailable = (e) => {
          if (websocket?.readyState === WebSocket.OPEN && e.data?.size > 0) {
            websocket.send(e.data);
          }
        };
        recorder.start(100);
      }
      isRecording = true;
      recordBtn.textContent = 'Stop';
      recordBtn.classList.add('recording');
    }
    function stopRecording(sendStop = true) {
      if (sendStop && websocket?.readyState === WebSocket.OPEN) {
        websocket.send(new Blob([], { type: 'audio/webm' }));
      }
      if (recorder) { try { recorder.stop(); } catch {} recorder = null; }
      if (recorderWorker) { recorderWorker.terminate(); recorderWorker = null; }
      if (workletNode) { workletNode.disconnect(); workletNode = null; }
      if (microphone) { microphone.disconnect(); microphone = null; }
      if (audioContext) { audioContext.close(); audioContext = null; }
      isRecording = false;
      recordBtn.textContent = 'Record';
      recordBtn.classList.remove('recording');
    }

    recordBtn.addEventListener('click', () => {
      if (!isRecording) {
        startRecording();
      } else {
        stopRecording();
      }
    });

    copyBtn.addEventListener('click', () => {
      navigator.clipboard.writeText(transcriptEl.textContent).then(() => {
        const original = copyBtn.textContent;
        copyBtn.textContent = 'Copied';
        setTimeout(() => { copyBtn.textContent = original; }, 1500);
      });
    });
  </script>
</body>
</html>
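The bracketed plain-text format produced by `formatSegment()` is easy to reproduce server-side or in tests. Here is a direct Python port of the JS function above (same field names, same markup), useful as a sketch if you want the copy/paste-friendly format outside the browser:

```python
def format_segment(segment: dict) -> str:
    """Python port of the formatSegment() logic (same bracket markup)."""
    speaker = segment.get("speaker")
    text = segment.get("text") or ""
    buf = segment.get("buffer") or {}
    start = segment.get("start") or ""
    end = segment.get("end") or ""
    language = segment.get("language") or ""
    # Silence marker
    if speaker == -2:
        return f"[SILENCE {start} - {end}]\n"
    # Speaker header
    out = f"[SPEAKER {speaker}]"
    if start and end:
        out += f" {start} - {end}"
    if language:
        out += f" [LANG: {language}]"
    out += "\n"
    # Main text, then the two ephemeral buffers
    if text:
        out += text
    if buf.get("diarization"):
        out += f"[DIAR_BUFFER]{buf['diarization']}[/DIAR_BUFFER]"
    if buf.get("transcription"):
        out += f"[TRANS_BUFFER]{buf['transcription']}[/TRANS_BUFFER]"
    out += "\n"
    # Translation (with its own buffer nested inside)
    if segment.get("translation"):
        out += f"[TRANSLATION]{segment['translation']}"
        if buf.get("translation"):
            out += f"[TRANS_BUFFER]{buf['translation']}[/TRANS_BUFFER]"
        out += "[/TRANSLATION]\n"
    elif buf.get("translation"):
        out += f"[TRANSLATION][TRANS_BUFFER]{buf['translation']}[/TRANS_BUFFER][/TRANSLATION]\n"
    return out
```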


@@ -13,6 +13,37 @@ def get_web_interface_html():
         logger.error(f"Error loading web interface HTML: {e}")
         return "<html><body><h1>Error loading interface</h1></body></html>"

+def get_text_transcript_html():
+    """Loads the simple text-based transcript HTML for easy copy/paste."""
+    try:
+        with resources.files('whisperlivekit.web').joinpath('text_transcript.html').open('r', encoding='utf-8') as f:
+            html_content = f.read()
+        # Inline the worker scripts
+        with resources.files('whisperlivekit.web').joinpath('pcm_worklet.js').open('r', encoding='utf-8') as f:
+            worklet_code = f.read()
+        with resources.files('whisperlivekit.web').joinpath('recorder_worker.js').open('r', encoding='utf-8') as f:
+            worker_code = f.read()
+        html_content = html_content.replace(
+            "await audioContext.audioWorklet.addModule('/web/pcm_worklet.js');",
+            'const workletBlob = new Blob([`' + worklet_code + '`], { type: "application/javascript" });\n' +
+            'const workletUrl = URL.createObjectURL(workletBlob);\n' +
+            'await audioContext.audioWorklet.addModule(workletUrl);'
+        )
+        html_content = html_content.replace(
+            "recorderWorker = new Worker('/web/recorder_worker.js');",
+            'const workerBlob = new Blob([`' + worker_code + '`], { type: "application/javascript" });\n' +
+            'const workerUrl = URL.createObjectURL(workerBlob);\n' +
+            'recorderWorker = new Worker(workerUrl);'
+        )
+        return html_content
+    except Exception as e:
+        logger.error(f"Error loading text transcript HTML: {e}")
+        return "<html><body><h1>Error loading text interface</h1></body></html>"
+
 def get_inline_ui_html():
     """Returns the complete web interface HTML with all assets embedded in a single call."""
     try:


@@ -18,6 +18,8 @@ from whisperlivekit.whisper.decoding import (DecodingOptions, DecodingResult,
 from whisperlivekit.whisper.model import ModelDimensions, Whisper
 from whisperlivekit.whisper.transcribe import transcribe
 from whisperlivekit.whisper.version import __version__
+from whisperlivekit.whisper.lora import (LoRAAdapter, LoRAAdapterManager,
+                                         LoRAConfig, LoRALinear)

 _MODELS = {
     "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
@@ -108,7 +110,7 @@ def available_models() -> List[str]:
 def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
     """
     attempt to infer ModelDimensions from a HF style config.json located
-    next to the given checkpoint, useful for distilled models/MLX models.
+    next to the given checkpoint, useful for distilled models
     """
     candidates = []
     if os.path.isdir(path):
@@ -122,25 +124,6 @@ def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
             with open(candidate, "r", encoding="utf-8") as f:
                 config = json.load(f)
-            # native Whisper format
-            native_keys = ["n_mels", "n_audio_ctx", "n_audio_state", "n_audio_head",
-                           "n_audio_layer", "n_vocab", "n_text_ctx", "n_text_state",
-                           "n_text_head", "n_text_layer"]
-            if all(k in config for k in native_keys):
-                return ModelDimensions(
-                    n_mels=config["n_mels"],
-                    n_audio_ctx=config["n_audio_ctx"],
-                    n_audio_state=config["n_audio_state"],
-                    n_audio_head=config["n_audio_head"],
-                    n_audio_layer=config["n_audio_layer"],
-                    n_vocab=config["n_vocab"],
-                    n_text_ctx=config["n_text_ctx"],
-                    n_text_state=config["n_text_state"],
-                    n_text_head=config["n_text_head"],
-                    n_text_layer=config["n_text_layer"],
-                )
-            # HuggingFace format
             try:
                 return ModelDimensions(
                     n_mels=config["num_mel_bins"],
@@ -255,24 +238,6 @@ def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     return converted if converted else state_dict

-def _convert_mlx_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-    """
-    Converts an mlx whisper checkpoint to a default openai whisper one
-    """
-    if not any("mlp1" in k or "mlp2" in k for k in state_dict):
-        return state_dict
-    converted = {}
-    for key, value in state_dict.items():
-        if key == "alignment_heads":
-            continue
-        new_key = key.replace(".mlp1.", ".mlp.0.").replace(".mlp2.", ".mlp.2.")
-        converted[new_key] = value
-    return converted

 def _load_lora_state(lora_path: str):
     safe_path = os.path.join(lora_path, "adapter_model.safetensors")
     bin_path = os.path.join(lora_path, "adapter_model.bin")
@@ -557,12 +522,7 @@ def load_model(
         state_dict = checkpoint["model_state_dict"]
     else:
         state_dict = checkpoint
-    if alignment_heads is None and "alignment_heads" in state_dict:
-        alignment_heads = state_dict["alignment_heads"]
     state_dict = _convert_hf_state_dict(state_dict)
-    state_dict = _convert_mlx_state_dict(state_dict)
     _apply_lora_adapter(state_dict, lora_path)

     if dims_cfg is not None:
@@ -588,16 +548,99 @@ def load_model(
     model.load_state_dict(state_dict)
     if alignment_heads is not None:
-        if isinstance(alignment_heads, bytes):
-            model.set_alignment_heads(alignment_heads)
-        elif isinstance(alignment_heads, torch.Tensor): #for mlx whisper
-            mask = torch.zeros(dims.n_text_layer, dims.n_text_head, dtype=torch.bool)
-            for layer, head in alignment_heads.tolist():
-                mask[layer, head] = True
-            model.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
+        model.set_alignment_heads(alignment_heads)
     return model.to(device)

+def load_model_with_lora_manager(
+    name: str,
+    device: Optional[Union[str, torch.device]] = None,
+    download_root: str = None,
+    in_memory: bool = False,
+    decoder_only: bool = False,
+    custom_alignment_heads: Optional[str] = None,
+    adapters: Optional[Dict[str, str]] = None,
+) -> tuple:
+    """
+    Load a Whisper model with a LoRA adapter manager for dynamic adapter swapping.
+
+    This allows you to load multiple LoRA adapters and switch between them at runtime
+    without keeping multiple full models in memory.
+
+    Parameters
+    ----------
+    name : str
+        Model name or path (same as load_model)
+    device : Union[str, torch.device]
+        Device to load model on
+    download_root : str
+        Download directory for model files
+    in_memory : bool
+        Whether to preload model weights into host memory
+    decoder_only : bool
+        If True, only load the decoder (no encoder)
+    custom_alignment_heads : str
+        Custom alignment heads configuration
+    adapters : Dict[str, str]
+        Optional dict mapping adapter names to paths/HuggingFace repo IDs.
+        Example: {"french": "path/to/french-lora", "spanish": "user/spanish-whisper-lora"}
+
+    Returns
+    -------
+    model : Whisper
+        The base Whisper model (without any LoRA baked in)
+    manager : LoRAAdapterManager
+        The adapter manager for loading/switching adapters
+
+    Example
+    -------
+    >>> model, manager = load_model_with_lora_manager(
+    ...     "large-v3",
+    ...     adapters={
+    ...         "french": "path/to/french-lora",
+    ...         "spanish": "path/to/spanish-lora"
+    ...     }
+    ... )
+    >>>
+    >>> # Switch to French adapter
+    >>> manager.set_adapter("french")
+    >>> result_fr = model.transcribe(audio_fr)
+    >>>
+    >>> # Switch to Spanish adapter
+    >>> manager.set_adapter("spanish")
+    >>> result_es = model.transcribe(audio_es)
+    >>>
+    >>> # Use base model without LoRA
+    >>> manager.set_adapter(None)
+    >>> result_base = model.transcribe(audio)
+    >>>
+    >>> # Check memory usage
+    >>> print(manager.get_memory_usage())
+    {'french': 12.5, 'spanish': 12.5}  # MB per adapter
+    """
+    # Load the base model WITHOUT any LoRA baked in
+    model = load_model(
+        name=name,
+        device=device,
+        download_root=download_root,
+        in_memory=in_memory,
+        decoder_only=decoder_only,
+        custom_alignment_heads=custom_alignment_heads,
+        lora_path=None,  # Important: no baked-in LoRA
+    )
+    # Create the adapter manager
+    manager = LoRAAdapterManager(model)
+    # Load any provided adapters
+    if adapters:
+        for adapter_name, adapter_path in adapters.items():
+            manager.load_adapter(adapter_name, adapter_path)
+    return model, manager

 def convert_encoder_to_coreml(
     model_name = "base",
     output_path= "whisper_encoder.mlpackage",


@@ -0,0 +1,473 @@
"""
Dynamic LoRA adapter support for Whisper models.

This module enables loading a single base Whisper model and dynamically swapping
between multiple LoRA adapters at runtime, saving GPU memory when working with
multiple language-specific fine-tuned models.

Usage:
    from whisperlivekit.whisper import load_model
    from whisperlivekit.whisper.lora import LoRAAdapterManager

    # Load base model without any LoRA baked in
    model = load_model("large-v3", device="cuda")

    # Create adapter manager
    manager = LoRAAdapterManager(model)

    # Load multiple adapters (small memory footprint each)
    manager.load_adapter("french", "path/to/french-lora")
    manager.load_adapter("spanish", "path/to/spanish-lora")

    # Switch between adapters at runtime
    manager.set_adapter("french")
    result_fr = model.transcribe(audio_fr)

    manager.set_adapter("spanish")
    result_es = model.transcribe(audio_es)

    # Disable LoRA (use base model only)
    manager.set_adapter(None)
"""
import json
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor, nn

from .model import Linear
@dataclass
class LoRAConfig:
    """Configuration for a LoRA adapter."""
    r: int  # LoRA rank
    alpha: float  # LoRA alpha (scaling factor)
    target_modules: List[str] = field(default_factory=list)

    @property
    def scaling(self) -> float:
        return self.alpha / self.r


@dataclass
class LoRAAdapter:
    """Holds the LoRA A/B weight matrices for a single adapter."""
    name: str
    config: LoRAConfig
    # Maps target module name -> (A matrix, B matrix)
    weights: Dict[str, Tuple[Tensor, Tensor]] = field(default_factory=dict)
    device: torch.device = field(default_factory=lambda: torch.device("cpu"))
    dtype: torch.dtype = field(default=torch.float32)

    def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
        """Move adapter weights to specified device/dtype."""
        self.device = device
        if dtype is not None:
            self.dtype = dtype
        self.weights = {
            name: (a.to(device=device, dtype=dtype or self.dtype),
                   b.to(device=device, dtype=dtype or self.dtype))
            for name, (a, b) in self.weights.items()
        }
        return self

    def memory_footprint_mb(self) -> float:
        """Return approximate memory usage in MB."""
        total_bytes = 0
        for a, b in self.weights.values():
            total_bytes += a.numel() * a.element_size()
            total_bytes += b.numel() * b.element_size()
        return total_bytes / (1024 * 1024)
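To get a feel for `memory_footprint_mb`, the per-adapter cost is just the A/B matrix sizes summed over targeted layers. A rough back-of-the-envelope sketch (the dims below assume a large-style model with width 1280 and 32 decoder blocks; the rank and targeted modules are illustrative, not taken from this repo):

```python
def lora_pair_bytes(in_features: int, out_features: int, r: int, bytes_per_el: int = 4) -> int:
    # A: (in_features, r), B: (r, out_features), same layout as LoRAAdapter.weights (fp32)
    return (in_features * r + r * out_features) * bytes_per_el

# Rank-16 adapter on the 4 attention projections (q/k/v/out, each 1280x1280)
# of 32 decoder blocks:
per_layer = 4 * lora_pair_bytes(1280, 1280, 16)
total_mb = 32 * per_layer / (1024 * 1024)
```

Tens of megabytes per adapter versus several gigabytes for a second full model is the whole point of swapping adapters instead of models.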
class LoRALinear(nn.Module):
    """
    A Linear layer wrapper that supports dynamic LoRA injection.

    The base weights remain unchanged. LoRA is applied additively during forward:
        output = base_linear(x) + (x @ A @ B) * scaling
    """
    def __init__(self, base_linear: Linear):
        super().__init__()
        self.base_linear = base_linear
        self.lora_A: Optional[Tensor] = None
        self.lora_B: Optional[Tensor] = None
        self.scaling: float = 1.0
        self._lora_enabled: bool = False

    def set_lora(self, A: Optional[Tensor], B: Optional[Tensor], scaling: float = 1.0):
        """Set the LoRA matrices for this layer."""
        self.lora_A = A
        self.lora_B = B
        self.scaling = scaling
        self._lora_enabled = A is not None and B is not None

    def clear_lora(self):
        """Remove LoRA from this layer."""
        self.lora_A = None
        self.lora_B = None
        self._lora_enabled = False

    def forward(self, x: Tensor) -> Tensor:
        # Base linear output
        out = self.base_linear(x)
        # Add LoRA contribution if enabled
        if self._lora_enabled and self.lora_A is not None and self.lora_B is not None:
            # x: (..., in_features)
            # A: (in_features, r)
            # B: (r, out_features)
            # lora_out: (..., out_features)
            lora_out = (x @ self.lora_A.to(x.dtype)) @ self.lora_B.to(x.dtype)
            out = out + lora_out * self.scaling
        return out

    # Delegate attribute access to base_linear for compatibility
    @property
    def weight(self):
        return self.base_linear.weight

    @property
    def bias(self):
        return self.base_linear.bias

    @property
    def in_features(self):
        return self.base_linear.in_features

    @property
    def out_features(self):
        return self.base_linear.out_features
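The additive path in `LoRALinear.forward` is mathematically equivalent to merging the low-rank update into the base weight, which is why adapters can be swapped without touching the base parameters. A minimal NumPy check of that identity (NumPy stands in for torch here to keep the sketch dependency-free; shapes follow the comments in `forward`):

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 8, 6, 2, 4.0
scaling = alpha / r  # same formula as LoRAConfig.scaling

W = rng.standard_normal((d_out, d_in))  # base weight, torch-Linear layout (out, in)
A = rng.standard_normal((d_in, r))      # lora_A: (in_features, r)
B = rng.standard_normal((r, d_out))     # lora_B: (r, out_features)
x = rng.standard_normal((3, d_in))

# Dynamic path, as in LoRALinear.forward: base(x) + (x @ A @ B) * scaling
y_dynamic = x @ W.T + (x @ A) @ B * scaling
# Equivalent merged weight: W' = W + scaling * (A @ B).T
y_merged = x @ (W + scaling * (A @ B).T).T
```

`y_dynamic` and `y_merged` agree to floating-point precision; the dynamic form trades a little extra compute per forward pass for instant, reversible adapter switching.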
# Mapping from HuggingFace LoRA module names to Whisper module paths
_HF_TO_WHISPER_MODULE_MAP = {
    # Encoder attention
    "model.encoder.layers.{}.self_attn.q_proj": "encoder.blocks.{}.attn.query",
    "model.encoder.layers.{}.self_attn.k_proj": "encoder.blocks.{}.attn.key",
    "model.encoder.layers.{}.self_attn.v_proj": "encoder.blocks.{}.attn.value",
    "model.encoder.layers.{}.self_attn.out_proj": "encoder.blocks.{}.attn.out",
    # Encoder MLP
    "model.encoder.layers.{}.fc1": "encoder.blocks.{}.mlp.0",
    "model.encoder.layers.{}.fc2": "encoder.blocks.{}.mlp.2",
    # Decoder self-attention
    "model.decoder.layers.{}.self_attn.q_proj": "decoder.blocks.{}.attn.query",
    "model.decoder.layers.{}.self_attn.k_proj": "decoder.blocks.{}.attn.key",
    "model.decoder.layers.{}.self_attn.v_proj": "decoder.blocks.{}.attn.value",
    "model.decoder.layers.{}.self_attn.out_proj": "decoder.blocks.{}.attn.out",
    # Decoder cross-attention
    "model.decoder.layers.{}.encoder_attn.q_proj": "decoder.blocks.{}.cross_attn.query",
    "model.decoder.layers.{}.encoder_attn.k_proj": "decoder.blocks.{}.cross_attn.key",
    "model.decoder.layers.{}.encoder_attn.v_proj": "decoder.blocks.{}.cross_attn.value",
    "model.decoder.layers.{}.encoder_attn.out_proj": "decoder.blocks.{}.cross_attn.out",
    # Decoder MLP
    "model.decoder.layers.{}.fc1": "decoder.blocks.{}.mlp.0",
    "model.decoder.layers.{}.fc2": "decoder.blocks.{}.mlp.2",
}

def _normalize_hf_module_name(name: str) -> str:
    """Normalize HF-style LoRA module names."""
    if name.startswith("base_model."):
        name = name[len("base_model."):]
    if name.startswith("model.model."):
        name = name[len("model."):]
    if not name.startswith("model."):
        name = f"model.{name}"
    return name

def _map_hf_to_whisper_module(hf_name: str) -> Optional[str]:
    """Map a HuggingFace LoRA module name to Whisper module path."""
    hf_name = _normalize_hf_module_name(hf_name)
    # Try to match with layer index patterns
    import re
    # Match patterns like model.encoder.layers.5.self_attn.q_proj
    for pattern, target_pattern in _HF_TO_WHISPER_MODULE_MAP.items():
        # Create regex from pattern (replace {} with capture group)
        regex = pattern.replace(".", r"\.").replace("{}", r"(\d+)")
        match = re.fullmatch(regex, hf_name)
        if match:
            layer_idx = match.group(1)
            return target_pattern.format(layer_idx)
    return None
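The mapping mechanism above is worth seeing in isolation: each `{}` placeholder becomes a digit capture group, so one table entry covers every layer index. A self-contained sketch exercising two entries from the table exactly the way `_map_hf_to_whisper_module` does:

```python
import re

# Two entries copied from _HF_TO_WHISPER_MODULE_MAP
patterns = {
    "model.decoder.layers.{}.self_attn.q_proj": "decoder.blocks.{}.attn.query",
    "model.encoder.layers.{}.fc1": "encoder.blocks.{}.mlp.0",
}

def map_name(hf_name: str):
    for pattern, target in patterns.items():
        # Escape dots, then turn {} into a digit capture group
        regex = pattern.replace(".", r"\.").replace("{}", r"(\d+)")
        m = re.fullmatch(regex, hf_name)
        if m:
            return target.format(m.group(1))
    return None  # unmapped modules are skipped, as in the real function

assert map_name("model.decoder.layers.5.self_attn.q_proj") == "decoder.blocks.5.attn.query"
```

Escaping the dots before substituting the capture group matters: without it, `model.decoder...` would match unintended names, since `.` is a regex wildcard.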
def _get_module_by_path(model: nn.Module, path: str) -> Optional[nn.Module]:
"""Get a submodule by dot-separated path."""
parts = path.split(".")
current = model
for part in parts:
if hasattr(current, part):
current = getattr(current, part)
elif hasattr(current, "__getitem__"):
try:
current = current[int(part)]
except (ValueError, IndexError, KeyError):
return None
else:
return None
return current
def _set_module_by_path(model: nn.Module, path: str, module: nn.Module):
    """Set a submodule by dot-separated path."""
    parts = path.split(".")
    parent = model
    for part in parts[:-1]:
        if hasattr(parent, part):
            parent = getattr(parent, part)
        elif hasattr(parent, "__getitem__"):
            parent = parent[int(part)]
        else:
            # Fail loudly instead of silently setting the attribute on the
            # wrong parent module.
            raise AttributeError(f"Cannot resolve '{part}' while setting '{path}'")
    setattr(parent, parts[-1], module)
class LoRAAdapterManager:
"""
Manages multiple LoRA adapters for a Whisper model.
Enables loading multiple adapters and switching between them at runtime
without reloading the full model.
"""
def __init__(self, model: nn.Module):
"""
Initialize the adapter manager.
Args:
model: A Whisper model instance
"""
self.model = model
self.adapters: Dict[str, LoRAAdapter] = {}
self.current_adapter: Optional[str] = None
self._lora_layers: Dict[str, LoRALinear] = {}
self._original_layers: Dict[str, Linear] = {}
self._initialized = False
    def _initialize_lora_layers(self, target_modules: List[str]):
        """
        Replace target Linear layers with LoRALinear wrappers.
        Done lazily on adapter load. Already-wrapped modules are skipped, so
        loading a second adapter that targets additional modules wraps only
        the new ones.
        """
        # Find and wrap all potential LoRA target modules
        for whisper_path in target_modules:
            if whisper_path in self._lora_layers:
                continue
            module = _get_module_by_path(self.model, whisper_path)
            if module is None:
                continue
            if isinstance(module, Linear) and not isinstance(module, LoRALinear):
                # Wrap the Linear layer
                lora_linear = LoRALinear(module)
                _set_module_by_path(self.model, whisper_path, lora_linear)
                self._lora_layers[whisper_path] = lora_linear
                self._original_layers[whisper_path] = module
        self._initialized = True
def _resolve_lora_path(self, lora_path: str) -> str:
"""Resolve LoRA path, downloading from HuggingFace Hub if needed."""
if os.path.isdir(lora_path):
return lora_path
# Try HuggingFace Hub
if "/" in lora_path:
try:
from huggingface_hub import snapshot_download
return snapshot_download(
repo_id=lora_path,
allow_patterns=["adapter_config.json", "adapter_model.*"],
)
except Exception as e:
raise FileNotFoundError(
f"Could not find LoRA adapter at local path or HuggingFace Hub: {lora_path}. Error: {e}"
)
raise FileNotFoundError(f"LoRA path '{lora_path}' not found.")
def _load_adapter_weights(self, lora_path: str) -> Dict[str, Tensor]:
"""Load adapter weights from safetensors or bin file."""
safe_path = os.path.join(lora_path, "adapter_model.safetensors")
bin_path = os.path.join(lora_path, "adapter_model.bin")
if os.path.isfile(safe_path):
from safetensors.torch import load_file
return load_file(safe_path)
elif os.path.isfile(bin_path):
return torch.load(bin_path, map_location="cpu")
else:
raise FileNotFoundError(
f"No adapter weights found in {lora_path}. "
"Expected adapter_model.safetensors or adapter_model.bin."
)
def load_adapter(
self,
name: str,
lora_path: str,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> LoRAAdapter:
"""
Load a LoRA adapter from disk or HuggingFace Hub.
Args:
name: Unique name for this adapter (e.g., "french", "spanish")
lora_path: Local path or HuggingFace repo ID
device: Device to load weights to (default: model's device)
dtype: Data type for weights (default: model's dtype)
Returns:
The loaded LoRAAdapter
"""
if device is None:
device = next(self.model.parameters()).device
if dtype is None:
dtype = next(self.model.parameters()).dtype
# Resolve path
lora_path = self._resolve_lora_path(lora_path)
# Load config
config_path = os.path.join(lora_path, "adapter_config.json")
if not os.path.isfile(config_path):
raise FileNotFoundError(f"Missing adapter_config.json in {lora_path}")
with open(config_path, "r", encoding="utf-8") as f:
config_dict = json.load(f)
if config_dict.get("peft_type") != "LORA":
raise ValueError("Only LoRA adapters are supported.")
config = LoRAConfig(
r=config_dict["r"],
alpha=config_dict.get("lora_alpha") or config_dict.get("alpha"),
target_modules=config_dict.get("target_modules", []),
)
# Load weights
adapter_state = self._load_adapter_weights(lora_path)
# Parse LoRA A/B matrices and map to Whisper module paths
lora_layers: Dict[str, Dict[str, Tensor]] = {}
for key, tensor in adapter_state.items():
if key.endswith("lora_A.weight"):
module = key[:-len(".lora_A.weight")]
lora_layers.setdefault(module, {})["A"] = tensor
elif key.endswith("lora_B.weight"):
module = key[:-len(".lora_B.weight")]
lora_layers.setdefault(module, {})["B"] = tensor
# Map to Whisper module paths and collect weights
weights: Dict[str, Tuple[Tensor, Tensor]] = {}
whisper_paths = set()
for hf_module, parts in lora_layers.items():
if "A" not in parts or "B" not in parts:
continue
whisper_path = _map_hf_to_whisper_module(hf_module)
if whisper_path is None:
# Try direct mapping (module might already be in Whisper format)
whisper_path = hf_module
# A: (r, in_features) -> transpose to (in_features, r)
# B: (out_features, r) -> transpose to (r, out_features)
A = parts["A"].T # (in_features, r)
B = parts["B"].T # (r, out_features)
weights[whisper_path] = (A, B)
whisper_paths.add(whisper_path)
# Create adapter
adapter = LoRAAdapter(
name=name,
config=config,
weights=weights,
device=device,
dtype=dtype,
)
adapter.to(device, dtype)
# Initialize LoRA layers if not done yet
self._initialize_lora_layers(list(whisper_paths))
# Store adapter
self.adapters[name] = adapter
return adapter
def set_adapter(self, name: Optional[str]):
"""
Switch to a different adapter or disable LoRA.
Args:
name: Adapter name to activate, or None to disable all LoRA
"""
if name is not None and name not in self.adapters:
raise KeyError(f"Adapter '{name}' not loaded. Available: {list(self.adapters.keys())}")
# Clear all LoRA from layers
for lora_linear in self._lora_layers.values():
lora_linear.clear_lora()
self.current_adapter = name
if name is None:
return
# Apply the selected adapter
adapter = self.adapters[name]
for module_path, (A, B) in adapter.weights.items():
if module_path in self._lora_layers:
self._lora_layers[module_path].set_lora(A, B, adapter.config.scaling)
def unload_adapter(self, name: str):
"""
Unload an adapter from memory.
Args:
name: Name of adapter to unload
"""
if name not in self.adapters:
return
if self.current_adapter == name:
self.set_adapter(None)
del self.adapters[name]
def list_adapters(self) -> List[str]:
"""Return list of loaded adapter names."""
return list(self.adapters.keys())
def get_memory_usage(self) -> Dict[str, float]:
"""Return memory usage in MB for each loaded adapter."""
return {name: adapter.memory_footprint_mb() for name, adapter in self.adapters.items()}
def restore_original_layers(self):
"""
Restore the original Linear layers, removing LoRA wrappers.
Call this if you want to go back to the original model structure.
"""
for path, original in self._original_layers.items():
_set_module_by_path(self.model, path, original)
self._lora_layers.clear()
self._original_layers.clear()
self._initialized = False
self.current_adapter = None
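# --- Usage sketch (illustrative comments only, not executed) ---
# Assumes a Whisper `model` instance already loaded elsewhere in this package;
# the adapter name and path below are hypothetical.
#
#   manager = LoRAAdapterManager(model)
#   manager.load_adapter("french", "path/to/french-lora")  # or an HF repo id
#   manager.set_adapter("french")        # activate the adapter's weights
#   manager.list_adapters()              # -> ["french"]
#   manager.set_adapter(None)            # fall back to base weights
#   manager.restore_original_layers()    # remove the LoRALinear wrappers entirely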