diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py
index 73a9c42..7aa4009 100644
--- a/whisperlivekit/simul_whisper/simul_whisper.py
+++ b/whisperlivekit/simul_whisper/simul_whisper.py
@@ -167,7 +167,10 @@ class PaddedAlignAttWhisper:
         self.inference.kv_cache = self.kv_cache
         self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
-
+
+        # Tokens to carry over to next chunk for incomplete UTF-8 characters
+        self.pending_incomplete_tokens = []
+
     def remove_hooks(self):
         for hook in self.l_hooks:
             hook.remove()
@@ -261,6 +264,7 @@ class PaddedAlignAttWhisper:
         self.segments = []
         self.log_segments += 1
+        self.pending_incomplete_tokens = []
 
     def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
         if self.always_fire:
             return True
@@ -562,6 +566,12 @@ class PaddedAlignAttWhisper:
 
         tokens_to_split = current_tokens[0, token_len_before_decoding:]
 
+        # Prepend pending tokens from previous chunk if any
+        if self.pending_incomplete_tokens:
+            logger.debug(f"[UTF-8 Fix] Prepending {len(self.pending_incomplete_tokens)} pending tokens: {self.pending_incomplete_tokens}")
+            pending_tensor = torch.tensor(self.pending_incomplete_tokens, dtype=torch.long, device=self.device)
+            tokens_to_split = torch.cat([pending_tensor, tokens_to_split])
+
         if fire_detected or is_last: #or punctuation_stop:
             new_hypothesis = tokens_to_split.flatten().tolist()
             split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
@@ -590,7 +600,14 @@ class PaddedAlignAttWhisper:
 
         timestamped_words = []
         timestamp_idx = 0
+        replacement_char = "\ufffd"
         for word, word_tokens in zip(split_words, split_tokens):
+            # Skip words containing incomplete UTF-8 from client output
+            if replacement_char in word:
+                logger.warning(f"[UTF-8 Filter] Skipping incomplete word from client output: {repr(word)}")
+                timestamp_idx += len(word_tokens)
+                continue
+
             try:
                 current_timestamp = l_absolute_timestamps[timestamp_idx]
             except:
@@ -608,5 +625,11 @@ class PaddedAlignAttWhisper:
                 self.global_time_offset
             )
             timestamped_words.append(timestamp_entry)
-
-        return timestamped_words
\ No newline at end of file
+
+        # Hold incomplete tokens for next chunk
+        self.pending_incomplete_tokens = []
+        if split_words and replacement_char in split_words[-1]:
+            self.pending_incomplete_tokens = split_tokens[-1]
+            logger.warning(f"[UTF-8 Fix] Holding {len(self.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.pending_incomplete_tokens}")
+
+        return timestamped_words
diff --git a/whisperlivekit/simul_whisper/token_buffer.py b/whisperlivekit/simul_whisper/token_buffer.py
index 50462e0..1146591 100644
--- a/whisperlivekit/simul_whisper/token_buffer.py
+++ b/whisperlivekit/simul_whisper/token_buffer.py
@@ -7,6 +7,7 @@ class TokenBuffer:
         self.prefix_token_ids = prefix_token_ids
         self.tokenizer = tokenizer
         self.device = device
+        self.pending_token_ids = []
 
 
     def as_token_ids(self, tokenizer=None):
@@ -64,7 +65,26 @@ class TokenBuffer:
     def append_token_ids(self, token_ids):
         tokenizer = self.tokenizer
         assert tokenizer is not None, "Tokenizer is not set."
-        self.text += self.tokenizer.decode(token_ids)
+
+        all_tokens = self.pending_token_ids + token_ids
+
+        decoded = tokenizer.decode(all_tokens)
+        replacement_char = "\ufffd"
+
+        if replacement_char in decoded:
+            if len(all_tokens) > 1:
+                decoded_partial = tokenizer.decode(all_tokens[:-1])
+
+                if replacement_char not in decoded_partial:
+                    self.text += decoded_partial
+                    self.pending_token_ids = [all_tokens[-1]]
+                else:
+                    self.pending_token_ids = all_tokens
+            else:
+                self.pending_token_ids = all_tokens
+        else:
+            self.text += decoded
+            self.pending_token_ids = []
 
     def as_split_word_tokens(self):
         tokenizer = self.tokenizer