2 Commits

Author SHA1 Message Date
Quentin Fuxa
e704b0b0db Refactor imports and update __all__ to include internal argument parsing functions 2025-05-05 09:38:46 +02:00
Quentin Fuxa
2dd974ade0 Add support for PyAudioWPatch audio input on Windows
- Updated README.md to include installation instructions for PyAudioWPatch.
- Modified setup.py to add PyAudioWPatch as an optional dependency.
- Enhanced audio_processor.py to initialize and handle PyAudioWPatch for system audio capture.
- Updated basic_server.py to manage audio input modes and integrate PyAudioWPatch processing.
- Refactored core.py to include audio input argument parsing.
2025-05-05 09:30:18 +02:00
13 changed files with 702 additions and 919 deletions

13
LICENSE
View File

@@ -1,6 +1,10 @@
MIT License MIT License
Copyright (c) 2025 Quentin Fuxa. Copyright (c) 2025 Quentin Fuxa.
Based on:
- The original work by ÚFAL. License: https://github.com/ufal/whisper_streaming/blob/main/LICENSE
- The work by Snakers4 (silero-vad). License: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
- The work in Diart by juanmc2005. License: https://github.com/juanmc2005/diart/blob/main/LICENSE
Permission is hereby granted, free of charge, to any person obtaining a copy Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal of this software and associated documentation files (the "Software"), to deal
@@ -22,7 +26,8 @@ SOFTWARE.
--- ---
Based on: Third-party components included in this software:
- **whisper_streaming** by ÚFAL MIT License https://github.com/ufal/whisper_streaming. The original work by ÚFAL. License: https://github.com/ufal/whisper_streaming/blob/main/LICENSE
- **silero-vad** by Snakers4 MIT License https://github.com/snakers4/silero-vad. The work by Snakers4 (silero-vad). License: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE - **whisper_streaming** by ÚFAL MIT License https://github.com/ufal/whisper_streaming
- **Diart** by juanmc2005 MIT License https://github.com/juanmc2005/diart. The work in Diart by juanmc2005. License: https://github.com/juanmc2005/diart/blob/main/LICENSE - **silero-vad** by Snakers4 MIT License https://github.com/snakers4/silero-vad
- **Diart** by juanmc2005 MIT License https://github.com/juanmc2005/diart

102
README.md
View File

@@ -9,8 +9,8 @@
<p align="center"> <p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a> <a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=downloads"></a> <a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=downloads"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a> <a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT-dark_green"></a> <a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/QuentinFuxa/WhisperLiveKit?color=blue"></a>
</p> </p>
## 🚀 Overview ## 🚀 Overview
@@ -32,7 +32,6 @@ WhisperLiveKit consists of three main components:
- **👥 Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart) - **👥 Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart)
- **🔒 Fully Local** - All processing happens on your machine - no data sent to external servers - **🔒 Fully Local** - All processing happens on your machine - no data sent to external servers
- **📱 Multi-User Support** - Handle multiple users simultaneously with a single backend/server - **📱 Multi-User Support** - Handle multiple users simultaneously with a single backend/server
- **📝 Punctuation-Based Speaker Splitting [BETA] ** - Align speaker changes with natural sentence boundaries for more readable transcripts
### ⚙️ Core differences from [Whisper Streaming](https://github.com/ufal/whisper_streaming) ### ⚙️ Core differences from [Whisper Streaming](https://github.com/ufal/whisper_streaming)
@@ -113,6 +112,9 @@ pip install whisperlivekit[whisper] # Original Whisper
pip install whisperlivekit[whisper-timestamped] # Improved timestamps pip install whisperlivekit[whisper-timestamped] # Improved timestamps
pip install whisperlivekit[mlx-whisper] # Apple Silicon optimization pip install whisperlivekit[mlx-whisper] # Apple Silicon optimization
pip install whisperlivekit[openai] # OpenAI API pip install whisperlivekit[openai] # OpenAI API
# System audio capture (Windows only)
pip install whisperlivekit[pyaudiowpatch] # Use PyAudioWPatch for system audio loopback
``` ```
### 🎹 Pyannote Models Setup ### 🎹 Pyannote Models Setup
@@ -140,82 +142,58 @@ whisperlivekit-server --model tiny.en
# Advanced configuration with diarization # Advanced configuration with diarization
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language auto whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language auto
# Using PyAudioWPatch for system audio input (Windows only)
whisperlivekit-server --model tiny.en --audio-input pyaudiowpatch
``` ```
### Python API Integration (Backend) ### Python API Integration (Backend)
Check [basic_server.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a complete example.
```python ```python
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args from whisperlivekit import WhisperLiveKit
from fastapi import FastAPI, WebSocket, WebSocketDisconnect from whisperlivekit.audio_processor import AudioProcessor
from fastapi.responses import HTMLResponse from fastapi import FastAPI, WebSocket
from contextlib import asynccontextmanager
import asyncio import asyncio
from fastapi.responses import HTMLResponse
# Global variable for the transcription engine # Initialize components
transcription_engine = None app = FastAPI()
kit = WhisperLiveKit(model="medium", diarization=True)
@asynccontextmanager
async def lifespan(app: FastAPI):
global transcription_engine
# Example: Initialize with specific parameters directly
# You can also load from command-line arguments using parse_args()
# args = parse_args()
# transcription_engine = TranscriptionEngine(**vars(args))
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
yield
app = FastAPI(lifespan=lifespan)
# Serve the web interface # Serve the web interface
@app.get("/") @app.get("/")
async def get(): async def get():
return HTMLResponse(get_web_interface_html()) return HTMLResponse(kit.web_interface()) # Use the built-in web interface
# Process WebSocket connections # Process WebSocket connections
async def handle_websocket_results(websocket: WebSocket, results_generator): async def handle_websocket_results(websocket, results_generator):
try: async for response in results_generator:
async for response in results_generator: await websocket.send_json(response)
await websocket.send_json(response)
await websocket.send_json({"type": "ready_to_stop"})
except WebSocketDisconnect:
print("WebSocket disconnected during results handling.")
@app.websocket("/asr") @app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
global transcription_engine audio_processor = AudioProcessor()
# Create a new AudioProcessor for each connection, passing the shared engine
audio_processor = AudioProcessor(transcription_engine=transcription_engine)
results_generator = await audio_processor.create_tasks()
send_results_to_client = handle_websocket_results(websocket, results_generator)
results_task = asyncio.create_task(send_results_to_client)
await websocket.accept() await websocket.accept()
results_generator = await audio_processor.create_tasks()
websocket_task = asyncio.create_task(
handle_websocket_results(websocket, results_generator)
)
try: try:
while True: while True:
message = await websocket.receive_bytes() message = await websocket.receive_bytes()
await audio_processor.process_audio(message) await audio_processor.process_audio(message)
except WebSocketDisconnect:
print(f"Client disconnected: {websocket.client}")
except Exception as e: except Exception as e:
await websocket.close(code=1011, reason=f"Server error: {e}") print(f"WebSocket error: {e}")
finally: websocket_task.cancel()
results_task.cancel()
try:
await results_task
except asyncio.CancelledError:
logger.info("Results task successfully cancelled.")
``` ```
### Frontend Implementation ### Frontend Implementation
The package includes a simple HTML/JavaScript implementation that you can adapt for your project. You can find it in `whisperlivekit/web/live_transcription.html`, or load its content using the `get_web_interface_html()` function from `whisperlivekit`: The package includes a simple HTML/JavaScript implementation that you can adapt for your project. You can get in in [whisperlivekit/web/live_transcription.html](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html), or using :
```python ```python
from whisperlivekit import get_web_interface_html kit.web_interface()
# ... later in your code where you need the HTML string ...
html_content = get_web_interface_html()
``` ```
## ⚙️ Configuration Reference ## ⚙️ Configuration Reference
@@ -231,17 +209,15 @@ WhisperLiveKit offers extensive configuration options:
| `--task` | `transcribe` or `translate` | `transcribe` | | `--task` | `transcribe` or `translate` | `transcribe` |
| `--backend` | Processing backend | `faster-whisper` | | `--backend` | Processing backend | `faster-whisper` |
| `--diarization` | Enable speaker identification | `False` | | `--diarization` | Enable speaker identification | `False` |
| `--punctuation-split` | Use punctuation to improve speaker boundaries | `True` |
| `--confidence-validation` | Use confidence scores for faster validation | `False` | | `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` | | `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
| `--vac` | Use Voice Activity Controller | `False` | | `--vac` | Use Voice Activity Controller | `False` |
| `--no-vad` | Disable Voice Activity Detection | `False` | | `--no-vad` | Disable Voice Activity Detection | `False` |
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` | | `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` | | `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
| `--audio-input` | Source of audio (`websocket` or `pyaudiowpatch`) | `websocket` |
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` | | `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` | | `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
## 🔧 How It Works ## 🔧 How It Works
@@ -249,12 +225,16 @@ WhisperLiveKit offers extensive configuration options:
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit in Action" width="500"> <img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit in Action" width="500">
</p> </p>
1. **Audio Capture**: Browser's MediaRecorder API captures audio in webm/opus format 1. **Audio Input**:
2. **Streaming**: Audio chunks are sent to the server via WebSocket - **WebSocket (Default)**: Browser's MediaRecorder API captures audio (webm/opus), streams via WebSocket.
3. **Processing**: Server decodes audio with FFmpeg and streams into Whisper for transcription - **PyAudioWPatch (Windows Only)**: Captures system audio output directly using WASAPI loopback. Requires `--audio-input pyaudiowpatch`.
4. **Real-time Output**: 2. **Processing**:
- Partial transcriptions appear immediately in light gray (the 'aperçu') - **WebSocket**: Server decodes webm/opus audio with FFmpeg.
- Finalized text appears in normal color - **PyAudioWPatch**: Server receives raw PCM audio directly.
- Audio is streamed into Whisper for transcription.
3. **Real-time Output**:
- Partial transcriptions appear immediately in light gray (the 'aperçu').
- Finalized text appears in normal color.
- (When enabled) Different speakers are identified and highlighted - (When enabled) Different speakers are identified and highlighted
## 🚀 Deployment Guide ## 🚀 Deployment Guide

View File

@@ -1,7 +1,7 @@
from setuptools import setup, find_packages from setuptools import setup, find_packages
setup( setup(
name="whisperlivekit", name="whisperlivekit",
version="0.1.9", version="0.1.5",
description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization", description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
long_description=open("README.md", "r", encoding="utf-8").read(), long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
@@ -25,6 +25,7 @@ setup(
"whisper-timestamped": ["whisper-timestamped"], "whisper-timestamped": ["whisper-timestamped"],
"mlx-whisper": ["mlx-whisper"], "mlx-whisper": ["mlx-whisper"],
"openai": ["openai"], "openai": ["openai"],
"pyaudiowpatch": ["PyAudioWPatch"],
}, },
package_data={ package_data={
'whisperlivekit': ['web/*.html'], 'whisperlivekit': ['web/*.html'],

View File

@@ -1,5 +1,4 @@
from .core import TranscriptionEngine from .core import WhisperLiveKit, _parse_args_internal, get_parsed_args
from .audio_processor import AudioProcessor from .audio_processor import AudioProcessor
from .web.web_interface import get_web_interface_html
from .parse_args import parse_args __all__ = ['WhisperLiveKit', 'AudioProcessor', '_parse_args_internal', 'get_parsed_args']
__all__ = ['TranscriptionEngine', 'AudioProcessor', 'get_web_interface_html', 'parse_args']

View File

@@ -2,20 +2,25 @@ import asyncio
import numpy as np import numpy as np
import ffmpeg import ffmpeg
from time import time, sleep from time import time, sleep
import platform # To check OS
try:
import pyaudiowpatch as pyaudio
PYAUDIOWPATCH_AVAILABLE = True
except ImportError:
pyaudio = None
PYAUDIOWPATCH_AVAILABLE = False
import math import math
import logging import logging
import traceback import traceback
from datetime import timedelta from datetime import timedelta
from whisperlivekit.timed_objects import ASRToken from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper_streaming_custom.whisper_online import online_factory from whisperlivekit.whisper_streaming_custom.whisper_online import online_factory
from whisperlivekit.core import TranscriptionEngine from whisperlivekit.core import WhisperLiveKit
# Set up logging once # Set up logging once
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
SENTINEL = object() # unique sentinel object for end of stream marker
def format_time(seconds: float) -> str: def format_time(seconds: float) -> str:
"""Format seconds as HH:MM:SS.""" """Format seconds as HH:MM:SS."""
@@ -27,13 +32,10 @@ class AudioProcessor:
Handles audio processing, state management, and result formatting. Handles audio processing, state management, and result formatting.
""" """
def __init__(self, **kwargs): def __init__(self):
"""Initialize the audio processor with configuration, models, and state.""" """Initialize the audio processor with configuration, models, and state."""
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine): models = WhisperLiveKit()
models = kwargs['transcription_engine']
else:
models = TranscriptionEngine(**kwargs)
# Audio processing settings # Audio processing settings
self.args = models.args self.args = models.args
@@ -46,9 +48,8 @@ class AudioProcessor:
self.last_ffmpeg_activity = time() self.last_ffmpeg_activity = time()
self.ffmpeg_health_check_interval = 5 self.ffmpeg_health_check_interval = 5
self.ffmpeg_max_idle_time = 10 self.ffmpeg_max_idle_time = 10
# State management # State management
self.is_stopping = False
self.tokens = [] self.tokens = []
self.buffer_transcription = "" self.buffer_transcription = ""
self.buffer_diarization = "" self.buffer_diarization = ""
@@ -64,55 +65,87 @@ class AudioProcessor:
self.asr = models.asr self.asr = models.asr
self.tokenizer = models.tokenizer self.tokenizer = models.tokenizer
self.diarization = models.diarization self.diarization = models.diarization
self.ffmpeg_process = self.start_ffmpeg_decoder()
self.transcription_queue = asyncio.Queue() if self.args.transcription else None self.transcription_queue = asyncio.Queue() if self.args.transcription else None
self.diarization_queue = asyncio.Queue() if self.args.diarization else None self.diarization_queue = asyncio.Queue() if self.args.diarization else None
self.pcm_buffer = bytearray() self.pcm_buffer = bytearray()
self.ffmpeg_process = None
self.pyaudio_instance = None
self.pyaudio_stream = None
# Task references # Initialize audio input based on args
self.transcription_task = None if self.args.audio_input == "websocket":
self.diarization_task = None self.ffmpeg_process = self.start_ffmpeg_decoder()
self.ffmpeg_reader_task = None elif self.args.audio_input == "pyaudiowpatch":
self.watchdog_task = None if not PYAUDIOWPATCH_AVAILABLE:
self.all_tasks_for_cleanup = [] logger.error("PyAudioWPatch selected but not installed. Please install it: pip install whisperlivekit[pyaudiowpatch]")
raise ImportError("PyAudioWPatch not found.")
if platform.system() != "Windows":
logger.error("PyAudioWPatch is only supported on Windows.")
raise OSError("PyAudioWPatch requires Windows.")
self.initialize_pyaudiowpatch()
else:
raise ValueError(f"Unsupported audio input type: {self.args.audio_input}")
# Initialize transcription engine if enabled # Initialize transcription engine if enabled
if self.args.transcription: if self.args.transcription:
self.online = online_factory(self.args, models.asr, models.tokenizer) self.online = online_factory(self.args, models.asr, models.tokenizer)
def initialize_pyaudiowpatch(self):
"""Initialize PyAudioWPatch for audio input."""
logger.info("Initializing PyAudioWPatch...")
try:
self.pyaudio_instance = pyaudio.PyAudio()
# Find the default WASAPI loopback device
wasapi_info = self.pyaudio_instance.get_host_api_info_by_type(pyaudio.paWASAPI)
default_speakers = self.pyaudio_instance.get_device_info_by_index(wasapi_info["defaultOutputDevice"])
if not default_speakers["isLoopbackDevice"]:
for loopback in self.pyaudio_instance.get_loopback_device_info_generator():
if default_speakers["name"] in loopback["name"]:
default_speakers = loopback
break
else:
logger.error("Default loopback output device not found.")
raise OSError("Default loopback output device not found.")
logger.info(f"Using loopback device: {default_speakers['name']}")
self.pyaudio_stream = self.pyaudio_instance.open(
format=pyaudio.paInt16,
channels=default_speakers["maxInputChannels"],
rate=int(default_speakers["defaultSampleRate"]),
input=True,
input_device_index=default_speakers["index"],
frames_per_buffer=int(self.sample_rate * self.args.min_chunk_size)
)
self.sample_rate = int(default_speakers["defaultSampleRate"])
self.channels = default_speakers["maxInputChannels"]
self.samples_per_sec = int(self.sample_rate * self.args.min_chunk_size)
self.bytes_per_sample = 2
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
logger.info(f"PyAudioWPatch initialized with {self.channels} channels and {self.sample_rate} Hz sample rate.")
except Exception as e:
logger.error(f"Failed to initialize PyAudioWPatch: {e}")
logger.error(traceback.format_exc())
if self.pyaudio_instance:
self.pyaudio_instance.terminate()
raise
def convert_pcm_to_float(self, pcm_buffer): def convert_pcm_to_float(self, pcm_buffer):
"""Convert PCM buffer in s16le format to normalized NumPy array.""" """Convert PCM buffer in s16le format to normalized NumPy array."""
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0 if isinstance(pcm_buffer, (bytes, bytearray)):
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
else:
logger.error(f"Invalid buffer type for PCM conversion: {type(pcm_buffer)}")
return np.array([], dtype=np.float32)
def start_ffmpeg_decoder(self): def start_ffmpeg_decoder(self):
"""Start FFmpeg process for WebM to PCM conversion.""" """Start FFmpeg process for WebM to PCM conversion."""
try: return (ffmpeg.input("pipe:0", format="webm")
return (ffmpeg.input("pipe:0", format="webm") .output("pipe:1", format="s16le", acodec="pcm_s16le",
.output("pipe:1", format="s16le", acodec="pcm_s16le", ac=self.channels, ar=str(self.sample_rate))
ac=self.channels, ar=str(self.sample_rate)) .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True))
.run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True))
except FileNotFoundError:
error = """
FFmpeg is not installed or not found in your system's PATH.
Please install FFmpeg to enable audio processing.
Installation instructions:
# Ubuntu/Debian:
sudo apt update && sudo apt install ffmpeg
# macOS (using Homebrew):
brew install ffmpeg
# Windows:
# 1. Download the latest static build from https://ffmpeg.org/download.html
# 2. Extract the archive (e.g., to C:\\FFmpeg).
# 3. Add the 'bin' directory (e.g., C:\\FFmpeg\\bin) to your system's PATH environment variable.
After installation, please restart the application.
"""
logger.error(error)
raise FileNotFoundError(error)
async def restart_ffmpeg(self): async def restart_ffmpeg(self):
"""Restart the FFmpeg process after failure.""" """Restart the FFmpeg process after failure."""
@@ -161,6 +194,45 @@ class AudioProcessor:
logger.critical(f"Failed to restart FFmpeg process on second attempt: {e2}") logger.critical(f"Failed to restart FFmpeg process on second attempt: {e2}")
logger.critical(traceback.format_exc()) logger.critical(traceback.format_exc())
async def pyaudiowpatch_reader(self):
"""Read audio data from PyAudioWPatch stream and process it."""
logger.info("Starting PyAudioWPatch reader task.")
loop = asyncio.get_event_loop()
while True:
try:
chunk = await loop.run_in_executor(
None,
self.pyaudio_stream.read,
int(self.sample_rate * self.args.min_chunk_size),
False
)
if not chunk:
logger.info("PyAudioWPatch stream closed or read empty chunk.")
await asyncio.sleep(0.1)
continue
pcm_array = self.convert_pcm_to_float(chunk)
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(pcm_array.copy())
if self.args.transcription and self.transcription_queue:
await self.transcription_queue.put(pcm_array.copy())
except OSError as e:
logger.error(f"PyAudioWPatch stream error: {e}")
logger.error(traceback.format_exc())
break
except Exception as e:
logger.error(f"Exception in pyaudiowpatch_reader: {e}")
logger.error(traceback.format_exc())
await asyncio.sleep(1) # Wait before retrying or breaking
break
logger.info("PyAudioWPatch reader task finished.")
async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep): async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep):
"""Thread-safe update of transcription with new data.""" """Thread-safe update of transcription with new data."""
async with self.lock: async with self.lock:
@@ -246,7 +318,7 @@ class AudioProcessor:
self.last_ffmpeg_activity = time() self.last_ffmpeg_activity = time()
if not chunk: if not chunk:
logger.info("FFmpeg stdout closed, no more data to read.") logger.info("FFmpeg stdout closed.")
break break
self.pcm_buffer.extend(chunk) self.pcm_buffer.extend(chunk)
@@ -281,86 +353,45 @@ class AudioProcessor:
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}") logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}") logger.warning(f"Traceback: {traceback.format_exc()}")
break break
logger.info("FFmpeg stdout processing finished. Signaling downstream processors.")
if self.args.transcription and self.transcription_queue:
await self.transcription_queue.put(SENTINEL)
logger.debug("Sentinel put into transcription_queue.")
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(SENTINEL)
logger.debug("Sentinel put into diarization_queue.")
async def transcription_processor(self): async def transcription_processor(self):
"""Process audio chunks for transcription.""" """Process audio chunks for transcription."""
self.full_transcription = "" self.full_transcription = ""
self.sep = self.online.asr.sep self.sep = self.online.asr.sep
cumulative_pcm_duration_stream_time = 0.0
while True: while True:
try: try:
pcm_array = await self.transcription_queue.get() pcm_array = await self.transcription_queue.get()
if pcm_array is SENTINEL:
logger.debug("Transcription processor received sentinel. Finishing.")
self.transcription_queue.task_done()
break
if not self.online: # Should not happen if queue is used logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio to process.")
logger.warning("Transcription processor: self.online not initialized.")
self.transcription_queue.task_done()
continue
asr_internal_buffer_duration_s = len(self.online.audio_buffer) / self.online.SAMPLING_RATE
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
logger.info(
f"ASR processing: internal_buffer={asr_internal_buffer_duration_s:.2f}s, "
f"lag={transcription_lag_s:.2f}s."
)
# Process transcription # Process transcription
duration_this_chunk = len(pcm_array) / self.sample_rate if isinstance(pcm_array, np.ndarray) else 0 self.online.insert_audio_chunk(pcm_array)
cumulative_pcm_duration_stream_time += duration_this_chunk new_tokens = self.online.process_iter()
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
new_tokens, current_audio_processed_upto = self.online.process_iter()
if new_tokens: if new_tokens:
self.full_transcription += self.sep.join([t.text for t in new_tokens]) self.full_transcription += self.sep.join([t.text for t in new_tokens])
# Get buffer information # Get buffer information
_buffer_transcript_obj = self.online.get_buffer() _buffer = self.online.get_buffer()
buffer_text = _buffer_transcript_obj.text buffer = _buffer.text
end_buffer = _buffer.end if _buffer.end else (
candidate_end_times = [self.end_buffer] new_tokens[-1].end if new_tokens else 0
)
if new_tokens:
candidate_end_times.append(new_tokens[-1].end)
if _buffer_transcript_obj.end is not None:
candidate_end_times.append(_buffer_transcript_obj.end)
candidate_end_times.append(current_audio_processed_upto)
new_end_buffer = max(candidate_end_times)
# Avoid duplicating content # Avoid duplicating content
if buffer_text in self.full_transcription: if buffer in self.full_transcription:
buffer_text = "" buffer = ""
await self.update_transcription( await self.update_transcription(
new_tokens, buffer_text, new_end_buffer, self.full_transcription, self.sep new_tokens, buffer, end_buffer, self.full_transcription, self.sep
) )
self.transcription_queue.task_done()
except Exception as e: except Exception as e:
logger.warning(f"Exception in transcription_processor: {e}") logger.warning(f"Exception in transcription_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}") logger.warning(f"Traceback: {traceback.format_exc()}")
if 'pcm_array' in locals() and pcm_array is not SENTINEL : # Check if pcm_array was assigned from queue finally:
self.transcription_queue.task_done() self.transcription_queue.task_done()
logger.info("Transcription processor task finished.")
async def diarization_processor(self, diarization_obj): async def diarization_processor(self, diarization_obj):
"""Process audio chunks for speaker diarization.""" """Process audio chunks for speaker diarization."""
@@ -369,33 +400,23 @@ class AudioProcessor:
while True: while True:
try: try:
pcm_array = await self.diarization_queue.get() pcm_array = await self.diarization_queue.get()
if pcm_array is SENTINEL:
logger.debug("Diarization processor received sentinel. Finishing.")
self.diarization_queue.task_done()
break
# Process diarization # Process diarization
await diarization_obj.diarize(pcm_array) await diarization_obj.diarize(pcm_array)
async with self.lock: # Get current state and update speakers
new_end = diarization_obj.assign_speakers_to_tokens( state = await self.get_current_state()
self.end_attributed_speaker, new_end = diarization_obj.assign_speakers_to_tokens(
self.tokens, state["end_attributed_speaker"], state["tokens"]
use_punctuation_split=self.args.punctuation_split )
)
self.end_attributed_speaker = new_end
if buffer_diarization:
self.buffer_diarization = buffer_diarization
self.diarization_queue.task_done() await self.update_diarization(new_end, buffer_diarization)
except Exception as e: except Exception as e:
logger.warning(f"Exception in diarization_processor: {e}") logger.warning(f"Exception in diarization_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}") logger.warning(f"Traceback: {traceback.format_exc()}")
if 'pcm_array' in locals() and pcm_array is not SENTINEL: finally:
self.diarization_queue.task_done() self.diarization_queue.task_done()
logger.info("Diarization processor task finished.")
async def results_formatter(self): async def results_formatter(self):
"""Format processing results for output.""" """Format processing results for output."""
@@ -459,51 +480,31 @@ class AudioProcessor:
await self.update_diarization(end_attributed_speaker, combined) await self.update_diarization(end_attributed_speaker, combined)
buffer_diarization = combined buffer_diarization = combined
response_status = "active_transcription" # Create response object
final_lines_for_response = lines.copy() if not lines:
lines = [{
if not tokens and not buffer_transcription and not buffer_diarization:
response_status = "no_audio_detected"
final_lines_for_response = []
elif response_status == "active_transcription" and not final_lines_for_response:
final_lines_for_response = [{
"speaker": 1, "speaker": 1,
"text": "", "text": "",
"beg": format_time(state.get("end_buffer", 0)), "beg": format_time(0),
"end": format_time(state.get("end_buffer", 0)), "end": format_time(tokens[-1].end if tokens else 0),
"diff": 0 "diff": 0
}] }]
response = { response = {
"status": response_status, "lines": lines,
"lines": final_lines_for_response,
"buffer_transcription": buffer_transcription, "buffer_transcription": buffer_transcription,
"buffer_diarization": buffer_diarization, "buffer_diarization": buffer_diarization,
"remaining_time_transcription": state["remaining_time_transcription"], "remaining_time_transcription": state["remaining_time_transcription"],
"remaining_time_diarization": state["remaining_time_diarization"] "remaining_time_diarization": state["remaining_time_diarization"]
} }
current_response_signature = f"{response_status} | " + \ # Only yield if content has changed
' '.join([f"{line['speaker']} {line['text']}" for line in final_lines_for_response]) + \ response_content = ' '.join([f"{line['speaker']} {line['text']}" for line in lines]) + \
f" | {buffer_transcription} | {buffer_diarization}" f" | {buffer_transcription} | {buffer_diarization}"
if current_response_signature != self.last_response_content and \ if response_content != self.last_response_content and (lines or buffer_transcription or buffer_diarization):
(final_lines_for_response or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"):
yield response yield response
self.last_response_content = current_response_signature self.last_response_content = response_content
# Check for termination condition
if self.is_stopping:
all_processors_done = True
if self.args.transcription and self.transcription_task and not self.transcription_task.done():
all_processors_done = False
if self.args.diarization and self.diarization_task and not self.diarization_task.done():
all_processors_done = False
if all_processors_done:
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
final_state = await self.get_current_state()
return
await asyncio.sleep(0.1) # Avoid overwhelming the client await asyncio.sleep(0.1) # Avoid overwhelming the client
@@ -514,117 +515,85 @@ class AudioProcessor:
async def create_tasks(self): async def create_tasks(self):
"""Create and start processing tasks.""" """Create and start processing tasks."""
self.all_tasks_for_cleanup = []
processing_tasks_for_watchdog = []
if self.args.transcription and self.online:
self.transcription_task = asyncio.create_task(self.transcription_processor())
self.all_tasks_for_cleanup.append(self.transcription_task)
processing_tasks_for_watchdog.append(self.transcription_task)
tasks = []
if self.args.transcription and self.online:
tasks.append(asyncio.create_task(self.transcription_processor()))
if self.args.diarization and self.diarization: if self.args.diarization and self.diarization:
self.diarization_task = asyncio.create_task(self.diarization_processor(self.diarization)) tasks.append(asyncio.create_task(self.diarization_processor(self.diarization))) # Corrected indentation
self.all_tasks_for_cleanup.append(self.diarization_task)
processing_tasks_for_watchdog.append(self.diarization_task) if self.args.audio_input == "websocket":
tasks.append(asyncio.create_task(self.ffmpeg_stdout_reader()))
self.ffmpeg_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader()) elif self.args.audio_input == "pyaudiowpatch":
self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task) tasks.append(asyncio.create_task(self.pyaudiowpatch_reader()))
processing_tasks_for_watchdog.append(self.ffmpeg_reader_task)
# Monitor overall system health # Monitor overall system health
self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog)) async def watchdog():
self.all_tasks_for_cleanup.append(self.watchdog_task) while True:
try:
await asyncio.sleep(10) # Check every 10 seconds instead of 60
current_time = time()
# Check for stalled tasks
for i, task in enumerate(tasks):
if task.done():
exc = task.exception() if task.done() else None
task_name = task.get_name() if hasattr(task, 'get_name') else f"Task {i}"
logger.error(f"{task_name} unexpectedly completed with exception: {exc}")
if self.args.audio_input == "websocket":
ffmpeg_idle_time = current_time - self.last_ffmpeg_activity
if ffmpeg_idle_time > 15: # 15 seconds instead of 180
logger.warning(f"FFmpeg idle for {ffmpeg_idle_time:.2f}s - may need attention")
# Force restart after 30 seconds of inactivity (instead of 600)
if ffmpeg_idle_time > 30:
logger.error("FFmpeg idle for too long, forcing restart")
await self.restart_ffmpeg()
elif self.args.audio_input == "pyaudiowpatch":
if self.pyaudio_stream and not self.pyaudio_stream.is_active():
logger.warning("PyAudioWPatch stream is not active. Attempting to restart or handle.")
except Exception as e:
logger.error(f"Error in watchdog task: {e}")
logger.error(traceback.format_exc())
tasks.append(asyncio.create_task(watchdog()))
self.tasks = tasks
return self.results_formatter() return self.results_formatter()
async def watchdog(self, tasks_to_monitor):
"""Monitors the health of critical processing tasks."""
while True:
try:
await asyncio.sleep(10)
current_time = time()
for i, task in enumerate(tasks_to_monitor):
if task.done():
exc = task.exception()
task_name = task.get_name() if hasattr(task, 'get_name') else f"Monitored Task {i}"
if exc:
logger.error(f"{task_name} unexpectedly completed with exception: {exc}")
else:
logger.info(f"{task_name} completed normally.")
ffmpeg_idle_time = current_time - self.last_ffmpeg_activity
if ffmpeg_idle_time > 15:
logger.warning(f"FFmpeg idle for {ffmpeg_idle_time:.2f}s - may need attention.")
if ffmpeg_idle_time > 30 and not self.is_stopping:
logger.error("FFmpeg idle for too long and not in stopping phase, forcing restart.")
await self.restart_ffmpeg()
except asyncio.CancelledError:
logger.info("Watchdog task cancelled.")
break
except Exception as e:
logger.error(f"Error in watchdog task: {e}", exc_info=True)
async def cleanup(self): async def cleanup(self):
"""Clean up resources when processing is complete.""" """Clean up resources when processing is complete."""
logger.info("Starting cleanup of AudioProcessor resources.") for task in self.tasks:
for task in self.all_tasks_for_cleanup: task.cancel()
if task and not task.done():
task.cancel()
created_tasks = [t for t in self.all_tasks_for_cleanup if t]
if created_tasks:
await asyncio.gather(*created_tasks, return_exceptions=True)
logger.info("All processing tasks cancelled or finished.")
if self.ffmpeg_process:
if self.ffmpeg_process.stdin and not self.ffmpeg_process.stdin.closed:
try:
self.ffmpeg_process.stdin.close()
except Exception as e:
logger.warning(f"Error closing ffmpeg stdin during cleanup: {e}")
# Wait for ffmpeg process to terminate try:
if self.ffmpeg_process.poll() is None: # Check if process is still running await asyncio.gather(*self.tasks, return_exceptions=True)
logger.info("Waiting for FFmpeg process to terminate...") if self.args.audio_input == "websocket" and self.ffmpeg_process:
try: if self.ffmpeg_process.stdin:
# Run wait in executor to avoid blocking async loop self.ffmpeg_process.stdin.close()
await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait, 5.0) # 5s timeout if self.ffmpeg_process.poll() is None:
except Exception as e: # subprocess.TimeoutExpired is not directly caught by asyncio.wait_for with run_in_executor self.ffmpeg_process.wait()
logger.warning(f"FFmpeg did not terminate gracefully, killing. Error: {e}") elif self.args.audio_input == "pyaudiowpatch":
self.ffmpeg_process.kill() if self.pyaudio_stream:
await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait) # Wait for kill self.pyaudio_stream.stop_stream()
logger.info("FFmpeg process terminated.") self.pyaudio_stream.close()
logger.info("PyAudioWPatch stream closed.")
if self.args.diarization and hasattr(self, 'diarization') and hasattr(self.diarization, 'close'): if self.pyaudio_instance:
self.pyaudio_instance.terminate()
logger.info("PyAudioWPatch instance terminated.")
except Exception as e:
logger.warning(f"Error during cleanup: {e}")
logger.warning(traceback.format_exc())
if self.args.diarization and hasattr(self, 'diarization'):
self.diarization.close() self.diarization.close()
logger.info("AudioProcessor cleanup complete.")
async def process_audio(self, message): async def process_audio(self, message):
"""Process incoming audio data.""" """Process incoming audio data."""
# If already stopping or stdin is closed, ignore further audio, especially residual chunks.
if self.is_stopping or (self.ffmpeg_process and self.ffmpeg_process.stdin and self.ffmpeg_process.stdin.closed):
logger.warning(f"AudioProcessor is stopping or stdin is closed. Ignoring incoming audio message (length: {len(message)}).")
if not message and self.ffmpeg_process and self.ffmpeg_process.stdin and not self.ffmpeg_process.stdin.closed:
logger.info("Received empty message while already in stopping state; ensuring stdin is closed.")
try:
self.ffmpeg_process.stdin.close()
except Exception as e:
logger.warning(f"Error closing ffmpeg stdin on redundant stop signal during stopping state: {e}")
return
if not message: # primary signal to start stopping
logger.info("Empty audio message received, initiating stop sequence.")
self.is_stopping = True
if self.ffmpeg_process and self.ffmpeg_process.stdin and not self.ffmpeg_process.stdin.closed:
try:
self.ffmpeg_process.stdin.close()
logger.info("FFmpeg stdin closed due to primary stop signal.")
except Exception as e:
logger.warning(f"Error closing ffmpeg stdin on stop: {e}")
return
retry_count = 0 retry_count = 0
max_retries = 3 max_retries = 3
@@ -633,14 +602,37 @@ class AudioProcessor:
if not hasattr(self, '_last_heartbeat') or current_time - self._last_heartbeat >= 10: if not hasattr(self, '_last_heartbeat') or current_time - self._last_heartbeat >= 10:
logger.debug(f"Processing audio chunk, last FFmpeg activity: {current_time - self.last_ffmpeg_activity:.2f}s ago") logger.debug(f"Processing audio chunk, last FFmpeg activity: {current_time - self.last_ffmpeg_activity:.2f}s ago")
self._last_heartbeat = current_time self._last_heartbeat = current_time
if self.args.audio_input != "websocket":
# logger.debug("Audio input is not WebSocket, skipping process_audio.")
return # Do nothing if input is not WebSocket
while retry_count < max_retries: while retry_count < max_retries:
try: try:
if not self.ffmpeg_process or not hasattr(self.ffmpeg_process, 'stdin') or self.ffmpeg_process.poll() is not None:
logger.warning("FFmpeg process not available, restarting...") if not self.ffmpeg_process or self.ffmpeg_process.poll() is not None:
logger.warning("FFmpeg process not running or unavailable, attempting restart...")
await self.restart_ffmpeg() await self.restart_ffmpeg()
loop = asyncio.get_running_loop() if not self.ffmpeg_process or self.ffmpeg_process.poll() is not None:
logger.error("FFmpeg restart failed or process terminated immediately.")
# maybe raise an error or break after retries
await asyncio.sleep(1)
retry_count += 1
continue
# Ensure stdin is available
if not hasattr(self.ffmpeg_process, 'stdin') or self.ffmpeg_process.stdin.closed:
logger.warning("FFmpeg stdin is not available or closed. Restarting...")
await self.restart_ffmpeg()
if not hasattr(self.ffmpeg_process, 'stdin') or self.ffmpeg_process.stdin.closed:
logger.error("FFmpeg stdin still unavailable after restart.")
await asyncio.sleep(1)
retry_count += 1
continue
loop = asyncio.get_running_loop()
try: try:
await asyncio.wait_for( await asyncio.wait_for(
loop.run_in_executor(None, lambda: self.ffmpeg_process.stdin.write(message)), loop.run_in_executor(None, lambda: self.ffmpeg_process.stdin.write(message)),
@@ -676,4 +668,4 @@ class AudioProcessor:
else: else:
logger.error("Maximum retries reached for FFmpeg process") logger.error("Maximum retries reached for FFmpeg process")
await self.restart_ffmpeg() await self.restart_ffmpeg()
return return

View File

@@ -2,26 +2,48 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args
from whisperlivekit import WhisperLiveKit, get_parsed_args
from whisperlivekit.audio_processor import AudioProcessor
import asyncio import asyncio
import logging import logging
import os, sys
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger().setLevel(logging.WARNING) logging.getLogger().setLevel(logging.WARNING)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
args = parse_args()
transcription_engine = None
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
global transcription_engine logger.info("Starting up...")
transcription_engine = TranscriptionEngine( kit = WhisperLiveKit()
**vars(args), app.state.kit = kit
) logger.info(f"Audio Input mode: {kit.args.audio_input}")
audio_processor = AudioProcessor()
app.state.audio_processor = audio_processor
app.state.results_generator = None # Initialize
if kit.args.audio_input == "pyaudiowpatch":
logger.info("Starting PyAudioWPatch processing tasks...")
try:
app.state.results_generator = await audio_processor.create_tasks()
except Exception as e:
logger.critical(f"Failed to start PyAudioWPatch processing: {e}", exc_info=True)
else:
logger.info("WebSocket input mode selected. Processing will start on client connection.")
yield yield
logger.info("Shutting down...")
if hasattr(app.state, 'audio_processor') and app.state.audio_processor:
logger.info("Cleaning up AudioProcessor...")
await app.state.audio_processor.cleanup()
logger.info("Shutdown complete.")
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
@@ -31,74 +53,126 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
@app.get("/") @app.get("/")
async def get(): async def get():
return HTMLResponse(get_web_interface_html()) return HTMLResponse(app.state.kit.web_interface())
async def handle_websocket_results(websocket, results_generator): async def handle_websocket_results(websocket: WebSocket, results_generator):
"""Consumes results from the audio processor and sends them via WebSocket.""" """Consumes results from the audio processor and sends them via WebSocket."""
try: try:
async for response in results_generator: async for response in results_generator:
await websocket.send_json(response) await websocket.send_json(response)
# when the results_generator finishes it means all audio has been processed
logger.info("Results generator finished. Sending 'ready_to_stop' to client.")
await websocket.send_json({"type": "ready_to_stop"})
except WebSocketDisconnect:
logger.info("WebSocket disconnected while handling results (client likely closed connection).")
except Exception as e: except Exception as e:
logger.warning(f"Error in WebSocket results handler: {e}") logger.warning(f"Error in WebSocket results handler: {e}")
@app.websocket("/asr") @app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
global transcription_engine
audio_processor = AudioProcessor(
transcription_engine=transcription_engine,
)
await websocket.accept() await websocket.accept()
logger.info("WebSocket connection opened.") logger.info("WebSocket connection accepted.")
results_generator = await audio_processor.create_tasks() audio_processor = app.state.audio_processor
websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator)) kit_args = app.state.kit.args
results_generator = None
websocket_task = None
receive_task = None
try: try:
while True: if kit_args.audio_input == "websocket":
message = await websocket.receive_bytes() logger.info("WebSocket mode: Starting processing tasks for this connection.")
await audio_processor.process_audio(message) results_generator = await audio_processor.create_tasks()
except KeyError as e: websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
if 'bytes' in str(e):
logger.warning(f"Client has closed the connection.") async def receive_audio():
else: try:
logger.error(f"Unexpected KeyError in websocket_endpoint: {e}", exc_info=True) while True:
except WebSocketDisconnect: message = await websocket.receive_bytes()
logger.info("WebSocket disconnected by client during message receiving loop.") await audio_processor.process_audio(message)
except Exception as e: except WebSocketDisconnect:
logger.error(f"Unexpected error in websocket_endpoint main loop: {e}", exc_info=True) logger.info("WebSocket disconnected by client (receive_audio).")
finally: except Exception as e:
logger.info("Cleaning up WebSocket endpoint...") logger.error(f"Error receiving audio: {e}", exc_info=True)
if not websocket_task.done(): finally:
websocket_task.cancel() logger.debug("Receive audio task finished.")
try:
receive_task = asyncio.create_task(receive_audio())
done, pending = await asyncio.wait(
{websocket_task, receive_task},
return_when=asyncio.FIRST_COMPLETED,
)
for task in pending:
task.cancel() # Cancel the other task
elif kit_args.audio_input == "pyaudiowpatch":
logger.info("PyAudioWPatch mode: Streaming existing results.")
results_generator = app.state.results_generator
if results_generator is None:
logger.error("PyAudioWPatch results generator not available. Was startup successful?")
await websocket.close(code=1011, reason="Server error: Audio processing not started.")
return
websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
await websocket_task await websocket_task
except asyncio.CancelledError:
logger.info("WebSocket results handler task was cancelled.") else:
except Exception as e: logger.error(f"Unsupported audio input mode configured: {kit_args.audio_input}")
logger.warning(f"Exception while awaiting websocket_task completion: {e}") await websocket.close(code=1011, reason="Server configuration error.")
await audio_processor.cleanup() except WebSocketDisconnect:
logger.info("WebSocket endpoint cleaned up successfully.") logger.info("WebSocket disconnected by client.")
except Exception as e:
logger.error(f"Error in WebSocket endpoint: {e}", exc_info=True)
# Attempt to close gracefully
try:
await websocket.close(code=1011, reason=f"Server error: {e}")
except Exception:
pass # Ignore errors during close after another error
finally:
logger.info("Cleaning up WebSocket connection...")
if websocket_task and not websocket_task.done():
websocket_task.cancel()
if receive_task and not receive_task.done():
receive_task.cancel()
if kit_args.audio_input == "websocket":
pass
logger.info("WebSocket connection closed.")
def main(): def main():
"""Entry point for the CLI command.""" """Entry point for the CLI command."""
import uvicorn import uvicorn
# Get the globally parsed arguments
args = get_parsed_args()
# Set logger level based on args
log_level_name = args.log_level.upper()
# Ensure the level name is valid for the logging module
numeric_level = getattr(logging, log_level_name, None)
if not isinstance(numeric_level, int):
logging.warning(f"Invalid log level: {args.log_level}. Defaulting to INFO.")
numeric_level = logging.INFO
logging.getLogger().setLevel(numeric_level) # Set root logger level
# Set our specific logger level too
logger.setLevel(numeric_level)
logger.info(f"Log level set to: {log_level_name}")
# Determine uvicorn log level (map CRITICAL to critical, etc.)
uvicorn_log_level = log_level_name.lower()
if uvicorn_log_level == "debug": # Uvicorn uses 'trace' for more verbose than debug
uvicorn_log_level = "trace"
uvicorn_kwargs = { uvicorn_kwargs = {
"app": "whisperlivekit.basic_server:app", "app": "whisperlivekit.basic_server:app",
"host":args.host, "host":args.host,
"port":args.port, "port":args.port,
"reload": False, "reload": False,
"log_level": "info", "log_level": uvicorn_log_level,
"lifespan": "on", "lifespan": "on",
} }
@@ -111,6 +185,7 @@ def main():
"ssl_keyfile": args.ssl_keyfile "ssl_keyfile": args.ssl_keyfile
} }
if ssl_kwargs: if ssl_kwargs:
uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs} uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs}

View File

@@ -1,66 +1,187 @@
import sys
from argparse import Namespace, ArgumentParser
try: try:
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory, warmup_asr from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
except ImportError: except ImportError:
from .whisper_streaming_custom.whisper_online import backend_factory, warmup_asr if '.' not in sys.path:
from argparse import Namespace sys.path.insert(0, '.')
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
def _parse_args_internal():
parser = ArgumentParser(description="Whisper FastAPI Online Server")
parser.add_argument(
"--host",
type=str,
default="localhost",
help="The host address to bind the server to.",
)
parser.add_argument(
"--port", type=int, default=8000, help="The port number to bind the server to."
)
parser.add_argument(
"--warmup-file",
type=str,
default=None,
dest="warmup_file",
help="""
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
If False, no warmup is performed.
""",
)
parser.add_argument(
"--confidence-validation",
action="store_true",
help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
)
parser.add_argument(
"--diarization",
action="store_true",
default=False,
help="Enable speaker diarization.",
)
parser.add_argument(
"--no-transcription",
action="store_true",
help="Disable transcription to only see live diarization results.",
)
parser.add_argument(
"--min-chunk-size",
type=float,
default=0.5,
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
)
parser.add_argument(
"--model",
type=str,
default="tiny",
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
)
parser.add_argument(
"--model_cache_dir",
type=str,
default=None,
help="Overriding the default model cache dir where models downloaded from the hub are saved",
)
parser.add_argument(
"--model_dir",
type=str,
default=None,
help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
)
parser.add_argument(
"--lan",
"--language",
type=str,
default="auto",
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
)
parser.add_argument(
"--task",
type=str,
default="transcribe",
choices=["transcribe", "translate"],
help="Transcribe or translate.",
)
parser.add_argument(
"--backend",
type=str,
default="faster-whisper",
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
help="Load only this backend for Whisper processing.",
)
parser.add_argument(
"--vac",
action="store_true",
default=False,
help="Use VAC = voice activity controller. Recommended. Requires torch.",
)
parser.add_argument(
"--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
)
parser.add_argument(
"--no-vad",
action="store_true",
help="Disable VAD (voice activity detection).",
)
parser.add_argument(
"--buffer_trimming",
type=str,
default="segment",
choices=["sentence", "segment"],
help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.',
)
parser.add_argument(
"--buffer_trimming_sec",
type=float,
default=15,
help="Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.",
)
parser.add_argument(
"-l",
"--log-level",
dest="log_level",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Set the log level",
default="DEBUG",
)
parser.add_argument(
"--audio-input",
type=str,
default="websocket",
choices=["websocket", "pyaudiowpatch"],
help="Source of the audio input. 'websocket' expects audio via WebSocket (default). 'pyaudiowpatch' uses PyAudioWPatch to capture system audio output.",
)
parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
class TranscriptionEngine: args = parser.parse_args()
args.transcription = not args.no_transcription
args.vad = not args.no_vad
delattr(args, 'no_transcription')
delattr(args, 'no_vad')
return args
_cli_args = _parse_args_internal()
def get_parsed_args() -> Namespace:
"""Returns the globally parsed command-line arguments."""
return _cli_args
# --- WhisperLiveKit Class ---
class WhisperLiveKit:
_instance = None _instance = None
_initialized = False _initialized = False
def __new__(cls, *args, **kwargs): def __new__(cls, args: Namespace = None, **kwargs):
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
return cls._instance return cls._instance
def __init__(self, **kwargs): def __init__(self, args: Namespace = None, **kwargs):
if TranscriptionEngine._initialized: """
Initializes WhisperLiveKit.
Args:
args (Namespace, optional): Pre-parsed arguments. If None, uses globally parsed args.
Defaults to None.
**kwargs: Additional keyword arguments (currently not used directly but captured).
"""
if WhisperLiveKit._initialized:
return return
defaults = { self.args = args if args is not None else get_parsed_args()
"host": "localhost",
"port": 8000,
"warmup_file": None,
"confidence_validation": False,
"diarization": False,
"punctuation_split": False,
"min_chunk_size": 0.5,
"model": "tiny",
"model_cache_dir": None,
"model_dir": None,
"lan": "auto",
"task": "transcribe",
"backend": "faster-whisper",
"vac": False,
"vac_chunk_size": 0.04,
"buffer_trimming": "segment",
"buffer_trimming_sec": 15,
"log_level": "DEBUG",
"ssl_certfile": None,
"ssl_keyfile": None,
"transcription": True,
"vad": True,
"segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding",
}
config_dict = {**defaults, **kwargs}
if 'no_transcription' in kwargs:
config_dict['transcription'] = not kwargs['no_transcription']
if 'no_vad' in kwargs:
config_dict['vad'] = not kwargs['no_vad']
config_dict.pop('no_transcription', None)
config_dict.pop('no_vad', None)
if 'language' in kwargs:
config_dict['lan'] = kwargs['language']
config_dict.pop('language', None)
self.args = Namespace(**config_dict)
self.asr = None self.asr = None
self.tokenizer = None self.tokenizer = None
self.diarization = None self.diarization = None
@@ -71,10 +192,13 @@ class TranscriptionEngine:
if self.args.diarization: if self.args.diarization:
from whisperlivekit.diarization.diarization_online import DiartDiarization from whisperlivekit.diarization.diarization_online import DiartDiarization
self.diarization = DiartDiarization( self.diarization = DiartDiarization()
block_duration=self.args.min_chunk_size,
segmentation_model_name=self.args.segmentation_model,
embedding_model_name=self.args.embedding_model
)
TranscriptionEngine._initialized = True WhisperLiveKit._initialized = True
def web_interface(self):
import pkg_resources
html_path = pkg_resources.resource_filename('whisperlivekit', 'web/live_transcription.html')
with open(html_path, "r", encoding="utf-8") as f:
html = f.read()
return html

View File

@@ -3,8 +3,7 @@ import re
import threading import threading
import numpy as np import numpy as np
import logging import logging
import time
from queue import SimpleQueue, Empty
from diart import SpeakerDiarization, SpeakerDiarizationConfig from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import StreamingInference from diart.inference import StreamingInference
@@ -14,7 +13,6 @@ from diart.sources import MicrophoneAudioSource
from rx.core import Observer from rx.core import Observer
from typing import Tuple, Any, List from typing import Tuple, Any, List
from pyannote.core import Annotation from pyannote.core import Annotation
import diart.models as m
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -80,114 +78,40 @@ class DiarizationObserver(Observer):
class WebSocketAudioSource(AudioSource): class WebSocketAudioSource(AudioSource):
""" """
Buffers incoming audio and releases it in fixed-size chunks at regular intervals. Custom AudioSource that blocks in read() until close() is called.
Use push_audio() to inject PCM chunks.
""" """
def __init__(self, uri: str = "websocket", sample_rate: int = 16000, block_duration: float = 0.5): def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
super().__init__(uri, sample_rate) super().__init__(uri, sample_rate)
self.block_duration = block_duration
self.block_size = int(np.rint(block_duration * sample_rate))
self._queue = SimpleQueue()
self._buffer = np.array([], dtype=np.float32)
self._buffer_lock = threading.Lock()
self._closed = False self._closed = False
self._close_event = threading.Event() self._close_event = threading.Event()
self._processing_thread = None
self._last_chunk_time = time.time()
def read(self): def read(self):
"""Start processing buffered audio and emit fixed-size chunks."""
self._processing_thread = threading.Thread(target=self._process_chunks)
self._processing_thread.daemon = True
self._processing_thread.start()
self._close_event.wait() self._close_event.wait()
if self._processing_thread:
self._processing_thread.join(timeout=2.0)
def _process_chunks(self):
"""Process audio from queue and emit fixed-size chunks at regular intervals."""
while not self._closed:
try:
audio_chunk = self._queue.get(timeout=0.1)
with self._buffer_lock:
self._buffer = np.concatenate([self._buffer, audio_chunk])
while len(self._buffer) >= self.block_size:
chunk = self._buffer[:self.block_size]
self._buffer = self._buffer[self.block_size:]
current_time = time.time()
time_since_last = current_time - self._last_chunk_time
if time_since_last < self.block_duration:
time.sleep(self.block_duration - time_since_last)
chunk_reshaped = chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self._last_chunk_time = time.time()
except Empty:
with self._buffer_lock:
if len(self._buffer) > 0 and time.time() - self._last_chunk_time > self.block_duration:
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
padded_chunk[:len(self._buffer)] = self._buffer
self._buffer = np.array([], dtype=np.float32)
chunk_reshaped = padded_chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self._last_chunk_time = time.time()
except Exception as e:
logger.error(f"Error in audio processing thread: {e}")
self.stream.on_error(e)
break
with self._buffer_lock:
if len(self._buffer) > 0:
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
padded_chunk[:len(self._buffer)] = self._buffer
chunk_reshaped = padded_chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self.stream.on_completed()
def close(self): def close(self):
if not self._closed: if not self._closed:
self._closed = True self._closed = True
self.stream.on_completed()
self._close_event.set() self._close_event.set()
def push_audio(self, chunk: np.ndarray): def push_audio(self, chunk: np.ndarray):
"""Add audio chunk to the processing queue."""
if not self._closed: if not self._closed:
if chunk.ndim > 1: new_audio = np.expand_dims(chunk, axis=0)
chunk = chunk.flatten() logger.debug('Add new chunk with shape:', new_audio.shape)
self._queue.put(chunk) self.stream.on_next(new_audio)
logger.debug(f'Added chunk to queue with {len(chunk)} samples')
class DiartDiarization: class DiartDiarization:
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "speechbrain/spkrec-ecapa-voxceleb"): def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False):
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
if config is None:
config = SpeakerDiarizationConfig(
segmentation=segmentation_model,
embedding=embedding_model,
)
self.pipeline = SpeakerDiarization(config=config) self.pipeline = SpeakerDiarization(config=config)
self.observer = DiarizationObserver() self.observer = DiarizationObserver()
self.lag_diart = None
if use_microphone: if use_microphone:
self.source = MicrophoneAudioSource(block_duration=block_duration) self.source = MicrophoneAudioSource()
self.custom_source = None self.custom_source = None
else: else:
self.custom_source = WebSocketAudioSource( self.custom_source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
uri="websocket_source",
sample_rate=sample_rate,
block_duration=block_duration
)
self.source = self.custom_source self.source = self.custom_source
self.inference = StreamingInference( self.inference = StreamingInference(
@@ -214,102 +138,16 @@ class DiartDiarization:
if self.custom_source: if self.custom_source:
self.custom_source.close() self.custom_source.close()
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float: def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float:
""" """
Assign speakers to tokens based on timing overlap with speaker segments. Assign speakers to tokens based on timing overlap with speaker segments.
Uses the segments collected by the observer. Uses the segments collected by the observer.
If use_punctuation_split is True, uses punctuation marks to refine speaker boundaries.
""" """
segments = self.observer.get_segments() segments = self.observer.get_segments()
# Debug logging
logger.debug(f"assign_speakers_to_tokens called with {len(tokens)} tokens")
logger.debug(f"Available segments: {len(segments)}")
for i, seg in enumerate(segments[:5]): # Show first 5 segments
logger.debug(f" Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]")
if not self.lag_diart and segments and tokens:
self.lag_diart = segments[0].start - tokens[0].start
for token in tokens: for token in tokens:
for segment in segments: for segment in segments:
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart): if not (segment.end <= token.start or segment.start >= token.end):
token.speaker = extract_number(segment.speaker) + 1 token.speaker = extract_number(segment.speaker) + 1
end_attributed_speaker = max(token.end, end_attributed_speaker) end_attributed_speaker = max(token.end, end_attributed_speaker)
return end_attributed_speaker
if use_punctuation_split and len(tokens) > 1:
punctuation_marks = {'.', '!', '?'}
print("Here are the tokens:",
[(t.text, t.start, t.end, t.speaker) for t in tokens[:10]])
segment_map = []
for segment in segments:
speaker_num = extract_number(segment.speaker) + 1
segment_map.append((segment.start, segment.end, speaker_num))
segment_map.sort(key=lambda x: x[0])
i = 0
while i < len(tokens):
current_token = tokens[i]
is_sentence_end = False
if current_token.text and current_token.text.strip():
text = current_token.text.strip()
if text[-1] in punctuation_marks:
is_sentence_end = True
logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
if is_sentence_end and current_token.speaker != -1:
punctuation_time = current_token.end
current_speaker = current_token.speaker
j = i + 1
next_sentence_tokens = []
while j < len(tokens):
next_token = tokens[j]
next_sentence_tokens.append(j)
# Check if this token ends the next sentence
if next_token.text and next_token.text.strip():
if next_token.text.strip()[-1] in punctuation_marks:
break
j += 1
if next_sentence_tokens:
speaker_times = {}
for idx in next_sentence_tokens:
token = tokens[idx]
# Find which segments overlap with this token
for seg_start, seg_end, seg_speaker in segment_map:
if not (seg_end <= token.start or seg_start >= token.end):
# Calculate overlap duration
overlap_start = max(seg_start, token.start)
overlap_end = min(seg_end, token.end)
overlap_duration = overlap_end - overlap_start
if seg_speaker not in speaker_times:
speaker_times[seg_speaker] = 0
speaker_times[seg_speaker] += overlap_duration
if speaker_times:
dominant_speaker = max(speaker_times.items(), key=lambda x: x[1])[0]
if dominant_speaker != current_speaker:
logger.debug(f" Speaker change after punctuation: {current_speaker}{dominant_speaker}")
for idx in next_sentence_tokens:
if tokens[idx].speaker != dominant_speaker:
logger.debug(f" Reassigning token {idx} ('{tokens[idx].text}') to Speaker {dominant_speaker}")
tokens[idx].speaker = dominant_speaker
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
else:
for idx in next_sentence_tokens:
if tokens[idx].speaker == -1:
tokens[idx].speaker = current_speaker
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
i += 1
return end_attributed_speaker

View File

@@ -1,162 +0,0 @@
from argparse import ArgumentParser
def parse_args():
parser = ArgumentParser(description="Whisper FastAPI Online Server")
parser.add_argument(
"--host",
type=str,
default="localhost",
help="The host address to bind the server to.",
)
parser.add_argument(
"--port", type=int, default=8000, help="The port number to bind the server to."
)
parser.add_argument(
"--warmup-file",
type=str,
default=None,
dest="warmup_file",
help="""
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
If False, no warmup is performed.
""",
)
parser.add_argument(
"--confidence-validation",
action="store_true",
help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
)
parser.add_argument(
"--diarization",
action="store_true",
default=False,
help="Enable speaker diarization.",
)
parser.add_argument(
"--punctuation-split",
action="store_true",
default=False,
help="Use punctuation marks from transcription to improve speaker boundary detection. Requires both transcription and diarization to be enabled.",
)
parser.add_argument(
"--segmentation-model",
type=str,
default="pyannote/segmentation-3.0",
help="Hugging Face model ID for pyannote.audio segmentation model.",
)
parser.add_argument(
"--embedding-model",
type=str,
default="pyannote/embedding",
help="Hugging Face model ID for pyannote.audio embedding model.",
)
parser.add_argument(
"--no-transcription",
action="store_true",
help="Disable transcription to only see live diarization results.",
)
parser.add_argument(
"--min-chunk-size",
type=float,
default=0.5,
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
)
parser.add_argument(
"--model",
type=str,
default="tiny",
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
)
parser.add_argument(
"--model_cache_dir",
type=str,
default=None,
help="Overriding the default model cache dir where models downloaded from the hub are saved",
)
parser.add_argument(
"--model_dir",
type=str,
default=None,
help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
)
parser.add_argument(
"--lan",
"--language",
type=str,
default="auto",
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
)
parser.add_argument(
"--task",
type=str,
default="transcribe",
choices=["transcribe", "translate"],
help="Transcribe or translate.",
)
parser.add_argument(
"--backend",
type=str,
default="faster-whisper",
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
help="Load only this backend for Whisper processing.",
)
parser.add_argument(
"--vac",
action="store_true",
default=False,
help="Use VAC = voice activity controller. Recommended. Requires torch.",
)
parser.add_argument(
"--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
)
parser.add_argument(
"--no-vad",
action="store_true",
help="Disable VAD (voice activity detection).",
)
parser.add_argument(
"--buffer_trimming",
type=str,
default="segment",
choices=["sentence", "segment"],
help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.',
)
parser.add_argument(
"--buffer_trimming_sec",
type=float,
default=15,
help="Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.",
)
parser.add_argument(
"-l",
"--log-level",
dest="log_level",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Set the log level",
default="DEBUG",
)
parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
args = parser.parse_args()
args.transcription = not args.no_transcription
args.vad = not args.no_vad
delattr(args, 'no_transcription')
delattr(args, 'no_vad')
return args

View File

@@ -26,7 +26,4 @@ class Transcript(TimedText):
@dataclass @dataclass
class SpeakerSegment(TimedText): class SpeakerSegment(TimedText):
"""Represents a segment of audio attributed to a specific speaker.
No text nor probability is associated with this segment.
"""
pass pass

View File

@@ -308,7 +308,6 @@
let waveCtx = waveCanvas.getContext("2d"); let waveCtx = waveCanvas.getContext("2d");
let animationFrame = null; let animationFrame = null;
let waitingForStop = false; let waitingForStop = false;
let lastReceivedData = null;
waveCanvas.width = 60 * (window.devicePixelRatio || 1); waveCanvas.width = 60 * (window.devicePixelRatio || 1);
waveCanvas.height = 30 * (window.devicePixelRatio || 1); waveCanvas.height = 30 * (window.devicePixelRatio || 1);
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1); waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
@@ -358,31 +357,18 @@
websocket.onclose = () => { websocket.onclose = () => {
if (userClosing) { if (userClosing) {
if (waitingForStop) { if (!statusText.textContent.includes("Recording stopped. Processing final audio")) { // This is a bit of a hack. We should have a better way to handle this. eg. using a status code.
statusText.textContent = "Processing finalized or connection closed."; statusText.textContent = "Finished processing audio! Ready to record again.";
if (lastReceivedData) {
renderLinesWithBuffer(
lastReceivedData.lines || [],
lastReceivedData.buffer_diarization || "",
lastReceivedData.buffer_transcription || "",
0, 0, true // isFinalizing = true
);
}
} }
// If ready_to_stop was received, statusText is already "Finished processing..." waitingForStop = false;
// and waitingForStop is false.
} else { } else {
statusText.textContent = "Disconnected from the WebSocket server. (Check logs if model is loading.)"; statusText.textContent =
"Disconnected from the WebSocket server. (Check logs if model is loading.)";
if (isRecording) { if (isRecording) {
stopRecording(); stopRecording();
} }
} }
isRecording = false; userClosing = false;
waitingForStop = false;
userClosing = false;
lastReceivedData = null;
websocket = null;
updateUI();
}; };
websocket.onerror = () => { websocket.onerror = () => {
@@ -396,39 +382,31 @@
// Check for status messages // Check for status messages
if (data.type === "ready_to_stop") { if (data.type === "ready_to_stop") {
console.log("Ready to stop received, finalizing display and closing WebSocket."); console.log("Ready to stop, closing WebSocket");
waitingForStop = false;
if (lastReceivedData) { // signal that we are not waiting for stop anymore
renderLinesWithBuffer( waitingForStop = false;
lastReceivedData.lines || [], recordButton.disabled = false; // this should be elsewhere
lastReceivedData.buffer_diarization || "", console.log("Record button enabled");
lastReceivedData.buffer_transcription || "",
0, // No more lag //Now we can close the WebSocket
0, // No more lag
true // isFinalizing = true
);
}
statusText.textContent = "Finished processing audio! Ready to record again.";
recordButton.disabled = false;
if (websocket) { if (websocket) {
websocket.close(); // will trigger onclose websocket.close();
// websocket = null; // onclose handle setting websocket to null websocket = null;
} }
return; return;
} }
lastReceivedData = data;
// Handle normal transcription updates // Handle normal transcription updates
const { const {
lines = [], lines = [],
buffer_transcription = "", buffer_transcription = "",
buffer_diarization = "", buffer_diarization = "",
remaining_time_transcription = 0, remaining_time_transcription = 0,
remaining_time_diarization = 0, remaining_time_diarization = 0
status = "active_transcription"
} = data; } = data;
renderLinesWithBuffer( renderLinesWithBuffer(
@@ -436,20 +414,13 @@
buffer_diarization, buffer_diarization,
buffer_transcription, buffer_transcription,
remaining_time_diarization, remaining_time_diarization,
remaining_time_transcription, remaining_time_transcription
false,
status
); );
}; };
}); });
} }
function renderLinesWithBuffer(lines, buffer_diarization, buffer_transcription, remaining_time_diarization, remaining_time_transcription, isFinalizing = false, current_status = "active_transcription") { function renderLinesWithBuffer(lines, buffer_diarization, buffer_transcription, remaining_time_diarization, remaining_time_transcription) {
if (current_status === "no_audio_detected") {
linesTranscriptDiv.innerHTML = "<p style='text-align: center; color: #666; margin-top: 20px;'><em>No audio detected...</em></p>";
return;
}
const linesHtml = lines.map((item, idx) => { const linesHtml = lines.map((item, idx) => {
let timeInfo = ""; let timeInfo = "";
if (item.beg !== undefined && item.end !== undefined) { if (item.beg !== undefined && item.end !== undefined) {
@@ -459,46 +430,30 @@
let speakerLabel = ""; let speakerLabel = "";
if (item.speaker === -2) { if (item.speaker === -2) {
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`; speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
} else if (item.speaker == 0 && !isFinalizing) { } else if (item.speaker == 0) {
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'>${remaining_time_diarization} second(s) of audio are undergoing diarization</span></span>`; speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'>${remaining_time_diarization} second(s) of audio are undergoing diarization</span></span>`;
} else if (item.speaker == -1) { } else if (item.speaker == -1) {
speakerLabel = `<span id="speaker">Speaker 1<span id='timeInfo'>${timeInfo}</span></span>`; speakerLabel = `<span id="speaker"><span id='timeInfo'>${timeInfo}</span></span>`;
} else if (item.speaker !== -1 && item.speaker !== 0) { } else if (item.speaker !== -1) {
speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`; speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
} }
let textContent = item.text;
let currentLineText = item.text || ""; if (idx === lines.length - 1) {
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'>${remaining_time_transcription}s</span></span>`
if (idx === lines.length - 1) { }
if (!isFinalizing) { if (idx === lines.length - 1 && buffer_diarization) {
if (remaining_time_transcription > 0) { speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'>${remaining_time_diarization}s</span></span>`
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'>${remaining_time_transcription}s</span></span>`; textContent += `<span class="buffer_diarization">${buffer_diarization}</span>`;
} }
if (buffer_diarization && remaining_time_diarization > 0) { if (idx === lines.length - 1) {
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'>${remaining_time_diarization}s</span></span>`; textContent += `<span class="buffer_transcription">${buffer_transcription}</span>`;
}
}
if (buffer_diarization) {
if (isFinalizing) {
currentLineText += (currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
} else {
currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
}
}
if (buffer_transcription) {
if (isFinalizing) {
currentLineText += (currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") + buffer_transcription.trim();
} else {
currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
}
}
} }
return currentLineText.trim().length > 0 || speakerLabel.length > 0
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>` return textContent
: `<p>${speakerLabel}<br/></p>`; ? `<p>${speakerLabel}<br/><div class='textcontent'>${textContent}</div></p>`
: `<p>${speakerLabel}<br/></p>`;
}).join(""); }).join("");
linesTranscriptDiv.innerHTML = linesHtml; linesTranscriptDiv.innerHTML = linesHtml;
@@ -623,6 +578,20 @@
timerElement.textContent = "00:00"; timerElement.textContent = "00:00";
startTime = null; startTime = null;
if (websocket && websocket.readyState === WebSocket.OPEN) {
try {
await websocket.send(JSON.stringify({
type: "stop",
message: "User stopped recording"
}));
statusText.textContent = "Recording stopped. Processing final audio...";
} catch (e) {
console.error("Could not send stop message:", e);
statusText.textContent = "Recording stopped. Error during final audio processing.";
websocket.close();
websocket = null;
}
}
isRecording = false; isRecording = false;
updateUI(); updateUI();
@@ -656,22 +625,19 @@
function updateUI() { function updateUI() {
recordButton.classList.toggle("recording", isRecording); recordButton.classList.toggle("recording", isRecording);
recordButton.disabled = waitingForStop;
if (waitingForStop) { if (waitingForStop) {
if (statusText.textContent !== "Recording stopped. Processing final audio...") { statusText.textContent = "Please wait for processing to complete...";
statusText.textContent = "Please wait for processing to complete..."; recordButton.disabled = true; // Optionally disable the button while waiting
} console.log("Record button disabled");
} else if (isRecording) { } else if (isRecording) {
statusText.textContent = "Recording..."; statusText.textContent = "Recording...";
} else {
if (statusText.textContent !== "Finished processing audio! Ready to record again." &&
statusText.textContent !== "Processing finalized or connection closed.") {
statusText.textContent = "Click to start transcription";
}
}
if (!waitingForStop) {
recordButton.disabled = false; recordButton.disabled = false;
console.log("Record button enabled");
} else {
statusText.textContent = "Click to start transcription";
recordButton.disabled = false;
console.log("Record button enabled");
} }
} }
@@ -679,4 +645,4 @@
</script> </script>
</body> </body>
</html> </html>

View File

@@ -1,13 +0,0 @@
import logging
import importlib.resources as resources
logger = logging.getLogger(__name__)
def get_web_interface_html():
"""Loads the HTML for the web interface using importlib.resources."""
try:
with resources.files('whisperlivekit.web').joinpath('live_transcription.html').open('r', encoding='utf-8') as f:
return f.read()
except Exception as e:
logger.error(f"Error loading web interface HTML: {e}")
return "<html><body><h1>Error loading interface</h1></body></html>"

View File

@@ -144,11 +144,7 @@ class OnlineASRProcessor:
self.transcript_buffer.last_committed_time = self.buffer_time_offset self.transcript_buffer.last_committed_time = self.buffer_time_offset
self.committed: List[ASRToken] = [] self.committed: List[ASRToken] = []
def get_audio_buffer_end_time(self) -> float: def insert_audio_chunk(self, audio: np.ndarray):
"""Returns the absolute end time of the current audio_buffer."""
return self.buffer_time_offset + (len(self.audio_buffer) / self.SAMPLING_RATE)
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
"""Append an audio chunk (a numpy array) to the current audio buffer.""" """Append an audio chunk (a numpy array) to the current audio buffer."""
self.audio_buffer = np.append(self.audio_buffer, audio) self.audio_buffer = np.append(self.audio_buffer, audio)
@@ -183,19 +179,18 @@ class OnlineASRProcessor:
return self.concatenate_tokens(self.transcript_buffer.buffer) return self.concatenate_tokens(self.transcript_buffer.buffer)
def process_iter(self) -> Tuple[List[ASRToken], float]: def process_iter(self) -> Transcript:
""" """
Processes the current audio buffer. Processes the current audio buffer.
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time). Returns a Transcript object representing the committed transcript.
""" """
current_audio_processed_upto = self.get_audio_buffer_end_time()
prompt_text, _ = self.prompt() prompt_text, _ = self.prompt()
logger.debug( logger.debug(
f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds from {self.buffer_time_offset:.2f}" f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds from {self.buffer_time_offset:.2f}"
) )
res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt_text) res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt_text)
tokens = self.asr.ts_words(res) tokens = self.asr.ts_words(res) # Expecting List[ASRToken]
self.transcript_buffer.insert(tokens, self.buffer_time_offset) self.transcript_buffer.insert(tokens, self.buffer_time_offset)
committed_tokens = self.transcript_buffer.flush() committed_tokens = self.transcript_buffer.flush()
self.committed.extend(committed_tokens) self.committed.extend(committed_tokens)
@@ -215,7 +210,7 @@ class OnlineASRProcessor:
logger.debug( logger.debug(
f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds" f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
) )
return committed_tokens, current_audio_processed_upto return committed_tokens
def chunk_completed_sentence(self): def chunk_completed_sentence(self):
""" """
@@ -348,17 +343,15 @@ class OnlineASRProcessor:
) )
sentences.append(sentence) sentences.append(sentence)
return sentences return sentences
def finish(self) -> Transcript:
def finish(self) -> Tuple[List[ASRToken], float]:
""" """
Flush the remaining transcript when processing ends. Flush the remaining transcript when processing ends.
Returns a tuple: (list of remaining ASRToken objects, float representing the final audio processed up to time).
""" """
remaining_tokens = self.transcript_buffer.buffer remaining_tokens = self.transcript_buffer.buffer
logger.debug(f"Final non-committed tokens: {remaining_tokens}") final_transcript = self.concatenate_tokens(remaining_tokens)
final_processed_upto = self.buffer_time_offset + (len(self.audio_buffer) / self.SAMPLING_RATE) logger.debug(f"Final non-committed transcript: {final_transcript}")
self.buffer_time_offset = final_processed_upto self.buffer_time_offset += len(self.audio_buffer) / self.SAMPLING_RATE
return remaining_tokens, final_processed_upto return final_transcript
def concatenate_tokens( def concatenate_tokens(
self, self,
@@ -391,8 +384,7 @@ class VACOnlineASRProcessor:
def __init__(self, online_chunk_size: float, *args, **kwargs): def __init__(self, online_chunk_size: float, *args, **kwargs):
self.online_chunk_size = online_chunk_size self.online_chunk_size = online_chunk_size
self.online = OnlineASRProcessor(*args, **kwargs) self.online = OnlineASRProcessor(*args, **kwargs)
self.asr = self.online.asr
# Load a VAD model (e.g. Silero VAD) # Load a VAD model (e.g. Silero VAD)
import torch import torch
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad") model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
@@ -400,35 +392,28 @@ class VACOnlineASRProcessor:
self.vac = FixedVADIterator(model) self.vac = FixedVADIterator(model)
self.logfile = self.online.logfile self.logfile = self.online.logfile
self.last_input_audio_stream_end_time: float = 0.0
self.init() self.init()
def init(self): def init(self):
self.online.init() self.online.init()
self.vac.reset_states() self.vac.reset_states()
self.current_online_chunk_buffer_size = 0 self.current_online_chunk_buffer_size = 0
self.last_input_audio_stream_end_time = self.online.buffer_time_offset
self.is_currently_final = False self.is_currently_final = False
self.status: Optional[str] = None # "voice" or "nonvoice" self.status: Optional[str] = None # "voice" or "nonvoice"
self.audio_buffer = np.array([], dtype=np.float32) self.audio_buffer = np.array([], dtype=np.float32)
self.buffer_offset = 0 # in frames self.buffer_offset = 0 # in frames
def get_audio_buffer_end_time(self) -> float:
"""Returns the absolute end time of the audio processed by the underlying OnlineASRProcessor."""
return self.online.get_audio_buffer_end_time()
def clear_buffer(self): def clear_buffer(self):
self.buffer_offset += len(self.audio_buffer) self.buffer_offset += len(self.audio_buffer)
self.audio_buffer = np.array([], dtype=np.float32) self.audio_buffer = np.array([], dtype=np.float32)
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float): def insert_audio_chunk(self, audio: np.ndarray):
""" """
Process an incoming small audio chunk: Process an incoming small audio chunk:
- run VAD on the chunk, - run VAD on the chunk,
- decide whether to send the audio to the online ASR processor immediately, - decide whether to send the audio to the online ASR processor immediately,
- and/or to mark the current utterance as finished. - and/or to mark the current utterance as finished.
""" """
self.last_input_audio_stream_end_time = audio_stream_end_time
res = self.vac(audio) res = self.vac(audio)
self.audio_buffer = np.append(self.audio_buffer, audio) self.audio_buffer = np.append(self.audio_buffer, audio)
@@ -470,11 +455,10 @@ class VACOnlineASRProcessor:
self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE) self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE)
self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:] self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]
def process_iter(self) -> Tuple[List[ASRToken], float]: def process_iter(self) -> Transcript:
""" """
Depending on the VAD status and the amount of accumulated audio, Depending on the VAD status and the amount of accumulated audio,
process the current audio chunk. process the current audio chunk.
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
""" """
if self.is_currently_final: if self.is_currently_final:
return self.finish() return self.finish()
@@ -483,20 +467,17 @@ class VACOnlineASRProcessor:
return self.online.process_iter() return self.online.process_iter()
else: else:
logger.debug("No online update, only VAD") logger.debug("No online update, only VAD")
return [], self.last_input_audio_stream_end_time return Transcript(None, None, "")
def finish(self) -> Tuple[List[ASRToken], float]: def finish(self) -> Transcript:
""" """Finish processing by flushing any remaining text."""
Finish processing by flushing any remaining text. result = self.online.finish()
Returns a tuple: (list of remaining ASRToken objects, float representing the final audio processed up to time).
"""
result_tokens, processed_upto = self.online.finish()
self.current_online_chunk_buffer_size = 0 self.current_online_chunk_buffer_size = 0
self.is_currently_final = False self.is_currently_final = False
return result_tokens, processed_upto return result
def get_buffer(self): def get_buffer(self):
""" """
Get the unvalidated buffer in string format. Get the unvalidated buffer in string format.
""" """
return self.online.concatenate_tokens(self.online.transcript_buffer.buffer) return self.online.concatenate_tokens(self.online.transcript_buffer.buffer).text