mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
restore a functional buffer_diarization
This commit is contained in:
@@ -453,12 +453,11 @@ class AudioProcessor:
|
||||
|
||||
# Handle undiarized text
|
||||
if undiarized_text:
|
||||
combined_undiarized = sep.join(undiarized_text)
|
||||
combined = sep.join(undiarized_text)
|
||||
if buffer_transcription:
|
||||
buffer_transcription = combined_undiarized + sep + buffer_transcription
|
||||
else:
|
||||
buffer_transcription = combined_undiarized
|
||||
buffer_diarization = ""
|
||||
combined += sep
|
||||
await self.update_diarization(end_attributed_speaker, combined)
|
||||
buffer_diarization = combined
|
||||
|
||||
response_status = "active_transcription"
|
||||
final_lines_for_response = lines.copy()
|
||||
|
||||
73
whisperlivekit/token_buffer.py
Normal file
73
whisperlivekit/token_buffer.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import torch
|
||||
import sys
|
||||
class TokenBuffer:
    """Mutable text buffer that can be viewed as token ids or tensors.

    Wraps a text string together with a tokenizer and target device, plus an
    optional list of prefix token ids that is prepended to every encoding.
    """

    def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=None):
        # FIX: the original used a mutable default ([]) for prefix_token_ids,
        # which would be shared across all instances; use a None sentinel.
        self.text = text
        self.prefix_token_ids = [] if prefix_token_ids is None else prefix_token_ids
        self.tokenizer = tokenizer
        self.device = device

    def as_token_ids(self, tokenizer=None):
        """Return prefix ids + the encoded text.

        Uses *tokenizer* when given, otherwise the stored one.
        Raises ValueError if no tokenizer is available.
        """
        if tokenizer is None:
            tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer is not set.")
        return self.prefix_token_ids + tokenizer.encode(self.text)

    def as_tensor(self, device=None):
        """Return the token ids as a (1, n) long tensor on *device*.

        Falls back to the stored device; raises ValueError if neither is set.
        """
        if device is None:
            device = self.device
        if device is None:
            raise ValueError("Device is not set.")
        tok_ids = self.as_token_ids()
        return torch.tensor(tok_ids, dtype=torch.long, device=device).unsqueeze(0)

    def as_tensor_beam(self, beam, device=None):
        """Return the token tensor repeated *beam* times along dim 0 -> (beam, n)."""
        t = self.as_tensor(device=device)
        return t.repeat_interleave(beam, dim=0)

    def as_text(self):
        """Return the raw buffered text."""
        return self.text

    @staticmethod
    def empty(*a, **kw):
        """Alternate constructor: a buffer with the default (empty) text."""
        return TokenBuffer(*a, **kw)

    @staticmethod
    def from_text(text, *a, **kw):
        """Alternate constructor from an initial text string."""
        return TokenBuffer(*a, text=text, **kw)

    def is_empty(self):
        """True when no text is buffered (None or empty string)."""
        return self.text is None or self.text == ""

    def trim_words(self, num=1, after=0):
        """Drop the first *num* words of the text, keeping the first *after* chars.

        num: how many words to trim from the beginning
        after: how many characters to skip (length of the static prompt)

        Returns the number of token ids removed (0 when there are no words).
        """
        tokenizer = self.tokenizer
        assert tokenizer is not None, "Tokenizer is not set."

        ids = tokenizer.encode(self.text[after:])
        # FIX: use the local `tokenizer` consistently (was self.tokenizer) and
        # drop the leftover debug prints to stderr.
        words, wids = tokenizer.split_to_word_tokens(ids)
        if not words:
            return 0
        self.text = self.text[:after] + "".join(words[num:])
        return sum(len(wi) for wi in wids[:num])

    def append_token_ids(self, token_ids):
        """Decode *token_ids* and append the resulting text to the buffer."""
        tokenizer = self.tokenizer
        assert tokenizer is not None, "Tokenizer is not set."
        self.text += tokenizer.decode(token_ids)

    def as_split_word_tokens(self):
        """Return (words, word_token_ids) for the current text."""
        tokenizer = self.tokenizer
        assert tokenizer is not None, "Tokenizer is not set."
        ids = tokenizer.encode(self.text)
        return tokenizer.split_to_word_tokens(ids)
|
||||
@@ -337,7 +337,7 @@ class SimulStreamingASR(ASRBase):
|
||||
|
||||
if model_dir is not None:
|
||||
self.model_path = model_dir
|
||||
elif modelsize is not None:
|
||||
elif modelsize is not None: #For the moment the .en.pt models do not work!
|
||||
model_mapping = {
|
||||
'tiny': './tiny.pt',
|
||||
'base': './base.pt',
|
||||
|
||||
@@ -545,11 +545,8 @@ class SimulStreamingOnlineProcessor:
|
||||
self.end = self.offset
|
||||
self.cumulative_audio_duration = 0.0
|
||||
|
||||
# Keep track of committed tokens for compatibility with existing interface
|
||||
self.committed: List[ASRToken] = []
|
||||
self.last_result_tokens: List[ASRToken] = []
|
||||
|
||||
# Buffer for unvalidated content
|
||||
self.last_result_tokens: List[ASRToken] = []
|
||||
self.buffer_content = ""
|
||||
|
||||
def get_audio_buffer_end_time(self) -> float:
|
||||
|
||||
@@ -77,7 +77,6 @@ def backend_factory(args):
|
||||
"See the documentation for installation instructions."
|
||||
)
|
||||
|
||||
# Extract SimulStreaming-specific arguments
|
||||
simulstreaming_kwargs = {}
|
||||
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
|
||||
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
|
||||
|
||||
Reference in New Issue
Block a user