mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-04-26 16:45:46 +00:00
Use Sentence, Transcript and ASRToken classes for clarity
src/whisper_streaming/asr_token.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+class ASRToken:
+    """
+    A token (word) from the ASR system with start/end times and text.
+    """
+    def __init__(self, start: float, end: float, text: str):
+        self.start = start
+        self.end = end
+        self.text = text
+
+    def with_offset(self, offset: float) -> "ASRToken":
+        """Return a new token with the time offset added."""
+        return ASRToken(self.start + offset, self.end + offset, self.text)
+
+    def __repr__(self):
+        return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"
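
A quick illustrative sketch of the new class (not part of the commit; assumes only the file above):

    from src.whisper_streaming.asr_token import ASRToken

    token = ASRToken(1.0, 1.5, "hello")
    shifted = token.with_offset(10.0)   # immutable-style update: returns a new token
    print(shifted)                      # ASRToken(start=11.00, end=11.50, text='hello')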
@@ -1,112 +1,145 @@
 import sys
 import numpy as np
 import logging
+from typing import List, Tuple, Optional
+from src.whisper_streaming.asr_token import ASRToken

 logger = logging.getLogger(__name__)


+class Sentence:
+    """
+    A sentence assembled from tokens.
+    """
+    def __init__(self, start: float, end: float, text: str):
+        self.start = start
+        self.end = end
+        self.text = text
+
+    def __repr__(self):
+        return f"Sentence(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"
+
+
+class Transcript:
+    """
+    A transcript that bundles a start time, an end time, and a concatenated text.
+    """
+    def __init__(self, start: Optional[float], end: Optional[float], text: str):
+        self.start = start
+        self.end = end
+        self.text = text
+
+    def __iter__(self):
+        return iter((self.start, self.end, self.text))
+
+    def __repr__(self):
+        return f"Transcript(start={self.start}, end={self.end}, text={self.text!r})"
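
The __iter__ above is what keeps old tuple-style call sites working; a small sketch (illustrative only):

    t = Transcript(0.0, 1.2, "hello world")
    beg, end, text = t      # unpacks like the old (beg, end, "text") tuples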

 class HypothesisBuffer:
+    """
+    Buffer to store and process ASR hypothesis tokens.
+
+    It holds:
+    - committed_in_buffer: tokens that have been confirmed (committed)
+    - buffer: the last hypothesis that is not yet committed
+    - new: new tokens coming from the recognizer
+    """
     def __init__(self, logfile=sys.stderr):
-        self.commited_in_buffer = []
-        self.buffer = []
-        self.new = []
-
-        self.last_commited_time = 0
-        self.last_commited_word = None
-
+        self.committed_in_buffer: List[ASRToken] = []
+        self.buffer: List[ASRToken] = []
+        self.new: List[ASRToken] = []
+        self.last_committed_time = 0.0
+        self.last_committed_word: Optional[str] = None
         self.logfile = logfile

-    def insert(self, new, offset):
+    def insert(self, new_tokens: List[ASRToken], offset: float):
         """
-        compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
-        The new tail is added to self.new
+        Insert new tokens (after applying a time offset) and compare them with the
+        already committed tokens. Only tokens that extend the committed hypothesis
+        are added.
         """
-        new = [(a + offset, b + offset, t) for a, b, t in new]
-        self.new = [(a, b, t) for a, b, t in new if a > self.last_commited_time - 0.1]
-
-        if len(self.new) >= 1:
-            a, b, t = self.new[0]
-            if abs(a - self.last_commited_time) < 1:
-                if self.commited_in_buffer:
-                    # it's going to search for 1, 2, ..., 5 consecutive words (n-grams) that are identical in commited and new. If they are, they're dropped.
-                    cn = len(self.commited_in_buffer)
-                    nn = len(self.new)
-                    for i in range(1, min(min(cn, nn), 5) + 1):  # 5 is the maximum
-                        c = " ".join(
-                            [self.commited_in_buffer[-j][2] for j in range(1, i + 1)][
-                                ::-1
-                            ]
-                        )
-                        tail = " ".join(self.new[j - 1][2] for j in range(1, i + 1))
-                        if c == tail:
-                            words = []
-                            for j in range(i):
-                                words.append(repr(self.new.pop(0)))
-                            words_msg = " ".join(words)
-                            logger.debug(f"removing last {i} words: {words_msg}")
+        # Apply the offset to each token.
+        new_tokens = [token.with_offset(offset) for token in new_tokens]
+        # Only keep tokens that are roughly “new”
+        self.new = [token for token in new_tokens if token.start > self.last_committed_time - 0.1]
+
+        if self.new:
+            first_token = self.new[0]
+            if abs(first_token.start - self.last_committed_time) < 1:
+                if self.committed_in_buffer:
+                    committed_len = len(self.committed_in_buffer)
+                    new_len = len(self.new)
+                    # Try to match 1 to 5 consecutive tokens
+                    max_ngram = min(min(committed_len, new_len), 5)
+                    for i in range(1, max_ngram + 1):
+                        committed_ngram = " ".join(token.text for token in self.committed_in_buffer[-i:])
+                        new_ngram = " ".join(token.text for token in self.new[:i])
+                        if committed_ngram == new_ngram:
+                            removed = []
+                            for _ in range(i):
+                                removed_token = self.new.pop(0)
+                                removed.append(repr(removed_token))
+                            logger.debug(f"Removing last {i} words: {' '.join(removed)}")
                             break
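
A worked sketch of the n-gram de-duplication (illustrative only; it pokes the buffer's internals directly to set up state):

    hb = HypothesisBuffer()
    # Pretend "hello world" is already committed up to t = 1.0 s.
    hb.committed_in_buffer = [ASRToken(0.0, 0.5, "hello"), ASRToken(0.5, 1.0, "world")]
    hb.last_committed_time = 1.0

    # An overlapping re-decode repeats "world": it passes the time filter
    # (start 0.95 > 1.0 - 0.1), but its 1-gram equals the committed tail,
    # so insert() drops it and keeps only "again".
    hb.insert([ASRToken(0.95, 1.3, "world"), ASRToken(1.3, 1.8, "again")], offset=0.0)
    print([t.text for t in hb.new])   # ['again']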

-    def flush(self):
-        # returns commited chunk = the longest common prefix of 2 last inserts.
-
-        commit = []
+    def flush(self) -> List[ASRToken]:
+        """
+        Returns the committed chunk, defined as the longest common prefix
+        between the previous hypothesis and the new tokens.
+        """
+        committed: List[ASRToken] = []
         while self.new:
-            na, nb, nt = self.new[0]
-
-            if len(self.buffer) == 0:
+            current_new = self.new[0]
+            if not self.buffer:
                 break
-
-            if nt == self.buffer[0][2]:
-                commit.append((na, nb, nt))
-                self.last_commited_word = nt
-                self.last_commited_time = nb
+            if current_new.text == self.buffer[0].text:
+                committed.append(current_new)
+                self.last_committed_word = current_new.text
+                self.last_committed_time = current_new.end
                 self.buffer.pop(0)
                 self.new.pop(0)
             else:
                 break
         self.buffer = self.new
         self.new = []
-        self.commited_in_buffer.extend(commit)
-        return commit
+        self.committed_in_buffer.extend(committed)
+        return committed
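
What flush() commits, in a tiny sketch (illustrative; single-word tokens):

    hb = HypothesisBuffer()
    hb.insert([ASRToken(0.0, 0.5, "hello"), ASRToken(0.5, 1.0, "world")], offset=0.0)
    hb.flush()   # [] -- there is no previous hypothesis to agree with yet

    hb.insert([ASRToken(0.0, 0.5, "hello"), ASRToken(0.5, 1.1, "there")], offset=0.0)
    hb.flush()   # [ASRToken 'hello'] -- only the agreed common prefix is committed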

-    def pop_commited(self, time):
-        "Remove (from the beginning) of commited_in_buffer all the words that are finished before `time`"
-        while self.commited_in_buffer and self.commited_in_buffer[0][1] <= time:
-            self.commited_in_buffer.pop(0)
+    def pop_committed(self, time: float):
+        """
+        Remove tokens (from the beginning) that have ended before `time`.
+        """
+        while self.committed_in_buffer and self.committed_in_buffer[0].end <= time:
+            self.committed_in_buffer.pop(0)

-    def complete(self):
+    def complete(self) -> List[ASRToken]:
+        """Return any remaining tokens (i.e. the current buffer)."""
         return self.buffer


 class OnlineASRProcessor:
+    """
+    Processes incoming audio in a streaming fashion, calling the ASR system
+    periodically, and uses a hypothesis buffer to commit and trim recognized text.
+
+    The processor supports two types of buffer trimming:
+    - "sentence": trims at sentence boundaries (using a sentence tokenizer)
+    - "segment": trims at fixed segment durations.
+    """
     SAMPLING_RATE = 16000

     def __init__(
         self,
         asr,
-        tokenize_method=None,
-        buffer_trimming=("segment", 15),
+        tokenize_method: Optional[callable] = None,
+        buffer_trimming: Tuple[str, float] = ("segment", 15),
         logfile=sys.stderr,
     ):
         """
         Initialize OnlineASRProcessor.

         Args:
-            asr: WhisperASR object
-            tokenize_method: Sentence tokenizer function for the target language.
-                Must be a function that takes a list of text as input like MosesSentenceSplitter.
-                Can be None if using "segment" buffer trimming option.
-            buffer_trimming: Tuple of (option, seconds) where:
-                - option: Either "sentence" or "segment"
-                - seconds: Number of seconds threshold for buffer trimming
-                Default is ("segment", 15)
-            logfile: File to store logs
+            asr: An ASR system object (for example, a WhisperASR instance) that
+                provides a `transcribe` method, a `ts_words` method (to extract tokens),
+                a `segments_end_ts` method, and a separator attribute `sep`.
+            tokenize_method: A function that receives text and returns a list of sentence strings.
+            buffer_trimming: A tuple (option, seconds), where option is either "sentence" or "segment".
         """
         self.asr = asr
         self.tokenize = tokenize_method
@@ -125,235 +158,209 @@ class OnlineASRProcessor:
                 f"buffer_trimming_sec is set to {self.buffer_trimming_sec}, which is very long. It may cause OOM."
             )

-    def init(self, offset=None):
-        """run this when starting or restarting processing"""
+    def init(self, offset: Optional[float] = None):
+        """Initialize or reset the processing buffers."""
         self.audio_buffer = np.array([], dtype=np.float32)
         self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
-        self.buffer_time_offset = 0
-        if offset is not None:
-            self.buffer_time_offset = offset
-        self.transcript_buffer.last_commited_time = self.buffer_time_offset
-        self.commited = []
+        self.buffer_time_offset = offset if offset is not None else 0.0
+        self.transcript_buffer.last_committed_time = self.buffer_time_offset
+        self.committed: List[ASRToken] = []

-    def insert_audio_chunk(self, audio):
+    def insert_audio_chunk(self, audio: np.ndarray):
+        """Append an audio chunk (a numpy array) to the current audio buffer."""
         self.audio_buffer = np.append(self.audio_buffer, audio)

-    def prompt(self):
-        """Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer.
-        "context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons.
-        """
-        k = max(0, len(self.commited) - 1)
-        while k > 0 and self.commited[k - 1][1] > self.buffer_time_offset:
-            k -= 1
-
-        p = self.commited[:k]
-        p = [t for _, _, t in p]
-        prompt = []
-        l = 0
-        while p and l < 200:  # 200 characters prompt size
-            x = p.pop(-1)
-            l += len(x) + 1
-            prompt.append(x)
-        non_prompt = self.commited[k:]
-        return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(
-            t for _, _, t in non_prompt
-        )
+    def prompt(self) -> Tuple[str, str]:
+        """
+        Returns a tuple: (prompt, context), where:
+        - prompt is a 200-character suffix of committed text that falls
+          outside the current audio buffer.
+        - context is the committed text within the current audio buffer.
+        """
+        k = len(self.committed)
+        while k > 0 and self.committed[k - 1].end > self.buffer_time_offset:
+            k -= 1
+
+        prompt_tokens = self.committed[:k]
+        prompt_words = [token.text for token in prompt_tokens]
+        prompt_list = []
+        length_count = 0
+        # Use the last words until reaching 200 characters.
+        while prompt_words and length_count < 200:
+            word = prompt_words.pop(-1)
+            length_count += len(word) + 1
+            prompt_list.append(word)
+        non_prompt_tokens = self.committed[k:]
+        context_text = self.asr.sep.join(token.text for token in non_prompt_tokens)
+        return self.asr.sep.join(prompt_list[::-1]), context_text
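
To make the prompt/context split concrete, a hypothetical state on an initialized processor `proc` (sketch; assumes asr.sep == " "):

    proc.committed = [ASRToken(0.0, 0.4, "hello"), ASRToken(0.4, 1.0, "world"),
                      ASRToken(1.2, 1.6, "again")]
    proc.buffer_time_offset = 1.0    # audio before t = 1.0 s has been trimmed away
    prompt, context = proc.prompt()
    # prompt  == "hello world"  -> passed to the ASR as init_prompt
    # context == "again"        -> still inside the buffer; re-transcribed and skipped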

-    def process_iter(self):
-        """Runs on the current audio buffer.
-        Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, "").
-        The non-emty text is confirmed (committed) partial transcript.
-        """
-        prompt, non_prompt = self.prompt()
+    def process_iter(self) -> Transcript:
+        """
+        Processes the current audio buffer.
+
+        Returns a Transcript object representing the committed transcript.
+        """
+        prompt_text, _ = self.prompt()
         logger.debug(
-            f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}"
+            f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds from {self.buffer_time_offset:.2f}"
         )
-        res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
-
-        # transform to [(beg,end,"word1"), ...]
-        tsw = self.asr.ts_words(res)
-
-        self.transcript_buffer.insert(tsw, self.buffer_time_offset)
-        o = self.transcript_buffer.flush()
-        self.commited.extend(o)
-        completed = self.concatenate_tsw(o)
-        logger.debug(f">>>>COMPLETE NOW: {completed[2]}")
-        the_rest = self.concatenate_tsw(self.transcript_buffer.complete())
-        logger.debug(f"INCOMPLETE: {the_rest[2]}")
+        res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt_text)
+        tokens = self.asr.ts_words(res)  # Expecting List[ASRToken]
+        self.transcript_buffer.insert(tokens, self.buffer_time_offset)
+        committed_tokens = self.transcript_buffer.flush()
+        self.committed.extend(committed_tokens)
+        completed = self.concatenate_tokens(committed_tokens)
+        logger.debug(f">>>> COMPLETE NOW: {completed.text}")
+        incomp = self.concatenate_tokens(self.transcript_buffer.complete())
+        logger.debug(f"INCOMPLETE: {incomp.text}")

-        # there is a newly confirmed text
-        if o and self.buffer_trimming_way == "sentence":  # trim the completed sentences
-            if (
-                len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec
-            ):  # longer than this
+        if committed_tokens and self.buffer_trimming_way == "sentence":
+            if len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec:
                 self.chunk_completed_sentence()

-        if self.buffer_trimming_way == "segment":
-            s = self.buffer_trimming_sec  # trim the completed segments longer than s,
-        else:
-            s = 30  # if the audio buffer is longer than 30s, trim it
-
+        s = self.buffer_trimming_sec if self.buffer_trimming_way == "segment" else 30
         if len(self.audio_buffer) / self.SAMPLING_RATE > s:
             self.chunk_completed_segment(res)
-
-            # alternative: on any word
-            # l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
-            # let's find commited word that is less
-            # k = len(self.commited)-1
-            # while k>0 and self.commited[k][1] > l:
-            #     k -= 1
-            # t = self.commited[k][1]
-            logger.debug("chunking segment")
-            # self.chunk_at(t)
+            logger.debug("Chunking segment")

         logger.debug(
-            f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}"
+            f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
         )
-        return self.concatenate_tsw(o)
+        return self.concatenate_tokens(committed_tokens)
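
A minimal driving loop for the class (sketch; assumes an `asr` backend exposing transcribe/ts_words/segments_end_ts/sep, and an iterable `chunks` of 16 kHz float32 numpy arrays):

    online = OnlineASRProcessor(asr)
    for chunk in chunks:
        online.insert_audio_chunk(chunk)
        transcript = online.process_iter()
        if transcript.text:
            print(f"{transcript.start:.2f}-{transcript.end:.2f} {transcript.text}")
    final = online.finish()   # flush whatever was never committed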

     def chunk_completed_sentence(self):
-        if self.commited == []:
+        """
+        If the committed tokens form at least two sentences, chunk the audio
+        buffer at the end time of the penultimate sentence.
+        """
+        if not self.committed:
             return
-        logger.debug("COMPLETED SENTENCE: ", [s[2] for s in self.commited])
-        sents = self.words_to_sentences(self.commited)
-        for s in sents:
-            logger.debug(f"\t\tSENT: {s}")
-        if len(sents) < 2:
+        logger.debug("COMPLETED SENTENCE: " + " ".join(token.text for token in self.committed))
+        sentences = self.words_to_sentences(self.committed)
+        for sentence in sentences:
+            logger.debug(f"\tSentence: {sentence.text}")
+        if len(sentences) < 2:
             return
-        while len(sents) > 2:
-            sents.pop(0)
-        # we will continue with audio processing at this timestamp
-        chunk_at = sents[-2][1]
-
-        logger.debug(f"--- sentence chunked at {chunk_at:2.2f}")
-        self.chunk_at(chunk_at)
+        # Keep the last two sentences.
+        while len(sentences) > 2:
+            sentences.pop(0)
+        chunk_time = sentences[-2].end
+        logger.debug(f"--- Sentence chunked at {chunk_time:.2f}")
+        self.chunk_at(chunk_time)

     def chunk_completed_segment(self, res):
-        if self.commited == []:
+        """
+        Chunk the audio buffer based on segment-end timestamps reported by the ASR.
+        """
+        if not self.committed:
             return

         ends = self.asr.segments_end_ts(res)
-        t = self.commited[-1][1]
+        last_committed_time = self.committed[-1].end
         if len(ends) > 1:
             e = ends[-2] + self.buffer_time_offset
-            while len(ends) > 2 and e > t:
+            while len(ends) > 2 and e > last_committed_time:
                 ends.pop(-1)
                 e = ends[-2] + self.buffer_time_offset
-            if e <= t:
-                logger.debug(f"--- segment chunked at {e:2.2f}")
+            if e <= last_committed_time:
+                logger.debug(f"--- Segment chunked at {e:.2f}")
                 self.chunk_at(e)
             else:
-                logger.debug(f"--- last segment not within commited area")
+                logger.debug("--- Last segment not within committed area")
         else:
-            logger.debug(f"--- not enough segments to chunk")
+            logger.debug("--- Not enough segments to chunk")

-    def chunk_at(self, time):
-        """trims the hypothesis and audio buffer at "time" """
-        logger.debug(f"chunking at {time:2.2f}s")
+    def chunk_at(self, time: float):
+        """
+        Trim both the hypothesis and audio buffer at the given time.
+        """
+        logger.debug(f"Chunking at {time:.2f}s")
         logger.debug(
-            f"len of audio buffer before chunking is: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}s"
+            f"Audio buffer length before chunking: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f}s"
         )
-        self.transcript_buffer.pop_commited(time)
+        self.transcript_buffer.pop_committed(time)
         cut_seconds = time - self.buffer_time_offset
-        self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE) :]
+        self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE):]
         self.buffer_time_offset = time
-
         logger.debug(
-            f"len of audio buffer is now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}s"
+            f"Audio buffer length after chunking: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f}s"
         )
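
A worked sketch of the segment-trimming arithmetic above (hypothetical numbers):

    # ends = [5.2, 9.8, 14.5]      segment ends, relative to the buffer
    # buffer_time_offset = 100.0   the buffer starts at t = 100 s
    # last committed token ends at t = 108.0
    #
    # e = ends[-2] + offset = 109.8  -> 109.8 > 108.0, pop the last end
    # e = ends[-2] + offset = 105.2  -> 105.2 <= 108.0, so chunk_at(105.2)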

-    def words_to_sentences(self, words):
-        """Uses self.tokenize for sentence segmentation of words.
-        Returns: [(beg,end,"sentence 1"),...]
-        """
-        cwords = [w for w in words]
-        t = " ".join(o[2] for o in cwords)
-        s = self.tokenize(t)
-        out = []
-        while s:
-            beg = None
-            end = None
-            sent = s.pop(0).strip()
-            fsent = sent
-            while cwords:
-                b, e, w = cwords.pop(0)
-                w = w.strip()
-                if beg is None and sent.startswith(w):
-                    beg = b
-                elif end is None and sent == w:
-                    end = e
-                    out.append((beg, end, fsent))
-                    break
-                sent = sent[len(w) :].strip()
-        return out
-
-    def finish(self):
-        """Flush the incomplete text when the whole processing ends.
-        Returns: the same format as self.process_iter()
-        """
-        o = self.transcript_buffer.complete()
-        f = self.concatenate_tsw(o)
-        logger.debug(f"last, noncommited: {f}")
-        self.buffer_time_offset += len(self.audio_buffer) / 16000
-        return f
+    def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
+        """
+        Converts a list of tokens to a list of Sentence objects by using the provided
+        sentence tokenizer.
+        """
+        full_text = " ".join(token.text for token in tokens)
+        sentence_texts = self.tokenize(full_text) if self.tokenize else [full_text]
+        sentences: List[Sentence] = []
+        token_index = 0
+        for sent_text in sentence_texts:
+            sent_text = sent_text.strip()
+            if not sent_text:
+                continue
+            sent_tokens = []
+            accumulated = ""
+            # Accumulate tokens until roughly matching the sentence text.
+            while token_index < len(tokens) and len(accumulated) < len(sent_text):
+                token = tokens[token_index]
+                accumulated = (accumulated + " " + token.text).strip() if accumulated else token.text
+                sent_tokens.append(token)
+                token_index += 1
+            if sent_tokens:
+                sentence = Sentence(
+                    start=sent_tokens[0].start,
+                    end=sent_tokens[-1].end,
+                    text=" ".join(t.text for t in sent_tokens),
+                )
+                sentences.append(sentence)
+        return sentences
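
A sketch of the new token-to-sentence matching (illustrative; uses a toy splitter on `proc` in place of a real sentence tokenizer):

    proc.tokenize = lambda text: [s for s in text.split(". ") if s]   # toy splitter
    tokens = [ASRToken(0.0, 0.4, "Hi."), ASRToken(0.5, 0.9, "How"),
              ASRToken(0.9, 1.3, "are"), ASRToken(1.3, 1.6, "you?")]
    for s in proc.words_to_sentences(tokens):
        print(s)
    # Sentence(start=0.00, end=0.40, text='Hi.')
    # Sentence(start=0.50, end=1.60, text='How are you?')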

-    def concatenate_tsw(
-        self,
-        sents,
-        sep=None,
-        offset=0,
-    ):
-        # concatenates the timestamped words or sentences into one sequence that is flushed in one line
-        # sents: [(beg1, end1, "sentence1"), ...] or [] if empty
-        # return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
-        if sep is None:
-            sep = self.asr.sep
-        t = sep.join(s[2] for s in sents)
-        if len(sents) == 0:
-            b = None
-            e = None
-        else:
-            b = offset + sents[0][0]
-            e = offset + sents[-1][1]
-        return (b, e, t)
+    def finish(self) -> Transcript:
+        """
+        Flush the remaining transcript when processing ends.
+        """
+        remaining_tokens = self.transcript_buffer.complete()
+        final_transcript = self.concatenate_tokens(remaining_tokens)
+        logger.debug(f"Final non-committed transcript: {final_transcript}")
+        self.buffer_time_offset += len(self.audio_buffer) / self.SAMPLING_RATE
+        return final_transcript
+
+    def concatenate_tokens(
+        self,
+        tokens: List[ASRToken],
+        sep: Optional[str] = None,
+        offset: float = 0
+    ) -> Transcript:
+        sep = sep if sep is not None else self.asr.sep
+        text = sep.join(token.text for token in tokens)
+        if tokens:
+            start = offset + tokens[0].start
+            end = offset + tokens[-1].end
+        else:
+            start = None
+            end = None
+        return Transcript(start, end, text)
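
And the new return type in action (illustrative; assumes asr.sep == " "):

    tokens = [ASRToken(0.0, 0.5, "hello"), ASRToken(0.5, 1.0, "world")]
    proc.concatenate_tokens(tokens)   # Transcript(start=0.0, end=1.0, text='hello world')
    proc.concatenate_tokens([])       # Transcript(None, None, '') -- the empty case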


-class VACOnlineASRProcessor(OnlineASRProcessor):
-    """Wraps OnlineASRProcessor with VAC (Voice Activity Controller).
-
-    It works the same way as OnlineASRProcessor: it receives chunks of audio (e.g. 0.04 seconds),
-    it runs VAD and continuously detects whether there is speech or not.
-    When it detects end of speech (non-voice for 500ms), it makes OnlineASRProcessor to end the utterance immediately.
-    """
+class VACOnlineASRProcessor:
+    """
+    Wraps an OnlineASRProcessor with a Voice Activity Controller (VAC).
+
+    It receives small chunks of audio, applies VAD (e.g. with Silero),
+    and when the system detects a pause in speech (or end of an utterance)
+    it finalizes the utterance immediately.
+    """
+    SAMPLING_RATE = 16000

     # TODO: VACOnlineASRProcessor does not break after chunk length is reached, this can lead to overflow!

-    def __init__(self, online_chunk_size, *a, **kw):
+    def __init__(self, online_chunk_size: float, *args, **kwargs):
         self.online_chunk_size = online_chunk_size
-        self.online = OnlineASRProcessor(*a, **kw)
+        self.online = OnlineASRProcessor(*args, **kwargs)

-        # VAC:
+        # Load a VAD model (e.g. Silero VAD)
         import torch
         model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
         from src.whisper_streaming.silero_vad_iterator import FixedVADIterator
-        self.vac = FixedVADIterator(
-            model
-        )  # we use the default options there: 500ms silence, 100ms padding, etc.
+        self.vac = FixedVADIterator(model)
         self.logfile = self.online.logfile
         self.init()
@@ -361,10 +368,8 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
         self.online.init()
         self.vac.reset_states()
         self.current_online_chunk_buffer_size = 0
         self.is_currently_final = False
-        self.status = None  # or "voice" or "nonvoice"
+        self.status: Optional[str] = None  # "voice" or "nonvoice"
         self.audio_buffer = np.array([], dtype=np.float32)
         self.buffer_offset = 0  # in frames
@@ -372,18 +377,23 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
         self.buffer_offset += len(self.audio_buffer)
         self.audio_buffer = np.array([], dtype=np.float32)

-    def insert_audio_chunk(self, audio):
+    def insert_audio_chunk(self, audio: np.ndarray):
+        """
+        Process an incoming small audio chunk:
+        - run VAD on the chunk,
+        - decide whether to send the audio to the online ASR processor immediately,
+        - and/or to mark the current utterance as finished.
+        """
         res = self.vac(audio)
         self.audio_buffer = np.append(self.audio_buffer, audio)

         if res is not None:
+            # VAD returned a result; adjust the frame number
             frame = list(res.values())[0] - self.buffer_offset
             if "start" in res and "end" not in res:
                 self.status = "voice"
                 send_audio = self.audio_buffer[frame:]
-                self.online.init(
-                    offset=(frame + self.buffer_offset) / self.SAMPLING_RATE
-                )
+                self.online.init(offset=(frame + self.buffer_offset) / self.SAMPLING_RATE)
                 self.online.insert_audio_chunk(send_audio)
                 self.current_online_chunk_buffer_size += len(send_audio)
                 self.clear_buffer()
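
For orientation, the VAD iterator's events are dicts of absolute sample indices (as in Silero's VADIterator; illustrative values):

    # {'start': 512000}  -> voice began at 512000 / 16000 = 32.0 s; start streaming to the ASR
    # {'end': 540000}    -> voice ended; the utterance will be finalized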
@@ -410,29 +420,28 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
             self.current_online_chunk_buffer_size += len(self.audio_buffer)
             self.clear_buffer()
         else:
-            # We keep 1 second because VAD may later find start of voice in it.
-            # But we trim it to prevent OOM.
-            self.buffer_offset += max(
-                0, len(self.audio_buffer) - self.SAMPLING_RATE
-            )
-            self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE :]
+            # Keep 1 second worth of audio in case VAD later detects voice,
+            # but trim to avoid unbounded memory usage.
+            self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE)
+            self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]

-    def process_iter(self):
+    def process_iter(self) -> Transcript:
+        """
+        Depending on the VAD status and the amount of accumulated audio,
+        process the current audio chunk.
+        """
         if self.is_currently_final:
             return self.finish()
-        elif (
-            self.current_online_chunk_buffer_size
-            > self.SAMPLING_RATE * self.online_chunk_size
-        ):
+        elif self.current_online_chunk_buffer_size > self.SAMPLING_RATE * self.online_chunk_size:
             self.current_online_chunk_buffer_size = 0
-            ret = self.online.process_iter()
-            return ret
+            return self.online.process_iter()
         else:
-            logger.debug("no online update, only VAD")
-            return (None, None, "")
+            logger.debug("No online update, only VAD")
+            return Transcript(None, None, "")

-    def finish(self):
-        ret = self.online.finish()
+    def finish(self) -> Transcript:
+        """Finish processing by flushing any remaining text."""
+        result = self.online.finish()
         self.current_online_chunk_buffer_size = 0
         self.is_currently_final = False
-        return ret
+        return result
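
End-to-end, the VAC wrapper is driven the same way as OnlineASRProcessor, just with small fixed-size chunks (sketch; `asr` and `chunks_40ms` are assumed):

    vac = VACOnlineASRProcessor(1.0, asr)   # flush to the ASR after ~1 s of voiced audio
    for chunk in chunks_40ms:               # e.g. 0.04 s chunks of 16 kHz float32 audio
        vac.insert_audio_chunk(chunk)
        transcript = vac.process_iter()
        if transcript.text:
            print(transcript)
    final = vac.finish()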