Mirror of https://github.com/QuentinFuxa/WhisperLiveKit.git (synced 2026-03-07 14:23:18 +00:00)
clean SimulStreamingOnlineProcessor initialization + audio processing
@@ -125,10 +125,7 @@ def online_factory(args, asr, tokenizer, logfile=sys.stderr):
         from simul_whisper import SimulStreamingOnlineProcessor
         online = SimulStreamingOnlineProcessor(
             asr,
             tokenizer,
-            logfile=logfile,
-            buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
-            confidence_validation=args.confidence_validation
         )
     elif args.vac:
         online = VACOnlineASRProcessor(
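
With those keyword arguments gone, the call site falls back to the defaults that __init__ declares in the next hunk. A minimal, self-contained sketch of what the trimmed call gets (Demo is a stand-in, not the real processor; only the default values are taken from the diff):

    import sys
    from typing import Tuple

    class Demo:
        # Same defaults as SimulStreamingOnlineProcessor.__init__ below.
        def __init__(self, asr, tokenizer,
                     buffer_trimming: Tuple[str, float] = ("segment", 15),
                     confidence_validation: bool = False,
                     logfile=sys.stderr):
            self.buffer_trimming = buffer_trimming
            self.confidence_validation = confidence_validation
            self.logfile = logfile

    online = Demo("asr-stub", "tokenizer-stub")  # trimmed call shape, as in the diff
    print(online.buffer_trimming)        # ('segment', 15)
    print(online.confidence_validation)  # False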
@@ -23,45 +23,34 @@ class SimulStreamingOnlineProcessor:
     def __init__(
         self,
         asr,
         tokenize_method: Optional[callable] = None,
         buffer_trimming: Tuple[str, float] = ("segment", 15),
         confidence_validation = False,
         logfile=sys.stderr,
     ):
         self.asr = asr
         self.logfile = logfile
         self.confidence_validation = confidence_validation
         # buffer does not work yet
         self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
         self.audio_chunks = []
         self.offset = 0.0
         self.is_last = False
-        self.beg = self.offset
-        self.end = self.offset
+        self.beg = 0.0
+        self.end = 0.0
         self.cumulative_audio_duration = 0.0
         self.last_audio_stream_end_time = self.offset

         self.committed: List[ASRToken] = []
         self.last_result_tokens: List[ASRToken] = []
         self.buffer_content = ""
         self.processed_audio_duration = 0.0

     def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
         """Append an audio chunk to be processed by SimulStreaming."""

         # Convert numpy array to torch tensor
         audio_tensor = torch.from_numpy(audio).float()
         self.audio_chunks.append(audio_tensor)

         # Update timing
         chunk_duration = len(audio) / self.SAMPLING_RATE
         self.cumulative_audio_duration += chunk_duration

         if audio_stream_end_time is not None:
             self.last_audio_stream_end_time = audio_stream_end_time
             self.end = audio_stream_end_time
         else:
-            self.end = self.offset + self.cumulative_audio_duration
+            self.end = self.cumulative_audio_duration
+        self.asr.model.insert_audio(audio_tensor)

     def get_buffer(self):
         """
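
insert_audio_chunk now both tracks timing and pushes each chunk straight into the model. A runnable sketch of just the timing bookkeeping (ChunkTimer is an illustration, not the real class; 16 kHz is Whisper's usual rate, assumed to match SAMPLING_RATE):

    import numpy as np
    import torch

    SAMPLING_RATE = 16000  # assumed value of the class constant

    class ChunkTimer:
        def __init__(self):
            self.cumulative_audio_duration = 0.0
            self.end = 0.0

        def insert(self, audio: np.ndarray, stream_end_time=None) -> torch.Tensor:
            tensor = torch.from_numpy(audio).float()  # model expects float32 tensors
            self.cumulative_audio_duration += len(audio) / SAMPLING_RATE
            # Prefer the caller-supplied stream end time when available.
            if stream_end_time is not None:
                self.end = stream_end_time
            else:
                self.end = self.cumulative_audio_duration
            return tensor

    timer = ChunkTimer()
    timer.insert(np.zeros(8000, dtype=np.float32))  # 0.5 s of silence
    assert abs(timer.end - 0.5) < 1e-9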
@@ -106,25 +95,7 @@ class SimulStreamingOnlineProcessor:

         Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
         """
-        if not self.audio_chunks:
-            return [], self.end
-
-        try:
-            # concatenate all audio chunks
-            if len(self.audio_chunks) == 1:
-                audio = self.audio_chunks[0]
-            else:
-                audio = torch.cat(self.audio_chunks, dim=0)
-
-            audio_duration = audio.shape[0] / self.SAMPLING_RATE if audio.shape[0] > 0 else 0
-            self.processed_audio_duration += audio_duration
-
-            self.audio_chunks = []
-
-            logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s")
-            logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s")
-
-            self.asr.model.insert_audio(audio)
+        try:
             tokens, generation_progress = self.asr.model.infer(is_last=self.is_last)
             ts_words = self.timestamped_text(tokens, generation_progress)
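
With the chunk buffering gone, process_iter reduces to driving the model's incremental decode. A hedged sketch of that contract, with MockModel standing in for the real PaddedAlignAttWhisper (only the insert_audio()/infer(is_last=...) call shapes come from the diff; the return values here are placeholders):

    import torch

    class MockModel:
        def __init__(self):
            self.buffered_samples = 0

        def insert_audio(self, chunk: torch.Tensor):
            # Accumulate pushed audio, as the real model does internally.
            self.buffered_samples += chunk.shape[0]

        def infer(self, is_last: bool = False):
            tokens = [f"tok@{self.buffered_samples}"]  # placeholder tokens
            progress = {"is_last": is_last}            # placeholder progress info
            return tokens, progress

    model = MockModel()
    chunks = [torch.zeros(8000), torch.zeros(8000)]
    for i, chunk in enumerate(chunks):
        model.insert_audio(chunk)
        tokens, progress = model.infer(is_last=(i == len(chunks) - 1))
        print(tokens, progress)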
@@ -214,9 +185,7 @@ class SimulStreamingASR():
             init_prompt=self.init_prompt,
             max_context_tokens=self.max_context_tokens,
             static_init_prompt=self.static_init_prompt,
         )

-        logger.info(f"Loading SimulStreaming model with language: {self.original_language}")
-        )
         model = PaddedAlignAttWhisper(cfg)
         return model
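
The load path assembles one config object and constructs the model from it once. A sketch under stated assumptions: AlignAttConfigSketch is a local stand-in, and only the three field names plus the PaddedAlignAttWhisper(cfg) call are visible in the diff.

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class AlignAttConfigSketch:
        # Field names taken from the kwargs in the hunk above.
        init_prompt: Optional[str] = None
        max_context_tokens: Optional[int] = None
        static_init_prompt: Optional[str] = None

    def load_model_sketch() -> AlignAttConfigSketch:
        cfg = AlignAttConfigSketch()
        # Real code: model = PaddedAlignAttWhisper(cfg); return model
        return cfg

    print(load_model_sketch())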