VAC before doing transcription and diarization. V0

This commit is contained in:
Quentin Fuxa
2025-08-16 23:04:21 +02:00
parent e4221fa6c3
commit 28bdc52e1d
5 changed files with 135 additions and 15 deletions

View File

@@ -5,10 +5,12 @@ import math
import logging
import traceback
from datetime import timedelta
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.timed_objects import ASRToken, Silence
from whisperlivekit.core import TranscriptionEngine, online_factory
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from .remove_silences import handle_silences
from trail_repetition import trim_tail_repetition
from silero_vad_iterator import FixedVADIterator
# Set up logging once
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
@@ -48,6 +50,8 @@ class AudioProcessor:
# State management
self.is_stopping = False
self.silence = False
self.silence_duration = 0.0
self.tokens = []
self.buffer_transcription = ""
self.buffer_diarization = ""
@@ -62,7 +66,10 @@ class AudioProcessor:
self.asr = models.asr
self.tokenizer = models.tokenizer
self.diarization = models.diarization
import torch
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
self.vac = FixedVADIterator(model)
self.vac.reset_states()
self.ffmpeg_manager = FFmpegManager(
sample_rate=self.sample_rate,
channels=self.channels
@@ -98,6 +105,17 @@ class AudioProcessor:
"""Thread-safe update of transcription with new data."""
async with self.lock:
self.tokens.extend(new_tokens)
# self.tokens, has_been_trimmed = trim_tail_repetition(
# self.tokens,
# key=lambda t: t.text.strip().lower(),
# min_block=2, # avoid trimming single '.' loops; set to 1 if you want to remove those too
# max_tail=200,
# prefer="longest", # prefer removing the longest repeated phrase
# keep=1
# )
# if has_been_trimmed:
# print('HAS BEEN TRIMMED !')
self.buffer_transcription = buffer
self.end_buffer = end_buffer
self.sep = sep
@@ -200,19 +218,45 @@ class AudioProcessor:
# Process audio chunk
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
# Send to transcription if enabled
if self.args.transcription and self.transcription_queue:
await self.transcription_queue.put(pcm_array.copy())
res = self.vac(pcm_array)
# Send to diarization if enabled
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(pcm_array.copy())
end_of_audio = False
silence_buffer = None
if self.silence:
print('NO AUDIO')
if res is not None:
if res.get('end', 0) > res.get('start', 0):
end_of_audio = True
elif self.silence: #end of silence
self.silence = False
silence_buffer = Silence(duration=time() - self.start_silence)
if silence_buffer:
if self.args.transcription and self.transcription_queue:
await self.transcription_queue.put(silence_buffer)
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(silence_buffer)
if not self.silence:
if self.args.transcription and self.transcription_queue:
await self.transcription_queue.put(pcm_array.copy())
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(pcm_array.copy())
self.silence_duration = 0.0
if end_of_audio:
self.silence = True
self.start_silence = time()
# Sleep if no processing is happening
if not self.args.transcription and not self.args.diarization:
await asyncio.sleep(0.1)
except Exception as e:
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
@@ -239,8 +283,8 @@ class AudioProcessor:
while True:
try:
pcm_array = await self.transcription_queue.get()
if pcm_array is SENTINEL:
item = await self.transcription_queue.get()
if item is SENTINEL:
logger.debug("Transcription processor received sentinel. Finishing.")
self.transcription_queue.task_done()
break
@@ -258,11 +302,23 @@ class AudioProcessor:
f"lag={transcription_lag_s:.2f}s."
)
# Process transcription
duration_this_chunk = len(pcm_array) / self.sample_rate if isinstance(pcm_array, np.ndarray) else 0
if type(item) is Silence:
cumulative_pcm_duration_stream_time += item.duration
self.online.insert_silence(item.duration)
continue
if isinstance(item, np.ndarray):
pcm_array = item
else:
raise Exception('item should be pcm_array')
duration_this_chunk = len(pcm_array) / self.sample_rate
cumulative_pcm_duration_stream_time += duration_this_chunk
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
new_tokens, current_audio_processed_upto = self.online.process_iter()

View File

@@ -29,4 +29,8 @@ class SpeakerSegment(TimedText):
"""Represents a segment of audio attributed to a specific speaker.
No text nor probability is associated with this segment.
"""
pass
pass
@dataclass
class Silence:
    """Marker object standing in for a stretch of audio that contained no speech.

    It carries no samples, only how long the gap lasted.
    """

    # Length of the silent gap. NOTE(review): presumably seconds, since
    # consumers add it to a running stream-time total — confirm with callers.
    duration: float

View File

@@ -0,0 +1,60 @@
from typing import Sequence, Callable, Any, Optional, Dict
def _detect_tail_repetition(
seq: Sequence[Any],
key: Callable[[Any], Any] = lambda x: x, # extract comparable value
min_block: int = 1, # set to 2 to ignore 1-token loops like "."
max_tail: int = 300, # search window from the end for speed
prefer: str = "longest", # "longest" coverage or "smallest" block
) -> Optional[Dict]:
vals = [key(x) for x in seq][-max_tail:]
n = len(vals)
best = None
# try every possible block length
for b in range(min_block, n // 2 + 1):
block = vals[-b:]
# count how many times this block repeats contiguously at the very end
count, i = 0, n
while i - b >= 0 and vals[i - b:i] == block:
count += 1
i -= b
if count >= 2:
cand = {
"block_size": b,
"count": count,
"start_index": len(seq) - count * b, # in original seq
"end_index": len(seq),
}
if (best is None or
(prefer == "longest" and count * b > best["count"] * best["block_size"]) or
(prefer == "smallest" and b < best["block_size"])):
best = cand
return best
def trim_tail_repetition(
    seq: Sequence[Any],
    key: Callable[[Any], Any] = lambda x: x,
    min_block: int = 1,
    max_tail: int = 300,
    prefer: str = "longest",
    keep: int = 1,  # how many copies of the repeating block to keep at the end (0 or 1 are common)
):
    """
    Cut surplus copies of a block that repeats at the end of ``seq``.

    Detection is delegated to ``_detect_tail_repetition`` with the same
    ``key``, ``min_block``, ``max_tail`` and ``prefer`` arguments.

    keep=1 -> leave exactly one copy of the repeated block.
    keep=0 -> drop every copy of the repeated block.
    Negative ``keep`` is treated as 0.

    Returns a ``(sequence, trimmed)`` pair. When nothing is removed the
    original ``seq`` object is returned unchanged with ``trimmed=False``;
    otherwise a shortened slice of ``seq`` is returned with ``trimmed=True``.
    """
    found = _detect_tail_repetition(seq, key, min_block, max_tail, prefer)
    if not found:
        # no repeating tail at all
        return seq, False

    block_len = found["block_size"]
    copies = found["count"]
    retained = max(keep, 0)
    if retained >= copies:
        # the tail already holds no more copies than the caller wants to keep
        return seq, False

    # drop the surplus copies from the end, keeping everything before them
    surplus = (copies - retained) * block_len
    return seq[:len(seq) - surplus], True

View File

@@ -411,7 +411,7 @@ class VACOnlineASRProcessor:
# Load a VAD model (e.g. Silero VAD)
import torch
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
from .silero_vad_iterator import FixedVADIterator
from ..silero_vad_iterator import FixedVADIterator
self.vac = FixedVADIterator(model)
self.logfile = self.online.logfile