mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-08 06:44:09 +00:00
Compare commits
8 Commits
rework_sta
...
VAD-evolut
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9d4ae33249 | ||
|
|
6206fff118 | ||
|
|
b5067249c0 | ||
|
|
f4f9831d39 | ||
|
|
254faaf64c | ||
|
|
8e7aea4fcf | ||
|
|
270faf2069 | ||
|
|
b7c1cc77cc |
@@ -141,7 +141,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
|-----------|-------------|---------|
|
||||
| `--model` | Whisper model size. List and recommendations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/available_models.md) | `small` |
|
||||
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommendations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/models_compatible_formats.md) | `None` |
|
||||
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
||||
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
||||
| `--target-language` | If set, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to English, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
|
||||
| `--diarization` | Enable speaker identification | `False` |
|
||||
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
- Example 2: The punctuation from STT comes from prediction `t`, but the speaker change from Diarization comes in the prediction `t-1`
|
||||
- Example 3: The punctuation from STT comes from prediction `t-1`, but the speaker change from Diarization comes in the prediction `t`
|
||||
|
||||
> `#` Is the split between the `t-1` prediction and t prediction.
|
||||
> `#` Is the split between the `t-1` prediction and `t` prediction.
|
||||
|
||||
|
||||
## Example 1:
|
||||
|
||||
43
docs/technical_integration.md
Normal file
43
docs/technical_integration.md
Normal file
@@ -0,0 +1,43 @@
|
||||
# Technical Integration Guide
|
||||
|
||||
This document explains how to reuse the core components when you do **not** want to ship the bundled frontend, FastAPI server, or even the provided CLI.
|
||||
|
||||
---
|
||||
|
||||
## 1. Runtime Components
|
||||
|
||||
| Layer | File(s) | Purpose |
|
||||
|-------|---------|---------|
|
||||
| Transport | `whisperlivekit/basic_server.py`, any ASGI/WebSocket server | Accepts audio over WebSocket (MediaRecorder WebM or raw PCM chunks) and streams JSON updates back |
|
||||
| Audio processing | `whisperlivekit/audio_processor.py` | Buffers audio, orchestrates transcription, diarization, translation, handles FFmpeg/PCM input |
|
||||
| Engines | `whisperlivekit/core.py`, `whisperlivekit/simul_whisper/*`, `whisperlivekit/local_agreement/*` | Load models once (SimulStreaming or LocalAgreement), expose `TranscriptionEngine` and helpers |
|
||||
| Frontends | `whisperlivekit/web/*`, `chrome-extension/*` | Optional UI layers feeding the WebSocket endpoint |
|
||||
|
||||
**Key idea:** The server boundary is just `AudioProcessor.process_audio()` for incoming bytes and the async generator returned by `AudioProcessor.create_tasks()` for outgoing updates (`FrontData`). Everything else is optional.
|
||||
|
||||
---
|
||||
|
||||
## 2. Running Without the Bundled Frontend
|
||||
|
||||
1. Start the server/engine however you like:
|
||||
```bash
|
||||
wlk --model small --language en --host 0.0.0.0 --port 9000
|
||||
# or launch your own app that instantiates TranscriptionEngine(...)
|
||||
```
|
||||
2. Build your own client (browser, mobile, desktop) that:
|
||||
- Opens `ws(s)://<host>:<port>/asr`
|
||||
- Sends either MediaRecorder/Opus WebM blobs **or** raw PCM (`--pcm-input` on the server tells the client to use the AudioWorklet).
|
||||
- Consumes the JSON payload defined in `docs/API.md`.
|
||||
|
||||
---
|
||||
|
||||
## 3. Running Without FastAPI
|
||||
|
||||
`whisperlivekit/basic_server.py` is just an example. Any async framework works, as long as you:
|
||||
|
||||
1. Create a global `TranscriptionEngine` (expensive to initialize; reuse it).
|
||||
2. Instantiate `AudioProcessor(transcription_engine=engine)` for each connection.
|
||||
3. Call `create_tasks()` to get the async generator, `process_audio()` with incoming bytes, and ensure `cleanup()` runs when the client disconnects.
|
||||
|
||||
|
||||
If you prefer to send compressed audio, instantiate `AudioProcessor(pcm_input=False)` and pipe encoded chunks through `FFmpegManager` transparently—just ensure `ffmpeg` is available or be ready to handle the `"ffmpeg_not_found"` error in the streamed `FrontData`.
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "whisperlivekit"
|
||||
version = "0.2.14.post4"
|
||||
version = "0.2.15"
|
||||
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
class TokensAlignment:
    """Align transcription tokens with punctuation-delimited segments.

    The instance wraps a state object and derives, from the tokens that have
    not been validated yet, the time spans that end on a punctuation token.
    """

    def __init__(self, state_light, silence=None, args=None):
        """Store the shared state and the optional configuration.

        Args:
            state_light: State object. The methods of this class read its
                ``tokens`` and ``last_validated_token`` attributes.
            silence: Stored as-is; not used by the methods of this class.
            args: Stored as-is; not used by the methods of this class.
        """
        self.state_light = state_light
        # The methods below read ``self.state``; without this alias every
        # call failed with AttributeError because only ``state_light`` was set.
        self.state = state_light
        self.silence = silence
        self.args = args

        self._tokens_index = 0
        self._diarization_index = 0
        self._translation_index = 0

    def update(self):
        """Placeholder hook; currently does nothing."""
        pass

    def compute_punctuations_segments(self):
        """Return the segments delimited by punctuation tokens.

        Only the tokens from ``state.last_validated_token`` onwards are
        inspected. Each punctuation token closes one segment.

        Returns:
            A list of dicts with keys ``start``, ``end``, ``token_index`` and
            ``token``. ``start`` is the ``end`` of the previous punctuation
            token (``0.0`` for the first segment), ``end`` is the ``end`` of
            the closing punctuation token, and ``token_index`` is relative to
            the inspected slice, not to the full token list.
        """
        new_tokens = self.state.tokens[self.state.last_validated_token:]

        punctuations_segments = []
        previous_end = 0.0
        for token_index, token in enumerate(new_tokens):
            if not token.is_punctuation():
                continue
            punctuations_segments.append({
                'start': previous_end,
                'end': token.end,
                'token_index': token_index,
                'token': token,
            })
            previous_end = token.end
        return punctuations_segments

    def concatenate_diar_segments(self):
        """Unimplemented stub; returns ``None``.

        NOTE(review): the original body only read
        ``state.diarization_segments`` into an unused local.
        """
        return None
|
||||
|
||||
if __name__ == "__main__":
|
||||
from whisperlivekit.timed_objects import State, ASRToken, SpeakerSegment, Transcript, Silence
|
||||
|
||||
# Reconstruct the state from the backup data
|
||||
tokens = [
|
||||
ASRToken(start=1.38, end=1.48, text=' The'),
|
||||
ASRToken(start=1.42, end=1.52, text=' description'),
|
||||
ASRToken(start=1.82, end=1.92, text=' technology'),
|
||||
ASRToken(start=2.54, end=2.64, text=' has'),
|
||||
ASRToken(start=2.7, end=2.8, text=' improved'),
|
||||
ASRToken(start=3.24, end=3.34, text=' so'),
|
||||
ASRToken(start=3.66, end=3.76, text=' much'),
|
||||
ASRToken(start=4.02, end=4.12, text=' in'),
|
||||
ASRToken(start=4.08, end=4.18, text=' the'),
|
||||
ASRToken(start=4.26, end=4.36, text=' past'),
|
||||
ASRToken(start=4.48, end=4.58, text=' few'),
|
||||
ASRToken(start=4.76, end=4.86, text=' years'),
|
||||
ASRToken(start=5.76, end=5.86, text='.'),
|
||||
ASRToken(start=5.72, end=5.82, text=' Have'),
|
||||
ASRToken(start=5.92, end=6.02, text=' you'),
|
||||
ASRToken(start=6.08, end=6.18, text=' noticed'),
|
||||
ASRToken(start=6.52, end=6.62, text=' how'),
|
||||
ASRToken(start=6.8, end=6.9, text=' accurate'),
|
||||
ASRToken(start=7.46, end=7.56, text=' real'),
|
||||
ASRToken(start=7.72, end=7.82, text='-time'),
|
||||
ASRToken(start=8.06, end=8.16, text=' speech'),
|
||||
ASRToken(start=8.48, end=8.58, text=' to'),
|
||||
ASRToken(start=8.68, end=8.78, text=' text'),
|
||||
ASRToken(start=9.0, end=9.1, text=' is'),
|
||||
ASRToken(start=9.24, end=9.34, text=' now'),
|
||||
ASRToken(start=9.82, end=9.92, text='?'),
|
||||
ASRToken(start=9.86, end=9.96, text=' Absolutely'),
|
||||
ASRToken(start=11.26, end=11.36, text='.'),
|
||||
ASRToken(start=11.36, end=11.46, text=' I'),
|
||||
ASRToken(start=11.58, end=11.68, text=' use'),
|
||||
ASRToken(start=11.78, end=11.88, text=' it'),
|
||||
ASRToken(start=11.94, end=12.04, text=' all'),
|
||||
ASRToken(start=12.08, end=12.18, text=' the'),
|
||||
ASRToken(start=12.32, end=12.42, text=' time'),
|
||||
ASRToken(start=12.58, end=12.68, text=' for'),
|
||||
ASRToken(start=12.78, end=12.88, text=' taking'),
|
||||
ASRToken(start=13.14, end=13.24, text=' notes'),
|
||||
ASRToken(start=13.4, end=13.5, text=' during'),
|
||||
ASRToken(start=13.78, end=13.88, text=' meetings'),
|
||||
ASRToken(start=14.6, end=14.7, text='.'),
|
||||
ASRToken(start=14.82, end=14.92, text=' It'),
|
||||
ASRToken(start=14.92, end=15.02, text="'s"),
|
||||
ASRToken(start=15.04, end=15.14, text=' amazing'),
|
||||
ASRToken(start=15.5, end=15.6, text=' how'),
|
||||
ASRToken(start=15.66, end=15.76, text=' it'),
|
||||
ASRToken(start=15.8, end=15.9, text=' can'),
|
||||
ASRToken(start=15.96, end=16.06, text=' recognize'),
|
||||
ASRToken(start=16.58, end=16.68, text=' different'),
|
||||
ASRToken(start=16.94, end=17.04, text=' speakers'),
|
||||
ASRToken(start=17.82, end=17.92, text=' and'),
|
||||
ASRToken(start=18.0, end=18.1, text=' even'),
|
||||
ASRToken(start=18.42, end=18.52, text=' add'),
|
||||
ASRToken(start=18.74, end=18.84, text=' punct'),
|
||||
ASRToken(start=19.02, end=19.12, text='uation'),
|
||||
ASRToken(start=19.68, end=19.78, text='.'),
|
||||
ASRToken(start=20.04, end=20.14, text=' Yeah'),
|
||||
ASRToken(start=20.5, end=20.6, text=','),
|
||||
ASRToken(start=20.6, end=20.7, text=' but'),
|
||||
ASRToken(start=20.76, end=20.86, text=' sometimes'),
|
||||
ASRToken(start=21.42, end=21.52, text=' noise'),
|
||||
ASRToken(start=21.82, end=21.92, text=' can'),
|
||||
ASRToken(start=22.08, end=22.18, text=' still'),
|
||||
ASRToken(start=22.38, end=22.48, text=' cause'),
|
||||
ASRToken(start=22.72, end=22.82, text=' mistakes'),
|
||||
ASRToken(start=23.74, end=23.84, text='.'),
|
||||
ASRToken(start=23.96, end=24.06, text=' Does'),
|
||||
ASRToken(start=24.16, end=24.26, text=' this'),
|
||||
ASRToken(start=24.4, end=24.5, text=' system'),
|
||||
ASRToken(start=24.76, end=24.86, text=' handle'),
|
||||
ASRToken(start=25.12, end=25.22, text=' that'),
|
||||
ASRToken(start=25.38, end=25.48, text=' well'),
|
||||
ASRToken(start=25.68, end=25.78, text='?'),
|
||||
ASRToken(start=26.4, end=26.5, text=' It'),
|
||||
ASRToken(start=26.5, end=26.6, text=' does'),
|
||||
ASRToken(start=26.7, end=26.8, text=' a'),
|
||||
ASRToken(start=27.08, end=27.18, text=' pretty'),
|
||||
ASRToken(start=27.12, end=27.22, text=' good'),
|
||||
ASRToken(start=27.34, end=27.44, text=' job'),
|
||||
ASRToken(start=27.64, end=27.74, text=' filtering'),
|
||||
ASRToken(start=28.1, end=28.2, text=' noise'),
|
||||
ASRToken(start=28.64, end=28.74, text=','),
|
||||
ASRToken(start=28.78, end=28.88, text=' especially'),
|
||||
ASRToken(start=29.3, end=29.4, text=' with'),
|
||||
ASRToken(start=29.51, end=29.61, text=' models'),
|
||||
ASRToken(start=29.99, end=30.09, text=' that'),
|
||||
ASRToken(start=30.21, end=30.31, text=' use'),
|
||||
ASRToken(start=30.51, end=30.61, text=' voice'),
|
||||
ASRToken(start=30.83, end=30.93, text=' activity'),
|
||||
]
|
||||
|
||||
diarization_segments = [
|
||||
SpeakerSegment(start=1.3255040645599365, end=4.3255040645599365, speaker=0),
|
||||
SpeakerSegment(start=4.806154012680054, end=9.806154012680054, speaker=0),
|
||||
SpeakerSegment(start=9.806154012680054, end=10.806154012680054, speaker=1),
|
||||
SpeakerSegment(start=11.168735027313232, end=14.168735027313232, speaker=1),
|
||||
SpeakerSegment(start=14.41029405593872, end=17.41029405593872, speaker=1),
|
||||
SpeakerSegment(start=17.52983808517456, end=19.52983808517456, speaker=1),
|
||||
SpeakerSegment(start=19.64953374862671, end=20.066200415293377, speaker=1),
|
||||
SpeakerSegment(start=20.066200415293377, end=22.64953374862671, speaker=2),
|
||||
SpeakerSegment(start=23.012792587280273, end=25.012792587280273, speaker=2),
|
||||
SpeakerSegment(start=25.495875597000122, end=26.41254226366679, speaker=2),
|
||||
SpeakerSegment(start=26.41254226366679, end=30.495875597000122, speaker=0),
|
||||
]
|
||||
|
||||
state = State(
|
||||
tokens=tokens,
|
||||
last_validated_token=72,
|
||||
last_speaker=-1,
|
||||
last_punctuation_index=71,
|
||||
translation_validated_segments=[],
|
||||
buffer_translation=Transcript(start=0, end=0, speaker=-1),
|
||||
buffer_transcription=Transcript(start=None, end=None, speaker=-1),
|
||||
diarization_segments=diarization_segments,
|
||||
end_buffer=31.21587559700018,
|
||||
end_attributed_speaker=30.495875597000122,
|
||||
remaining_time_transcription=0.4,
|
||||
remaining_time_diarization=0.7,
|
||||
beg_loop=1763627603.968919
|
||||
)
|
||||
|
||||
alignment = TokensAlignment(state)
|
||||
@@ -1,36 +1,23 @@
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from time import time, sleep
|
||||
import math
|
||||
from time import time
|
||||
import logging
|
||||
import traceback
|
||||
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, StateLight, Transcript, ChangeSpeaker
|
||||
from typing import Optional, Union, List, Any, AsyncGenerator
|
||||
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript, ChangeSpeaker
|
||||
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
|
||||
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
||||
from whisperlivekit.results_formater import format_output
|
||||
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||
from whisperlivekit.TokensAlignment import TokensAlignment
|
||||
from whisperlivekit.tokens_alignment import TokensAlignment
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
SENTINEL = object() # unique sentinel object for end of stream marker
|
||||
MIN_DURATION_REAL_SILENCE = 5
|
||||
|
||||
def cut_at(cumulative_pcm, cut_sec, sample_rate=16000):
    """Split a sequence of PCM chunks at a time offset.

    Args:
        cumulative_pcm: Sequence of 1-D NumPy arrays holding consecutive
            audio samples.
        cut_sec: Offset, in seconds from the start of the first chunk, at
            which to cut. Negative values are treated as ``0``.
        sample_rate: Samples per second used to convert ``cut_sec`` into a
            sample index. Defaults to 16000, the previously hard-coded value.

    Returns:
        A ``(before, after)`` tuple. ``before`` is a single array with every
        sample preceding the cut. ``after`` is a list of the remaining chunks;
        its first element is the tail of the chunk that contained the cut and
        may be empty when the cut falls exactly on a chunk boundary. When the
        cut lies beyond the available audio, ``after`` is ``[]``.
    """
    chunks = list(cumulative_pcm)
    if not chunks:
        # np.concatenate raises ValueError on an empty sequence.
        # NOTE(review): float32 is assumed to match the PCM pipeline dtype.
        return np.array([], dtype=np.float32), []

    # A negative index would slice from the end of the first chunk.
    cut_sample = max(0, int(cut_sec * sample_rate))

    consumed = 0
    for ind, pcm_array in enumerate(chunks):
        chunk_len = len(pcm_array)
        if consumed + chunk_len >= cut_sample:
            cut_chunk = cut_sample - consumed
            before = np.concatenate(chunks[:ind] + [pcm_array[:cut_chunk]])
            after = [pcm_array[cut_chunk:]] + chunks[ind + 1:]
            return before, after
        consumed += chunk_len
    return np.concatenate(chunks), []
|
||||
|
||||
async def get_all_from_queue(queue):
|
||||
items = []
|
||||
async def get_all_from_queue(queue: asyncio.Queue) -> Union[object, Silence, np.ndarray, List[Any]]:
|
||||
items: List[Any] = []
|
||||
|
||||
first_item = await queue.get()
|
||||
queue.task_done()
|
||||
@@ -61,7 +48,7 @@ class AudioProcessor:
|
||||
Handles audio processing, state management, and result formatting.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the audio processor with configuration, models, and state."""
|
||||
|
||||
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
|
||||
@@ -80,33 +67,27 @@ class AudioProcessor:
|
||||
self.is_pcm_input = self.args.pcm_input
|
||||
|
||||
# State management
|
||||
self.is_stopping = False
|
||||
self.silence = True
|
||||
self.silence_duration = 0.0
|
||||
self.start_silence = None
|
||||
self.last_silence_dispatch_time = None
|
||||
self.state = State()
|
||||
self.state_light = StateLight()
|
||||
self.lock = asyncio.Lock()
|
||||
self.sep = " " # Default separator
|
||||
self.last_response_content = FrontData()
|
||||
self.last_detected_speaker = None
|
||||
self.speaker_languages = {}
|
||||
self.is_stopping: bool = False
|
||||
self.current_silence: Optional[Silence] = None
|
||||
self.state: State = State()
|
||||
self.lock: asyncio.Lock = asyncio.Lock()
|
||||
self.sep: str = " " # Default separator
|
||||
self.last_response_content: FrontData = FrontData()
|
||||
|
||||
self.tokens_alignment = TokensAlignment(self.state_light, self.args, self.sep)
|
||||
self.beg_loop = None
|
||||
self.tokens_alignment: TokensAlignment = TokensAlignment(self.state, self.args, self.sep)
|
||||
self.beg_loop: Optional[float] = None
|
||||
|
||||
# Models and processing
|
||||
self.asr = models.asr
|
||||
self.vac_model = models.vac_model
|
||||
self.asr: Any = models.asr
|
||||
self.vac_model: Any = models.vac_model
|
||||
if self.args.vac:
|
||||
self.vac = FixedVADIterator(models.vac_model)
|
||||
self.vac: Optional[FixedVADIterator] = FixedVADIterator(models.vac_model)
|
||||
else:
|
||||
self.vac = None
|
||||
self.vac: Optional[FixedVADIterator] = None
|
||||
|
||||
self.ffmpeg_manager = None
|
||||
self.ffmpeg_reader_task = None
|
||||
self._ffmpeg_error = None
|
||||
self.ffmpeg_manager: Optional[FFmpegManager] = None
|
||||
self.ffmpeg_reader_task: Optional[asyncio.Task] = None
|
||||
self._ffmpeg_error: Optional[str] = None
|
||||
|
||||
if not self.is_pcm_input:
|
||||
self.ffmpeg_manager = FFmpegManager(
|
||||
@@ -118,21 +99,20 @@ class AudioProcessor:
|
||||
self._ffmpeg_error = error_type
|
||||
self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
|
||||
|
||||
self.transcription_queue = asyncio.Queue() if self.args.transcription else None
|
||||
self.diarization_queue = asyncio.Queue() if self.args.diarization else None
|
||||
self.translation_queue = asyncio.Queue() if self.args.target_language else None
|
||||
self.pcm_buffer = bytearray()
|
||||
self.total_pcm_samples = 0
|
||||
self.end_buffer = 0.0
|
||||
self.transcription_task = None
|
||||
self.diarization_task = None
|
||||
self.translation_task = None
|
||||
self.watchdog_task = None
|
||||
self.all_tasks_for_cleanup = []
|
||||
self.transcription_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.transcription else None
|
||||
self.diarization_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.diarization else None
|
||||
self.translation_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.target_language else None
|
||||
self.pcm_buffer: bytearray = bytearray()
|
||||
self.total_pcm_samples: int = 0
|
||||
self.transcription_task: Optional[asyncio.Task] = None
|
||||
self.diarization_task: Optional[asyncio.Task] = None
|
||||
self.translation_task: Optional[asyncio.Task] = None
|
||||
self.watchdog_task: Optional[asyncio.Task] = None
|
||||
self.all_tasks_for_cleanup: List[asyncio.Task] = []
|
||||
|
||||
self.transcription = None
|
||||
self.translation = None
|
||||
self.diarization = None
|
||||
self.transcription: Optional[Any] = None
|
||||
self.translation: Optional[Any] = None
|
||||
self.diarization: Optional[Any] = None
|
||||
|
||||
if self.args.transcription:
|
||||
self.transcription = online_factory(self.args, models.asr)
|
||||
@@ -142,44 +122,45 @@ class AudioProcessor:
|
||||
if models.translation_model:
|
||||
self.translation = online_translation_factory(self.args, models.translation_model)
|
||||
|
||||
async def _push_silence_event(self, silence_buffer: Silence):
|
||||
async def _push_silence_event(self) -> None:
|
||||
if self.transcription_queue:
|
||||
await self.transcription_queue.put(silence_buffer)
|
||||
await self.transcription_queue.put(self.current_silence)
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(silence_buffer)
|
||||
await self.diarization_queue.put(self.current_silence)
|
||||
if self.translation_queue:
|
||||
await self.translation_queue.put(silence_buffer)
|
||||
await self.translation_queue.put(self.current_silence)
|
||||
|
||||
async def _begin_silence(self):
|
||||
if self.silence:
|
||||
async def _begin_silence(self) -> None:
|
||||
if self.current_silence:
|
||||
return
|
||||
self.silence = True
|
||||
now = time()
|
||||
self.start_silence = now
|
||||
self.last_silence_dispatch_time = now
|
||||
await self._push_silence_event(Silence(is_starting=True))
|
||||
now = time() - self.beg_loop
|
||||
self.current_silence = Silence(
|
||||
is_starting=True, start=now
|
||||
)
|
||||
await self._push_silence_event()
|
||||
|
||||
async def _end_silence(self):
|
||||
if not self.silence:
|
||||
async def _end_silence(self) -> None:
|
||||
if not self.current_silence:
|
||||
return
|
||||
now = time()
|
||||
duration = now - (self.last_silence_dispatch_time if self.last_silence_dispatch_time else self.beg_loop)
|
||||
await self._push_silence_event(Silence(duration=duration, has_ended=True))
|
||||
self.last_silence_dispatch_time = now
|
||||
self.silence = False
|
||||
self.start_silence = None
|
||||
self.last_silence_dispatch_time = None
|
||||
now = time() - self.beg_loop
|
||||
self.current_silence.end = now
|
||||
self.current_silence.is_starting=False
|
||||
self.current_silence.has_ended=True
|
||||
self.current_silence.compute_duration()
|
||||
if self.current_silence.duration > MIN_DURATION_REAL_SILENCE:
|
||||
self.state.new_tokens.append(self.current_silence)
|
||||
await self._push_silence_event()
|
||||
self.current_silence = None
|
||||
|
||||
async def _enqueue_active_audio(self, pcm_chunk: np.ndarray):
|
||||
async def _enqueue_active_audio(self, pcm_chunk: np.ndarray) -> None:
|
||||
if pcm_chunk is None or pcm_chunk.size == 0:
|
||||
return
|
||||
if self.transcription_queue:
|
||||
await self.transcription_queue.put(pcm_chunk.copy())
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(pcm_chunk.copy())
|
||||
self.silence_duration = 0.0
|
||||
|
||||
def _slice_before_silence(self, pcm_array, chunk_sample_start, silence_sample):
|
||||
def _slice_before_silence(self, pcm_array: np.ndarray, chunk_sample_start: int, silence_sample: Optional[int]) -> Optional[np.ndarray]:
|
||||
if silence_sample is None:
|
||||
return None
|
||||
relative_index = int(silence_sample - chunk_sample_start)
|
||||
@@ -190,22 +171,22 @@ class AudioProcessor:
|
||||
return None
|
||||
return pcm_array[:split_index]
|
||||
|
||||
def convert_pcm_to_float(self, pcm_buffer):
|
||||
def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray:
|
||||
"""Convert PCM buffer in s16le format to normalized NumPy array."""
|
||||
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
|
||||
async def get_current_state(self):
|
||||
async def get_current_state(self) -> State:
|
||||
"""Get current state."""
|
||||
async with self.lock:
|
||||
current_time = time()
|
||||
|
||||
remaining_transcription = 0
|
||||
if self.end_buffer > 0:
|
||||
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 1))
|
||||
if self.state.end_buffer > 0:
|
||||
remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1))
|
||||
|
||||
remaining_diarization = 0
|
||||
if self.state.tokens:
|
||||
latest_end = max(self.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
|
||||
latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
|
||||
remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1))
|
||||
|
||||
self.state.remaining_time_transcription = remaining_transcription
|
||||
@@ -213,7 +194,7 @@ class AudioProcessor:
|
||||
|
||||
return self.state
|
||||
|
||||
async def ffmpeg_stdout_reader(self):
|
||||
async def ffmpeg_stdout_reader(self) -> None:
|
||||
"""Read audio data from FFmpeg stdout and process it into the PCM pipeline."""
|
||||
beg = time()
|
||||
while True:
|
||||
@@ -263,7 +244,7 @@ class AudioProcessor:
|
||||
if self.translation:
|
||||
await self.translation_queue.put(SENTINEL)
|
||||
|
||||
async def transcription_processor(self):
|
||||
async def transcription_processor(self) -> None:
|
||||
"""Process audio chunks for transcription."""
|
||||
cumulative_pcm_duration_stream_time = 0.0
|
||||
|
||||
@@ -276,11 +257,11 @@ class AudioProcessor:
|
||||
break
|
||||
|
||||
asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
|
||||
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
|
||||
transcription_lag_s = max(0.0, time() - self.beg_loop - self.state.end_buffer)
|
||||
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
|
||||
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
|
||||
new_tokens = []
|
||||
current_audio_processed_upto = self.end_buffer
|
||||
current_audio_processed_upto = self.state.end_buffer
|
||||
|
||||
if isinstance(item, Silence):
|
||||
if item.is_starting:
|
||||
@@ -318,7 +299,7 @@ class AudioProcessor:
|
||||
if buffer_text.startswith(validated_text):
|
||||
_buffer_transcript.text = buffer_text[len(validated_text):].lstrip()
|
||||
|
||||
candidate_end_times = [self.end_buffer]
|
||||
candidate_end_times = [self.state.end_buffer]
|
||||
|
||||
if new_tokens:
|
||||
candidate_end_times.append(new_tokens[-1].end)
|
||||
@@ -331,10 +312,9 @@ class AudioProcessor:
|
||||
async with self.lock:
|
||||
self.state.tokens.extend(new_tokens)
|
||||
self.state.buffer_transcription = _buffer_transcript
|
||||
self.end_buffer = max(candidate_end_times)
|
||||
self.state_light.new_tokens = new_tokens
|
||||
self.state_light.new_tokens += 1
|
||||
self.state_light.new_tokens_buffer = _buffer_transcript
|
||||
self.state.end_buffer = max(candidate_end_times)
|
||||
self.state.new_tokens.extend(new_tokens)
|
||||
self.state.new_tokens_buffer = _buffer_transcript
|
||||
|
||||
if self.translation_queue:
|
||||
for token in new_tokens:
|
||||
@@ -355,7 +335,7 @@ class AudioProcessor:
|
||||
logger.info("Transcription processor task finished.")
|
||||
|
||||
|
||||
async def diarization_processor(self):
|
||||
async def diarization_processor(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
item = await get_all_from_queue(self.diarization_queue)
|
||||
@@ -368,41 +348,44 @@ class AudioProcessor:
|
||||
|
||||
self.diarization.insert_audio_chunk(item)
|
||||
diarization_segments = await self.diarization.diarize()
|
||||
self.state_light.new_diarization = diarization_segments
|
||||
self.state_light.new_diarization_index += 1
|
||||
self.state.new_diarization = diarization_segments
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in diarization_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
logger.info("Diarization processor task finished.")
|
||||
|
||||
async def translation_processor(self):
|
||||
async def translation_processor(self) -> None:
|
||||
# the idea is to ignore diarization for the moment. We use only transcription tokens.
|
||||
# And the speaker is attributed given the segments used for the translation
|
||||
# in the future we want to have different languages for each speaker etc, so it will be more complex.
|
||||
while True:
|
||||
try:
|
||||
tokens_to_process = await get_all_from_queue(self.translation_queue)
|
||||
if tokens_to_process is SENTINEL:
|
||||
item = await get_all_from_queue(self.translation_queue)
|
||||
if item is SENTINEL:
|
||||
logger.debug("Translation processor received sentinel. Finishing.")
|
||||
self.translation_queue.task_done()
|
||||
break
|
||||
elif type(tokens_to_process) is Silence:
|
||||
if tokens_to_process.has_ended:
|
||||
self.translation.insert_silence(tokens_to_process.duration)
|
||||
continue
|
||||
if tokens_to_process:
|
||||
self.translation.insert_tokens(tokens_to_process)
|
||||
translation_validated_segments, buffer_translation = await asyncio.to_thread(self.translation.process)
|
||||
async with self.lock:
|
||||
self.state.translation_validated_segments = translation_validated_segments
|
||||
self.state.buffer_translation = buffer_translation
|
||||
elif type(item) is Silence:
|
||||
if item.is_starting:
|
||||
new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
|
||||
if item.has_ended:
|
||||
self.translation.insert_silence(item.duration)
|
||||
continue
|
||||
elif isinstance(item, ChangeSpeaker):
|
||||
new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
|
||||
pass
|
||||
else:
|
||||
self.translation.insert_tokens(item)
|
||||
new_translation, new_translation_buffer = await asyncio.to_thread(self.translation.process)
|
||||
async with self.lock:
|
||||
self.state.new_translation.append(new_translation)
|
||||
self.state.new_translation_buffer = new_translation_buffer
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in translation_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
logger.info("Translation processor task finished.")
|
||||
|
||||
async def results_formatter(self):
|
||||
async def results_formatter(self) -> AsyncGenerator[FrontData, None]:
|
||||
"""Format processing results for output."""
|
||||
while True:
|
||||
try:
|
||||
@@ -412,55 +395,32 @@ class AudioProcessor:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
state = await self.get_current_state()
|
||||
self.tokens_alignment.compute_punctuations_segments()
|
||||
lines, undiarized_text = format_output(
|
||||
state,
|
||||
self.silence,
|
||||
args = self.args,
|
||||
sep=self.sep
|
||||
self.tokens_alignment.update()
|
||||
lines, buffer_diarization_text, buffer_translation_text = self.tokens_alignment.get_lines(
|
||||
diarization=self.args.diarization,
|
||||
translation=bool(self.translation),
|
||||
current_silence=self.current_silence
|
||||
)
|
||||
if lines and lines[-1].speaker == -2:
|
||||
buffer_transcription = Transcript()
|
||||
else:
|
||||
buffer_transcription = state.buffer_transcription
|
||||
state = await self.get_current_state()
|
||||
|
||||
buffer_diarization = ''
|
||||
if undiarized_text:
|
||||
buffer_diarization = self.sep.join(undiarized_text)
|
||||
buffer_transcription_text = state.buffer_transcription.text if state.buffer_transcription else ''
|
||||
|
||||
async with self.lock:
|
||||
self.state.end_attributed_speaker = state.end_attributed_speaker
|
||||
|
||||
buffer_translation_text = ''
|
||||
if state.buffer_translation:
|
||||
raw_buffer_translation = getattr(state.buffer_translation, 'text', state.buffer_translation)
|
||||
if raw_buffer_translation:
|
||||
buffer_translation_text = raw_buffer_translation.strip()
|
||||
|
||||
response_status = "active_transcription"
|
||||
if not state.tokens and not buffer_transcription and not buffer_diarization:
|
||||
if not lines and not buffer_transcription_text and not buffer_diarization_text:
|
||||
response_status = "no_audio_detected"
|
||||
lines = []
|
||||
elif not lines:
|
||||
lines = [Line(
|
||||
speaker=1,
|
||||
start=state.end_buffer,
|
||||
end=state.end_buffer
|
||||
)]
|
||||
|
||||
|
||||
response = FrontData(
|
||||
status=response_status,
|
||||
lines=lines,
|
||||
buffer_transcription=buffer_transcription.text.strip(),
|
||||
buffer_diarization=buffer_diarization,
|
||||
buffer_transcription=buffer_transcription_text,
|
||||
buffer_diarization=buffer_diarization_text,
|
||||
buffer_translation=buffer_translation_text,
|
||||
remaining_time_transcription=state.remaining_time_transcription,
|
||||
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
|
||||
)
|
||||
|
||||
should_push = (response != self.last_response_content)
|
||||
if should_push and (lines or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"):
|
||||
if should_push:
|
||||
yield response
|
||||
self.last_response_content = response
|
||||
|
||||
@@ -474,17 +434,17 @@ class AudioProcessor:
|
||||
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
async def create_tasks(self):
|
||||
async def create_tasks(self) -> AsyncGenerator[FrontData, None]:
|
||||
"""Create and start processing tasks."""
|
||||
self.all_tasks_for_cleanup = []
|
||||
processing_tasks_for_watchdog = []
|
||||
processing_tasks_for_watchdog: List[asyncio.Task] = []
|
||||
|
||||
# If using FFmpeg (non-PCM input), start it and spawn stdout reader
|
||||
if not self.is_pcm_input:
|
||||
success = await self.ffmpeg_manager.start()
|
||||
if not success:
|
||||
logger.error("Failed to start FFmpeg manager")
|
||||
async def error_generator():
|
||||
async def error_generator() -> AsyncGenerator[FrontData, None]:
|
||||
yield FrontData(
|
||||
status="error",
|
||||
error="FFmpeg failed to start. Please check that FFmpeg is installed."
|
||||
@@ -515,9 +475,9 @@ class AudioProcessor:
|
||||
|
||||
return self.results_formatter()
|
||||
|
||||
async def watchdog(self, tasks_to_monitor):
|
||||
async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None:
|
||||
"""Monitors the health of critical processing tasks."""
|
||||
tasks_remaining = [task for task in tasks_to_monitor if task]
|
||||
tasks_remaining: List[asyncio.Task] = [task for task in tasks_to_monitor if task]
|
||||
while True:
|
||||
try:
|
||||
if not tasks_remaining:
|
||||
@@ -542,7 +502,7 @@ class AudioProcessor:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in watchdog task: {e}", exc_info=True)
|
||||
|
||||
async def cleanup(self):
|
||||
async def cleanup(self) -> None:
|
||||
"""Clean up resources when processing is complete."""
|
||||
logger.info("Starting cleanup of AudioProcessor resources.")
|
||||
self.is_stopping = True
|
||||
@@ -565,7 +525,7 @@ class AudioProcessor:
|
||||
self.diarization.close()
|
||||
logger.info("AudioProcessor cleanup complete.")
|
||||
|
||||
def _processing_tasks_done(self):
|
||||
def _processing_tasks_done(self) -> bool:
|
||||
"""Return True when all active processing tasks have completed."""
|
||||
tasks_to_check = [
|
||||
self.transcription_task,
|
||||
@@ -576,11 +536,13 @@ class AudioProcessor:
|
||||
return all(task.done() for task in tasks_to_check if task)
|
||||
|
||||
|
||||
async def process_audio(self, message):
|
||||
async def process_audio(self, message: Optional[bytes]) -> None:
|
||||
"""Process incoming audio data."""
|
||||
|
||||
if not self.beg_loop:
|
||||
self.beg_loop = time()
|
||||
self.current_silence = Silence(start=0.0, is_starting=True)
|
||||
self.tokens_alignment.beg_loop = self.beg_loop
|
||||
|
||||
if not message:
|
||||
logger.info("Empty audio message received, initiating stop sequence.")
|
||||
@@ -613,7 +575,7 @@ class AudioProcessor:
|
||||
else:
|
||||
logger.warning("Failed to write audio data to FFmpeg")
|
||||
|
||||
async def handle_pcm_data(self):
|
||||
async def handle_pcm_data(self) -> None:
|
||||
# Process when enough data
|
||||
if len(self.pcm_buffer) < self.bytes_per_sec:
|
||||
return
|
||||
@@ -642,17 +604,17 @@ class AudioProcessor:
|
||||
|
||||
if res is not None:
|
||||
silence_detected = res.get("end", 0) > res.get("start", 0)
|
||||
if silence_detected and not self.silence:
|
||||
if silence_detected and not self.current_silence:
|
||||
pre_silence_chunk = self._slice_before_silence(
|
||||
pcm_array, chunk_sample_start, res.get("end")
|
||||
)
|
||||
if pre_silence_chunk is not None and pre_silence_chunk.size > 0:
|
||||
await self._enqueue_active_audio(pre_silence_chunk)
|
||||
await self._begin_silence()
|
||||
elif self.silence:
|
||||
elif self.current_silence:
|
||||
await self._end_silence()
|
||||
|
||||
if not self.silence:
|
||||
if not self.current_silence:
|
||||
await self._enqueue_active_audio(pcm_array)
|
||||
|
||||
self.total_pcm_samples = chunk_sample_end
|
||||
|
||||
@@ -224,7 +224,8 @@ class MLXWhisper(ASRBase):
|
||||
if segment.get("no_speech_prob", 0) > 0.9:
|
||||
continue
|
||||
for word in segment.get("words", []):
|
||||
token = ASRToken(word["start"], word["end"], word["word"], probability=word["probability"])
|
||||
probability=word["probability"]
|
||||
token = ASRToken(word["start"], word["end"], word["word"])
|
||||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
|
||||
@@ -411,11 +411,11 @@ class OnlineASRProcessor:
|
||||
) -> Transcript:
|
||||
sep = sep if sep is not None else self.asr.sep
|
||||
text = sep.join(token.text for token in tokens)
|
||||
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
||||
# probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
||||
if tokens:
|
||||
start = offset + tokens[0].start
|
||||
end = offset + tokens[-1].end
|
||||
else:
|
||||
start = None
|
||||
end = None
|
||||
return Transcript(start, end, text, probability=probability)
|
||||
return Transcript(start, end, text)
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from time import time
|
||||
import re
|
||||
|
||||
# Minimum duration a blank run or an inter-token gap must reach to be kept as a silence token.
MIN_SILENCE_DURATION = 4 # in seconds
# Time without any new token after which the stream is considered to end with silence.
END_SILENCE_DURATION = 8 # in seconds. Keep it large to avoid false positives when the model lags significantly
END_SILENCE_DURATION_VAC = 3 # in seconds. VAC is good at detecting silences, but we want to skip the smallest silences
|
||||
|
||||
def blank_to_silence(tokens):
    """Replace runs of ``[BLANK_AUDIO]`` / ``[typing]`` tokens with silence tokens.

    The token texts are concatenated and scanned for blank markers. Consecutive
    tokens that lie entirely inside a marker span are collapsed into a single
    ``ASRToken`` with ``speaker=-2``, which is kept only when it lasts at least
    ``MIN_SILENCE_DURATION`` seconds (shorter blank runs are dropped).

    Args:
        tokens: Sequence of tokens exposing ``text``, ``start`` and ``end``.

    Returns:
        The original ``tokens`` object when no marker is found, otherwise a new
        list with blank runs replaced (or removed).
    """
    full_string = ''.join(t.text for t in tokens)
    patterns = [re.compile(r'(?:\s*\[BLANK_AUDIO\]\s*)+'), re.compile(r'(?:\s*\[typing\]\s*)+')]
    # Sort the spans by position: spans of the second pattern may occur before
    # spans of the first one in the text, and the scan below walks left to right.
    matches = sorted(m.span() for pattern in patterns for m in pattern.finditer(full_string))
    if not matches:
        return tokens

    cleaned_tokens = []
    silence_token = None
    match_index = 0
    cumulated_len = 0
    for token in tokens:
        start = cumulated_len
        end = start + len(token.text)
        cumulated_len = end

        # Spans ending before this token ends cannot contain it or any later token.
        while match_index < len(matches) and matches[match_index][1] < end:
            match_index += 1
        in_blank = match_index < len(matches) and matches[match_index][0] <= start

        if in_blank:
            # Compare against None explicitly: TimedText defines truthiness as
            # bool(text), so a silence token created without text is falsy.
            if silence_token is None:
                silence_token = ASRToken(
                    start=token.start,
                    end=token.end,
                    speaker=-2,
                )
            else:
                silence_token.start = min(silence_token.start, token.start)
                silence_token.end = max(silence_token.end, token.end)
            continue

        if silence_token is not None:  # the blank run just ended
            if silence_token.duration() >= MIN_SILENCE_DURATION:
                cleaned_tokens.append(silence_token)
            silence_token = None
        cleaned_tokens.append(token)

    # Flush a blank run that reaches the end of the token list.
    if silence_token is not None and silence_token.duration() >= MIN_SILENCE_DURATION:
        cleaned_tokens.append(silence_token)
    return cleaned_tokens
|
||||
|
||||
def no_token_to_silence(tokens):
    """Merge adjacent silence tokens and insert silence where tokens are far apart.

    Silence tokens are those with ``speaker == -2``. A gap of at least
    ``MIN_SILENCE_DURATION`` seconds between the end of the previously kept
    token (``0.0`` when nothing was kept yet) and the start of the current
    token either extends the preceding silence token or creates a new one.

    Args:
        tokens: Iterable of tokens exposing ``speaker``, ``start`` and ``end``.

    Returns:
        A new list of tokens with silences merged and gaps made explicit.
    """
    merged = []
    for current in tokens:
        is_silence = current.speaker == -2
        if is_silence:
            # Two silences in a row collapse into the first one.
            if merged and merged[-1].speaker == -2:
                merged[-1].end = current.end
            else:
                merged.append(current)

        previous_end = merged[-1].end if merged else 0.0
        if current.start - previous_end >= MIN_SILENCE_DURATION:
            # Long gap before this token: stretch or create a silence to cover it.
            if merged and merged[-1].speaker == -2:
                merged[-1].end = current.start
            else:
                merged.append(
                    ASRToken(start=previous_end, end=current.start, speaker=-2)
                )

        if not is_silence:
            merged.append(current)
    return merged
|
||||
|
||||
def ends_with_silence(tokens, beg_loop, vac_detected_silence):
    """Extend or append a trailing silence token when the stream has gone quiet.

    Trailing silence is assumed when ``vac_detected_silence`` is truthy or when
    the last token ended at least ``END_SILENCE_DURATION`` seconds before the
    current stream time (``time() - beg_loop``).

    Args:
        tokens: List of tokens; modified in place.
        beg_loop: Wall-clock time at which processing started; falsy means 0.0.
        vac_detected_silence: Whether the voice-activity controller reports silence.

    Returns:
        The same ``tokens`` list. An empty list is returned unchanged.
    """
    if not tokens:
        # Nothing to extend: every token may have been removed upstream.
        return tokens

    current_time = time() - (beg_loop if beg_loop else 0.0)
    last_token = tokens[-1]
    if vac_detected_silence or (current_time - last_token.end >= END_SILENCE_DURATION):
        if last_token.speaker == -2:
            # Already ending with silence: just stretch it to now.
            last_token.end = current_time
        else:
            tokens.append(
                ASRToken(
                    start=last_token.end,
                    end=current_time,
                    speaker=-2,
                )
            )
    return tokens
|
||||
|
||||
|
||||
def handle_silences(tokens, beg_loop, vac_detected_silence):
    """Run the three silence passes over ``tokens`` and return the result.

    The passes are, in order: ``blank_to_silence``, ``no_token_to_silence`` and
    ``ends_with_silence``. An empty or falsy ``tokens`` yields a new empty list.
    """
    if tokens:
        # Useful for the simulstreaming backend, which tends to generate [BLANK_AUDIO] text.
        without_blanks = blank_to_silence(tokens)
        with_gaps = no_token_to_silence(without_blanks)
        return ends_with_silence(with_gaps, beg_loop, vac_detected_silence)
    return []
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
########### WHAT IS PRODUCED: ###########
|
||||
|
||||
SPEAKER 1 0:00:04 - 0:00:06
|
||||
Transcription technology has improved so much in the past
|
||||
|
||||
SPEAKER 1 0:00:07 - 0:00:12
|
||||
years. Have you noticed how accurate real-time speech detects is now?
|
||||
|
||||
SPEAKER 2 0:00:12 - 0:00:12
|
||||
Absolutely
|
||||
|
||||
SPEAKER 1 0:00:13 - 0:00:13
|
||||
.
|
||||
|
||||
SPEAKER 2 0:00:14 - 0:00:14
|
||||
I
|
||||
|
||||
SPEAKER 1 0:00:14 - 0:00:17
|
||||
use it all the time for taking notes during meetings.
|
||||
|
||||
SPEAKER 2 0:00:17 - 0:00:17
|
||||
It
|
||||
|
||||
SPEAKER 1 0:00:17 - 0:00:22
|
||||
's amazing how it can recognize different speakers, and even add punctuation.
|
||||
|
||||
SPEAKER 2 0:00:22 - 0:00:22
|
||||
Yeah
|
||||
|
||||
SPEAKER 1 0:00:23 - 0:00:26
|
||||
, but sometimes noise can still cause mistakes.
|
||||
|
||||
SPEAKER 3 0:00:26 - 0:00:27
|
||||
Does
|
||||
|
||||
SPEAKER 1 0:00:27 - 0:00:28
|
||||
this system handle that
|
||||
|
||||
SPEAKER 1 0:00:29 - 0:00:29
|
||||
?
|
||||
|
||||
SPEAKER 3 0:00:29 - 0:00:29
|
||||
It
|
||||
|
||||
SPEAKER 1 0:00:29 - 0:00:33
|
||||
does a pretty good job filtering noise, especially with models that use voice activity
|
||||
|
||||
########### WHAT SHOULD BE PRODUCED: ###########
|
||||
|
||||
SPEAKER 1 0:00:04 - 0:00:12
|
||||
Transcription technology has improved so much in the past years. Have you noticed how accurate real-time speech detects is now?
|
||||
|
||||
SPEAKER 2 0:00:12 - 0:00:22
|
||||
Absolutely. I use it all the time for taking notes during meetings. It's amazing how it can recognize different speakers, and even add punctuation.
|
||||
|
||||
SPEAKER 3 0:00:22 - 0:00:28
|
||||
Yeah, but sometimes noise can still cause mistakes. Does this system handle that well?
|
||||
|
||||
SPEAKER 1 0:00:29 - 0:00:29
|
||||
It does a pretty good job filtering noise, especially with models that use voice activity
|
||||
@@ -1,257 +0,0 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
from whisperlivekit.remove_silences import handle_silences
|
||||
from whisperlivekit.timed_objects import Line, format_time, SpeakerSegment
|
||||
from typing import List
|
||||
|
||||
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Maximum number of neighbouring tokens inspected when looking for a
# punctuation mark or a speaker change around a token.
CHECK_AROUND = 4
# When True, each token's "[start : end]" timestamps are appended to the line text.
DEBUG = False
|
||||
|
||||
def next_punctuation_change(i, tokens):
    """Return the index of the first punctuation token shortly after ``i``.

    Only the ``CHECK_AROUND`` tokens following position ``i`` are inspected.
    Returns ``None`` when none of them is punctuation.
    """
    upper_bound = min(len(tokens), i + CHECK_AROUND + 1)
    return next(
        (pos for pos in range(i + 1, upper_bound) if tokens[pos].is_punctuation()),
        None,
    )
|
||||
|
||||
def next_speaker_change(i, tokens, speaker):
    """Search backwards from ``i`` for a token attributed to another speaker.

    At most ``CHECK_AROUND`` tokens before position ``i`` are inspected, and
    the search stops at the first punctuation token.

    Returns:
        ``(index, other_speaker)`` for the first differing token found, or
        ``(None, speaker)`` when there is none.
    """
    lowest = max(0, i - CHECK_AROUND)
    pos = i - 1
    while pos >= lowest:
        candidate = tokens[pos]
        if candidate.is_punctuation():
            break
        if candidate.speaker != speaker:
            return pos, candidate.speaker
        pos -= 1
    return None, speaker
|
||||
|
||||
def new_line(
    token,
):
    """Build a ``Line`` from a single token.

    The line takes the token's corrected speaker, text, time span and detected
    language. When ``DEBUG`` is set, the token's formatted timestamps are
    appended to the text.
    """
    debug_suffix = f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else ""
    return Line(
        speaker=token.corrected_speaker,
        text=token.text + debug_suffix,
        start=token.start,
        end=token.end,
        detected_language=token.detected_language,
    )
|
||||
|
||||
def append_token_to_last_line(lines, sep, token):
    """Append ``token`` to the last entry of ``lines``, creating it if needed.

    With an empty ``lines`` a new line is created from the token. Otherwise a
    token with non-empty text extends the last line's text (prefixed by
    ``sep``) and end time, and the line inherits the token's detected language
    when it has none yet. ``lines`` is modified in place.
    """
    if not lines:
        lines.append(new_line(token))
        return

    last = lines[-1]
    if token.text:
        debug_suffix = f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else ""
        last.text += sep + token.text + debug_suffix
        last.end = token.end
    if not last.detected_language and token.detected_language:
        last.detected_language = token.detected_language
|
||||
|
||||
def extract_number(s) -> int:
    """Return the first integer found in a speaker label (for diart compatibility).

    Integers are returned unchanged. Any other value is converted with
    ``str()`` and its first run of digits is parsed; ``0`` is returned when it
    contains no digit.
    """
    if isinstance(s, int):
        return s
    digits = re.search(r'\d+', str(s))
    if digits is None:
        return 0
    return int(digits.group())
|
||||
|
||||
def concatenate_speakers(segments: List[SpeakerSegment]) -> List[dict]:
    """Merge consecutive segments that belong to the same speaker.

    Speaker labels are converted with ``extract_number`` and shifted by one,
    so the resulting speaker ids are 1-based.

    Returns:
        A list of ``{"speaker", "begin", "end"}`` dicts, one per run of
        consecutive segments from the same speaker; empty when ``segments`` is.
    """
    if not segments:
        return []

    merged: List[dict] = []
    for segment in segments:
        speaker = extract_number(segment.speaker) + 1
        if merged and merged[-1]['speaker'] == speaker:
            # Same speaker keeps talking: stretch the current run.
            merged[-1]['end'] = segment.end
        else:
            merged.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
    return merged
|
||||
|
||||
def add_speaker_to_tokens_with_punctuation(segments: List[SpeakerSegment], tokens: list) -> list:
    """Assign speakers to tokens with punctuation-aware boundary adjustment.

    Steps:
      1. Merge consecutive same-speaker segments (``concatenate_speakers``).
      2. Move each segment end to the closest sentence-ending punctuation
         token (``.``, ``!`` or ``?``) around it.
      3. Shift token timestamps so that tokens no longer overlap.
      4. Give every token ending before a segment end that segment's speaker.

    Args:
        segments: Diarization segments exposing ``speaker``, ``start``, ``end``.
        tokens: Tokens exposing ``text``, ``start``, ``end`` and ``speaker``;
            their ``start``, ``end`` and ``speaker`` are updated in place.

    Returns:
        The same ``tokens`` list.
    """
    punctuation_marks = {'.', '!', '?'}
    punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
    segments_concatenated = concatenate_speakers(segments)

    for ind, segment in enumerate(segments_concatenated):
        for i, punctuation_token in enumerate(punctuation_tokens):
            if punctuation_token.start > segment['end']:
                # First punctuation after the segment end: choose between it
                # and the previous punctuation, whichever is closer.
                after_length = punctuation_token.start - segment['end']
                before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
                if before_length > after_length:
                    segment['end'] = punctuation_token.start
                    if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
                        segments_concatenated[ind + 1]['begin'] = punctuation_token.start
                else:
                    segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
                    # NOTE(review): this updates the *previous* segment's 'begin';
                    # 'begin' is not read below, so the assignment has no effect
                    # on the result. Possibly meant to be ind + 1; confirm.
                    if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
                        segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
                break

    # Ensure non-overlapping tokens
    last_end = 0.0
    for token in tokens:
        start = max(last_end + 0.01, token.start)
        token.start = start
        token.end = max(start, token.end)
        last_end = token.end

    # Assign speakers based on adjusted segments. ``ind_last_speaker`` is an
    # absolute index into ``tokens``, so each segment resumes right after the
    # last token assigned by the previous one.
    ind_last_speaker = 0
    for segment in segments_concatenated:
        for i in range(ind_last_speaker, len(tokens)):
            token = tokens[i]
            if token.end <= segment['end']:
                token.speaker = segment['speaker']
                ind_last_speaker = i + 1
            elif token.start > segment['end']:
                break

    return tokens
|
||||
|
||||
def assign_speakers_to_tokens(tokens: list, segments: List[SpeakerSegment], use_punctuation_split: bool = False) -> list:
    """
    Attribute a speaker to every token from the diarization segments.

    Args:
        tokens: List of tokens with timing information
        segments: List of speaker segments
        use_punctuation_split: Whether to use punctuation for boundary refinement

    Returns:
        List of tokens with speaker assignments. ``tokens`` is returned
        untouched when either input is empty.
    """
    if not segments or not tokens:
        logger.debug("No segments or tokens available for speaker assignment")
        return tokens

    logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")

    if use_punctuation_split:
        # Punctuation-aware assignment
        return add_speaker_to_tokens_with_punctuation(segments, tokens)

    # Plain overlap-based assignment: the first overlapping segment wins.
    for token in tokens:
        token.speaker = -1  # no speaker until an overlapping segment is found
        overlapping = next(
            (
                segment for segment in segments
                if not (segment.end <= token.start or segment.start >= token.end)
            ),
            None,
        )
        if overlapping is not None:
            # Convert to 1-based indexing
            token.speaker = extract_number(overlapping.speaker) + 1
    return tokens
|
||||
|
||||
def format_output(state, silence, args, sep):
    """Turn the tokens held in ``state`` into display lines grouped by speaker.

    Args:
        state: Processing state. Read: ``tokens``, ``translation_validated_segments``,
            ``last_validated_token``, ``last_speaker``, ``beg_loop``,
            ``buffer_transcription``. Written: ``last_punctuation_index``,
            ``last_validated_token``, ``last_speaker``.
        silence: Forwarded to ``handle_silences`` as the VAC silence flag.
        args: Options; ``diarization`` and ``disable_punctuation_split`` are read.
        sep: Separator inserted between token texts inside a line.

    Returns:
        ``(lines, undiarized_text)``: the list of ``Line`` objects and the texts
        of tokens left without a speaker.
    """
    diarization = args.diarization
    disable_punctuation_split = args.disable_punctuation_split
    tokens = state.tokens
    translation_validated_segments = state.translation_validated_segments # Here we will attribute the speakers only based on the timestamps of the segments
    last_validated_token = state.last_validated_token

    last_speaker = abs(state.last_speaker)
    undiarized_text = []
    tokens = handle_silences(tokens, state.beg_loop, silence)

    # Assign speakers to tokens based on segments stored in state
    # NOTE: disabled by the leading `False`; this branch never runs.
    if False and diarization and state.diarization_segments:
        use_punctuation_split = args.punctuation_split if hasattr(args, 'punctuation_split') else False
        tokens = assign_speakers_to_tokens(tokens, state.diarization_segments, use_punctuation_split=use_punctuation_split)
    # Only tokens from the last validated one onwards are (re)processed.
    for i in range(last_validated_token, len(tokens)):
        token = tokens[i]
        speaker = int(token.speaker)
        token.corrected_speaker = speaker
        # NOTE: `True or ...` is always true, so the diarization-aware `else`
        # branch below is currently unreachable.
        if True or not diarization:
            if speaker == -1: # Speaker -1 means not attributed by diarization. In the frontend, it should appear under 'Speaker 1'
                token.corrected_speaker = 1
                token.validated_speaker = True
        else:
            if token.speaker == -1:
                undiarized_text.append(token.text)
            elif token.is_punctuation():
                state.last_punctuation_index = i
                token.corrected_speaker = last_speaker
                token.validated_speaker = True
            elif state.last_punctuation_index == i-1:
                if token.speaker != last_speaker:
                    token.corrected_speaker = token.speaker
                    token.validated_speaker = True
                    # perfect, diarization perfectly aligned
                else:
                    speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
                    # NOTE(review): a change found at index 0 is falsy and is ignored here.
                    if speaker_change_pos:
                        # Corrects delay:
                        # That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
                        # should become:
                        # That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
                        token.corrected_speaker = new_speaker
                        token.validated_speaker = True
            elif speaker != last_speaker:
                if not (speaker == -2 or last_speaker == -2):
                    if next_punctuation_change(i, tokens):
                        # Corrects advance:
                        # Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
                        # should become:
                        # Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
                        token.corrected_speaker = last_speaker
                        token.validated_speaker = True
                    else: # Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True.
                        if not disable_punctuation_split:
                            token.corrected_speaker = last_speaker
                            token.validated_speaker = False
        if token.validated_speaker:
            state.last_validated_token = i
        state.last_speaker = token.corrected_speaker

    last_speaker = 1

    # Group tokens into lines: a new line starts whenever the corrected speaker changes.
    lines = []
    for token in tokens:
        if token.corrected_speaker != -1:
            if int(token.corrected_speaker) != int(last_speaker):
                lines.append(new_line(token))
            else:
                append_token_to_last_line(lines, sep, token)

            last_speaker = token.corrected_speaker

    if lines:
        # First pass: attach each translated segment to the first line it overlaps,
        # cutting it at the line end when it spills over.
        unassigned_translated_segments = []
        for ts in translation_validated_segments:
            assigned = False
            for line in lines:
                if ts and ts.overlaps_with(line):
                    if ts.is_within(line):
                        line.translation += ts.text + ' '
                        assigned = True
                        break
                    else:
                        ts0, ts1 = ts.approximate_cut_at(line.end)
                        if ts0 and line.overlaps_with(ts0):
                            line.translation += ts0.text + ' '
                        if ts1:
                            unassigned_translated_segments.append(ts1)
                        assigned = True
                        break
            if not assigned:
                unassigned_translated_segments.append(ts)

        # Second pass: give the leftovers to the first line they overlap.
        if unassigned_translated_segments:
            for line in lines:
                remaining_segments = []
                for ts in unassigned_translated_segments:
                    if ts and ts.overlaps_with(line):
                        line.translation += ts.text + ' '
                    else:
                        remaining_segments.append(ts)
                unassigned_translated_segments = remaining_segments # maybe do something in the future about that

    # The last line extends at least to the end of the pending transcription buffer.
    if state.buffer_transcription and lines:
        lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)

    return lines, undiarized_text
|
||||
@@ -266,7 +266,7 @@ class AlignAtt:
|
||||
logger.debug("Refreshing segment:")
|
||||
self.init_tokens()
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.detected_language = None
|
||||
# self.detected_language = None
|
||||
self.cumulative_time_offset = 0.0
|
||||
self.init_context()
|
||||
logger.debug(f"Context: {self.context}")
|
||||
|
||||
261
whisperlivekit/ten_vad_alpha.py
Normal file
261
whisperlivekit/ten_vad_alpha.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""
|
||||
ALPHA. results are not great yet
|
||||
To use it, replace `from whisperlivekit.silero_vad_iterator import FixedVADIterator`
|
||||
by `from whisperlivekit.ten_vad_alpha import TenVADIterator`
|
||||
|
||||
Use self.vac = TenVADIterator() instead of self.vac = FixedVADIterator(models.vac_model)
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from ten_vad import TenVad
|
||||
|
||||
|
||||
class TenVADIterator:
    """Streaming speech start/end detector built on ``TenVad``.

    Audio is buffered and analysed in frames of 256 samples. Each call returns
    a dict with a ``'start'`` and/or ``'end'`` key when a speech boundary was
    crossed, or ``None`` otherwise.

    NOTE(review): frames are cast to int16 without rescaling, so samples are
    presumably expected in int16 range rather than [-1, 1]; confirm with callers.
    """

    def __init__(self,
                 threshold: float = 0.5,
                 sampling_rate: int = 16000,
                 min_silence_duration_ms: int = 100,
                 speech_pad_ms: int = 30):
        self.vad = TenVad()
        self.threshold = threshold
        self.sampling_rate = sampling_rate
        self.min_silence_duration_ms = min_silence_duration_ms
        self.speech_pad_ms = speech_pad_ms

        # Millisecond settings converted to sample counts.
        self.min_silence_samples = int(sampling_rate * min_silence_duration_ms / 1000)
        self.speech_pad_samples = int(sampling_rate * speech_pad_ms / 1000)

        self.reset_states()

    def reset_states(self):
        """Forget the detection state and drop any buffered audio."""
        self.buffer = np.array([], dtype=np.float32)
        self.current_sample = 0
        self.temp_end = 0
        self.triggered = False

    def _to_position(self, sample, return_seconds):
        """Express a sample index as an int, or as seconds rounded to 0.1."""
        if return_seconds:
            return round(sample / self.sampling_rate, 1)
        return int(sample)

    def __call__(self, x, return_seconds=False):
        """Feed audio samples and report the speech boundaries crossed, if any.

        Args:
            x: Samples to append to the internal buffer (array or sequence).
            return_seconds: Report positions in seconds instead of sample indices.

        Returns:
            ``None`` or a dict holding ``'start'`` and/or ``'end'``.
        """
        if not isinstance(x, np.ndarray):
            x = np.array(x, dtype=np.float32)
        self.buffer = np.append(self.buffer, x)

        hop = 256
        event = None

        while len(self.buffer) >= hop:
            frame, self.buffer = self.buffer[:hop].astype(np.int16), self.buffer[hop:]
            self.current_sample += hop
            speech_prob, _ = self.vad.process(frame)
            is_speech = speech_prob >= self.threshold

            # Speech resumed before the silence lasted long enough: cancel the pending end.
            if is_speech and self.temp_end:
                self.temp_end = 0

            if is_speech and not self.triggered:
                self.triggered = True
                onset = max(0, self.current_sample - self.speech_pad_samples - hop)
                update = {'start': self._to_position(onset, return_seconds)}
                # A start following an end reported in the same call replaces it.
                if event is None or "end" in event:
                    event = update
                else:
                    event.update(update)

            # The end threshold sits 0.15 below the start threshold.
            if self.triggered and speech_prob < self.threshold - 0.15:
                if not self.temp_end:
                    self.temp_end = self.current_sample
                if self.current_sample - self.temp_end >= self.min_silence_samples:
                    offset = self.temp_end + self.speech_pad_samples - hop
                    self.temp_end = 0
                    self.triggered = False
                    update = {'end': self._to_position(offset, return_seconds)}
                    if event is None:
                        event = update
                    else:
                        event.update(update)

        return event if event else None
|
||||
|
||||
|
||||
def test_on_record_wav():
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
audio_file = Path("record.wav")
|
||||
if not audio_file.exists():
|
||||
return
|
||||
|
||||
import soundfile as sf
|
||||
audio_data, sample_rate = sf.read(str(audio_file), dtype='float32')
|
||||
|
||||
if len(audio_data.shape) > 1:
|
||||
audio_data = np.mean(audio_data, axis=1)
|
||||
|
||||
vad = TenVADIterator(
|
||||
threshold=0.5,
|
||||
sampling_rate=sample_rate,
|
||||
min_silence_duration_ms=100,
|
||||
speech_pad_ms=30
|
||||
)
|
||||
|
||||
chunk_size = 1024
|
||||
speech_segments = []
|
||||
current_segment = None
|
||||
|
||||
for i in range(0, len(audio_data), chunk_size):
|
||||
chunk = audio_data[i:i+chunk_size]
|
||||
|
||||
if chunk.dtype != np.int16:
|
||||
chunk_int16 = (chunk * 32767.0).astype(np.int16)
|
||||
else:
|
||||
chunk_int16 = chunk
|
||||
|
||||
result = vad(chunk_int16, return_seconds=True)
|
||||
|
||||
if result is not None:
|
||||
if 'start' in result:
|
||||
current_segment = {'start': result['start'], 'end': None}
|
||||
print(f"Speech start detected at {result['start']:.2f}s")
|
||||
elif 'end' in result:
|
||||
if current_segment:
|
||||
current_segment['end'] = result['end']
|
||||
duration = current_segment['end'] - current_segment['start']
|
||||
speech_segments.append(current_segment)
|
||||
print(f"Speech end detected at {result['end']:.2f}s (duration: {duration:.2f}s)")
|
||||
current_segment = None
|
||||
else:
|
||||
print(f"Speech end detected at {result['end']:.2f}s (no corresponding start)")
|
||||
|
||||
if current_segment and current_segment['end'] is None:
|
||||
current_segment['end'] = len(audio_data) / sample_rate
|
||||
speech_segments.append(current_segment)
|
||||
print(f"End of file (last segment at {current_segment['start']:.2f}s)")
|
||||
|
||||
print("-" * 60)
|
||||
print(f"\nSummary:")
|
||||
print(f"Number of speech segments detected: {len(speech_segments)}")
|
||||
|
||||
if speech_segments:
|
||||
total_speech_time = sum(seg['end'] - seg['start'] for seg in speech_segments)
|
||||
total_time = len(audio_data) / sample_rate
|
||||
speech_ratio = (total_speech_time / total_time) * 100
|
||||
|
||||
print(f"Total speech time: {total_speech_time:.2f}s")
|
||||
print(f"Total file time: {total_time:.2f}s")
|
||||
print(f"Speech ratio: {speech_ratio:.1f}%")
|
||||
print(f"\nDetected segments:")
|
||||
for i, seg in enumerate(speech_segments, 1):
|
||||
print(f" {i}. {seg['start']:.2f}s - {seg['end']:.2f}s (duration: {seg['end'] - seg['start']:.2f}s)")
|
||||
else:
|
||||
print("No speech segments detected")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Extracting silence segments...")
|
||||
|
||||
silence_segments = []
|
||||
total_time = len(audio_data) / sample_rate
|
||||
|
||||
if not speech_segments:
|
||||
silence_segments = [{'start': 0.0, 'end': total_time}]
|
||||
else:
|
||||
if speech_segments[0]['start'] > 0:
|
||||
silence_segments.append({'start': 0.0, 'end': speech_segments[0]['start']})
|
||||
|
||||
for i in range(len(speech_segments) - 1):
|
||||
silence_start = speech_segments[i]['end']
|
||||
silence_end = speech_segments[i + 1]['start']
|
||||
if silence_end > silence_start:
|
||||
silence_segments.append({'start': silence_start, 'end': silence_end})
|
||||
|
||||
if speech_segments[-1]['end'] < total_time:
|
||||
silence_segments.append({'start': speech_segments[-1]['end'], 'end': total_time})
|
||||
|
||||
silence_audio = np.array([], dtype=audio_data.dtype)
|
||||
|
||||
for seg in silence_segments:
|
||||
start_sample = int(seg['start'] * sample_rate)
|
||||
end_sample = int(seg['end'] * sample_rate)
|
||||
start_sample = max(0, min(start_sample, len(audio_data)))
|
||||
end_sample = max(0, min(end_sample, len(audio_data)))
|
||||
|
||||
if end_sample > start_sample:
|
||||
silence_audio = np.concatenate([silence_audio, audio_data[start_sample:end_sample]])
|
||||
|
||||
if len(silence_audio) > 0:
|
||||
output_file = "record_silence_only.wav"
|
||||
try:
|
||||
import soundfile as sf
|
||||
sf.write(output_file, silence_audio, sample_rate)
|
||||
print(f"Silence file saved: {output_file}")
|
||||
except ImportError:
|
||||
try:
|
||||
from scipy.io import wavfile
|
||||
if silence_audio.dtype == np.float32:
|
||||
silence_audio_int16 = (silence_audio * 32767.0).astype(np.int16)
|
||||
else:
|
||||
silence_audio_int16 = silence_audio.astype(np.int16)
|
||||
wavfile.write(output_file, sample_rate, silence_audio_int16)
|
||||
print(f"Silence file saved: {output_file}")
|
||||
except ImportError:
|
||||
print("Unable to save: soundfile or scipy required")
|
||||
|
||||
total_silence_time = sum(seg['end'] - seg['start'] for seg in silence_segments)
|
||||
silence_ratio = (total_silence_time / total_time) * 100
|
||||
print(f"Total silence duration: {total_silence_time:.2f}s")
|
||||
print(f"Silence ratio: {silence_ratio:.1f}%")
|
||||
print(f"Number of silence segments: {len(silence_segments)}")
|
||||
print(f"\nYou can listen to {output_file} to verify that only silences are present.")
|
||||
else:
|
||||
print("No silence segments found (file entirely speech)")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Extracting speech segments...")
|
||||
|
||||
if speech_segments:
|
||||
speech_audio = np.array([], dtype=audio_data.dtype)
|
||||
|
||||
for seg in speech_segments:
|
||||
start_sample = int(seg['start'] * sample_rate)
|
||||
end_sample = int(seg['end'] * sample_rate)
|
||||
start_sample = max(0, min(start_sample, len(audio_data)))
|
||||
end_sample = max(0, min(end_sample, len(audio_data)))
|
||||
|
||||
if end_sample > start_sample:
|
||||
speech_audio = np.concatenate([speech_audio, audio_data[start_sample:end_sample]])
|
||||
|
||||
if len(speech_audio) > 0:
|
||||
output_file = "record_speech_only.wav"
|
||||
try:
|
||||
import soundfile as sf
|
||||
sf.write(output_file, speech_audio, sample_rate)
|
||||
print(f"Speech file saved: {output_file}")
|
||||
except ImportError:
|
||||
try:
|
||||
from scipy.io import wavfile
|
||||
if speech_audio.dtype == np.float32:
|
||||
speech_audio_int16 = (speech_audio * 32767.0).astype(np.int16)
|
||||
else:
|
||||
speech_audio_int16 = speech_audio.astype(np.int16)
|
||||
wavfile.write(output_file, sample_rate, speech_audio_int16)
|
||||
print(f"Speech file saved: {output_file}")
|
||||
except ImportError:
|
||||
print("Unable to save: soundfile or scipy required")
|
||||
|
||||
total_speech_time = sum(seg['end'] - seg['start'] for seg in speech_segments)
|
||||
speech_ratio = (total_speech_time / total_time) * 100
|
||||
print(f"Total speech duration: {total_speech_time:.2f}s")
|
||||
print(f"Speech ratio: {speech_ratio:.1f}%")
|
||||
print(f"Number of speech segments: {len(speech_segments)}")
|
||||
print(f"\nYou can listen to {output_file} to verify that only speech segments are present.")
|
||||
else:
|
||||
print("No speech audio to extract")
|
||||
else:
|
||||
print("No speech segments found (file entirely silence)")
|
||||
|
||||
|
||||
# Run the manual record.wav check when this module is executed as a script.
if __name__ == "__main__":
    test_on_record_wav()
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Any, List
|
||||
from typing import Optional, List, Union, Dict, Any
|
||||
from datetime import timedelta
|
||||
|
||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||
@@ -19,11 +19,8 @@ class TimedText(Timed):
|
||||
speaker: Optional[int] = -1
|
||||
detected_language: Optional[str] = None
|
||||
|
||||
def is_punctuation(self):
|
||||
return self.text.strip() in PUNCTUATION_MARKS
|
||||
|
||||
def overlaps_with(self, other: 'TimedText') -> bool:
|
||||
return not (self.end <= other.start or other.end <= self.start)
|
||||
def has_punctuation(self) -> bool:
|
||||
return any(char in PUNCTUATION_MARKS for char in self.text.strip())
|
||||
|
||||
def is_within(self, other: 'TimedText') -> bool:
|
||||
return other.contains_timespan(self)
|
||||
@@ -31,28 +28,26 @@ class TimedText(Timed):
|
||||
def duration(self) -> float:
|
||||
return self.end - self.start
|
||||
|
||||
def contains_time(self, time: float) -> bool:
|
||||
return self.start <= time <= self.end
|
||||
|
||||
def contains_timespan(self, other: 'TimedText') -> bool:
|
||||
return self.start <= other.start and self.end >= other.end
|
||||
|
||||
def __bool__(self):
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.text)
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.text)
|
||||
|
||||
@dataclass()
|
||||
class ASRToken(TimedText):
|
||||
|
||||
corrected_speaker: Optional[int] = -1
|
||||
validated_speaker: bool = False
|
||||
validated_text: bool = False
|
||||
validated_language: bool = False
|
||||
|
||||
def with_offset(self, offset: float) -> "ASRToken":
|
||||
"""Return a new token with the time offset added."""
|
||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language)
|
||||
|
||||
def is_silence(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class Sentence(TimedText):
|
||||
pass
|
||||
@@ -70,6 +65,7 @@ class Transcript(TimedText):
|
||||
sep: Optional[str] = None,
|
||||
offset: float = 0
|
||||
) -> "Transcript":
|
||||
"""Collapse multiple ASR tokens into a single transcript span."""
|
||||
sep = sep if sep is not None else ' '
|
||||
text = sep.join(token.text for token in tokens)
|
||||
if tokens:
|
||||
@@ -93,47 +89,70 @@ class SpeakerSegment(Timed):
|
||||
class Translation(TimedText):
|
||||
pass
|
||||
|
||||
def approximate_cut_at(self, cut_time):
    """Split this span in two at ``cut_time``, distributing words proportionally.

    Each word in ``text`` is considered to last ``(end - start) / number_of_words``.

    Args:
        cut_time: Timestamp at which to cut, on the same clock as ``start``/``end``.

    Returns:
        ``(self, None)`` when no split is possible: empty text, ``cut_time``
        outside the span, whitespace-only text, or a zero-length span.
        Otherwise ``(before, after)``, two new ``Translation`` objects
        covering ``[start, cut_time]`` and ``[cut_time, end]``.
    """
    if not self.text or not self.contains_time(cut_time):
        return self, None

    words = self.text.split()
    num_words = len(words)
    if num_words == 0:
        return self, None

    duration_per_word = self.duration() / num_words
    if duration_per_word <= 0:
        # A zero-length span (start == end == cut_time) passes contains_time()
        # but would divide by zero below; there is nothing to apportion.
        return self, None

    # Index of the first word that belongs to the second half, clamped so the
    # second half always keeps at least one word.
    cut_word_index = min(
        int((cut_time - self.start) / duration_per_word),
        num_words - 1,
    )

    text0 = " ".join(words[:cut_word_index])
    text1 = " ".join(words[cut_word_index:])

    segment0 = Translation(start=self.start, end=cut_time, text=text0)
    segment1 = Translation(start=cut_time, end=self.end, text=text1)

    return segment0, segment1
|
||||
|
||||
|
||||
@dataclass
class Silence:
    """Marker describing a silent stretch of the audio stream.

    All timing fields default to ``None`` until they are known. ``duration``
    is only filled in by :meth:`compute_duration`.
    """

    start: Optional[float] = None
    end: Optional[float] = None
    duration: Optional[float] = None
    is_starting: bool = False
    has_ended: bool = False

    def compute_duration(self) -> Optional[float]:
        """Store and return ``end - start``.

        Returns ``None`` and leaves ``duration`` untouched when either
        boundary is still unknown.
        """
        if self.start is not None and self.end is not None:
            self.duration = self.end - self.start
            return self.duration
        return None

    def is_silence(self) -> bool:
        """Always ``True``; lets silence markers be told apart from tokens."""
        return True
|
||||
|
||||
|
||||
@dataclass
class Segment(TimedText):
    """Generic contiguous span built from tokens or silence markers.

    ``speaker`` holds an integer code. ``from_tokens`` assigns ``-1`` to
    ordinary segments and ``-2`` to silence segments, and ``is_silence``
    tests for ``-2``.
    """
    start: Optional[float]
    end: Optional[float]
    text: Optional[str]
    # Integer code, not a string: only ever set to / compared with -1 and -2 here.
    speaker: Optional[int]

    @classmethod
    def from_tokens(
            cls,
            tokens: List[Union[ASRToken, Silence]],
            is_silence: bool = False
        ) -> Optional["Segment"]:
        """Return a normalized segment representing the provided tokens.

        Args:
            tokens: Items spanned by the segment, in order. The first supplies
                ``start`` and the last supplies ``end``.
            is_silence: When true, build a silence segment (``text=None``,
                ``speaker=-2``). Otherwise concatenate the token texts and
                copy ``detected_language`` from the first token.

        Returns:
            The new segment, or ``None`` when ``tokens`` is empty.
        """
        if not tokens:
            return None

        start_token = tokens[0]
        end_token = tokens[-1]
        if is_silence:
            return cls(
                start=start_token.start,
                end=end_token.end,
                text=None,
                speaker=-2,
            )
        return cls(
            start=start_token.start,
            end=end_token.end,
            text=''.join(token.text for token in tokens),
            speaker=-1,
            detected_language=start_token.detected_language,
        )

    def is_silence(self) -> bool:
        """True when this segment represents a silence gap."""
        return self.speaker == -2
|
||||
|
||||
|
||||
@dataclass
|
||||
class Line(TimedText):
|
||||
translation: str = ''
|
||||
|
||||
def to_dict(self):
|
||||
_dict = {
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serialize the line for frontend consumption."""
|
||||
_dict: Dict[str, Any] = {
|
||||
'speaker': int(self.speaker) if self.speaker != -1 else 1,
|
||||
'text': self.text,
|
||||
'start': format_time(self.start),
|
||||
@@ -145,6 +164,33 @@ class Line(TimedText):
|
||||
_dict['detected_language'] = self.detected_language
|
||||
return _dict
|
||||
|
||||
def build_from_tokens(self, tokens: List[ASRToken]) -> "Line":
|
||||
"""Populate line attributes from a contiguous token list."""
|
||||
self.text = ''.join([token.text for token in tokens])
|
||||
self.start = tokens[0].start
|
||||
self.end = tokens[-1].end
|
||||
self.speaker = 1
|
||||
self.detected_language = tokens[0].detected_language
|
||||
return self
|
||||
|
||||
def build_from_segment(self, segment: Segment) -> "Line":
|
||||
"""Populate the line fields from a pre-built segment."""
|
||||
self.text = segment.text
|
||||
self.start = segment.start
|
||||
self.end = segment.end
|
||||
self.speaker = segment.speaker
|
||||
self.detected_language = segment.detected_language
|
||||
return self
|
||||
|
||||
def is_silent(self) -> bool:
|
||||
return self.speaker == -2
|
||||
|
||||
class SilentLine(Line):
    """A ``Line`` that stands for silence: speaker code ``-2`` and empty text."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the line normally, then force the silence markers.

        Any ``speaker`` or ``text`` passed by the caller is overwritten.
        """
        super().__init__(*args, **kwargs)
        self.text = ''
        self.speaker = -2
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrontData():
|
||||
@@ -157,8 +203,9 @@ class FrontData():
|
||||
remaining_time_transcription: float = 0.
|
||||
remaining_time_diarization: float = 0.
|
||||
|
||||
def to_dict(self):
|
||||
_dict = {
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serialize the front-end data payload."""
|
||||
_dict: Dict[str, Any] = {
|
||||
'status': self.status,
|
||||
'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)],
|
||||
'buffer_transcription': self.buffer_transcription,
|
||||
@@ -178,26 +225,22 @@ class ChangeSpeaker:
|
||||
|
||||
@dataclass
|
||||
class State():
|
||||
tokens: list = field(default_factory=list)
|
||||
last_validated_token: int = 0
|
||||
last_speaker: int = 1
|
||||
last_punctuation_index: Optional[int] = None
|
||||
translation_validated_segments: list = field(default_factory=list)
|
||||
buffer_translation: str = field(default_factory=Transcript)
|
||||
buffer_transcription: str = field(default_factory=Transcript)
|
||||
diarization_segments: list = field(default_factory=list)
|
||||
"""Unified state class for audio processing.
|
||||
|
||||
Contains both persistent state (tokens, buffers) and temporary update buffers
|
||||
(new_* fields) that are consumed by TokensAlignment.
|
||||
"""
|
||||
# Persistent state
|
||||
tokens: List[ASRToken] = field(default_factory=list)
|
||||
buffer_transcription: Transcript = field(default_factory=Transcript)
|
||||
end_buffer: float = 0.0
|
||||
end_attributed_speaker: float = 0.0
|
||||
remaining_time_transcription: float = 0.0
|
||||
remaining_time_diarization: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateLight():
|
||||
new_tokens: list = field(default_factory=list)
|
||||
new_translation: list = field(default_factory=list)
|
||||
new_diarization: list = field(default_factory=list)
|
||||
new_tokens_buffer: list = field(default_factory=list) #only when local agreement
|
||||
new_tokens_index = 0
|
||||
new_translation_index = 0
|
||||
new_diarization_index = 0
|
||||
|
||||
# Temporary update buffers (consumed by TokensAlignment.update())
|
||||
new_tokens: List[Union[ASRToken, Silence]] = field(default_factory=list)
|
||||
new_translation: List[Any] = field(default_factory=list)
|
||||
new_diarization: List[Any] = field(default_factory=list)
|
||||
new_tokens_buffer: List[Any] = field(default_factory=list) # only when local agreement
|
||||
new_translation_buffer= TimedText()
|
||||
177
whisperlivekit/tokens_alignment.py
Normal file
177
whisperlivekit/tokens_alignment.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from time import time
|
||||
from typing import Optional, List, Tuple, Union, Any
|
||||
|
||||
from whisperlivekit.timed_objects import Line, SilentLine, ASRToken, SpeakerSegment, Silence, TimedText, Segment
|
||||
|
||||
|
||||
class TokensAlignment:
    """Accumulate tokens, diarization and translation updates and turn them into lines.

    ``update()`` drains the ``new_*`` buffers of the shared ``state`` object
    into the ``all_*`` histories kept here. ``get_lines()`` then rebuilds the
    display lines from those histories on every call.
    """

    def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None:
        """Bind to the shared state.

        Args:
            state: Object exposing ``new_tokens``, ``new_diarization``,
                ``new_translation``, ``new_tokens_buffer`` and
                ``new_translation_buffer``.
            args: Options object; only ``args.diarization`` is read.
            sep: Separator appended after each translated chunk; a single
                space when ``None``.
        """
        self.state = state
        self.diarization = args.diarization
        self._tokens_index: int = 0
        self._diarization_index: int = 0
        self._translation_index: int = 0

        self.all_tokens: List[ASRToken] = []
        self.all_diarization_segments: List[SpeakerSegment] = []
        self.all_translation_segments: List[Any] = []

        self.new_tokens: List[ASRToken] = []
        self.new_diarization: List[SpeakerSegment] = []
        self.new_translation: List[Any] = []
        self.new_translation_buffer: Union[TimedText, str] = TimedText()
        self.new_tokens_buffer: List[Any] = []
        self.sep: str = sep if sep is not None else ' '
        # Reference time subtracted from time() to date silences that are
        # still running; expected to be assigned by the owner of this object.
        self.beg_loop: Optional[float] = None

    def update(self) -> None:
        """Drain state buffers into the running alignment context."""
        # Swap-and-reset: take ownership of each pending list and leave a
        # fresh empty one behind on the state object.
        self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
        self.new_diarization, self.state.new_diarization = self.state.new_diarization, []
        self.new_translation, self.state.new_translation = self.state.new_translation, []
        self.new_tokens_buffer, self.state.new_tokens_buffer = self.state.new_tokens_buffer, []

        self.all_tokens.extend(self.new_tokens)
        self.all_diarization_segments.extend(self.new_diarization)
        self.all_translation_segments.extend(self.new_translation)
        # The translation buffer is read, not drained.
        self.new_translation_buffer = self.state.new_translation_buffer

    def add_translation(self, line: Line) -> None:
        """Append to ``line.translation`` every translated segment lying within the line."""
        for ts in self.all_translation_segments:
            if ts.is_within(line):
                line.translation += ts.text + (self.sep if ts.text else '')
            elif line.translation:
                # Stop at the first non-matching segment after a match.
                break

    def compute_punctuations_segments(self, tokens: Optional[List[ASRToken]] = None) -> List[Segment]:
        """Group tokens into segments split by punctuation and explicit silence.

        Args:
            tokens: Tokens to segment; defaults to ``self.all_tokens``.

        Returns:
            Segments in order. Every silence token becomes its own silence
            segment; a token containing punctuation closes the current segment.
        """
        if tokens is None:
            tokens = self.all_tokens

        segments: List[Segment] = []
        segment_start_idx = 0
        for i, token in enumerate(tokens):
            if token.is_silence():
                # Flush whatever text preceded the silence, then the silence itself.
                previous_segment = Segment.from_tokens(
                    tokens=tokens[segment_start_idx:i],
                )
                if previous_segment:
                    segments.append(previous_segment)
                segments.append(Segment.from_tokens(tokens=[token], is_silence=True))
                segment_start_idx = i + 1
            elif token.has_punctuation():
                segments.append(Segment.from_tokens(
                    tokens=tokens[segment_start_idx:i + 1],
                ))
                segment_start_idx = i + 1

        final_segment = Segment.from_tokens(tokens=tokens[segment_start_idx:])
        if final_segment:
            segments.append(final_segment)
        return segments

    def concatenate_diar_segments(self) -> List[SpeakerSegment]:
        """Merge consecutive diarization slices that share the same speaker.

        Merging is done in place: the first segment of each run (an object
        stored in ``all_diarization_segments``) has its ``end`` overwritten
        with the end of the last segment of the run.
        """
        if not self.all_diarization_segments:
            return []
        merged = [self.all_diarization_segments[0]]
        for segment in self.all_diarization_segments[1:]:
            if segment.speaker == merged[-1].speaker:
                merged[-1].end = segment.end
            else:
                merged.append(segment)
        return merged

    @staticmethod
    def intersection_duration(seg1: TimedText, seg2: TimedText) -> float:
        """Return the overlap duration between two timed segments (0 when disjoint)."""
        start = max(seg1.start, seg2.start)
        end = min(seg1.end, seg2.end)
        return max(0.0, end - start)

    def get_lines_diarization(self) -> Tuple[List[Line], str]:
        """Build lines when diarization is enabled and track overflow buffer.

        Returns:
            ``(lines, diarization_buffer)``. ``diarization_buffer`` is the
            concatenated text of segments starting at or after the end of the
            last diarization segment.
        """
        diarization_buffer = ''
        punctuation_segments = self.compute_punctuations_segments()
        diarization_segments = self.concatenate_diar_segments()

        for punctuation_segment in punctuation_segments:
            if punctuation_segment.is_silence():
                continue
            if diarization_segments and punctuation_segment.start >= diarization_segments[-1].end:
                # Diarization has not reached this segment yet.
                diarization_buffer += punctuation_segment.text
                continue
            # Attribute the segment to the speaker it overlaps the longest;
            # speaker 1 when nothing overlaps.
            max_overlap = 0.0
            max_overlap_speaker = 1
            for diarization_segment in diarization_segments:
                intersec = self.intersection_duration(punctuation_segment, diarization_segment)
                if intersec > max_overlap:
                    max_overlap = intersec
                    # Diarization speaker ids are shifted by one for output.
                    max_overlap_speaker = diarization_segment.speaker + 1
            punctuation_segment.speaker = max_overlap_speaker

        lines: List[Line] = []
        if punctuation_segments:
            lines = [Line().build_from_segment(punctuation_segments[0])]
            for segment in punctuation_segments[1:]:
                if segment.speaker == lines[-1].speaker:
                    # Same speaker (or consecutive silences): extend the current line.
                    if lines[-1].text:
                        lines[-1].text += segment.text
                    lines[-1].end = segment.end
                else:
                    lines.append(Line().build_from_segment(segment))

        return lines, diarization_buffer

    def _silence_end(self, silence: Any) -> Optional[float]:
        """Return the end time to display for a silence marker.

        A finished silence uses its recorded ``end``. A running one is dated
        from the clock relative to ``beg_loop``. If ``beg_loop`` was never
        set, fall back to the marker's own bounds instead of raising
        ``TypeError`` on ``time() - None``.
        """
        if silence.has_ended:
            return silence.end
        if self.beg_loop is None:
            return silence.end if silence.end is not None else silence.start
        return time() - self.beg_loop

    def _append_silence(self, lines: List[Line], silence: Any) -> None:
        """Extend a trailing silent line, or append a new ``SilentLine``."""
        end_silence = self._silence_end(silence)
        if lines and lines[-1].is_silent():
            lines[-1].end = end_silence
        else:
            lines.append(SilentLine(
                start=silence.start,
                end=end_silence,
            ))

    def get_lines(
            self,
            diarization: bool = False,
            translation: bool = False,
            current_silence: Optional[Silence] = None
        ) -> Tuple[List[Line], str, Union[str, TimedText]]:
        """Return the formatted lines plus buffers, optionally with diarization/translation.

        Args:
            diarization: Build lines through ``get_lines_diarization``.
            translation: Attach translated text to every non-silent line.
            current_silence: Silence in progress, rendered as a trailing silent line.

        Returns:
            ``(lines, diarization_buffer, translation_buffer_text)``.
        """
        if diarization:
            lines, diarization_buffer = self.get_lines_diarization()
        else:
            diarization_buffer = ''
            lines = []
            current_line_tokens: List[ASRToken] = []
            for token in self.all_tokens:
                if token.is_silence():
                    if current_line_tokens:
                        lines.append(Line().build_from_tokens(current_line_tokens))
                        current_line_tokens = []
                    self._append_silence(lines, token)
                else:
                    current_line_tokens.append(token)
            if current_line_tokens:
                lines.append(Line().build_from_tokens(current_line_tokens))

        # NOTE(review): indentation was lost in the reviewed copy; this is
        # applied to both the diarization and plain paths - confirm.
        if current_silence:
            self._append_silence(lines, current_silence)

        if translation:
            # Lines are Line/SilentLine objects, never Silence, so a
            # type(line) == Silence test can never exclude silent lines;
            # use the line's own predicate.
            for line in lines:
                if not line.is_silent():
                    self.add_translation(line)

        # The buffer is declared as TimedText or str; return its text either way.
        translation_buffer = self.new_translation_buffer
        return lines, diarization_buffer, getattr(translation_buffer, 'text', translation_buffer)
|
||||
@@ -1,60 +0,0 @@
|
||||
from typing import Sequence, Callable, Any, Optional, Dict
|
||||
|
||||
def _detect_tail_repetition(
    seq: Sequence[Any],
    key: Callable[[Any], Any] = lambda x: x,  # extract comparable value
    min_block: int = 1,                      # set to 2 to ignore 1-token loops like "."
    max_tail: int = 300,                     # search window from the end for speed
    prefer: str = "longest",                 # "longest" coverage or "smallest" block
) -> Optional[Dict]:
    """Find a block repeated back-to-back at the very end of ``seq``.

    Args:
        seq: Sequence to inspect.
        key: Maps each item to the value used for comparison.
        min_block: Smallest block length considered; values below 1 are
            treated as 1.
        max_tail: Only the last ``max_tail`` items are searched.
        prefer: ``"longest"`` keeps the candidate covering the most items,
            ``"smallest"`` the one with the shortest block.

    Returns:
        ``None`` when no block occurs at least twice, otherwise a dict with
        ``block_size``, ``count``, ``start_index`` and ``end_index`` (indices
        into the original ``seq``).
    """
    vals = [key(x) for x in seq][-max_tail:]
    n = len(vals)
    best = None

    # A block length <= 0 never advances the scan below and can loop forever
    # (e.g. empty input with min_block=0), so start at 1.
    first_block = max(1, min_block)

    # try every possible block length
    for b in range(first_block, n // 2 + 1):
        block = vals[-b:]
        # count how many times this block repeats contiguously at the very end
        count, i = 0, n
        while i - b >= 0 and vals[i - b:i] == block:
            count += 1
            i -= b

        if count < 2:
            continue

        cand = {
            "block_size": b,
            "count": count,
            "start_index": len(seq) - count * b,  # in original seq
            "end_index": len(seq),
        }
        if (best is None or
                (prefer == "longest" and count * b > best["count"] * best["block_size"]) or
                (prefer == "smallest" and b < best["block_size"])):
            best = cand
    return best


def trim_tail_repetition(
    seq: Sequence[Any],
    key: Callable[[Any], Any] = lambda x: x,
    min_block: int = 1,
    max_tail: int = 300,
    prefer: str = "longest",
    keep: int = 1,  # how many copies of the repeating block to keep at the end (0 or 1 are common)
):
    """Trim a block that repeats at the end of ``seq``.

    keep=1 -> keep a single copy of the repeated block.
    keep=0 -> remove all copies of the repeated block.
    Negative ``keep`` is treated as 0.

    Returns:
        ``(sequence, trimmed)``. When something was trimmed, ``sequence`` is
        the slice ``seq[:new_len]`` and ``trimmed`` is ``True``. Otherwise the
        original ``seq`` object is returned unchanged with ``False``.
    """
    rep = _detect_tail_repetition(seq, key, min_block, max_tail, prefer)
    if not rep:
        return seq, False  # nothing to trim

    b, c = rep["block_size"], rep["count"]
    keep = max(keep, 0)
    if keep >= c:
        return seq, False  # nothing to trim (already <= keep copies)
    # new length = total - (copies_to_remove * block_size)
    new_len = len(seq) - (c - keep) * b
    return seq[:new_len], True
|
||||
Reference in New Issue
Block a user