Merge branch 'improve_EOS_handling'

This commit is contained in:
Quentin Fuxa
2025-11-16 22:30:31 +01:00
5 changed files with 103 additions and 55 deletions

View File

@@ -67,6 +67,8 @@ class AudioProcessor:
self.is_stopping = False
self.silence = False
self.silence_duration = 0.0
self.start_silence = None
self.last_silence_dispatch_time = None
self.state = State()
self.lock = asyncio.Lock()
self.sep = " " # Default separator
@@ -128,6 +130,34 @@ class AudioProcessor:
if models.translation_model:
self.translation = online_translation_factory(self.args, models.translation_model)
async def _push_silence_event(self, silence_buffer: Silence):
    """Fan a silence marker out to every active downstream consumer queue."""
    recipients = []
    if self.transcription_queue and not self.diarization_before_transcription:
        recipients.append(self.transcription_queue)
    if self.diarization_queue and self.args.diarization:
        recipients.append(self.diarization_queue)
    if self.translation_queue:
        recipients.append(self.translation_queue)
    for queue in recipients:
        await queue.put(silence_buffer)
async def _begin_silence(self):
    """Open a silence span and notify consumers; no-op if one is already open."""
    if self.silence:
        return  # already inside a silence span
    timestamp = time()
    self.silence = True
    self.start_silence = timestamp
    self.last_silence_dispatch_time = timestamp
    await self._push_silence_event(Silence(is_starting=True))
async def _end_silence(self):
    """Close the current silence span and dispatch its measured duration.

    No-op when no silence span is open. Duration is measured from the last
    dispatch timestamp, then all silence-tracking state is reset.
    """
    if not self.silence:
        return
    duration = time() - self.last_silence_dispatch_time
    await self._push_silence_event(Silence(duration=duration, has_ended=True))
    # Reset all silence state. (The original re-assigned
    # last_silence_dispatch_time to `now` here, which was a dead store since
    # it is cleared to None two lines later.)
    self.silence = False
    self.start_silence = None
    self.last_silence_dispatch_time = None
def convert_pcm_to_float(self, pcm_buffer):
    """Decode little-endian signed 16-bit PCM bytes into float32 samples in [-1.0, 1.0)."""
    int16_samples = np.frombuffer(pcm_buffer, dtype=np.int16)
    return int16_samples.astype(np.float32) / 32768.0
@@ -225,28 +255,42 @@ class AudioProcessor:
asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
transcription_lag_s = max(0.0, time() - self.state.beg_loop - self.state.end_buffer)
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
if type(item) is Silence:
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
new_tokens = []
current_audio_processed_upto = self.state.end_buffer
if isinstance(item, Silence):
if item.is_starting:
new_tokens, current_audio_processed_upto = await asyncio.to_thread(
self.transcription.start_silence
)
asr_processing_logs += f" + Silence starting"
if item.has_ended:
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
cumulative_pcm_duration_stream_time += item.duration
current_audio_processed_upto = cumulative_pcm_duration_stream_time
self.transcription.end_silence(item.duration, self.state.tokens[-1].end if self.state.tokens else 0)
if self.state.tokens:
asr_processing_logs += f" | last_end = {self.state.tokens[-1].end} |"
logger.info(asr_processing_logs)
cumulative_pcm_duration_stream_time += item.duration
self.transcription.insert_silence(item.duration, self.state.tokens[-1].end if self.state.tokens else 0)
continue
new_tokens = new_tokens or []
current_audio_processed_upto = max(current_audio_processed_upto, stream_time_end_of_current_pcm)
elif isinstance(item, ChangeSpeaker):
self.transcription.new_speaker(item)
self.transcription_queue.task_done()
continue
elif isinstance(item, np.ndarray):
pcm_array = item
logger.info(asr_processing_logs)
duration_this_chunk = len(pcm_array) / self.sample_rate
cumulative_pcm_duration_stream_time += duration_this_chunk
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
logger.info(asr_processing_logs)
cumulative_pcm_duration_stream_time += len(pcm_array) / self.sample_rate
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter)
new_tokens = new_tokens or []
else:
self.transcription_queue.task_done()
continue
self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter)
_buffer_transcript = self.transcription.get_buffer()
buffer_text = _buffer_transcript.text
@@ -304,7 +348,7 @@ class AudioProcessor:
logger.debug("Diarization processor received sentinel. Finishing.")
self.diarization_queue.task_done()
break
elif type(item) is Silence:
elif type(item) is Silence and item.has_ended:
diarization_obj.insert_silence(item.duration)
continue
elif isinstance(item, np.ndarray):
@@ -380,7 +424,7 @@ class AudioProcessor:
if additional_token is SENTINEL:
sentinel_found = True
break
elif type(additional_token) is Silence:
elif type(additional_token) is Silence and additional_token.has_ended:
self.translation.insert_silence(additional_token.duration)
continue
else:
@@ -640,26 +684,15 @@ class AudioProcessor:
self.pcm_buffer = self.pcm_buffer[aligned_chunk_size:]
res = None
end_of_audio = False
silence_buffer = None
if self.args.vac:
res = self.vac(pcm_array)
if res is not None:
if res.get("end", 0) > res.get("start", 0):
end_of_audio = True
elif self.silence: #end of silence
self.silence = False
silence_buffer = Silence(duration=time() - self.start_silence)
if res.get("end", 0) > res.get("start", 0) and not self.silence:
await self._begin_silence()
elif self.silence:
await self._end_silence()
if silence_buffer:
if not self.diarization_before_transcription and self.transcription_queue:
await self.transcription_queue.put(silence_buffer)
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(silence_buffer)
if self.translation_queue:
await self.translation_queue.put(silence_buffer)
if not self.silence:
if not self.diarization_before_transcription and self.transcription_queue:
@@ -670,9 +703,5 @@ class AudioProcessor:
self.silence_duration = 0.0
if end_of_audio:
self.silence = True
self.start_silence = time()
if not self.args.transcription and not self.args.diarization:
await asyncio.sleep(0.1)

View File

@@ -160,7 +160,7 @@ class SortformerDiarizationOnline:
# Initialize total predictions tensor
self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)
def insert_silence(self, silence_duration: float):
def insert_silence(self, silence_duration: Optional[float]):
"""
Insert silence period by adjusting the global time offset.

View File

@@ -151,21 +151,32 @@ class OnlineASRProcessor:
"""Append an audio chunk (a numpy array) to the current audio buffer."""
self.audio_buffer = np.append(self.audio_buffer, audio)
# NOTE(review): this appears to be the pre-merge body of insert_silence shown
# by the diff viewer; an end_silence/insert_silence replacement pair exists
# further down in the same hunk.
def insert_silence(self, silence_duration, offset):
"""
If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame
"""
# if self.transcript_buffer.buffer:
#     self.committed.extend(self.transcript_buffer.buffer)
#     self.transcript_buffer.buffer = []
# NOTE(review): the guard is hard-wired to True; the commented-out condition
# suggests a <3s short-silence threshold was intended — confirm upstream.
if True: #silence_duration < 3: #we want the last audio to be treated to not have a gap. could also be handled in the future in ends_with_silence.
# NOTE(review): sample rate 16000 is hard-coded and dtype is int16, while the
# replacement end_silence uses self.SAMPLING_RATE and float32 — presumably a
# dtype inconsistency in this legacy version; verify against insert_audio_chunk.
gap_silence = np.zeros(int(16000 * silence_duration), dtype=np.int16)
self.insert_audio_chunk(gap_silence)
def start_silence(self):
    """Flush pending audio when a silence span opens.

    Returns a ``(new_tokens, processed_upto)`` pair; with nothing buffered
    there is nothing to decode, so an empty token list is returned.
    """
    if self.audio_buffer.size > 0:
        return self.process_iter()
    return [], self.get_audio_buffer_end_time()
def end_silence(self, silence_duration: Optional[float], offset: float):
# Closes a silence span: short gaps (< 5 s) are bridged with zero samples so
# the audio stream stays contiguous; long gaps reset the decoding context.
if not silence_duration or silence_duration <= 0:
return
long_silence = silence_duration >= 5
if not long_silence:
gap_samples = int(self.SAMPLING_RATE * silence_duration)
if gap_samples > 0:
# float32 zeros to match the normalized PCM fed by insert_audio_chunk.
gap_silence = np.zeros(gap_samples, dtype=np.float32)
self.insert_audio_chunk(gap_silence)
else:
self.init(offset=silence_duration + offset)
# NOTE(review): the diff rendering strips indentation, so it is ambiguous
# whether this offset shift applies only to the long-silence branch or
# unconditionally — confirm against the repository before editing.
self.global_time_offset += silence_duration
def insert_silence(self, silence_duration, offset):
    """Deprecated alias retained for legacy callers; delegates to end_silence."""
    self.end_silence(silence_duration, offset)
def prompt(self) -> Tuple[str, str]:
"""
Returns a tuple: (prompt, context), where:

View File

@@ -63,16 +63,22 @@ class SimulStreamingOnlineProcessor:
fw_encoder=self.asr.fw_encoder,
)
def insert_silence(self, silence_duration, offset):
def start_silence(self):
    """Force a final decode pass at silence onset.

    Returns the ``(tokens, processed_upto)`` pair produced by flushing the
    decoder with ``is_last=True``.
    """
    flushed_tokens, processed_upto = self.process_iter(is_last=True)
    return flushed_tokens, processed_upto
def end_silence(self, silence_duration, offset):
"""
If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame
"""
# NOTE(review): the diff viewer shows the pre-merge body (the silence_duration < 5
# block directly below) AND its replacement (the long_silence logic after it)
# without +/- markers; only the second version survives in the merged file —
# confirm against the repository.
if silence_duration < 5:
gap_silence = torch.zeros(int(16000*silence_duration))
self.model.insert_audio(gap_silence)
# self.global_time_offset += silence_duration
else:
self.process_iter(is_last=True) #we want to totally process what remains in the buffer.
self.end += silence_duration
long_silence = silence_duration >= 5
if not long_silence:
# Short gap: bridge with zeros so downstream audio stays contiguous.
gap_len = int(16000 * silence_duration)
if gap_len > 0:
gap_silence = torch.zeros(gap_len)
self.model.insert_audio(gap_silence)
if long_silence:
# Long gap: drop the whole segment context and rebase the time offset.
self.model.refresh_segment(complete=True)
self.model.global_time_offset = silence_duration + offset

View File

@@ -123,7 +123,9 @@ class Translation(TimedText):
@dataclass
class Silence():
    """Marker event describing a silence span in the audio stream.

    duration: length of the span in seconds; None while the span is still
        open (i.e. on the starting event).
    is_starting: True on the event that opens a silence span.
    has_ended: True on the event that closes the span (duration is set then).

    The stale pre-merge annotation ``duration: float`` (diff residue,
    immediately shadowed by the Optional version) has been removed.
    """
    duration: Optional[float] = None
    is_starting: bool = False
    has_ended: bool = False
@dataclass