black formatting

This commit is contained in:
silask
2024-12-30 21:20:38 +01:00
parent 4cb3660666
commit 5fdb08edae
3 changed files with 492 additions and 244 deletions

View File

@@ -6,15 +6,16 @@ import torch
# Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
class VADIterator:
def __init__(self,
model,
threshold: float = 0.5,
sampling_rate: int = 16000,
min_silence_duration_ms: int = 500, # makes sense on one recording that I checked
speech_pad_ms: int = 100 # same
):
class VADIterator:
def __init__(
self,
model,
threshold: float = 0.5,
sampling_rate: int = 16000,
min_silence_duration_ms: int = 500, # makes sense on one recording that I checked
speech_pad_ms: int = 100, # same
):
"""
Class for stream imitation
@@ -41,7 +42,9 @@ class VADIterator:
self.sampling_rate = sampling_rate
if sampling_rate not in [8000, 16000]:
raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
raise ValueError(
"VADIterator does not support sampling rates other than [8000, 16000]"
)
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
@@ -80,7 +83,13 @@ class VADIterator:
if (speech_prob >= self.threshold) and not self.triggered:
self.triggered = True
speech_start = self.current_sample - self.speech_pad_samples
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
return {
"start": (
int(speech_start)
if not return_seconds
else round(speech_start / self.sampling_rate, 1)
)
}
if (speech_prob < self.threshold - 0.15) and self.triggered:
if not self.temp_end:
@@ -91,26 +100,35 @@ class VADIterator:
speech_end = self.temp_end + self.speech_pad_samples
self.temp_end = 0
self.triggered = False
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
return {
"end": (
int(speech_end)
if not return_seconds
else round(speech_end / self.sampling_rate, 1)
)
}
return None
#######################
# because Silero now requires exactly 512-sized audio chunks
# because Silero now requires exactly 512-sized audio chunks
import numpy as np
class FixedVADIterator(VADIterator):
'''It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once.
If audio to be processed at once is long and multiple voiced segments detected,
then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment.
'''
"""It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once.
If audio to be processed at once is long and multiple voiced segments detected,
then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment.
"""
def reset_states(self):
super().reset_states()
self.buffer = np.array([],dtype=np.float32)
self.buffer = np.array([], dtype=np.float32)
def __call__(self, x, return_seconds=False):
self.buffer = np.append(self.buffer, x)
self.buffer = np.append(self.buffer, x)
ret = None
while len(self.buffer) >= 512:
r = super().__call__(self.buffer[:512], return_seconds=return_seconds)
@@ -118,29 +136,28 @@ class FixedVADIterator(VADIterator):
if ret is None:
ret = r
elif r is not None:
if 'end' in r:
ret['end'] = r['end'] # the latter end
if 'start' in r and 'end' in ret: # there is an earlier start.
if "end" in r:
ret["end"] = r["end"] # the latter end
if "start" in r and "end" in ret: # there is an earlier start.
# Remove end, merging this segment with the previous one.
del ret['end']
del ret["end"]
return ret if ret != {} else None
if __name__ == "__main__":
# test/demonstrate the need for FixedVADIterator:
import torch
model, _ = torch.hub.load(
repo_or_dir='snakers4/silero-vad',
model='silero_vad'
)
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
vac = FixedVADIterator(model)
# vac = VADIterator(model) # the second case crashes with this
# vac = VADIterator(model) # the second case crashes with this
# this works: for both
audio_buffer = np.array([0]*(512),dtype=np.float32)
audio_buffer = np.array([0] * (512), dtype=np.float32)
vac(audio_buffer)
# this crashes on the non FixedVADIterator with
# this crashes on the non FixedVADIterator with
# ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
audio_buffer = np.array([0]*(512-1),dtype=np.float32)
audio_buffer = np.array([0] * (512 - 1), dtype=np.float32)
vac(audio_buffer)

View File

@@ -22,10 +22,21 @@ app.add_middleware(
parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
parser.add_argument("--host", type=str, default='localhost', help="The host address to bind the server to.")
parser.add_argument("--port", type=int, default=8000, help="The port number to bind the server to.")
parser.add_argument("--warmup-file", type=str, dest="warmup_file",
help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .")
parser.add_argument(
"--host",
type=str,
default="localhost",
help="The host address to bind the server to.",
)
parser.add_argument(
"--port", type=int, default=8000, help="The port number to bind the server to."
)
parser.add_argument(
"--warmup-file",
type=str,
dest="warmup_file",
help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
)
add_shared_args(parser)
args = parser.parse_args()
@@ -35,29 +46,38 @@ asr, online = asr_factory(args)
with open("src/live_transcription.html", "r") as f:
html = f.read()
@app.get("/")
async def get():
return HTMLResponse(html)
SAMPLE_RATE = 16000
CHANNELS = 1
SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size)
BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
async def start_ffmpeg_decoder():
"""
Start an FFmpeg process in async streaming mode that reads WebM from stdin
and outputs raw s16le PCM on stdout. Returns the process object.
"""
process = (
ffmpeg
.input('pipe:0', format='webm')
.output('pipe:1', format='s16le', acodec='pcm_s16le', ac=CHANNELS, ar=str(SAMPLE_RATE))
ffmpeg.input("pipe:0", format="webm")
.output(
"pipe:1",
format="s16le",
acodec="pcm_s16le",
ac=CHANNELS,
ar=str(SAMPLE_RATE),
)
.run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)
)
return process
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
@@ -65,6 +85,7 @@ async def websocket_endpoint(websocket: WebSocket):
ffmpeg_process = await start_ffmpeg_decoder()
pcm_buffer = bytearray()
# Continuously read decoded PCM from ffmpeg stdout in a background task
async def ffmpeg_stdout_reader():
nonlocal pcm_buffer
@@ -75,10 +96,16 @@ async def websocket_endpoint(websocket: WebSocket):
try:
elapsed_time = int(time() - beg)
beg = time()
chunk = await loop.run_in_executor(None, ffmpeg_process.stdout.read, 32000*elapsed_time)
if not chunk: # The first chunk will be almost empty, FFmpeg is still starting up
chunk = await loop.run_in_executor(None, ffmpeg_process.stdout.read, 4096)
if not chunk: # FFmpeg might have closed
chunk = await loop.run_in_executor(
None, ffmpeg_process.stdout.read, 32000 * elapsed_time
)
if (
not chunk
): # The first chunk will be almost empty, FFmpeg is still starting up
chunk = await loop.run_in_executor(
None, ffmpeg_process.stdout.read, 4096
)
if not chunk: # FFmpeg might have closed
print("FFmpeg stdout closed.")
break
@@ -86,21 +113,29 @@ async def websocket_endpoint(websocket: WebSocket):
if len(pcm_buffer) >= BYTES_PER_SEC:
# Convert int16 -> float32
pcm_array = np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
pcm_array = (
np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32)
/ 32768.0
)
pcm_buffer = bytearray()
online.insert_audio_chunk(pcm_array)
transcription = online.process_iter()[2]
full_transcription += transcription
if args.vac:
buffer = online.online.to_flush(online.online.transcript_buffer.buffer)[2] # We need to access the underlying online object to get the buffer
buffer = online.online.to_flush(
online.online.transcript_buffer.buffer
)[
2
] # We need to access the underlying online object to get the buffer
else:
buffer = online.to_flush(online.transcript_buffer.buffer)[2]
if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed
if (
buffer in full_transcription
): # With VAC, the buffer is not updated until the next chunk is processed
buffer = ""
await websocket.send_json({
"transcription": transcription,
"buffer": buffer
})
await websocket.send_json(
{"transcription": transcription, "buffer": buffer}
)
except Exception as e:
print(f"Exception in ffmpeg_stdout_reader: {e}")
break
@@ -135,8 +170,11 @@ async def websocket_endpoint(websocket: WebSocket):
pass
ffmpeg_process.wait()
if __name__ == "__main__":
import uvicorn
uvicorn.run("whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True)
uvicorn.run(
"whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True
)

File diff suppressed because it is too large Load Diff