mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
VAC before doing transcription and diarization. V0
This commit is contained in:
@@ -5,10 +5,12 @@ import math
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import timedelta
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from whisperlivekit.timed_objects import ASRToken, Silence
|
||||
from whisperlivekit.core import TranscriptionEngine, online_factory
|
||||
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||
from .remove_silences import handle_silences
|
||||
from trail_repetition import trim_tail_repetition
|
||||
from silero_vad_iterator import FixedVADIterator
|
||||
# Set up logging once
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -48,6 +50,8 @@ class AudioProcessor:
|
||||
|
||||
# State management
|
||||
self.is_stopping = False
|
||||
self.silence = False
|
||||
self.silence_duration = 0.0
|
||||
self.tokens = []
|
||||
self.buffer_transcription = ""
|
||||
self.buffer_diarization = ""
|
||||
@@ -62,7 +66,10 @@ class AudioProcessor:
|
||||
self.asr = models.asr
|
||||
self.tokenizer = models.tokenizer
|
||||
self.diarization = models.diarization
|
||||
|
||||
import torch
|
||||
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
|
||||
self.vac = FixedVADIterator(model)
|
||||
self.vac.reset_states()
|
||||
self.ffmpeg_manager = FFmpegManager(
|
||||
sample_rate=self.sample_rate,
|
||||
channels=self.channels
|
||||
@@ -98,6 +105,17 @@ class AudioProcessor:
|
||||
"""Thread-safe update of transcription with new data."""
|
||||
async with self.lock:
|
||||
self.tokens.extend(new_tokens)
|
||||
|
||||
# self.tokens, has_been_trimmed = trim_tail_repetition(
|
||||
# self.tokens,
|
||||
# key=lambda t: t.text.strip().lower(),
|
||||
# min_block=2, # avoid trimming single '.' loops; set to 1 if you want to remove those too
|
||||
# max_tail=200,
|
||||
# prefer="longest", # prefer removing the longest repeated phrase
|
||||
# keep=1
|
||||
# )
|
||||
# if has_been_trimmed:
|
||||
# print('HAS BEEN TRIMMED !')
|
||||
self.buffer_transcription = buffer
|
||||
self.end_buffer = end_buffer
|
||||
self.sep = sep
|
||||
@@ -200,19 +218,45 @@ class AudioProcessor:
|
||||
# Process audio chunk
|
||||
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
|
||||
self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
|
||||
|
||||
# Send to transcription if enabled
|
||||
if self.args.transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(pcm_array.copy())
|
||||
res = self.vac(pcm_array)
|
||||
|
||||
# Send to diarization if enabled
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(pcm_array.copy())
|
||||
end_of_audio = False
|
||||
silence_buffer = None
|
||||
|
||||
if self.silence:
|
||||
print('NO AUDIO')
|
||||
|
||||
if res is not None:
|
||||
if res.get('end', 0) > res.get('start', 0):
|
||||
end_of_audio = True
|
||||
elif self.silence: #end of silence
|
||||
self.silence = False
|
||||
silence_buffer = Silence(duration=time() - self.start_silence)
|
||||
|
||||
if silence_buffer:
|
||||
if self.args.transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(silence_buffer)
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(silence_buffer)
|
||||
|
||||
if not self.silence:
|
||||
if self.args.transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(pcm_array.copy())
|
||||
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(pcm_array.copy())
|
||||
|
||||
self.silence_duration = 0.0
|
||||
if end_of_audio:
|
||||
self.silence = True
|
||||
self.start_silence = time()
|
||||
|
||||
# Sleep if no processing is happening
|
||||
if not self.args.transcription and not self.args.diarization:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
@@ -239,8 +283,8 @@ class AudioProcessor:
|
||||
|
||||
while True:
|
||||
try:
|
||||
pcm_array = await self.transcription_queue.get()
|
||||
if pcm_array is SENTINEL:
|
||||
item = await self.transcription_queue.get()
|
||||
if item is SENTINEL:
|
||||
logger.debug("Transcription processor received sentinel. Finishing.")
|
||||
self.transcription_queue.task_done()
|
||||
break
|
||||
@@ -258,11 +302,23 @@ class AudioProcessor:
|
||||
f"lag={transcription_lag_s:.2f}s."
|
||||
)
|
||||
|
||||
# Process transcription
|
||||
duration_this_chunk = len(pcm_array) / self.sample_rate if isinstance(pcm_array, np.ndarray) else 0
|
||||
if type(item) is Silence:
|
||||
cumulative_pcm_duration_stream_time += item.duration
|
||||
self.online.insert_silence(item.duration)
|
||||
continue
|
||||
|
||||
if isinstance(item, np.ndarray):
|
||||
pcm_array = item
|
||||
else:
|
||||
raise Exception('item should be pcm_array')
|
||||
|
||||
duration_this_chunk = len(pcm_array) / self.sample_rate
|
||||
cumulative_pcm_duration_stream_time += duration_this_chunk
|
||||
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
|
||||
|
||||
|
||||
|
||||
|
||||
self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
|
||||
new_tokens, current_audio_processed_upto = self.online.process_iter()
|
||||
|
||||
|
||||
@@ -29,4 +29,8 @@ class SpeakerSegment(TimedText):
|
||||
"""Represents a segment of audio attributed to a specific speaker.
|
||||
No text nor probability is associated with this segment.
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class Silence():
|
||||
duration: float
|
||||
60
whisperlivekit/trail_repetition.py
Normal file
60
whisperlivekit/trail_repetition.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from typing import Sequence, Callable, Any, Optional, Dict
|
||||
|
||||
def _detect_tail_repetition(
|
||||
seq: Sequence[Any],
|
||||
key: Callable[[Any], Any] = lambda x: x, # extract comparable value
|
||||
min_block: int = 1, # set to 2 to ignore 1-token loops like "."
|
||||
max_tail: int = 300, # search window from the end for speed
|
||||
prefer: str = "longest", # "longest" coverage or "smallest" block
|
||||
) -> Optional[Dict]:
|
||||
vals = [key(x) for x in seq][-max_tail:]
|
||||
n = len(vals)
|
||||
best = None
|
||||
|
||||
# try every possible block length
|
||||
for b in range(min_block, n // 2 + 1):
|
||||
block = vals[-b:]
|
||||
# count how many times this block repeats contiguously at the very end
|
||||
count, i = 0, n
|
||||
while i - b >= 0 and vals[i - b:i] == block:
|
||||
count += 1
|
||||
i -= b
|
||||
|
||||
if count >= 2:
|
||||
cand = {
|
||||
"block_size": b,
|
||||
"count": count,
|
||||
"start_index": len(seq) - count * b, # in original seq
|
||||
"end_index": len(seq),
|
||||
}
|
||||
if (best is None or
|
||||
(prefer == "longest" and count * b > best["count"] * best["block_size"]) or
|
||||
(prefer == "smallest" and b < best["block_size"])):
|
||||
best = cand
|
||||
return best
|
||||
|
||||
def trim_tail_repetition(
|
||||
seq: Sequence[Any],
|
||||
key: Callable[[Any], Any] = lambda x: x,
|
||||
min_block: int = 1,
|
||||
max_tail: int = 300,
|
||||
prefer: str = "longest",
|
||||
keep: int = 1, # how many copies of the repeating block to keep at the end (0 or 1 are common)
|
||||
):
|
||||
"""
|
||||
Returns a new sequence with repeated tail trimmed.
|
||||
keep=1 -> keep a single copy of the repeated block.
|
||||
keep=0 -> remove all copies of the repeated block.
|
||||
"""
|
||||
rep = _detect_tail_repetition(seq, key, min_block, max_tail, prefer)
|
||||
if not rep:
|
||||
return seq, False # nothing to trim
|
||||
|
||||
b, c = rep["block_size"], rep["count"]
|
||||
if keep < 0:
|
||||
keep = 0
|
||||
if keep >= c:
|
||||
return seq, False # nothing to trim (already <= keep copies)
|
||||
# new length = total - (copies_to_remove * block_size)
|
||||
new_len = len(seq) - (c - keep) * b
|
||||
return seq[:new_len], True
|
||||
@@ -411,7 +411,7 @@ class VACOnlineASRProcessor:
|
||||
# Load a VAD model (e.g. Silero VAD)
|
||||
import torch
|
||||
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
|
||||
from .silero_vad_iterator import FixedVADIterator
|
||||
from ..silero_vad_iterator import FixedVADIterator
|
||||
|
||||
self.vac = FixedVADIterator(model)
|
||||
self.logfile = self.online.logfile
|
||||
|
||||
Reference in New Issue
Block a user