mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-21 16:40:35 +00:00
Refactor DiartDiarization initialization and streamline WebSocket audio processing
This commit is contained in:
136
audio.py
136
audio.py
@@ -1,25 +1,15 @@
|
||||
import io
|
||||
import argparse
|
||||
import asyncio
|
||||
import numpy as np
|
||||
import ffmpeg
|
||||
from time import time, sleep
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from whisper_streaming_custom.whisper_online import backend_factory, online_factory, add_shared_args, warmup_asr
|
||||
from timed_objects import ASRToken
|
||||
|
||||
from whisper_streaming_custom.whisper_online import online_factory
|
||||
import math
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
import traceback
|
||||
from state import SharedState
|
||||
from formatters import format_time
|
||||
from parse_args import parse_args
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
@@ -27,7 +17,6 @@ logging.getLogger().setLevel(logging.WARNING)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
class AudioProcessor:
|
||||
|
||||
def __init__(self, args, asr, tokenizer):
|
||||
@@ -38,9 +27,22 @@ class AudioProcessor:
|
||||
self.bytes_per_sample = 2
|
||||
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
|
||||
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
|
||||
|
||||
|
||||
self.shared_state = SharedState()
|
||||
self.asr = asr
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.ffmpeg_process = self.start_ffmpeg_decoder()
|
||||
|
||||
self.transcription_queue = asyncio.Queue() if self.args.transcription else None
|
||||
self.diarization_queue = asyncio.Queue() if self.args.diarization else None
|
||||
|
||||
self.pcm_buffer = bytearray()
|
||||
if self.args.transcription:
|
||||
self.online = online_factory(self.args, self.asr, self.tokenizer)
|
||||
|
||||
|
||||
|
||||
def convert_pcm_to_float(self, pcm_buffer):
|
||||
"""
|
||||
@@ -70,26 +72,17 @@ class AudioProcessor:
|
||||
)
|
||||
return process
|
||||
|
||||
async def restart_ffmpeg(self, ffmpeg_process, online, pcm_buffer):
|
||||
if ffmpeg_process:
|
||||
async def restart_ffmpeg(self):
|
||||
if self.ffmpeg_process:
|
||||
try:
|
||||
ffmpeg_process.kill()
|
||||
await asyncio.get_event_loop().run_in_executor(None, ffmpeg_process.wait)
|
||||
self.ffmpeg_process.kill()
|
||||
await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error killing FFmpeg process: {e}")
|
||||
ffmpeg_process = await self.start_ffmpeg_decoder()
|
||||
pcm_buffer = bytearray()
|
||||
|
||||
if self.args.transcription:
|
||||
online = online_factory(self.args, self.asr, self.tokenizer)
|
||||
|
||||
await self.shared_state.reset()
|
||||
logger.info("FFmpeg process started.")
|
||||
return ffmpeg_process, online, pcm_buffer
|
||||
self.ffmpeg_process = await self.start_ffmpeg_decoder()
|
||||
self.pcm_buffer = bytearray()
|
||||
|
||||
|
||||
|
||||
async def ffmpeg_stdout_reader(self, ffmpeg_process, pcm_buffer, diarization_queue, transcription_queue):
|
||||
async def ffmpeg_stdout_reader(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
beg = time()
|
||||
|
||||
@@ -103,36 +96,36 @@ class AudioProcessor:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(
|
||||
loop.run_in_executor(
|
||||
None, ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration
|
||||
None, self.ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration
|
||||
),
|
||||
timeout=15.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("FFmpeg read timeout. Restarting...")
|
||||
ffmpeg_process, online, pcm_buffer = await self.restart_ffmpeg(ffmpeg_process, online, pcm_buffer)
|
||||
await self.restart_ffmpeg()
|
||||
beg = time()
|
||||
continue # Skip processing and read from new process
|
||||
|
||||
if not chunk:
|
||||
logger.info("FFmpeg stdout closed.")
|
||||
break
|
||||
pcm_buffer.extend(chunk)
|
||||
self.pcm_buffer.extend(chunk)
|
||||
|
||||
if self.args.diarization and diarization_queue:
|
||||
await diarization_queue.put(self.convert_pcm_to_float(pcm_buffer).copy())
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(self.convert_pcm_to_float(self.pcm_buffer).copy())
|
||||
|
||||
if len(pcm_buffer) >= self.bytes_per_sec:
|
||||
if len(pcm_buffer) > self.max_bytes_per_sec:
|
||||
if len(self.pcm_buffer) >= self.bytes_per_sec:
|
||||
if len(self.pcm_buffer) > self.max_bytes_per_sec:
|
||||
logger.warning(
|
||||
f"""Audio buffer is too large: {len(pcm_buffer) / self.bytes_per_sec:.2f} seconds.
|
||||
f"""Audio buffer is too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f} seconds.
|
||||
The model probably struggles to keep up. Consider using a smaller model.
|
||||
""")
|
||||
|
||||
pcm_array = self.convert_pcm_to_float(pcm_buffer[:self.max_bytes_per_sec])
|
||||
pcm_buffer = pcm_buffer[self.max_bytes_per_sec:]
|
||||
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
|
||||
self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
|
||||
|
||||
if self.args.transcription and transcription_queue:
|
||||
await transcription_queue.put(pcm_array.copy())
|
||||
if self.args.transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(pcm_array.copy())
|
||||
|
||||
|
||||
if not self.args.transcription and not self.args.diarization:
|
||||
@@ -144,27 +137,24 @@ class AudioProcessor:
|
||||
break
|
||||
logger.info("Exiting ffmpeg_stdout_reader...")
|
||||
|
||||
|
||||
|
||||
|
||||
async def transcription_processor(self, pcm_queue, online):
|
||||
async def transcription_processor(self):
|
||||
full_transcription = ""
|
||||
sep = online.asr.sep
|
||||
sep = self.online.asr.sep
|
||||
|
||||
while True:
|
||||
try:
|
||||
pcm_array = await pcm_queue.get()
|
||||
pcm_array = await self.transcription_queue.get()
|
||||
|
||||
logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
|
||||
logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio will be processed by the model.")
|
||||
|
||||
# Process transcription
|
||||
online.insert_audio_chunk(pcm_array)
|
||||
new_tokens = online.process_iter()
|
||||
self.online.insert_audio_chunk(pcm_array)
|
||||
new_tokens = self.online.process_iter()
|
||||
|
||||
if new_tokens:
|
||||
full_transcription += sep.join([t.text for t in new_tokens])
|
||||
|
||||
_buffer = online.get_buffer()
|
||||
_buffer = self.online.get_buffer()
|
||||
buffer = _buffer.text
|
||||
end_buffer = _buffer.end if _buffer.end else (new_tokens[-1].end if new_tokens else 0)
|
||||
|
||||
@@ -178,14 +168,15 @@ class AudioProcessor:
|
||||
logger.warning(f"Exception in transcription_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
finally:
|
||||
pcm_queue.task_done()
|
||||
self.transcription_queue.task_done()
|
||||
|
||||
async def diarization_processor(self, pcm_queue, diarization_obj):
|
||||
|
||||
async def diarization_processor(self, diarization_obj):
|
||||
buffer_diarization = ""
|
||||
|
||||
while True:
|
||||
try:
|
||||
pcm_array = await pcm_queue.get()
|
||||
pcm_array = await self.diarization_queue.get()
|
||||
|
||||
# Process diarization
|
||||
await diarization_obj.diarize(pcm_array)
|
||||
@@ -205,7 +196,7 @@ class AudioProcessor:
|
||||
logger.warning(f"Exception in diarization_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
finally:
|
||||
pcm_queue.task_done()
|
||||
self.diarization_queue.task_done()
|
||||
|
||||
async def results_formatter(self, websocket):
|
||||
while True:
|
||||
@@ -304,3 +295,40 @@ class AudioProcessor:
|
||||
logger.warning(f"Exception in results_formatter: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
await asyncio.sleep(0.5) # Back off on error
|
||||
|
||||
async def create_tasks(self, websocket, diarization):
|
||||
tasks = []
|
||||
if self.args.transcription and self.online:
|
||||
tasks.append(asyncio.create_task(self.transcription_processor()))
|
||||
if self.args.diarization and diarization:
|
||||
tasks.append(asyncio.create_task(self.diarization_processor(diarization)))
|
||||
formatter_task = asyncio.create_task(self.results_formatter(websocket))
|
||||
tasks.append(formatter_task)
|
||||
stdout_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
|
||||
tasks.append(stdout_reader_task)
|
||||
self.tasks = tasks
|
||||
self.diarization = diarization
|
||||
|
||||
async def cleanup(self):
|
||||
for task in self.tasks:
|
||||
task.cancel()
|
||||
try:
|
||||
await asyncio.gather(*self.tasks, return_exceptions=True)
|
||||
self.ffmpeg_process.stdin.close()
|
||||
self.ffmpeg_process.wait()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during cleanup: {e}")
|
||||
if self.args.diarization and self.diarization:
|
||||
self.diarization.close()
|
||||
|
||||
async def process_audio(self, message):
|
||||
try:
|
||||
self.ffmpeg_process.stdin.write(message)
|
||||
self.ffmpeg_process.stdin.flush()
|
||||
except (BrokenPipeError, AttributeError) as e:
|
||||
logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
|
||||
await self.restart_ffmpeg()
|
||||
self.ffmpeg_process.stdin.write(message)
|
||||
self.ffmpeg_process.stdin.flush()
|
||||
|
||||
|
||||
@@ -103,7 +103,7 @@ class WebSocketAudioSource(AudioSource):
|
||||
|
||||
|
||||
class DiartDiarization:
|
||||
def __init__(self, sample_rate: int, config : SpeakerDiarizationConfig = None, use_microphone: bool = False):
|
||||
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False):
|
||||
self.pipeline = SpeakerDiarization(config=config)
|
||||
self.observer = DiarizationObserver()
|
||||
|
||||
|
||||
@@ -1,24 +1,11 @@
|
||||
import io
|
||||
import argparse
|
||||
import asyncio
|
||||
import numpy as np
|
||||
import ffmpeg
|
||||
from time import time, sleep
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from whisper_streaming_custom.whisper_online import backend_factory, online_factory, add_shared_args, warmup_asr
|
||||
from timed_objects import ASRToken
|
||||
|
||||
import math
|
||||
from whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
import traceback
|
||||
from state import SharedState
|
||||
from formatters import format_time
|
||||
from parse_args import parse_args
|
||||
from audio import AudioProcessor
|
||||
|
||||
@@ -27,19 +14,8 @@ logging.getLogger().setLevel(logging.WARNING)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
|
||||
args = parse_args()
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
# CHANNELS = 1
|
||||
# SAMPLES_PER_SEC = int(SAMPLE_RATE * args.min_chunk_size)
|
||||
# BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
|
||||
# BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
|
||||
# MAX_BYTES_PER_SEC = 32000 * 5 # 5 seconds of audio at 32 kHz
|
||||
|
||||
|
||||
##### LOAD APP #####
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@@ -52,7 +28,7 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
if args.diarization:
|
||||
from diarization.diarization_online import DiartDiarization
|
||||
diarization = DiartDiarization(SAMPLE_RATE)
|
||||
diarization = DiartDiarization()
|
||||
else :
|
||||
diarization = None
|
||||
yield
|
||||
@@ -75,66 +51,22 @@ with open("web/live_transcription.html", "r", encoding="utf-8") as f:
|
||||
async def get():
|
||||
return HTMLResponse(html)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@app.websocket("/asr")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
audio_processor = AudioProcessor(args, asr, tokenizer)
|
||||
|
||||
await websocket.accept()
|
||||
logger.info("WebSocket connection opened.")
|
||||
|
||||
ffmpeg_process = None
|
||||
pcm_buffer = bytearray()
|
||||
|
||||
transcription_queue = asyncio.Queue() if args.transcription else None
|
||||
diarization_queue = asyncio.Queue() if args.diarization else None
|
||||
|
||||
online = None
|
||||
|
||||
ffmpeg_process, online, pcm_buffer = await audio_processor.restart_ffmpeg(ffmpeg_process, online, pcm_buffer)
|
||||
tasks = []
|
||||
if args.transcription and online:
|
||||
tasks.append(asyncio.create_task(
|
||||
audio_processor.transcription_processor(transcription_queue, online)))
|
||||
if args.diarization and diarization:
|
||||
tasks.append(asyncio.create_task(
|
||||
audio_processor.diarization_processor(diarization_queue, diarization)))
|
||||
formatter_task = asyncio.create_task(audio_processor.results_formatter(websocket))
|
||||
tasks.append(formatter_task)
|
||||
stdout_reader_task = asyncio.create_task(audio_processor.ffmpeg_stdout_reader(ffmpeg_process, pcm_buffer, diarization_queue, transcription_queue))
|
||||
tasks.append(stdout_reader_task)
|
||||
|
||||
|
||||
await audio_processor.create_tasks(websocket, diarization)
|
||||
try:
|
||||
while True:
|
||||
# Receive incoming WebM audio chunks from the client
|
||||
message = await websocket.receive_bytes()
|
||||
try:
|
||||
ffmpeg_process.stdin.write(message)
|
||||
ffmpeg_process.stdin.flush()
|
||||
except (BrokenPipeError, AttributeError) as e:
|
||||
logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
|
||||
ffmpeg_process, online, pcm_buffer = await audio_processor.restart_ffmpeg(ffmpeg_process, online, pcm_buffer)
|
||||
ffmpeg_process.stdin.write(message)
|
||||
ffmpeg_process.stdin.flush()
|
||||
audio_processor.process_audio(message)
|
||||
except WebSocketDisconnect:
|
||||
logger.warning("WebSocket disconnected.")
|
||||
finally:
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
ffmpeg_process.stdin.close()
|
||||
ffmpeg_process.wait()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during cleanup: {e}")
|
||||
if args.diarization and diarization:
|
||||
diarization.close()
|
||||
audio_processor.cleanup()
|
||||
logger.info("WebSocket endpoint cleaned up.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user