4 Commits
0.1.8 ... 0.1.9

Author SHA1 Message Date
Quentin Fuxa
e165916952 add diarization model list url 2025-06-19 16:43:23 +02:00
Quentin Fuxa
8532a91c7a add segmentation and embedding model options to configuration 2025-06-19 16:29:25 +02:00
Quentin Fuxa
b01b81bad0 improve diarization with lag diarization substraction 2025-06-19 16:18:49 +02:00
Quentin Fuxa
0f79d442ee improve diarization speed + Use punctuation to better align speakers and diarization 2025-06-19 13:03:29 +02:00
7 changed files with 224 additions and 24 deletions

View File

@@ -32,6 +32,7 @@ WhisperLiveKit consists of three main components:
- **👥 Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart) - **👥 Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart)
- **🔒 Fully Local** - All processing happens on your machine - no data sent to external servers - **🔒 Fully Local** - All processing happens on your machine - no data sent to external servers
- **📱 Multi-User Support** - Handle multiple users simultaneously with a single backend/server - **📱 Multi-User Support** - Handle multiple users simultaneously with a single backend/server
- **📝 Punctuation-Based Speaker Splitting [BETA] ** - Align speaker changes with natural sentence boundaries for more readable transcripts
### ⚙️ Core differences from [Whisper Streaming](https://github.com/ufal/whisper_streaming) ### ⚙️ Core differences from [Whisper Streaming](https://github.com/ufal/whisper_streaming)
@@ -230,6 +231,7 @@ WhisperLiveKit offers extensive configuration options:
| `--task` | `transcribe` or `translate` | `transcribe` | | `--task` | `transcribe` or `translate` | `transcribe` |
| `--backend` | Processing backend | `faster-whisper` | | `--backend` | Processing backend | `faster-whisper` |
| `--diarization` | Enable speaker identification | `False` | | `--diarization` | Enable speaker identification | `False` |
| `--punctuation-split` | Use punctuation to improve speaker boundaries | `True` |
| `--confidence-validation` | Use confidence scores for faster validation | `False` | | `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` | | `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
| `--vac` | Use Voice Activity Controller | `False` | | `--vac` | Use Voice Activity Controller | `False` |
@@ -238,6 +240,8 @@ WhisperLiveKit offers extensive configuration options:
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` | | `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` | | `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` | | `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
## 🔧 How It Works ## 🔧 How It Works

View File

@@ -1,7 +1,7 @@
from setuptools import setup, find_packages from setuptools import setup, find_packages
setup( setup(
name="whisperlivekit", name="whisperlivekit",
version="0.1.8", version="0.1.9",
description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization", description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
long_description=open("README.md", "r", encoding="utf-8").read(), long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",

View File

@@ -377,13 +377,16 @@ class AudioProcessor:
# Process diarization # Process diarization
await diarization_obj.diarize(pcm_array) await diarization_obj.diarize(pcm_array)
# Get current state and update speakers async with self.lock:
state = await self.get_current_state() new_end = diarization_obj.assign_speakers_to_tokens(
new_end = diarization_obj.assign_speakers_to_tokens( self.end_attributed_speaker,
state["end_attributed_speaker"], state["tokens"] self.tokens,
) use_punctuation_split=self.args.punctuation_split
)
self.end_attributed_speaker = new_end
if buffer_diarization:
self.buffer_diarization = buffer_diarization
await self.update_diarization(new_end, buffer_diarization)
self.diarization_queue.task_done() self.diarization_queue.task_done()
except Exception as e: except Exception as e:

View File

@@ -24,6 +24,7 @@ class TranscriptionEngine:
"warmup_file": None, "warmup_file": None,
"confidence_validation": False, "confidence_validation": False,
"diarization": False, "diarization": False,
"punctuation_split": False,
"min_chunk_size": 0.5, "min_chunk_size": 0.5,
"model": "tiny", "model": "tiny",
"model_cache_dir": None, "model_cache_dir": None,
@@ -40,6 +41,8 @@ class TranscriptionEngine:
"ssl_keyfile": None, "ssl_keyfile": None,
"transcription": True, "transcription": True,
"vad": True, "vad": True,
"segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding",
} }
config_dict = {**defaults, **kwargs} config_dict = {**defaults, **kwargs}
@@ -68,6 +71,10 @@ class TranscriptionEngine:
if self.args.diarization: if self.args.diarization:
from whisperlivekit.diarization.diarization_online import DiartDiarization from whisperlivekit.diarization.diarization_online import DiartDiarization
self.diarization = DiartDiarization() self.diarization = DiartDiarization(
block_duration=self.args.min_chunk_size,
segmentation_model_name=self.args.segmentation_model,
embedding_model_name=self.args.embedding_model
)
TranscriptionEngine._initialized = True TranscriptionEngine._initialized = True

View File

@@ -3,7 +3,8 @@ import re
import threading import threading
import numpy as np import numpy as np
import logging import logging
import time
from queue import SimpleQueue, Empty
from diart import SpeakerDiarization, SpeakerDiarizationConfig from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import StreamingInference from diart.inference import StreamingInference
@@ -13,6 +14,7 @@ from diart.sources import MicrophoneAudioSource
from rx.core import Observer from rx.core import Observer
from typing import Tuple, Any, List from typing import Tuple, Any, List
from pyannote.core import Annotation from pyannote.core import Annotation
import diart.models as m
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -78,40 +80,114 @@ class DiarizationObserver(Observer):
class WebSocketAudioSource(AudioSource): class WebSocketAudioSource(AudioSource):
""" """
Custom AudioSource that blocks in read() until close() is called. Buffers incoming audio and releases it in fixed-size chunks at regular intervals.
Use push_audio() to inject PCM chunks.
""" """
def __init__(self, uri: str = "websocket", sample_rate: int = 16000): def __init__(self, uri: str = "websocket", sample_rate: int = 16000, block_duration: float = 0.5):
super().__init__(uri, sample_rate) super().__init__(uri, sample_rate)
self.block_duration = block_duration
self.block_size = int(np.rint(block_duration * sample_rate))
self._queue = SimpleQueue()
self._buffer = np.array([], dtype=np.float32)
self._buffer_lock = threading.Lock()
self._closed = False self._closed = False
self._close_event = threading.Event() self._close_event = threading.Event()
self._processing_thread = None
self._last_chunk_time = time.time()
def read(self): def read(self):
"""Start processing buffered audio and emit fixed-size chunks."""
self._processing_thread = threading.Thread(target=self._process_chunks)
self._processing_thread.daemon = True
self._processing_thread.start()
self._close_event.wait() self._close_event.wait()
if self._processing_thread:
self._processing_thread.join(timeout=2.0)
def _process_chunks(self):
"""Process audio from queue and emit fixed-size chunks at regular intervals."""
while not self._closed:
try:
audio_chunk = self._queue.get(timeout=0.1)
with self._buffer_lock:
self._buffer = np.concatenate([self._buffer, audio_chunk])
while len(self._buffer) >= self.block_size:
chunk = self._buffer[:self.block_size]
self._buffer = self._buffer[self.block_size:]
current_time = time.time()
time_since_last = current_time - self._last_chunk_time
if time_since_last < self.block_duration:
time.sleep(self.block_duration - time_since_last)
chunk_reshaped = chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self._last_chunk_time = time.time()
except Empty:
with self._buffer_lock:
if len(self._buffer) > 0 and time.time() - self._last_chunk_time > self.block_duration:
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
padded_chunk[:len(self._buffer)] = self._buffer
self._buffer = np.array([], dtype=np.float32)
chunk_reshaped = padded_chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self._last_chunk_time = time.time()
except Exception as e:
logger.error(f"Error in audio processing thread: {e}")
self.stream.on_error(e)
break
with self._buffer_lock:
if len(self._buffer) > 0:
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
padded_chunk[:len(self._buffer)] = self._buffer
chunk_reshaped = padded_chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self.stream.on_completed()
def close(self): def close(self):
if not self._closed: if not self._closed:
self._closed = True self._closed = True
self.stream.on_completed()
self._close_event.set() self._close_event.set()
def push_audio(self, chunk: np.ndarray): def push_audio(self, chunk: np.ndarray):
"""Add audio chunk to the processing queue."""
if not self._closed: if not self._closed:
new_audio = np.expand_dims(chunk, axis=0) if chunk.ndim > 1:
logger.debug('Add new chunk with shape:', new_audio.shape) chunk = chunk.flatten()
self.stream.on_next(new_audio) self._queue.put(chunk)
logger.debug(f'Added chunk to queue with {len(chunk)} samples')
class DiartDiarization: class DiartDiarization:
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False): def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "speechbrain/spkrec-ecapa-voxceleb"):
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
if config is None:
config = SpeakerDiarizationConfig(
segmentation=segmentation_model,
embedding=embedding_model,
)
self.pipeline = SpeakerDiarization(config=config) self.pipeline = SpeakerDiarization(config=config)
self.observer = DiarizationObserver() self.observer = DiarizationObserver()
self.lag_diart = None
if use_microphone: if use_microphone:
self.source = MicrophoneAudioSource() self.source = MicrophoneAudioSource(block_duration=block_duration)
self.custom_source = None self.custom_source = None
else: else:
self.custom_source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate) self.custom_source = WebSocketAudioSource(
uri="websocket_source",
sample_rate=sample_rate,
block_duration=block_duration
)
self.source = self.custom_source self.source = self.custom_source
self.inference = StreamingInference( self.inference = StreamingInference(
@@ -138,16 +214,102 @@ class DiartDiarization:
if self.custom_source: if self.custom_source:
self.custom_source.close() self.custom_source.close()
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float: def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float:
""" """
Assign speakers to tokens based on timing overlap with speaker segments. Assign speakers to tokens based on timing overlap with speaker segments.
Uses the segments collected by the observer. Uses the segments collected by the observer.
If use_punctuation_split is True, uses punctuation marks to refine speaker boundaries.
""" """
segments = self.observer.get_segments() segments = self.observer.get_segments()
# Debug logging
logger.debug(f"assign_speakers_to_tokens called with {len(tokens)} tokens")
logger.debug(f"Available segments: {len(segments)}")
for i, seg in enumerate(segments[:5]): # Show first 5 segments
logger.debug(f" Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]")
if not self.lag_diart and segments and tokens:
self.lag_diart = segments[0].start - tokens[0].start
for token in tokens: for token in tokens:
for segment in segments: for segment in segments:
if not (segment.end <= token.start or segment.start >= token.end): if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
token.speaker = extract_number(segment.speaker) + 1 token.speaker = extract_number(segment.speaker) + 1
end_attributed_speaker = max(token.end, end_attributed_speaker) end_attributed_speaker = max(token.end, end_attributed_speaker)
if use_punctuation_split and len(tokens) > 1:
punctuation_marks = {'.', '!', '?'}
print("Here are the tokens:",
[(t.text, t.start, t.end, t.speaker) for t in tokens[:10]])
segment_map = []
for segment in segments:
speaker_num = extract_number(segment.speaker) + 1
segment_map.append((segment.start, segment.end, speaker_num))
segment_map.sort(key=lambda x: x[0])
i = 0
while i < len(tokens):
current_token = tokens[i]
is_sentence_end = False
if current_token.text and current_token.text.strip():
text = current_token.text.strip()
if text[-1] in punctuation_marks:
is_sentence_end = True
logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
if is_sentence_end and current_token.speaker != -1:
punctuation_time = current_token.end
current_speaker = current_token.speaker
j = i + 1
next_sentence_tokens = []
while j < len(tokens):
next_token = tokens[j]
next_sentence_tokens.append(j)
# Check if this token ends the next sentence
if next_token.text and next_token.text.strip():
if next_token.text.strip()[-1] in punctuation_marks:
break
j += 1
if next_sentence_tokens:
speaker_times = {}
for idx in next_sentence_tokens:
token = tokens[idx]
# Find which segments overlap with this token
for seg_start, seg_end, seg_speaker in segment_map:
if not (seg_end <= token.start or seg_start >= token.end):
# Calculate overlap duration
overlap_start = max(seg_start, token.start)
overlap_end = min(seg_end, token.end)
overlap_duration = overlap_end - overlap_start
if seg_speaker not in speaker_times:
speaker_times[seg_speaker] = 0
speaker_times[seg_speaker] += overlap_duration
if speaker_times:
dominant_speaker = max(speaker_times.items(), key=lambda x: x[1])[0]
if dominant_speaker != current_speaker:
logger.debug(f" Speaker change after punctuation: {current_speaker}{dominant_speaker}")
for idx in next_sentence_tokens:
if tokens[idx].speaker != dominant_speaker:
logger.debug(f" Reassigning token {idx} ('{tokens[idx].text}') to Speaker {dominant_speaker}")
tokens[idx].speaker = dominant_speaker
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
else:
for idx in next_sentence_tokens:
if tokens[idx].speaker == -1:
tokens[idx].speaker = current_speaker
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
i += 1
return end_attributed_speaker return end_attributed_speaker

View File

@@ -37,6 +37,27 @@ def parse_args():
help="Enable speaker diarization.", help="Enable speaker diarization.",
) )
parser.add_argument(
"--punctuation-split",
action="store_true",
default=False,
help="Use punctuation marks from transcription to improve speaker boundary detection. Requires both transcription and diarization to be enabled.",
)
parser.add_argument(
"--segmentation-model",
type=str,
default="pyannote/segmentation-3.0",
help="Hugging Face model ID for pyannote.audio segmentation model.",
)
parser.add_argument(
"--embedding-model",
type=str,
default="pyannote/embedding",
help="Hugging Face model ID for pyannote.audio embedding model.",
)
parser.add_argument( parser.add_argument(
"--no-transcription", "--no-transcription",
action="store_true", action="store_true",

View File

@@ -26,4 +26,7 @@ class Transcript(TimedText):
@dataclass @dataclass
class SpeakerSegment(TimedText): class SpeakerSegment(TimedText):
"""Represents a segment of audio attributed to a specific speaker.
No text nor probability is associated with this segment.
"""
pass pass