Use Sentence, Transcript and ASRToken classes for clarity

Quentin Fuxa
2025-02-07 12:24:11 +01:00
parent 48c111f494
commit 46f7f9cbd1
2 changed files with 293 additions and 269 deletions

View File

@@ -0,0 +1,15 @@
class ASRToken:
    """
    A token (word) from the ASR system with start/end times and text.
    """
    def __init__(self, start: float, end: float, text: str):
        self.start = start
        self.end = end
        self.text = text

    def with_offset(self, offset: float) -> "ASRToken":
        """Return a new token with the time offset added."""
        return ASRToken(self.start + offset, self.end + offset, self.text)

    def __repr__(self):
        return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"

View File

@@ -1,112 +1,145 @@
import sys
import numpy as np
import logging
from typing import List, Tuple, Optional
from src.whisper_streaming.asr_token import ASRToken

logger = logging.getLogger(__name__)


class Sentence:
    """
    A sentence assembled from tokens.
    """
    def __init__(self, start: float, end: float, text: str):
        self.start = start
        self.end = end
        self.text = text

    def __repr__(self):
        return f"Sentence(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"


class Transcript:
    """
    A transcript that bundles a start time, an end time, and a concatenated text.
    """
    def __init__(self, start: Optional[float], end: Optional[float], text: str):
        self.start = start
        self.end = end
        self.text = text

    def __iter__(self):
        return iter((self.start, self.end, self.text))

    def __repr__(self):
        return f"Transcript(start={self.start}, end={self.end}, text={self.text!r})"


class HypothesisBuffer:
    """
    Buffer to store and process ASR hypothesis tokens.

    It holds:
    - committed_in_buffer: tokens that have been confirmed (committed)
    - buffer: the last hypothesis that is not yet committed
    - new: new tokens coming from the recognizer
    """
    def __init__(self, logfile=sys.stderr):
        self.committed_in_buffer: List[ASRToken] = []
        self.buffer: List[ASRToken] = []
        self.new: List[ASRToken] = []
        self.last_committed_time = 0.0
        self.last_committed_word: Optional[str] = None
        self.logfile = logfile

    def insert(self, new_tokens: List[ASRToken], offset: float):
        """
        Insert new tokens (after applying a time offset) and compare them with the
        already committed tokens. Only tokens that extend the committed hypothesis
        are added.
        """
        # Apply the offset to each token.
        new_tokens = [token.with_offset(offset) for token in new_tokens]
        # Only keep tokens that are roughly "new".
        self.new = [token for token in new_tokens if token.start > self.last_committed_time - 0.1]

        if self.new:
            first_token = self.new[0]
            if abs(first_token.start - self.last_committed_time) < 1:
                if self.committed_in_buffer:
                    committed_len = len(self.committed_in_buffer)
                    new_len = len(self.new)
                    # Try to match 1 to 5 consecutive tokens.
                    max_ngram = min(min(committed_len, new_len), 5)
                    for i in range(1, max_ngram + 1):
                        committed_ngram = " ".join(token.text for token in self.committed_in_buffer[-i:])
                        new_ngram = " ".join(token.text for token in self.new[:i])
                        if committed_ngram == new_ngram:
                            removed = []
                            for _ in range(i):
                                removed_token = self.new.pop(0)
                                removed.append(repr(removed_token))
                            logger.debug(f"Removing last {i} words: {' '.join(removed)}")
                            break

    def flush(self) -> List[ASRToken]:
        """
        Returns the committed chunk, defined as the longest common prefix
        between the previous hypothesis and the new tokens.
        """
        committed: List[ASRToken] = []
        while self.new:
            current_new = self.new[0]
            if not self.buffer:
                break
            if current_new.text == self.buffer[0].text:
                committed.append(current_new)
                self.last_committed_word = current_new.text
                self.last_committed_time = current_new.end
                self.buffer.pop(0)
                self.new.pop(0)
            else:
                break
        self.buffer = self.new
        self.new = []
        self.committed_in_buffer.extend(committed)
        return committed

    def pop_committed(self, time: float):
        """
        Remove tokens (from the beginning) that have ended before `time`.
        """
        while self.committed_in_buffer and self.committed_in_buffer[0].end <= time:
            self.committed_in_buffer.pop(0)

    def complete(self) -> List[ASRToken]:
        """Return any remaining tokens (i.e. the current buffer)."""
        return self.buffer


class OnlineASRProcessor:
    """
    Processes incoming audio in a streaming fashion, calling the ASR system
    periodically, and uses a hypothesis buffer to commit and trim recognized text.

    The processor supports two types of buffer trimming:
    - "sentence": trims at sentence boundaries (using a sentence tokenizer)
    - "segment": trims at fixed segment durations.
    """
    SAMPLING_RATE = 16000

    def __init__(
        self,
        asr,
        tokenize_method: Optional[callable] = None,
        buffer_trimming: Tuple[str, float] = ("segment", 15),
        logfile=sys.stderr,
    ):
        """
        Initialize OnlineASRProcessor.

        Args:
            asr: An ASR system object (for example, a WhisperASR instance) that
                provides a `transcribe` method, a `ts_words` method (to extract tokens),
                a `segments_end_ts` method, and a separator attribute `sep`.
            tokenize_method: A function that receives text and returns a list of sentence strings.
            buffer_trimming: A tuple (option, seconds), where option is either "sentence" or "segment".
        """
        self.asr = asr
        self.tokenize = tokenize_method
@@ -125,235 +158,209 @@ class OnlineASRProcessor:
f"buffer_trimming_sec is set to {self.buffer_trimming_sec}, which is very long. It may cause OOM."
)
def init(self, offset=None):
"""run this when starting or restarting processing"""
def init(self, offset: Optional[float] = None):
"""Initialize or reset the processing buffers."""
self.audio_buffer = np.array([], dtype=np.float32)
self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
self.buffer_time_offset = 0
if offset is not None:
self.buffer_time_offset = offset
self.transcript_buffer.last_commited_time = self.buffer_time_offset
self.commited = []
self.buffer_time_offset = offset if offset is not None else 0.0
self.transcript_buffer.last_committed_time = self.buffer_time_offset
self.committed: List[ASRToken] = []
def insert_audio_chunk(self, audio):
def insert_audio_chunk(self, audio: np.ndarray):
"""Append an audio chunk (a numpy array) to the current audio buffer."""
self.audio_buffer = np.append(self.audio_buffer, audio)
def prompt(self):
"""Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer.
"context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons.
def prompt(self) -> Tuple[str, str]:
"""
k = max(0, len(self.commited) - 1)
while k > 0 and self.commited[k - 1][1] > self.buffer_time_offset:
Returns a tuple: (prompt, context), where:
- prompt is a 200-character suffix of committed text that falls
outside the current audio buffer.
- context is the committed text within the current audio buffer.
"""
k = len(self.committed)
while k > 0 and self.committed[k - 1].end > self.buffer_time_offset:
k -= 1
p = self.commited[:k]
p = [t for _, _, t in p]
prompt = []
l = 0
while p and l < 200: # 200 characters prompt size
x = p.pop(-1)
l += len(x) + 1
prompt.append(x)
non_prompt = self.commited[k:]
return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(
t for _, _, t in non_prompt
)
prompt_tokens = self.committed[:k]
prompt_words = [token.text for token in prompt_tokens]
prompt_list = []
length_count = 0
# Use the last words until reaching 200 characters.
while prompt_words and length_count < 200:
word = prompt_words.pop(-1)
length_count += len(word) + 1
prompt_list.append(word)
non_prompt_tokens = self.committed[k:]
context_text = self.asr.sep.join(token.text for token in non_prompt_tokens)
return self.asr.sep.join(prompt_list[::-1]), context_text
def process_iter(self):
"""Runs on the current audio buffer.
Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, "").
The non-emty text is confirmed (committed) partial transcript.
def process_iter(self) -> Transcript:
"""
Processes the current audio buffer.
prompt, non_prompt = self.prompt()
Returns a Transcript object representing the committed transcript.
"""
prompt_text, _ = self.prompt()
logger.debug(
f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}"
f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds from {self.buffer_time_offset:.2f}"
)
res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt_text)
tokens = self.asr.ts_words(res) # Expecting List[ASRToken]
self.transcript_buffer.insert(tokens, self.buffer_time_offset)
committed_tokens = self.transcript_buffer.flush()
self.committed.extend(committed_tokens)
completed = self.concatenate_tokens(committed_tokens)
logger.debug(f">>>> COMPLETE NOW: {completed.text}")
incomp = self.concatenate_tokens(self.transcript_buffer.complete())
logger.debug(f"INCOMPLETE: {incomp.text}")
# transform to [(beg,end,"word1"), ...]
tsw = self.asr.ts_words(res)
self.transcript_buffer.insert(tsw, self.buffer_time_offset)
o = self.transcript_buffer.flush()
self.commited.extend(o)
completed = self.concatenate_tsw(o)
logger.debug(f">>>>COMPLETE NOW: {completed[2]}")
the_rest = self.concatenate_tsw(self.transcript_buffer.complete())
logger.debug(f"INCOMPLETE: {the_rest[2]}")
# there is a newly confirmed text
if o and self.buffer_trimming_way == "sentence": # trim the completed sentences
if (
len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec
): # longer than this
if committed_tokens and self.buffer_trimming_way == "sentence":
if len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec:
self.chunk_completed_sentence()
if self.buffer_trimming_way == "segment":
s = self.buffer_trimming_sec # trim the completed segments longer than s,
else:
s = 30 # if the audio buffer is longer than 30s, trim it
s = self.buffer_trimming_sec if self.buffer_trimming_way == "segment" else 30
if len(self.audio_buffer) / self.SAMPLING_RATE > s:
self.chunk_completed_segment(res)
# alternative: on any word
# l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
# let's find commited word that is less
# k = len(self.commited)-1
# while k>0 and self.commited[k][1] > l:
# k -= 1
# t = self.commited[k][1]
logger.debug("chunking segment")
# self.chunk_at(t)
logger.debug("Chunking segment")
logger.debug(
f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}"
f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
)
return self.concatenate_tsw(o)
return self.concatenate_tokens(committed_tokens)
def chunk_completed_sentence(self):
if self.commited == []:
"""
If the committed tokens form at least two sentences, chunk the audio
buffer at the end time of the penultimate sentence.
"""
if not self.committed:
return
logger.debug("COMPLETED SENTENCE: ", [s[2] for s in self.commited])
sents = self.words_to_sentences(self.commited)
for s in sents:
logger.debug(f"\t\tSENT: {s}")
if len(sents) < 2:
logger.debug("COMPLETED SENTENCE: " + " ".join(token.text for token in self.committed))
sentences = self.words_to_sentences(self.committed)
for sentence in sentences:
logger.debug(f"\tSentence: {sentence.text}")
if len(sentences) < 2:
return
while len(sents) > 2:
sents.pop(0)
# we will continue with audio processing at this timestamp
chunk_at = sents[-2][1]
logger.debug(f"--- sentence chunked at {chunk_at:2.2f}")
self.chunk_at(chunk_at)
# Keep the last two sentences.
while len(sentences) > 2:
sentences.pop(0)
chunk_time = sentences[-2].end
logger.debug(f"--- Sentence chunked at {chunk_time:.2f}")
self.chunk_at(chunk_time)
def chunk_completed_segment(self, res):
if self.commited == []:
"""
Chunk the audio buffer based on segment-end timestamps reported by the ASR.
"""
if not self.committed:
return
ends = self.asr.segments_end_ts(res)
t = self.commited[-1][1]
last_committed_time = self.committed[-1].end
if len(ends) > 1:
e = ends[-2] + self.buffer_time_offset
while len(ends) > 2 and e > t:
while len(ends) > 2 and e > last_committed_time:
ends.pop(-1)
e = ends[-2] + self.buffer_time_offset
if e <= t:
logger.debug(f"--- segment chunked at {e:2.2f}")
if e <= last_committed_time:
logger.debug(f"--- Segment chunked at {e:.2f}")
self.chunk_at(e)
else:
logger.debug(f"--- last segment not within commited area")
logger.debug("--- Last segment not within committed area")
else:
logger.debug(f"--- not enough segments to chunk")
def chunk_at(self, time):
"""trims the hypothesis and audio buffer at "time" """
logger.debug(f"chunking at {time:2.2f}s")
logger.debug("--- Not enough segments to chunk")
def chunk_at(self, time: float):
"""
Trim both the hypothesis and audio buffer at the given time.
"""
logger.debug(f"Chunking at {time:.2f}s")
logger.debug(
f"len of audio buffer before chunking is: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}s"
)
self.transcript_buffer.pop_commited(time)
f"Audio buffer length before chunking: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f}s"
)
self.transcript_buffer.pop_committed(time)
cut_seconds = time - self.buffer_time_offset
self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE) :]
self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE):]
self.buffer_time_offset = time
logger.debug(
f"len of audio buffer is now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}s"
)
f"Audio buffer length after chunking: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f}s"
)
def words_to_sentences(self, words):
"""Uses self.tokenize for sentence segmentation of words.
Returns: [(beg,end,"sentence 1"),...]
def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
"""
cwords = [w for w in words]
t = " ".join(o[2] for o in cwords)
s = self.tokenize(t)
out = []
while s:
beg = None
end = None
sent = s.pop(0).strip()
fsent = sent
while cwords:
b, e, w = cwords.pop(0)
w = w.strip()
if beg is None and sent.startswith(w):
beg = b
elif end is None and sent == w:
end = e
out.append((beg, end, fsent))
break
sent = sent[len(w) :].strip()
return out
def finish(self):
"""Flush the incomplete text when the whole processing ends.
Returns: the same format as self.process_iter()
Converts a list of tokens to a list of Sentence objects by using the provided
sentence tokenizer.
"""
o = self.transcript_buffer.complete()
f = self.concatenate_tsw(o)
logger.debug(f"last, noncommited: {f}")
self.buffer_time_offset += len(self.audio_buffer) / 16000
return f
full_text = " ".join(token.text for token in tokens)
sentence_texts = self.tokenize(full_text) if self.tokenize else [full_text]
sentences: List[Sentence] = []
token_index = 0
for sent_text in sentence_texts:
sent_text = sent_text.strip()
if not sent_text:
continue
sent_tokens = []
accumulated = ""
# Accumulate tokens until roughly matching the sentence text.
while token_index < len(tokens) and len(accumulated) < len(sent_text):
token = tokens[token_index]
accumulated = (accumulated + " " + token.text).strip() if accumulated else token.text
sent_tokens.append(token)
token_index += 1
if sent_tokens:
sentence = Sentence(
start=sent_tokens[0].start,
end=sent_tokens[-1].end,
text=" ".join(t.text for t in sent_tokens),
)
sentences.append(sentence)
return sentences
def concatenate_tsw(
def finish(self) -> Transcript:
"""
Flush the remaining transcript when processing ends.
"""
remaining_tokens = self.transcript_buffer.complete()
final_transcript = self.concatenate_tokens(remaining_tokens)
logger.debug(f"Final non-committed transcript: {final_transcript}")
self.buffer_time_offset += len(self.audio_buffer) / self.SAMPLING_RATE
return final_transcript
def concatenate_tokens(
self,
sents,
sep=None,
offset=0,
):
# concatenates the timestamped words or sentences into one sequence that is flushed in one line
# sents: [(beg1, end1, "sentence1"), ...] or [] if empty
# return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
if sep is None:
sep = self.asr.sep
t = sep.join(s[2] for s in sents)
if len(sents) == 0:
b = None
e = None
tokens: List[ASRToken],
sep: Optional[str] = None,
offset: float = 0
) -> Transcript:
sep = sep if sep is not None else self.asr.sep
text = sep.join(token.text for token in tokens)
if tokens:
start = offset + tokens[0].start
end = offset + tokens[-1].end
else:
b = offset + sents[0][0]
e = offset + sents[-1][1]
return (b, e, t)
start = None
end = None
return Transcript(start, end, text)
class VACOnlineASRProcessor(OnlineASRProcessor):
"""Wraps OnlineASRProcessor with VAC (Voice Activity Controller).
It works the same way as OnlineASRProcessor: it receives chunks of audio (e.g. 0.04 seconds),
it runs VAD and continuously detects whether there is speech or not.
When it detects end of speech (non-voice for 500ms), it makes OnlineASRProcessor to end the utterance immediately.
class VACOnlineASRProcessor:
"""
Wraps an OnlineASRProcessor with a Voice Activity Controller (VAC).
It receives small chunks of audio, applies VAD (e.g. with Silero),
and when the system detects a pause in speech (or end of an utterance)
it finalizes the utterance immediately.
"""
SAMPLING_RATE = 16000
# TODO: VACOnlineASRProcessor does not break after chunch length is reached, this can lead to overflow!
def __init__(self, online_chunk_size, *a, **kw):
def __init__(self, online_chunk_size: float, *args, **kwargs):
self.online_chunk_size = online_chunk_size
self.online = OnlineASRProcessor(*args, **kwargs)
self.online = OnlineASRProcessor(*a, **kw)
# VAC:
# Load a VAD model (e.g. Silero VAD)
import torch
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
from src.whisper_streaming.silero_vad_iterator import FixedVADIterator
self.vac = FixedVADIterator(
model
) # we use the default options there: 500ms silence, 100ms padding, etc.
self.vac = FixedVADIterator(model)
self.logfile = self.online.logfile
self.init()
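An aside on the new commit logic (illustrative only; the tokens are invented): HypothesisBuffer commits a token only once two successive hypotheses agree on it, i.e. flush() returns the longest common prefix of the previous and the current insert:

hb = HypothesisBuffer()
hb.insert([ASRToken(0.0, 0.4, "hello"), ASRToken(0.5, 0.9, "world")], offset=0.0)
print(hb.flush())  # [] -- the first hypothesis is only buffered

hb.insert([ASRToken(0.0, 0.4, "hello"), ASRToken(0.5, 0.9, "word")], offset=0.0)
print(hb.flush())  # [ASRToken(start=0.00, end=0.40, text='hello')] -- only the agreed prefix commits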
@@ -361,10 +368,8 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
        self.online.init()
        self.vac.reset_states()
        self.current_online_chunk_buffer_size = 0
        self.is_currently_final = False
        self.status: Optional[str] = None  # "voice" or "nonvoice"
        self.audio_buffer = np.array([], dtype=np.float32)
        self.buffer_offset = 0  # in frames
@@ -372,18 +377,23 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
        self.buffer_offset += len(self.audio_buffer)
        self.audio_buffer = np.array([], dtype=np.float32)

    def insert_audio_chunk(self, audio: np.ndarray):
        """
        Process an incoming small audio chunk:
        - run VAD on the chunk,
        - decide whether to send the audio to the online ASR processor immediately,
        - and/or to mark the current utterance as finished.
        """
        res = self.vac(audio)
        self.audio_buffer = np.append(self.audio_buffer, audio)

        if res is not None:
            # VAD returned a result; adjust the frame number
            frame = list(res.values())[0] - self.buffer_offset
            if "start" in res and "end" not in res:
                self.status = "voice"
                send_audio = self.audio_buffer[frame:]
                self.online.init(offset=(frame + self.buffer_offset) / self.SAMPLING_RATE)
                self.online.insert_audio_chunk(send_audio)
                self.current_online_chunk_buffer_size += len(send_audio)
                self.clear_buffer()
@@ -410,29 +420,28 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
            self.current_online_chunk_buffer_size += len(self.audio_buffer)
            self.clear_buffer()
        else:
            # Keep 1 second worth of audio in case VAD later detects voice,
            # but trim to avoid unbounded memory usage.
            self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE)
            self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]

    def process_iter(self) -> Transcript:
        """
        Depending on the VAD status and the amount of accumulated audio,
        process the current audio chunk.
        """
        if self.is_currently_final:
            return self.finish()
        elif self.current_online_chunk_buffer_size > self.SAMPLING_RATE * self.online_chunk_size:
            self.current_online_chunk_buffer_size = 0
            return self.online.process_iter()
        else:
            logger.debug("No online update, only VAD")
            return Transcript(None, None, "")

    def finish(self) -> Transcript:
        """Finish processing by flushing any remaining text."""
        result = self.online.finish()
        self.current_online_chunk_buffer_size = 0
        self.is_currently_final = False
        return result
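Finally, a minimal end-to-end sketch of the refactored interfaces (illustrative only; StubASR is a hypothetical stand-in for a WhisperASR backend, satisfying the contract named in the OnlineASRProcessor docstring: transcribe, ts_words, segments_end_ts, and sep):

import numpy as np

class StubASR:
    """Hypothetical ASR backend stub (illustration only)."""
    sep = " "

    def transcribe(self, audio, init_prompt=""):
        return {}  # a real backend returns its native result object

    def ts_words(self, res):
        # A real backend extracts timestamped tokens from its result object.
        return [ASRToken(0.0, 0.4, "hello"), ASRToken(0.5, 0.9, "world")]

    def segments_end_ts(self, res):
        return [0.9]

online = OnlineASRProcessor(StubASR(), buffer_trimming=("segment", 15))
online.init()
online.insert_audio_chunk(np.zeros(16000, dtype=np.float32))  # 1 s of audio

first = online.process_iter()   # Transcript(start=None, end=None, text='') -- buffered only
second = online.process_iter()  # Transcript(start=0.0, end=0.9, text='hello world')
start, end, text = second       # a Transcript still unpacks like the old tuple
final = online.finish()         # flushes whatever remains uncommitted

Callers that previously destructured the returned tuple keep working because Transcript implements __iter__; new code can read the named start, end, and text attributes instead.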