mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-09 15:25:34 +00:00
Compare commits: `voxtral_te`...`api_live` (4 commits)

| Author | SHA1 | Date |
|---|---|---|
| | d9a4c8dcb2 | |
| | 4fb735a784 | |
| | d2f998cb7e | |
| | 7b18917f2b | |
BIN `architecture.png` — binary file not shown (422 KiB, unchanged)

`docs/API.md` (299)
@@ -1,53 +1,22 @@

# WhisperLiveKit WebSocket API Documentation

> **Note**: The new API structure described in this document is currently being deployed.

This documentation is intended for developers who want to build custom frontends.

WLK provides real-time speech transcription, speaker diarization, and translation through a WebSocket API. The server sends updates as audio is processed, allowing clients to display live transcription results with minimal latency.

---

## Legacy API (Current)

### Message Structure

The legacy API sends complete state snapshots on each update (several times per second):

```typescript
{
  "type": str,
  "status": str,
  "lines": [
    {
      "speaker": int,
      "text": str,
      "start": float,
      "end": float,
      "translation": str | null,
      "detected_language": str
    }
  ],
  "buffer_transcription": str,
  "buffer_diarization": str,
  "remaining_time_transcription": float,
  "remaining_time_diarization": float
}
```

### Endpoints

| Endpoint | Description |
|----------|-------------|
| `/` | Main web interface with visual styling |
| `/text` | Simple text-based interface for easy copy/paste (debug/development) |
| `/asr` | WebSocket endpoint for audio streaming |

---
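Since the legacy API resends the full state on every message, a frontend can simply re-render from scratch each time. A minimal sketch of that idea (the `render_snapshot` helper name is ours, not part of WLK):

```python
def render_snapshot(msg: dict) -> str:
    """Rebuild the whole display from one legacy snapshot message."""
    out = []
    for line in msg.get("lines", []):
        out.append(f"[SPEAKER {line['speaker']}] {line['text']}")
    # Buffers are ephemeral: display them, never store them.
    if msg.get("buffer_transcription"):
        out.append("(pending) " + msg["buffer_transcription"].strip())
    return "\n".join(out)

snapshot = {
    "type": "transcript_update",
    "status": "active_transcription",
    "lines": [{"speaker": 1, "text": "Hello world", "start": 0.0, "end": 1.2,
               "translation": None, "detected_language": "en"}],
    "buffer_transcription": " how are",
    "buffer_diarization": "",
    "remaining_time_transcription": 0.5,
    "remaining_time_diarization": 0.0,
}
print(render_snapshot(snapshot))
# → [SPEAKER 1] Hello world
# → (pending) how are
```

The simplicity is the point of the legacy format: no client-side merging logic is needed, at the cost of resending everything.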

## New API (Under Development)

### Philosophy

Principles:

- **Incremental Updates**: Only updated and new segments are sent.
- **Ephemeral Buffers**: Temporary, unvalidated data is displayed in real time but overwritten on the next update, at the speaker level.

## Message Format

### Transcript Update (Server → Client)

```typescript
{
  "type": "transcript_update",
  "status": string,
  "segments": [
    {
      "id": number,
      "speaker": number,
      "text": string,
      "start_speaker": string, // HH:MM:SS format
      "start": string,         // HH:MM:SS format
      "end": string,           // HH:MM:SS format
      "language": string | null,
      "translation": string,
      "buffer": {
        "transcription": string,
        "diarization": string,
        "translation": string
      }
    }
  ],
  "metadata": {
    "remaining_time_transcription": float,
    "remaining_time_diarization": float
  }
}
```

#### Config Message (Server → Client)

```json
{
  "type": "config",
  "useAudioWorklet": true
}
```

- `useAudioWorklet`: If `true`, the client should use AudioWorklet for PCM streaming. If `false`, use MediaRecorder for WebM.

#### Ready to Stop Message (sent after processing is complete)

```json
{
  "type": "ready_to_stop"
}
```

Indicates that all audio has been processed and the client can safely close the connection.

---
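Under these principles, a client keeps a map of segments keyed by `id` and replaces a stored segment whenever an update with the same `id` arrives. A sketch of that bookkeeping (class and method names are ours, not part of WLK):

```python
class TranscriptState:
    """Client-side segment store for the incremental API (sketch)."""

    def __init__(self):
        self.segments = {}  # id -> latest segment dict

    def apply(self, msg: dict) -> None:
        if msg.get("type") != "transcript_update":
            return
        for seg in msg.get("segments", []):
            # Same id => the new segment supersedes the stored one.
            self.segments[seg["id"]] = seg

    def render(self) -> str:
        parts = []
        for seg in sorted(self.segments.values(), key=lambda s: s["id"]):
            if seg["speaker"] == -2:
                parts.append(f"[SILENCE {seg['start']} - {seg['end']}]")
            else:
                # Show validated text plus the ephemeral diarization buffer.
                parts.append(f"[SPEAKER {seg['speaker']}] "
                             f"{seg['text']}{seg['buffer']['diarization']}")
        return "\n".join(parts)

state = TranscriptState()
state.apply({"type": "transcript_update", "segments": [
    {"id": 1, "speaker": 1, "text": "Hello", "start": "0:00:00", "end": "0:00:01",
     "buffer": {"transcription": "", "diarization": " there", "translation": ""}}]})
print(state.render())
# → [SPEAKER 1] Hello there
```

Because buffers are re-sent whole on each update, `render` never accumulates buffer text; only `text` is authoritative.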

### Segment Fields

| Field | Type | Description |
|-------|------|-------------|
| `id` | `number` | Unique identifier for this segment. |
| `speaker` | `number` | Speaker ID (1, 2, 3...). The special value `-2` indicates silence. |
| `text` | `string` | Validated transcription text. |
| `start_speaker` | `string` | Timestamp (HH:MM:SS) when this speaker segment began. |
| `start` | `string` | Timestamp (HH:MM:SS) of the first word. |
| `end` | `string` | Timestamp (HH:MM:SS) of the last word. |
| `language` | `string \| null` | ISO language code (e.g., "en", "fr"). `null` until detected. |
| `translation` | `string` | Validated translation text. |
| `buffer` | `Object` | Per-segment temporary buffers (see below). |
### Buffer Object (Per-Segment)

Buffers are **ephemeral**: they should be displayed to the user but are overwritten on each update. Only the **last non-silent segment** contains buffer content.

| Field | Type | Description |
|-------|------|-------------|
| `transcription` | `string` | Text pending validation (waiting for more context). |
| `diarization` | `string` | Text pending speaker assignment (diarization hasn't caught up). |
| `translation` | `string` | Translation pending validation. |
### Metadata Fields

| Field | Type | Description |
|-------|------|-------------|
| `remaining_time_transcription` | `float` | Seconds of audio waiting for transcription. |
| `remaining_time_diarization` | `float` | Seconds of audio waiting for diarization. |

### Status Values

| Status | Description |
|--------|-------------|
| `active_transcription` | Normal operation, transcription is active. |
| `no_audio_detected` | No audio/speech has been detected yet. |
---

## Behavior Notes

### Silence Handling

- **Short silences (< 2 seconds)** are filtered out and not displayed.
- Only significant pauses appear as silence segments with `speaker: -2`.
- Consecutive same-speaker segments are merged even across short silences.

### Update Frequency

- **Active transcription**: ~20 updates/second (every 50 ms)
- **During silence**: ~2 updates/second (every 500 ms) to reduce bandwidth

### Token-by-Token Validation (Diarization Mode)

When diarization is enabled, text is validated **token-by-token** as soon as diarization covers each token, rather than waiting for punctuation. This provides:

- Faster text validation
- More responsive speaker attribution
- A buffer that contains only the tokens diarization hasn't processed yet
---

## Example Messages

### Normal Transcription

```json
{
  "type": "transcript_update",
  "status": "active_transcription",
  "segments": [
    {
      "id": 1,
      "speaker": 1,
      "text": "Hello, how are you today?",
      "start_speaker": "0:00:02",
      "start": "0:00:02",
      "end": "0:00:05",
      "language": "en",
      "translation": "",
      "buffer": {
        "transcription": " I'm doing",
        "diarization": "",
        "translation": ""
      }
    }
  ],
  "metadata": {
    "remaining_time_transcription": 0.5,
    "remaining_time_diarization": 0
  }
}
```
### With Diarization Buffer

```json
{
  "type": "transcript_update",
  "status": "active_transcription",
  "segments": [
    {
      "id": 1,
      "speaker": 1,
      "text": "The meeting starts at nine.",
      "start_speaker": "0:00:03",
      "start": "0:00:03",
      "end": "0:00:06",
      "language": "en",
      "translation": "",
      "buffer": {
        "transcription": "",
        "diarization": " Let me check my calendar",
        "translation": ""
      }
    }
  ],
  "metadata": {
    "remaining_time_transcription": 0.3,
    "remaining_time_diarization": 2.1
  }
}
```
### Silence Segment

Silence is represented with speaker id `-2`:

```json
{
  "id": 5,
  "speaker": -2,
  "text": "",
  "start_speaker": "0:00:10",
  "start": "0:00:10",
  "end": "0:00:15",
  "language": null,
  "translation": "",
  "buffer": {
    "transcription": "",
    "diarization": "",
    "translation": ""
  }
}
```
---

## Text Transcript Endpoint (`/text`)

The `/text` endpoint provides a simple, monospace text interface designed for:

- Easy copy/paste of transcripts
- Debugging and development
- Integration testing

Output uses text markers instead of HTML styling:

```
[METADATA transcription_lag=0.5s diarization_lag=1.2s]

[SPEAKER 1] 0:00:03 - 0:00:11 [LANG: en]
Hello world, how are you doing today?[DIAR_BUFFER] I'm doing fine[/DIAR_BUFFER]

[SILENCE 0:00:15 - 0:00:18]

[SPEAKER 2] 0:00:18 - 0:00:22 [LANG: en]
That's great to hear!
[TRANSLATION]C'est super à entendre![/TRANSLATION]
```

### Markers

| Marker | Description |
|--------|-------------|
| `[SPEAKER N]` | Speaker label with ID |
| `[SILENCE start - end]` | Silence segment |
| `[LANG: xx]` | Detected language code |
| `[DIAR_BUFFER]...[/DIAR_BUFFER]` | Text pending speaker assignment |
| `[TRANS_BUFFER]...[/TRANS_BUFFER]` | Text pending validation |
| `[TRANSLATION]...[/TRANSLATION]` | Translation content |
| `[METADATA ...]` | Lag/timing information |
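For integration tests it can be handy to reduce this marker output to plain validated text. A sketch using only the markers listed above (the `strip_markers` name and exact regexes are ours, not part of WLK):

```python
import re

def strip_markers(transcript: str) -> str:
    """Remove WLK /text markers, keeping only validated speech text."""
    # Drop whole-line METADATA and SILENCE markers.
    transcript = re.sub(r"^\[(METADATA|SILENCE)[^\]]*\]\s*$", "", transcript, flags=re.M)
    # Drop buffered (unvalidated) spans entirely.
    transcript = re.sub(r"\[(DIAR_BUFFER|TRANS_BUFFER)\].*?\[/\1\]", "", transcript, flags=re.S)
    # Unwrap translations; drop remaining inline markers and time ranges.
    transcript = re.sub(r"\[/?TRANSLATION\]", "", transcript)
    transcript = re.sub(r"\[SPEAKER \d+\]|\[LANG: \w+\]|\d+:\d{2}:\d{2} - \d+:\d{2}:\d{2}", "", transcript)
    return "\n".join(line.strip() for line in transcript.splitlines() if line.strip())

sample = "[SPEAKER 1] 0:00:03 - 0:00:11 [LANG: en]\nHello world[DIAR_BUFFER] maybe[/DIAR_BUFFER]"
print(strip_markers(sample))
# → Hello world
```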

@@ -1,13 +1,73 @@

# Alignment Principles

This document explains how transcription tokens are aligned with diarization (speaker identification) segments.

---
## Token-by-Token Validation

When diarization is enabled, text is validated **token-by-token** rather than waiting for sentence boundaries. As soon as diarization covers a token's time range, that token is validated and assigned to the appropriate speaker.

### How It Works

1. **Transcription produces tokens** with timestamps (start, end)
2. **Diarization produces speaker segments** with timestamps
3. **For each token**: check whether diarization has caught up to that token's time
   - If yes → find the speaker with maximum overlap and validate the token
   - If no → keep the token pending (it becomes the diarization buffer)

```
Timeline:       0s -------- 5s -------- 10s -------- 15s
                |           |           |            |
Transcription:  [Hello, how are you doing today?]
                |_______|___|____|_____|_____|_____|
                  tok1  tok2 tok3 tok4  tok5  tok6

Diarization:    [SPEAKER 1        ][SPEAKER 2        ]
                |__________________|__________________|
                0s                 8s                 15s

At time t, when diarization covers up to 8s:
- Tokens 1-4 (0s-7s)  → validated as SPEAKER 1
- Tokens 5-6 (7s-10s) → in buffer (diarization hasn't caught up)
```
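The coverage check in step 3 can be sketched as a pure function (all names here are ours; WLK's real implementation lives in its alignment code):

```python
from dataclasses import dataclass

@dataclass
class Token:
    text: str
    start: float
    end: float

@dataclass
class SpeakerSeg:
    speaker: int
    start: float
    end: float

def overlap(a_start, a_end, b_start, b_end):
    """Duration of the intersection of two time ranges (0 if disjoint)."""
    return max(0.0, min(a_end, b_end) - max(a_start, b_start))

def validate_tokens(pending, diar_segs):
    """Split tokens into validated (token, speaker) pairs and still-pending tokens."""
    validated, still_pending = [], []
    covered_until = diar_segs[-1].end if diar_segs else 0.0
    for tok in pending:
        if tok.end > covered_until:
            still_pending.append(tok)  # diarization hasn't caught up yet
            continue
        best = max(diar_segs, key=lambda s: overlap(tok.start, tok.end, s.start, s.end))
        validated.append((tok, best.speaker))
    return validated, still_pending

v, pending = validate_tokens([Token("Hello", 0.0, 1.0), Token("today", 7.5, 9.0)],
                             [SpeakerSeg(1, 0.0, 8.0)])
print([(t.text, spk) for t, spk in v], [t.text for t in pending])
# → [('Hello', 1)] ['today']
```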

---

## Silence Handling

- **Short silences (< 2 seconds)**: filtered out, not displayed
- **Significant silences (≥ 2 seconds)**: displayed as silence segments with `speaker: -2`
- **Same speaker across gaps**: segments are merged even if separated by short silences

```
Before filtering:
[SPK1 0:00-0:03] [SILENCE 0:03-0:04] [SPK1 0:04-0:08]

After filtering (silence < 2s):
[SPK1 0:00-0:08]  ← merged into a single segment
```
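The filter-then-merge step above can be sketched on simple `(speaker, start, end)` tuples (the helper name and tuple representation are ours):

```python
def merge_short_silences(segments, min_silence=2.0):
    """Drop silences shorter than min_silence, then merge same-speaker neighbours.

    Each segment is (speaker, start, end); speaker -2 marks silence.
    """
    # 1. Drop short silences.
    kept = [s for s in segments
            if not (s[0] == -2 and s[2] - s[1] < min_silence)]
    # 2. Merge consecutive segments from the same speaker.
    merged = []
    for spk, start, end in kept:
        if merged and merged[-1][0] == spk:
            merged[-1] = (spk, merged[-1][1], end)
        else:
            merged.append((spk, start, end))
    return merged

print(merge_short_silences([(1, 0, 3), (-2, 3, 4), (1, 4, 8)]))
# → [(1, 0, 8)]
```

A silence of 3 seconds, by contrast, survives the filter and keeps the surrounding same-speaker segments apart.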

---

## Buffer Types

| Buffer | Contains | Displayed When |
|--------|----------|----------------|
| `transcription` | Text awaiting validation (more context needed) | Always, on the last segment |
| `diarization` | Text awaiting speaker assignment | When diarization lags behind transcription |
| `translation` | Translation awaiting validation | When translation is enabled |

---

## Legacy: Punctuation-Based Splitting

The previous approach split segments at punctuation marks and aligned them with diarization at those boundaries. It has been replaced by token-by-token validation for faster, more responsive results.

### Historical Examples (for reference)

In the diagrams below, `#` marks the split between the `t-1` prediction and the `t` prediction.

#### Example 1

```text
punctuations_segments : __#_______.__________________!____
diarization_segments:
SPK2 # ___________________
-->
ALIGNED SPK1 __#_______.
ALIGNED SPK2 # __________________!____

t-1 output:
SPK1: __#
SPK2: NO
DIARIZATION BUFFER: NO

t output:
SPK1: __#__.
SPK2: __________________!____
DIARIZATION BUFFER: NO
```

#### Example 2

```text
punctuations_segments : _____#__.___________
diarization_segments:
SPK1 ___ #
SPK2 __#______________
-->
ALIGNED SPK1 _____#__.
ALIGNED SPK2 # ___________

t-1 output:
SPK1: ___ #
SPK2:
DIARIZATION BUFFER: __#

t output:
SPK1: __#__.
SPK2: ___________
DIARIZATION BUFFER: NO
```

#### Example 3

```text
punctuations_segments : ___.__#__________
diarization_segments:
SPK1 ______#__
SPK2 # ________
-->
ALIGNED SPK1 ___. #
ALIGNED SPK2 __#__________

t-1 output:
SPK1: ___. #
SPK2:
DIARIZATION BUFFER: __#

t output:
SPK1: #
SPK2: __#___________
DIARIZATION BUFFER: NO
```

With token-by-token validation, this alignment happens continuously rather than at punctuation boundaries.

@@ -1,7 +1,7 @@
from .audio_processor import AudioProcessor
from .core import TranscriptionEngine
from .parse_args import parse_args
from .web.web_interface import get_inline_ui_html, get_text_transcript_html, get_web_interface_html

__all__ = [
    "TranscriptionEngine",
@@ -9,5 +9,6 @@ __all__ = [
    "parse_args",
    "get_web_interface_html",
    "get_inline_ui_html",
    "get_text_transcript_html",
    "download_simulstreaming_backend",
]

@@ -393,6 +393,10 @@ class AudioProcessor:

    async def results_formatter(self) -> AsyncGenerator[FrontData, None]:
        """Format processing results for output."""
        # Update intervals
        ACTIVE_INTERVAL = 0.05   # 20 updates/sec during active transcription
        SILENCE_INTERVAL = 0.5   # 2 updates/sec during silence

        while True:
            try:
                if self._ffmpeg_error:
@@ -402,25 +406,35 @@ class AudioProcessor:
                    continue

                self.tokens_alignment.update()
                state = await self.get_current_state()

                # Get transcription buffer text to pass to get_lines
                buffer_transcription_text = state.buffer_transcription.text if state.buffer_transcription else ''

                # get_lines now returns segments with per-segment buffers
                segments = self.tokens_alignment.get_lines(
                    diarization=self.args.diarization,
                    translation=bool(self.translation),
                    current_silence=self.current_silence,
                    buffer_transcription=buffer_transcription_text
                )

                response_status = "active_transcription"
                # Check if there's any content (segments with text or buffers)
                has_active_content = any(
                    seg.buffer and (seg.buffer.transcription or seg.buffer.diarization)
                    for seg in segments if not seg.is_silence()
                )
                has_any_content = any(
                    seg.text or (seg.buffer and (seg.buffer.transcription or seg.buffer.diarization))
                    for seg in segments if not seg.is_silence()
                )
                if not segments or not has_any_content:
                    response_status = "no_audio_detected"

                response = FrontData(
                    status=response_status,
                    segments=segments,
                    remaining_time_transcription=state.remaining_time_transcription,
                    remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
                )
@@ -434,7 +448,15 @@ class AudioProcessor:
                    logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
                    return

                # Throttle updates during silence: use the slower interval when in silence mode
                # with no pending buffers (nothing actively being processed)
                is_in_silence = self.current_silence is not None
                has_pending_work = has_active_content or state.remaining_time_transcription > 0.5

                if is_in_silence and not has_pending_work:
                    await asyncio.sleep(SILENCE_INTERVAL)
                else:
                    await asyncio.sleep(ACTIVE_INTERVAL)

            except Exception as e:
                logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")

@@ -7,7 +7,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse

from whisperlivekit import (AudioProcessor, TranscriptionEngine,
                            get_inline_ui_html, get_text_transcript_html, parse_args)

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger().setLevel(logging.WARNING)
@@ -39,6 +39,12 @@ async def get():
    return HTMLResponse(get_inline_ui_html())


@app.get("/text")
async def get_text():
    """Simple text-based transcript view for easy copy/paste."""
    return HTMLResponse(get_text_transcript_html())


async def handle_websocket_results(websocket, results_generator):
    """Consumes results from the audio processor and sends them via WebSocket."""
    try:
@@ -107,6 +107,21 @@ class Silence():
        return True


@dataclass
class SegmentBuffer:
    """Per-segment buffer for ephemeral/unvalidated content."""
    transcription: str = ''
    diarization: str = ''
    translation: str = ''

    def to_dict(self) -> Dict[str, str]:
        return {
            'transcription': self.transcription,
            'diarization': self.diarization,
            'translation': self.translation
        }


@dataclass
class Segment(TimedText):
    """Generic contiguous span built from tokens or silence markers."""
@@ -114,14 +129,18 @@ class Segment(TimedText):
    end: Optional[float]
    text: Optional[str]
    speaker: Optional[str]
    id: Optional[int] = None
    start_speaker: Optional[float] = None
    tokens: Optional[ASRToken] = None
    translation: Optional[Translation] = None
    buffer: Optional[SegmentBuffer] = None

    @classmethod
    def from_tokens(
        cls,
        tokens: List[Union[ASRToken, Silence]],
        is_silence: bool = False,
        segment_id: Optional[int] = None
    ) -> Optional["Segment"]:
        """Return a normalized segment representing the provided tokens."""
        if not tokens:
@@ -134,7 +153,9 @@ class Segment(TimedText):
                start=start_token.start,
                end=end_token.end,
                text=None,
                speaker=-2,
                id=segment_id,
                start_speaker=start_token.start
            )
        else:
            return cls(
@@ -142,6 +163,8 @@ class Segment(TimedText):
                end=end_token.end,
                text=''.join(token.text for token in tokens),
                speaker=-1,
                id=segment_id,
                start_speaker=start_token.start,
                detected_language=start_token.detected_language
            )

@@ -150,17 +173,18 @@ class Segment(TimedText):
        return self.speaker == -2

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the segment for frontend consumption (new API format)."""
        _dict: Dict[str, Any] = {
            'id': self.id if self.id is not None else 0,
            'speaker': int(self.speaker) if self.speaker != -1 else 1,
            'text': self.text or '',
            'start_speaker': format_time(self.start_speaker) if self.start_speaker is not None else format_time(self.start),
            'start': format_time(self.start),
            'end': format_time(self.end),
            'language': self.detected_language,
            'translation': self.translation or '',
            'buffer': self.buffer.to_dict() if self.buffer else SegmentBuffer().to_dict()
        }
        return _dict
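`format_time` is used above but not shown in this diff. A minimal implementation consistent with the `0:00:05`-style timestamps in the API examples might look like this (the body is our assumption, only the name comes from the diff):

```python
def format_time(seconds):
    """Render a duration in seconds as H:MM:SS, e.g. 125.7 -> "0:02:05"."""
    if seconds is None:
        return ''
    total = int(seconds)
    h, rem = divmod(total, 3600)
    m, s = divmod(rem, 60)
    return f"{h}:{m:02d}:{s:02d}"

print(format_time(125.7))
# → 0:02:05
```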

@@ -179,23 +203,20 @@ class SilentSegment(Segment):
class FrontData():
    status: str = ''
    error: str = ''
    segments: list[Segment] = field(default_factory=list)
    remaining_time_transcription: float = 0.
    remaining_time_diarization: float = 0.

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the front-end data payload (new API format)."""
        _dict: Dict[str, Any] = {
            'type': 'transcript_update',
            'status': self.status,
            'segments': [seg.to_dict() for seg in self.segments if (seg.text or seg.speaker == -2)],
            'metadata': {
                'remaining_time_transcription': self.remaining_time_transcription,
                'remaining_time_diarization': self.remaining_time_diarization,
            }
        }
        if self.error:
            _dict['error'] = self.error
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from time import time
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
from whisperlivekit.timed_objects import (ASRToken, Segment, PuncSegment, Silence,
|
||||
from whisperlivekit.timed_objects import (ASRToken, Segment, SegmentBuffer, PuncSegment, Silence,
|
||||
SilentSegment, SpeakerSegment,
|
||||
TimedText)
|
||||
|
||||
|
||||
class TokensAlignment:
|
||||
# Minimum duration (seconds) for a silence to be displayed
|
||||
MIN_SILENCE_DISPLAY_DURATION = 2.0
|
||||
|
||||
def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None:
|
||||
self.state = state
|
||||
@@ -33,7 +35,15 @@ class TokensAlignment:
|
||||
|
||||
self.last_punctuation = None
|
||||
self.last_uncompleted_punc_segment: PuncSegment = None
|
||||
self.unvalidated_tokens: PuncSegment = []
|
||||
self.tokens_after_last_punctuation: PuncSegment = []
|
||||
self.all_validated_segments: List[Segment] = []
|
||||
|
||||
# For token-by-token validation with diarization
|
||||
self.pending_tokens: List[ASRToken] = []
|
||||
self.last_validated_token_end: float = 0.0
|
||||
|
||||
# Segment ID counter for the new API
|
||||
self._next_segment_id: int = 1
|
||||
|
||||
def update(self) -> None:
|
||||
"""Drain state buffers into the running alignment context."""
|
||||
@@ -91,11 +101,11 @@ class TokensAlignment:
|
||||
def compute_new_punctuations_segments(self) -> List[PuncSegment]:
|
||||
new_punc_segments = []
|
||||
segment_start_idx = 0
|
||||
self.unvalidated_tokens += self.new_tokens
|
||||
for i, token in enumerate(self.unvalidated_tokens):
|
||||
self.tokens_after_last_punctuation += self.new_tokens
|
||||
for i, token in enumerate(self.tokens_after_last_punctuation):
|
||||
if token.is_silence():
|
||||
previous_segment = PuncSegment.from_tokens(
|
||||
tokens=self.unvalidated_tokens[segment_start_idx: i],
|
||||
tokens=self.tokens_after_last_punctuation[segment_start_idx: i],
|
||||
)
|
||||
if previous_segment:
|
||||
new_punc_segments.append(previous_segment)
|
||||
@@ -108,12 +118,12 @@ class TokensAlignment:
|
||||
else:
|
||||
if token.has_punctuation():
|
||||
segment = PuncSegment.from_tokens(
|
||||
tokens=self.unvalidated_tokens[segment_start_idx: i+1],
|
||||
tokens=self.tokens_after_last_punctuation[segment_start_idx: i+1],
|
||||
)
|
||||
new_punc_segments.append(segment)
|
||||
segment_start_idx = i+1
|
||||
|
||||
self.unvalidated_tokens = self.unvalidated_tokens[segment_start_idx:]
|
||||
self.tokens_after_last_punctuation = self.tokens_after_last_punctuation[segment_start_idx:]
|
||||
return new_punc_segments
|
||||
|
||||
|
||||
@@ -138,64 +148,189 @@ class TokensAlignment:

        return max(0, end - start)

    def get_lines_diarization(self) -> Tuple[List[Segment], str]:
        """Build segments when diarization is enabled and track overflow buffer."""
        diarization_buffer = ''
        punctuation_segments = self.compute_punctuations_segments()
        diarization_segments = self.concatenate_diar_segments()
        for punctuation_segment in punctuation_segments:
            if not punctuation_segment.is_silence():
                if diarization_segments and punctuation_segment.start >= diarization_segments[-1].end:
                    diarization_buffer += punctuation_segment.text
                else:
                    max_overlap = 0.0
                    max_overlap_speaker = 1
                    for diarization_segment in diarization_segments:
                        intersec = self.intersection_duration(punctuation_segment, diarization_segment)
                        if intersec > max_overlap:
                            max_overlap = intersec
                            max_overlap_speaker = diarization_segment.speaker + 1
                    punctuation_segment.speaker = max_overlap_speaker
    def _get_speaker_for_token(self, token: ASRToken, diarization_segments: List[SpeakerSegment]) -> Optional[int]:
        """Get speaker ID for a token based on diarization overlap. Returns None if not covered."""
        if not diarization_segments:
            return None

        segments = []
        if punctuation_segments:
            segments = [punctuation_segments[0]]
            for segment in punctuation_segments[1:]:
                if segment.speaker == segments[-1].speaker:
                    if segments[-1].text:
                        segments[-1].text += segment.text
                        segments[-1].end = segment.end
        # Check if token is beyond diarization coverage
        if token.start >= diarization_segments[-1].end:
            return None

        # Find speaker with max overlap
        max_overlap = 0.0
        best_speaker = None
        for diar_seg in diarization_segments:
            overlap = self.intersection_duration(token, diar_seg)
            if overlap > max_overlap:
                max_overlap = overlap
                best_speaker = diar_seg.speaker + 1  # 1-indexed

        return best_speaker if max_overlap > 0 else None

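The max-overlap rule in `_get_speaker_for_token` is plain interval arithmetic and can be sketched on its own. The `Interval` dataclass below is a stand-in for the real token and diarization-segment types; the logic (overlap clamp, coverage check, 1-indexed speaker) follows the diff above.

```python
# Illustrative sketch, not the library's API: assign a speaker to a token by
# maximum temporal overlap with diarization segments, or None if the token
# lies beyond diarization coverage and must stay pending.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Interval:
    start: float
    end: float
    speaker: int = 0

def intersection_duration(a: Interval, b: Interval) -> float:
    # Overlap of two time intervals, clamped at zero.
    return max(0.0, min(a.end, b.end) - max(a.start, b.start))

def speaker_for_token(token: Interval, diar: List[Interval]) -> Optional[int]:
    if not diar:
        return None
    if token.start >= diar[-1].end:
        return None  # beyond diarization coverage: keep pending
    best, best_overlap = None, 0.0
    for seg in diar:
        overlap = intersection_duration(token, seg)
        if overlap > best_overlap:
            best_overlap, best = overlap, seg.speaker + 1  # 1-indexed
    return best
```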
    def get_lines_diarization(self) -> Tuple[List[Segment], str]:
        """Build segments with token-by-token validation when diarization covers them."""
        diarization_segments = self.concatenate_diar_segments()

        # Add new tokens to pending
        self.pending_tokens.extend(self.new_tokens)

        # Process pending tokens - validate those covered by diarization
        still_pending = []
        for token in self.pending_tokens:
            if token.is_silence():
                # Handle silence tokens
                silence_duration = (token.end or 0) - (token.start or 0)
                if silence_duration >= self.MIN_SILENCE_DISPLAY_DURATION:
                    # Significant silence - add as separate segment
                    if self.all_validated_segments and not self.all_validated_segments[-1].is_silence():
                        self.all_validated_segments.append(SilentSegment(
                            start=token.start,
                            end=token.end
                        ))
                    elif self.all_validated_segments and self.all_validated_segments[-1].is_silence():
                        # Extend existing silence
                        self.all_validated_segments[-1].end = token.end
                    else:
                        self.all_validated_segments.append(SilentSegment(
                            start=token.start,
                            end=token.end
                        ))
                # Short silences are ignored (don't go to pending either)
                continue

            speaker = self._get_speaker_for_token(token, diarization_segments)

            if speaker is not None:
                # Token is covered by diarization - validate it
                if self.all_validated_segments:
                    last_seg = self.all_validated_segments[-1]
                    if not last_seg.is_silence() and last_seg.speaker == speaker:
                        # Same speaker - append to existing segment
                        last_seg.text += token.text
                        last_seg.end = token.end
                    else:
                        # Different speaker or after silence - new segment
                        new_seg = Segment(
                            start=token.start,
                            end=token.end,
                            text=token.text,
                            speaker=speaker,
                            start_speaker=token.start,
                            detected_language=token.detected_language
                        )
                        self.all_validated_segments.append(new_seg)
                else:
            segments.append(segment)
                    # First segment
                    new_seg = Segment(
                        start=token.start,
                        end=token.end,
                        text=token.text,
                        speaker=speaker,
                        start_speaker=token.start,
                        detected_language=token.detected_language
                    )
                    self.all_validated_segments.append(new_seg)

                self.last_validated_token_end = token.end
            else:
                # Token not yet covered by diarization - keep pending
                still_pending.append(token)

        self.pending_tokens = still_pending

        # Build diarization buffer from pending tokens
        diarization_buffer = ''.join(t.text for t in self.pending_tokens if not t.is_silence())

        return self.all_validated_segments, diarization_buffer

        return segments, diarization_buffer

    def _assign_segment_ids(self, segments: List[Segment]) -> None:
        """Assign unique IDs to segments that don't have one yet."""
        for segment in segments:
            if segment.id is None:
                segment.id = self._next_segment_id
                self._next_segment_id += 1

    def _assign_buffers_to_last_segment(
        self,
        segments: List[Segment],
        buffer_transcription: str,
        buffer_diarization: str,
        buffer_translation: str
    ) -> None:
        """Assign buffer content to the last non-silent segment."""
        # First, clear ALL buffers (they're ephemeral and shouldn't persist)
        for segment in segments:
            segment.buffer = SegmentBuffer()

        # Find the last non-silent segment and assign buffers to it
        for segment in reversed(segments):
            if not segment.is_silence():
                segment.buffer = SegmentBuffer(
                    transcription=buffer_transcription,
                    diarization=buffer_diarization,
                    translation=buffer_translation
                )
                break

    def _filter_and_merge_segments(self, segments: List[Segment]) -> List[Segment]:
        """Filter parasitic silences and merge consecutive same-speaker segments."""
        if not segments:
            return segments

        result = []
        for seg in segments:
            if seg.is_silence():
                # Filter short silences
                duration = (seg.end or 0) - (seg.start or 0)
                if duration < self.MIN_SILENCE_DISPLAY_DURATION:
                    continue
                # Merge consecutive silences
                if result and result[-1].is_silence():
                    result[-1].end = seg.end
                    continue
            else:
                # Merge same speaker segments (across filtered silences)
                if result and not result[-1].is_silence() and result[-1].speaker == seg.speaker:
                    result[-1].text += seg.text
                    result[-1].end = seg.end
                    continue

            result.append(seg)

        return result

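The merge policy in `_filter_and_merge_segments` can be exercised with plain tuples. This is a hedged sketch: `(speaker, text, start, end)` tuples and the `-2` silence sentinel stand in for the real `Segment` / `SilentSegment` types, and `MIN_SILENCE` stands in for `MIN_SILENCE_DISPLAY_DURATION`.

```python
# Hedged sketch of the merge policy: drop silences shorter than a threshold,
# merge adjacent silences, and merge runs of the same speaker (which also
# joins across silences that were dropped).
MIN_SILENCE = 2.0  # assumed threshold, illustrative only

def filter_and_merge(segments):
    result = []
    for speaker, text, start, end in segments:
        if speaker == -2:  # silence sentinel
            if end - start < MIN_SILENCE:
                continue                                  # parasitic silence: drop
            if result and result[-1][0] == -2:
                prev = result.pop()
                result.append((-2, '', prev[2], end))     # merge consecutive silences
                continue
        elif result and result[-1][0] == speaker:
            prev = result.pop()
            result.append((speaker, prev[1] + text, prev[2], end))  # merge same speaker
            continue
        result.append((speaker, text, start, end))
    return result
```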
    def get_lines(
        self,
        diarization: bool = False,
        translation: bool = False,
        current_silence: Optional[Silence] = None
    ) -> Tuple[List[Segment], str, Union[str, TimedText]]:
        """Return the formatted segments plus buffers, optionally with diarization/translation."""
        current_silence: Optional[Silence] = None,
        buffer_transcription: str = ''
    ) -> List[Segment]:
        """Return the formatted segments with per-segment buffers, optionally with diarization/translation."""
        diarization_buffer = ''

        if diarization:
            segments, diarization_buffer = self.get_lines_diarization()
        else:
            diarization_buffer = ''
            for token in self.new_tokens:
                if token.is_silence():
                    if self.current_line_tokens:
                        self.validated_segments.append(Segment().from_tokens(self.current_line_tokens))
                        self.current_line_tokens = []

                    end_silence = token.end if token.has_ended else time() - self.beg_loop
                    if self.validated_segments and self.validated_segments[-1].is_silence():
                        self.validated_segments[-1].end = end_silence
                    else:
                        self.validated_segments.append(SilentSegment(
                            start=token.start,
                            end=end_silence
                        ))
                    # Check silence duration before adding
                    silence_duration = (token.end or 0) - (token.start or 0)
                    if silence_duration >= self.MIN_SILENCE_DISPLAY_DURATION:
                        if self.current_line_tokens:
                            self.validated_segments.append(Segment().from_tokens(self.current_line_tokens))
                            self.current_line_tokens = []

                        end_silence = token.end if token.has_ended else time() - self.beg_loop
                        if self.validated_segments and self.validated_segments[-1].is_silence():
                            self.validated_segments[-1].end = end_silence
                        else:
                            self.validated_segments.append(SilentSegment(
                                start=token.start,
                                end=end_silence
                            ))
                else:
                    self.current_line_tokens.append(token)

@@ -203,15 +338,37 @@ class TokensAlignment:
            if self.current_line_tokens:
                segments.append(Segment().from_tokens(self.current_line_tokens))

        # Handle current ongoing silence
        if current_silence:
            end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
            if segments and segments[-1].is_silence():
                segments[-1] = SilentSegment(start=segments[-1].start, end=end_silence)
            else:
                segments.append(SilentSegment(
                    start=current_silence.start,
                    end=end_silence
                ))
            silence_duration = (current_silence.end or time() - self.beg_loop) - (current_silence.start or 0)
            if silence_duration >= self.MIN_SILENCE_DISPLAY_DURATION:
                end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
                if segments and segments[-1].is_silence():
                    segments[-1] = SilentSegment(start=segments[-1].start, end=end_silence)
                else:
                    segments.append(SilentSegment(
                        start=current_silence.start,
                        end=end_silence
                    ))

        if translation:
            [self.add_translation(segment) for segment in segments if not segment.is_silence()]
            return segments, diarization_buffer, self.new_translation_buffer.text

        # Get translation buffer text
        translation_buffer = self.new_translation_buffer.text if self.new_translation_buffer else ''

        # Filter parasitic silences and merge same-speaker segments
        segments = self._filter_and_merge_segments(segments)

        # Assign unique IDs to all segments
        self._assign_segment_ids(segments)

        # Assign buffers to the last active segment
        self._assign_buffers_to_last_segment(
            segments,
            buffer_transcription=buffer_transcription,
            buffer_diarization=diarization_buffer,
            buffer_translation=translation_buffer
        )

        return segments

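Taken together with the frontend changes below, the backend now ships segments that carry their own `buffer` sub-object, with timing lags moved under a `metadata` wrapper. A hedged sketch of the resulting message shape, inferred from the destructuring in the new `renderSegments` code; all values are invented:

```python
# Sketch of the per-segment message shape the new frontend consumes.
# Field names follow the JS destructuring; values are illustrative only.
import json

message = {
    "status": "active_transcription",
    "segments": [
        {
            "id": 0,
            "speaker": 1,
            "text": "Hello everyone.",
            "start": 0.0,
            "end": 1.4,
            "language": "en",
            "buffer": {
                "transcription": " and wel",  # text awaiting validation
                "diarization": "",            # text awaiting speaker assignment
                "translation": "",
            },
        },
        {"id": 1, "speaker": -2, "start": 1.4, "end": 4.0},  # silence segment
    ],
    "metadata": {
        "remaining_time_transcription": 0.8,
        "remaining_time_diarization": 1.5,
    },
}
payload = json.dumps(message)
```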
@@ -454,8 +454,9 @@ label {
  gap: 4px;
}

.lag-diarization-value {
  margin-left: 10px;
.lag-diarization-value,
.lag-transcription-value {
  font-weight: 600;
}

.label_translation img {

@@ -232,11 +232,8 @@ function setupWebSocket() {
    if (waitingForStop) {
      statusText.textContent = "Processing finalized or connection closed.";
      if (lastReceivedData) {
        renderLinesWithBuffer(
          lastReceivedData.lines || [],
          lastReceivedData.buffer_diarization || "",
          lastReceivedData.buffer_transcription || "",
          lastReceivedData.buffer_translation || "",
        renderSegments(
          lastReceivedData.segments || [],
          0,
          0,
          true
@@ -278,11 +275,8 @@ function setupWebSocket() {
    waitingForStop = false;

    if (lastReceivedData) {
      renderLinesWithBuffer(
        lastReceivedData.lines || [],
        lastReceivedData.buffer_diarization || "",
        lastReceivedData.buffer_transcription || "",
        lastReceivedData.buffer_translation || "",
      renderSegments(
        lastReceivedData.segments || [],
        0,
        0,
        true
@@ -299,21 +293,20 @@ function setupWebSocket() {

    lastReceivedData = data;

    // New API format: segments with per-segment buffers, metadata wrapper
    const {
      lines = [],
      buffer_transcription = "",
      buffer_diarization = "",
      buffer_translation = "",
      remaining_time_transcription = 0,
      remaining_time_diarization = 0,
      segments = [],
      metadata = {},
      status = "active_transcription",
    } = data;

    const {
      remaining_time_transcription = 0,
      remaining_time_diarization = 0,
    } = metadata;

    renderLinesWithBuffer(
      lines,
      buffer_diarization,
      buffer_transcription,
      buffer_translation,
    renderSegments(
      segments,
      remaining_time_diarization,
      remaining_time_transcription,
      false,
@@ -323,11 +316,8 @@ function setupWebSocket() {
  });
}

function renderLinesWithBuffer(
  lines,
  buffer_diarization,
  buffer_transcription,
  buffer_translation,
function renderSegments(
  segments,
  remaining_time_diarization,
  remaining_time_transcription,
  isFinalizing = false,
@@ -339,33 +329,38 @@ function renderLinesWithBuffer(
    return;
  }

  const showLoading = !isFinalizing && (lines || []).some((it) => it.speaker == 0);
  const showTransLag = !isFinalizing && remaining_time_transcription > 0;
  const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
  // Build signature for change detection
  const signature = JSON.stringify({
    lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end, detected_language: it.detected_language })),
    buffer_transcription: buffer_transcription || "",
    buffer_diarization: buffer_diarization || "",
    buffer_translation: buffer_translation,
    segments: (segments || []).map((it) => ({
      id: it.id,
      speaker: it.speaker,
      text: it.text,
      start: it.start,
      end: it.end,
      language: it.language,
      buffer: it.buffer || {}
    })),
    status: current_status,
    showLoading,
    showTransLag,
    showDiaLag,
    isFinalizing: !!isFinalizing,
  });

  // Only update lag values if signature unchanged
  if (lastSignature === signature) {
    const t = document.querySelector(".lag-transcription-value");
    if (t) t.textContent = fmt1(remaining_time_transcription);
    const d = document.querySelector(".lag-diarization-value");
    if (d) d.textContent = fmt1(remaining_time_diarization);
    const ld = document.querySelector(".loading-diarization-value");
    if (ld) ld.textContent = fmt1(remaining_time_diarization);
    return;
  }
  lastSignature = signature;

  const linesHtml = (lines || [])
  const segmentsHtml = (segments || [])
    .map((item, idx) => {
      const buffer = item.buffer || {};
      const buffer_transcription = buffer.transcription || "";
      const buffer_diarization = buffer.diarization || "";
      const buffer_translation = buffer.translation || "";

      let timeInfo = "";
      if (item.start !== undefined && item.end !== undefined) {
        timeInfo = ` ${item.start} - ${item.end}`;
@@ -373,80 +368,78 @@ function renderLinesWithBuffer(

      let speakerLabel = "";
      if (item.speaker === -2) {
        // Silence segment
        speakerLabel = `<span class="silence">${silenceIcon}<span id='timeInfo'>${timeInfo}</span></span>`;
      } else if (item.speaker == 0 && !isFinalizing) {
        speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
          remaining_time_diarization
        )}</span> second(s) of audio are undergoing diarization</span></span>`;
      } else if (item.speaker !== 0) {
        // Normal speaker segment
        const speakerNum = `<span class="speaker-badge">${item.speaker}</span>`;
        speakerLabel = `<span id="speaker">${speakerIcon}${speakerNum}<span id='timeInfo'>${timeInfo}</span></span>`;

        if (item.detected_language) {
          speakerLabel += `<span class="label_language">${languageIcon}<span>${item.detected_language}</span></span>`;
        if (item.language) {
          speakerLabel += `<span class="label_language">${languageIcon}<span>${item.language}</span></span>`;
        }
      }

      let currentLineText = item.text || "";

      if (idx === lines.length - 1) {
        if (!isFinalizing && item.speaker !== -2) {
          speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(
            remaining_time_transcription
          )}</span>s</span></span>`;

          if (buffer_diarization && remaining_time_diarization) {
            speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(
              remaining_time_diarization
            )}</span>s</span></span>`;
          }
      const isLastSegment = idx === segments.length - 1;
      const hasBufferContent = buffer_diarization || buffer_transcription;

      // Show lag indicators on last non-silent segment (without spinners)
      if (isLastSegment && item.speaker !== -2 && !isFinalizing) {
        if (remaining_time_transcription > 0) {
          speakerLabel += `<span class="label_transcription">Transcription lag: <span class="lag-transcription-value">${fmt1(remaining_time_transcription)}</span>s</span>`;
        }
        if (buffer_diarization && remaining_time_diarization > 0) {
          speakerLabel += `<span class="label_diarization">Diarization lag: <span class="lag-diarization-value">${fmt1(remaining_time_diarization)}</span>s</span>`;
        }
      }

      // Render buffers
      if (hasBufferContent && item.speaker !== -2) {
        if (buffer_diarization) {
          if (isFinalizing) {
            currentLineText +=
              (currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
            currentLineText += (currentLineText.length > 0 ? " " : "") + buffer_diarization.trim();
          } else {
            currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
          }
        }
        if (buffer_transcription) {
          if (isFinalizing) {
            currentLineText +=
              (currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") +
              buffer_transcription.trim();
            currentLineText += (currentLineText.length > 0 ? " " : "") + buffer_transcription.trim();
          } else {
            currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
          }
        }
      }

      // Translation
      let translationContent = "";
      if (item.translation) {
        translationContent += item.translation.trim();
      }
      if (idx === lines.length - 1 && buffer_translation) {
      if (buffer_translation) {
        const bufferPiece = isFinalizing
          ? buffer_translation
          : `<span class="buffer_translation">${buffer_translation}</span>`;
        translationContent += translationContent ? `${bufferPiece}` : bufferPiece;
        translationContent += translationContent ? bufferPiece : bufferPiece;
      }
      if (translationContent.trim().length > 0) {
        currentLineText += `
          <div>
            <div class="label_translation">
              ${translationIcon}
              <span class="translation_text">${translationContent}</span>
            </div>
          </div>`;
          <div class="label_translation">
            ${translationIcon}
            <span class="translation_text">${translationContent}</span>
          </div>`;
      }

      return currentLineText.trim().length > 0 || speakerLabel.length > 0
        ? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
        : `<p>${speakerLabel}<br/></p>`;
      if (currentLineText.trim().length > 0 || speakerLabel.length > 0) {
        return `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`;
      }
      return speakerLabel ? `<p>${speakerLabel}</p>` : "";
    })
    .filter(html => html.length > 0)
    .join("");

  linesTranscriptDiv.innerHTML = linesHtml;
  linesTranscriptDiv.innerHTML = segmentsHtml;
  const transcriptContainer = document.querySelector('.transcript-container');
  if (transcriptContainer) {
    transcriptContainer.scrollTo({ top: transcriptContainer.scrollHeight, behavior: "smooth" });

377
whisperlivekit/web/text_transcript.html
Normal file
@@ -0,0 +1,377 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>WhisperLiveKit Transcript</title>
    <style>
        :root {
            --bg: #111;
            --text: #ddd;
            --dim: #666;
            --border: #333;
            --active: #e74c3c;
        }
        body {
            font-family: 'SF Mono', 'Monaco', 'Inconsolata', 'Roboto Mono', monospace;
            background: var(--bg);
            color: var(--text);
            margin: 0;
            padding: 2rem;
            font-size: 13px;
            line-height: 1.6;
        }
        .nav {
            display: flex;
            gap: 12px;
            align-items: center;
            margin-bottom: 3rem;
            font-size: 12px;
        }
        button, input, select {
            background: transparent;
            border: 1px solid var(--border);
            color: var(--dim);
            padding: 6px 12px;
            font-family: inherit;
            font-size: inherit;
            border-radius: 4px;
            outline: none;
            transition: all 0.2s;
        }
        button:hover, input:hover, input:focus, select:hover, select:focus {
            border-color: var(--text);
            color: var(--text);
            cursor: pointer;
        }
        select {
            cursor: pointer;
            appearance: none; /* Minimalist look */
            background-image: linear-gradient(45deg, transparent 50%, var(--dim) 50%), linear-gradient(135deg, var(--dim) 50%, transparent 50%);
            background-position: calc(100% - 15px) 50%, calc(100% - 10px) 50%;
            background-size: 5px 5px, 5px 5px;
            background-repeat: no-repeat;
            padding-right: 25px;
        }
        select:hover, select:focus {
            background-image: linear-gradient(45deg, transparent 50%, var(--text) 50%), linear-gradient(135deg, var(--text) 50%, transparent 50%);
        }
        button.recording {
            border-color: var(--active);
            color: var(--active);
        }
        input {
            width: 150px;
            cursor: text;
        }
        #status {
            margin-left: auto;
            color: var(--dim);
        }
        #transcript {
            white-space: pre-wrap;
            word-wrap: break-word;
            max-width: 800px;
            margin: 0 auto;
            outline: none;
        }
        /* Minimal scrollbar */
        ::-webkit-scrollbar { width: 6px; }
        ::-webkit-scrollbar-track { background: transparent; }
        ::-webkit-scrollbar-thumb { background: #222; border-radius: 3px; }
        ::-webkit-scrollbar-thumb:hover { background: #333; }
    </style>
</head>
<body>
    <div class="nav">
        <button id="recordBtn">Record</button>
        <button id="copyBtn">Copy</button>
        <select id="microphoneSelect"></select>
        <input type="text" id="wsUrl" placeholder="WebSocket URL">
        <div id="status">Ready</div>
    </div>

    <div id="transcript"></div>

<script>
|
||||
const recordBtn = document.getElementById('recordBtn');
|
||||
const copyBtn = document.getElementById('copyBtn');
|
||||
const wsUrlInput = document.getElementById('wsUrl');
|
||||
const statusEl = document.getElementById('status');
|
||||
const transcriptEl = document.getElementById('transcript');
|
||||
const microphoneSelect = document.getElementById('microphoneSelect');
|
||||
|
||||
// Default WebSocket URL
|
||||
const protocol = window.location.protocol === 'https:' ? 'wss' : 'ws';
|
||||
const host = window.location.hostname || 'localhost';
|
||||
const port = window.location.port;
|
||||
const defaultUrl = `${protocol}://${host}${port ? ':' + port : ''}/asr`;
|
||||
wsUrlInput.value = defaultUrl;
|
||||
|
||||
let websocket = null;
|
||||
let isRecording = false;
|
||||
let audioContext = null;
|
||||
let workletNode = null;
|
||||
let recorderWorker = null;
|
||||
let microphone = null;
|
||||
let useAudioWorklet = false;
|
||||
let recorder = null;
|
||||
let availableMicrophones = [];
|
||||
let selectedMicrophoneId = null;
|
||||
|
||||
async function enumerateMicrophones() {
|
||||
try {
|
||||
// Request permission first to get labels
|
||||
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||
stream.getTracks().forEach(track => track.stop());
|
||||
|
||||
const devices = await navigator.mediaDevices.enumerateDevices();
|
||||
availableMicrophones = devices.filter(device => device.kind === 'audioinput');
|
||||
|
||||
populateMicrophoneSelect();
|
||||
} catch (error) {
|
||||
console.error('Error enumerating microphones:', error);
|
||||
statusEl.textContent = "Mic permission needed";
|
||||
}
|
||||
}
|
||||
|
||||
function populateMicrophoneSelect() {
|
||||
microphoneSelect.innerHTML = '<option value="">Default Microphone</option>';
|
||||
|
||||
availableMicrophones.forEach((device, index) => {
|
||||
const option = document.createElement('option');
|
||||
option.value = device.deviceId;
|
||||
option.textContent = device.label || `Microphone ${index + 1}`;
|
||||
microphoneSelect.appendChild(option);
|
||||
});
|
||||
|
||||
const savedMicId = localStorage.getItem('selectedMicrophone');
|
||||
if (savedMicId && availableMicrophones.some(mic => mic.deviceId === savedMicId)) {
|
||||
microphoneSelect.value = savedMicId;
|
||||
selectedMicrophoneId = savedMicId;
|
||||
}
|
||||
}
|
||||
|
||||
function handleMicrophoneChange() {
|
||||
selectedMicrophoneId = microphoneSelect.value || null;
|
||||
localStorage.setItem('selectedMicrophone', selectedMicrophoneId || '');
|
||||
|
||||
if (isRecording) {
|
||||
stopRecording();
|
||||
setTimeout(() => {
|
||||
startRecording();
|
||||
}, 500);
|
||||
}
|
||||
}
|
||||
|
||||
microphoneSelect.addEventListener('change', handleMicrophoneChange);
|
||||
|
||||
// Initial enumeration
|
||||
enumerateMicrophones();
|
||||
navigator.mediaDevices.addEventListener('devicechange', enumerateMicrophones);
|
||||
|
||||
function formatSegment(segment) {
|
||||
const speaker = segment.speaker;
|
||||
const text = segment.text || '';
|
||||
const buffer = segment.buffer || {};
|
||||
const start = segment.start || '';
|
||||
const end = segment.end || '';
|
||||
const language = segment.language || '';
|
||||
|
||||
let output = '';
|
||||
|
||||
// Silence marker
|
||||
if (speaker === -2) {
|
||||
output += `[SILENCE ${start} - ${end}]\n`;
|
||||
return output;
|
||||
}
|
||||
|
||||
// Speaker header
|
||||
output += `[SPEAKER ${speaker}]`;
|
||||
if (start && end) output += ` ${start} - ${end}`;
|
||||
if (language) output += ` [LANG: ${language}]`;
|
||||
output += '\n';
|
||||
|
||||
// Main text
|
||||
if (text) {
|
||||
output += text;
|
||||
}
|
||||
|
||||
// Diarization buffer (text waiting for speaker assignment)
|
||||
if (buffer.diarization) {
|
||||
output += `[DIAR_BUFFER]${buffer.diarization}[/DIAR_BUFFER]`;
|
||||
}
|
||||
|
||||
// Transcription buffer (text waiting for validation)
|
||||
if (buffer.transcription) {
|
||||
output += `[TRANS_BUFFER]${buffer.transcription}[/TRANS_BUFFER]`;
|
||||
}
|
||||
|
||||
output += '\n';
|
||||
|
||||
// Translation
|
||||
if (segment.translation) {
|
||||
output += `[TRANSLATION]${segment.translation}`;
|
||||
if (buffer.translation) {
|
||||
output += `[TRANS_BUFFER]${buffer.translation}[/TRANS_BUFFER]`;
|
||||
}
|
||||
output += `[/TRANSLATION]\n`;
|
||||
} else if (buffer.translation) {
|
||||
output += `[TRANSLATION][TRANS_BUFFER]${buffer.translation}[/TRANS_BUFFER][/TRANSLATION]\n`;
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
function renderTranscript(data) {
|
||||
const { segments = [], metadata = {}, status: msgStatus } = data;
|
||||
|
||||
if (msgStatus === 'no_audio_detected') {
|
||||
// transcriptEl.textContent = '[NO AUDIO DETECTED]';
|
||||
// Minimalist: maybe just don't show anything or show status
|
||||
statusEl.textContent = 'No audio detected';
|
||||
return;
|
||||
}
|
||||
|
||||
let output = '';
|
||||
|
||||
// Metadata header
|
||||
const remainingTrans = metadata.remaining_time_transcription || 0;
|
||||
const remainingDiar = metadata.remaining_time_diarization || 0;
|
||||
if (remainingTrans > 0 || remainingDiar > 0) {
|
||||
output += `[LAG: trans=${remainingTrans.toFixed(1)}s diar=${remainingDiar.toFixed(1)}s]\n\n`;
|
||||
}
|
||||
|
||||
// All segments
|
||||
for (const segment of segments) {
|
||||
output += formatSegment(segment);
|
||||
output += '\n';
|
||||
}
|
||||
|
||||
transcriptEl.textContent = output;
|
||||
transcriptEl.scrollTop = transcriptEl.scrollHeight;
|
||||
}
|
||||
|
||||
async function startRecording() {
|
||||
try {
|
||||
websocket = new WebSocket(wsUrlInput.value);
|
||||
|
||||
websocket.onopen = async () => {
|
||||
statusEl.textContent = 'Connecting...';
|
||||
};
|
||||
|
||||
websocket.onmessage = async (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
|
||||
if (data.type === 'config') {
|
||||
useAudioWorklet = !!data.useAudioWorklet;
|
||||
statusEl.textContent = 'Recording';
|
||||
await initAudio();
|
||||
return;
|
||||
}
|
||||
|
||||
if (data.type === 'ready_to_stop') {
|
||||
statusEl.textContent = 'Done';
|
||||
return;
|
||||
}
|
||||
|
||||
// transcript_update
|
||||
renderTranscript(data);
|
||||
};
|
||||
|
||||
websocket.onclose = () => {
|
||||
statusEl.textContent = 'Disconnected';
|
||||
stopRecording(false);
|
||||
};
|
||||
|
||||
websocket.onerror = () => {
|
||||
statusEl.textContent = 'Error';
|
||||
};
|
||||
|
||||
} catch (err) {
|
||||
        statusEl.textContent = 'Error: ' + err.message;
      }
    }

    async function initAudio() {
      const audioConstraints = selectedMicrophoneId
        ? { audio: { deviceId: { exact: selectedMicrophoneId } } }
        : { audio: true };

      const stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
      audioContext = new (window.AudioContext || window.webkitAudioContext)();
      microphone = audioContext.createMediaStreamSource(stream);

      if (useAudioWorklet) {
        await audioContext.audioWorklet.addModule('/web/pcm_worklet.js');
        workletNode = new AudioWorkletNode(audioContext, 'pcm-forwarder', {
          numberOfInputs: 1, numberOfOutputs: 0, channelCount: 1
        });
        microphone.connect(workletNode);

        recorderWorker = new Worker('/web/recorder_worker.js');
        recorderWorker.postMessage({ command: 'init', config: { sampleRate: audioContext.sampleRate } });

        recorderWorker.onmessage = (e) => {
          if (websocket?.readyState === WebSocket.OPEN) {
            websocket.send(e.data.buffer);
          }
        };

        workletNode.port.onmessage = (e) => {
          const ab = e.data instanceof ArrayBuffer ? e.data : e.data.buffer;
          recorderWorker.postMessage({ command: 'record', buffer: ab }, [ab]);
        };
      } else {
        try {
          recorder = new MediaRecorder(stream, { mimeType: 'audio/webm' });
        } catch {
          recorder = new MediaRecorder(stream);
        }
        recorder.ondataavailable = (e) => {
          if (websocket?.readyState === WebSocket.OPEN && e.data?.size > 0) {
            websocket.send(e.data);
          }
        };
        recorder.start(100);
      }

      isRecording = true;
      recordBtn.textContent = 'Stop';
      recordBtn.classList.add('recording');
    }

    function stopRecording(sendStop = true) {
      if (sendStop && websocket?.readyState === WebSocket.OPEN) {
        websocket.send(new Blob([], { type: 'audio/webm' }));
      }

      if (recorder) { try { recorder.stop(); } catch {} recorder = null; }
      if (recorderWorker) { recorderWorker.terminate(); recorderWorker = null; }
      if (workletNode) { workletNode.disconnect(); workletNode = null; }
      if (microphone) { microphone.disconnect(); microphone = null; }
      if (audioContext) { audioContext.close(); audioContext = null; }

      isRecording = false;
      recordBtn.textContent = 'Record';
      recordBtn.classList.remove('recording');
    }

    recordBtn.addEventListener('click', () => {
      if (!isRecording) {
        startRecording();
      } else {
        stopRecording();
      }
    });

    copyBtn.addEventListener('click', () => {
      navigator.clipboard.writeText(transcriptEl.textContent).then(() => {
        const original = copyBtn.textContent;
        copyBtn.textContent = 'Copied';
        setTimeout(() => { copyBtn.textContent = original; }, 1500);
      });
    });
  </script>
</body>
</html>
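On the worklet path, the recorder worker is expected to convert the Float32 samples coming out of the AudioWorklet into 16-bit PCM before they are sent over the WebSocket. A minimal Python sketch of that conversion (the clamping and scaling here are assumptions about the worker's behavior, not taken from its source):

```python
import struct

def float32_to_pcm16(samples):
    """Convert [-1.0, 1.0] float samples to little-endian 16-bit PCM bytes."""
    out = bytearray()
    for s in samples:
        s = max(-1.0, min(1.0, s))   # clamp to the valid range
        out += struct.pack('<h', int(s * 32767))  # scale to int16
    return bytes(out)

chunk = float32_to_pcm16([0.0, 0.5, -1.0, 1.0])  # each sample becomes 2 bytes
```

Each binary WebSocket message on this path is therefore raw PCM, while the MediaRecorder fallback sends WebM-encoded chunks instead.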
@@ -13,6 +13,37 @@ def get_web_interface_html():
        logger.error(f"Error loading web interface HTML: {e}")
        return "<html><body><h1>Error loading interface</h1></body></html>"


def get_text_transcript_html():
    """Loads the simple text-based transcript HTML for easy copy/paste."""
    try:
        with resources.files('whisperlivekit.web').joinpath('text_transcript.html').open('r', encoding='utf-8') as f:
            html_content = f.read()

        # Inline the worker scripts
        with resources.files('whisperlivekit.web').joinpath('pcm_worklet.js').open('r', encoding='utf-8') as f:
            worklet_code = f.read()
        with resources.files('whisperlivekit.web').joinpath('recorder_worker.js').open('r', encoding='utf-8') as f:
            worker_code = f.read()

        html_content = html_content.replace(
            "await audioContext.audioWorklet.addModule('/web/pcm_worklet.js');",
            'const workletBlob = new Blob([`' + worklet_code + '`], { type: "application/javascript" });\n' +
            'const workletUrl = URL.createObjectURL(workletBlob);\n' +
            'await audioContext.audioWorklet.addModule(workletUrl);'
        )
        html_content = html_content.replace(
            "recorderWorker = new Worker('/web/recorder_worker.js');",
            'const workerBlob = new Blob([`' + worker_code + '`], { type: "application/javascript" });\n' +
            'const workerUrl = URL.createObjectURL(workerBlob);\n' +
            'recorderWorker = new Worker(workerUrl);'
        )

        return html_content
    except Exception as e:
        logger.error(f"Error loading text transcript HTML: {e}")
        return "<html><body><h1>Error loading text interface</h1></body></html>"


def get_inline_ui_html():
    """Returns the complete web interface HTML with all assets embedded in a single call."""
    try:
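The substitution above swaps a same-origin script URL for a Blob object URL, so the page keeps working when served as a single self-contained document. A reduced sketch of the same string rewrite (the stand-in worker source is illustrative, not the real `recorder_worker.js`):

```python
worker_code = "onmessage = (e) => postMessage(e.data);"  # stand-in for recorder_worker.js

html = "recorderWorker = new Worker('/web/recorder_worker.js');"
inlined = html.replace(
    "recorderWorker = new Worker('/web/recorder_worker.js');",
    'const workerBlob = new Blob([`' + worker_code + '`], { type: "application/javascript" });\n'
    'const workerUrl = URL.createObjectURL(workerBlob);\n'
    'recorderWorker = new Worker(workerUrl);'
)

# the external path is gone; the worker source is embedded instead
assert '/web/recorder_worker.js' not in inlined
assert worker_code in inlined
```

One caveat of this approach: because the worker code is embedded inside a JavaScript template literal, backticks or `${` sequences in the worker source would break the generated page.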
@@ -18,6 +18,8 @@ from whisperlivekit.whisper.decoding import (DecodingOptions, DecodingResult,
from whisperlivekit.whisper.model import ModelDimensions, Whisper
from whisperlivekit.whisper.transcribe import transcribe
from whisperlivekit.whisper.version import __version__
from whisperlivekit.whisper.lora import (LoRAAdapter, LoRAAdapterManager,
                                         LoRAConfig, LoRALinear)

_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
@@ -551,6 +553,94 @@ def load_model(
    return model.to(device)


def load_model_with_lora_manager(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: str = None,
    in_memory: bool = False,
    decoder_only: bool = False,
    custom_alignment_heads: Optional[str] = None,
    adapters: Optional[Dict[str, str]] = None,
) -> tuple:
    """
    Load a Whisper model with a LoRA adapter manager for dynamic adapter swapping.

    This allows you to load multiple LoRA adapters and switch between them at runtime
    without keeping multiple full models in memory.

    Parameters
    ----------
    name : str
        Model name or path (same as load_model)
    device : Union[str, torch.device]
        Device to load the model on
    download_root : str
        Download directory for model files
    in_memory : bool
        Whether to preload model weights into host memory
    decoder_only : bool
        If True, only load the decoder (no encoder)
    custom_alignment_heads : str
        Custom alignment heads configuration
    adapters : Dict[str, str]
        Optional dict mapping adapter names to paths/HuggingFace repo IDs.
        Example: {"french": "path/to/french-lora", "spanish": "user/spanish-whisper-lora"}

    Returns
    -------
    model : Whisper
        The base Whisper model (without any LoRA baked in)
    manager : LoRAAdapterManager
        The adapter manager for loading/switching adapters

    Example
    -------
    >>> model, manager = load_model_with_lora_manager(
    ...     "large-v3",
    ...     adapters={
    ...         "french": "path/to/french-lora",
    ...         "spanish": "path/to/spanish-lora"
    ...     }
    ... )
    >>>
    >>> # Switch to the French adapter
    >>> manager.set_adapter("french")
    >>> result_fr = model.transcribe(audio_fr)
    >>>
    >>> # Switch to the Spanish adapter
    >>> manager.set_adapter("spanish")
    >>> result_es = model.transcribe(audio_es)
    >>>
    >>> # Use the base model without LoRA
    >>> manager.set_adapter(None)
    >>> result_base = model.transcribe(audio)
    >>>
    >>> # Check memory usage
    >>> print(manager.get_memory_usage())
    {'french': 12.5, 'spanish': 12.5}  # MB per adapter
    """
    # Load the base model WITHOUT any LoRA baked in
    model = load_model(
        name=name,
        device=device,
        download_root=download_root,
        in_memory=in_memory,
        decoder_only=decoder_only,
        custom_alignment_heads=custom_alignment_heads,
        lora_path=None,  # Important: no baked-in LoRA
    )

    # Create the adapter manager
    manager = LoRAAdapterManager(model)

    # Load any provided adapters
    if adapters:
        for adapter_name, adapter_path in adapters.items():
            manager.load_adapter(adapter_name, adapter_path)

    return model, manager


def convert_encoder_to_coreml(
    model_name="base",
    output_path="whisper_encoder.mlpackage",
whisperlivekit/whisper/lora.py (new file, 473 lines)
@@ -0,0 +1,473 @@
"""
|
||||
Dynamic LoRA adapter support for Whisper models.
|
||||
|
||||
This module enables loading a single base Whisper model and dynamically swapping
|
||||
between multiple LoRA adapters at runtime, saving GPU memory when working with
|
||||
multiple language-specific fine-tuned models.
|
||||
|
||||
Usage:
|
||||
from whisperlivekit.whisper import load_model
|
||||
from whisperlivekit.whisper.lora import LoRAAdapterManager
|
||||
|
||||
# Load base model without any LoRA baked in
|
||||
model = load_model("large-v3", device="cuda")
|
||||
|
||||
# Create adapter manager
|
||||
manager = LoRAAdapterManager(model)
|
||||
|
||||
# Load multiple adapters (small memory footprint each)
|
||||
manager.load_adapter("french", "path/to/french-lora")
|
||||
manager.load_adapter("spanish", "path/to/spanish-lora")
|
||||
|
||||
# Switch between adapters at runtime
|
||||
manager.set_adapter("french")
|
||||
result_fr = model.transcribe(audio_fr)
|
||||
|
||||
manager.set_adapter("spanish")
|
||||
result_es = model.transcribe(audio_es)
|
||||
|
||||
# Disable LoRA (use base model only)
|
||||
manager.set_adapter(None)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .model import Linear
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAConfig:
|
||||
"""Configuration for a LoRA adapter."""
|
||||
r: int # LoRA rank
|
||||
alpha: float # LoRA alpha (scaling factor)
|
||||
target_modules: List[str] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def scaling(self) -> float:
|
||||
return self.alpha / self.r
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAAdapter:
|
||||
"""Holds the LoRA A/B weight matrices for a single adapter."""
|
||||
name: str
|
||||
config: LoRAConfig
|
||||
# Maps target module name -> (A matrix, B matrix)
|
||||
weights: Dict[str, Tuple[Tensor, Tensor]] = field(default_factory=dict)
|
||||
device: torch.device = field(default_factory=lambda: torch.device("cpu"))
|
||||
dtype: torch.dtype = field(default=torch.float32)
|
||||
|
||||
def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
|
||||
"""Move adapter weights to specified device/dtype."""
|
||||
self.device = device
|
||||
if dtype is not None:
|
||||
self.dtype = dtype
|
||||
self.weights = {
|
||||
name: (a.to(device=device, dtype=dtype or self.dtype),
|
||||
b.to(device=device, dtype=dtype or self.dtype))
|
||||
for name, (a, b) in self.weights.items()
|
||||
}
|
||||
return self
|
||||
|
||||
def memory_footprint_mb(self) -> float:
|
||||
"""Return approximate memory usage in MB."""
|
||||
total_bytes = 0
|
||||
for a, b in self.weights.values():
|
||||
total_bytes += a.numel() * a.element_size()
|
||||
total_bytes += b.numel() * b.element_size()
|
||||
return total_bytes / (1024 * 1024)
|
||||
|
||||
|
||||
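The footprint above is small because a rank-r adapter replaces a dense `in_features x out_features` delta with two thin matrices. A quick sanity check of the arithmetic (the 1280 width matches large-v3's model dimension; r=8 and float16 are assumed example values):

```python
def lora_param_count(in_features, out_features, r):
    """Parameters in a rank-r LoRA pair: A (in x r) plus B (r x out)."""
    return in_features * r + r * out_features

d = 1280           # width of a large-v3 attention projection
full = d * d       # a dense delta-W for the same layer
lora = lora_param_count(d, d, r=8)

assert full == 1_638_400
assert lora == 20_480        # 80x fewer parameters than the dense delta

# at float16 (2 bytes per parameter), the adapter pair for this one layer
mb = lora * 2 / (1024 * 1024)   # well under a tenth of a megabyte
```

This is why a manager can keep many adapters resident at once while the base model is loaded only once.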
class LoRALinear(nn.Module):
    """
    A Linear layer wrapper that supports dynamic LoRA injection.

    The base weights remain unchanged. LoRA is applied additively during forward:
        output = base_linear(x) + (x @ A @ B) * scaling
    """

    def __init__(self, base_linear: Linear):
        super().__init__()
        self.base_linear = base_linear
        self.lora_A: Optional[Tensor] = None
        self.lora_B: Optional[Tensor] = None
        self.scaling: float = 1.0
        self._lora_enabled: bool = False

    def set_lora(self, A: Optional[Tensor], B: Optional[Tensor], scaling: float = 1.0):
        """Set the LoRA matrices for this layer."""
        self.lora_A = A
        self.lora_B = B
        self.scaling = scaling
        self._lora_enabled = A is not None and B is not None

    def clear_lora(self):
        """Remove LoRA from this layer."""
        self.lora_A = None
        self.lora_B = None
        self._lora_enabled = False

    def forward(self, x: Tensor) -> Tensor:
        # Base linear output
        out = self.base_linear(x)

        # Add the LoRA contribution if enabled
        if self._lora_enabled and self.lora_A is not None and self.lora_B is not None:
            # x: (..., in_features)
            # A: (in_features, r)
            # B: (r, out_features)
            # lora_out: (..., out_features)
            lora_out = (x @ self.lora_A.to(x.dtype)) @ self.lora_B.to(x.dtype)
            out = out + lora_out * self.scaling

        return out

    # Delegate attribute access to base_linear for compatibility
    @property
    def weight(self):
        return self.base_linear.weight

    @property
    def bias(self):
        return self.base_linear.bias

    @property
    def in_features(self):
        return self.base_linear.in_features

    @property
    def out_features(self):
        return self.base_linear.out_features

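The forward pass computes `(x @ A) @ B` rather than materializing `A @ B`: grouped this way, the extra cost per sample is O(r·(in+out)) instead of O(in·out). Both groupings give the same result by associativity, which a dependency-free sketch (toy matrices, no torch) can confirm:

```python
def matmul(X, Y):
    """Naive matrix product for nested-list matrices."""
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y)))
             for j in range(len(Y[0]))] for i in range(len(X))]

x = [[1.0, 2.0, 3.0]]        # one sample, in_features = 3
A = [[0.5], [0.0], [-0.5]]   # (in_features, r) with r = 1
B = [[2.0, 0.0]]             # (r, out_features = 2)

low_rank_path = matmul(matmul(x, A), B)  # what LoRALinear.forward does
merged_path = matmul(x, matmul(A, B))    # equivalent, but builds a 3x2 delta

assert low_rank_path == merged_path == [[-2.0, 0.0]]
```

Keeping A and B separate also makes swapping adapters a pointer assignment rather than a weight merge, which is what `set_lora`/`clear_lora` rely on.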
# Mapping from HuggingFace LoRA module names to Whisper module paths
_HF_TO_WHISPER_MODULE_MAP = {
    # Encoder attention
    "model.encoder.layers.{}.self_attn.q_proj": "encoder.blocks.{}.attn.query",
    "model.encoder.layers.{}.self_attn.k_proj": "encoder.blocks.{}.attn.key",
    "model.encoder.layers.{}.self_attn.v_proj": "encoder.blocks.{}.attn.value",
    "model.encoder.layers.{}.self_attn.out_proj": "encoder.blocks.{}.attn.out",
    # Encoder MLP
    "model.encoder.layers.{}.fc1": "encoder.blocks.{}.mlp.0",
    "model.encoder.layers.{}.fc2": "encoder.blocks.{}.mlp.2",

    # Decoder self-attention
    "model.decoder.layers.{}.self_attn.q_proj": "decoder.blocks.{}.attn.query",
    "model.decoder.layers.{}.self_attn.k_proj": "decoder.blocks.{}.attn.key",
    "model.decoder.layers.{}.self_attn.v_proj": "decoder.blocks.{}.attn.value",
    "model.decoder.layers.{}.self_attn.out_proj": "decoder.blocks.{}.attn.out",
    # Decoder cross-attention
    "model.decoder.layers.{}.encoder_attn.q_proj": "decoder.blocks.{}.cross_attn.query",
    "model.decoder.layers.{}.encoder_attn.k_proj": "decoder.blocks.{}.cross_attn.key",
    "model.decoder.layers.{}.encoder_attn.v_proj": "decoder.blocks.{}.cross_attn.value",
    "model.decoder.layers.{}.encoder_attn.out_proj": "decoder.blocks.{}.cross_attn.out",
    # Decoder MLP
    "model.decoder.layers.{}.fc1": "decoder.blocks.{}.mlp.0",
    "model.decoder.layers.{}.fc2": "decoder.blocks.{}.mlp.2",
}


def _normalize_hf_module_name(name: str) -> str:
    """Normalize HF-style LoRA module names."""
    if name.startswith("base_model."):
        name = name[len("base_model."):]
    if name.startswith("model.model."):
        name = name[len("model."):]
    if not name.startswith("model."):
        name = f"model.{name}"
    return name

def _map_hf_to_whisper_module(hf_name: str) -> Optional[str]:
    """Map a HuggingFace LoRA module name to a Whisper module path."""
    hf_name = _normalize_hf_module_name(hf_name)

    # Try to match with layer index patterns
    import re

    # Match patterns like model.encoder.layers.5.self_attn.q_proj
    for pattern, target_pattern in _HF_TO_WHISPER_MODULE_MAP.items():
        # Create a regex from the pattern (replace {} with a capture group)
        regex = pattern.replace(".", r"\.").replace("{}", r"(\d+)")
        match = re.fullmatch(regex, hf_name)
        if match:
            layer_idx = match.group(1)
            return target_pattern.format(layer_idx)

    return None

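A self-contained replica of the normalization and regex matching above shows how a PEFT-style key lands on a Whisper module path (a one-entry map is used here for brevity; the full table in the module covers all attention and MLP projections):

```python
import re

MAP = {"model.encoder.layers.{}.self_attn.q_proj": "encoder.blocks.{}.attn.query"}

def map_name(hf_name):
    # strip PEFT wrapper prefixes, then canonicalize to a "model." prefix
    if hf_name.startswith("base_model."):
        hf_name = hf_name[len("base_model."):]
    if hf_name.startswith("model.model."):
        hf_name = hf_name[len("model."):]
    if not hf_name.startswith("model."):
        hf_name = "model." + hf_name
    # turn each pattern into a regex with the layer index captured
    for pattern, target in MAP.items():
        regex = pattern.replace(".", r"\.").replace("{}", r"(\d+)")
        m = re.fullmatch(regex, hf_name)
        if m:
            return target.format(m.group(1))
    return None

assert map_name("base_model.model.encoder.layers.5.self_attn.q_proj") == "encoder.blocks.5.attn.query"
assert map_name("encoder.layers.3.self_attn.q_proj") == "encoder.blocks.3.attn.query"
```

Note the order of operations matters: `.` must be escaped before `{}` is replaced, or the `\d+` capture group would itself be mangled.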
def _get_module_by_path(model: nn.Module, path: str) -> Optional[nn.Module]:
    """Get a submodule by dot-separated path."""
    parts = path.split(".")
    current = model
    for part in parts:
        if hasattr(current, part):
            current = getattr(current, part)
        elif hasattr(current, "__getitem__"):
            try:
                current = current[int(part)]
            except (ValueError, IndexError, KeyError):
                return None
        else:
            return None
    return current


def _set_module_by_path(model: nn.Module, path: str, module: nn.Module):
    """Set a submodule by dot-separated path."""
    parts = path.split(".")
    parent = model
    for part in parts[:-1]:
        if hasattr(parent, part):
            parent = getattr(parent, part)
        elif hasattr(parent, "__getitem__"):
            parent = parent[int(part)]
    setattr(parent, parts[-1], module)

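These helpers walk both attribute access (`encoder.blocks`) and integer index access (`blocks.5`), which is what lets a path like `encoder.blocks.5.attn.query` reach into a `ModuleList`. The traversal can be exercised on plain objects (hypothetical stand-ins for the model tree, so torch is not required):

```python
class Node:
    """Bare attribute holder standing in for an nn.Module."""
    pass

def get_by_path(obj, path):
    """Walk a dot-separated path through attributes and integer indices."""
    for part in path.split("."):
        if hasattr(obj, part):
            obj = getattr(obj, part)
        else:
            obj = obj[int(part)]   # list-like container, e.g. a ModuleList
    return obj

root = Node()
root.encoder = Node()
root.encoder.blocks = [Node(), Node()]
root.encoder.blocks[1].attn = Node()
root.encoder.blocks[1].attn.query = "the-linear-layer"

assert get_by_path(root, "encoder.blocks.1.attn.query") == "the-linear-layer"
```

The `_set_` variant stops one step short and uses `setattr` on the parent, which is how the manager swaps a `Linear` for its `LoRALinear` wrapper in place.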
class LoRAAdapterManager:
    """
    Manages multiple LoRA adapters for a Whisper model.

    Enables loading multiple adapters and switching between them at runtime
    without reloading the full model.
    """

    def __init__(self, model: nn.Module):
        """
        Initialize the adapter manager.

        Args:
            model: A Whisper model instance
        """
        self.model = model
        self.adapters: Dict[str, LoRAAdapter] = {}
        self.current_adapter: Optional[str] = None
        self._lora_layers: Dict[str, LoRALinear] = {}
        self._original_layers: Dict[str, Linear] = {}
        self._initialized = False

    def _initialize_lora_layers(self, target_modules: List[str]):
        """
        Replace target Linear layers with LoRALinear wrappers.

        This is done lazily on the first adapter load.
        """
        if self._initialized:
            return

        # Find and wrap all potential LoRA target modules
        for whisper_path in target_modules:
            module = _get_module_by_path(self.model, whisper_path)
            if module is None:
                continue
            if isinstance(module, Linear) and not isinstance(module, LoRALinear):
                # Wrap the Linear layer
                lora_linear = LoRALinear(module)
                _set_module_by_path(self.model, whisper_path, lora_linear)
                self._lora_layers[whisper_path] = lora_linear
                self._original_layers[whisper_path] = module

        self._initialized = True

    def _resolve_lora_path(self, lora_path: str) -> str:
        """Resolve a LoRA path, downloading from the HuggingFace Hub if needed."""
        if os.path.isdir(lora_path):
            return lora_path

        # Try the HuggingFace Hub
        if "/" in lora_path:
            try:
                from huggingface_hub import snapshot_download
                return snapshot_download(
                    repo_id=lora_path,
                    allow_patterns=["adapter_config.json", "adapter_model.*"],
                )
            except Exception as e:
                raise FileNotFoundError(
                    f"Could not find LoRA adapter at local path or HuggingFace Hub: {lora_path}. Error: {e}"
                )

        raise FileNotFoundError(f"LoRA path '{lora_path}' not found.")

    def _load_adapter_weights(self, lora_path: str) -> Dict[str, Tensor]:
        """Load adapter weights from a safetensors or bin file."""
        safe_path = os.path.join(lora_path, "adapter_model.safetensors")
        bin_path = os.path.join(lora_path, "adapter_model.bin")

        if os.path.isfile(safe_path):
            from safetensors.torch import load_file
            return load_file(safe_path)
        elif os.path.isfile(bin_path):
            return torch.load(bin_path, map_location="cpu")
        else:
            raise FileNotFoundError(
                f"No adapter weights found in {lora_path}. "
                "Expected adapter_model.safetensors or adapter_model.bin."
            )

    def load_adapter(
        self,
        name: str,
        lora_path: str,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> LoRAAdapter:
        """
        Load a LoRA adapter from disk or the HuggingFace Hub.

        Args:
            name: Unique name for this adapter (e.g., "french", "spanish")
            lora_path: Local path or HuggingFace repo ID
            device: Device to load weights to (default: the model's device)
            dtype: Data type for weights (default: the model's dtype)

        Returns:
            The loaded LoRAAdapter
        """
        if device is None:
            device = next(self.model.parameters()).device
        if dtype is None:
            dtype = next(self.model.parameters()).dtype

        # Resolve the path
        lora_path = self._resolve_lora_path(lora_path)

        # Load the config
        config_path = os.path.join(lora_path, "adapter_config.json")
        if not os.path.isfile(config_path):
            raise FileNotFoundError(f"Missing adapter_config.json in {lora_path}")

        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)

        if config_dict.get("peft_type") != "LORA":
            raise ValueError("Only LoRA adapters are supported.")

        config = LoRAConfig(
            r=config_dict["r"],
            alpha=config_dict.get("lora_alpha") or config_dict.get("alpha"),
            target_modules=config_dict.get("target_modules", []),
        )

        # Load the weights
        adapter_state = self._load_adapter_weights(lora_path)

        # Parse the LoRA A/B matrices and map to Whisper module paths
        lora_layers: Dict[str, Dict[str, Tensor]] = {}
        for key, tensor in adapter_state.items():
            if key.endswith("lora_A.weight"):
                module = key[:-len(".lora_A.weight")]
                lora_layers.setdefault(module, {})["A"] = tensor
            elif key.endswith("lora_B.weight"):
                module = key[:-len(".lora_B.weight")]
                lora_layers.setdefault(module, {})["B"] = tensor

        # Map to Whisper module paths and collect weights
        weights: Dict[str, Tuple[Tensor, Tensor]] = {}
        whisper_paths = set()

        for hf_module, parts in lora_layers.items():
            if "A" not in parts or "B" not in parts:
                continue

            whisper_path = _map_hf_to_whisper_module(hf_module)
            if whisper_path is None:
                # Try a direct mapping (the module might already be in Whisper format)
                whisper_path = hf_module

            # A: (r, in_features) -> transpose to (in_features, r)
            # B: (out_features, r) -> transpose to (r, out_features)
            A = parts["A"].T  # (in_features, r)
            B = parts["B"].T  # (r, out_features)

            weights[whisper_path] = (A, B)
            whisper_paths.add(whisper_path)

        # Create the adapter
        adapter = LoRAAdapter(
            name=name,
            config=config,
            weights=weights,
            device=device,
            dtype=dtype,
        )
        adapter.to(device, dtype)

        # Initialize the LoRA layers if not done yet
        self._initialize_lora_layers(list(whisper_paths))

        # Store the adapter
        self.adapters[name] = adapter

        return adapter

    def set_adapter(self, name: Optional[str]):
        """
        Switch to a different adapter or disable LoRA.

        Args:
            name: Adapter name to activate, or None to disable all LoRA
        """
        if name is not None and name not in self.adapters:
            raise KeyError(f"Adapter '{name}' not loaded. Available: {list(self.adapters.keys())}")

        # Clear all LoRA from the layers
        for lora_linear in self._lora_layers.values():
            lora_linear.clear_lora()

        self.current_adapter = name

        if name is None:
            return

        # Apply the selected adapter
        adapter = self.adapters[name]
        for module_path, (A, B) in adapter.weights.items():
            if module_path in self._lora_layers:
                self._lora_layers[module_path].set_lora(A, B, adapter.config.scaling)

    def unload_adapter(self, name: str):
        """
        Unload an adapter from memory.

        Args:
            name: Name of the adapter to unload
        """
        if name not in self.adapters:
            return

        if self.current_adapter == name:
            self.set_adapter(None)

        del self.adapters[name]

    def list_adapters(self) -> List[str]:
        """Return the list of loaded adapter names."""
        return list(self.adapters.keys())

    def get_memory_usage(self) -> Dict[str, float]:
        """Return memory usage in MB for each loaded adapter."""
        return {name: adapter.memory_footprint_mb() for name, adapter in self.adapters.items()}

    def restore_original_layers(self):
        """
        Restore the original Linear layers, removing the LoRA wrappers.

        Call this to go back to the original model structure.
        """
        for path, original in self._original_layers.items():
            _set_module_by_path(self.model, path, original)

        self._lora_layers.clear()
        self._original_layers.clear()
        self._initialized = False
        self.current_adapter = None