mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e165916952 | ||
|
|
8532a91c7a | ||
|
|
b01b81bad0 | ||
|
|
0f79d442ee |
@@ -32,6 +32,7 @@ WhisperLiveKit consists of three main components:
|
|||||||
- **👥 Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart)
|
- **👥 Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart)
|
||||||
- **🔒 Fully Local** - All processing happens on your machine - no data sent to external servers
|
- **🔒 Fully Local** - All processing happens on your machine - no data sent to external servers
|
||||||
- **📱 Multi-User Support** - Handle multiple users simultaneously with a single backend/server
|
- **📱 Multi-User Support** - Handle multiple users simultaneously with a single backend/server
|
||||||
|
- **📝 Punctuation-Based Speaker Splitting [BETA]** - Align speaker changes with natural sentence boundaries for more readable transcripts
|
||||||
|
|
||||||
### ⚙️ Core differences from [Whisper Streaming](https://github.com/ufal/whisper_streaming)
|
### ⚙️ Core differences from [Whisper Streaming](https://github.com/ufal/whisper_streaming)
|
||||||
|
|
||||||
@@ -230,6 +231,7 @@ WhisperLiveKit offers extensive configuration options:
|
|||||||
| `--task` | `transcribe` or `translate` | `transcribe` |
|
| `--task` | `transcribe` or `translate` | `transcribe` |
|
||||||
| `--backend` | Processing backend | `faster-whisper` |
|
| `--backend` | Processing backend | `faster-whisper` |
|
||||||
| `--diarization` | Enable speaker identification | `False` |
|
| `--diarization` | Enable speaker identification | `False` |
|
||||||
|
| `--punctuation-split` | Use punctuation marks from the transcription to improve speaker boundaries (requires both transcription and diarization to be enabled) | `False` |
|
||||||
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
|
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
|
||||||
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
|
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
|
||||||
| `--vac` | Use Voice Activity Controller | `False` |
|
| `--vac` | Use Voice Activity Controller | `False` |
|
||||||
@@ -238,6 +240,8 @@ WhisperLiveKit offers extensive configuration options:
|
|||||||
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
||||||
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
|
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
|
||||||
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
|
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
|
||||||
|
| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
||||||
|
| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/embedding` |
|
||||||
|
|
||||||
## 🔧 How It Works
|
## 🔧 How It Works
|
||||||
|
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -1,7 +1,7 @@
|
|||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages
|
||||||
setup(
|
setup(
|
||||||
name="whisperlivekit",
|
name="whisperlivekit",
|
||||||
version="0.1.8",
|
version="0.1.9",
|
||||||
description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
|
description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
|
||||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
|||||||
@@ -377,13 +377,16 @@ class AudioProcessor:
|
|||||||
# Process diarization
|
# Process diarization
|
||||||
await diarization_obj.diarize(pcm_array)
|
await diarization_obj.diarize(pcm_array)
|
||||||
|
|
||||||
# Get current state and update speakers
|
async with self.lock:
|
||||||
state = await self.get_current_state()
|
new_end = diarization_obj.assign_speakers_to_tokens(
|
||||||
new_end = diarization_obj.assign_speakers_to_tokens(
|
self.end_attributed_speaker,
|
||||||
state["end_attributed_speaker"], state["tokens"]
|
self.tokens,
|
||||||
)
|
use_punctuation_split=self.args.punctuation_split
|
||||||
|
)
|
||||||
|
self.end_attributed_speaker = new_end
|
||||||
|
if buffer_diarization:
|
||||||
|
self.buffer_diarization = buffer_diarization
|
||||||
|
|
||||||
await self.update_diarization(new_end, buffer_diarization)
|
|
||||||
self.diarization_queue.task_done()
|
self.diarization_queue.task_done()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ class TranscriptionEngine:
|
|||||||
"warmup_file": None,
|
"warmup_file": None,
|
||||||
"confidence_validation": False,
|
"confidence_validation": False,
|
||||||
"diarization": False,
|
"diarization": False,
|
||||||
|
"punctuation_split": False,
|
||||||
"min_chunk_size": 0.5,
|
"min_chunk_size": 0.5,
|
||||||
"model": "tiny",
|
"model": "tiny",
|
||||||
"model_cache_dir": None,
|
"model_cache_dir": None,
|
||||||
@@ -40,6 +41,8 @@ class TranscriptionEngine:
|
|||||||
"ssl_keyfile": None,
|
"ssl_keyfile": None,
|
||||||
"transcription": True,
|
"transcription": True,
|
||||||
"vad": True,
|
"vad": True,
|
||||||
|
"segmentation_model": "pyannote/segmentation-3.0",
|
||||||
|
"embedding_model": "pyannote/embedding",
|
||||||
}
|
}
|
||||||
|
|
||||||
config_dict = {**defaults, **kwargs}
|
config_dict = {**defaults, **kwargs}
|
||||||
@@ -68,6 +71,10 @@ class TranscriptionEngine:
|
|||||||
|
|
||||||
if self.args.diarization:
|
if self.args.diarization:
|
||||||
from whisperlivekit.diarization.diarization_online import DiartDiarization
|
from whisperlivekit.diarization.diarization_online import DiartDiarization
|
||||||
self.diarization = DiartDiarization()
|
self.diarization = DiartDiarization(
|
||||||
|
block_duration=self.args.min_chunk_size,
|
||||||
|
segmentation_model_name=self.args.segmentation_model,
|
||||||
|
embedding_model_name=self.args.embedding_model
|
||||||
|
)
|
||||||
|
|
||||||
TranscriptionEngine._initialized = True
|
TranscriptionEngine._initialized = True
|
||||||
|
|||||||
@@ -3,7 +3,8 @@ import re
|
|||||||
import threading
|
import threading
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
|
from queue import SimpleQueue, Empty
|
||||||
|
|
||||||
from diart import SpeakerDiarization, SpeakerDiarizationConfig
|
from diart import SpeakerDiarization, SpeakerDiarizationConfig
|
||||||
from diart.inference import StreamingInference
|
from diart.inference import StreamingInference
|
||||||
@@ -13,6 +14,7 @@ from diart.sources import MicrophoneAudioSource
|
|||||||
from rx.core import Observer
|
from rx.core import Observer
|
||||||
from typing import Tuple, Any, List
|
from typing import Tuple, Any, List
|
||||||
from pyannote.core import Annotation
|
from pyannote.core import Annotation
|
||||||
|
import diart.models as m
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -78,40 +80,114 @@ class DiarizationObserver(Observer):
|
|||||||
|
|
||||||
class WebSocketAudioSource(AudioSource):
|
class WebSocketAudioSource(AudioSource):
|
||||||
"""
|
"""
|
||||||
Custom AudioSource that blocks in read() until close() is called.
|
Buffers incoming audio and releases it in fixed-size chunks at regular intervals.
|
||||||
Use push_audio() to inject PCM chunks.
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
|
def __init__(self, uri: str = "websocket", sample_rate: int = 16000, block_duration: float = 0.5):
|
||||||
super().__init__(uri, sample_rate)
|
super().__init__(uri, sample_rate)
|
||||||
|
self.block_duration = block_duration
|
||||||
|
self.block_size = int(np.rint(block_duration * sample_rate))
|
||||||
|
self._queue = SimpleQueue()
|
||||||
|
self._buffer = np.array([], dtype=np.float32)
|
||||||
|
self._buffer_lock = threading.Lock()
|
||||||
self._closed = False
|
self._closed = False
|
||||||
self._close_event = threading.Event()
|
self._close_event = threading.Event()
|
||||||
|
self._processing_thread = None
|
||||||
|
self._last_chunk_time = time.time()
|
||||||
|
|
||||||
def read(self):
|
def read(self):
|
||||||
|
"""Start processing buffered audio and emit fixed-size chunks."""
|
||||||
|
self._processing_thread = threading.Thread(target=self._process_chunks)
|
||||||
|
self._processing_thread.daemon = True
|
||||||
|
self._processing_thread.start()
|
||||||
|
|
||||||
self._close_event.wait()
|
self._close_event.wait()
|
||||||
|
if self._processing_thread:
|
||||||
|
self._processing_thread.join(timeout=2.0)
|
||||||
|
|
||||||
|
def _process_chunks(self):
|
||||||
|
"""Process audio from queue and emit fixed-size chunks at regular intervals."""
|
||||||
|
while not self._closed:
|
||||||
|
try:
|
||||||
|
audio_chunk = self._queue.get(timeout=0.1)
|
||||||
|
|
||||||
|
with self._buffer_lock:
|
||||||
|
self._buffer = np.concatenate([self._buffer, audio_chunk])
|
||||||
|
|
||||||
|
while len(self._buffer) >= self.block_size:
|
||||||
|
chunk = self._buffer[:self.block_size]
|
||||||
|
self._buffer = self._buffer[self.block_size:]
|
||||||
|
|
||||||
|
current_time = time.time()
|
||||||
|
time_since_last = current_time - self._last_chunk_time
|
||||||
|
if time_since_last < self.block_duration:
|
||||||
|
time.sleep(self.block_duration - time_since_last)
|
||||||
|
|
||||||
|
chunk_reshaped = chunk.reshape(1, -1)
|
||||||
|
self.stream.on_next(chunk_reshaped)
|
||||||
|
self._last_chunk_time = time.time()
|
||||||
|
|
||||||
|
except Empty:
|
||||||
|
with self._buffer_lock:
|
||||||
|
if len(self._buffer) > 0 and time.time() - self._last_chunk_time > self.block_duration:
|
||||||
|
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
|
||||||
|
padded_chunk[:len(self._buffer)] = self._buffer
|
||||||
|
self._buffer = np.array([], dtype=np.float32)
|
||||||
|
|
||||||
|
chunk_reshaped = padded_chunk.reshape(1, -1)
|
||||||
|
self.stream.on_next(chunk_reshaped)
|
||||||
|
self._last_chunk_time = time.time()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in audio processing thread: {e}")
|
||||||
|
self.stream.on_error(e)
|
||||||
|
break
|
||||||
|
|
||||||
|
with self._buffer_lock:
|
||||||
|
if len(self._buffer) > 0:
|
||||||
|
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
|
||||||
|
padded_chunk[:len(self._buffer)] = self._buffer
|
||||||
|
chunk_reshaped = padded_chunk.reshape(1, -1)
|
||||||
|
self.stream.on_next(chunk_reshaped)
|
||||||
|
|
||||||
|
self.stream.on_completed()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if not self._closed:
|
if not self._closed:
|
||||||
self._closed = True
|
self._closed = True
|
||||||
self.stream.on_completed()
|
|
||||||
self._close_event.set()
|
self._close_event.set()
|
||||||
|
|
||||||
def push_audio(self, chunk: np.ndarray):
|
def push_audio(self, chunk: np.ndarray):
|
||||||
|
"""Add audio chunk to the processing queue."""
|
||||||
if not self._closed:
|
if not self._closed:
|
||||||
new_audio = np.expand_dims(chunk, axis=0)
|
if chunk.ndim > 1:
|
||||||
logger.debug('Add new chunk with shape:', new_audio.shape)
|
chunk = chunk.flatten()
|
||||||
self.stream.on_next(new_audio)
|
self._queue.put(chunk)
|
||||||
|
logger.debug(f'Added chunk to queue with {len(chunk)} samples')
|
||||||
|
|
||||||
|
|
||||||
class DiartDiarization:
|
class DiartDiarization:
|
||||||
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False):
|
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "speechbrain/spkrec-ecapa-voxceleb"):
|
||||||
|
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
|
||||||
|
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
|
||||||
|
|
||||||
|
if config is None:
|
||||||
|
config = SpeakerDiarizationConfig(
|
||||||
|
segmentation=segmentation_model,
|
||||||
|
embedding=embedding_model,
|
||||||
|
)
|
||||||
|
|
||||||
self.pipeline = SpeakerDiarization(config=config)
|
self.pipeline = SpeakerDiarization(config=config)
|
||||||
self.observer = DiarizationObserver()
|
self.observer = DiarizationObserver()
|
||||||
|
self.lag_diart = None
|
||||||
|
|
||||||
if use_microphone:
|
if use_microphone:
|
||||||
self.source = MicrophoneAudioSource()
|
self.source = MicrophoneAudioSource(block_duration=block_duration)
|
||||||
self.custom_source = None
|
self.custom_source = None
|
||||||
else:
|
else:
|
||||||
self.custom_source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
|
self.custom_source = WebSocketAudioSource(
|
||||||
|
uri="websocket_source",
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
block_duration=block_duration
|
||||||
|
)
|
||||||
self.source = self.custom_source
|
self.source = self.custom_source
|
||||||
|
|
||||||
self.inference = StreamingInference(
|
self.inference = StreamingInference(
|
||||||
@@ -138,16 +214,102 @@ class DiartDiarization:
|
|||||||
if self.custom_source:
|
if self.custom_source:
|
||||||
self.custom_source.close()
|
self.custom_source.close()
|
||||||
|
|
||||||
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float:
|
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float:
|
||||||
"""
|
"""
|
||||||
Assign speakers to tokens based on timing overlap with speaker segments.
|
Assign speakers to tokens based on timing overlap with speaker segments.
|
||||||
Uses the segments collected by the observer.
|
Uses the segments collected by the observer.
|
||||||
|
|
||||||
|
If use_punctuation_split is True, uses punctuation marks to refine speaker boundaries.
|
||||||
"""
|
"""
|
||||||
segments = self.observer.get_segments()
|
segments = self.observer.get_segments()
|
||||||
|
|
||||||
|
# Debug logging
|
||||||
|
logger.debug(f"assign_speakers_to_tokens called with {len(tokens)} tokens")
|
||||||
|
logger.debug(f"Available segments: {len(segments)}")
|
||||||
|
for i, seg in enumerate(segments[:5]): # Show first 5 segments
|
||||||
|
logger.debug(f" Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]")
|
||||||
|
|
||||||
|
if not self.lag_diart and segments and tokens:
|
||||||
|
self.lag_diart = segments[0].start - tokens[0].start
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
if not (segment.end <= token.start or segment.start >= token.end):
|
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
|
||||||
token.speaker = extract_number(segment.speaker) + 1
|
token.speaker = extract_number(segment.speaker) + 1
|
||||||
end_attributed_speaker = max(token.end, end_attributed_speaker)
|
end_attributed_speaker = max(token.end, end_attributed_speaker)
|
||||||
|
|
||||||
|
if use_punctuation_split and len(tokens) > 1:
|
||||||
|
punctuation_marks = {'.', '!', '?'}
|
||||||
|
|
||||||
|
print("Here are the tokens:",
|
||||||
|
[(t.text, t.start, t.end, t.speaker) for t in tokens[:10]])
|
||||||
|
|
||||||
|
segment_map = []
|
||||||
|
for segment in segments:
|
||||||
|
speaker_num = extract_number(segment.speaker) + 1
|
||||||
|
segment_map.append((segment.start, segment.end, speaker_num))
|
||||||
|
segment_map.sort(key=lambda x: x[0])
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
while i < len(tokens):
|
||||||
|
current_token = tokens[i]
|
||||||
|
|
||||||
|
is_sentence_end = False
|
||||||
|
if current_token.text and current_token.text.strip():
|
||||||
|
text = current_token.text.strip()
|
||||||
|
if text[-1] in punctuation_marks:
|
||||||
|
is_sentence_end = True
|
||||||
|
logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
|
||||||
|
|
||||||
|
if is_sentence_end and current_token.speaker != -1:
|
||||||
|
punctuation_time = current_token.end
|
||||||
|
current_speaker = current_token.speaker
|
||||||
|
|
||||||
|
j = i + 1
|
||||||
|
next_sentence_tokens = []
|
||||||
|
while j < len(tokens):
|
||||||
|
next_token = tokens[j]
|
||||||
|
next_sentence_tokens.append(j)
|
||||||
|
|
||||||
|
# Check if this token ends the next sentence
|
||||||
|
if next_token.text and next_token.text.strip():
|
||||||
|
if next_token.text.strip()[-1] in punctuation_marks:
|
||||||
|
break
|
||||||
|
j += 1
|
||||||
|
|
||||||
|
if next_sentence_tokens:
|
||||||
|
speaker_times = {}
|
||||||
|
|
||||||
|
for idx in next_sentence_tokens:
|
||||||
|
token = tokens[idx]
|
||||||
|
# Find which segments overlap with this token
|
||||||
|
for seg_start, seg_end, seg_speaker in segment_map:
|
||||||
|
if not (seg_end <= token.start or seg_start >= token.end):
|
||||||
|
# Calculate overlap duration
|
||||||
|
overlap_start = max(seg_start, token.start)
|
||||||
|
overlap_end = min(seg_end, token.end)
|
||||||
|
overlap_duration = overlap_end - overlap_start
|
||||||
|
|
||||||
|
if seg_speaker not in speaker_times:
|
||||||
|
speaker_times[seg_speaker] = 0
|
||||||
|
speaker_times[seg_speaker] += overlap_duration
|
||||||
|
|
||||||
|
if speaker_times:
|
||||||
|
dominant_speaker = max(speaker_times.items(), key=lambda x: x[1])[0]
|
||||||
|
|
||||||
|
if dominant_speaker != current_speaker:
|
||||||
|
logger.debug(f" Speaker change after punctuation: {current_speaker} → {dominant_speaker}")
|
||||||
|
|
||||||
|
for idx in next_sentence_tokens:
|
||||||
|
if tokens[idx].speaker != dominant_speaker:
|
||||||
|
logger.debug(f" Reassigning token {idx} ('{tokens[idx].text}') to Speaker {dominant_speaker}")
|
||||||
|
tokens[idx].speaker = dominant_speaker
|
||||||
|
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
|
||||||
|
else:
|
||||||
|
for idx in next_sentence_tokens:
|
||||||
|
if tokens[idx].speaker == -1:
|
||||||
|
tokens[idx].speaker = current_speaker
|
||||||
|
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
return end_attributed_speaker
|
return end_attributed_speaker
|
||||||
@@ -37,6 +37,27 @@ def parse_args():
|
|||||||
help="Enable speaker diarization.",
|
help="Enable speaker diarization.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--punctuation-split",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Use punctuation marks from transcription to improve speaker boundary detection. Requires both transcription and diarization to be enabled.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--segmentation-model",
|
||||||
|
type=str,
|
||||||
|
default="pyannote/segmentation-3.0",
|
||||||
|
help="Hugging Face model ID for pyannote.audio segmentation model.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
type=str,
|
||||||
|
default="pyannote/embedding",
|
||||||
|
help="Hugging Face model ID for pyannote.audio embedding model.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no-transcription",
|
"--no-transcription",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
@@ -26,4 +26,7 @@ class Transcript(TimedText):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SpeakerSegment(TimedText):
|
class SpeakerSegment(TimedText):
|
||||||
|
"""Represents a segment of audio attributed to a specific speaker.
|
||||||
|
No text nor probability is associated with this segment.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
Reference in New Issue
Block a user