diff --git a/README.md b/README.md
index a1bcd82..e0e5969 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,7 @@ WhisperLiveKit consists of three main components:
 - **👥 Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart)
 - **🔒 Fully Local** - All processing happens on your machine - no data sent to external servers
 - **📱 Multi-User Support** - Handle multiple users simultaneously with a single backend/server
+- **📝 Punctuation-Based Speaker Splitting [BETA]** - Align speaker changes with natural sentence boundaries for more readable transcripts
 
 ### ⚙️ Core differences from [Whisper Streaming](https://github.com/ufal/whisper_streaming)
 
@@ -230,6 +231,7 @@ WhisperLiveKit offers extensive configuration options:
 | `--task` | `transcribe` or `translate` | `transcribe` |
 | `--backend` | Processing backend | `faster-whisper` |
 | `--diarization` | Enable speaker identification | `False` |
+| `--punctuation-split` | Use punctuation to improve speaker boundaries | `False` |
 | `--confidence-validation` | Use confidence scores for faster validation | `False` |
 | `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
 | `--vac` | Use Voice Activity Controller | `False` |
diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py
index 08f2b29..bd0af88 100644
--- a/whisperlivekit/audio_processor.py
+++ b/whisperlivekit/audio_processor.py
@@ -377,13 +377,16 @@ class AudioProcessor:
                     # Process diarization
                     await diarization_obj.diarize(pcm_array)
 
-                    # Get current state and update speakers
-                    state = await self.get_current_state()
-                    new_end = diarization_obj.assign_speakers_to_tokens(
-                        state["end_attributed_speaker"], state["tokens"]
-                    )
+                    async with self.lock:
+                        new_end = diarization_obj.assign_speakers_to_tokens(
+                            self.end_attributed_speaker,
+                            self.tokens,
+                            use_punctuation_split=self.args.punctuation_split
+                        )
+                        self.end_attributed_speaker = new_end
+                        if buffer_diarization:
+                            self.buffer_diarization = buffer_diarization
 
-                    await self.update_diarization(new_end, buffer_diarization)
                     self.diarization_queue.task_done()
 
                 except Exception as e:
diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py
index 45467b9..ae6fa71 100644
--- a/whisperlivekit/core.py
+++ b/whisperlivekit/core.py
@@ -24,6 +24,7 @@ class TranscriptionEngine:
         "warmup_file": None,
         "confidence_validation": False,
         "diarization": False,
+        "punctuation_split": False,
         "min_chunk_size": 0.5,
         "model": "tiny",
         "model_cache_dir": None,
@@ -68,6 +69,6 @@ class TranscriptionEngine:
         if self.args.diarization:
             from whisperlivekit.diarization.diarization_online import DiartDiarization
-            self.diarization = DiartDiarization()
+            self.diarization = DiartDiarization(block_duration=self.args.min_chunk_size)
 
         TranscriptionEngine._initialized = True
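Reviewer note (not part of the patch): `punctuation_split` reaches the diarization worker through `self.args`, so library users can toggle it the same way the CLI does. A minimal sketch, assuming `TranscriptionEngine` accepts the defaults dict above as keyword overrides, as in WhisperLiveKit's documented Python usage:

from whisperlivekit import TranscriptionEngine

# punctuation_split only has an effect when diarization is also enabled,
# per the --punctuation-split help text added in parse_args.py below.
engine = TranscriptionEngine(
    model="tiny",
    diarization=True,
    punctuation_split=True,
    min_chunk_size=0.5,  # also passed to DiartDiarization as block_duration
)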
diff --git a/whisperlivekit/diarization/diarization_online.py b/whisperlivekit/diarization/diarization_online.py
index 7db5003..978c47a 100644
--- a/whisperlivekit/diarization/diarization_online.py
+++ b/whisperlivekit/diarization/diarization_online.py
@@ -3,7 +3,8 @@ import re
 import threading
 import numpy as np
 import logging
-
+import time
+from queue import SimpleQueue, Empty
 from diart import SpeakerDiarization, SpeakerDiarizationConfig
 from diart.inference import StreamingInference
@@ -13,6 +14,10 @@ from diart.sources import MicrophoneAudioSource
 from rx.core import Observer
 from typing import Tuple, Any, List
 from pyannote.core import Annotation
+import diart.models as m
+
+segmentation = m.SegmentationModel.from_pretrained("pyannote/segmentation-3.0")
+embedding = m.EmbeddingModel.from_pretrained("speechbrain/spkrec-ecapa-voxceleb")
 
 logger = logging.getLogger(__name__)
 
@@ -78,40 +83,104 @@ class DiarizationObserver(Observer):
 
 class WebSocketAudioSource(AudioSource):
     """
-    Custom AudioSource that blocks in read() until close() is called.
-    Use push_audio() to inject PCM chunks.
+    Buffers incoming audio and releases it in fixed-size chunks at regular intervals.
     """
-    def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
+    def __init__(self, uri: str = "websocket", sample_rate: int = 16000, block_duration: float = 0.5):
         super().__init__(uri, sample_rate)
+        self.block_duration = block_duration
+        self.block_size = int(np.rint(block_duration * sample_rate))
+        self._queue = SimpleQueue()
+        self._buffer = np.array([], dtype=np.float32)
+        self._buffer_lock = threading.Lock()
         self._closed = False
         self._close_event = threading.Event()
+        self._processing_thread = None
+        self._last_chunk_time = time.time()
 
     def read(self):
+        """Start processing buffered audio and emit fixed-size chunks."""
+        self._processing_thread = threading.Thread(target=self._process_chunks)
+        self._processing_thread.daemon = True
+        self._processing_thread.start()
+        self._close_event.wait()
+        if self._processing_thread:
+            self._processing_thread.join(timeout=2.0)
+
+    def _process_chunks(self):
+        """Process audio from queue and emit fixed-size chunks at regular intervals."""
+        while not self._closed:
+            try:
+                audio_chunk = self._queue.get(timeout=0.1)
+
+                with self._buffer_lock:
+                    self._buffer = np.concatenate([self._buffer, audio_chunk])
+
+                while len(self._buffer) >= self.block_size:
+                    chunk = self._buffer[:self.block_size]
+                    self._buffer = self._buffer[self.block_size:]
+
+                    current_time = time.time()
+                    time_since_last = current_time - self._last_chunk_time
+                    if time_since_last < self.block_duration:
+                        time.sleep(self.block_duration - time_since_last)
+
+                    chunk_reshaped = chunk.reshape(1, -1)
+                    self.stream.on_next(chunk_reshaped)
+                    self._last_chunk_time = time.time()
+
+            except Empty:
+                with self._buffer_lock:
+                    if len(self._buffer) > 0 and time.time() - self._last_chunk_time > self.block_duration:
+                        padded_chunk = np.zeros(self.block_size, dtype=np.float32)
+                        padded_chunk[:len(self._buffer)] = self._buffer
+                        self._buffer = np.array([], dtype=np.float32)
+
+                        chunk_reshaped = padded_chunk.reshape(1, -1)
+                        self.stream.on_next(chunk_reshaped)
+                        self._last_chunk_time = time.time()
+            except Exception as e:
+                logger.error(f"Error in audio processing thread: {e}")
+                self.stream.on_error(e)
+                break
+
+        with self._buffer_lock:
+            if len(self._buffer) > 0:
+                padded_chunk = np.zeros(self.block_size, dtype=np.float32)
+                padded_chunk[:len(self._buffer)] = self._buffer
+                chunk_reshaped = padded_chunk.reshape(1, -1)
+                self.stream.on_next(chunk_reshaped)
+
+        self.stream.on_completed()
 
     def close(self):
         if not self._closed:
             self._closed = True
-            self.stream.on_completed()
             self._close_event.set()
 
     def push_audio(self, chunk: np.ndarray):
+        """Add audio chunk to the processing queue."""
         if not self._closed:
-            new_audio = np.expand_dims(chunk, axis=0)
-            logger.debug('Add new chunk with shape:', new_audio.shape)
-            self.stream.on_next(new_audio)
+            if chunk.ndim > 1:
+                chunk = chunk.flatten()
+            self._queue.put(chunk)
+            logger.debug(f'Added chunk to queue with {len(chunk)} samples')
 
 
 class DiartDiarization:
-    def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False):
+    def __init__(self, sample_rate: int = 16000, config: SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5):
         self.pipeline = SpeakerDiarization(config=config)
         self.observer = DiarizationObserver()
         if use_microphone:
-            self.source = MicrophoneAudioSource()
+            self.source = MicrophoneAudioSource(block_duration=block_duration)
             self.custom_source = None
         else:
-            self.custom_source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
+            self.custom_source = WebSocketAudioSource(
+                uri="websocket_source",
+                sample_rate=sample_rate,
+                block_duration=block_duration
+            )
             self.source = self.custom_source
 
         self.inference = StreamingInference(
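Reviewer note (illustrative, not in the patch): the buffering in `_process_chunks` reduces to "accumulate samples, split off fixed-size blocks, carry the remainder". A pure-function sketch of just that regrouping step, without the threading, padding, or real-time pacing; the names are hypothetical:

import numpy as np

def rechunk(buffer, incoming, block_size):
    # Append incoming samples, then split off as many full blocks as possible.
    # Returns (blocks, remainder); the remainder carries over to the next call.
    buffer = np.concatenate([buffer, incoming])
    n = len(buffer) // block_size
    blocks = [buffer[i * block_size:(i + 1) * block_size] for i in range(n)]
    return blocks, buffer[n * block_size:]

buf = np.array([], dtype=np.float32)
blocks, buf = rechunk(buf, np.zeros(12000, dtype=np.float32), 8000)
# -> one 8000-sample block (0.5 s at 16 kHz), 4000 samples left over in buf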
@@ -138,16 +207,107 @@ class DiartDiarization:
         if self.custom_source:
             self.custom_source.close()
 
-    def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float:
+    def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float:
         """
         Assign speakers to tokens based on timing overlap with speaker segments.
         Uses the segments collected by the observer.
+
+        If use_punctuation_split is True, uses punctuation marks to refine speaker boundaries.
         """
         segments = self.observer.get_segments()
 
+        # Debug logging
+        logger.debug(f"assign_speakers_to_tokens called with {len(tokens)} tokens")
+        logger.debug(f"Available segments: {len(segments)}")
+        for i, seg in enumerate(segments[:5]):  # Show first 5 segments
+            logger.debug(f"  Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]")
+
+        # First pass: assign speakers based on timing overlap
         for token in tokens:
             for segment in segments:
                 if not (segment.end <= token.start or segment.start >= token.end):
                     token.speaker = extract_number(segment.speaker) + 1
                     end_attributed_speaker = max(token.end, end_attributed_speaker)
-        return end_attributed_speaker
\ No newline at end of file
+
+        if use_punctuation_split and len(tokens) > 1:
+            punctuation_marks = {'.', '!', '?'}
+
+            logger.debug(f"Here are the tokens: {[(t.text, t.start, t.end, t.speaker) for t in tokens[:10]]}")
+
+            segment_map = []
+            for segment in segments:
+                speaker_num = extract_number(segment.speaker) + 1
+                segment_map.append((segment.start, segment.end, speaker_num))
+            segment_map.sort(key=lambda x: x[0])  # Sort by start time
+
+            i = 0
+            while i < len(tokens):
+                current_token = tokens[i]
+
+                # Check if current token ends with sentence-ending punctuation
+                is_sentence_end = False
+                if current_token.text and current_token.text.strip():
+                    text = current_token.text.strip()
+                    if text[-1] in punctuation_marks:
+                        is_sentence_end = True
+                        logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
+
+                if is_sentence_end and current_token.speaker != -1:
+                    # Find the dominant speaker for the tokens after this punctuation
+                    current_speaker = current_token.speaker
+
+                    # Look ahead to find where the next sentence starts and ends
+                    j = i + 1
+                    next_sentence_tokens = []
+
+                    # Collect tokens until we hit another sentence-ending punctuation or run out
+                    while j < len(tokens):
+                        next_token = tokens[j]
+                        next_sentence_tokens.append(j)
+
+                        # Check if this token ends the next sentence
+                        if next_token.text and next_token.text.strip():
+                            if next_token.text.strip()[-1] in punctuation_marks:
+                                break
+                        j += 1
+
+                    if next_sentence_tokens:
+                        speaker_times = {}
+
+                        for idx in next_sentence_tokens:
+                            token = tokens[idx]
+                            # Find which segments overlap with this token
+                            for seg_start, seg_end, seg_speaker in segment_map:
+                                if not (seg_end <= token.start or seg_start >= token.end):
+                                    # Calculate overlap duration
+                                    overlap_start = max(seg_start, token.start)
+                                    overlap_end = min(seg_end, token.end)
+                                    overlap_duration = overlap_end - overlap_start
+
+                                    if seg_speaker not in speaker_times:
+                                        speaker_times[seg_speaker] = 0
+                                    speaker_times[seg_speaker] += overlap_duration
+
+                        if speaker_times:
+                            dominant_speaker = max(speaker_times.items(), key=lambda x: x[1])[0]
+
+                            if dominant_speaker != current_speaker:
+                                logger.debug(f"  Speaker change after punctuation: {current_speaker} → {dominant_speaker}")
+
+                                for idx in next_sentence_tokens:
+                                    if tokens[idx].speaker != dominant_speaker:
+                                        logger.debug(f"    Reassigning token {idx} ('{tokens[idx].text}') to Speaker {dominant_speaker}")
+                                        tokens[idx].speaker = dominant_speaker
+                                        end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
+                            else:
+                                for idx in next_sentence_tokens:
+                                    if tokens[idx].speaker == -1:
+                                        tokens[idx].speaker = current_speaker
+                                        end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
+
+                i += 1
+
+        return end_attributed_speaker
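Reviewer note (illustrative, not in the patch): the second pass above is a dominant-speaker vote weighted by overlap duration. A self-contained sketch of that core computation, using plain (start, end) tuples instead of the token and segment objects in the patch:

def dominant_speaker(token_spans, segments):
    # token_spans: [(start, end)]; segments: [(start, end, speaker)].
    # Returns the speaker with the largest total overlap, or None.
    totals = {}
    for t_start, t_end in token_spans:
        for s_start, s_end, speaker in segments:
            overlap = min(t_end, s_end) - max(t_start, s_start)
            if overlap > 0:
                totals[speaker] = totals.get(speaker, 0.0) + overlap
    return max(totals, key=totals.get) if totals else None

# Three tokens of the next sentence against two overlapping speaker segments:
# speaker 2 accumulates 0.8 s of overlap vs. 0.5 s for speaker 1.
assert dominant_speaker([(0.0, 0.4), (0.4, 0.9), (0.9, 1.3)],
                        [(0.0, 0.5, 1), (0.5, 1.3, 2)]) == 2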
diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py
index f97215a..da28532 100644
--- a/whisperlivekit/parse_args.py
+++ b/whisperlivekit/parse_args.py
@@ -37,6 +37,13 @@ def parse_args():
         help="Enable speaker diarization.",
     )
 
+    parser.add_argument(
+        "--punctuation-split",
+        action="store_true",
+        default=False,
+        help="Use punctuation marks from transcription to improve speaker boundary detection. Requires both transcription and diarization to be enabled.",
+    )
+
     parser.add_argument(
         "--no-transcription",
        action="store_true",
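Reviewer note: since `--punctuation-split` is a `store_true` flag that defaults to off, a typical invocation that exercises it (assuming the package's existing `whisperlivekit-server` entry point) would be:

whisperlivekit-server --model tiny --diarization --punctuation-split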