Mirror of https://github.com/QuentinFuxa/WhisperLiveKit.git (synced 2026-03-07 22:33:36 +00:00)

Commit: clean simulwhisper backend and online
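For orientation, the online processor touched below is driven as a feed-and-poll loop. The sketch that follows is assembled only from method names and signatures visible in this diff (insert_audio_chunk, process_iter, get_buffer, the is_last flag); the construction of `processor` and the chunking of the audio source are assumptions, not part of this commit.

import numpy as np

# Hypothetical driver loop. `processor` is an already-constructed
# SimulStreamingOnlineProcessor; `chunks` is a list of float32 PCM arrays.
def transcribe_stream(processor, chunks):
    committed = []
    for i, chunk in enumerate(chunks):
        if i == len(chunks) - 1:
            processor.is_last = True                    # let the final infer() flush
        processor.insert_audio_chunk(np.asarray(chunk, dtype=np.float32))
        tokens, buffer_end = processor.process_iter()   # (new ASRTokens, buffer end time)
        committed.extend(tokens)
    return committed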
@@ -3,16 +3,15 @@ import numpy as np
import logging
from typing import List, Tuple, Optional
import logging
from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
from whisperlivekit.timed_objects import ASRToken, Transcript
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
logger = logging.getLogger(__name__)

try:
    import torch
    from whisperlivekit.simul_whisper.config import AlignAttConfig
    from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD
    from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
    from whisperlivekit.simul_whisper.whisper import tokenizer
    SIMULSTREAMING_AVAILABLE = True
except ImportError as e:
    raise ImportError(
        """SimulStreaming dependencies are not available.
@@ -28,23 +27,14 @@ class SimulStreamingOnlineProcessor:
        buffer_trimming: Tuple[str, float] = ("segment", 15),
        confidence_validation = False,
        logfile=sys.stderr,
    ):
        if not SIMULSTREAMING_AVAILABLE:
            raise ImportError("SimulStreaming dependencies are not available.")

    ):
        self.asr = asr
        self.tokenize = tokenize_method
        self.logfile = logfile
        self.confidence_validation = confidence_validation
        self.init()

        # buffer does not work yet
        self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming

    def init(self, offset: Optional[float] = None):
        """Initialize or reset the processing state."""
        self.audio_chunks = []
        self.offset = offset if offset is not None else 0.0
        self.offset = 0.0
        self.is_last = False
        self.beg = self.offset
        self.end = self.offset
@@ -56,14 +46,8 @@ class SimulStreamingOnlineProcessor:
        self.buffer_content = ""
        self.processed_audio_duration = 0.0

    def get_audio_buffer_end_time(self) -> float:
        """Returns the absolute end time of the current audio buffer."""
        return self.end

    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
        """Append an audio chunk to be processed by SimulStreaming."""
        if torch is None:
            raise ImportError("PyTorch is required for SimulStreaming but not available")

        # Convert numpy array to torch tensor
        audio_tensor = torch.from_numpy(audio).float()
@@ -79,13 +63,6 @@ class SimulStreamingOnlineProcessor:
        else:
            self.end = self.offset + self.cumulative_audio_duration

    def prompt(self) -> Tuple[str, str]:
        """
        Returns a tuple: (prompt, context).
        SimulStreaming handles prompting internally, so we return empty strings.
        """
        return "", ""

    def get_buffer(self):
        """
        Get the unvalidated buffer content.
@@ -150,7 +127,6 @@ class SimulStreamingOnlineProcessor:
            self.asr.model.insert_audio(audio)
            tokens, generation_progress = self.asr.model.infer(is_last=self.is_last)
            ts_words = self.timestamped_text(tokens, generation_progress)
            text = self.asr.model.tokenizer.decode(tokens)

            new_tokens = []
            for ts_word in ts_words:
@@ -172,55 +148,6 @@ class SimulStreamingOnlineProcessor:
            logger.exception(f"SimulStreaming processing error: {e}")
            return [], self.end

    def finish(self) -> Tuple[List[ASRToken], float]:
        logger.debug("SimulStreaming finish() called")
        self.is_last = True
        final_tokens, final_time = self.process_iter()
        self.is_last = False
        return final_tokens, final_time

    def concatenate_tokens(
        self,
        tokens: List[ASRToken],
        sep: Optional[str] = None,
        offset: float = 0
    ) -> Transcript:
        """Concatenate tokens into a Transcript object."""
        sep = sep if sep is not None else self.asr.sep
        text = sep.join(token.text for token in tokens)
        probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
        if tokens:
            start = offset + tokens[0].start
            end = offset + tokens[-1].end
        else:
            start = None
            end = None
        return Transcript(start, end, text, probability=probability)

    def chunk_at(self, time: float):
        """
        useless but kept for compatibility
        """
        logger.debug(f"SimulStreaming chunk_at({time:.2f}) - handled internally")
        pass

    def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
        """
        Create simple sentences.
        """
        if not tokens:
            return []

        full_text = " ".join(token.text for token in tokens)
        sentence = Sentence(
            start=tokens[0].start,
            end=tokens[-1].end,
            text=full_text
        )
        return [sentence]


class SimulStreamingASR():
    """SimulStreaming backend with AlignAtt policy."""
    sep = ""
@@ -247,7 +174,7 @@ class SimulStreamingASR():

        if model_dir is not None:
            self.model_path = model_dir
        elif modelsize is not None: #For the moment the .en.pt models do not work!
        elif modelsize is not None:
            model_mapping = {
                'tiny': './tiny.pt',
                'base': './base.pt',
@@ -297,13 +224,6 @@ class SimulStreamingASR():
            logger.error(f"Failed to load SimulStreaming model: {e}")
            raise

    def segments_end_ts(self, result) -> List[float]:
        """Get segment end timestamps."""
        if torch.is_tensor(result):
            num_tokens = len(result)
            return [num_tokens * 0.1] # rough estimate
        return [1.0]

    def set_translate_task(self):
        """Set up translation task."""
        try:

@@ -6,18 +6,6 @@ from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript

logger = logging.getLogger(__name__)

# simulStreaming imports - we check if the files are here
try:
    import torch
    from whisperlivekit.simul_whisper.config import AlignAttConfig
    SIMULSTREAMING_AVAILABLE = True
except ImportError:
    logger.warning("SimulStreaming dependencies not available for online processor.")
    SIMULSTREAMING_AVAILABLE = False
    OnlineProcessorInterface = None
    torch = None


class HypothesisBuffer:
    """
    Buffer to store and process ASR hypothesis tokens.
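The SimulStreamingASR hunk above (@@ -247,7 +174,7 @@) resolves the checkpoint path by letting an explicit model_dir win over a modelsize lookup. Below is a standalone restatement of that resolution as a sketch: only the 'tiny' and 'base' entries appear in the diff, and the behaviour when neither argument is given is an assumption, not taken from the repository.

# Illustrative restatement of the path resolution in SimulStreamingASR.__init__.
def resolve_model_path(model_dir=None, modelsize=None):
    if model_dir is not None:
        return model_dir                      # explicit checkpoint/directory wins
    if modelsize is not None:
        model_mapping = {
            'tiny': './tiny.pt',
            'base': './base.pt',
            # larger sizes follow the same pattern in the source
        }
        return model_mapping[modelsize]
    # not shown in the diff; raising is an assumption for this sketch
    raise ValueError("either model_dir or modelsize must be provided")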