Files
WhisperLiveKit/whisperlivekit/token_buffer.py
2025-06-25 23:38:23 +02:00

74 lines
2.3 KiB
Python

import logging
import sys

import torch
class TokenBuffer:
    """Text buffer that is converted to token ids or tensors on demand.

    The buffer stores plain text plus a list of prefix token ids. Token ids
    are not cached: every conversion re-encodes ``self.text`` with the
    tokenizer. The tokenizer must provide ``encode``, ``decode`` and
    ``split_to_word_tokens``, which are the only methods called here.
    """

    def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=None):
        """Create a buffer.

        Args:
            text: Initial text content.
            tokenizer: Object used to encode/decode the text; may be set later.
            device: Default device for tensors built by ``as_tensor``.
            prefix_token_ids: Token ids prepended to the encoded text by
                ``as_token_ids``. Defaults to a new empty list.
        """
        self.text = text
        # A fresh list per instance: a literal ``[]`` default would be one
        # object shared by every buffer created without this argument.
        self.prefix_token_ids = [] if prefix_token_ids is None else prefix_token_ids
        self.tokenizer = tokenizer
        self.device = device

    def _require_tokenizer(self, tokenizer=None):
        """Return ``tokenizer`` or the buffer's own one; raise if neither is set."""
        if tokenizer is None:
            tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer is not set.")
        return tokenizer

    def as_token_ids(self, tokenizer=None):
        """Return the prefix token ids followed by the encoded text.

        Args:
            tokenizer: Optional override for the buffer's tokenizer.

        Raises:
            ValueError: If no tokenizer is given and none is set on the buffer.
        """
        tokenizer = self._require_tokenizer(tokenizer)
        return self.prefix_token_ids + tokenizer.encode(self.text)

    def as_tensor(self, device=None):
        """Return the token ids as a ``torch.long`` tensor of shape (1, n).

        Args:
            device: Optional override for the buffer's default device.

        Raises:
            ValueError: If no device (or no tokenizer) is available.
        """
        if device is None:
            device = self.device
        if device is None:
            raise ValueError("Device is not set.")
        tok_ids = self.as_token_ids()
        # unsqueeze(0) adds the leading dimension of size one.
        return torch.tensor(tok_ids, dtype=torch.long, device=device).unsqueeze(0)

    def as_tensor_beam(self, beam, device=None):
        """Return the ``as_tensor`` result repeated ``beam`` times along dim 0."""
        t = self.as_tensor(device=device)
        return t.repeat_interleave(beam, dim=0)

    def as_text(self):
        """Return the buffered text."""
        return self.text

    @classmethod
    def empty(cls, *a, **kw):
        """Build a buffer, forwarding all arguments to the constructor."""
        return cls(*a, **kw)

    @classmethod
    def from_text(cls, text, *a, **kw):
        """Build a buffer holding ``text``.

        ``text`` is passed as the first positional argument so that further
        positional arguments (tokenizer, device, ...) can follow it; passing
        it by keyword after ``*a`` would collide with the first of them.
        """
        return cls(text, *a, **kw)

    def is_empty(self):
        """Return True when the text is ``None`` or the empty string."""
        return self.text is None or self.text == ""

    def trim_words(self, num=1, after=0):
        """Drop the first ``num`` words that follow the first ``after`` characters.

        Args:
            num: How many words to trim from the beginning.
            after: How many characters to skip (length of the static prompt).

        Returns:
            The total number of token ids in the trimmed words, or 0 when the
            text after ``after`` contains no words.

        Raises:
            ValueError: If no tokenizer is set.
        """
        tokenizer = self._require_tokenizer()
        ids = tokenizer.encode(self.text[after:])
        # Used as two parallel lists: the words and the token ids of each word.
        words, wids = tokenizer.split_to_word_tokens(ids)
        # Diagnostic output goes through logging at DEBUG level instead of
        # being written to stderr unconditionally on every call.
        logging.getLogger(__name__).debug("trim_words: words=%s wids=%s", words, wids)
        if not words:
            return 0
        self.text = self.text[:after] + "".join(words[num:])
        return sum(len(wi) for wi in wids[:num])

    def append_token_ids(self, token_ids):
        """Decode ``token_ids`` and append the resulting text to the buffer.

        Raises:
            ValueError: If no tokenizer is set.
        """
        tokenizer = self._require_tokenizer()
        self.text += tokenizer.decode(token_ids)

    def as_split_word_tokens(self):
        """Encode the text and return ``tokenizer.split_to_word_tokens`` of it.

        Raises:
            ValueError: If no tokenizer is set.
        """
        tokenizer = self._require_tokenizer()
        ids = tokenizer.encode(self.text)
        return tokenizer.split_to_word_tokens(ids)