diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index c9ff1f7..fa454c2 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -5,10 +5,12 @@ import math import logging import traceback from datetime import timedelta -from whisperlivekit.timed_objects import ASRToken +from whisperlivekit.timed_objects import ASRToken, Silence from whisperlivekit.core import TranscriptionEngine, online_factory from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState from .remove_silences import handle_silences +from trail_repetition import trim_tail_repetition +from silero_vad_iterator import FixedVADIterator # Set up logging once logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) @@ -48,6 +50,8 @@ class AudioProcessor: # State management self.is_stopping = False + self.silence = False + self.silence_duration = 0.0 self.tokens = [] self.buffer_transcription = "" self.buffer_diarization = "" @@ -62,7 +66,10 @@ class AudioProcessor: self.asr = models.asr self.tokenizer = models.tokenizer self.diarization = models.diarization - + import torch + model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad") + self.vac = FixedVADIterator(model) + self.vac.reset_states() self.ffmpeg_manager = FFmpegManager( sample_rate=self.sample_rate, channels=self.channels @@ -98,6 +105,17 @@ class AudioProcessor: """Thread-safe update of transcription with new data.""" async with self.lock: self.tokens.extend(new_tokens) + + # self.tokens, has_been_trimmed = trim_tail_repetition( + # self.tokens, + # key=lambda t: t.text.strip().lower(), + # min_block=2, # avoid trimming single '.' loops; set to 1 if you want to remove those too + # max_tail=200, + # prefer="longest", # prefer removing the longest repeated phrase + # keep=1 + # ) + # if has_been_trimmed: + # print('HAS BEEN TRIMMED !') self.buffer_transcription = buffer self.end_buffer = end_buffer self.sep = sep @@ -200,19 +218,45 @@ class AudioProcessor: # Process audio chunk pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec]) self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:] - - # Send to transcription if enabled - if self.args.transcription and self.transcription_queue: - await self.transcription_queue.put(pcm_array.copy()) + res = self.vac(pcm_array) - # Send to diarization if enabled - if self.args.diarization and self.diarization_queue: - await self.diarization_queue.put(pcm_array.copy()) + end_of_audio = False + silence_buffer = None + + if self.silence: + print('NO AUDIO') + + if res is not None: + if res.get('end', 0) > res.get('start', 0): + end_of_audio = True + elif self.silence: #end of silence + self.silence = False + silence_buffer = Silence(duration=time() - self.start_silence) + + if silence_buffer: + if self.args.transcription and self.transcription_queue: + await self.transcription_queue.put(silence_buffer) + if self.args.diarization and self.diarization_queue: + await self.diarization_queue.put(silence_buffer) + + if not self.silence: + if self.args.transcription and self.transcription_queue: + await self.transcription_queue.put(pcm_array.copy()) + + if self.args.diarization and self.diarization_queue: + await self.diarization_queue.put(pcm_array.copy()) + + self.silence_duration = 0.0 + if end_of_audio: + self.silence = True + self.start_silence = time() # Sleep if no processing is happening if not self.args.transcription and not self.args.diarization: await asyncio.sleep(0.1) + + except Exception as e: logger.warning(f"Exception in ffmpeg_stdout_reader: {e}") logger.warning(f"Traceback: {traceback.format_exc()}") @@ -239,8 +283,8 @@ class AudioProcessor: while True: try: - pcm_array = await self.transcription_queue.get() - if pcm_array is SENTINEL: + item = await self.transcription_queue.get() + if item is SENTINEL: logger.debug("Transcription processor received sentinel. Finishing.") self.transcription_queue.task_done() break @@ -258,11 +302,23 @@ class AudioProcessor: f"lag={transcription_lag_s:.2f}s." ) - # Process transcription - duration_this_chunk = len(pcm_array) / self.sample_rate if isinstance(pcm_array, np.ndarray) else 0 + if type(item) is Silence: + cumulative_pcm_duration_stream_time += item.duration + self.online.insert_silence(item.duration) + continue + + if isinstance(item, np.ndarray): + pcm_array = item + else: + raise Exception('item should be pcm_array') + + duration_this_chunk = len(pcm_array) / self.sample_rate cumulative_pcm_duration_stream_time += duration_this_chunk stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time + + + self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm) new_tokens, current_audio_processed_upto = self.online.process_iter() diff --git a/whisperlivekit/whisper_streaming_custom/silero_vad_iterator.py b/whisperlivekit/silero_vad_iterator.py similarity index 100% rename from whisperlivekit/whisper_streaming_custom/silero_vad_iterator.py rename to whisperlivekit/silero_vad_iterator.py diff --git a/whisperlivekit/timed_objects.py b/whisperlivekit/timed_objects.py index 82fd1f6..c2a3619 100644 --- a/whisperlivekit/timed_objects.py +++ b/whisperlivekit/timed_objects.py @@ -29,4 +29,8 @@ class SpeakerSegment(TimedText): """Represents a segment of audio attributed to a specific speaker. No text nor probability is associated with this segment. """ - pass \ No newline at end of file + pass + +@dataclass +class Silence(): + duration: float \ No newline at end of file diff --git a/whisperlivekit/trail_repetition.py b/whisperlivekit/trail_repetition.py new file mode 100644 index 0000000..18d9f5e --- /dev/null +++ b/whisperlivekit/trail_repetition.py @@ -0,0 +1,60 @@ +from typing import Sequence, Callable, Any, Optional, Dict + +def _detect_tail_repetition( + seq: Sequence[Any], + key: Callable[[Any], Any] = lambda x: x, # extract comparable value + min_block: int = 1, # set to 2 to ignore 1-token loops like "." + max_tail: int = 300, # search window from the end for speed + prefer: str = "longest", # "longest" coverage or "smallest" block +) -> Optional[Dict]: + vals = [key(x) for x in seq][-max_tail:] + n = len(vals) + best = None + + # try every possible block length + for b in range(min_block, n // 2 + 1): + block = vals[-b:] + # count how many times this block repeats contiguously at the very end + count, i = 0, n + while i - b >= 0 and vals[i - b:i] == block: + count += 1 + i -= b + + if count >= 2: + cand = { + "block_size": b, + "count": count, + "start_index": len(seq) - count * b, # in original seq + "end_index": len(seq), + } + if (best is None or + (prefer == "longest" and count * b > best["count"] * best["block_size"]) or + (prefer == "smallest" and b < best["block_size"])): + best = cand + return best + +def trim_tail_repetition( + seq: Sequence[Any], + key: Callable[[Any], Any] = lambda x: x, + min_block: int = 1, + max_tail: int = 300, + prefer: str = "longest", + keep: int = 1, # how many copies of the repeating block to keep at the end (0 or 1 are common) +): + """ + Returns a new sequence with repeated tail trimmed. + keep=1 -> keep a single copy of the repeated block. + keep=0 -> remove all copies of the repeated block. + """ + rep = _detect_tail_repetition(seq, key, min_block, max_tail, prefer) + if not rep: + return seq, False # nothing to trim + + b, c = rep["block_size"], rep["count"] + if keep < 0: + keep = 0 + if keep >= c: + return seq, False # nothing to trim (already <= keep copies) + # new length = total - (copies_to_remove * block_size) + new_len = len(seq) - (c - keep) * b + return seq[:new_len], True \ No newline at end of file diff --git a/whisperlivekit/whisper_streaming_custom/online_asr.py b/whisperlivekit/whisper_streaming_custom/online_asr.py index 4fd6de8..d575035 100644 --- a/whisperlivekit/whisper_streaming_custom/online_asr.py +++ b/whisperlivekit/whisper_streaming_custom/online_asr.py @@ -411,7 +411,7 @@ class VACOnlineASRProcessor: # Load a VAD model (e.g. Silero VAD) import torch model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad") - from .silero_vad_iterator import FixedVADIterator + from ..silero_vad_iterator import FixedVADIterator self.vac = FixedVADIterator(model) self.logfile = self.online.logfile