clean simulwhisper backend and online

This commit is contained in:
Quentin Fuxa
2025-08-09 18:02:15 +02:00
parent 197293e25e
commit b05297a96d
2 changed files with 5 additions and 97 deletions

View File

@@ -3,16 +3,15 @@ import numpy as np
import logging
from typing import List, Tuple, Optional
import logging
from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
from whisperlivekit.timed_objects import ASRToken, Transcript
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
logger = logging.getLogger(__name__)
try:
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
from whisperlivekit.simul_whisper.whisper import tokenizer
SIMULSTREAMING_AVAILABLE = True
except ImportError as e:
raise ImportError(
"""SimulStreaming dependencies are not available.
@@ -28,23 +27,14 @@ class SimulStreamingOnlineProcessor:
buffer_trimming: Tuple[str, float] = ("segment", 15),
confidence_validation = False,
logfile=sys.stderr,
):
if not SIMULSTREAMING_AVAILABLE:
raise ImportError("SimulStreaming dependencies are not available.")
):
self.asr = asr
self.tokenize = tokenize_method
self.logfile = logfile
self.confidence_validation = confidence_validation
self.init()
# buffer does not work yet
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
def init(self, offset: Optional[float] = None):
"""Initialize or reset the processing state."""
self.audio_chunks = []
self.offset = offset if offset is not None else 0.0
self.offset = 0.0
self.is_last = False
self.beg = self.offset
self.end = self.offset
@@ -56,14 +46,8 @@ class SimulStreamingOnlineProcessor:
self.buffer_content = ""
self.processed_audio_duration = 0.0
def get_audio_buffer_end_time(self) -> float:
"""Returns the absolute end time of the current audio buffer."""
return self.end
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
"""Append an audio chunk to be processed by SimulStreaming."""
if torch is None:
raise ImportError("PyTorch is required for SimulStreaming but not available")
# Convert numpy array to torch tensor
audio_tensor = torch.from_numpy(audio).float()
@@ -79,13 +63,6 @@ class SimulStreamingOnlineProcessor:
else:
self.end = self.offset + self.cumulative_audio_duration
def prompt(self) -> Tuple[str, str]:
"""
Returns a tuple: (prompt, context).
SimulStreaming handles prompting internally, so we return empty strings.
"""
return "", ""
def get_buffer(self):
"""
Get the unvalidated buffer content.
@@ -150,7 +127,6 @@ class SimulStreamingOnlineProcessor:
self.asr.model.insert_audio(audio)
tokens, generation_progress = self.asr.model.infer(is_last=self.is_last)
ts_words = self.timestamped_text(tokens, generation_progress)
text = self.asr.model.tokenizer.decode(tokens)
new_tokens = []
for ts_word in ts_words:
@@ -172,55 +148,6 @@ class SimulStreamingOnlineProcessor:
logger.exception(f"SimulStreaming processing error: {e}")
return [], self.end
def finish(self) -> Tuple[List[ASRToken], float]:
logger.debug("SimulStreaming finish() called")
self.is_last = True
final_tokens, final_time = self.process_iter()
self.is_last = False
return final_tokens, final_time
def concatenate_tokens(
self,
tokens: List[ASRToken],
sep: Optional[str] = None,
offset: float = 0
) -> Transcript:
"""Concatenate tokens into a Transcript object."""
sep = sep if sep is not None else self.asr.sep
text = sep.join(token.text for token in tokens)
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
if tokens:
start = offset + tokens[0].start
end = offset + tokens[-1].end
else:
start = None
end = None
return Transcript(start, end, text, probability=probability)
def chunk_at(self, time: float):
"""
useless but kept for compatibility
"""
logger.debug(f"SimulStreaming chunk_at({time:.2f}) - handled internally")
pass
def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
"""
Create simple sentences.
"""
if not tokens:
return []
full_text = " ".join(token.text for token in tokens)
sentence = Sentence(
start=tokens[0].start,
end=tokens[-1].end,
text=full_text
)
return [sentence]
class SimulStreamingASR():
"""SimulStreaming backend with AlignAtt policy."""
sep = ""
@@ -247,7 +174,7 @@ class SimulStreamingASR():
if model_dir is not None:
self.model_path = model_dir
elif modelsize is not None: #For the moment the .en.pt models do not work!
elif modelsize is not None:
model_mapping = {
'tiny': './tiny.pt',
'base': './base.pt',
@@ -297,13 +224,6 @@ class SimulStreamingASR():
logger.error(f"Failed to load SimulStreaming model: {e}")
raise
def segments_end_ts(self, result) -> List[float]:
"""Get segment end timestamps."""
if torch.is_tensor(result):
num_tokens = len(result)
return [num_tokens * 0.1] # rough estimate
return [1.0]
def set_translate_task(self):
"""Set up translation task."""
try:

View File

@@ -6,18 +6,6 @@ from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
logger = logging.getLogger(__name__)
# simulStreaming imports - we check if the files are here
try:
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
SIMULSTREAMING_AVAILABLE = True
except ImportError:
logger.warning("SimulStreaming dependencies not available for online processor.")
SIMULSTREAMING_AVAILABLE = False
OnlineProcessorInterface = None
torch = None
class HypothesisBuffer:
"""
Buffer to store and process ASR hypothesis tokens.