mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
- BENCHMARK.md: whisper also supports --language auto, voxtral is not the only one. Fixed mlx-whisper speed comparison (LA is actually faster than SS for mlx-whisper, not comparable). - metrics.py: median calculation was wrong for even-length lists (took upper middle instead of averaging the two middle values). - metrics_collector.py: RTF was inflated because log_summary() used wall-clock elapsed time instead of sum of actual ASR call durations. - README.md: clarified that whisper also supports auto language detection, voxtral just does it better. - Added 2 new median tests (even + odd length).
184 lines
6.8 KiB
Python
184 lines
6.8 KiB
Python
"""Tests for whisperlivekit.metrics — WER, timestamp accuracy, normalization."""
|
|
|
|
import pytest
|
|
|
|
from whisperlivekit.metrics import compute_wer, compute_timestamp_accuracy, normalize_text
|
|
|
|
|
|
class TestNormalizeText:
    """Unit tests for ``normalize_text``.

    Every test passes a short literal string to ``normalize_text`` and
    compares the returned value with the normalized form it expects.
    """

    def test_lowercase(self):
        """Mixed-case input is expected to come back in lower case."""
        assert normalize_text("Hello World") == "hello world"

    def test_strip_punctuation(self):
        """A comma and an exclamation mark are expected to be removed."""
        assert normalize_text("Hello, world!") == "hello world"

    def test_collapse_whitespace(self):
        """Outer blanks are trimmed and inner runs shrink to one space."""
        # The input must contain runs of several spaces: with single spaces
        # only, a plain strip() would satisfy the assertion and the
        # collapsing behaviour named by this test would never be exercised.
        assert normalize_text("  hello   world  ") == "hello world"

    def test_keep_hyphens(self):
        """A hyphen inside a word is expected to survive normalization."""
        assert normalize_text("real-time") == "real-time"

    def test_keep_apostrophes(self):
        """An apostrophe inside a word is expected to survive normalization."""
        assert normalize_text("don't") == "don't"

    def test_unicode_normalized(self):
        """Decomposed and precomposed accents must normalize identically."""
        # "e" followed by a combining acute accent (U+0065 U+0301) should
        # give the same result as the precomposed character U+00E9.
        decomposed = normalize_text("caf\u0065\u0301")
        precomposed = normalize_text("caf\u00e9")
        assert decomposed == precomposed

    def test_empty(self):
        """The empty string is expected to map to the empty string."""
        assert normalize_text("") == ""

    def test_only_punctuation(self):
        """Input made only of punctuation is expected to normalize to ""."""
        assert normalize_text("...!?") == ""
|
|
|
|
|
|
class TestComputeWER:
    """Unit tests for ``compute_wer``.

    The calls below pass the reference text first and the hypothesis second;
    the returned mapping is inspected through its ``wer``, ``substitutions``,
    ``insertions``, ``deletions``, ``ref_words`` and ``hyp_words`` entries.
    """

    def test_perfect_match(self):
        """Identical texts: zero WER and no edit operations of any kind."""
        outcome = compute_wer("hello world", "hello world")
        assert outcome["wer"] == 0.0
        for edit_kind in ("substitutions", "insertions", "deletions"):
            assert outcome[edit_kind] == 0

    def test_case_insensitive(self):
        """Texts differing only in letter case score a WER of zero."""
        assert compute_wer("Hello World", "hello world")["wer"] == 0.0

    def test_punctuation_ignored(self):
        """Texts differing only in punctuation and case score a WER of zero."""
        assert compute_wer("Hello, world!", "hello world")["wer"] == 0.0

    def test_one_substitution(self):
        """One replaced word out of two reference words gives WER 0.5."""
        outcome = compute_wer("hello world", "hello earth")
        assert outcome["wer"] == pytest.approx(0.5)
        assert outcome["substitutions"] == 1

    def test_one_insertion(self):
        """One extra hypothesis word against two reference words gives 0.5."""
        outcome = compute_wer("hello world", "hello big world")
        assert outcome["wer"] == pytest.approx(0.5)
        assert outcome["insertions"] == 1

    def test_one_deletion(self):
        """One missing word out of three reference words gives WER 1/3."""
        outcome = compute_wer("hello big world", "hello world")
        assert outcome["wer"] == pytest.approx(1 / 3)
        assert outcome["deletions"] == 1

    def test_completely_different(self):
        """Three reference words, none matched, give a WER of 1.0."""
        outcome = compute_wer("the cat sat", "a dog ran")
        assert outcome["wer"] == pytest.approx(1.0)

    def test_empty_reference(self):
        """An empty reference with a one-word hypothesis yields WER 1.0."""
        outcome = compute_wer("", "hello")
        # One inserted word over zero reference words; per the original
        # author's note the implementation treats this case as float(m).
        assert outcome["wer"] == 1.0
        assert outcome["ref_words"] == 0

    def test_empty_hypothesis(self):
        """An empty hypothesis counts every reference word as a deletion."""
        outcome = compute_wer("hello world", "")
        assert outcome["wer"] == pytest.approx(1.0)
        assert outcome["deletions"] == 2

    def test_both_empty(self):
        """Two empty texts score a WER of zero."""
        assert compute_wer("", "")["wer"] == 0.0

    def test_ref_and_hyp_word_counts(self):
        """Word counts of both texts are reported alongside the WER."""
        outcome = compute_wer("one two three", "one two three four")
        assert outcome["ref_words"] == 3
        assert outcome["hyp_words"] == 4
|
|
|
|
|
|
class TestComputeTimestampAccuracy:
    """Unit tests for ``compute_timestamp_accuracy``.

    The calls below pass the predicted words first and the reference words
    second. Each word is a dict with ``word``, ``start`` and ``end`` keys.
    """

    @staticmethod
    def _words(*entries):
        """Build a list of word dicts from ``(word, start, end)`` tuples."""
        return [
            {"word": text, "start": begin, "end": finish}
            for text, begin, finish in entries
        ]

    def test_perfect_match(self):
        """Comparing a word list with itself reports zero start error."""
        timeline = self._words(
            ("hello", 0.0, 0.5),
            ("world", 0.5, 1.0),
        )
        # The very same list object is used as prediction and as reference.
        report = compute_timestamp_accuracy(timeline, timeline)
        assert report["mae_start"] == 0.0
        assert report["max_delta_start"] == 0.0
        assert report["n_matched"] == 2

    def test_constant_offset(self):
        """A uniform 0.1 shift shows up as both the MAE and the max delta."""
        reference = self._words(
            ("hello", 0.0, 0.5),
            ("world", 0.5, 1.0),
        )
        predicted = self._words(
            ("hello", 0.1, 0.6),
            ("world", 0.6, 1.1),
        )
        report = compute_timestamp_accuracy(predicted, reference)
        assert report["mae_start"] == pytest.approx(0.1)
        assert report["max_delta_start"] == pytest.approx(0.1)
        assert report["n_matched"] == 2

    def test_mismatched_word_counts(self):
        """A word missing from the prediction still leaves two matches."""
        reference = self._words(
            ("hello", 0.0, 0.5),
            ("beautiful", 0.5, 1.0),
            ("world", 1.0, 1.5),
        )
        predicted = self._words(
            ("hello", 0.0, 0.5),
            ("world", 1.1, 1.6),
        )
        report = compute_timestamp_accuracy(predicted, reference)
        assert report["n_matched"] == 2
        assert report["n_ref"] == 3
        assert report["n_pred"] == 2

    def test_empty_predicted(self):
        """No predicted words: ``mae_start`` is None and nothing matches."""
        reference = self._words(("hello", 0.0, 0.5))
        report = compute_timestamp_accuracy([], reference)
        assert report["mae_start"] is None
        assert report["n_matched"] == 0

    def test_empty_reference(self):
        """No reference words: ``mae_start`` is None and nothing matches."""
        predicted = self._words(("hello", 0.0, 0.5))
        report = compute_timestamp_accuracy(predicted, [])
        assert report["mae_start"] is None
        assert report["n_matched"] == 0

    def test_case_insensitive_matching(self):
        """Words differing only in letter case are still paired up."""
        reference = self._words(("Hello", 0.0, 0.5))
        predicted = self._words(("hello", 0.1, 0.6))
        report = compute_timestamp_accuracy(predicted, reference)
        assert report["n_matched"] == 1
        assert report["mae_start"] == pytest.approx(0.1)

    def test_median_even_count(self):
        """With four matched words the median averages the middle pair."""
        reference = self._words(
            ("a", 0.0, 0.2),
            ("b", 0.5, 0.7),
            ("c", 1.0, 1.2),
            ("d", 1.5, 1.7),
        )
        predicted = self._words(
            ("a", 0.1, 0.3),  # delta 0.1
            ("b", 0.7, 0.9),  # delta 0.2
            ("c", 1.3, 1.5),  # delta 0.3
            ("d", 1.9, 2.1),  # delta 0.4
        )
        report = compute_timestamp_accuracy(predicted, reference)
        assert report["n_matched"] == 4
        # Sorted absolute deltas are [0.1, 0.2, 0.3, 0.4], so the expected
        # median is (0.2 + 0.3) / 2 = 0.25.
        assert report["median_delta_start"] == pytest.approx(0.25)

    def test_median_odd_count(self):
        """With three matched words the median is the middle delta."""
        reference = self._words(
            ("a", 0.0, 0.2),
            ("b", 0.5, 0.7),
            ("c", 1.0, 1.2),
        )
        predicted = self._words(
            ("a", 0.1, 0.3),  # delta 0.1
            ("b", 0.8, 1.0),  # delta 0.3
            ("c", 1.2, 1.4),  # delta 0.2
        )
        report = compute_timestamp_accuracy(predicted, reference)
        assert report["n_matched"] == 3
        # Sorted absolute deltas are [0.1, 0.2, 0.3], so the median is 0.2.
        assert report["median_delta_start"] == pytest.approx(0.2)
|