# Mirror of https://github.com/QuentinFuxa/WhisperLiveKit.git (synced 2026-03-07)
import sys
import logging

import numpy as np

logger = logging.getLogger(__name__)


class HypothesisBuffer:
    """Buffer of timestamped word hypotheses.

    Words are committed once two consecutive inserts agree on them
    (LocalAgreement-2): flush() returns the longest common prefix of the
    last two inserted hypotheses.
    """

    def __init__(self, logfile=sys.stderr):
        self.commited_in_buffer = []
        self.buffer = []
        self.new = []

        self.last_commited_time = 0
        self.last_commited_word = None

        self.logfile = logfile

    def insert(self, new, offset):
        # Compare self.commited_in_buffer and new. Insert only the words from
        # new that extend commited_in_buffer, i.e. that are roughly past
        # last_commited_time and new in content.
        # The new tail is added to self.new.

        new = [(a + offset, b + offset, t) for a, b, t in new]
        self.new = [(a, b, t) for a, b, t in new if a > self.last_commited_time - 0.1]

        if len(self.new) >= 1:
            a, b, t = self.new[0]
            if abs(a - self.last_commited_time) < 1:
                if self.commited_in_buffer:
                    # Search for 1, 2, ..., 5 consecutive words (n-grams) that
                    # are identical in commited and new. If found, drop them.
                    cn = len(self.commited_in_buffer)
                    nn = len(self.new)
                    for i in range(1, min(min(cn, nn), 5) + 1):  # 5 is the maximum
                        c = " ".join(
                            [self.commited_in_buffer[-j][2] for j in range(1, i + 1)][::-1]
                        )
                        tail = " ".join(self.new[j - 1][2] for j in range(1, i + 1))
                        if c == tail:
                            words = []
                            for j in range(i):
                                words.append(repr(self.new.pop(0)))
                            words_msg = " ".join(words)
                            logger.debug(f"removing last {i} words: {words_msg}")
                            break

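    # Illustrative n-gram removal in insert() (made-up numbers): with a
    # committed tail ..., (0.3, 0.6, "see"), (0.6, 1.0, "you") and a new
    # hypothesis [(0.95, 1.05, "see"), (1.05, 1.4, "you"), (1.4, 1.8, "soon")],
    # the 2-gram "see you" matches the committed tail, so both words are
    # dropped and only (1.4, 1.8, "soon") stays in self.new.
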
    def flush(self):
        # Returns the committed chunk = the longest common prefix of the last
        # two inserts.

        commit = []
        while self.new:
            na, nb, nt = self.new[0]

            if len(self.buffer) == 0:
                break

            if nt == self.buffer[0][2]:
                commit.append((na, nb, nt))
                self.last_commited_word = nt
                self.last_commited_time = nb
                self.buffer.pop(0)
                self.new.pop(0)
            else:
                break
        self.buffer = self.new
        self.new = []
        self.commited_in_buffer.extend(commit)
        return commit

    def pop_commited(self, time):
        while self.commited_in_buffer and self.commited_in_buffer[0][1] <= time:
            self.commited_in_buffer.pop(0)

    def complete(self):
        return self.buffer


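# A minimal HypothesisBuffer sketch (illustrative word triples, not real ASR
# output); flush() commits only the prefix on which two inserts agree:
#
#   hb = HypothesisBuffer()
#   hb.insert([(0.0, 0.4, "hello"), (0.4, 0.9, "world")], offset=0)
#   hb.flush()  # [] -- nothing agrees yet, buffer holds the first hypothesis
#   hb.insert([(0.0, 0.4, "hello"), (0.4, 0.9, "word"), (0.9, 1.2, "two")], offset=0)
#   hb.flush()  # [(0.0, 0.4, "hello")] -- only "hello" matched both inserts

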
class OnlineASRProcessor:

    SAMPLING_RATE = 16000

    def __init__(
        self,
        asr,
        tokenize_method=None,
        buffer_trimming=("segment", 15),
        logfile=sys.stderr,
    ):
        """asr: WhisperASR object
        tokenize_method: sentence tokenizer function for the target language.
            Must be a callable that behaves like the one of MosesTokenizer.
            It can be None: if the "segment" buffer trimming option is used,
            the tokenizer is not used at all.
        buffer_trimming: a pair (option, seconds), where option is either
            "sentence" or "segment", and seconds is a number. The buffer is
            trimmed if it is longer than the "seconds" threshold. The default,
            ("segment", 15), is the most recommended option.
        logfile: where to store the log.
        """
        self.asr = asr
        self.tokenize = tokenize_method
        self.logfile = logfile

        self.init()

        self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming

    def init(self, offset=None):
        """Run this when starting or restarting processing."""
        self.audio_buffer = np.array([], dtype=np.float32)
        self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
        self.buffer_time_offset = 0
        if offset is not None:
            self.buffer_time_offset = offset
        self.transcript_buffer.last_commited_time = self.buffer_time_offset
        self.commited = []

    def insert_audio_chunk(self, audio):
        self.audio_buffer = np.append(self.audio_buffer, audio)

    def prompt(self):
        """Returns a tuple (prompt, context).

        "prompt" is a 200-character suffix of committed text that is inside
        the scrolled-away part of the audio buffer.
        "context" is the committed text that is inside the audio buffer. It is
        transcribed again and skipped. It is returned only for debugging and
        logging reasons.
        """
        k = max(0, len(self.commited) - 1)
        while k > 0 and self.commited[k - 1][1] > self.buffer_time_offset:
            k -= 1

        p = self.commited[:k]
        p = [t for _, _, t in p]
        prompt = []
        l = 0
        while p and l < 200:  # 200 characters prompt size
            x = p.pop(-1)
            l += len(x) + 1
            prompt.append(x)
        non_prompt = self.commited[k:]
        return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(
            t for _, _, t in non_prompt
        )

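    # Illustrative prompt()/context split (made-up times): if the buffer was
    # trimmed at buffer_time_offset = 5.0 s and words are committed up to
    # 7.2 s, words ending before 5.0 s feed the (<= 200-char) prompt, while
    # words in 5.0-7.2 s form the "context" that is re-transcribed and skipped.
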
    def process_iter(self):
        """Runs on the current audio buffer.

        Returns: a tuple (beg_timestamp, end_timestamp, "text"), or
        (None, None, ""). The non-empty text is the confirmed (committed)
        partial transcript.
        """

        prompt, non_prompt = self.prompt()
        logger.debug(f"PROMPT: {prompt}")
        logger.debug(f"CONTEXT: {non_prompt}")
        logger.debug(
            f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}"
        )
        res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)

        # transform to [(beg, end, "word1"), ...]
        tsw = self.asr.ts_words(res)

        self.transcript_buffer.insert(tsw, self.buffer_time_offset)
        o = self.transcript_buffer.flush()
        self.commited.extend(o)
        completed = self.to_flush(o)
        logger.debug(f">>>>COMPLETE NOW: {completed[2]}")
        the_rest = self.to_flush(self.transcript_buffer.complete())
        logger.debug(f"INCOMPLETE: {the_rest[2]}")

        # there is newly confirmed text

        if o and self.buffer_trimming_way == "sentence":  # trim the completed sentences
            if (
                len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec
            ):  # longer than this
                logger.debug("chunking sentence")
                self.chunk_completed_sentence()
            else:
                logger.debug("not enough audio to trim as a sentence")

        if self.buffer_trimming_way == "segment":
            s = self.buffer_trimming_sec  # trim the completed segments longer than s
        else:
            s = 30  # if the audio buffer is longer than 30 s, trim it

        if len(self.audio_buffer) / self.SAMPLING_RATE > s:
            self.chunk_completed_segment(res)

            # alternative: on any word
            # l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
            # let's find the committed word that is less
            # k = len(self.commited)-1
            # while k>0 and self.commited[k][1] > l:
            #     k -= 1
            # t = self.commited[k][1]
            logger.debug("chunking segment")
            # self.chunk_at(t)

        logger.debug(
            f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}"
        )
        return self.to_flush(o)

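    # Trimming policy in process_iter(), summarized: with ("segment", 15) the
    # buffer is chunked at a transcribed-segment boundary once it exceeds 15 s;
    # with ("sentence", 15) completed sentences are trimmed instead, and a
    # hard 30 s segment-based fallback still applies.
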
    def chunk_completed_sentence(self):
        if self.commited == []:
            return
        raw_text = self.asr.sep.join([s[2] for s in self.commited])
        logger.debug(f"COMPLETED SENTENCE: {raw_text}")
        sents = self.words_to_sentences(self.commited)
        for s in sents:
            logger.debug(f"\t\tSENT: {s}")
        if len(sents) < 2:
            return
        while len(sents) > 2:
            sents.pop(0)
        # we will continue the audio processing at this timestamp
        chunk_at = sents[-2][1]

        logger.debug(f"--- sentence chunked at {chunk_at:2.2f}")
        self.chunk_at(chunk_at)

    def chunk_completed_segment(self, res):
        if self.commited == []:
            return

        ends = self.asr.segments_end_ts(res)

        t = self.commited[-1][1]

        if len(ends) > 1:

            e = ends[-2] + self.buffer_time_offset
            while len(ends) > 2 and e > t:
                ends.pop(-1)
                e = ends[-2] + self.buffer_time_offset
            if e <= t:
                logger.debug(f"--- segment chunked at {e:2.2f}")
                self.chunk_at(e)
            else:
                logger.debug("--- last segment not within committed area")
        else:
            logger.debug("--- not enough segments to chunk")

    def chunk_at(self, time):
        """Trims the hypothesis and audio buffer at "time"."""
        self.transcript_buffer.pop_commited(time)
        cut_seconds = time - self.buffer_time_offset
        self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE):]
        self.buffer_time_offset = time

    def words_to_sentences(self, words):
        """Uses self.tokenize for sentence segmentation of words.

        Returns: [(beg, end, "sentence 1"), ...]
        """

        cwords = [w for w in words]
        t = self.asr.sep.join(o[2] for o in cwords)
        s = self.tokenize(t)
        out = []
        while s:
            beg = None
            end = None
            sent = s.pop(0).strip()
            fsent = sent
            while cwords:
                b, e, w = cwords.pop(0)
                w = w.strip()
                if beg is None and sent.startswith(w):
                    beg = b
                elif end is None and sent == w:
                    end = e
                    out.append((beg, end, fsent))
                    break
                sent = sent[len(w):].strip()
        return out

    def finish(self):
        """Flush the incomplete text when the whole processing ends.

        Returns: the same format as self.process_iter()
        """
        o = self.transcript_buffer.complete()
        f = self.to_flush(o)
        if f[0] is not None:  # guard: f is (None, None, "") when nothing is left
            logger.debug(f"last, noncommitted: {f[0]*1000:.0f}-{f[1]*1000:.0f}: {f[2]}")
        self.buffer_time_offset += len(self.audio_buffer) / self.SAMPLING_RATE
        return f

    def to_flush(
        self,
        sents,
        sep=None,
        offset=0,
    ):
        # Concatenates the timestamped words or sentences into one sequence
        # that is flushed in one line.
        # sents: [(beg1, end1, "sentence1"), ...] or [] if empty
        # return: (beg1, end-of-last-sentence, "concatenation of sentences"),
        # or (None, None, "") if empty
        if sep is None:
            sep = self.asr.sep
        t = sep.join(s[2] for s in sents)
        if len(sents) == 0:
            b = None
            e = None
        else:
            b = offset + sents[0][0]
            e = offset + sents[-1][1]
        return (b, e, t)
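
    # Example (illustrative): to_flush([(0.0, 0.4, "hello"), (0.4, 0.9, "world")])
    # with sep=" " returns (0.0, 0.9, "hello world"); to_flush([]) returns
    # (None, None, "").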
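

# A minimal OnlineASRProcessor driving loop (sketch only; `asr` is the
# WhisperASR-like object from the __init__ docstring, and `audio_chunks`
# yields 16 kHz float32 numpy arrays -- both are assumptions here):
#
#   online = OnlineASRProcessor(asr, buffer_trimming=("segment", 15))
#   for chunk in audio_chunks:           # e.g. from a microphone or socket
#       online.insert_audio_chunk(chunk)
#       beg, end, text = online.process_iter()
#       if text:
#           print(f"{beg:.2f}-{end:.2f} {text}")
#   print(online.finish()[2])            # flush the remaining hypothesis

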
class VACOnlineASRProcessor(OnlineASRProcessor):
    """Wraps OnlineASRProcessor with VAC (Voice Activity Controller).

    It works the same way as OnlineASRProcessor: it receives chunks of audio
    (e.g. 0.04 seconds), runs VAD on them, and continuously detects whether
    there is speech or not. When it detects the end of speech (non-voice for
    500 ms), it makes OnlineASRProcessor end the utterance immediately.
    """

    def __init__(self, online_chunk_size, *a, **kw):
        self.online_chunk_size = online_chunk_size

        self.online = OnlineASRProcessor(*a, **kw)

        # VAC: local imports, so torch and the VAD model are loaded only when
        # this class is instantiated.
        import torch

        model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
        from silero_vad_iterator import FixedVADIterator

        self.vac = FixedVADIterator(
            model
        )  # we use the default options there: 500 ms silence, 100 ms padding, etc.

        self.logfile = self.online.logfile
        self.init()

    def init(self):
        self.online.init()
        self.vac.reset_states()
        self.current_online_chunk_buffer_size = 0

        self.is_currently_final = False

        self.status = None  # or "voice" or "nonvoice"
        self.audio_buffer = np.array([], dtype=np.float32)
        self.buffer_offset = 0  # in frames

    def clear_buffer(self):
        self.buffer_offset += len(self.audio_buffer)
        self.audio_buffer = np.array([], dtype=np.float32)

    def insert_audio_chunk(self, audio):
        res = self.vac(audio)
        self.audio_buffer = np.append(self.audio_buffer, audio)

        if res is not None:
            frame = list(res.values())[0] - self.buffer_offset
            if "start" in res and "end" not in res:
                self.status = "voice"
                send_audio = self.audio_buffer[frame:]
                self.online.init(
                    offset=(frame + self.buffer_offset) / self.SAMPLING_RATE
                )
                self.online.insert_audio_chunk(send_audio)
                self.current_online_chunk_buffer_size += len(send_audio)
                self.clear_buffer()
            elif "end" in res and "start" not in res:
                self.status = "nonvoice"
                send_audio = self.audio_buffer[:frame]
                self.online.insert_audio_chunk(send_audio)
                self.current_online_chunk_buffer_size += len(send_audio)
                self.is_currently_final = True
                self.clear_buffer()
            else:
                beg = res["start"] - self.buffer_offset
                end = res["end"] - self.buffer_offset
                self.status = "nonvoice"
                send_audio = self.audio_buffer[beg:end]
                self.online.init(offset=(beg + self.buffer_offset) / self.SAMPLING_RATE)
                self.online.insert_audio_chunk(send_audio)
                self.current_online_chunk_buffer_size += len(send_audio)
                self.is_currently_final = True
                self.clear_buffer()
        else:
            if self.status == "voice":
                self.online.insert_audio_chunk(self.audio_buffer)
                self.current_online_chunk_buffer_size += len(self.audio_buffer)
                self.clear_buffer()
            else:
                # We keep 1 second because VAD may later find the start of
                # voice in it. But we trim it to prevent OOM.
                self.buffer_offset += max(
                    0, len(self.audio_buffer) - self.SAMPLING_RATE
                )
                self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]

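    # VAD-result handling above, summarized (frames are sample indices into
    # self.audio_buffer after subtracting buffer_offset):
    #   {"start": f}             -> speech begins: re-init the online processor
    #                               at that offset and stream audio from f on
    #   {"end": f}               -> speech ends: send audio up to f and mark
    #                               the utterance final
    #   {"start": f1, "end": f2} -> a short utterance fully inside this chunk
    #   None                     -> keep streaming while in "voice"; otherwise
    #                               retain only the last 1 s of audio
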
    def process_iter(self):
        if self.is_currently_final:
            return self.finish()
        elif (
            self.current_online_chunk_buffer_size
            > self.SAMPLING_RATE * self.online_chunk_size
        ):
            self.current_online_chunk_buffer_size = 0
            ret = self.online.process_iter()
            return ret
        else:
            print("no online update, only VAD", self.status, file=self.logfile)
            return (None, None, "")

    def finish(self):
        ret = self.online.finish()
        self.current_online_chunk_buffer_size = 0
        self.is_currently_final = False
        return ret
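

# VACOnlineASRProcessor usage sketch (assumes torch.hub can fetch
# snakers4/silero-vad and that silero_vad_iterator is importable; `asr` and
# `audio_chunks` are the same assumptions as in the OnlineASRProcessor sketch):
#
#   vac_online = VACOnlineASRProcessor(1.0, asr)   # emit ~1 s online chunks
#   for chunk in audio_chunks:                     # e.g. 0.04 s frames
#       vac_online.insert_audio_chunk(chunk)
#       beg, end, text = vac_online.process_iter()
#       if text:
#           print(f"{beg:.2f}-{end:.2f} {text}")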