diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index 37c6a44..3d43c03 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -9,6 +9,7 @@ import numpy as np from whisperlivekit.core import (TranscriptionEngine, online_diarization_factory, online_factory, online_translation_factory) +from whisperlivekit.metrics_collector import SessionMetrics from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState from whisperlivekit.silero_vad_iterator import FixedVADIterator, OnnxWrapper, load_jit_vad from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData, @@ -118,6 +119,7 @@ class AudioProcessor: self.translation_task: Optional[asyncio.Task] = None self.watchdog_task: Optional[asyncio.Task] = None self.all_tasks_for_cleanup: List[asyncio.Task] = [] + self.metrics: SessionMetrics = SessionMetrics() self.transcription: Optional[Any] = None self.translation: Optional[Any] = None @@ -139,25 +141,43 @@ class AudioProcessor: if self.translation_queue: await self.translation_queue.put(self.current_silence) - async def _begin_silence(self) -> None: + async def _begin_silence(self, at_sample: Optional[int] = None) -> None: if self.current_silence: return - now = time() - self.beg_loop + # Use audio stream time (sample-precise) for accurate silence duration + if at_sample is not None: + audio_t = at_sample / self.sample_rate + else: + audio_t = self.total_pcm_samples / self.sample_rate if self.sample_rate else 0.0 self.current_silence = Silence( - is_starting=True, start=now + is_starting=True, start=audio_t ) - await self._push_silence_event() + # Push a separate start-only event so _end_silence won't mutate it + start_event = Silence(is_starting=True, start=audio_t) + if self.transcription_queue: + await self.transcription_queue.put(start_event) + if self.args.diarization and self.diarization_queue: + await self.diarization_queue.put(start_event) + if self.translation_queue: + await 
self.translation_queue.put(start_event) - async def _end_silence(self) -> None: + async def _end_silence(self, at_sample: Optional[int] = None) -> None: if not self.current_silence: return - now = time() - self.beg_loop - self.current_silence.end = now - self.current_silence.is_starting=False - self.current_silence.has_ended=True + if at_sample is not None: + audio_t = at_sample / self.sample_rate + else: + audio_t = self.total_pcm_samples / self.sample_rate if self.sample_rate else 0.0 + self.current_silence.end = audio_t + self.current_silence.is_starting = False + self.current_silence.has_ended = True self.current_silence.compute_duration() + self.metrics.n_silence_events += 1 + if self.current_silence.duration is not None: + self.metrics.total_silence_duration_s += self.current_silence.duration if self.current_silence.duration > MIN_DURATION_REAL_SILENCE: self.state.new_tokens.append(self.current_silence) + # Push the completed silence as the end event (separate from the start event) await self._push_silence_event() self.current_silence = None @@ -253,6 +273,34 @@ class AudioProcessor: if self.translation: await self.translation_queue.put(SENTINEL) + async def _finish_transcription(self) -> None: + """Call finish() on the online processor to flush remaining tokens.""" + if not self.transcription: + return + try: + if hasattr(self.transcription, 'finish'): + final_tokens, end_time = await asyncio.to_thread(self.transcription.finish) + else: + # SimulStreamingOnlineProcessor uses start_silence() → process_iter(is_last=True) + final_tokens, end_time = await asyncio.to_thread(self.transcription.start_silence) + + final_tokens = final_tokens or [] + if final_tokens: + logger.info(f"Finish flushed {len(final_tokens)} tokens") + _buffer_transcript = self.transcription.get_buffer() + async with self.lock: + self.state.tokens.extend(final_tokens) + self.state.buffer_transcription = _buffer_transcript + self.state.end_buffer = max(self.state.end_buffer, end_time) + 
self.state.new_tokens.extend(final_tokens) + self.state.new_tokens_buffer = _buffer_transcript + if self.translation_queue: + for token in final_tokens: + await self.translation_queue.put(token) + except Exception as e: + logger.warning(f"Error finishing transcription: {e}") + logger.debug(f"Traceback: {traceback.format_exc()}") + async def transcription_processor(self) -> None: """Process audio chunks for transcription.""" cumulative_pcm_duration_stream_time = 0.0 @@ -263,6 +311,7 @@ class AudioProcessor: item = await get_all_from_queue(self.transcription_queue) if item is SENTINEL: logger.debug("Transcription processor received sentinel. Finishing.") + await self._finish_transcription() break asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE @@ -297,8 +346,13 @@ class AudioProcessor: cumulative_pcm_duration_stream_time += len(pcm_array) / self.sample_rate stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm) + _t0 = time() new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter) + _dur = time() - _t0 + self.metrics.transcription_durations.append(_dur) + self.metrics.n_transcription_calls += 1 new_tokens = new_tokens or [] + self.metrics.n_tokens_produced += len(new_tokens) _buffer_transcript = self.transcription.get_buffer() buffer_text = _buffer_transcript.text @@ -433,6 +487,7 @@ class AudioProcessor: should_push = (response != self.last_response_content) if should_push: + self.metrics.n_responses_sent += 1 yield response self.last_response_content = response @@ -535,6 +590,10 @@ class AudioProcessor: logger.warning(f"Error stopping FFmpeg manager: {e}") if self.diarization: self.diarization.close() + + # Finalize session metrics + self.metrics.total_audio_duration_s = self.total_pcm_samples / self.sample_rate + self.metrics.log_summary() 
logger.info("AudioProcessor cleanup complete.") def _processing_tasks_done(self) -> bool: @@ -553,6 +612,7 @@ class AudioProcessor: if not self.beg_loop: self.beg_loop = time() + self.metrics.session_start = self.beg_loop self.current_silence = Silence(start=0.0, is_starting=True) self.tokens_alignment.beg_loop = self.beg_loop @@ -560,6 +620,10 @@ class AudioProcessor: logger.info("Empty audio message received, initiating stop sequence.") self.is_stopping = True + # Flush any remaining PCM data before signaling end-of-stream + if self.is_pcm_input and self.pcm_buffer: + await self._flush_remaining_pcm() + if self.transcription_queue: await self.transcription_queue.put(SENTINEL) @@ -572,6 +636,8 @@ class AudioProcessor: logger.warning("AudioProcessor is stopping. Ignoring incoming audio.") return + self.metrics.n_chunks_received += 1 + if self.is_pcm_input: self.pcm_buffer.extend(message) await self.handle_pcm_data() @@ -588,6 +654,11 @@ class AudioProcessor: logger.warning("Failed to write audio data to FFmpeg") async def handle_pcm_data(self) -> None: + # Without VAC, there's no speech detector to end the initial silence. + # Clear it on the first audio chunk so audio actually gets enqueued. 
+ if not self.args.vac and self.current_silence: + await self._end_silence() + # Process when enough data if len(self.pcm_buffer) < self.bytes_per_sec: return @@ -616,7 +687,7 @@ class AudioProcessor: if res is not None: if "start" in res and self.current_silence: - await self._end_silence() + await self._end_silence(at_sample=res.get("start")) if "end" in res and not self.current_silence: pre_silence_chunk = self._slice_before_silence( @@ -624,7 +695,7 @@ class AudioProcessor: ) if pre_silence_chunk is not None and pre_silence_chunk.size > 0: await self._enqueue_active_audio(pre_silence_chunk) - await self._begin_silence() + await self._begin_silence(at_sample=res.get("end")) if not self.current_silence: await self._enqueue_active_audio(pcm_array) @@ -633,3 +704,21 @@ class AudioProcessor: if not self.args.transcription and not self.args.diarization: await asyncio.sleep(0.1) + + async def _flush_remaining_pcm(self) -> None: + """Flush whatever PCM data remains in the buffer, regardless of size threshold.""" + if not self.pcm_buffer: + return + aligned_size = (len(self.pcm_buffer) // self.bytes_per_sample) * self.bytes_per_sample + if aligned_size == 0: + return + pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_size]) + self.pcm_buffer = self.pcm_buffer[aligned_size:] + + # End any active silence so the audio gets enqueued + if self.current_silence: + await self._end_silence(at_sample=self.total_pcm_samples) + + await self._enqueue_active_audio(pcm_array) + self.total_pcm_samples += len(pcm_array) + logger.info(f"Flushed remaining PCM buffer: {len(pcm_array)} samples ({len(pcm_array)/self.sample_rate:.2f}s)") diff --git a/whisperlivekit/metrics.py b/whisperlivekit/metrics.py new file mode 100644 index 0000000..09e9c12 --- /dev/null +++ b/whisperlivekit/metrics.py @@ -0,0 +1,151 @@ +"""Lightweight ASR evaluation metrics — no external dependencies. 
"""Lightweight ASR evaluation metrics — no external dependencies.

Provides WER (Word Error Rate) via word-level Levenshtein distance,
text normalization, and word-level timestamp accuracy with greedy alignment.
"""

import re
import unicodedata
from typing import Dict, List


def normalize_text(text: str) -> str:
    """Normalize text for WER comparison.

    Lowercases, applies Unicode NFC normalization, replaces every character
    that is not a word character, whitespace, hyphen or apostrophe with a
    space, then collapses runs of whitespace.
    """
    text = text.lower()
    # Compose accented characters so visually-equal strings compare equal.
    text = unicodedata.normalize("NFC", text)
    # Strip punctuation; hyphens and apostrophes are kept wherever they occur.
    text = re.sub(r"[^\w\s\-']", " ", text)
    return re.sub(r"\s+", " ", text).strip()


def compute_wer(reference: str, hypothesis: str) -> Dict:
    """Compute Word Error Rate using word-level Levenshtein edit distance.

    Args:
        reference: Ground truth transcription.
        hypothesis: Predicted transcription.

    Returns:
        Dict with keys: wer, substitutions, insertions, deletions,
        ref_words, hyp_words. WER can exceed 1.0 if there are more errors
        than reference words. Convention for an empty reference: WER equals
        the hypothesis word count (0.0 when both are empty).
    """
    ref_words = normalize_text(reference).split()
    hyp_words = normalize_text(hypothesis).split()

    n = len(ref_words)
    m = len(hyp_words)

    if n == 0:
        return {
            "wer": 0.0 if m == 0 else float(m),
            "substitutions": 0,
            "insertions": m,
            "deletions": 0,
            "ref_words": 0,
            "hyp_words": m,
        }

    # Rolling two-row DP: each cell is (edit_distance, subs, ins, dels).
    # Only the previous reference row is needed, so memory is O(m) not O(n*m).
    prev = [(j, 0, j, 0) for j in range(m + 1)]  # empty ref prefix: j insertions

    for i in range(1, n + 1):
        curr: List = [(i, 0, 0, i)] + [None] * m  # empty hyp prefix: i deletions
        for j in range(1, m + 1):
            if ref_words[i - 1] == hyp_words[j - 1]:
                curr[j] = prev[j - 1]
            else:
                sub = prev[j - 1]
                ins = curr[j - 1]
                dele = prev[j]
                # Tie-break order (substitution, deletion, insertion) matches
                # the counts reported for equal-cost alignments.
                curr[j] = min(
                    (sub[0] + 1, sub[1] + 1, sub[2], sub[3]),
                    (dele[0] + 1, dele[1], dele[2], dele[3] + 1),
                    (ins[0] + 1, ins[1], ins[2] + 1, ins[3]),
                    key=lambda c: c[0],
                )
        prev = curr

    dist, subs, ins, dels = prev[m]
    return {
        "wer": dist / n,
        "substitutions": subs,
        "insertions": ins,
        "deletions": dels,
        "ref_words": n,
        "hyp_words": m,
    }


def compute_timestamp_accuracy(
    predicted: List[Dict],
    reference: List[Dict],
    max_skip: int = 2,
) -> Dict:
    """Compute timestamp accuracy by aligning predicted words to reference words.

    Uses greedy left-to-right alignment on normalized text. For each matched
    pair, computes the start-time delta (predicted - reference); the reported
    mae/max/median statistics are over the ABSOLUTE values of those deltas.

    Args:
        predicted: List of dicts with keys: word, start, end.
        reference: List of dicts with keys: word, start, end.
        max_skip: Maximum number of reference words the aligner may skip
            when searching for a match (default 2, the original behavior).

    Returns:
        Dict with keys: mae_start, max_delta_start, median_delta_start,
        n_matched, n_ref, n_pred. Returns None values if no matches found.
    """
    no_match = {
        "mae_start": None,
        "max_delta_start": None,
        "median_delta_start": None,
        "n_matched": 0,
        "n_ref": len(reference),
        "n_pred": len(predicted),
    }
    if not predicted or not reference:
        return no_match

    pred_norm = [normalize_text(p["word"]) for p in predicted]
    ref_norm = [normalize_text(r["word"]) for r in reference]

    # Greedy left-to-right alignment; each reference word is consumed at most once.
    deltas_start: List[float] = []
    ref_idx = 0
    for p_idx, p_word in enumerate(pred_norm):
        if not p_word:
            continue
        # Look ahead at most max_skip reference words for a match.
        search_limit = min(ref_idx + max_skip + 1, len(ref_norm))
        for r_idx in range(ref_idx, search_limit):
            if ref_norm[r_idx] == p_word:
                deltas_start.append(predicted[p_idx]["start"] - reference[r_idx]["start"])
                ref_idx = r_idx + 1
                break

    if not deltas_start:
        return no_match

    abs_deltas = sorted(abs(d) for d in deltas_start)
    mid = len(abs_deltas) // 2
    # True median: average the two middle elements for even-length samples.
    if len(abs_deltas) % 2:
        median_abs = abs_deltas[mid]
    else:
        median_abs = (abs_deltas[mid - 1] + abs_deltas[mid]) / 2

    return {
        "mae_start": sum(abs_deltas) / len(abs_deltas),
        "max_delta_start": abs_deltas[-1],
        "median_delta_start": median_abs,
        "n_matched": len(deltas_start),
        "n_ref": len(reference),
        "n_pred": len(predicted),
    }
import logging
import time
from dataclasses import dataclass, field
from typing import Dict, List

logger = logging.getLogger(__name__)


@dataclass
class SessionMetrics:
    """Runtime counters and timings for a single AudioProcessor session.

    Pure in-memory accumulation: integer increments and list appends during
    normal operation, aggregation only when a property or to_dict is read.
    """

    # Wall-clock epoch time when the session started; 0.0 = never started.
    session_start: float = 0.0
    total_audio_duration_s: float = 0.0
    total_processing_time_s: float = 0.0

    # Counters for received chunks, ASR calls, emitted tokens and responses.
    n_chunks_received: int = 0
    n_transcription_calls: int = 0
    n_tokens_produced: int = 0
    n_responses_sent: int = 0

    # Wall time of each individual ASR call, in seconds.
    transcription_durations: List[float] = field(default_factory=list)

    # Silence bookkeeping.
    n_silence_events: int = 0
    total_silence_duration_s: float = 0.0

    @property
    def rtf(self) -> float:
        """Real-time factor: processing time divided by audio duration (0.0 if no audio)."""
        audio = self.total_audio_duration_s
        if audio <= 0:
            return 0.0
        return self.total_processing_time_s / audio

    @property
    def avg_latency_ms(self) -> float:
        """Mean per-call ASR latency in milliseconds (0.0 when nothing recorded)."""
        samples = self.transcription_durations
        if not samples:
            return 0.0
        mean_s = sum(samples) / len(samples)
        return mean_s * 1000

    @property
    def p95_latency_ms(self) -> float:
        """95th-percentile per-call ASR latency in milliseconds (0.0 when empty)."""
        samples = self.transcription_durations
        if not samples:
            return 0.0
        ordered = sorted(samples)
        pos = min(int(len(ordered) * 0.95), len(ordered) - 1)
        return ordered[pos] * 1000

    def to_dict(self) -> Dict:
        """Snapshot every metric into a JSON-safe plain dict."""
        return {
            "session_start": self.session_start,
            "total_audio_duration_s": round(self.total_audio_duration_s, 3),
            "total_processing_time_s": round(self.total_processing_time_s, 3),
            "rtf": round(self.rtf, 3),
            "n_chunks_received": self.n_chunks_received,
            "n_transcription_calls": self.n_transcription_calls,
            "n_tokens_produced": self.n_tokens_produced,
            "n_responses_sent": self.n_responses_sent,
            "avg_latency_ms": round(self.avg_latency_ms, 2),
            "p95_latency_ms": round(self.p95_latency_ms, 2),
            "n_silence_events": self.n_silence_events,
            "total_silence_duration_s": round(self.total_silence_duration_s, 3),
        }

    def log_summary(self) -> None:
        """Finalize processing time from the wall clock and emit one summary log line."""
        wall = time.time() - self.session_start if self.session_start else 0
        self.total_processing_time_s = wall
        payload = self.to_dict()
        logger.info(f"SESSION_METRICS {payload}")