Mirror of https://github.com/QuentinFuxa/WhisperLiveKit.git (synced 2026-03-07 22:33:36 +00:00)
Refactor import statement for AudioProcessor and update cleanup method to be awaited; remove unused formatters and state management files
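
For callers, the change reduces to the two edits shown in the server hunks at the end of this page (illustrative summary, not part of the diff itself):

-from audio import AudioProcessor
+from audio_processor import AudioProcessor

-    audio_processor.cleanup()
+    await audio_processor.cleanup()  # cleanup is now a coroutine and must be awaited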

audio.py (335 lines removed)
@@ -1,335 +0,0 @@
import asyncio
import numpy as np
import ffmpeg
from time import time, sleep

from whisper_streaming_custom.whisper_online import online_factory
import math
import logging
import traceback
from state import SharedState
from formatters import format_time

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger().setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class AudioProcessor:

    def __init__(self, args, asr, tokenizer):
        self.args = args
        self.sample_rate = 16000
        self.channels = 1
        self.samples_per_sec = int(self.sample_rate * args.min_chunk_size)
        self.bytes_per_sample = 2
        self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
        self.max_bytes_per_sec = 32000 * 5  # 5 seconds of audio at 16 kHz, 16-bit mono (32000 bytes/s)

        self.shared_state = SharedState()
        self.asr = asr
        self.tokenizer = tokenizer

        self.ffmpeg_process = self.start_ffmpeg_decoder()

        self.transcription_queue = asyncio.Queue() if self.args.transcription else None
        self.diarization_queue = asyncio.Queue() if self.args.diarization else None

        self.pcm_buffer = bytearray()
        if self.args.transcription:
            self.online = online_factory(self.args, self.asr, self.tokenizer)

    def convert_pcm_to_float(self, pcm_buffer):
        """
        Convert a PCM buffer in s16le format to a normalized NumPy array.

        Args:
            pcm_buffer: PCM buffer containing raw audio data in s16le format.

        Returns:
            np.ndarray: float32 array normalized between -1.0 and 1.0.
        """
        pcm_array = (np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32)
                     / 32768.0)
        return pcm_array

    def start_ffmpeg_decoder(self):
        """
        Start an FFmpeg process in async streaming mode that reads WebM from stdin
        and outputs raw s16le PCM on stdout. Returns the process object.
        """
        process = (
            ffmpeg.input("pipe:0", format="webm")
            .output(
                "pipe:1",
                format="s16le",
                acodec="pcm_s16le",
                ac=self.channels,
                ar=str(self.sample_rate),
            )
            .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)
        )
        return process

    async def restart_ffmpeg(self):
        if self.ffmpeg_process:
            try:
                self.ffmpeg_process.kill()
                await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait)
            except Exception as e:
                logger.warning(f"Error killing FFmpeg process: {e}")
        self.ffmpeg_process = self.start_ffmpeg_decoder()
        self.pcm_buffer = bytearray()

    async def ffmpeg_stdout_reader(self):
        loop = asyncio.get_event_loop()
        beg = time()

        while True:
            try:
                elapsed_time = math.floor((time() - beg) * 10) / 10  # Round to 0.1 sec
                ffmpeg_buffer_from_duration = max(int(32000 * elapsed_time), 4096)
                beg = time()

                # Read chunk with timeout
                try:
                    chunk = await asyncio.wait_for(
                        loop.run_in_executor(
                            None, self.ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration
                        ),
                        timeout=15.0
                    )
                except asyncio.TimeoutError:
                    logger.warning("FFmpeg read timeout. Restarting...")
                    await self.restart_ffmpeg()
                    beg = time()
                    continue  # Skip processing and read from new process

                if not chunk:
                    logger.info("FFmpeg stdout closed.")
                    break
                self.pcm_buffer.extend(chunk)

                if self.args.diarization and self.diarization_queue:
                    await self.diarization_queue.put(self.convert_pcm_to_float(self.pcm_buffer).copy())

                if len(self.pcm_buffer) >= self.bytes_per_sec:
                    if len(self.pcm_buffer) > self.max_bytes_per_sec:
                        logger.warning(
                            f"""Audio buffer is too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f} seconds.
                            The model probably struggles to keep up. Consider using a smaller model.
                            """)

                    pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
                    self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]

                    if self.args.transcription and self.transcription_queue:
                        await self.transcription_queue.put(pcm_array.copy())

                if not self.args.transcription and not self.args.diarization:
                    await asyncio.sleep(0.1)

            except Exception as e:
                logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
                break
        logger.info("Exiting ffmpeg_stdout_reader...")

    async def transcription_processor(self):
        full_transcription = ""
        sep = self.online.asr.sep

        while True:
            try:
                pcm_array = await self.transcription_queue.get()

                logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio will be processed by the model.")

                # Process transcription
                self.online.insert_audio_chunk(pcm_array)
                new_tokens = self.online.process_iter()

                if new_tokens:
                    full_transcription += sep.join([t.text for t in new_tokens])

                _buffer = self.online.get_buffer()
                buffer = _buffer.text
                end_buffer = _buffer.end if _buffer.end else (new_tokens[-1].end if new_tokens else 0)

                if buffer in full_transcription:
                    buffer = ""

                await self.shared_state.update_transcription(
                    new_tokens, buffer, end_buffer, full_transcription, sep)

            except Exception as e:
                logger.warning(f"Exception in transcription_processor: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
            finally:
                self.transcription_queue.task_done()

    async def diarization_processor(self, diarization_obj):
        buffer_diarization = ""

        while True:
            try:
                pcm_array = await self.diarization_queue.get()

                # Process diarization
                await diarization_obj.diarize(pcm_array)

                # Get current state
                state = await self.shared_state.get_current_state()
                tokens = state["tokens"]
                end_attributed_speaker = state["end_attributed_speaker"]

                # Update speaker information
                new_end_attributed_speaker = diarization_obj.assign_speakers_to_tokens(
                    end_attributed_speaker, tokens)

                await self.shared_state.update_diarization(new_end_attributed_speaker, buffer_diarization)

            except Exception as e:
                logger.warning(f"Exception in diarization_processor: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
            finally:
                self.diarization_queue.task_done()

    async def results_formatter(self):
        while True:
            try:
                state = await self.shared_state.get_current_state()
                tokens = state["tokens"]
                buffer_transcription = state["buffer_transcription"]
                buffer_diarization = state["buffer_diarization"]
                end_attributed_speaker = state["end_attributed_speaker"]
                remaining_time_transcription = state["remaining_time_transcription"]
                remaining_time_diarization = state["remaining_time_diarization"]
                sep = state["sep"]

                # If diarization is enabled but no transcription, add dummy tokens periodically
                if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
                    await self.shared_state.add_dummy_token()
                    sleep(0.5)
                    state = await self.shared_state.get_current_state()
                    tokens = state["tokens"]
                previous_speaker = -1
                lines = []
                last_end_diarized = 0
                undiarized_text = []

                for token in tokens:
                    speaker = token.speaker
                    if self.args.diarization:
                        if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker:
                            undiarized_text.append(token.text)
                            continue
                        elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker:
                            speaker = previous_speaker
                        if speaker not in [-1, 0]:
                            last_end_diarized = max(token.end, last_end_diarized)

                    if speaker != previous_speaker or not lines:
                        lines.append(
                            {
                                "speaker": speaker,
                                "text": token.text,
                                "beg": format_time(token.start),
                                "end": format_time(token.end),
                                "diff": round(token.end - last_end_diarized, 2)
                            }
                        )
                        previous_speaker = speaker
                    elif token.text:  # Only append if text isn't empty
                        lines[-1]["text"] += sep + token.text
                        lines[-1]["end"] = format_time(token.end)
                        lines[-1]["diff"] = round(token.end - last_end_diarized, 2)

                if undiarized_text:
                    combined_buffer_diarization = sep.join(undiarized_text)
                    if buffer_transcription:
                        combined_buffer_diarization += sep
                    await self.shared_state.update_diarization(end_attributed_speaker, combined_buffer_diarization)
                    buffer_diarization = combined_buffer_diarization

                if lines:
                    response = {
                        "lines": lines,
                        "buffer_transcription": buffer_transcription,
                        "buffer_diarization": buffer_diarization,
                        "remaining_time_transcription": remaining_time_transcription,
                        "remaining_time_diarization": remaining_time_diarization
                    }
                else:
                    response = {
                        "lines": [{
                            "speaker": 1,
                            "text": "",
                            "beg": format_time(0),
                            "end": format_time(tokens[-1].end) if tokens else format_time(0),
                            "diff": 0
                        }],
                        "buffer_transcription": buffer_transcription,
                        "buffer_diarization": buffer_diarization,
                        "remaining_time_transcription": remaining_time_transcription,
                        "remaining_time_diarization": remaining_time_diarization
                    }

                response_content = ' '.join([str(line['speaker']) + ' ' + line["text"] for line in lines]) + ' | ' + buffer_transcription + ' | ' + buffer_diarization

                if response_content != self.shared_state.last_response_content:
                    if lines or buffer_transcription or buffer_diarization:
                        yield response
                        self.shared_state.last_response_content = response_content

                # Small delay to avoid overwhelming the client
                await asyncio.sleep(0.1)

            except Exception as e:
                logger.warning(f"Exception in results_formatter: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
                await asyncio.sleep(0.5)  # Back off on error

    async def create_tasks(self, diarization=None):
        if diarization:
            self.diarization = diarization

        tasks = []
        if self.args.transcription and self.online:
            tasks.append(asyncio.create_task(self.transcription_processor()))
        if self.args.diarization and self.diarization:
            tasks.append(asyncio.create_task(self.diarization_processor(self.diarization)))

        stdout_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
        tasks.append(stdout_reader_task)

        self.tasks = tasks

        return self.results_formatter()

    async def cleanup(self):
        for task in self.tasks:
            task.cancel()
        try:
            await asyncio.gather(*self.tasks, return_exceptions=True)
            self.ffmpeg_process.stdin.close()
            self.ffmpeg_process.wait()
        except Exception as e:
            logger.warning(f"Error during cleanup: {e}")
        if self.args.diarization and self.diarization:
            self.diarization.close()

    async def process_audio(self, message):
        try:
            self.ffmpeg_process.stdin.write(message)
            self.ffmpeg_process.stdin.flush()
        except (BrokenPipeError, AttributeError) as e:
            logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
            await self.restart_ffmpeg()
            self.ffmpeg_process.stdin.write(message)
            self.ffmpeg_process.stdin.flush()
audio_processor.py (new file, 406 lines)
@@ -0,0 +1,406 @@
import asyncio
import numpy as np
import ffmpeg
from time import time, sleep
import math
import logging
import traceback
from datetime import timedelta
from typing import List, Dict, Any
from timed_objects import ASRToken
from whisper_streaming_custom.whisper_online import online_factory

# Set up logging once
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def format_time(seconds: float) -> str:
    """Format seconds as H:MM:SS."""
    return str(timedelta(seconds=int(seconds)))
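
# Illustrative: format_time(3661.9) returns "1:01:01"; the int() cast drops
# fractional seconds and str(timedelta) does not zero-pad the hours field.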


class AudioProcessor:
    """
    Processes audio streams for transcription and diarization.
    Handles audio processing, state management, and result formatting in a single class.
    """

    def __init__(self, args, asr, tokenizer):
        """Initialize the audio processor with configuration, models, and state."""
        # Audio processing settings
        self.args = args
        self.sample_rate = 16000
        self.channels = 1
        self.samples_per_sec = int(self.sample_rate * args.min_chunk_size)
        self.bytes_per_sample = 2
        self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
        self.max_bytes_per_sec = 32000 * 5  # 5 seconds of audio at 16 kHz, 16-bit mono (32000 bytes/s)

        # State management
        self.tokens = []
        self.buffer_transcription = ""
        self.buffer_diarization = ""
        self.full_transcription = ""
        self.end_buffer = 0
        self.end_attributed_speaker = 0
        self.lock = asyncio.Lock()
        self.beg_loop = time()
        self.sep = " "  # Default separator
        self.last_response_content = ""

        # Models and processing
        self.asr = asr
        self.tokenizer = tokenizer
        self.ffmpeg_process = self.start_ffmpeg_decoder()
        self.transcription_queue = asyncio.Queue() if args.transcription else None
        self.diarization_queue = asyncio.Queue() if args.diarization else None
        self.pcm_buffer = bytearray()

        # Initialize transcription engine if enabled
        if args.transcription:
            self.online = online_factory(args, asr, tokenizer)

    def convert_pcm_to_float(self, pcm_buffer):
        """Convert a PCM buffer in s16le format to a normalized float32 NumPy array."""
        return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
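
    # Illustrative: the int16 extremes map to the edges of the float range;
    # e.g. np.frombuffer(b"\x00\x80", dtype=np.int16) is [-32768], giving -1.0 here.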

    def start_ffmpeg_decoder(self):
        """Start FFmpeg process for WebM to PCM conversion."""
        return (ffmpeg.input("pipe:0", format="webm")
                .output("pipe:1", format="s16le", acodec="pcm_s16le",
                        ac=self.channels, ar=str(self.sample_rate))
                .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True))
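
    # Roughly the command line this builds through ffmpeg-python (assumed flag mapping):
    #   ffmpeg -f webm -i pipe:0 -f s16le -acodec pcm_s16le -ac 1 -ar 16000 pipe:1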

    async def restart_ffmpeg(self):
        """Restart the FFmpeg process after failure."""
        if self.ffmpeg_process:
            try:
                self.ffmpeg_process.kill()
                await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait)
            except Exception as e:
                logger.warning(f"Error killing FFmpeg process: {e}")
        self.ffmpeg_process = self.start_ffmpeg_decoder()
        self.pcm_buffer = bytearray()

    async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep):
        """Thread-safe update of transcription with new data."""
        async with self.lock:
            self.tokens.extend(new_tokens)
            self.buffer_transcription = buffer
            self.end_buffer = end_buffer
            self.full_transcription = full_transcription
            self.sep = sep

    async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
        """Thread-safe update of diarization with new data."""
        async with self.lock:
            self.end_attributed_speaker = end_attributed_speaker
            if buffer_diarization:
                self.buffer_diarization = buffer_diarization

    async def add_dummy_token(self):
        """Placeholder token when no transcription is available."""
        async with self.lock:
            current_time = time() - self.beg_loop
            self.tokens.append(ASRToken(
                start=current_time, end=current_time + 1,
                text=".", speaker=-1, is_dummy=True
            ))

    async def get_current_state(self):
        """Get current state."""
        async with self.lock:
            current_time = time()

            # Calculate remaining times
            remaining_transcription = 0
            if self.end_buffer > 0:
                remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2))

            remaining_diarization = 0
            if self.tokens:
                latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
                remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2))

            return {
                "tokens": self.tokens.copy(),
                "buffer_transcription": self.buffer_transcription,
                "buffer_diarization": self.buffer_diarization,
                "end_buffer": self.end_buffer,
                "end_attributed_speaker": self.end_attributed_speaker,
                "sep": self.sep,
                "remaining_time_transcription": remaining_transcription,
                "remaining_time_diarization": remaining_diarization
            }

    async def reset(self):
        """Reset all state variables to initial values."""
        async with self.lock:
            self.tokens = []
            self.buffer_transcription = self.buffer_diarization = ""
            self.end_buffer = self.end_attributed_speaker = 0
            self.full_transcription = self.last_response_content = ""
            self.beg_loop = time()

    async def ffmpeg_stdout_reader(self):
        """Read audio data from FFmpeg stdout and process it."""
        loop = asyncio.get_event_loop()
        beg = time()

        while True:
            try:
                # Calculate buffer size based on elapsed time
                elapsed_time = math.floor((time() - beg) * 10) / 10  # Round to 0.1 sec
                buffer_size = max(int(32000 * elapsed_time), 4096)
                beg = time()

                # Read chunk with timeout
                try:
                    chunk = await asyncio.wait_for(
                        loop.run_in_executor(None, self.ffmpeg_process.stdout.read, buffer_size),
                        timeout=15.0
                    )
                except asyncio.TimeoutError:
                    logger.warning("FFmpeg read timeout. Restarting...")
                    await self.restart_ffmpeg()
                    beg = time()
                    continue

                if not chunk:
                    logger.info("FFmpeg stdout closed.")
                    break

                self.pcm_buffer.extend(chunk)

                # Send to diarization if enabled
                if self.args.diarization and self.diarization_queue:
                    await self.diarization_queue.put(
                        self.convert_pcm_to_float(self.pcm_buffer).copy()
                    )

                # Process when we have enough data
                if len(self.pcm_buffer) >= self.bytes_per_sec:
                    if len(self.pcm_buffer) > self.max_bytes_per_sec:
                        logger.warning(
                            f"Audio buffer too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f}s. "
                            f"Consider using a smaller model."
                        )

                    # Process audio chunk
                    pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
                    self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]

                    # Send to transcription if enabled
                    if self.args.transcription and self.transcription_queue:
                        await self.transcription_queue.put(pcm_array.copy())

                # Sleep if no processing is happening
                if not self.args.transcription and not self.args.diarization:
                    await asyncio.sleep(0.1)

            except Exception as e:
                logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
                break

    async def transcription_processor(self):
        """Process audio chunks for transcription."""
        self.full_transcription = ""
        self.sep = self.online.asr.sep

        while True:
            try:
                pcm_array = await self.transcription_queue.get()

                logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio to process.")

                # Process transcription
                self.online.insert_audio_chunk(pcm_array)
                new_tokens = self.online.process_iter()

                if new_tokens:
                    self.full_transcription += self.sep.join([t.text for t in new_tokens])

                # Get buffer information
                _buffer = self.online.get_buffer()
                buffer = _buffer.text
                end_buffer = _buffer.end if _buffer.end else (
                    new_tokens[-1].end if new_tokens else 0
                )

                # Avoid duplicating content
                if buffer in self.full_transcription:
                    buffer = ""

                await self.update_transcription(
                    new_tokens, buffer, end_buffer, self.full_transcription, self.sep
                )

            except Exception as e:
                logger.warning(f"Exception in transcription_processor: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
            finally:
                self.transcription_queue.task_done()

    async def diarization_processor(self, diarization_obj):
        """Process audio chunks for speaker diarization."""
        buffer_diarization = ""

        while True:
            try:
                pcm_array = await self.diarization_queue.get()

                # Process diarization
                await diarization_obj.diarize(pcm_array)

                # Get current state and update speakers
                state = await self.get_current_state()
                new_end = diarization_obj.assign_speakers_to_tokens(
                    state["end_attributed_speaker"], state["tokens"]
                )

                await self.update_diarization(new_end, buffer_diarization)

            except Exception as e:
                logger.warning(f"Exception in diarization_processor: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
            finally:
                self.diarization_queue.task_done()

    async def results_formatter(self):
        """Format processing results for output."""
        while True:
            try:
                # Get current state
                state = await self.get_current_state()
                tokens = state["tokens"]
                buffer_transcription = state["buffer_transcription"]
                buffer_diarization = state["buffer_diarization"]
                end_attributed_speaker = state["end_attributed_speaker"]
                sep = state["sep"]

                # Add dummy tokens if needed
                if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
                    await self.add_dummy_token()
                    sleep(0.5)
                    state = await self.get_current_state()
                    tokens = state["tokens"]

                # Format output
                previous_speaker = -1
                lines = []
                last_end_diarized = 0
                undiarized_text = []

                # Process each token
                for token in tokens:
                    speaker = token.speaker

                    # Handle diarization
                    if self.args.diarization:
                        if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
                            undiarized_text.append(token.text)
                            continue
                        elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
                            speaker = previous_speaker
                        if speaker not in [-1, 0]:
                            last_end_diarized = max(token.end, last_end_diarized)

                    # Group by speaker
                    if speaker != previous_speaker or not lines:
                        lines.append({
                            "speaker": speaker,
                            "text": token.text,
                            "beg": format_time(token.start),
                            "end": format_time(token.end),
                            "diff": round(token.end - last_end_diarized, 2)
                        })
                        previous_speaker = speaker
                    elif token.text:  # Only append if text isn't empty
                        lines[-1]["text"] += sep + token.text
                        lines[-1]["end"] = format_time(token.end)
                        lines[-1]["diff"] = round(token.end - last_end_diarized, 2)

                # Handle undiarized text
                if undiarized_text:
                    combined = sep.join(undiarized_text)
                    if buffer_transcription:
                        combined += sep
                    await self.update_diarization(end_attributed_speaker, combined)
                    buffer_diarization = combined

                # Create response object
                if not lines:
                    lines = [{
                        "speaker": 1,
                        "text": "",
                        "beg": format_time(0),
                        "end": format_time(tokens[-1].end if tokens else 0),
                        "diff": 0
                    }]

                response = {
                    "lines": lines,
                    "buffer_transcription": buffer_transcription,
                    "buffer_diarization": buffer_diarization,
                    "remaining_time_transcription": state["remaining_time_transcription"],
                    "remaining_time_diarization": state["remaining_time_diarization"]
                }

                # Only yield if content has changed
                response_content = ' '.join([f"{line['speaker']} {line['text']}" for line in lines]) + \
                    f" | {buffer_transcription} | {buffer_diarization}"

                if response_content != self.last_response_content and (lines or buffer_transcription or buffer_diarization):
                    yield response
                    self.last_response_content = response_content

                await asyncio.sleep(0.1)  # Avoid overwhelming the client

            except Exception as e:
                logger.warning(f"Exception in results_formatter: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
                await asyncio.sleep(0.5)  # Back off on error

    async def create_tasks(self, diarization=None):
        """Create and start processing tasks."""
        if diarization:
            self.diarization = diarization

        tasks = []
        if self.args.transcription and self.online:
            tasks.append(asyncio.create_task(self.transcription_processor()))

        if self.args.diarization and self.diarization:
            tasks.append(asyncio.create_task(self.diarization_processor(self.diarization)))

        tasks.append(asyncio.create_task(self.ffmpeg_stdout_reader()))
        self.tasks = tasks

        return self.results_formatter()

    async def cleanup(self):
        """Clean up resources when processing is complete."""
        for task in self.tasks:
            task.cancel()

        try:
            await asyncio.gather(*self.tasks, return_exceptions=True)
            self.ffmpeg_process.stdin.close()
            self.ffmpeg_process.wait()
        except Exception as e:
            logger.warning(f"Error during cleanup: {e}")

        if self.args.diarization and hasattr(self, 'diarization'):
            self.diarization.close()

    async def process_audio(self, message):
        """Process incoming audio data."""
        try:
            self.ffmpeg_process.stdin.write(message)
            self.ffmpeg_process.stdin.flush()
        except (BrokenPipeError, AttributeError) as e:
            logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
            await self.restart_ffmpeg()
            self.ffmpeg_process.stdin.write(message)
            self.ffmpeg_process.stdin.flush()
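
A minimal sketch of how a caller can drive the new class end to end (illustrative only; receive_chunks, args, asr, and tokenizer are hypothetical stand-ins for what the repo's server provides):

import asyncio
from audio_processor import AudioProcessor

async def run_session(args, asr, tokenizer, receive_chunks):
    processor = AudioProcessor(args, asr, tokenizer)
    results = await processor.create_tasks()  # returns the results_formatter async generator

    async def feed():
        # receive_chunks: hypothetical async iterator yielding WebM byte chunks
        async for chunk in receive_chunks:
            await processor.process_audio(chunk)

    feeder = asyncio.create_task(feed())
    try:
        async for response in results:  # dicts with "lines", buffers, and timing fields
            print(response["lines"])
    finally:
        feeder.cancel()
        await processor.cleanup()  # cleanup is a coroutine in this version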

formatters.py (91 lines removed)
@@ -1,91 +0,0 @@
from typing import Dict, Any, List
from datetime import timedelta


def format_time(seconds: float) -> str:
    """Format seconds as H:MM:SS."""
    return str(timedelta(seconds=int(seconds)))


def format_response(state: Dict[str, Any], with_diarization: bool = False) -> Dict[str, Any]:
    """
    Format the shared state into a client-friendly response.

    Args:
        state: Current shared state dictionary
        with_diarization: Whether to include diarization formatting

    Returns:
        Formatted response dictionary ready to send to client
    """
    tokens = state["tokens"]
    buffer_transcription = state["buffer_transcription"]
    buffer_diarization = state["buffer_diarization"]
    end_attributed_speaker = state["end_attributed_speaker"]
    remaining_time_transcription = state["remaining_time_transcription"]
    remaining_time_diarization = state["remaining_time_diarization"]
    sep = state["sep"]

    # Default response for empty state
    if not tokens:
        return {
            "lines": [{
                "speaker": 1,
                "text": "",
                "beg": format_time(0),
                "end": format_time(0),
                "diff": 0
            }],
            "buffer_transcription": buffer_transcription,
            "buffer_diarization": buffer_diarization,
            "remaining_time_transcription": remaining_time_transcription,
            "remaining_time_diarization": remaining_time_diarization
        }

    # Process tokens to create response
    previous_speaker = -1
    lines = []
    last_end_diarized = 0
    undiarized_text = []

    for token in tokens:
        speaker = token.speaker

        # Handle diarization logic
        if with_diarization:
            if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker:
                undiarized_text.append(token.text)
                continue
            elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker:
                speaker = previous_speaker

        if speaker not in [-1, 0]:
            last_end_diarized = max(token.end, last_end_diarized)

        # Add new line or append to existing line
        if speaker != previous_speaker or not lines:
            lines.append({
                "speaker": speaker,
                "text": token.text,
                "beg": format_time(token.start),
                "end": format_time(token.end),
                "diff": round(token.end - last_end_diarized, 2)
            })
            previous_speaker = speaker
        elif token.text:  # Only append if text isn't empty
            lines[-1]["text"] += sep + token.text
            lines[-1]["end"] = format_time(token.end)
            lines[-1]["diff"] = round(token.end - last_end_diarized, 2)

    # If we have undiarized text, include it in the buffer
    if undiarized_text:
        combined_buffer = sep.join(undiarized_text)
        if buffer_transcription:
            combined_buffer += sep + buffer_transcription
        buffer_diarization = combined_buffer

    return {
        "lines": lines,
        "buffer_transcription": buffer_transcription,
        "buffer_diarization": buffer_diarization,
        "remaining_time_transcription": remaining_time_transcription,
        "remaining_time_diarization": remaining_time_diarization
    }
state.py (96 lines removed)
@@ -1,96 +0,0 @@
import asyncio
import logging
from time import time
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
from timed_objects import ASRToken

logger = logging.getLogger(__name__)


class SharedState:
    """
    Thread-safe state manager for streaming transcription and diarization.
    Handles coordination between audio processing, transcription, and diarization.
    """

    def __init__(self):
        self.tokens: List[ASRToken] = []
        self.buffer_transcription: str = ""
        self.buffer_diarization: str = ""
        self.full_transcription: str = ""
        self.end_buffer: float = 0
        self.end_attributed_speaker: float = 0
        self.lock = asyncio.Lock()
        self.beg_loop: float = time()
        self.sep: str = " "  # Default separator
        self.last_response_content: str = ""  # To track changes in response

    async def update_transcription(self, new_tokens: List[ASRToken], buffer: str,
                                   end_buffer: float, full_transcription: str, sep: str) -> None:
        """Update the state with new transcription data."""
        async with self.lock:
            self.tokens.extend(new_tokens)
            self.buffer_transcription = buffer
            self.end_buffer = end_buffer
            self.full_transcription = full_transcription
            self.sep = sep

    async def update_diarization(self, end_attributed_speaker: float, buffer_diarization: str = "") -> None:
        """Update the state with new diarization data."""
        async with self.lock:
            self.end_attributed_speaker = end_attributed_speaker
            if buffer_diarization:
                self.buffer_diarization = buffer_diarization

    async def add_dummy_token(self) -> None:
        """Add a dummy token to keep the state updated even without transcription."""
        async with self.lock:
            current_time = time() - self.beg_loop
            dummy_token = ASRToken(
                start=current_time,
                end=current_time + 1,
                text=".",
                speaker=-1,
                is_dummy=True
            )
            self.tokens.append(dummy_token)

    async def get_current_state(self) -> Dict[str, Any]:
        """Get the current state with calculated timing information."""
        async with self.lock:
            current_time = time()
            remaining_time_transcription = 0
            remaining_time_diarization = 0

            # Calculate remaining time for transcription buffer
            if self.end_buffer > 0:
                remaining_time_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2))

            # Calculate remaining time for diarization
            if self.tokens:
                latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
                remaining_time_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2))

            return {
                "tokens": self.tokens.copy(),
                "buffer_transcription": self.buffer_transcription,
                "buffer_diarization": self.buffer_diarization,
                "end_buffer": self.end_buffer,
                "end_attributed_speaker": self.end_attributed_speaker,
                "sep": self.sep,
                "remaining_time_transcription": remaining_time_transcription,
                "remaining_time_diarization": remaining_time_diarization
            }

    async def reset(self) -> None:
        """Reset the state to initial values."""
        async with self.lock:
            self.tokens = []
            self.buffer_transcription = ""
            self.buffer_diarization = ""
            self.end_buffer = 0
            self.end_attributed_speaker = 0
            self.full_transcription = ""
            self.beg_loop = time()
            self.last_response_content = ""
@@ -8,7 +8,7 @@ from whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
 import asyncio
 import logging
 from parse_args import parse_args
-from audio import AudioProcessor
+from audio_processor import AudioProcessor
 
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logging.getLogger().setLevel(logging.WARNING)
@@ -80,7 +80,7 @@ async def websocket_endpoint(websocket: WebSocket):
         logger.warning("WebSocket disconnected.")
     finally:
         websocket_task.cancel()
-        audio_processor.cleanup()
+        await audio_processor.cleanup()
         logger.info("WebSocket endpoint cleaned up.")
 
 if __name__ == "__main__":