From 5a12c627b474f9af84f2acb89d38a8affbe1714e Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Sun, 22 Feb 2026 23:27:40 +0100 Subject: [PATCH] feat: add 99-test unit test suite with zero model dependencies Test suite covering: - metrics.py: WER computation, timestamp accuracy, text normalization - config.py: defaults, .en model detection, policy aliases, from_namespace - timed_objects.py: ASRToken, Silence, Transcript, Segment, FrontData - hypothesis_buffer.py: insert, flush, LCP matching, pop_committed - silence_handling.py: state machine, double-counting regression test - audio_processor.py: async pipeline with MockOnlineProcessor All tests run in ~1.3s without downloading any ASR models. Add pytest and pytest-asyncio as optional test dependencies. Update .gitignore to allow tests/ directory. --- .gitignore | 6 +- pyproject.toml | 4 +- tests/__init__.py | 0 tests/conftest.py | 58 +++++++++ tests/test_audio_processor.py | 209 ++++++++++++++++++++++++++++++++ tests/test_config.py | 99 +++++++++++++++ tests/test_hypothesis_buffer.py | 172 ++++++++++++++++++++++++++ tests/test_metrics.py | 147 ++++++++++++++++++++++ tests/test_silence_handling.py | 99 +++++++++++++++ tests/test_timed_objects.py | 185 ++++++++++++++++++++++++++++ 10 files changed, 976 insertions(+), 3 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_audio_processor.py create mode 100644 tests/test_config.py create mode 100644 tests/test_hypothesis_buffer.py create mode 100644 tests/test_metrics.py create mode 100644 tests/test_silence_handling.py create mode 100644 tests/test_timed_objects.py diff --git a/.gitignore b/.gitignore index a015198..ecfdcd4 100644 --- a/.gitignore +++ b/.gitignore @@ -119,9 +119,11 @@ run_*.sh *.pt # Debug & testing -test_*.py +/test_*.py +!test_backend_offline.py launch.json .DS_Store -test/* +/test/ +!tests/ nllb-200-distilled-600M-ctranslate2/* *.mp3 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 74ade12..9a79780 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "whisperlivekit" -version = "0.2.18" +version = "0.2.19" description = "Real-time speech-to-text with speaker diarization using Whisper" readme = "README.md" authors = [ @@ -42,6 +42,7 @@ dependencies = [ ] [project.optional-dependencies] +test = ["pytest>=7.0", "pytest-asyncio>=0.21"] translation = ["nllw"] sentence_tokenizer = ["mosestokenizer", "wtpsplit"] voxtral-hf = ["transformers>=5.2.0", "mistral-common[audio]"] @@ -64,6 +65,7 @@ packages = [ "whisperlivekit.whisper.normalizers", "whisperlivekit.web", "whisperlivekit.local_agreement", + "whisperlivekit.voxtral_mlx", "whisperlivekit.silero_vad_models" ] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1a26f33 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,58 @@ +"""Shared pytest fixtures for WhisperLiveKit tests.""" + +import json +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from whisperlivekit.timed_objects import ASRToken, Silence, Transcript + + +AUDIO_TESTS_DIR = Path(__file__).parent.parent / "audio_tests" + + +@pytest.fixture +def sample_tokens(): + """A short sequence of ASRToken objects.""" + return [ + ASRToken(start=0.0, end=0.5, text="Hello"), + ASRToken(start=0.5, end=1.0, text=" world"), + ASRToken(start=1.0, end=1.5, text=" test."), + ] + + +@pytest.fixture +def sample_silence(): + """A completed silence event.""" + s = Silence(start=1.5, end=3.0, is_starting=False, has_ended=True) + s.compute_duration() + return s + + +@pytest.fixture +def mock_args(): + """Minimal args namespace for AudioProcessor tests.""" + return SimpleNamespace( + diarization=False, + transcription=True, + target_language="", + vac=False, + vac_chunk_size=0.04, + min_chunk_size=0.1, + pcm_input=True, + punctuation_split=False, + backend="faster-whisper", + backend_policy="localagreement", + vad=True, + ) + + +@pytest.fixture +def ground_truth_en(): + """Ground truth transcript for the 7s English audio (if available).""" + path = AUDIO_TESTS_DIR / "00_00_07_english_1_speaker.transcript.json" + if path.exists(): + with open(path) as f: + return json.load(f) + return None diff --git a/tests/test_audio_processor.py b/tests/test_audio_processor.py new file mode 100644 index 0000000..9286108 --- /dev/null +++ b/tests/test_audio_processor.py @@ -0,0 +1,209 @@ +"""Tests for AudioProcessor pipeline with mocked ASR backends. + +These tests verify the async audio processing pipeline works correctly +without requiring any real ASR models to be loaded. +""" + +import asyncio +from types import SimpleNamespace +from unittest.mock import patch + +import numpy as np +import pytest + +from whisperlivekit.timed_objects import ASRToken, Transcript + + +# --------------------------------------------------------------------------- +# Mock ASR components +# --------------------------------------------------------------------------- + +class MockASR: + """Mock ASR model holder.""" + sep = " " + SAMPLING_RATE = 16000 + + def __init__(self): + self.transcribe_kargs = {} + self.original_language = "en" + self.backend_choice = "mock" + + def transcribe(self, audio): + return None + + +class MockOnlineProcessor: + """Mock online processor that returns canned tokens.""" + SAMPLING_RATE = 16000 + + def __init__(self, asr=None): + self.asr = asr or MockASR() + self.audio_buffer = np.array([], dtype=np.float32) + self.end = 0.0 + self._call_count = 0 + self._finished = False + + def insert_audio_chunk(self, audio, audio_stream_end_time): + self.audio_buffer = np.append(self.audio_buffer, audio) + self.end = audio_stream_end_time + + def process_iter(self, is_last=False): + self._call_count += 1 + # Emit a token on every call when we have audio + if len(self.audio_buffer) > 0: + t = self._call_count * 0.5 + return [ASRToken(start=t, end=t + 0.5, text=f"word{self._call_count}")], self.end + return [], self.end + + def get_buffer(self): + return Transcript(start=None, end=None, text="") + + def start_silence(self): + return [], self.end + + def end_silence(self, silence_duration, offset): + pass + + def new_speaker(self, change_speaker): + pass + + def finish(self): + self._finished = True + return [], self.end + + def warmup(self, audio, init_prompt=""): + pass + + +def _make_pcm_bytes(duration_s=0.1, sample_rate=16000): + """Generate silent PCM s16le bytes.""" + n_samples = int(duration_s * sample_rate) + audio = np.zeros(n_samples, dtype=np.float32) + return (audio * 32768).clip(-32768, 32767).astype(np.int16).tobytes() + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def mock_engine(): + """Create a mock TranscriptionEngine-like object.""" + engine = SimpleNamespace( + asr=MockASR(), + diarization_model=None, + translation_model=None, + args=SimpleNamespace( + diarization=False, + transcription=True, + target_language="", + vac=False, + vac_chunk_size=0.04, + min_chunk_size=0.1, + pcm_input=True, + punctuation_split=False, + backend="mock", + backend_policy="localagreement", + vad=True, + model_size="base", + lan="en", + ), + ) + return engine + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestPCMConversion: + """Test PCM byte conversion without needing the full pipeline.""" + + def test_s16le_roundtrip(self): + """Convert float32 → s16le → float32 and verify approximate roundtrip.""" + original = np.array([0.0, 0.5, -0.5, 1.0, -1.0], dtype=np.float32) + s16 = (original * 32768).clip(-32768, 32767).astype(np.int16) + pcm_bytes = s16.tobytes() + # Direct numpy conversion (same logic as AudioProcessor.convert_pcm_to_float) + recovered = np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float32) / 32768.0 + + np.testing.assert_allclose(recovered, original, atol=1 / 32768) + + +@pytest.mark.asyncio +class TestPipelineBasics: + async def test_feed_audio_and_get_responses(self, mock_engine): + """Feed audio through the pipeline and verify we get responses.""" + from whisperlivekit.audio_processor import AudioProcessor + + with patch("whisperlivekit.audio_processor.online_factory", return_value=MockOnlineProcessor()): + processor = AudioProcessor(transcription_engine=mock_engine) + results_gen = await processor.create_tasks() + + responses = [] + + async def collect(): + async for resp in results_gen: + responses.append(resp) + + task = asyncio.create_task(collect()) + + # Feed 2 seconds of audio in 100ms chunks + for _ in range(20): + await processor.process_audio(_make_pcm_bytes(0.1)) + + # Signal EOF + await processor.process_audio(None) + + await asyncio.wait_for(task, timeout=10.0) + await processor.cleanup() + + # We should have gotten at least one response + assert len(responses) > 0 + + async def test_eof_terminates_pipeline(self, mock_engine): + """Sending None (EOF) should cleanly terminate the pipeline.""" + from whisperlivekit.audio_processor import AudioProcessor + + with patch("whisperlivekit.audio_processor.online_factory", return_value=MockOnlineProcessor()): + processor = AudioProcessor(transcription_engine=mock_engine) + results_gen = await processor.create_tasks() + + responses = [] + + async def collect(): + async for resp in results_gen: + responses.append(resp) + + task = asyncio.create_task(collect()) + + # Send a small amount of audio then EOF + await processor.process_audio(_make_pcm_bytes(0.5)) + await processor.process_audio(None) + + await asyncio.wait_for(task, timeout=10.0) + await processor.cleanup() + + # Pipeline should have terminated without error + assert task.done() + + async def test_empty_audio_no_crash(self, mock_engine): + """Sending EOF immediately (no audio) should not crash.""" + from whisperlivekit.audio_processor import AudioProcessor + + with patch("whisperlivekit.audio_processor.online_factory", return_value=MockOnlineProcessor()): + processor = AudioProcessor(transcription_engine=mock_engine) + results_gen = await processor.create_tasks() + + responses = [] + + async def collect(): + async for resp in results_gen: + responses.append(resp) + + task = asyncio.create_task(collect()) + await processor.process_audio(None) + + await asyncio.wait_for(task, timeout=10.0) + await processor.cleanup() + assert task.done() diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..23f4c56 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,99 @@ +"""Tests for WhisperLiveKitConfig.""" + +import logging +from types import SimpleNamespace + +import pytest + +from whisperlivekit.config import WhisperLiveKitConfig + + +class TestDefaults: + def test_default_backend(self): + c = WhisperLiveKitConfig() + assert c.backend == "auto" + + def test_default_policy(self): + c = WhisperLiveKitConfig() + assert c.backend_policy == "simulstreaming" + + def test_default_language(self): + c = WhisperLiveKitConfig() + assert c.lan == "auto" + + def test_default_vac(self): + c = WhisperLiveKitConfig() + assert c.vac is True + + def test_default_model_size(self): + c = WhisperLiveKitConfig() + assert c.model_size == "base" + + def test_default_transcription(self): + c = WhisperLiveKitConfig() + assert c.transcription is True + assert c.diarization is False + + +class TestPostInit: + def test_en_model_forces_english(self): + c = WhisperLiveKitConfig(model_size="tiny.en") + assert c.lan == "en" + + def test_en_suffix_with_auto_language(self): + c = WhisperLiveKitConfig(model_size="base.en", lan="auto") + assert c.lan == "en" + + def test_non_en_model_keeps_language(self): + c = WhisperLiveKitConfig(model_size="base", lan="fr") + assert c.lan == "fr" + + def test_policy_alias_1(self): + c = WhisperLiveKitConfig(backend_policy="1") + assert c.backend_policy == "simulstreaming" + + def test_policy_alias_2(self): + c = WhisperLiveKitConfig(backend_policy="2") + assert c.backend_policy == "localagreement" + + def test_policy_no_alias(self): + c = WhisperLiveKitConfig(backend_policy="localagreement") + assert c.backend_policy == "localagreement" + + +class TestFromNamespace: + def test_known_keys(self): + ns = SimpleNamespace(backend="faster-whisper", lan="en", model_size="large-v3") + c = WhisperLiveKitConfig.from_namespace(ns) + assert c.backend == "faster-whisper" + assert c.lan == "en" + assert c.model_size == "large-v3" + + def test_ignores_unknown_keys(self): + ns = SimpleNamespace(backend="auto", unknown_key="value", another="x") + c = WhisperLiveKitConfig.from_namespace(ns) + assert c.backend == "auto" + assert not hasattr(c, "unknown_key") + + def test_preserves_defaults_for_missing(self): + ns = SimpleNamespace(backend="voxtral-mlx") + c = WhisperLiveKitConfig.from_namespace(ns) + assert c.lan == "auto" + assert c.vac is True + + +class TestFromKwargs: + def test_known_keys(self): + c = WhisperLiveKitConfig.from_kwargs(backend="mlx-whisper", lan="fr") + assert c.backend == "mlx-whisper" + assert c.lan == "fr" + + def test_warns_on_unknown_keys(self, caplog): + with caplog.at_level(logging.WARNING, logger="whisperlivekit.config"): + c = WhisperLiveKitConfig.from_kwargs(backend="auto", bogus="value") + assert c.backend == "auto" + assert "bogus" in caplog.text + + def test_post_init_runs(self): + c = WhisperLiveKitConfig.from_kwargs(model_size="small.en") + assert c.lan == "en" diff --git a/tests/test_hypothesis_buffer.py b/tests/test_hypothesis_buffer.py new file mode 100644 index 0000000..732090a --- /dev/null +++ b/tests/test_hypothesis_buffer.py @@ -0,0 +1,172 @@ +"""Tests for HypothesisBuffer — the core of LocalAgreement policy.""" + +import pytest + +from whisperlivekit.timed_objects import ASRToken +from whisperlivekit.local_agreement.online_asr import HypothesisBuffer + + +def make_tokens(words, start=0.0, step=0.5): + """Helper: create ASRToken list from word strings.""" + tokens = [] + t = start + for w in words: + tokens.append(ASRToken(start=t, end=t + step, text=w, probability=0.9)) + t += step + return tokens + + +class TestInsert: + def test_basic_insert(self): + buf = HypothesisBuffer() + tokens = make_tokens(["hello", "world"]) + buf.insert(tokens, offset=0.0) + assert len(buf.new) == 2 + assert buf.new[0].text == "hello" + + def test_insert_with_offset(self): + buf = HypothesisBuffer() + tokens = make_tokens(["hello"], start=0.0) + buf.insert(tokens, offset=5.0) + assert buf.new[0].start == pytest.approx(5.0) + + def test_insert_filters_old_tokens(self): + buf = HypothesisBuffer() + buf.last_committed_time = 10.0 + tokens = make_tokens(["old", "new"], start=5.0, step=3.0) + buf.insert(tokens, offset=0.0) + # "old" at 5.0 is before last_committed_time - 0.1 = 9.9 → filtered + # "new" at 8.0 is also before 9.9 → filtered + assert len(buf.new) == 0 + + def test_insert_deduplicates_committed(self): + buf = HypothesisBuffer() + # Commit "hello" + tokens1 = make_tokens(["hello", "world"]) + buf.insert(tokens1, offset=0.0) + buf.flush() # commits "hello" (buffer was empty, so nothing matches) + # Actually with empty buffer, flush won't commit anything + # Let's do it properly: two rounds + buf2 = HypothesisBuffer() + first = make_tokens(["hello", "world"]) + buf2.insert(first, offset=0.0) + buf2.flush() # buffer was empty → no commits, buffer = ["hello", "world"] + + second = make_tokens(["hello", "world", "test"]) + buf2.insert(second, offset=0.0) + committed = buf2.flush() + # LCP of ["hello", "world"] and ["hello", "world", "test"] = ["hello", "world"] + assert len(committed) == 2 + assert committed[0].text == "hello" + assert committed[1].text == "world" + + +class TestFlush: + def test_flush_empty(self): + buf = HypothesisBuffer() + committed = buf.flush() + assert committed == [] + + def test_flush_lcp_matching(self): + buf = HypothesisBuffer() + # Round 1: establish buffer + buf.insert(make_tokens(["hello", "world"]), offset=0.0) + buf.flush() # buffer = ["hello", "world"], committed = [] + + # Round 2: same prefix, new suffix + buf.insert(make_tokens(["hello", "world", "test"]), offset=0.0) + committed = buf.flush() + assert [t.text for t in committed] == ["hello", "world"] + + def test_flush_no_match(self): + buf = HypothesisBuffer() + # Round 1 + buf.insert(make_tokens(["hello", "world"]), offset=0.0) + buf.flush() + + # Round 2: completely different + buf.insert(make_tokens(["foo", "bar"]), offset=0.0) + committed = buf.flush() + assert committed == [] + + def test_flush_partial_match(self): + buf = HypothesisBuffer() + buf.insert(make_tokens(["hello", "world", "test"]), offset=0.0) + buf.flush() + + buf.insert(make_tokens(["hello", "earth", "again"]), offset=0.0) + committed = buf.flush() + assert len(committed) == 1 + assert committed[0].text == "hello" + + def test_flush_updates_last_committed(self): + buf = HypothesisBuffer() + buf.insert(make_tokens(["hello", "world"]), offset=0.0) + buf.flush() + + buf.insert(make_tokens(["hello", "world", "test"]), offset=0.0) + buf.flush() + assert buf.last_committed_word == "world" + assert buf.last_committed_time > 0 + + def test_flush_with_confidence_validation(self): + buf = HypothesisBuffer(confidence_validation=True) + high_conf = [ + ASRToken(start=0.0, end=0.5, text="sure", probability=0.99), + ASRToken(start=0.5, end=1.0, text="maybe", probability=0.5), + ] + buf.insert(high_conf, offset=0.0) + committed = buf.flush() + # "sure" has p>0.95 → committed immediately + assert len(committed) == 1 + assert committed[0].text == "sure" + + +class TestPopCommitted: + def test_pop_removes_old(self): + buf = HypothesisBuffer() + buf.committed_in_buffer = make_tokens(["a", "b", "c"], start=0.0, step=1.0) + # "a": end=1.0, "b": end=2.0, "c": end=3.0 + # pop_committed removes tokens with end <= time + buf.pop_committed(2.0) + # "a" (end=1.0) and "b" (end=2.0) removed, "c" (end=3.0) remains + assert len(buf.committed_in_buffer) == 1 + assert buf.committed_in_buffer[0].text == "c" + + def test_pop_nothing(self): + buf = HypothesisBuffer() + buf.committed_in_buffer = make_tokens(["a", "b"], start=5.0) + buf.pop_committed(0.0) + assert len(buf.committed_in_buffer) == 2 + + def test_pop_all(self): + buf = HypothesisBuffer() + buf.committed_in_buffer = make_tokens(["a", "b"], start=0.0, step=0.5) + buf.pop_committed(100.0) + assert len(buf.committed_in_buffer) == 0 + + +class TestStreamingSimulation: + """Multi-round insert/flush simulating real streaming behavior.""" + + def test_three_rounds(self): + buf = HypothesisBuffer() + all_committed = [] + + # Round 1: "this is" + buf.insert(make_tokens(["this", "is"]), offset=0.0) + all_committed.extend(buf.flush()) + + # Round 2: "this is a test" + buf.insert(make_tokens(["this", "is", "a", "test"]), offset=0.0) + all_committed.extend(buf.flush()) + + # Round 3: "this is a test today" + buf.insert(make_tokens(["this", "is", "a", "test", "today"]), offset=0.0) + all_committed.extend(buf.flush()) + + words = [t.text for t in all_committed] + assert "this" in words + assert "is" in words + assert "a" in words + assert "test" in words diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..365e168 --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,147 @@ +"""Tests for whisperlivekit.metrics — WER, timestamp accuracy, normalization.""" + +import pytest + +from whisperlivekit.metrics import compute_wer, compute_timestamp_accuracy, normalize_text + + +class TestNormalizeText: + def test_lowercase(self): + assert normalize_text("Hello World") == "hello world" + + def test_strip_punctuation(self): + assert normalize_text("Hello, world!") == "hello world" + + def test_collapse_whitespace(self): + assert normalize_text(" hello world ") == "hello world" + + def test_keep_hyphens(self): + assert normalize_text("real-time") == "real-time" + + def test_keep_apostrophes(self): + assert normalize_text("don't") == "don't" + + def test_unicode_normalized(self): + # e + combining accent should be same as precomposed + assert normalize_text("caf\u0065\u0301") == normalize_text("caf\u00e9") + + def test_empty(self): + assert normalize_text("") == "" + + def test_only_punctuation(self): + assert normalize_text("...!?") == "" + + +class TestComputeWER: + def test_perfect_match(self): + result = compute_wer("hello world", "hello world") + assert result["wer"] == 0.0 + assert result["substitutions"] == 0 + assert result["insertions"] == 0 + assert result["deletions"] == 0 + + def test_case_insensitive(self): + result = compute_wer("Hello World", "hello world") + assert result["wer"] == 0.0 + + def test_punctuation_ignored(self): + result = compute_wer("Hello, world!", "hello world") + assert result["wer"] == 0.0 + + def test_one_substitution(self): + result = compute_wer("hello world", "hello earth") + assert result["wer"] == pytest.approx(0.5) + assert result["substitutions"] == 1 + + def test_one_insertion(self): + result = compute_wer("hello world", "hello big world") + assert result["wer"] == pytest.approx(0.5) + assert result["insertions"] == 1 + + def test_one_deletion(self): + result = compute_wer("hello big world", "hello world") + assert result["wer"] == pytest.approx(1 / 3) + assert result["deletions"] == 1 + + def test_completely_different(self): + result = compute_wer("the cat sat", "a dog ran") + assert result["wer"] == pytest.approx(1.0) + + def test_empty_reference(self): + result = compute_wer("", "hello") + assert result["wer"] == 1.0 # 1 insertion / 0 ref → treated as float(m) + assert result["ref_words"] == 0 + + def test_empty_hypothesis(self): + result = compute_wer("hello world", "") + assert result["wer"] == pytest.approx(1.0) + assert result["deletions"] == 2 + + def test_both_empty(self): + result = compute_wer("", "") + assert result["wer"] == 0.0 + + def test_ref_and_hyp_word_counts(self): + result = compute_wer("one two three", "one two three four") + assert result["ref_words"] == 3 + assert result["hyp_words"] == 4 + + +class TestComputeTimestampAccuracy: + def test_perfect_match(self): + words = [ + {"word": "hello", "start": 0.0, "end": 0.5}, + {"word": "world", "start": 0.5, "end": 1.0}, + ] + result = compute_timestamp_accuracy(words, words) + assert result["mae_start"] == 0.0 + assert result["max_delta_start"] == 0.0 + assert result["n_matched"] == 2 + + def test_constant_offset(self): + ref = [ + {"word": "hello", "start": 0.0, "end": 0.5}, + {"word": "world", "start": 0.5, "end": 1.0}, + ] + pred = [ + {"word": "hello", "start": 0.1, "end": 0.6}, + {"word": "world", "start": 0.6, "end": 1.1}, + ] + result = compute_timestamp_accuracy(pred, ref) + assert result["mae_start"] == pytest.approx(0.1) + assert result["max_delta_start"] == pytest.approx(0.1) + assert result["n_matched"] == 2 + + def test_mismatched_word_counts(self): + ref = [ + {"word": "hello", "start": 0.0, "end": 0.5}, + {"word": "beautiful", "start": 0.5, "end": 1.0}, + {"word": "world", "start": 1.0, "end": 1.5}, + ] + pred = [ + {"word": "hello", "start": 0.0, "end": 0.5}, + {"word": "world", "start": 1.1, "end": 1.6}, + ] + result = compute_timestamp_accuracy(pred, ref) + assert result["n_matched"] == 2 + assert result["n_ref"] == 3 + assert result["n_pred"] == 2 + + def test_empty_predicted(self): + ref = [{"word": "hello", "start": 0.0, "end": 0.5}] + result = compute_timestamp_accuracy([], ref) + assert result["mae_start"] is None + assert result["n_matched"] == 0 + + def test_empty_reference(self): + pred = [{"word": "hello", "start": 0.0, "end": 0.5}] + result = compute_timestamp_accuracy(pred, []) + assert result["mae_start"] is None + assert result["n_matched"] == 0 + + def test_case_insensitive_matching(self): + ref = [{"word": "Hello", "start": 0.0, "end": 0.5}] + pred = [{"word": "hello", "start": 0.1, "end": 0.6}] + result = compute_timestamp_accuracy(pred, ref) + assert result["n_matched"] == 1 + assert result["mae_start"] == pytest.approx(0.1) diff --git a/tests/test_silence_handling.py b/tests/test_silence_handling.py new file mode 100644 index 0000000..08028be --- /dev/null +++ b/tests/test_silence_handling.py @@ -0,0 +1,99 @@ +"""Tests for silence handling — state machine and double-counting regression.""" + +import pytest + +from whisperlivekit.timed_objects import Silence + + +class TestSilenceStateMachine: + """Test Silence object state transitions.""" + + def test_initial_state(self): + s = Silence(start=1.0, is_starting=True) + assert s.is_starting is True + assert s.has_ended is False + assert s.duration is None + assert s.end is None + + def test_end_silence(self): + s = Silence(start=1.0, is_starting=True) + s.end = 3.0 + s.is_starting = False + s.has_ended = True + s.compute_duration() + assert s.duration == pytest.approx(2.0) + + def test_very_short_silence(self): + s = Silence(start=1.0, end=1.01, is_starting=False, has_ended=True) + s.compute_duration() + assert s.duration == pytest.approx(0.01) + + def test_zero_duration_silence(self): + s = Silence(start=5.0, end=5.0) + s.compute_duration() + assert s.duration == pytest.approx(0.0) + + +class TestSilenceDoubleCounting: + """Regression tests for the silence double-counting bug. + + The bug: _begin_silence and _end_silence both pushed self.current_silence + to the queue. Since they were the same Python object, _end_silence's mutation + affected the already-queued start event. The consumer processed both as + ended silences, doubling the duration. + + Fix: _begin_silence now pushes a separate Silence object for the start event. + """ + + def test_start_and_end_are_separate_objects(self): + """Simulate the fix: start event and end event must be different objects.""" + # Simulate _begin_silence: creates start event as separate object + current_silence = Silence(start=1.0, is_starting=True) + start_event = Silence(start=1.0, is_starting=True) # separate copy + + # Simulate _end_silence: mutates current_silence + current_silence.end = 3.0 + current_silence.is_starting = False + current_silence.has_ended = True + current_silence.compute_duration() + + # start_event should NOT be affected by mutations to current_silence + assert start_event.is_starting is True + assert start_event.has_ended is False + assert start_event.end is None + + # current_silence (end event) has the final state + assert current_silence.has_ended is True + assert current_silence.duration == pytest.approx(2.0) + + def test_single_object_would_cause_double_counting(self): + """Demonstrate the bug: if same object is used for both events.""" + shared = Silence(start=1.0, is_starting=True) + queue = [shared] # start event queued + + # Mutate (simulates _end_silence) + shared.end = 3.0 + shared.is_starting = False + shared.has_ended = True + shared.compute_duration() + queue.append(shared) # end event queued + + # Both queue items point to the SAME mutated object + assert queue[0] is queue[1] # same reference + assert queue[0].has_ended is True # start event also shows ended! + + # This would cause double-counting: both items have has_ended=True + # and duration=2.0, so the consumer adds 2.0 twice = 4.0 + + +class TestConsecutiveSilences: + def test_multiple_silences(self): + """Multiple silence periods should have independent durations.""" + s1 = Silence(start=1.0, end=2.0) + s1.compute_duration() + s2 = Silence(start=5.0, end=8.0) + s2.compute_duration() + assert s1.duration == pytest.approx(1.0) + assert s2.duration == pytest.approx(3.0) + # Total silence should be sum, not accumulated on single object + assert s1.duration + s2.duration == pytest.approx(4.0) diff --git a/tests/test_timed_objects.py b/tests/test_timed_objects.py new file mode 100644 index 0000000..559a1c3 --- /dev/null +++ b/tests/test_timed_objects.py @@ -0,0 +1,185 @@ +"""Tests for whisperlivekit.timed_objects data classes.""" + +import pytest + +from whisperlivekit.timed_objects import ( + ASRToken, + FrontData, + Segment, + Silence, + TimedText, + Transcript, + format_time, +) + + +class TestFormatTime: + def test_zero(self): + assert format_time(0) == "0:00:00" + + def test_one_minute(self): + assert format_time(60) == "0:01:00" + + def test_one_hour(self): + assert format_time(3600) == "1:00:00" + + def test_fractional_truncated(self): + assert format_time(61.9) == "0:01:01" + + +class TestASRToken: + def test_with_offset(self): + t = ASRToken(start=1.0, end=2.0, text="hello") + shifted = t.with_offset(0.5) + assert shifted.start == pytest.approx(1.5) + assert shifted.end == pytest.approx(2.5) + assert shifted.text == "hello" + + def test_with_offset_preserves_fields(self): + t = ASRToken(start=0.0, end=1.0, text="hi", speaker=2, probability=0.95) + shifted = t.with_offset(1.0) + assert shifted.speaker == 2 + assert shifted.probability == 0.95 + + def test_is_silence_false(self): + t = ASRToken(start=0.0, end=1.0, text="hello") + assert t.is_silence() is False + + def test_bool_truthy(self): + t = ASRToken(start=0.0, end=1.0, text="hello") + assert bool(t) is True + + def test_bool_falsy(self): + t = ASRToken(start=0.0, end=1.0, text="") + assert bool(t) is False + + +class TestTimedText: + def test_has_punctuation_period(self): + t = TimedText(text="hello.") + assert t.has_punctuation() is True + + def test_has_punctuation_exclamation(self): + t = TimedText(text="wow!") + assert t.has_punctuation() is True + + def test_has_punctuation_question(self): + t = TimedText(text="really?") + assert t.has_punctuation() is True + + def test_has_punctuation_cjk(self): + t = TimedText(text="hello。") + assert t.has_punctuation() is True + + def test_no_punctuation(self): + t = TimedText(text="hello world") + assert t.has_punctuation() is False + + def test_duration(self): + t = TimedText(start=1.0, end=3.5) + assert t.duration() == pytest.approx(2.5) + + def test_contains_timespan(self): + outer = TimedText(start=0.0, end=5.0) + inner = TimedText(start=1.0, end=3.0) + assert outer.contains_timespan(inner) is True + assert inner.contains_timespan(outer) is False + + +class TestSilence: + def test_compute_duration(self): + s = Silence(start=1.0, end=3.5) + d = s.compute_duration() + assert d == pytest.approx(2.5) + assert s.duration == pytest.approx(2.5) + + def test_compute_duration_none_start(self): + s = Silence(start=None, end=3.5) + d = s.compute_duration() + assert d is None + + def test_compute_duration_none_end(self): + s = Silence(start=1.0, end=None) + d = s.compute_duration() + assert d is None + + def test_is_silence_true(self): + s = Silence() + assert s.is_silence() is True + + +class TestTranscript: + def test_from_tokens(self, sample_tokens): + t = Transcript.from_tokens(sample_tokens, sep="") + assert t.text == "Hello world test." + assert t.start == pytest.approx(0.0) + assert t.end == pytest.approx(1.5) + + def test_from_tokens_with_sep(self, sample_tokens): + t = Transcript.from_tokens(sample_tokens, sep="|") + assert t.text == "Hello| world| test." + + def test_from_empty_tokens(self): + t = Transcript.from_tokens([]) + assert t.text == "" + assert t.start is None + assert t.end is None + + def test_from_tokens_with_offset(self, sample_tokens): + t = Transcript.from_tokens(sample_tokens, offset=10.0) + assert t.start == pytest.approx(10.0) + assert t.end == pytest.approx(11.5) + + +class TestSegment: + def test_from_tokens(self, sample_tokens): + seg = Segment.from_tokens(sample_tokens) + assert seg is not None + assert seg.text == "Hello world test." + assert seg.start == pytest.approx(0.0) + assert seg.end == pytest.approx(1.5) + assert seg.speaker == -1 + + def test_from_silence_tokens(self): + silences = [ + Silence(start=1.0, end=2.0), + Silence(start=2.0, end=3.0), + ] + seg = Segment.from_tokens(silences, is_silence=True) + assert seg is not None + assert seg.speaker == -2 + assert seg.is_silence() is True + assert seg.text is None + + def test_from_empty_tokens(self): + seg = Segment.from_tokens([]) + assert seg is None + + def test_to_dict(self, sample_tokens): + seg = Segment.from_tokens(sample_tokens) + d = seg.to_dict() + assert "text" in d + assert "speaker" in d + assert "start" in d + assert "end" in d + + +class TestFrontData: + def test_to_dict_empty(self): + fd = FrontData() + d = fd.to_dict() + assert d["lines"] == [] + assert d["buffer_transcription"] == "" + assert "error" not in d + + def test_to_dict_with_error(self): + fd = FrontData(error="something broke") + d = fd.to_dict() + assert d["error"] == "something broke" + + def test_to_dict_with_lines(self, sample_tokens): + seg = Segment.from_tokens(sample_tokens) + fd = FrontData(lines=[seg]) + d = fd.to_dict() + assert len(d["lines"]) == 1 + assert d["lines"][0]["text"] == "Hello world test."