asyncio.to_thread for transcription and translation
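This commit does two related things. First, it moves the blocking model calls off the asyncio event loop: `self.online.process_iter()` (transcription) and `online_translation.process()` (translation) now run via `asyncio.to_thread`, so inference no longer stalls the websocket and formatter coroutines. Second, it replaces the dict payloads that flowed between `AudioProcessor`, `format_output`, and the websocket handler with two dataclasses from `whisperlivekit.timed_objects`: `State` (the internal snapshot returned by `get_current_state()`) and `FrontData` (the response pushed to the client). The dedup logic in `results_formatter` simplifies accordingly, and its polling interval drops from 0.1 s to 0.05 s.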
@@ -4,7 +4,7 @@ from time import time, sleep
 import math
 import logging
 import traceback
-from whisperlivekit.timed_objects import ASRToken, Silence, Line
+from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State
 from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
 from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
 from whisperlivekit.silero_vad_iterator import FixedVADIterator
@@ -68,7 +68,7 @@ class AudioProcessor:
         self.lock = asyncio.Lock()
         self.beg_loop = None #to deal with a potential little lag at the websocket initialization, this is now set in process_audio
         self.sep = " " # Default separator
-        self.last_response_content = ""
+        self.last_response_content = FrontData()
 
         # Models and processing
         self.asr = models.asr
@@ -103,7 +103,8 @@ class AudioProcessor:
         self.all_tasks_for_cleanup = []
 
         if self.args.transcription:
             self.online = online_factory(self.args, models.asr, models.tokenizer)
+            self.sep = self.online.asr.sep
         if self.args.diarization:
             self.diarization = online_diarization_factory(self.args, models.diarization_model)
         if self.args.target_language:
@@ -113,13 +114,12 @@ class AudioProcessor:
         """Convert PCM buffer in s16le format to normalized NumPy array."""
         return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
 
-    async def update_transcription(self, new_tokens, buffer, end_buffer, sep):
+    async def update_transcription(self, new_tokens, buffer, end_buffer):
         """Thread-safe update of transcription with new data."""
         async with self.lock:
             self.tokens.extend(new_tokens)
             self.buffer_transcription = buffer
             self.end_buffer = end_buffer
-            self.sep = sep
 
     async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
         """Thread-safe update of diarization with new data."""
@@ -152,17 +152,16 @@ class AudioProcessor:
             latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
             remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 1))
 
-        return {
-            "tokens": self.tokens.copy(),
-            "translated_segments": self.translated_segments.copy(),
-            "buffer_transcription": self.buffer_transcription,
-            "buffer_diarization": self.buffer_diarization,
-            "end_buffer": self.end_buffer,
-            "end_attributed_speaker": self.end_attributed_speaker,
-            "sep": self.sep,
-            "remaining_time_transcription": remaining_transcription,
-            "remaining_time_diarization": remaining_diarization
-        }
+        return State(
+            tokens=self.tokens.copy(),
+            translated_segments=self.translated_segments.copy(),
+            buffer_transcription=self.buffer_transcription,
+            buffer_diarization=self.buffer_diarization,
+            end_buffer=self.end_buffer,
+            end_attributed_speaker=self.end_attributed_speaker,
+            remaining_time_transcription=remaining_transcription,
+            remaining_time_diarization=remaining_diarization
+        )
 
     async def reset(self):
         """Reset all state variables to initial values."""
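`get_current_state()` now hands back a `State` dataclass rather than a dict, so consumers write `state.tokens` instead of `state["tokens"]`, and `sep` drops out of the snapshot entirely (it is set once in `__init__` from `self.online.asr.sep` and read as `self.sep`). A minimal sketch of the access change, using a trimmed two-field stand-in for the real `State`:

```python
from dataclasses import dataclass

# Trimmed stand-in for whisperlivekit.timed_objects.State (illustrative subset).
@dataclass
class State:
    tokens: list
    end_buffer: float

state = State(tokens=[], end_buffer=1.2)

# Attribute access replaces the stringly-typed dict lookups; a misspelled
# attribute fails as an AttributeError that linters and IDEs can flag.
print(state.end_buffer)  # 1.2
print(state.tokens)      # []
```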
@@ -236,7 +235,6 @@ class AudioProcessor:
 
     async def transcription_processor(self):
         """Process audio chunks for transcription."""
-        self.sep = self.online.asr.sep
         cumulative_pcm_duration_stream_time = 0.0
 
         while True:
@@ -276,7 +274,7 @@ class AudioProcessor:
                 stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
 
                 self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
-                new_tokens, current_audio_processed_upto = self.online.process_iter()
+                new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.online.process_iter)
 
                 # Get buffer information
                 _buffer_transcript_obj = self.online.get_buffer()
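This is the headline change: `process_iter()` runs Whisper inference and can block for long stretches, and calling it inline would freeze the entire event loop (websocket reads, diarization, the results formatter). `asyncio.to_thread` (Python 3.9+) runs the callable on the default `ThreadPoolExecutor` and suspends only the awaiting coroutine. A self-contained sketch of the pattern; the names here are illustrative, not WhisperLiveKit APIs:

```python
import asyncio
import time

def blocking_inference(chunk: bytes) -> str:
    """Stand-in for online.process_iter(): CPU/GPU-bound, blocks its thread."""
    time.sleep(0.5)  # simulate model latency
    return f"decoded {len(chunk)} bytes"

async def heartbeat():
    # Keeps printing while inference runs, proving the loop is not stalled.
    for _ in range(5):
        print("event loop is still responsive")
        await asyncio.sleep(0.1)

async def main():
    # Before: result = blocking_inference(chunk)  -> freezes every coroutine.
    # After: the call runs in the default ThreadPoolExecutor; only this
    # coroutine waits for the result.
    result, _ = await asyncio.gather(
        asyncio.to_thread(blocking_inference, b"\x00" * 3200),
        heartbeat(),
    )
    print(result)

asyncio.run(main())
```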
@@ -300,7 +298,7 @@ class AudioProcessor:
                 new_end_buffer = max(candidate_end_times)
 
                 await self.update_transcription(
-                    new_tokens, buffer_text, new_end_buffer, self.sep
+                    new_tokens, buffer_text, new_end_buffer
                 )
 
                 if new_tokens and self.args.target_language and self.translation_queue:
@@ -385,7 +383,7 @@ class AudioProcessor:
                     tokens_to_process.append(additional_token)
                 if tokens_to_process:
                     online_translation.insert_tokens(tokens_to_process)
-                    self.translated_segments = online_translation.process()
+                    self.translated_segments = await asyncio.to_thread(online_translation.process)
 
                 self.translation_queue.task_done()
                 for _ in additional_tokens:
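The translation path gets the same treatment: `online_translation.process()` now runs in a worker thread too. One caveat worth stating: `asyncio.to_thread` only buys real concurrency while the blocking call releases the GIL, as native inference backends typically do during compute; a pure-Python busy loop moved to a thread still contends with the event loop for the interpreter.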
@@ -407,39 +405,25 @@ class AudioProcessor:
 
     async def results_formatter(self):
         """Format processing results for output."""
-        last_sent_trans = None
-        last_sent_diar = None
         while True:
             try:
                 ffmpeg_state = await self.ffmpeg_manager.get_state()
                 if ffmpeg_state == FFmpegState.FAILED and self._ffmpeg_error:
-                    yield {
-                        "status": "error",
-                        "error": f"FFmpeg error: {self._ffmpeg_error}",
-                        "lines": [],
-                        "buffer_transcription": "",
-                        "buffer_diarization": "",
-                        "remaining_time_transcription": 0,
-                        "remaining_time_diarization": 0
-                    }
+                    yield FrontData(
+                        status="error",
+                        error=f"FFmpeg error: {self._ffmpeg_error}"
+                    )
                     self._ffmpeg_error = None
                     await asyncio.sleep(1)
                     continue
 
-                # Get current state
                 state = await self.get_current_state()
-                tokens = state["tokens"]
-                buffer_transcription = state["buffer_transcription"]
-                buffer_diarization = state["buffer_diarization"]
-                end_attributed_speaker = state["end_attributed_speaker"]
-                sep = state["sep"]
 
                 # Add dummy tokens if needed
-                if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
+                if (not state.tokens or state.tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
                     await self.add_dummy_token()
                     sleep(0.5)
                     state = await self.get_current_state()
-                    tokens = state["tokens"]
 
                 # Format output
                 lines, undiarized_text, buffer_transcription, buffer_diarization = format_output(
@@ -447,18 +431,19 @@ class AudioProcessor:
                     self.silence,
                     current_time = time() - self.beg_loop if self.beg_loop else None,
                     args = self.args,
-                    debug = self.debug
+                    debug = self.debug,
+                    sep=self.sep
                 )
                 # Handle undiarized text
                 if undiarized_text:
-                    combined = sep.join(undiarized_text)
+                    combined = self.sep.join(undiarized_text)
                     if buffer_transcription:
-                        combined += sep
-                    await self.update_diarization(end_attributed_speaker, combined)
+                        combined += self.sep
+                    await self.update_diarization(state.end_attributed_speaker, combined)
                     buffer_diarization = combined
 
                 response_status = "active_transcription"
-                if not tokens and not buffer_transcription and not buffer_diarization:
+                if not state.tokens and not buffer_transcription and not buffer_diarization:
                     response_status = "no_audio_detected"
                     lines = []
                 elif response_status == "active_transcription" and not lines:
@@ -468,32 +453,19 @@ class AudioProcessor:
                         end=state.get("end_buffer", 0)
                     )]
 
-                response = {
-                    "status": response_status,
-                    "lines": [line.to_dict() for line in lines],
-                    "buffer_transcription": buffer_transcription,
-                    "buffer_diarization": buffer_diarization,
-                    "remaining_time_transcription": state["remaining_time_transcription"],
-                    "remaining_time_diarization": state["remaining_time_diarization"] if self.args.diarization else 0
-                }
+                response = FrontData(
+                    status=response_status,
+                    lines=lines,
+                    buffer_transcription=buffer_transcription,
+                    buffer_diarization=buffer_diarization,
+                    remaining_time_transcription=state.remaining_time_transcription,
+                    remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
+                )
 
-                current_response_signature = f"{response_status} | " + \
-                    ' '.join([f"{line.speaker} {line.text}" for line in lines]) + \
-                    f" | {buffer_transcription} | {buffer_diarization}"
-
-                trans = state["remaining_time_transcription"]
-                diar = state["remaining_time_diarization"]
-                should_push = (
-                    current_response_signature != self.last_response_content
-                    or last_sent_trans is None
-                    or round(trans, 1) != round(last_sent_trans, 1)
-                    or round(diar, 1) != round(last_sent_diar, 1)
-                )
-                if should_push and (lines or buffer_transcription or buffer_diarization or response_status == "no_audio_detected" or trans > 0 or diar > 0):
+                should_push = (response != self.last_response_content)
+                if should_push and (lines or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"):
                     yield response
-                    self.last_response_content = current_response_signature
-                    last_sent_trans = trans
-                    last_sent_diar = diar
+                    self.last_response_content = response
 
                 # Check for termination condition
                 if self.is_stopping:
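With the response wrapped in a dataclass, the hand-rolled change detection (an f-string signature plus `last_sent_trans`/`last_sent_diar` rounding) collapses to one comparison: `@dataclass` auto-generates a field-by-field `__eq__`, and the `remaining_time_*` values are now fields of the response itself. A minimal sketch of why the one-liner works, using a trimmed stand-in for `FrontData`:

```python
from dataclasses import dataclass, field

# Trimmed stand-in for timed_objects.FrontData (illustrative subset).
@dataclass
class FrontData:
    status: str = ''
    lines: list = field(default_factory=list)
    buffer_transcription: str = ''
    remaining_time_transcription: float = 0.0

a = FrontData(status="active_transcription", buffer_transcription="hello")
b = FrontData(status="active_transcription", buffer_transcription="hello")
c = FrontData(status="active_transcription", buffer_transcription="hello world")

# @dataclass generates __eq__ comparing every field, so the formatter pushes
# exactly when anything the client would see has changed:
print(a == b)  # True  -> should_push False, duplicate response suppressed
print(a == c)  # False -> should_push True, response yielded
```

One apparent leftover: the unchanged context line `end=state.get("end_buffer", 0)` above still uses dict-style access, which the plain `State` dataclass does not provide; if that `elif` branch runs it would raise `AttributeError`, so it looks like a spot the refactor missed.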
@@ -507,12 +479,12 @@ class AudioProcessor:
                     logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
                     return
 
-                await asyncio.sleep(0.1) # Avoid overwhelming the client
+                await asyncio.sleep(0.05)
 
             except Exception as e:
                 logger.warning(f"Exception in results_formatter: {e}")
                 logger.warning(f"Traceback: {traceback.format_exc()}")
-                await asyncio.sleep(0.5) # Back off on error
+                await asyncio.sleep(0.5)
 
     async def create_tasks(self):
         """Create and start processing tasks."""
@@ -523,15 +495,10 @@ class AudioProcessor:
         if not success:
             logger.error("Failed to start FFmpeg manager")
             async def error_generator():
-                yield {
-                    "status": "error",
-                    "error": "FFmpeg failed to start. Please check that FFmpeg is installed.",
-                    "lines": [],
-                    "buffer_transcription": "",
-                    "buffer_diarization": "",
-                    "remaining_time_transcription": 0,
-                    "remaining_time_diarization": 0
-                }
+                yield FrontData(
+                    status="error",
+                    error="FFmpeg failed to start. Please check that FFmpeg is installed."
+                )
             return error_generator()
 
         if self.args.transcription and self.online:
@@ -54,7 +54,7 @@ async def handle_websocket_results(websocket, results_generator):
     """Consumes results from the audio processor and sends them via WebSocket."""
     try:
         async for response in results_generator:
-            await websocket.send_json(response)
+            await websocket.send_json(response.to_dict())
         # when the results_generator finishes it means all audio has been processed
         logger.info("Results generator finished. Sending 'ready_to_stop' to client.")
         await websocket.send_json({"type": "ready_to_stop"})
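`FrontData` is a dataclass, not a dict, so it is not directly JSON-serializable; hence the explicit `response.to_dict()` before `send_json`. A quick illustration with a hypothetical `Msg` class:

```python
import json
from dataclasses import dataclass

@dataclass
class Msg:  # hypothetical, for illustration only
    status: str = "ok"

    def to_dict(self):
        return {"status": self.status}

m = Msg()
# json.dumps(m)                 # TypeError: Object of type Msg is not JSON serializable
print(json.dumps(m.to_dict()))  # {"status": "ok"}
```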
@@ -46,15 +46,14 @@ def append_token_to_last_line(lines, sep, token, debug_info):
     lines[-1].text += sep + token.text + debug_info
     lines[-1].end = token.end
 
-def format_output(state, silence, current_time, args, debug):
+def format_output(state, silence, current_time, args, debug, sep):
     diarization = args.diarization
     disable_punctuation_split = args.disable_punctuation_split
-    tokens = state["tokens"]
-    translated_segments = state["translated_segments"] # Here we will attribute the speakers only based on the timestamps of the segments
-    buffer_transcription = state["buffer_transcription"]
-    buffer_diarization = state["buffer_diarization"]
-    end_attributed_speaker = state["end_attributed_speaker"]
-    sep = state["sep"]
+    tokens = state.tokens
+    translated_segments = state.translated_segments # Here we will attribute the speakers only based on the timestamps of the segments
+    buffer_transcription = state.buffer_transcription
+    buffer_diarization = state.buffer_diarization
+    end_attributed_speaker = state.end_attributed_speaker
 
     previous_speaker = -1
     lines = []
@@ -128,7 +127,7 @@ def format_output(state, silence, current_time, args, debug):
     for line in lines:
         while cts_idx < len(translated_segments):
             ts = translated_segments[cts_idx]
-            if ts.start and ts.start >= line.start and ts.end <= line.end:
+            if ts and ts.start and ts.start >= line.start and ts.end <= line.end:
                 line.translation += ts.text + ' '
                 cts_idx += 1
             else:
@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Optional
 from datetime import timedelta
 
@@ -57,4 +57,38 @@ class Line(TimedText):
             'translation': self.translation,
             'start': format_time(self.start),
             'end': format_time(self.end),
         }
+
+@dataclass
+class FrontData():
+    status: str = ''
+    error: str = ''
+    lines: list[Line] = field(default_factory=list)
+    buffer_transcription: str = ''
+    buffer_diarization: str = ''
+    remaining_time_transcription: float = 0.
+    remaining_time_diarization: float = 0.
+
+    def to_dict(self):
+        _dict = {
+            'status': self.status,
+            'lines': [line.to_dict() for line in self.lines],
+            'buffer_transcription': self.buffer_transcription,
+            'buffer_diarization': self.buffer_diarization,
+            'remaining_time_transcription': self.remaining_time_transcription,
+            'remaining_time_diarization': self.remaining_time_diarization,
+        }
+        if self.error:
+            _dict['error'] = self.error
+        return _dict
+
+@dataclass
+class State():
+    tokens: list
+    translated_segments: list
+    buffer_transcription: str
+    buffer_diarization: str
+    end_buffer: float
+    end_attributed_speaker: float
+    remaining_time_transcription: float
+    remaining_time_diarization: float
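The two new dataclasses in `timed_objects.py` in use: `FrontData` is the client-facing payload, and its `to_dict()` emits the `error` key only when one is set, keeping the steady-state payload lean (note the built-in generic `list[Line]` annotation needs Python 3.9+, the same floor as `asyncio.to_thread`). `State` has no defaults, so every field must be supplied, which is exactly what `get_current_state()` does. A usage sketch, assuming the definitions above are importable:

```python
from whisperlivekit.timed_objects import FrontData

front = FrontData(status="active_transcription",
                  buffer_transcription="hello wor")
print(front.to_dict())
# {'status': 'active_transcription', 'lines': [],
#  'buffer_transcription': 'hello wor', 'buffer_diarization': '',
#  'remaining_time_transcription': 0.0, 'remaining_time_diarization': 0.0}

front_err = FrontData(status="error", error="FFmpeg error: pipe closed")
print(front_err.to_dict()['error'])  # 'error' key appears only when set
```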