Mirror of https://github.com/QuentinFuxa/WhisperLiveKit.git (synced 2026-03-07 22:33:36 +00:00)

# Compare commits (21 commits)
| SHA1 |
|---|
| e165916952 |
| 8532a91c7a |
| b01b81bad0 |
| 0f79d442ee |
| c9f60504e3 |
| 993a83546a |
| eabd1b199a |
| f7644268c1 |
| 34e8fe260e |
| debfefaf3e |
| 101ca9ef90 |
| 94bb05d53e |
| 6797b88176 |
| 46770efd6c |
| b23ef3ec3e |
| fa29a24abe |
| fea3c3553c |
| d6d65a663b |
| 083d5b2f44 |
| 8e4674b093 |
| bc7c32100f |
LICENSE (33 lines changed)

```diff
@@ -1,21 +1,28 @@
 MIT License
 
-Copyright (c) 2023 ÚFAL
+Copyright (c) 2025 Quentin Fuxa.
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
 in the Software without restriction, including without limitation the rights
 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 copies of the Software, and to permit persons to whom the Software is
 furnished to do so, subject to the following conditions:
 
 The above copyright notice and this permission notice shall be included in all
 copies or substantial portions of the Software.
 
 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
+
+---
+
+Based on:
+- **whisper_streaming** by ÚFAL – MIT License – https://github.com/ufal/whisper_streaming. License: https://github.com/ufal/whisper_streaming/blob/main/LICENSE
+- **silero-vad** by Snakers4 – MIT License – https://github.com/snakers4/silero-vad. License: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
+- **Diart** by juanmc2005 – MIT License – https://github.com/juanmc2005/diart. License: https://github.com/juanmc2005/diart/blob/main/LICENSE
```
README.md (79 lines changed)

```diff
@@ -9,8 +9,8 @@
 <p align="center">
 <a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
 <a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=downloads"></a>
-<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-dark_green"></a>
+<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
-<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/QuentinFuxa/WhisperLiveKit?color=blue"></a>
+<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT-dark_green"></a>
 </p>
 
 ## 🚀 Overview
```
```diff
@@ -32,6 +32,7 @@ WhisperLiveKit consists of three main components:
 - **👥 Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart)
 - **🔒 Fully Local** - All processing happens on your machine - no data sent to external servers
 - **📱 Multi-User Support** - Handle multiple users simultaneously with a single backend/server
+- **📝 Punctuation-Based Speaker Splitting [BETA]** - Align speaker changes with natural sentence boundaries for more readable transcripts
 
 ### ⚙️ Core differences from [Whisper Streaming](https://github.com/ufal/whisper_streaming)
```
````diff
@@ -142,52 +143,79 @@ whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --
 ```
 
 ### Python API Integration (Backend)
+Check [basic_server.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a complete example.
 
 ```python
-from whisperlivekit import WhisperLiveKit
-from whisperlivekit.audio_processor import AudioProcessor
-from fastapi import FastAPI, WebSocket
-import asyncio
-from fastapi.responses import HTMLResponse
+from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+from fastapi.responses import HTMLResponse
+from contextlib import asynccontextmanager
+import asyncio
 
-# Initialize components
-app = FastAPI()
-kit = WhisperLiveKit(model="medium", diarization=True)
+# Global variable for the transcription engine
+transcription_engine = None
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    global transcription_engine
+    # Example: Initialize with specific parameters directly
+    # You can also load from command-line arguments using parse_args()
+    # args = parse_args()
+    # transcription_engine = TranscriptionEngine(**vars(args))
+    transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
+    yield
+
+app = FastAPI(lifespan=lifespan)
 
 # Serve the web interface
 @app.get("/")
 async def get():
-    return HTMLResponse(kit.web_interface()) # Use the built-in web interface
+    return HTMLResponse(get_web_interface_html())
 
 # Process WebSocket connections
-async def handle_websocket_results(websocket, results_generator):
-    async for response in results_generator:
-        await websocket.send_json(response)
+async def handle_websocket_results(websocket: WebSocket, results_generator):
+    try:
+        async for response in results_generator:
+            await websocket.send_json(response)
+        await websocket.send_json({"type": "ready_to_stop"})
+    except WebSocketDisconnect:
+        print("WebSocket disconnected during results handling.")
 
 @app.websocket("/asr")
 async def websocket_endpoint(websocket: WebSocket):
-    audio_processor = AudioProcessor()
-    await websocket.accept()
-    results_generator = await audio_processor.create_tasks()
-    websocket_task = asyncio.create_task(
-        handle_websocket_results(websocket, results_generator)
-    )
+    global transcription_engine
+    # Create a new AudioProcessor for each connection, passing the shared engine
+    audio_processor = AudioProcessor(transcription_engine=transcription_engine)
+    results_generator = await audio_processor.create_tasks()
+    send_results_to_client = handle_websocket_results(websocket, results_generator)
+    results_task = asyncio.create_task(send_results_to_client)
+    await websocket.accept()
     try:
         while True:
             message = await websocket.receive_bytes()
             await audio_processor.process_audio(message)
+    except WebSocketDisconnect:
+        print(f"Client disconnected: {websocket.client}")
     except Exception as e:
-        print(f"WebSocket error: {e}")
-        websocket_task.cancel()
+        await websocket.close(code=1011, reason=f"Server error: {e}")
+    finally:
+        results_task.cancel()
+        try:
+            await results_task
+        except asyncio.CancelledError:
+            print("Results task successfully cancelled.")
 ```
 
 ### Frontend Implementation
 
-The package includes a simple HTML/JavaScript implementation that you can adapt for your project. You can get in in [whisperlivekit/web/live_transcription.html](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html), or using :
+The package includes a simple HTML/JavaScript implementation that you can adapt for your project. You can find it in `whisperlivekit/web/live_transcription.html`, or load its content using the `get_web_interface_html()` function from `whisperlivekit`:
 
 ```python
-kit.web_interface()
+from whisperlivekit import get_web_interface_html
+
+# ... later in your code where you need the HTML string ...
+html_content = get_web_interface_html()
 ```
 
 ## ⚙️ Configuration Reference
````
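For testing the `/asr` endpoint in the new example without the bundled web page, a small script can stream a file over the WebSocket. This is a minimal sketch, assuming the third-party `websockets` package and a server started as above; the file path, chunk size, and pacing are illustrative, and the endpoint expects the same WebM/Opus bytes the browser recorder would send. An empty binary message tells the server to finish processing (see `process_audio` in the `audio_processor.py` changes below).

```python
import asyncio
import json

import websockets  # assumed dependency: pip install websockets

async def stream_file(path: str):
    async with websockets.connect("ws://localhost:8000/asr") as ws:
        with open(path, "rb") as f:
            while chunk := f.read(4096):
                await ws.send(chunk)       # binary frame, received via receive_bytes()
                await asyncio.sleep(0.1)   # pace the upload roughly like a live source
        await ws.send(b"")                 # empty message: signals end of stream
        async for raw in ws:
            response = json.loads(raw)
            if response.get("type") == "ready_to_stop":
                break                      # server has flushed all results
            print(response.get("lines", []))

asyncio.run(stream_file("sample.webm"))    # hypothetical input file
```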
```diff
@@ -203,6 +231,7 @@ WhisperLiveKit offers extensive configuration options:
 | `--task` | `transcribe` or `translate` | `transcribe` |
 | `--backend` | Processing backend | `faster-whisper` |
 | `--diarization` | Enable speaker identification | `False` |
+| `--punctuation-split` | Use punctuation to improve speaker boundaries | `True` |
 | `--confidence-validation` | Use confidence scores for faster validation | `False` |
 | `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
 | `--vac` | Use Voice Activity Controller | `False` |
@@ -211,6 +240,8 @@ WhisperLiveKit offers extensive configuration options:
 | `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
 | `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
 | `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
+| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
+| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
 
 ## 🔧 How It Works
```
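The same options can also be set programmatically. A sketch, assuming the `TranscriptionEngine` keyword names that appear in the `core.py` changes later in this comparison (CLI dashes become underscores); the model choices are illustrative:

```python
from whisperlivekit import TranscriptionEngine

engine = TranscriptionEngine(
    model="medium",
    diarization=True,
    punctuation_split=True,                              # new option in this comparison
    segmentation_model="pyannote/segmentation-3.0",
    embedding_model="speechbrain/spkrec-ecapa-voxceleb",
)
```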
setup.py (2 lines changed)

```diff
@@ -1,7 +1,7 @@
 from setuptools import setup, find_packages
 setup(
     name="whisperlivekit",
-    version="0.1.5",
+    version="0.1.9",
     description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
```
whisperlivekit/__init__.py

```diff
@@ -1,4 +1,5 @@
-from .core import WhisperLiveKit, parse_args
+from .core import TranscriptionEngine
 from .audio_processor import AudioProcessor
+from .web.web_interface import get_web_interface_html
+from .parse_args import parse_args
-__all__ = ['WhisperLiveKit', 'AudioProcessor', 'parse_args']
+__all__ = ['TranscriptionEngine', 'AudioProcessor', 'get_web_interface_html', 'parse_args']
```
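Taken together, the new top-level import surface looks like this; a usage sketch based only on the exports above:

```python
from whisperlivekit import (
    TranscriptionEngine,      # shared, heavyweight models (singleton, see core.py below)
    AudioProcessor,           # per-connection processing state
    get_web_interface_html,   # bundled web UI as an HTML string
    parse_args,               # CLI argument parsing, now in its own module
)

engine = TranscriptionEngine(model="tiny")
processor = AudioProcessor(transcription_engine=engine)
html = get_web_interface_html()
```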
whisperlivekit/audio_processor.py

```diff
@@ -8,13 +8,15 @@ import traceback
 from datetime import timedelta
 from whisperlivekit.timed_objects import ASRToken
 from whisperlivekit.whisper_streaming_custom.whisper_online import online_factory
-from whisperlivekit.core import WhisperLiveKit
+from whisperlivekit.core import TranscriptionEngine
 
 # Set up logging once
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
+SENTINEL = object()  # unique sentinel object for end of stream marker
+
 def format_time(seconds: float) -> str:
     """Format seconds as HH:MM:SS."""
     return str(timedelta(seconds=int(seconds)))
```
```diff
@@ -25,10 +27,13 @@ class AudioProcessor:
     Handles audio processing, state management, and result formatting.
     """
 
-    def __init__(self):
+    def __init__(self, **kwargs):
         """Initialize the audio processor with configuration, models, and state."""
-        models = WhisperLiveKit()
+        if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
+            models = kwargs['transcription_engine']
+        else:
+            models = TranscriptionEngine(**kwargs)
 
         # Audio processing settings
         self.args = models.args
@@ -41,8 +46,9 @@ class AudioProcessor:
         self.last_ffmpeg_activity = time()
         self.ffmpeg_health_check_interval = 5
         self.ffmpeg_max_idle_time = 10
 
         # State management
+        self.is_stopping = False
         self.tokens = []
         self.buffer_transcription = ""
         self.buffer_diarization = ""
@@ -62,6 +68,13 @@ class AudioProcessor:
         self.transcription_queue = asyncio.Queue() if self.args.transcription else None
         self.diarization_queue = asyncio.Queue() if self.args.diarization else None
         self.pcm_buffer = bytearray()
 
+        # Task references
+        self.transcription_task = None
+        self.diarization_task = None
+        self.ffmpeg_reader_task = None
+        self.watchdog_task = None
+        self.all_tasks_for_cleanup = []
+
         # Initialize transcription engine if enabled
         if self.args.transcription:
```
```diff
@@ -73,10 +86,33 @@ class AudioProcessor:
 
     def start_ffmpeg_decoder(self):
         """Start FFmpeg process for WebM to PCM conversion."""
-        return (ffmpeg.input("pipe:0", format="webm")
-                .output("pipe:1", format="s16le", acodec="pcm_s16le",
-                        ac=self.channels, ar=str(self.sample_rate))
-                .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True))
+        try:
+            return (ffmpeg.input("pipe:0", format="webm")
+                    .output("pipe:1", format="s16le", acodec="pcm_s16le",
+                            ac=self.channels, ar=str(self.sample_rate))
+                    .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True))
+        except FileNotFoundError:
+            error = """
+            FFmpeg is not installed or not found in your system's PATH.
+            Please install FFmpeg to enable audio processing.
+
+            Installation instructions:
+
+            # Ubuntu/Debian:
+            sudo apt update && sudo apt install ffmpeg
+
+            # macOS (using Homebrew):
+            brew install ffmpeg
+
+            # Windows:
+            # 1. Download the latest static build from https://ffmpeg.org/download.html
+            # 2. Extract the archive (e.g., to C:\\FFmpeg).
+            # 3. Add the 'bin' directory (e.g., C:\\FFmpeg\\bin) to your system's PATH environment variable.
+
+            After installation, please restart the application.
+            """
+            logger.error(error)
+            raise FileNotFoundError(error)
 
     async def restart_ffmpeg(self):
         """Restart the FFmpeg process after failure."""
```
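The `ffmpeg-python` chain above is roughly equivalent to spawning this subprocess directly; a sketch, with mono 16 kHz assumed for `self.channels` and `self.sample_rate`:

```python
import subprocess

process = subprocess.Popen(
    [
        "ffmpeg",
        "-f", "webm", "-i", "pipe:0",           # read WebM from stdin
        "-f", "s16le", "-acodec", "pcm_s16le",  # raw signed 16-bit little-endian PCM
        "-ac", "1", "-ar", "16000",             # assumed channel count and sample rate
        "pipe:1",                               # write PCM to stdout
    ],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
)
```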
```diff
@@ -210,7 +246,7 @@ class AudioProcessor:
                     self.last_ffmpeg_activity = time()
 
                 if not chunk:
-                    logger.info("FFmpeg stdout closed.")
+                    logger.info("FFmpeg stdout closed, no more data to read.")
                     break
 
                 self.pcm_buffer.extend(chunk)
```
```diff
@@ -245,45 +281,86 @@ class AudioProcessor:
                 logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
                 logger.warning(f"Traceback: {traceback.format_exc()}")
                 break
 
+        logger.info("FFmpeg stdout processing finished. Signaling downstream processors.")
+        if self.args.transcription and self.transcription_queue:
+            await self.transcription_queue.put(SENTINEL)
+            logger.debug("Sentinel put into transcription_queue.")
+        if self.args.diarization and self.diarization_queue:
+            await self.diarization_queue.put(SENTINEL)
+            logger.debug("Sentinel put into diarization_queue.")
+
     async def transcription_processor(self):
         """Process audio chunks for transcription."""
         self.full_transcription = ""
         self.sep = self.online.asr.sep
+        cumulative_pcm_duration_stream_time = 0.0
 
         while True:
             try:
                 pcm_array = await self.transcription_queue.get()
+                if pcm_array is SENTINEL:
+                    logger.debug("Transcription processor received sentinel. Finishing.")
+                    self.transcription_queue.task_done()
+                    break
 
-                logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio to process.")
+                if not self.online:  # Should not happen if queue is used
+                    logger.warning("Transcription processor: self.online not initialized.")
+                    self.transcription_queue.task_done()
+                    continue
+
+                asr_internal_buffer_duration_s = len(self.online.audio_buffer) / self.online.SAMPLING_RATE
+                transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
+
+                logger.info(
+                    f"ASR processing: internal_buffer={asr_internal_buffer_duration_s:.2f}s, "
+                    f"lag={transcription_lag_s:.2f}s."
+                )
 
                 # Process transcription
-                self.online.insert_audio_chunk(pcm_array)
-                new_tokens = self.online.process_iter()
+                duration_this_chunk = len(pcm_array) / self.sample_rate if isinstance(pcm_array, np.ndarray) else 0
+                cumulative_pcm_duration_stream_time += duration_this_chunk
+                stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
+
+                self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
+                new_tokens, current_audio_processed_upto = self.online.process_iter()
 
                 if new_tokens:
                     self.full_transcription += self.sep.join([t.text for t in new_tokens])
 
                 # Get buffer information
-                _buffer = self.online.get_buffer()
-                buffer = _buffer.text
-                end_buffer = _buffer.end if _buffer.end else (
-                    new_tokens[-1].end if new_tokens else 0
-                )
+                _buffer_transcript_obj = self.online.get_buffer()
+                buffer_text = _buffer_transcript_obj.text
+
+                candidate_end_times = [self.end_buffer]
+
+                if new_tokens:
+                    candidate_end_times.append(new_tokens[-1].end)
+
+                if _buffer_transcript_obj.end is not None:
+                    candidate_end_times.append(_buffer_transcript_obj.end)
+
+                candidate_end_times.append(current_audio_processed_upto)
+
+                new_end_buffer = max(candidate_end_times)
 
                 # Avoid duplicating content
-                if buffer in self.full_transcription:
-                    buffer = ""
+                if buffer_text in self.full_transcription:
+                    buffer_text = ""
 
                 await self.update_transcription(
-                    new_tokens, buffer, end_buffer, self.full_transcription, self.sep
+                    new_tokens, buffer_text, new_end_buffer, self.full_transcription, self.sep
                 )
+                self.transcription_queue.task_done()
 
             except Exception as e:
                 logger.warning(f"Exception in transcription_processor: {e}")
                 logger.warning(f"Traceback: {traceback.format_exc()}")
-            finally:
-                self.transcription_queue.task_done()
+                if 'pcm_array' in locals() and pcm_array is not SENTINEL:  # Check if pcm_array was assigned from queue
+                    self.transcription_queue.task_done()
+
+        logger.info("Transcription processor task finished.")
 
     async def diarization_processor(self, diarization_obj):
         """Process audio chunks for speaker diarization."""
```
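The `SENTINEL` handoff introduced above is the standard end-of-stream idiom for `asyncio.Queue`; a minimal standalone sketch of the pattern:

```python
import asyncio

SENTINEL = object()  # identity-compared, so it can never collide with real data

async def producer(queue: asyncio.Queue):
    for item in ("chunk-1", "chunk-2"):
        await queue.put(item)
    await queue.put(SENTINEL)  # no more items will ever arrive

async def consumer(queue: asyncio.Queue):
    while True:
        item = await queue.get()
        if item is SENTINEL:   # `is`, not `==`: check object identity
            queue.task_done()
            break
        print(f"processing {item}")
        queue.task_done()

async def main():
    queue = asyncio.Queue()
    await asyncio.gather(producer(queue), consumer(queue))

asyncio.run(main())
```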
```diff
@@ -292,23 +369,33 @@ class AudioProcessor:
         while True:
             try:
                 pcm_array = await self.diarization_queue.get()
+                if pcm_array is SENTINEL:
+                    logger.debug("Diarization processor received sentinel. Finishing.")
+                    self.diarization_queue.task_done()
+                    break
 
                 # Process diarization
                 await diarization_obj.diarize(pcm_array)
 
-                # Get current state and update speakers
-                state = await self.get_current_state()
-                new_end = diarization_obj.assign_speakers_to_tokens(
-                    state["end_attributed_speaker"], state["tokens"]
-                )
+                async with self.lock:
+                    new_end = diarization_obj.assign_speakers_to_tokens(
+                        self.end_attributed_speaker,
+                        self.tokens,
+                        use_punctuation_split=self.args.punctuation_split
+                    )
+                    self.end_attributed_speaker = new_end
+                    if buffer_diarization:
+                        self.buffer_diarization = buffer_diarization
 
-                await self.update_diarization(new_end, buffer_diarization)
+                self.diarization_queue.task_done()
 
             except Exception as e:
                 logger.warning(f"Exception in diarization_processor: {e}")
                 logger.warning(f"Traceback: {traceback.format_exc()}")
-            finally:
-                self.diarization_queue.task_done()
+                if 'pcm_array' in locals() and pcm_array is not SENTINEL:
+                    self.diarization_queue.task_done()
+
+        logger.info("Diarization processor task finished.")
 
     async def results_formatter(self):
         """Format processing results for output."""
```
```diff
@@ -372,31 +459,51 @@ class AudioProcessor:
                     await self.update_diarization(end_attributed_speaker, combined)
                     buffer_diarization = combined
 
-                # Create response object
-                if not lines:
-                    lines = [{
+                response_status = "active_transcription"
+                final_lines_for_response = lines.copy()
+
+                if not tokens and not buffer_transcription and not buffer_diarization:
+                    response_status = "no_audio_detected"
+                    final_lines_for_response = []
+                elif response_status == "active_transcription" and not final_lines_for_response:
+                    final_lines_for_response = [{
                         "speaker": 1,
                         "text": "",
-                        "beg": format_time(0),
-                        "end": format_time(tokens[-1].end if tokens else 0),
+                        "beg": format_time(state.get("end_buffer", 0)),
+                        "end": format_time(state.get("end_buffer", 0)),
                         "diff": 0
                     }]
 
                 response = {
-                    "lines": lines,
+                    "status": response_status,
+                    "lines": final_lines_for_response,
                     "buffer_transcription": buffer_transcription,
                     "buffer_diarization": buffer_diarization,
                     "remaining_time_transcription": state["remaining_time_transcription"],
                     "remaining_time_diarization": state["remaining_time_diarization"]
                 }
 
                 # Only yield if content has changed
-                response_content = ' '.join([f"{line['speaker']} {line['text']}" for line in lines]) + \
-                                   f" | {buffer_transcription} | {buffer_diarization}"
+                current_response_signature = f"{response_status} | " + \
+                    ' '.join([f"{line['speaker']} {line['text']}" for line in final_lines_for_response]) + \
+                    f" | {buffer_transcription} | {buffer_diarization}"
 
-                if response_content != self.last_response_content and (lines or buffer_transcription or buffer_diarization):
+                if current_response_signature != self.last_response_content and \
+                   (final_lines_for_response or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"):
                     yield response
-                    self.last_response_content = response_content
+                    self.last_response_content = current_response_signature
 
+                # Check for termination condition
+                if self.is_stopping:
+                    all_processors_done = True
+                    if self.args.transcription and self.transcription_task and not self.transcription_task.done():
+                        all_processors_done = False
+                    if self.args.diarization and self.diarization_task and not self.diarization_task.done():
+                        all_processors_done = False
+
+                    if all_processors_done:
+                        logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
+                        final_state = await self.get_current_state()
+                        return
 
                 await asyncio.sleep(0.1)  # Avoid overwhelming the client
```
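The yield condition above relies on a cheap string signature of the outgoing payload; reduced to its core, the pattern looks like this (a sketch, names are illustrative):

```python
last_signature = None

def should_emit(status: str, lines: list, buffer_text: str) -> bool:
    """Return True only when the response differs from the previous one."""
    global last_signature
    signature = f"{status} | " + " ".join(lines) + f" | {buffer_text}"
    if signature == last_signature:
        return False  # identical to the last update, skip it
    last_signature = signature
    return True
```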
```diff
@@ -407,65 +514,117 @@ class AudioProcessor:
 
     async def create_tasks(self):
         """Create and start processing tasks."""
-        tasks = []
+        self.all_tasks_for_cleanup = []
+        processing_tasks_for_watchdog = []
 
         if self.args.transcription and self.online:
-            tasks.append(asyncio.create_task(self.transcription_processor()))
+            self.transcription_task = asyncio.create_task(self.transcription_processor())
+            self.all_tasks_for_cleanup.append(self.transcription_task)
+            processing_tasks_for_watchdog.append(self.transcription_task)
 
         if self.args.diarization and self.diarization:
-            tasks.append(asyncio.create_task(self.diarization_processor(self.diarization)))
+            self.diarization_task = asyncio.create_task(self.diarization_processor(self.diarization))
+            self.all_tasks_for_cleanup.append(self.diarization_task)
+            processing_tasks_for_watchdog.append(self.diarization_task)
 
-        tasks.append(asyncio.create_task(self.ffmpeg_stdout_reader()))
+        self.ffmpeg_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
+        self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task)
+        processing_tasks_for_watchdog.append(self.ffmpeg_reader_task)
 
-        # Monitor overall system health
-        async def watchdog():
-            while True:
-                try:
-                    await asyncio.sleep(10)  # Check every 10 seconds instead of 60
-
-                    current_time = time()
-                    # Check for stalled tasks
-                    for i, task in enumerate(tasks):
-                        if task.done():
-                            exc = task.exception() if task.done() else None
-                            task_name = task.get_name() if hasattr(task, 'get_name') else f"Task {i}"
-                            logger.error(f"{task_name} unexpectedly completed with exception: {exc}")
-
-                    # Check for FFmpeg process health with shorter thresholds
-                    ffmpeg_idle_time = current_time - self.last_ffmpeg_activity
-                    if ffmpeg_idle_time > 15:  # 15 seconds instead of 180
-                        logger.warning(f"FFmpeg idle for {ffmpeg_idle_time:.2f}s - may need attention")
-
-                        # Force restart after 30 seconds of inactivity (instead of 600)
-                        if ffmpeg_idle_time > 30:
-                            logger.error("FFmpeg idle for too long, forcing restart")
-                            await self.restart_ffmpeg()
-
-                except Exception as e:
-                    logger.error(f"Error in watchdog task: {e}")
-
-        tasks.append(asyncio.create_task(watchdog()))
-        self.tasks = tasks
+        # Monitor overall system health
+        self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog))
+        self.all_tasks_for_cleanup.append(self.watchdog_task)
 
         return self.results_formatter()
 
+    async def watchdog(self, tasks_to_monitor):
+        """Monitors the health of critical processing tasks."""
+        while True:
+            try:
+                await asyncio.sleep(10)
+                current_time = time()
+
+                for i, task in enumerate(tasks_to_monitor):
+                    if task.done():
+                        exc = task.exception()
+                        task_name = task.get_name() if hasattr(task, 'get_name') else f"Monitored Task {i}"
+                        if exc:
+                            logger.error(f"{task_name} unexpectedly completed with exception: {exc}")
+                        else:
+                            logger.info(f"{task_name} completed normally.")
+
+                ffmpeg_idle_time = current_time - self.last_ffmpeg_activity
+                if ffmpeg_idle_time > 15:
+                    logger.warning(f"FFmpeg idle for {ffmpeg_idle_time:.2f}s - may need attention.")
+                    if ffmpeg_idle_time > 30 and not self.is_stopping:
+                        logger.error("FFmpeg idle for too long and not in stopping phase, forcing restart.")
+                        await self.restart_ffmpeg()
+            except asyncio.CancelledError:
+                logger.info("Watchdog task cancelled.")
+                break
+            except Exception as e:
+                logger.error(f"Error in watchdog task: {e}", exc_info=True)
+
     async def cleanup(self):
         """Clean up resources when processing is complete."""
-        for task in self.tasks:
-            task.cancel()
+        logger.info("Starting cleanup of AudioProcessor resources.")
+        for task in self.all_tasks_for_cleanup:
+            if task and not task.done():
+                task.cancel()
+
+        created_tasks = [t for t in self.all_tasks_for_cleanup if t]
+        if created_tasks:
+            await asyncio.gather(*created_tasks, return_exceptions=True)
+        logger.info("All processing tasks cancelled or finished.")
 
-        try:
-            await asyncio.gather(*self.tasks, return_exceptions=True)
-            self.ffmpeg_process.stdin.close()
-            self.ffmpeg_process.wait()
-        except Exception as e:
-            logger.warning(f"Error during cleanup: {e}")
+        if self.ffmpeg_process:
+            if self.ffmpeg_process.stdin and not self.ffmpeg_process.stdin.closed:
+                try:
+                    self.ffmpeg_process.stdin.close()
+                except Exception as e:
+                    logger.warning(f"Error closing ffmpeg stdin during cleanup: {e}")
+
+            # Wait for ffmpeg process to terminate
+            if self.ffmpeg_process.poll() is None:  # Check if process is still running
+                logger.info("Waiting for FFmpeg process to terminate...")
+                try:
+                    # Run wait in executor to avoid blocking async loop
+                    await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait, 5.0)  # 5s timeout
+                except Exception as e:  # subprocess.TimeoutExpired is not directly caught by asyncio.wait_for with run_in_executor
+                    logger.warning(f"FFmpeg did not terminate gracefully, killing. Error: {e}")
+                    self.ffmpeg_process.kill()
+                    await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait)  # Wait for kill
+            logger.info("FFmpeg process terminated.")
 
-        if self.args.diarization and hasattr(self, 'diarization'):
+        if self.args.diarization and hasattr(self, 'diarization') and hasattr(self.diarization, 'close'):
             self.diarization.close()
+        logger.info("AudioProcessor cleanup complete.")
 
     async def process_audio(self, message):
         """Process incoming audio data."""
+        # If already stopping or stdin is closed, ignore further audio, especially residual chunks.
+        if self.is_stopping or (self.ffmpeg_process and self.ffmpeg_process.stdin and self.ffmpeg_process.stdin.closed):
+            logger.warning(f"AudioProcessor is stopping or stdin is closed. Ignoring incoming audio message (length: {len(message)}).")
+            if not message and self.ffmpeg_process and self.ffmpeg_process.stdin and not self.ffmpeg_process.stdin.closed:
+                logger.info("Received empty message while already in stopping state; ensuring stdin is closed.")
+                try:
+                    self.ffmpeg_process.stdin.close()
+                except Exception as e:
+                    logger.warning(f"Error closing ffmpeg stdin on redundant stop signal during stopping state: {e}")
+            return
+
+        if not message:  # primary signal to start stopping
+            logger.info("Empty audio message received, initiating stop sequence.")
+            self.is_stopping = True
+            if self.ffmpeg_process and self.ffmpeg_process.stdin and not self.ffmpeg_process.stdin.closed:
+                try:
+                    self.ffmpeg_process.stdin.close()
+                    logger.info("FFmpeg stdin closed due to primary stop signal.")
+                except Exception as e:
+                    logger.warning(f"Error closing ffmpeg stdin on stop: {e}")
+            return
+
         retry_count = 0
         max_retries = 3
@@ -517,4 +676,4 @@ class AudioProcessor:
             else:
                 logger.error("Maximum retries reached for FFmpeg process")
                 await self.restart_ffmpeg()
                 return
```
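The new `cleanup()` uses the usual asyncio teardown idiom: cancel what is still running, then gather with `return_exceptions=True` so pending `CancelledError`s are absorbed. Isolated as a sketch:

```python
import asyncio

async def shutdown(tasks: list) -> None:
    """Cancel outstanding tasks and wait for them without raising."""
    for task in tasks:
        if task and not task.done():
            task.cancel()
    await asyncio.gather(
        *[t for t in tasks if t],
        return_exceptions=True,  # swallow CancelledError and task exceptions
    )
```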
whisperlivekit/basic_server.py

```diff
@@ -2,26 +2,24 @@ from contextlib import asynccontextmanager
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from fastapi.responses import HTMLResponse
 from fastapi.middleware.cors import CORSMiddleware
-from whisperlivekit import WhisperLiveKit, parse_args
-from whisperlivekit.audio_processor import AudioProcessor
+from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args
 
 import asyncio
 import logging
-import os, sys
-import argparse
 
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logging.getLogger().setLevel(logging.WARNING)
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
-kit = None
+args = parse_args()
+transcription_engine = None
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    global kit
-    kit = WhisperLiveKit()
+    global transcription_engine
+    transcription_engine = TranscriptionEngine(
+        **vars(args),
+    )
     yield
 
 app = FastAPI(lifespan=lifespan)
```
```diff
@@ -33,10 +31,9 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-
 @app.get("/")
 async def get():
-    return HTMLResponse(kit.web_interface())
+    return HTMLResponse(get_web_interface_html())
 
 
 async def handle_websocket_results(websocket, results_generator):
```
```diff
@@ -44,14 +41,21 @@ async def handle_websocket_results(websocket, results_generator):
     try:
         async for response in results_generator:
             await websocket.send_json(response)
+        # when the results_generator finishes it means all audio has been processed
+        logger.info("Results generator finished. Sending 'ready_to_stop' to client.")
+        await websocket.send_json({"type": "ready_to_stop"})
+    except WebSocketDisconnect:
+        logger.info("WebSocket disconnected while handling results (client likely closed connection).")
     except Exception as e:
         logger.warning(f"Error in WebSocket results handler: {e}")
 
 
 @app.websocket("/asr")
 async def websocket_endpoint(websocket: WebSocket):
-    audio_processor = AudioProcessor()
+    global transcription_engine
+    audio_processor = AudioProcessor(
+        transcription_engine=transcription_engine,
+    )
     await websocket.accept()
     logger.info("WebSocket connection opened.")
@@ -62,19 +66,33 @@ async def websocket_endpoint(websocket: WebSocket):
         while True:
             message = await websocket.receive_bytes()
             await audio_processor.process_audio(message)
+    except KeyError as e:
+        if 'bytes' in str(e):
+            logger.warning(f"Client has closed the connection.")
+        else:
+            logger.error(f"Unexpected KeyError in websocket_endpoint: {e}", exc_info=True)
     except WebSocketDisconnect:
-        logger.warning("WebSocket disconnected.")
+        logger.info("WebSocket disconnected by client during message receiving loop.")
+    except Exception as e:
+        logger.error(f"Unexpected error in websocket_endpoint main loop: {e}", exc_info=True)
     finally:
-        websocket_task.cancel()
+        logger.info("Cleaning up WebSocket endpoint...")
+        if not websocket_task.done():
+            websocket_task.cancel()
+        try:
+            await websocket_task
+        except asyncio.CancelledError:
+            logger.info("WebSocket results handler task was cancelled.")
+        except Exception as e:
+            logger.warning(f"Exception while awaiting websocket_task completion: {e}")
 
         await audio_processor.cleanup()
-        logger.info("WebSocket endpoint cleaned up.")
+        logger.info("WebSocket endpoint cleaned up successfully.")
 
 def main():
     """Entry point for the CLI command."""
     import uvicorn
 
-    args = parse_args()
-
     uvicorn_kwargs = {
         "app": "whisperlivekit.basic_server:app",
         "host": args.host,
@@ -93,7 +111,6 @@ def main():
         "ssl_keyfile": args.ssl_keyfile
     }
 
-
     if ssl_kwargs:
         uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs}
```
whisperlivekit/core.py

```diff
@@ -2,148 +2,10 @@ try:
     from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
 except ImportError:
     from .whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
-from argparse import Namespace, ArgumentParser
+from argparse import Namespace
 
-def parse_args():
-    parser = ArgumentParser(description="Whisper FastAPI Online Server")
-    parser.add_argument(
-        "--host",
-        type=str,
-        default="localhost",
-        help="The host address to bind the server to.",
-    )
-    parser.add_argument(
-        "--port", type=int, default=8000, help="The port number to bind the server to."
-    )
-    parser.add_argument(
-        "--warmup-file",
-        type=str,
-        default=None,
-        dest="warmup_file",
-        help="""
-        The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
-        If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
-        If False, no warmup is performed.
-        """,
-    )
-
-    parser.add_argument(
-        "--confidence-validation",
-        action="store_true",
-        help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
-    )
-
-    parser.add_argument(
-        "--diarization",
-        action="store_true",
-        default=False,
-        help="Enable speaker diarization.",
-    )
-
-    parser.add_argument(
-        "--no-transcription",
-        action="store_true",
-        help="Disable transcription to only see live diarization results.",
-    )
-
-    parser.add_argument(
-        "--min-chunk-size",
-        type=float,
-        default=0.5,
-        help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
-    )
-
-    parser.add_argument(
-        "--model",
-        type=str,
-        default="tiny",
-        help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
-    )
-
-    parser.add_argument(
-        "--model_cache_dir",
-        type=str,
-        default=None,
-        help="Overriding the default model cache dir where models downloaded from the hub are saved",
-    )
-    parser.add_argument(
-        "--model_dir",
-        type=str,
-        default=None,
-        help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
-    )
-    parser.add_argument(
-        "--lan",
-        "--language",
-        type=str,
-        default="auto",
-        help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
-    )
-    parser.add_argument(
-        "--task",
-        type=str,
-        default="transcribe",
-        choices=["transcribe", "translate"],
-        help="Transcribe or translate.",
-    )
-    parser.add_argument(
-        "--backend",
-        type=str,
-        default="faster-whisper",
-        choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
-        help="Load only this backend for Whisper processing.",
-    )
-    parser.add_argument(
-        "--vac",
-        action="store_true",
-        default=False,
-        help="Use VAC = voice activity controller. Recommended. Requires torch.",
-    )
-    parser.add_argument(
-        "--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
-    )
-
-    parser.add_argument(
-        "--no-vad",
-        action="store_true",
-        help="Disable VAD (voice activity detection).",
-    )
-
-    parser.add_argument(
-        "--buffer_trimming",
-        type=str,
-        default="segment",
-        choices=["sentence", "segment"],
-        help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.',
-    )
-    parser.add_argument(
-        "--buffer_trimming_sec",
-        type=float,
-        default=15,
-        help="Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.",
-    )
-    parser.add_argument(
-        "-l",
-        "--log-level",
-        dest="log_level",
-        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
-        help="Set the log level",
-        default="DEBUG",
-    )
-    parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
-    parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
-
-    args = parser.parse_args()
-
-    args.transcription = not args.no_transcription
-    args.vad = not args.no_vad
-    delattr(args, 'no_transcription')
-    delattr(args, 'no_vad')
-
-    return args
-
-class WhisperLiveKit:
+class TranscriptionEngine:
     _instance = None
     _initialized = False
```
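With argument parsing moved out of `core.py`, server code now builds the engine from parsed CLI options; this mirrors the pattern already shown in `basic_server.py` above:

```python
from whisperlivekit import TranscriptionEngine, parse_args

args = parse_args()
engine = TranscriptionEngine(**vars(args))  # Namespace fields become keyword overrides
```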
```diff
@@ -153,14 +15,51 @@ class TranscriptionEngine:
         return cls._instance
 
     def __init__(self, **kwargs):
-        if WhisperLiveKit._initialized:
+        if TranscriptionEngine._initialized:
             return
 
-        default_args = vars(parse_args())
-
-        merged_args = {**default_args, **kwargs}
-
-        self.args = Namespace(**merged_args)
+        defaults = {
+            "host": "localhost",
+            "port": 8000,
+            "warmup_file": None,
+            "confidence_validation": False,
+            "diarization": False,
+            "punctuation_split": False,
+            "min_chunk_size": 0.5,
+            "model": "tiny",
+            "model_cache_dir": None,
+            "model_dir": None,
+            "lan": "auto",
+            "task": "transcribe",
+            "backend": "faster-whisper",
+            "vac": False,
+            "vac_chunk_size": 0.04,
+            "buffer_trimming": "segment",
+            "buffer_trimming_sec": 15,
+            "log_level": "DEBUG",
+            "ssl_certfile": None,
+            "ssl_keyfile": None,
+            "transcription": True,
+            "vad": True,
+            "segmentation_model": "pyannote/segmentation-3.0",
+            "embedding_model": "pyannote/embedding",
+        }
+
+        config_dict = {**defaults, **kwargs}
+
+        if 'no_transcription' in kwargs:
+            config_dict['transcription'] = not kwargs['no_transcription']
+        if 'no_vad' in kwargs:
+            config_dict['vad'] = not kwargs['no_vad']
+
+        config_dict.pop('no_transcription', None)
+        config_dict.pop('no_vad', None)
+
+        if 'language' in kwargs:
+            config_dict['lan'] = kwargs['language']
+        config_dict.pop('language', None)
+
+        self.args = Namespace(**config_dict)
 
         self.asr = None
         self.tokenizer = None
```
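Because of the `_instance`/`_initialized` guards above, `TranscriptionEngine` behaves as a process-wide singleton: the first construction fixes the configuration and later constructions return the same object. A sketch of the implied behavior:

```python
engine_a = TranscriptionEngine(model="medium", diarization=True)
engine_b = TranscriptionEngine(model="large-v3")  # ignored: already initialized

assert engine_a is engine_b
assert engine_a.args.model == "medium"  # the first configuration wins
```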
```diff
@@ -172,13 +71,10 @@ class TranscriptionEngine:
 
         if self.args.diarization:
             from whisperlivekit.diarization.diarization_online import DiartDiarization
-            self.diarization = DiartDiarization()
+            self.diarization = DiartDiarization(
+                block_duration=self.args.min_chunk_size,
+                segmentation_model_name=self.args.segmentation_model,
+                embedding_model_name=self.args.embedding_model
+            )
 
-        WhisperLiveKit._initialized = True
-
-    def web_interface(self):
-        import pkg_resources
-        html_path = pkg_resources.resource_filename('whisperlivekit', 'web/live_transcription.html')
-        with open(html_path, "r", encoding="utf-8") as f:
-            html = f.read()
-        return html
+        TranscriptionEngine._initialized = True
```
whisperlivekit/diarization/diarization_online.py

```diff
@@ -3,7 +3,8 @@ import re
 import threading
 import numpy as np
 import logging
+import time
+from queue import SimpleQueue, Empty
 
 from diart import SpeakerDiarization, SpeakerDiarizationConfig
 from diart.inference import StreamingInference
@@ -13,6 +14,7 @@ from diart.sources import MicrophoneAudioSource
 from rx.core import Observer
 from typing import Tuple, Any, List
 from pyannote.core import Annotation
+import diart.models as m
 
 logger = logging.getLogger(__name__)
```
```diff
@@ -78,40 +80,114 @@ class DiarizationObserver(Observer):
 
 class WebSocketAudioSource(AudioSource):
     """
-    Custom AudioSource that blocks in read() until close() is called.
-    Use push_audio() to inject PCM chunks.
+    Buffers incoming audio and releases it in fixed-size chunks at regular intervals.
     """
-    def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
+    def __init__(self, uri: str = "websocket", sample_rate: int = 16000, block_duration: float = 0.5):
         super().__init__(uri, sample_rate)
+        self.block_duration = block_duration
+        self.block_size = int(np.rint(block_duration * sample_rate))
+        self._queue = SimpleQueue()
+        self._buffer = np.array([], dtype=np.float32)
+        self._buffer_lock = threading.Lock()
         self._closed = False
         self._close_event = threading.Event()
+        self._processing_thread = None
+        self._last_chunk_time = time.time()
 
     def read(self):
+        """Start processing buffered audio and emit fixed-size chunks."""
+        self._processing_thread = threading.Thread(target=self._process_chunks)
+        self._processing_thread.daemon = True
+        self._processing_thread.start()
+
         self._close_event.wait()
+        if self._processing_thread:
+            self._processing_thread.join(timeout=2.0)
+
+    def _process_chunks(self):
+        """Process audio from queue and emit fixed-size chunks at regular intervals."""
+        while not self._closed:
+            try:
+                audio_chunk = self._queue.get(timeout=0.1)
+
+                with self._buffer_lock:
+                    self._buffer = np.concatenate([self._buffer, audio_chunk])
+
+                    while len(self._buffer) >= self.block_size:
+                        chunk = self._buffer[:self.block_size]
+                        self._buffer = self._buffer[self.block_size:]
+
+                        current_time = time.time()
+                        time_since_last = current_time - self._last_chunk_time
+                        if time_since_last < self.block_duration:
+                            time.sleep(self.block_duration - time_since_last)
+
+                        chunk_reshaped = chunk.reshape(1, -1)
+                        self.stream.on_next(chunk_reshaped)
+                        self._last_chunk_time = time.time()
+
+            except Empty:
+                with self._buffer_lock:
+                    if len(self._buffer) > 0 and time.time() - self._last_chunk_time > self.block_duration:
+                        padded_chunk = np.zeros(self.block_size, dtype=np.float32)
+                        padded_chunk[:len(self._buffer)] = self._buffer
+                        self._buffer = np.array([], dtype=np.float32)
+
+                        chunk_reshaped = padded_chunk.reshape(1, -1)
+                        self.stream.on_next(chunk_reshaped)
+                        self._last_chunk_time = time.time()
+            except Exception as e:
+                logger.error(f"Error in audio processing thread: {e}")
+                self.stream.on_error(e)
+                break
+
+        with self._buffer_lock:
+            if len(self._buffer) > 0:
+                padded_chunk = np.zeros(self.block_size, dtype=np.float32)
+                padded_chunk[:len(self._buffer)] = self._buffer
+                chunk_reshaped = padded_chunk.reshape(1, -1)
+                self.stream.on_next(chunk_reshaped)
+
+        self.stream.on_completed()
 
     def close(self):
         if not self._closed:
             self._closed = True
-            self.stream.on_completed()
             self._close_event.set()
 
     def push_audio(self, chunk: np.ndarray):
+        """Add audio chunk to the processing queue."""
         if not self._closed:
-            new_audio = np.expand_dims(chunk, axis=0)
-            logger.debug('Add new chunk with shape:', new_audio.shape)
-            self.stream.on_next(new_audio)
+            if chunk.ndim > 1:
+                chunk = chunk.flatten()
+            self._queue.put(chunk)
+            logger.debug(f'Added chunk to queue with {len(chunk)} samples')
```
|
||||||
|
|
||||||
|
|
||||||
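The class above implements a re-chunking policy: incoming PCM of arbitrary size is queued, accumulated, and re-emitted as exact block_size chunks, with a zero-padded flush for trailing samples. A minimal standalone sketch of that policy (rechunk is an illustrative name, and the pacing sleep is omitted):

    import numpy as np

    def rechunk(chunks, block_duration=0.5, sample_rate=16000):
        block_size = int(np.rint(block_duration * sample_rate))
        buffer = np.array([], dtype=np.float32)
        for chunk in chunks:
            buffer = np.concatenate([buffer, chunk])
            while len(buffer) >= block_size:
                yield buffer[:block_size]        # full block, ready for diart
                buffer = buffer[block_size:]
        if len(buffer) > 0:                      # flush: zero-pad the leftover
            padded = np.zeros(block_size, dtype=np.float32)
            padded[:len(buffer)] = buffer
            yield padded

    # Three 3000-sample chunks at 16 kHz yield one full 8000-sample block
    # plus one padded block.
    blocks = list(rechunk([np.ones(3000, dtype=np.float32)] * 3))
    assert len(blocks) == 2 and all(len(b) == 8000 for b in blocks)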
 class DiartDiarization:
-    def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False):
+    def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "speechbrain/spkrec-ecapa-voxceleb"):
+        segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
+        embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
+
+        if config is None:
+            config = SpeakerDiarizationConfig(
+                segmentation=segmentation_model,
+                embedding=embedding_model,
+            )
+
         self.pipeline = SpeakerDiarization(config=config)
         self.observer = DiarizationObserver()
+        self.lag_diart = None
 
         if use_microphone:
-            self.source = MicrophoneAudioSource()
+            self.source = MicrophoneAudioSource(block_duration=block_duration)
             self.custom_source = None
         else:
-            self.custom_source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
+            self.custom_source = WebSocketAudioSource(
+                uri="websocket_source",
+                sample_rate=sample_rate,
+                block_duration=block_duration
+            )
             self.source = self.custom_source
 
         self.inference = StreamingInference(
@@ -138,16 +214,102 @@ class DiartDiarization:
         if self.custom_source:
             self.custom_source.close()
 
-    def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float:
+    def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float:
         """
         Assign speakers to tokens based on timing overlap with speaker segments.
         Uses the segments collected by the observer.
+
+        If use_punctuation_split is True, uses punctuation marks to refine speaker boundaries.
         """
         segments = self.observer.get_segments()
+
+        # Debug logging
+        logger.debug(f"assign_speakers_to_tokens called with {len(tokens)} tokens")
+        logger.debug(f"Available segments: {len(segments)}")
+        for i, seg in enumerate(segments[:5]):  # Show first 5 segments
+            logger.debug(f"  Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]")
+
+        if not self.lag_diart and segments and tokens:
+            self.lag_diart = segments[0].start - tokens[0].start
         for token in tokens:
             for segment in segments:
-                if not (segment.end <= token.start or segment.start >= token.end):
+                if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
                     token.speaker = extract_number(segment.speaker) + 1
                     end_attributed_speaker = max(token.end, end_attributed_speaker)
-        return end_attributed_speaker
+
+        if use_punctuation_split and len(tokens) > 1:
+            punctuation_marks = {'.', '!', '?'}
+
print("Here are the tokens:",
|
||||||
|
[(t.text, t.start, t.end, t.speaker) for t in tokens[:10]])
|
||||||
|
+
+            segment_map = []
+            for segment in segments:
+                speaker_num = extract_number(segment.speaker) + 1
+                segment_map.append((segment.start, segment.end, speaker_num))
+            segment_map.sort(key=lambda x: x[0])
+
+            i = 0
+            while i < len(tokens):
+                current_token = tokens[i]
+
+                is_sentence_end = False
+                if current_token.text and current_token.text.strip():
+                    text = current_token.text.strip()
+                    if text[-1] in punctuation_marks:
+                        is_sentence_end = True
+                        logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
+
+                if is_sentence_end and current_token.speaker != -1:
+                    punctuation_time = current_token.end
+                    current_speaker = current_token.speaker
+
+                    j = i + 1
+                    next_sentence_tokens = []
+                    while j < len(tokens):
+                        next_token = tokens[j]
+                        next_sentence_tokens.append(j)
+
+                        # Check if this token ends the next sentence
+                        if next_token.text and next_token.text.strip():
+                            if next_token.text.strip()[-1] in punctuation_marks:
+                                break
+                        j += 1
+
+                    if next_sentence_tokens:
+                        speaker_times = {}
+
+                        for idx in next_sentence_tokens:
+                            token = tokens[idx]
+                            # Find which segments overlap with this token
+                            for seg_start, seg_end, seg_speaker in segment_map:
+                                if not (seg_end <= token.start or seg_start >= token.end):
+                                    # Calculate overlap duration
+                                    overlap_start = max(seg_start, token.start)
+                                    overlap_end = min(seg_end, token.end)
+                                    overlap_duration = overlap_end - overlap_start
+
+                                    if seg_speaker not in speaker_times:
+                                        speaker_times[seg_speaker] = 0
+                                    speaker_times[seg_speaker] += overlap_duration
+
+                        if speaker_times:
+                            dominant_speaker = max(speaker_times.items(), key=lambda x: x[1])[0]
+
+                            if dominant_speaker != current_speaker:
+                                logger.debug(f"  Speaker change after punctuation: {current_speaker} → {dominant_speaker}")
+
+                                for idx in next_sentence_tokens:
+                                    if tokens[idx].speaker != dominant_speaker:
+                                        logger.debug(f"    Reassigning token {idx} ('{tokens[idx].text}') to Speaker {dominant_speaker}")
+                                        tokens[idx].speaker = dominant_speaker
+                                        end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
+                            else:
+                                for idx in next_sentence_tokens:
+                                    if tokens[idx].speaker == -1:
+                                        tokens[idx].speaker = current_speaker
+                                        end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
+
+                i += 1
+
+        return end_attributed_speaker
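The punctuation-split pass above decides each sentence's speaker by an overlap vote: every token in the sentence accumulates overlap duration against each diarization segment, and the speaker with the largest total wins. A self-contained sketch of that vote (dominant_speaker and the tuple inputs are illustrative, not the repository API):

    def dominant_speaker(token_spans, segment_map):
        """token_spans: [(start, end)]; segment_map: [(start, end, speaker)]."""
        speaker_times = {}
        for tok_start, tok_end in token_spans:
            for seg_start, seg_end, speaker in segment_map:
                if not (seg_end <= tok_start or seg_start >= tok_end):
                    overlap = min(seg_end, tok_end) - max(seg_start, tok_start)
                    speaker_times[speaker] = speaker_times.get(speaker, 0.0) + overlap
        return max(speaker_times.items(), key=lambda x: x[1])[0] if speaker_times else None

    # A token spanning 2.0-3.0 s overlaps speaker 1 for 0.2 s and speaker 2 for 0.8 s.
    assert dominant_speaker([(2.0, 3.0)], [(0.0, 2.2, 1), (2.2, 5.0, 2)]) == 2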
|||||||
whisperlivekit/parse_args.py (new file, 162 lines)
@@ -0,0 +1,162 @@
from argparse import ArgumentParser

def parse_args():
    parser = ArgumentParser(description="Whisper FastAPI Online Server")
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help="The host address to bind the server to.",
    )
    parser.add_argument(
        "--port", type=int, default=8000, help="The port number to bind the server to."
    )
    parser.add_argument(
        "--warmup-file",
        type=str,
        default=None,
        dest="warmup_file",
        help="""
        The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
        If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
        If False, no warmup is performed.
        """,
    )

    parser.add_argument(
        "--confidence-validation",
        action="store_true",
        help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
    )

    parser.add_argument(
        "--diarization",
        action="store_true",
        default=False,
        help="Enable speaker diarization.",
    )

    parser.add_argument(
        "--punctuation-split",
        action="store_true",
        default=False,
        help="Use punctuation marks from transcription to improve speaker boundary detection. Requires both transcription and diarization to be enabled.",
    )

    parser.add_argument(
        "--segmentation-model",
        type=str,
        default="pyannote/segmentation-3.0",
        help="Hugging Face model ID for pyannote.audio segmentation model.",
    )

    parser.add_argument(
        "--embedding-model",
        type=str,
        default="pyannote/embedding",
        help="Hugging Face model ID for pyannote.audio embedding model.",
    )

    parser.add_argument(
        "--no-transcription",
        action="store_true",
        help="Disable transcription to only see live diarization results.",
    )

    parser.add_argument(
        "--min-chunk-size",
        type=float,
        default=0.5,
        help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
    )

    parser.add_argument(
        "--model",
        type=str,
        default="tiny",
        help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
    )

    parser.add_argument(
        "--model_cache_dir",
        type=str,
        default=None,
        help="Overriding the default model cache dir where models downloaded from the hub are saved",
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default=None,
        help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
    )
    parser.add_argument(
        "--lan",
        "--language",
        type=str,
        default="auto",
        help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
    )
    parser.add_argument(
        "--task",
        type=str,
        default="transcribe",
        choices=["transcribe", "translate"],
        help="Transcribe or translate.",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="faster-whisper",
        choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
        help="Load only this backend for Whisper processing.",
    )
    parser.add_argument(
        "--vac",
        action="store_true",
        default=False,
        help="Use VAC = voice activity controller. Recommended. Requires torch.",
    )
    parser.add_argument(
        "--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
    )

    parser.add_argument(
        "--no-vad",
        action="store_true",
        help="Disable VAD (voice activity detection).",
    )

    parser.add_argument(
        "--buffer_trimming",
        type=str,
        default="segment",
        choices=["sentence", "segment"],
        help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.',
    )
    parser.add_argument(
        "--buffer_trimming_sec",
        type=float,
        default=15,
        help="Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.",
    )
    parser.add_argument(
        "-l",
        "--log-level",
        dest="log_level",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Set the log level",
        default="DEBUG",
    )
    parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
    parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)

    args = parser.parse_args()

    args.transcription = not args.no_transcription
    args.vad = not args.no_vad
    delattr(args, 'no_transcription')
    delattr(args, 'no_vad')

    return args
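A quick sanity sketch of the flag normalization at the end of parse_args (argv is faked for illustration; assumes the package is importable):

    import sys
    from whisperlivekit.parse_args import parse_args

    sys.argv = ["whisperlivekit-server", "--no-vad", "--language", "en"]
    args = parse_args()
    assert args.vad is False              # --no-vad folded into args.vad
    assert args.transcription is True     # --no-transcription was not passed
    assert args.lan == "en"               # --language is an alias for --lan
    assert not hasattr(args, "no_vad")    # raw negative flags are removed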
@@ -26,4 +26,7 @@ class Transcript(TimedText):
 
 @dataclass
 class SpeakerSegment(TimedText):
+    """Represents a segment of audio attributed to a specific speaker.
+    No text nor probability is associated with this segment.
+    """
     pass
whisperlivekit/web/live_transcription.html
@@ -308,6 +308,7 @@
     let waveCtx = waveCanvas.getContext("2d");
     let animationFrame = null;
     let waitingForStop = false;
+    let lastReceivedData = null;
     waveCanvas.width = 60 * (window.devicePixelRatio || 1);
     waveCanvas.height = 30 * (window.devicePixelRatio || 1);
     waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
@@ -357,18 +358,31 @@
 
         websocket.onclose = () => {
             if (userClosing) {
-                if (!statusText.textContent.includes("Recording stopped. Processing final audio")) { // This is a bit of a hack. We should have a better way to handle this. eg. using a status code.
-                    statusText.textContent = "Finished processing audio! Ready to record again.";
+                if (waitingForStop) {
+                    statusText.textContent = "Processing finalized or connection closed.";
+                    if (lastReceivedData) {
+                        renderLinesWithBuffer(
+                            lastReceivedData.lines || [],
+                            lastReceivedData.buffer_diarization || "",
+                            lastReceivedData.buffer_transcription || "",
+                            0, 0, true // isFinalizing = true
+                        );
+                    }
                 }
-                waitingForStop = false;
+                // If ready_to_stop was received, statusText is already "Finished processing..."
+                // and waitingForStop is false.
             } else {
-                statusText.textContent =
-                    "Disconnected from the WebSocket server. (Check logs if model is loading.)";
+                statusText.textContent = "Disconnected from the WebSocket server. (Check logs if model is loading.)";
                 if (isRecording) {
                     stopRecording();
                 }
             }
-            userClosing = false;
+            isRecording = false;
+            waitingForStop = false;
+            userClosing = false;
+            lastReceivedData = null;
+            websocket = null;
+            updateUI();
         };
 
         websocket.onerror = () => {
@@ -382,31 +396,39 @@
 
             // Check for status messages
             if (data.type === "ready_to_stop") {
-                console.log("Ready to stop, closing WebSocket");
+                console.log("Ready to stop received, finalizing display and closing WebSocket.");
 
-                // signal that we are not waiting for stop anymore
                 waitingForStop = false;
-                recordButton.disabled = false; // this should be elsewhere
-                console.log("Record button enabled");
 
-                //Now we can close the WebSocket
-                if (websocket) {
-                    websocket.close();
-                    websocket = null;
+                if (lastReceivedData) {
+                    renderLinesWithBuffer(
+                        lastReceivedData.lines || [],
+                        lastReceivedData.buffer_diarization || "",
+                        lastReceivedData.buffer_transcription || "",
+                        0, // No more lag
+                        0, // No more lag
+                        true // isFinalizing = true
+                    );
                 }
+                statusText.textContent = "Finished processing audio! Ready to record again.";
+                recordButton.disabled = false;
+
+                if (websocket) {
+                    websocket.close(); // will trigger onclose
+                    // websocket = null; // onclose handles setting websocket to null
+                }
                 return;
             }
+
+            lastReceivedData = data;
 
             // Handle normal transcription updates
             const {
                 lines = [],
                 buffer_transcription = "",
                 buffer_diarization = "",
                 remaining_time_transcription = 0,
-                remaining_time_diarization = 0
+                remaining_time_diarization = 0,
+                status = "active_transcription"
             } = data;
 
             renderLinesWithBuffer(
@@ -414,13 +436,20 @@
                 buffer_diarization,
                 buffer_transcription,
                 remaining_time_diarization,
-                remaining_time_transcription
+                remaining_time_transcription,
+                false,
+                status
             );
         };
     });
 }
 
-function renderLinesWithBuffer(lines, buffer_diarization, buffer_transcription, remaining_time_diarization, remaining_time_transcription) {
+function renderLinesWithBuffer(lines, buffer_diarization, buffer_transcription, remaining_time_diarization, remaining_time_transcription, isFinalizing = false, current_status = "active_transcription") {
+    if (current_status === "no_audio_detected") {
+        linesTranscriptDiv.innerHTML = "<p style='text-align: center; color: #666; margin-top: 20px;'><em>No audio detected...</em></p>";
+        return;
+    }
+
     const linesHtml = lines.map((item, idx) => {
         let timeInfo = "";
         if (item.beg !== undefined && item.end !== undefined) {
@@ -430,30 +459,46 @@
         let speakerLabel = "";
         if (item.speaker === -2) {
             speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
-        } else if (item.speaker == 0) {
+        } else if (item.speaker == 0 && !isFinalizing) {
             speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'>${remaining_time_diarization} second(s) of audio are undergoing diarization</span></span>`;
         } else if (item.speaker == -1) {
-            speakerLabel = `<span id="speaker"><span id='timeInfo'>${timeInfo}</span></span>`;
+            speakerLabel = `<span id="speaker">Speaker 1<span id='timeInfo'>${timeInfo}</span></span>`;
-        } else if (item.speaker !== -1) {
+        } else if (item.speaker !== -1 && item.speaker !== 0) {
             speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
         }
 
-        let textContent = item.text;
-        if (idx === lines.length - 1) {
-            speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'>${remaining_time_transcription}s</span></span>`
-        }
-        if (idx === lines.length - 1 && buffer_diarization) {
-            speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'>${remaining_time_diarization}s</span></span>`
-            textContent += `<span class="buffer_diarization">${buffer_diarization}</span>`;
-        }
-        if (idx === lines.length - 1) {
-            textContent += `<span class="buffer_transcription">${buffer_transcription}</span>`;
-        }
+        let currentLineText = item.text || "";
+
+        if (idx === lines.length - 1) {
+            if (!isFinalizing) {
+                if (remaining_time_transcription > 0) {
+                    speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'>${remaining_time_transcription}s</span></span>`;
+                }
+                if (buffer_diarization && remaining_time_diarization > 0) {
+                    speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'>${remaining_time_diarization}s</span></span>`;
+                }
+            }
+
+            if (buffer_diarization) {
+                if (isFinalizing) {
+                    currentLineText += (currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
+                } else {
+                    currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
+                }
+            }
+            if (buffer_transcription) {
+                if (isFinalizing) {
+                    currentLineText += (currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") + buffer_transcription.trim();
+                } else {
+                    currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
+                }
+            }
+        }
 
-        return textContent
-            ? `<p>${speakerLabel}<br/><div class='textcontent'>${textContent}</div></p>`
-            : `<p>${speakerLabel}<br/></p>`;
+        return currentLineText.trim().length > 0 || speakerLabel.length > 0
+            ? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
+            : `<p>${speakerLabel}<br/></p>`;
     }).join("");
 
     linesTranscriptDiv.innerHTML = linesHtml;
@@ -578,20 +623,6 @@
     timerElement.textContent = "00:00";
     startTime = null;
 
-    if (websocket && websocket.readyState === WebSocket.OPEN) {
-        try {
-            await websocket.send(JSON.stringify({
-                type: "stop",
-                message: "User stopped recording"
-            }));
-            statusText.textContent = "Recording stopped. Processing final audio...";
-        } catch (e) {
-            console.error("Could not send stop message:", e);
-            statusText.textContent = "Recording stopped. Error during final audio processing.";
-            websocket.close();
-            websocket = null;
-        }
-    }
-
     isRecording = false;
     updateUI();
@@ -625,19 +656,22 @@
 
 function updateUI() {
     recordButton.classList.toggle("recording", isRecording);
+    recordButton.disabled = waitingForStop;
 
     if (waitingForStop) {
-        statusText.textContent = "Please wait for processing to complete...";
-        recordButton.disabled = true; // Optionally disable the button while waiting
-        console.log("Record button disabled");
+        if (statusText.textContent !== "Recording stopped. Processing final audio...") {
+            statusText.textContent = "Please wait for processing to complete...";
+        }
    } else if (isRecording) {
         statusText.textContent = "Recording...";
-        recordButton.disabled = false;
-        console.log("Record button enabled");
    } else {
-        statusText.textContent = "Click to start transcription";
+        if (statusText.textContent !== "Finished processing audio! Ready to record again." &&
+            statusText.textContent !== "Processing finalized or connection closed.") {
+            statusText.textContent = "Click to start transcription";
+        }
+    }
+    if (!waitingForStop) {
         recordButton.disabled = false;
-        console.log("Record button enabled");
    }
 }
@@ -645,4 +679,4 @@
 </script>
 </body>
 
 </html>
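Client side, the protocol the page now follows is: keep the latest JSON update, and treat a {"type": "ready_to_stop"} message as the signal that the server has flushed everything. A minimal Python client sketch of that handshake (the endpoint path and the websockets dependency are assumptions for illustration; audio upload is omitted):

    import asyncio
    import json
    import websockets  # assumed third-party dependency for this sketch

    async def consume(url="ws://localhost:8000/asr"):
        async with websockets.connect(url) as ws:
            last = None                      # mirror of the page's lastReceivedData
            async for raw in ws:
                data = json.loads(raw)
                if data.get("type") == "ready_to_stop":
                    return last              # final state to render
                last = data

    # asyncio.run(consume())  # receive side only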
whisperlivekit/web/web_interface.py (new file, 13 lines)
@@ -0,0 +1,13 @@
import logging
import importlib.resources as resources

logger = logging.getLogger(__name__)

def get_web_interface_html():
    """Loads the HTML for the web interface using importlib.resources."""
    try:
        with resources.files('whisperlivekit.web').joinpath('live_transcription.html').open('r', encoding='utf-8') as f:
            return f.read()
    except Exception as e:
        logger.error(f"Error loading web interface HTML: {e}")
        return "<html><body><h1>Error loading interface</h1></body></html>"
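A hypothetical wiring sketch showing how the helper might be served (the route and FastAPI usage here are illustrative, not taken from the diff):

    from fastapi import FastAPI
    from fastapi.responses import HTMLResponse
    from whisperlivekit.web.web_interface import get_web_interface_html

    app = FastAPI()

    @app.get("/")
    async def index():
        # Packaged HTML is resolved via importlib.resources, so this works
        # from an installed wheel as well as from a source checkout.
        return HTMLResponse(get_web_interface_html())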
@@ -144,7 +144,11 @@ class OnlineASRProcessor:
         self.transcript_buffer.last_committed_time = self.buffer_time_offset
         self.committed: List[ASRToken] = []
 
-    def insert_audio_chunk(self, audio: np.ndarray):
+    def get_audio_buffer_end_time(self) -> float:
+        """Returns the absolute end time of the current audio_buffer."""
+        return self.buffer_time_offset + (len(self.audio_buffer) / self.SAMPLING_RATE)
+
+    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
         """Append an audio chunk (a numpy array) to the current audio buffer."""
         self.audio_buffer = np.append(self.audio_buffer, audio)
 
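A worked example of the buffer-end arithmetic introduced above: with buffer_time_offset at 12.0 s and 8000 buffered samples at 16 kHz, the buffer ends at 12.0 + 8000/16000 = 12.5 s.

    SAMPLING_RATE = 16000
    buffer_time_offset = 12.0   # absolute start time of the buffer, in seconds
    buffered_samples = 8000     # half a second of audio at 16 kHz
    assert buffer_time_offset + buffered_samples / SAMPLING_RATE == 12.5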
@@ -179,18 +183,19 @@ class OnlineASRProcessor:
         return self.concatenate_tokens(self.transcript_buffer.buffer)
 
 
-    def process_iter(self) -> Transcript:
+    def process_iter(self) -> Tuple[List[ASRToken], float]:
         """
         Processes the current audio buffer.
 
-        Returns a Transcript object representing the committed transcript.
+        Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
         """
+        current_audio_processed_upto = self.get_audio_buffer_end_time()
         prompt_text, _ = self.prompt()
         logger.debug(
             f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds from {self.buffer_time_offset:.2f}"
         )
         res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt_text)
-        tokens = self.asr.ts_words(res)  # Expecting List[ASRToken]
+        tokens = self.asr.ts_words(res)
         self.transcript_buffer.insert(tokens, self.buffer_time_offset)
         committed_tokens = self.transcript_buffer.flush()
         self.committed.extend(committed_tokens)
@@ -210,7 +215,7 @@ class OnlineASRProcessor:
         logger.debug(
             f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
         )
-        return committed_tokens
+        return committed_tokens, current_audio_processed_upto
 
     def chunk_completed_sentence(self):
         """
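Callers must adapt to the new contract: process_iter no longer returns a Transcript but a (tokens, processed_upto) pair. A caller-side sketch (drain is an illustrative name; an initialized processor is assumed):

    def drain(online_processor):
        """Consume one iteration and report how far the audio has been read."""
        committed_tokens, processed_upto = online_processor.process_iter()
        for token in committed_tokens:
            print(f"[{token.start:.2f}-{token.end:.2f}] {token.text}")
        return processed_upto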
@@ -343,15 +348,17 @@ class OnlineASRProcessor:
             )
             sentences.append(sentence)
         return sentences
-    def finish(self) -> Transcript:
+
+    def finish(self) -> Tuple[List[ASRToken], float]:
         """
         Flush the remaining transcript when processing ends.
+        Returns a tuple: (list of remaining ASRToken objects, float representing the final audio processed up to time).
         """
         remaining_tokens = self.transcript_buffer.buffer
-        final_transcript = self.concatenate_tokens(remaining_tokens)
-        logger.debug(f"Final non-committed transcript: {final_transcript}")
-        self.buffer_time_offset += len(self.audio_buffer) / self.SAMPLING_RATE
-        return final_transcript
+        logger.debug(f"Final non-committed tokens: {remaining_tokens}")
+        final_processed_upto = self.buffer_time_offset + (len(self.audio_buffer) / self.SAMPLING_RATE)
+        self.buffer_time_offset = final_processed_upto
+        return remaining_tokens, final_processed_upto
 
     def concatenate_tokens(
         self,
@@ -384,7 +391,8 @@ class VACOnlineASRProcessor:
     def __init__(self, online_chunk_size: float, *args, **kwargs):
         self.online_chunk_size = online_chunk_size
         self.online = OnlineASRProcessor(*args, **kwargs)
+        self.asr = self.online.asr
 
         # Load a VAD model (e.g. Silero VAD)
         import torch
         model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
@@ -392,28 +400,35 @@ class VACOnlineASRProcessor:
 
         self.vac = FixedVADIterator(model)
         self.logfile = self.online.logfile
+        self.last_input_audio_stream_end_time: float = 0.0
         self.init()
 
     def init(self):
         self.online.init()
         self.vac.reset_states()
         self.current_online_chunk_buffer_size = 0
+        self.last_input_audio_stream_end_time = self.online.buffer_time_offset
         self.is_currently_final = False
         self.status: Optional[str] = None  # "voice" or "nonvoice"
         self.audio_buffer = np.array([], dtype=np.float32)
         self.buffer_offset = 0  # in frames
 
+    def get_audio_buffer_end_time(self) -> float:
+        """Returns the absolute end time of the audio processed by the underlying OnlineASRProcessor."""
+        return self.online.get_audio_buffer_end_time()
+
     def clear_buffer(self):
         self.buffer_offset += len(self.audio_buffer)
         self.audio_buffer = np.array([], dtype=np.float32)
 
-    def insert_audio_chunk(self, audio: np.ndarray):
+    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
         """
         Process an incoming small audio chunk:
         - run VAD on the chunk,
         - decide whether to send the audio to the online ASR processor immediately,
         - and/or to mark the current utterance as finished.
         """
+        self.last_input_audio_stream_end_time = audio_stream_end_time
         res = self.vac(audio)
         self.audio_buffer = np.append(self.audio_buffer, audio)
 
@@ -455,10 +470,11 @@ class VACOnlineASRProcessor:
         self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE)
         self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]
 
-    def process_iter(self) -> Transcript:
+    def process_iter(self) -> Tuple[List[ASRToken], float]:
         """
         Depending on the VAD status and the amount of accumulated audio,
         process the current audio chunk.
+        Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
         """
         if self.is_currently_final:
             return self.finish()
@@ -467,17 +483,20 @@ class VACOnlineASRProcessor:
             return self.online.process_iter()
         else:
             logger.debug("No online update, only VAD")
-            return Transcript(None, None, "")
+            return [], self.last_input_audio_stream_end_time
 
-    def finish(self) -> Transcript:
-        """Finish processing by flushing any remaining text."""
-        result = self.online.finish()
+    def finish(self) -> Tuple[List[ASRToken], float]:
+        """
+        Finish processing by flushing any remaining text.
+        Returns a tuple: (list of remaining ASRToken objects, float representing the final audio processed up to time).
+        """
+        result_tokens, processed_upto = self.online.finish()
         self.current_online_chunk_buffer_size = 0
         self.is_currently_final = False
-        return result
+        return result_tokens, processed_upto
 
     def get_buffer(self):
         """
         Get the unvalidated buffer in string format.
         """
-        return self.online.concatenate_tokens(self.online.transcript_buffer.buffer).text
+        return self.online.concatenate_tokens(self.online.transcript_buffer.buffer)