16 Commits

Author SHA1 Message Date
Quentin Fuxa
e7e82f7c19 bump to 0.2.18 2026-02-11 22:10:00 +01:00
Quentin Fuxa
8c799fa4d1 fix simulstreaming vram leak: cap cross-attn accumulation + token budget
fixes #283, fixes #275

- accumulated_cross_attns was growing unboundedly during the decoding loop,
  using up to ~5GB for repetition loops. now capped to a rolling window of 16
- max_tokens_per_chunk was using TOKENS_PER_SECOND (mel frame rate = 50)
  instead of actual text token rate (~15/s), allowing 10-40x too many
  decoding steps
- removed unused torch.cat on early return path
- removed dead self.committed/last_result_tokens lists (never read)
- same fixes applied to mlx variant
2026-02-11 22:10:00 +01:00
Quentin Fuxa
8923337380 fix --direct-english-translation not setting task=translate for localagreement backends
the flag was only used for tokenizer language selection but never
actually passed to whisper/faster-whisper transcribe calls. also init
OpenaiApiASR.task and read from transcribe_kargs.

fixes #306
2026-02-11 22:10:00 +01:00
Quentin Fuxa
aded1649ae fix model_cache_dir + direct_english_translation task in simulstreaming
pass the actual cache dir instead of None, and use a proper task string
instead of a boolean for AlignAttConfig

fixes #310
2026-02-11 22:10:00 +01:00
Quentin Fuxa
3b535e857a fix NoneType concatenation in add_translation
fixes #296
2026-02-11 22:10:00 +01:00
Quentin Fuxa
d649250b9a fix Segment classmethod call + isinstance type narrowing
fixes #331, fixes #329
2026-02-11 22:10:00 +01:00
Quentin Fuxa
7735478286 add insert_audio_chunk to DiartDiarization
fixes #332
2026-02-11 22:10:00 +01:00
Quentin Fuxa
b9e72d2b9a add probability field to ASRToken
fixes #330, fixes #313
2026-02-11 22:10:00 +01:00
Quentin Fuxa
e5b01033af add json normalizers for english language in build 2026-01-16 10:47:46 +01:00
Quentin Fuxa
6ae545bcb1 bump to 0.2.17.post1 2026-01-16 10:43:52 +01:00
Quentin Fuxa
04980d3f5e Merge branch 'main' of https://github.com/QuentinFuxa/WhisperLiveKit 2026-01-16 10:38:29 +01:00
Quentin Fuxa
79a705c969 fixes #323 2026-01-16 10:38:07 +01:00
Quentin Fuxa
34e4abd455 Merge pull request #322 from eschmidbauer/fix/thread-safety-issues
Fix kv cache not being properly cleaned between sessions
2026-01-09 19:23:35 +01:00
Emmanuel Schmidbauer
d59ddbaeae Fix critical thread safety issues 2026-01-09 11:23:19 -05:00
Quentin Fuxa
4dd66e7766 Merge pull request #317 from jantonj/fix-bug-diarization-lag
update diarization lag after stream analysed
2025-12-19 17:43:07 +01:00
Anton Jacobson
3db5d81a20 update diarization lag after stream analysed 2025-12-18 14:13:28 +01:00
13 changed files with 259 additions and 77 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "whisperlivekit" name = "whisperlivekit"
version = "0.2.17" version = "0.2.18"
description = "Real-time speech-to-text with speaker diarization using Whisper" description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md" readme = "README.md"
authors = [ authors = [
@@ -57,6 +57,7 @@ packages = [
"whisperlivekit", "whisperlivekit",
"whisperlivekit.diarization", "whisperlivekit.diarization",
"whisperlivekit.simul_whisper", "whisperlivekit.simul_whisper",
"whisperlivekit.simul_whisper.mlx",
"whisperlivekit.whisper", "whisperlivekit.whisper",
"whisperlivekit.whisper.assets", "whisperlivekit.whisper.assets",
"whisperlivekit.whisper.normalizers", "whisperlivekit.whisper.normalizers",
@@ -68,4 +69,5 @@ packages = [
[tool.setuptools.package-data] [tool.setuptools.package-data]
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"] whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"] "whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
"whisperlivekit.whisper.normalizers" = ["*.json"]
"whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"] "whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]

View File

@@ -32,7 +32,7 @@ async def get_all_from_queue(queue: asyncio.Queue) -> Union[object, Silence, np.
if isinstance(first_item, Silence): if isinstance(first_item, Silence):
return first_item return first_item
items.append(first_item) items.append(first_item)
while True: while True:
if not queue._queue: if not queue._queue:
break break
@@ -53,15 +53,15 @@ class AudioProcessor:
Processes audio streams for transcription and diarization. Processes audio streams for transcription and diarization.
Handles audio processing, state management, and result formatting. Handles audio processing, state management, and result formatting.
""" """
def __init__(self, **kwargs: Any) -> None: def __init__(self, **kwargs: Any) -> None:
"""Initialize the audio processor with configuration, models, and state.""" """Initialize the audio processor with configuration, models, and state."""
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine): if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
models = kwargs['transcription_engine'] models = kwargs['transcription_engine']
else: else:
models = TranscriptionEngine(**kwargs) models = TranscriptionEngine(**kwargs)
# Audio processing settings # Audio processing settings
self.args = models.args self.args = models.args
self.sample_rate = 16000 self.sample_rate = 16000
@@ -86,13 +86,13 @@ class AudioProcessor:
# Models and processing # Models and processing
self.asr: Any = models.asr self.asr: Any = models.asr
self.vac: Optional[FixedVADIterator] = None self.vac: Optional[FixedVADIterator] = None
if self.args.vac: if self.args.vac:
if models.vac_session is not None: if models.vac_session is not None:
vac_model = OnnxWrapper(session=models.vac_session) vac_model = OnnxWrapper(session=models.vac_session)
self.vac = FixedVADIterator(vac_model) self.vac = FixedVADIterator(vac_model)
else: else:
self.vac = FixedVADIterator(load_jit_vad()) self.vac = FixedVADIterator(load_jit_vad())
self.ffmpeg_manager: Optional[FFmpegManager] = None self.ffmpeg_manager: Optional[FFmpegManager] = None
self.ffmpeg_reader_task: Optional[asyncio.Task] = None self.ffmpeg_reader_task: Optional[asyncio.Task] = None
self._ffmpeg_error: Optional[str] = None self._ffmpeg_error: Optional[str] = None
@@ -106,7 +106,7 @@ class AudioProcessor:
logger.error(f"FFmpeg error: {error_type}") logger.error(f"FFmpeg error: {error_type}")
self._ffmpeg_error = error_type self._ffmpeg_error = error_type
self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
self.transcription_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.transcription else None self.transcription_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.transcription else None
self.diarization_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.diarization else None self.diarization_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.diarization else None
self.translation_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.target_language else None self.translation_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.target_language else None
@@ -117,14 +117,14 @@ class AudioProcessor:
self.translation_task: Optional[asyncio.Task] = None self.translation_task: Optional[asyncio.Task] = None
self.watchdog_task: Optional[asyncio.Task] = None self.watchdog_task: Optional[asyncio.Task] = None
self.all_tasks_for_cleanup: List[asyncio.Task] = [] self.all_tasks_for_cleanup: List[asyncio.Task] = []
self.transcription: Optional[Any] = None self.transcription: Optional[Any] = None
self.translation: Optional[Any] = None self.translation: Optional[Any] = None
self.diarization: Optional[Any] = None self.diarization: Optional[Any] = None
if self.args.transcription: if self.args.transcription:
self.transcription = online_factory(self.args, models.asr) self.transcription = online_factory(self.args, models.asr)
self.sep = self.transcription.asr.sep self.sep = self.transcription.asr.sep
if self.args.diarization: if self.args.diarization:
self.diarization = online_diarization_factory(self.args, models.diarization_model) self.diarization = online_diarization_factory(self.args, models.diarization_model)
if models.translation_model: if models.translation_model:
@@ -182,24 +182,24 @@ class AudioProcessor:
def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray: def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray:
"""Convert PCM buffer in s16le format to normalized NumPy array.""" """Convert PCM buffer in s16le format to normalized NumPy array."""
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0 return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
async def get_current_state(self) -> State: async def get_current_state(self) -> State:
"""Get current state.""" """Get current state."""
async with self.lock: async with self.lock:
current_time = time() current_time = time()
remaining_transcription = 0 remaining_transcription = 0
if self.state.end_buffer > 0: if self.state.end_buffer > 0:
remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1)) remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1))
remaining_diarization = 0 remaining_diarization = 0
if self.state.tokens: if self.state.tokens:
latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0) latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1)) remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1))
self.state.remaining_time_transcription = remaining_transcription self.state.remaining_time_transcription = remaining_transcription
self.state.remaining_time_diarization = remaining_diarization self.state.remaining_time_diarization = remaining_diarization
return self.state return self.state
async def ffmpeg_stdout_reader(self) -> None: async def ffmpeg_stdout_reader(self) -> None:
@@ -255,7 +255,7 @@ class AudioProcessor:
async def transcription_processor(self) -> None: async def transcription_processor(self) -> None:
"""Process audio chunks for transcription.""" """Process audio chunks for transcription."""
cumulative_pcm_duration_stream_time = 0.0 cumulative_pcm_duration_stream_time = 0.0
while True: while True:
try: try:
# item = await self.transcription_queue.get() # item = await self.transcription_queue.get()
@@ -311,12 +311,12 @@ class AudioProcessor:
if new_tokens: if new_tokens:
candidate_end_times.append(new_tokens[-1].end) candidate_end_times.append(new_tokens[-1].end)
if _buffer_transcript.end is not None: if _buffer_transcript.end is not None:
candidate_end_times.append(_buffer_transcript.end) candidate_end_times.append(_buffer_transcript.end)
candidate_end_times.append(current_audio_processed_upto) candidate_end_times.append(current_audio_processed_upto)
async with self.lock: async with self.lock:
self.state.tokens.extend(new_tokens) self.state.tokens.extend(new_tokens)
self.state.buffer_transcription = _buffer_transcript self.state.buffer_transcription = _buffer_transcript
@@ -326,13 +326,13 @@ class AudioProcessor:
if self.translation_queue: if self.translation_queue:
for token in new_tokens: for token in new_tokens:
await self.translation_queue.put(token) await self.translation_queue.put(token)
except Exception as e: except Exception as e:
logger.warning(f"Exception in transcription_processor: {e}") logger.warning(f"Exception in transcription_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}") logger.warning(f"Traceback: {traceback.format_exc()}")
if 'pcm_array' in locals() and pcm_array is not SENTINEL : # Check if pcm_array was assigned from queue if 'pcm_array' in locals() and pcm_array is not SENTINEL : # Check if pcm_array was assigned from queue
self.transcription_queue.task_done() self.transcription_queue.task_done()
if self.is_stopping: if self.is_stopping:
logger.info("Transcription processor finishing due to stopping flag.") logger.info("Transcription processor finishing due to stopping flag.")
if self.diarization_queue: if self.diarization_queue:
@@ -353,18 +353,21 @@ class AudioProcessor:
if item.has_ended: if item.has_ended:
self.diarization.insert_silence(item.duration) self.diarization.insert_silence(item.duration)
continue continue
self.diarization.insert_audio_chunk(item) self.diarization.insert_audio_chunk(item)
diarization_segments = await self.diarization.diarize() diarization_segments = await self.diarization.diarize()
self.state.new_diarization = diarization_segments diar_end = 0.0
if diarization_segments:
diar_end = max(getattr(s, "end", 0.0) for s in diarization_segments)
async with self.lock:
self.state.new_diarization = diarization_segments
self.state.end_attributed_speaker = max(self.state.end_attributed_speaker, diar_end)
except Exception as e: except Exception as e:
logger.warning(f"Exception in diarization_processor: {e}") logger.warning(f"Exception in diarization_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}") logger.warning(f"Traceback: {traceback.format_exc()}")
logger.info("Diarization processor task finished.") logger.info("Diarization processor task finished.")
async def translation_processor(self) -> None: async def translation_processor(self) -> None:
# the idea is to ignore diarization for the moment. We use only transcription tokens. # the idea is to ignore diarization for the moment. We use only transcription tokens.
# And the speaker is attributed given the segments used for the translation # And the speaker is attributed given the segments used for the translation
# in the future we want to have different languages for each speaker etc, so it will be more complex. # in the future we want to have different languages for each speaker etc, so it will be more complex.
while True: while True:
@@ -426,22 +429,22 @@ class AudioProcessor:
remaining_time_transcription=state.remaining_time_transcription, remaining_time_transcription=state.remaining_time_transcription,
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0 remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
) )
should_push = (response != self.last_response_content) should_push = (response != self.last_response_content)
if should_push: if should_push:
yield response yield response
self.last_response_content = response self.last_response_content = response
if self.is_stopping and self._processing_tasks_done(): if self.is_stopping and self._processing_tasks_done():
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.") logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
return return
await asyncio.sleep(0.05) await asyncio.sleep(0.05)
except Exception as e: except Exception as e:
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}") logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
async def create_tasks(self) -> AsyncGenerator[FrontData, None]: async def create_tasks(self) -> AsyncGenerator[FrontData, None]:
"""Create and start processing tasks.""" """Create and start processing tasks."""
self.all_tasks_for_cleanup = [] self.all_tasks_for_cleanup = []
@@ -466,21 +469,21 @@ class AudioProcessor:
self.transcription_task = asyncio.create_task(self.transcription_processor()) self.transcription_task = asyncio.create_task(self.transcription_processor())
self.all_tasks_for_cleanup.append(self.transcription_task) self.all_tasks_for_cleanup.append(self.transcription_task)
processing_tasks_for_watchdog.append(self.transcription_task) processing_tasks_for_watchdog.append(self.transcription_task)
if self.diarization: if self.diarization:
self.diarization_task = asyncio.create_task(self.diarization_processor()) self.diarization_task = asyncio.create_task(self.diarization_processor())
self.all_tasks_for_cleanup.append(self.diarization_task) self.all_tasks_for_cleanup.append(self.diarization_task)
processing_tasks_for_watchdog.append(self.diarization_task) processing_tasks_for_watchdog.append(self.diarization_task)
if self.translation: if self.translation:
self.translation_task = asyncio.create_task(self.translation_processor()) self.translation_task = asyncio.create_task(self.translation_processor())
self.all_tasks_for_cleanup.append(self.translation_task) self.all_tasks_for_cleanup.append(self.translation_task)
processing_tasks_for_watchdog.append(self.translation_task) processing_tasks_for_watchdog.append(self.translation_task)
# Monitor overall system health # Monitor overall system health
self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog)) self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog))
self.all_tasks_for_cleanup.append(self.watchdog_task) self.all_tasks_for_cleanup.append(self.watchdog_task)
return self.results_formatter() return self.results_formatter()
async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None: async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None:
@@ -493,7 +496,7 @@ class AudioProcessor:
return return
await asyncio.sleep(10) await asyncio.sleep(10)
for i, task in enumerate(list(tasks_remaining)): for i, task in enumerate(list(tasks_remaining)):
if task.done(): if task.done():
exc = task.exception() exc = task.exception()
@@ -503,13 +506,13 @@ class AudioProcessor:
else: else:
logger.info(f"{task_name} completed normally.") logger.info(f"{task_name} completed normally.")
tasks_remaining.remove(task) tasks_remaining.remove(task)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("Watchdog task cancelled.") logger.info("Watchdog task cancelled.")
break break
except Exception as e: except Exception as e:
logger.error(f"Error in watchdog task: {e}", exc_info=True) logger.error(f"Error in watchdog task: {e}", exc_info=True)
async def cleanup(self) -> None: async def cleanup(self) -> None:
"""Clean up resources when processing is complete.""" """Clean up resources when processing is complete."""
logger.info("Starting cleanup of AudioProcessor resources.") logger.info("Starting cleanup of AudioProcessor resources.")
@@ -517,7 +520,7 @@ class AudioProcessor:
for task in self.all_tasks_for_cleanup: for task in self.all_tasks_for_cleanup:
if task and not task.done(): if task and not task.done():
task.cancel() task.cancel()
created_tasks = [t for t in self.all_tasks_for_cleanup if t] created_tasks = [t for t in self.all_tasks_for_cleanup if t]
if created_tasks: if created_tasks:
await asyncio.gather(*created_tasks, return_exceptions=True) await asyncio.gather(*created_tasks, return_exceptions=True)
@@ -555,7 +558,7 @@ class AudioProcessor:
if not message: if not message:
logger.info("Empty audio message received, initiating stop sequence.") logger.info("Empty audio message received, initiating stop sequence.")
self.is_stopping = True self.is_stopping = True
if self.transcription_queue: if self.transcription_queue:
await self.transcription_queue.put(SENTINEL) await self.transcription_queue.put(SENTINEL)
@@ -596,7 +599,7 @@ class AudioProcessor:
chunk_size = min(len(self.pcm_buffer), self.max_bytes_per_sec) chunk_size = min(len(self.pcm_buffer), self.max_bytes_per_sec)
aligned_chunk_size = (chunk_size // self.bytes_per_sample) * self.bytes_per_sample aligned_chunk_size = (chunk_size // self.bytes_per_sample) * self.bytes_per_sample
if aligned_chunk_size == 0: if aligned_chunk_size == 0:
return return
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size]) pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size])
@@ -613,7 +616,7 @@ class AudioProcessor:
if res is not None: if res is not None:
if "start" in res and self.current_silence: if "start" in res and self.current_silence:
await self._end_silence() await self._end_silence()
if "end" in res and not self.current_silence: if "end" in res and not self.current_silence:
pre_silence_chunk = self._slice_before_silence( pre_silence_chunk = self._slice_before_silence(
pcm_array, chunk_sample_start, res.get("end") pcm_array, chunk_sample_start, res.get("end")

View File

@@ -1,5 +1,6 @@
import logging import logging
import sys import sys
import threading
from argparse import Namespace from argparse import Namespace
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
@@ -19,16 +20,26 @@ logger = logging.getLogger(__name__)
class TranscriptionEngine: class TranscriptionEngine:
_instance = None _instance = None
_initialized = False _initialized = False
_lock = threading.Lock() # Thread-safe singleton lock
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
# Double-checked locking pattern for thread-safe singleton
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) with cls._lock:
# Check again inside lock to prevent race condition
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance return cls._instance
def __init__(self, **kwargs): def __init__(self, **kwargs):
if TranscriptionEngine._initialized: # Thread-safe initialization check
return with TranscriptionEngine._lock:
if TranscriptionEngine._initialized:
return
# Set flag immediately to prevent re-initialization
TranscriptionEngine._initialized = True
# Perform initialization outside lock to avoid holding lock during slow operations
global_params = { global_params = {
"host": "localhost", "host": "localhost",
"port": 8000, "port": 8000,
@@ -172,7 +183,6 @@ class TranscriptionEngine:
} }
translation_params = update_with_kwargs(translation_params, kwargs) translation_params = update_with_kwargs(translation_params, kwargs)
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
TranscriptionEngine._initialized = True
def online_factory(args, asr): def online_factory(args, asr):

View File

@@ -202,14 +202,14 @@ class DiartDiarization:
def insert_silence(self, silence_duration): def insert_silence(self, silence_duration):
self.observer.global_time_offset += silence_duration self.observer.global_time_offset += silence_duration
async def diarize(self, pcm_array: np.ndarray): def insert_audio_chunk(self, pcm_array: np.ndarray):
""" """Buffer audio for the next diarization step."""
Process audio data for diarization.
Only used when working with WebSocketAudioSource.
"""
if self.custom_source: if self.custom_source:
self.custom_source.push_audio(pcm_array) self.custom_source.push_audio(pcm_array)
# self.observer.clear_old_segments()
async def diarize(self):
"""Return the current speaker segments from the diarization pipeline."""
return self.observer.get_segments()
def close(self): def close(self):
"""Close the audio source.""" """Close the audio source."""

View File

@@ -151,7 +151,7 @@ class FasterWhisperASR(ASRBase):
if segment.no_speech_prob > 0.9: if segment.no_speech_prob > 0.9:
continue continue
for word in segment.words: for word in segment.words:
token = ASRToken(word.start, word.end, word.word) token = ASRToken(word.start, word.end, word.word, probability=word.probability)
tokens.append(token) tokens.append(token)
return tokens return tokens
@@ -249,6 +249,7 @@ class OpenaiApiASR(ASRBase):
self.load_model() self.load_model()
self.use_vad_opt = False self.use_vad_opt = False
self.direct_english_translation = False self.direct_english_translation = False
self.task = "transcribe"
def load_model(self, *args, **kwargs): def load_model(self, *args, **kwargs):
from openai import OpenAI from openai import OpenAI
@@ -294,7 +295,8 @@ class OpenaiApiASR(ASRBase):
params["language"] = self.original_language params["language"] = self.original_language
if prompt: if prompt:
params["prompt"] = prompt params["prompt"] = prompt
proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions task = self.transcribe_kargs.get("task", self.task)
proc = self.client.audio.translations if task == "translate" else self.client.audio.transcriptions
transcript = proc.create(**params) transcript = proc.create(**params)
logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds") logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
return transcript return transcript

View File

@@ -146,6 +146,7 @@ def backend_factory(
if direct_english_translation: if direct_english_translation:
tgt_language = "en" # Whisper translates into English tgt_language = "en" # Whisper translates into English
asr.transcribe_kargs["task"] = "translate"
else: else:
tgt_language = lan # Whisper transcribes in this language tgt_language = lan # Whisper transcribes in this language
@@ -154,9 +155,9 @@ def backend_factory(
tokenizer = create_tokenizer(tgt_language) tokenizer = create_tokenizer(tgt_language)
else: else:
tokenizer = None tokenizer = None
warmup_asr(asr, warmup_file) warmup_asr(asr, warmup_file)
asr.confidence_validation = confidence_validation asr.confidence_validation = confidence_validation
asr.tokenizer = tokenizer asr.tokenizer = tokenizer
asr.buffer_trimming = buffer_trimming asr.buffer_trimming = buffer_trimming

View File

@@ -46,8 +46,6 @@ class SimulStreamingOnlineProcessor:
self.logfile = logfile self.logfile = logfile
self.end = 0.0 self.end = 0.0
self.buffer = [] self.buffer = []
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
self.model = self._create_alignatt() self.model = self._create_alignatt()
if asr.tokenizer: if asr.tokenizer:
@@ -122,7 +120,6 @@ class SimulStreamingOnlineProcessor:
self.buffer.extend(timestamped_words) self.buffer.extend(timestamped_words)
return [], self.end return [], self.end
self.committed.extend(timestamped_words)
self.buffer = [] self.buffer = []
return timestamped_words, self.end return timestamped_words, self.end
except Exception as e: except Exception as e:
@@ -217,7 +214,7 @@ class SimulStreamingASR:
cif_ckpt_path=self.cif_ckpt_path, cif_ckpt_path=self.cif_ckpt_path,
decoder_type="beam", decoder_type="beam",
beam_size=self.beams, beam_size=self.beams,
task=self.direct_english_translation, task="translate" if self.direct_english_translation else "transcribe",
never_fire=self.never_fire, never_fire=self.never_fire,
init_prompt=self.init_prompt, init_prompt=self.init_prompt,
max_context_tokens=self.max_context_tokens, max_context_tokens=self.max_context_tokens,
@@ -330,7 +327,7 @@ class SimulStreamingASR:
lora_path = getattr(self, 'lora_path', None) lora_path = getattr(self, 'lora_path', None)
whisper_model = load_model( whisper_model = load_model(
name=model_ref, name=model_ref,
download_root=None, download_root=getattr(self, 'model_cache_dir', None),
decoder_only=self.fast_encoder, decoder_only=self.fast_encoder,
custom_alignment_heads=self.custom_alignment_heads, custom_alignment_heads=self.custom_alignment_heads,
lora_path=lora_path, lora_path=lora_path,

View File

@@ -47,9 +47,24 @@ class DecoderState:
def clean_cache(self): def clean_cache(self):
"""Clean the kv_cache after each inference step.""" """Clean the kv_cache after each inference step."""
self.kv_cache = {} # Explicitly delete tensor references to free GPU memory
if self.kv_cache:
for key in list(self.kv_cache.keys()):
tensor = self.kv_cache.pop(key, None)
if tensor is not None:
del tensor
# Clear the dict
self.kv_cache.clear()
# Force GPU cache cleanup (only if CUDA is available)
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
if self.decoder_type == "beam" and self.inference is not None: if self.decoder_type == "beam" and self.inference is not None:
self.inference.kv_cache = self.kv_cache # Create NEW dict instead of sharing reference
self.inference.kv_cache = {}
if self.token_decoder is not None: if self.token_decoder is not None:
self.token_decoder.reset() self.token_decoder.reset()

View File

@@ -532,7 +532,9 @@ class MLXAlignAtt:
accumulated_cross_attns = [] accumulated_cross_attns = []
audio_duration_s = self.segments_len() audio_duration_s = self.segments_len()
max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0)) # ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
# is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
tokens_produced_this_chunk = 0 tokens_produced_this_chunk = 0
while not completed and current_tokens.shape[1] < self.max_text_len: while not completed and current_tokens.shape[1] < self.max_text_len:
@@ -558,6 +560,8 @@ class MLXAlignAtt:
mx.eval(logits) mx.eval(logits)
accumulated_cross_attns.append(cross_qk) accumulated_cross_attns.append(cross_qk)
if len(accumulated_cross_attns) > 16:
accumulated_cross_attns = accumulated_cross_attns[-16:]
if new_segment and self.tokenizer.no_speech is not None: if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1) probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1)

View File

@@ -390,7 +390,6 @@ class AlignAtt:
return [] return []
if not self._apply_minseglen(): if not self._apply_minseglen():
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.") logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
input_segments = torch.cat(self.state.segments, dim=0)
return [] return []
# input_segments is concatenation of audio, it's one array # input_segments is concatenation of audio, it's one array
@@ -485,7 +484,9 @@ class AlignAtt:
accumulated_cross_attns = [] accumulated_cross_attns = []
audio_duration_s = self.segments_len() audio_duration_s = self.segments_len()
max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0)) # 2x margin, min 50 # ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
# is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
tokens_produced_this_chunk = 0 tokens_produced_this_chunk = 0
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
@@ -506,8 +507,12 @@ class AlignAtt:
result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True) result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True)
logits, cross_attns = result logits, cross_attns = result
# Accumulate cross-attention from this forward pass # Accumulate cross-attention from this forward pass (rolling window to
# bound VRAM — only the last entry matters for alignment, and the
# median_filter kernel is 7, so 16 entries is more than enough).
accumulated_cross_attns.append(cross_attns) accumulated_cross_attns.append(cross_attns)
if len(accumulated_cross_attns) > 16:
accumulated_cross_attns = accumulated_cross_attns[-16:]
if new_segment and self.tokenizer.no_speech is not None: if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1) probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
@@ -626,8 +631,10 @@ class AlignAtt:
try: try:
current_timestamp = l_absolute_timestamps[timestamp_idx] current_timestamp = l_absolute_timestamps[timestamp_idx]
except: except IndexError:
pass # Use last timestamp if index out of range
logger.warning(f"Timestamp index {timestamp_idx} out of range, using last timestamp")
current_timestamp = l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0
timestamp_idx += len(word_tokens) timestamp_idx += len(word_tokens)
timestamp_entry = ASRToken( timestamp_entry = ASRToken(

View File

@@ -0,0 +1,139 @@
"""
Thread Safety Configuration for WhisperLiveKit
This module provides thread safety configuration and utilities.
Environment Variables:
WHISPERLIVEKIT_MODEL_LOCK: Enable/disable model locking (default: 1)
Set to "0" to disable for single-connection deployments
WHISPERLIVEKIT_LOCK_TIMEOUT: Lock acquisition timeout in seconds (default: 30)
Usage:
# Enable model locking (default)
export WHISPERLIVEKIT_MODEL_LOCK=1
# Disable for single-connection deployment
export WHISPERLIVEKIT_MODEL_LOCK=0
# Custom timeout
export WHISPERLIVEKIT_LOCK_TIMEOUT=60
"""
import os
import logging
import threading
logger = logging.getLogger(__name__)
# Configuration
# Locking is on by default; set WHISPERLIVEKIT_MODEL_LOCK=0 to disable it
# for single-connection deployments (see module docstring).
USE_MODEL_LOCK = os.environ.get("WHISPERLIVEKIT_MODEL_LOCK", "1") == "1"
# Seconds to wait for the lock before giving up (WHISPERLIVEKIT_LOCK_TIMEOUT).
LOCK_TIMEOUT = float(os.environ.get("WHISPERLIVEKIT_LOCK_TIMEOUT", "30.0"))
# Global model lock
# One process-wide lock shared by every connection handled by this worker.
_model_lock = threading.Lock()
# Announce the effective lock configuration once, at import time.
if not USE_MODEL_LOCK:
    logger.warning("Model locking DISABLED - only safe for single-connection deployments")
else:
    logger.info(f"Model locking ENABLED (timeout: {LOCK_TIMEOUT}s)")
    logger.info("For single-connection deployments, set WHISPERLIVEKIT_MODEL_LOCK=0")
def get_model_lock():
    """Return the process-wide model lock shared by all connections."""
    return _model_lock
def acquire_model_lock(timeout=None):
    """
    Acquire the global model lock, waiting up to *timeout* seconds.

    Args:
        timeout: Lock acquisition timeout in seconds (default: LOCK_TIMEOUT).
            A timeout of 0 performs a single non-blocking attempt.

    Returns:
        bool: True if the lock was acquired (always True when locking is
        disabled via WHISPERLIVEKIT_MODEL_LOCK=0), False on timeout.
    """
    if not USE_MODEL_LOCK:
        # Locking disabled: report success so callers need no special case.
        return True
    # Compare against None rather than `timeout or LOCK_TIMEOUT`: the latter
    # silently replaced an explicit timeout of 0 (non-blocking attempt) with
    # the 30s default because 0 is falsy.
    if timeout is None:
        timeout = LOCK_TIMEOUT
    acquired = _model_lock.acquire(timeout=timeout)
    if not acquired:
        logger.error(f"Failed to acquire model lock within {timeout}s")
    return acquired
def release_model_lock():
    """Release the global model lock; a no-op when locking is disabled or the lock is not held."""
    if USE_MODEL_LOCK:
        try:
            _model_lock.release()
        except RuntimeError:
            # Releasing an unheld lock raises RuntimeError; treat as a no-op.
            pass
class ModelLockContext:
    """Context manager around acquire_model_lock/release_model_lock.

    The value bound by ``with ModelLockContext() as ok:`` is a bool telling
    whether the lock was actually acquired within the timeout; callers are
    expected to check it. The lock is released on exit only if it was held.
    """

    def __init__(self, timeout=None):
        # Forwarded to acquire_model_lock; None means the module default.
        self.timeout = timeout
        self.acquired = False

    def __enter__(self):
        self.acquired = acquire_model_lock(self.timeout)
        return self.acquired

    def __exit__(self, *exc_info):
        # Release only what we hold; exceptions always propagate.
        if self.acquired:
            release_model_lock()
        return False
# Concurrency recommendations
# With the model lock enabled, inference is serialized, so a worker can only
# usefully serve one connection at a time; with it disabled, a worker is only
# safe with a single connection anyway — either way the answer is 1.
# (The previous `1 if USE_MODEL_LOCK else 1` was a dead conditional.)
RECOMMENDED_CONNECTIONS_PER_WORKER = 1
RECOMMENDED_WORKERS = 4
def print_deployment_recommendations():
    """Print recommended deployment configuration"""
    banner = "=" * 60
    lines = ["", banner, "WhisperLiveKit Deployment Recommendations", banner]
    if USE_MODEL_LOCK:
        lines += [
            "⚠️ Model locking is ENABLED",
            " This serializes inference across connections.",
            "",
            "Recommended deployment:",
            f" gunicorn -w {RECOMMENDED_WORKERS} \\",
            " -k uvicorn.workers.UvicornWorker \\",
            " --worker-connections 1 \\",
            " whisperlivekit.basic_server:app",
            "",
            "Expected capacity:",
            f" - {RECOMMENDED_WORKERS} concurrent users (1 per worker)",
            f" - Memory: ~{RECOMMENDED_WORKERS}x model size",
        ]
    else:
        lines += [
            "✅ Model locking is DISABLED",
            " ⚠️ ONLY safe for single-connection deployments",
            "",
            "Recommended deployment:",
            " uvicorn whisperlivekit.basic_server:app \\",
            " --host 0.0.0.0 --port 8000 \\",
            " --workers 1",
            "",
            "Expected capacity:",
            " - 1 concurrent user only",
        ]
    # Trailing "\n" reproduces the original closing `print(banner + "\n")`.
    lines.append(banner + "\n")
    print("\n".join(lines))
# Running this module directly prints the deployment guidance to stdout.
if __name__ == "__main__":
    print_deployment_recommendations()

View File

@@ -39,10 +39,11 @@ class TimedText(Timed):
@dataclass() @dataclass()
class ASRToken(TimedText): class ASRToken(TimedText):
probability: Optional[float] = None
def with_offset(self, offset: float) -> "ASRToken": def with_offset(self, offset: float) -> "ASRToken":
"""Return a new token with the time offset added.""" """Return a new token with the time offset added."""
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language) return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language, probability=self.probability)
def is_silence(self) -> bool: def is_silence(self) -> bool:
return False return False

View File

@@ -53,7 +53,8 @@ class TokensAlignment:
segment.translation = '' segment.translation = ''
for ts in self.all_translation_segments: for ts in self.all_translation_segments:
if ts.is_within(segment): if ts.is_within(segment):
segment.translation += ts.text + (self.sep if ts.text else '') if ts.text:
segment.translation += ts.text + self.sep
elif segment.translation: elif segment.translation:
break break
@@ -185,11 +186,11 @@ class TokensAlignment:
else: else:
diarization_buffer = '' diarization_buffer = ''
for token in self.new_tokens: for token in self.new_tokens:
if token.is_silence(): if isinstance(token, Silence):
if self.current_line_tokens: if self.current_line_tokens:
self.validated_segments.append(Segment().from_tokens(self.current_line_tokens)) self.validated_segments.append(Segment.from_tokens(self.current_line_tokens))
self.current_line_tokens = [] self.current_line_tokens = []
end_silence = token.end if token.has_ended else time() - self.beg_loop end_silence = token.end if token.has_ended else time() - self.beg_loop
if self.validated_segments and self.validated_segments[-1].is_silence(): if self.validated_segments and self.validated_segments[-1].is_silence():
self.validated_segments[-1].end = end_silence self.validated_segments[-1].end = end_silence
@@ -203,7 +204,7 @@ class TokensAlignment:
segments = list(self.validated_segments) segments = list(self.validated_segments)
if self.current_line_tokens: if self.current_line_tokens:
segments.append(Segment().from_tokens(self.current_line_tokens)) segments.append(Segment.from_tokens(self.current_line_tokens))
if current_silence: if current_silence:
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop