Apply VAC (Silero VAD) to each audio chunk before doing transcription and diarization. V0

This commit is contained in:
Quentin Fuxa
2025-08-16 23:04:21 +02:00
parent e4221fa6c3
commit 28bdc52e1d
5 changed files with 135 additions and 15 deletions

View File

@@ -5,10 +5,12 @@ import math
import logging
import traceback
from datetime import timedelta
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.timed_objects import ASRToken, Silence
from whisperlivekit.core import TranscriptionEngine, online_factory
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from .remove_silences import handle_silences
from trail_repetition import trim_tail_repetition
from silero_vad_iterator import FixedVADIterator
# Set up logging once
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
@@ -48,6 +50,8 @@ class AudioProcessor:
# State management
self.is_stopping = False
self.silence = False
self.silence_duration = 0.0
self.tokens = []
self.buffer_transcription = ""
self.buffer_diarization = ""
@@ -62,7 +66,10 @@ class AudioProcessor:
self.asr = models.asr
self.tokenizer = models.tokenizer
self.diarization = models.diarization
import torch
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
self.vac = FixedVADIterator(model)
self.vac.reset_states()
self.ffmpeg_manager = FFmpegManager(
sample_rate=self.sample_rate,
channels=self.channels
@@ -98,6 +105,17 @@ class AudioProcessor:
"""Thread-safe update of transcription with new data."""
async with self.lock:
self.tokens.extend(new_tokens)
# self.tokens, has_been_trimmed = trim_tail_repetition(
# self.tokens,
# key=lambda t: t.text.strip().lower(),
# min_block=2, # avoid trimming single '.' loops; set to 1 if you want to remove those too
# max_tail=200,
# prefer="longest", # prefer removing the longest repeated phrase
# keep=1
# )
# if has_been_trimmed:
# print('HAS BEEN TRIMMED !')
self.buffer_transcription = buffer
self.end_buffer = end_buffer
self.sep = sep
@@ -200,19 +218,45 @@ class AudioProcessor:
# Process audio chunk
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
# Send to transcription if enabled
if self.args.transcription and self.transcription_queue:
await self.transcription_queue.put(pcm_array.copy())
res = self.vac(pcm_array)
# Send to diarization if enabled
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(pcm_array.copy())
end_of_audio = False
silence_buffer = None
if self.silence:
print('NO AUDIO')
if res is not None:
if res.get('end', 0) > res.get('start', 0):
end_of_audio = True
elif self.silence: #end of silence
self.silence = False
silence_buffer = Silence(duration=time() - self.start_silence)
if silence_buffer:
if self.args.transcription and self.transcription_queue:
await self.transcription_queue.put(silence_buffer)
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(silence_buffer)
if not self.silence:
if self.args.transcription and self.transcription_queue:
await self.transcription_queue.put(pcm_array.copy())
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(pcm_array.copy())
self.silence_duration = 0.0
if end_of_audio:
self.silence = True
self.start_silence = time()
# Sleep if no processing is happening
if not self.args.transcription and not self.args.diarization:
await asyncio.sleep(0.1)
except Exception as e:
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
@@ -239,8 +283,8 @@ class AudioProcessor:
while True:
try:
pcm_array = await self.transcription_queue.get()
if pcm_array is SENTINEL:
item = await self.transcription_queue.get()
if item is SENTINEL:
logger.debug("Transcription processor received sentinel. Finishing.")
self.transcription_queue.task_done()
break
@@ -258,11 +302,23 @@ class AudioProcessor:
f"lag={transcription_lag_s:.2f}s."
)
# Process transcription
duration_this_chunk = len(pcm_array) / self.sample_rate if isinstance(pcm_array, np.ndarray) else 0
if type(item) is Silence:
cumulative_pcm_duration_stream_time += item.duration
self.online.insert_silence(item.duration)
continue
if isinstance(item, np.ndarray):
pcm_array = item
else:
raise Exception('item should be pcm_array')
duration_this_chunk = len(pcm_array) / self.sample_rate
cumulative_pcm_duration_stream_time += duration_this_chunk
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
new_tokens, current_audio_processed_upto = self.online.process_iter()