diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py
index 86d0fc6..6229b32 100644
--- a/whisperlivekit/audio_processor.py
+++ b/whisperlivekit/audio_processor.py
@@ -453,12 +453,11 @@ class AudioProcessor:
 
                 # Handle undiarized text
                 if undiarized_text:
-                    combined_undiarized = sep.join(undiarized_text)
+                    combined = sep.join(undiarized_text)
                     if buffer_transcription:
-                        buffer_transcription = combined_undiarized + sep + buffer_transcription
-                    else:
-                        buffer_transcription = combined_undiarized
-                    buffer_diarization = ""
+                        combined += sep
+                    await self.update_diarization(end_attributed_speaker, combined)
+                    buffer_diarization = combined
                     response_status = "active_transcription"
                     final_lines_for_response = lines.copy()
 
diff --git a/whisperlivekit/token_buffer.py b/whisperlivekit/token_buffer.py
new file mode 100644
index 0000000..1972f3d
--- /dev/null
+++ b/whisperlivekit/token_buffer.py
@@ -0,0 +1,73 @@
+import torch
+import sys
+class TokenBuffer:
+
+    def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):
+        self.text = text
+        self.prefix_token_ids = prefix_token_ids
+        self.tokenizer = tokenizer
+        self.device = device
+
+    def as_token_ids(self, tokenizer=None):
+
+        if tokenizer is None:
+            tokenizer = self.tokenizer
+        if tokenizer is None:
+            raise ValueError("Tokenizer is not set.")
+        return self.prefix_token_ids + tokenizer.encode(self.text)
+
+    def as_tensor(self, device=None):
+        if device is None:
+            device = self.device
+        if device is None:
+            raise ValueError("Device is not set.")
+        tok_ids = self.as_token_ids()
+        return torch.tensor(tok_ids,
+                            dtype=torch.long, device=device).unsqueeze(0)
+
+    def as_tensor_beam(self, beam, device=None):
+        t = self.as_tensor(device=device)
+        return t.repeat_interleave(beam, dim=0)
+
+
+    def as_text(self):
+        return self.text
+
+    @staticmethod
+    def empty(*a, **kw):
+        return TokenBuffer(*a,**kw)
+
+    @staticmethod
+    def from_text(text, *a, **kw):
+        return TokenBuffer(*a, text=text, **kw)
+
+    def is_empty(self):
+        return self.text is None or self.text == ""
+
+    def trim_words(self, num=1, after=0):
+        '''
+        num: how many words to trim from the beginning
+        after: how many characters to skip (length of the static prompt)
+        '''
+        tokenizer = self.tokenizer
+        assert tokenizer is not None, "Tokenizer is not set."
+
+        ids = tokenizer.encode(self.text[after:])
+        words, wids = self.tokenizer.split_to_word_tokens(ids)
+        print(words, file=sys.stderr)
+        print(wids, file=sys.stderr)
+        if not words:
+            return 0
+        self.text = self.text[:after] + "".join(words[num:])
+        return sum(len(wi) for wi in wids[:num])
+
+    def append_token_ids(self, token_ids):
+        tokenizer = self.tokenizer
+        assert tokenizer is not None, "Tokenizer is not set."
+        self.text += self.tokenizer.decode(token_ids)
+
+    def as_split_word_tokens(self):
+        tokenizer = self.tokenizer
+        assert tokenizer is not None, "Tokenizer is not set."
+        ids = tokenizer.encode(self.text)
+        return tokenizer.split_to_word_tokens(ids)
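The new token_buffer.py keeps the decoding hypothesis as plain text plus optional prefix token ids and re-encodes on demand, so trim_words drops whole words and never leaves a partial token in the buffer. Below is a minimal sketch of the intended call pattern; ToyTokenizer is a hypothetical stand-in for illustration only, mimicking the encode/decode/split_to_word_tokens interface the real Whisper tokenizer would supply:

    import re
    from whisperlivekit.token_buffer import TokenBuffer

    class ToyTokenizer:
        """Hypothetical character-level tokenizer, for illustration only."""
        def encode(self, text):
            return [ord(c) for c in text]          # one token id per character
        def decode(self, ids):
            return "".join(chr(i) for i in ids)
        def split_to_word_tokens(self, ids):
            # Words keep their leading whitespace so "".join(words) round-trips.
            words = re.findall(r"\s*\S+", self.decode(ids))
            return words, [self.encode(w) for w in words]

    buf = TokenBuffer.from_text(" hello world", tokenizer=ToyTokenizer(), device="cpu")
    print(buf.as_tensor().shape)          # torch.Size([1, 12])
    print(buf.as_tensor_beam(4).shape)    # torch.Size([4, 12]), one row per beam
    buf.append_token_ids([ord(c) for c in " again"])
    removed = buf.trim_words(num=1)       # drops " hello"; returns 6 trimmed token ids
    print(buf.as_text())                  # " world again"

Keeping the text as the source of truth trades a re-encode per call for the guarantee that the trim boundary always falls between words, after the static prompt that the trim_words docstring describes.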
diff --git a/whisperlivekit/whisper_streaming_custom/backends.py b/whisperlivekit/whisper_streaming_custom/backends.py
index 4f94bae..91c11a7 100644
--- a/whisperlivekit/whisper_streaming_custom/backends.py
+++ b/whisperlivekit/whisper_streaming_custom/backends.py
@@ -337,7 +337,7 @@ class SimulStreamingASR(ASRBase):
 
         if model_dir is not None:
             self.model_path = model_dir
-        elif modelsize is not None:
+        elif modelsize is not None: #For the moment the .en.pt models do not work!
             model_mapping = {
                 'tiny': './tiny.pt',
                 'base': './base.pt',
diff --git a/whisperlivekit/whisper_streaming_custom/online_asr.py b/whisperlivekit/whisper_streaming_custom/online_asr.py
index 3fcd1f1..bc7c69f 100644
--- a/whisperlivekit/whisper_streaming_custom/online_asr.py
+++ b/whisperlivekit/whisper_streaming_custom/online_asr.py
@@ -545,11 +545,8 @@ class SimulStreamingOnlineProcessor:
         self.end = self.offset
         self.cumulative_audio_duration = 0.0
 
-        # Keep track of committed tokens for compatibility with existing interface
         self.committed: List[ASRToken] = []
-        self.last_result_tokens: List[ASRToken] = []
-
-        # Buffer for unvalidated content
+        self.last_result_tokens: List[ASRToken] = []
         self.buffer_content = ""
 
     def get_audio_buffer_end_time(self) -> float:
diff --git a/whisperlivekit/whisper_streaming_custom/whisper_online.py b/whisperlivekit/whisper_streaming_custom/whisper_online.py
index e9da486..8a3869d 100644
--- a/whisperlivekit/whisper_streaming_custom/whisper_online.py
+++ b/whisperlivekit/whisper_streaming_custom/whisper_online.py
@@ -77,7 +77,6 @@ def backend_factory(args):
                 "See the documentation for installation instructions."
             )
 
-        # Extract SimulStreaming-specific arguments
         simulstreaming_kwargs = {}
         for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
                      'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',