mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Merge pull request #10 from SilasK/main
More flexibility by using custom tokenize_method + black
This commit is contained in:
@@ -6,15 +6,16 @@ import torch
|
||||
|
||||
# Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
|
||||
|
||||
class VADIterator:
|
||||
def __init__(self,
|
||||
model,
|
||||
threshold: float = 0.5,
|
||||
sampling_rate: int = 16000,
|
||||
min_silence_duration_ms: int = 500, # makes sense on one recording that I checked
|
||||
speech_pad_ms: int = 100 # same
|
||||
):
|
||||
|
||||
class VADIterator:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
threshold: float = 0.5,
|
||||
sampling_rate: int = 16000,
|
||||
min_silence_duration_ms: int = 500, # makes sense on one recording that I checked
|
||||
speech_pad_ms: int = 100, # same
|
||||
):
|
||||
"""
|
||||
Class for stream imitation
|
||||
|
||||
@@ -41,7 +42,9 @@ class VADIterator:
|
||||
self.sampling_rate = sampling_rate
|
||||
|
||||
if sampling_rate not in [8000, 16000]:
|
||||
raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
|
||||
raise ValueError(
|
||||
"VADIterator does not support sampling rates other than [8000, 16000]"
|
||||
)
|
||||
|
||||
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
|
||||
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
||||
@@ -80,7 +83,13 @@ class VADIterator:
|
||||
if (speech_prob >= self.threshold) and not self.triggered:
|
||||
self.triggered = True
|
||||
speech_start = self.current_sample - self.speech_pad_samples
|
||||
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
|
||||
return {
|
||||
"start": (
|
||||
int(speech_start)
|
||||
if not return_seconds
|
||||
else round(speech_start / self.sampling_rate, 1)
|
||||
)
|
||||
}
|
||||
|
||||
if (speech_prob < self.threshold - 0.15) and self.triggered:
|
||||
if not self.temp_end:
|
||||
@@ -91,26 +100,35 @@ class VADIterator:
|
||||
speech_end = self.temp_end + self.speech_pad_samples
|
||||
self.temp_end = 0
|
||||
self.triggered = False
|
||||
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
|
||||
return {
|
||||
"end": (
|
||||
int(speech_end)
|
||||
if not return_seconds
|
||||
else round(speech_end / self.sampling_rate, 1)
|
||||
)
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
#######################
|
||||
# because Silero now requires exactly 512-sized audio chunks
|
||||
# because Silero now requires exactly 512-sized audio chunks
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class FixedVADIterator(VADIterator):
|
||||
'''It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once.
|
||||
If audio to be processed at once is long and multiple voiced segments detected,
|
||||
then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment.
|
||||
'''
|
||||
"""It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once.
|
||||
If audio to be processed at once is long and multiple voiced segments detected,
|
||||
then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment.
|
||||
"""
|
||||
|
||||
def reset_states(self):
|
||||
super().reset_states()
|
||||
self.buffer = np.array([],dtype=np.float32)
|
||||
self.buffer = np.array([], dtype=np.float32)
|
||||
|
||||
def __call__(self, x, return_seconds=False):
|
||||
self.buffer = np.append(self.buffer, x)
|
||||
self.buffer = np.append(self.buffer, x)
|
||||
ret = None
|
||||
while len(self.buffer) >= 512:
|
||||
r = super().__call__(self.buffer[:512], return_seconds=return_seconds)
|
||||
@@ -118,29 +136,28 @@ class FixedVADIterator(VADIterator):
|
||||
if ret is None:
|
||||
ret = r
|
||||
elif r is not None:
|
||||
if 'end' in r:
|
||||
ret['end'] = r['end'] # the latter end
|
||||
if 'start' in r and 'end' in ret: # there is an earlier start.
|
||||
if "end" in r:
|
||||
ret["end"] = r["end"] # the latter end
|
||||
if "start" in r and "end" in ret: # there is an earlier start.
|
||||
# Remove end, merging this segment with the previous one.
|
||||
del ret['end']
|
||||
del ret["end"]
|
||||
return ret if ret != {} else None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test/demonstrate the need for FixedVADIterator:
|
||||
|
||||
import torch
|
||||
model, _ = torch.hub.load(
|
||||
repo_or_dir='snakers4/silero-vad',
|
||||
model='silero_vad'
|
||||
)
|
||||
|
||||
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
|
||||
vac = FixedVADIterator(model)
|
||||
# vac = VADIterator(model) # the second case crashes with this
|
||||
# vac = VADIterator(model) # the second case crashes with this
|
||||
|
||||
# this works: for both
|
||||
audio_buffer = np.array([0]*(512),dtype=np.float32)
|
||||
audio_buffer = np.array([0] * (512), dtype=np.float32)
|
||||
vac(audio_buffer)
|
||||
|
||||
# this crashes on the non FixedVADIterator with
|
||||
# this crashes on the non FixedVADIterator with
|
||||
# ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
|
||||
audio_buffer = np.array([0]*(512-1),dtype=np.float32)
|
||||
audio_buffer = np.array([0] * (512 - 1), dtype=np.float32)
|
||||
vac(audio_buffer)
|
||||
|
||||
@@ -22,10 +22,21 @@ app.add_middleware(
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
|
||||
parser.add_argument("--host", type=str, default='localhost', help="The host address to bind the server to.")
|
||||
parser.add_argument("--port", type=int, default=8000, help="The port number to bind the server to.")
|
||||
parser.add_argument("--warmup-file", type=str, dest="warmup_file",
|
||||
help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .")
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="The host address to bind the server to.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=8000, help="The port number to bind the server to."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup-file",
|
||||
type=str,
|
||||
dest="warmup_file",
|
||||
help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
|
||||
)
|
||||
add_shared_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -35,29 +46,38 @@ asr, online = asr_factory(args)
|
||||
with open("src/live_transcription.html", "r") as f:
|
||||
html = f.read()
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def get():
|
||||
return HTMLResponse(html)
|
||||
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
CHANNELS = 1
|
||||
SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size)
|
||||
BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
|
||||
BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
|
||||
BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
|
||||
|
||||
|
||||
async def start_ffmpeg_decoder():
|
||||
"""
|
||||
Start an FFmpeg process in async streaming mode that reads WebM from stdin
|
||||
and outputs raw s16le PCM on stdout. Returns the process object.
|
||||
"""
|
||||
process = (
|
||||
ffmpeg
|
||||
.input('pipe:0', format='webm')
|
||||
.output('pipe:1', format='s16le', acodec='pcm_s16le', ac=CHANNELS, ar=str(SAMPLE_RATE))
|
||||
ffmpeg.input("pipe:0", format="webm")
|
||||
.output(
|
||||
"pipe:1",
|
||||
format="s16le",
|
||||
acodec="pcm_s16le",
|
||||
ac=CHANNELS,
|
||||
ar=str(SAMPLE_RATE),
|
||||
)
|
||||
.run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)
|
||||
)
|
||||
return process
|
||||
|
||||
|
||||
@app.websocket("/asr")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
@@ -65,6 +85,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
|
||||
ffmpeg_process = await start_ffmpeg_decoder()
|
||||
pcm_buffer = bytearray()
|
||||
|
||||
# Continuously read decoded PCM from ffmpeg stdout in a background task
|
||||
async def ffmpeg_stdout_reader():
|
||||
nonlocal pcm_buffer
|
||||
@@ -75,10 +96,16 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
try:
|
||||
elapsed_time = int(time() - beg)
|
||||
beg = time()
|
||||
chunk = await loop.run_in_executor(None, ffmpeg_process.stdout.read, 32000*elapsed_time)
|
||||
if not chunk: # The first chunk will be almost empty, FFmpeg is still starting up
|
||||
chunk = await loop.run_in_executor(None, ffmpeg_process.stdout.read, 4096)
|
||||
if not chunk: # FFmpeg might have closed
|
||||
chunk = await loop.run_in_executor(
|
||||
None, ffmpeg_process.stdout.read, 32000 * elapsed_time
|
||||
)
|
||||
if (
|
||||
not chunk
|
||||
): # The first chunk will be almost empty, FFmpeg is still starting up
|
||||
chunk = await loop.run_in_executor(
|
||||
None, ffmpeg_process.stdout.read, 4096
|
||||
)
|
||||
if not chunk: # FFmpeg might have closed
|
||||
print("FFmpeg stdout closed.")
|
||||
break
|
||||
|
||||
@@ -86,21 +113,29 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
|
||||
if len(pcm_buffer) >= BYTES_PER_SEC:
|
||||
# Convert int16 -> float32
|
||||
pcm_array = np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
pcm_array = (
|
||||
np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32)
|
||||
/ 32768.0
|
||||
)
|
||||
pcm_buffer = bytearray()
|
||||
online.insert_audio_chunk(pcm_array)
|
||||
transcription = online.process_iter()[2]
|
||||
full_transcription += transcription
|
||||
if args.vac:
|
||||
buffer = online.online.to_flush(online.online.transcript_buffer.buffer)[2] # We need to access the underlying online object to get the buffer
|
||||
buffer = online.online.to_flush(
|
||||
online.online.transcript_buffer.buffer
|
||||
)[
|
||||
2
|
||||
] # We need to access the underlying online object to get the buffer
|
||||
else:
|
||||
buffer = online.to_flush(online.transcript_buffer.buffer)[2]
|
||||
if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed
|
||||
if (
|
||||
buffer in full_transcription
|
||||
): # With VAC, the buffer is not updated until the next chunk is processed
|
||||
buffer = ""
|
||||
await websocket.send_json({
|
||||
"transcription": transcription,
|
||||
"buffer": buffer
|
||||
})
|
||||
await websocket.send_json(
|
||||
{"transcription": transcription, "buffer": buffer}
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Exception in ffmpeg_stdout_reader: {e}")
|
||||
break
|
||||
@@ -135,8 +170,11 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
pass
|
||||
|
||||
ffmpeg_process.wait()
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run("whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True)
|
||||
|
||||
uvicorn.run(
|
||||
"whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user