From d59ddbaeae32417daa2be95f5cd8234356d1245f Mon Sep 17 00:00:00 2001 From: Emmanuel Schmidbauer Date: Fri, 9 Jan 2026 11:23:19 -0500 Subject: [PATCH] Fix critical thread safety issues --- whisperlivekit/core.py | 18 ++- whisperlivekit/simul_whisper/decoder_state.py | 19 ++- whisperlivekit/simul_whisper/simul_whisper.py | 6 +- whisperlivekit/thread_safety.py | 139 ++++++++++++++++++ 4 files changed, 174 insertions(+), 8 deletions(-) create mode 100644 whisperlivekit/thread_safety.py diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index 978329b..133ccbf 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -1,5 +1,6 @@ import logging import sys +import threading from argparse import Namespace from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor @@ -19,16 +20,26 @@ logger = logging.getLogger(__name__) class TranscriptionEngine: _instance = None _initialized = False + _lock = threading.Lock() # Thread-safe singleton lock def __new__(cls, *args, **kwargs): + # Double-checked locking pattern for thread-safe singleton if cls._instance is None: - cls._instance = super().__new__(cls) + with cls._lock: + # Check again inside lock to prevent race condition + if cls._instance is None: + cls._instance = super().__new__(cls) return cls._instance def __init__(self, **kwargs): - if TranscriptionEngine._initialized: - return + # Thread-safe initialization check + with TranscriptionEngine._lock: + if TranscriptionEngine._initialized: + return + # Set flag immediately to prevent re-initialization + TranscriptionEngine._initialized = True + # Perform initialization outside lock to avoid holding lock during slow operations global_params = { "host": "localhost", "port": 8000, @@ -172,7 +183,6 @@ class TranscriptionEngine: } translation_params = update_with_kwargs(translation_params, kwargs) self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers 
- TranscriptionEngine._initialized = True def online_factory(args, asr): diff --git a/whisperlivekit/simul_whisper/decoder_state.py b/whisperlivekit/simul_whisper/decoder_state.py index bbab43b..428e798 100644 --- a/whisperlivekit/simul_whisper/decoder_state.py +++ b/whisperlivekit/simul_whisper/decoder_state.py @@ -47,9 +47,24 @@ class DecoderState: def clean_cache(self): """Clean the kv_cache after each inference step.""" - self.kv_cache = {} + # Explicitly delete tensor references to free GPU memory + if self.kv_cache: + for key in list(self.kv_cache.keys()): + tensor = self.kv_cache.pop(key, None) + if tensor is not None: + del tensor + + # Clear the dict + self.kv_cache.clear() + + # Force GPU cache cleanup (only if CUDA is available) + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + if self.decoder_type == "beam" and self.inference is not None: - self.inference.kv_cache = self.kv_cache + # Create NEW dict instead of sharing reference + self.inference.kv_cache = {} if self.token_decoder is not None: self.token_decoder.reset() diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index 104db15..174e806 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ b/whisperlivekit/simul_whisper/simul_whisper.py @@ -626,8 +626,10 @@ class AlignAtt: try: current_timestamp = l_absolute_timestamps[timestamp_idx] - except: - pass + except IndexError: + # Use last timestamp if index out of range + logger.warning(f"Timestamp index {timestamp_idx} out of range, using last timestamp") + current_timestamp = l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0 timestamp_idx += len(word_tokens) timestamp_entry = ASRToken( diff --git a/whisperlivekit/thread_safety.py b/whisperlivekit/thread_safety.py new file mode 100644 index 0000000..18e8303 --- /dev/null +++ b/whisperlivekit/thread_safety.py @@ -0,0 +1,139 @@ +""" +Thread Safety Configuration for WhisperLiveKit + +This module provides 
thread safety configuration and utilities. + +Environment Variables: + WHISPERLIVEKIT_MODEL_LOCK: Enable/disable model locking (default: 1) + Set to "0" to disable for single-connection deployments + + WHISPERLIVEKIT_LOCK_TIMEOUT: Lock acquisition timeout in seconds (default: 30) + +Usage: + # Enable model locking (default) + export WHISPERLIVEKIT_MODEL_LOCK=1 + + # Disable for single-connection deployment + export WHISPERLIVEKIT_MODEL_LOCK=0 + + # Custom timeout + export WHISPERLIVEKIT_LOCK_TIMEOUT=60 +""" + +import os +import logging +import threading + +logger = logging.getLogger(__name__) + +# Configuration +USE_MODEL_LOCK = os.environ.get("WHISPERLIVEKIT_MODEL_LOCK", "1") == "1" +LOCK_TIMEOUT = float(os.environ.get("WHISPERLIVEKIT_LOCK_TIMEOUT", "30.0")) + +# Global model lock +_model_lock = threading.Lock() + +# Log configuration on import +if USE_MODEL_LOCK: + logger.info(f"Model locking ENABLED (timeout: {LOCK_TIMEOUT}s)") + logger.info("For single-connection deployments, set WHISPERLIVEKIT_MODEL_LOCK=0") +else: + logger.warning("Model locking DISABLED - only safe for single-connection deployments") + + +def get_model_lock(): + """Get the global model lock instance""" + return _model_lock + + +def acquire_model_lock(timeout=None): + """ + Acquire model lock with timeout. 
+ + Args: + timeout: Lock acquisition timeout (default: use LOCK_TIMEOUT) + + Returns: + bool: True if lock acquired, False on timeout + """ + if not USE_MODEL_LOCK: + return True + + timeout = LOCK_TIMEOUT if timeout is None else timeout + acquired = _model_lock.acquire(timeout=timeout) + + if not acquired: + logger.error(f"Failed to acquire model lock within {timeout}s") + + return acquired + + +def release_model_lock(): + """Release model lock""" + if not USE_MODEL_LOCK: + return + + try: + _model_lock.release() + except RuntimeError: + # Lock not held - this is fine + pass + + +class ModelLockContext: + """Context manager for model lock""" + + def __init__(self, timeout=None): + self.timeout = timeout + self.acquired = False + + def __enter__(self): + self.acquired = acquire_model_lock(self.timeout) + return self.acquired + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.acquired: + release_model_lock() + return False + + +# Concurrency recommendations +RECOMMENDED_CONNECTIONS_PER_WORKER = 1 if USE_MODEL_LOCK else 1 +RECOMMENDED_WORKERS = 4 + +def print_deployment_recommendations(): + """Print recommended deployment configuration""" + print("\n" + "="*60) + print("WhisperLiveKit Deployment Recommendations") + print("="*60) + + if USE_MODEL_LOCK: + print("⚠️ Model locking is ENABLED") + print(" This serializes inference across connections.") + print() + print("Recommended deployment:") + print(f" gunicorn -w {RECOMMENDED_WORKERS} \\") + print(" -k uvicorn.workers.UvicornWorker \\") + print(" --worker-connections 1 \\") + print(" whisperlivekit.basic_server:app") + print() + print("Expected capacity:") + print(f" - {RECOMMENDED_WORKERS} concurrent users (1 per worker)") + print(f" - Memory: ~{RECOMMENDED_WORKERS}x model size") + else: + print("✅ Model locking is DISABLED") + print(" ⚠️ ONLY safe for single-connection deployments") + print() + print("Recommended deployment:") + print(" uvicorn whisperlivekit.basic_server:app \\") + print(" --host 0.0.0.0 --port 8000 \\") + 
print(" --workers 1") + print() + print("Expected capacity:") + print(" - 1 concurrent user only") + + print("="*60 + "\n") + + +if __name__ == "__main__": + print_deployment_recommendations()