6 Commits
0.1.7 ... 0.1.9

Author SHA1 Message Date
Quentin Fuxa
e165916952 add diarization model list url 2025-06-19 16:43:23 +02:00
Quentin Fuxa
8532a91c7a add segmentation and embedding model options to configuration 2025-06-19 16:29:25 +02:00
Quentin Fuxa
b01b81bad0 improve diarization with lag diarization substraction 2025-06-19 16:18:49 +02:00
Quentin Fuxa
0f79d442ee improve diarization speed + Use punctuation to better align speakers and diarization 2025-06-19 13:03:29 +02:00
Quentin Fuxa
c9f60504e3 update with up to date example 2025-06-16 16:57:47 +02:00
Quentin Fuxa
993a83546a core refactoring 2025-06-16 16:13:57 +02:00
10 changed files with 490 additions and 220 deletions

View File

@@ -32,6 +32,7 @@ WhisperLiveKit consists of three main components:
- **👥 Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart)
- **🔒 Fully Local** - All processing happens on your machine - no data sent to external servers
- **📱 Multi-User Support** - Handle multiple users simultaneously with a single backend/server
- **📝 Punctuation-Based Speaker Splitting [BETA] ** - Align speaker changes with natural sentence boundaries for more readable transcripts
### ⚙️ Core differences from [Whisper Streaming](https://github.com/ufal/whisper_streaming)
@@ -142,52 +143,79 @@ whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --
```
### Python API Integration (Backend)
Check [basic_server.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a complete example.
```python
from whisperlivekit import WhisperLiveKit
from whisperlivekit.audio_processor import AudioProcessor
from fastapi import FastAPI, WebSocket
import asyncio
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from contextlib import asynccontextmanager
import asyncio
# Initialize components
app = FastAPI()
kit = WhisperLiveKit(model="medium", diarization=True)
# Global variable for the transcription engine
transcription_engine = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global transcription_engine
# Example: Initialize with specific parameters directly
# You can also load from command-line arguments using parse_args()
# args = parse_args()
# transcription_engine = TranscriptionEngine(**vars(args))
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
yield
app = FastAPI(lifespan=lifespan)
# Serve the web interface
@app.get("/")
async def get():
return HTMLResponse(kit.web_interface()) # Use the built-in web interface
return HTMLResponse(get_web_interface_html())
# Process WebSocket connections
async def handle_websocket_results(websocket, results_generator):
async for response in results_generator:
await websocket.send_json(response)
async def handle_websocket_results(websocket: WebSocket, results_generator):
try:
async for response in results_generator:
await websocket.send_json(response)
await websocket.send_json({"type": "ready_to_stop"})
except WebSocketDisconnect:
print("WebSocket disconnected during results handling.")
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
audio_processor = AudioProcessor()
await websocket.accept()
results_generator = await audio_processor.create_tasks()
websocket_task = asyncio.create_task(
handle_websocket_results(websocket, results_generator)
)
global transcription_engine
# Create a new AudioProcessor for each connection, passing the shared engine
audio_processor = AudioProcessor(transcription_engine=transcription_engine)
results_generator = await audio_processor.create_tasks()
send_results_to_client = handle_websocket_results(websocket, results_generator)
results_task = asyncio.create_task(send_results_to_client)
await websocket.accept()
try:
while True:
message = await websocket.receive_bytes()
await audio_processor.process_audio(message)
await audio_processor.process_audio(message)
except WebSocketDisconnect:
print(f"Client disconnected: {websocket.client}")
except Exception as e:
print(f"WebSocket error: {e}")
websocket_task.cancel()
await websocket.close(code=1011, reason=f"Server error: {e}")
finally:
results_task.cancel()
try:
await results_task
except asyncio.CancelledError:
logger.info("Results task successfully cancelled.")
```
### Frontend Implementation
The package includes a simple HTML/JavaScript implementation that you can adapt for your project. You can get in in [whisperlivekit/web/live_transcription.html](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html), or using :
The package includes a simple HTML/JavaScript implementation that you can adapt for your project. You can find it in `whisperlivekit/web/live_transcription.html`, or load its content using the `get_web_interface_html()` function from `whisperlivekit`:
```python
kit.web_interface()
from whisperlivekit import get_web_interface_html
# ... later in your code where you need the HTML string ...
html_content = get_web_interface_html()
```
## ⚙️ Configuration Reference
@@ -203,6 +231,7 @@ WhisperLiveKit offers extensive configuration options:
| `--task` | `transcribe` or `translate` | `transcribe` |
| `--backend` | Processing backend | `faster-whisper` |
| `--diarization` | Enable speaker identification | `False` |
| `--punctuation-split` | Use punctuation to improve speaker boundaries | `True` |
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
| `--vac` | Use Voice Activity Controller | `False` |
@@ -211,6 +240,8 @@ WhisperLiveKit offers extensive configuration options:
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
## 🔧 How It Works

View File

@@ -1,7 +1,7 @@
from setuptools import setup, find_packages
setup(
name="whisperlivekit",
version="0.1.7",
version="0.1.9",
description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",

View File

@@ -1,4 +1,5 @@
from .core import WhisperLiveKit, parse_args
from .core import TranscriptionEngine
from .audio_processor import AudioProcessor
__all__ = ['WhisperLiveKit', 'AudioProcessor', 'parse_args']
from .web.web_interface import get_web_interface_html
from .parse_args import parse_args
__all__ = ['TranscriptionEngine', 'AudioProcessor', 'get_web_interface_html', 'parse_args']

View File

@@ -8,7 +8,7 @@ import traceback
from datetime import timedelta
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper_streaming_custom.whisper_online import online_factory
from whisperlivekit.core import WhisperLiveKit
from whisperlivekit.core import TranscriptionEngine
# Set up logging once
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -27,10 +27,13 @@ class AudioProcessor:
Handles audio processing, state management, and result formatting.
"""
def __init__(self):
def __init__(self, **kwargs):
"""Initialize the audio processor with configuration, models, and state."""
models = WhisperLiveKit()
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
models = kwargs['transcription_engine']
else:
models = TranscriptionEngine(**kwargs)
# Audio processing settings
self.args = models.args
@@ -374,13 +377,16 @@ class AudioProcessor:
# Process diarization
await diarization_obj.diarize(pcm_array)
# Get current state and update speakers
state = await self.get_current_state()
new_end = diarization_obj.assign_speakers_to_tokens(
state["end_attributed_speaker"], state["tokens"]
)
async with self.lock:
new_end = diarization_obj.assign_speakers_to_tokens(
self.end_attributed_speaker,
self.tokens,
use_punctuation_split=self.args.punctuation_split
)
self.end_attributed_speaker = new_end
if buffer_diarization:
self.buffer_diarization = buffer_diarization
await self.update_diarization(new_end, buffer_diarization)
self.diarization_queue.task_done()
except Exception as e:

View File

@@ -2,26 +2,24 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from whisperlivekit import WhisperLiveKit, parse_args
from whisperlivekit.audio_processor import AudioProcessor
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args
import asyncio
import logging
import os, sys
import argparse
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger().setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
kit = None
args = parse_args()
transcription_engine = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global kit
kit = WhisperLiveKit()
global transcription_engine
transcription_engine = TranscriptionEngine(
**vars(args),
)
yield
app = FastAPI(lifespan=lifespan)
@@ -33,10 +31,9 @@ app.add_middleware(
allow_headers=["*"],
)
@app.get("/")
async def get():
return HTMLResponse(kit.web_interface())
return HTMLResponse(get_web_interface_html())
async def handle_websocket_results(websocket, results_generator):
@@ -55,8 +52,10 @@ async def handle_websocket_results(websocket, results_generator):
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
audio_processor = AudioProcessor()
global transcription_engine
audio_processor = AudioProcessor(
transcription_engine=transcription_engine,
)
await websocket.accept()
logger.info("WebSocket connection opened.")
@@ -94,8 +93,6 @@ def main():
"""Entry point for the CLI command."""
import uvicorn
args = parse_args()
uvicorn_kwargs = {
"app": "whisperlivekit.basic_server:app",
"host":args.host,
@@ -114,7 +111,6 @@ def main():
"ssl_keyfile": args.ssl_keyfile
}
if ssl_kwargs:
uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs}

View File

@@ -2,148 +2,10 @@ try:
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
except ImportError:
from .whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
from argparse import Namespace, ArgumentParser
def parse_args():
parser = ArgumentParser(description="Whisper FastAPI Online Server")
parser.add_argument(
"--host",
type=str,
default="localhost",
help="The host address to bind the server to.",
)
parser.add_argument(
"--port", type=int, default=8000, help="The port number to bind the server to."
)
parser.add_argument(
"--warmup-file",
type=str,
default=None,
dest="warmup_file",
help="""
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
If False, no warmup is performed.
""",
)
parser.add_argument(
"--confidence-validation",
action="store_true",
help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
)
parser.add_argument(
"--diarization",
action="store_true",
default=False,
help="Enable speaker diarization.",
)
parser.add_argument(
"--no-transcription",
action="store_true",
help="Disable transcription to only see live diarization results.",
)
parser.add_argument(
"--min-chunk-size",
type=float,
default=0.5,
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
)
parser.add_argument(
"--model",
type=str,
default="tiny",
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
)
parser.add_argument(
"--model_cache_dir",
type=str,
default=None,
help="Overriding the default model cache dir where models downloaded from the hub are saved",
)
parser.add_argument(
"--model_dir",
type=str,
default=None,
help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
)
parser.add_argument(
"--lan",
"--language",
type=str,
default="auto",
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
)
parser.add_argument(
"--task",
type=str,
default="transcribe",
choices=["transcribe", "translate"],
help="Transcribe or translate.",
)
parser.add_argument(
"--backend",
type=str,
default="faster-whisper",
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
help="Load only this backend for Whisper processing.",
)
parser.add_argument(
"--vac",
action="store_true",
default=False,
help="Use VAC = voice activity controller. Recommended. Requires torch.",
)
parser.add_argument(
"--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
)
parser.add_argument(
"--no-vad",
action="store_true",
help="Disable VAD (voice activity detection).",
)
parser.add_argument(
"--buffer_trimming",
type=str,
default="segment",
choices=["sentence", "segment"],
help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.',
)
parser.add_argument(
"--buffer_trimming_sec",
type=float,
default=15,
help="Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.",
)
parser.add_argument(
"-l",
"--log-level",
dest="log_level",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Set the log level",
default="DEBUG",
)
parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
from argparse import Namespace
args = parser.parse_args()
args.transcription = not args.no_transcription
args.vad = not args.no_vad
delattr(args, 'no_transcription')
delattr(args, 'no_vad')
return args
class WhisperLiveKit:
class TranscriptionEngine:
_instance = None
_initialized = False
@@ -153,14 +15,51 @@ class WhisperLiveKit:
return cls._instance
def __init__(self, **kwargs):
if WhisperLiveKit._initialized:
if TranscriptionEngine._initialized:
return
default_args = vars(parse_args())
defaults = {
"host": "localhost",
"port": 8000,
"warmup_file": None,
"confidence_validation": False,
"diarization": False,
"punctuation_split": False,
"min_chunk_size": 0.5,
"model": "tiny",
"model_cache_dir": None,
"model_dir": None,
"lan": "auto",
"task": "transcribe",
"backend": "faster-whisper",
"vac": False,
"vac_chunk_size": 0.04,
"buffer_trimming": "segment",
"buffer_trimming_sec": 15,
"log_level": "DEBUG",
"ssl_certfile": None,
"ssl_keyfile": None,
"transcription": True,
"vad": True,
"segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding",
}
config_dict = {**defaults, **kwargs}
if 'no_transcription' in kwargs:
config_dict['transcription'] = not kwargs['no_transcription']
if 'no_vad' in kwargs:
config_dict['vad'] = not kwargs['no_vad']
merged_args = {**default_args, **kwargs}
self.args = Namespace(**merged_args)
config_dict.pop('no_transcription', None)
config_dict.pop('no_vad', None)
if 'language' in kwargs:
config_dict['lan'] = kwargs['language']
config_dict.pop('language', None)
self.args = Namespace(**config_dict)
self.asr = None
self.tokenizer = None
@@ -172,13 +71,10 @@ class WhisperLiveKit:
if self.args.diarization:
from whisperlivekit.diarization.diarization_online import DiartDiarization
self.diarization = DiartDiarization()
self.diarization = DiartDiarization(
block_duration=self.args.min_chunk_size,
segmentation_model_name=self.args.segmentation_model,
embedding_model_name=self.args.embedding_model
)
WhisperLiveKit._initialized = True
def web_interface(self):
import pkg_resources
html_path = pkg_resources.resource_filename('whisperlivekit', 'web/live_transcription.html')
with open(html_path, "r", encoding="utf-8") as f:
html = f.read()
return html
TranscriptionEngine._initialized = True

View File

@@ -3,7 +3,8 @@ import re
import threading
import numpy as np
import logging
import time
from queue import SimpleQueue, Empty
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import StreamingInference
@@ -13,6 +14,7 @@ from diart.sources import MicrophoneAudioSource
from rx.core import Observer
from typing import Tuple, Any, List
from pyannote.core import Annotation
import diart.models as m
logger = logging.getLogger(__name__)
@@ -78,40 +80,114 @@ class DiarizationObserver(Observer):
class WebSocketAudioSource(AudioSource):
"""
Custom AudioSource that blocks in read() until close() is called.
Use push_audio() to inject PCM chunks.
Buffers incoming audio and releases it in fixed-size chunks at regular intervals.
"""
def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
def __init__(self, uri: str = "websocket", sample_rate: int = 16000, block_duration: float = 0.5):
super().__init__(uri, sample_rate)
self.block_duration = block_duration
self.block_size = int(np.rint(block_duration * sample_rate))
self._queue = SimpleQueue()
self._buffer = np.array([], dtype=np.float32)
self._buffer_lock = threading.Lock()
self._closed = False
self._close_event = threading.Event()
self._processing_thread = None
self._last_chunk_time = time.time()
def read(self):
"""Start processing buffered audio and emit fixed-size chunks."""
self._processing_thread = threading.Thread(target=self._process_chunks)
self._processing_thread.daemon = True
self._processing_thread.start()
self._close_event.wait()
if self._processing_thread:
self._processing_thread.join(timeout=2.0)
def _process_chunks(self):
"""Process audio from queue and emit fixed-size chunks at regular intervals."""
while not self._closed:
try:
audio_chunk = self._queue.get(timeout=0.1)
with self._buffer_lock:
self._buffer = np.concatenate([self._buffer, audio_chunk])
while len(self._buffer) >= self.block_size:
chunk = self._buffer[:self.block_size]
self._buffer = self._buffer[self.block_size:]
current_time = time.time()
time_since_last = current_time - self._last_chunk_time
if time_since_last < self.block_duration:
time.sleep(self.block_duration - time_since_last)
chunk_reshaped = chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self._last_chunk_time = time.time()
except Empty:
with self._buffer_lock:
if len(self._buffer) > 0 and time.time() - self._last_chunk_time > self.block_duration:
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
padded_chunk[:len(self._buffer)] = self._buffer
self._buffer = np.array([], dtype=np.float32)
chunk_reshaped = padded_chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self._last_chunk_time = time.time()
except Exception as e:
logger.error(f"Error in audio processing thread: {e}")
self.stream.on_error(e)
break
with self._buffer_lock:
if len(self._buffer) > 0:
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
padded_chunk[:len(self._buffer)] = self._buffer
chunk_reshaped = padded_chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self.stream.on_completed()
def close(self):
if not self._closed:
self._closed = True
self.stream.on_completed()
self._close_event.set()
def push_audio(self, chunk: np.ndarray):
"""Add audio chunk to the processing queue."""
if not self._closed:
new_audio = np.expand_dims(chunk, axis=0)
logger.debug('Add new chunk with shape:', new_audio.shape)
self.stream.on_next(new_audio)
if chunk.ndim > 1:
chunk = chunk.flatten()
self._queue.put(chunk)
logger.debug(f'Added chunk to queue with {len(chunk)} samples')
class DiartDiarization:
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False):
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "speechbrain/spkrec-ecapa-voxceleb"):
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
if config is None:
config = SpeakerDiarizationConfig(
segmentation=segmentation_model,
embedding=embedding_model,
)
self.pipeline = SpeakerDiarization(config=config)
self.observer = DiarizationObserver()
self.lag_diart = None
if use_microphone:
self.source = MicrophoneAudioSource()
self.source = MicrophoneAudioSource(block_duration=block_duration)
self.custom_source = None
else:
self.custom_source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
self.custom_source = WebSocketAudioSource(
uri="websocket_source",
sample_rate=sample_rate,
block_duration=block_duration
)
self.source = self.custom_source
self.inference = StreamingInference(
@@ -138,16 +214,102 @@ class DiartDiarization:
if self.custom_source:
self.custom_source.close()
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float:
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float:
"""
Assign speakers to tokens based on timing overlap with speaker segments.
Uses the segments collected by the observer.
If use_punctuation_split is True, uses punctuation marks to refine speaker boundaries.
"""
segments = self.observer.get_segments()
# Debug logging
logger.debug(f"assign_speakers_to_tokens called with {len(tokens)} tokens")
logger.debug(f"Available segments: {len(segments)}")
for i, seg in enumerate(segments[:5]): # Show first 5 segments
logger.debug(f" Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]")
if not self.lag_diart and segments and tokens:
self.lag_diart = segments[0].start - tokens[0].start
for token in tokens:
for segment in segments:
if not (segment.end <= token.start or segment.start >= token.end):
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
token.speaker = extract_number(segment.speaker) + 1
end_attributed_speaker = max(token.end, end_attributed_speaker)
return end_attributed_speaker
if use_punctuation_split and len(tokens) > 1:
punctuation_marks = {'.', '!', '?'}
print("Here are the tokens:",
[(t.text, t.start, t.end, t.speaker) for t in tokens[:10]])
segment_map = []
for segment in segments:
speaker_num = extract_number(segment.speaker) + 1
segment_map.append((segment.start, segment.end, speaker_num))
segment_map.sort(key=lambda x: x[0])
i = 0
while i < len(tokens):
current_token = tokens[i]
is_sentence_end = False
if current_token.text and current_token.text.strip():
text = current_token.text.strip()
if text[-1] in punctuation_marks:
is_sentence_end = True
logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
if is_sentence_end and current_token.speaker != -1:
punctuation_time = current_token.end
current_speaker = current_token.speaker
j = i + 1
next_sentence_tokens = []
while j < len(tokens):
next_token = tokens[j]
next_sentence_tokens.append(j)
# Check if this token ends the next sentence
if next_token.text and next_token.text.strip():
if next_token.text.strip()[-1] in punctuation_marks:
break
j += 1
if next_sentence_tokens:
speaker_times = {}
for idx in next_sentence_tokens:
token = tokens[idx]
# Find which segments overlap with this token
for seg_start, seg_end, seg_speaker in segment_map:
if not (seg_end <= token.start or seg_start >= token.end):
# Calculate overlap duration
overlap_start = max(seg_start, token.start)
overlap_end = min(seg_end, token.end)
overlap_duration = overlap_end - overlap_start
if seg_speaker not in speaker_times:
speaker_times[seg_speaker] = 0
speaker_times[seg_speaker] += overlap_duration
if speaker_times:
dominant_speaker = max(speaker_times.items(), key=lambda x: x[1])[0]
if dominant_speaker != current_speaker:
logger.debug(f" Speaker change after punctuation: {current_speaker}{dominant_speaker}")
for idx in next_sentence_tokens:
if tokens[idx].speaker != dominant_speaker:
logger.debug(f" Reassigning token {idx} ('{tokens[idx].text}') to Speaker {dominant_speaker}")
tokens[idx].speaker = dominant_speaker
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
else:
for idx in next_sentence_tokens:
if tokens[idx].speaker == -1:
tokens[idx].speaker = current_speaker
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
i += 1
return end_attributed_speaker

View File

@@ -0,0 +1,162 @@
from argparse import ArgumentParser
def parse_args():
parser = ArgumentParser(description="Whisper FastAPI Online Server")
parser.add_argument(
"--host",
type=str,
default="localhost",
help="The host address to bind the server to.",
)
parser.add_argument(
"--port", type=int, default=8000, help="The port number to bind the server to."
)
parser.add_argument(
"--warmup-file",
type=str,
default=None,
dest="warmup_file",
help="""
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
If False, no warmup is performed.
""",
)
parser.add_argument(
"--confidence-validation",
action="store_true",
help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
)
parser.add_argument(
"--diarization",
action="store_true",
default=False,
help="Enable speaker diarization.",
)
parser.add_argument(
"--punctuation-split",
action="store_true",
default=False,
help="Use punctuation marks from transcription to improve speaker boundary detection. Requires both transcription and diarization to be enabled.",
)
parser.add_argument(
"--segmentation-model",
type=str,
default="pyannote/segmentation-3.0",
help="Hugging Face model ID for pyannote.audio segmentation model.",
)
parser.add_argument(
"--embedding-model",
type=str,
default="pyannote/embedding",
help="Hugging Face model ID for pyannote.audio embedding model.",
)
parser.add_argument(
"--no-transcription",
action="store_true",
help="Disable transcription to only see live diarization results.",
)
parser.add_argument(
"--min-chunk-size",
type=float,
default=0.5,
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
)
parser.add_argument(
"--model",
type=str,
default="tiny",
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
)
parser.add_argument(
"--model_cache_dir",
type=str,
default=None,
help="Overriding the default model cache dir where models downloaded from the hub are saved",
)
parser.add_argument(
"--model_dir",
type=str,
default=None,
help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
)
parser.add_argument(
"--lan",
"--language",
type=str,
default="auto",
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
)
parser.add_argument(
"--task",
type=str,
default="transcribe",
choices=["transcribe", "translate"],
help="Transcribe or translate.",
)
parser.add_argument(
"--backend",
type=str,
default="faster-whisper",
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
help="Load only this backend for Whisper processing.",
)
parser.add_argument(
"--vac",
action="store_true",
default=False,
help="Use VAC = voice activity controller. Recommended. Requires torch.",
)
parser.add_argument(
"--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
)
parser.add_argument(
"--no-vad",
action="store_true",
help="Disable VAD (voice activity detection).",
)
parser.add_argument(
"--buffer_trimming",
type=str,
default="segment",
choices=["sentence", "segment"],
help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.',
)
parser.add_argument(
"--buffer_trimming_sec",
type=float,
default=15,
help="Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.",
)
parser.add_argument(
"-l",
"--log-level",
dest="log_level",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Set the log level",
default="DEBUG",
)
parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
args = parser.parse_args()
args.transcription = not args.no_transcription
args.vad = not args.no_vad
delattr(args, 'no_transcription')
delattr(args, 'no_vad')
return args

View File

@@ -26,4 +26,7 @@ class Transcript(TimedText):
@dataclass
class SpeakerSegment(TimedText):
"""Represents a segment of audio attributed to a specific speaker.
No text nor probability is associated with this segment.
"""
pass

View File

@@ -0,0 +1,13 @@
import logging
import importlib.resources as resources
logger = logging.getLogger(__name__)
def get_web_interface_html():
"""Loads the HTML for the web interface using importlib.resources."""
try:
with resources.files('whisperlivekit.web').joinpath('live_transcription.html').open('r', encoding='utf-8') as f:
return f.read()
except Exception as e:
logger.error(f"Error loading web interface HTML: {e}")
return "<html><body><h1>Error loading interface</h1></body></html>"