restore a functional buffer_diarization

This commit is contained in:
Quentin Fuxa
2025-06-25 23:38:23 +02:00
parent 6867041254
commit bfec335a5f
5 changed files with 79 additions and 11 deletions

View File

@@ -453,12 +453,11 @@ class AudioProcessor:
# Handle undiarized text
if undiarized_text:
combined_undiarized = sep.join(undiarized_text)
combined = sep.join(undiarized_text)
if buffer_transcription:
buffer_transcription = combined_undiarized + sep + buffer_transcription
else:
buffer_transcription = combined_undiarized
buffer_diarization = ""
combined += sep
await self.update_diarization(end_attributed_speaker, combined)
buffer_diarization = combined
response_status = "active_transcription"
final_lines_for_response = lines.copy()

View File

@@ -0,0 +1,73 @@
import torch
import sys
class TokenBuffer:
    """A text buffer paired with optional prefix token ids.

    Provides conversions between plain text, token-id lists, and torch
    tensors suitable for feeding a decoder (optionally repeated per beam).
    """

    def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=None):
        self.text = text
        # None sentinel instead of a mutable default list, so instances
        # never share (and accidentally alias) the same prefix list.
        self.prefix_token_ids = [] if prefix_token_ids is None else prefix_token_ids
        self.tokenizer = tokenizer
        self.device = device

    def as_token_ids(self, tokenizer=None):
        """Return prefix token ids followed by the encoded text.

        Falls back to self.tokenizer; raises ValueError if neither is set.
        """
        if tokenizer is None:
            tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer is not set.")
        return self.prefix_token_ids + tokenizer.encode(self.text)

    def as_tensor(self, device=None):
        """Return token ids as a (1, n) long tensor on `device`.

        Falls back to self.device; raises ValueError if neither is set.
        """
        if device is None:
            device = self.device
        if device is None:
            raise ValueError("Device is not set.")
        tok_ids = self.as_token_ids()
        return torch.tensor(tok_ids,
                            dtype=torch.long, device=device).unsqueeze(0)

    def as_tensor_beam(self, beam, device=None):
        """Return the token tensor repeated `beam` times: shape (beam, n)."""
        t = self.as_tensor(device=device)
        return t.repeat_interleave(beam, dim=0)

    def as_text(self):
        """Return the buffered text."""
        return self.text

    @staticmethod
    def empty(*a, **kw):
        """Alternate constructor: an empty buffer."""
        return TokenBuffer(*a, **kw)

    @staticmethod
    def from_text(text, *a, **kw):
        """Alternate constructor from plain text."""
        return TokenBuffer(*a, text=text, **kw)

    def is_empty(self):
        """True when no text is buffered."""
        return self.text is None or self.text == ""

    def trim_words(self, num=1, after=0):
        """Drop `num` leading words from the text, preserving the first
        `after` characters (the static prompt), and return the number of
        token ids removed.

        num: how many words to trim from the beginning
        after: how many characters to skip (length of the static prompt)
        """
        tokenizer = self.tokenizer
        # raise instead of assert: assert is stripped under `python -O`,
        # and the other methods already raise ValueError for this case
        if tokenizer is None:
            raise ValueError("Tokenizer is not set.")
        ids = tokenizer.encode(self.text[after:])
        words, wids = tokenizer.split_to_word_tokens(ids)
        if not words:
            return 0
        self.text = self.text[:after] + "".join(words[num:])
        return sum(len(wi) for wi in wids[:num])

    def append_token_ids(self, token_ids):
        """Decode `token_ids` and append the resulting text to the buffer."""
        tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer is not set.")
        self.text += tokenizer.decode(token_ids)

    def as_split_word_tokens(self):
        """Encode the text and split it into (words, word_token_ids)."""
        tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer is not set.")
        ids = tokenizer.encode(self.text)
        return tokenizer.split_to_word_tokens(ids)

View File

@@ -337,7 +337,7 @@ class SimulStreamingASR(ASRBase):
if model_dir is not None:
self.model_path = model_dir
elif modelsize is not None:
elif modelsize is not None: #For the moment the .en.pt models do not work!
model_mapping = {
'tiny': './tiny.pt',
'base': './base.pt',

View File

@@ -545,11 +545,8 @@ class SimulStreamingOnlineProcessor:
self.end = self.offset
self.cumulative_audio_duration = 0.0
# Keep track of committed tokens for compatibility with existing interface
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
# Buffer for unvalidated content
self.last_result_tokens: List[ASRToken] = []
self.buffer_content = ""
def get_audio_buffer_end_time(self) -> float:

View File

@@ -77,7 +77,6 @@ def backend_factory(args):
"See the documentation for installation instructions."
)
# Extract SimulStreaming-specific arguments
simulstreaming_kwargs = {}
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',