mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e165916952 | ||
|
|
8532a91c7a | ||
|
|
b01b81bad0 | ||
|
|
0f79d442ee | ||
|
|
c9f60504e3 | ||
|
|
993a83546a |
75
README.md
75
README.md
@@ -32,6 +32,7 @@ WhisperLiveKit consists of three main components:
|
||||
- **👥 Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart)
|
||||
- **🔒 Fully Local** - All processing happens on your machine - no data sent to external servers
|
||||
- **📱 Multi-User Support** - Handle multiple users simultaneously with a single backend/server
|
||||
- **📝 Punctuation-Based Speaker Splitting [BETA] ** - Align speaker changes with natural sentence boundaries for more readable transcripts
|
||||
|
||||
### ⚙️ Core differences from [Whisper Streaming](https://github.com/ufal/whisper_streaming)
|
||||
|
||||
@@ -142,52 +143,79 @@ whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --
|
||||
```
|
||||
|
||||
### Python API Integration (Backend)
|
||||
Check [basic_server.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a complete example.
|
||||
|
||||
```python
|
||||
from whisperlivekit import WhisperLiveKit
|
||||
from whisperlivekit.audio_processor import AudioProcessor
|
||||
from fastapi import FastAPI, WebSocket
|
||||
import asyncio
|
||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
|
||||
# Initialize components
|
||||
app = FastAPI()
|
||||
kit = WhisperLiveKit(model="medium", diarization=True)
|
||||
# Global variable for the transcription engine
|
||||
transcription_engine = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global transcription_engine
|
||||
# Example: Initialize with specific parameters directly
|
||||
# You can also load from command-line arguments using parse_args()
|
||||
# args = parse_args()
|
||||
# transcription_engine = TranscriptionEngine(**vars(args))
|
||||
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
|
||||
yield
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# Serve the web interface
|
||||
@app.get("/")
|
||||
async def get():
|
||||
return HTMLResponse(kit.web_interface()) # Use the built-in web interface
|
||||
return HTMLResponse(get_web_interface_html())
|
||||
|
||||
# Process WebSocket connections
|
||||
async def handle_websocket_results(websocket, results_generator):
|
||||
async for response in results_generator:
|
||||
await websocket.send_json(response)
|
||||
async def handle_websocket_results(websocket: WebSocket, results_generator):
|
||||
try:
|
||||
async for response in results_generator:
|
||||
await websocket.send_json(response)
|
||||
await websocket.send_json({"type": "ready_to_stop"})
|
||||
except WebSocketDisconnect:
|
||||
print("WebSocket disconnected during results handling.")
|
||||
|
||||
@app.websocket("/asr")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
audio_processor = AudioProcessor()
|
||||
await websocket.accept()
|
||||
results_generator = await audio_processor.create_tasks()
|
||||
websocket_task = asyncio.create_task(
|
||||
handle_websocket_results(websocket, results_generator)
|
||||
)
|
||||
global transcription_engine
|
||||
|
||||
# Create a new AudioProcessor for each connection, passing the shared engine
|
||||
audio_processor = AudioProcessor(transcription_engine=transcription_engine)
|
||||
results_generator = await audio_processor.create_tasks()
|
||||
send_results_to_client = handle_websocket_results(websocket, results_generator)
|
||||
results_task = asyncio.create_task(send_results_to_client)
|
||||
await websocket.accept()
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive_bytes()
|
||||
await audio_processor.process_audio(message)
|
||||
await audio_processor.process_audio(message)
|
||||
except WebSocketDisconnect:
|
||||
print(f"Client disconnected: {websocket.client}")
|
||||
except Exception as e:
|
||||
print(f"WebSocket error: {e}")
|
||||
websocket_task.cancel()
|
||||
await websocket.close(code=1011, reason=f"Server error: {e}")
|
||||
finally:
|
||||
results_task.cancel()
|
||||
try:
|
||||
await results_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Results task successfully cancelled.")
|
||||
```
|
||||
|
||||
### Frontend Implementation
|
||||
|
||||
The package includes a simple HTML/JavaScript implementation that you can adapt for your project. You can get in in [whisperlivekit/web/live_transcription.html](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html), or using :
|
||||
The package includes a simple HTML/JavaScript implementation that you can adapt for your project. You can find it in `whisperlivekit/web/live_transcription.html`, or load its content using the `get_web_interface_html()` function from `whisperlivekit`:
|
||||
|
||||
```python
|
||||
kit.web_interface()
|
||||
from whisperlivekit import get_web_interface_html
|
||||
|
||||
# ... later in your code where you need the HTML string ...
|
||||
html_content = get_web_interface_html()
|
||||
```
|
||||
|
||||
## ⚙️ Configuration Reference
|
||||
@@ -203,6 +231,7 @@ WhisperLiveKit offers extensive configuration options:
|
||||
| `--task` | `transcribe` or `translate` | `transcribe` |
|
||||
| `--backend` | Processing backend | `faster-whisper` |
|
||||
| `--diarization` | Enable speaker identification | `False` |
|
||||
| `--punctuation-split` | Use punctuation to improve speaker boundaries | `True` |
|
||||
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
|
||||
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
|
||||
| `--vac` | Use Voice Activity Controller | `False` |
|
||||
@@ -211,6 +240,8 @@ WhisperLiveKit offers extensive configuration options:
|
||||
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
||||
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
|
||||
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
|
||||
| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
||||
| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||
|
||||
## 🔧 How It Works
|
||||
|
||||
|
||||
2
setup.py
2
setup.py
@@ -1,7 +1,7 @@
|
||||
from setuptools import setup, find_packages
|
||||
setup(
|
||||
name="whisperlivekit",
|
||||
version="0.1.7",
|
||||
version="0.1.9",
|
||||
description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from .core import WhisperLiveKit, parse_args
|
||||
from .core import TranscriptionEngine
|
||||
from .audio_processor import AudioProcessor
|
||||
|
||||
__all__ = ['WhisperLiveKit', 'AudioProcessor', 'parse_args']
|
||||
from .web.web_interface import get_web_interface_html
|
||||
from .parse_args import parse_args
|
||||
__all__ = ['TranscriptionEngine', 'AudioProcessor', 'get_web_interface_html', 'parse_args']
|
||||
@@ -8,7 +8,7 @@ import traceback
|
||||
from datetime import timedelta
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from whisperlivekit.whisper_streaming_custom.whisper_online import online_factory
|
||||
from whisperlivekit.core import WhisperLiveKit
|
||||
from whisperlivekit.core import TranscriptionEngine
|
||||
|
||||
# Set up logging once
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
@@ -27,10 +27,13 @@ class AudioProcessor:
|
||||
Handles audio processing, state management, and result formatting.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the audio processor with configuration, models, and state."""
|
||||
|
||||
models = WhisperLiveKit()
|
||||
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
|
||||
models = kwargs['transcription_engine']
|
||||
else:
|
||||
models = TranscriptionEngine(**kwargs)
|
||||
|
||||
# Audio processing settings
|
||||
self.args = models.args
|
||||
@@ -374,13 +377,16 @@ class AudioProcessor:
|
||||
# Process diarization
|
||||
await diarization_obj.diarize(pcm_array)
|
||||
|
||||
# Get current state and update speakers
|
||||
state = await self.get_current_state()
|
||||
new_end = diarization_obj.assign_speakers_to_tokens(
|
||||
state["end_attributed_speaker"], state["tokens"]
|
||||
)
|
||||
async with self.lock:
|
||||
new_end = diarization_obj.assign_speakers_to_tokens(
|
||||
self.end_attributed_speaker,
|
||||
self.tokens,
|
||||
use_punctuation_split=self.args.punctuation_split
|
||||
)
|
||||
self.end_attributed_speaker = new_end
|
||||
if buffer_diarization:
|
||||
self.buffer_diarization = buffer_diarization
|
||||
|
||||
await self.update_diarization(new_end, buffer_diarization)
|
||||
self.diarization_queue.task_done()
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -2,26 +2,24 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from whisperlivekit import WhisperLiveKit, parse_args
|
||||
from whisperlivekit.audio_processor import AudioProcessor
|
||||
|
||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args
|
||||
import asyncio
|
||||
import logging
|
||||
import os, sys
|
||||
import argparse
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
kit = None
|
||||
args = parse_args()
|
||||
transcription_engine = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global kit
|
||||
kit = WhisperLiveKit()
|
||||
global transcription_engine
|
||||
transcription_engine = TranscriptionEngine(
|
||||
**vars(args),
|
||||
)
|
||||
yield
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
@@ -33,10 +31,9 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def get():
|
||||
return HTMLResponse(kit.web_interface())
|
||||
return HTMLResponse(get_web_interface_html())
|
||||
|
||||
|
||||
async def handle_websocket_results(websocket, results_generator):
|
||||
@@ -55,8 +52,10 @@ async def handle_websocket_results(websocket, results_generator):
|
||||
|
||||
@app.websocket("/asr")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
audio_processor = AudioProcessor()
|
||||
|
||||
global transcription_engine
|
||||
audio_processor = AudioProcessor(
|
||||
transcription_engine=transcription_engine,
|
||||
)
|
||||
await websocket.accept()
|
||||
logger.info("WebSocket connection opened.")
|
||||
|
||||
@@ -94,8 +93,6 @@ def main():
|
||||
"""Entry point for the CLI command."""
|
||||
import uvicorn
|
||||
|
||||
args = parse_args()
|
||||
|
||||
uvicorn_kwargs = {
|
||||
"app": "whisperlivekit.basic_server:app",
|
||||
"host":args.host,
|
||||
@@ -114,7 +111,6 @@ def main():
|
||||
"ssl_keyfile": args.ssl_keyfile
|
||||
}
|
||||
|
||||
|
||||
if ssl_kwargs:
|
||||
uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs}
|
||||
|
||||
|
||||
@@ -2,148 +2,10 @@ try:
|
||||
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
|
||||
except ImportError:
|
||||
from .whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
|
||||
from argparse import Namespace, ArgumentParser
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser(description="Whisper FastAPI Online Server")
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="The host address to bind the server to.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=8000, help="The port number to bind the server to."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup-file",
|
||||
type=str,
|
||||
default=None,
|
||||
dest="warmup_file",
|
||||
help="""
|
||||
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
|
||||
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
|
||||
If False, no warmup is performed.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--confidence-validation",
|
||||
action="store_true",
|
||||
help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--diarization",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable speaker diarization.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-transcription",
|
||||
action="store_true",
|
||||
help="Disable transcription to only see live diarization results.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--min-chunk-size",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="tiny",
|
||||
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Overriding the default model cache dir where models downloaded from the hub are saved",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lan",
|
||||
"--language",
|
||||
type=str,
|
||||
default="auto",
|
||||
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default="transcribe",
|
||||
choices=["transcribe", "translate"],
|
||||
help="Transcribe or translate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="faster-whisper",
|
||||
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
|
||||
help="Load only this backend for Whisper processing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vac",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use VAC = voice activity controller. Recommended. Requires torch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-vad",
|
||||
action="store_true",
|
||||
help="Disable VAD (voice activity detection).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--buffer_trimming",
|
||||
type=str,
|
||||
default="segment",
|
||||
choices=["sentence", "segment"],
|
||||
help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--buffer_trimming_sec",
|
||||
type=float,
|
||||
default=15,
|
||||
help="Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--log-level",
|
||||
dest="log_level",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||
help="Set the log level",
|
||||
default="DEBUG",
|
||||
)
|
||||
parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
|
||||
parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
|
||||
from argparse import Namespace
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
args.transcription = not args.no_transcription
|
||||
args.vad = not args.no_vad
|
||||
delattr(args, 'no_transcription')
|
||||
delattr(args, 'no_vad')
|
||||
|
||||
return args
|
||||
|
||||
class WhisperLiveKit:
|
||||
class TranscriptionEngine:
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
@@ -153,14 +15,51 @@ class WhisperLiveKit:
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if WhisperLiveKit._initialized:
|
||||
if TranscriptionEngine._initialized:
|
||||
return
|
||||
|
||||
default_args = vars(parse_args())
|
||||
|
||||
defaults = {
|
||||
"host": "localhost",
|
||||
"port": 8000,
|
||||
"warmup_file": None,
|
||||
"confidence_validation": False,
|
||||
"diarization": False,
|
||||
"punctuation_split": False,
|
||||
"min_chunk_size": 0.5,
|
||||
"model": "tiny",
|
||||
"model_cache_dir": None,
|
||||
"model_dir": None,
|
||||
"lan": "auto",
|
||||
"task": "transcribe",
|
||||
"backend": "faster-whisper",
|
||||
"vac": False,
|
||||
"vac_chunk_size": 0.04,
|
||||
"buffer_trimming": "segment",
|
||||
"buffer_trimming_sec": 15,
|
||||
"log_level": "DEBUG",
|
||||
"ssl_certfile": None,
|
||||
"ssl_keyfile": None,
|
||||
"transcription": True,
|
||||
"vad": True,
|
||||
"segmentation_model": "pyannote/segmentation-3.0",
|
||||
"embedding_model": "pyannote/embedding",
|
||||
}
|
||||
|
||||
config_dict = {**defaults, **kwargs}
|
||||
|
||||
if 'no_transcription' in kwargs:
|
||||
config_dict['transcription'] = not kwargs['no_transcription']
|
||||
if 'no_vad' in kwargs:
|
||||
config_dict['vad'] = not kwargs['no_vad']
|
||||
|
||||
merged_args = {**default_args, **kwargs}
|
||||
|
||||
self.args = Namespace(**merged_args)
|
||||
config_dict.pop('no_transcription', None)
|
||||
config_dict.pop('no_vad', None)
|
||||
|
||||
if 'language' in kwargs:
|
||||
config_dict['lan'] = kwargs['language']
|
||||
config_dict.pop('language', None)
|
||||
|
||||
self.args = Namespace(**config_dict)
|
||||
|
||||
self.asr = None
|
||||
self.tokenizer = None
|
||||
@@ -172,13 +71,10 @@ class WhisperLiveKit:
|
||||
|
||||
if self.args.diarization:
|
||||
from whisperlivekit.diarization.diarization_online import DiartDiarization
|
||||
self.diarization = DiartDiarization()
|
||||
self.diarization = DiartDiarization(
|
||||
block_duration=self.args.min_chunk_size,
|
||||
segmentation_model_name=self.args.segmentation_model,
|
||||
embedding_model_name=self.args.embedding_model
|
||||
)
|
||||
|
||||
WhisperLiveKit._initialized = True
|
||||
|
||||
def web_interface(self):
|
||||
import pkg_resources
|
||||
html_path = pkg_resources.resource_filename('whisperlivekit', 'web/live_transcription.html')
|
||||
with open(html_path, "r", encoding="utf-8") as f:
|
||||
html = f.read()
|
||||
return html
|
||||
TranscriptionEngine._initialized = True
|
||||
|
||||
@@ -3,7 +3,8 @@ import re
|
||||
import threading
|
||||
import numpy as np
|
||||
import logging
|
||||
|
||||
import time
|
||||
from queue import SimpleQueue, Empty
|
||||
|
||||
from diart import SpeakerDiarization, SpeakerDiarizationConfig
|
||||
from diart.inference import StreamingInference
|
||||
@@ -13,6 +14,7 @@ from diart.sources import MicrophoneAudioSource
|
||||
from rx.core import Observer
|
||||
from typing import Tuple, Any, List
|
||||
from pyannote.core import Annotation
|
||||
import diart.models as m
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -78,40 +80,114 @@ class DiarizationObserver(Observer):
|
||||
|
||||
class WebSocketAudioSource(AudioSource):
|
||||
"""
|
||||
Custom AudioSource that blocks in read() until close() is called.
|
||||
Use push_audio() to inject PCM chunks.
|
||||
Buffers incoming audio and releases it in fixed-size chunks at regular intervals.
|
||||
"""
|
||||
def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
|
||||
def __init__(self, uri: str = "websocket", sample_rate: int = 16000, block_duration: float = 0.5):
|
||||
super().__init__(uri, sample_rate)
|
||||
self.block_duration = block_duration
|
||||
self.block_size = int(np.rint(block_duration * sample_rate))
|
||||
self._queue = SimpleQueue()
|
||||
self._buffer = np.array([], dtype=np.float32)
|
||||
self._buffer_lock = threading.Lock()
|
||||
self._closed = False
|
||||
self._close_event = threading.Event()
|
||||
self._processing_thread = None
|
||||
self._last_chunk_time = time.time()
|
||||
|
||||
def read(self):
|
||||
"""Start processing buffered audio and emit fixed-size chunks."""
|
||||
self._processing_thread = threading.Thread(target=self._process_chunks)
|
||||
self._processing_thread.daemon = True
|
||||
self._processing_thread.start()
|
||||
|
||||
self._close_event.wait()
|
||||
if self._processing_thread:
|
||||
self._processing_thread.join(timeout=2.0)
|
||||
|
||||
def _process_chunks(self):
|
||||
"""Process audio from queue and emit fixed-size chunks at regular intervals."""
|
||||
while not self._closed:
|
||||
try:
|
||||
audio_chunk = self._queue.get(timeout=0.1)
|
||||
|
||||
with self._buffer_lock:
|
||||
self._buffer = np.concatenate([self._buffer, audio_chunk])
|
||||
|
||||
while len(self._buffer) >= self.block_size:
|
||||
chunk = self._buffer[:self.block_size]
|
||||
self._buffer = self._buffer[self.block_size:]
|
||||
|
||||
current_time = time.time()
|
||||
time_since_last = current_time - self._last_chunk_time
|
||||
if time_since_last < self.block_duration:
|
||||
time.sleep(self.block_duration - time_since_last)
|
||||
|
||||
chunk_reshaped = chunk.reshape(1, -1)
|
||||
self.stream.on_next(chunk_reshaped)
|
||||
self._last_chunk_time = time.time()
|
||||
|
||||
except Empty:
|
||||
with self._buffer_lock:
|
||||
if len(self._buffer) > 0 and time.time() - self._last_chunk_time > self.block_duration:
|
||||
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
|
||||
padded_chunk[:len(self._buffer)] = self._buffer
|
||||
self._buffer = np.array([], dtype=np.float32)
|
||||
|
||||
chunk_reshaped = padded_chunk.reshape(1, -1)
|
||||
self.stream.on_next(chunk_reshaped)
|
||||
self._last_chunk_time = time.time()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in audio processing thread: {e}")
|
||||
self.stream.on_error(e)
|
||||
break
|
||||
|
||||
with self._buffer_lock:
|
||||
if len(self._buffer) > 0:
|
||||
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
|
||||
padded_chunk[:len(self._buffer)] = self._buffer
|
||||
chunk_reshaped = padded_chunk.reshape(1, -1)
|
||||
self.stream.on_next(chunk_reshaped)
|
||||
|
||||
self.stream.on_completed()
|
||||
|
||||
def close(self):
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
self.stream.on_completed()
|
||||
self._close_event.set()
|
||||
|
||||
def push_audio(self, chunk: np.ndarray):
|
||||
"""Add audio chunk to the processing queue."""
|
||||
if not self._closed:
|
||||
new_audio = np.expand_dims(chunk, axis=0)
|
||||
logger.debug('Add new chunk with shape:', new_audio.shape)
|
||||
self.stream.on_next(new_audio)
|
||||
if chunk.ndim > 1:
|
||||
chunk = chunk.flatten()
|
||||
self._queue.put(chunk)
|
||||
logger.debug(f'Added chunk to queue with {len(chunk)} samples')
|
||||
|
||||
|
||||
class DiartDiarization:
|
||||
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False):
|
||||
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "speechbrain/spkrec-ecapa-voxceleb"):
|
||||
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
|
||||
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
|
||||
|
||||
if config is None:
|
||||
config = SpeakerDiarizationConfig(
|
||||
segmentation=segmentation_model,
|
||||
embedding=embedding_model,
|
||||
)
|
||||
|
||||
self.pipeline = SpeakerDiarization(config=config)
|
||||
self.observer = DiarizationObserver()
|
||||
self.lag_diart = None
|
||||
|
||||
if use_microphone:
|
||||
self.source = MicrophoneAudioSource()
|
||||
self.source = MicrophoneAudioSource(block_duration=block_duration)
|
||||
self.custom_source = None
|
||||
else:
|
||||
self.custom_source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
|
||||
self.custom_source = WebSocketAudioSource(
|
||||
uri="websocket_source",
|
||||
sample_rate=sample_rate,
|
||||
block_duration=block_duration
|
||||
)
|
||||
self.source = self.custom_source
|
||||
|
||||
self.inference = StreamingInference(
|
||||
@@ -138,16 +214,102 @@ class DiartDiarization:
|
||||
if self.custom_source:
|
||||
self.custom_source.close()
|
||||
|
||||
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float:
|
||||
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float:
|
||||
"""
|
||||
Assign speakers to tokens based on timing overlap with speaker segments.
|
||||
Uses the segments collected by the observer.
|
||||
|
||||
If use_punctuation_split is True, uses punctuation marks to refine speaker boundaries.
|
||||
"""
|
||||
segments = self.observer.get_segments()
|
||||
|
||||
# Debug logging
|
||||
logger.debug(f"assign_speakers_to_tokens called with {len(tokens)} tokens")
|
||||
logger.debug(f"Available segments: {len(segments)}")
|
||||
for i, seg in enumerate(segments[:5]): # Show first 5 segments
|
||||
logger.debug(f" Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]")
|
||||
|
||||
if not self.lag_diart and segments and tokens:
|
||||
self.lag_diart = segments[0].start - tokens[0].start
|
||||
for token in tokens:
|
||||
for segment in segments:
|
||||
if not (segment.end <= token.start or segment.start >= token.end):
|
||||
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
|
||||
token.speaker = extract_number(segment.speaker) + 1
|
||||
end_attributed_speaker = max(token.end, end_attributed_speaker)
|
||||
return end_attributed_speaker
|
||||
|
||||
if use_punctuation_split and len(tokens) > 1:
|
||||
punctuation_marks = {'.', '!', '?'}
|
||||
|
||||
print("Here are the tokens:",
|
||||
[(t.text, t.start, t.end, t.speaker) for t in tokens[:10]])
|
||||
|
||||
segment_map = []
|
||||
for segment in segments:
|
||||
speaker_num = extract_number(segment.speaker) + 1
|
||||
segment_map.append((segment.start, segment.end, speaker_num))
|
||||
segment_map.sort(key=lambda x: x[0])
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
current_token = tokens[i]
|
||||
|
||||
is_sentence_end = False
|
||||
if current_token.text and current_token.text.strip():
|
||||
text = current_token.text.strip()
|
||||
if text[-1] in punctuation_marks:
|
||||
is_sentence_end = True
|
||||
logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
|
||||
|
||||
if is_sentence_end and current_token.speaker != -1:
|
||||
punctuation_time = current_token.end
|
||||
current_speaker = current_token.speaker
|
||||
|
||||
j = i + 1
|
||||
next_sentence_tokens = []
|
||||
while j < len(tokens):
|
||||
next_token = tokens[j]
|
||||
next_sentence_tokens.append(j)
|
||||
|
||||
# Check if this token ends the next sentence
|
||||
if next_token.text and next_token.text.strip():
|
||||
if next_token.text.strip()[-1] in punctuation_marks:
|
||||
break
|
||||
j += 1
|
||||
|
||||
if next_sentence_tokens:
|
||||
speaker_times = {}
|
||||
|
||||
for idx in next_sentence_tokens:
|
||||
token = tokens[idx]
|
||||
# Find which segments overlap with this token
|
||||
for seg_start, seg_end, seg_speaker in segment_map:
|
||||
if not (seg_end <= token.start or seg_start >= token.end):
|
||||
# Calculate overlap duration
|
||||
overlap_start = max(seg_start, token.start)
|
||||
overlap_end = min(seg_end, token.end)
|
||||
overlap_duration = overlap_end - overlap_start
|
||||
|
||||
if seg_speaker not in speaker_times:
|
||||
speaker_times[seg_speaker] = 0
|
||||
speaker_times[seg_speaker] += overlap_duration
|
||||
|
||||
if speaker_times:
|
||||
dominant_speaker = max(speaker_times.items(), key=lambda x: x[1])[0]
|
||||
|
||||
if dominant_speaker != current_speaker:
|
||||
logger.debug(f" Speaker change after punctuation: {current_speaker} → {dominant_speaker}")
|
||||
|
||||
for idx in next_sentence_tokens:
|
||||
if tokens[idx].speaker != dominant_speaker:
|
||||
logger.debug(f" Reassigning token {idx} ('{tokens[idx].text}') to Speaker {dominant_speaker}")
|
||||
tokens[idx].speaker = dominant_speaker
|
||||
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
|
||||
else:
|
||||
for idx in next_sentence_tokens:
|
||||
if tokens[idx].speaker == -1:
|
||||
tokens[idx].speaker = current_speaker
|
||||
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
|
||||
|
||||
i += 1
|
||||
|
||||
return end_attributed_speaker
|
||||
|
||||
162
whisperlivekit/parse_args.py
Normal file
162
whisperlivekit/parse_args.py
Normal file
@@ -0,0 +1,162 @@
|
||||
|
||||
from argparse import ArgumentParser
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser(description="Whisper FastAPI Online Server")
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="The host address to bind the server to.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=8000, help="The port number to bind the server to."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup-file",
|
||||
type=str,
|
||||
default=None,
|
||||
dest="warmup_file",
|
||||
help="""
|
||||
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
|
||||
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
|
||||
If False, no warmup is performed.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--confidence-validation",
|
||||
action="store_true",
|
||||
help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--diarization",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable speaker diarization.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--punctuation-split",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use punctuation marks from transcription to improve speaker boundary detection. Requires both transcription and diarization to be enabled.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--segmentation-model",
|
||||
type=str,
|
||||
default="pyannote/segmentation-3.0",
|
||||
help="Hugging Face model ID for pyannote.audio segmentation model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default="pyannote/embedding",
|
||||
help="Hugging Face model ID for pyannote.audio embedding model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-transcription",
|
||||
action="store_true",
|
||||
help="Disable transcription to only see live diarization results.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--min-chunk-size",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="tiny",
|
||||
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Overriding the default model cache dir where models downloaded from the hub are saved",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lan",
|
||||
"--language",
|
||||
type=str,
|
||||
default="auto",
|
||||
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default="transcribe",
|
||||
choices=["transcribe", "translate"],
|
||||
help="Transcribe or translate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="faster-whisper",
|
||||
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
|
||||
help="Load only this backend for Whisper processing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vac",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use VAC = voice activity controller. Recommended. Requires torch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-vad",
|
||||
action="store_true",
|
||||
help="Disable VAD (voice activity detection).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--buffer_trimming",
|
||||
type=str,
|
||||
default="segment",
|
||||
choices=["sentence", "segment"],
|
||||
help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--buffer_trimming_sec",
|
||||
type=float,
|
||||
default=15,
|
||||
help="Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--log-level",
|
||||
dest="log_level",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||
help="Set the log level",
|
||||
default="DEBUG",
|
||||
)
|
||||
parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
|
||||
parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
args.transcription = not args.no_transcription
|
||||
args.vad = not args.no_vad
|
||||
delattr(args, 'no_transcription')
|
||||
delattr(args, 'no_vad')
|
||||
|
||||
return args
|
||||
@@ -26,4 +26,7 @@ class Transcript(TimedText):
|
||||
|
||||
@dataclass
|
||||
class SpeakerSegment(TimedText):
|
||||
"""Represents a segment of audio attributed to a specific speaker.
|
||||
No text nor probability is associated with this segment.
|
||||
"""
|
||||
pass
|
||||
13
whisperlivekit/web/web_interface.py
Normal file
13
whisperlivekit/web/web_interface.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import logging
|
||||
import importlib.resources as resources
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_web_interface_html():
|
||||
"""Loads the HTML for the web interface using importlib.resources."""
|
||||
try:
|
||||
with resources.files('whisperlivekit.web').joinpath('live_transcription.html').open('r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading web interface HTML: {e}")
|
||||
return "<html><body><h1>Error loading interface</h1></body></html>"
|
||||
Reference in New Issue
Block a user