8 Commits

Author SHA1 Message Date
Quentin Fuxa
9d4ae33249 WIP. Trying ten VAD #280 2025-11-23 11:20:00 +01:00
Quentin Fuxa
6206fff118 0.2.15 2025-11-21 23:52:00 +01:00
Quentin Fuxa
b5067249c0 stt/diar/nllw alignment: internal rework 5 2025-11-20 23:52:00 +01:00
Quentin Fuxa
f4f9831d39 stt/diar/nllw alignment: internal rework 5 2025-11-20 23:52:00 +01:00
Quentin Fuxa
254faaf64c stt/diar/nllw alignment: internal rework 5 2025-11-20 23:52:00 +01:00
Quentin Fuxa
8e7aea4fcf internal rework 4 2025-11-20 23:45:20 +01:00
Quentin Fuxa
270faf2069 internal rework 3 2025-11-20 22:28:30 +01:00
Quentin Fuxa
b7c1cc77cc internal rework 2 2025-11-20 22:06:38 +01:00
16 changed files with 720 additions and 881 deletions

View File

@@ -141,7 +141,7 @@ async def websocket_endpoint(websocket: WebSocket):
|-----------|-------------|---------|
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/available_models.md) | `small` |
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/models_compatible_formats.md) | `None` |
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
| `--diarization` | Enable speaker identification | `False` |
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |

View File

@@ -4,7 +4,7 @@
- Example 2: The punctuation from STT comes from prediction `t`, but the speaker change from Diariation come in the prediction `t-1`
- Example 3: The punctuation from STT comes from prediction `t-1`, but the speaker change from Diariation come in the prediction `t`
> `#` Is the split between the `t-1` prediction and t prediction.
> `#` Is the split between the `t-1` prediction and `t` prediction.
## Example 1:

View File

@@ -0,0 +1,43 @@
# Technical Integration Guide
This document introduce how to reuse the core components when you do **not** want to ship the bundled frontend, FastAPI server, or even the provided CLI.
---
## 1. Runtime Components
| Layer | File(s) | Purpose |
|-------|---------|---------|
| Transport | `whisperlivekit/basic_server.py`, any ASGI/WebSocket server | Accepts audio over WebSocket (MediaRecorder WebM or raw PCM chunks) and streams JSON updates back |
| Audio processing | `whisperlivekit/audio_processor.py` | Buffers audio, orchestrates transcription, diarization, translation, handles FFmpeg/PCM input |
| Engines | `whisperlivekit/core.py`, `whisperlivekit/simul_whisper/*`, `whisperlivekit/local_agreement/*` | Load models once (SimulStreaming or LocalAgreement), expose `TranscriptionEngine` and helpers |
| Frontends | `whisperlivekit/web/*`, `chrome-extension/*` | Optional UI layers feeding the WebSocket endpoint |
**Key idea:** The server boundary is just `AudioProcessor.process_audio()` for incoming bytes and the async generator returned by `AudioProcessor.create_tasks()` for outgoing updates (`FrontData`). Everything else is optional.
---
## 2. Running Without the Bundled Frontend
1. Start the server/engine however you like:
```bash
wlk --model small --language en --host 0.0.0.0 --port 9000
# or launch your own app that instantiates TranscriptionEngine(...)
```
2. Build your own client (browser, mobile, desktop) that:
- Opens `ws(s)://<host>:<port>/asr`
- Sends either MediaRecorder/Opus WebM blobs **or** raw PCM (`--pcm-input` on the server tells the client to use the AudioWorklet).
- Consumes the JSON payload defined in `docs/API.md`.
---
## 3. Running Without FastAPI
`whisperlivekit/basic_server.py` is just an example. Any async framework works, as long as you:
1. Create a global `TranscriptionEngine` (expensive to initialize; reuse it).
2. Instantiate `AudioProcessor(transcription_engine=engine)` for each connection.
3. Call `create_tasks()` to get the async generator, `process_audio()` with incoming bytes, and ensure `cleanup()` runs when the client disconnects.
If you prefer to send compressed audio, instantiate `AudioProcessor(pcm_input=False)` and pipe encoded chunks through `FFmpegManager` transparently—just ensure `ffmpeg` is available or be ready to handle the `"ffmpeg_not_found"` error in the streamed `FrontData`.

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "whisperlivekit"
version = "0.2.14.post4"
version = "0.2.15"
description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md"
authors = [

View File

@@ -1,168 +0,0 @@
class TokensAlignment:
def __init__(self, state_light, silence=None, args=None):
self.state_light = state_light
self.silence = silence
self.args = args
self._tokens_index = 0
self._diarization_index = 0
self._translation_index = 0
def update(self):
pass
def compute_punctuations_segments(self):
punctuations_breaks = []
new_tokens = self.state.tokens[self.state.last_validated_token:]
for i in range(len(new_tokens)):
token = new_tokens[i]
if token.is_punctuation():
punctuations_breaks.append({
'token_index': i,
'token': token,
'start': token.start,
'end': token.end,
})
punctuations_segments = []
for i, break_info in enumerate(punctuations_breaks):
start = punctuations_breaks[i - 1]['end'] if i > 0 else 0.0
end = break_info['end']
punctuations_segments.append({
'start': start,
'end': end,
'token_index': break_info['token_index'],
'token': break_info['token']
})
return punctuations_segments
def concatenate_diar_segments(self):
diarization_segments = self.state.diarization_segments
if __name__ == "__main__":
from whisperlivekit.timed_objects import State, ASRToken, SpeakerSegment, Transcript, Silence
# Reconstruct the state from the backup data
tokens = [
ASRToken(start=1.38, end=1.48, text=' The'),
ASRToken(start=1.42, end=1.52, text=' description'),
ASRToken(start=1.82, end=1.92, text=' technology'),
ASRToken(start=2.54, end=2.64, text=' has'),
ASRToken(start=2.7, end=2.8, text=' improved'),
ASRToken(start=3.24, end=3.34, text=' so'),
ASRToken(start=3.66, end=3.76, text=' much'),
ASRToken(start=4.02, end=4.12, text=' in'),
ASRToken(start=4.08, end=4.18, text=' the'),
ASRToken(start=4.26, end=4.36, text=' past'),
ASRToken(start=4.48, end=4.58, text=' few'),
ASRToken(start=4.76, end=4.86, text=' years'),
ASRToken(start=5.76, end=5.86, text='.'),
ASRToken(start=5.72, end=5.82, text=' Have'),
ASRToken(start=5.92, end=6.02, text=' you'),
ASRToken(start=6.08, end=6.18, text=' noticed'),
ASRToken(start=6.52, end=6.62, text=' how'),
ASRToken(start=6.8, end=6.9, text=' accurate'),
ASRToken(start=7.46, end=7.56, text=' real'),
ASRToken(start=7.72, end=7.82, text='-time'),
ASRToken(start=8.06, end=8.16, text=' speech'),
ASRToken(start=8.48, end=8.58, text=' to'),
ASRToken(start=8.68, end=8.78, text=' text'),
ASRToken(start=9.0, end=9.1, text=' is'),
ASRToken(start=9.24, end=9.34, text=' now'),
ASRToken(start=9.82, end=9.92, text='?'),
ASRToken(start=9.86, end=9.96, text=' Absolutely'),
ASRToken(start=11.26, end=11.36, text='.'),
ASRToken(start=11.36, end=11.46, text=' I'),
ASRToken(start=11.58, end=11.68, text=' use'),
ASRToken(start=11.78, end=11.88, text=' it'),
ASRToken(start=11.94, end=12.04, text=' all'),
ASRToken(start=12.08, end=12.18, text=' the'),
ASRToken(start=12.32, end=12.42, text=' time'),
ASRToken(start=12.58, end=12.68, text=' for'),
ASRToken(start=12.78, end=12.88, text=' taking'),
ASRToken(start=13.14, end=13.24, text=' notes'),
ASRToken(start=13.4, end=13.5, text=' during'),
ASRToken(start=13.78, end=13.88, text=' meetings'),
ASRToken(start=14.6, end=14.7, text='.'),
ASRToken(start=14.82, end=14.92, text=' It'),
ASRToken(start=14.92, end=15.02, text="'s"),
ASRToken(start=15.04, end=15.14, text=' amazing'),
ASRToken(start=15.5, end=15.6, text=' how'),
ASRToken(start=15.66, end=15.76, text=' it'),
ASRToken(start=15.8, end=15.9, text=' can'),
ASRToken(start=15.96, end=16.06, text=' recognize'),
ASRToken(start=16.58, end=16.68, text=' different'),
ASRToken(start=16.94, end=17.04, text=' speakers'),
ASRToken(start=17.82, end=17.92, text=' and'),
ASRToken(start=18.0, end=18.1, text=' even'),
ASRToken(start=18.42, end=18.52, text=' add'),
ASRToken(start=18.74, end=18.84, text=' punct'),
ASRToken(start=19.02, end=19.12, text='uation'),
ASRToken(start=19.68, end=19.78, text='.'),
ASRToken(start=20.04, end=20.14, text=' Yeah'),
ASRToken(start=20.5, end=20.6, text=','),
ASRToken(start=20.6, end=20.7, text=' but'),
ASRToken(start=20.76, end=20.86, text=' sometimes'),
ASRToken(start=21.42, end=21.52, text=' noise'),
ASRToken(start=21.82, end=21.92, text=' can'),
ASRToken(start=22.08, end=22.18, text=' still'),
ASRToken(start=22.38, end=22.48, text=' cause'),
ASRToken(start=22.72, end=22.82, text=' mistakes'),
ASRToken(start=23.74, end=23.84, text='.'),
ASRToken(start=23.96, end=24.06, text=' Does'),
ASRToken(start=24.16, end=24.26, text=' this'),
ASRToken(start=24.4, end=24.5, text=' system'),
ASRToken(start=24.76, end=24.86, text=' handle'),
ASRToken(start=25.12, end=25.22, text=' that'),
ASRToken(start=25.38, end=25.48, text=' well'),
ASRToken(start=25.68, end=25.78, text='?'),
ASRToken(start=26.4, end=26.5, text=' It'),
ASRToken(start=26.5, end=26.6, text=' does'),
ASRToken(start=26.7, end=26.8, text=' a'),
ASRToken(start=27.08, end=27.18, text=' pretty'),
ASRToken(start=27.12, end=27.22, text=' good'),
ASRToken(start=27.34, end=27.44, text=' job'),
ASRToken(start=27.64, end=27.74, text=' filtering'),
ASRToken(start=28.1, end=28.2, text=' noise'),
ASRToken(start=28.64, end=28.74, text=','),
ASRToken(start=28.78, end=28.88, text=' especially'),
ASRToken(start=29.3, end=29.4, text=' with'),
ASRToken(start=29.51, end=29.61, text=' models'),
ASRToken(start=29.99, end=30.09, text=' that'),
ASRToken(start=30.21, end=30.31, text=' use'),
ASRToken(start=30.51, end=30.61, text=' voice'),
ASRToken(start=30.83, end=30.93, text=' activity'),
]
diarization_segments = [
SpeakerSegment(start=1.3255040645599365, end=4.3255040645599365, speaker=0),
SpeakerSegment(start=4.806154012680054, end=9.806154012680054, speaker=0),
SpeakerSegment(start=9.806154012680054, end=10.806154012680054, speaker=1),
SpeakerSegment(start=11.168735027313232, end=14.168735027313232, speaker=1),
SpeakerSegment(start=14.41029405593872, end=17.41029405593872, speaker=1),
SpeakerSegment(start=17.52983808517456, end=19.52983808517456, speaker=1),
SpeakerSegment(start=19.64953374862671, end=20.066200415293377, speaker=1),
SpeakerSegment(start=20.066200415293377, end=22.64953374862671, speaker=2),
SpeakerSegment(start=23.012792587280273, end=25.012792587280273, speaker=2),
SpeakerSegment(start=25.495875597000122, end=26.41254226366679, speaker=2),
SpeakerSegment(start=26.41254226366679, end=30.495875597000122, speaker=0),
]
state = State(
tokens=tokens,
last_validated_token=72,
last_speaker=-1,
last_punctuation_index=71,
translation_validated_segments=[],
buffer_translation=Transcript(start=0, end=0, speaker=-1),
buffer_transcription=Transcript(start=None, end=None, speaker=-1),
diarization_segments=diarization_segments,
end_buffer=31.21587559700018,
end_attributed_speaker=30.495875597000122,
remaining_time_transcription=0.4,
remaining_time_diarization=0.7,
beg_loop=1763627603.968919
)
alignment = TokensAlignment(state)

View File

@@ -1,36 +1,23 @@
import asyncio
import numpy as np
from time import time, sleep
import math
from time import time
import logging
import traceback
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, StateLight, Transcript, ChangeSpeaker
from typing import Optional, Union, List, Any, AsyncGenerator
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript, ChangeSpeaker
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
from whisperlivekit.silero_vad_iterator import FixedVADIterator
from whisperlivekit.results_formater import format_output
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from whisperlivekit.TokensAlignment import TokensAlignment
from whisperlivekit.tokens_alignment import TokensAlignment
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
SENTINEL = object() # unique sentinel object for end of stream marker
MIN_DURATION_REAL_SILENCE = 5
def cut_at(cumulative_pcm, cut_sec):
cumulative_len = 0
cut_sample = int(cut_sec * 16000)
for ind, pcm_array in enumerate(cumulative_pcm):
if (cumulative_len + len(pcm_array)) >= cut_sample:
cut_chunk = cut_sample - cumulative_len
before = np.concatenate(cumulative_pcm[:ind] + [cumulative_pcm[ind][:cut_chunk]])
after = [cumulative_pcm[ind][cut_chunk:]] + cumulative_pcm[ind+1:]
return before, after
cumulative_len += len(pcm_array)
return np.concatenate(cumulative_pcm), []
async def get_all_from_queue(queue):
items = []
async def get_all_from_queue(queue: asyncio.Queue) -> Union[object, Silence, np.ndarray, List[Any]]:
items: List[Any] = []
first_item = await queue.get()
queue.task_done()
@@ -61,7 +48,7 @@ class AudioProcessor:
Handles audio processing, state management, and result formatting.
"""
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
"""Initialize the audio processor with configuration, models, and state."""
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
@@ -80,33 +67,27 @@ class AudioProcessor:
self.is_pcm_input = self.args.pcm_input
# State management
self.is_stopping = False
self.silence = True
self.silence_duration = 0.0
self.start_silence = None
self.last_silence_dispatch_time = None
self.state = State()
self.state_light = StateLight()
self.lock = asyncio.Lock()
self.sep = " " # Default separator
self.last_response_content = FrontData()
self.last_detected_speaker = None
self.speaker_languages = {}
self.is_stopping: bool = False
self.current_silence: Optional[Silence] = None
self.state: State = State()
self.lock: asyncio.Lock = asyncio.Lock()
self.sep: str = " " # Default separator
self.last_response_content: FrontData = FrontData()
self.tokens_alignment = TokensAlignment(self.state_light, self.args, self.sep)
self.beg_loop = None
self.tokens_alignment: TokensAlignment = TokensAlignment(self.state, self.args, self.sep)
self.beg_loop: Optional[float] = None
# Models and processing
self.asr = models.asr
self.vac_model = models.vac_model
self.asr: Any = models.asr
self.vac_model: Any = models.vac_model
if self.args.vac:
self.vac = FixedVADIterator(models.vac_model)
self.vac: Optional[FixedVADIterator] = FixedVADIterator(models.vac_model)
else:
self.vac = None
self.vac: Optional[FixedVADIterator] = None
self.ffmpeg_manager = None
self.ffmpeg_reader_task = None
self._ffmpeg_error = None
self.ffmpeg_manager: Optional[FFmpegManager] = None
self.ffmpeg_reader_task: Optional[asyncio.Task] = None
self._ffmpeg_error: Optional[str] = None
if not self.is_pcm_input:
self.ffmpeg_manager = FFmpegManager(
@@ -118,21 +99,20 @@ class AudioProcessor:
self._ffmpeg_error = error_type
self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
self.transcription_queue = asyncio.Queue() if self.args.transcription else None
self.diarization_queue = asyncio.Queue() if self.args.diarization else None
self.translation_queue = asyncio.Queue() if self.args.target_language else None
self.pcm_buffer = bytearray()
self.total_pcm_samples = 0
self.end_buffer = 0.0
self.transcription_task = None
self.diarization_task = None
self.translation_task = None
self.watchdog_task = None
self.all_tasks_for_cleanup = []
self.transcription_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.transcription else None
self.diarization_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.diarization else None
self.translation_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.target_language else None
self.pcm_buffer: bytearray = bytearray()
self.total_pcm_samples: int = 0
self.transcription_task: Optional[asyncio.Task] = None
self.diarization_task: Optional[asyncio.Task] = None
self.translation_task: Optional[asyncio.Task] = None
self.watchdog_task: Optional[asyncio.Task] = None
self.all_tasks_for_cleanup: List[asyncio.Task] = []
self.transcription = None
self.translation = None
self.diarization = None
self.transcription: Optional[Any] = None
self.translation: Optional[Any] = None
self.diarization: Optional[Any] = None
if self.args.transcription:
self.transcription = online_factory(self.args, models.asr)
@@ -142,44 +122,45 @@ class AudioProcessor:
if models.translation_model:
self.translation = online_translation_factory(self.args, models.translation_model)
async def _push_silence_event(self, silence_buffer: Silence):
async def _push_silence_event(self) -> None:
if self.transcription_queue:
await self.transcription_queue.put(silence_buffer)
await self.transcription_queue.put(self.current_silence)
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(silence_buffer)
await self.diarization_queue.put(self.current_silence)
if self.translation_queue:
await self.translation_queue.put(silence_buffer)
await self.translation_queue.put(self.current_silence)
async def _begin_silence(self):
if self.silence:
async def _begin_silence(self) -> None:
if self.current_silence:
return
self.silence = True
now = time()
self.start_silence = now
self.last_silence_dispatch_time = now
await self._push_silence_event(Silence(is_starting=True))
now = time() - self.beg_loop
self.current_silence = Silence(
is_starting=True, start=now
)
await self._push_silence_event()
async def _end_silence(self):
if not self.silence:
async def _end_silence(self) -> None:
if not self.current_silence:
return
now = time()
duration = now - (self.last_silence_dispatch_time if self.last_silence_dispatch_time else self.beg_loop)
await self._push_silence_event(Silence(duration=duration, has_ended=True))
self.last_silence_dispatch_time = now
self.silence = False
self.start_silence = None
self.last_silence_dispatch_time = None
now = time() - self.beg_loop
self.current_silence.end = now
self.current_silence.is_starting=False
self.current_silence.has_ended=True
self.current_silence.compute_duration()
if self.current_silence.duration > MIN_DURATION_REAL_SILENCE:
self.state.new_tokens.append(self.current_silence)
await self._push_silence_event()
self.current_silence = None
async def _enqueue_active_audio(self, pcm_chunk: np.ndarray):
async def _enqueue_active_audio(self, pcm_chunk: np.ndarray) -> None:
if pcm_chunk is None or pcm_chunk.size == 0:
return
if self.transcription_queue:
await self.transcription_queue.put(pcm_chunk.copy())
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(pcm_chunk.copy())
self.silence_duration = 0.0
def _slice_before_silence(self, pcm_array, chunk_sample_start, silence_sample):
def _slice_before_silence(self, pcm_array: np.ndarray, chunk_sample_start: int, silence_sample: Optional[int]) -> Optional[np.ndarray]:
if silence_sample is None:
return None
relative_index = int(silence_sample - chunk_sample_start)
@@ -190,22 +171,22 @@ class AudioProcessor:
return None
return pcm_array[:split_index]
def convert_pcm_to_float(self, pcm_buffer):
def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray:
"""Convert PCM buffer in s16le format to normalized NumPy array."""
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
async def get_current_state(self):
async def get_current_state(self) -> State:
"""Get current state."""
async with self.lock:
current_time = time()
remaining_transcription = 0
if self.end_buffer > 0:
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 1))
if self.state.end_buffer > 0:
remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1))
remaining_diarization = 0
if self.state.tokens:
latest_end = max(self.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1))
self.state.remaining_time_transcription = remaining_transcription
@@ -213,7 +194,7 @@ class AudioProcessor:
return self.state
async def ffmpeg_stdout_reader(self):
async def ffmpeg_stdout_reader(self) -> None:
"""Read audio data from FFmpeg stdout and process it into the PCM pipeline."""
beg = time()
while True:
@@ -263,7 +244,7 @@ class AudioProcessor:
if self.translation:
await self.translation_queue.put(SENTINEL)
async def transcription_processor(self):
async def transcription_processor(self) -> None:
"""Process audio chunks for transcription."""
cumulative_pcm_duration_stream_time = 0.0
@@ -276,11 +257,11 @@ class AudioProcessor:
break
asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
transcription_lag_s = max(0.0, time() - self.beg_loop - self.state.end_buffer)
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
new_tokens = []
current_audio_processed_upto = self.end_buffer
current_audio_processed_upto = self.state.end_buffer
if isinstance(item, Silence):
if item.is_starting:
@@ -318,7 +299,7 @@ class AudioProcessor:
if buffer_text.startswith(validated_text):
_buffer_transcript.text = buffer_text[len(validated_text):].lstrip()
candidate_end_times = [self.end_buffer]
candidate_end_times = [self.state.end_buffer]
if new_tokens:
candidate_end_times.append(new_tokens[-1].end)
@@ -331,10 +312,9 @@ class AudioProcessor:
async with self.lock:
self.state.tokens.extend(new_tokens)
self.state.buffer_transcription = _buffer_transcript
self.end_buffer = max(candidate_end_times)
self.state_light.new_tokens = new_tokens
self.state_light.new_tokens += 1
self.state_light.new_tokens_buffer = _buffer_transcript
self.state.end_buffer = max(candidate_end_times)
self.state.new_tokens.extend(new_tokens)
self.state.new_tokens_buffer = _buffer_transcript
if self.translation_queue:
for token in new_tokens:
@@ -355,7 +335,7 @@ class AudioProcessor:
logger.info("Transcription processor task finished.")
async def diarization_processor(self):
async def diarization_processor(self) -> None:
while True:
try:
item = await get_all_from_queue(self.diarization_queue)
@@ -368,41 +348,44 @@ class AudioProcessor:
self.diarization.insert_audio_chunk(item)
diarization_segments = await self.diarization.diarize()
self.state_light.new_diarization = diarization_segments
self.state_light.new_diarization_index += 1
self.state.new_diarization = diarization_segments
except Exception as e:
logger.warning(f"Exception in diarization_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
logger.info("Diarization processor task finished.")
async def translation_processor(self):
async def translation_processor(self) -> None:
# the idea is to ignore diarization for the moment. We use only transcription tokens.
# And the speaker is attributed given the segments used for the translation
# in the future we want to have different languages for each speaker etc, so it will be more complex.
while True:
try:
tokens_to_process = await get_all_from_queue(self.translation_queue)
if tokens_to_process is SENTINEL:
item = await get_all_from_queue(self.translation_queue)
if item is SENTINEL:
logger.debug("Translation processor received sentinel. Finishing.")
self.translation_queue.task_done()
break
elif type(tokens_to_process) is Silence:
if tokens_to_process.has_ended:
self.translation.insert_silence(tokens_to_process.duration)
continue
if tokens_to_process:
self.translation.insert_tokens(tokens_to_process)
translation_validated_segments, buffer_translation = await asyncio.to_thread(self.translation.process)
async with self.lock:
self.state.translation_validated_segments = translation_validated_segments
self.state.buffer_translation = buffer_translation
elif type(item) is Silence:
if item.is_starting:
new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
if item.has_ended:
self.translation.insert_silence(item.duration)
continue
elif isinstance(item, ChangeSpeaker):
new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
pass
else:
self.translation.insert_tokens(item)
new_translation, new_translation_buffer = await asyncio.to_thread(self.translation.process)
async with self.lock:
self.state.new_translation.append(new_translation)
self.state.new_translation_buffer = new_translation_buffer
except Exception as e:
logger.warning(f"Exception in translation_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
logger.info("Translation processor task finished.")
async def results_formatter(self):
async def results_formatter(self) -> AsyncGenerator[FrontData, None]:
"""Format processing results for output."""
while True:
try:
@@ -412,55 +395,32 @@ class AudioProcessor:
await asyncio.sleep(1)
continue
state = await self.get_current_state()
self.tokens_alignment.compute_punctuations_segments()
lines, undiarized_text = format_output(
state,
self.silence,
args = self.args,
sep=self.sep
self.tokens_alignment.update()
lines, buffer_diarization_text, buffer_translation_text = self.tokens_alignment.get_lines(
diarization=self.args.diarization,
translation=bool(self.translation),
current_silence=self.current_silence
)
if lines and lines[-1].speaker == -2:
buffer_transcription = Transcript()
else:
buffer_transcription = state.buffer_transcription
state = await self.get_current_state()
buffer_diarization = ''
if undiarized_text:
buffer_diarization = self.sep.join(undiarized_text)
buffer_transcription_text = state.buffer_transcription.text if state.buffer_transcription else ''
async with self.lock:
self.state.end_attributed_speaker = state.end_attributed_speaker
buffer_translation_text = ''
if state.buffer_translation:
raw_buffer_translation = getattr(state.buffer_translation, 'text', state.buffer_translation)
if raw_buffer_translation:
buffer_translation_text = raw_buffer_translation.strip()
response_status = "active_transcription"
if not state.tokens and not buffer_transcription and not buffer_diarization:
if not lines and not buffer_transcription_text and not buffer_diarization_text:
response_status = "no_audio_detected"
lines = []
elif not lines:
lines = [Line(
speaker=1,
start=state.end_buffer,
end=state.end_buffer
)]
response = FrontData(
status=response_status,
lines=lines,
buffer_transcription=buffer_transcription.text.strip(),
buffer_diarization=buffer_diarization,
buffer_transcription=buffer_transcription_text,
buffer_diarization=buffer_diarization_text,
buffer_translation=buffer_translation_text,
remaining_time_transcription=state.remaining_time_transcription,
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
)
should_push = (response != self.last_response_content)
if should_push and (lines or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"):
if should_push:
yield response
self.last_response_content = response
@@ -474,17 +434,17 @@ class AudioProcessor:
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
await asyncio.sleep(0.5)
async def create_tasks(self):
async def create_tasks(self) -> AsyncGenerator[FrontData, None]:
"""Create and start processing tasks."""
self.all_tasks_for_cleanup = []
processing_tasks_for_watchdog = []
processing_tasks_for_watchdog: List[asyncio.Task] = []
# If using FFmpeg (non-PCM input), start it and spawn stdout reader
if not self.is_pcm_input:
success = await self.ffmpeg_manager.start()
if not success:
logger.error("Failed to start FFmpeg manager")
async def error_generator():
async def error_generator() -> AsyncGenerator[FrontData, None]:
yield FrontData(
status="error",
error="FFmpeg failed to start. Please check that FFmpeg is installed."
@@ -515,9 +475,9 @@ class AudioProcessor:
return self.results_formatter()
async def watchdog(self, tasks_to_monitor):
async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None:
"""Monitors the health of critical processing tasks."""
tasks_remaining = [task for task in tasks_to_monitor if task]
tasks_remaining: List[asyncio.Task] = [task for task in tasks_to_monitor if task]
while True:
try:
if not tasks_remaining:
@@ -542,7 +502,7 @@ class AudioProcessor:
except Exception as e:
logger.error(f"Error in watchdog task: {e}", exc_info=True)
async def cleanup(self):
async def cleanup(self) -> None:
"""Clean up resources when processing is complete."""
logger.info("Starting cleanup of AudioProcessor resources.")
self.is_stopping = True
@@ -565,7 +525,7 @@ class AudioProcessor:
self.diarization.close()
logger.info("AudioProcessor cleanup complete.")
def _processing_tasks_done(self):
def _processing_tasks_done(self) -> bool:
"""Return True when all active processing tasks have completed."""
tasks_to_check = [
self.transcription_task,
@@ -576,11 +536,13 @@ class AudioProcessor:
return all(task.done() for task in tasks_to_check if task)
async def process_audio(self, message):
async def process_audio(self, message: Optional[bytes]) -> None:
"""Process incoming audio data."""
if not self.beg_loop:
self.beg_loop = time()
self.current_silence = Silence(start=0.0, is_starting=True)
self.tokens_alignment.beg_loop = self.beg_loop
if not message:
logger.info("Empty audio message received, initiating stop sequence.")
@@ -613,7 +575,7 @@ class AudioProcessor:
else:
logger.warning("Failed to write audio data to FFmpeg")
async def handle_pcm_data(self):
async def handle_pcm_data(self) -> None:
# Process when enough data
if len(self.pcm_buffer) < self.bytes_per_sec:
return
@@ -642,17 +604,17 @@ class AudioProcessor:
if res is not None:
silence_detected = res.get("end", 0) > res.get("start", 0)
if silence_detected and not self.silence:
if silence_detected and not self.current_silence:
pre_silence_chunk = self._slice_before_silence(
pcm_array, chunk_sample_start, res.get("end")
)
if pre_silence_chunk is not None and pre_silence_chunk.size > 0:
await self._enqueue_active_audio(pre_silence_chunk)
await self._begin_silence()
elif self.silence:
elif self.current_silence:
await self._end_silence()
if not self.silence:
if not self.current_silence:
await self._enqueue_active_audio(pcm_array)
self.total_pcm_samples = chunk_sample_end

View File

@@ -224,7 +224,8 @@ class MLXWhisper(ASRBase):
if segment.get("no_speech_prob", 0) > 0.9:
continue
for word in segment.get("words", []):
token = ASRToken(word["start"], word["end"], word["word"], probability=word["probability"])
probability=word["probability"]
token = ASRToken(word["start"], word["end"], word["word"])
tokens.append(token)
return tokens

View File

@@ -411,11 +411,11 @@ class OnlineASRProcessor:
) -> Transcript:
sep = sep if sep is not None else self.asr.sep
text = sep.join(token.text for token in tokens)
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
# probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
if tokens:
start = offset + tokens[0].start
end = offset + tokens[-1].end
else:
start = None
end = None
return Transcript(start, end, text, probability=probability)
return Transcript(start, end, text)

View File

@@ -1,103 +0,0 @@
from whisperlivekit.timed_objects import ASRToken
from time import time
import re
MIN_SILENCE_DURATION = 4 #in seconds
END_SILENCE_DURATION = 8 #in seconds. you should keep it important to not have false positive when the model lag is important
END_SILENCE_DURATION_VAC = 3 #VAC is good at detecting silences, but we want to skip the smallest silences
def blank_to_silence(tokens):
full_string = ''.join([t.text for t in tokens])
patterns = [re.compile(r'(?:\s*\[BLANK_AUDIO\]\s*)+'), re.compile(r'(?:\s*\[typing\]\s*)+')]
matches = []
for pattern in patterns:
for m in pattern.finditer(full_string):
matches.append({
'start': m.start(),
'end': m.end()
})
if matches:
# cleaned = pattern.sub(' ', full_string).strip()
# print("Cleaned:", cleaned)
cumulated_len = 0
silence_token = None
cleaned_tokens = []
for token in tokens:
if matches:
start = cumulated_len
end = cumulated_len + len(token.text)
cumulated_len = end
if start >= matches[0]['start'] and end <= matches[0]['end']:
if silence_token: #previous token was already silence
silence_token.start = min(silence_token.start, token.start)
silence_token.end = max(silence_token.end, token.end)
else: #new silence
silence_token = ASRToken(
start=token.start,
end=token.end,
speaker=-2,
)
else:
if silence_token: #there was silence but no more
if silence_token.duration() >= MIN_SILENCE_DURATION:
cleaned_tokens.append(
silence_token
)
silence_token = None
matches.pop(0)
cleaned_tokens.append(token)
# print(cleaned_tokens)
return cleaned_tokens
return tokens
def no_token_to_silence(tokens):
new_tokens = []
silence_token = None
for token in tokens:
if token.speaker == -2:
if new_tokens and new_tokens[-1].speaker == -2: #if token is silence and previous one too
new_tokens[-1].end = token.end
else:
new_tokens.append(token)
last_end = new_tokens[-1].end if new_tokens else 0.0
if token.start - last_end >= MIN_SILENCE_DURATION: #if token is not silence but important gap
if new_tokens and new_tokens[-1].speaker == -2:
new_tokens[-1].end = token.start
else:
silence_token = ASRToken(
start=last_end,
end=token.start,
speaker=-2,
)
new_tokens.append(silence_token)
if token.speaker != -2:
new_tokens.append(token)
return new_tokens
def ends_with_silence(tokens, beg_loop, vac_detected_silence):
current_time = time() - (beg_loop if beg_loop else 0.0)
last_token = tokens[-1]
if vac_detected_silence or (current_time - last_token.end >= END_SILENCE_DURATION):
if last_token.speaker == -2:
last_token.end = current_time
else:
tokens.append(
ASRToken(
start=tokens[-1].end,
end=current_time,
speaker=-2,
)
)
return tokens
def handle_silences(tokens, beg_loop, vac_detected_silence):
if not tokens:
return []
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
tokens = no_token_to_silence(tokens)
tokens = ends_with_silence(tokens, beg_loop, vac_detected_silence)
return tokens

View File

@@ -1,60 +0,0 @@
########### WHAT IS PRODUCED: ###########
SPEAKER 1 0:00:04 - 0:00:06
Transcription technology has improved so much in the past
SPEAKER 1 0:00:07 - 0:00:12
years. Have you noticed how accurate real-time speech detects is now?
SPEAKER 2 0:00:12 - 0:00:12
Absolutely
SPEAKER 1 0:00:13 - 0:00:13
.
SPEAKER 2 0:00:14 - 0:00:14
I
SPEAKER 1 0:00:14 - 0:00:17
use it all the time for taking notes during meetings.
SPEAKER 2 0:00:17 - 0:00:17
It
SPEAKER 1 0:00:17 - 0:00:22
's amazing how it can recognize different speakers, and even add punctuation.
SPEAKER 2 0:00:22 - 0:00:22
Yeah
SPEAKER 1 0:00:23 - 0:00:26
, but sometimes noise can still cause mistakes.
SPEAKER 3 0:00:26 - 0:00:27
Does
SPEAKER 1 0:00:27 - 0:00:28
this system handle that
SPEAKER 1 0:00:29 - 0:00:29
?
SPEAKER 3 0:00:29 - 0:00:29
It
SPEAKER 1 0:00:29 - 0:00:33
does a pretty good job filtering noise, especially with models that use voice activity
########### WHAT SHOULD BE PRODUCED: ###########
SPEAKER 1 0:00:04 - 0:00:12
Transcription technology has improved so much in the past years. Have you noticed how accurate real-time speech detects is now?
SPEAKER 2 0:00:12 - 0:00:22
Absolutely. I use it all the time for taking notes during meetings. It's amazing how it can recognize different speakers, and even add punctuation.
SPEAKER 3 0:00:22 - 0:00:28
Yeah, but sometimes noise can still cause mistakes. Does this system handle that well?
SPEAKER 1 0:00:29 - 0:00:29
It does a pretty good job filtering noise, especially with models that use voice activity

View File

@@ -1,257 +0,0 @@
import logging
import re
from whisperlivekit.remove_silences import handle_silences
from whisperlivekit.timed_objects import Line, format_time, SpeakerSegment
from typing import List
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
CHECK_AROUND = 4
DEBUG = False
def next_punctuation_change(i, tokens):
for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)):
if tokens[ind].is_punctuation():
return ind
return None
def next_speaker_change(i, tokens, speaker):
for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1):
token = tokens[ind]
if token.is_punctuation():
break
if token.speaker != speaker:
return ind, token.speaker
return None, speaker
def new_line(
token,
):
return Line(
speaker = token.corrected_speaker,
text = token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else ""),
start = token.start,
end = token.end,
detected_language=token.detected_language
)
def append_token_to_last_line(lines, sep, token):
if not lines:
lines.append(new_line(token))
else:
if token.text:
lines[-1].text += sep + token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else "")
lines[-1].end = token.end
if not lines[-1].detected_language and token.detected_language:
lines[-1].detected_language = token.detected_language
def extract_number(s) -> int:
"""Extract number from speaker string (for diart compatibility)."""
if isinstance(s, int):
return s
m = re.search(r'\d+', str(s))
return int(m.group()) if m else 0
def concatenate_speakers(segments: List[SpeakerSegment]) -> List[dict]:
"""Concatenate consecutive segments from the same speaker."""
if not segments:
return []
# Get speaker number from first segment
first_speaker = extract_number(segments[0].speaker)
segments_concatenated = [{"speaker": first_speaker + 1, "begin": segments[0].start, "end": segments[0].end}]
for segment in segments[1:]:
speaker = extract_number(segment.speaker) + 1
if segments_concatenated[-1]['speaker'] != speaker:
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
else:
segments_concatenated[-1]['end'] = segment.end
return segments_concatenated
def add_speaker_to_tokens_with_punctuation(segments: List[SpeakerSegment], tokens: list) -> list:
"""Assign speakers to tokens with punctuation-aware boundary adjustment."""
punctuation_marks = {'.', '!', '?'}
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
segments_concatenated = concatenate_speakers(segments)
for ind, segment in enumerate(segments_concatenated):
for i, punctuation_token in enumerate(punctuation_tokens):
if punctuation_token.start > segment['end']:
after_length = punctuation_token.start - segment['end']
before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
if before_length > after_length:
segment['end'] = punctuation_token.start
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
else:
segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
break
# Ensure non-overlapping tokens
last_end = 0.0
for token in tokens:
start = max(last_end + 0.01, token.start)
token.start = start
token.end = max(start, token.end)
last_end = token.end
# Assign speakers based on adjusted segments
ind_last_speaker = 0
for segment in segments_concatenated:
for i, token in enumerate(tokens[ind_last_speaker:]):
if token.end <= segment['end']:
token.speaker = segment['speaker']
ind_last_speaker = i + 1
elif token.start > segment['end']:
break
return tokens
def assign_speakers_to_tokens(tokens: list, segments: List[SpeakerSegment], use_punctuation_split: bool = False) -> list:
"""
Assign speakers to tokens based on timing overlap with speaker segments.
Args:
tokens: List of tokens with timing information
segments: List of speaker segments
use_punctuation_split: Whether to use punctuation for boundary refinement
Returns:
List of tokens with speaker assignments
"""
if not segments or not tokens:
logger.debug("No segments or tokens available for speaker assignment")
return tokens
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
if not use_punctuation_split:
# Simple overlap-based assignment
for token in tokens:
token.speaker = -1 # Default to no speaker
for segment in segments:
# Check for timing overlap
if not (segment.end <= token.start or segment.start >= token.end):
speaker_num = extract_number(segment.speaker)
token.speaker = speaker_num + 1 # Convert to 1-based indexing
break
else:
# Use punctuation-aware assignment
tokens = add_speaker_to_tokens_with_punctuation(segments, tokens)
return tokens
def format_output(state, silence, args, sep):
diarization = args.diarization
disable_punctuation_split = args.disable_punctuation_split
tokens = state.tokens
translation_validated_segments = state.translation_validated_segments # Here we will attribute the speakers only based on the timestamps of the segments
last_validated_token = state.last_validated_token
last_speaker = abs(state.last_speaker)
undiarized_text = []
tokens = handle_silences(tokens, state.beg_loop, silence)
# Assign speakers to tokens based on segments stored in state
if False and diarization and state.diarization_segments:
use_punctuation_split = args.punctuation_split if hasattr(args, 'punctuation_split') else False
tokens = assign_speakers_to_tokens(tokens, state.diarization_segments, use_punctuation_split=use_punctuation_split)
for i in range(last_validated_token, len(tokens)):
token = tokens[i]
speaker = int(token.speaker)
token.corrected_speaker = speaker
if True or not diarization:
if speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
token.corrected_speaker = 1
token.validated_speaker = True
else:
if token.speaker == -1:
undiarized_text.append(token.text)
elif token.is_punctuation():
state.last_punctuation_index = i
token.corrected_speaker = last_speaker
token.validated_speaker = True
elif state.last_punctuation_index == i-1:
if token.speaker != last_speaker:
token.corrected_speaker = token.speaker
token.validated_speaker = True
# perfect, diarization perfectly aligned
else:
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
if speaker_change_pos:
# Corrects delay:
# That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
# should become:
# That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
token.corrected_speaker = new_speaker
token.validated_speaker = True
elif speaker != last_speaker:
if not (speaker == -2 or last_speaker == -2):
if next_punctuation_change(i, tokens):
# Corrects advance:
# Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
# should become:
# Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
token.corrected_speaker = last_speaker
token.validated_speaker = True
else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True.
if not disable_punctuation_split:
token.corrected_speaker = last_speaker
token.validated_speaker = False
if token.validated_speaker:
state.last_validated_token = i
state.last_speaker = token.corrected_speaker
last_speaker = 1
lines = []
for token in tokens:
if token.corrected_speaker != -1:
if int(token.corrected_speaker) != int(last_speaker):
lines.append(new_line(token))
else:
append_token_to_last_line(lines, sep, token)
last_speaker = token.corrected_speaker
if lines:
unassigned_translated_segments = []
for ts in translation_validated_segments:
assigned = False
for line in lines:
if ts and ts.overlaps_with(line):
if ts.is_within(line):
line.translation += ts.text + ' '
assigned = True
break
else:
ts0, ts1 = ts.approximate_cut_at(line.end)
if ts0 and line.overlaps_with(ts0):
line.translation += ts0.text + ' '
if ts1:
unassigned_translated_segments.append(ts1)
assigned = True
break
if not assigned:
unassigned_translated_segments.append(ts)
if unassigned_translated_segments:
for line in lines:
remaining_segments = []
for ts in unassigned_translated_segments:
if ts and ts.overlaps_with(line):
line.translation += ts.text + ' '
else:
remaining_segments.append(ts)
unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
if state.buffer_transcription and lines:
lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
return lines, undiarized_text

View File

@@ -266,7 +266,7 @@ class AlignAtt:
logger.debug("Refreshing segment:")
self.init_tokens()
self.last_attend_frame = -self.cfg.rewind_threshold
self.detected_language = None
# self.detected_language = None
self.cumulative_time_offset = 0.0
self.init_context()
logger.debug(f"Context: {self.context}")

View File

@@ -0,0 +1,261 @@
"""
ALPHA. results are not great yet
To replace `whisperlivekit.silero_vad_iterator import FixedVADIterator`
by `from whisperlivekit.ten_vad_iterator import TenVADIterator`
Use self.vac = TenVADIterator() instead of self.vac = FixedVADIterator(models.vac_model)
"""
import numpy as np
from ten_vad import TenVad
class TenVADIterator:
def __init__(self,
threshold: float = 0.5,
sampling_rate: int = 16000,
min_silence_duration_ms: int = 100,
speech_pad_ms: int = 30):
self.vad = TenVad()
self.threshold = threshold
self.sampling_rate = sampling_rate
self.min_silence_duration_ms = min_silence_duration_ms
self.speech_pad_ms = speech_pad_ms
self.min_silence_samples = int(sampling_rate * min_silence_duration_ms / 1000)
self.speech_pad_samples = int(sampling_rate * speech_pad_ms / 1000)
self.reset_states()
def reset_states(self):
self.triggered = False
self.temp_end = 0
self.current_sample = 0
self.buffer = np.array([], dtype=np.float32)
def __call__(self, x, return_seconds=False):
if not isinstance(x, np.ndarray):
x = np.array(x, dtype=np.float32)
self.buffer = np.append(self.buffer, x)
chunk_size = 256
ret = None
while len(self.buffer) >= chunk_size:
chunk = self.buffer[:chunk_size].astype(np.int16)
self.buffer = self.buffer[chunk_size:]
window_size_samples = len(chunk)
self.current_sample += window_size_samples
speech_prob, speech_flag = self.vad.process(chunk)
if (speech_prob >= self.threshold) and self.temp_end:
self.temp_end = 0
if (speech_prob >= self.threshold) and not self.triggered:
self.triggered = True
speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
result = {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
if ret is None:
ret = result
elif "end" in ret:
ret = result
else:
ret.update(result)
if (speech_prob < self.threshold - 0.15) and self.triggered:
if not self.temp_end:
self.temp_end = self.current_sample
if self.current_sample - self.temp_end < self.min_silence_samples:
continue
else:
speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
self.temp_end = 0
self.triggered = False
result = {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
if ret is None:
ret = result
else:
ret.update(result)
return ret if ret != {} else None
def test_on_record_wav():
import os
from pathlib import Path
audio_file = Path("record.wav")
if not audio_file.exists():
return
import soundfile as sf
audio_data, sample_rate = sf.read(str(audio_file), dtype='float32')
if len(audio_data.shape) > 1:
audio_data = np.mean(audio_data, axis=1)
vad = TenVADIterator(
threshold=0.5,
sampling_rate=sample_rate,
min_silence_duration_ms=100,
speech_pad_ms=30
)
chunk_size = 1024
speech_segments = []
current_segment = None
for i in range(0, len(audio_data), chunk_size):
chunk = audio_data[i:i+chunk_size]
if chunk.dtype != np.int16:
chunk_int16 = (chunk * 32767.0).astype(np.int16)
else:
chunk_int16 = chunk
result = vad(chunk_int16, return_seconds=True)
if result is not None:
if 'start' in result:
current_segment = {'start': result['start'], 'end': None}
print(f"Speech start detected at {result['start']:.2f}s")
elif 'end' in result:
if current_segment:
current_segment['end'] = result['end']
duration = current_segment['end'] - current_segment['start']
speech_segments.append(current_segment)
print(f"Speech end detected at {result['end']:.2f}s (duration: {duration:.2f}s)")
current_segment = None
else:
print(f"Speech end detected at {result['end']:.2f}s (no corresponding start)")
if current_segment and current_segment['end'] is None:
current_segment['end'] = len(audio_data) / sample_rate
speech_segments.append(current_segment)
print(f"End of file (last segment at {current_segment['start']:.2f}s)")
print("-" * 60)
print(f"\nSummary:")
print(f"Number of speech segments detected: {len(speech_segments)}")
if speech_segments:
total_speech_time = sum(seg['end'] - seg['start'] for seg in speech_segments)
total_time = len(audio_data) / sample_rate
speech_ratio = (total_speech_time / total_time) * 100
print(f"Total speech time: {total_speech_time:.2f}s")
print(f"Total file time: {total_time:.2f}s")
print(f"Speech ratio: {speech_ratio:.1f}%")
print(f"\nDetected segments:")
for i, seg in enumerate(speech_segments, 1):
print(f" {i}. {seg['start']:.2f}s - {seg['end']:.2f}s (duration: {seg['end'] - seg['start']:.2f}s)")
else:
print("No speech segments detected")
print("\n" + "=" * 60)
print("Extracting silence segments...")
silence_segments = []
total_time = len(audio_data) / sample_rate
if not speech_segments:
silence_segments = [{'start': 0.0, 'end': total_time}]
else:
if speech_segments[0]['start'] > 0:
silence_segments.append({'start': 0.0, 'end': speech_segments[0]['start']})
for i in range(len(speech_segments) - 1):
silence_start = speech_segments[i]['end']
silence_end = speech_segments[i + 1]['start']
if silence_end > silence_start:
silence_segments.append({'start': silence_start, 'end': silence_end})
if speech_segments[-1]['end'] < total_time:
silence_segments.append({'start': speech_segments[-1]['end'], 'end': total_time})
silence_audio = np.array([], dtype=audio_data.dtype)
for seg in silence_segments:
start_sample = int(seg['start'] * sample_rate)
end_sample = int(seg['end'] * sample_rate)
start_sample = max(0, min(start_sample, len(audio_data)))
end_sample = max(0, min(end_sample, len(audio_data)))
if end_sample > start_sample:
silence_audio = np.concatenate([silence_audio, audio_data[start_sample:end_sample]])
if len(silence_audio) > 0:
output_file = "record_silence_only.wav"
try:
import soundfile as sf
sf.write(output_file, silence_audio, sample_rate)
print(f"Silence file saved: {output_file}")
except ImportError:
try:
from scipy.io import wavfile
if silence_audio.dtype == np.float32:
silence_audio_int16 = (silence_audio * 32767.0).astype(np.int16)
else:
silence_audio_int16 = silence_audio.astype(np.int16)
wavfile.write(output_file, sample_rate, silence_audio_int16)
print(f"Silence file saved: {output_file}")
except ImportError:
print("Unable to save: soundfile or scipy required")
total_silence_time = sum(seg['end'] - seg['start'] for seg in silence_segments)
silence_ratio = (total_silence_time / total_time) * 100
print(f"Total silence duration: {total_silence_time:.2f}s")
print(f"Silence ratio: {silence_ratio:.1f}%")
print(f"Number of silence segments: {len(silence_segments)}")
print(f"\nYou can listen to {output_file} to verify that only silences are present.")
else:
print("No silence segments found (file entirely speech)")
print("\n" + "=" * 60)
print("Extracting speech segments...")
if speech_segments:
speech_audio = np.array([], dtype=audio_data.dtype)
for seg in speech_segments:
start_sample = int(seg['start'] * sample_rate)
end_sample = int(seg['end'] * sample_rate)
start_sample = max(0, min(start_sample, len(audio_data)))
end_sample = max(0, min(end_sample, len(audio_data)))
if end_sample > start_sample:
speech_audio = np.concatenate([speech_audio, audio_data[start_sample:end_sample]])
if len(speech_audio) > 0:
output_file = "record_speech_only.wav"
try:
import soundfile as sf
sf.write(output_file, speech_audio, sample_rate)
print(f"Speech file saved: {output_file}")
except ImportError:
try:
from scipy.io import wavfile
if speech_audio.dtype == np.float32:
speech_audio_int16 = (speech_audio * 32767.0).astype(np.int16)
else:
speech_audio_int16 = speech_audio.astype(np.int16)
wavfile.write(output_file, sample_rate, speech_audio_int16)
print(f"Speech file saved: {output_file}")
except ImportError:
print("Unable to save: soundfile or scipy required")
total_speech_time = sum(seg['end'] - seg['start'] for seg in speech_segments)
speech_ratio = (total_speech_time / total_time) * 100
print(f"Total speech duration: {total_speech_time:.2f}s")
print(f"Speech ratio: {speech_ratio:.1f}%")
print(f"Number of speech segments: {len(speech_segments)}")
print(f"\nYou can listen to {output_file} to verify that only speech segments are present.")
else:
print("No speech audio to extract")
else:
print("No speech segments found (file entirely silence)")
if __name__ == "__main__":
test_on_record_wav()

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Optional, Any, List
from typing import Optional, List, Union, Dict, Any
from datetime import timedelta
PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''}
@@ -19,11 +19,8 @@ class TimedText(Timed):
speaker: Optional[int] = -1
detected_language: Optional[str] = None
def is_punctuation(self):
return self.text.strip() in PUNCTUATION_MARKS
def overlaps_with(self, other: 'TimedText') -> bool:
return not (self.end <= other.start or other.end <= self.start)
def has_punctuation(self) -> bool:
return any(char in PUNCTUATION_MARKS for char in self.text.strip())
def is_within(self, other: 'TimedText') -> bool:
return other.contains_timespan(self)
@@ -31,28 +28,26 @@ class TimedText(Timed):
def duration(self) -> float:
return self.end - self.start
def contains_time(self, time: float) -> bool:
return self.start <= time <= self.end
def contains_timespan(self, other: 'TimedText') -> bool:
return self.start <= other.start and self.end >= other.end
def __bool__(self):
def __bool__(self) -> bool:
return bool(self.text)
def __str__(self) -> str:
return str(self.text)
@dataclass()
class ASRToken(TimedText):
corrected_speaker: Optional[int] = -1
validated_speaker: bool = False
validated_text: bool = False
validated_language: bool = False
def with_offset(self, offset: float) -> "ASRToken":
"""Return a new token with the time offset added."""
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language)
def is_silence(self) -> bool:
return False
@dataclass
class Sentence(TimedText):
pass
@@ -70,6 +65,7 @@ class Transcript(TimedText):
sep: Optional[str] = None,
offset: float = 0
) -> "Transcript":
"""Collapse multiple ASR tokens into a single transcript span."""
sep = sep if sep is not None else ' '
text = sep.join(token.text for token in tokens)
if tokens:
@@ -93,47 +89,70 @@ class SpeakerSegment(Timed):
class Translation(TimedText):
pass
def approximate_cut_at(self, cut_time):
"""
Each word in text is considered to be of duration (end-start)/len(words in text)
"""
if not self.text or not self.contains_time(cut_time):
return self, None
words = self.text.split()
num_words = len(words)
if num_words == 0:
return self, None
duration_per_word = self.duration() / num_words
cut_word_index = int((cut_time - self.start) / duration_per_word)
if cut_word_index >= num_words:
cut_word_index = num_words -1
text0 = " ".join(words[:cut_word_index])
text1 = " ".join(words[cut_word_index:])
segment0 = Translation(start=self.start, end=cut_time, text=text0)
segment1 = Translation(start=cut_time, end=self.end, text=text1)
return segment0, segment1
@dataclass
class Silence():
start: Optional[float] = None
end: Optional[float] = None
duration: Optional[float] = None
is_starting: bool = False
has_ended: bool = False
def compute_duration(self) -> Optional[float]:
if self.start is None or self.end is None:
return None
self.duration = self.end - self.start
return self.duration
def is_silence(self) -> bool:
return True
@dataclass
class Segment(TimedText):
"""Generic contiguous span built from tokens or silence markers."""
start: Optional[float]
end: Optional[float]
text: Optional[str]
speaker: Optional[str]
@classmethod
def from_tokens(
cls,
tokens: List[Union[ASRToken, Silence]],
is_silence: bool = False
) -> Optional["Segment"]:
"""Return a normalized segment representing the provided tokens."""
if not tokens:
return None
start_token = tokens[0]
end_token = tokens[-1]
if is_silence:
return cls(
start=start_token.start,
end=end_token.end,
text=None,
speaker=-2
)
else:
return cls(
start=start_token.start,
end=end_token.end,
text=''.join(token.text for token in tokens),
speaker=-1,
detected_language=start_token.detected_language
)
def is_silence(self) -> bool:
"""True when this segment represents a silence gap."""
return self.speaker == -2
@dataclass
class Line(TimedText):
translation: str = ''
def to_dict(self):
_dict = {
def to_dict(self) -> Dict[str, Any]:
"""Serialize the line for frontend consumption."""
_dict: Dict[str, Any] = {
'speaker': int(self.speaker) if self.speaker != -1 else 1,
'text': self.text,
'start': format_time(self.start),
@@ -145,6 +164,33 @@ class Line(TimedText):
_dict['detected_language'] = self.detected_language
return _dict
def build_from_tokens(self, tokens: List[ASRToken]) -> "Line":
"""Populate line attributes from a contiguous token list."""
self.text = ''.join([token.text for token in tokens])
self.start = tokens[0].start
self.end = tokens[-1].end
self.speaker = 1
self.detected_language = tokens[0].detected_language
return self
def build_from_segment(self, segment: Segment) -> "Line":
"""Populate the line fields from a pre-built segment."""
self.text = segment.text
self.start = segment.start
self.end = segment.end
self.speaker = segment.speaker
self.detected_language = segment.detected_language
return self
def is_silent(self) -> bool:
return self.speaker == -2
class SilentLine(Line):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.speaker = -2
self.text = ''
@dataclass
class FrontData():
@@ -157,8 +203,9 @@ class FrontData():
remaining_time_transcription: float = 0.
remaining_time_diarization: float = 0.
def to_dict(self):
_dict = {
def to_dict(self) -> Dict[str, Any]:
"""Serialize the front-end data payload."""
_dict: Dict[str, Any] = {
'status': self.status,
'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)],
'buffer_transcription': self.buffer_transcription,
@@ -178,26 +225,22 @@ class ChangeSpeaker:
@dataclass
class State():
tokens: list = field(default_factory=list)
last_validated_token: int = 0
last_speaker: int = 1
last_punctuation_index: Optional[int] = None
translation_validated_segments: list = field(default_factory=list)
buffer_translation: str = field(default_factory=Transcript)
buffer_transcription: str = field(default_factory=Transcript)
diarization_segments: list = field(default_factory=list)
"""Unified state class for audio processing.
Contains both persistent state (tokens, buffers) and temporary update buffers
(new_* fields) that are consumed by TokensAlignment.
"""
# Persistent state
tokens: List[ASRToken] = field(default_factory=list)
buffer_transcription: Transcript = field(default_factory=Transcript)
end_buffer: float = 0.0
end_attributed_speaker: float = 0.0
remaining_time_transcription: float = 0.0
remaining_time_diarization: float = 0.0
@dataclass
class StateLight():
new_tokens: list = field(default_factory=list)
new_translation: list = field(default_factory=list)
new_diarization: list = field(default_factory=list)
new_tokens_buffer: list = field(default_factory=list) #only when local agreement
new_tokens_index = 0
new_translation_index = 0
new_diarization_index = 0
# Temporary update buffers (consumed by TokensAlignment.update())
new_tokens: List[Union[ASRToken, Silence]] = field(default_factory=list)
new_translation: List[Any] = field(default_factory=list)
new_diarization: List[Any] = field(default_factory=list)
new_tokens_buffer: List[Any] = field(default_factory=list) # only when local agreement
new_translation_buffer= TimedText()

View File

@@ -0,0 +1,177 @@
from time import time
from typing import Optional, List, Tuple, Union, Any
from whisperlivekit.timed_objects import Line, SilentLine, ASRToken, SpeakerSegment, Silence, TimedText, Segment
class TokensAlignment:
def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None:
self.state = state
self.diarization = args.diarization
self._tokens_index: int = 0
self._diarization_index: int = 0
self._translation_index: int = 0
self.all_tokens: List[ASRToken] = []
self.all_diarization_segments: List[SpeakerSegment] = []
self.all_translation_segments: List[Any] = []
self.new_tokens: List[ASRToken] = []
self.new_diarization: List[SpeakerSegment] = []
self.new_translation: List[Any] = []
self.new_translation_buffer: Union[TimedText, str] = TimedText()
self.new_tokens_buffer: List[Any] = []
self.sep: str = sep if sep is not None else ' '
self.beg_loop: Optional[float] = None
def update(self) -> None:
"""Drain state buffers into the running alignment context."""
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
self.new_diarization, self.state.new_diarization = self.state.new_diarization, []
self.new_translation, self.state.new_translation = self.state.new_translation, []
self.new_tokens_buffer, self.state.new_tokens_buffer = self.state.new_tokens_buffer, []
self.all_tokens.extend(self.new_tokens)
self.all_diarization_segments.extend(self.new_diarization)
self.all_translation_segments.extend(self.new_translation)
self.new_translation_buffer = self.state.new_translation_buffer
def add_translation(self, line: Line) -> None:
"""Append translated text segments that overlap with a line."""
for ts in self.all_translation_segments:
if ts.is_within(line):
line.translation += ts.text + (self.sep if ts.text else '')
elif line.translation:
break
def compute_punctuations_segments(self, tokens: Optional[List[ASRToken]] = None) -> List[Segment]:
"""Group tokens into segments split by punctuation and explicit silence."""
segments = []
segment_start_idx = 0
for i, token in enumerate(self.all_tokens):
if token.is_silence():
previous_segment = Segment.from_tokens(
tokens=self.all_tokens[segment_start_idx: i],
)
if previous_segment:
segments.append(previous_segment)
segment = Segment.from_tokens(
tokens=[token],
is_silence=True
)
segments.append(segment)
segment_start_idx = i+1
else:
if token.has_punctuation():
segment = Segment.from_tokens(
tokens=self.all_tokens[segment_start_idx: i+1],
)
segments.append(segment)
segment_start_idx = i+1
final_segment = Segment.from_tokens(
tokens=self.all_tokens[segment_start_idx:],
)
if final_segment:
segments.append(final_segment)
return segments
def concatenate_diar_segments(self) -> List[SpeakerSegment]:
"""Merge consecutive diarization slices that share the same speaker."""
if not self.all_diarization_segments:
return []
merged = [self.all_diarization_segments[0]]
for segment in self.all_diarization_segments[1:]:
if segment.speaker == merged[-1].speaker:
merged[-1].end = segment.end
else:
merged.append(segment)
return merged
@staticmethod
def intersection_duration(seg1: TimedText, seg2: TimedText) -> float:
"""Return the overlap duration between two timed segments."""
start = max(seg1.start, seg2.start)
end = min(seg1.end, seg2.end)
return max(0, end - start)
def get_lines_diarization(self) -> Tuple[List[Line], str]:
"""Build lines when diarization is enabled and track overflow buffer."""
diarization_buffer = ''
punctuation_segments = self.compute_punctuations_segments()
diarization_segments = self.concatenate_diar_segments()
for punctuation_segment in punctuation_segments:
if not punctuation_segment.is_silence():
if diarization_segments and punctuation_segment.start >= diarization_segments[-1].end:
diarization_buffer += punctuation_segment.text
else:
max_overlap = 0.0
max_overlap_speaker = 1
for diarization_segment in diarization_segments:
intersec = self.intersection_duration(punctuation_segment, diarization_segment)
if intersec > max_overlap:
max_overlap = intersec
max_overlap_speaker = diarization_segment.speaker + 1
punctuation_segment.speaker = max_overlap_speaker
lines = []
if punctuation_segments:
lines = [Line().build_from_segment(punctuation_segments[0])]
for segment in punctuation_segments[1:]:
if segment.speaker == lines[-1].speaker:
if lines[-1].text:
lines[-1].text += segment.text
lines[-1].end = segment.end
else:
lines.append(Line().build_from_segment(segment))
return lines, diarization_buffer
def get_lines(
self,
diarization: bool = False,
translation: bool = False,
current_silence: Optional[Silence] = None
) -> Tuple[List[Line], str, Union[str, TimedText]]:
"""Return the formatted lines plus buffers, optionally with diarization/translation."""
if diarization:
lines, diarization_buffer = self.get_lines_diarization()
else:
diarization_buffer = ''
lines = []
current_line_tokens = []
for token in self.all_tokens:
if token.is_silence():
if current_line_tokens:
lines.append(Line().build_from_tokens(current_line_tokens))
current_line_tokens = []
end_silence = token.end if token.has_ended else time() - self.beg_loop
if lines and lines[-1].is_silent():
lines[-1].end = end_silence
else:
lines.append(SilentLine(
start = token.start,
end = end_silence
))
else:
current_line_tokens.append(token)
if current_line_tokens:
lines.append(Line().build_from_tokens(current_line_tokens))
if current_silence:
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
if lines and lines[-1].is_silent():
lines[-1].end = end_silence
else:
lines.append(SilentLine(
start = current_silence.start,
end = end_silence
))
if translation:
[self.add_translation(line) for line in lines if not type(line) == Silence]
return lines, diarization_buffer, self.new_translation_buffer.text

View File

@@ -1,60 +0,0 @@
from typing import Sequence, Callable, Any, Optional, Dict
def _detect_tail_repetition(
seq: Sequence[Any],
key: Callable[[Any], Any] = lambda x: x, # extract comparable value
min_block: int = 1, # set to 2 to ignore 1-token loops like "."
max_tail: int = 300, # search window from the end for speed
prefer: str = "longest", # "longest" coverage or "smallest" block
) -> Optional[Dict]:
vals = [key(x) for x in seq][-max_tail:]
n = len(vals)
best = None
# try every possible block length
for b in range(min_block, n // 2 + 1):
block = vals[-b:]
# count how many times this block repeats contiguously at the very end
count, i = 0, n
while i - b >= 0 and vals[i - b:i] == block:
count += 1
i -= b
if count >= 2:
cand = {
"block_size": b,
"count": count,
"start_index": len(seq) - count * b, # in original seq
"end_index": len(seq),
}
if (best is None or
(prefer == "longest" and count * b > best["count"] * best["block_size"]) or
(prefer == "smallest" and b < best["block_size"])):
best = cand
return best
def trim_tail_repetition(
seq: Sequence[Any],
key: Callable[[Any], Any] = lambda x: x,
min_block: int = 1,
max_tail: int = 300,
prefer: str = "longest",
keep: int = 1, # how many copies of the repeating block to keep at the end (0 or 1 are common)
):
"""
Returns a new sequence with repeated tail trimmed.
keep=1 -> keep a single copy of the repeated block.
keep=0 -> remove all copies of the repeated block.
"""
rep = _detect_tail_repetition(seq, key, min_block, max_tail, prefer)
if not rep:
return seq, False # nothing to trim
b, c = rep["block_size"], rep["count"]
if keep < 0:
keep = 0
if keep >= c:
return seq, False # nothing to trim (already <= keep copies)
# new length = total - (copies_to_remove * block_size)
new_len = len(seq) - (c - keep) * b
return seq[:new_len], True