mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-08 06:44:09 +00:00
Compare commits
38 Commits
0.2.17.pos
...
api_live
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d9a4c8dcb2 | ||
|
|
4fb735a784 | ||
|
|
d2f998cb7e | ||
|
|
7b18917f2b | ||
|
|
f1113e3eb0 | ||
|
|
cc5f819ce7 | ||
|
|
82cd24bb75 | ||
|
|
d45c397c6a | ||
|
|
45bf3f57d7 | ||
|
|
1d88ba9d69 | ||
|
|
c0965c6c31 | ||
|
|
34ddd2ac02 | ||
|
|
345d781e97 | ||
|
|
28cf831701 | ||
|
|
60c62f8f84 | ||
|
|
7faa21f95f | ||
|
|
4e9f951551 | ||
|
|
870141298c | ||
|
|
872faa422a | ||
|
|
fc9cb66813 | ||
|
|
a175d1a327 | ||
|
|
6206fff118 | ||
|
|
b5067249c0 | ||
|
|
f4f9831d39 | ||
|
|
254faaf64c | ||
|
|
8e7aea4fcf | ||
|
|
270faf2069 | ||
|
|
b7c1cc77cc | ||
|
|
9a45ec221c | ||
|
|
3e13ee6fc3 | ||
|
|
b7d20a0ff0 | ||
|
|
c1bb9c2bde | ||
|
|
11e9def0b2 | ||
|
|
3104f40f6e | ||
|
|
e9b4ceeee5 | ||
|
|
437641fb43 | ||
|
|
bfd60b3921 | ||
|
|
1e67bf97f0 |
@@ -37,10 +37,9 @@ RUN pip3 install --upgrade pip setuptools wheel && \
|
|||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
# Install WhisperLiveKit directly, allowing for optional dependencies
|
# Install WhisperLiveKit directly, allowing for optional dependencies
|
||||||
# Example: --build-arg EXTRAS="translation"
|
|
||||||
RUN if [ -n "$EXTRAS" ]; then \
|
RUN if [ -n "$EXTRAS" ]; then \
|
||||||
echo "Installing with extras: [$EXTRAS]"; \
|
echo "Installing with extras: [$EXTRAS]"; \
|
||||||
pip install --no-cache-dir "whisperlivekit[$EXTRAS]"; \
|
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
|
||||||
else \
|
else \
|
||||||
echo "Installing base package only"; \
|
echo "Installing base package only"; \
|
||||||
pip install --no-cache-dir whisperlivekit; \
|
pip install --no-cache-dir whisperlivekit; \
|
||||||
|
|||||||
@@ -147,8 +147,8 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `small` |
|
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `small` |
|
||||||
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `None` |
|
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `None` |
|
||||||
| `--language` | List [here](docs/supported_languages.md). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
||||||
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
|
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
|
||||||
| `--diarization` | Enable speaker identification | `False` |
|
| `--diarization` | Enable speaker identification | `False` |
|
||||||
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
|
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
|
||||||
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
|
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
|
||||||
@@ -267,7 +267,7 @@ docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
|
|||||||
#### Customization
|
#### Customization
|
||||||
|
|
||||||
- `--build-arg` Options:
|
- `--build-arg` Options:
|
||||||
- `EXTRAS="translation"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
|
- `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
|
||||||
- `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start
|
- `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start
|
||||||
- `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models
|
- `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ Capture the audio of your current tab, transcribe diarize and translate it using
|
|||||||
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">
|
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">
|
||||||
|
|
||||||
## Running this extension
|
## Running this extension
|
||||||
1. Run `python scripts/sync_extension.py` to copy frontend files to the `chrome-extension` directory.
|
1. Run `python sync_extension.py` to copy frontend files to the `chrome-extension` directory.
|
||||||
2. Load the `chrome-extension` directory in Chrome as an unpacked extension.
|
2. Load the `chrome-extension` directory in Chrome as an unpacked extension.
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
299
docs/API.md
299
docs/API.md
@@ -1,53 +1,22 @@
|
|||||||
# WhisperLiveKit WebSocket API Documentation
|
# WhisperLiveKit WebSocket API Documentation
|
||||||
|
|
||||||
> !! **Note**: The new API structure described in this document is currently under deployment.
|
WLK provides real-time speech transcription, speaker diarization, and translation through a WebSocket API. The server sends updates as audio is processed, allowing clients to display live transcription results with minimal latency.
|
||||||
This documentation is intended for devs who want to build custom frontends.
|
|
||||||
|
|
||||||
WLK provides real-time speech transcription, speaker diarization, and translation through a WebSocket API. The server sends incremental updates as audio is processed, allowing clients to display live transcription results with minimal latency.
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Legacy API (Current)
|
## Endpoints
|
||||||
|
|
||||||
### Message Structure
|
| Endpoint | Description |
|
||||||
|
|----------|-------------|
|
||||||
The current API sends complete state snapshots on each update (several time per second)
|
| `/` | Main web interface with visual styling |
|
||||||
|
| `/text` | Simple text-based interface for easy copy/paste (debug/development) |
|
||||||
```typescript
|
| `/asr` | WebSocket endpoint for audio streaming |
|
||||||
{
|
|
||||||
"type": str,
|
|
||||||
"status": str,
|
|
||||||
"lines": [
|
|
||||||
{
|
|
||||||
"speaker": int,
|
|
||||||
"text": str,
|
|
||||||
"start": float,
|
|
||||||
"end": float,
|
|
||||||
"translation": str | null,
|
|
||||||
"detected_language": str
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"buffer_transcription": str,
|
|
||||||
"buffer_diarization": str,
|
|
||||||
"remaining_time_transcription": float,
|
|
||||||
"remaining_time_diarization": float
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## New API (Under Development)
|
|
||||||
|
|
||||||
### Philosophy
|
|
||||||
|
|
||||||
Principles:
|
|
||||||
|
|
||||||
- **Incremental Updates**: Only updates and new segments are sent
|
|
||||||
- **Ephemeral Buffers**: Temporary, unvalidated data displayed in real-time but overwritten on next update, at speaker level
|
|
||||||
|
|
||||||
|
|
||||||
## Message Format
|
## Message Format
|
||||||
|
|
||||||
|
### Transcript Update (Server → Client)
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
{
|
{
|
||||||
@@ -58,22 +27,11 @@ Principles:
|
|||||||
"id": number,
|
"id": number,
|
||||||
"speaker": number,
|
"speaker": number,
|
||||||
"text": string,
|
"text": string,
|
||||||
"start_speaker": float,
|
"start_speaker": string, // HH:MM:SS format
|
||||||
"start": float,
|
"start": string, // HH:MM:SS format
|
||||||
"end": float,
|
"end": string, // HH:MM:SS format
|
||||||
"language": string | null,
|
"language": string | null,
|
||||||
"translation": string,
|
"translation": string,
|
||||||
"words": [
|
|
||||||
{
|
|
||||||
"text": string,
|
|
||||||
"start": float,
|
|
||||||
"end": float,
|
|
||||||
"validated": {
|
|
||||||
"text": boolean,
|
|
||||||
"speaker": boolean,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"buffer": {
|
"buffer": {
|
||||||
"transcription": string,
|
"transcription": string,
|
||||||
"diarization": string,
|
"diarization": string,
|
||||||
@@ -94,9 +52,10 @@ Principles:
|
|||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"type": "config",
|
"type": "config",
|
||||||
"useAudioWorklet": true / false
|
"useAudioWorklet": true
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
- `useAudioWorklet`: If `true`, client should use AudioWorklet for PCM streaming. If `false`, use MediaRecorder for WebM.
|
||||||
|
|
||||||
#### Ready to Stop Message (sent after processing complete)
|
#### Ready to Stop Message (sent after processing complete)
|
||||||
```json
|
```json
|
||||||
@@ -104,6 +63,7 @@ Principles:
|
|||||||
"type": "ready_to_stop"
|
"type": "ready_to_stop"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
Indicates all audio has been processed and the client can safely close the connection.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -113,152 +73,179 @@ Principles:
|
|||||||
|
|
||||||
| Field | Type | Description |
|
| Field | Type | Description |
|
||||||
|-------|------|-------------|
|
|-------|------|-------------|
|
||||||
| `id` | `number` | Unique identifier for this segment. Used by clients to update specific segments efficiently. |
|
| `id` | `number` | Unique identifier for this segment. |
|
||||||
| `speaker` | `number` | Speaker ID (1, 2, 3...). Special value `-2` indicates silence. |
|
| `speaker` | `number` | Speaker ID (1, 2, 3...). Special value `-2` indicates silence. |
|
||||||
| `text` | `string` | Validated transcription text for this update. Should be **appended** to the segment's text on the client side. |
|
| `text` | `string` | Validated transcription text. |
|
||||||
| `start_speaker` | `float` | Timestamp (seconds) when this speaker segment began. |
|
| `start_speaker` | `string` | Timestamp (HH:MM:SS) when this speaker segment began. |
|
||||||
| `start` | `float` | Timestamp (seconds) of the first word in this update. |
|
| `start` | `string` | Timestamp (HH:MM:SS) of the first word. |
|
||||||
| `end` | `float` | Timestamp (seconds) of the last word in this update. |
|
| `end` | `string` | Timestamp (HH:MM:SS) of the last word. |
|
||||||
| `language` | `string \| null` | ISO language code (e.g., "en", "fr"). `null` until language is detected. |
|
| `language` | `string \| null` | ISO language code (e.g., "en", "fr"). `null` until detected. |
|
||||||
| `translation` | `string` | Validated translation text for this update. Should be **appended** to the segment's translation on the client side. |
|
| `translation` | `string` | Validated translation text. |
|
||||||
| `words` | `Array` | Array of word-level objects with timing and validation information. |
|
| `buffer` | `Object` | Per-segment temporary buffers (see below). |
|
||||||
| `buffer` | `Object` | Per-segment temporary buffers, see below |
|
|
||||||
|
|
||||||
### Word Object
|
|
||||||
|
|
||||||
| Field | Type | Description |
|
|
||||||
|-------|------|-------------|
|
|
||||||
| `text` | `string` | The word text. |
|
|
||||||
| `start` | `number` | Start timestamp (seconds) of this word. |
|
|
||||||
| `end` | `number` | End timestamp (seconds) of this word. |
|
|
||||||
| `validated.text` | `boolean` | Whether the transcription text has been validated. if false, word is also in buffer: transcription |
|
|
||||||
| `validated.speaker` | `boolean` | Whether the speaker assignment has been validated. if false, word is also in buffer: diarization |
|
|
||||||
| `validated.language` | `boolean` | Whether the language detection has been validated. if false, word is also in buffer: translation |
|
|
||||||
|
|
||||||
### Buffer Object (Per-Segment)
|
### Buffer Object (Per-Segment)
|
||||||
|
|
||||||
Buffers are **ephemeral**. They should be displayed to the user but not stored permanently in the frontend. Each update may contain a completely different buffer value, and previous buffer is likely to be in the next validated text.
|
Buffers are **ephemeral**. They should be displayed to the user but are overwritten on each update. Only the **last non-silent segment** contains buffer content.
|
||||||
|
|
||||||
| Field | Type | Description |
|
| Field | Type | Description |
|
||||||
|-------|------|-------------|
|
|-------|------|-------------|
|
||||||
| `transcription` | `string` | Pending transcription text. Displayed immediately but **overwritten** on next update. |
|
| `transcription` | `string` | Text pending validation (waiting for more context). |
|
||||||
| `diarization` | `string` | Pending diarization text (text waiting for speaker assignment). Displayed immediately but **overwritten** on next update. |
|
| `diarization` | `string` | Text pending speaker assignment (diarization hasn't caught up). |
|
||||||
| `translation` | `string` | Pending translation text. Displayed immediately but **overwritten** on next update. |
|
| `translation` | `string` | Translation pending validation. |
|
||||||
|
|
||||||
|
|
||||||
### Metadata Fields
|
### Metadata Fields
|
||||||
|
|
||||||
| Field | Type | Description |
|
| Field | Type | Description |
|
||||||
|-------|------|-------------|
|
|-------|------|-------------|
|
||||||
| `remaining_time_transcription` | `float` | Seconds of audio waiting for transcription processing. |
|
| `remaining_time_transcription` | `float` | Seconds of audio waiting for transcription. |
|
||||||
| `remaining_time_diarization` | `float` | Seconds of audio waiting for speaker diarization. |
|
| `remaining_time_diarization` | `float` | Seconds of audio waiting for diarization. |
|
||||||
|
|
||||||
### Status Values
|
### Status Values
|
||||||
|
|
||||||
| Status | Description |
|
| Status | Description |
|
||||||
|--------|-------------|
|
|--------|-------------|
|
||||||
| `active_transcription` | Normal operation, transcription is active. |
|
| `active_transcription` | Normal operation, transcription is active. |
|
||||||
| `no_audio_detected` | No audio has been detected yet. |
|
| `no_audio_detected` | No audio/speech has been detected yet. |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Update Behavior
|
## Behavior Notes
|
||||||
|
|
||||||
### Incremental Updates
|
### Silence Handling
|
||||||
|
|
||||||
The API sends **only changed or new segments**. Clients should:
|
- **Short silences (< 2 seconds)** are filtered out and not displayed.
|
||||||
|
- Only significant pauses appear as silence segments with `speaker: -2`.
|
||||||
|
- Consecutive same-speaker segments are merged even across short silences.
|
||||||
|
|
||||||
1. Maintain a local map of segments by ID
|
### Update Frequency
|
||||||
2. When receiving an update, merge/update segments by ID
|
|
||||||
3. Render only the changed segments
|
|
||||||
|
|
||||||
### Language Detection
|
- **Active transcription**: ~20 updates/second (every 50ms)
|
||||||
|
- **During silence**: ~2 updates/second (every 500ms) to reduce bandwidth
|
||||||
|
|
||||||
When language is detected for a segment:
|
### Token-by-Token Validation (Diarization Mode)
|
||||||
|
|
||||||
```jsonc
|
When diarization is enabled, text is validated **token-by-token** as soon as diarization covers each token, rather than waiting for punctuation. This provides:
|
||||||
// Update 1: No language yet
|
- Faster text validation
|
||||||
{
|
- More responsive speaker attribution
|
||||||
"segments": [
|
- Buffer only contains tokens that diarization hasn't processed yet
|
||||||
{"id": 1, "speaker": 1, "text": "May see", "language": null}
|
|
||||||
]
|
---
|
||||||
}
|
|
||||||
|
## Example Messages
|
||||||
// Update 2: Same segment ID, language now detected
|
|
||||||
{
|
### Normal Transcription
|
||||||
"segments": [
|
|
||||||
{"id": 1, "speaker": 1, "text": "Merci", "language": "fr"}
|
```json
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**Client behavior**: **Replace** the existing segment with the same ID.
|
|
||||||
|
|
||||||
### Buffer Behavior
|
|
||||||
|
|
||||||
Buffers are **per-segment** to handle multi-speaker scenarios correctly.
|
|
||||||
|
|
||||||
#### Example: Translation with diarization and translation
|
|
||||||
|
|
||||||
```jsonc
|
|
||||||
// Update 1
|
|
||||||
{
|
{
|
||||||
|
"type": "transcript_update",
|
||||||
|
"status": "active_transcription",
|
||||||
"segments": [
|
"segments": [
|
||||||
{
|
{
|
||||||
"id": 1,
|
"id": 1,
|
||||||
"speaker": 1,
|
"speaker": 1,
|
||||||
"text": "Hello world, how are",
|
"text": "Hello, how are you today?",
|
||||||
|
"start_speaker": "0:00:02",
|
||||||
|
"start": "0:00:02",
|
||||||
|
"end": "0:00:05",
|
||||||
|
"language": "en",
|
||||||
|
"translation": "",
|
||||||
|
"buffer": {
|
||||||
|
"transcription": " I'm doing",
|
||||||
|
"diarization": "",
|
||||||
|
"translation": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"remaining_time_transcription": 0.5,
|
||||||
|
"remaining_time_diarization": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### With Diarization Buffer
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "transcript_update",
|
||||||
|
"status": "active_transcription",
|
||||||
|
"segments": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"speaker": 1,
|
||||||
|
"text": "The meeting starts at nine.",
|
||||||
|
"start_speaker": "0:00:03",
|
||||||
|
"start": "0:00:03",
|
||||||
|
"end": "0:00:06",
|
||||||
|
"language": "en",
|
||||||
"translation": "",
|
"translation": "",
|
||||||
"buffer": {
|
"buffer": {
|
||||||
"transcription": "",
|
"transcription": "",
|
||||||
"diarization": " you on",
|
"diarization": " Let me check my calendar",
|
||||||
"translation": "Bonjour le monde"
|
"translation": ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"metadata": {
|
||||||
|
"remaining_time_transcription": 0.3,
|
||||||
|
"remaining_time_diarization": 2.1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// ==== Frontend ====
|
|
||||||
// <SPEAKER>1</SPEAKER>
|
|
||||||
// <TRANSCRIPTION>Hello world, how are <DIARIZATION BUFFER> you on</DIARIZATION BUFFER></TRANSCRIPTION>
|
|
||||||
// <TRANSLATION><TRANSLATION BUFFER>Bonjour le monde</TRANSLATION BUFFER></TRANSLATION>
|
|
||||||
|
|
||||||
|
|
||||||
// Update 2
|
|
||||||
{
|
|
||||||
"segments": [
|
|
||||||
{
|
|
||||||
"id": 1,
|
|
||||||
"speaker": 1,
|
|
||||||
"text": " you on this",
|
|
||||||
"translation": "Bonjour tout le monde",
|
|
||||||
"buffer": {
|
|
||||||
"transcription": "",
|
|
||||||
"diarization": " beautiful day",
|
|
||||||
"translation": ",comment"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// ==== Frontend ====
|
|
||||||
// <SPEAKER>1</SPEAKER>
|
|
||||||
// <TRANSCRIPTION>Hello world, how are you on this<DIARIZATION BUFFER> beautiful day</DIARIZATION BUFFER></TRANSCRIPTION>
|
|
||||||
// <TRANSLATION>Bonjour tout le monde<TRANSLATION BUFFER>, comment</TRANSLATION BUFFER><TRANSLATION>
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Silence Segments
|
### Silence Segment
|
||||||
|
|
||||||
Silence is represented with the speaker id = `-2`:
|
```json
|
||||||
|
|
||||||
```jsonc
|
|
||||||
{
|
{
|
||||||
"id": 5,
|
"id": 5,
|
||||||
"speaker": -2,
|
"speaker": -2,
|
||||||
"text": "",
|
"text": "",
|
||||||
"start": 10.5,
|
"start_speaker": "0:00:10",
|
||||||
"end": 12.3
|
"start": "0:00:10",
|
||||||
|
"end": "0:00:15",
|
||||||
|
"language": null,
|
||||||
|
"translation": "",
|
||||||
|
"buffer": {
|
||||||
|
"transcription": "",
|
||||||
|
"diarization": "",
|
||||||
|
"translation": ""
|
||||||
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Text Transcript Endpoint (`/text`)
|
||||||
|
|
||||||
|
The `/text` endpoint provides a simple, monospace text interface designed for:
|
||||||
|
- Easy copy/paste of transcripts
|
||||||
|
- Debugging and development
|
||||||
|
- Integration testing
|
||||||
|
|
||||||
|
Output uses text markers instead of HTML styling:
|
||||||
|
|
||||||
|
```
|
||||||
|
[METADATA transcription_lag=0.5s diarization_lag=1.2s]
|
||||||
|
|
||||||
|
[SPEAKER 1] 0:00:03 - 0:00:11 [LANG: en]
|
||||||
|
Hello world, how are you doing today?[DIAR_BUFFER] I'm doing fine[/DIAR_BUFFER]
|
||||||
|
|
||||||
|
[SILENCE 0:00:15 - 0:00:18]
|
||||||
|
|
||||||
|
[SPEAKER 2] 0:00:18 - 0:00:22 [LANG: en]
|
||||||
|
That's great to hear!
|
||||||
|
[TRANSLATION]C'est super à entendre![/TRANSLATION]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Markers
|
||||||
|
|
||||||
|
| Marker | Description |
|
||||||
|
|--------|-------------|
|
||||||
|
| `[SPEAKER N]` | Speaker label with ID |
|
||||||
|
| `[SILENCE start - end]` | Silence segment |
|
||||||
|
| `[LANG: xx]` | Detected language code |
|
||||||
|
| `[DIAR_BUFFER]...[/DIAR_BUFFER]` | Text pending speaker assignment |
|
||||||
|
| `[TRANS_BUFFER]...[/TRANS_BUFFER]` | Text pending validation |
|
||||||
|
| `[TRANSLATION]...[/TRANSLATION]` | Translation content |
|
||||||
|
| `[METADATA ...]` | Lag/timing information |
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,73 @@
|
|||||||
### Alignment between STT Tokens and Diarization Segments
|
# Alignment Principles
|
||||||
|
|
||||||
- Example 1: The punctuation from STT and the speaker change from Diariation come in the prediction `t`
|
This document explains how transcription tokens are aligned with diarization (speaker identification) segments.
|
||||||
- Example 2: The punctuation from STT comes from prediction `t`, but the speaker change from Diariation come in the prediction `t-1`
|
|
||||||
- Example 3: The punctuation from STT comes from prediction `t-1`, but the speaker change from Diariation come in the prediction `t`
|
|
||||||
|
|
||||||
> `#` Is the split between the `t-1` prediction and `t` prediction.
|
---
|
||||||
|
|
||||||
|
## Token-by-Token Validation
|
||||||
|
|
||||||
|
When diarization is enabled, text is validated **token-by-token** rather than waiting for sentence boundaries. As soon as diarization covers a token's time range, that token is validated and assigned to the appropriate speaker.
|
||||||
|
|
||||||
|
### How It Works
|
||||||
|
|
||||||
|
1. **Transcription produces tokens** with timestamps (start, end)
|
||||||
|
2. **Diarization produces speaker segments** with timestamps
|
||||||
|
3. **For each token**: Check if diarization has caught up to that token's time
|
||||||
|
- If yes → Find speaker with maximum overlap, validate token
|
||||||
|
- If no → Keep token in "pending" (becomes diarization buffer)
|
||||||
|
|
||||||
|
```
|
||||||
|
Timeline: 0s -------- 5s -------- 10s -------- 15s
|
||||||
|
| | | |
|
||||||
|
Transcription: [Hello, how are you doing today?]
|
||||||
|
|_______|___|____|_____|_____|_____|
|
||||||
|
tok1 tok2 tok3 tok4 tok5 tok6
|
||||||
|
|
||||||
|
Diarization: [SPEAKER 1 ][SPEAKER 2 ]
|
||||||
|
|__________________|__________________|
|
||||||
|
0s 8s 15s
|
||||||
|
|
||||||
|
At time t when diarization covers up to 8s:
|
||||||
|
- Tokens 1-4 (0s-7s) → Validated as SPEAKER 1
|
||||||
|
- Tokens 5-6 (7s-10s) → In buffer (diarization hasn't caught up)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Silence Handling
|
||||||
|
|
||||||
|
- **Short silences (< 2 seconds)**: Filtered out, not displayed
|
||||||
|
- **Significant silences (≥ 2 seconds)**: Displayed as silence segments with `speaker: -2`
|
||||||
|
- **Same speaker across gaps**: Segments are merged even if separated by short silences
|
||||||
|
|
||||||
|
```
|
||||||
|
Before filtering:
|
||||||
|
[SPK1 0:00-0:03] [SILENCE 0:03-0:04] [SPK1 0:04-0:08]
|
||||||
|
|
||||||
|
After filtering (silence < 2s):
|
||||||
|
[SPK1 0:00-0:08] ← Merged into single segment
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Buffer Types
|
||||||
|
|
||||||
|
| Buffer | Contains | Displayed When |
|
||||||
|
|--------|----------|----------------|
|
||||||
|
| `transcription` | Text awaiting validation (more context needed) | Always on last segment |
|
||||||
|
| `diarization` | Text awaiting speaker assignment | When diarization lags behind transcription |
|
||||||
|
| `translation` | Translation awaiting validation | When translation is enabled |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Legacy: Punctuation-Based Splitting
|
||||||
|
|
||||||
|
The previous approach split segments at punctuation marks and aligned with diarization at those boundaries. This is now replaced by token-by-token validation for faster, more responsive results.
|
||||||
|
|
||||||
|
### Historical Examples (for reference)
|
||||||
|
|
||||||
|
Example of punctuation-based alignment:
|
||||||
|
|
||||||
## Example 1:
|
|
||||||
```text
|
```text
|
||||||
punctuations_segments : __#_______.__________________!____
|
punctuations_segments : __#_______.__________________!____
|
||||||
diarization_segments:
|
diarization_segments:
|
||||||
@@ -16,56 +76,6 @@ SPK2 # ___________________
|
|||||||
-->
|
-->
|
||||||
ALIGNED SPK1 __#_______.
|
ALIGNED SPK1 __#_______.
|
||||||
ALIGNED SPK2 # __________________!____
|
ALIGNED SPK2 # __________________!____
|
||||||
|
|
||||||
t-1 output:
|
|
||||||
SPK1: __#
|
|
||||||
SPK2: NO
|
|
||||||
DIARIZATION BUFFER: NO
|
|
||||||
|
|
||||||
t output:
|
|
||||||
SPK1: __#__.
|
|
||||||
SPK2: __________________!____
|
|
||||||
DIARIZATION BUFFER: No
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Example 2:
|
With token-by-token validation, the alignment happens continuously rather than at punctuation boundaries.
|
||||||
```text
|
|
||||||
punctuations_segments : _____#__.___________
|
|
||||||
diarization_segments:
|
|
||||||
SPK1 ___ #
|
|
||||||
SPK2 __#______________
|
|
||||||
-->
|
|
||||||
ALIGNED SPK1 _____#__.
|
|
||||||
ALIGNED SPK2 # ___________
|
|
||||||
|
|
||||||
t-1 output:
|
|
||||||
SPK1: ___ #
|
|
||||||
SPK2:
|
|
||||||
DIARIZATION BUFFER: __#
|
|
||||||
|
|
||||||
t output:
|
|
||||||
SPK1: __#__.
|
|
||||||
SPK2: ___________
|
|
||||||
DIARIZATION BUFFER: No
|
|
||||||
```
|
|
||||||
|
|
||||||
## Example 3:
|
|
||||||
```text
|
|
||||||
punctuations_segments : ___.__#__________
|
|
||||||
diarization_segments:
|
|
||||||
SPK1 ______#__
|
|
||||||
SPK2 # ________
|
|
||||||
-->
|
|
||||||
ALIGNED SPK1 ___. #
|
|
||||||
ALIGNED SPK2 __#__________
|
|
||||||
|
|
||||||
t-1 output:
|
|
||||||
SPK1: ___. #
|
|
||||||
SPK2:
|
|
||||||
DIARIZATION BUFFER: __#
|
|
||||||
|
|
||||||
t output:
|
|
||||||
SPK1: #
|
|
||||||
SPK2: __#___________
|
|
||||||
DIARIZATION BUFFER: NO
|
|
||||||
```
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "whisperlivekit"
|
name = "whisperlivekit"
|
||||||
version = "0.2.17.post1"
|
version = "0.2.16.dev0"
|
||||||
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [
|
authors = [
|
||||||
@@ -35,7 +35,6 @@ dependencies = [
|
|||||||
"torchaudio>=2.0.0",
|
"torchaudio>=2.0.0",
|
||||||
"torch>=2.0.0",
|
"torch>=2.0.0",
|
||||||
"huggingface-hub>=0.25.0",
|
"huggingface-hub>=0.25.0",
|
||||||
"faster-whisper>=1.2.0",
|
|
||||||
"tqdm",
|
"tqdm",
|
||||||
"tiktoken",
|
"tiktoken",
|
||||||
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
|
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
|
||||||
@@ -57,7 +56,6 @@ packages = [
|
|||||||
"whisperlivekit",
|
"whisperlivekit",
|
||||||
"whisperlivekit.diarization",
|
"whisperlivekit.diarization",
|
||||||
"whisperlivekit.simul_whisper",
|
"whisperlivekit.simul_whisper",
|
||||||
"whisperlivekit.simul_whisper.mlx",
|
|
||||||
"whisperlivekit.whisper",
|
"whisperlivekit.whisper",
|
||||||
"whisperlivekit.whisper.assets",
|
"whisperlivekit.whisper.assets",
|
||||||
"whisperlivekit.whisper.normalizers",
|
"whisperlivekit.whisper.normalizers",
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from .audio_processor import AudioProcessor
|
from .audio_processor import AudioProcessor
|
||||||
from .core import TranscriptionEngine
|
from .core import TranscriptionEngine
|
||||||
from .parse_args import parse_args
|
from .parse_args import parse_args
|
||||||
from .web.web_interface import get_inline_ui_html, get_web_interface_html
|
from .web.web_interface import get_inline_ui_html, get_text_transcript_html, get_web_interface_html
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TranscriptionEngine",
|
"TranscriptionEngine",
|
||||||
@@ -9,5 +9,6 @@ __all__ = [
|
|||||||
"parse_args",
|
"parse_args",
|
||||||
"get_web_interface_html",
|
"get_web_interface_html",
|
||||||
"get_inline_ui_html",
|
"get_inline_ui_html",
|
||||||
|
"get_text_transcript_html",
|
||||||
"download_simulstreaming_backend",
|
"download_simulstreaming_backend",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from whisperlivekit.core import (TranscriptionEngine,
|
|||||||
online_diarization_factory, online_factory,
|
online_diarization_factory, online_factory,
|
||||||
online_translation_factory)
|
online_translation_factory)
|
||||||
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||||
from whisperlivekit.silero_vad_iterator import FixedVADIterator, OnnxWrapper, load_jit_vad
|
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
||||||
from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
|
from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
|
||||||
Segment, Silence, State, Transcript)
|
Segment, Silence, State, Transcript)
|
||||||
from whisperlivekit.tokens_alignment import TokensAlignment
|
from whisperlivekit.tokens_alignment import TokensAlignment
|
||||||
@@ -32,7 +32,7 @@ async def get_all_from_queue(queue: asyncio.Queue) -> Union[object, Silence, np.
|
|||||||
if isinstance(first_item, Silence):
|
if isinstance(first_item, Silence):
|
||||||
return first_item
|
return first_item
|
||||||
items.append(first_item)
|
items.append(first_item)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if not queue._queue:
|
if not queue._queue:
|
||||||
break
|
break
|
||||||
@@ -53,15 +53,15 @@ class AudioProcessor:
|
|||||||
Processes audio streams for transcription and diarization.
|
Processes audio streams for transcription and diarization.
|
||||||
Handles audio processing, state management, and result formatting.
|
Handles audio processing, state management, and result formatting.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
"""Initialize the audio processor with configuration, models, and state."""
|
"""Initialize the audio processor with configuration, models, and state."""
|
||||||
|
|
||||||
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
|
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
|
||||||
models = kwargs['transcription_engine']
|
models = kwargs['transcription_engine']
|
||||||
else:
|
else:
|
||||||
models = TranscriptionEngine(**kwargs)
|
models = TranscriptionEngine(**kwargs)
|
||||||
|
|
||||||
# Audio processing settings
|
# Audio processing settings
|
||||||
self.args = models.args
|
self.args = models.args
|
||||||
self.sample_rate = 16000
|
self.sample_rate = 16000
|
||||||
@@ -85,14 +85,12 @@ class AudioProcessor:
|
|||||||
|
|
||||||
# Models and processing
|
# Models and processing
|
||||||
self.asr: Any = models.asr
|
self.asr: Any = models.asr
|
||||||
self.vac: Optional[FixedVADIterator] = None
|
self.vac_model: Any = models.vac_model
|
||||||
|
|
||||||
if self.args.vac:
|
if self.args.vac:
|
||||||
if models.vac_session is not None:
|
self.vac: Optional[FixedVADIterator] = FixedVADIterator(models.vac_model)
|
||||||
vac_model = OnnxWrapper(session=models.vac_session)
|
else:
|
||||||
self.vac = FixedVADIterator(vac_model)
|
self.vac: Optional[FixedVADIterator] = None
|
||||||
else:
|
|
||||||
self.vac = FixedVADIterator(load_jit_vad())
|
|
||||||
self.ffmpeg_manager: Optional[FFmpegManager] = None
|
self.ffmpeg_manager: Optional[FFmpegManager] = None
|
||||||
self.ffmpeg_reader_task: Optional[asyncio.Task] = None
|
self.ffmpeg_reader_task: Optional[asyncio.Task] = None
|
||||||
self._ffmpeg_error: Optional[str] = None
|
self._ffmpeg_error: Optional[str] = None
|
||||||
@@ -106,7 +104,7 @@ class AudioProcessor:
|
|||||||
logger.error(f"FFmpeg error: {error_type}")
|
logger.error(f"FFmpeg error: {error_type}")
|
||||||
self._ffmpeg_error = error_type
|
self._ffmpeg_error = error_type
|
||||||
self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
|
self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
|
||||||
|
|
||||||
self.transcription_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.transcription else None
|
self.transcription_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.transcription else None
|
||||||
self.diarization_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.diarization else None
|
self.diarization_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.diarization else None
|
||||||
self.translation_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.target_language else None
|
self.translation_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.target_language else None
|
||||||
@@ -117,14 +115,14 @@ class AudioProcessor:
|
|||||||
self.translation_task: Optional[asyncio.Task] = None
|
self.translation_task: Optional[asyncio.Task] = None
|
||||||
self.watchdog_task: Optional[asyncio.Task] = None
|
self.watchdog_task: Optional[asyncio.Task] = None
|
||||||
self.all_tasks_for_cleanup: List[asyncio.Task] = []
|
self.all_tasks_for_cleanup: List[asyncio.Task] = []
|
||||||
|
|
||||||
self.transcription: Optional[Any] = None
|
self.transcription: Optional[Any] = None
|
||||||
self.translation: Optional[Any] = None
|
self.translation: Optional[Any] = None
|
||||||
self.diarization: Optional[Any] = None
|
self.diarization: Optional[Any] = None
|
||||||
|
|
||||||
if self.args.transcription:
|
if self.args.transcription:
|
||||||
self.transcription = online_factory(self.args, models.asr)
|
self.transcription = online_factory(self.args, models.asr)
|
||||||
self.sep = self.transcription.asr.sep
|
self.sep = self.transcription.asr.sep
|
||||||
if self.args.diarization:
|
if self.args.diarization:
|
||||||
self.diarization = online_diarization_factory(self.args, models.diarization_model)
|
self.diarization = online_diarization_factory(self.args, models.diarization_model)
|
||||||
if models.translation_model:
|
if models.translation_model:
|
||||||
@@ -182,24 +180,24 @@ class AudioProcessor:
|
|||||||
def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray:
|
def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray:
|
||||||
"""Convert PCM buffer in s16le format to normalized NumPy array."""
|
"""Convert PCM buffer in s16le format to normalized NumPy array."""
|
||||||
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
|
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
|
||||||
|
|
||||||
async def get_current_state(self) -> State:
|
async def get_current_state(self) -> State:
|
||||||
"""Get current state."""
|
"""Get current state."""
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
current_time = time()
|
current_time = time()
|
||||||
|
|
||||||
remaining_transcription = 0
|
remaining_transcription = 0
|
||||||
if self.state.end_buffer > 0:
|
if self.state.end_buffer > 0:
|
||||||
remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1))
|
remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1))
|
||||||
|
|
||||||
remaining_diarization = 0
|
remaining_diarization = 0
|
||||||
if self.state.tokens:
|
if self.state.tokens:
|
||||||
latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
|
latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
|
||||||
remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1))
|
remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1))
|
||||||
|
|
||||||
self.state.remaining_time_transcription = remaining_transcription
|
self.state.remaining_time_transcription = remaining_transcription
|
||||||
self.state.remaining_time_diarization = remaining_diarization
|
self.state.remaining_time_diarization = remaining_diarization
|
||||||
|
|
||||||
return self.state
|
return self.state
|
||||||
|
|
||||||
async def ffmpeg_stdout_reader(self) -> None:
|
async def ffmpeg_stdout_reader(self) -> None:
|
||||||
@@ -255,7 +253,7 @@ class AudioProcessor:
|
|||||||
async def transcription_processor(self) -> None:
|
async def transcription_processor(self) -> None:
|
||||||
"""Process audio chunks for transcription."""
|
"""Process audio chunks for transcription."""
|
||||||
cumulative_pcm_duration_stream_time = 0.0
|
cumulative_pcm_duration_stream_time = 0.0
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
# item = await self.transcription_queue.get()
|
# item = await self.transcription_queue.get()
|
||||||
@@ -311,12 +309,12 @@ class AudioProcessor:
|
|||||||
|
|
||||||
if new_tokens:
|
if new_tokens:
|
||||||
candidate_end_times.append(new_tokens[-1].end)
|
candidate_end_times.append(new_tokens[-1].end)
|
||||||
|
|
||||||
if _buffer_transcript.end is not None:
|
if _buffer_transcript.end is not None:
|
||||||
candidate_end_times.append(_buffer_transcript.end)
|
candidate_end_times.append(_buffer_transcript.end)
|
||||||
|
|
||||||
candidate_end_times.append(current_audio_processed_upto)
|
candidate_end_times.append(current_audio_processed_upto)
|
||||||
|
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
self.state.tokens.extend(new_tokens)
|
self.state.tokens.extend(new_tokens)
|
||||||
self.state.buffer_transcription = _buffer_transcript
|
self.state.buffer_transcription = _buffer_transcript
|
||||||
@@ -326,13 +324,13 @@ class AudioProcessor:
|
|||||||
|
|
||||||
if self.translation_queue:
|
if self.translation_queue:
|
||||||
for token in new_tokens:
|
for token in new_tokens:
|
||||||
await self.translation_queue.put(token)
|
await self.translation_queue.put(token)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Exception in transcription_processor: {e}")
|
logger.warning(f"Exception in transcription_processor: {e}")
|
||||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||||
if 'pcm_array' in locals() and pcm_array is not SENTINEL : # Check if pcm_array was assigned from queue
|
if 'pcm_array' in locals() and pcm_array is not SENTINEL : # Check if pcm_array was assigned from queue
|
||||||
self.transcription_queue.task_done()
|
self.transcription_queue.task_done()
|
||||||
|
|
||||||
if self.is_stopping:
|
if self.is_stopping:
|
||||||
logger.info("Transcription processor finishing due to stopping flag.")
|
logger.info("Transcription processor finishing due to stopping flag.")
|
||||||
if self.diarization_queue:
|
if self.diarization_queue:
|
||||||
@@ -353,21 +351,18 @@ class AudioProcessor:
|
|||||||
if item.has_ended:
|
if item.has_ended:
|
||||||
self.diarization.insert_silence(item.duration)
|
self.diarization.insert_silence(item.duration)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.diarization.insert_audio_chunk(item)
|
self.diarization.insert_audio_chunk(item)
|
||||||
diarization_segments = await self.diarization.diarize()
|
diarization_segments = await self.diarization.diarize()
|
||||||
diar_end = 0.0
|
self.state.new_diarization = diarization_segments
|
||||||
if diarization_segments:
|
|
||||||
diar_end = max(getattr(s, "end", 0.0) for s in diarization_segments)
|
|
||||||
async with self.lock:
|
|
||||||
self.state.new_diarization = diarization_segments
|
|
||||||
self.state.end_attributed_speaker = max(self.state.end_attributed_speaker, diar_end)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Exception in diarization_processor: {e}")
|
logger.warning(f"Exception in diarization_processor: {e}")
|
||||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||||
logger.info("Diarization processor task finished.")
|
logger.info("Diarization processor task finished.")
|
||||||
|
|
||||||
async def translation_processor(self) -> None:
|
async def translation_processor(self) -> None:
|
||||||
# the idea is to ignore diarization for the moment. We use only transcription tokens.
|
# the idea is to ignore diarization for the moment. We use only transcription tokens.
|
||||||
# And the speaker is attributed given the segments used for the translation
|
# And the speaker is attributed given the segments used for the translation
|
||||||
# in the future we want to have different languages for each speaker etc, so it will be more complex.
|
# in the future we want to have different languages for each speaker etc, so it will be more complex.
|
||||||
while True:
|
while True:
|
||||||
@@ -398,6 +393,10 @@ class AudioProcessor:
|
|||||||
|
|
||||||
async def results_formatter(self) -> AsyncGenerator[FrontData, None]:
|
async def results_formatter(self) -> AsyncGenerator[FrontData, None]:
|
||||||
"""Format processing results for output."""
|
"""Format processing results for output."""
|
||||||
|
# Update intervals
|
||||||
|
ACTIVE_INTERVAL = 0.05 # 20 updates/sec during active transcription
|
||||||
|
SILENCE_INTERVAL = 0.5 # 2 updates/sec during silence
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
if self._ffmpeg_error:
|
if self._ffmpeg_error:
|
||||||
@@ -407,44 +406,62 @@ class AudioProcessor:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
self.tokens_alignment.update()
|
self.tokens_alignment.update()
|
||||||
lines, buffer_diarization_text, buffer_translation_text = self.tokens_alignment.get_lines(
|
state = await self.get_current_state()
|
||||||
|
|
||||||
|
# Get transcription buffer text to pass to get_lines
|
||||||
|
buffer_transcription_text = state.buffer_transcription.text if state.buffer_transcription else ''
|
||||||
|
|
||||||
|
# get_lines now returns segments with per-segment buffers
|
||||||
|
segments = self.tokens_alignment.get_lines(
|
||||||
diarization=self.args.diarization,
|
diarization=self.args.diarization,
|
||||||
translation=bool(self.translation),
|
translation=bool(self.translation),
|
||||||
current_silence=self.current_silence
|
current_silence=self.current_silence,
|
||||||
|
buffer_transcription=buffer_transcription_text
|
||||||
)
|
)
|
||||||
state = await self.get_current_state()
|
|
||||||
|
|
||||||
buffer_transcription_text = state.buffer_transcription.text if state.buffer_transcription else ''
|
|
||||||
|
|
||||||
response_status = "active_transcription"
|
response_status = "active_transcription"
|
||||||
if not lines and not buffer_transcription_text and not buffer_diarization_text:
|
# Check if there's any content (segments with text or buffers)
|
||||||
|
has_active_content = any(
|
||||||
|
seg.buffer and (seg.buffer.transcription or seg.buffer.diarization)
|
||||||
|
for seg in segments if not seg.is_silence()
|
||||||
|
)
|
||||||
|
has_any_content = any(
|
||||||
|
seg.text or (seg.buffer and (seg.buffer.transcription or seg.buffer.diarization))
|
||||||
|
for seg in segments if not seg.is_silence()
|
||||||
|
)
|
||||||
|
if not segments or not has_any_content:
|
||||||
response_status = "no_audio_detected"
|
response_status = "no_audio_detected"
|
||||||
|
|
||||||
response = FrontData(
|
response = FrontData(
|
||||||
status=response_status,
|
status=response_status,
|
||||||
lines=lines,
|
segments=segments,
|
||||||
buffer_transcription=buffer_transcription_text,
|
|
||||||
buffer_diarization=buffer_diarization_text,
|
|
||||||
buffer_translation=buffer_translation_text,
|
|
||||||
remaining_time_transcription=state.remaining_time_transcription,
|
remaining_time_transcription=state.remaining_time_transcription,
|
||||||
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
|
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
|
||||||
)
|
)
|
||||||
|
|
||||||
should_push = (response != self.last_response_content)
|
should_push = (response != self.last_response_content)
|
||||||
if should_push:
|
if should_push:
|
||||||
yield response
|
yield response
|
||||||
self.last_response_content = response
|
self.last_response_content = response
|
||||||
|
|
||||||
if self.is_stopping and self._processing_tasks_done():
|
if self.is_stopping and self._processing_tasks_done():
|
||||||
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
|
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
|
||||||
return
|
return
|
||||||
|
|
||||||
await asyncio.sleep(0.05)
|
# Throttle updates during silence: use slower interval when in silence mode
|
||||||
|
# with no pending buffers (nothing actively being processed)
|
||||||
|
is_in_silence = self.current_silence is not None
|
||||||
|
has_pending_work = has_active_content or state.remaining_time_transcription > 0.5
|
||||||
|
|
||||||
|
if is_in_silence and not has_pending_work:
|
||||||
|
await asyncio.sleep(SILENCE_INTERVAL)
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(ACTIVE_INTERVAL)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
|
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
async def create_tasks(self) -> AsyncGenerator[FrontData, None]:
|
async def create_tasks(self) -> AsyncGenerator[FrontData, None]:
|
||||||
"""Create and start processing tasks."""
|
"""Create and start processing tasks."""
|
||||||
self.all_tasks_for_cleanup = []
|
self.all_tasks_for_cleanup = []
|
||||||
@@ -469,21 +486,21 @@ class AudioProcessor:
|
|||||||
self.transcription_task = asyncio.create_task(self.transcription_processor())
|
self.transcription_task = asyncio.create_task(self.transcription_processor())
|
||||||
self.all_tasks_for_cleanup.append(self.transcription_task)
|
self.all_tasks_for_cleanup.append(self.transcription_task)
|
||||||
processing_tasks_for_watchdog.append(self.transcription_task)
|
processing_tasks_for_watchdog.append(self.transcription_task)
|
||||||
|
|
||||||
if self.diarization:
|
if self.diarization:
|
||||||
self.diarization_task = asyncio.create_task(self.diarization_processor())
|
self.diarization_task = asyncio.create_task(self.diarization_processor())
|
||||||
self.all_tasks_for_cleanup.append(self.diarization_task)
|
self.all_tasks_for_cleanup.append(self.diarization_task)
|
||||||
processing_tasks_for_watchdog.append(self.diarization_task)
|
processing_tasks_for_watchdog.append(self.diarization_task)
|
||||||
|
|
||||||
if self.translation:
|
if self.translation:
|
||||||
self.translation_task = asyncio.create_task(self.translation_processor())
|
self.translation_task = asyncio.create_task(self.translation_processor())
|
||||||
self.all_tasks_for_cleanup.append(self.translation_task)
|
self.all_tasks_for_cleanup.append(self.translation_task)
|
||||||
processing_tasks_for_watchdog.append(self.translation_task)
|
processing_tasks_for_watchdog.append(self.translation_task)
|
||||||
|
|
||||||
# Monitor overall system health
|
# Monitor overall system health
|
||||||
self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog))
|
self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog))
|
||||||
self.all_tasks_for_cleanup.append(self.watchdog_task)
|
self.all_tasks_for_cleanup.append(self.watchdog_task)
|
||||||
|
|
||||||
return self.results_formatter()
|
return self.results_formatter()
|
||||||
|
|
||||||
async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None:
|
async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None:
|
||||||
@@ -496,7 +513,7 @@ class AudioProcessor:
|
|||||||
return
|
return
|
||||||
|
|
||||||
await asyncio.sleep(10)
|
await asyncio.sleep(10)
|
||||||
|
|
||||||
for i, task in enumerate(list(tasks_remaining)):
|
for i, task in enumerate(list(tasks_remaining)):
|
||||||
if task.done():
|
if task.done():
|
||||||
exc = task.exception()
|
exc = task.exception()
|
||||||
@@ -506,13 +523,13 @@ class AudioProcessor:
|
|||||||
else:
|
else:
|
||||||
logger.info(f"{task_name} completed normally.")
|
logger.info(f"{task_name} completed normally.")
|
||||||
tasks_remaining.remove(task)
|
tasks_remaining.remove(task)
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("Watchdog task cancelled.")
|
logger.info("Watchdog task cancelled.")
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in watchdog task: {e}", exc_info=True)
|
logger.error(f"Error in watchdog task: {e}", exc_info=True)
|
||||||
|
|
||||||
async def cleanup(self) -> None:
|
async def cleanup(self) -> None:
|
||||||
"""Clean up resources when processing is complete."""
|
"""Clean up resources when processing is complete."""
|
||||||
logger.info("Starting cleanup of AudioProcessor resources.")
|
logger.info("Starting cleanup of AudioProcessor resources.")
|
||||||
@@ -520,7 +537,7 @@ class AudioProcessor:
|
|||||||
for task in self.all_tasks_for_cleanup:
|
for task in self.all_tasks_for_cleanup:
|
||||||
if task and not task.done():
|
if task and not task.done():
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
created_tasks = [t for t in self.all_tasks_for_cleanup if t]
|
created_tasks = [t for t in self.all_tasks_for_cleanup if t]
|
||||||
if created_tasks:
|
if created_tasks:
|
||||||
await asyncio.gather(*created_tasks, return_exceptions=True)
|
await asyncio.gather(*created_tasks, return_exceptions=True)
|
||||||
@@ -558,7 +575,7 @@ class AudioProcessor:
|
|||||||
if not message:
|
if not message:
|
||||||
logger.info("Empty audio message received, initiating stop sequence.")
|
logger.info("Empty audio message received, initiating stop sequence.")
|
||||||
self.is_stopping = True
|
self.is_stopping = True
|
||||||
|
|
||||||
if self.transcription_queue:
|
if self.transcription_queue:
|
||||||
await self.transcription_queue.put(SENTINEL)
|
await self.transcription_queue.put(SENTINEL)
|
||||||
|
|
||||||
@@ -599,7 +616,7 @@ class AudioProcessor:
|
|||||||
|
|
||||||
chunk_size = min(len(self.pcm_buffer), self.max_bytes_per_sec)
|
chunk_size = min(len(self.pcm_buffer), self.max_bytes_per_sec)
|
||||||
aligned_chunk_size = (chunk_size // self.bytes_per_sample) * self.bytes_per_sample
|
aligned_chunk_size = (chunk_size // self.bytes_per_sample) * self.bytes_per_sample
|
||||||
|
|
||||||
if aligned_chunk_size == 0:
|
if aligned_chunk_size == 0:
|
||||||
return
|
return
|
||||||
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size])
|
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size])
|
||||||
@@ -616,7 +633,7 @@ class AudioProcessor:
|
|||||||
if res is not None:
|
if res is not None:
|
||||||
if "start" in res and self.current_silence:
|
if "start" in res and self.current_silence:
|
||||||
await self._end_silence()
|
await self._end_silence()
|
||||||
|
|
||||||
if "end" in res and not self.current_silence:
|
if "end" in res and not self.current_silence:
|
||||||
pre_silence_chunk = self._slice_before_silence(
|
pre_silence_chunk = self._slice_before_silence(
|
||||||
pcm_array, chunk_sample_start, res.get("end")
|
pcm_array, chunk_sample_start, res.get("end")
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
|
|
||||||
from whisperlivekit import (AudioProcessor, TranscriptionEngine,
|
from whisperlivekit import (AudioProcessor, TranscriptionEngine,
|
||||||
get_inline_ui_html, parse_args)
|
get_inline_ui_html, get_text_transcript_html, parse_args)
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||||
logging.getLogger().setLevel(logging.WARNING)
|
logging.getLogger().setLevel(logging.WARNING)
|
||||||
@@ -39,6 +39,12 @@ async def get():
|
|||||||
return HTMLResponse(get_inline_ui_html())
|
return HTMLResponse(get_inline_ui_html())
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/text")
|
||||||
|
async def get_text():
|
||||||
|
"""Simple text-based transcript view for easy copy/paste."""
|
||||||
|
return HTMLResponse(get_text_transcript_html())
|
||||||
|
|
||||||
|
|
||||||
async def handle_websocket_results(websocket, results_generator):
|
async def handle_websocket_results(websocket, results_generator):
|
||||||
"""Consumes results from the audio processor and sends them via WebSocket."""
|
"""Consumes results from the audio processor and sends them via WebSocket."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
import threading
|
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
|
||||||
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
|
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
|
||||||
@@ -20,26 +19,16 @@ logger = logging.getLogger(__name__)
|
|||||||
class TranscriptionEngine:
|
class TranscriptionEngine:
|
||||||
_instance = None
|
_instance = None
|
||||||
_initialized = False
|
_initialized = False
|
||||||
_lock = threading.Lock() # Thread-safe singleton lock
|
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
# Double-checked locking pattern for thread-safe singleton
|
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
with cls._lock:
|
cls._instance = super().__new__(cls)
|
||||||
# Check again inside lock to prevent race condition
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = super().__new__(cls)
|
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
# Thread-safe initialization check
|
if TranscriptionEngine._initialized:
|
||||||
with TranscriptionEngine._lock:
|
return
|
||||||
if TranscriptionEngine._initialized:
|
|
||||||
return
|
|
||||||
# Set flag immediately to prevent re-initialization
|
|
||||||
TranscriptionEngine._initialized = True
|
|
||||||
|
|
||||||
# Perform initialization outside lock to avoid holding lock during slow operations
|
|
||||||
global_params = {
|
global_params = {
|
||||||
"host": "localhost",
|
"host": "localhost",
|
||||||
"port": 8000,
|
"port": 8000,
|
||||||
@@ -47,6 +36,7 @@ class TranscriptionEngine:
|
|||||||
"punctuation_split": False,
|
"punctuation_split": False,
|
||||||
"target_language": "",
|
"target_language": "",
|
||||||
"vac": True,
|
"vac": True,
|
||||||
|
"vac_onnx": False,
|
||||||
"vac_chunk_size": 0.04,
|
"vac_chunk_size": 0.04,
|
||||||
"log_level": "DEBUG",
|
"log_level": "DEBUG",
|
||||||
"ssl_certfile": None,
|
"ssl_certfile": None,
|
||||||
@@ -89,19 +79,15 @@ class TranscriptionEngine:
|
|||||||
self.asr = None
|
self.asr = None
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.diarization = None
|
self.diarization = None
|
||||||
self.vac_session = None
|
self.vac_model = None
|
||||||
|
|
||||||
if self.args.vac:
|
if self.args.vac:
|
||||||
from whisperlivekit.silero_vad_iterator import is_onnx_available
|
from whisperlivekit.silero_vad_iterator import load_silero_vad
|
||||||
|
|
||||||
if is_onnx_available():
|
# Use ONNX if specified, otherwise use JIT (default)
|
||||||
from whisperlivekit.silero_vad_iterator import load_onnx_session
|
use_onnx = kwargs.get('vac_onnx', False)
|
||||||
self.vac_session = load_onnx_session()
|
self.vac_model = load_silero_vad(onnx=use_onnx)
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"onnxruntime not installed. VAC will use JIT model which is loaded per-session. "
|
|
||||||
"For multi-user scenarios, install onnxruntime: pip install onnxruntime"
|
|
||||||
)
|
|
||||||
backend_policy = self.args.backend_policy
|
backend_policy = self.args.backend_policy
|
||||||
if self.args.transcription:
|
if self.args.transcription:
|
||||||
if backend_policy == "simulstreaming":
|
if backend_policy == "simulstreaming":
|
||||||
@@ -183,13 +169,16 @@ class TranscriptionEngine:
|
|||||||
}
|
}
|
||||||
translation_params = update_with_kwargs(translation_params, kwargs)
|
translation_params = update_with_kwargs(translation_params, kwargs)
|
||||||
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
|
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
|
||||||
|
TranscriptionEngine._initialized = True
|
||||||
|
|
||||||
|
|
||||||
def online_factory(args, asr):
|
def online_factory(args, asr):
|
||||||
if args.backend_policy == "simulstreaming":
|
if args.backend_policy == "simulstreaming":
|
||||||
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
||||||
return SimulStreamingOnlineProcessor(asr)
|
online = SimulStreamingOnlineProcessor(asr)
|
||||||
return OnlineASRProcessor(asr)
|
else:
|
||||||
|
online = OnlineASRProcessor(asr)
|
||||||
|
return online
|
||||||
|
|
||||||
|
|
||||||
def online_diarization_factory(args, diarization_backend):
|
def online_diarization_factory(args, diarization_backend):
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ class FasterWhisperASR(ASRBase):
|
|||||||
if segment.no_speech_prob > 0.9:
|
if segment.no_speech_prob > 0.9:
|
||||||
continue
|
continue
|
||||||
for word in segment.words:
|
for word in segment.words:
|
||||||
token = ASRToken(word.start, word.end, word.word)
|
token = ASRToken(word.start, word.end, word.word, probability=word.probability)
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
|
|||||||
@@ -8,15 +8,6 @@ import torch
|
|||||||
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
|
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def is_onnx_available() -> bool:
|
|
||||||
"""Check if onnxruntime is installed."""
|
|
||||||
try:
|
|
||||||
import onnxruntime
|
|
||||||
return True
|
|
||||||
except ImportError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def init_jit_model(model_path: str, device=torch.device('cpu')):
|
def init_jit_model(model_path: str, device=torch.device('cpu')):
|
||||||
"""Load a JIT model from file."""
|
"""Load a JIT model from file."""
|
||||||
model = torch.jit.load(model_path, map_location=device)
|
model = torch.jit.load(model_path, map_location=device)
|
||||||
@@ -24,12 +15,12 @@ def init_jit_model(model_path: str, device=torch.device('cpu')):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
class OnnxSession():
|
class OnnxWrapper():
|
||||||
"""
|
"""ONNX Runtime wrapper for Silero VAD model."""
|
||||||
Shared ONNX session for Silero VAD model (stateless).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, path, force_onnx_cpu=False):
|
def __init__(self, path, force_onnx_cpu=False):
|
||||||
|
global np
|
||||||
|
import numpy as np
|
||||||
import onnxruntime
|
import onnxruntime
|
||||||
|
|
||||||
opts = onnxruntime.SessionOptions()
|
opts = onnxruntime.SessionOptions()
|
||||||
@@ -41,28 +32,13 @@ class OnnxSession():
|
|||||||
else:
|
else:
|
||||||
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
|
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
|
||||||
|
|
||||||
self.path = path
|
self.reset_states()
|
||||||
if '16k' in path:
|
if '16k' in path:
|
||||||
warnings.warn('This model support only 16000 sampling rate!')
|
warnings.warn('This model support only 16000 sampling rate!')
|
||||||
self.sample_rates = [16000]
|
self.sample_rates = [16000]
|
||||||
else:
|
else:
|
||||||
self.sample_rates = [8000, 16000]
|
self.sample_rates = [8000, 16000]
|
||||||
|
|
||||||
|
|
||||||
class OnnxWrapper():
|
|
||||||
"""
|
|
||||||
ONNX Runtime wrapper for Silero VAD model with per-instance state.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, session: OnnxSession, force_onnx_cpu=False):
|
|
||||||
self._shared_session = session
|
|
||||||
self.sample_rates = session.sample_rates
|
|
||||||
self.reset_states()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def session(self):
|
|
||||||
return self._shared_session.session
|
|
||||||
|
|
||||||
def _validate_input(self, x, sr: int):
|
def _validate_input(self, x, sr: int):
|
||||||
if x.dim() == 1:
|
if x.dim() == 1:
|
||||||
x = x.unsqueeze(0)
|
x = x.unsqueeze(0)
|
||||||
@@ -125,20 +101,38 @@ class OnnxWrapper():
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Path:
|
def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: int = 16):
|
||||||
"""Get the path to the ONNX model file."""
|
"""
|
||||||
|
Load Silero VAD model (JIT or ONNX).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model_path : str, optional
|
||||||
|
Path to model file. If None, uses default bundled model.
|
||||||
|
onnx : bool, default False
|
||||||
|
Whether to use ONNX runtime (requires onnxruntime package).
|
||||||
|
opset_version : int, default 16
|
||||||
|
ONNX opset version (15 or 16). Only used if onnx=True.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
model
|
||||||
|
Loaded VAD model (JIT or ONNX wrapper)
|
||||||
|
"""
|
||||||
available_ops = [15, 16]
|
available_ops = [15, 16]
|
||||||
if opset_version not in available_ops:
|
if onnx and opset_version not in available_ops:
|
||||||
raise Exception(f'Available ONNX opset_version: {available_ops}')
|
raise Exception(f'Available ONNX opset_version: {available_ops}')
|
||||||
|
|
||||||
if model_path is None:
|
if model_path is None:
|
||||||
current_dir = Path(__file__).parent
|
current_dir = Path(__file__).parent
|
||||||
data_dir = current_dir / 'silero_vad_models'
|
data_dir = current_dir / 'silero_vad_models'
|
||||||
|
|
||||||
if opset_version == 16:
|
if onnx:
|
||||||
model_name = 'silero_vad.onnx'
|
if opset_version == 16:
|
||||||
|
model_name = 'silero_vad.onnx'
|
||||||
|
else:
|
||||||
|
model_name = f'silero_vad_16k_op{opset_version}.onnx'
|
||||||
else:
|
else:
|
||||||
model_name = f'silero_vad_16k_op{opset_version}.onnx'
|
model_name = 'silero_vad.jit'
|
||||||
|
|
||||||
model_path = data_dir / model_name
|
model_path = data_dir / model_name
|
||||||
|
|
||||||
@@ -149,39 +143,17 @@ def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Pat
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
|
if onnx:
|
||||||
return model_path
|
try:
|
||||||
|
model = OnnxWrapper(str(model_path), force_onnx_cpu=True)
|
||||||
|
except ImportError:
|
||||||
def load_onnx_session(model_path: str = None, opset_version: int = 16, force_onnx_cpu: bool = True) -> OnnxSession:
|
raise ImportError(
|
||||||
"""
|
"ONNX runtime not available. Install with: pip install onnxruntime\n"
|
||||||
Load a shared ONNX session for Silero VAD.
|
"Or use JIT model by setting onnx=False"
|
||||||
"""
|
|
||||||
path = _get_onnx_model_path(model_path, opset_version)
|
|
||||||
return OnnxSession(str(path), force_onnx_cpu=force_onnx_cpu)
|
|
||||||
|
|
||||||
|
|
||||||
def load_jit_vad(model_path: str = None):
|
|
||||||
"""
|
|
||||||
Load Silero VAD model in JIT format.
|
|
||||||
"""
|
|
||||||
if model_path is None:
|
|
||||||
current_dir = Path(__file__).parent
|
|
||||||
data_dir = current_dir / 'silero_vad_models'
|
|
||||||
model_name = 'silero_vad.jit'
|
|
||||||
|
|
||||||
model_path = data_dir / model_name
|
|
||||||
|
|
||||||
if not model_path.exists():
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"Model file not found: {model_path}\n"
|
|
||||||
f"Please ensure the whisperlivekit/silero_vad_models/ directory contains the model files."
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model_path = Path(model_path)
|
model = init_jit_model(str(model_path))
|
||||||
|
|
||||||
model = init_jit_model(str(model_path))
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@@ -313,14 +285,13 @@ class FixedVADIterator(VADIterator):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# vad = FixedVADIterator(load_jit_vad())
|
model = load_silero_vad(onnx=False)
|
||||||
vad = FixedVADIterator(OnnxWrapper(session=load_onnx_session()))
|
vad = FixedVADIterator(model)
|
||||||
|
|
||||||
audio_buffer = np.array([0] * 512, dtype=np.float32)
|
audio_buffer = np.array([0] * 512, dtype=np.float32)
|
||||||
result = vad(audio_buffer)
|
result = vad(audio_buffer)
|
||||||
print(f" 512 samples: {result}")
|
print(f" 512 samples: {result}")
|
||||||
|
|
||||||
# test with 511 samples
|
# test with 511 samples
|
||||||
audio_buffer = np.array([0] * 511, dtype=np.float32)
|
audio_buffer = np.array([0] * 511, dtype=np.float32)
|
||||||
result = vad(audio_buffer)
|
result = vad(audio_buffer)
|
||||||
print(f" 511 samples: {result}")
|
|
||||||
@@ -24,11 +24,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
|
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
|
||||||
if HAS_MLX_WHISPER:
|
if HAS_MLX_WHISPER:
|
||||||
from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
|
from .mlx_encoder import load_mlx_encoder, mlx_model_mapping
|
||||||
from .mlx import MLXAlignAtt
|
|
||||||
else:
|
else:
|
||||||
mlx_model_mapping = {}
|
mlx_model_mapping = {}
|
||||||
MLXAlignAtt = None
|
|
||||||
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
|
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
|
||||||
if HAS_FASTER_WHISPER:
|
if HAS_FASTER_WHISPER:
|
||||||
from faster_whisper import WhisperModel
|
from faster_whisper import WhisperModel
|
||||||
@@ -38,49 +36,50 @@ else:
|
|||||||
MIN_DURATION_REAL_SILENCE = 5
|
MIN_DURATION_REAL_SILENCE = 5
|
||||||
|
|
||||||
class SimulStreamingOnlineProcessor:
|
class SimulStreamingOnlineProcessor:
|
||||||
"""Online processor for SimulStreaming ASR."""
|
|
||||||
SAMPLING_RATE = 16000
|
SAMPLING_RATE = 16000
|
||||||
|
|
||||||
def __init__(self, asr, logfile=sys.stderr):
|
def __init__(
|
||||||
|
self,
|
||||||
|
asr,
|
||||||
|
logfile=sys.stderr,
|
||||||
|
):
|
||||||
self.asr = asr
|
self.asr = asr
|
||||||
self.logfile = logfile
|
self.logfile = logfile
|
||||||
self.end = 0.0
|
self.end = 0.0
|
||||||
self.buffer = []
|
self.buffer = []
|
||||||
self.committed: List[ASRToken] = []
|
self.committed: List[ASRToken] = []
|
||||||
self.last_result_tokens: List[ASRToken] = []
|
self.last_result_tokens: List[ASRToken] = []
|
||||||
self.model = self._create_alignatt()
|
self.load_new_alignatt_instance()
|
||||||
|
|
||||||
if asr.tokenizer:
|
if asr.tokenizer:
|
||||||
self.model.tokenizer = asr.tokenizer
|
self.model.tokenizer = asr.tokenizer
|
||||||
self.model.state.tokenizer = asr.tokenizer
|
|
||||||
|
|
||||||
def _create_alignatt(self):
|
def load_new_alignatt_instance(self):
|
||||||
"""Create the AlignAtt decoder instance based on ASR mode."""
|
"""Initialize AlignAtt decoder using the shared model."""
|
||||||
if self.asr.use_full_mlx and HAS_MLX_WHISPER:
|
self.model = AlignAtt(
|
||||||
return MLXAlignAtt(cfg=self.asr.cfg, mlx_model=self.asr.mlx_model)
|
cfg=self.asr.cfg,
|
||||||
else:
|
loaded_model=self.asr.shared_model,
|
||||||
return AlignAtt(
|
mlx_encoder=self.asr.mlx_encoder,
|
||||||
cfg=self.asr.cfg,
|
fw_encoder=self.asr.fw_encoder,
|
||||||
loaded_model=self.asr.shared_model,
|
)
|
||||||
mlx_encoder=self.asr.mlx_encoder,
|
|
||||||
fw_encoder=self.asr.fw_encoder,
|
|
||||||
)
|
|
||||||
|
|
||||||
def start_silence(self):
|
def start_silence(self):
|
||||||
tokens, processed_upto = self.process_iter(is_last=True)
|
tokens, processed_upto = self.process_iter(is_last=True)
|
||||||
return tokens, processed_upto
|
return tokens, processed_upto
|
||||||
|
|
||||||
def end_silence(self, silence_duration, offset):
|
def end_silence(self, silence_duration, offset):
|
||||||
"""Handle silence period."""
|
"""
|
||||||
|
Handle silence period.
|
||||||
|
|
||||||
|
If silence > MIN_DURATION_REAL_SILENCE, do a complete context clear.
|
||||||
|
Otherwise, insert a small silence and shift the last_attend_frame.
|
||||||
|
"""
|
||||||
self.end += silence_duration
|
self.end += silence_duration
|
||||||
long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
|
long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
|
||||||
if not long_silence:
|
if not long_silence:
|
||||||
gap_len = int(16000 * silence_duration)
|
gap_len = int(16000 * silence_duration)
|
||||||
if gap_len > 0:
|
if gap_len > 0:
|
||||||
if self.asr.use_full_mlx:
|
gap_silence = torch.zeros(gap_len)
|
||||||
gap_silence = np.zeros(gap_len, dtype=np.float32)
|
|
||||||
else:
|
|
||||||
gap_silence = torch.zeros(gap_len)
|
|
||||||
self.model.insert_audio(gap_silence)
|
self.model.insert_audio(gap_silence)
|
||||||
if long_silence:
|
if long_silence:
|
||||||
self.model.refresh_segment(complete=True)
|
self.model.refresh_segment(complete=True)
|
||||||
@@ -88,12 +87,11 @@ class SimulStreamingOnlineProcessor:
|
|||||||
|
|
||||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
|
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
|
||||||
"""Append an audio chunk to be processed by SimulStreaming."""
|
"""Append an audio chunk to be processed by SimulStreaming."""
|
||||||
self.end = audio_stream_end_time
|
|
||||||
if self.asr.use_full_mlx:
|
# Convert numpy array to torch tensor
|
||||||
self.model.insert_audio(audio)
|
audio_tensor = torch.from_numpy(audio).float()
|
||||||
else:
|
self.end = audio_stream_end_time # Aligned with whisperstreaming backend behavior
|
||||||
audio_tensor = torch.from_numpy(audio).float()
|
self.model.insert_audio(audio_tensor)
|
||||||
self.model.insert_audio(audio_tensor)
|
|
||||||
|
|
||||||
def new_speaker(self, change_speaker: ChangeSpeaker):
|
def new_speaker(self, change_speaker: ChangeSpeaker):
|
||||||
"""Handle speaker change event."""
|
"""Handle speaker change event."""
|
||||||
@@ -132,10 +130,6 @@ class SimulStreamingOnlineProcessor:
|
|||||||
def warmup(self, audio, init_prompt=""):
|
def warmup(self, audio, init_prompt=""):
|
||||||
"""Warmup the SimulStreaming model."""
|
"""Warmup the SimulStreaming model."""
|
||||||
try:
|
try:
|
||||||
if self.asr.use_full_mlx:
|
|
||||||
# MLX mode: ensure numpy array
|
|
||||||
if hasattr(audio, 'numpy'):
|
|
||||||
audio = audio.numpy()
|
|
||||||
self.model.insert_audio(audio)
|
self.model.insert_audio(audio)
|
||||||
self.model.infer(True)
|
self.model.infer(True)
|
||||||
self.model.refresh_segment(complete=True)
|
self.model.refresh_segment(complete=True)
|
||||||
@@ -145,14 +139,9 @@ class SimulStreamingOnlineProcessor:
|
|||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if not getattr(self.asr, 'use_full_mlx', True) and torch is not None:
|
torch.cuda.empty_cache()
|
||||||
try:
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
class SimulStreamingASR():
|
||||||
class SimulStreamingASR:
|
|
||||||
"""SimulStreaming backend with AlignAtt policy."""
|
"""SimulStreaming backend with AlignAtt policy."""
|
||||||
sep = ""
|
sep = ""
|
||||||
|
|
||||||
@@ -169,7 +158,6 @@ class SimulStreamingASR:
|
|||||||
self.fast_encoder = False
|
self.fast_encoder = False
|
||||||
self._resolved_model_path = None
|
self._resolved_model_path = None
|
||||||
self.encoder_backend = "whisper"
|
self.encoder_backend = "whisper"
|
||||||
self.use_full_mlx = getattr(self, "use_full_mlx", False)
|
|
||||||
preferred_backend = getattr(self, "backend", "auto")
|
preferred_backend = getattr(self, "backend", "auto")
|
||||||
compatible_whisper_mlx, compatible_faster_whisper = True, True
|
compatible_whisper_mlx, compatible_faster_whisper = True, True
|
||||||
|
|
||||||
@@ -182,7 +170,7 @@ class SimulStreamingASR:
|
|||||||
compatible_whisper_mlx = model_info.compatible_whisper_mlx
|
compatible_whisper_mlx = model_info.compatible_whisper_mlx
|
||||||
compatible_faster_whisper = model_info.compatible_faster_whisper
|
compatible_faster_whisper = model_info.compatible_faster_whisper
|
||||||
|
|
||||||
if not self.use_full_mlx and not model_info.has_pytorch:
|
if not model_info.has_pytorch:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
|
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
|
||||||
)
|
)
|
||||||
@@ -202,10 +190,6 @@ class SimulStreamingASR:
|
|||||||
self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper")
|
self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper")
|
||||||
if self.encoder_backend == "whisper":
|
if self.encoder_backend == "whisper":
|
||||||
self.disable_fast_encoder = True
|
self.disable_fast_encoder = True
|
||||||
|
|
||||||
if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin":
|
|
||||||
if not hasattr(self, '_full_mlx_disabled'):
|
|
||||||
self.use_full_mlx = True
|
|
||||||
|
|
||||||
self.cfg = AlignAttConfig(
|
self.cfg = AlignAttConfig(
|
||||||
tokenizer_is_multilingual= is_multilingual,
|
tokenizer_is_multilingual= is_multilingual,
|
||||||
@@ -230,36 +214,20 @@ class SimulStreamingASR:
|
|||||||
else:
|
else:
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
|
|
||||||
self.mlx_encoder, self.fw_encoder, self.mlx_model = None, None, None
|
self.mlx_encoder, self.fw_encoder = None, None
|
||||||
self.shared_model = None
|
if self.encoder_backend == "mlx-whisper":
|
||||||
|
print('Simulstreaming will use MLX whisper to increase encoding speed.')
|
||||||
if self.use_full_mlx and HAS_MLX_WHISPER:
|
|
||||||
logger.info('MLX Whisper backend used.')
|
|
||||||
if self._resolved_model_path is not None:
|
if self._resolved_model_path is not None:
|
||||||
mlx_model_path = str(self._resolved_model_path)
|
mlx_model = str(self._resolved_model_path)
|
||||||
else:
|
else:
|
||||||
mlx_model_path = mlx_model_mapping.get(self.model_name)
|
mlx_model = mlx_model_mapping.get(self.model_name)
|
||||||
if not mlx_model_path:
|
if not mlx_model:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
|
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
|
||||||
)
|
)
|
||||||
self.mlx_model = load_mlx_model(path_or_hf_repo=mlx_model_path)
|
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
|
||||||
self._warmup_mlx_model()
|
|
||||||
elif self.encoder_backend == "mlx-whisper":
|
|
||||||
# hybrid mode: mlx encoder + pytorch decoder
|
|
||||||
logger.info('SimulStreaming will use MLX Whisper encoder with PyTorch decoder.')
|
|
||||||
if self._resolved_model_path is not None:
|
|
||||||
mlx_model_path = str(self._resolved_model_path)
|
|
||||||
else:
|
|
||||||
mlx_model_path = mlx_model_mapping.get(self.model_name)
|
|
||||||
if not mlx_model_path:
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
|
|
||||||
)
|
|
||||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path)
|
|
||||||
self.shared_model = self.load_model()
|
|
||||||
elif self.encoder_backend == "faster-whisper":
|
elif self.encoder_backend == "faster-whisper":
|
||||||
print('SimulStreaming will use Faster Whisper for the encoder.')
|
print('Simulstreaming will use Faster Whisper for the encoder.')
|
||||||
if self._resolved_model_path is not None:
|
if self._resolved_model_path is not None:
|
||||||
fw_model = str(self._resolved_model_path)
|
fw_model = str(self._resolved_model_path)
|
||||||
else:
|
else:
|
||||||
@@ -269,20 +237,7 @@ class SimulStreamingASR:
|
|||||||
device='auto',
|
device='auto',
|
||||||
compute_type='auto',
|
compute_type='auto',
|
||||||
)
|
)
|
||||||
self.shared_model = self.load_model()
|
self.shared_model = self.load_model()
|
||||||
else:
|
|
||||||
self.shared_model = self.load_model()
|
|
||||||
|
|
||||||
def _warmup_mlx_model(self):
|
|
||||||
"""Warmup the full MLX model."""
|
|
||||||
warmup_audio = load_file(self.warmup_file)
|
|
||||||
if warmup_audio is not None:
|
|
||||||
temp_model = MLXAlignAtt(
|
|
||||||
cfg=self.cfg,
|
|
||||||
mlx_model=self.mlx_model,
|
|
||||||
)
|
|
||||||
temp_model.warmup(warmup_audio)
|
|
||||||
logger.info("Full MLX model warmed up successfully")
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
|
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
|
||||||
|
|||||||
@@ -47,24 +47,9 @@ class DecoderState:
|
|||||||
|
|
||||||
def clean_cache(self):
|
def clean_cache(self):
|
||||||
"""Clean the kv_cache after each inference step."""
|
"""Clean the kv_cache after each inference step."""
|
||||||
# Explicitly delete tensor references to free GPU memory
|
self.kv_cache = {}
|
||||||
if self.kv_cache:
|
|
||||||
for key in list(self.kv_cache.keys()):
|
|
||||||
tensor = self.kv_cache.pop(key, None)
|
|
||||||
if tensor is not None:
|
|
||||||
del tensor
|
|
||||||
|
|
||||||
# Clear the dict
|
|
||||||
self.kv_cache.clear()
|
|
||||||
|
|
||||||
# Force GPU cache cleanup (only if CUDA is available)
|
|
||||||
import torch
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
if self.decoder_type == "beam" and self.inference is not None:
|
if self.decoder_type == "beam" and self.inference is not None:
|
||||||
# Create NEW dict instead of sharing reference
|
self.inference.kv_cache = self.kv_cache
|
||||||
self.inference.kv_cache = {}
|
|
||||||
if self.token_decoder is not None:
|
if self.token_decoder is not None:
|
||||||
self.token_decoder.reset()
|
self.token_decoder.reset()
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +0,0 @@
|
|||||||
from .decoder_state import MLXDecoderState
|
|
||||||
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
|
|
||||||
from .simul_whisper import MLXAlignAtt
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"MLXAlignAtt",
|
|
||||||
"MLXBeamSearchDecoder",
|
|
||||||
"MLXDecoderState",
|
|
||||||
"MLXGreedyDecoder",
|
|
||||||
"MLXInference",
|
|
||||||
]
|
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MLXDecoderState:
|
|
||||||
"""
|
|
||||||
mlx kv cache format: List of ((k, v), (cross_k, cross_v)) tuples per layer,
|
|
||||||
where each element is a tuple of mx.arrays.
|
|
||||||
"""
|
|
||||||
|
|
||||||
kv_cache: Optional[List[Tuple[Tuple[mx.array, mx.array], Tuple[mx.array, mx.array]]]] = None
|
|
||||||
|
|
||||||
tokenizer: Any = None
|
|
||||||
detected_language: Optional[str] = None
|
|
||||||
reset_tokenizer_to_auto_next_call: bool = False
|
|
||||||
|
|
||||||
tokens: List[mx.array] = field(default_factory=list)
|
|
||||||
initial_tokens: Optional[mx.array] = None
|
|
||||||
initial_token_length: int = 0
|
|
||||||
sot_index: int = 0
|
|
||||||
align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
|
|
||||||
num_align_heads: int = 0
|
|
||||||
segments: List[np.ndarray] = field(default_factory=list)
|
|
||||||
|
|
||||||
context: Any = None
|
|
||||||
|
|
||||||
pending_incomplete_tokens: List[int] = field(default_factory=list)
|
|
||||||
|
|
||||||
global_time_offset: float = 0.0
|
|
||||||
cumulative_time_offset: float = 0.0
|
|
||||||
first_timestamp: Optional[float] = None
|
|
||||||
last_attend_frame: int = 0
|
|
||||||
|
|
||||||
speaker: int = -1
|
|
||||||
log_segments: int = 0
|
|
||||||
cif_weights: Optional[mx.array] = None
|
|
||||||
always_fire: bool = False
|
|
||||||
never_fire: bool = False
|
|
||||||
|
|
||||||
suppress_tokens: Optional[Tuple[int, ...]] = None
|
|
||||||
|
|
||||||
token_decoder: Any = None
|
|
||||||
decoder_type: str = "greedy"
|
|
||||||
|
|
||||||
inference: Any = None
|
|
||||||
|
|
||||||
def clean_cache(self):
|
|
||||||
self.kv_cache = None
|
|
||||||
if self.decoder_type == "beam" and self.inference is not None:
|
|
||||||
self.inference.kv_cache = None
|
|
||||||
if self.token_decoder is not None:
|
|
||||||
self.token_decoder.reset()
|
|
||||||
|
|
||||||
def reset(self, rewind_threshold: int = 200):
|
|
||||||
self.last_attend_frame = -rewind_threshold
|
|
||||||
self.cumulative_time_offset = 0.0
|
|
||||||
self.pending_incomplete_tokens = []
|
|
||||||
self.log_segments += 1
|
|
||||||
|
|
||||||
def full_reset(self, rewind_threshold: int = 200):
|
|
||||||
"""
|
|
||||||
Full reset including audio segments and tokens.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
rewind_threshold: Value for resetting last_attend_frame
|
|
||||||
"""
|
|
||||||
self.reset(rewind_threshold)
|
|
||||||
self.segments = []
|
|
||||||
self.tokens = []
|
|
||||||
self.kv_cache = None
|
|
||||||
self.first_timestamp = None
|
|
||||||
|
|
||||||
@@ -1,219 +0,0 @@
|
|||||||
"""
|
|
||||||
MLX-native token decoders for streaming ASR.
|
|
||||||
"""
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
class MLXGreedyDecoder:
|
|
||||||
"""Greedy decoder using MLX operations."""
|
|
||||||
|
|
||||||
def __init__(self, temperature: float, eot: int):
|
|
||||||
self.temperature = temperature
|
|
||||||
self.eot = eot
|
|
||||||
|
|
||||||
def update(
|
|
||||||
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
|
|
||||||
) -> Tuple[mx.array, bool]:
|
|
||||||
"""
|
|
||||||
Update tokens with next predicted token.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tokens: Current token sequence, shape (batch, seq_len)
|
|
||||||
logits: Logits for next token, shape (batch, vocab_size)
|
|
||||||
sum_logprobs: Cumulative log probabilities, shape (batch,)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated tokens and completion flag
|
|
||||||
"""
|
|
||||||
if self.temperature == 0:
|
|
||||||
next_tokens = mx.argmax(logits, axis=-1)
|
|
||||||
else:
|
|
||||||
probs = mx.softmax(logits / self.temperature, axis=-1)
|
|
||||||
next_tokens = mx.random.categorical(mx.log(probs + 1e-10))
|
|
||||||
|
|
||||||
logprobs = mx.softmax(logits, axis=-1)
|
|
||||||
logprobs = mx.log(logprobs + 1e-10)
|
|
||||||
batch_size = logprobs.shape[0]
|
|
||||||
current_logprobs = logprobs[mx.arange(batch_size), next_tokens]
|
|
||||||
mask = (tokens[:, -1] != self.eot).astype(mx.float32)
|
|
||||||
sum_logprobs = sum_logprobs + current_logprobs * mask
|
|
||||||
eot_mask = (tokens[:, -1] == self.eot)
|
|
||||||
next_tokens = mx.where(eot_mask, mx.array(self.eot), next_tokens)
|
|
||||||
tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=1)
|
|
||||||
completed = bool(mx.all(tokens[:, -1] == self.eot))
|
|
||||||
|
|
||||||
return tokens, completed
|
|
||||||
|
|
||||||
def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
|
|
||||||
"""Finalize decoding by ensuring EOT at end."""
|
|
||||||
eot_column = mx.full((tokens.shape[0], 1), self.eot, dtype=tokens.dtype)
|
|
||||||
tokens = mx.concatenate([tokens, eot_column], axis=1)
|
|
||||||
return tokens, sum_logprobs.tolist()
|
|
||||||
|
|
||||||
|
|
||||||
class MLXBeamSearchDecoder:
|
|
||||||
"""Beam search decoder using MLX operations."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
beam_size: int,
|
|
||||||
eot: int,
|
|
||||||
inference: Any,
|
|
||||||
patience: Optional[float] = None,
|
|
||||||
):
|
|
||||||
self.beam_size = beam_size
|
|
||||||
self.eot = eot
|
|
||||||
self.inference = inference
|
|
||||||
self.patience = patience or 1.0
|
|
||||||
self.max_candidates: int = round(beam_size * self.patience)
|
|
||||||
self.finished_sequences: Optional[List[Dict]] = None
|
|
||||||
|
|
||||||
assert (
|
|
||||||
self.max_candidates > 0
|
|
||||||
), f"Invalid beam size ({beam_size}) or patience ({patience})"
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
"""Reset finished sequences for new segment."""
|
|
||||||
self.finished_sequences = None
|
|
||||||
|
|
||||||
def update(
|
|
||||||
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
|
|
||||||
) -> Tuple[mx.array, bool]:
|
|
||||||
"""
|
|
||||||
Update tokens using beam search.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tokens: Current token sequences, shape (batch * beam_size, seq_len)
|
|
||||||
logits: Logits for next token, shape (batch * beam_size, vocab_size)
|
|
||||||
sum_logprobs: Cumulative log probabilities, shape (batch * beam_size,)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated tokens and completion flag
|
|
||||||
"""
|
|
||||||
if tokens.shape[0] % self.beam_size != 0:
|
|
||||||
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
|
|
||||||
|
|
||||||
n_audio = tokens.shape[0] // self.beam_size
|
|
||||||
if self.finished_sequences is None:
|
|
||||||
self.finished_sequences = [{} for _ in range(n_audio)]
|
|
||||||
logprobs = mx.softmax(logits, axis=-1)
|
|
||||||
logprobs = mx.log(logprobs + 1e-10)
|
|
||||||
logprobs_np = np.array(logprobs)
|
|
||||||
tokens_np = np.array(tokens)
|
|
||||||
sum_logprobs_np = np.array(sum_logprobs)
|
|
||||||
|
|
||||||
next_tokens, source_indices, finished_sequences = [], [], []
|
|
||||||
new_sum_logprobs = []
|
|
||||||
|
|
||||||
for i in range(n_audio):
|
|
||||||
scores, sources, finished = {}, {}, {}
|
|
||||||
for j in range(self.beam_size):
|
|
||||||
idx = i * self.beam_size + j
|
|
||||||
prefix = tokens_np[idx].tolist()
|
|
||||||
top_k_indices = np.argsort(logprobs_np[idx])[-self.beam_size - 1:][::-1]
|
|
||||||
|
|
||||||
for token_idx in top_k_indices:
|
|
||||||
logprob = logprobs_np[idx, token_idx]
|
|
||||||
new_logprob = sum_logprobs_np[idx] + logprob
|
|
||||||
sequence = tuple(prefix + [int(token_idx)])
|
|
||||||
scores[sequence] = new_logprob
|
|
||||||
sources[sequence] = idx
|
|
||||||
saved = 0
|
|
||||||
for sequence in sorted(scores, key=scores.get, reverse=True):
|
|
||||||
if sequence[-1] == self.eot:
|
|
||||||
finished[sequence] = scores[sequence]
|
|
||||||
else:
|
|
||||||
new_sum_logprobs.append(scores[sequence])
|
|
||||||
next_tokens.append(sequence)
|
|
||||||
source_indices.append(sources[sequence])
|
|
||||||
|
|
||||||
saved += 1
|
|
||||||
if saved == self.beam_size:
|
|
||||||
break
|
|
||||||
|
|
||||||
finished_sequences.append(finished)
|
|
||||||
tokens = mx.array(np.array(next_tokens, dtype=np.int32))
|
|
||||||
sum_logprobs = mx.array(np.array(new_sum_logprobs, dtype=np.float32))
|
|
||||||
self.inference.rearrange_kv_cache(source_indices)
|
|
||||||
assert len(self.finished_sequences) == len(finished_sequences)
|
|
||||||
for previously_finished, newly_finished in zip(
|
|
||||||
self.finished_sequences, finished_sequences
|
|
||||||
):
|
|
||||||
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
|
|
||||||
if len(previously_finished) >= self.max_candidates:
|
|
||||||
break
|
|
||||||
previously_finished[seq] = newly_finished[seq]
|
|
||||||
completed = all(
|
|
||||||
len(sequences) >= self.max_candidates
|
|
||||||
for sequences in self.finished_sequences
|
|
||||||
)
|
|
||||||
|
|
||||||
return tokens, completed
|
|
||||||
|
|
||||||
def finalize(self, preceding_tokens: mx.array, sum_logprobs: mx.array):
|
|
||||||
"""Finalize beam search by selecting best sequences."""
|
|
||||||
preceding_tokens_np = np.array(preceding_tokens)
|
|
||||||
sum_logprobs_np = np.array(sum_logprobs)
|
|
||||||
|
|
||||||
n_audio = preceding_tokens_np.shape[0] // self.beam_size
|
|
||||||
tokens_list: List[List[int]] = [[] for _ in range(n_audio)]
|
|
||||||
sum_logprobs_list: List[float] = [0.0] * n_audio
|
|
||||||
|
|
||||||
for i, sequences in enumerate(self.finished_sequences):
|
|
||||||
if sequences:
|
|
||||||
best_seq = max(sequences, key=sequences.get)
|
|
||||||
tokens_list[i] = list(best_seq)
|
|
||||||
sum_logprobs_list[i] = sequences[best_seq]
|
|
||||||
else:
|
|
||||||
idx = i * self.beam_size
|
|
||||||
tokens_list[i] = preceding_tokens_np[idx].tolist() + [self.eot]
|
|
||||||
sum_logprobs_list[i] = float(sum_logprobs_np[idx])
|
|
||||||
max_len = max(len(t) for t in tokens_list)
|
|
||||||
for i, t in enumerate(tokens_list):
|
|
||||||
tokens_list[i] = t + [self.eot] * (max_len - len(t))
|
|
||||||
|
|
||||||
tokens = mx.array(np.array(tokens_list, dtype=np.int32))
|
|
||||||
return tokens, sum_logprobs_list
|
|
||||||
|
|
||||||
|
|
||||||
class MLXInference:
|
|
||||||
"""MLX inference wrapper for beam search KV cache management."""
|
|
||||||
|
|
||||||
def __init__(self, model, initial_token_length: int):
|
|
||||||
self.model = model
|
|
||||||
self.initial_token_length = initial_token_length
|
|
||||||
self.kv_cache = None
|
|
||||||
|
|
||||||
def rearrange_kv_cache(self, source_indices: List[int]):
|
|
||||||
"""Rearrange KV cache based on beam search source indices."""
|
|
||||||
if self.kv_cache is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
if source_indices == list(range(len(source_indices))):
|
|
||||||
return
|
|
||||||
|
|
||||||
source_indices_mx = mx.array(source_indices, dtype=mx.int32)
|
|
||||||
|
|
||||||
new_cache = []
|
|
||||||
for layer_cache in self.kv_cache:
|
|
||||||
(k, v), (cross_k, cross_v) = layer_cache
|
|
||||||
new_k = k[source_indices_mx]
|
|
||||||
new_v = v[source_indices_mx]
|
|
||||||
new_cache.append(((new_k, new_v), (cross_k, cross_v)))
|
|
||||||
|
|
||||||
self.kv_cache = new_cache
|
|
||||||
|
|
||||||
def logits(
|
|
||||||
self,
|
|
||||||
tokens: mx.array,
|
|
||||||
audio_features: mx.array,
|
|
||||||
) -> Tuple[mx.array, List]:
|
|
||||||
"""Get logits from decoder with KV cache."""
|
|
||||||
logits, self.kv_cache, cross_qk = self.model.decoder(
|
|
||||||
tokens, audio_features, kv_cache=self.kv_cache
|
|
||||||
)
|
|
||||||
return logits, cross_qk
|
|
||||||
|
|
||||||
@@ -1,752 +0,0 @@
|
|||||||
"""
|
|
||||||
MLX whisper AlignAtt streaming decoder
|
|
||||||
"""
|
|
||||||
import logging
|
|
||||||
from time import time
|
|
||||||
from typing import Any, List, Optional, Tuple
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
|
|
||||||
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
|
||||||
|
|
||||||
from whisperlivekit.timed_objects import ASRToken
|
|
||||||
from whisperlivekit.whisper import DecodingOptions, tokenizer
|
|
||||||
from whisperlivekit.whisper.audio import N_FRAMES, N_SAMPLES, TOKENS_PER_SECOND
|
|
||||||
|
|
||||||
from ..config import AlignAttConfig
|
|
||||||
from .decoder_state import MLXDecoderState
|
|
||||||
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
|
|
||||||
|
|
||||||
DEC_PAD = 50257
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class MLXTokenBuffer: #should try to make it heritate from classic simul whisper class
|
|
||||||
"""Token buffer for MLX-based decoding."""
|
|
||||||
|
|
||||||
def __init__(self, text="", tokenizer=None, prefix_token_ids=None):
|
|
||||||
self.text = text
|
|
||||||
self.prefix_token_ids = prefix_token_ids or []
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.pending_token_ids = []
|
|
||||||
|
|
||||||
def as_token_ids(self, tokenizer=None):
|
|
||||||
if tokenizer is None:
|
|
||||||
tokenizer = self.tokenizer
|
|
||||||
if tokenizer is None:
|
|
||||||
raise ValueError("Tokenizer is not set.")
|
|
||||||
return self.prefix_token_ids + tokenizer.encode(self.text)
|
|
||||||
|
|
||||||
def as_mlx_array(self) -> mx.array:
|
|
||||||
"""Return tokens as MLX array."""
|
|
||||||
tok_ids = self.as_token_ids()
|
|
||||||
return mx.array([tok_ids], dtype=mx.int32)
|
|
||||||
|
|
||||||
def as_mlx_array_beam(self, beam: int) -> mx.array:
|
|
||||||
"""Return tokens as MLX array repeated for beam search."""
|
|
||||||
t = self.as_mlx_array()
|
|
||||||
return mx.repeat(t, beam, axis=0)
|
|
||||||
|
|
||||||
def as_text(self):
|
|
||||||
return self.text
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def empty(*a, **kw):
|
|
||||||
return MLXTokenBuffer(*a, **kw)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_text(text, *a, **kw):
|
|
||||||
return MLXTokenBuffer(*a, text=text, **kw)
|
|
||||||
|
|
||||||
def is_empty(self):
|
|
||||||
return self.text is None or self.text == ""
|
|
||||||
|
|
||||||
def trim_words(self, num=1, after=0):
|
|
||||||
"""Trim words from the beginning of the context."""
|
|
||||||
tokenizer = self.tokenizer
|
|
||||||
assert tokenizer is not None, "Tokenizer is not set."
|
|
||||||
|
|
||||||
ids = tokenizer.encode(self.text[after:])
|
|
||||||
words, wids = self.tokenizer.split_to_word_tokens(ids)
|
|
||||||
if not words:
|
|
||||||
return 0
|
|
||||||
self.text = self.text[:after] + "".join(words[num:])
|
|
||||||
return sum(len(wi) for wi in wids[:num])
|
|
||||||
|
|
||||||
def append_token_ids(self, token_ids):
|
|
||||||
"""Append token IDs to the buffer, handling incomplete UTF-8."""
|
|
||||||
tokenizer = self.tokenizer
|
|
||||||
assert tokenizer is not None, "Tokenizer is not set."
|
|
||||||
|
|
||||||
all_tokens = self.pending_token_ids + token_ids
|
|
||||||
decoded = tokenizer.decode(all_tokens)
|
|
||||||
replacement_char = "\ufffd"
|
|
||||||
|
|
||||||
if replacement_char in decoded:
|
|
||||||
if len(all_tokens) > 1:
|
|
||||||
decoded_partial = tokenizer.decode(all_tokens[:-1])
|
|
||||||
if replacement_char not in decoded_partial:
|
|
||||||
self.text += decoded_partial
|
|
||||||
self.pending_token_ids = [all_tokens[-1]]
|
|
||||||
else:
|
|
||||||
self.pending_token_ids = all_tokens
|
|
||||||
else:
|
|
||||||
self.pending_token_ids = all_tokens
|
|
||||||
else:
|
|
||||||
self.text += decoded
|
|
||||||
self.pending_token_ids = []
|
|
||||||
|
|
||||||
|
|
||||||
def mlx_median_filter(x: mx.array, filter_width: int) -> mx.array:
|
|
||||||
"""
|
|
||||||
Apply median filter along the last axis.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x: Input array of shape (..., T)
|
|
||||||
filter_width: Width of the median filter (should be odd)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Filtered array of same shape
|
|
||||||
"""
|
|
||||||
if filter_width <= 1:
|
|
||||||
return x
|
|
||||||
|
|
||||||
pad_width = filter_width // 2
|
|
||||||
shape = x.shape
|
|
||||||
|
|
||||||
left_pad = mx.repeat(x[..., :1], pad_width, axis=-1)
|
|
||||||
right_pad = mx.repeat(x[..., -1:], pad_width, axis=-1)
|
|
||||||
x_padded = mx.concatenate([left_pad, x, right_pad], axis=-1)
|
|
||||||
|
|
||||||
result_shape = list(shape)
|
|
||||||
result = []
|
|
||||||
|
|
||||||
for i in range(shape[-1]):
|
|
||||||
window = x_padded[..., i:i + filter_width]
|
|
||||||
sorted_window = mx.sort(window, axis=-1)
|
|
||||||
median_val = sorted_window[..., filter_width // 2:filter_width // 2 + 1]
|
|
||||||
result.append(median_val)
|
|
||||||
|
|
||||||
return mx.concatenate(result, axis=-1)
|
|
||||||
|
|
||||||
|
|
||||||
class MLXAlignAtt:
|
|
||||||
"""
|
|
||||||
MLX-native Alignment-based Attention decoder for SimulStreaming.
|
|
||||||
|
|
||||||
This class runs entirely on MLX, with no PyTorch dependencies for inference.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def speaker(self):
|
|
||||||
return self.state.speaker
|
|
||||||
|
|
||||||
@speaker.setter
|
|
||||||
def speaker(self, value):
|
|
||||||
self.state.speaker = value
|
|
||||||
|
|
||||||
@property
|
|
||||||
def global_time_offset(self):
|
|
||||||
return self.state.global_time_offset
|
|
||||||
|
|
||||||
@global_time_offset.setter
|
|
||||||
def global_time_offset(self, value):
|
|
||||||
self.state.global_time_offset = value
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
cfg: AlignAttConfig,
|
|
||||||
mlx_model: Any,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Initialize MLX AlignAtt decoder.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cfg: AlignAtt configuration
|
|
||||||
mlx_model: MLX Whisper model (full model, not just encoder)
|
|
||||||
"""
|
|
||||||
self.model = mlx_model
|
|
||||||
self.cfg = cfg
|
|
||||||
|
|
||||||
logger.info(f"MLX Model dimensions: {self.model.dims}")
|
|
||||||
|
|
||||||
self.decode_options = DecodingOptions(
|
|
||||||
language=cfg.language,
|
|
||||||
without_timestamps=True,
|
|
||||||
task=cfg.task
|
|
||||||
)
|
|
||||||
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
|
|
||||||
|
|
||||||
self.max_text_len = self.model.dims.n_text_ctx
|
|
||||||
self.num_decoder_layers = len(self.model.decoder.blocks)
|
|
||||||
|
|
||||||
if self.cfg.max_context_tokens is None:
|
|
||||||
self.max_context_tokens = self.max_text_len
|
|
||||||
else:
|
|
||||||
self.max_context_tokens = self.cfg.max_context_tokens
|
|
||||||
|
|
||||||
# Initialize per-session state
|
|
||||||
self.state = MLXDecoderState()
|
|
||||||
self._init_state(cfg)
|
|
||||||
|
|
||||||
def _init_state(self, cfg: AlignAttConfig):
|
|
||||||
"""Initialize the per-session decoder state."""
|
|
||||||
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
|
||||||
self.state.tokenizer = self.tokenizer
|
|
||||||
self.state.detected_language = cfg.language if cfg.language != "auto" else None
|
|
||||||
self.state.global_time_offset = 0.0
|
|
||||||
self.state.last_attend_frame = -cfg.rewind_threshold
|
|
||||||
self.state.speaker = -1
|
|
||||||
|
|
||||||
if cfg.cif_ckpt_path is None or not cfg.cif_ckpt_path:
|
|
||||||
if cfg.never_fire:
|
|
||||||
self.state.never_fire = True
|
|
||||||
self.state.always_fire = False
|
|
||||||
else:
|
|
||||||
self.state.always_fire = True
|
|
||||||
self.state.never_fire = False
|
|
||||||
else:
|
|
||||||
logger.warning("CIF checkpoint provided but MLX CIF not implemented. Using always_fire=True")
|
|
||||||
self.state.always_fire = True
|
|
||||||
self.state.never_fire = cfg.never_fire
|
|
||||||
|
|
||||||
self._build_alignment_source()
|
|
||||||
|
|
||||||
suppress_tokens = [
|
|
||||||
self.tokenizer.transcribe,
|
|
||||||
self.tokenizer.translate,
|
|
||||||
self.tokenizer.sot,
|
|
||||||
self.tokenizer.sot_prev,
|
|
||||||
self.tokenizer.sot_lm,
|
|
||||||
self.tokenizer.no_timestamps,
|
|
||||||
] + list(self.tokenizer.all_language_tokens)
|
|
||||||
if self.tokenizer.no_speech is not None:
|
|
||||||
suppress_tokens.append(self.tokenizer.no_speech)
|
|
||||||
self.state.suppress_tokens = tuple(sorted(set(suppress_tokens)))
|
|
||||||
logger.debug(f"Suppress tokens: {self.state.suppress_tokens}")
|
|
||||||
|
|
||||||
self.init_tokens()
|
|
||||||
self.init_context()
|
|
||||||
|
|
||||||
self.state.decoder_type = cfg.decoder_type
|
|
||||||
if cfg.decoder_type == "greedy":
|
|
||||||
logger.info("Using MLX greedy decoder")
|
|
||||||
self.state.token_decoder = MLXGreedyDecoder(0.0, self.tokenizer.eot)
|
|
||||||
elif cfg.decoder_type == "beam":
|
|
||||||
logger.info("Using MLX beam decoder")
|
|
||||||
self.state.inference = MLXInference(self.model, self.state.initial_token_length)
|
|
||||||
self.state.token_decoder = MLXBeamSearchDecoder(
|
|
||||||
inference=self.state.inference,
|
|
||||||
eot=self.tokenizer.eot,
|
|
||||||
beam_size=cfg.beam_size
|
|
||||||
)
|
|
||||||
|
|
||||||
def _build_alignment_source(self):
|
|
||||||
"""Build alignment source mapping from model's alignment_heads."""
|
|
||||||
self.state.align_source = {}
|
|
||||||
self.state.num_align_heads = 0
|
|
||||||
|
|
||||||
alignment_heads = self.model.alignment_heads
|
|
||||||
|
|
||||||
if alignment_heads is None:
|
|
||||||
logger.warning("No alignment heads found in model")
|
|
||||||
return
|
|
||||||
|
|
||||||
if hasattr(alignment_heads, 'tolist'):
|
|
||||||
heads_list = alignment_heads.tolist()
|
|
||||||
else:
|
|
||||||
heads_list = np.array(alignment_heads).tolist()
|
|
||||||
|
|
||||||
for layer_rank, head_id in heads_list:
|
|
||||||
layer_rank = int(layer_rank)
|
|
||||||
head_id = int(head_id)
|
|
||||||
heads = self.state.align_source.get(layer_rank, [])
|
|
||||||
heads.append((self.state.num_align_heads, head_id))
|
|
||||||
self.state.align_source[layer_rank] = heads
|
|
||||||
self.state.num_align_heads += 1
|
|
||||||
|
|
||||||
def warmup(self, audio: np.ndarray):
|
|
||||||
"""Warmup the model with sample audio."""
|
|
||||||
try:
|
|
||||||
self.insert_audio(audio)
|
|
||||||
self.infer(is_last=True)
|
|
||||||
self.refresh_segment(complete=True)
|
|
||||||
logger.info("MLX model warmed up successfully")
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"MLX model warmup failed: {e}")
|
|
||||||
|
|
||||||
def create_tokenizer(self, language=None):
|
|
||||||
"""Create tokenizer for the given language."""
|
|
||||||
self.tokenizer = tokenizer.get_tokenizer(
|
|
||||||
multilingual=self.tokenizer_is_multilingual,
|
|
||||||
language=language,
|
|
||||||
num_languages=self.model.num_languages,
|
|
||||||
task=self.decode_options.task
|
|
||||||
)
|
|
||||||
self.state.tokenizer = self.tokenizer
|
|
||||||
|
|
||||||
def init_context(self):
|
|
||||||
"""Initialize context buffer."""
|
|
||||||
kw = {
|
|
||||||
'tokenizer': self.tokenizer,
|
|
||||||
'prefix_token_ids': [self.tokenizer.sot_prev]
|
|
||||||
}
|
|
||||||
self.state.context = MLXTokenBuffer.empty(**kw)
|
|
||||||
if self.cfg.static_init_prompt is not None:
|
|
||||||
self.state.context = MLXTokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
|
|
||||||
if self.cfg.init_prompt is not None:
|
|
||||||
self.state.context.text += self.cfg.init_prompt
|
|
||||||
|
|
||||||
def init_tokens(self):
|
|
||||||
"""Initialize token sequence."""
|
|
||||||
logger.debug(f"init tokens, {len(self.state.segments)}")
|
|
||||||
self.state.initial_tokens = mx.array(
|
|
||||||
[self.tokenizer.sot_sequence_including_notimestamps],
|
|
||||||
dtype=mx.int32
|
|
||||||
)
|
|
||||||
self.state.initial_token_length = self.state.initial_tokens.shape[1]
|
|
||||||
self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
|
|
||||||
logger.debug(f"init tokens after, {len(self.state.segments)}")
|
|
||||||
self.state.tokens = [self.state.initial_tokens]
|
|
||||||
|
|
||||||
def trim_context(self):
|
|
||||||
"""Trim context if too long."""
|
|
||||||
logger.info("Trimming context")
|
|
||||||
c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids)
|
|
||||||
logger.info(f"Context text: {self.state.context.as_text()}")
|
|
||||||
l = sum(t.shape[1] for t in self.state.tokens) + c
|
|
||||||
if self.cfg.static_init_prompt is None:
|
|
||||||
after = 0
|
|
||||||
else:
|
|
||||||
after = len(self.cfg.static_init_prompt)
|
|
||||||
while c > self.max_context_tokens or l > self.max_text_len - 20:
|
|
||||||
t = self.state.context.trim_words(after=after)
|
|
||||||
l -= t
|
|
||||||
c -= t
|
|
||||||
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
|
||||||
if t == 0:
|
|
||||||
break
|
|
||||||
logger.info(f"Context after trim: {self.state.context.text} (len: {l})")
|
|
||||||
|
|
||||||
def refresh_segment(self, complete=False):
|
|
||||||
"""Refresh segment state."""
|
|
||||||
logger.debug("Refreshing segment:")
|
|
||||||
self.init_tokens()
|
|
||||||
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
|
||||||
self.state.cumulative_time_offset = 0.0
|
|
||||||
self.init_context()
|
|
||||||
logger.debug(f"Context: {self.state.context}")
|
|
||||||
if not complete and len(self.state.segments) > 2:
|
|
||||||
self.state.segments = self.state.segments[-2:]
|
|
||||||
else:
|
|
||||||
logger.debug("removing all segments.")
|
|
||||||
self.state.segments = []
|
|
||||||
self.state.log_segments += 1
|
|
||||||
self.state.pending_incomplete_tokens = []
|
|
||||||
|
|
||||||
def fire_at_boundary(self, chunked_encoder_feature: mx.array) -> bool:
|
|
||||||
"""Check if we should fire at word boundary (CIF-based)."""
|
|
||||||
if self.state.always_fire:
|
|
||||||
return True
|
|
||||||
if self.state.never_fire:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _current_tokens(self) -> mx.array:
|
|
||||||
"""Get current token sequence for decoding."""
|
|
||||||
toks = self.state.tokens
|
|
||||||
|
|
||||||
if toks[0].shape[0] == 1:
|
|
||||||
toks[0] = mx.repeat(toks[0], self.cfg.beam_size, axis=0)
|
|
||||||
|
|
||||||
if not self.state.context.is_empty():
|
|
||||||
context_toks = self.state.context.as_mlx_array_beam(self.cfg.beam_size)
|
|
||||||
toks = [context_toks] + toks
|
|
||||||
|
|
||||||
# Concatenate all tokens
|
|
||||||
if len(toks) > 1:
|
|
||||||
current_tokens = mx.concatenate(toks, axis=1)
|
|
||||||
else:
|
|
||||||
current_tokens = toks[0]
|
|
||||||
|
|
||||||
logger.debug("debug print current_tokens:")
|
|
||||||
self.debug_print_tokens(current_tokens)
|
|
||||||
return current_tokens
|
|
||||||
|
|
||||||
def debug_print_tokens(self, tokens: mx.array):
|
|
||||||
"""Debug print token sequences."""
|
|
||||||
tokens_np = np.array(tokens)
|
|
||||||
for i in range(min(self.cfg.beam_size, tokens_np.shape[0])):
|
|
||||||
logger.debug(self.tokenizer.decode_with_timestamps(tokens_np[i].tolist()))
|
|
||||||
|
|
||||||
def segments_len(self) -> float:
|
|
||||||
"""Get total length of audio segments in seconds."""
|
|
||||||
return sum(s.shape[0] for s in self.state.segments) / 16000
|
|
||||||
|
|
||||||
def _apply_minseglen(self) -> bool:
|
|
||||||
"""Check if we have enough audio to process."""
|
|
||||||
segments_len = self.segments_len()
|
|
||||||
if segments_len < self.cfg.audio_min_len:
|
|
||||||
logger.debug("waiting for next segment")
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def insert_audio(self, segment: np.ndarray = None):
|
|
||||||
"""Insert audio segment into buffer."""
|
|
||||||
if segment is not None:
|
|
||||||
if hasattr(segment, 'numpy'):
|
|
||||||
segment = segment.numpy()
|
|
||||||
self.state.segments.append(segment)
|
|
||||||
|
|
||||||
removed_len = 0
|
|
||||||
segments_len = self.segments_len()
|
|
||||||
|
|
||||||
while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
|
|
||||||
removed_len = self.state.segments[0].shape[0] / 16000
|
|
||||||
segments_len -= removed_len
|
|
||||||
self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
|
|
||||||
self.state.cumulative_time_offset += removed_len
|
|
||||||
self.state.segments = self.state.segments[1:]
|
|
||||||
logger.debug(f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, cumulative offset: {self.state.cumulative_time_offset:.2f}s")
|
|
||||||
|
|
||||||
if len(self.state.tokens) > 1:
|
|
||||||
# Convert MLX array to list for context
|
|
||||||
token_list = np.array(self.state.tokens[1][0, :]).tolist()
|
|
||||||
self.state.context.append_token_ids(token_list)
|
|
||||||
self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]
|
|
||||||
|
|
||||||
return removed_len
|
|
||||||
|
|
||||||
def _clean_cache(self):
|
|
||||||
"""Clean the kv_cache after each inference step."""
|
|
||||||
self.state.clean_cache()
|
|
||||||
|
|
||||||
def _suppress_tokens(self, logits: mx.array) -> mx.array:
|
|
||||||
"""Apply token suppression to logits."""
|
|
||||||
if self.state.suppress_tokens:
|
|
||||||
suppress_indices = mx.array(list(self.state.suppress_tokens), dtype=mx.int32)
|
|
||||||
logits = logits.at[:, suppress_indices].add(-float('inf'))
|
|
||||||
return logits
|
|
||||||
|
|
||||||
def lang_id(self, encoder_features: mx.array) -> Tuple[mx.array, List[dict]]:
|
|
||||||
"""Language detection from encoder features."""
|
|
||||||
n_audio = encoder_features.shape[0]
|
|
||||||
x = mx.array([[self.tokenizer.sot]] * n_audio, dtype=mx.int32)
|
|
||||||
|
|
||||||
logits, _, _ = self.model.decoder(x, encoder_features, kv_cache=None)
|
|
||||||
logits = logits[:, 0]
|
|
||||||
|
|
||||||
mask = mx.ones(logits.shape[-1], dtype=mx.bool_)
|
|
||||||
language_token_indices = mx.array(list(self.tokenizer.all_language_tokens), dtype=mx.int32)
|
|
||||||
mask = mask.at[language_token_indices].add(False)
|
|
||||||
|
|
||||||
logits = mx.where(mask, mx.array(-float('inf')), logits)
|
|
||||||
|
|
||||||
language_tokens = mx.argmax(logits, axis=-1)
|
|
||||||
language_token_probs = mx.softmax(logits, axis=-1)
|
|
||||||
|
|
||||||
probs_np = np.array(language_token_probs)
|
|
||||||
|
|
||||||
language_probs = [
|
|
||||||
{
|
|
||||||
c: float(probs_np[i, j])
|
|
||||||
for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
|
|
||||||
}
|
|
||||||
for i in range(n_audio)
|
|
||||||
]
|
|
||||||
|
|
||||||
self._clean_cache()
|
|
||||||
return language_tokens, language_probs
|
|
||||||
|
|
||||||
def infer(self, is_last: bool = False) -> List[ASRToken]:
|
|
||||||
"""
|
|
||||||
Main inference method.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
is_last: Whether this is the final chunk
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of timestamped ASR tokens
|
|
||||||
"""
|
|
||||||
new_segment = True
|
|
||||||
|
|
||||||
if len(self.state.segments) == 0:
|
|
||||||
logger.debug("No segments, nothing to do")
|
|
||||||
return []
|
|
||||||
|
|
||||||
if not self._apply_minseglen():
|
|
||||||
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
|
||||||
return []
|
|
||||||
|
|
||||||
if len(self.state.segments) > 1:
|
|
||||||
input_segments = np.concatenate(self.state.segments, axis=0)
|
|
||||||
else:
|
|
||||||
input_segments = self.state.segments[0]
|
|
||||||
|
|
||||||
beg_encode = time()
|
|
||||||
|
|
||||||
mlx_mel_padded = mlx_log_mel_spectrogram(
|
|
||||||
audio=input_segments,
|
|
||||||
n_mels=self.model.dims.n_mels,
|
|
||||||
padding=N_SAMPLES
|
|
||||||
)
|
|
||||||
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
|
|
||||||
encoder_feature = self.model.encoder(mlx_mel[None])
|
|
||||||
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0]) / 2)
|
|
||||||
|
|
||||||
mx.eval(encoder_feature)
|
|
||||||
|
|
||||||
end_encode = time()
|
|
||||||
logger.debug(f'MLX Encoder duration: {end_encode - beg_encode:.3f}s')
|
|
||||||
|
|
||||||
if self.cfg.language == "auto" and self.state.detected_language is None and self.state.first_timestamp:
|
|
||||||
seconds_since_start = self.segments_len() - self.state.first_timestamp
|
|
||||||
if seconds_since_start >= 2.0:
|
|
||||||
language_tokens, language_probs = self.lang_id(encoder_feature)
|
|
||||||
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
|
||||||
print(f"Detected language: {top_lan} with p={p:.4f}")
|
|
||||||
self.create_tokenizer(top_lan)
|
|
||||||
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
|
||||||
self.state.cumulative_time_offset = 0.0
|
|
||||||
self.init_tokens()
|
|
||||||
self.init_context()
|
|
||||||
self.state.detected_language = top_lan
|
|
||||||
logger.info(f"Tokenizer language: {self.tokenizer.language}")
|
|
||||||
|
|
||||||
self.trim_context()
|
|
||||||
current_tokens = self._current_tokens()
|
|
||||||
|
|
||||||
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
|
|
||||||
|
|
||||||
sum_logprobs = mx.zeros((self.cfg.beam_size,), dtype=mx.float32)
|
|
||||||
completed = False
|
|
||||||
|
|
||||||
attn_of_alignment_heads = None
|
|
||||||
most_attended_frame = None
|
|
||||||
|
|
||||||
token_len_before_decoding = current_tokens.shape[1]
|
|
||||||
|
|
||||||
l_absolute_timestamps = []
|
|
||||||
accumulated_cross_attns = []
|
|
||||||
|
|
||||||
audio_duration_s = self.segments_len()
|
|
||||||
max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0))
|
|
||||||
tokens_produced_this_chunk = 0
|
|
||||||
|
|
||||||
while not completed and current_tokens.shape[1] < self.max_text_len:
|
|
||||||
tokens_produced_this_chunk += 1
|
|
||||||
|
|
||||||
if tokens_produced_this_chunk > max_tokens_per_chunk:
|
|
||||||
logger.warning(f"[Loop Detection] Too many tokens ({tokens_produced_this_chunk}) for {audio_duration_s:.2f}s audio. Breaking.")
|
|
||||||
current_tokens = current_tokens[:, :token_len_before_decoding]
|
|
||||||
break
|
|
||||||
|
|
||||||
if new_segment:
|
|
||||||
tokens_for_logits = current_tokens
|
|
||||||
else:
|
|
||||||
tokens_for_logits = current_tokens[:, -1:]
|
|
||||||
|
|
||||||
if self.state.decoder_type == "greedy":
|
|
||||||
logits, self.state.kv_cache, cross_qk = self.model.decoder(
|
|
||||||
tokens_for_logits, encoder_feature, kv_cache=self.state.kv_cache
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits, cross_qk = self.state.inference.logits(tokens_for_logits, encoder_feature)
|
|
||||||
|
|
||||||
mx.eval(logits)
|
|
||||||
|
|
||||||
accumulated_cross_attns.append(cross_qk)
|
|
||||||
|
|
||||||
if new_segment and self.tokenizer.no_speech is not None:
|
|
||||||
probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1)
|
|
||||||
no_speech_probs = np.array(probs_at_sot[:, self.tokenizer.no_speech]).tolist()
|
|
||||||
if no_speech_probs[0] > self.cfg.nonspeech_prob:
|
|
||||||
logger.info("no speech, stop")
|
|
||||||
break
|
|
||||||
|
|
||||||
logits = logits[:, -1, :] # Last token logits
|
|
||||||
|
|
||||||
# Suppress tokens at segment start
|
|
||||||
if new_segment:
|
|
||||||
blank_tokens = self.tokenizer.encode(" ") + [self.tokenizer.eot]
|
|
||||||
logits = logits.at[:, blank_tokens].add(-float('inf'))
|
|
||||||
new_segment = False
|
|
||||||
|
|
||||||
logits = self._suppress_tokens(logits)
|
|
||||||
|
|
||||||
current_tokens, completed = self.state.token_decoder.update(
|
|
||||||
current_tokens, logits, sum_logprobs
|
|
||||||
)
|
|
||||||
mx.eval(current_tokens)
|
|
||||||
|
|
||||||
logger.debug(f"Decoding completed: {completed}")
|
|
||||||
self.debug_print_tokens(current_tokens)
|
|
||||||
|
|
||||||
attn_of_alignment_heads = self._process_cross_attention(
|
|
||||||
accumulated_cross_attns, content_mel_len
|
|
||||||
)
|
|
||||||
|
|
||||||
most_attended_frames = mx.argmax(attn_of_alignment_heads[:, -1, :], axis=-1)
|
|
||||||
most_attended_frames_np = np.array(most_attended_frames)
|
|
||||||
|
|
||||||
absolute_timestamps = [
|
|
||||||
(frame * 0.02 + self.state.cumulative_time_offset)
|
|
||||||
for frame in most_attended_frames_np.tolist()
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.debug(str(most_attended_frames_np.tolist()) + " most att frames")
|
|
||||||
logger.debug(f"Absolute timestamps: {absolute_timestamps}")
|
|
||||||
|
|
||||||
most_attended_frame = int(most_attended_frames_np[0])
|
|
||||||
l_absolute_timestamps.append(absolute_timestamps[0])
|
|
||||||
|
|
||||||
if completed:
|
|
||||||
current_tokens = current_tokens[:, :-1]
|
|
||||||
break
|
|
||||||
if not is_last and self.state.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
|
|
||||||
current_tokens_np = np.array(current_tokens)
|
|
||||||
if current_tokens.shape[1] > 1 and current_tokens_np[0, -2] >= DEC_PAD:
|
|
||||||
logger.debug("omit rewinding from special tokens")
|
|
||||||
self.state.last_attend_frame = most_attended_frame
|
|
||||||
else:
|
|
||||||
logger.debug(f"[rewind detected] current: {most_attended_frame}, last: {self.state.last_attend_frame}")
|
|
||||||
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
|
||||||
current_tokens = mx.concatenate(self.state.tokens, axis=1) if len(self.state.tokens) > 0 else self.state.tokens[0]
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
self.state.last_attend_frame = most_attended_frame
|
|
||||||
if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
|
|
||||||
logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
|
|
||||||
current_tokens = current_tokens[:, :-1]
|
|
||||||
break
|
|
||||||
tokens_to_split = np.array(current_tokens[0, token_len_before_decoding:]).tolist()
|
|
||||||
if self.state.pending_incomplete_tokens:
|
|
||||||
logger.debug(f"[UTF-8 Fix] Prepending pending tokens: {self.state.pending_incomplete_tokens}")
|
|
||||||
tokens_to_split = self.state.pending_incomplete_tokens + tokens_to_split
|
|
||||||
|
|
||||||
if fire_detected or is_last:
|
|
||||||
new_hypothesis = tokens_to_split
|
|
||||||
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
|
|
||||||
else:
|
|
||||||
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split)
|
|
||||||
if len(split_words) > 1:
|
|
||||||
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
|
|
||||||
else:
|
|
||||||
new_hypothesis = []
|
|
||||||
|
|
||||||
logger.debug(f"new_hypothesis: {new_hypothesis}")
|
|
||||||
new_tokens = mx.array([new_hypothesis], dtype=mx.int32)
|
|
||||||
new_tokens = mx.repeat(new_tokens, self.cfg.beam_size, axis=0)
|
|
||||||
self.state.tokens.append(new_tokens)
|
|
||||||
|
|
||||||
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
|
|
||||||
|
|
||||||
self._clean_cache()
|
|
||||||
|
|
||||||
if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None:
|
|
||||||
self.state.first_timestamp = l_absolute_timestamps[0]
|
|
||||||
timestamped_words = []
|
|
||||||
timestamp_idx = 0
|
|
||||||
replacement_char = "\ufffd"
|
|
||||||
|
|
||||||
for word, word_tokens in zip(split_words, split_tokens):
|
|
||||||
if replacement_char in word:
|
|
||||||
logger.warning(f"[UTF-8 Filter] Skipping: {repr(word)}")
|
|
||||||
timestamp_idx += len(word_tokens)
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
|
||||||
except IndexError:
|
|
||||||
pass
|
|
||||||
timestamp_idx += len(word_tokens)
|
|
||||||
|
|
||||||
timestamp_entry = ASRToken(
|
|
||||||
start=round(current_timestamp, 2),
|
|
||||||
end=round(current_timestamp + 0.1, 2),
|
|
||||||
text=word,
|
|
||||||
speaker=self.state.speaker,
|
|
||||||
detected_language=self.state.detected_language
|
|
||||||
).with_offset(self.state.global_time_offset)
|
|
||||||
timestamped_words.append(timestamp_entry)
|
|
||||||
self.state.pending_incomplete_tokens = []
|
|
||||||
MAX_PENDING_TOKENS = 10
|
|
||||||
if split_words and replacement_char in split_words[-1]:
|
|
||||||
if len(split_tokens[-1]) <= MAX_PENDING_TOKENS:
|
|
||||||
self.state.pending_incomplete_tokens = split_tokens[-1]
|
|
||||||
logger.debug(f"[UTF-8 Fix] Holding incomplete tokens")
|
|
||||||
else:
|
|
||||||
logger.warning(f"[UTF-8 Fix] Skipping too many tokens")
|
|
||||||
|
|
||||||
return timestamped_words
|
|
||||||
|
|
||||||
def _process_cross_attention(
|
|
||||||
self,
|
|
||||||
cross_attns: List[List[mx.array]],
|
|
||||||
content_mel_len: int
|
|
||||||
) -> mx.array:
|
|
||||||
"""
|
|
||||||
Process cross-attention weights for alignment.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cross_attns: List of cross-attention from each forward pass
|
|
||||||
Each element is a list of mx.arrays per layer
|
|
||||||
content_mel_len: Length of actual audio content
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Processed attention tensor, shape (batch, seq_len, content_mel_len)
|
|
||||||
"""
|
|
||||||
attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
|
|
||||||
num_decoder_layers = self.num_decoder_layers
|
|
||||||
|
|
||||||
if cross_attns and isinstance(cross_attns[0], list):
|
|
||||||
flattened_attns = [attn for layer_list in cross_attns for attn in layer_list]
|
|
||||||
else:
|
|
||||||
flattened_attns = cross_attns
|
|
||||||
|
|
||||||
for idx, attn_mat in enumerate(flattened_attns):
|
|
||||||
if attn_mat is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
layer_rank = idx % num_decoder_layers
|
|
||||||
align_heads_in_layer = self.state.align_source.get(layer_rank, [])
|
|
||||||
|
|
||||||
if len(align_heads_in_layer) == 0:
|
|
||||||
continue
|
|
||||||
attn_mat = mx.softmax(attn_mat, axis=-1)
|
|
||||||
|
|
||||||
for align_head_rank, head_id in align_heads_in_layer:
|
|
||||||
if self.cfg.beam_size == 1:
|
|
||||||
if attn_mat.ndim == 4:
|
|
||||||
a = attn_mat[0, head_id, :, :]
|
|
||||||
else:
|
|
||||||
a = attn_mat[head_id, :, :]
|
|
||||||
a = a[None, :, :]
|
|
||||||
else:
|
|
||||||
a = attn_mat[:, head_id, :, :]
|
|
||||||
attn_of_alignment_heads[align_head_rank].append(a)
|
|
||||||
tmp = []
|
|
||||||
for mat in attn_of_alignment_heads:
|
|
||||||
if mat:
|
|
||||||
t = mx.concatenate(mat, axis=1)
|
|
||||||
tmp.append(t)
|
|
||||||
|
|
||||||
if not tmp:
|
|
||||||
return mx.zeros((self.cfg.beam_size, 1, content_mel_len))
|
|
||||||
attn_of_alignment_heads = mx.stack(tmp, axis=1)
|
|
||||||
|
|
||||||
std = mx.std(attn_of_alignment_heads, axis=-2, keepdims=True)
|
|
||||||
mean = mx.mean(attn_of_alignment_heads, axis=-2, keepdims=True)
|
|
||||||
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)
|
|
||||||
|
|
||||||
attn_of_alignment_heads = mlx_median_filter(attn_of_alignment_heads, 7)
|
|
||||||
|
|
||||||
attn_of_alignment_heads = mx.mean(attn_of_alignment_heads, axis=1)
|
|
||||||
|
|
||||||
attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
|
|
||||||
|
|
||||||
mx.eval(attn_of_alignment_heads)
|
|
||||||
return attn_of_alignment_heads
|
|
||||||
|
|
||||||
@@ -68,40 +68,4 @@ def load_mlx_encoder(
|
|||||||
|
|
||||||
model.update(encoder_weights)
|
model.update(encoder_weights)
|
||||||
mx.eval(model.parameters())
|
mx.eval(model.parameters())
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def load_mlx_model(
|
|
||||||
path_or_hf_repo: str,
|
|
||||||
dtype: mx.Dtype = mx.float32,
|
|
||||||
) -> whisper.Whisper:
|
|
||||||
model_path = Path(path_or_hf_repo)
|
|
||||||
if not model_path.exists():
|
|
||||||
model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
|
|
||||||
|
|
||||||
with open(str(model_path / "config.json"), "r") as f:
|
|
||||||
config = json.loads(f.read())
|
|
||||||
config.pop("model_type", None)
|
|
||||||
quantization = config.pop("quantization", None)
|
|
||||||
|
|
||||||
model_args = whisper.ModelDimensions(**config)
|
|
||||||
|
|
||||||
wf = model_path / "weights.safetensors"
|
|
||||||
if not wf.exists():
|
|
||||||
wf = model_path / "weights.npz"
|
|
||||||
weights = mx.load(str(wf))
|
|
||||||
|
|
||||||
model = whisper.Whisper(model_args, dtype)
|
|
||||||
|
|
||||||
if quantization is not None:
|
|
||||||
class_predicate = (
|
|
||||||
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
|
|
||||||
and f"{p}.scales" in weights
|
|
||||||
)
|
|
||||||
nn.quantize(model, **quantization, class_predicate=class_predicate)
|
|
||||||
|
|
||||||
weights = tree_unflatten(list(weights.items()))
|
|
||||||
|
|
||||||
model.update(weights)
|
|
||||||
mx.eval(model.parameters())
|
|
||||||
return model
|
return model
|
||||||
@@ -626,10 +626,8 @@ class AlignAtt:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
||||||
except IndexError:
|
except:
|
||||||
# Use last timestamp if index out of range
|
pass
|
||||||
logger.warning(f"Timestamp index {timestamp_idx} out of range, using last timestamp")
|
|
||||||
current_timestamp = l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0
|
|
||||||
timestamp_idx += len(word_tokens)
|
timestamp_idx += len(word_tokens)
|
||||||
|
|
||||||
timestamp_entry = ASRToken(
|
timestamp_entry = ASRToken(
|
||||||
|
|||||||
@@ -1,139 +0,0 @@
|
|||||||
"""
|
|
||||||
Thread Safety Configuration for WhisperLiveKit
|
|
||||||
|
|
||||||
This module provides thread safety configuration and utilities.
|
|
||||||
|
|
||||||
Environment Variables:
|
|
||||||
WHISPERLIVEKIT_MODEL_LOCK: Enable/disable model locking (default: 1)
|
|
||||||
Set to "0" to disable for single-connection deployments
|
|
||||||
|
|
||||||
WHISPERLIVEKIT_LOCK_TIMEOUT: Lock acquisition timeout in seconds (default: 30)
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
# Enable model locking (default)
|
|
||||||
export WHISPERLIVEKIT_MODEL_LOCK=1
|
|
||||||
|
|
||||||
# Disable for single-connection deployment
|
|
||||||
export WHISPERLIVEKIT_MODEL_LOCK=0
|
|
||||||
|
|
||||||
# Custom timeout
|
|
||||||
export WHISPERLIVEKIT_LOCK_TIMEOUT=60
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Configuration
|
|
||||||
USE_MODEL_LOCK = os.environ.get("WHISPERLIVEKIT_MODEL_LOCK", "1") == "1"
|
|
||||||
LOCK_TIMEOUT = float(os.environ.get("WHISPERLIVEKIT_LOCK_TIMEOUT", "30.0"))
|
|
||||||
|
|
||||||
# Global model lock
|
|
||||||
_model_lock = threading.Lock()
|
|
||||||
|
|
||||||
# Log configuration on import
|
|
||||||
if USE_MODEL_LOCK:
|
|
||||||
logger.info(f"Model locking ENABLED (timeout: {LOCK_TIMEOUT}s)")
|
|
||||||
logger.info("For single-connection deployments, set WHISPERLIVEKIT_MODEL_LOCK=0")
|
|
||||||
else:
|
|
||||||
logger.warning("Model locking DISABLED - only safe for single-connection deployments")
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_lock():
|
|
||||||
"""Get the global model lock instance"""
|
|
||||||
return _model_lock
|
|
||||||
|
|
||||||
|
|
||||||
def acquire_model_lock(timeout=None):
|
|
||||||
"""
|
|
||||||
Acquire model lock with timeout.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
timeout: Lock acquisition timeout (default: use LOCK_TIMEOUT)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if lock acquired, False on timeout
|
|
||||||
"""
|
|
||||||
if not USE_MODEL_LOCK:
|
|
||||||
return True
|
|
||||||
|
|
||||||
timeout = timeout or LOCK_TIMEOUT
|
|
||||||
acquired = _model_lock.acquire(timeout=timeout)
|
|
||||||
|
|
||||||
if not acquired:
|
|
||||||
logger.error(f"Failed to acquire model lock within {timeout}s")
|
|
||||||
|
|
||||||
return acquired
|
|
||||||
|
|
||||||
|
|
||||||
def release_model_lock():
|
|
||||||
"""Release model lock"""
|
|
||||||
if not USE_MODEL_LOCK:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
_model_lock.release()
|
|
||||||
except RuntimeError:
|
|
||||||
# Lock not held - this is fine
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ModelLockContext:
|
|
||||||
"""Context manager for model lock"""
|
|
||||||
|
|
||||||
def __init__(self, timeout=None):
|
|
||||||
self.timeout = timeout
|
|
||||||
self.acquired = False
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
self.acquired = acquire_model_lock(self.timeout)
|
|
||||||
return self.acquired
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
if self.acquired:
|
|
||||||
release_model_lock()
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# Concurrency recommendations
|
|
||||||
RECOMMENDED_CONNECTIONS_PER_WORKER = 1 if USE_MODEL_LOCK else 1
|
|
||||||
RECOMMENDED_WORKERS = 4
|
|
||||||
|
|
||||||
def print_deployment_recommendations():
|
|
||||||
"""Print recommended deployment configuration"""
|
|
||||||
print("\n" + "="*60)
|
|
||||||
print("WhisperLiveKit Deployment Recommendations")
|
|
||||||
print("="*60)
|
|
||||||
|
|
||||||
if USE_MODEL_LOCK:
|
|
||||||
print("⚠️ Model locking is ENABLED")
|
|
||||||
print(" This serializes inference across connections.")
|
|
||||||
print()
|
|
||||||
print("Recommended deployment:")
|
|
||||||
print(f" gunicorn -w {RECOMMENDED_WORKERS} \\")
|
|
||||||
print(" -k uvicorn.workers.UvicornWorker \\")
|
|
||||||
print(" --worker-connections 1 \\")
|
|
||||||
print(" whisperlivekit.basic_server:app")
|
|
||||||
print()
|
|
||||||
print("Expected capacity:")
|
|
||||||
print(f" - {RECOMMENDED_WORKERS} concurrent users (1 per worker)")
|
|
||||||
print(f" - Memory: ~{RECOMMENDED_WORKERS}x model size")
|
|
||||||
else:
|
|
||||||
print("✅ Model locking is DISABLED")
|
|
||||||
print(" ⚠️ ONLY safe for single-connection deployments")
|
|
||||||
print()
|
|
||||||
print("Recommended deployment:")
|
|
||||||
print(" uvicorn whisperlivekit.basic_server:app \\")
|
|
||||||
print(" --host 0.0.0.0 --port 8000 \\")
|
|
||||||
print(" --workers 1")
|
|
||||||
print()
|
|
||||||
print("Expected capacity:")
|
|
||||||
print(" - 1 concurrent user only")
|
|
||||||
|
|
||||||
print("="*60 + "\n")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print_deployment_recommendations()
|
|
||||||
@@ -107,6 +107,21 @@ class Silence():
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SegmentBuffer:
|
||||||
|
"""Per-segment buffer for ephemeral/unvalidated content."""
|
||||||
|
transcription: str = ''
|
||||||
|
diarization: str = ''
|
||||||
|
translation: str = ''
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, str]:
|
||||||
|
return {
|
||||||
|
'transcription': self.transcription,
|
||||||
|
'diarization': self.diarization,
|
||||||
|
'translation': self.translation
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Segment(TimedText):
|
class Segment(TimedText):
|
||||||
"""Generic contiguous span built from tokens or silence markers."""
|
"""Generic contiguous span built from tokens or silence markers."""
|
||||||
@@ -114,14 +129,18 @@ class Segment(TimedText):
|
|||||||
end: Optional[float]
|
end: Optional[float]
|
||||||
text: Optional[str]
|
text: Optional[str]
|
||||||
speaker: Optional[str]
|
speaker: Optional[str]
|
||||||
|
id: Optional[int] = None
|
||||||
|
start_speaker: Optional[float] = None
|
||||||
tokens: Optional[ASRToken] = None
|
tokens: Optional[ASRToken] = None
|
||||||
translation: Optional[Translation] = None
|
translation: Optional[Translation] = None
|
||||||
|
buffer: Optional[SegmentBuffer] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_tokens(
|
def from_tokens(
|
||||||
cls,
|
cls,
|
||||||
tokens: List[Union[ASRToken, Silence]],
|
tokens: List[Union[ASRToken, Silence]],
|
||||||
is_silence: bool = False
|
is_silence: bool = False,
|
||||||
|
segment_id: Optional[int] = None
|
||||||
) -> Optional["Segment"]:
|
) -> Optional["Segment"]:
|
||||||
"""Return a normalized segment representing the provided tokens."""
|
"""Return a normalized segment representing the provided tokens."""
|
||||||
if not tokens:
|
if not tokens:
|
||||||
@@ -134,7 +153,9 @@ class Segment(TimedText):
|
|||||||
start=start_token.start,
|
start=start_token.start,
|
||||||
end=end_token.end,
|
end=end_token.end,
|
||||||
text=None,
|
text=None,
|
||||||
speaker=-2
|
speaker=-2,
|
||||||
|
id=segment_id,
|
||||||
|
start_speaker=start_token.start
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return cls(
|
return cls(
|
||||||
@@ -142,6 +163,8 @@ class Segment(TimedText):
|
|||||||
end=end_token.end,
|
end=end_token.end,
|
||||||
text=''.join(token.text for token in tokens),
|
text=''.join(token.text for token in tokens),
|
||||||
speaker=-1,
|
speaker=-1,
|
||||||
|
id=segment_id,
|
||||||
|
start_speaker=start_token.start,
|
||||||
detected_language=start_token.detected_language
|
detected_language=start_token.detected_language
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -150,17 +173,18 @@ class Segment(TimedText):
|
|||||||
return self.speaker == -2
|
return self.speaker == -2
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""Serialize the segment for frontend consumption."""
|
"""Serialize the segment for frontend consumption (new API format)."""
|
||||||
_dict: Dict[str, Any] = {
|
_dict: Dict[str, Any] = {
|
||||||
|
'id': self.id if self.id is not None else 0,
|
||||||
'speaker': int(self.speaker) if self.speaker != -1 else 1,
|
'speaker': int(self.speaker) if self.speaker != -1 else 1,
|
||||||
'text': self.text,
|
'text': self.text or '',
|
||||||
|
'start_speaker': format_time(self.start_speaker) if self.start_speaker is not None else format_time(self.start),
|
||||||
'start': format_time(self.start),
|
'start': format_time(self.start),
|
||||||
'end': format_time(self.end),
|
'end': format_time(self.end),
|
||||||
|
'language': self.detected_language,
|
||||||
|
'translation': self.translation or '',
|
||||||
|
'buffer': self.buffer.to_dict() if self.buffer else SegmentBuffer().to_dict()
|
||||||
}
|
}
|
||||||
if self.translation:
|
|
||||||
_dict['translation'] = self.translation
|
|
||||||
if self.detected_language:
|
|
||||||
_dict['detected_language'] = self.detected_language
|
|
||||||
return _dict
|
return _dict
|
||||||
|
|
||||||
|
|
||||||
@@ -179,23 +203,20 @@ class SilentSegment(Segment):
|
|||||||
class FrontData():
|
class FrontData():
|
||||||
status: str = ''
|
status: str = ''
|
||||||
error: str = ''
|
error: str = ''
|
||||||
lines: list[Segment] = field(default_factory=list)
|
segments: list[Segment] = field(default_factory=list)
|
||||||
buffer_transcription: str = ''
|
|
||||||
buffer_diarization: str = ''
|
|
||||||
buffer_translation: str = ''
|
|
||||||
remaining_time_transcription: float = 0.
|
remaining_time_transcription: float = 0.
|
||||||
remaining_time_diarization: float = 0.
|
remaining_time_diarization: float = 0.
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""Serialize the front-end data payload."""
|
"""Serialize the front-end data payload (new API format)."""
|
||||||
_dict: Dict[str, Any] = {
|
_dict: Dict[str, Any] = {
|
||||||
|
'type': 'transcript_update',
|
||||||
'status': self.status,
|
'status': self.status,
|
||||||
'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)],
|
'segments': [seg.to_dict() for seg in self.segments if (seg.text or seg.speaker == -2)],
|
||||||
'buffer_transcription': self.buffer_transcription,
|
'metadata': {
|
||||||
'buffer_diarization': self.buffer_diarization,
|
'remaining_time_transcription': self.remaining_time_transcription,
|
||||||
'buffer_translation': self.buffer_translation,
|
'remaining_time_diarization': self.remaining_time_diarization,
|
||||||
'remaining_time_transcription': self.remaining_time_transcription,
|
}
|
||||||
'remaining_time_diarization': self.remaining_time_diarization,
|
|
||||||
}
|
}
|
||||||
if self.error:
|
if self.error:
|
||||||
_dict['error'] = self.error
|
_dict['error'] = self.error
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
from time import time
|
from time import time
|
||||||
from typing import Any, List, Optional, Tuple, Union
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from whisperlivekit.timed_objects import (ASRToken, Segment, PuncSegment, Silence,
|
from whisperlivekit.timed_objects import (ASRToken, Segment, SegmentBuffer, PuncSegment, Silence,
|
||||||
SilentSegment, SpeakerSegment,
|
SilentSegment, SpeakerSegment,
|
||||||
TimedText)
|
TimedText)
|
||||||
|
|
||||||
|
|
||||||
class TokensAlignment:
|
class TokensAlignment:
|
||||||
|
# Minimum duration (seconds) for a silence to be displayed
|
||||||
|
MIN_SILENCE_DISPLAY_DURATION = 2.0
|
||||||
|
|
||||||
def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None:
|
def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None:
|
||||||
self.state = state
|
self.state = state
|
||||||
@@ -33,7 +35,15 @@ class TokensAlignment:
|
|||||||
|
|
||||||
self.last_punctuation = None
|
self.last_punctuation = None
|
||||||
self.last_uncompleted_punc_segment: PuncSegment = None
|
self.last_uncompleted_punc_segment: PuncSegment = None
|
||||||
self.unvalidated_tokens: PuncSegment = []
|
self.tokens_after_last_punctuation: PuncSegment = []
|
||||||
|
self.all_validated_segments: List[Segment] = []
|
||||||
|
|
||||||
|
# For token-by-token validation with diarization
|
||||||
|
self.pending_tokens: List[ASRToken] = []
|
||||||
|
self.last_validated_token_end: float = 0.0
|
||||||
|
|
||||||
|
# Segment ID counter for the new API
|
||||||
|
self._next_segment_id: int = 1
|
||||||
|
|
||||||
def update(self) -> None:
|
def update(self) -> None:
|
||||||
"""Drain state buffers into the running alignment context."""
|
"""Drain state buffers into the running alignment context."""
|
||||||
@@ -49,8 +59,6 @@ class TokensAlignment:
|
|||||||
|
|
||||||
def add_translation(self, segment: Segment) -> None:
|
def add_translation(self, segment: Segment) -> None:
|
||||||
"""Append translated text segments that overlap with a segment."""
|
"""Append translated text segments that overlap with a segment."""
|
||||||
if segment.translation is None:
|
|
||||||
segment.translation = ''
|
|
||||||
for ts in self.all_translation_segments:
|
for ts in self.all_translation_segments:
|
||||||
if ts.is_within(segment):
|
if ts.is_within(segment):
|
||||||
segment.translation += ts.text + (self.sep if ts.text else '')
|
segment.translation += ts.text + (self.sep if ts.text else '')
|
||||||
@@ -93,11 +101,11 @@ class TokensAlignment:
|
|||||||
def compute_new_punctuations_segments(self) -> List[PuncSegment]:
|
def compute_new_punctuations_segments(self) -> List[PuncSegment]:
|
||||||
new_punc_segments = []
|
new_punc_segments = []
|
||||||
segment_start_idx = 0
|
segment_start_idx = 0
|
||||||
self.unvalidated_tokens += self.new_tokens
|
self.tokens_after_last_punctuation += self.new_tokens
|
||||||
for i, token in enumerate(self.unvalidated_tokens):
|
for i, token in enumerate(self.tokens_after_last_punctuation):
|
||||||
if token.is_silence():
|
if token.is_silence():
|
||||||
previous_segment = PuncSegment.from_tokens(
|
previous_segment = PuncSegment.from_tokens(
|
||||||
tokens=self.unvalidated_tokens[segment_start_idx: i],
|
tokens=self.tokens_after_last_punctuation[segment_start_idx: i],
|
||||||
)
|
)
|
||||||
if previous_segment:
|
if previous_segment:
|
||||||
new_punc_segments.append(previous_segment)
|
new_punc_segments.append(previous_segment)
|
||||||
@@ -110,12 +118,12 @@ class TokensAlignment:
|
|||||||
else:
|
else:
|
||||||
if token.has_punctuation():
|
if token.has_punctuation():
|
||||||
segment = PuncSegment.from_tokens(
|
segment = PuncSegment.from_tokens(
|
||||||
tokens=self.unvalidated_tokens[segment_start_idx: i+1],
|
tokens=self.tokens_after_last_punctuation[segment_start_idx: i+1],
|
||||||
)
|
)
|
||||||
new_punc_segments.append(segment)
|
new_punc_segments.append(segment)
|
||||||
segment_start_idx = i+1
|
segment_start_idx = i+1
|
||||||
|
|
||||||
self.unvalidated_tokens = self.unvalidated_tokens[segment_start_idx:]
|
self.tokens_after_last_punctuation = self.tokens_after_last_punctuation[segment_start_idx:]
|
||||||
return new_punc_segments
|
return new_punc_segments
|
||||||
|
|
||||||
|
|
||||||
@@ -140,64 +148,189 @@ class TokensAlignment:
|
|||||||
|
|
||||||
return max(0, end - start)
|
return max(0, end - start)
|
||||||
|
|
||||||
def get_lines_diarization(self) -> Tuple[List[Segment], str]:
|
def _get_speaker_for_token(self, token: ASRToken, diarization_segments: List[SpeakerSegment]) -> Optional[int]:
|
||||||
"""Build segments when diarization is enabled and track overflow buffer."""
|
"""Get speaker ID for a token based on diarization overlap. Returns None if not covered."""
|
||||||
diarization_buffer = ''
|
if not diarization_segments:
|
||||||
punctuation_segments = self.compute_punctuations_segments()
|
return None
|
||||||
diarization_segments = self.concatenate_diar_segments()
|
|
||||||
for punctuation_segment in punctuation_segments:
|
|
||||||
if not punctuation_segment.is_silence():
|
|
||||||
if diarization_segments and punctuation_segment.start >= diarization_segments[-1].end:
|
|
||||||
diarization_buffer += punctuation_segment.text
|
|
||||||
else:
|
|
||||||
max_overlap = 0.0
|
|
||||||
max_overlap_speaker = 1
|
|
||||||
for diarization_segment in diarization_segments:
|
|
||||||
intersec = self.intersection_duration(punctuation_segment, diarization_segment)
|
|
||||||
if intersec > max_overlap:
|
|
||||||
max_overlap = intersec
|
|
||||||
max_overlap_speaker = diarization_segment.speaker + 1
|
|
||||||
punctuation_segment.speaker = max_overlap_speaker
|
|
||||||
|
|
||||||
segments = []
|
# Check if token is beyond diarization coverage
|
||||||
if punctuation_segments:
|
if token.start >= diarization_segments[-1].end:
|
||||||
segments = [punctuation_segments[0]]
|
return None
|
||||||
for segment in punctuation_segments[1:]:
|
|
||||||
if segment.speaker == segments[-1].speaker:
|
# Find speaker with max overlap
|
||||||
if segments[-1].text:
|
max_overlap = 0.0
|
||||||
segments[-1].text += segment.text
|
best_speaker = None
|
||||||
segments[-1].end = segment.end
|
for diar_seg in diarization_segments:
|
||||||
|
overlap = self.intersection_duration(token, diar_seg)
|
||||||
|
if overlap > max_overlap:
|
||||||
|
max_overlap = overlap
|
||||||
|
best_speaker = diar_seg.speaker + 1 # 1-indexed
|
||||||
|
|
||||||
|
return best_speaker if max_overlap > 0 else None
|
||||||
|
|
||||||
|
def get_lines_diarization(self) -> Tuple[List[Segment], str]:
|
||||||
|
"""Build segments with token-by-token validation when diarization covers them."""
|
||||||
|
diarization_segments = self.concatenate_diar_segments()
|
||||||
|
|
||||||
|
# Add new tokens to pending
|
||||||
|
self.pending_tokens.extend(self.new_tokens)
|
||||||
|
|
||||||
|
# Process pending tokens - validate those covered by diarization
|
||||||
|
still_pending = []
|
||||||
|
for token in self.pending_tokens:
|
||||||
|
if token.is_silence():
|
||||||
|
# Handle silence tokens
|
||||||
|
silence_duration = (token.end or 0) - (token.start or 0)
|
||||||
|
if silence_duration >= self.MIN_SILENCE_DISPLAY_DURATION:
|
||||||
|
# Significant silence - add as separate segment
|
||||||
|
if self.all_validated_segments and not self.all_validated_segments[-1].is_silence():
|
||||||
|
self.all_validated_segments.append(SilentSegment(
|
||||||
|
start=token.start,
|
||||||
|
end=token.end
|
||||||
|
))
|
||||||
|
elif self.all_validated_segments and self.all_validated_segments[-1].is_silence():
|
||||||
|
# Extend existing silence
|
||||||
|
self.all_validated_segments[-1].end = token.end
|
||||||
|
else:
|
||||||
|
self.all_validated_segments.append(SilentSegment(
|
||||||
|
start=token.start,
|
||||||
|
end=token.end
|
||||||
|
))
|
||||||
|
# Short silences are ignored (don't go to pending either)
|
||||||
|
continue
|
||||||
|
|
||||||
|
speaker = self._get_speaker_for_token(token, diarization_segments)
|
||||||
|
|
||||||
|
if speaker is not None:
|
||||||
|
# Token is covered by diarization - validate it
|
||||||
|
if self.all_validated_segments:
|
||||||
|
last_seg = self.all_validated_segments[-1]
|
||||||
|
if not last_seg.is_silence() and last_seg.speaker == speaker:
|
||||||
|
# Same speaker - append to existing segment
|
||||||
|
last_seg.text += token.text
|
||||||
|
last_seg.end = token.end
|
||||||
|
else:
|
||||||
|
# Different speaker or after silence - new segment
|
||||||
|
new_seg = Segment(
|
||||||
|
start=token.start,
|
||||||
|
end=token.end,
|
||||||
|
text=token.text,
|
||||||
|
speaker=speaker,
|
||||||
|
start_speaker=token.start,
|
||||||
|
detected_language=token.detected_language
|
||||||
|
)
|
||||||
|
self.all_validated_segments.append(new_seg)
|
||||||
else:
|
else:
|
||||||
segments.append(segment)
|
# First segment
|
||||||
|
new_seg = Segment(
|
||||||
|
start=token.start,
|
||||||
|
end=token.end,
|
||||||
|
text=token.text,
|
||||||
|
speaker=speaker,
|
||||||
|
start_speaker=token.start,
|
||||||
|
detected_language=token.detected_language
|
||||||
|
)
|
||||||
|
self.all_validated_segments.append(new_seg)
|
||||||
|
|
||||||
|
self.last_validated_token_end = token.end
|
||||||
|
else:
|
||||||
|
# Token not yet covered by diarization - keep pending
|
||||||
|
still_pending.append(token)
|
||||||
|
|
||||||
|
self.pending_tokens = still_pending
|
||||||
|
|
||||||
|
# Build diarization buffer from pending tokens
|
||||||
|
diarization_buffer = ''.join(t.text for t in self.pending_tokens if not t.is_silence())
|
||||||
|
|
||||||
|
return self.all_validated_segments, diarization_buffer
|
||||||
|
|
||||||
return segments, diarization_buffer
|
|
||||||
|
|
||||||
|
def _assign_segment_ids(self, segments: List[Segment]) -> None:
|
||||||
|
"""Assign unique IDs to segments that don't have one yet."""
|
||||||
|
for segment in segments:
|
||||||
|
if segment.id is None:
|
||||||
|
segment.id = self._next_segment_id
|
||||||
|
self._next_segment_id += 1
|
||||||
|
|
||||||
|
def _assign_buffers_to_last_segment(
|
||||||
|
self,
|
||||||
|
segments: List[Segment],
|
||||||
|
buffer_transcription: str,
|
||||||
|
buffer_diarization: str,
|
||||||
|
buffer_translation: str
|
||||||
|
) -> None:
|
||||||
|
"""Assign buffer content to the last non-silent segment."""
|
||||||
|
# First, clear ALL buffers (they're ephemeral and shouldn't persist)
|
||||||
|
for segment in segments:
|
||||||
|
segment.buffer = SegmentBuffer()
|
||||||
|
|
||||||
|
# Find the last non-silent segment and assign buffers to it
|
||||||
|
for segment in reversed(segments):
|
||||||
|
if not segment.is_silence():
|
||||||
|
segment.buffer = SegmentBuffer(
|
||||||
|
transcription=buffer_transcription,
|
||||||
|
diarization=buffer_diarization,
|
||||||
|
translation=buffer_translation
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
def _filter_and_merge_segments(self, segments: List[Segment]) -> List[Segment]:
|
||||||
|
"""Filter parasitic silences and merge consecutive same-speaker segments."""
|
||||||
|
if not segments:
|
||||||
|
return segments
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for seg in segments:
|
||||||
|
if seg.is_silence():
|
||||||
|
# Filter short silences
|
||||||
|
duration = (seg.end or 0) - (seg.start or 0)
|
||||||
|
if duration < self.MIN_SILENCE_DISPLAY_DURATION:
|
||||||
|
continue
|
||||||
|
# Merge consecutive silences
|
||||||
|
if result and result[-1].is_silence():
|
||||||
|
result[-1].end = seg.end
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# Merge same speaker segments (across filtered silences)
|
||||||
|
if result and not result[-1].is_silence() and result[-1].speaker == seg.speaker:
|
||||||
|
result[-1].text += seg.text
|
||||||
|
result[-1].end = seg.end
|
||||||
|
continue
|
||||||
|
|
||||||
|
result.append(seg)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def get_lines(
|
def get_lines(
|
||||||
self,
|
self,
|
||||||
diarization: bool = False,
|
diarization: bool = False,
|
||||||
translation: bool = False,
|
translation: bool = False,
|
||||||
current_silence: Optional[Silence] = None
|
current_silence: Optional[Silence] = None,
|
||||||
) -> Tuple[List[Segment], str, Union[str, TimedText]]:
|
buffer_transcription: str = ''
|
||||||
"""Return the formatted segments plus buffers, optionally with diarization/translation."""
|
) -> List[Segment]:
|
||||||
|
"""Return the formatted segments with per-segment buffers, optionally with diarization/translation."""
|
||||||
|
diarization_buffer = ''
|
||||||
|
|
||||||
if diarization:
|
if diarization:
|
||||||
segments, diarization_buffer = self.get_lines_diarization()
|
segments, diarization_buffer = self.get_lines_diarization()
|
||||||
else:
|
else:
|
||||||
diarization_buffer = ''
|
|
||||||
for token in self.new_tokens:
|
for token in self.new_tokens:
|
||||||
if token.is_silence():
|
if token.is_silence():
|
||||||
if self.current_line_tokens:
|
# Check silence duration before adding
|
||||||
self.validated_segments.append(Segment().from_tokens(self.current_line_tokens))
|
silence_duration = (token.end or 0) - (token.start or 0)
|
||||||
self.current_line_tokens = []
|
if silence_duration >= self.MIN_SILENCE_DISPLAY_DURATION:
|
||||||
|
if self.current_line_tokens:
|
||||||
end_silence = token.end if token.has_ended else time() - self.beg_loop
|
self.validated_segments.append(Segment().from_tokens(self.current_line_tokens))
|
||||||
if self.validated_segments and self.validated_segments[-1].is_silence():
|
self.current_line_tokens = []
|
||||||
self.validated_segments[-1].end = end_silence
|
|
||||||
else:
|
end_silence = token.end if token.has_ended else time() - self.beg_loop
|
||||||
self.validated_segments.append(SilentSegment(
|
if self.validated_segments and self.validated_segments[-1].is_silence():
|
||||||
start=token.start,
|
self.validated_segments[-1].end = end_silence
|
||||||
end=end_silence
|
else:
|
||||||
))
|
self.validated_segments.append(SilentSegment(
|
||||||
|
start=token.start,
|
||||||
|
end=end_silence
|
||||||
|
))
|
||||||
else:
|
else:
|
||||||
self.current_line_tokens.append(token)
|
self.current_line_tokens.append(token)
|
||||||
|
|
||||||
@@ -205,15 +338,37 @@ class TokensAlignment:
|
|||||||
if self.current_line_tokens:
|
if self.current_line_tokens:
|
||||||
segments.append(Segment().from_tokens(self.current_line_tokens))
|
segments.append(Segment().from_tokens(self.current_line_tokens))
|
||||||
|
|
||||||
|
# Handle current ongoing silence
|
||||||
if current_silence:
|
if current_silence:
|
||||||
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
|
silence_duration = (current_silence.end or time() - self.beg_loop) - (current_silence.start or 0)
|
||||||
if segments and segments[-1].is_silence():
|
if silence_duration >= self.MIN_SILENCE_DISPLAY_DURATION:
|
||||||
segments[-1] = SilentSegment(start=segments[-1].start, end=end_silence)
|
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
|
||||||
else:
|
if segments and segments[-1].is_silence():
|
||||||
segments.append(SilentSegment(
|
segments[-1] = SilentSegment(start=segments[-1].start, end=end_silence)
|
||||||
start=current_silence.start,
|
else:
|
||||||
end=end_silence
|
segments.append(SilentSegment(
|
||||||
))
|
start=current_silence.start,
|
||||||
|
end=end_silence
|
||||||
|
))
|
||||||
|
|
||||||
if translation:
|
if translation:
|
||||||
[self.add_translation(segment) for segment in segments if not segment.is_silence()]
|
[self.add_translation(segment) for segment in segments if not segment.is_silence()]
|
||||||
return segments, diarization_buffer, self.new_translation_buffer.text
|
|
||||||
|
# Get translation buffer text
|
||||||
|
translation_buffer = self.new_translation_buffer.text if self.new_translation_buffer else ''
|
||||||
|
|
||||||
|
# Filter parasitic silences and merge same-speaker segments
|
||||||
|
segments = self._filter_and_merge_segments(segments)
|
||||||
|
|
||||||
|
# Assign unique IDs to all segments
|
||||||
|
self._assign_segment_ids(segments)
|
||||||
|
|
||||||
|
# Assign buffers to the last active segment
|
||||||
|
self._assign_buffers_to_last_segment(
|
||||||
|
segments,
|
||||||
|
buffer_transcription=buffer_transcription,
|
||||||
|
buffer_diarization=diarization_buffer,
|
||||||
|
buffer_translation=translation_buffer
|
||||||
|
)
|
||||||
|
|
||||||
|
return segments
|
||||||
|
|||||||
@@ -454,8 +454,9 @@ label {
|
|||||||
gap: 4px;
|
gap: 4px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.lag-diarization-value {
|
.lag-diarization-value,
|
||||||
margin-left: 10px;
|
.lag-transcription-value {
|
||||||
|
font-weight: 600;
|
||||||
}
|
}
|
||||||
|
|
||||||
.label_translation img {
|
.label_translation img {
|
||||||
|
|||||||
@@ -232,11 +232,8 @@ function setupWebSocket() {
|
|||||||
if (waitingForStop) {
|
if (waitingForStop) {
|
||||||
statusText.textContent = "Processing finalized or connection closed.";
|
statusText.textContent = "Processing finalized or connection closed.";
|
||||||
if (lastReceivedData) {
|
if (lastReceivedData) {
|
||||||
renderLinesWithBuffer(
|
renderSegments(
|
||||||
lastReceivedData.lines || [],
|
lastReceivedData.segments || [],
|
||||||
lastReceivedData.buffer_diarization || "",
|
|
||||||
lastReceivedData.buffer_transcription || "",
|
|
||||||
lastReceivedData.buffer_translation || "",
|
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
true
|
true
|
||||||
@@ -278,11 +275,8 @@ function setupWebSocket() {
|
|||||||
waitingForStop = false;
|
waitingForStop = false;
|
||||||
|
|
||||||
if (lastReceivedData) {
|
if (lastReceivedData) {
|
||||||
renderLinesWithBuffer(
|
renderSegments(
|
||||||
lastReceivedData.lines || [],
|
lastReceivedData.segments || [],
|
||||||
lastReceivedData.buffer_diarization || "",
|
|
||||||
lastReceivedData.buffer_transcription || "",
|
|
||||||
lastReceivedData.buffer_translation || "",
|
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
true
|
true
|
||||||
@@ -299,21 +293,20 @@ function setupWebSocket() {
|
|||||||
|
|
||||||
lastReceivedData = data;
|
lastReceivedData = data;
|
||||||
|
|
||||||
|
// New API format: segments with per-segment buffers, metadata wrapper
|
||||||
const {
|
const {
|
||||||
lines = [],
|
segments = [],
|
||||||
buffer_transcription = "",
|
metadata = {},
|
||||||
buffer_diarization = "",
|
|
||||||
buffer_translation = "",
|
|
||||||
remaining_time_transcription = 0,
|
|
||||||
remaining_time_diarization = 0,
|
|
||||||
status = "active_transcription",
|
status = "active_transcription",
|
||||||
} = data;
|
} = data;
|
||||||
|
|
||||||
|
const {
|
||||||
|
remaining_time_transcription = 0,
|
||||||
|
remaining_time_diarization = 0,
|
||||||
|
} = metadata;
|
||||||
|
|
||||||
renderLinesWithBuffer(
|
renderSegments(
|
||||||
lines,
|
segments,
|
||||||
buffer_diarization,
|
|
||||||
buffer_transcription,
|
|
||||||
buffer_translation,
|
|
||||||
remaining_time_diarization,
|
remaining_time_diarization,
|
||||||
remaining_time_transcription,
|
remaining_time_transcription,
|
||||||
false,
|
false,
|
||||||
@@ -323,11 +316,8 @@ function setupWebSocket() {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
function renderLinesWithBuffer(
|
function renderSegments(
|
||||||
lines,
|
segments,
|
||||||
buffer_diarization,
|
|
||||||
buffer_transcription,
|
|
||||||
buffer_translation,
|
|
||||||
remaining_time_diarization,
|
remaining_time_diarization,
|
||||||
remaining_time_transcription,
|
remaining_time_transcription,
|
||||||
isFinalizing = false,
|
isFinalizing = false,
|
||||||
@@ -339,33 +329,38 @@ function renderLinesWithBuffer(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const showLoading = !isFinalizing && (lines || []).some((it) => it.speaker == 0);
|
// Build signature for change detection
|
||||||
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
|
|
||||||
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
|
|
||||||
const signature = JSON.stringify({
|
const signature = JSON.stringify({
|
||||||
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end, detected_language: it.detected_language })),
|
segments: (segments || []).map((it) => ({
|
||||||
buffer_transcription: buffer_transcription || "",
|
id: it.id,
|
||||||
buffer_diarization: buffer_diarization || "",
|
speaker: it.speaker,
|
||||||
buffer_translation: buffer_translation,
|
text: it.text,
|
||||||
|
start: it.start,
|
||||||
|
end: it.end,
|
||||||
|
language: it.language,
|
||||||
|
buffer: it.buffer || {}
|
||||||
|
})),
|
||||||
status: current_status,
|
status: current_status,
|
||||||
showLoading,
|
|
||||||
showTransLag,
|
|
||||||
showDiaLag,
|
|
||||||
isFinalizing: !!isFinalizing,
|
isFinalizing: !!isFinalizing,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Only update lag values if signature unchanged
|
||||||
if (lastSignature === signature) {
|
if (lastSignature === signature) {
|
||||||
const t = document.querySelector(".lag-transcription-value");
|
const t = document.querySelector(".lag-transcription-value");
|
||||||
if (t) t.textContent = fmt1(remaining_time_transcription);
|
if (t) t.textContent = fmt1(remaining_time_transcription);
|
||||||
const d = document.querySelector(".lag-diarization-value");
|
const d = document.querySelector(".lag-diarization-value");
|
||||||
if (d) d.textContent = fmt1(remaining_time_diarization);
|
if (d) d.textContent = fmt1(remaining_time_diarization);
|
||||||
const ld = document.querySelector(".loading-diarization-value");
|
|
||||||
if (ld) ld.textContent = fmt1(remaining_time_diarization);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
lastSignature = signature;
|
lastSignature = signature;
|
||||||
|
|
||||||
const linesHtml = (lines || [])
|
const segmentsHtml = (segments || [])
|
||||||
.map((item, idx) => {
|
.map((item, idx) => {
|
||||||
|
const buffer = item.buffer || {};
|
||||||
|
const buffer_transcription = buffer.transcription || "";
|
||||||
|
const buffer_diarization = buffer.diarization || "";
|
||||||
|
const buffer_translation = buffer.translation || "";
|
||||||
|
|
||||||
let timeInfo = "";
|
let timeInfo = "";
|
||||||
if (item.start !== undefined && item.end !== undefined) {
|
if (item.start !== undefined && item.end !== undefined) {
|
||||||
timeInfo = ` ${item.start} - ${item.end}`;
|
timeInfo = ` ${item.start} - ${item.end}`;
|
||||||
@@ -373,80 +368,78 @@ function renderLinesWithBuffer(
|
|||||||
|
|
||||||
let speakerLabel = "";
|
let speakerLabel = "";
|
||||||
if (item.speaker === -2) {
|
if (item.speaker === -2) {
|
||||||
|
// Silence segment
|
||||||
speakerLabel = `<span class="silence">${silenceIcon}<span id='timeInfo'>${timeInfo}</span></span>`;
|
speakerLabel = `<span class="silence">${silenceIcon}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||||
} else if (item.speaker == 0 && !isFinalizing) {
|
|
||||||
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
|
|
||||||
remaining_time_diarization
|
|
||||||
)}</span> second(s) of audio are undergoing diarization</span></span>`;
|
|
||||||
} else if (item.speaker !== 0) {
|
} else if (item.speaker !== 0) {
|
||||||
|
// Normal speaker segment
|
||||||
const speakerNum = `<span class="speaker-badge">${item.speaker}</span>`;
|
const speakerNum = `<span class="speaker-badge">${item.speaker}</span>`;
|
||||||
speakerLabel = `<span id="speaker">${speakerIcon}${speakerNum}<span id='timeInfo'>${timeInfo}</span></span>`;
|
speakerLabel = `<span id="speaker">${speakerIcon}${speakerNum}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||||
|
|
||||||
if (item.detected_language) {
|
if (item.language) {
|
||||||
speakerLabel += `<span class="label_language">${languageIcon}<span>${item.detected_language}</span></span>`;
|
speakerLabel += `<span class="label_language">${languageIcon}<span>${item.language}</span></span>`;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let currentLineText = item.text || "";
|
let currentLineText = item.text || "";
|
||||||
|
const isLastSegment = idx === segments.length - 1;
|
||||||
if (idx === lines.length - 1) {
|
const hasBufferContent = buffer_diarization || buffer_transcription;
|
||||||
if (!isFinalizing && item.speaker !== -2) {
|
|
||||||
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(
|
// Show lag indicators on last non-silent segment (without spinners)
|
||||||
remaining_time_transcription
|
if (isLastSegment && item.speaker !== -2 && !isFinalizing) {
|
||||||
)}</span>s</span></span>`;
|
if (remaining_time_transcription > 0) {
|
||||||
|
speakerLabel += `<span class="label_transcription">Transcription lag: <span class="lag-transcription-value">${fmt1(remaining_time_transcription)}</span>s</span>`;
|
||||||
if (buffer_diarization && remaining_time_diarization) {
|
|
||||||
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(
|
|
||||||
remaining_time_diarization
|
|
||||||
)}</span>s</span></span>`;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if (buffer_diarization && remaining_time_diarization > 0) {
|
||||||
|
speakerLabel += `<span class="label_diarization">Diarization lag: <span class="lag-diarization-value">${fmt1(remaining_time_diarization)}</span>s</span>`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render buffers
|
||||||
|
if (hasBufferContent && item.speaker !== -2) {
|
||||||
if (buffer_diarization) {
|
if (buffer_diarization) {
|
||||||
if (isFinalizing) {
|
if (isFinalizing) {
|
||||||
currentLineText +=
|
currentLineText += (currentLineText.length > 0 ? " " : "") + buffer_diarization.trim();
|
||||||
(currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
|
|
||||||
} else {
|
} else {
|
||||||
currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
|
currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (buffer_transcription) {
|
if (buffer_transcription) {
|
||||||
if (isFinalizing) {
|
if (isFinalizing) {
|
||||||
currentLineText +=
|
currentLineText += (currentLineText.length > 0 ? " " : "") + buffer_transcription.trim();
|
||||||
(currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") +
|
|
||||||
buffer_transcription.trim();
|
|
||||||
} else {
|
} else {
|
||||||
currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
|
currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Translation
|
||||||
let translationContent = "";
|
let translationContent = "";
|
||||||
if (item.translation) {
|
if (item.translation) {
|
||||||
translationContent += item.translation.trim();
|
translationContent += item.translation.trim();
|
||||||
}
|
}
|
||||||
if (idx === lines.length - 1 && buffer_translation) {
|
if (buffer_translation) {
|
||||||
const bufferPiece = isFinalizing
|
const bufferPiece = isFinalizing
|
||||||
? buffer_translation
|
? buffer_translation
|
||||||
: `<span class="buffer_translation">${buffer_translation}</span>`;
|
: `<span class="buffer_translation">${buffer_translation}</span>`;
|
||||||
translationContent += translationContent ? `${bufferPiece}` : bufferPiece;
|
translationContent += translationContent ? bufferPiece : bufferPiece;
|
||||||
}
|
}
|
||||||
if (translationContent.trim().length > 0) {
|
if (translationContent.trim().length > 0) {
|
||||||
currentLineText += `
|
currentLineText += `
|
||||||
<div>
|
<div class="label_translation">
|
||||||
<div class="label_translation">
|
${translationIcon}
|
||||||
${translationIcon}
|
<span class="translation_text">${translationContent}</span>
|
||||||
<span class="translation_text">${translationContent}</span>
|
</div>`;
|
||||||
</div>
|
|
||||||
</div>`;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return currentLineText.trim().length > 0 || speakerLabel.length > 0
|
if (currentLineText.trim().length > 0 || speakerLabel.length > 0) {
|
||||||
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
|
return `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`;
|
||||||
: `<p>${speakerLabel}<br/></p>`;
|
}
|
||||||
|
return speakerLabel ? `<p>${speakerLabel}</p>` : "";
|
||||||
})
|
})
|
||||||
|
.filter(html => html.length > 0)
|
||||||
.join("");
|
.join("");
|
||||||
|
|
||||||
linesTranscriptDiv.innerHTML = linesHtml;
|
linesTranscriptDiv.innerHTML = segmentsHtml;
|
||||||
const transcriptContainer = document.querySelector('.transcript-container');
|
const transcriptContainer = document.querySelector('.transcript-container');
|
||||||
if (transcriptContainer) {
|
if (transcriptContainer) {
|
||||||
transcriptContainer.scrollTo({ top: transcriptContainer.scrollHeight, behavior: "smooth" });
|
transcriptContainer.scrollTo({ top: transcriptContainer.scrollHeight, behavior: "smooth" });
|
||||||
|
|||||||
377
whisperlivekit/web/text_transcript.html
Normal file
377
whisperlivekit/web/text_transcript.html
Normal file
@@ -0,0 +1,377 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>WhisperLiveKit Transcript</title>
|
||||||
|
<style>
|
||||||
|
:root {
|
||||||
|
--bg: #111;
|
||||||
|
--text: #ddd;
|
||||||
|
--dim: #666;
|
||||||
|
--border: #333;
|
||||||
|
--active: #e74c3c;
|
||||||
|
}
|
||||||
|
body {
|
||||||
|
font-family: 'SF Mono', 'Monaco', 'Inconsolata', 'Roboto Mono', monospace;
|
||||||
|
background: var(--bg);
|
||||||
|
color: var(--text);
|
||||||
|
margin: 0;
|
||||||
|
padding: 2rem;
|
||||||
|
font-size: 13px;
|
||||||
|
line-height: 1.6;
|
||||||
|
}
|
||||||
|
.nav {
|
||||||
|
display: flex;
|
||||||
|
gap: 12px;
|
||||||
|
align-items: center;
|
||||||
|
margin-bottom: 3rem;
|
||||||
|
font-size: 12px;
|
||||||
|
}
|
||||||
|
button, input, select {
|
||||||
|
background: transparent;
|
||||||
|
border: 1px solid var(--border);
|
||||||
|
color: var(--dim);
|
||||||
|
padding: 6px 12px;
|
||||||
|
font-family: inherit;
|
||||||
|
font-size: inherit;
|
||||||
|
border-radius: 4px;
|
||||||
|
outline: none;
|
||||||
|
transition: all 0.2s;
|
||||||
|
}
|
||||||
|
button:hover, input:hover, input:focus, select:hover, select:focus {
|
||||||
|
border-color: var(--text);
|
||||||
|
color: var(--text);
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
cursor: pointer;
|
||||||
|
appearance: none; /* Minimalist look */
|
||||||
|
background-image: linear-gradient(45deg, transparent 50%, var(--dim) 50%), linear-gradient(135deg, var(--dim) 50%, transparent 50%);
|
||||||
|
background-position: calc(100% - 15px) 50%, calc(100% - 10px) 50%;
|
||||||
|
background-size: 5px 5px, 5px 5px;
|
||||||
|
background-repeat: no-repeat;
|
||||||
|
padding-right: 25px;
|
||||||
|
}
|
||||||
|
select:hover, select:focus {
|
||||||
|
background-image: linear-gradient(45deg, transparent 50%, var(--text) 50%), linear-gradient(135deg, var(--text) 50%, transparent 50%);
|
||||||
|
}
|
||||||
|
button.recording {
|
||||||
|
border-color: var(--active);
|
||||||
|
color: var(--active);
|
||||||
|
}
|
||||||
|
input {
|
||||||
|
width: 150px;
|
||||||
|
cursor: text;
|
||||||
|
}
|
||||||
|
#status {
|
||||||
|
margin-left: auto;
|
||||||
|
color: var(--dim);
|
||||||
|
}
|
||||||
|
#transcript {
|
||||||
|
white-space: pre-wrap;
|
||||||
|
word-wrap: break-word;
|
||||||
|
max-width: 800px;
|
||||||
|
margin: 0 auto;
|
||||||
|
outline: none;
|
||||||
|
}
|
||||||
|
/* Minimal scrollbar */
|
||||||
|
::-webkit-scrollbar { width: 6px; }
|
||||||
|
::-webkit-scrollbar-track { background: transparent; }
|
||||||
|
::-webkit-scrollbar-thumb { background: #222; border-radius: 3px; }
|
||||||
|
::-webkit-scrollbar-thumb:hover { background: #333; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="nav">
|
||||||
|
<button id="recordBtn">Record</button>
|
||||||
|
<button id="copyBtn">Copy</button>
|
||||||
|
<select id="microphoneSelect"></select>
|
||||||
|
<input type="text" id="wsUrl" placeholder="WebSocket URL">
|
||||||
|
<div id="status">Ready</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div id="transcript"></div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
const recordBtn = document.getElementById('recordBtn');
|
||||||
|
const copyBtn = document.getElementById('copyBtn');
|
||||||
|
const wsUrlInput = document.getElementById('wsUrl');
|
||||||
|
const statusEl = document.getElementById('status');
|
||||||
|
const transcriptEl = document.getElementById('transcript');
|
||||||
|
const microphoneSelect = document.getElementById('microphoneSelect');
|
||||||
|
|
||||||
|
// Default WebSocket URL
|
||||||
|
const protocol = window.location.protocol === 'https:' ? 'wss' : 'ws';
|
||||||
|
const host = window.location.hostname || 'localhost';
|
||||||
|
const port = window.location.port;
|
||||||
|
const defaultUrl = `${protocol}://${host}${port ? ':' + port : ''}/asr`;
|
||||||
|
wsUrlInput.value = defaultUrl;
|
||||||
|
|
||||||
|
let websocket = null;
|
||||||
|
let isRecording = false;
|
||||||
|
let audioContext = null;
|
||||||
|
let workletNode = null;
|
||||||
|
let recorderWorker = null;
|
||||||
|
let microphone = null;
|
||||||
|
let useAudioWorklet = false;
|
||||||
|
let recorder = null;
|
||||||
|
let availableMicrophones = [];
|
||||||
|
let selectedMicrophoneId = null;
|
||||||
|
|
||||||
|
async function enumerateMicrophones() {
|
||||||
|
try {
|
||||||
|
// Request permission first to get labels
|
||||||
|
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||||
|
stream.getTracks().forEach(track => track.stop());
|
||||||
|
|
||||||
|
const devices = await navigator.mediaDevices.enumerateDevices();
|
||||||
|
availableMicrophones = devices.filter(device => device.kind === 'audioinput');
|
||||||
|
|
||||||
|
populateMicrophoneSelect();
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error enumerating microphones:', error);
|
||||||
|
statusEl.textContent = "Mic permission needed";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function populateMicrophoneSelect() {
|
||||||
|
microphoneSelect.innerHTML = '<option value="">Default Microphone</option>';
|
||||||
|
|
||||||
|
availableMicrophones.forEach((device, index) => {
|
||||||
|
const option = document.createElement('option');
|
||||||
|
option.value = device.deviceId;
|
||||||
|
option.textContent = device.label || `Microphone ${index + 1}`;
|
||||||
|
microphoneSelect.appendChild(option);
|
||||||
|
});
|
||||||
|
|
||||||
|
const savedMicId = localStorage.getItem('selectedMicrophone');
|
||||||
|
if (savedMicId && availableMicrophones.some(mic => mic.deviceId === savedMicId)) {
|
||||||
|
microphoneSelect.value = savedMicId;
|
||||||
|
selectedMicrophoneId = savedMicId;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleMicrophoneChange() {
|
||||||
|
selectedMicrophoneId = microphoneSelect.value || null;
|
||||||
|
localStorage.setItem('selectedMicrophone', selectedMicrophoneId || '');
|
||||||
|
|
||||||
|
if (isRecording) {
|
||||||
|
stopRecording();
|
||||||
|
setTimeout(() => {
|
||||||
|
startRecording();
|
||||||
|
}, 500);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
microphoneSelect.addEventListener('change', handleMicrophoneChange);
|
||||||
|
|
||||||
|
// Initial enumeration
|
||||||
|
enumerateMicrophones();
|
||||||
|
navigator.mediaDevices.addEventListener('devicechange', enumerateMicrophones);
|
||||||
|
|
||||||
|
function formatSegment(segment) {
|
||||||
|
const speaker = segment.speaker;
|
||||||
|
const text = segment.text || '';
|
||||||
|
const buffer = segment.buffer || {};
|
||||||
|
const start = segment.start || '';
|
||||||
|
const end = segment.end || '';
|
||||||
|
const language = segment.language || '';
|
||||||
|
|
||||||
|
let output = '';
|
||||||
|
|
||||||
|
// Silence marker
|
||||||
|
if (speaker === -2) {
|
||||||
|
output += `[SILENCE ${start} - ${end}]\n`;
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Speaker header
|
||||||
|
output += `[SPEAKER ${speaker}]`;
|
||||||
|
if (start && end) output += ` ${start} - ${end}`;
|
||||||
|
if (language) output += ` [LANG: ${language}]`;
|
||||||
|
output += '\n';
|
||||||
|
|
||||||
|
// Main text
|
||||||
|
if (text) {
|
||||||
|
output += text;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Diarization buffer (text waiting for speaker assignment)
|
||||||
|
if (buffer.diarization) {
|
||||||
|
output += `[DIAR_BUFFER]${buffer.diarization}[/DIAR_BUFFER]`;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transcription buffer (text waiting for validation)
|
||||||
|
if (buffer.transcription) {
|
||||||
|
output += `[TRANS_BUFFER]${buffer.transcription}[/TRANS_BUFFER]`;
|
||||||
|
}
|
||||||
|
|
||||||
|
output += '\n';
|
||||||
|
|
||||||
|
// Translation
|
||||||
|
if (segment.translation) {
|
||||||
|
output += `[TRANSLATION]${segment.translation}`;
|
||||||
|
if (buffer.translation) {
|
||||||
|
output += `[TRANS_BUFFER]${buffer.translation}[/TRANS_BUFFER]`;
|
||||||
|
}
|
||||||
|
output += `[/TRANSLATION]\n`;
|
||||||
|
} else if (buffer.translation) {
|
||||||
|
output += `[TRANSLATION][TRANS_BUFFER]${buffer.translation}[/TRANS_BUFFER][/TRANSLATION]\n`;
|
||||||
|
}
|
||||||
|
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
function renderTranscript(data) {
|
||||||
|
const { segments = [], metadata = {}, status: msgStatus } = data;
|
||||||
|
|
||||||
|
if (msgStatus === 'no_audio_detected') {
|
||||||
|
// transcriptEl.textContent = '[NO AUDIO DETECTED]';
|
||||||
|
// Minimalist: maybe just don't show anything or show status
|
||||||
|
statusEl.textContent = 'No audio detected';
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let output = '';
|
||||||
|
|
||||||
|
// Metadata header
|
||||||
|
const remainingTrans = metadata.remaining_time_transcription || 0;
|
||||||
|
const remainingDiar = metadata.remaining_time_diarization || 0;
|
||||||
|
if (remainingTrans > 0 || remainingDiar > 0) {
|
||||||
|
output += `[LAG: trans=${remainingTrans.toFixed(1)}s diar=${remainingDiar.toFixed(1)}s]\n\n`;
|
||||||
|
}
|
||||||
|
|
||||||
|
// All segments
|
||||||
|
for (const segment of segments) {
|
||||||
|
output += formatSegment(segment);
|
||||||
|
output += '\n';
|
||||||
|
}
|
||||||
|
|
||||||
|
transcriptEl.textContent = output;
|
||||||
|
transcriptEl.scrollTop = transcriptEl.scrollHeight;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function startRecording() {
|
||||||
|
try {
|
||||||
|
websocket = new WebSocket(wsUrlInput.value);
|
||||||
|
|
||||||
|
websocket.onopen = async () => {
|
||||||
|
statusEl.textContent = 'Connecting...';
|
||||||
|
};
|
||||||
|
|
||||||
|
websocket.onmessage = async (event) => {
|
||||||
|
const data = JSON.parse(event.data);
|
||||||
|
|
||||||
|
if (data.type === 'config') {
|
||||||
|
useAudioWorklet = !!data.useAudioWorklet;
|
||||||
|
statusEl.textContent = 'Recording';
|
||||||
|
await initAudio();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (data.type === 'ready_to_stop') {
|
||||||
|
statusEl.textContent = 'Done';
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// transcript_update
|
||||||
|
renderTranscript(data);
|
||||||
|
};
|
||||||
|
|
||||||
|
websocket.onclose = () => {
|
||||||
|
statusEl.textContent = 'Disconnected';
|
||||||
|
stopRecording(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
websocket.onerror = () => {
|
||||||
|
statusEl.textContent = 'Error';
|
||||||
|
};
|
||||||
|
|
||||||
|
} catch (err) {
|
||||||
|
statusEl.textContent = 'Error: ' + err.message;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function initAudio() {
|
||||||
|
const audioConstraints = selectedMicrophoneId
|
||||||
|
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
|
||||||
|
: { audio: true };
|
||||||
|
|
||||||
|
const stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
|
||||||
|
audioContext = new (window.AudioContext || window.webkitAudioContext)();
|
||||||
|
microphone = audioContext.createMediaStreamSource(stream);
|
||||||
|
|
||||||
|
if (useAudioWorklet) {
|
||||||
|
await audioContext.audioWorklet.addModule('/web/pcm_worklet.js');
|
||||||
|
workletNode = new AudioWorkletNode(audioContext, 'pcm-forwarder', {
|
||||||
|
numberOfInputs: 1, numberOfOutputs: 0, channelCount: 1
|
||||||
|
});
|
||||||
|
microphone.connect(workletNode);
|
||||||
|
|
||||||
|
recorderWorker = new Worker('/web/recorder_worker.js');
|
||||||
|
recorderWorker.postMessage({ command: 'init', config: { sampleRate: audioContext.sampleRate } });
|
||||||
|
|
||||||
|
recorderWorker.onmessage = (e) => {
|
||||||
|
if (websocket?.readyState === WebSocket.OPEN) {
|
||||||
|
websocket.send(e.data.buffer);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
workletNode.port.onmessage = (e) => {
|
||||||
|
const ab = e.data instanceof ArrayBuffer ? e.data : e.data.buffer;
|
||||||
|
recorderWorker.postMessage({ command: 'record', buffer: ab }, [ab]);
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
try {
|
||||||
|
recorder = new MediaRecorder(stream, { mimeType: 'audio/webm' });
|
||||||
|
} catch {
|
||||||
|
recorder = new MediaRecorder(stream);
|
||||||
|
}
|
||||||
|
recorder.ondataavailable = (e) => {
|
||||||
|
if (websocket?.readyState === WebSocket.OPEN && e.data?.size > 0) {
|
||||||
|
websocket.send(e.data);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
recorder.start(100);
|
||||||
|
}
|
||||||
|
|
||||||
|
isRecording = true;
|
||||||
|
recordBtn.textContent = 'Stop';
|
||||||
|
recordBtn.classList.add('recording');
|
||||||
|
}
|
||||||
|
|
||||||
|
function stopRecording(sendStop = true) {
|
||||||
|
if (sendStop && websocket?.readyState === WebSocket.OPEN) {
|
||||||
|
websocket.send(new Blob([], { type: 'audio/webm' }));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (recorder) { try { recorder.stop(); } catch {} recorder = null; }
|
||||||
|
if (recorderWorker) { recorderWorker.terminate(); recorderWorker = null; }
|
||||||
|
if (workletNode) { workletNode.disconnect(); workletNode = null; }
|
||||||
|
if (microphone) { microphone.disconnect(); microphone = null; }
|
||||||
|
if (audioContext) { audioContext.close(); audioContext = null; }
|
||||||
|
|
||||||
|
isRecording = false;
|
||||||
|
recordBtn.textContent = 'Record';
|
||||||
|
recordBtn.classList.remove('recording');
|
||||||
|
}
|
||||||
|
|
||||||
|
recordBtn.addEventListener('click', () => {
|
||||||
|
if (!isRecording) {
|
||||||
|
startRecording();
|
||||||
|
} else {
|
||||||
|
stopRecording();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
copyBtn.addEventListener('click', () => {
|
||||||
|
navigator.clipboard.writeText(transcriptEl.textContent).then(() => {
|
||||||
|
const original = copyBtn.textContent;
|
||||||
|
copyBtn.textContent = 'Copied';
|
||||||
|
setTimeout(() => { copyBtn.textContent = original; }, 1500);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
@@ -13,6 +13,37 @@ def get_web_interface_html():
|
|||||||
logger.error(f"Error loading web interface HTML: {e}")
|
logger.error(f"Error loading web interface HTML: {e}")
|
||||||
return "<html><body><h1>Error loading interface</h1></body></html>"
|
return "<html><body><h1>Error loading interface</h1></body></html>"
|
||||||
|
|
||||||
|
|
||||||
|
def get_text_transcript_html():
|
||||||
|
"""Loads the simple text-based transcript HTML for easy copy/paste."""
|
||||||
|
try:
|
||||||
|
with resources.files('whisperlivekit.web').joinpath('text_transcript.html').open('r', encoding='utf-8') as f:
|
||||||
|
html_content = f.read()
|
||||||
|
|
||||||
|
# Inline the worker scripts
|
||||||
|
with resources.files('whisperlivekit.web').joinpath('pcm_worklet.js').open('r', encoding='utf-8') as f:
|
||||||
|
worklet_code = f.read()
|
||||||
|
with resources.files('whisperlivekit.web').joinpath('recorder_worker.js').open('r', encoding='utf-8') as f:
|
||||||
|
worker_code = f.read()
|
||||||
|
|
||||||
|
html_content = html_content.replace(
|
||||||
|
"await audioContext.audioWorklet.addModule('/web/pcm_worklet.js');",
|
||||||
|
'const workletBlob = new Blob([`' + worklet_code + '`], { type: "application/javascript" });\n' +
|
||||||
|
'const workletUrl = URL.createObjectURL(workletBlob);\n' +
|
||||||
|
'await audioContext.audioWorklet.addModule(workletUrl);'
|
||||||
|
)
|
||||||
|
html_content = html_content.replace(
|
||||||
|
"recorderWorker = new Worker('/web/recorder_worker.js');",
|
||||||
|
'const workerBlob = new Blob([`' + worker_code + '`], { type: "application/javascript" });\n' +
|
||||||
|
'const workerUrl = URL.createObjectURL(workerBlob);\n' +
|
||||||
|
'recorderWorker = new Worker(workerUrl);'
|
||||||
|
)
|
||||||
|
|
||||||
|
return html_content
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading text transcript HTML: {e}")
|
||||||
|
return "<html><body><h1>Error loading text interface</h1></body></html>"
|
||||||
|
|
||||||
def get_inline_ui_html():
|
def get_inline_ui_html():
|
||||||
"""Returns the complete web interface HTML with all assets embedded in a single call."""
|
"""Returns the complete web interface HTML with all assets embedded in a single call."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ from whisperlivekit.whisper.decoding import (DecodingOptions, DecodingResult,
|
|||||||
from whisperlivekit.whisper.model import ModelDimensions, Whisper
|
from whisperlivekit.whisper.model import ModelDimensions, Whisper
|
||||||
from whisperlivekit.whisper.transcribe import transcribe
|
from whisperlivekit.whisper.transcribe import transcribe
|
||||||
from whisperlivekit.whisper.version import __version__
|
from whisperlivekit.whisper.version import __version__
|
||||||
|
from whisperlivekit.whisper.lora import (LoRAAdapter, LoRAAdapterManager,
|
||||||
|
LoRAConfig, LoRALinear)
|
||||||
|
|
||||||
_MODELS = {
|
_MODELS = {
|
||||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||||
@@ -108,7 +110,7 @@ def available_models() -> List[str]:
|
|||||||
def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
|
def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
|
||||||
"""
|
"""
|
||||||
attempt to infer ModelDimensions from a HF style config.json located
|
attempt to infer ModelDimensions from a HF style config.json located
|
||||||
next to the given checkpoint, usefull for distilled models/MLX models.
|
next to the given checkpoint, usefull for distilled models
|
||||||
"""
|
"""
|
||||||
candidates = []
|
candidates = []
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
@@ -122,25 +124,6 @@ def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
|
|||||||
with open(candidate, "r", encoding="utf-8") as f:
|
with open(candidate, "r", encoding="utf-8") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
# native Whisper format
|
|
||||||
native_keys = ["n_mels", "n_audio_ctx", "n_audio_state", "n_audio_head",
|
|
||||||
"n_audio_layer", "n_vocab", "n_text_ctx", "n_text_state",
|
|
||||||
"n_text_head", "n_text_layer"]
|
|
||||||
if all(k in config for k in native_keys):
|
|
||||||
return ModelDimensions(
|
|
||||||
n_mels=config["n_mels"],
|
|
||||||
n_audio_ctx=config["n_audio_ctx"],
|
|
||||||
n_audio_state=config["n_audio_state"],
|
|
||||||
n_audio_head=config["n_audio_head"],
|
|
||||||
n_audio_layer=config["n_audio_layer"],
|
|
||||||
n_vocab=config["n_vocab"],
|
|
||||||
n_text_ctx=config["n_text_ctx"],
|
|
||||||
n_text_state=config["n_text_state"],
|
|
||||||
n_text_head=config["n_text_head"],
|
|
||||||
n_text_layer=config["n_text_layer"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# HuggingFace format
|
|
||||||
try:
|
try:
|
||||||
return ModelDimensions(
|
return ModelDimensions(
|
||||||
n_mels=config["num_mel_bins"],
|
n_mels=config["num_mel_bins"],
|
||||||
@@ -255,24 +238,6 @@ def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, tor
|
|||||||
return converted if converted else state_dict
|
return converted if converted else state_dict
|
||||||
|
|
||||||
|
|
||||||
def _convert_mlx_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Converts an mlx whisper checkpoint to a default openai whisper one
|
|
||||||
"""
|
|
||||||
if not any("mlp1" in k or "mlp2" in k for k in state_dict):
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
converted = {}
|
|
||||||
for key, value in state_dict.items():
|
|
||||||
if key == "alignment_heads":
|
|
||||||
continue
|
|
||||||
|
|
||||||
new_key = key.replace(".mlp1.", ".mlp.0.").replace(".mlp2.", ".mlp.2.")
|
|
||||||
converted[new_key] = value
|
|
||||||
|
|
||||||
return converted
|
|
||||||
|
|
||||||
|
|
||||||
def _load_lora_state(lora_path: str):
|
def _load_lora_state(lora_path: str):
|
||||||
safe_path = os.path.join(lora_path, "adapter_model.safetensors")
|
safe_path = os.path.join(lora_path, "adapter_model.safetensors")
|
||||||
bin_path = os.path.join(lora_path, "adapter_model.bin")
|
bin_path = os.path.join(lora_path, "adapter_model.bin")
|
||||||
@@ -557,12 +522,7 @@ def load_model(
|
|||||||
state_dict = checkpoint["model_state_dict"]
|
state_dict = checkpoint["model_state_dict"]
|
||||||
else:
|
else:
|
||||||
state_dict = checkpoint
|
state_dict = checkpoint
|
||||||
|
|
||||||
if alignment_heads is None and "alignment_heads" in state_dict:
|
|
||||||
alignment_heads = state_dict["alignment_heads"]
|
|
||||||
|
|
||||||
state_dict = _convert_hf_state_dict(state_dict)
|
state_dict = _convert_hf_state_dict(state_dict)
|
||||||
state_dict = _convert_mlx_state_dict(state_dict)
|
|
||||||
_apply_lora_adapter(state_dict, lora_path)
|
_apply_lora_adapter(state_dict, lora_path)
|
||||||
|
|
||||||
if dims_cfg is not None:
|
if dims_cfg is not None:
|
||||||
@@ -588,16 +548,99 @@ def load_model(
|
|||||||
model.load_state_dict(state_dict)
|
model.load_state_dict(state_dict)
|
||||||
|
|
||||||
if alignment_heads is not None:
|
if alignment_heads is not None:
|
||||||
if isinstance(alignment_heads, bytes):
|
model.set_alignment_heads(alignment_heads)
|
||||||
model.set_alignment_heads(alignment_heads)
|
|
||||||
elif isinstance(alignment_heads, torch.Tensor): #for mlx whisper
|
|
||||||
mask = torch.zeros(dims.n_text_layer, dims.n_text_head, dtype=torch.bool)
|
|
||||||
for layer, head in alignment_heads.tolist():
|
|
||||||
mask[layer, head] = True
|
|
||||||
model.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
|
|
||||||
return model.to(device)
|
return model.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_with_lora_manager(
|
||||||
|
name: str,
|
||||||
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
|
download_root: str = None,
|
||||||
|
in_memory: bool = False,
|
||||||
|
decoder_only: bool = False,
|
||||||
|
custom_alignment_heads: Optional[str] = None,
|
||||||
|
adapters: Optional[Dict[str, str]] = None,
|
||||||
|
) -> tuple:
|
||||||
|
"""
|
||||||
|
Load a Whisper model with a LoRA adapter manager for dynamic adapter swapping.
|
||||||
|
|
||||||
|
This allows you to load multiple LoRA adapters and switch between them at runtime
|
||||||
|
without keeping multiple full models in memory.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
Model name or path (same as load_model)
|
||||||
|
device : Union[str, torch.device]
|
||||||
|
Device to load model on
|
||||||
|
download_root : str
|
||||||
|
Download directory for model files
|
||||||
|
in_memory : bool
|
||||||
|
Whether to preload model weights into host memory
|
||||||
|
decoder_only : bool
|
||||||
|
If True, only load the decoder (no encoder)
|
||||||
|
custom_alignment_heads : str
|
||||||
|
Custom alignment heads configuration
|
||||||
|
adapters : Dict[str, str]
|
||||||
|
Optional dict mapping adapter names to paths/HuggingFace repo IDs.
|
||||||
|
Example: {"french": "path/to/french-lora", "spanish": "user/spanish-whisper-lora"}
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
model : Whisper
|
||||||
|
The base Whisper model (without any LoRA baked in)
|
||||||
|
manager : LoRAAdapterManager
|
||||||
|
The adapter manager for loading/switching adapters
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> model, manager = load_model_with_lora_manager(
|
||||||
|
... "large-v3",
|
||||||
|
... adapters={
|
||||||
|
... "french": "path/to/french-lora",
|
||||||
|
... "spanish": "path/to/spanish-lora"
|
||||||
|
... }
|
||||||
|
... )
|
||||||
|
>>>
|
||||||
|
>>> # Switch to French adapter
|
||||||
|
>>> manager.set_adapter("french")
|
||||||
|
>>> result_fr = model.transcribe(audio_fr)
|
||||||
|
>>>
|
||||||
|
>>> # Switch to Spanish adapter
|
||||||
|
>>> manager.set_adapter("spanish")
|
||||||
|
>>> result_es = model.transcribe(audio_es)
|
||||||
|
>>>
|
||||||
|
>>> # Use base model without LoRA
|
||||||
|
>>> manager.set_adapter(None)
|
||||||
|
>>> result_base = model.transcribe(audio)
|
||||||
|
>>>
|
||||||
|
>>> # Check memory usage
|
||||||
|
>>> print(manager.get_memory_usage())
|
||||||
|
{'french': 12.5, 'spanish': 12.5} # MB per adapter
|
||||||
|
"""
|
||||||
|
# Load the base model WITHOUT any LoRA baked in
|
||||||
|
model = load_model(
|
||||||
|
name=name,
|
||||||
|
device=device,
|
||||||
|
download_root=download_root,
|
||||||
|
in_memory=in_memory,
|
||||||
|
decoder_only=decoder_only,
|
||||||
|
custom_alignment_heads=custom_alignment_heads,
|
||||||
|
lora_path=None, # Important: no baked-in LoRA
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the adapter manager
|
||||||
|
manager = LoRAAdapterManager(model)
|
||||||
|
|
||||||
|
# Load any provided adapters
|
||||||
|
if adapters:
|
||||||
|
for adapter_name, adapter_path in adapters.items():
|
||||||
|
manager.load_adapter(adapter_name, adapter_path)
|
||||||
|
|
||||||
|
return model, manager
|
||||||
|
|
||||||
|
|
||||||
def convert_encoder_to_coreml(
|
def convert_encoder_to_coreml(
|
||||||
model_name = "base",
|
model_name = "base",
|
||||||
output_path= "whisper_encoder.mlpackage",
|
output_path= "whisper_encoder.mlpackage",
|
||||||
|
|||||||
473
whisperlivekit/whisper/lora.py
Normal file
473
whisperlivekit/whisper/lora.py
Normal file
@@ -0,0 +1,473 @@
|
|||||||
|
"""
|
||||||
|
Dynamic LoRA adapter support for Whisper models.
|
||||||
|
|
||||||
|
This module enables loading a single base Whisper model and dynamically swapping
|
||||||
|
between multiple LoRA adapters at runtime, saving GPU memory when working with
|
||||||
|
multiple language-specific fine-tuned models.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from whisperlivekit.whisper import load_model
|
||||||
|
from whisperlivekit.whisper.lora import LoRAAdapterManager
|
||||||
|
|
||||||
|
# Load base model without any LoRA baked in
|
||||||
|
model = load_model("large-v3", device="cuda")
|
||||||
|
|
||||||
|
# Create adapter manager
|
||||||
|
manager = LoRAAdapterManager(model)
|
||||||
|
|
||||||
|
# Load multiple adapters (small memory footprint each)
|
||||||
|
manager.load_adapter("french", "path/to/french-lora")
|
||||||
|
manager.load_adapter("spanish", "path/to/spanish-lora")
|
||||||
|
|
||||||
|
# Switch between adapters at runtime
|
||||||
|
manager.set_adapter("french")
|
||||||
|
result_fr = model.transcribe(audio_fr)
|
||||||
|
|
||||||
|
manager.set_adapter("spanish")
|
||||||
|
result_es = model.transcribe(audio_es)
|
||||||
|
|
||||||
|
# Disable LoRA (use base model only)
|
||||||
|
manager.set_adapter(None)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from .model import Linear
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LoRAConfig:
    """Hyper-parameters describing a single LoRA adapter."""
    # LoRA rank: dimensionality of the low-rank update.
    r: int
    # LoRA alpha; the low-rank update is scaled by alpha / r.
    alpha: float
    # Names of the modules this adapter targets.
    target_modules: List[str] = field(default_factory=list)

    @property
    def scaling(self) -> float:
        """Multiplier applied to the low-rank update (alpha / r)."""
        return self.alpha / self.r
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LoRAAdapter:
    """Container for the low-rank A/B matrices of one named adapter."""
    name: str
    config: LoRAConfig
    # Maps target module name -> (A matrix, B matrix).
    weights: Dict[str, Tuple[Tensor, Tensor]] = field(default_factory=dict)
    device: torch.device = field(default_factory=lambda: torch.device("cpu"))
    dtype: torch.dtype = field(default=torch.float32)

    def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
        """Move every A/B pair to *device* (and optionally *dtype*); returns self."""
        self.device = device
        if dtype is not None:
            self.dtype = dtype
        # After the update above, self.dtype is always the effective target dtype.
        target_dtype = dtype or self.dtype
        moved: Dict[str, Tuple[Tensor, Tensor]] = {}
        for module_name, (a, b) in self.weights.items():
            moved[module_name] = (
                a.to(device=device, dtype=target_dtype),
                b.to(device=device, dtype=target_dtype),
            )
        self.weights = moved
        return self

    def memory_footprint_mb(self) -> float:
        """Approximate size of all adapter tensors, in megabytes."""
        total_bytes = sum(
            t.numel() * t.element_size()
            for pair in self.weights.values()
            for t in pair
        )
        return total_bytes / (1024 * 1024)
|
||||||
|
|
||||||
|
|
||||||
|
class LoRALinear(nn.Module):
    """
    Wraps a Linear layer so a LoRA update can be injected dynamically.

    The base weights are never modified; when an adapter is active the
    forward pass adds the low-rank term:
        output = base_linear(x) + (x @ A @ B) * scaling
    """

    def __init__(self, base_linear: Linear):
        super().__init__()
        self.base_linear = base_linear
        self.lora_A: Optional[Tensor] = None  # (in_features, r)
        self.lora_B: Optional[Tensor] = None  # (r, out_features)
        self.scaling: float = 1.0
        self._lora_enabled: bool = False

    def set_lora(self, A: Optional[Tensor], B: Optional[Tensor], scaling: float = 1.0):
        """Install LoRA matrices; passing None for either matrix disables the update."""
        self.lora_A = A
        self.lora_B = B
        self.scaling = scaling
        self._lora_enabled = not (A is None or B is None)

    def clear_lora(self):
        """Detach any LoRA matrices from this layer."""
        self.lora_A = None
        self.lora_B = None
        self._lora_enabled = False

    def forward(self, x: Tensor) -> Tensor:
        out = self.base_linear(x)
        if not (self._lora_enabled and self.lora_A is not None and self.lora_B is not None):
            return out
        # Low-rank path: (..., in) @ (in, r) @ (r, out) -> (..., out).
        # Cast A/B to the activation dtype so mixed-precision inference works.
        delta = (x @ self.lora_A.to(x.dtype)) @ self.lora_B.to(x.dtype)
        return out + delta * self.scaling

    # The wrapper mirrors the base layer's public attributes so existing code
    # that inspects Linear layers keeps working unchanged.
    @property
    def weight(self):
        return self.base_linear.weight

    @property
    def bias(self):
        return self.base_linear.bias

    @property
    def in_features(self):
        return self.base_linear.in_features

    @property
    def out_features(self):
        return self.base_linear.out_features
|
||||||
|
|
||||||
|
|
||||||
|
# Mapping from HuggingFace LoRA module names to Whisper module paths.
# Each key/value contains one "{}" placeholder for the layer index; the
# regex built in _map_hf_to_whisper_module fills it from the HF name.
_HF_TO_WHISPER_MODULE_MAP = {
    # Encoder attention
    "model.encoder.layers.{}.self_attn.q_proj": "encoder.blocks.{}.attn.query",
    "model.encoder.layers.{}.self_attn.k_proj": "encoder.blocks.{}.attn.key",
    "model.encoder.layers.{}.self_attn.v_proj": "encoder.blocks.{}.attn.value",
    "model.encoder.layers.{}.self_attn.out_proj": "encoder.blocks.{}.attn.out",
    # Encoder MLP (mlp.0 / mlp.2 are the two Linear layers of the Sequential MLP)
    "model.encoder.layers.{}.fc1": "encoder.blocks.{}.mlp.0",
    "model.encoder.layers.{}.fc2": "encoder.blocks.{}.mlp.2",

    # Decoder self-attention
    "model.decoder.layers.{}.self_attn.q_proj": "decoder.blocks.{}.attn.query",
    "model.decoder.layers.{}.self_attn.k_proj": "decoder.blocks.{}.attn.key",
    "model.decoder.layers.{}.self_attn.v_proj": "decoder.blocks.{}.attn.value",
    "model.decoder.layers.{}.self_attn.out_proj": "decoder.blocks.{}.attn.out",
    # Decoder cross-attention
    "model.decoder.layers.{}.encoder_attn.q_proj": "decoder.blocks.{}.cross_attn.query",
    "model.decoder.layers.{}.encoder_attn.k_proj": "decoder.blocks.{}.cross_attn.key",
    "model.decoder.layers.{}.encoder_attn.v_proj": "decoder.blocks.{}.cross_attn.value",
    "model.decoder.layers.{}.encoder_attn.out_proj": "decoder.blocks.{}.cross_attn.out",
    # Decoder MLP
    "model.decoder.layers.{}.fc1": "decoder.blocks.{}.mlp.0",
    "model.decoder.layers.{}.fc2": "decoder.blocks.{}.mlp.2",
}
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_hf_module_name(name: str) -> str:
|
||||||
|
"""Normalize HF-style LoRA module names."""
|
||||||
|
if name.startswith("base_model."):
|
||||||
|
name = name[len("base_model."):]
|
||||||
|
if name.startswith("model.model."):
|
||||||
|
name = name[len("model."):]
|
||||||
|
if not name.startswith("model."):
|
||||||
|
name = f"model.{name}"
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def _map_hf_to_whisper_module(hf_name: str) -> Optional[str]:
    """
    Translate a HuggingFace LoRA module name into the equivalent Whisper
    module path, or return None when no known pattern matches.
    """
    import re

    canonical = _normalize_hf_module_name(hf_name)

    for hf_pattern, whisper_template in _HF_TO_WHISPER_MODULE_MAP.items():
        # Turn the pattern into a regex: literal dots, layer index "{}"
        # becomes a numeric capture group.
        layer_regex = hf_pattern.replace(".", r"\.").replace("{}", r"(\d+)")
        match = re.fullmatch(layer_regex, canonical)
        if match is not None:
            return whisper_template.format(match.group(1))

    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_module_by_path(model: nn.Module, path: str) -> Optional[nn.Module]:
|
||||||
|
"""Get a submodule by dot-separated path."""
|
||||||
|
parts = path.split(".")
|
||||||
|
current = model
|
||||||
|
for part in parts:
|
||||||
|
if hasattr(current, part):
|
||||||
|
current = getattr(current, part)
|
||||||
|
elif hasattr(current, "__getitem__"):
|
||||||
|
try:
|
||||||
|
current = current[int(part)]
|
||||||
|
except (ValueError, IndexError, KeyError):
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
return current
|
||||||
|
|
||||||
|
|
||||||
|
def _set_module_by_path(model: nn.Module, path: str, module: nn.Module):
|
||||||
|
"""Set a submodule by dot-separated path."""
|
||||||
|
parts = path.split(".")
|
||||||
|
parent = model
|
||||||
|
for part in parts[:-1]:
|
||||||
|
if hasattr(parent, part):
|
||||||
|
parent = getattr(parent, part)
|
||||||
|
elif hasattr(parent, "__getitem__"):
|
||||||
|
parent = parent[int(part)]
|
||||||
|
setattr(parent, parts[-1], module)
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAAdapterManager:
    """
    Manages multiple LoRA adapters for a Whisper model.

    Target Linear layers are wrapped in LoRALinear the first time an adapter
    needing them is loaded; switching adapters afterwards only swaps the small
    A/B matrices, so many adapters can stay resident without duplicating the
    base model weights.
    """

    def __init__(self, model: nn.Module):
        """
        Initialize the adapter manager.

        Args:
            model: A Whisper model instance
        """
        self.model = model
        # name -> loaded adapter (weights kept resident for fast switching)
        self.adapters: Dict[str, LoRAAdapter] = {}
        self.current_adapter: Optional[str] = None
        # whisper module path -> LoRALinear wrapper installed in the model
        self._lora_layers: Dict[str, LoRALinear] = {}
        # whisper module path -> original Linear (for restore_original_layers)
        self._original_layers: Dict[str, Linear] = {}
        # True once at least one wrapping pass has run (kept for introspection)
        self._initialized = False

    def _initialize_lora_layers(self, target_modules: List[str]):
        """
        Wrap target Linear layers with LoRALinear so adapters can be injected.

        Safe to call repeatedly: already-wrapped layers are skipped. BUGFIX:
        the previous early-return on self._initialized meant a second adapter
        targeting modules the first one did not would silently never get its
        layers wrapped (its LoRA weights were never applied).
        """
        for whisper_path in target_modules:
            module = _get_module_by_path(self.model, whisper_path)
            if module is None:
                # Unknown path for this model architecture — skip quietly.
                continue
            if isinstance(module, Linear) and not isinstance(module, LoRALinear):
                # Wrap the Linear layer and remember the original for restore.
                lora_linear = LoRALinear(module)
                _set_module_by_path(self.model, whisper_path, lora_linear)
                self._lora_layers[whisper_path] = lora_linear
                self._original_layers[whisper_path] = module

        self._initialized = True

    def _resolve_lora_path(self, lora_path: str) -> str:
        """Resolve a local directory, or download the adapter from HuggingFace Hub."""
        if os.path.isdir(lora_path):
            return lora_path

        # A "user/repo" style identifier: try HuggingFace Hub.
        if "/" in lora_path:
            try:
                from huggingface_hub import snapshot_download
                return snapshot_download(
                    repo_id=lora_path,
                    allow_patterns=["adapter_config.json", "adapter_model.*"],
                )
            except Exception as e:
                raise FileNotFoundError(
                    f"Could not find LoRA adapter at local path or HuggingFace Hub: {lora_path}. Error: {e}"
                )

        raise FileNotFoundError(f"LoRA path '{lora_path}' not found.")

    def _load_adapter_weights(self, lora_path: str) -> Dict[str, Tensor]:
        """Load adapter weights from a safetensors (preferred) or pickle .bin file."""
        safe_path = os.path.join(lora_path, "adapter_model.safetensors")
        bin_path = os.path.join(lora_path, "adapter_model.bin")

        if os.path.isfile(safe_path):
            from safetensors.torch import load_file
            return load_file(safe_path)
        elif os.path.isfile(bin_path):
            # SECURITY NOTE: .bin adapters are pickled — only load trusted files.
            return torch.load(bin_path, map_location="cpu")
        else:
            raise FileNotFoundError(
                f"No adapter weights found in {lora_path}. "
                "Expected adapter_model.safetensors or adapter_model.bin."
            )

    def load_adapter(
        self,
        name: str,
        lora_path: str,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> LoRAAdapter:
        """
        Load a LoRA adapter from disk or HuggingFace Hub.

        Args:
            name: Unique name for this adapter (e.g., "french", "spanish")
            lora_path: Local path or HuggingFace repo ID
            device: Device to load weights to (default: model's device)
            dtype: Data type for weights (default: model's dtype)

        Returns:
            The loaded LoRAAdapter

        Raises:
            FileNotFoundError: If the adapter or its config cannot be found.
            ValueError: If the config is not a LoRA config or lacks an alpha.
        """
        if device is None:
            device = next(self.model.parameters()).device
        if dtype is None:
            dtype = next(self.model.parameters()).dtype

        # Resolve path (may download from the Hub).
        lora_path = self._resolve_lora_path(lora_path)

        # Load and validate the PEFT adapter config.
        config_path = os.path.join(lora_path, "adapter_config.json")
        if not os.path.isfile(config_path):
            raise FileNotFoundError(f"Missing adapter_config.json in {lora_path}")

        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)

        if config_dict.get("peft_type") != "LORA":
            raise ValueError("Only LoRA adapters are supported.")

        # PEFT writes "lora_alpha"; accept legacy "alpha" too. A missing value
        # would crash later in LoRAConfig.scaling, so fail fast with a clear
        # message (the old `get(...) or get(...)` also mis-handled alpha == 0).
        alpha = config_dict.get("lora_alpha", config_dict.get("alpha"))
        if alpha is None:
            raise ValueError(
                f"adapter_config.json in {lora_path} has no 'lora_alpha'/'alpha' value"
            )

        config = LoRAConfig(
            r=config_dict["r"],
            alpha=alpha,
            target_modules=config_dict.get("target_modules", []),
        )

        # Load raw weights.
        adapter_state = self._load_adapter_weights(lora_path)

        # Group lora_A / lora_B tensors by their owning module name.
        lora_layers: Dict[str, Dict[str, Tensor]] = {}
        for key, tensor in adapter_state.items():
            if key.endswith("lora_A.weight"):
                module = key[:-len(".lora_A.weight")]
                lora_layers.setdefault(module, {})["A"] = tensor
            elif key.endswith("lora_B.weight"):
                module = key[:-len(".lora_B.weight")]
                lora_layers.setdefault(module, {})["B"] = tensor

        # Map HF module names onto Whisper paths and transpose to matmul layout.
        weights: Dict[str, Tuple[Tensor, Tensor]] = {}
        whisper_paths = set()

        for hf_module, parts in lora_layers.items():
            if "A" not in parts or "B" not in parts:
                # Incomplete pair (e.g. pruned checkpoint): skip this module.
                continue

            whisper_path = _map_hf_to_whisper_module(hf_module)
            if whisper_path is None:
                # Fall back to a direct path (module may already be in Whisper format).
                whisper_path = hf_module

            # PEFT stores A as (r, in_features) and B as (out_features, r);
            # transpose both so forward can compute (x @ A) @ B directly.
            A = parts["A"].T  # (in_features, r)
            B = parts["B"].T  # (r, out_features)

            weights[whisper_path] = (A, B)
            whisper_paths.add(whisper_path)

        # Create the adapter and move its weights to the target device/dtype.
        adapter = LoRAAdapter(
            name=name,
            config=config,
            weights=weights,
            device=device,
            dtype=dtype,
        )
        adapter.to(device, dtype)

        # Wrap any newly targeted Linear layers (idempotent for existing ones).
        self._initialize_lora_layers(list(whisper_paths))

        self.adapters[name] = adapter
        return adapter

    def set_adapter(self, name: Optional[str]):
        """
        Switch to a different adapter or disable LoRA.

        Args:
            name: Adapter name to activate, or None to disable all LoRA

        Raises:
            KeyError: If *name* has not been loaded via load_adapter().
        """
        if name is not None and name not in self.adapters:
            raise KeyError(f"Adapter '{name}' not loaded. Available: {list(self.adapters.keys())}")

        # Clear all LoRA from layers before applying the new selection.
        for lora_linear in self._lora_layers.values():
            lora_linear.clear_lora()

        self.current_adapter = name

        if name is None:
            return

        # Apply the selected adapter's matrices to every wrapped layer it targets.
        adapter = self.adapters[name]
        for module_path, (A, B) in adapter.weights.items():
            if module_path in self._lora_layers:
                self._lora_layers[module_path].set_lora(A, B, adapter.config.scaling)

    def unload_adapter(self, name: str):
        """
        Unload an adapter from memory (no-op if not loaded).

        Args:
            name: Name of adapter to unload
        """
        if name not in self.adapters:
            return

        # Deactivate first so no layer keeps references to the freed tensors.
        if self.current_adapter == name:
            self.set_adapter(None)

        del self.adapters[name]

    def list_adapters(self) -> List[str]:
        """Return list of loaded adapter names."""
        return list(self.adapters.keys())

    def get_memory_usage(self) -> Dict[str, float]:
        """Return memory usage in MB for each loaded adapter."""
        return {name: adapter.memory_footprint_mb() for name, adapter in self.adapters.items()}

    def restore_original_layers(self):
        """
        Restore the original Linear layers, removing LoRA wrappers.

        Call this if you want to go back to the original model structure.
        """
        for path, original in self._original_layers.items():
            _set_module_by_path(self.model, path, original)

        self._lora_layers.clear()
        self._original_layers.clear()
        self._initialized = False
        self.current_adapter = None
|
||||||
|
|
||||||
Reference in New Issue
Block a user