diff --git a/README.md b/README.md
index 7c43451..1dc0a7e 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
This project is based on [Whisper Streaming](https://github.com/ufal/whisper_streaming) and lets you transcribe audio directly from your browser. Simply launch the local server and grant microphone access. Everything runs locally on your machine ✨
-
+
### Differences from [Whisper Streaming](https://github.com/ufal/whisper_streaming)
@@ -24,20 +24,27 @@ This project is based on [Whisper Streaming](https://github.com/ufal/whisper_str
- **FastAPI WebSocket Server** – Real-time speech-to-text processing with async FFmpeg streaming.
- **JavaScript Client** – Ready-to-use MediaRecorder implementation for seamless client-side integration.
-
## Installation
+### Via pip
+
+```bash
+pip install whisperlivekit
+```
+
+### From source
+
1. **Clone the Repository**:
```bash
git clone https://github.com/QuentinFuxa/WhisperLiveKit
cd WhisperLiveKit
+ pip install -e .
```
+### System Dependencies
-### How to Launch the Server
-
-1. **Dependencies**:
+You need to install FFmpeg on your system:
- Install system dependencies:
```bash
diff --git a/web/demo.png b/demo.png
similarity index 100%
rename from web/demo.png
rename to demo.png
diff --git a/parse_args.py b/parse_args.py
deleted file mode 100644
index f201477..0000000
--- a/parse_args.py
+++ /dev/null
@@ -1,52 +0,0 @@
-
-import argparse
-from whisper_streaming_custom.whisper_online import add_shared_args
-
-
-def parse_args():
- parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
- parser.add_argument(
- "--host",
- type=str,
- default="localhost",
- help="The host address to bind the server to.",
- )
- parser.add_argument(
- "--port", type=int, default=8000, help="The port number to bind the server to."
- )
- parser.add_argument(
- "--warmup-file",
- type=str,
- default=None,
- dest="warmup_file",
- help="""
- The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
- If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
- If False, no warmup is performed.
- """,
- )
-
- parser.add_argument(
- "--confidence-validation",
- type=bool,
- default=False,
- help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
- )
-
- parser.add_argument(
- "--diarization",
- type=bool,
- default=True,
- help="Whether to enable speaker diarization.",
- )
-
- parser.add_argument(
- "--transcription",
- type=bool,
- default=True,
- help="To disable to only see live diarization results.",
- )
-
- add_shared_args(parser)
- args = parser.parse_args()
- return args
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..b7eb28b
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,44 @@
+from setuptools import setup, find_packages
+
+setup(
+ name="whisperlivekit",
+ version="0.1.0",
+ description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
+ long_description=open("README.md", "r", encoding="utf-8").read(),
+ long_description_content_type="text/markdown",
+ author="Quentin Fuxa",
+ url="https://github.com/QuentinFuxa/WhisperLiveKit",
+ packages=find_packages(),
+ install_requires=[
+ "fastapi",
+ "ffmpeg-python",
+ "librosa",
+ "soundfile",
+ "faster-whisper",
+ "uvicorn",
+ "websockets",
+ ],
+ extras_require={
+ "diarization": ["diart"],
+ "vac": ["torch"],
+ "sentence": ["mosestokenizer", "wtpsplit"],
+ },
+ package_data={
+ 'whisperlivekit': ['web/*.html'],
+ },
+ entry_points={
+ 'console_scripts': [
+ 'whisperlivekit-server=whisperlivekit.server:run_server',
+ ],
+ },
+ classifiers=[
+ "Development Status :: 4 - Beta",
+ "Intended Audience :: Developers",
+ "License :: OSI Approved :: MIT License",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: Multimedia :: Sound/Audio :: Speech",
+ ],
+ python_requires=">=3.9",
+)
\ No newline at end of file
diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py
index 6c41f27..053998f 100644
--- a/whisper_fastapi_online_server.py
+++ b/whisper_fastapi_online_server.py
@@ -1,37 +1,26 @@
from contextlib import asynccontextmanager
-
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
-from whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
+from whisperlivekit import WhisperLiveKit
+from whisperlivekit.audio_processor import AudioProcessor
+
import asyncio
import logging
-from parse_args import parse_args
-from audio_processor import AudioProcessor
+import os
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger().setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
-args = parse_args()
-
+kit = None
@asynccontextmanager
async def lifespan(app: FastAPI):
- global asr, tokenizer, diarization
- if args.transcription:
- asr, tokenizer = backend_factory(args)
- warmup_asr(asr, args.warmup_file)
- else:
- asr, tokenizer = None, None
-
- if args.diarization:
- from diarization.diarization_online import DiartDiarization
- diarization = DiartDiarization()
- else :
- diarization = None
+ global kit
+ kit = WhisperLiveKit()
yield
app = FastAPI(lifespan=lifespan)
@@ -44,13 +33,9 @@ app.add_middleware(
)
-# Load demo HTML for the root endpoint
-with open("web/live_transcription.html", "r", encoding="utf-8") as f:
- html = f.read()
-
@app.get("/")
async def get():
- return HTMLResponse(html)
+ return HTMLResponse(kit.web_interface())
async def handle_websocket_results(websocket, results_generator):
@@ -64,12 +49,12 @@ async def handle_websocket_results(websocket, results_generator):
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
- audio_processor = AudioProcessor(args, asr, tokenizer)
+ audio_processor = AudioProcessor()
await websocket.accept()
logger.info("WebSocket connection opened.")
- results_generator = await audio_processor.create_tasks(diarization)
+ results_generator = await audio_processor.create_tasks()
websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
try:
@@ -85,8 +70,13 @@ async def websocket_endpoint(websocket: WebSocket):
if __name__ == "__main__":
import uvicorn
-
+
+ temp_kit = WhisperLiveKit(transcription=False, diarization=False)
+
uvicorn.run(
- "whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True,
+ "whisper_fastapi_online_server:app",
+ host=temp_kit.args.host,
+ port=temp_kit.args.port,
+ reload=True,
log_level="info"
)
\ No newline at end of file
diff --git a/whisperlivekit/__init__.py b/whisperlivekit/__init__.py
new file mode 100644
index 0000000..7319521
--- /dev/null
+++ b/whisperlivekit/__init__.py
@@ -0,0 +1,4 @@
+from .core import WhisperLiveKit, parse_args
+from .audio_processor import AudioProcessor
+
+__all__ = ['WhisperLiveKit', 'AudioProcessor', 'parse_args']
\ No newline at end of file
diff --git a/audio_processor.py b/whisperlivekit/audio_processor.py
similarity index 95%
rename from audio_processor.py
rename to whisperlivekit/audio_processor.py
index 94bfbef..68ee571 100644
--- a/audio_processor.py
+++ b/whisperlivekit/audio_processor.py
@@ -7,8 +7,9 @@ import logging
import traceback
from datetime import timedelta
from typing import List, Dict, Any
-from timed_objects import ASRToken
-from whisper_streaming_custom.whisper_online import online_factory
+from whisperlivekit.timed_objects import ASRToken
+from whisperlivekit.whisper_streaming_custom.whisper_online import online_factory
+from whisperlivekit.core import WhisperLiveKit
# Set up logging once
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -22,16 +23,19 @@ def format_time(seconds: float) -> str:
class AudioProcessor:
"""
Processes audio streams for transcription and diarization.
- Handles audio processing, state management, and result formatting in a single class.
+ Handles audio processing, state management, and result formatting.
"""
- def __init__(self, args, asr, tokenizer):
+ def __init__(self):
"""Initialize the audio processor with configuration, models, and state."""
+
+ models = WhisperLiveKit()
+
# Audio processing settings
- self.args = args
+ self.args = models.args
self.sample_rate = 16000
self.channels = 1
- self.samples_per_sec = int(self.sample_rate * args.min_chunk_size)
+ self.samples_per_sec = int(self.sample_rate * self.args.min_chunk_size)
self.bytes_per_sample = 2
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
@@ -49,16 +53,17 @@ class AudioProcessor:
self.last_response_content = ""
# Models and processing
- self.asr = asr
- self.tokenizer = tokenizer
+ self.asr = models.asr
+ self.tokenizer = models.tokenizer
+ self.diarization = models.diarization
self.ffmpeg_process = self.start_ffmpeg_decoder()
- self.transcription_queue = asyncio.Queue() if args.transcription else None
- self.diarization_queue = asyncio.Queue() if args.diarization else None
+ self.transcription_queue = asyncio.Queue() if self.args.transcription else None
+ self.diarization_queue = asyncio.Queue() if self.args.diarization else None
self.pcm_buffer = bytearray()
# Initialize transcription engine if enabled
- if args.transcription:
- self.online = online_factory(args, asr, tokenizer)
+ if self.args.transcription:
+ self.online = online_factory(self.args, models.asr, models.tokenizer)
def convert_pcm_to_float(self, pcm_buffer):
"""Convert PCM buffer in s16le format to normalized NumPy array."""
@@ -362,10 +367,8 @@ class AudioProcessor:
logger.warning(f"Traceback: {traceback.format_exc()}")
await asyncio.sleep(0.5) # Back off on error
- async def create_tasks(self, diarization=None):
+ async def create_tasks(self):
"""Create and start processing tasks."""
- if diarization:
- self.diarization = diarization
tasks = []
if self.args.transcription and self.online:
diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py
new file mode 100644
index 0000000..9363668
--- /dev/null
+++ b/whisperlivekit/core.py
@@ -0,0 +1,174 @@
+from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
+from argparse import Namespace, ArgumentParser
+
+def parse_args():
+ parser = ArgumentParser(description="Whisper FastAPI Online Server")
+ parser.add_argument(
+ "--host",
+ type=str,
+ default="localhost",
+ help="The host address to bind the server to.",
+ )
+ parser.add_argument(
+ "--port", type=int, default=8000, help="The port number to bind the server to."
+ )
+ parser.add_argument(
+ "--warmup-file",
+ type=str,
+ default=None,
+ dest="warmup_file",
+ help="""
+ The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
+ If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
+ If False, no warmup is performed.
+ """,
+ )
+
+ parser.add_argument(
+ "--confidence-validation",
+ type=bool,
+ default=False,
+ help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
+ )
+
+ parser.add_argument(
+ "--diarization",
+ type=bool,
+ default=True,
+ help="Whether to enable speaker diarization.",
+ )
+
+ parser.add_argument(
+ "--transcription",
+ type=bool,
+ default=True,
+ help="Disable to only see live diarization results.",
+ )
+
+ parser.add_argument(
+ "--min-chunk-size",
+ type=float,
+ default=0.5,
+ help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="tiny",
+ choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
+ ","
+ ),
+ help="Name size of the Whisper model to use (default: tiny). The model is automatically downloaded from the model hub if not present in model cache dir.",
+ )
+ parser.add_argument(
+ "--model_cache_dir",
+ type=str,
+ default=None,
+ help="Overriding the default model cache dir where models downloaded from the hub are saved",
+ )
+ parser.add_argument(
+ "--model_dir",
+ type=str,
+ default=None,
+ help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
+ )
+ parser.add_argument(
+ "--lan",
+ "--language",
+ type=str,
+ default="auto",
+ help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
+ )
+ parser.add_argument(
+ "--task",
+ type=str,
+ default="transcribe",
+ choices=["transcribe", "translate"],
+ help="Transcribe or translate.",
+ )
+ parser.add_argument(
+ "--backend",
+ type=str,
+ default="faster-whisper",
+ choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
+ help="Load only this backend for Whisper processing.",
+ )
+ parser.add_argument(
+ "--vac",
+ action="store_true",
+ default=False,
+ help="Use VAC = voice activity controller. Recommended. Requires torch.",
+ )
+ parser.add_argument(
+ "--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
+ )
+ parser.add_argument(
+ "--vad",
+ action="store_true",
+ default=True,
+ help="Use VAD = voice activity detection, with the default parameters.",
+ )
+ parser.add_argument(
+ "--buffer_trimming",
+ type=str,
+ default="segment",
+ choices=["sentence", "segment"],
+ help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.',
+ )
+ parser.add_argument(
+ "--buffer_trimming_sec",
+ type=float,
+ default=15,
+ help="Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.",
+ )
+ parser.add_argument(
+ "-l",
+ "--log-level",
+ dest="log_level",
+ choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
+ help="Set the log level",
+ default="DEBUG",
+ )
+
+ args = parser.parse_args()
+ return args
+
+class WhisperLiveKit:
+ _instance = None
+ _initialized = False
+
+ def __new__(cls, *args, **kwargs):
+ if cls._instance is None:
+ cls._instance = super().__new__(cls)
+ return cls._instance
+
+ def __init__(self, **kwargs):
+ if WhisperLiveKit._initialized:
+ return
+
+ default_args = vars(parse_args())
+
+ merged_args = {**default_args, **kwargs}
+
+ self.args = Namespace(**merged_args)
+
+ self.asr = None
+ self.tokenizer = None
+ self.diarization = None
+
+ if self.args.transcription:
+ self.asr, self.tokenizer = backend_factory(self.args)
+ warmup_asr(self.asr, self.args.warmup_file)
+
+ if self.args.diarization:
+ from whisperlivekit.diarization.diarization_online import DiartDiarization
+ self.diarization = DiartDiarization()
+
+ WhisperLiveKit._initialized = True
+
+ def web_interface(self):
+ import pkg_resources
+ html_path = pkg_resources.resource_filename('whisperlivekit', 'web/live_transcription.html')
+ with open(html_path, "r", encoding="utf-8") as f:
+ html = f.read()
+ return html
\ No newline at end of file
diff --git a/diarization/diarization_online.py b/whisperlivekit/diarization/diarization_online.py
similarity index 99%
rename from diarization/diarization_online.py
rename to whisperlivekit/diarization/diarization_online.py
index 622fb15..7db5003 100644
--- a/diarization/diarization_online.py
+++ b/whisperlivekit/diarization/diarization_online.py
@@ -8,7 +8,7 @@ import logging
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import StreamingInference
from diart.sources import AudioSource
-from timed_objects import SpeakerSegment
+from whisperlivekit.timed_objects import SpeakerSegment
from diart.sources import MicrophoneAudioSource
from rx.core import Observer
from typing import Tuple, Any, List
diff --git a/silero_vad_iterator.py b/whisperlivekit/silero_vad_iterator.py
similarity index 100%
rename from silero_vad_iterator.py
rename to whisperlivekit/silero_vad_iterator.py
diff --git a/timed_objects.py b/whisperlivekit/timed_objects.py
similarity index 100%
rename from timed_objects.py
rename to whisperlivekit/timed_objects.py
diff --git a/web/live_transcription.html b/whisperlivekit/web/live_transcription.html
similarity index 100%
rename from web/live_transcription.html
rename to whisperlivekit/web/live_transcription.html
diff --git a/whisper_streaming_custom/backends.py b/whisperlivekit/whisper_streaming_custom/backends.py
similarity index 99%
rename from whisper_streaming_custom/backends.py
rename to whisperlivekit/whisper_streaming_custom/backends.py
index fa52104..8f1090e 100644
--- a/whisper_streaming_custom/backends.py
+++ b/whisperlivekit/whisper_streaming_custom/backends.py
@@ -6,7 +6,7 @@ import math
import torch
from typing import List
import numpy as np
-from timed_objects import ASRToken
+from whisperlivekit.timed_objects import ASRToken
logger = logging.getLogger(__name__)
diff --git a/whisper_streaming_custom/online_asr.py b/whisperlivekit/whisper_streaming_custom/online_asr.py
similarity index 99%
rename from whisper_streaming_custom/online_asr.py
rename to whisperlivekit/whisper_streaming_custom/online_asr.py
index bc09395..2fd9de0 100644
--- a/whisper_streaming_custom/online_asr.py
+++ b/whisperlivekit/whisper_streaming_custom/online_asr.py
@@ -2,7 +2,7 @@ import sys
import numpy as np
import logging
from typing import List, Tuple, Optional
-from timed_objects import ASRToken, Sentence, Transcript
+from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
logger = logging.getLogger(__name__)
diff --git a/whisper_streaming_custom/whisper_online.py b/whisperlivekit/whisper_streaming_custom/whisper_online.py
similarity index 66%
rename from whisper_streaming_custom/whisper_online.py
rename to whisperlivekit/whisper_streaming_custom/whisper_online.py
index d7263ac..00287a9 100644
--- a/whisper_streaming_custom/whisper_online.py
+++ b/whisperlivekit/whisper_streaming_custom/whisper_online.py
@@ -64,95 +64,6 @@ def create_tokenizer(lan):
return WtPtok()
-def add_shared_args(parser):
- """shared args for simulation (this entry point) and server
- parser: argparse.ArgumentParser object
- """
- parser.add_argument(
- "--min-chunk-size",
- type=float,
- default=0.5,
- help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
- )
- parser.add_argument(
- "--model",
- type=str,
- default="tiny",
- choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
- ","
- ),
- help="Name size of the Whisper model to use (default: large-v2). The model is automatically downloaded from the model hub if not present in model cache dir.",
- )
- parser.add_argument(
- "--model_cache_dir",
- type=str,
- default=None,
- help="Overriding the default model cache dir where models downloaded from the hub are saved",
- )
- parser.add_argument(
- "--model_dir",
- type=str,
- default=None,
- help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
- )
- parser.add_argument(
- "--lan",
- "--language",
- type=str,
- default="auto",
- help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
- )
- parser.add_argument(
- "--task",
- type=str,
- default="transcribe",
- choices=["transcribe", "translate"],
- help="Transcribe or translate.",
- )
- parser.add_argument(
- "--backend",
- type=str,
- default="faster-whisper",
- choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
- help="Load only this backend for Whisper processing.",
- )
- parser.add_argument(
- "--vac",
- action="store_true",
- default=False,
- help="Use VAC = voice activity controller. Recommended. Requires torch.",
- )
- parser.add_argument(
- "--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
- )
- parser.add_argument(
- "--vad",
- action="store_true",
- default=True,
- help="Use VAD = voice activity detection, with the default parameters.",
- )
- parser.add_argument(
- "--buffer_trimming",
- type=str,
- default="segment",
- choices=["sentence", "segment"],
- help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.',
- )
- parser.add_argument(
- "--buffer_trimming_sec",
- type=float,
- default=15,
- help="Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.",
- )
- parser.add_argument(
- "-l",
- "--log-level",
- dest="log_level",
- choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
- help="Set the log level",
- default="DEBUG",
- )
-
def backend_factory(args):
backend = args.backend
if backend == "openai-api":