mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
compatible with the latest version of simulstreaming
This commit is contained in:
@@ -484,7 +484,8 @@ class SimulStreamingASR(ASRBase):
|
||||
try:
|
||||
if isinstance(audio, np.ndarray):
|
||||
audio = torch.from_numpy(audio).float()
|
||||
self.model.infer(audio, True)
|
||||
self.model.insert_audio(audio)
|
||||
self.model.infer(True)
|
||||
self.model.refresh_segment(complete=True)
|
||||
logger.info("SimulStreaming model warmed up successfully")
|
||||
except Exception as e:
|
||||
|
||||
@@ -609,6 +609,31 @@ class SimulStreamingOnlineProcessor:
|
||||
probability=None
|
||||
)
|
||||
|
||||
def timestamped_text(self, tokens, generation):
    """Attach start/end times to each decoded word.

    Adapted from the simulstreaming repo (``self.model`` became
    ``self.asr.model``).

    Args:
        tokens: Sequence of decoded token ids. It is only indexed, never
            mutated.
        generation: Mapping describing the decoding run. ``"progress"`` must
            hold one entry per token, each with a ``"most_attended_frames"``
            sequence whose first element is used as that token's frame.
            When ``"result"`` is present, its ``"split_words"`` and
            ``"split_tokens"`` are used as the word segmentation; otherwise
            the tokenizer's ``split_to_word_tokens`` is called on ``tokens``.
            An empty or ``None`` value means there is nothing to align.

    Returns:
        A list of ``(start, end, word)`` tuples, where ``start`` and ``end``
        are the frames of the word's first and last token multiplied by the
        frame duration. Empty when ``generation`` is empty.

    Raises:
        ValueError: If a token in ``tokens`` differs from the token the word
            segmentation expects at the same position.
        IndexError: If the segmentation expects more tokens or frames than
            were provided.
    """
    # Nothing was generated (e.g. no audio to decode): no words to align.
    if not generation:
        return []

    progress = generation["progress"]
    if "result" in generation:
        result = generation["result"]
        split_words, split_tokens = result["split_words"], result["split_tokens"]
    else:
        split_words, split_tokens = self.asr.model.tokenizer.split_to_word_tokens(tokens)

    frames = [step["most_attended_frames"][0] for step in progress]

    # NOTE(review): presumably the duration in seconds of one attended frame
    # (20 ms) -- confirm against the model's frame rate.
    seconds_per_frame = 0.02

    timestamped = []
    # Walk tokens and frames with a cursor instead of pop(0): no copy of
    # `tokens` is needed and each step is O(1).
    cursor = 0
    for word, word_tokens in zip(split_words, split_tokens):
        if not word_tokens:
            # A word without tokens has no frames to derive a time from.
            continue
        first_frame = None
        frame = None
        for expected in word_tokens:
            token, frame = tokens[cursor], frames[cursor]
            cursor += 1
            if token != expected:
                raise ValueError(f"Token mismatch: {token} != {expected} at frame {frame}.")
            if first_frame is None:
                first_frame = frame
        # `frame` now holds the frame of the word's last token.
        entry = (first_frame * seconds_per_frame, frame * seconds_per_frame, word)
        timestamped.append(entry)
        logger.debug("TS-WORD:\t%s", " ".join(map(str, entry)))
    return timestamped
|
||||
|
||||
def process_iter(self) -> Tuple[List[ASRToken], float]:
|
||||
"""
|
||||
Process accumulated audio chunks using SimulStreaming.
|
||||
@@ -633,52 +658,26 @@ class SimulStreamingOnlineProcessor:
|
||||
logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s")
|
||||
logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s")
|
||||
|
||||
result = self.asr.model.infer(audio, is_last=self.is_last)
|
||||
self.asr.model.insert_audio(audio)
|
||||
tokens, generation_progress = self.asr.model.infer(is_last=self.is_last)
|
||||
ts_words = self.timestamped_text(tokens, generation_progress)
|
||||
text = self.asr.model.tokenizer.decode(tokens)
|
||||
|
||||
if torch.is_tensor(result):
|
||||
# we filter out padding tokens as it's done in simul whisper
|
||||
from simul_whisper.simul_whisper import DEC_PAD
|
||||
result = result[result < DEC_PAD]
|
||||
new_tokens = []
|
||||
for ts_word in ts_words:
|
||||
|
||||
# C/P from simul_whisper.simul_whisper.py
|
||||
if len(result) > 0:
|
||||
decoded_text = self.asr.model.tokenizer.decode(result.cpu().numpy())
|
||||
logger.debug(f"SimulStreaming decoded: {decoded_text}")
|
||||
|
||||
if decoded_text.strip():
|
||||
words = decoded_text.strip().split()
|
||||
new_tokens = []
|
||||
|
||||
num_words = len(words)
|
||||
if num_words > 0:
|
||||
# distribute words evenly across the processed audio duration
|
||||
# we NEED that for when we use diarization. Even if that's not perfect
|
||||
start_time = self.end - audio_duration
|
||||
time_per_word = audio_duration / num_words if num_words > 1 else audio_duration
|
||||
|
||||
for i, word in enumerate(words):
|
||||
token_start = start_time + (i * time_per_word)
|
||||
token_end = start_time + ((i + 1) * time_per_word)
|
||||
|
||||
token_end = min(token_end, self.end)
|
||||
|
||||
token = ASRToken(
|
||||
start=token_start,
|
||||
end=token_end,
|
||||
text=word,
|
||||
probability=0.95 # fake prob. Maybe we can extract it from the model?
|
||||
)
|
||||
new_tokens.append(token)
|
||||
|
||||
self.beg = self.end
|
||||
|
||||
self.committed.extend(new_tokens)
|
||||
self.last_result_tokens = new_tokens
|
||||
|
||||
logger.debug(f"SimulStreaming generated {len(new_tokens)} tokens with end time: {self.end:.2f}s")
|
||||
return new_tokens, self.end
|
||||
start, end, word = ts_word
|
||||
token = ASRToken(
|
||||
start=start,
|
||||
end=end,
|
||||
text=word,
|
||||
probability=0.95 # fake prob. Maybe we can extract it from the model?
|
||||
)
|
||||
new_tokens.append(token)
|
||||
self.committed.extend(new_tokens)
|
||||
|
||||
return [], self.end
|
||||
return new_tokens, self.end
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"SimulStreaming processing error: {e}")
|
||||
|
||||
Reference in New Issue
Block a user