16 Commits

Author SHA1 Message Date
Quentin Fuxa
e7e82f7c19 bump to 0.2.18 2026-02-11 22:10:00 +01:00
Quentin Fuxa
8c799fa4d1 fix simulstreaming vram leak: cap cross-attn accumulation + token budget
fixes #283, fixes #275

- accumulated_cross_attns was growing unboundedly during decoding loop,
  using up to ~5GB for repetition loops. now capped to rolling window of 16
- max_tokens_per_chunk was using TOKENS_PER_SECOND (mel frame rate = 50)
  instead of actual text token rate (~15/s), allowing 10-40x too many
  decoding steps
- removed unused torch.cat on early return path
- removed dead self.committed/last_result_tokens lists (never read)
- same fixes applied to mlx variant
2026-02-11 22:10:00 +01:00
Quentin Fuxa
8923337380 fix --direct-english-translation not setting task=translate for localagreement backends
the flag was only used for tokenizer language selection but never
actually passed to whisper/faster-whisper transcribe calls. also init
OpenaiApiASR.task and read from transcribe_kargs.

fixes #306
2026-02-11 22:10:00 +01:00
Quentin Fuxa
aded1649ae fix model_cache_dir + direct_english_translation task in simulstreaming
pass actual cache dir instead of None, and use proper task string
instead of boolean for AlignAttConfig

fixes #310
2026-02-11 22:10:00 +01:00
Quentin Fuxa
3b535e857a fix NoneType concatenation in add_translation
fixes #296
2026-02-11 22:10:00 +01:00
Quentin Fuxa
d649250b9a fix Segment classmethod call + isinstance type narrowing
fixes #331, fixes #329
2026-02-11 22:10:00 +01:00
Quentin Fuxa
7735478286 add insert_audio_chunk to DiartDiarization
fixes #332
2026-02-11 22:10:00 +01:00
Quentin Fuxa
b9e72d2b9a add probability field to ASRToken
fixes #330, fixes #313
2026-02-11 22:10:00 +01:00
Quentin Fuxa
e5b01033af add json normalizers for english language in build 2026-01-16 10:47:46 +01:00
Quentin Fuxa
6ae545bcb1 bump to 0.2.17.post1 2026-01-16 10:43:52 +01:00
Quentin Fuxa
04980d3f5e Merge branch 'main' of https://github.com/QuentinFuxa/WhisperLiveKit 2026-01-16 10:38:29 +01:00
Quentin Fuxa
79a705c969 fixes #323 2026-01-16 10:38:07 +01:00
Quentin Fuxa
34e4abd455 Merge pull request #322 from eschmidbauer/fix/thread-safety-issues
Fix kv cache not being properly cleaned between sessions
2026-01-09 19:23:35 +01:00
Emmanuel Schmidbauer
d59ddbaeae Fix critical thread safety issues 2026-01-09 11:23:19 -05:00
Quentin Fuxa
4dd66e7766 Merge pull request #317 from jantonj/fix-bug-diarization-lag
update diarization lag after stream analysed
2025-12-19 17:43:07 +01:00
Anton Jacobson
3db5d81a20 update diarization lag after stream analysed 2025-12-18 14:13:28 +01:00
13 changed files with 259 additions and 77 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "whisperlivekit"
version = "0.2.17"
version = "0.2.18"
description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md"
authors = [
@@ -57,6 +57,7 @@ packages = [
"whisperlivekit",
"whisperlivekit.diarization",
"whisperlivekit.simul_whisper",
"whisperlivekit.simul_whisper.mlx",
"whisperlivekit.whisper",
"whisperlivekit.whisper.assets",
"whisperlivekit.whisper.normalizers",
@@ -68,4 +69,5 @@ packages = [
[tool.setuptools.package-data]
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
"whisperlivekit.whisper.normalizers" = ["*.json"]
"whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]

View File

@@ -32,7 +32,7 @@ async def get_all_from_queue(queue: asyncio.Queue) -> Union[object, Silence, np.
if isinstance(first_item, Silence):
return first_item
items.append(first_item)
while True:
if not queue._queue:
break
@@ -53,15 +53,15 @@ class AudioProcessor:
Processes audio streams for transcription and diarization.
Handles audio processing, state management, and result formatting.
"""
def __init__(self, **kwargs: Any) -> None:
"""Initialize the audio processor with configuration, models, and state."""
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
models = kwargs['transcription_engine']
else:
models = TranscriptionEngine(**kwargs)
# Audio processing settings
self.args = models.args
self.sample_rate = 16000
@@ -86,13 +86,13 @@ class AudioProcessor:
# Models and processing
self.asr: Any = models.asr
self.vac: Optional[FixedVADIterator] = None
if self.args.vac:
if models.vac_session is not None:
vac_model = OnnxWrapper(session=models.vac_session)
self.vac = FixedVADIterator(vac_model)
else:
self.vac = FixedVADIterator(load_jit_vad())
self.vac = FixedVADIterator(load_jit_vad())
self.ffmpeg_manager: Optional[FFmpegManager] = None
self.ffmpeg_reader_task: Optional[asyncio.Task] = None
self._ffmpeg_error: Optional[str] = None
@@ -106,7 +106,7 @@ class AudioProcessor:
logger.error(f"FFmpeg error: {error_type}")
self._ffmpeg_error = error_type
self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
self.transcription_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.transcription else None
self.diarization_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.diarization else None
self.translation_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.target_language else None
@@ -117,14 +117,14 @@ class AudioProcessor:
self.translation_task: Optional[asyncio.Task] = None
self.watchdog_task: Optional[asyncio.Task] = None
self.all_tasks_for_cleanup: List[asyncio.Task] = []
self.transcription: Optional[Any] = None
self.translation: Optional[Any] = None
self.diarization: Optional[Any] = None
if self.args.transcription:
self.transcription = online_factory(self.args, models.asr)
self.sep = self.transcription.asr.sep
self.transcription = online_factory(self.args, models.asr)
self.sep = self.transcription.asr.sep
if self.args.diarization:
self.diarization = online_diarization_factory(self.args, models.diarization_model)
if models.translation_model:
@@ -182,24 +182,24 @@ class AudioProcessor:
def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray:
"""Convert PCM buffer in s16le format to normalized NumPy array."""
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
async def get_current_state(self) -> State:
"""Get current state."""
async with self.lock:
current_time = time()
remaining_transcription = 0
if self.state.end_buffer > 0:
remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1))
remaining_diarization = 0
if self.state.tokens:
latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1))
self.state.remaining_time_transcription = remaining_transcription
self.state.remaining_time_diarization = remaining_diarization
return self.state
async def ffmpeg_stdout_reader(self) -> None:
@@ -255,7 +255,7 @@ class AudioProcessor:
async def transcription_processor(self) -> None:
"""Process audio chunks for transcription."""
cumulative_pcm_duration_stream_time = 0.0
while True:
try:
# item = await self.transcription_queue.get()
@@ -311,12 +311,12 @@ class AudioProcessor:
if new_tokens:
candidate_end_times.append(new_tokens[-1].end)
if _buffer_transcript.end is not None:
candidate_end_times.append(_buffer_transcript.end)
candidate_end_times.append(current_audio_processed_upto)
async with self.lock:
self.state.tokens.extend(new_tokens)
self.state.buffer_transcription = _buffer_transcript
@@ -326,13 +326,13 @@ class AudioProcessor:
if self.translation_queue:
for token in new_tokens:
await self.translation_queue.put(token)
await self.translation_queue.put(token)
except Exception as e:
logger.warning(f"Exception in transcription_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
if 'pcm_array' in locals() and pcm_array is not SENTINEL : # Check if pcm_array was assigned from queue
self.transcription_queue.task_done()
if self.is_stopping:
logger.info("Transcription processor finishing due to stopping flag.")
if self.diarization_queue:
@@ -353,18 +353,21 @@ class AudioProcessor:
if item.has_ended:
self.diarization.insert_silence(item.duration)
continue
self.diarization.insert_audio_chunk(item)
diarization_segments = await self.diarization.diarize()
self.state.new_diarization = diarization_segments
diar_end = 0.0
if diarization_segments:
diar_end = max(getattr(s, "end", 0.0) for s in diarization_segments)
async with self.lock:
self.state.new_diarization = diarization_segments
self.state.end_attributed_speaker = max(self.state.end_attributed_speaker, diar_end)
except Exception as e:
logger.warning(f"Exception in diarization_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
logger.info("Diarization processor task finished.")
async def translation_processor(self) -> None:
# the idea is to ignore diarization for the moment. We use only transcription tokens.
# the idea is to ignore diarization for the moment. We use only transcription tokens.
# And the speaker is attributed given the segments used for the translation
# in the future we want to have different languages for each speaker etc, so it will be more complex.
while True:
@@ -426,22 +429,22 @@ class AudioProcessor:
remaining_time_transcription=state.remaining_time_transcription,
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
)
should_push = (response != self.last_response_content)
if should_push:
yield response
self.last_response_content = response
if self.is_stopping and self._processing_tasks_done():
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
return
await asyncio.sleep(0.05)
except Exception as e:
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
await asyncio.sleep(0.5)
async def create_tasks(self) -> AsyncGenerator[FrontData, None]:
"""Create and start processing tasks."""
self.all_tasks_for_cleanup = []
@@ -466,21 +469,21 @@ class AudioProcessor:
self.transcription_task = asyncio.create_task(self.transcription_processor())
self.all_tasks_for_cleanup.append(self.transcription_task)
processing_tasks_for_watchdog.append(self.transcription_task)
if self.diarization:
self.diarization_task = asyncio.create_task(self.diarization_processor())
self.all_tasks_for_cleanup.append(self.diarization_task)
processing_tasks_for_watchdog.append(self.diarization_task)
if self.translation:
self.translation_task = asyncio.create_task(self.translation_processor())
self.all_tasks_for_cleanup.append(self.translation_task)
processing_tasks_for_watchdog.append(self.translation_task)
# Monitor overall system health
self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog))
self.all_tasks_for_cleanup.append(self.watchdog_task)
return self.results_formatter()
async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None:
@@ -493,7 +496,7 @@ class AudioProcessor:
return
await asyncio.sleep(10)
for i, task in enumerate(list(tasks_remaining)):
if task.done():
exc = task.exception()
@@ -503,13 +506,13 @@ class AudioProcessor:
else:
logger.info(f"{task_name} completed normally.")
tasks_remaining.remove(task)
except asyncio.CancelledError:
logger.info("Watchdog task cancelled.")
break
except Exception as e:
logger.error(f"Error in watchdog task: {e}", exc_info=True)
async def cleanup(self) -> None:
"""Clean up resources when processing is complete."""
logger.info("Starting cleanup of AudioProcessor resources.")
@@ -517,7 +520,7 @@ class AudioProcessor:
for task in self.all_tasks_for_cleanup:
if task and not task.done():
task.cancel()
created_tasks = [t for t in self.all_tasks_for_cleanup if t]
if created_tasks:
await asyncio.gather(*created_tasks, return_exceptions=True)
@@ -555,7 +558,7 @@ class AudioProcessor:
if not message:
logger.info("Empty audio message received, initiating stop sequence.")
self.is_stopping = True
if self.transcription_queue:
await self.transcription_queue.put(SENTINEL)
@@ -596,7 +599,7 @@ class AudioProcessor:
chunk_size = min(len(self.pcm_buffer), self.max_bytes_per_sec)
aligned_chunk_size = (chunk_size // self.bytes_per_sample) * self.bytes_per_sample
if aligned_chunk_size == 0:
return
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size])
@@ -613,7 +616,7 @@ class AudioProcessor:
if res is not None:
if "start" in res and self.current_silence:
await self._end_silence()
if "end" in res and not self.current_silence:
pre_silence_chunk = self._slice_before_silence(
pcm_array, chunk_sample_start, res.get("end")

View File

@@ -1,5 +1,6 @@
import logging
import sys
import threading
from argparse import Namespace
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
@@ -19,16 +20,26 @@ logger = logging.getLogger(__name__)
class TranscriptionEngine:
_instance = None
_initialized = False
_lock = threading.Lock() # Thread-safe singleton lock
def __new__(cls, *args, **kwargs):
# Double-checked locking pattern for thread-safe singleton
if cls._instance is None:
cls._instance = super().__new__(cls)
with cls._lock:
# Check again inside lock to prevent race condition
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, **kwargs):
if TranscriptionEngine._initialized:
return
# Thread-safe initialization check
with TranscriptionEngine._lock:
if TranscriptionEngine._initialized:
return
# Set flag immediately to prevent re-initialization
TranscriptionEngine._initialized = True
# Perform initialization outside lock to avoid holding lock during slow operations
global_params = {
"host": "localhost",
"port": 8000,
@@ -172,7 +183,6 @@ class TranscriptionEngine:
}
translation_params = update_with_kwargs(translation_params, kwargs)
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
TranscriptionEngine._initialized = True
def online_factory(args, asr):

View File

@@ -202,14 +202,14 @@ class DiartDiarization:
def insert_silence(self, silence_duration):
self.observer.global_time_offset += silence_duration
async def diarize(self, pcm_array: np.ndarray):
"""
Process audio data for diarization.
Only used when working with WebSocketAudioSource.
"""
def insert_audio_chunk(self, pcm_array: np.ndarray):
"""Buffer audio for the next diarization step."""
if self.custom_source:
self.custom_source.push_audio(pcm_array)
# self.observer.clear_old_segments()
self.custom_source.push_audio(pcm_array)
async def diarize(self):
"""Return the current speaker segments from the diarization pipeline."""
return self.observer.get_segments()
def close(self):
"""Close the audio source."""

View File

@@ -151,7 +151,7 @@ class FasterWhisperASR(ASRBase):
if segment.no_speech_prob > 0.9:
continue
for word in segment.words:
token = ASRToken(word.start, word.end, word.word)
token = ASRToken(word.start, word.end, word.word, probability=word.probability)
tokens.append(token)
return tokens
@@ -249,6 +249,7 @@ class OpenaiApiASR(ASRBase):
self.load_model()
self.use_vad_opt = False
self.direct_english_translation = False
self.task = "transcribe"
def load_model(self, *args, **kwargs):
from openai import OpenAI
@@ -294,7 +295,8 @@ class OpenaiApiASR(ASRBase):
params["language"] = self.original_language
if prompt:
params["prompt"] = prompt
proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions
task = self.transcribe_kargs.get("task", self.task)
proc = self.client.audio.translations if task == "translate" else self.client.audio.transcriptions
transcript = proc.create(**params)
logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
return transcript

View File

@@ -146,6 +146,7 @@ def backend_factory(
if direct_english_translation:
tgt_language = "en" # Whisper translates into English
asr.transcribe_kargs["task"] = "translate"
else:
tgt_language = lan # Whisper transcribes in this language
@@ -154,9 +155,9 @@ def backend_factory(
tokenizer = create_tokenizer(tgt_language)
else:
tokenizer = None
warmup_asr(asr, warmup_file)
asr.confidence_validation = confidence_validation
asr.tokenizer = tokenizer
asr.buffer_trimming = buffer_trimming

View File

@@ -46,8 +46,6 @@ class SimulStreamingOnlineProcessor:
self.logfile = logfile
self.end = 0.0
self.buffer = []
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
self.model = self._create_alignatt()
if asr.tokenizer:
@@ -122,7 +120,6 @@ class SimulStreamingOnlineProcessor:
self.buffer.extend(timestamped_words)
return [], self.end
self.committed.extend(timestamped_words)
self.buffer = []
return timestamped_words, self.end
except Exception as e:
@@ -217,7 +214,7 @@ class SimulStreamingASR:
cif_ckpt_path=self.cif_ckpt_path,
decoder_type="beam",
beam_size=self.beams,
task=self.direct_english_translation,
task="translate" if self.direct_english_translation else "transcribe",
never_fire=self.never_fire,
init_prompt=self.init_prompt,
max_context_tokens=self.max_context_tokens,
@@ -330,7 +327,7 @@ class SimulStreamingASR:
lora_path = getattr(self, 'lora_path', None)
whisper_model = load_model(
name=model_ref,
download_root=None,
download_root=getattr(self, 'model_cache_dir', None),
decoder_only=self.fast_encoder,
custom_alignment_heads=self.custom_alignment_heads,
lora_path=lora_path,

View File

@@ -47,9 +47,24 @@ class DecoderState:
def clean_cache(self):
"""Clean the kv_cache after each inference step."""
self.kv_cache = {}
# Explicitly delete tensor references to free GPU memory
if self.kv_cache:
for key in list(self.kv_cache.keys()):
tensor = self.kv_cache.pop(key, None)
if tensor is not None:
del tensor
# Clear the dict
self.kv_cache.clear()
# Force GPU cache cleanup (only if CUDA is available)
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
if self.decoder_type == "beam" and self.inference is not None:
self.inference.kv_cache = self.kv_cache
# Create NEW dict instead of sharing reference
self.inference.kv_cache = {}
if self.token_decoder is not None:
self.token_decoder.reset()

View File

@@ -532,7 +532,9 @@ class MLXAlignAtt:
accumulated_cross_attns = []
audio_duration_s = self.segments_len()
max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0))
# ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
# is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
tokens_produced_this_chunk = 0
while not completed and current_tokens.shape[1] < self.max_text_len:
@@ -558,6 +560,8 @@ class MLXAlignAtt:
mx.eval(logits)
accumulated_cross_attns.append(cross_qk)
if len(accumulated_cross_attns) > 16:
accumulated_cross_attns = accumulated_cross_attns[-16:]
if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1)

View File

@@ -390,7 +390,6 @@ class AlignAtt:
return []
if not self._apply_minseglen():
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
input_segments = torch.cat(self.state.segments, dim=0)
return []
# input_segments is concatenation of audio, it's one array
@@ -485,7 +484,9 @@ class AlignAtt:
accumulated_cross_attns = []
audio_duration_s = self.segments_len()
max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0)) # 2x margin, min 50
# ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
# is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
tokens_produced_this_chunk = 0
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
@@ -506,8 +507,12 @@ class AlignAtt:
result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True)
logits, cross_attns = result
# Accumulate cross-attention from this forward pass
# Accumulate cross-attention from this forward pass (rolling window to
# bound VRAM — only the last entry matters for alignment, and the
# median_filter kernel is 7, so 16 entries is more than enough).
accumulated_cross_attns.append(cross_attns)
if len(accumulated_cross_attns) > 16:
accumulated_cross_attns = accumulated_cross_attns[-16:]
if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
@@ -626,8 +631,10 @@ class AlignAtt:
try:
current_timestamp = l_absolute_timestamps[timestamp_idx]
except:
pass
except IndexError:
# Use last timestamp if index out of range
logger.warning(f"Timestamp index {timestamp_idx} out of range, using last timestamp")
current_timestamp = l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0
timestamp_idx += len(word_tokens)
timestamp_entry = ASRToken(

View File

@@ -0,0 +1,139 @@
"""
Thread Safety Configuration for WhisperLiveKit
This module provides thread safety configuration and utilities.
Environment Variables:
WHISPERLIVEKIT_MODEL_LOCK: Enable/disable model locking (default: 1)
Set to "0" to disable for single-connection deployments
WHISPERLIVEKIT_LOCK_TIMEOUT: Lock acquisition timeout in seconds (default: 30)
Usage:
# Enable model locking (default)
export WHISPERLIVEKIT_MODEL_LOCK=1
# Disable for single-connection deployment
export WHISPERLIVEKIT_MODEL_LOCK=0
# Custom timeout
export WHISPERLIVEKIT_LOCK_TIMEOUT=60
"""
import os
import logging
import threading
logger = logging.getLogger(__name__)
# Configuration
USE_MODEL_LOCK = os.environ.get("WHISPERLIVEKIT_MODEL_LOCK", "1") == "1"
LOCK_TIMEOUT = float(os.environ.get("WHISPERLIVEKIT_LOCK_TIMEOUT", "30.0"))
# Global model lock
_model_lock = threading.Lock()
# Log configuration on import
if USE_MODEL_LOCK:
logger.info(f"Model locking ENABLED (timeout: {LOCK_TIMEOUT}s)")
logger.info("For single-connection deployments, set WHISPERLIVEKIT_MODEL_LOCK=0")
else:
logger.warning("Model locking DISABLED - only safe for single-connection deployments")
def get_model_lock():
"""Get the global model lock instance"""
return _model_lock
def acquire_model_lock(timeout=None):
"""
Acquire model lock with timeout.
Args:
timeout: Lock acquisition timeout (default: use LOCK_TIMEOUT)
Returns:
bool: True if lock acquired, False on timeout
"""
if not USE_MODEL_LOCK:
return True
timeout = timeout or LOCK_TIMEOUT
acquired = _model_lock.acquire(timeout=timeout)
if not acquired:
logger.error(f"Failed to acquire model lock within {timeout}s")
return acquired
def release_model_lock():
"""Release model lock"""
if not USE_MODEL_LOCK:
return
try:
_model_lock.release()
except RuntimeError:
# Lock not held - this is fine
pass
class ModelLockContext:
"""Context manager for model lock"""
def __init__(self, timeout=None):
self.timeout = timeout
self.acquired = False
def __enter__(self):
self.acquired = acquire_model_lock(self.timeout)
return self.acquired
def __exit__(self, exc_type, exc_val, exc_tb):
if self.acquired:
release_model_lock()
return False
# Concurrency recommendations
RECOMMENDED_CONNECTIONS_PER_WORKER = 1 if USE_MODEL_LOCK else 1
RECOMMENDED_WORKERS = 4
def print_deployment_recommendations():
"""Print recommended deployment configuration"""
print("\n" + "="*60)
print("WhisperLiveKit Deployment Recommendations")
print("="*60)
if USE_MODEL_LOCK:
print("⚠️ Model locking is ENABLED")
print(" This serializes inference across connections.")
print()
print("Recommended deployment:")
print(f" gunicorn -w {RECOMMENDED_WORKERS} \\")
print(" -k uvicorn.workers.UvicornWorker \\")
print(" --worker-connections 1 \\")
print(" whisperlivekit.basic_server:app")
print()
print("Expected capacity:")
print(f" - {RECOMMENDED_WORKERS} concurrent users (1 per worker)")
print(f" - Memory: ~{RECOMMENDED_WORKERS}x model size")
else:
print("✅ Model locking is DISABLED")
print(" ⚠️ ONLY safe for single-connection deployments")
print()
print("Recommended deployment:")
print(" uvicorn whisperlivekit.basic_server:app \\")
print(" --host 0.0.0.0 --port 8000 \\")
print(" --workers 1")
print()
print("Expected capacity:")
print(" - 1 concurrent user only")
print("="*60 + "\n")
if __name__ == "__main__":
print_deployment_recommendations()

View File

@@ -39,10 +39,11 @@ class TimedText(Timed):
@dataclass()
class ASRToken(TimedText):
probability: Optional[float] = None
def with_offset(self, offset: float) -> "ASRToken":
"""Return a new token with the time offset added."""
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language)
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language, probability=self.probability)
def is_silence(self) -> bool:
return False

View File

@@ -53,7 +53,8 @@ class TokensAlignment:
segment.translation = ''
for ts in self.all_translation_segments:
if ts.is_within(segment):
segment.translation += ts.text + (self.sep if ts.text else '')
if ts.text:
segment.translation += ts.text + self.sep
elif segment.translation:
break
@@ -185,11 +186,11 @@ class TokensAlignment:
else:
diarization_buffer = ''
for token in self.new_tokens:
if token.is_silence():
if isinstance(token, Silence):
if self.current_line_tokens:
self.validated_segments.append(Segment().from_tokens(self.current_line_tokens))
self.validated_segments.append(Segment.from_tokens(self.current_line_tokens))
self.current_line_tokens = []
end_silence = token.end if token.has_ended else time() - self.beg_loop
if self.validated_segments and self.validated_segments[-1].is_silence():
self.validated_segments[-1].end = end_silence
@@ -203,7 +204,7 @@ class TokensAlignment:
segments = list(self.validated_segments)
if self.current_line_tokens:
segments.append(Segment().from_tokens(self.current_line_tokens))
segments.append(Segment.from_tokens(self.current_line_tokens))
if current_silence:
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop