Mirror of https://github.com/QuentinFuxa/WhisperLiveKit.git (synced 2026-03-07 14:23:18 +00:00)
This fix addresses a critical bug in the Whisper tokenizer that causes the transcription server to crash with an `IndexError: string index out of range` when streaming audio in languages utilizing multi-byte UTF-8 characters (e.g., Cantonese, Japanese, Mandarin). When a 3-byte character is cut off at the boundary of an audio chunk, incomplete bytes are decoded into a single Unicode replacement character (`\ufffd`), artificially shortening the string and breaking the offset mapping assumed by `split_tokens_on_unicode`. This ports the upstream fix from SYSTRAN/faster-whisper (PR #111) to add a strict bounds check before accessing the string index, allowing incomplete bytes to be safely caught and handled in the next chunk.
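To make the failure mode concrete, here is a minimal self-contained sketch (illustrative values only; the variable names and literals below are mine, not part of the patched file). It shows how a truncated 3-byte character collapses to a single U+FFFD on decode, and mirrors the guarded comparison that the ported fix introduces:

# "你" is 3 bytes in UTF-8 (b'\xe4\xbd\xa0'); decoding a truncated prefix
# with errors="replace" (as tiktoken's decode does) yields one U+FFFD, so
# the decoded string is shorter than the accumulated offsets assume.
partial = b"\xe4\xbd".decode("utf-8", errors="replace")
assert partial == "\ufffd" and len(partial) == 1

# Unguarded indexing crashes once the offset runs past the shortened
# string; the ported fix checks bounds first and simply carries the
# incomplete bytes into the next chunk.
decoded_full = "\ufffd"        # illustrative shortened full decode
replacement_char_index = 3     # illustrative out-of-range offset
word_is_complete = (
    replacement_char_index < len(decoded_full)
    and decoded_full[replacement_char_index] == "\ufffd"
)
assert word_is_complete is False  # short-circuits, no IndexError raised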
401 lines · 12 KiB · Python
import base64
import os
import string
from dataclasses import dataclass, field
from functools import cached_property, lru_cache
from typing import Dict, List, Optional, Tuple

import tiktoken

LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}


@dataclass
class Tokenizer:
    """A thin wrapper around `tiktoken` providing quick access to special tokens"""

    encoding: tiktoken.Encoding
    num_languages: int
    language: Optional[str] = None
    task: Optional[str] = None
    sot_sequence: Tuple[int] = ()
    special_tokens: Dict[str, int] = field(default_factory=dict)

    def __post_init__(self):
        for special in self.encoding.special_tokens_set:
            special_token = self.encoding.encode_single_token(special)
            self.special_tokens[special] = special_token

        sot: int = self.special_tokens["<|startoftranscript|>"]
        translate: int = self.special_tokens["<|translate|>"]
        transcribe: int = self.special_tokens["<|transcribe|>"]

        langs = tuple(LANGUAGES.keys())[: self.num_languages]
        sot_sequence = [sot]
        if self.language is not None:
            sot_sequence.append(sot + 1 + langs.index(self.language))
        if self.task is not None:
            task_token: int = transcribe if self.task == "transcribe" else translate
            sot_sequence.append(task_token)

        self.sot_sequence = tuple(sot_sequence)

    def encode(self, text, **kwargs):
        return self.encoding.encode(text, **kwargs)

    def decode(self, token_ids: List[int], **kwargs) -> str:
        token_ids = [t for t in token_ids if t < self.timestamp_begin]
        return self.encoding.decode(token_ids, **kwargs)

    def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
        """
        Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
        This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
        """
        return self.encoding.decode(token_ids, **kwargs)

    @cached_property
    def eot(self) -> int:
        return self.encoding.eot_token

    @cached_property
    def transcribe(self) -> int:
        return self.special_tokens["<|transcribe|>"]

    @cached_property
    def translate(self) -> int:
        return self.special_tokens["<|translate|>"]

    @cached_property
    def sot(self) -> int:
        return self.special_tokens["<|startoftranscript|>"]

    @cached_property
    def sot_lm(self) -> int:
        return self.special_tokens["<|startoflm|>"]

    @cached_property
    def sot_prev(self) -> int:
        return self.special_tokens["<|startofprev|>"]

    @cached_property
    def no_speech(self) -> int:
        return self.special_tokens["<|nospeech|>"]

    @cached_property
    def no_timestamps(self) -> int:
        return self.special_tokens["<|notimestamps|>"]

    @cached_property
    def timestamp_begin(self) -> int:
        return self.special_tokens["<|0.00|>"]

    @cached_property
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field"""
        if self.language is None:
            raise ValueError("This tokenizer does not have language token configured")

        return self.to_language_token(self.language)

    def to_language_token(self, language):
        if token := self.special_tokens.get(f"<|{language}|>", None):
            return token

        raise KeyError(f"Language {language} not found in tokenizer.")

    @cached_property
    def all_language_tokens(self) -> Tuple[int]:
        result = []
        for token, token_id in self.special_tokens.items():
            if token.strip("<|>") in LANGUAGES:
                result.append(token_id)
        return tuple(result)[: self.num_languages]

    @cached_property
    def all_language_codes(self) -> Tuple[str]:
        return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)

    @cached_property
    def sot_sequence_including_notimestamps(self) -> Tuple[int]:
        return tuple(list(self.sot_sequence) + [self.no_timestamps])

    @cached_property
    def non_speech_tokens(self) -> Tuple[int]:
        """
        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += (
            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
        )

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
        for symbol in symbols + list(miscellaneous):
            for tokens in [
                self.encoding.encode(symbol),
                self.encoding.encode(" " + symbol),
            ]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])

        return tuple(sorted(result))

    def split_to_word_tokens(self, tokens: List[int]):
        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
            # These languages don't typically use spaces, so it is difficult to split words
            # without morpheme analysis. Here, we instead split words at any
            # position where the tokens are decoded as valid unicode points
            return self.split_tokens_on_unicode(tokens)

        return self.split_tokens_on_spaces(tokens)

    def split_tokens_on_unicode(self, tokens: List[int]):
        decoded_full = self.decode_with_timestamps(tokens)
        replacement_char = "\ufffd"

        words = []
        word_tokens = []
        current_tokens = []
        unicode_offset = 0

        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode_with_timestamps(current_tokens)

            try:
                replacement_char_index = decoded.index(replacement_char)
                replacement_char_index += unicode_offset
            except ValueError:
                replacement_char_index = None

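            # Bounds check ported from SYSTRAN/faster-whisper (PR #111): a
            # truncated multi-byte character decodes to a lone U+FFFD and
            # shortens `decoded_full`, so the accumulated offset can point
            # past its end. Guard the index so the incomplete bytes are
            # carried into the next chunk instead of raising IndexError.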
            if replacement_char_index is None or (
                replacement_char_index < len(decoded_full)
                and decoded_full[replacement_char_index] == replacement_char
            ):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)

        return words, word_tokens

    def split_tokens_on_spaces(self, tokens: List[int]):
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words = []
        word_tokens = []

        for subword, subword_tokens in zip(subwords, subword_tokens_list):
            special = subword_tokens[0] >= self.eot
            with_space = subword.startswith(" ")
            punctuation = subword.strip() in string.punctuation
            if special or with_space or punctuation or len(words) == 0:
                words.append(subword)
                word_tokens.append(subword_tokens)
            else:
                words[-1] = words[-1] + subword
                word_tokens[-1].extend(subword_tokens)

        return words, word_tokens


@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
    ranks = {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in open(vocab_path) if line)
    }
    n_vocab = len(ranks)
    special_tokens = {}

    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]

    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )


@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
) -> Tokenizer:
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if multilingual:
        encoding_name = "multilingual"
        language = language or "en"
        task = task or "transcribe"
    else:
        encoding_name = "gpt2"
        language = None
        task = None

    encoding = get_encoding(name=encoding_name, num_languages=num_languages)

    return Tokenizer(
        encoding=encoding, num_languages=num_languages, language=language, task=task
    )
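For reference, a minimal usage sketch exercising the patched path. Assumptions: the import path `tokenizer` is hypothetical, and the vocab file assets/multilingual.tiktoken must sit next to this module for get_encoding to load it.

from tokenizer import get_tokenizer  # hypothetical module path

# "zh" is one of the spaceless languages, so word splitting is routed
# through split_tokens_on_unicode, the function patched by this fix.
tok = get_tokenizer(multilingual=True, language="zh", task="transcribe")
words, word_tokens = tok.split_to_word_tokens(tok.encode("你好世界"))
print(words)  # tokens grouped only at complete unicode boundaries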