5 Commits

Author SHA1 Message Date
Quentin Fuxa
674b20d3af in buffer while language not detected » 2025-09-21 11:05:00 +02:00
Quentin Fuxa
a5503308c5 O(n) to O(1) for simulstreaming timestamp determination 2025-09-21 11:04:00 +02:00
Quentin Fuxa
e61afdefa3 punctuation is now checked in timed_object 2025-09-22 22:40:39 +02:00
Quentin Fuxa
426d70a790 simulstreaming infer does not return a dictionary anymore 2025-09-21 11:03:00 +02:00
Quentin Fuxa
b03a212fbf fixes #227 , auto language dectection v0.1 - simulstreaming only - when diarization and auto 2025-09-19 19:15:28 +02:00
15 changed files with 262 additions and 303 deletions

BIN
demo.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.2 MiB

After

Width:  |  Height:  |  Size: 985 KiB

View File

@@ -4,7 +4,7 @@ from time import time, sleep
import math
import logging
import traceback
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
from whisperlivekit.silero_vad_iterator import FixedVADIterator
from whisperlivekit.results_formater import format_output
@@ -58,7 +58,7 @@ class AudioProcessor:
self.silence_duration = 0.0
self.tokens = []
self.translated_segments = []
self.buffer_transcription = ""
self.buffer_transcription = Transcript()
self.buffer_diarization = ""
self.end_buffer = 0
self.end_attributed_speaker = 0
@@ -66,6 +66,8 @@ class AudioProcessor:
self.beg_loop = None #to deal with a potential little lag at the websocket initialization, this is now set in process_audio
self.sep = " " # Default separator
self.last_response_content = FrontData()
self.last_detected_speaker = None
self.speaker_languages = {}
# Models and processing
self.asr = models.asr
@@ -112,20 +114,6 @@ class AudioProcessor:
"""Convert PCM buffer in s16le format to normalized NumPy array."""
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
async def update_transcription(self, new_tokens, buffer, end_buffer):
"""Thread-safe update of transcription with new data."""
async with self.lock:
self.tokens.extend(new_tokens)
self.buffer_transcription = buffer
self.end_buffer = end_buffer
async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
"""Thread-safe update of diarization with new data."""
async with self.lock:
self.end_attributed_speaker = end_attributed_speaker
if buffer_diarization:
self.buffer_diarization = buffer_diarization
async def add_dummy_token(self):
"""Placeholder token when no transcription is available."""
async with self.lock:
@@ -166,7 +154,7 @@ class AudioProcessor:
async with self.lock:
self.tokens = []
self.translated_segments = []
self.buffer_transcription = self.buffer_diarization = ""
self.buffer_transcription = self.buffer_diarization = Transcript()
self.end_buffer = self.end_attributed_speaker = 0
self.beg_loop = time()
@@ -262,30 +250,28 @@ class AudioProcessor:
self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.online.process_iter)
# Get buffer information
_buffer_transcript_obj = self.online.get_buffer()
buffer_text = _buffer_transcript_obj.text
_buffer_transcript = self.online.get_buffer()
buffer_text = _buffer_transcript.text
if new_tokens:
validated_text = self.sep.join([t.text for t in new_tokens])
if buffer_text.startswith(validated_text):
buffer_text = buffer_text[len(validated_text):].lstrip()
_buffer_transcript.text = buffer_text[len(validated_text):].lstrip()
candidate_end_times = [self.end_buffer]
if new_tokens:
candidate_end_times.append(new_tokens[-1].end)
if _buffer_transcript_obj.end is not None:
candidate_end_times.append(_buffer_transcript_obj.end)
if _buffer_transcript.end is not None:
candidate_end_times.append(_buffer_transcript.end)
candidate_end_times.append(current_audio_processed_upto)
new_end_buffer = max(candidate_end_times)
await self.update_transcription(
new_tokens, buffer_text, new_end_buffer
)
async with self.lock:
self.tokens.extend(new_tokens)
self.buffer_transcription = _buffer_transcript
self.end_buffer = max(candidate_end_times)
if self.translation_queue:
for token in new_tokens:
@@ -333,7 +319,7 @@ class AudioProcessor:
await diarization_obj.diarize(pcm_array)
async with self.lock:
self.tokens = diarization_obj.assign_speakers_to_tokens(
self.tokens, last_segment = diarization_obj.assign_speakers_to_tokens(
self.tokens,
use_punctuation_split=self.args.punctuation_split
)
@@ -341,7 +327,12 @@ class AudioProcessor:
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
if buffer_diarization:
self.buffer_diarization = buffer_diarization
# if last_segment is not None and last_segment.speaker != self.last_detected_speaker:
# if not self.speaker_languages.get(last_segment.speaker, None):
# self.last_detected_speaker = last_segment.speaker
# self.online.on_new_speaker(last_segment)
self.diarization_queue.task_done()
except Exception as e:
@@ -422,7 +413,7 @@ class AudioProcessor:
state = await self.get_current_state()
# Format output
lines, undiarized_text, buffer_transcription, buffer_diarization = format_output(
lines, undiarized_text, end_w_silence = format_output(
state,
self.silence,
current_time = time() - self.beg_loop if self.beg_loop else None,
@@ -430,13 +421,25 @@ class AudioProcessor:
debug = self.debug,
sep=self.sep
)
if end_w_silence:
buffer_transcription = Transcript()
buffer_diarization = Transcript()
else:
buffer_transcription = state.buffer_transcription
buffer_diarization = state.buffer_diarization
# Handle undiarized text
if undiarized_text:
combined = self.sep.join(undiarized_text)
if buffer_transcription:
combined += self.sep
await self.update_diarization(state.end_attributed_speaker, combined)
buffer_diarization = combined
async with self.lock:
self.end_attributed_speaker = state.end_attributed_speaker
if buffer_diarization:
self.buffer_diarization = buffer_diarization
buffer_diarization.text = combined
response_status = "active_transcription"
if not state.tokens and not buffer_transcription and not buffer_diarization:
@@ -452,8 +455,8 @@ class AudioProcessor:
response = FrontData(
status=response_status,
lines=lines,
buffer_transcription=buffer_transcription,
buffer_diarization=buffer_diarization,
buffer_transcription=buffer_transcription.text,
buffer_diarization=buffer_transcription.text,
remaining_time_transcription=state.remaining_time_transcription,
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
)
@@ -552,20 +555,20 @@ class AudioProcessor:
if task and not task.done():
task.cancel()
created_tasks = [t for t in self.all_tasks_for_cleanup if t]
if created_tasks:
await asyncio.gather(*created_tasks, return_exceptions=True)
logger.info("All processing tasks cancelled or finished.")
created_tasks = [t for t in self.all_tasks_for_cleanup if t]
if created_tasks:
await asyncio.gather(*created_tasks, return_exceptions=True)
logger.info("All processing tasks cancelled or finished.")
if not self.is_pcm_input and self.ffmpeg_manager:
try:
await self.ffmpeg_manager.stop()
logger.info("FFmpeg manager stopped.")
except Exception as e:
logger.warning(f"Error stopping FFmpeg manager: {e}")
if self.args.diarization and hasattr(self, 'dianization') and hasattr(self.diarization, 'close'):
self.diarization.close()
logger.info("AudioProcessor cleanup complete.")
if not self.is_pcm_input and self.ffmpeg_manager:
try:
await self.ffmpeg_manager.stop()
logger.info("FFmpeg manager stopped.")
except Exception as e:
logger.warning(f"Error stopping FFmpeg manager: {e}")
if self.args.diarization and hasattr(self, 'dianization') and hasattr(self.diarization, 'close'):
self.diarization.close()
logger.info("AudioProcessor cleanup complete.")
async def process_audio(self, message):

View File

@@ -242,7 +242,7 @@ class DiartDiarization:
token.speaker = extract_number(segment.speaker) + 1
else:
tokens = add_speaker_to_tokens(segments, tokens)
return tokens
return tokens, segments[-1]
def concatenate_speakers(segments):
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]

View File

@@ -289,13 +289,14 @@ class SortformerDiarizationOnline:
Returns:
List of tokens with speaker assignments
Last speaker_segment
"""
with self.segment_lock:
segments = self.speaker_segments.copy()
if not segments or not tokens:
logger.debug("No segments or tokens available for speaker assignment")
return tokens
return tokens, None
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
use_punctuation_split = False
@@ -312,7 +313,7 @@ class SortformerDiarizationOnline:
# Use punctuation-aware assignment (similar to diart_backend)
tokens = self._add_speaker_to_tokens_with_punctuation(segments, tokens)
return tokens
return tokens, segments[-1]
def _add_speaker_to_tokens_with_punctuation(self, segments: List[SpeakerSegment], tokens: list) -> list:
"""

View File

@@ -77,15 +77,17 @@ def no_token_to_silence(tokens):
new_tokens.append(token)
return new_tokens
def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
def ends_with_silence(tokens, current_time, vac_detected_silence):
end_w_silence = False
if not tokens:
return [], buffer_transcription, buffer_diarization
return [], end_w_silence
last_token = tokens[-1]
if tokens and current_time and (
current_time - last_token.end >= END_SILENCE_DURATION
or
or
(current_time - last_token.end >= 3 and vac_detected_silence)
):
end_w_silence = True
if last_token.speaker == -2:
last_token.end = current_time
else:
@@ -97,14 +99,12 @@ def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_
probability=0.95
)
)
buffer_transcription = "" # for whisperstreaming backend, we should probably validate the buffer has because of the silence
buffer_diarization = ""
return tokens, buffer_transcription, buffer_diarization
return tokens, end_w_silence
def handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
def handle_silences(tokens, current_time, vac_detected_silence):
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
tokens = no_token_to_silence(tokens)
tokens, buffer_transcription, buffer_diarization = ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence)
return tokens, buffer_transcription, buffer_diarization
tokens, end_w_silence = ends_with_silence(tokens, current_time, vac_detected_silence)
return tokens, end_w_silence

View File

@@ -6,11 +6,10 @@ from whisperlivekit.timed_objects import Line, format_time
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''}
CHECK_AROUND = 4
def is_punctuation(token):
if token.text.strip() in PUNCTUATION_MARKS:
if token.is_punctuation():
return True
return False
@@ -51,14 +50,12 @@ def format_output(state, silence, current_time, args, debug, sep):
disable_punctuation_split = args.disable_punctuation_split
tokens = state.tokens
translated_segments = state.translated_segments # Here we will attribute the speakers only based on the timestamps of the segments
buffer_transcription = state.buffer_transcription
buffer_diarization = state.buffer_diarization
end_attributed_speaker = state.end_attributed_speaker
previous_speaker = -1
lines = []
undiarized_text = []
tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, silence)
tokens, end_w_silence = handle_silences(tokens, current_time, silence)
last_punctuation = None
for i, token in enumerate(tokens):
speaker = token.speaker
@@ -122,6 +119,7 @@ def format_output(state, silence, current_time, args, debug, sep):
pass
append_token_to_last_line(lines, sep, token, debug_info)
if lines and translated_segments:
unassigned_translated_segments = []
for ts in translated_segments:
@@ -152,4 +150,8 @@ def format_output(state, silence, current_time, args, debug, sep):
else:
remaining_segments.append(ts)
unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
return lines, undiarized_text, buffer_transcription, ''
if state.buffer_transcription and lines:
lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
return lines, undiarized_text, end_w_silence

View File

@@ -4,9 +4,8 @@ import logging
from typing import List, Tuple, Optional
import logging
import platform
from whisperlivekit.timed_objects import ASRToken, Transcript
from whisperlivekit.timed_objects import ASRToken, Transcript, SpeakerSegment
from whisperlivekit.warmup import load_file
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
from .whisper import load_model, tokenizer
from .whisper.audio import TOKENS_PER_SECOND
import os
@@ -23,7 +22,11 @@ try:
HAS_MLX_WHISPER = True
except ImportError:
if platform.system() == "Darwin" and platform.machine() == "arm64":
print('MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper')
print(f"""
{"="*50}
MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper
{"="*50}
""")
HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False
@@ -49,8 +52,7 @@ class SimulStreamingOnlineProcessor:
self.asr = asr
self.logfile = logfile
self.end = 0.0
self.global_time_offset = 0.0
self.buffer = []
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
self.load_new_backend()
@@ -79,7 +81,7 @@ class SimulStreamingOnlineProcessor:
else:
self.process_iter(is_last=True) #we want to totally process what remains in the buffer.
self.model.refresh_segment(complete=True)
self.global_time_offset = silence_duration + offset
self.model.global_time_offset = silence_duration + offset
@@ -91,64 +93,14 @@ class SimulStreamingOnlineProcessor:
self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend.
self.model.insert_audio(audio_tensor)
def on_new_speaker(self, last_segment: SpeakerSegment):
self.model.on_new_speaker(last_segment)
self.model.refresh_segment(complete=True)
def get_buffer(self):
return Transcript(
start=None,
end=None,
text='',
probability=None
)
def timestamped_text(self, tokens, generation):
"""
generate timestamped text from tokens and generation data.
args:
tokens: List of tokens to process
generation: Dictionary containing generation progress and optionally results
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
return concat_buffer
returns:
List of tuples containing (start_time, end_time, word) for each word
"""
FRAME_DURATION = 0.02
if "result" in generation:
split_words = generation["result"]["split_words"]
split_tokens = generation["result"]["split_tokens"]
else:
split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
progress = generation["progress"]
frames = [p["most_attended_frames"][0] for p in progress]
absolute_timestamps = [p["absolute_timestamps"][0] for p in progress]
tokens_queue = tokens.copy()
timestamped_words = []
for word, word_tokens in zip(split_words, split_tokens):
# start_frame = None
# end_frame = None
for expected_token in word_tokens:
if not tokens_queue or not frames:
raise ValueError(f"Insufficient tokens or frames for word '{word}'")
actual_token = tokens_queue.pop(0)
current_frame = frames.pop(0)
current_timestamp = absolute_timestamps.pop(0)
if actual_token != expected_token:
raise ValueError(
f"Token mismatch: expected '{expected_token}', "
f"got '{actual_token}' at frame {current_frame}"
)
# if start_frame is None:
# start_frame = current_frame
# end_frame = current_frame
# start_time = start_frame * FRAME_DURATION
# end_time = end_frame * FRAME_DURATION
start_time = current_timestamp
end_time = current_timestamp + 0.1
timestamp_entry = (start_time, end_time, word)
timestamped_words.append(timestamp_entry)
logger.debug(f"TS-WORD:\t{start_time:.2f}\t{end_time:.2f}\t{word}")
return timestamped_words
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
"""
Process accumulated audio chunks using SimulStreaming.
@@ -156,47 +108,10 @@ class SimulStreamingOnlineProcessor:
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
"""
try:
tokens, generation_progress = self.model.infer(is_last=is_last)
ts_words = self.timestamped_text(tokens, generation_progress)
new_tokens = []
for ts_word in ts_words:
start, end, word = ts_word
token = ASRToken(
start=start,
end=end,
text=word,
probability=0.95 # fake prob. Maybe we can extract it from the model?
).with_offset(
self.global_time_offset
)
new_tokens.append(token)
# identical_tokens = 0
# n_new_tokens = len(new_tokens)
# if n_new_tokens:
self.committed.extend(new_tokens)
# if token in self.committed:
# pos = len(self.committed) - 1 - self.committed[::-1].index(token)
# if pos:
# for i in range(len(self.committed) - n_new_tokens, -1, -n_new_tokens):
# commited_segment = self.committed[i:i+n_new_tokens]
# if commited_segment == new_tokens:
# identical_segments +=1
# if identical_tokens >= TOO_MANY_REPETITIONS:
# logger.warning('Too many repetition, model is stuck. Load a new one')
# self.committed = self.committed[:i]
# self.load_new_backend()
# return [], self.end
# pos = self.committed.rindex(token)
return new_tokens, self.end
timestamped_words, timestamped_buffer_language = self.model.infer(is_last=is_last)
self.buffer = timestamped_buffer_language
self.committed.extend(timestamped_words)
return timestamped_words, self.end
except Exception as e:
@@ -226,7 +141,6 @@ class SimulStreamingASR():
sep = ""
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
logger.warning(SIMULSTREAMING_LICENSE)
self.logfile = logfile
self.transcribe_kargs = {}
self.original_language = lan
@@ -362,4 +276,4 @@ class SimulStreamingASR():
"""
Warmup is done directly in load_model
"""
pass
pass

View File

@@ -8,6 +8,7 @@ import torch.nn.functional as F
from .whisper import load_model, DecodingOptions, tokenizer
from .config import AlignAttConfig
from whisperlivekit.timed_objects import ASRToken
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
from .whisper.timing import median_filter
from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
@@ -18,6 +19,7 @@ from time import time
from .token_buffer import TokenBuffer
import numpy as np
from ..timed_objects import PUNCTUATION_MARKS
from .generation_progress import *
DEC_PAD = 50257
@@ -40,12 +42,6 @@ else:
except ImportError:
HAS_FASTER_WHISPER = False
# New features added to the original version of Simul-Whisper:
# - large-v3 model support
# - translation support
# - beam search
# - prompt -- static vs. non-static
# - context
class PaddedAlignAttWhisper:
def __init__(
self,
@@ -78,7 +74,11 @@ class PaddedAlignAttWhisper:
)
self.tokenizer_is_multilingual = not model_name.endswith(".en")
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
# self.create_tokenizer('en')
self.detected_language = cfg.language if cfg.language != "auto" else None
self.global_time_offset = 0.0
self.reset_tokenizer_to_auto_next_call = False
self.sentence_start_time = 0.0
self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks)
@@ -153,6 +153,7 @@ class PaddedAlignAttWhisper:
self.last_attend_frame = -self.cfg.rewind_threshold
self.cumulative_time_offset = 0.0
self.sentence_start_time = self.cumulative_time_offset + self.segments_len()
if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
@@ -382,11 +383,11 @@ class PaddedAlignAttWhisper:
new_segment = True
if len(self.segments) == 0:
logger.debug("No segments, nothing to do")
return [], {}
return []
if not self._apply_minseglen():
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
input_segments = torch.cat(self.segments, dim=0)
return [], {}
return []
# input_segments is concatenation of audio, it's one array
if len(self.segments) > 1:
@@ -394,6 +395,13 @@ class PaddedAlignAttWhisper:
else:
input_segments = self.segments[0]
# if self.cfg.language == "auto" and self.reset_tokenizer_to_auto_next_call:
# logger.debug("Resetting tokenizer to auto for new sentence.")
# self.create_tokenizer(None)
# self.detected_language = None
# self.init_tokens()
# self.reset_tokenizer_to_auto_next_call = False
# NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder
beg_encode = time()
if self.mlx_encoder:
@@ -426,58 +434,37 @@ class PaddedAlignAttWhisper:
end_encode = time()
# print('Encoder duration:', end_encode-beg_encode)
# logger.debug(f"Encoder feature shape: {encoder_feature.shape}")
# if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
# logger.debug("mel ")
if self.cfg.language == "auto" and self.detected_language is None:
language_tokens, language_probs = self.lang_id(encoder_feature)
logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
logger.info(f"Detected language: {top_lan} with p={p:.4f}")
#self.tokenizer.language = top_lan
#self.tokenizer.__post_init__()
self.create_tokenizer(top_lan)
self.detected_language = top_lan
self.init_tokens()
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
seconds_since_start = (self.cumulative_time_offset + self.segments_len()) - self.sentence_start_time
if seconds_since_start >= 3.0:
language_tokens, language_probs = self.lang_id(encoder_feature)
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
print(f"Detected language: {top_lan} with p={p:.4f}")
self.create_tokenizer(top_lan)
self.refresh_segment(complete=True)
self.detected_language = top_lan
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
else:
logger.debug(f"Skipping language detection: {seconds_since_start:.2f}s < 3.0s")
self.trim_context()
current_tokens = self._current_tokens()
#
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
####################### Decoding loop
logger.info("Decoding loop starts\n")
sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device)
completed = False
# punctuation_stop = False
attn_of_alignment_heads = None
most_attended_frame = None
token_len_before_decoding = current_tokens.shape[1]
generation_progress = []
generation = {
"starting_tokens": BeamTokens(current_tokens[0,:].clone(), self.cfg.beam_size),
"token_len_before_decoding": token_len_before_decoding,
#"fire_detected": fire_detected,
"frames_len": content_mel_len,
"frames_threshold": 4 if is_last else self.cfg.frame_threshold,
# to be filled later
"logits_starting": None,
# to be filled later
"no_speech_prob": None,
"no_speech": False,
# to be filled in the loop
"progress": generation_progress,
}
l_absolute_timestamps = []
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
generation_progress_loop = []
if new_segment:
tokens_for_logits = current_tokens
@@ -486,50 +473,26 @@ class PaddedAlignAttWhisper:
tokens_for_logits = current_tokens[:,-1:]
logits = self.logits(tokens_for_logits, encoder_feature) # B, len(tokens), token dict size
if new_segment:
generation["logits_starting"] = Logits(logits[:,:,:])
if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
generation["no_speech_prob"] = no_speech_probs[0]
if no_speech_probs[0] > self.cfg.nonspeech_prob:
generation["no_speech"] = True
logger.info("no speech, stop")
break
logits = logits[:, -1, :] # logits for the last token
generation_progress_loop.append(("logits_before_suppress",Logits(logits)))
# supress blank tokens only at the beginning of the segment
if new_segment:
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
new_segment = False
self.suppress_tokens(logits)
#generation_progress_loop.append(("logits_after_suppres",BeamLogits(logits[0,:].clone(), self.cfg.beam_size)))
generation_progress_loop.append(("logits_after_suppress",Logits(logits)))
current_tokens, completed = self.token_decoder.update(current_tokens, logits, sum_logprobs)
generation_progress_loop.append(("beam_tokens",Tokens(current_tokens[:,-1].clone())))
generation_progress_loop.append(("sum_logprobs",sum_logprobs.tolist()))
generation_progress_loop.append(("completed",completed))
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
self.debug_print_tokens(current_tokens)
# if self.decoder_type == "beam":
# logger.debug(f"Finished sequences: {self.token_decoder.finished_sequences}")
# logprobs = F.log_softmax(logits.float(), dim=-1)
# idx = 0
# logger.debug(f"Beam search topk: {logprobs[idx].topk(self.cfg.beam_size + 1)}")
# logger.debug(f"Greedy search argmax: {logits.argmax(dim=-1)}")
# if completed:
# self.debug_print_tokens(current_tokens)
# logger.debug("decode stopped because decoder completed")
attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
for i, attn_mat in enumerate(self.dec_attns):
layer_rank = int(i % len(self.model.decoder.blocks))
@@ -548,30 +511,24 @@ class PaddedAlignAttWhisper:
t = torch.cat(mat, dim=1)
tmp.append(t)
attn_of_alignment_heads = torch.stack(tmp, dim=1)
# logger.debug(str(attn_of_alignment_heads.shape) + " tttady")
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / std
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7) # from whisper.timing
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
# logger.debug(str(attn_of_alignment_heads.shape) + " po mean")
attn_of_alignment_heads = attn_of_alignment_heads[:,:, :content_mel_len]
# logger.debug(str(attn_of_alignment_heads.shape) + " pak ")
# for each beam, the most attended frame is:
most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
generation_progress_loop.append(("most_attended_frames",most_attended_frames.clone().tolist()))
# Calculate absolute timestamps accounting for cumulative offset
absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()]
generation_progress_loop.append(("absolute_timestamps", absolute_timestamps))
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)")
most_attended_frame = most_attended_frames[0].item()
l_absolute_timestamps.append(absolute_timestamps[0])
generation_progress.append(dict(generation_progress_loop))
logger.debug("current tokens" + str(current_tokens.shape))
if completed:
# # stripping the last token, the eot
@@ -609,66 +566,53 @@ class PaddedAlignAttWhisper:
self.tokenizer.decode([current_tokens[i, -1].item()])
))
# for k,v in generation.items():
# print(k,v,file=sys.stderr)
# for x in generation_progress:
# for y in x.items():
# print("\t\t",*y,file=sys.stderr)
# print("\t","----", file=sys.stderr)
# print("\t", "end of generation_progress_loop", file=sys.stderr)
# sys.exit(1)
####################### End of decoding loop
logger.info("End of decoding loop")
# if attn_of_alignment_heads is not None:
# seg_len = int(segment.shape[0] / 16000 * TOKENS_PER_SECOND)
# # Lets' now consider only the top hypothesis in the beam search
# top_beam_attn_of_alignment_heads = attn_of_alignment_heads[0]
# # debug print: how is the new token attended?
# new_token_attn = top_beam_attn_of_alignment_heads[token_len_before_decoding:, -seg_len:]
# logger.debug(f"New token attention shape: {new_token_attn.shape}")
# if new_token_attn.shape[0] == 0: # it's not attended in the current audio segment
# logger.debug("no token generated")
# else: # it is, and the max attention is:
# new_token_max_attn, _ = new_token_attn.max(dim=-1)
# logger.debug(f"segment max attention: {new_token_max_attn.mean().item()/len(self.segments)}")
# let's now operate only with the top beam hypothesis
tokens_to_split = current_tokens[0, token_len_before_decoding:]
if fire_detected or is_last:
if fire_detected or is_last: #or punctuation_stop:
new_hypothesis = tokens_to_split.flatten().tolist()
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
else:
# going to truncate the tokens after the last space
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split.tolist())
generation["result"] = {"split_words": split_words[:-1], "split_tokens": split_tokens[:-1]}
generation["result_truncated"] = {"split_words": split_words[-1:], "split_tokens": split_tokens[-1:]}
# text_to_split = self.tokenizer.decode(tokens_to_split)
# logger.debug(f"text_to_split: {text_to_split}")
# logger.debug("text at current step: {}".format(text_to_split.replace(" ", "<space>")))
# text_before_space = " ".join(text_to_split.split(" ")[:-1])
# logger.debug("before the last space: {}".format(text_before_space.replace(" ", "<space>")))
if len(split_words) > 1:
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
else:
new_hypothesis = []
### new hypothesis
logger.debug(f"new_hypothesis: {new_hypothesis}")
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
device=self.device,
)
self.tokens.append(new_tokens)
# TODO: test if this is redundant or not
# ret = ret[ret<DEC_PAD]
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
self._clean_cache()
return new_hypothesis, generation
timestamped_words = []
timestamp_idx = 0
for word, word_tokens in zip(split_words, split_tokens):
try:
current_timestamp = l_absolute_timestamps[timestamp_idx]
except:
pass
timestamp_idx += len(word_tokens)
timestamp_entry = ASRToken(
start=current_timestamp,
end=current_timestamp + 0.1,
text= word,
probability=0.95,
language=self.detected_language
).with_offset(
self.global_time_offset
)
timestamped_words.append(timestamp_entry)
if self.detected_language is None and self.cfg.language == "auto":
timestamped_buffer_language, timestamped_words = timestamped_words, []
else:
timestamped_buffer_language = []
return timestamped_words, timestamped_buffer_language

View File

@@ -1,7 +1,9 @@
from dataclasses import dataclass, field
from typing import Optional, Any
from typing import Optional, Any, List
from datetime import timedelta
PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''}
def format_time(seconds: float) -> str:
"""Format seconds as HH:MM:SS."""
return str(timedelta(seconds=int(seconds)))
@@ -15,6 +17,10 @@ class TimedText:
speaker: Optional[int] = -1
probability: Optional[float] = None
is_dummy: Optional[bool] = False
language: str = None
def is_punctuation(self):
return self.text.strip() in PUNCTUATION_MARKS
def overlaps_with(self, other: 'TimedText') -> bool:
return not (self.end <= other.start or other.end <= self.start)
@@ -30,6 +36,10 @@ class TimedText:
def contains_timespan(self, other: 'TimedText') -> bool:
return self.start <= other.start and self.end >= other.end
def __bool__(self):
return bool(self.text)
@dataclass
class ASRToken(TimedText):
@@ -43,7 +53,28 @@ class Sentence(TimedText):
@dataclass
class Transcript(TimedText):
pass
"""
represents a concatenation of several ASRToken
"""
@classmethod
def from_tokens(
cls,
tokens: List[ASRToken],
sep: Optional[str] = None,
offset: float = 0
) -> "Transcript":
sep = sep if sep is not None else ' '
text = sep.join(token.text for token in tokens)
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
if tokens:
start = offset + tokens[0].start
end = offset + tokens[-1].end
else:
start = None
end = None
return cls(start, end, text, probability=probability)
@dataclass
class SpeakerSegment(TimedText):
@@ -92,16 +123,22 @@ class Silence():
@dataclass
class Line(TimedText):
translation: str = ''
detected_language: str = None
def to_dict(self):
return {
_dict = {
'speaker': int(self.speaker),
'text': self.text,
'translation': self.translation,
'start': format_time(self.start),
'end': format_time(self.end),
}
if self.translation:
_dict['translation'] = self.translation
if self.detected_language:
_dict['detected_language'] = self.detected_language
return _dict
@dataclass
class FrontData():
status: str = ''

View File

@@ -12,8 +12,6 @@ logger = logging.getLogger(__name__)
#In diarization case, we may want to translate just one speaker, or at least start the sentences there
PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''}
MIN_SILENCE_DURATION_DEL_BUFFER = 3 #After a silence of x seconds, we consider the model should not use the buffer, even if the previous
# sentence is not finished.
@@ -111,7 +109,7 @@ class OnlineTranslation:
if len(self.buffer) < self.len_processed_buffer + 3: #nothing new to process
return self.validated + [self.translation_remaining]
while i < len(self.buffer):
if self.buffer[i].text in PUNCTUATION_MARKS:
if self.buffer[i].is_punctuation():
translation_sentence = self.translate_tokens(self.buffer[:i+1])
self.validated.append(translation_sentence)
self.buffer = self.buffer[i+1:]

View File

@@ -346,7 +346,7 @@ label {
.label_diarization {
background-color: var(--chip-bg);
border-radius: 8px 8px 8px 8px;
border-radius: 100px;
padding: 2px 10px;
margin-left: 10px;
display: inline-block;
@@ -358,7 +358,7 @@ label {
.label_transcription {
background-color: var(--chip-bg);
border-radius: 8px 8px 8px 8px;
border-radius: 100px;
padding: 2px 10px;
display: inline-block;
white-space: nowrap;
@@ -370,16 +370,20 @@ label {
.label_translation {
background-color: var(--chip-bg);
display: inline-flex;
border-radius: 10px;
padding: 4px 8px;
margin-top: 4px;
font-size: 14px;
color: var(--text);
display: flex;
align-items: flex-start;
gap: 4px;
}
.lag-diarization-value {
margin-left: 10px;
}
.label_translation img {
margin-top: 2px;
}
@@ -391,7 +395,7 @@ label {
#timeInfo {
color: var(--muted);
margin-left: 10px;
margin-left: 0px;
}
.textcontent {
@@ -514,3 +518,49 @@ label {
padding: 10px;
}
}
.label_language {
background-color: var(--chip-bg);
margin-bottom: 0px;
margin-top: 5px;
height: 18.5px;
border-radius: 100px;
padding: 2px 8px;
margin-left: 10px;
display: inline-flex;
align-items: center;
gap: 4px;
font-size: 14px;
color: var(--muted);
}
.label_language img {
width: 12px;
height: 12px;
}
.silence-icon {
width: 14px;
height: 14px;
vertical-align: text-bottom;
}
.speaker-icon {
width: 16px;
height: 16px;
vertical-align: text-bottom;
}
.speaker-badge {
display: inline-flex;
align-items: center;
justify-content: center;
width: 16px;
height: 16px;
margin-left: -5px;
border-radius: 50%;
font-size: 11px;
line-height: 1;
font-weight: 800;
color: var(--muted);
}

View File

@@ -306,7 +306,7 @@ function renderLinesWithBuffer(
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
const signature = JSON.stringify({
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end })),
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end, detected_language: it.detected_language })),
buffer_transcription: buffer_transcription || "",
buffer_diarization: buffer_diarization || "",
status: current_status,
@@ -335,13 +335,20 @@ function renderLinesWithBuffer(
let speakerLabel = "";
if (item.speaker === -2) {
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
const silenceIcon = `<img class="silence-icon" src="/web/src/silence.svg" alt="Silence" />`;
speakerLabel = `<span class="silence">${silenceIcon}<span id='timeInfo'>${timeInfo}</span></span>`;
} else if (item.speaker == 0 && !isFinalizing) {
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
remaining_time_diarization
)}</span> second(s) of audio are undergoing diarization</span></span>`;
} else if (item.speaker !== 0) {
speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
const speakerIcon = `<img class="speaker-icon" src="/web/src/speaker.svg" alt="Speaker ${item.speaker}" />`;
const speakerNum = `<span class="speaker-badge">${item.speaker}</span>`;
speakerLabel = `<span id="speaker">${speakerIcon}${speakerNum}<span id='timeInfo'>${timeInfo}</span></span>`;
if (item.detected_language) {
speakerLabel += `<span class="label_language"><img src="/web/src/language.svg" alt="Detected language" width="12" height="12" /><span>${item.detected_language}</span></span>`;
}
}
let currentLineText = item.text || "";

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-80q-82 0-155-31.5t-127.5-86Q143-252 111.5-325T80-480q0-83 31.5-155.5t86-127Q252-817 325-848.5T480-880q83 0 155.5 31.5t127 86q54.5 54.5 86 127T880-480q0 82-31.5 155t-86 127.5q-54.5 54.5-127 86T480-80Zm0-82q26-36 45-75t31-83H404q12 44 31 83t45 75Zm-104-16q-18-33-31.5-68.5T322-320H204q29 50 72.5 87t99.5 55Zm208 0q56-18 99.5-55t72.5-87H638q-9 38-22.5 73.5T584-178ZM170-400h136q-3-20-4.5-39.5T300-480q0-21 1.5-40.5T306-560H170q-5 20-7.5 39.5T160-480q0 21 2.5 40.5T170-400Zm216 0h188q3-20 4.5-39.5T580-480q0-21-1.5-40.5T574-560H386q-3 20-4.5 39.5T380-480q0 21 1.5 40.5T386-400Zm268 0h136q5-20 7.5-39.5T800-480q0-21-2.5-40.5T790-560H654q3 20 4.5 39.5T660-480q0 21-1.5 40.5T654-400Zm-16-240h118q-29-50-72.5-87T584-782q18 33 31.5 68.5T638-640Zm-234 0h152q-12-44-31-83t-45-75q-26 36-45 75t-31 83Zm-200 0h118q9-38 22.5-73.5T376-782q-56 18-99.5 55T204-640Z"/></svg>

After

Width:  |  Height:  |  Size: 976 B

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M514-556 320-752q9-3 19-5.5t21-2.5q66 0 113 47t47 113q0 11-1.5 22t-4.5 22ZM40-200v-32q0-33 17-62t47-44q51-26 115-44t141-18q26 0 49.5 2.5T456-392l-56-54q-9 3-19 4.5t-21 1.5q-66 0-113-47t-47-113q0-11 1.5-21t4.5-19L84-764q-11-11-11-28t11-28q12-12 28.5-12t27.5 12l675 685q11 11 11.5 27.5T816-80q-11 13-28 12.5T759-80L641-200h39q0 33-23.5 56.5T600-120H120q-33 0-56.5-23.5T40-200Zm80 0h480v-32q0-14-4.5-19.5T580-266q-36-18-92.5-36T360-320q-71 0-127.5 18T140-266q-9 5-14.5 14t-5.5 20v32Zm240 0Zm560-400q0 69-24.5 131.5T829-355q-12 14-30 15t-32-13q-13-13-12-31t12-33q30-38 46.5-85t16.5-98q0-51-16.5-97T767-781q-12-15-12.5-33t12.5-32q13-14 31.5-13.5T829-845q42 51 66.5 113.5T920-600Zm-182 0q0 32-10 61.5T700-484q-11 15-29.5 15.5T638-482q-13-13-13.5-31.5T633-549q6-11 9.5-24t3.5-27q0-14-3.5-27t-9.5-25q-9-17-8.5-35t13.5-31q14-14 32.5-13.5T700-716q18 25 28 54.5t10 61.5Z"/></svg>

After

Width:  |  Height:  |  Size: 984 B

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-480q-66 0-113-47t-47-113q0-66 47-113t113-47q66 0 113 47t47 113q0 66-47 113t-113 47ZM160-240v-32q0-34 17.5-62.5T224-378q62-31 126-46.5T480-440q66 0 130 15.5T736-378q29 15 46.5 43.5T800-272v32q0 33-23.5 56.5T720-160H240q-33 0-56.5-23.5T160-240Zm80 0h480v-32q0-11-5.5-20T700-306q-54-27-109-40.5T480-360q-56 0-111 13.5T260-306q-9 5-14.5 14t-5.5 20v32Zm240-320q33 0 56.5-23.5T560-640q0-33-23.5-56.5T480-720q-33 0-56.5 23.5T400-640q0 33 23.5 56.5T480-560Zm0-80Zm0 400Z"/></svg>

After

Width:  |  Height:  |  Size: 592 B