AudioProcessor state variables are now stored exclusively in the State dataclass

This commit is contained in:
Quentin Fuxa
2025-10-26 18:54:47 +01:00
parent 4e455b8aab
commit 61edb70fff
4 changed files with 54 additions and 73 deletions

View File

@@ -67,20 +67,17 @@ class AudioProcessor:
self.is_stopping = False
self.silence = False
self.silence_duration = 0.0
self.tokens = []
self.last_validated_token = 0
self.translated_segments = []
self.buffer_transcription = Transcript()
self.end_buffer = 0
self.end_attributed_speaker = 0
self.state = State()
self.lock = asyncio.Lock()
self.beg_loop = 0.0 #to deal with a potential little lag at the websocket initialization, this is now set in process_audio
self.sep = " " # Default separator
self.last_response_content = FrontData()
self.last_detected_speaker = None
self.speaker_languages = {}
self.diarization_before_transcription = False
self.segments = []
if self.diarization_before_transcription:
self.cumulative_pcm = []
self.last_start = 0.0
@@ -138,8 +135,8 @@ class AudioProcessor:
async def add_dummy_token(self):
"""Placeholder token when no transcription is available."""
async with self.lock:
current_time = time() - self.beg_loop
self.tokens.append(ASRToken(
current_time = time() - self.state.beg_loop
self.state.tokens.append(ASRToken(
start=current_time, end=current_time + 1,
text=".", speaker=-1, is_dummy=True
))
@@ -149,35 +146,19 @@ class AudioProcessor:
async with self.lock:
current_time = time()
# Calculate remaining times
remaining_transcription = 0
if self.end_buffer > 0:
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 1))
if self.state.end_buffer > 0:
remaining_transcription = max(0, round(current_time - self.state.beg_loop - self.state.end_buffer, 1))
remaining_diarization = 0
if self.tokens:
latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 1))
if self.state.tokens:
latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1))
return State(
tokens=self.tokens.copy(),
last_validated_token=self.last_validated_token,
translated_segments=self.translated_segments.copy(),
buffer_transcription=self.buffer_transcription,
end_buffer=self.end_buffer,
end_attributed_speaker=self.end_attributed_speaker,
remaining_time_transcription=remaining_transcription,
remaining_time_diarization=remaining_diarization
)
self.state.remaining_time_transcription = remaining_transcription
self.state.remaining_time_diarization = remaining_diarization
async def reset(self):
"""Reset all state variables to initial values."""
async with self.lock:
self.tokens = []
self.translated_segments = []
self.buffer_transcription = Transcript()
self.end_buffer = self.end_attributed_speaker = 0
self.beg_loop = time()
return self.state
async def ffmpeg_stdout_reader(self):
"""Read audio data from FFmpeg stdout and process it into the PCM pipeline."""
@@ -242,15 +223,15 @@ class AudioProcessor:
break
asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
transcription_lag_s = max(0.0, time() - self.state.beg_loop - self.state.end_buffer)
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
if type(item) is Silence:
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
if self.tokens:
asr_processing_logs += f" | last_end = {self.tokens[-1].end} |"
if self.state.tokens:
asr_processing_logs += f" | last_end = {self.state.tokens[-1].end} |"
logger.info(asr_processing_logs)
cumulative_pcm_duration_stream_time += item.duration
self.transcription.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
self.transcription.insert_silence(item.duration, self.state.tokens[-1].end if self.state.tokens else 0)
continue
elif isinstance(item, ChangeSpeaker):
self.transcription.new_speaker(item)
@@ -274,7 +255,7 @@ class AudioProcessor:
if buffer_text.startswith(validated_text):
_buffer_transcript.text = buffer_text[len(validated_text):].lstrip()
candidate_end_times = [self.end_buffer]
candidate_end_times = [self.state.end_buffer]
if new_tokens:
candidate_end_times.append(new_tokens[-1].end)
@@ -285,9 +266,9 @@ class AudioProcessor:
candidate_end_times.append(current_audio_processed_upto)
async with self.lock:
self.tokens.extend(new_tokens)
self.buffer_transcription = _buffer_transcript
self.end_buffer = max(candidate_end_times)
self.state.tokens.extend(new_tokens)
self.state.buffer_transcription = _buffer_transcript
self.state.end_buffer = max(candidate_end_times)
if self.translation_queue:
for token in new_tokens:
@@ -360,12 +341,12 @@ class AudioProcessor:
self.last_end = last_segment.end
elif not self.diarization_before_transcription:
async with self.lock:
self.tokens = diarization_obj.assign_speakers_to_tokens(
self.tokens,
self.state.tokens = diarization_obj.assign_speakers_to_tokens(
self.state.tokens,
use_punctuation_split=self.args.punctuation_split
)
if len(self.tokens) > 0:
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
if len(self.state.tokens) > 0:
self.state.end_attributed_speaker = max(self.state.tokens[-1].end, self.state.end_attributed_speaker)
self.diarization_queue.task_done()
except Exception as e:
@@ -406,7 +387,10 @@ class AudioProcessor:
tokens_to_process.append(additional_token)
if tokens_to_process:
self.translation.insert_tokens(tokens_to_process)
self.translated_segments = await asyncio.to_thread(self.translation.process)
translation_validated_segments, translation_buffer = await asyncio.to_thread(self.translation.process)
async with self.lock:
self.state.translation_validated_segments = translation_validated_segments
self.state.translation_buffer = translation_buffer
self.translation_queue.task_done()
for _ in additional_tokens:
self.translation_queue.task_done()
@@ -440,7 +424,6 @@ class AudioProcessor:
lines, undiarized_text = format_output(
state,
self.silence,
current_time = time() - self.beg_loop,
args = self.args,
sep=self.sep
)
@@ -454,7 +437,7 @@ class AudioProcessor:
buffer_diarization = self.sep.join(undiarized_text)
async with self.lock:
self.end_attributed_speaker = state.end_attributed_speaker
self.state.end_attributed_speaker = state.end_attributed_speaker
response_status = "active_transcription"
if not state.tokens and not buffer_transcription and not buffer_diarization:
@@ -580,8 +563,8 @@ class AudioProcessor:
async def process_audio(self, message):
"""Process incoming audio data."""
if not self.beg_loop:
self.beg_loop = time()
if not self.state.beg_loop:
self.state.beg_loop = time()
if not message:
logger.info("Empty audio message received, initiating stop sequence.")

View File

@@ -1,4 +1,5 @@
from whisperlivekit.timed_objects import ASRToken
from time import time
import re
MIN_SILENCE_DURATION = 4 #in seconds
@@ -77,7 +78,8 @@ def no_token_to_silence(tokens):
new_tokens.append(token)
return new_tokens
def ends_with_silence(tokens, current_time, vac_detected_silence):
def ends_with_silence(tokens, beg_loop, vac_detected_silence):
current_time = time() - (beg_loop if beg_loop else 0.0)
last_token = tokens[-1]
if vac_detected_silence or (current_time - last_token.end >= END_SILENCE_DURATION):
if last_token.speaker == -2:
@@ -94,11 +96,11 @@ def ends_with_silence(tokens, current_time, vac_detected_silence):
return tokens
def handle_silences(tokens, current_time, vac_detected_silence):
def handle_silences(tokens, beg_loop, vac_detected_silence):
if not tokens:
return []
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
tokens = no_token_to_silence(tokens)
tokens = ends_with_silence(tokens, current_time, vac_detected_silence)
tokens = ends_with_silence(tokens, beg_loop, vac_detected_silence)
return tokens

View File

@@ -52,16 +52,17 @@ def append_token_to_last_line(lines, sep, token):
lines[-1].detected_language = token.detected_language
def format_output(state, silence, current_time, args, sep):
def format_output(state, silence, args, sep):
diarization = args.diarization
disable_punctuation_split = args.disable_punctuation_split
tokens = state.tokens
translated_segments = state.translated_segments # Here we will attribute the speakers only based on the timestamps of the segments
translation_validated_segments = state.translation_validated_segments # Here we will attribute the speakers only based on the timestamps of the segments
translation_buffer = state.translation_buffer
last_validated_token = state.last_validated_token
previous_speaker = 1
undiarized_text = []
tokens = handle_silences(tokens, current_time, silence)
tokens = handle_silences(tokens, state.beg_loop, silence)
last_punctuation = None
for i, token in enumerate(tokens[last_validated_token:]):
speaker = int(token.speaker)
@@ -71,13 +72,6 @@ def format_output(state, silence, current_time, args, sep):
token.corrected_speaker = 1
token.validated_speaker = True
else:
# if token.end > end_attributed_speaker and token.speaker != -2:
# if tokens[-1].speaker == -2: #if it finishes by a silence, we want to append the undiarized text to the last speaker.
# token.corrected_speaker = previous_speaker
# else:
# undiarized_text.append(token.text)
# continue
# else:
if is_punctuation(token):
last_punctuation = i
@@ -123,9 +117,9 @@ def format_output(state, silence, current_time, args, sep):
previous_speaker = token.corrected_speaker
if lines and translated_segments:
if lines:
unassigned_translated_segments = []
for ts in translated_segments:
for ts in translation_validated_segments:
assigned = False
for line in lines:
if ts and ts.overlaps_with(line):

View File

@@ -174,11 +174,13 @@ class ChangeSpeaker:
@dataclass
class State():
tokens: list
last_validated_token: int
translated_segments: list
buffer_transcription: str
end_buffer: float
end_attributed_speaker: float
remaining_time_transcription: float
remaining_time_diarization: float
tokens: list = field(default_factory=list)
last_validated_token: int = 0
translation_validated_segments: list = field(default_factory=list)
translation_buffer: list = field(default_factory=list)
buffer_transcription: str = field(default_factory=Transcript)
end_buffer: float = 0.0
end_attributed_speaker: float = 0.0
remaining_time_transcription: float = 0.0
remaining_time_diarization: float = 0.0
beg_loop: Optional[int] = None