Files
WhisperLiveKit/whisperlivekit/test_client.py
2026-02-22 16:19:00 +01:00

394 lines
13 KiB
Python

"""Headless test client for WhisperLiveKit.

Feeds audio files to the transcription pipeline via WebSocket
and collects results — no browser or microphone needed.

Usage:
    # Against a running server (server must be started with --pcm-input):
    python -m whisperlivekit.test_client audio.wav

    # Custom server URL and speed:
    python -m whisperlivekit.test_client audio.wav --url ws://localhost:9090/asr --speed 0

    # Output raw JSON responses:
    python -m whisperlivekit.test_client audio.wav --json

    # Programmatic usage:
    import asyncio
    from whisperlivekit.test_client import transcribe_audio
    result = asyncio.run(transcribe_audio("audio.wav"))
    print(result.text)
"""
import argparse
import asyncio
import json
import logging
import subprocess
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional
# Module-level logger; main() sets the level and format via logging.basicConfig.
logger = logging.getLogger(__name__)
# Sample rate (Hz) of the PCM stream: the default "-ar" value handed to ffmpeg
# in load_audio_pcm, and the basis for duration and chunk-size arithmetic.
SAMPLE_RATE = 16000
# Size of one PCM sample in bytes (signed 16-bit little-endian).
BYTES_PER_SAMPLE = 2 # s16le
@dataclass
class TranscriptionResult:
    """Accumulates the server responses of one transcription session.

    ``responses`` holds every response dict in arrival order and
    ``audio_duration`` is the length of the submitted audio in seconds.
    The properties below derive convenient views from ``responses``.
    """

    responses: List[dict] = field(default_factory=list)
    audio_duration: float = 0.0

    @property
    def text(self) -> str:
        """Committed lines plus buffer text of the newest non-empty response."""
        for update in reversed(self.responses):
            committed = update.get("lines", [])
            pending = update.get("buffer_transcription", "")
            if not committed and not pending:
                # Nothing to show in this response; look at an older one.
                continue
            pieces = [entry["text"] for entry in committed if entry.get("text")]
            if pending:
                pieces.append(pending)
            return " ".join(pieces)
        return ""

    @property
    def committed_text(self) -> str:
        """Text of the committed (finalized) lines only, without the buffer."""
        return " ".join(
            entry["text"] for entry in self.lines if entry.get("text")
        )

    @property
    def lines(self) -> List[dict]:
        """Line dicts of the newest response that carries any lines."""
        newest_first = reversed(self.responses)
        return next(
            (update["lines"] for update in newest_first if update.get("lines")),
            [],
        )

    @property
    def n_updates(self) -> int:
        """Count of responses containing lines or buffer text."""
        return sum(
            bool(update.get("lines") or update.get("buffer_transcription"))
            for update in self.responses
        )
def reconstruct_state(msg: dict, lines: List[dict]) -> dict:
    """Fold a snapshot or diff message into the running line list.

    ``lines`` is updated in place. A snapshot replaces its contents and is
    returned unchanged; any other message is applied as a diff (drop
    ``lines_pruned`` entries from the front, append ``new_lines``) and a
    freshly built full-state dict shaped like a regular response is returned.
    """
    if msg.get("type") != "snapshot":
        # Diff: trim the oldest entries first, then add the new ones.
        drop_count = msg.get("lines_pruned", 0)
        if drop_count > 0:
            del lines[:drop_count]
        lines.extend(msg.get("new_lines", []))

        # Copy the list so later diffs do not alter the returned state.
        state = {"status": msg.get("status", ""), "lines": list(lines)}
        for key in ("buffer_transcription", "buffer_diarization", "buffer_translation"):
            state[key] = msg.get(key, "")
        for key in ("remaining_time_transcription", "remaining_time_diarization"):
            state[key] = msg.get(key, 0)
        return state

    # Snapshot: the message already carries the complete state.
    lines[:] = []
    lines.extend(msg.get("lines", []))
    return msg
def load_audio_pcm(audio_path: str, sample_rate: int = SAMPLE_RATE) -> bytes:
    """Decode an audio file to raw PCM (s16le, mono) by running ffmpeg.

    Any format ffmpeg can decode is accepted (wav, mp3, flac, ogg, m4a, ...).

    Args:
        audio_path: Path of the audio file to decode.
        sample_rate: Output sample rate in Hz, passed to ffmpeg's ``-ar``.

    Returns:
        The decoded audio as raw signed 16-bit little-endian mono bytes.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status or writes no
            audio data to stdout.
    """
    cmd = [
        # -nostdin: ffmpeg must not treat our stdin as interactive input.
        "ffmpeg", "-nostdin", "-i", str(audio_path),
        "-f", "s16le", "-acodec", "pcm_s16le",
        "-ar", str(sample_rate), "-ac", "1",
        "-loglevel", "error",
        "pipe:1",
    ]
    # Detach stdin so the child can neither consume nor block on the caller's
    # terminal or pipe input.
    proc = subprocess.run(cmd, stdin=subprocess.DEVNULL, capture_output=True)
    if proc.returncode != 0:
        # ffmpeg may echo undecodable bytes (e.g. file names); never let a
        # decoding error hide the real failure message.
        detail = proc.stderr.decode(errors="replace").strip()
        raise RuntimeError(f"ffmpeg conversion failed: {detail}")
    if not proc.stdout:
        raise RuntimeError(f"ffmpeg produced no output for {audio_path}")
    return proc.stdout
async def transcribe_audio(
    audio_path: str,
    url: str = "ws://localhost:8000/asr",
    chunk_duration: float = 0.5,
    speed: float = 1.0,
    timeout: float = 60.0,
    on_response: Optional[callable] = None,
    mode: str = "full",
) -> TranscriptionResult:
    """Feed an audio file to a running WhisperLiveKit server and collect results.

    Args:
        audio_path: Path to an audio file (any format ffmpeg supports).
        url: WebSocket URL of the /asr endpoint.
        chunk_duration: Duration of each audio chunk sent (seconds). Must be
            positive.
        speed: Playback speed multiplier (1.0 = real-time, 0 = as fast as possible).
        timeout: Max seconds to wait for the server after audio finishes.
        on_response: Optional callback invoked with each response dict as it arrives.
        mode: Output mode — "full" (default) or "diff" for incremental updates.
            In "diff" mode ``mode=diff`` is added to the URL unless the URL
            already requests it, and every message is expanded back into a
            full-state dict.

    Returns:
        TranscriptionResult with collected responses and convenience accessors.

    Raises:
        ValueError: If ``chunk_duration`` is not positive.
        RuntimeError: If the audio file cannot be converted to PCM.
    """
    import websockets
    from urllib.parse import parse_qs, urlsplit

    # A chunk size of zero (or less) would make the send loop spin forever on
    # empty frames, which the protocol uses as the end-of-audio marker.
    if chunk_duration <= 0:
        raise ValueError(f"chunk_duration must be positive, got {chunk_duration!r}")

    result = TranscriptionResult()

    # Convert audio to PCM for both modes (we need duration either way)
    pcm_data = load_audio_pcm(audio_path)
    result.audio_duration = len(pcm_data) / (SAMPLE_RATE * BYTES_PER_SAMPLE)
    logger.info("Loaded %s: %.1fs of audio", audio_path, result.audio_duration)

    # Keep chunks a whole number of samples (and at least one sample) so a
    # 16-bit sample is never split across two WebSocket frames.
    chunk_bytes = int(chunk_duration * SAMPLE_RATE * BYTES_PER_SAMPLE)
    chunk_bytes -= chunk_bytes % BYTES_PER_SAMPLE
    chunk_bytes = max(chunk_bytes, BYTES_PER_SAMPLE)

    # Request diff mode through the query string, unless the caller's URL
    # already does so (avoids a duplicated "mode=diff" parameter).
    connect_url = url
    if mode == "diff":
        requested_modes = parse_qs(urlsplit(url).query).get("mode", [])
        if "diff" not in requested_modes:
            sep = "&" if "?" in url else "?"
            connect_url = f"{url}{sep}mode=diff"

    async with websockets.connect(connect_url) as ws:
        # Server sends config on connect
        config_raw = await ws.recv()
        config_msg = json.loads(config_raw)
        is_pcm = config_msg.get("useAudioWorklet", False)
        logger.info("Server config: %s", config_msg)
        if not is_pcm:
            logger.warning(
                "Server is not in PCM mode. Start the server with --pcm-input "
                "for the test client. Attempting raw file streaming instead."
            )

        diff_lines: List[dict] = []  # running state for diff mode reconstruction

        async def send_audio():
            if is_pcm:
                offset = 0
                n_chunks = 0
                while offset < len(pcm_data):
                    end = min(offset + chunk_bytes, len(pcm_data))
                    await ws.send(pcm_data[offset:end])
                    offset = end
                    n_chunks += 1
                    # speed <= 0 means "no pacing": send as fast as possible.
                    if speed > 0:
                        await asyncio.sleep(chunk_duration / speed)
                logger.info("Sent %d PCM chunks (%.1fs)", n_chunks, result.audio_duration)
            else:
                # Non-PCM: send raw file bytes for server-side ffmpeg decoding
                file_bytes = Path(audio_path).read_bytes()
                raw_chunk_size = 32000
                offset = 0
                while offset < len(file_bytes):
                    end = min(offset + raw_chunk_size, len(file_bytes))
                    await ws.send(file_bytes[offset:end])
                    offset = end
                    if speed > 0:
                        await asyncio.sleep(0.5 / speed)
                logger.info("Sent %d bytes of raw audio", len(file_bytes))
            # Signal end of audio
            await ws.send(b"")
            logger.info("End-of-audio signal sent")

        async def receive_results():
            # Best-effort receiver: any failure simply ends collection and the
            # responses gathered so far are kept.
            try:
                async for raw_msg in ws:
                    data = json.loads(raw_msg)
                    if data.get("type") == "ready_to_stop":
                        logger.info("Server signaled ready_to_stop")
                        return
                    # In diff mode, reconstruct full state for uniform API
                    if mode == "diff" and data.get("type") in ("snapshot", "diff"):
                        data = reconstruct_state(data, diff_lines)
                    result.responses.append(data)
                    if on_response:
                        on_response(data)
            except Exception as e:
                logger.debug("Receiver ended: %s", e)

        send_task = asyncio.create_task(send_audio())
        recv_task = asyncio.create_task(receive_results())

        # Total wait = time to send + time for server to process + timeout margin
        send_time = result.audio_duration / speed if speed > 0 else 1.0
        total_timeout = send_time + timeout
        try:
            await asyncio.wait_for(
                asyncio.gather(send_task, recv_task),
                timeout=total_timeout,
            )
        except asyncio.TimeoutError:
            logger.warning("Timed out after %.0fs", total_timeout)
        finally:
            # Never leave a worker running, whether we finished, timed out,
            # the sender failed, or the caller cancelled us. cancel() is a
            # no-op on tasks that are already done.
            send_task.cancel()
            recv_task.cancel()
            await asyncio.gather(send_task, recv_task, return_exceptions=True)

    logger.info(
        "Session complete: %d responses, %d updates",
        len(result.responses), result.n_updates,
    )
    return result
def _print_result(result: TranscriptionResult, output_json: bool = False) -> None:
    """Write a session's transcription to stdout.

    With ``output_json`` every collected response is emitted as one JSON
    document per line and nothing else is printed. Otherwise the committed
    lines are printed with their time span (and speaker, when set and not 1),
    followed by the trailing buffer text and a one-line summary.
    """
    if output_json:
        for payload in map(json.dumps, result.responses):
            print(payload)
        return

    committed = result.lines
    for entry in committed:
        span = f"[{entry.get('start', '')} -> {entry.get('end', '')}]"
        who = entry.get("speaker", "")
        label = f"{span} Speaker {who}" if who and who != 1 else span
        print(f"{label} {entry.get('text', '')}")

    # Only the very last response is consulted for pending buffer text.
    if result.responses:
        tail = result.responses[-1].get("buffer_transcription", "")
    else:
        tail = ""
    if tail:
        print(f"[buffer] {tail}")

    if not (committed or tail):
        print("(no transcription received)")

    n_responses = len(result.responses)
    print(
        f"\n--- {n_responses} responses | "
        f"{result.n_updates} updates | "
        f"{result.audio_duration:.1f}s audio ---"
    )
def main():
    """Command-line entry point: stream one audio file and print the result.

    Parses the arguments, configures logging, checks that the audio file
    exists (exit status 1 otherwise), runs ``transcribe_audio`` and prints
    the outcome with ``_print_result``.
    """
    from urllib.parse import quote

    parser = argparse.ArgumentParser(
        prog="whisperlivekit-test-client",
        description=(
            "Headless test client for WhisperLiveKit. "
            "Feeds audio files via WebSocket and prints the transcription."
        ),
    )
    parser.add_argument("audio", help="Path to audio file (wav, mp3, flac, ...)")
    parser.add_argument(
        "--url", default="ws://localhost:8000/asr",
        help="WebSocket endpoint URL (default: ws://localhost:8000/asr)",
    )
    parser.add_argument(
        "--speed", type=float, default=1.0,
        help="Playback speed multiplier (1.0 = real-time, 0 = fastest, default: 1.0)",
    )
    parser.add_argument(
        "--chunk-duration", type=float, default=0.5,
        help="Chunk duration in seconds (default: 0.5)",
    )
    parser.add_argument(
        "--timeout", type=float, default=60.0,
        help="Max seconds to wait for server after audio ends (default: 60)",
    )
    parser.add_argument(
        "--language", "-l", default=None,
        help="Override transcription language for this session (e.g. en, fr, auto)",
    )
    parser.add_argument("--json", action="store_true", help="Output raw JSON responses")
    parser.add_argument(
        "--diff", action="store_true",
        help="Use diff protocol (only receive incremental changes from server)",
    )
    parser.add_argument(
        "--live", action="store_true",
        help="Print transcription updates as they arrive",
    )
    parser.add_argument("--verbose", "-v", action="store_true")
    args = parser.parse_args()

    # A non-positive chunk duration yields empty chunks and a send loop that
    # never advances, so reject it up front.
    if args.chunk_duration <= 0:
        parser.error("--chunk-duration must be positive")

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.WARNING,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    audio_path = Path(args.audio)
    if not audio_path.exists():
        print(f"Error: file not found: {audio_path}", file=sys.stderr)
        sys.exit(1)

    live_callback = None
    if args.live:
        def live_callback(data):
            # Redraw the current transcript on one terminal line.
            lines = data.get("lines", [])
            buf = data.get("buffer_transcription", "")
            parts = [line["text"] for line in lines if line.get("text")]
            if buf:
                parts.append(f"[{buf}]")
            if parts:
                print("\r" + " ".join(parts), end="", flush=True)

    # Only the language goes into the URL here: transcribe_audio() adds
    # "mode=diff" itself when called with mode="diff", so adding it here too
    # would duplicate the parameter.
    url = args.url
    if args.language:
        sep = "&" if "?" in url else "?"
        # Percent-encode so characters such as "&" or "#" cannot corrupt the query.
        url = f"{url}{sep}language={quote(args.language, safe='')}"

    result = asyncio.run(transcribe_audio(
        audio_path=str(audio_path),
        url=url,
        chunk_duration=args.chunk_duration,
        speed=args.speed,
        timeout=args.timeout,
        on_response=live_callback,
        mode="diff" if args.diff else "full",
    ))

    if args.live:
        print()  # newline after live output
    _print_result(result, output_json=args.json)
# Run the command-line interface only when this module is executed as a
# script (e.g. ``python -m whisperlivekit.test_client``), not when imported.
if __name__ == "__main__":
    main()