compatible with the latest version of simulstreaming

This commit is contained in:
Quentin Fuxa
2025-07-01 20:10:45 +02:00
parent 25526b3aa2
commit 62bf28949e
2 changed files with 44 additions and 44 deletions

View File

@@ -484,7 +484,8 @@ class SimulStreamingASR(ASRBase):
try:
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio).float()
self.model.infer(audio, True)
self.model.insert_audio(audio)
self.model.infer(True)
self.model.refresh_segment(complete=True)
logger.info("SimulStreaming model warmed up successfully")
except Exception as e:

View File

@@ -609,6 +609,31 @@ class SimulStreamingOnlineProcessor:
probability=None
)
def timestamped_text(self, tokens, generation):
# From the simulstreaming repo. self.model to self.asr.model
pr = generation["progress"]
if "result" not in generation:
split_words, split_tokens = self.asr.model.tokenizer.split_to_word_tokens(tokens)
else:
split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"]
frames = [p["most_attended_frames"][0] for p in pr]
tokens = tokens.copy()
ret = []
for sw,st in zip(split_words,split_tokens):
b = None
for stt in st:
t,f = tokens.pop(0), frames.pop(0)
if t != stt:
raise ValueError(f"Token mismatch: {t} != {stt} at frame {f}.")
if b is None:
b = f
e = f
out = (b*0.02, e*0.02, sw)
ret.append(out)
logger.debug(f"TS-WORD:\t{' '.join(map(str, out))}")
return ret
def process_iter(self) -> Tuple[List[ASRToken], float]:
"""
Process accumulated audio chunks using SimulStreaming.
@@ -633,52 +658,26 @@ class SimulStreamingOnlineProcessor:
logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s")
logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s")
result = self.asr.model.infer(audio, is_last=self.is_last)
self.asr.model.insert_audio(audio)
tokens, generation_progress = self.asr.model.infer(is_last=self.is_last)
ts_words = self.timestamped_text(tokens, generation_progress)
text = self.asr.model.tokenizer.decode(tokens)
if torch.is_tensor(result):
# we filter out padding tokens as it s done in simul whisper
from simul_whisper.simul_whisper import DEC_PAD
result = result[result < DEC_PAD]
new_tokens = []
for ts_word in ts_words:
# C/P from simul_whisper.simul_whisper.py
if len(result) > 0:
decoded_text = self.asr.model.tokenizer.decode(result.cpu().numpy())
logger.debug(f"SimulStreaming decoded: {decoded_text}")
if decoded_text.strip():
words = decoded_text.strip().split()
new_tokens = []
num_words = len(words)
if num_words > 0:
# distribute words evenly across the processed audio duration
# we NEED that for when we use diarization. Even if that s not perfect
start_time = self.end - audio_duration
time_per_word = audio_duration / num_words if num_words > 1 else audio_duration
for i, word in enumerate(words):
token_start = start_time + (i * time_per_word)
token_end = start_time + ((i + 1) * time_per_word)
token_end = min(token_end, self.end)
token = ASRToken(
start=token_start,
end=token_end,
text=word,
probability=0.95 # fake prob. Maybe we can extract it from the model?
)
new_tokens.append(token)
self.beg = self.end
self.committed.extend(new_tokens)
self.last_result_tokens = new_tokens
logger.debug(f"SimulStreaming generated {len(new_tokens)} tokens with end time: {self.end:.2f}s")
return new_tokens, self.end
start, end, word = ts_word
token = ASRToken(
start=start,
end=end,
text=word,
probability=0.95 # fake prob. Maybe we can extract it from the model?
)
new_tokens.append(token)
self.committed.extend(new_tokens)
return [], self.end
return new_tokens, self.end
except Exception as e:
logger.error(f"SimulStreaming processing error: {e}")