mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-04-24 07:10:29 +00:00
Add test harness and test client
This commit is contained in:
393
whisperlivekit/test_client.py
Normal file
393
whisperlivekit/test_client.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""Headless test client for WhisperLiveKit.
|
||||
|
||||
Feeds audio files to the transcription pipeline via WebSocket
|
||||
and collects results — no browser or microphone needed.
|
||||
|
||||
Usage:
|
||||
# Against a running server (server must be started with --pcm-input):
|
||||
python -m whisperlivekit.test_client audio.wav
|
||||
|
||||
# Custom server URL and speed:
|
||||
python -m whisperlivekit.test_client audio.wav --url ws://localhost:9090/asr --speed 0
|
||||
|
||||
# Output raw JSON responses:
|
||||
python -m whisperlivekit.test_client audio.wav --json
|
||||
|
||||
# Programmatic usage:
|
||||
from whisperlivekit.test_client import transcribe_audio
|
||||
result = asyncio.run(transcribe_audio("audio.wav"))
|
||||
print(result.text)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
BYTES_PER_SAMPLE = 2 # s16le
|
||||
|
||||
|
||||
@dataclass
class TranscriptionResult:
    """Accumulated server responses from one transcription session."""

    # Every JSON response received, in arrival order.
    responses: List[dict] = field(default_factory=list)
    # Duration of the submitted audio, in seconds.
    audio_duration: float = 0.0

    @property
    def text(self) -> str:
        """Full transcription text from the last response (committed lines + buffer)."""
        for resp in reversed(self.responses):
            committed = resp.get("lines", [])
            pending = resp.get("buffer_transcription", "")
            if not committed and not pending:
                continue
            pieces = [entry["text"] for entry in committed if entry.get("text")]
            if pending:
                pieces.append(pending)
            return " ".join(pieces)
        return ""

    @property
    def committed_text(self) -> str:
        """Only the committed (finalized) transcription lines, no buffer."""
        for resp in reversed(self.responses):
            committed = resp.get("lines", [])
            if committed:
                return " ".join(entry["text"] for entry in committed if entry.get("text"))
        return ""

    @property
    def lines(self) -> List[dict]:
        """Committed lines from the last response that carried any."""
        for resp in reversed(self.responses):
            if resp.get("lines"):
                return resp["lines"]
        return []

    @property
    def n_updates(self) -> int:
        """Number of non-empty updates (responses with lines or buffered text)."""
        return sum(
            1
            for resp in self.responses
            if resp.get("lines") or resp.get("buffer_transcription")
        )
|
||||
|
||||
|
||||
def reconstruct_state(msg: dict, lines: List[dict]) -> dict:
    """Reconstruct full state from a diff or snapshot message.

    Mutates ``lines`` in-place (prune front, append new) and returns
    a full-state dict compatible with TranscriptionResult.
    """
    if msg.get("type") == "snapshot":
        # A snapshot replaces the entire running line list.
        lines[:] = msg.get("lines", [])
        return msg

    # Diff message: drop pruned lines from the front, append the new ones.
    pruned = msg.get("lines_pruned", 0)
    if pruned > 0:
        del lines[:pruned]
    lines.extend(msg.get("new_lines", []))

    passthrough_defaults = (
        ("status", ""),
        ("buffer_transcription", ""),
        ("buffer_diarization", ""),
        ("buffer_translation", ""),
        ("remaining_time_transcription", 0),
        ("remaining_time_diarization", 0),
    )
    state = {key: msg.get(key, default) for key, default in passthrough_defaults}
    # Snapshot copy, detached from the running list we keep mutating.
    state["lines"] = list(lines)
    return state
|
||||
|
||||
|
||||
def load_audio_pcm(audio_path: str, sample_rate: int = SAMPLE_RATE) -> bytes:
    """Load an audio file and convert to PCM s16le mono via ffmpeg.

    Supports any format ffmpeg can decode (wav, mp3, flac, ogg, m4a, ...).

    Raises:
        RuntimeError: if ffmpeg exits non-zero or emits no data.
    """
    ffmpeg_args = [
        "ffmpeg",
        "-i", str(audio_path),
        "-f", "s16le",
        "-acodec", "pcm_s16le",
        "-ar", str(sample_rate),
        "-ac", "1",
        "-loglevel", "error",
        "pipe:1",
    ]
    completed = subprocess.run(ffmpeg_args, capture_output=True)
    if completed.returncode != 0:
        raise RuntimeError(f"ffmpeg conversion failed: {completed.stderr.decode().strip()}")
    if not completed.stdout:
        raise RuntimeError(f"ffmpeg produced no output for {audio_path}")
    return completed.stdout
|
||||
|
||||
|
||||
async def transcribe_audio(
    audio_path: str,
    url: str = "ws://localhost:8000/asr",
    chunk_duration: float = 0.5,
    speed: float = 1.0,
    timeout: float = 60.0,
    on_response: Optional[callable] = None,
    mode: str = "full",
) -> TranscriptionResult:
    """Feed an audio file to a running WhisperLiveKit server and collect results.

    Args:
        audio_path: Path to an audio file (any format ffmpeg supports).
        url: WebSocket URL of the /asr endpoint.
        chunk_duration: Duration of each audio chunk sent (seconds).
        speed: Playback speed multiplier (1.0 = real-time, 0 = as fast as possible).
        timeout: Max seconds to wait for the server after audio finishes.
        on_response: Optional callback invoked with each response dict as it arrives.
        mode: Output mode — "full" (default) or "diff" for incremental updates.

    Returns:
        TranscriptionResult with collected responses and convenience accessors.
    """
    import websockets

    result = TranscriptionResult()

    # Convert audio to PCM for both modes (we need the duration either way).
    pcm_data = load_audio_pcm(audio_path)
    result.audio_duration = len(pcm_data) / (SAMPLE_RATE * BYTES_PER_SAMPLE)
    logger.info("Loaded %s: %.1fs of audio", audio_path, result.audio_duration)

    chunk_bytes = int(chunk_duration * SAMPLE_RATE * BYTES_PER_SAMPLE)

    # Diff mode is requested via a query parameter on the endpoint URL.
    connect_url = url
    if mode == "diff":
        connect_url = f"{url}{'&' if '?' in url else '?'}mode=diff"

    async with websockets.connect(connect_url) as ws:
        # The server's first message on connect is its configuration.
        config_msg = json.loads(await ws.recv())
        is_pcm = config_msg.get("useAudioWorklet", False)
        logger.info("Server config: %s", config_msg)

        if not is_pcm:
            logger.warning(
                "Server is not in PCM mode. Start the server with --pcm-input "
                "for the test client. Attempting raw file streaming instead."
            )

        done_event = asyncio.Event()
        diff_lines: List[dict] = []  # running state for diff mode reconstruction

        async def send_audio():
            # Stream audio to the server, pacing chunks when speed > 0.
            if is_pcm:
                sent = 0
                for start in range(0, len(pcm_data), chunk_bytes):
                    await ws.send(pcm_data[start:start + chunk_bytes])
                    sent += 1
                    if speed > 0:
                        await asyncio.sleep(chunk_duration / speed)
                logger.info("Sent %d PCM chunks (%.1fs)", sent, result.audio_duration)
            else:
                # Non-PCM: send raw file bytes for server-side ffmpeg decoding.
                file_bytes = Path(audio_path).read_bytes()
                raw_chunk_size = 32000
                for start in range(0, len(file_bytes), raw_chunk_size):
                    await ws.send(file_bytes[start:start + raw_chunk_size])
                    if speed > 0:
                        await asyncio.sleep(0.5 / speed)
                logger.info("Sent %d bytes of raw audio", len(file_bytes))

            # An empty frame signals end of audio to the server.
            await ws.send(b"")
            logger.info("End-of-audio signal sent")

        async def receive_results():
            try:
                async for raw_msg in ws:
                    data = json.loads(raw_msg)
                    if data.get("type") == "ready_to_stop":
                        logger.info("Server signaled ready_to_stop")
                        done_event.set()
                        return
                    # In diff mode, reconstruct full state for a uniform API.
                    if mode == "diff" and data.get("type") in ("snapshot", "diff"):
                        data = reconstruct_state(data, diff_lines)
                    result.responses.append(data)
                    if on_response:
                        on_response(data)
            except Exception as e:
                # Connection teardown surfaces here; treat it as end of session.
                logger.debug("Receiver ended: %s", e)
                done_event.set()

        send_task = asyncio.create_task(send_audio())
        recv_task = asyncio.create_task(receive_results())

        # Total wait = time to stream the audio + server-processing margin.
        send_time = result.audio_duration / speed if speed > 0 else 1.0
        total_timeout = send_time + timeout

        try:
            await asyncio.wait_for(
                asyncio.gather(send_task, recv_task),
                timeout=total_timeout,
            )
        except asyncio.TimeoutError:
            logger.warning("Timed out after %.0fs", total_timeout)
            send_task.cancel()
            recv_task.cancel()
            try:
                await asyncio.gather(send_task, recv_task, return_exceptions=True)
            except Exception:
                pass

    logger.info(
        "Session complete: %d responses, %d updates",
        len(result.responses), result.n_updates,
    )
    return result
|
||||
|
||||
|
||||
def _print_result(result: TranscriptionResult, output_json: bool = False) -> None:
    """Print transcription results to stdout.

    Args:
        result: Collected session results.
        output_json: When True, dump each raw response as one JSON line instead
            of the human-readable summary.
    """
    if output_json:
        for resp in result.responses:
            print(json.dumps(resp))
        return

    final_lines = result.lines
    for line in final_lines:
        speaker = line.get("speaker", "")
        prefix = f"[{line.get('start', '')} -> {line.get('end', '')}]"
        # Speaker 1 is the default single speaker; only label others.
        if speaker and speaker != 1:
            prefix += f" Speaker {speaker}"
        print(f"{prefix} {line.get('text', '')}")

    buffer = result.responses[-1].get("buffer_transcription", "") if result.responses else ""
    if buffer:
        print(f"[buffer] {buffer}")

    if not final_lines and not buffer:
        print("(no transcription received)")

    print(
        f"\n--- {len(result.responses)} responses | "
        f"{result.n_updates} updates | "
        f"{result.audio_duration:.1f}s audio ---"
    )
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments, run a session, print the results."""
    parser = argparse.ArgumentParser(
        prog="whisperlivekit-test-client",
        description=(
            "Headless test client for WhisperLiveKit. "
            "Feeds audio files via WebSocket and prints the transcription."
        ),
    )
    parser.add_argument("audio", help="Path to audio file (wav, mp3, flac, ...)")
    parser.add_argument(
        "--url", default="ws://localhost:8000/asr",
        help="WebSocket endpoint URL (default: ws://localhost:8000/asr)",
    )
    parser.add_argument(
        "--speed", type=float, default=1.0,
        help="Playback speed multiplier (1.0 = real-time, 0 = fastest, default: 1.0)",
    )
    parser.add_argument(
        "--chunk-duration", type=float, default=0.5,
        help="Chunk duration in seconds (default: 0.5)",
    )
    parser.add_argument(
        "--timeout", type=float, default=60.0,
        help="Max seconds to wait for server after audio ends (default: 60)",
    )
    parser.add_argument(
        "--language", "-l", default=None,
        help="Override transcription language for this session (e.g. en, fr, auto)",
    )
    parser.add_argument("--json", action="store_true", help="Output raw JSON responses")
    parser.add_argument(
        "--diff", action="store_true",
        help="Use diff protocol (only receive incremental changes from server)",
    )
    parser.add_argument(
        "--live", action="store_true",
        help="Print transcription updates as they arrive",
    )
    parser.add_argument("--verbose", "-v", action="store_true")

    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.WARNING,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    audio_path = Path(args.audio)
    if not audio_path.exists():
        print(f"Error: file not found: {audio_path}", file=sys.stderr)
        sys.exit(1)

    live_callback = None
    if args.live:
        def live_callback(data):
            # Overwrite the current console line with the latest hypothesis.
            pieces = [entry["text"] for entry in data.get("lines", []) if entry.get("text")]
            pending = data.get("buffer_transcription", "")
            if pending:
                pieces.append(f"[{pending}]")
            if pieces:
                print("\r" + " ".join(pieces), end="", flush=True)

    # Assemble query parameters (language override, diff protocol).
    query_parts = []
    if args.language:
        query_parts.append(f"language={args.language}")
    if args.diff:
        query_parts.append("mode=diff")
    url = args.url
    if query_parts:
        url = f"{url}{'&' if '?' in url else '?'}{'&'.join(query_parts)}"

    result = asyncio.run(transcribe_audio(
        audio_path=str(audio_path),
        url=url,
        chunk_duration=args.chunk_duration,
        speed=args.speed,
        timeout=args.timeout,
        on_response=live_callback,
        mode="diff" if args.diff else "full",
    ))

    if args.live:
        print()  # newline after live output

    _print_result(result, output_json=args.json)


if __name__ == "__main__":
    main()
|
||||
365
whisperlivekit/test_data.py
Normal file
365
whisperlivekit/test_data.py
Normal file
@@ -0,0 +1,365 @@
|
||||
"""Standard test audio samples for evaluating the WhisperLiveKit pipeline.
|
||||
|
||||
Downloads curated samples from public ASR datasets (LibriSpeech, AMI)
|
||||
and caches them locally. Each sample includes the audio file path,
|
||||
ground truth transcript, speaker info, and timing metadata.
|
||||
|
||||
Usage::
|
||||
|
||||
from whisperlivekit.test_data import get_samples, get_sample
|
||||
|
||||
# Download all standard test samples (first call downloads, then cached)
|
||||
samples = get_samples()
|
||||
|
||||
for s in samples:
|
||||
print(f"{s.name}: {s.duration:.1f}s, {s.n_speakers} speaker(s)")
|
||||
print(f" Reference: {s.reference[:60]}...")
|
||||
|
||||
# Use with TestHarness
|
||||
from whisperlivekit.test_harness import TestHarness
|
||||
|
||||
async with TestHarness(model_size="base", lan="en") as h:
|
||||
sample = get_sample("librispeech_short")
|
||||
await h.feed(sample.path, speed=0)
|
||||
result = await h.finish()
|
||||
print(f"WER: {result.wer(sample.reference):.2%}")
|
||||
|
||||
Requires: pip install whisperlivekit[test] (installs 'datasets' and 'librosa')
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import wave
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CACHE_DIR = Path.home() / ".cache" / "whisperlivekit" / "test_data"
|
||||
METADATA_FILE = "metadata.json"
|
||||
|
||||
|
||||
@dataclass
class TestSample:
    """A test audio sample with ground truth metadata."""

    name: str
    path: str            # absolute path to WAV file
    reference: str       # ground truth transcript
    duration: float      # audio duration in seconds
    sample_rate: int = 16000
    n_speakers: int = 1
    language: str = "en"
    source: str = ""     # dataset name
    # Per-utterance ground truth for multi-speaker: [(start, end, speaker, text), ...]
    utterances: List[Dict] = field(default_factory=list)

    @property
    def has_timestamps(self) -> bool:
        """True when per-utterance timing data is available."""
        return bool(self.utterances)
|
||||
|
||||
|
||||
def _save_wav(path: Path, audio: np.ndarray, sample_rate: int = 16000) -> None:
|
||||
"""Save numpy audio array as 16-bit PCM WAV."""
|
||||
# Ensure mono
|
||||
if audio.ndim > 1:
|
||||
audio = audio.mean(axis=-1)
|
||||
# Normalize to int16 range
|
||||
if audio.dtype in (np.float32, np.float64):
|
||||
audio = np.clip(audio, -1.0, 1.0)
|
||||
audio = (audio * 32767).astype(np.int16)
|
||||
elif audio.dtype != np.int16:
|
||||
audio = audio.astype(np.int16)
|
||||
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with wave.open(str(path), "w") as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(audio.tobytes())
|
||||
|
||||
|
||||
def _load_metadata() -> Dict:
    """Load cached metadata if it exists, else an empty dict."""
    meta_path = CACHE_DIR / METADATA_FILE
    if not meta_path.exists():
        return {}
    return json.loads(meta_path.read_text())
|
||||
|
||||
|
||||
def _save_metadata(meta: Dict) -> None:
    """Persist metadata JSON into the cache directory (created if missing)."""
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    destination = CACHE_DIR / METADATA_FILE
    destination.write_text(json.dumps(meta, indent=2))
|
||||
|
||||
|
||||
def _ensure_datasets():
|
||||
"""Check that the datasets library is available."""
|
||||
try:
|
||||
import datasets # noqa: F401
|
||||
return True
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'datasets' package is required for test data download. "
|
||||
"Install it with: pip install whisperlivekit[test]"
|
||||
)
|
||||
|
||||
|
||||
def _decode_audio(audio_bytes: bytes) -> tuple:
    """Decode audio bytes using soundfile (avoids torchcodec dependency).

    Returns:
        (audio_array, sample_rate) — float32 numpy array and int sample rate.
    """
    import io

    import soundfile as sf
    stream = io.BytesIO(audio_bytes)
    decoded, rate = sf.read(stream, dtype="float32")
    return np.array(decoded, dtype=np.float32), rate
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dataset-specific download functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _download_librispeech_samples(n_samples: int = 3) -> List[Dict]:
    """Download short samples from LibriSpeech test-clean.

    Streams the dataset (no full download), decodes each clip with
    soundfile, writes a WAV into the cache dir, and returns metadata dicts.
    """
    _ensure_datasets()
    import datasets.config
    # Keep decoding on soundfile rather than torchcodec (see _decode_audio).
    datasets.config.TORCHCODEC_AVAILABLE = False
    from datasets import Audio, load_dataset

    logger.info("Downloading LibriSpeech test-clean samples (streaming)...")
    ds = load_dataset(
        "openslr/librispeech_asr",
        "clean",
        split="test",
        streaming=True,
    )
    # Keep raw bytes so we can decode them ourselves.
    ds = ds.cast_column("audio", Audio(decode=False))

    collected = []
    for idx, item in enumerate(ds):
        if idx >= n_samples:
            break

        audio_array, sr = _decode_audio(item["audio"]["bytes"])
        clip_seconds = len(audio_array) / sr
        transcript = item["text"]
        source_id = item.get("id", f"librispeech_{idx}")

        # Persist the decoded clip as a cached WAV.
        wav_name = f"librispeech_{idx}.wav"
        _save_wav(CACHE_DIR / wav_name, audio_array, sr)

        # The first sample gets the friendly alias "librispeech_short".
        sample_name = "librispeech_short" if idx == 0 else f"librispeech_{idx}"

        collected.append({
            "name": sample_name,
            "file": wav_name,
            "reference": transcript,
            "duration": round(clip_seconds, 2),
            "sample_rate": sr,
            "n_speakers": 1,
            "language": "en",
            "source": "openslr/librispeech_asr (test-clean)",
            "source_id": str(source_id),
            "utterances": [],
        })
        logger.info(
            " [%d] %.1fs - %s",
            idx, clip_seconds, transcript[:60] + ("..." if len(transcript) > 60 else ""),
        )

    return collected
|
||||
|
||||
|
||||
def _download_ami_sample() -> List[Dict]:
    """Download one AMI meeting segment with multiple speakers.

    Streams utterances from the first meeting in the test split until
    roughly 60 seconds have been collected, concatenates the audio into a
    single WAV, and returns one metadata dict (or [] when nothing arrived).
    """
    _ensure_datasets()
    import datasets.config
    # Keep decoding on soundfile rather than torchcodec (see _decode_audio).
    datasets.config.TORCHCODEC_AVAILABLE = False
    from datasets import Audio, load_dataset

    logger.info("Downloading AMI meeting test sample (streaming)...")

    # The edinburghcstr/ami version ships pre-segmented utterances with
    # speaker_id, begin_time, end_time, and text.
    ds = load_dataset(
        "edinburghcstr/ami",
        "ihm",
        split="test",
        streaming=True,
    )
    ds = ds.cast_column("audio", Audio(decode=False))

    # Accumulate utterances belonging to the first meeting seen.
    meeting_utterances = []
    meeting_id = None
    audio_arrays = []
    sample_rate = None

    for item in ds:
        mid = item.get("meeting_id", "unknown")

        if meeting_id is None:
            meeting_id = mid
        elif mid != meeting_id:
            # A different meeting started — we only want the first one.
            break

        audio_array, sr = _decode_audio(item["audio"]["bytes"])
        sample_rate = sr

        meeting_utterances.append({
            "start": round(item.get("begin_time", 0.0), 2),
            "end": round(item.get("end_time", 0.0), 2),
            "speaker": item.get("speaker_id", "unknown"),
            "text": item.get("text", ""),
        })
        audio_arrays.append(audio_array)

        # Cap the sample at roughly 60s of utterance time.
        total_dur = sum(u["end"] - u["start"] for u in meeting_utterances)
        if total_dur > 60:
            break

    if not audio_arrays:
        logger.warning("No AMI samples found")
        return []

    # Stitch all utterance audio into one continuous clip.
    full_audio = np.concatenate(audio_arrays)
    duration = len(full_audio) / sample_rate

    speakers = {u["speaker"] for u in meeting_utterances}
    reference = " ".join(u["text"] for u in meeting_utterances if u["text"])

    wav_name = "ami_meeting.wav"
    _save_wav(CACHE_DIR / wav_name, full_audio, sample_rate)

    logger.info(
        " AMI meeting %s: %.1fs, %d speakers, %d utterances",
        meeting_id, duration, len(speakers), len(meeting_utterances),
    )

    return [{
        "name": "ami_meeting",
        "file": wav_name,
        "reference": reference,
        "duration": round(duration, 2),
        "sample_rate": sample_rate,
        "n_speakers": len(speakers),
        "language": "en",
        "source": f"edinburghcstr/ami (ihm, meeting {meeting_id})",
        "source_id": meeting_id,
        "utterances": meeting_utterances,
    }]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def download_test_samples(force: bool = False) -> List[TestSample]:
    """Download standard test audio samples.

    Downloads samples from LibriSpeech (clean single-speaker) and
    AMI (multi-speaker meetings) on first call. Subsequent calls
    return cached data.

    Args:
        force: Re-download even if cached.

    Returns:
        List of TestSample objects ready for use with TestHarness.

    Raises:
        RuntimeError: when every download source fails.
    """
    meta = _load_metadata()

    # Serve from cache when possible — but only if every WAV is still there.
    if meta.get("samples") and not force:
        cached = meta["samples"]
        if all((CACHE_DIR / entry["file"]).exists() for entry in cached):
            return _meta_to_samples(cached)

    logger.info("Downloading test samples to %s ...", CACHE_DIR)
    CACHE_DIR.mkdir(parents=True, exist_ok=True)

    all_samples = []

    # Each source is best-effort: a failure is logged, not fatal,
    # as long as at least one source succeeds.
    try:
        all_samples.extend(_download_librispeech_samples(n_samples=3))
    except Exception as e:
        logger.warning("Failed to download LibriSpeech samples: %s", e)

    try:
        all_samples.extend(_download_ami_sample())
    except Exception as e:
        logger.warning("Failed to download AMI sample: %s", e)

    if not all_samples:
        raise RuntimeError(
            "Failed to download any test samples. "
            "Check your internet connection and ensure 'datasets' is installed: "
            "pip install whisperlivekit[test]"
        )

    _save_metadata({"samples": all_samples})
    logger.info("Downloaded %d test samples to %s", len(all_samples), CACHE_DIR)

    return _meta_to_samples(all_samples)
|
||||
|
||||
|
||||
def get_samples() -> List[TestSample]:
    """Get standard test samples (downloads on first call, cached after)."""
    return download_test_samples()
|
||||
|
||||
|
||||
def get_sample(name: str) -> TestSample:
    """Get a specific test sample by name.

    Available names: 'librispeech_short', 'librispeech_1', 'librispeech_2',
    'ami_meeting'.

    Raises:
        KeyError: If the sample name is not found.
    """
    samples = get_samples()
    match = next((s for s in samples if s.name == name), None)
    if match is not None:
        return match
    available = [s.name for s in samples]
    raise KeyError(f"Sample '{name}' not found. Available: {available}")
|
||||
|
||||
|
||||
def list_sample_names() -> List[str]:
    """List names of available test samples (downloads if needed)."""
    return [sample.name for sample in get_samples()]
|
||||
|
||||
|
||||
def _meta_to_samples(meta_list: List[Dict]) -> List[TestSample]:
    """Convert metadata dicts into TestSample objects."""
    return [
        TestSample(
            name=m["name"],
            path=str(CACHE_DIR / m["file"]),
            reference=m["reference"],
            duration=m["duration"],
            sample_rate=m.get("sample_rate", 16000),
            n_speakers=m.get("n_speakers", 1),
            language=m.get("language", "en"),
            source=m.get("source", ""),
            utterances=m.get("utterances", []),
        )
        for m in meta_list
    ]
|
||||
745
whisperlivekit/test_harness.py
Normal file
745
whisperlivekit/test_harness.py
Normal file
@@ -0,0 +1,745 @@
|
||||
"""In-process testing harness for the full WhisperLiveKit pipeline.
|
||||
|
||||
Wraps AudioProcessor to provide a controllable, observable interface
|
||||
for testing transcription, diarization, silence detection, and timing
|
||||
without needing a running server or WebSocket connection.
|
||||
|
||||
Designed for use by AI agents: feed audio with timeline control,
|
||||
inspect state at any point, pause/resume to test silence detection,
|
||||
cut to test abrupt termination.
|
||||
|
||||
Usage::
|
||||
|
||||
import asyncio
|
||||
from whisperlivekit.test_harness import TestHarness
|
||||
|
||||
async def main():
|
||||
async with TestHarness(model_size="base", lan="en") as h:
|
||||
# Load audio with timeline control
|
||||
player = h.load_audio("interview.wav")
|
||||
|
||||
# Play first 5 seconds at real-time speed
|
||||
await player.play(5.0, speed=1.0)
|
||||
print(h.state.text) # Check what's transcribed so far
|
||||
|
||||
# Pause for 7 seconds (triggers silence detection)
|
||||
await h.pause(7.0, speed=1.0)
|
||||
assert h.state.has_silence
|
||||
|
||||
# Resume playback
|
||||
await player.play(5.0, speed=1.0)
|
||||
|
||||
# Finish and evaluate
|
||||
result = await h.finish()
|
||||
print(f"WER: {result.wer('expected transcription'):.2%}")
|
||||
print(f"Speakers: {result.speakers}")
|
||||
print(f"Silence segments: {len(result.silence_segments)}")
|
||||
|
||||
# Inspect historical state at specific audio position
|
||||
snap = h.snapshot_at(3.0)
|
||||
print(f"At 3s: '{snap.text}'")
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import subprocess
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from whisperlivekit.timed_objects import FrontData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Engine cache: avoids reloading models when switching backends in tests.
|
||||
# Key is a frozen config tuple, value is the TranscriptionEngine instance.
|
||||
_engine_cache: Dict[Tuple, "Any"] = {}
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
BYTES_PER_SAMPLE = 2 # s16le
|
||||
|
||||
|
||||
def _parse_time(time_str: str) -> float:
|
||||
"""Parse 'H:MM:SS.cc' timestamp string to seconds."""
|
||||
parts = time_str.split(":")
|
||||
if len(parts) == 3:
|
||||
return int(parts[0]) * 3600 + int(parts[1]) * 60 + float(parts[2])
|
||||
if len(parts) == 2:
|
||||
return int(parts[0]) * 60 + float(parts[1])
|
||||
return float(parts[0])
|
||||
|
||||
|
||||
def load_audio_pcm(audio_path: str, sample_rate: int = SAMPLE_RATE) -> bytes:
    """Load any audio file and convert to PCM s16le mono via ffmpeg.

    Raises:
        RuntimeError: if ffmpeg exits non-zero or emits no data.
    """
    command = [
        "ffmpeg",
        "-i", str(audio_path),
        "-f", "s16le",
        "-acodec", "pcm_s16le",
        "-ar", str(sample_rate),
        "-ac", "1",
        "-loglevel", "error",
        "pipe:1",
    ]
    run_result = subprocess.run(command, capture_output=True)
    if run_result.returncode != 0:
        raise RuntimeError(f"ffmpeg conversion failed: {run_result.stderr.decode().strip()}")
    if not run_result.stdout:
        raise RuntimeError(f"ffmpeg produced no output for {audio_path}")
    return run_result.stdout
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestState — observable transcription state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class TestState:
|
||||
"""Observable transcription state at a point in time.
|
||||
|
||||
Provides accessors for inspecting lines, buffers, speakers, timestamps,
|
||||
silence segments, and computing evaluation metrics like WER.
|
||||
|
||||
All time-based queries accept seconds as floats.
|
||||
"""
|
||||
|
||||
lines: List[Dict[str, Any]] = field(default_factory=list)
|
||||
buffer_transcription: str = ""
|
||||
buffer_diarization: str = ""
|
||||
buffer_translation: str = ""
|
||||
remaining_time_transcription: float = 0.0
|
||||
remaining_time_diarization: float = 0.0
|
||||
audio_position: float = 0.0
|
||||
status: str = ""
|
||||
error: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_front_data(cls, front_data: FrontData, audio_position: float = 0.0) -> "TestState":
|
||||
d = front_data.to_dict()
|
||||
return cls(
|
||||
lines=d.get("lines", []),
|
||||
buffer_transcription=d.get("buffer_transcription", ""),
|
||||
buffer_diarization=d.get("buffer_diarization", ""),
|
||||
buffer_translation=d.get("buffer_translation", ""),
|
||||
remaining_time_transcription=d.get("remaining_time_transcription", 0),
|
||||
remaining_time_diarization=d.get("remaining_time_diarization", 0),
|
||||
audio_position=audio_position,
|
||||
status=d.get("status", ""),
|
||||
error=d.get("error", ""),
|
||||
)
|
||||
|
||||
# ── Text accessors ──
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
"""Full transcription: committed lines + buffer."""
|
||||
parts = [l["text"] for l in self.lines if l.get("text")]
|
||||
if self.buffer_transcription:
|
||||
parts.append(self.buffer_transcription)
|
||||
return " ".join(parts)
|
||||
|
||||
@property
|
||||
def committed_text(self) -> str:
|
||||
"""Only committed (finalized) lines, no buffer."""
|
||||
return " ".join(l["text"] for l in self.lines if l.get("text"))
|
||||
|
||||
@property
|
||||
def committed_word_count(self) -> int:
|
||||
"""Number of words in committed lines."""
|
||||
t = self.committed_text
|
||||
return len(t.split()) if t.strip() else 0
|
||||
|
||||
@property
def buffer_word_count(self) -> int:
    """Number of words in the unconfirmed buffer."""
    buf = self.buffer_transcription
    if not buf.strip():
        return 0
    return len(buf.split())
|
||||
|
||||
# ── Speaker accessors ──
|
||||
|
||||
@property
def speakers(self) -> Set[int]:
    """Set of speaker IDs (excluding silence marker -2)."""
    found = set()
    for line in self.lines:
        if line.get("speaker", 0) > 0:
            found.add(line["speaker"])
    return found
|
||||
|
||||
@property
def n_speakers(self) -> int:
    """Number of distinct speakers detected."""
    return len(self.speakers)
|
||||
|
||||
def speaker_at(self, time_s: float) -> Optional[int]:
    """Speaker ID at the given timestamp, or None if no segment covers it.

    Args:
        time_s: Timestamp in seconds.

    Returns:
        The covering line's speaker ID, or None when no line covers
        time_s or the covering line has no "speaker" key.
    """
    line = self.line_at(time_s)
    if line is None:
        return None
    # .get() instead of ["speaker"]: a line lacking the key should yield
    # None per the documented contract, not raise KeyError.
    return line.get("speaker")
|
||||
|
||||
def speakers_in(self, start_s: float, end_s: float) -> Set[int]:
    """All speaker IDs active in the time range (excluding silence -2)."""
    active = set()
    for line in self.lines_between(start_s, end_s):
        speaker = line.get("speaker", 0)
        if speaker > 0:
            active.add(speaker)
    return active
|
||||
|
||||
@property
def speaker_timeline(self) -> List[Dict[str, Any]]:
    """Timeline: [{"start": float, "end": float, "speaker": int}] for all lines."""
    timeline = []
    for line in self.lines:
        timeline.append({
            "start": _parse_time(line.get("start", "0:00:00")),
            "end": _parse_time(line.get("end", "0:00:00")),
            "speaker": line.get("speaker", -1),
        })
    return timeline
|
||||
|
||||
@property
def n_speaker_changes(self) -> int:
    """Number of speaker transitions (excluding silence segments)."""
    speech = [seg for seg in self.speaker_timeline if seg["speaker"] != -2]
    changes = 0
    for earlier, later in zip(speech, speech[1:]):
        if later["speaker"] != earlier["speaker"]:
            changes += 1
    return changes
|
||||
|
||||
# ── Silence accessors ──
|
||||
|
||||
@property
def has_silence(self) -> bool:
    """Whether any silence segment (speaker=-2) exists."""
    for line in self.lines:
        if line.get("speaker") == -2:
            return True
    return False
|
||||
|
||||
@property
def silence_segments(self) -> List[Dict[str, Any]]:
    """All silence segments (raw line dicts)."""
    return [line for line in self.lines if line.get("speaker") == -2]
|
||||
|
||||
def silence_at(self, time_s: float) -> bool:
    """True if time_s falls within a silence segment."""
    line = self.line_at(time_s)
    if line is None:
        return False
    return line.get("speaker") == -2
|
||||
|
||||
# ── Line / segment accessors ──
|
||||
|
||||
@property
def speech_lines(self) -> List[Dict[str, Any]]:
    """Lines excluding silence segments."""
    kept = []
    for line in self.lines:
        if line.get("speaker", 0) != -2 and line.get("text"):
            kept.append(line)
    return kept
|
||||
|
||||
def line_at(self, time_s: float) -> Optional[Dict[str, Any]]:
    """Find the line covering the given timestamp (seconds)."""
    for line in self.lines:
        span_start = _parse_time(line.get("start", "0:00:00"))
        span_end = _parse_time(line.get("end", "0:00:00"))
        if span_start <= time_s <= span_end:
            return line
    return None
|
||||
|
||||
def text_at(self, time_s: float) -> Optional[str]:
    """Text of the segment covering the given timestamp."""
    line = self.line_at(time_s)
    if not line:
        return None
    return line["text"]
|
||||
|
||||
def lines_between(self, start_s: float, end_s: float) -> List[Dict[str, Any]]:
    """All lines overlapping the time range [start_s, end_s]."""
    overlapping = []
    for line in self.lines:
        line_start = _parse_time(line.get("start", "0:00:00"))
        line_end = _parse_time(line.get("end", "0:00:00"))
        # Overlap test: segment ends at/after the window start and
        # begins at/before the window end (inclusive on both sides).
        if line_end >= start_s and line_start <= end_s:
            overlapping.append(line)
    return overlapping
|
||||
|
||||
def text_between(self, start_s: float, end_s: float) -> str:
    """Concatenated text of all lines overlapping the time range."""
    texts = []
    for line in self.lines_between(start_s, end_s):
        if line.get("text"):
            texts.append(line["text"])
    return " ".join(texts)
|
||||
|
||||
# ── Evaluation ──
|
||||
|
||||
def wer(self, reference: str) -> float:
    """Word Error Rate of committed text against reference.

    Returns:
        WER as a float (0.0 = perfect, 1.0 = 100% error rate).
    """
    from whisperlivekit.metrics import compute_wer

    return compute_wer(reference, self.committed_text)["wer"]
|
||||
|
||||
def wer_detailed(self, reference: str) -> Dict:
    """Full WER breakdown: substitutions, insertions, deletions, etc."""
    from whisperlivekit.metrics import compute_wer

    breakdown = compute_wer(reference, self.committed_text)
    return breakdown
|
||||
|
||||
# ── Timing validation ──
|
||||
|
||||
@property
def timestamps(self) -> List[Dict[str, Any]]:
    """All line timestamps as [{"start": float, "end": float, "speaker": int, "text": str}]."""
    return [
        {
            "start": _parse_time(line.get("start", "0:00:00")),
            "end": _parse_time(line.get("end", "0:00:00")),
            "speaker": line.get("speaker", -1),
            "text": line.get("text", ""),
        }
        for line in self.lines
    ]
|
||||
|
||||
@property
def timing_valid(self) -> bool:
    """All timestamps have start <= end and no negative values."""
    return all(
        ts["start"] >= 0 and ts["end"] >= 0 and ts["start"] <= ts["end"]
        for ts in self.timestamps
    )
|
||||
|
||||
@property
def timing_monotonic(self) -> bool:
    """Line start times are non-decreasing."""
    stamps = self.timestamps
    return all(
        later["start"] >= earlier["start"]
        for earlier, later in zip(stamps, stamps[1:])
    )
|
||||
|
||||
def timing_errors(self) -> List[str]:
    """Human-readable list of timing issues found."""
    problems: List[str] = []
    stamps = self.timestamps
    # Per-line sanity checks: no negative times, end after start.
    for i, ts in enumerate(stamps):
        if ts["start"] < 0:
            problems.append(f"Line {i}: negative start {ts['start']:.2f}s")
        if ts["end"] < 0:
            problems.append(f"Line {i}: negative end {ts['end']:.2f}s")
        if ts["end"] < ts["start"]:
            problems.append(
                f"Line {i}: end ({ts['end']:.2f}s) < start ({ts['start']:.2f}s)"
            )
    # Cross-line check: start times must be non-decreasing.
    for i in range(1, len(stamps)):
        if stamps[i]["start"] < stamps[i - 1]["start"]:
            problems.append(
                f"Line {i}: start ({stamps[i]['start']:.2f}s) < previous start "
                f"({stamps[i-1]['start']:.2f}s) — non-monotonic"
            )
    return problems
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AudioPlayer — timeline control for a loaded audio file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class AudioPlayer:
    """Controls playback of a loaded audio file through the pipeline.

    Tracks position in the audio, enabling play/pause/resume patterns::

        player = h.load_audio("speech.wav")
        await player.play(3.0)   # Play first 3 seconds
        await h.pause(7.0)       # 7s silence (triggers detection)
        await player.play(5.0)   # Play next 5 seconds
        await player.play()      # Play all remaining audio

    Args:
        harness: The TestHarness instance.
        pcm_data: Raw PCM s16le 16kHz mono bytes.
        sample_rate: Audio sample rate (default 16000).
    """

    def __init__(self, harness: "TestHarness", pcm_data: bytes, sample_rate: int = SAMPLE_RATE):
        self._harness = harness
        self._pcm = pcm_data
        self._sr = sample_rate
        self._bps = sample_rate * BYTES_PER_SAMPLE  # bytes per second of audio
        self._pos = 0  # playback cursor, bytes from start of self._pcm

    @property
    def position(self) -> float:
        """Current playback position in seconds."""
        return self._pos / self._bps

    @property
    def duration(self) -> float:
        """Total audio duration in seconds."""
        return len(self._pcm) / self._bps

    @property
    def remaining(self) -> float:
        """Remaining audio in seconds."""
        unplayed = len(self._pcm) - self._pos
        return max(0.0, unplayed / self._bps)

    @property
    def done(self) -> bool:
        """True if all audio has been played."""
        return self._pos >= len(self._pcm)

    def _align(self, byte_pos: int) -> int:
        """Round a byte offset down to a whole-sample boundary."""
        return (byte_pos // BYTES_PER_SAMPLE) * BYTES_PER_SAMPLE

    async def _emit(self, end_pos: int, speed: float, chunk_duration: float) -> None:
        """Feed audio from the cursor up to end_pos (sample-aligned), advancing the cursor."""
        if end_pos <= self._pos:
            return
        segment = self._pcm[self._pos:end_pos]
        self._pos = end_pos
        await self._harness.feed_pcm(segment, speed=speed, chunk_duration=chunk_duration)

    async def play(
        self,
        duration_s: Optional[float] = None,
        speed: float = 1.0,
        chunk_duration: float = 0.5,
    ) -> None:
        """Play audio from the current position.

        Args:
            duration_s: Seconds of audio to play. None = all remaining.
            speed: 1.0 = real-time, 0 = instant, >1 = faster.
            chunk_duration: Size of each chunk fed to the pipeline (seconds).
        """
        if duration_s is None:
            target = len(self._pcm)
        else:
            target = min(self._pos + int(duration_s * self._bps), len(self._pcm))
        await self._emit(self._align(target), speed, chunk_duration)

    async def play_until(
        self,
        time_s: float,
        speed: float = 1.0,
        chunk_duration: float = 0.5,
    ) -> None:
        """Play until reaching time_s in the audio timeline.

        A no-op when time_s is at or before the current position.
        """
        target = min(int(time_s * self._bps), len(self._pcm))
        await self._emit(self._align(target), speed, chunk_duration)

    def seek(self, time_s: float) -> None:
        """Move the playback cursor without feeding audio."""
        target = self._align(int(time_s * self._bps))
        self._pos = max(0, min(target, len(self._pcm)))

    def reset(self) -> None:
        """Reset to the beginning of the audio."""
        self._pos = 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestHarness — pipeline controller
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHarness:
    """In-process testing harness for the full WhisperLiveKit pipeline.

    Use as an async context manager. Provides methods to feed audio,
    pause/resume, inspect state, and evaluate results.

    Methods:
        load_audio(path) → AudioPlayer with play/seek controls
        feed(path, speed) → feed entire audio file (simple mode)
        pause(duration) → inject silence (triggers detection if > 5s)
        drain(seconds) → let pipeline catch up
        finish() → flush and return final state
        cut() → abrupt stop, return partial state
        wait_for(pred) → wait for condition on state

    State inspection:
        .state → current TestState
        .history → all historical states
        .snapshot_at(t) → state at audio position t
        .metrics → SessionMetrics (latency, RTF, etc.)

    Args:
        All keyword arguments passed to AudioProcessor.
        Common: model_size, lan, backend, diarization, vac.
    """

    def __init__(self, **kwargs: Any):
        # The harness feeds raw PCM directly, so default pcm_input on
        # unless the caller explicitly set it.
        kwargs.setdefault("pcm_input", True)
        self._engine_kwargs = kwargs
        self._processor = None      # AudioProcessor, created in __aenter__
        self._results_gen = None    # async generator of pipeline results
        self._collect_task = None   # background task consuming results
        self._state = TestState()   # most recent observed state
        self._audio_position = 0.0  # seconds of audio fed so far
        self._history: List[TestState] = []
        self._on_update: Optional[Callable[[TestState], None]] = None

    async def __aenter__(self) -> "TestHarness":
        """Create or reuse the engine, start the processor and result collector."""
        from whisperlivekit.audio_processor import AudioProcessor
        from whisperlivekit.core import TranscriptionEngine

        # Cache engines by config to avoid reloading models when switching
        # backends between tests. The singleton is reset only when the
        # requested config doesn't match any cached engine.
        cache_key = tuple(sorted(self._engine_kwargs.items()))

        if cache_key not in _engine_cache:
            TranscriptionEngine.reset()
            _engine_cache[cache_key] = TranscriptionEngine(**self._engine_kwargs)

        engine = _engine_cache[cache_key]

        self._processor = AudioProcessor(transcription_engine=engine)
        self._results_gen = await self._processor.create_tasks()
        self._collect_task = asyncio.create_task(self._collect_results())
        return self

    async def __aexit__(self, *exc: Any) -> None:
        """Clean up the processor and cancel the result collector."""
        if self._processor:
            await self._processor.cleanup()
        if self._collect_task and not self._collect_task.done():
            self._collect_task.cancel()
            # Awaiting the cancelled task only when it exists avoids
            # awaiting None when the harness was never entered fully.
            try:
                await self._collect_task
            except asyncio.CancelledError:
                pass

    async def _collect_results(self) -> None:
        """Background task: consume results from the pipeline."""
        try:
            async for front_data in self._results_gen:
                self._state = TestState.from_front_data(front_data, self._audio_position)
                self._history.append(self._state)
                if self._on_update:
                    self._on_update(self._state)
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.warning("Result collector ended: %s", e)

    # ── Properties ──

    @property
    def state(self) -> TestState:
        """Current transcription state (updated live as results arrive)."""
        return self._state

    @property
    def history(self) -> List[TestState]:
        """All states received so far, in order."""
        return self._history

    @property
    def audio_position(self) -> float:
        """How many seconds of audio have been fed so far."""
        return self._audio_position

    @property
    def metrics(self):
        """Pipeline's SessionMetrics (latency, RTF, token counts, etc.)."""
        if self._processor:
            return self._processor.metrics
        return None

    def on_update(self, callback: Callable[[TestState], None]) -> None:
        """Register a callback invoked on each new state update."""
        self._on_update = callback

    # ── Audio loading and feeding ──

    def load_audio(self, source) -> AudioPlayer:
        """Load audio and return a player with timeline control.

        Args:
            source: Path to audio file (str), or a TestSample with .path attribute.

        Returns:
            AudioPlayer with play/play_until/seek/reset methods.
        """
        path = source.path if hasattr(source, "path") else str(source)
        pcm = load_audio_pcm(path)
        return AudioPlayer(self, pcm)

    async def feed(
        self,
        audio_path: str,
        speed: float = 1.0,
        chunk_duration: float = 0.5,
    ) -> None:
        """Feed an entire audio file to the pipeline (simple mode).

        For timeline control (play/pause/resume), use load_audio() instead.

        Args:
            audio_path: Path to any audio file ffmpeg can decode.
            speed: Playback speed (1.0 = real-time, 0 = instant).
            chunk_duration: Size of each PCM chunk in seconds.
        """
        pcm = load_audio_pcm(audio_path)
        await self.feed_pcm(pcm, speed=speed, chunk_duration=chunk_duration)

    async def feed_pcm(
        self,
        pcm_data: bytes,
        speed: float = 1.0,
        chunk_duration: float = 0.5,
    ) -> None:
        """Feed raw PCM s16le 16kHz mono bytes to the pipeline.

        Args:
            pcm_data: Raw PCM bytes.
            speed: Playback speed multiplier (<= 0 disables pacing sleeps).
            chunk_duration: Duration of each chunk sent (seconds).
        """
        chunk_bytes = int(chunk_duration * SAMPLE_RATE * BYTES_PER_SAMPLE)
        offset = 0
        while offset < len(pcm_data):
            end = min(offset + chunk_bytes, len(pcm_data))
            await self._processor.process_audio(pcm_data[offset:end])
            chunk_seconds = (end - offset) / (SAMPLE_RATE * BYTES_PER_SAMPLE)
            self._audio_position += chunk_seconds
            offset = end
            if speed > 0:
                # Sleep for the actual chunk length (the final chunk may be
                # shorter than chunk_duration), scaled by playback speed.
                await asyncio.sleep(chunk_seconds / speed)

    # ── Pause / silence ──

    async def pause(self, duration_s: float, speed: float = 1.0) -> None:
        """Inject silence to simulate a pause in speech.

        Pauses > 5s trigger silence segment detection (MIN_DURATION_REAL_SILENCE).
        Pauses < 5s are treated as brief gaps and produce no silence segment
        (provided speech resumes afterward).

        Args:
            duration_s: Duration of silence in seconds.
            speed: Playback speed (1.0 = real-time, 0 = instant).
        """
        # bytes() of an int yields that many zero bytes == digital silence.
        silent_pcm = bytes(int(duration_s * SAMPLE_RATE * BYTES_PER_SAMPLE))
        await self.feed_pcm(silent_pcm, speed=speed)

    async def silence(self, duration_s: float, speed: float = 1.0) -> None:
        """Alias for pause(). Inject silence for the given duration."""
        await self.pause(duration_s, speed=speed)

    # ── Waiting ──

    async def wait_for(
        self,
        predicate: Callable[[TestState], bool],
        timeout: float = 30.0,
        poll_interval: float = 0.1,
    ) -> TestState:
        """Wait until predicate(state) returns True.

        Args:
            predicate: Called with the current TestState on each poll.
            timeout: Maximum seconds to wait.
            poll_interval: Seconds between polls.

        Raises:
            TimeoutError: If the condition is not met within timeout.
        """
        # get_running_loop() is the supported API inside a coroutine;
        # get_event_loop() is deprecated for this use.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while loop.time() < deadline:
            if predicate(self._state):
                return self._state
            await asyncio.sleep(poll_interval)
        raise TimeoutError(
            f"Condition not met within {timeout}s. "
            f"Current state: {len(self._state.lines)} lines, "
            f"buffer='{self._state.buffer_transcription[:50]}', "
            f"audio_pos={self._audio_position:.1f}s"
        )

    async def wait_for_text(self, timeout: float = 30.0) -> TestState:
        """Wait until any transcription text appears."""
        return await self.wait_for(lambda s: s.text.strip(), timeout=timeout)

    async def wait_for_lines(self, n: int = 1, timeout: float = 30.0) -> TestState:
        """Wait until at least n committed speech lines exist."""
        return await self.wait_for(lambda s: len(s.speech_lines) >= n, timeout=timeout)

    async def wait_for_silence(self, timeout: float = 30.0) -> TestState:
        """Wait until a silence segment is detected."""
        return await self.wait_for(lambda s: s.has_silence, timeout=timeout)

    async def wait_for_speakers(self, n: int = 2, timeout: float = 30.0) -> TestState:
        """Wait until at least n distinct speakers are detected."""
        return await self.wait_for(lambda s: s.n_speakers >= n, timeout=timeout)

    async def drain(self, seconds: float = 2.0) -> None:
        """Let the pipeline process without feeding audio.

        Useful after feeding audio to allow the ASR backend to catch up.
        """
        await asyncio.sleep(seconds)

    # ── Finishing ──

    async def finish(self, timeout: float = 30.0) -> TestState:
        """Signal end of audio and wait for pipeline to flush all results.

        Returns:
            Final TestState with all committed lines and empty buffer.
        """
        # An empty chunk is the pipeline's EOF signal.
        await self._processor.process_audio(b"")
        if self._collect_task:
            try:
                await asyncio.wait_for(self._collect_task, timeout=timeout)
            except asyncio.TimeoutError:
                logger.warning("Timed out waiting for pipeline to finish after %.0fs", timeout)
            except asyncio.CancelledError:
                pass
        return self._state

    async def cut(self, timeout: float = 5.0) -> TestState:
        """Abrupt audio stop — signal EOF and return current state quickly.

        Simulates user closing the connection mid-speech. Sends EOF but
        uses a short timeout, so partial results are returned even if
        the pipeline hasn't fully flushed.

        Returns:
            TestState with whatever has been processed so far.
        """
        await self._processor.process_audio(b"")
        if self._collect_task:
            try:
                await asyncio.wait_for(self._collect_task, timeout=timeout)
            except (asyncio.TimeoutError, asyncio.CancelledError):
                pass
        return self._state

    # ── History inspection ──

    def snapshot_at(self, audio_time: float) -> Optional[TestState]:
        """Find the historical state closest to when audio_time was reached.

        Args:
            audio_time: Audio position in seconds.

        Returns:
            The TestState captured at that point, or None if no history.
        """
        if not self._history:
            return None
        # min() returns the first minimal element, matching a linear scan
        # that keeps the earliest best match.
        return min(self._history, key=lambda s: abs(s.audio_position - audio_time))

    # ── Debug ──

    def print_state(self) -> None:
        """Print current state to stdout for debugging."""
        s = self._state
        print(f"--- Audio: {self._audio_position:.1f}s | Status: {s.status} ---")
        for line in s.lines:
            speaker = line.get("speaker", "?")
            text = line.get("text", "")
            start = line.get("start", "")
            end = line.get("end", "")
            tag = "SILENCE" if speaker == -2 else f"Speaker {speaker}"
            print(f"  [{start} -> {end}] {tag}: {text}")
        if s.buffer_transcription:
            print(f"  [buffer] {s.buffer_transcription}")
        if s.buffer_diarization:
            print(f"  [diar buffer] {s.buffer_diarization}")
        print(f"  Speakers: {s.speakers or 'none'} | Silence: {s.has_silence}")
        print()
|
||||
Reference in New Issue
Block a user