From 6121083549c129aba048f61e6ac1c11a558c3102 Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Wed, 19 Feb 2025 11:21:40 +0100 Subject: [PATCH] add parameter to disable transcription (only diarization), add time in output --- whisper_fastapi_online_server.py | 84 ++++++++++++++++++++------------ 1 file changed, 54 insertions(+), 30 deletions(-) diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py index 89b010c..e4a1571 100644 --- a/whisper_fastapi_online_server.py +++ b/whisper_fastapi_online_server.py @@ -3,7 +3,7 @@ import argparse import asyncio import numpy as np import ffmpeg -from time import time +from time import time, sleep from contextlib import asynccontextmanager from fastapi import FastAPI, WebSocket, WebSocketDisconnect @@ -12,9 +12,12 @@ from fastapi.middleware.cors import CORSMiddleware from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args -import subprocess import math import logging +from datetime import timedelta + +def format_time(seconds): + return str(timedelta(seconds=int(seconds))) logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") @@ -48,6 +51,12 @@ parser.add_argument( help="Whether to enable speaker diarization.", ) +parser.add_argument( + "--transcription", + type=bool, + default=True, + help="To disable to only see live diarization results.", +) add_shared_args(parser) args = parser.parse_args() @@ -68,7 +77,10 @@ if args.diarization: @asynccontextmanager async def lifespan(app: FastAPI): global asr, tokenizer - asr, tokenizer = backend_factory(args) + if args.transcription: + asr, tokenizer = backend_factory(args) + else: + asr, tokenizer = None, None yield app = FastAPI(lifespan=lifespan) @@ -117,7 +129,7 @@ async def websocket_endpoint(websocket: WebSocket): ffmpeg_process = None pcm_buffer = bytearray() - online = online_factory(args, asr, tokenizer) + online = online_factory(args, asr, tokenizer) if args.transcription else None diarization = DiartDiarization(SAMPLE_RATE) if args.diarization else None async def restart_ffmpeg(): @@ -130,7 +142,7 @@ async def websocket_endpoint(websocket: WebSocket): logger.warning(f"Error killing FFmpeg process: {e}") ffmpeg_process = await start_ffmpeg_decoder() pcm_buffer = bytearray() - online = online_factory(args, asr, tokenizer) + online = online_factory(args, asr, tokenizer) if args.transcription else None if args.diarization: diarization = DiartDiarization(SAMPLE_RATE) logger.info("FFmpeg process started.") @@ -142,7 +154,7 @@ async def websocket_endpoint(websocket: WebSocket): loop = asyncio.get_event_loop() full_transcription = "" beg = time() - + beg_loop = time() chunk_history = [] # Will store dicts: {beg, end, text, speaker} while True: @@ -184,45 +196,57 @@ async def websocket_endpoint(websocket: WebSocket): / 32768.0 ) pcm_buffer = pcm_buffer[MAX_BYTES_PER_SEC:] - logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.") - online.insert_audio_chunk(pcm_array) - transcription = online.process_iter() - if transcription: + if args.transcription: + logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.") + online.insert_audio_chunk(pcm_array) + transcription = online.process_iter() + if transcription.start: + chunk_history.append({ + "beg": transcription.start, + "end": transcription.end, + "text": transcription.text, + }) + full_transcription += transcription.text if transcription else "" + buffer = online.get_buffer() + if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed + buffer = "" + else: chunk_history.append({ - "beg": transcription.start, - "end": transcription.end, - "text": transcription.text, - "speaker": "0" + "beg": time() - beg_loop, + "end": time() - beg_loop + 0.1, + "text": '', }) + sleep(0.1) + buffer = '' - full_transcription += transcription.text if transcription else "" - buffer = online.get_buffer() - - if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed - buffer = "" - - lines = [ - { - "speaker": "0", - "text": "", - } - ] - if args.diarization: await diarization.diarize(pcm_array) diarization.assign_speakers_to_chunks(chunk_history) + + current_speaker = -1 + lines = [{ + "beg": 0, + "end": 0, + "speaker": current_speaker, + "text": "" + }] for ch in chunk_history: - if args.diarization and ch["speaker"] and ch["speaker"][-1] != lines[-1]["speaker"]: + if args.diarization and ch["speaker"] and ch["speaker"] != current_speaker: + new_speaker = ch["speaker"] lines.append( { - "speaker": ch["speaker"][-1], - "text": ch['text'] + "speaker": new_speaker, + "text": ch['text'], + "beg": format_time(ch['beg']), + "end": format_time(ch['end']), } ) + current_speaker = new_speaker else: lines[-1]["text"] += ch['text'] + lines[-1]["end"] = format_time(ch['end']) response = {"lines": lines, "buffer": buffer} await websocket.send_json(response)