8 Commits

| Author | SHA1 | Message | Date |
|--------|------|---------|------|
| Quentin Fuxa | 9d4ae33249 | WIP. Trying ten VAD #280 | 2025-11-23 11:20:00 +01:00 |
| Quentin Fuxa | 6206fff118 | 0.2.15 | 2025-11-21 23:52:00 +01:00 |
| Quentin Fuxa | b5067249c0 | stt/diar/nllw alignment: internal rework 5 | 2025-11-20 23:52:00 +01:00 |
| Quentin Fuxa | f4f9831d39 | stt/diar/nllw alignment: internal rework 5 | 2025-11-20 23:52:00 +01:00 |
| Quentin Fuxa | 254faaf64c | stt/diar/nllw alignment: internal rework 5 | 2025-11-20 23:52:00 +01:00 |
| Quentin Fuxa | 8e7aea4fcf | internal rework 4 | 2025-11-20 23:45:20 +01:00 |
| Quentin Fuxa | 270faf2069 | internal rework 3 | 2025-11-20 22:28:30 +01:00 |
| Quentin Fuxa | b7c1cc77cc | internal rework 2 | 2025-11-20 22:06:38 +01:00 |
16 changed files with 720 additions and 881 deletions


@@ -141,7 +141,7 @@ async def websocket_endpoint(websocket: WebSocket):
|-----------|-------------|---------|
| `--model` | Whisper model size. List and recommendations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/available_models.md) | `small` |
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommendations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/models_compatible_formats.md) | `None` |
-| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
+| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
| `--target-language` | If set, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to English, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
| `--diarization` | Enable speaker identification | `False` |
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |


@@ -4,7 +4,7 @@
- Example 2: The punctuation from STT comes from prediction `t`, but the speaker change from Diarization comes in prediction `t-1`
- Example 3: The punctuation from STT comes from prediction `t-1`, but the speaker change from Diarization comes in prediction `t`
-> `#` is the split between the `t-1` prediction and t prediction.
+> `#` is the split between the `t-1` prediction and `t` prediction.
## Example 1:


@@ -0,0 +1,43 @@
# Technical Integration Guide
This document introduces how to reuse the core components when you do **not** want to ship the bundled frontend, FastAPI server, or even the provided CLI.
---
## 1. Runtime Components
| Layer | File(s) | Purpose |
|-------|---------|---------|
| Transport | `whisperlivekit/basic_server.py`, any ASGI/WebSocket server | Accepts audio over WebSocket (MediaRecorder WebM or raw PCM chunks) and streams JSON updates back |
| Audio processing | `whisperlivekit/audio_processor.py` | Buffers audio, orchestrates transcription, diarization, translation, handles FFmpeg/PCM input |
| Engines | `whisperlivekit/core.py`, `whisperlivekit/simul_whisper/*`, `whisperlivekit/local_agreement/*` | Load models once (SimulStreaming or LocalAgreement), expose `TranscriptionEngine` and helpers |
| Frontends | `whisperlivekit/web/*`, `chrome-extension/*` | Optional UI layers feeding the WebSocket endpoint |
**Key idea:** The server boundary is just `AudioProcessor.process_audio()` for incoming bytes and the async generator returned by `AudioProcessor.create_tasks()` for outgoing updates (`FrontData`). Everything else is optional.
---
## 2. Running Without the Bundled Frontend
1. Start the server/engine however you like:
```bash
wlk --model small --language en --host 0.0.0.0 --port 9000
# or launch your own app that instantiates TranscriptionEngine(...)
```
2. Build your own client (browser, mobile, desktop) that:
- Opens `ws(s)://<host>:<port>/asr`
- Sends either MediaRecorder/Opus WebM blobs **or** raw PCM (`--pcm-input` on the server tells the client to use the AudioWorklet).
- Consumes the JSON payload defined in `docs/API.md`.
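The client side of step 2 can be sketched in a few lines of Python. This is a minimal sketch, assuming the server runs with `--pcm-input` and accepts 16 kHz mono s16le chunks; the `websockets` package, the URL, and the chunk pacing are illustrative choices, not requirements of the project:

```python
import asyncio
import json

# 16 kHz mono s16le PCM — the format the raw-PCM path expects.
SAMPLE_RATE = 16000
BYTES_PER_SAMPLE = 2

def pcm_chunks(pcm: bytes, chunk_ms: int = 100) -> list[bytes]:
    """Split an s16le PCM byte string into fixed-duration chunks."""
    step = SAMPLE_RATE * BYTES_PER_SAMPLE * chunk_ms // 1000
    return [pcm[i:i + step] for i in range(0, len(pcm), step)]

async def stream(pcm: bytes, url: str = "ws://localhost:9000/asr") -> None:
    import websockets  # third-party: pip install websockets

    async with websockets.connect(url) as ws:
        async def sender():
            for chunk in pcm_chunks(pcm):
                await ws.send(chunk)
                await asyncio.sleep(0.1)  # pace roughly in real time

        async def receiver():
            async for message in ws:
                # JSON payload as described in docs/API.md
                update = json.loads(message)
                print(update.get("status"), update.get("lines"))

        await asyncio.gather(sender(), receiver())
```

A browser client would instead feed MediaRecorder WebM blobs (or AudioWorklet PCM) into the same `/asr` endpoint; only the sender side changes.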
---
## 3. Running Without FastAPI
`whisperlivekit/basic_server.py` is just an example. Any async framework works, as long as you:
1. Create a global `TranscriptionEngine` (expensive to initialize; reuse it).
2. Instantiate `AudioProcessor(transcription_engine=engine)` for each connection.
3. Call `create_tasks()` to get the async generator, feed incoming bytes to `process_audio()`, and ensure `cleanup()` runs when the client disconnects.
If you prefer to send compressed audio, instantiate `AudioProcessor(pcm_input=False)` and pipe encoded chunks through `FFmpegManager` transparently—just ensure `ffmpeg` is available or be ready to handle the `"ffmpeg_not_found"` error in the streamed `FrontData`.
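The three steps above can be sketched outside FastAPI as follows. This uses the third-party `websockets` package as the transport; the import paths follow the file layout listed earlier, but treat the constructor arguments and the `FrontData` serialization call as assumptions to verify against the package:

```python
import asyncio

async def handle_connection(websocket, engine) -> None:
    # One AudioProcessor per connection; the engine is shared and reused.
    from whisperlivekit.audio_processor import AudioProcessor  # path per the table above
    processor = AudioProcessor(transcription_engine=engine)
    results = await processor.create_tasks()  # async generator of FrontData updates

    async def forward_results():
        async for front_data in results:
            # Serialization is illustrative — check how FrontData is emitted in basic_server.py
            await websocket.send(front_data.to_json())

    forward = asyncio.create_task(forward_results())
    try:
        async for chunk in websocket:  # incoming audio bytes (WebM or raw PCM)
            await processor.process_audio(chunk)
    finally:
        forward.cancel()
        await processor.cleanup()  # always release FFmpeg/tasks on disconnect

async def main() -> None:
    import websockets  # third-party: pip install websockets
    from whisperlivekit.core import TranscriptionEngine

    # Expensive to initialize: build once, share across connections.
    engine = TranscriptionEngine(model="small")  # argument names are an assumption
    async with websockets.serve(lambda ws: handle_connection(ws, engine), "0.0.0.0", 9000):
        await asyncio.Future()  # run forever
```

The same shape works under aiohttp, Starlette, or a raw ASGI app: only the `websocket` object changes, the `AudioProcessor` contract stays identical.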


@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "whisperlivekit"
-version = "0.2.14.post4"
+version = "0.2.15"
description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md"
authors = [


@@ -1,168 +0,0 @@
class TokensAlignment:
    def __init__(self, state, silence=None, args=None):
        self.state = state
        self.silence = silence
        self.args = args
        self._tokens_index = 0
        self._diarization_index = 0
        self._translation_index = 0

    def update(self):
        pass

    def compute_punctuations_segments(self):
        punctuations_breaks = []
        new_tokens = self.state.tokens[self.state.last_validated_token:]
        for i in range(len(new_tokens)):
            token = new_tokens[i]
            if token.is_punctuation():
                punctuations_breaks.append({
                    'token_index': i,
                    'token': token,
                    'start': token.start,
                    'end': token.end,
                })
        punctuations_segments = []
        for i, break_info in enumerate(punctuations_breaks):
            start = punctuations_breaks[i - 1]['end'] if i > 0 else 0.0
            end = break_info['end']
            punctuations_segments.append({
                'start': start,
                'end': end,
                'token_index': break_info['token_index'],
                'token': break_info['token']
            })
        return punctuations_segments

    def concatenate_diar_segments(self):
        diarization_segments = self.state.diarization_segments
if __name__ == "__main__":
    from whisperlivekit.timed_objects import State, ASRToken, SpeakerSegment, Transcript, Silence

    # Reconstruct the state from the backup data
    tokens = [
        ASRToken(start=1.38, end=1.48, text=' The'),
        ASRToken(start=1.42, end=1.52, text=' description'),
        ASRToken(start=1.82, end=1.92, text=' technology'),
        ASRToken(start=2.54, end=2.64, text=' has'),
        ASRToken(start=2.7, end=2.8, text=' improved'),
        ASRToken(start=3.24, end=3.34, text=' so'),
        ASRToken(start=3.66, end=3.76, text=' much'),
        ASRToken(start=4.02, end=4.12, text=' in'),
        ASRToken(start=4.08, end=4.18, text=' the'),
        ASRToken(start=4.26, end=4.36, text=' past'),
        ASRToken(start=4.48, end=4.58, text=' few'),
        ASRToken(start=4.76, end=4.86, text=' years'),
        ASRToken(start=5.76, end=5.86, text='.'),
        ASRToken(start=5.72, end=5.82, text=' Have'),
        ASRToken(start=5.92, end=6.02, text=' you'),
        ASRToken(start=6.08, end=6.18, text=' noticed'),
        ASRToken(start=6.52, end=6.62, text=' how'),
        ASRToken(start=6.8, end=6.9, text=' accurate'),
        ASRToken(start=7.46, end=7.56, text=' real'),
        ASRToken(start=7.72, end=7.82, text='-time'),
        ASRToken(start=8.06, end=8.16, text=' speech'),
        ASRToken(start=8.48, end=8.58, text=' to'),
        ASRToken(start=8.68, end=8.78, text=' text'),
        ASRToken(start=9.0, end=9.1, text=' is'),
        ASRToken(start=9.24, end=9.34, text=' now'),
        ASRToken(start=9.82, end=9.92, text='?'),
        ASRToken(start=9.86, end=9.96, text=' Absolutely'),
        ASRToken(start=11.26, end=11.36, text='.'),
        ASRToken(start=11.36, end=11.46, text=' I'),
        ASRToken(start=11.58, end=11.68, text=' use'),
        ASRToken(start=11.78, end=11.88, text=' it'),
        ASRToken(start=11.94, end=12.04, text=' all'),
        ASRToken(start=12.08, end=12.18, text=' the'),
        ASRToken(start=12.32, end=12.42, text=' time'),
        ASRToken(start=12.58, end=12.68, text=' for'),
        ASRToken(start=12.78, end=12.88, text=' taking'),
        ASRToken(start=13.14, end=13.24, text=' notes'),
        ASRToken(start=13.4, end=13.5, text=' during'),
        ASRToken(start=13.78, end=13.88, text=' meetings'),
        ASRToken(start=14.6, end=14.7, text='.'),
        ASRToken(start=14.82, end=14.92, text=' It'),
        ASRToken(start=14.92, end=15.02, text="'s"),
        ASRToken(start=15.04, end=15.14, text=' amazing'),
        ASRToken(start=15.5, end=15.6, text=' how'),
        ASRToken(start=15.66, end=15.76, text=' it'),
        ASRToken(start=15.8, end=15.9, text=' can'),
        ASRToken(start=15.96, end=16.06, text=' recognize'),
        ASRToken(start=16.58, end=16.68, text=' different'),
        ASRToken(start=16.94, end=17.04, text=' speakers'),
        ASRToken(start=17.82, end=17.92, text=' and'),
        ASRToken(start=18.0, end=18.1, text=' even'),
        ASRToken(start=18.42, end=18.52, text=' add'),
        ASRToken(start=18.74, end=18.84, text=' punct'),
        ASRToken(start=19.02, end=19.12, text='uation'),
        ASRToken(start=19.68, end=19.78, text='.'),
        ASRToken(start=20.04, end=20.14, text=' Yeah'),
        ASRToken(start=20.5, end=20.6, text=','),
        ASRToken(start=20.6, end=20.7, text=' but'),
        ASRToken(start=20.76, end=20.86, text=' sometimes'),
        ASRToken(start=21.42, end=21.52, text=' noise'),
        ASRToken(start=21.82, end=21.92, text=' can'),
        ASRToken(start=22.08, end=22.18, text=' still'),
        ASRToken(start=22.38, end=22.48, text=' cause'),
        ASRToken(start=22.72, end=22.82, text=' mistakes'),
        ASRToken(start=23.74, end=23.84, text='.'),
        ASRToken(start=23.96, end=24.06, text=' Does'),
        ASRToken(start=24.16, end=24.26, text=' this'),
        ASRToken(start=24.4, end=24.5, text=' system'),
        ASRToken(start=24.76, end=24.86, text=' handle'),
        ASRToken(start=25.12, end=25.22, text=' that'),
        ASRToken(start=25.38, end=25.48, text=' well'),
        ASRToken(start=25.68, end=25.78, text='?'),
        ASRToken(start=26.4, end=26.5, text=' It'),
        ASRToken(start=26.5, end=26.6, text=' does'),
        ASRToken(start=26.7, end=26.8, text=' a'),
        ASRToken(start=27.08, end=27.18, text=' pretty'),
        ASRToken(start=27.12, end=27.22, text=' good'),
        ASRToken(start=27.34, end=27.44, text=' job'),
        ASRToken(start=27.64, end=27.74, text=' filtering'),
        ASRToken(start=28.1, end=28.2, text=' noise'),
        ASRToken(start=28.64, end=28.74, text=','),
        ASRToken(start=28.78, end=28.88, text=' especially'),
        ASRToken(start=29.3, end=29.4, text=' with'),
        ASRToken(start=29.51, end=29.61, text=' models'),
        ASRToken(start=29.99, end=30.09, text=' that'),
        ASRToken(start=30.21, end=30.31, text=' use'),
        ASRToken(start=30.51, end=30.61, text=' voice'),
        ASRToken(start=30.83, end=30.93, text=' activity'),
    ]
    diarization_segments = [
        SpeakerSegment(start=1.3255040645599365, end=4.3255040645599365, speaker=0),
        SpeakerSegment(start=4.806154012680054, end=9.806154012680054, speaker=0),
        SpeakerSegment(start=9.806154012680054, end=10.806154012680054, speaker=1),
        SpeakerSegment(start=11.168735027313232, end=14.168735027313232, speaker=1),
        SpeakerSegment(start=14.41029405593872, end=17.41029405593872, speaker=1),
        SpeakerSegment(start=17.52983808517456, end=19.52983808517456, speaker=1),
        SpeakerSegment(start=19.64953374862671, end=20.066200415293377, speaker=1),
        SpeakerSegment(start=20.066200415293377, end=22.64953374862671, speaker=2),
        SpeakerSegment(start=23.012792587280273, end=25.012792587280273, speaker=2),
        SpeakerSegment(start=25.495875597000122, end=26.41254226366679, speaker=2),
        SpeakerSegment(start=26.41254226366679, end=30.495875597000122, speaker=0),
    ]
    state = State(
        tokens=tokens,
        last_validated_token=72,
        last_speaker=-1,
        last_punctuation_index=71,
        translation_validated_segments=[],
        buffer_translation=Transcript(start=0, end=0, speaker=-1),
        buffer_transcription=Transcript(start=None, end=None, speaker=-1),
        diarization_segments=diarization_segments,
        end_buffer=31.21587559700018,
        end_attributed_speaker=30.495875597000122,
        remaining_time_transcription=0.4,
        remaining_time_diarization=0.7,
        beg_loop=1763627603.968919
    )
    alignment = TokensAlignment(state)


@@ -1,36 +1,23 @@
import asyncio
import numpy as np
-from time import time, sleep
-import math
+from time import time
import logging
import traceback
-from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, StateLight, Transcript, ChangeSpeaker
+from typing import Optional, Union, List, Any, AsyncGenerator
+from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript, ChangeSpeaker
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
from whisperlivekit.silero_vad_iterator import FixedVADIterator
-from whisperlivekit.results_formater import format_output
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
-from whisperlivekit.TokensAlignment import TokensAlignment
+from whisperlivekit.tokens_alignment import TokensAlignment

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

SENTINEL = object()  # unique sentinel object for end of stream marker
+MIN_DURATION_REAL_SILENCE = 5

-def cut_at(cumulative_pcm, cut_sec):
-    cumulative_len = 0
-    cut_sample = int(cut_sec * 16000)
-    for ind, pcm_array in enumerate(cumulative_pcm):
-        if (cumulative_len + len(pcm_array)) >= cut_sample:
-            cut_chunk = cut_sample - cumulative_len
-            before = np.concatenate(cumulative_pcm[:ind] + [cumulative_pcm[ind][:cut_chunk]])
-            after = [cumulative_pcm[ind][cut_chunk:]] + cumulative_pcm[ind+1:]
-            return before, after
-        cumulative_len += len(pcm_array)
-    return np.concatenate(cumulative_pcm), []

-async def get_all_from_queue(queue):
-    items = []
+async def get_all_from_queue(queue: asyncio.Queue) -> Union[object, Silence, np.ndarray, List[Any]]:
+    items: List[Any] = []
    first_item = await queue.get()
    queue.task_done()
@@ -61,7 +48,7 @@ class AudioProcessor:
    Handles audio processing, state management, and result formatting.
    """

-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any) -> None:
        """Initialize the audio processor with configuration, models, and state."""
        if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
@@ -80,33 +67,27 @@ class AudioProcessor:
        self.is_pcm_input = self.args.pcm_input

        # State management
-        self.is_stopping = False
-        self.silence = True
-        self.silence_duration = 0.0
-        self.start_silence = None
-        self.last_silence_dispatch_time = None
-        self.state = State()
-        self.state_light = StateLight()
-        self.lock = asyncio.Lock()
-        self.sep = " "  # Default separator
-        self.last_response_content = FrontData()
-        self.last_detected_speaker = None
-        self.speaker_languages = {}
-        self.tokens_alignment = TokensAlignment(self.state_light, self.args, self.sep)
-        self.beg_loop = None
+        self.is_stopping: bool = False
+        self.current_silence: Optional[Silence] = None
+        self.state: State = State()
+        self.lock: asyncio.Lock = asyncio.Lock()
+        self.sep: str = " "  # Default separator
+        self.last_response_content: FrontData = FrontData()
+        self.tokens_alignment: TokensAlignment = TokensAlignment(self.state, self.args, self.sep)
+        self.beg_loop: Optional[float] = None

        # Models and processing
-        self.asr = models.asr
-        self.vac_model = models.vac_model
+        self.asr: Any = models.asr
+        self.vac_model: Any = models.vac_model
        if self.args.vac:
-            self.vac = FixedVADIterator(models.vac_model)
+            self.vac: Optional[FixedVADIterator] = FixedVADIterator(models.vac_model)
        else:
-            self.vac = None
+            self.vac: Optional[FixedVADIterator] = None
-        self.ffmpeg_manager = None
-        self.ffmpeg_reader_task = None
-        self._ffmpeg_error = None
+        self.ffmpeg_manager: Optional[FFmpegManager] = None
+        self.ffmpeg_reader_task: Optional[asyncio.Task] = None
+        self._ffmpeg_error: Optional[str] = None
        if not self.is_pcm_input:
            self.ffmpeg_manager = FFmpegManager(
@@ -118,21 +99,20 @@ class AudioProcessor:
                self._ffmpeg_error = error_type
            self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error

-        self.transcription_queue = asyncio.Queue() if self.args.transcription else None
-        self.diarization_queue = asyncio.Queue() if self.args.diarization else None
-        self.translation_queue = asyncio.Queue() if self.args.target_language else None
-        self.pcm_buffer = bytearray()
-        self.total_pcm_samples = 0
-        self.end_buffer = 0.0
-        self.transcription_task = None
-        self.diarization_task = None
-        self.translation_task = None
-        self.watchdog_task = None
-        self.all_tasks_for_cleanup = []
+        self.transcription_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.transcription else None
+        self.diarization_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.diarization else None
+        self.translation_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.target_language else None
+        self.pcm_buffer: bytearray = bytearray()
+        self.total_pcm_samples: int = 0
+        self.transcription_task: Optional[asyncio.Task] = None
+        self.diarization_task: Optional[asyncio.Task] = None
+        self.translation_task: Optional[asyncio.Task] = None
+        self.watchdog_task: Optional[asyncio.Task] = None
+        self.all_tasks_for_cleanup: List[asyncio.Task] = []

-        self.transcription = None
-        self.translation = None
-        self.diarization = None
+        self.transcription: Optional[Any] = None
+        self.translation: Optional[Any] = None
+        self.diarization: Optional[Any] = None
        if self.args.transcription:
            self.transcription = online_factory(self.args, models.asr)
@@ -142,44 +122,45 @@ class AudioProcessor:
        if models.translation_model:
            self.translation = online_translation_factory(self.args, models.translation_model)

-    async def _push_silence_event(self, silence_buffer: Silence):
+    async def _push_silence_event(self) -> None:
        if self.transcription_queue:
-            await self.transcription_queue.put(silence_buffer)
+            await self.transcription_queue.put(self.current_silence)
        if self.args.diarization and self.diarization_queue:
-            await self.diarization_queue.put(silence_buffer)
+            await self.diarization_queue.put(self.current_silence)
        if self.translation_queue:
-            await self.translation_queue.put(silence_buffer)
+            await self.translation_queue.put(self.current_silence)

-    async def _begin_silence(self):
-        if self.silence:
+    async def _begin_silence(self) -> None:
+        if self.current_silence:
            return
-        self.silence = True
-        now = time()
-        self.start_silence = now
-        self.last_silence_dispatch_time = now
-        await self._push_silence_event(Silence(is_starting=True))
+        now = time() - self.beg_loop
+        self.current_silence = Silence(
+            is_starting=True, start=now
+        )
+        await self._push_silence_event()

-    async def _end_silence(self):
-        if not self.silence:
+    async def _end_silence(self) -> None:
+        if not self.current_silence:
            return
-        now = time()
-        duration = now - (self.last_silence_dispatch_time if self.last_silence_dispatch_time else self.beg_loop)
-        await self._push_silence_event(Silence(duration=duration, has_ended=True))
-        self.last_silence_dispatch_time = now
-        self.silence = False
-        self.start_silence = None
-        self.last_silence_dispatch_time = None
+        now = time() - self.beg_loop
+        self.current_silence.end = now
+        self.current_silence.is_starting = False
+        self.current_silence.has_ended = True
+        self.current_silence.compute_duration()
+        if self.current_silence.duration > MIN_DURATION_REAL_SILENCE:
+            self.state.new_tokens.append(self.current_silence)
+        await self._push_silence_event()
+        self.current_silence = None

-    async def _enqueue_active_audio(self, pcm_chunk: np.ndarray):
+    async def _enqueue_active_audio(self, pcm_chunk: np.ndarray) -> None:
        if pcm_chunk is None or pcm_chunk.size == 0:
            return
        if self.transcription_queue:
            await self.transcription_queue.put(pcm_chunk.copy())
        if self.args.diarization and self.diarization_queue:
            await self.diarization_queue.put(pcm_chunk.copy())
-        self.silence_duration = 0.0

-    def _slice_before_silence(self, pcm_array, chunk_sample_start, silence_sample):
+    def _slice_before_silence(self, pcm_array: np.ndarray, chunk_sample_start: int, silence_sample: Optional[int]) -> Optional[np.ndarray]:
        if silence_sample is None:
            return None
        relative_index = int(silence_sample - chunk_sample_start)
@@ -190,22 +171,22 @@ class AudioProcessor:
            return None
        return pcm_array[:split_index]

-    def convert_pcm_to_float(self, pcm_buffer):
+    def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray:
        """Convert PCM buffer in s16le format to normalized NumPy array."""
        return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0

-    async def get_current_state(self):
+    async def get_current_state(self) -> State:
        """Get current state."""
        async with self.lock:
            current_time = time()
            remaining_transcription = 0
-            if self.end_buffer > 0:
-                remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 1))
+            if self.state.end_buffer > 0:
+                remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1))
            remaining_diarization = 0
            if self.state.tokens:
-                latest_end = max(self.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
+                latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
                remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1))
            self.state.remaining_time_transcription = remaining_transcription
@@ -213,7 +194,7 @@ class AudioProcessor:
            return self.state

-    async def ffmpeg_stdout_reader(self):
+    async def ffmpeg_stdout_reader(self) -> None:
        """Read audio data from FFmpeg stdout and process it into the PCM pipeline."""
        beg = time()
        while True:
@@ -263,7 +244,7 @@ class AudioProcessor:
        if self.translation:
            await self.translation_queue.put(SENTINEL)

-    async def transcription_processor(self):
+    async def transcription_processor(self) -> None:
        """Process audio chunks for transcription."""
        cumulative_pcm_duration_stream_time = 0.0
@@ -276,11 +257,11 @@ class AudioProcessor:
                    break

                asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
-                transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
+                transcription_lag_s = max(0.0, time() - self.beg_loop - self.state.end_buffer)
                asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
                stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
                new_tokens = []
-                current_audio_processed_upto = self.end_buffer
+                current_audio_processed_upto = self.state.end_buffer
                if isinstance(item, Silence):
                    if item.is_starting:
@@ -318,7 +299,7 @@ class AudioProcessor:
                if buffer_text.startswith(validated_text):
                    _buffer_transcript.text = buffer_text[len(validated_text):].lstrip()

-                candidate_end_times = [self.end_buffer]
+                candidate_end_times = [self.state.end_buffer]
                if new_tokens:
                    candidate_end_times.append(new_tokens[-1].end)
@@ -331,10 +312,9 @@ class AudioProcessor:
                async with self.lock:
                    self.state.tokens.extend(new_tokens)
                    self.state.buffer_transcription = _buffer_transcript
-                    self.end_buffer = max(candidate_end_times)
-                    self.state_light.new_tokens = new_tokens
-                    self.state_light.new_tokens += 1
-                    self.state_light.new_tokens_buffer = _buffer_transcript
+                    self.state.end_buffer = max(candidate_end_times)
+                    self.state.new_tokens.extend(new_tokens)
+                    self.state.new_tokens_buffer = _buffer_transcript

                if self.translation_queue:
                    for token in new_tokens:
@@ -355,7 +335,7 @@ class AudioProcessor:
        logger.info("Transcription processor task finished.")

-    async def diarization_processor(self):
+    async def diarization_processor(self) -> None:
        while True:
            try:
                item = await get_all_from_queue(self.diarization_queue)
@@ -368,41 +348,44 @@ class AudioProcessor:
                self.diarization.insert_audio_chunk(item)
                diarization_segments = await self.diarization.diarize()
-                self.state_light.new_diarization = diarization_segments
-                self.state_light.new_diarization_index += 1
+                self.state.new_diarization = diarization_segments
            except Exception as e:
                logger.warning(f"Exception in diarization_processor: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
        logger.info("Diarization processor task finished.")

-    async def translation_processor(self):
+    async def translation_processor(self) -> None:
        # the idea is to ignore diarization for the moment. We use only transcription tokens.
        # And the speaker is attributed given the segments used for the translation
        # in the future we want to have different languages for each speaker etc, so it will be more complex.
        while True:
            try:
-                tokens_to_process = await get_all_from_queue(self.translation_queue)
-                if tokens_to_process is SENTINEL:
+                item = await get_all_from_queue(self.translation_queue)
+                if item is SENTINEL:
                    logger.debug("Translation processor received sentinel. Finishing.")
+                    self.translation_queue.task_done()
                    break
-                elif type(tokens_to_process) is Silence:
-                    if tokens_to_process.has_ended:
-                        self.translation.insert_silence(tokens_to_process.duration)
-                    continue
-                if tokens_to_process:
-                    self.translation.insert_tokens(tokens_to_process)
-                translation_validated_segments, buffer_translation = await asyncio.to_thread(self.translation.process)
-                async with self.lock:
-                    self.state.translation_validated_segments = translation_validated_segments
-                    self.state.buffer_translation = buffer_translation
+                elif type(item) is Silence:
+                    if item.is_starting:
+                        new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
+                    if item.has_ended:
+                        self.translation.insert_silence(item.duration)
+                    continue
+                elif isinstance(item, ChangeSpeaker):
+                    new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
+                    pass
+                else:
+                    self.translation.insert_tokens(item)
+                    new_translation, new_translation_buffer = await asyncio.to_thread(self.translation.process)
+                async with self.lock:
+                    self.state.new_translation.append(new_translation)
+                    self.state.new_translation_buffer = new_translation_buffer
            except Exception as e:
                logger.warning(f"Exception in translation_processor: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
        logger.info("Translation processor task finished.")

-    async def results_formatter(self):
+    async def results_formatter(self) -> AsyncGenerator[FrontData, None]:
        """Format processing results for output."""
        while True:
            try:
@@ -412,55 +395,32 @@ class AudioProcessor:
                    await asyncio.sleep(1)
                    continue

-                state = await self.get_current_state()
-                self.tokens_alignment.compute_punctuations_segments()
-                lines, undiarized_text = format_output(
-                    state,
-                    self.silence,
-                    args=self.args,
-                    sep=self.sep
-                )
-                if lines and lines[-1].speaker == -2:
-                    buffer_transcription = Transcript()
-                else:
-                    buffer_transcription = state.buffer_transcription
-                buffer_diarization = ''
-                if undiarized_text:
-                    buffer_diarization = self.sep.join(undiarized_text)
-                    async with self.lock:
-                        self.state.end_attributed_speaker = state.end_attributed_speaker
-                buffer_translation_text = ''
-                if state.buffer_translation:
-                    raw_buffer_translation = getattr(state.buffer_translation, 'text', state.buffer_translation)
-                    if raw_buffer_translation:
-                        buffer_translation_text = raw_buffer_translation.strip()
+                self.tokens_alignment.update()
+                lines, buffer_diarization_text, buffer_translation_text = self.tokens_alignment.get_lines(
+                    diarization=self.args.diarization,
+                    translation=bool(self.translation),
+                    current_silence=self.current_silence
+                )
+                state = await self.get_current_state()
+                buffer_transcription_text = state.buffer_transcription.text if state.buffer_transcription else ''

                response_status = "active_transcription"
-                if not state.tokens and not buffer_transcription and not buffer_diarization:
+                if not lines and not buffer_transcription_text and not buffer_diarization_text:
                    response_status = "no_audio_detected"
-                    lines = []
-                elif not lines:
-                    lines = [Line(
-                        speaker=1,
-                        start=state.end_buffer,
-                        end=state.end_buffer
-                    )]

                response = FrontData(
                    status=response_status,
                    lines=lines,
-                    buffer_transcription=buffer_transcription.text.strip(),
-                    buffer_diarization=buffer_diarization,
+                    buffer_transcription=buffer_transcription_text,
+                    buffer_diarization=buffer_diarization_text,
                    buffer_translation=buffer_translation_text,
                    remaining_time_transcription=state.remaining_time_transcription,
                    remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
) )
should_push = (response != self.last_response_content) should_push = (response != self.last_response_content)
if should_push and (lines or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"): if should_push:
yield response yield response
self.last_response_content = response self.last_response_content = response
@@ -474,17 +434,17 @@ class AudioProcessor:
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}") logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
-   async def create_tasks(self):
+   async def create_tasks(self) -> AsyncGenerator[FrontData, None]:
        """Create and start processing tasks."""
        self.all_tasks_for_cleanup = []
-       processing_tasks_for_watchdog = []
+       processing_tasks_for_watchdog: List[asyncio.Task] = []
        # If using FFmpeg (non-PCM input), start it and spawn stdout reader
        if not self.is_pcm_input:
            success = await self.ffmpeg_manager.start()
            if not success:
                logger.error("Failed to start FFmpeg manager")
-               async def error_generator():
+               async def error_generator() -> AsyncGenerator[FrontData, None]:
                    yield FrontData(
                        status="error",
                        error="FFmpeg failed to start. Please check that FFmpeg is installed."
@@ -515,9 +475,9 @@ class AudioProcessor:
        return self.results_formatter()

-   async def watchdog(self, tasks_to_monitor):
+   async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None:
        """Monitors the health of critical processing tasks."""
-       tasks_remaining = [task for task in tasks_to_monitor if task]
+       tasks_remaining: List[asyncio.Task] = [task for task in tasks_to_monitor if task]
        while True:
            try:
                if not tasks_remaining:
@@ -542,7 +502,7 @@ class AudioProcessor:
            except Exception as e:
                logger.error(f"Error in watchdog task: {e}", exc_info=True)

-   async def cleanup(self):
+   async def cleanup(self) -> None:
        """Clean up resources when processing is complete."""
        logger.info("Starting cleanup of AudioProcessor resources.")
        self.is_stopping = True
@@ -565,7 +525,7 @@ class AudioProcessor:
            self.diarization.close()
        logger.info("AudioProcessor cleanup complete.")

-   def _processing_tasks_done(self):
+   def _processing_tasks_done(self) -> bool:
        """Return True when all active processing tasks have completed."""
        tasks_to_check = [
            self.transcription_task,
@@ -576,11 +536,13 @@ class AudioProcessor:
        return all(task.done() for task in tasks_to_check if task)

-   async def process_audio(self, message):
+   async def process_audio(self, message: Optional[bytes]) -> None:
        """Process incoming audio data."""
        if not self.beg_loop:
            self.beg_loop = time()
+           self.current_silence = Silence(start=0.0, is_starting=True)
+           self.tokens_alignment.beg_loop = self.beg_loop
        if not message:
            logger.info("Empty audio message received, initiating stop sequence.")
@@ -613,7 +575,7 @@ class AudioProcessor:
        else:
            logger.warning("Failed to write audio data to FFmpeg")

-   async def handle_pcm_data(self):
+   async def handle_pcm_data(self) -> None:
        # Process when enough data
        if len(self.pcm_buffer) < self.bytes_per_sec:
            return
@@ -642,17 +604,17 @@ class AudioProcessor:
        if res is not None:
            silence_detected = res.get("end", 0) > res.get("start", 0)
-           if silence_detected and not self.silence:
+           if silence_detected and not self.current_silence:
                pre_silence_chunk = self._slice_before_silence(
                    pcm_array, chunk_sample_start, res.get("end")
                )
                if pre_silence_chunk is not None and pre_silence_chunk.size > 0:
                    await self._enqueue_active_audio(pre_silence_chunk)
                await self._begin_silence()
-           elif self.silence:
+           elif self.current_silence:
                await self._end_silence()
-           if not self.silence:
+           if not self.current_silence:
                await self._enqueue_active_audio(pcm_array)
        self.total_pcm_samples = chunk_sample_end

View File

@@ -224,7 +224,8 @@ class MLXWhisper(ASRBase):
            if segment.get("no_speech_prob", 0) > 0.9:
                continue
            for word in segment.get("words", []):
-               token = ASRToken(word["start"], word["end"], word["word"], probability=word["probability"])
+               probability=word["probability"]
+               token = ASRToken(word["start"], word["end"], word["word"])
                tokens.append(token)
        return tokens

View File

@@ -411,11 +411,11 @@ class OnlineASRProcessor:
    ) -> Transcript:
        sep = sep if sep is not None else self.asr.sep
        text = sep.join(token.text for token in tokens)
-       probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
+       # probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
        if tokens:
            start = offset + tokens[0].start
            end = offset + tokens[-1].end
        else:
            start = None
            end = None
-       return Transcript(start, end, text, probability=probability)
+       return Transcript(start, end, text)

View File

@@ -1,103 +0,0 @@
from whisperlivekit.timed_objects import ASRToken
from time import time
import re

MIN_SILENCE_DURATION = 4  # in seconds
END_SILENCE_DURATION = 8  # in seconds. Keep it large to avoid false positives when the model lag is important
END_SILENCE_DURATION_VAC = 3  # VAC is good at detecting silences, but we want to skip the smallest silences

def blank_to_silence(tokens):
    full_string = ''.join([t.text for t in tokens])
    patterns = [re.compile(r'(?:\s*\[BLANK_AUDIO\]\s*)+'), re.compile(r'(?:\s*\[typing\]\s*)+')]
    matches = []
    for pattern in patterns:
        for m in pattern.finditer(full_string):
            matches.append({
                'start': m.start(),
                'end': m.end()
            })
    if matches:
        # cleaned = pattern.sub(' ', full_string).strip()
        # print("Cleaned:", cleaned)
        cumulated_len = 0
        silence_token = None
        cleaned_tokens = []
        for token in tokens:
            if matches:
                start = cumulated_len
                end = cumulated_len + len(token.text)
                cumulated_len = end
                if start >= matches[0]['start'] and end <= matches[0]['end']:
                    if silence_token:  # previous token was already silence
                        silence_token.start = min(silence_token.start, token.start)
                        silence_token.end = max(silence_token.end, token.end)
                    else:  # new silence
                        silence_token = ASRToken(
                            start=token.start,
                            end=token.end,
                            speaker=-2,
                        )
                else:
                    if silence_token:  # there was silence but no more
                        if silence_token.duration() >= MIN_SILENCE_DURATION:
                            cleaned_tokens.append(
                                silence_token
                            )
                        silence_token = None
                        matches.pop(0)
                    cleaned_tokens.append(token)
        # print(cleaned_tokens)
        return cleaned_tokens
    return tokens

def no_token_to_silence(tokens):
    new_tokens = []
    silence_token = None
    for token in tokens:
        if token.speaker == -2:
            if new_tokens and new_tokens[-1].speaker == -2:  # if token is silence and previous one too
                new_tokens[-1].end = token.end
            else:
                new_tokens.append(token)
        last_end = new_tokens[-1].end if new_tokens else 0.0
        if token.start - last_end >= MIN_SILENCE_DURATION:  # if token is not silence but there is an important gap
            if new_tokens and new_tokens[-1].speaker == -2:
                new_tokens[-1].end = token.start
            else:
                silence_token = ASRToken(
                    start=last_end,
                    end=token.start,
                    speaker=-2,
                )
                new_tokens.append(silence_token)
        if token.speaker != -2:
            new_tokens.append(token)
    return new_tokens

def ends_with_silence(tokens, beg_loop, vac_detected_silence):
    current_time = time() - (beg_loop if beg_loop else 0.0)
    last_token = tokens[-1]
    if vac_detected_silence or (current_time - last_token.end >= END_SILENCE_DURATION):
        if last_token.speaker == -2:
            last_token.end = current_time
        else:
            tokens.append(
                ASRToken(
                    start=tokens[-1].end,
                    end=current_time,
                    speaker=-2,
                )
            )
    return tokens

def handle_silences(tokens, beg_loop, vac_detected_silence):
    if not tokens:
        return []
    tokens = blank_to_silence(tokens)  # useful for the simulstreaming backend, which tends to generate [BLANK_AUDIO] text
    tokens = no_token_to_silence(tokens)
    tokens = ends_with_silence(tokens, beg_loop, vac_detected_silence)
    return tokens

View File

@@ -1,60 +0,0 @@
########### WHAT IS PRODUCED: ###########
SPEAKER 1 0:00:04 - 0:00:06
Transcription technology has improved so much in the past
SPEAKER 1 0:00:07 - 0:00:12
years. Have you noticed how accurate real-time speech detects is now?
SPEAKER 2 0:00:12 - 0:00:12
Absolutely
SPEAKER 1 0:00:13 - 0:00:13
.
SPEAKER 2 0:00:14 - 0:00:14
I
SPEAKER 1 0:00:14 - 0:00:17
use it all the time for taking notes during meetings.
SPEAKER 2 0:00:17 - 0:00:17
It
SPEAKER 1 0:00:17 - 0:00:22
's amazing how it can recognize different speakers, and even add punctuation.
SPEAKER 2 0:00:22 - 0:00:22
Yeah
SPEAKER 1 0:00:23 - 0:00:26
, but sometimes noise can still cause mistakes.
SPEAKER 3 0:00:26 - 0:00:27
Does
SPEAKER 1 0:00:27 - 0:00:28
this system handle that
SPEAKER 1 0:00:29 - 0:00:29
?
SPEAKER 3 0:00:29 - 0:00:29
It
SPEAKER 1 0:00:29 - 0:00:33
does a pretty good job filtering noise, especially with models that use voice activity
########### WHAT SHOULD BE PRODUCED: ###########
SPEAKER 1 0:00:04 - 0:00:12
Transcription technology has improved so much in the past years. Have you noticed how accurate real-time speech detects is now?
SPEAKER 2 0:00:12 - 0:00:22
Absolutely. I use it all the time for taking notes during meetings. It's amazing how it can recognize different speakers, and even add punctuation.
SPEAKER 3 0:00:22 - 0:00:28
Yeah, but sometimes noise can still cause mistakes. Does this system handle that well?
SPEAKER 1 0:00:29 - 0:00:29
It does a pretty good job filtering noise, especially with models that use voice activity

View File

@@ -1,257 +0,0 @@
import logging
import re
from whisperlivekit.remove_silences import handle_silences
from whisperlivekit.timed_objects import Line, format_time, SpeakerSegment
from typing import List

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

CHECK_AROUND = 4
DEBUG = False

def next_punctuation_change(i, tokens):
    for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)):
        if tokens[ind].is_punctuation():
            return ind
    return None

def next_speaker_change(i, tokens, speaker):
    for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1):
        token = tokens[ind]
        if token.is_punctuation():
            break
        if token.speaker != speaker:
            return ind, token.speaker
    return None, speaker

def new_line(
    token,
):
    return Line(
        speaker = token.corrected_speaker,
        text = token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else ""),
        start = token.start,
        end = token.end,
        detected_language=token.detected_language
    )

def append_token_to_last_line(lines, sep, token):
    if not lines:
        lines.append(new_line(token))
    else:
        if token.text:
            lines[-1].text += sep + token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else "")
        lines[-1].end = token.end
        if not lines[-1].detected_language and token.detected_language:
            lines[-1].detected_language = token.detected_language

def extract_number(s) -> int:
    """Extract number from speaker string (for diart compatibility)."""
    if isinstance(s, int):
        return s
    m = re.search(r'\d+', str(s))
    return int(m.group()) if m else 0

def concatenate_speakers(segments: List[SpeakerSegment]) -> List[dict]:
    """Concatenate consecutive segments from the same speaker."""
    if not segments:
        return []
    # Get speaker number from first segment
    first_speaker = extract_number(segments[0].speaker)
    segments_concatenated = [{"speaker": first_speaker + 1, "begin": segments[0].start, "end": segments[0].end}]
    for segment in segments[1:]:
        speaker = extract_number(segment.speaker) + 1
        if segments_concatenated[-1]['speaker'] != speaker:
            segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
        else:
            segments_concatenated[-1]['end'] = segment.end
    return segments_concatenated

def add_speaker_to_tokens_with_punctuation(segments: List[SpeakerSegment], tokens: list) -> list:
    """Assign speakers to tokens with punctuation-aware boundary adjustment."""
    punctuation_marks = {'.', '!', '?'}
    punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
    segments_concatenated = concatenate_speakers(segments)
    for ind, segment in enumerate(segments_concatenated):
        for i, punctuation_token in enumerate(punctuation_tokens):
            if punctuation_token.start > segment['end']:
                after_length = punctuation_token.start - segment['end']
                before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
                if before_length > after_length:
                    segment['end'] = punctuation_token.start
                    if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
                        segments_concatenated[ind + 1]['begin'] = punctuation_token.start
                else:
                    segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
                    if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
                        segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
                break
    # Ensure non-overlapping tokens
    last_end = 0.0
    for token in tokens:
        start = max(last_end + 0.01, token.start)
        token.start = start
        token.end = max(start, token.end)
        last_end = token.end
    # Assign speakers based on adjusted segments
    ind_last_speaker = 0
    for segment in segments_concatenated:
        for i, token in enumerate(tokens[ind_last_speaker:]):
            if token.end <= segment['end']:
                token.speaker = segment['speaker']
                ind_last_speaker = i + 1
            elif token.start > segment['end']:
                break
    return tokens

def assign_speakers_to_tokens(tokens: list, segments: List[SpeakerSegment], use_punctuation_split: bool = False) -> list:
    """
    Assign speakers to tokens based on timing overlap with speaker segments.

    Args:
        tokens: List of tokens with timing information
        segments: List of speaker segments
        use_punctuation_split: Whether to use punctuation for boundary refinement

    Returns:
        List of tokens with speaker assignments
    """
    if not segments or not tokens:
        logger.debug("No segments or tokens available for speaker assignment")
        return tokens
    logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
    if not use_punctuation_split:
        # Simple overlap-based assignment
        for token in tokens:
            token.speaker = -1  # Default to no speaker
            for segment in segments:
                # Check for timing overlap
                if not (segment.end <= token.start or segment.start >= token.end):
                    speaker_num = extract_number(segment.speaker)
                    token.speaker = speaker_num + 1  # Convert to 1-based indexing
                    break
    else:
        # Use punctuation-aware assignment
        tokens = add_speaker_to_tokens_with_punctuation(segments, tokens)
    return tokens

def format_output(state, silence, args, sep):
    diarization = args.diarization
    disable_punctuation_split = args.disable_punctuation_split
    tokens = state.tokens
    translation_validated_segments = state.translation_validated_segments  # Here we will attribute the speakers only based on the timestamps of the segments
    last_validated_token = state.last_validated_token
    last_speaker = abs(state.last_speaker)
    undiarized_text = []
    tokens = handle_silences(tokens, state.beg_loop, silence)
    # Assign speakers to tokens based on segments stored in state
    if False and diarization and state.diarization_segments:
        use_punctuation_split = args.punctuation_split if hasattr(args, 'punctuation_split') else False
        tokens = assign_speakers_to_tokens(tokens, state.diarization_segments, use_punctuation_split=use_punctuation_split)
    for i in range(last_validated_token, len(tokens)):
        token = tokens[i]
        speaker = int(token.speaker)
        token.corrected_speaker = speaker
        if True or not diarization:
            if speaker == -1:  # Speaker -1 means not attributed by diarization. In the frontend, it should appear under 'Speaker 1'
                token.corrected_speaker = 1
            token.validated_speaker = True
        else:
            if token.speaker == -1:
                undiarized_text.append(token.text)
            elif token.is_punctuation():
                state.last_punctuation_index = i
                token.corrected_speaker = last_speaker
                token.validated_speaker = True
            elif state.last_punctuation_index == i-1:
                if token.speaker != last_speaker:
                    token.corrected_speaker = token.speaker
                    token.validated_speaker = True
                    # perfect, diarization perfectly aligned
            else:
                speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
                if speaker_change_pos:
                    # Corrects delay:
                    #   That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
                    # should become:
                    #   That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
                    token.corrected_speaker = new_speaker
                    token.validated_speaker = True
                elif speaker != last_speaker:
                    if not (speaker == -2 or last_speaker == -2):
                        if next_punctuation_change(i, tokens):
                            # Corrects advance:
                            #   Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
                            # should become:
                            #   Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
                            token.corrected_speaker = last_speaker
                            token.validated_speaker = True
                        else:  # Problematic, except if the language has no punctuation. We append to the previous line, unless disable_punctuation_split is set to True.
                            if not disable_punctuation_split:
                                token.corrected_speaker = last_speaker
                                token.validated_speaker = False
        if token.validated_speaker:
            state.last_validated_token = i
            state.last_speaker = token.corrected_speaker
    last_speaker = 1
    lines = []
    for token in tokens:
        if token.corrected_speaker != -1:
            if int(token.corrected_speaker) != int(last_speaker):
                lines.append(new_line(token))
            else:
                append_token_to_last_line(lines, sep, token)
            last_speaker = token.corrected_speaker
    if lines:
        unassigned_translated_segments = []
        for ts in translation_validated_segments:
            assigned = False
            for line in lines:
                if ts and ts.overlaps_with(line):
                    if ts.is_within(line):
                        line.translation += ts.text + ' '
                        assigned = True
                        break
                    else:
                        ts0, ts1 = ts.approximate_cut_at(line.end)
                        if ts0 and line.overlaps_with(ts0):
                            line.translation += ts0.text + ' '
                            if ts1:
                                unassigned_translated_segments.append(ts1)
                            assigned = True
                            break
            if not assigned:
                unassigned_translated_segments.append(ts)
        if unassigned_translated_segments:
            for line in lines:
                remaining_segments = []
                for ts in unassigned_translated_segments:
                    if ts and ts.overlaps_with(line):
                        line.translation += ts.text + ' '
                    else:
                        remaining_segments.append(ts)
                unassigned_translated_segments = remaining_segments  # maybe do something in the future about that
    if state.buffer_transcription and lines:
        lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
    return lines, undiarized_text
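The deleted `concatenate_speakers` above merges consecutive diarization segments that share a speaker. The same merge logic on plain dicts, as a self-contained sketch (stand-in data, not the project's `SpeakerSegment`):

```python
def merge_consecutive(segments):
    """Merge adjacent segments that share a speaker; each run keeps
    the earliest begin and the latest end."""
    merged = []
    for seg in segments:
        if merged and merged[-1]["speaker"] == seg["speaker"]:
            merged[-1]["end"] = seg["end"]
        else:
            merged.append(dict(seg))
    return merged

segments = [
    {"speaker": 1, "begin": 0.0, "end": 2.0},
    {"speaker": 1, "begin": 2.1, "end": 4.0},
    {"speaker": 2, "begin": 4.2, "end": 5.0},
]
merged = merge_consecutive(segments)
# -> two segments: speaker 1 spanning 0.0-4.0, then speaker 2
```

The real function additionally normalizes diart speaker labels with `extract_number` and converts them to 1-based indices; this sketch only shows the run-merging step.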

View File

@@ -266,7 +266,7 @@ class AlignAtt:
        logger.debug("Refreshing segment:")
        self.init_tokens()
        self.last_attend_frame = -self.cfg.rewind_threshold
-       self.detected_language = None
+       # self.detected_language = None
        self.cumulative_time_offset = 0.0
        self.init_context()
        logger.debug(f"Context: {self.context}")

View File

@@ -0,0 +1,261 @@
"""
ALPHA. Results are not great yet.

To replace `from whisperlivekit.silero_vad_iterator import FixedVADIterator`
with `from whisperlivekit.ten_vad_iterator import TenVADIterator`:
use `self.vac = TenVADIterator()` instead of `self.vac = FixedVADIterator(models.vac_model)`.
"""
import numpy as np
from ten_vad import TenVad

class TenVADIterator:
    def __init__(self,
                 threshold: float = 0.5,
                 sampling_rate: int = 16000,
                 min_silence_duration_ms: int = 100,
                 speech_pad_ms: int = 30):
        self.vad = TenVad()
        self.threshold = threshold
        self.sampling_rate = sampling_rate
        self.min_silence_duration_ms = min_silence_duration_ms
        self.speech_pad_ms = speech_pad_ms
        self.min_silence_samples = int(sampling_rate * min_silence_duration_ms / 1000)
        self.speech_pad_samples = int(sampling_rate * speech_pad_ms / 1000)
        self.reset_states()

    def reset_states(self):
        self.triggered = False
        self.temp_end = 0
        self.current_sample = 0
        self.buffer = np.array([], dtype=np.float32)

    def __call__(self, x, return_seconds=False):
        if not isinstance(x, np.ndarray):
            x = np.array(x, dtype=np.float32)
        self.buffer = np.append(self.buffer, x)
        chunk_size = 256
        ret = None
        while len(self.buffer) >= chunk_size:
            chunk = self.buffer[:chunk_size].astype(np.int16)
            self.buffer = self.buffer[chunk_size:]
            window_size_samples = len(chunk)
            self.current_sample += window_size_samples
            speech_prob, speech_flag = self.vad.process(chunk)
            if (speech_prob >= self.threshold) and self.temp_end:
                self.temp_end = 0
            if (speech_prob >= self.threshold) and not self.triggered:
                self.triggered = True
                speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
                result = {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
                if ret is None:
                    ret = result
                elif "end" in ret:
                    ret = result
                else:
                    ret.update(result)
            if (speech_prob < self.threshold - 0.15) and self.triggered:
                if not self.temp_end:
                    self.temp_end = self.current_sample
                if self.current_sample - self.temp_end < self.min_silence_samples:
                    continue
                else:
                    speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
                    self.temp_end = 0
                    self.triggered = False
                    result = {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
                    if ret is None:
                        ret = result
                    else:
                        ret.update(result)
        return ret if ret != {} else None

def test_on_record_wav():
    import os
    from pathlib import Path
    audio_file = Path("record.wav")
    if not audio_file.exists():
        return
    import soundfile as sf
    audio_data, sample_rate = sf.read(str(audio_file), dtype='float32')
    if len(audio_data.shape) > 1:
        audio_data = np.mean(audio_data, axis=1)
    vad = TenVADIterator(
        threshold=0.5,
        sampling_rate=sample_rate,
        min_silence_duration_ms=100,
        speech_pad_ms=30
    )
    chunk_size = 1024
    speech_segments = []
    current_segment = None
    for i in range(0, len(audio_data), chunk_size):
        chunk = audio_data[i:i+chunk_size]
        if chunk.dtype != np.int16:
            chunk_int16 = (chunk * 32767.0).astype(np.int16)
        else:
            chunk_int16 = chunk
        result = vad(chunk_int16, return_seconds=True)
        if result is not None:
            if 'start' in result:
                current_segment = {'start': result['start'], 'end': None}
                print(f"Speech start detected at {result['start']:.2f}s")
            elif 'end' in result:
                if current_segment:
                    current_segment['end'] = result['end']
                    duration = current_segment['end'] - current_segment['start']
                    speech_segments.append(current_segment)
                    print(f"Speech end detected at {result['end']:.2f}s (duration: {duration:.2f}s)")
                    current_segment = None
                else:
                    print(f"Speech end detected at {result['end']:.2f}s (no corresponding start)")
    if current_segment and current_segment['end'] is None:
        current_segment['end'] = len(audio_data) / sample_rate
        speech_segments.append(current_segment)
        print(f"End of file (last segment at {current_segment['start']:.2f}s)")
    print("-" * 60)
    print(f"\nSummary:")
    print(f"Number of speech segments detected: {len(speech_segments)}")
    if speech_segments:
        total_speech_time = sum(seg['end'] - seg['start'] for seg in speech_segments)
        total_time = len(audio_data) / sample_rate
        speech_ratio = (total_speech_time / total_time) * 100
        print(f"Total speech time: {total_speech_time:.2f}s")
        print(f"Total file time: {total_time:.2f}s")
        print(f"Speech ratio: {speech_ratio:.1f}%")
        print(f"\nDetected segments:")
        for i, seg in enumerate(speech_segments, 1):
            print(f"  {i}. {seg['start']:.2f}s - {seg['end']:.2f}s (duration: {seg['end'] - seg['start']:.2f}s)")
    else:
        print("No speech segments detected")
    print("\n" + "=" * 60)
    print("Extracting silence segments...")
    silence_segments = []
    total_time = len(audio_data) / sample_rate
    if not speech_segments:
        silence_segments = [{'start': 0.0, 'end': total_time}]
    else:
        if speech_segments[0]['start'] > 0:
            silence_segments.append({'start': 0.0, 'end': speech_segments[0]['start']})
        for i in range(len(speech_segments) - 1):
            silence_start = speech_segments[i]['end']
            silence_end = speech_segments[i + 1]['start']
            if silence_end > silence_start:
                silence_segments.append({'start': silence_start, 'end': silence_end})
        if speech_segments[-1]['end'] < total_time:
            silence_segments.append({'start': speech_segments[-1]['end'], 'end': total_time})
    silence_audio = np.array([], dtype=audio_data.dtype)
    for seg in silence_segments:
        start_sample = int(seg['start'] * sample_rate)
        end_sample = int(seg['end'] * sample_rate)
        start_sample = max(0, min(start_sample, len(audio_data)))
        end_sample = max(0, min(end_sample, len(audio_data)))
        if end_sample > start_sample:
            silence_audio = np.concatenate([silence_audio, audio_data[start_sample:end_sample]])
    if len(silence_audio) > 0:
        output_file = "record_silence_only.wav"
        try:
            import soundfile as sf
            sf.write(output_file, silence_audio, sample_rate)
            print(f"Silence file saved: {output_file}")
        except ImportError:
            try:
                from scipy.io import wavfile
                if silence_audio.dtype == np.float32:
                    silence_audio_int16 = (silence_audio * 32767.0).astype(np.int16)
                else:
                    silence_audio_int16 = silence_audio.astype(np.int16)
                wavfile.write(output_file, sample_rate, silence_audio_int16)
                print(f"Silence file saved: {output_file}")
            except ImportError:
                print("Unable to save: soundfile or scipy required")
        total_silence_time = sum(seg['end'] - seg['start'] for seg in silence_segments)
        silence_ratio = (total_silence_time / total_time) * 100
        print(f"Total silence duration: {total_silence_time:.2f}s")
        print(f"Silence ratio: {silence_ratio:.1f}%")
        print(f"Number of silence segments: {len(silence_segments)}")
        print(f"\nYou can listen to {output_file} to verify that only silences are present.")
    else:
        print("No silence segments found (file entirely speech)")
    print("\n" + "=" * 60)
    print("Extracting speech segments...")
    if speech_segments:
        speech_audio = np.array([], dtype=audio_data.dtype)
        for seg in speech_segments:
            start_sample = int(seg['start'] * sample_rate)
            end_sample = int(seg['end'] * sample_rate)
            start_sample = max(0, min(start_sample, len(audio_data)))
            end_sample = max(0, min(end_sample, len(audio_data)))
            if end_sample > start_sample:
                speech_audio = np.concatenate([speech_audio, audio_data[start_sample:end_sample]])
        if len(speech_audio) > 0:
            output_file = "record_speech_only.wav"
            try:
                import soundfile as sf
                sf.write(output_file, speech_audio, sample_rate)
                print(f"Speech file saved: {output_file}")
            except ImportError:
                try:
                    from scipy.io import wavfile
                    if speech_audio.dtype == np.float32:
                        speech_audio_int16 = (speech_audio * 32767.0).astype(np.int16)
                    else:
                        speech_audio_int16 = speech_audio.astype(np.int16)
                    wavfile.write(output_file, sample_rate, speech_audio_int16)
                    print(f"Speech file saved: {output_file}")
                except ImportError:
                    print("Unable to save: soundfile or scipy required")
            total_speech_time = sum(seg['end'] - seg['start'] for seg in speech_segments)
            speech_ratio = (total_speech_time / total_time) * 100
            print(f"Total speech duration: {total_speech_time:.2f}s")
            print(f"Speech ratio: {speech_ratio:.1f}%")
            print(f"Number of speech segments: {len(speech_segments)}")
            print(f"\nYou can listen to {output_file} to verify that only speech segments are present.")
        else:
            print("No speech audio to extract")
    else:
        print("No speech segments found (file entirely silence)")

if __name__ == "__main__":
    test_on_record_wav()
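The test harness above converts float32 audio in [-1, 1] to int16 PCM before feeding the VAD (`(chunk * 32767.0).astype(np.int16)`). The same scaling in plain Python, for illustration; note that this sketch clamps out-of-range values, which a raw numpy cast would not:

```python
def float_to_int16(samples):
    """Scale float samples in [-1.0, 1.0] to the int16 PCM range,
    clamping values that would overflow."""
    out = []
    for s in samples:
        v = int(s * 32767.0)  # truncates toward zero, like the numpy cast
        out.append(max(-32768, min(32767, v)))
    return out

pcm = float_to_int16([0.0, 0.5, -1.0, 1.0])
# -> [0, 16383, -32767, 32767]
```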

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, Any, List from typing import Optional, List, Union, Dict, Any
from datetime import timedelta from datetime import timedelta
PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''} PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''}
@@ -19,11 +19,8 @@ class TimedText(Timed):
speaker: Optional[int] = -1 speaker: Optional[int] = -1
detected_language: Optional[str] = None detected_language: Optional[str] = None
def is_punctuation(self): def has_punctuation(self) -> bool:
return self.text.strip() in PUNCTUATION_MARKS return any(char in PUNCTUATION_MARKS for char in self.text.strip())
def overlaps_with(self, other: 'TimedText') -> bool:
return not (self.end <= other.start or other.end <= self.start)
def is_within(self, other: 'TimedText') -> bool: def is_within(self, other: 'TimedText') -> bool:
return other.contains_timespan(self) return other.contains_timespan(self)
@@ -31,28 +28,26 @@ class TimedText(Timed):
     def duration(self) -> float:
         return self.end - self.start
 
-    def contains_time(self, time: float) -> bool:
-        return self.start <= time <= self.end
-
     def contains_timespan(self, other: 'TimedText') -> bool:
         return self.start <= other.start and self.end >= other.end
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         return bool(self.text)
 
+    def __str__(self) -> str:
+        return str(self.text)
+
 
 @dataclass()
 class ASRToken(TimedText):
-    corrected_speaker: Optional[int] = -1
-    validated_speaker: bool = False
-    validated_text: bool = False
-    validated_language: bool = False
-
     def with_offset(self, offset: float) -> "ASRToken":
         """Return a new token with the time offset added."""
         return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language)
 
+    def is_silence(self) -> bool:
+        return False
+
 
 @dataclass
 class Sentence(TimedText):
     pass
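`with_offset` returns a shifted copy rather than mutating the token in place; the same idea on a bare dataclass (the `Token` class below is illustrative, not the repo's `ASRToken`):

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Token:
    start: float
    end: float
    text: str

def with_offset(token: Token, offset: float) -> Token:
    # Return a shifted copy; the source token is left unchanged.
    return replace(token, start=token.start + offset, end=token.end + offset)

t = Token(1.0, 1.5, "hi")
shifted = with_offset(t, 10.0)
```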
@@ -70,6 +65,7 @@ class Transcript(TimedText):
         sep: Optional[str] = None,
         offset: float = 0
     ) -> "Transcript":
+        """Collapse multiple ASR tokens into a single transcript span."""
         sep = sep if sep is not None else ' '
         text = sep.join(token.text for token in tokens)
         if tokens:
@@ -93,47 +89,70 @@ class SpeakerSegment(Timed):
 class Translation(TimedText):
     pass
 
-    def approximate_cut_at(self, cut_time):
-        """
-        Each word in text is considered to be of duration (end-start)/len(words in text)
-        """
-        if not self.text or not self.contains_time(cut_time):
-            return self, None
-        words = self.text.split()
-        num_words = len(words)
-        if num_words == 0:
-            return self, None
-        duration_per_word = self.duration() / num_words
-        cut_word_index = int((cut_time - self.start) / duration_per_word)
-        if cut_word_index >= num_words:
-            cut_word_index = num_words - 1
-        text0 = " ".join(words[:cut_word_index])
-        text1 = " ".join(words[cut_word_index:])
-        segment0 = Translation(start=self.start, end=cut_time, text=text0)
-        segment1 = Translation(start=cut_time, end=self.end, text=text1)
-        return segment0, segment1
-
 
 @dataclass
 class Silence():
+    start: Optional[float] = None
+    end: Optional[float] = None
     duration: Optional[float] = None
     is_starting: bool = False
     has_ended: bool = False
 
+    def compute_duration(self) -> Optional[float]:
+        if self.start is None or self.end is None:
+            return None
+        self.duration = self.end - self.start
+        return self.duration
+
+    def is_silence(self) -> bool:
+        return True
+
+
+@dataclass
+class Segment(TimedText):
+    """Generic contiguous span built from tokens or silence markers."""
+    start: Optional[float]
+    end: Optional[float]
+    text: Optional[str]
+    speaker: Optional[int]
+
+    @classmethod
+    def from_tokens(
+        cls,
+        tokens: List[Union[ASRToken, Silence]],
+        is_silence: bool = False
+    ) -> Optional["Segment"]:
+        """Return a normalized segment representing the provided tokens."""
+        if not tokens:
+            return None
+        start_token = tokens[0]
+        end_token = tokens[-1]
+        if is_silence:
+            return cls(
+                start=start_token.start,
+                end=end_token.end,
+                text=None,
+                speaker=-2
+            )
+        else:
+            return cls(
+                start=start_token.start,
+                end=end_token.end,
+                text=''.join(token.text for token in tokens),
+                speaker=-1,
+                detected_language=start_token.detected_language
+            )
+
+    def is_silence(self) -> bool:
+        """True when this segment represents a silence gap."""
+        return self.speaker == -2
+
 
 @dataclass
 class Line(TimedText):
     translation: str = ''
 
-    def to_dict(self):
-        _dict = {
+    def to_dict(self) -> Dict[str, Any]:
+        """Serialize the line for frontend consumption."""
+        _dict: Dict[str, Any] = {
             'speaker': int(self.speaker) if self.speaker != -1 else 1,
             'text': self.text,
             'start': format_time(self.start),
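The `Segment.from_tokens` constructor above underpins punctuation-based grouping: a run of tokens is closed into one segment whenever a token carries sentence-final punctuation. The splitting rule can be sketched standalone (plain dicts with a `punct` flag stand in for the repo's timed tokens; this is an illustration, not the repo's API):

```python
def split_on_punctuation(tokens):
    # tokens: list of {"text": str, "punct": bool}; a segment closes after
    # each token whose text carries sentence-final punctuation.
    segments, current = [], []
    for token in tokens:
        current.append(token["text"])
        if token["punct"]:
            segments.append("".join(current))
            current = []
    if current:  # flush the trailing, not-yet-punctuated run
        segments.append("".join(current))
    return segments

split_on_punctuation([
    {"text": " Hello", "punct": False},
    {"text": " world.", "punct": True},
    {"text": " Bye", "punct": False},
])
```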
@@ -145,6 +164,33 @@ class Line(TimedText):
             _dict['detected_language'] = self.detected_language
         return _dict
 
+    def build_from_tokens(self, tokens: List[ASRToken]) -> "Line":
+        """Populate line attributes from a contiguous token list."""
+        self.text = ''.join([token.text for token in tokens])
+        self.start = tokens[0].start
+        self.end = tokens[-1].end
+        self.speaker = 1
+        self.detected_language = tokens[0].detected_language
+        return self
+
+    def build_from_segment(self, segment: Segment) -> "Line":
+        """Populate the line fields from a pre-built segment."""
+        self.text = segment.text
+        self.start = segment.start
+        self.end = segment.end
+        self.speaker = segment.speaker
+        self.detected_language = segment.detected_language
+        return self
+
+    def is_silent(self) -> bool:
+        return self.speaker == -2
+
+
+class SilentLine(Line):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self.speaker = -2
+        self.text = ''
+
+
 @dataclass
 class FrontData():
@@ -157,8 +203,9 @@ class FrontData():
     remaining_time_transcription: float = 0.
     remaining_time_diarization: float = 0.
 
-    def to_dict(self):
-        _dict = {
+    def to_dict(self) -> Dict[str, Any]:
+        """Serialize the front-end data payload."""
+        _dict: Dict[str, Any] = {
             'status': self.status,
             'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)],
             'buffer_transcription': self.buffer_transcription,
@@ -178,26 +225,22 @@ class ChangeSpeaker:
 @dataclass
 class State():
-    tokens: list = field(default_factory=list)
-    last_validated_token: int = 0
-    last_speaker: int = 1
-    last_punctuation_index: Optional[int] = None
-    translation_validated_segments: list = field(default_factory=list)
-    buffer_translation: str = field(default_factory=Transcript)
-    buffer_transcription: str = field(default_factory=Transcript)
-    diarization_segments: list = field(default_factory=list)
+    """Unified state class for audio processing.
+
+    Contains both persistent state (tokens, buffers) and temporary update buffers
+    (new_* fields) that are consumed by TokensAlignment.
+    """
+    # Persistent state
+    tokens: List[ASRToken] = field(default_factory=list)
+    buffer_transcription: Transcript = field(default_factory=Transcript)
     end_buffer: float = 0.0
     end_attributed_speaker: float = 0.0
     remaining_time_transcription: float = 0.0
     remaining_time_diarization: float = 0.0
-
-
-@dataclass
-class StateLight():
-    new_tokens: list = field(default_factory=list)
-    new_translation: list = field(default_factory=list)
-    new_diarization: list = field(default_factory=list)
-    new_tokens_buffer: list = field(default_factory=list)  # only when local agreement
-    new_tokens_index = 0
-    new_translation_index = 0
-    new_diarization_index = 0
+    # Temporary update buffers (consumed by TokensAlignment.update())
+    new_tokens: List[Union[ASRToken, Silence]] = field(default_factory=list)
+    new_translation: List[Any] = field(default_factory=list)
+    new_diarization: List[Any] = field(default_factory=list)
+    new_tokens_buffer: List[Any] = field(default_factory=list)  # only when local agreement
+    new_translation_buffer: TimedText = field(default_factory=TimedText)
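The `State` rewrite leans on `field(default_factory=...)` for every mutable default; a dataclass attribute assigned without a type annotation is not a field at all, but a plain class attribute shared by every instance. A minimal sketch of the difference (the `Shared`/`PerInstance` classes are illustrative only):

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Shared:
    items = []  # no annotation: class attribute, one list shared by all instances

@dataclass
class PerInstance:
    items: List[int] = field(default_factory=list)  # fresh list per instance

a, b = Shared(), Shared()
a.items.append(1)   # the mutation is visible through b as well

c, d = PerInstance(), PerInstance()
c.items.append(1)   # d.items stays empty
```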

View File

@@ -0,0 +1,177 @@
+from time import time
+from typing import Optional, List, Tuple, Union, Any
+
+from whisperlivekit.timed_objects import Line, SilentLine, ASRToken, SpeakerSegment, Silence, TimedText, Segment
+
+
+class TokensAlignment:
+    def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None:
+        self.state = state
+        self.diarization = args.diarization
+        self._tokens_index: int = 0
+        self._diarization_index: int = 0
+        self._translation_index: int = 0
+        self.all_tokens: List[ASRToken] = []
+        self.all_diarization_segments: List[SpeakerSegment] = []
+        self.all_translation_segments: List[Any] = []
+        self.new_tokens: List[ASRToken] = []
+        self.new_diarization: List[SpeakerSegment] = []
+        self.new_translation: List[Any] = []
+        self.new_translation_buffer: TimedText = TimedText()
+        self.new_tokens_buffer: List[Any] = []
+        self.sep: str = sep if sep is not None else ' '
+        self.beg_loop: Optional[float] = None
+
+    def update(self) -> None:
+        """Drain state buffers into the running alignment context."""
+        self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
+        self.new_diarization, self.state.new_diarization = self.state.new_diarization, []
+        self.new_translation, self.state.new_translation = self.state.new_translation, []
+        self.new_tokens_buffer, self.state.new_tokens_buffer = self.state.new_tokens_buffer, []
+        self.all_tokens.extend(self.new_tokens)
+        self.all_diarization_segments.extend(self.new_diarization)
+        self.all_translation_segments.extend(self.new_translation)
+        self.new_translation_buffer = self.state.new_translation_buffer
+
+    def add_translation(self, line: Line) -> None:
+        """Append translated text segments that overlap with a line."""
+        for ts in self.all_translation_segments:
+            if ts.is_within(line):
+                line.translation += ts.text + (self.sep if ts.text else '')
+            elif line.translation:
+                break
+
+    def compute_punctuations_segments(self) -> List[Segment]:
+        """Group tokens into segments split by punctuation and explicit silence."""
+        segments: List[Segment] = []
+        segment_start_idx = 0
+        for i, token in enumerate(self.all_tokens):
+            if token.is_silence():
+                previous_segment = Segment.from_tokens(
+                    tokens=self.all_tokens[segment_start_idx:i],
+                )
+                if previous_segment:
+                    segments.append(previous_segment)
+                segment = Segment.from_tokens(
+                    tokens=[token],
+                    is_silence=True
+                )
+                segments.append(segment)
+                segment_start_idx = i + 1
+            elif token.has_punctuation():
+                segment = Segment.from_tokens(
+                    tokens=self.all_tokens[segment_start_idx:i + 1],
+                )
+                segments.append(segment)
+                segment_start_idx = i + 1
+        final_segment = Segment.from_tokens(
+            tokens=self.all_tokens[segment_start_idx:],
+        )
+        if final_segment:
+            segments.append(final_segment)
+        return segments
+
+    def concatenate_diar_segments(self) -> List[SpeakerSegment]:
+        """Merge consecutive diarization slices that share the same speaker."""
+        if not self.all_diarization_segments:
+            return []
+        merged = [self.all_diarization_segments[0]]
+        for segment in self.all_diarization_segments[1:]:
+            if segment.speaker == merged[-1].speaker:
+                merged[-1].end = segment.end
+            else:
+                merged.append(segment)
+        return merged
+
+    @staticmethod
+    def intersection_duration(seg1: TimedText, seg2: TimedText) -> float:
+        """Return the overlap duration between two timed segments."""
+        start = max(seg1.start, seg2.start)
+        end = min(seg1.end, seg2.end)
+        return max(0.0, end - start)
+
+    def get_lines_diarization(self) -> Tuple[List[Line], str]:
+        """Build lines when diarization is enabled and track overflow buffer."""
+        diarization_buffer = ''
+        punctuation_segments = self.compute_punctuations_segments()
+        diarization_segments = self.concatenate_diar_segments()
+        for punctuation_segment in punctuation_segments:
+            if not punctuation_segment.is_silence():
+                if diarization_segments and punctuation_segment.start >= diarization_segments[-1].end:
+                    diarization_buffer += punctuation_segment.text
+                else:
+                    max_overlap = 0.0
+                    max_overlap_speaker = 1
+                    for diarization_segment in diarization_segments:
+                        intersec = self.intersection_duration(punctuation_segment, diarization_segment)
+                        if intersec > max_overlap:
+                            max_overlap = intersec
+                            max_overlap_speaker = diarization_segment.speaker + 1
+                    punctuation_segment.speaker = max_overlap_speaker
+        lines: List[Line] = []
+        if punctuation_segments:
+            lines = [Line().build_from_segment(punctuation_segments[0])]
+            for segment in punctuation_segments[1:]:
+                if segment.speaker == lines[-1].speaker:
+                    if lines[-1].text:
+                        lines[-1].text += segment.text
+                    lines[-1].end = segment.end
+                else:
+                    lines.append(Line().build_from_segment(segment))
+        return lines, diarization_buffer
+
+    def get_lines(
+        self,
+        diarization: bool = False,
+        translation: bool = False,
+        current_silence: Optional[Silence] = None
+    ) -> Tuple[List[Line], str, str]:
+        """Return the formatted lines plus buffers, optionally with diarization/translation."""
+        if diarization:
+            lines, diarization_buffer = self.get_lines_diarization()
+        else:
+            diarization_buffer = ''
+            lines = []
+            current_line_tokens: List[ASRToken] = []
+            for token in self.all_tokens:
+                if token.is_silence():
+                    if current_line_tokens:
+                        lines.append(Line().build_from_tokens(current_line_tokens))
+                        current_line_tokens = []
+                    end_silence = token.end if token.has_ended else time() - self.beg_loop
+                    if lines and lines[-1].is_silent():
+                        lines[-1].end = end_silence
+                    else:
+                        lines.append(SilentLine(
+                            start=token.start,
+                            end=end_silence
+                        ))
+                else:
+                    current_line_tokens.append(token)
+            if current_line_tokens:
+                lines.append(Line().build_from_tokens(current_line_tokens))
+        if current_silence:
+            end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
+            if lines and lines[-1].is_silent():
+                lines[-1].end = end_silence
+            else:
+                lines.append(SilentLine(
+                    start=current_silence.start,
+                    end=end_silence
+                ))
+        if translation:
+            for line in lines:
+                if not isinstance(line, SilentLine):
+                    self.add_translation(line)
+        return lines, diarization_buffer, self.new_translation_buffer.text
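The speaker-attribution rule in `get_lines_diarization` picks, for each punctuation segment, the diarization segment with the greatest temporal overlap (and shifts speaker ids by one for display). The interval math can be sketched standalone, with plain tuples standing in for the timed objects (illustrative, not the repo's API):

```python
def intersection_duration(seg1, seg2):
    # seg = (start, end); overlap is clamped at zero for disjoint spans
    start = max(seg1[0], seg2[0])
    end = min(seg1[1], seg2[1])
    return max(0.0, end - start)

def best_speaker(segment, diar_segments):
    # diar_segments: list of (start, end, speaker); default display speaker is 1
    best, best_overlap = 1, 0.0
    for start, end, speaker in diar_segments:
        overlap = intersection_duration(segment, (start, end))
        if overlap > best_overlap:
            best, best_overlap = speaker + 1, overlap
    return best

diar = [(0.0, 2.0, 0), (2.0, 5.0, 1)]
# (1.5, 4.0) overlaps speaker 0 for 0.5 s and speaker 1 for 2.0 s
best_speaker((1.5, 4.0), diar)
```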

View File

@@ -1,60 +0,0 @@
-from typing import Sequence, Callable, Any, Optional, Dict
-
-
-def _detect_tail_repetition(
-    seq: Sequence[Any],
-    key: Callable[[Any], Any] = lambda x: x,  # extract comparable value
-    min_block: int = 1,       # set to 2 to ignore 1-token loops like "."
-    max_tail: int = 300,      # search window from the end for speed
-    prefer: str = "longest",  # "longest" coverage or "smallest" block
-) -> Optional[Dict]:
-    vals = [key(x) for x in seq][-max_tail:]
-    n = len(vals)
-    best = None
-    # try every possible block length
-    for b in range(min_block, n // 2 + 1):
-        block = vals[-b:]
-        # count how many times this block repeats contiguously at the very end
-        count, i = 0, n
-        while i - b >= 0 and vals[i - b:i] == block:
-            count += 1
-            i -= b
-        if count >= 2:
-            cand = {
-                "block_size": b,
-                "count": count,
-                "start_index": len(seq) - count * b,  # in original seq
-                "end_index": len(seq),
-            }
-            if (best is None or
-                    (prefer == "longest" and count * b > best["count"] * best["block_size"]) or
-                    (prefer == "smallest" and b < best["block_size"])):
-                best = cand
-    return best
-
-
-def trim_tail_repetition(
-    seq: Sequence[Any],
-    key: Callable[[Any], Any] = lambda x: x,
-    min_block: int = 1,
-    max_tail: int = 300,
-    prefer: str = "longest",
-    keep: int = 1,  # how many copies of the repeating block to keep at the end (0 or 1 are common)
-):
-    """
-    Returns a new sequence with repeated tail trimmed.
-    keep=1 -> keep a single copy of the repeated block.
-    keep=0 -> remove all copies of the repeated block.
-    """
-    rep = _detect_tail_repetition(seq, key, min_block, max_tail, prefer)
-    if not rep:
-        return seq, False  # nothing to trim
-    b, c = rep["block_size"], rep["count"]
-    if keep < 0:
-        keep = 0
-    if keep >= c:
-        return seq, False  # nothing to trim (already <= keep copies)
-    # new length = total - (copies_to_remove * block_size)
-    new_len = len(seq) - (c - keep) * b
-    return seq[:new_len], True
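For reference, the removed helper trimmed trailing loops such as repeated hallucinated tokens. A simplified, self-contained sketch of the same idea (greedy smallest-block variant; it omits the removed `key`, `max_tail`, and `prefer` options):

```python
def trim_tail_repetition(seq, keep=1, min_block=1):
    # Find a block that repeats contiguously at the very end of seq,
    # then drop all but `keep` copies of it.
    n = len(seq)
    for b in range(min_block, n // 2 + 1):
        block = list(seq[-b:])
        count, i = 0, n
        # count contiguous copies of the candidate block, walking backwards
        while i - b >= 0 and list(seq[i - b:i]) == block:
            count += 1
            i -= b
        if count >= 2 and count > keep:
            return list(seq[: n - (count - keep) * b]), True
    return list(seq), False
```

With `keep=1`, a hallucinated loop like `['a', 'b', 'c', 'b', 'c', 'b', 'c']` collapses to `['a', 'b', 'c']`; a sequence without a repeating tail is returned unchanged.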