mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-21 16:40:35 +00:00
394 lines
13 KiB
Python
394 lines
13 KiB
Python
"""Headless test client for WhisperLiveKit.
|
|
|
|
Feeds audio files to the transcription pipeline via WebSocket
|
|
and collects results — no browser or microphone needed.
|
|
|
|
Usage:
|
|
# Against a running server (server must be started with --pcm-input):
|
|
python -m whisperlivekit.test_client audio.wav
|
|
|
|
# Custom server URL and speed:
|
|
python -m whisperlivekit.test_client audio.wav --url ws://localhost:9090/asr --speed 0
|
|
|
|
# Output raw JSON responses:
|
|
python -m whisperlivekit.test_client audio.wav --json
|
|
|
|
# Programmatic usage:
|
|
from whisperlivekit.test_client import transcribe_audio
|
|
result = asyncio.run(transcribe_audio("audio.wav"))
|
|
print(result.text)
|
|
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import subprocess
|
|
import sys
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import List, Optional
|
|
|
|
# Module-level logger; main() configures its verbosity via logging.basicConfig.
logger = logging.getLogger(__name__)

# PCM stream format: default resampling rate for load_audio_pcm(), also used
# by transcribe_audio() to derive audio duration and chunk size.
SAMPLE_RATE = 16000
BYTES_PER_SAMPLE = 2  # s16le: signed 16-bit little-endian samples
|
|
|
|
|
|
@dataclass
class TranscriptionResult:
    """Accumulates every server response received during one session.

    The accessors below never look at a single fixed response: they scan
    ``responses`` newest-first and use the first one that carries the data
    they need, so trailing empty messages do not hide earlier content.
    """

    # Raw (or reconstructed) response dicts, in arrival order.
    responses: List[dict] = field(default_factory=list)
    # Length of the decoded audio, in seconds.
    audio_duration: float = 0.0

    @property
    def text(self) -> str:
        """Committed lines plus pending buffer of the newest non-empty response.

        Returns an empty string when no response carries lines or buffer text.
        """
        for response in self.responses[::-1]:
            committed = response.get("lines", [])
            pending = response.get("buffer_transcription", "")
            if not (committed or pending):
                continue
            pieces = [entry["text"] for entry in committed if entry.get("text")]
            if pending:
                pieces.append(pending)
            return " ".join(pieces)
        return ""

    @property
    def committed_text(self) -> str:
        """Committed (finalized) line texts joined by spaces, without the buffer."""
        return " ".join(
            entry["text"] for entry in self.lines if entry.get("text")
        )

    @property
    def lines(self) -> List[dict]:
        """Committed lines of the newest response that has any (else ``[]``)."""
        return next(
            (
                response["lines"]
                for response in reversed(self.responses)
                if response.get("lines")
            ),
            [],
        )

    @property
    def n_updates(self) -> int:
        """Count of responses that carry committed lines or buffer text."""
        count = 0
        for response in self.responses:
            if response.get("lines") or response.get("buffer_transcription"):
                count += 1
        return count
|
|
|
|
|
|
def reconstruct_state(msg: dict, lines: List[dict]) -> dict:
    """Reconstruct full state from a diff or snapshot message.

    Mutates ``lines`` in-place (prune front, append new) and returns
    a full-state dict compatible with TranscriptionResult.

    Args:
        msg: Decoded server message. A message whose ``type`` is
            ``"snapshot"`` replaces the running state; any other message
            is applied as a diff (``lines_pruned`` / ``new_lines``).
        lines: Running list of committed lines, updated in place.

    Returns:
        ``msg`` itself for a snapshot; otherwise a new dict holding a copy
        of the updated lines plus the buffer/status fields taken from ``msg``.
    """
    if msg.get("type") == "snapshot":
        lines.clear()
        # ``or []`` tolerates an explicit JSON null as well as a missing key.
        lines.extend(msg.get("lines") or [])
        return msg

    # Apply diff. Null / missing counters and lists mean "nothing to do".
    n_pruned = msg.get("lines_pruned") or 0
    if n_pruned > 0:
        del lines[:n_pruned]
    lines.extend(msg.get("new_lines") or [])

    return {
        "status": msg.get("status", ""),
        # Copy so later diffs do not rewrite states already handed out.
        "lines": list(lines),
        "buffer_transcription": msg.get("buffer_transcription", ""),
        "buffer_diarization": msg.get("buffer_diarization", ""),
        "buffer_translation": msg.get("buffer_translation", ""),
        "remaining_time_transcription": msg.get("remaining_time_transcription", 0),
        "remaining_time_diarization": msg.get("remaining_time_diarization", 0),
    }
|
|
|
|
|
|
def load_audio_pcm(audio_path: str, sample_rate: int = SAMPLE_RATE) -> bytes:
    """Load an audio file and convert to PCM s16le mono via ffmpeg.

    Supports any format ffmpeg can decode (wav, mp3, flac, ogg, m4a, ...).

    Args:
        audio_path: Path of the file to decode.
        sample_rate: Output sample rate in Hz.

    Returns:
        The raw PCM bytes that ffmpeg wrote to stdout.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status or produces no
            output.
        FileNotFoundError: Raised by ``subprocess`` when the ``ffmpeg``
            executable cannot be found.
    """
    cmd = [
        "ffmpeg",
        # Never let ffmpeg treat our stdin as an interactive console.
        "-nostdin",
        "-i", str(audio_path),
        "-f", "s16le", "-acodec", "pcm_s16le",
        "-ar", str(sample_rate), "-ac", "1",
        "-loglevel", "error",
        "pipe:1",
    ]
    # stdin is detached so ffmpeg cannot swallow input meant for the caller.
    proc = subprocess.run(cmd, capture_output=True, stdin=subprocess.DEVNULL)
    if proc.returncode != 0:
        # ffmpeg may echo file names/metadata in arbitrary encodings; never
        # let a decode error mask the actual failure message.
        detail = proc.stderr.decode(errors="replace").strip()
        raise RuntimeError(f"ffmpeg conversion failed: {detail}")
    if not proc.stdout:
        raise RuntimeError(f"ffmpeg produced no output for {audio_path}")
    return proc.stdout
|
|
|
|
|
|
async def transcribe_audio(
    audio_path: str,
    url: str = "ws://localhost:8000/asr",
    chunk_duration: float = 0.5,
    speed: float = 1.0,
    timeout: float = 60.0,
    on_response: Optional[callable] = None,
    mode: str = "full",
) -> "TranscriptionResult":
    """Feed an audio file to a running WhisperLiveKit server and collect results.

    Args:
        audio_path: Path to an audio file (any format ffmpeg supports).
        url: WebSocket URL of the /asr endpoint.
        chunk_duration: Duration of each audio chunk sent (seconds). Must be
            greater than zero.
        speed: Playback speed multiplier (1.0 = real-time, 0 = as fast as possible).
        timeout: Max seconds to wait for the server after audio finishes.
        on_response: Optional callback invoked with each response dict as it arrives.
        mode: Output mode — "full" (default) or "diff" for incremental updates.

    Returns:
        TranscriptionResult with collected responses and convenience accessors.

    Raises:
        ValueError: If ``chunk_duration`` is not positive.
        RuntimeError: If ffmpeg cannot decode ``audio_path`` (from load_audio_pcm).
    """
    # Validate before any I/O: a non-positive chunk size would keep the send
    # loop from ever advancing through the audio.
    if chunk_duration <= 0:
        raise ValueError(f"chunk_duration must be positive, got {chunk_duration!r}")

    from urllib.parse import parse_qs, urlsplit

    import websockets

    result = TranscriptionResult()

    # Convert audio to PCM for both modes (we need duration either way)
    pcm_data = load_audio_pcm(audio_path)
    result.audio_duration = len(pcm_data) / (SAMPLE_RATE * BYTES_PER_SAMPLE)
    logger.info("Loaded %s: %.1fs of audio", audio_path, result.audio_duration)

    # Size chunks in whole samples (at least one) so that no message ever
    # ends in the middle of a 16-bit sample.
    chunk_bytes = max(int(chunk_duration * SAMPLE_RATE), 1) * BYTES_PER_SAMPLE

    # Request diff mode via query parameter, unless the caller's URL already does.
    connect_url = url
    if mode == "diff" and "diff" not in parse_qs(urlsplit(url).query).get("mode", []):
        sep = "&" if "?" in url else "?"
        connect_url = f"{url}{sep}mode=diff"

    async with websockets.connect(connect_url) as ws:
        # Server sends config on connect
        config_raw = await ws.recv()
        config_msg = json.loads(config_raw)
        # The "useAudioWorklet" config flag is what this client treats as PCM mode.
        is_pcm = config_msg.get("useAudioWorklet", False)
        logger.info("Server config: %s", config_msg)

        if not is_pcm:
            logger.warning(
                "Server is not in PCM mode. Start the server with --pcm-input "
                "for the test client. Attempting raw file streaming instead."
            )

        diff_lines: List[dict] = []  # running state for diff mode reconstruction

        async def send_audio():
            """Stream the audio in chunks, then send the end-of-audio marker."""
            if is_pcm:
                n_chunks = 0
                for offset in range(0, len(pcm_data), chunk_bytes):
                    await ws.send(pcm_data[offset:offset + chunk_bytes])
                    n_chunks += 1
                    if speed > 0:
                        # Pace the stream to emulate (scaled) real-time capture.
                        await asyncio.sleep(chunk_duration / speed)
                logger.info("Sent %d PCM chunks (%.1fs)", n_chunks, result.audio_duration)
            else:
                # Non-PCM: send raw file bytes for server-side ffmpeg decoding
                file_bytes = Path(audio_path).read_bytes()
                raw_chunk_size = 32000
                for offset in range(0, len(file_bytes), raw_chunk_size):
                    await ws.send(file_bytes[offset:offset + raw_chunk_size])
                    if speed > 0:
                        await asyncio.sleep(0.5 / speed)
                logger.info("Sent %d bytes of raw audio", len(file_bytes))

            # Signal end of audio
            await ws.send(b"")
            logger.info("End-of-audio signal sent")

        async def receive_results():
            """Collect responses until the server reports ready_to_stop."""
            try:
                async for raw_msg in ws:
                    data = json.loads(raw_msg)
                    if data.get("type") == "ready_to_stop":
                        logger.info("Server signaled ready_to_stop")
                        return
                    # In diff mode, reconstruct full state for uniform API
                    if mode == "diff" and data.get("type") in ("snapshot", "diff"):
                        data = reconstruct_state(data, diff_lines)
                    result.responses.append(data)
                    if on_response:
                        on_response(data)
            except Exception as e:
                # Best effort: a dropped connection ends collection, keeping
                # whatever was received so far.
                logger.debug("Receiver ended: %s", e)

        send_task = asyncio.create_task(send_audio())
        recv_task = asyncio.create_task(receive_results())

        # Total wait = time to send + time for server to process + timeout margin
        send_time = result.audio_duration / speed if speed > 0 else 1.0
        total_timeout = send_time + timeout

        try:
            await asyncio.wait_for(
                asyncio.gather(send_task, recv_task),
                timeout=total_timeout,
            )
        except asyncio.TimeoutError:
            logger.warning("Timed out after %.0fs", total_timeout)
        finally:
            # Make sure neither task outlives the connection, including when
            # the sender failed. Cancelling a finished task is a no-op.
            send_task.cancel()
            recv_task.cancel()
            await asyncio.gather(send_task, recv_task, return_exceptions=True)

    logger.info(
        "Session complete: %d responses, %d updates",
        len(result.responses), result.n_updates,
    )
    return result
|
|
|
|
|
|
def _print_result(result: TranscriptionResult, output_json: bool = False) -> None:
    """Write a session's transcription to stdout.

    With ``output_json`` every collected response is dumped as one JSON
    document per line and nothing else is printed. Otherwise the committed
    lines are printed with their time span (and speaker, unless it is 1),
    followed by the buffer of the very last response and a summary footer.
    """
    if output_json:
        for response in result.responses:
            print(json.dumps(response))
        return

    committed = result.lines
    for entry in committed:
        start = entry.get("start", "")
        end = entry.get("end", "")
        text = entry.get("text", "")
        speaker = entry.get("speaker", "")

        prefix = f"[{start} -> {end}]"
        show_speaker = bool(speaker) and speaker != 1
        if show_speaker:
            prefix += f" Speaker {speaker}"
        print(f"{prefix} {text}")

    # Only the final response's buffer counts as still-pending text.
    pending = (
        result.responses[-1].get("buffer_transcription", "")
        if result.responses
        else ""
    )
    if pending:
        print(f"[buffer] {pending}")

    if not (committed or pending):
        print("(no transcription received)")

    summary = (
        f"\n--- {len(result.responses)} responses | "
        f"{result.n_updates} updates | "
        f"{result.audio_duration:.1f}s audio ---"
    )
    print(summary)
|
|
|
|
|
|
def main():
    """Command-line entry point.

    Parses the arguments, streams the given audio file to the server via
    transcribe_audio() and prints the outcome with _print_result().
    Exits with status 1 when the audio file does not exist and with
    argparse's usage error (status 2) on invalid arguments.
    """
    from urllib.parse import urlencode

    parser = argparse.ArgumentParser(
        prog="whisperlivekit-test-client",
        description=(
            "Headless test client for WhisperLiveKit. "
            "Feeds audio files via WebSocket and prints the transcription."
        ),
    )
    parser.add_argument("audio", help="Path to audio file (wav, mp3, flac, ...)")
    parser.add_argument(
        "--url", default="ws://localhost:8000/asr",
        help="WebSocket endpoint URL (default: ws://localhost:8000/asr)",
    )
    parser.add_argument(
        "--speed", type=float, default=1.0,
        help="Playback speed multiplier (1.0 = real-time, 0 = fastest, default: 1.0)",
    )
    parser.add_argument(
        "--chunk-duration", type=float, default=0.5,
        help="Chunk duration in seconds (default: 0.5)",
    )
    parser.add_argument(
        "--timeout", type=float, default=60.0,
        help="Max seconds to wait for server after audio ends (default: 60)",
    )
    parser.add_argument(
        "--language", "-l", default=None,
        help="Override transcription language for this session (e.g. en, fr, auto)",
    )
    parser.add_argument("--json", action="store_true", help="Output raw JSON responses")
    parser.add_argument(
        "--diff", action="store_true",
        help="Use diff protocol (only receive incremental changes from server)",
    )
    parser.add_argument(
        "--live", action="store_true",
        help="Print transcription updates as they arrive",
    )
    parser.add_argument("--verbose", "-v", action="store_true")

    args = parser.parse_args()

    # A zero or negative chunk would never advance through the audio.
    if args.chunk_duration <= 0:
        parser.error("--chunk-duration must be greater than 0")

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.WARNING,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    audio_path = Path(args.audio)
    if not audio_path.exists():
        print(f"Error: file not found: {audio_path}", file=sys.stderr)
        sys.exit(1)

    live_callback = None
    if args.live:
        def live_callback(data):
            """Redraw the current transcript on one line; buffer shown in brackets."""
            lines = data.get("lines", [])
            buf = data.get("buffer_transcription", "")
            parts = [line["text"] for line in lines if line.get("text")]
            if buf:
                parts.append(f"[{buf}]")
            if parts:
                print("\r" + " ".join(parts), end="", flush=True)

    # Add the language override to the URL. The diff-mode parameter is NOT
    # added here: transcribe_audio() appends it itself when mode="diff".
    url = args.url
    if args.language:
        sep = "&" if "?" in url else "?"
        # urlencode escapes characters that would otherwise corrupt the query.
        url = f"{url}{sep}{urlencode({'language': args.language})}"

    result = asyncio.run(transcribe_audio(
        audio_path=str(audio_path),
        url=url,
        chunk_duration=args.chunk_duration,
        speed=args.speed,
        timeout=args.timeout,
        on_response=live_callback,
        mode="diff" if args.diff else "full",
    ))

    if args.live:
        print()  # newline after live output

    _print_result(result, output_json=args.json)
|
|
|
|
|
|
# Run the CLI when executed as a script or via ``python -m``; importing the
# module has no side effects beyond the definitions above.
if __name__ == "__main__":
    main()
|