improve diarization speed + Use punctuation to better align speakers and diarization

This commit is contained in:
Quentin Fuxa
2025-06-19 13:03:29 +02:00
parent c9f60504e3
commit 0f79d442ee
5 changed files with 193 additions and 20 deletions

View File

@@ -32,6 +32,7 @@ WhisperLiveKit consists of three main components:
- **👥 Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart)
- **🔒 Fully Local** - All processing happens on your machine - no data sent to external servers
- **📱 Multi-User Support** - Handle multiple users simultaneously with a single backend/server
- **📝 Punctuation-Based Speaker Splitting [BETA] ** - Align speaker changes with natural sentence boundaries for more readable transcripts
### ⚙️ Core differences from [Whisper Streaming](https://github.com/ufal/whisper_streaming)
@@ -230,6 +231,7 @@ WhisperLiveKit offers extensive configuration options:
| `--task` | `transcribe` or `translate` | `transcribe` |
| `--backend` | Processing backend | `faster-whisper` |
| `--diarization` | Enable speaker identification | `False` |
| `--punctuation-split` | Use punctuation to improve speaker boundaries | `True` |
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
| `--vac` | Use Voice Activity Controller | `False` |

View File

@@ -377,13 +377,16 @@ class AudioProcessor:
# Process diarization
await diarization_obj.diarize(pcm_array)
# Get current state and update speakers
state = await self.get_current_state()
new_end = diarization_obj.assign_speakers_to_tokens(
state["end_attributed_speaker"], state["tokens"]
)
async with self.lock:
new_end = diarization_obj.assign_speakers_to_tokens(
self.end_attributed_speaker,
self.tokens,
use_punctuation_split=self.args.punctuation_split
)
self.end_attributed_speaker = new_end
if buffer_diarization:
self.buffer_diarization = buffer_diarization
await self.update_diarization(new_end, buffer_diarization)
self.diarization_queue.task_done()
except Exception as e:

View File

@@ -24,6 +24,7 @@ class TranscriptionEngine:
"warmup_file": None,
"confidence_validation": False,
"diarization": False,
"punctuation_split": True,
"min_chunk_size": 0.5,
"model": "tiny",
"model_cache_dir": None,
@@ -68,6 +69,6 @@ class TranscriptionEngine:
if self.args.diarization:
from whisperlivekit.diarization.diarization_online import DiartDiarization
self.diarization = DiartDiarization()
self.diarization = DiartDiarization(block_duration=self.args.min_chunk_size)
TranscriptionEngine._initialized = True

View File

@@ -3,7 +3,8 @@ import re
import threading
import numpy as np
import logging
import time
from queue import SimpleQueue, Empty
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import StreamingInference
@@ -13,6 +14,10 @@ from diart.sources import MicrophoneAudioSource
from rx.core import Observer
from typing import Tuple, Any, List
from pyannote.core import Annotation
import diart.models as m
segmentation = m.SegmentationModel.from_pretrained("pyannote/segmentation-3.0")
embedding = m.EmbeddingModel.from_pretrained("speechbrain/spkrec-ecapa-voxceleb")
logger = logging.getLogger(__name__)
@@ -78,40 +83,104 @@ class DiarizationObserver(Observer):
class WebSocketAudioSource(AudioSource):
"""
Custom AudioSource that blocks in read() until close() is called.
Use push_audio() to inject PCM chunks.
Buffers incoming audio and releases it in fixed-size chunks at regular intervals.
"""
def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
def __init__(self, uri: str = "websocket", sample_rate: int = 16000, block_duration: float = 0.5):
super().__init__(uri, sample_rate)
self.block_duration = block_duration
self.block_size = int(np.rint(block_duration * sample_rate))
self._queue = SimpleQueue()
self._buffer = np.array([], dtype=np.float32)
self._buffer_lock = threading.Lock()
self._closed = False
self._close_event = threading.Event()
self._processing_thread = None
self._last_chunk_time = time.time()
def read(self):
"""Start processing buffered audio and emit fixed-size chunks."""
self._processing_thread = threading.Thread(target=self._process_chunks)
self._processing_thread.daemon = True
self._processing_thread.start()
self._close_event.wait()
if self._processing_thread:
self._processing_thread.join(timeout=2.0)
def _process_chunks(self):
"""Process audio from queue and emit fixed-size chunks at regular intervals."""
while not self._closed:
try:
audio_chunk = self._queue.get(timeout=0.1)
with self._buffer_lock:
self._buffer = np.concatenate([self._buffer, audio_chunk])
while len(self._buffer) >= self.block_size:
chunk = self._buffer[:self.block_size]
self._buffer = self._buffer[self.block_size:]
current_time = time.time()
time_since_last = current_time - self._last_chunk_time
if time_since_last < self.block_duration:
time.sleep(self.block_duration - time_since_last)
chunk_reshaped = chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self._last_chunk_time = time.time()
except Empty:
with self._buffer_lock:
if len(self._buffer) > 0 and time.time() - self._last_chunk_time > self.block_duration:
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
padded_chunk[:len(self._buffer)] = self._buffer
self._buffer = np.array([], dtype=np.float32)
chunk_reshaped = padded_chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self._last_chunk_time = time.time()
except Exception as e:
logger.error(f"Error in audio processing thread: {e}")
self.stream.on_error(e)
break
with self._buffer_lock:
if len(self._buffer) > 0:
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
padded_chunk[:len(self._buffer)] = self._buffer
chunk_reshaped = padded_chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self.stream.on_completed()
def close(self):
if not self._closed:
self._closed = True
self.stream.on_completed()
self._close_event.set()
def push_audio(self, chunk: np.ndarray):
"""Add audio chunk to the processing queue."""
if not self._closed:
new_audio = np.expand_dims(chunk, axis=0)
logger.debug('Add new chunk with shape:', new_audio.shape)
self.stream.on_next(new_audio)
if chunk.ndim > 1:
chunk = chunk.flatten()
self._queue.put(chunk)
logger.debug(f'Added chunk to queue with {len(chunk)} samples')
class DiartDiarization:
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False):
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5):
self.pipeline = SpeakerDiarization(config=config)
self.observer = DiarizationObserver()
if use_microphone:
self.source = MicrophoneAudioSource()
self.source = MicrophoneAudioSource(block_duration=block_duration)
self.custom_source = None
else:
self.custom_source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
self.custom_source = WebSocketAudioSource(
uri="websocket_source",
sample_rate=sample_rate,
block_duration=block_duration
)
self.source = self.custom_source
self.inference = StreamingInference(
@@ -138,16 +207,107 @@ class DiartDiarization:
if self.custom_source:
self.custom_source.close()
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float:
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float:
"""
Assign speakers to tokens based on timing overlap with speaker segments.
Uses the segments collected by the observer.
If use_punctuation_split is True, uses punctuation marks to refine speaker boundaries.
"""
segments = self.observer.get_segments()
# Debug logging
logger.debug(f"assign_speakers_to_tokens called with {len(tokens)} tokens")
logger.debug(f"Available segments: {len(segments)}")
for i, seg in enumerate(segments[:5]): # Show first 5 segments
logger.debug(f" Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]")
# First pass: assign speakers based on timing overlap
for token in tokens:
for segment in segments:
if not (segment.end <= token.start or segment.start >= token.end):
token.speaker = extract_number(segment.speaker) + 1
end_attributed_speaker = max(token.end, end_attributed_speaker)
return end_attributed_speaker
if use_punctuation_split and len(tokens) > 1:
punctuation_marks = {'.', '!', '?'}
print("Here are the tokens:",
[(t.text, t.start, t.end, t.speaker) for t in tokens[:10]])
segment_map = []
for segment in segments:
speaker_num = extract_number(segment.speaker) + 1
segment_map.append((segment.start, segment.end, speaker_num))
segment_map.sort(key=lambda x: x[0]) # Sort by start time
i = 0
while i < len(tokens):
current_token = tokens[i]
# Check if current token ends with sentence-ending punctuation
is_sentence_end = False
if current_token.text and current_token.text.strip():
text = current_token.text.strip()
if text[-1] in punctuation_marks:
is_sentence_end = True
logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
if is_sentence_end and current_token.speaker != -1:
# Find the dominant speaker for tokens after this punctuation
punctuation_time = current_token.end
current_speaker = current_token.speaker
# Look ahead to find where the next sentence starts and ends
j = i + 1
next_sentence_tokens = []
# Collect tokens until we hit another sentence-ending punctuation or run out
while j < len(tokens):
next_token = tokens[j]
next_sentence_tokens.append(j)
# Check if this token ends the next sentence
if next_token.text and next_token.text.strip():
if next_token.text.strip()[-1] in punctuation_marks:
break
j += 1
if next_sentence_tokens:
speaker_times = {}
for idx in next_sentence_tokens:
token = tokens[idx]
# Find which segments overlap with this token
for seg_start, seg_end, seg_speaker in segment_map:
if not (seg_end <= token.start or seg_start >= token.end):
# Calculate overlap duration
overlap_start = max(seg_start, token.start)
overlap_end = min(seg_end, token.end)
overlap_duration = overlap_end - overlap_start
if seg_speaker not in speaker_times:
speaker_times[seg_speaker] = 0
speaker_times[seg_speaker] += overlap_duration
if speaker_times:
dominant_speaker = max(speaker_times.items(), key=lambda x: x[1])[0]
if dominant_speaker != current_speaker:
logger.debug(f" Speaker change after punctuation: {current_speaker}{dominant_speaker}")
for idx in next_sentence_tokens:
if tokens[idx].speaker != dominant_speaker:
logger.debug(f" Reassigning token {idx} ('{tokens[idx].text}') to Speaker {dominant_speaker}")
tokens[idx].speaker = dominant_speaker
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
else:
for idx in next_sentence_tokens:
if tokens[idx].speaker == -1:
tokens[idx].speaker = current_speaker
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
i += 1
return end_attributed_speaker

View File

@@ -37,6 +37,13 @@ def parse_args():
help="Enable speaker diarization.",
)
parser.add_argument(
"--punctuation-split",
action="store_true",
default=False,
help="Use punctuation marks from transcription to improve speaker boundary detection. Requires both transcription and diarization to be enabled.",
)
parser.add_argument(
"--no-transcription",
action="store_true",