Add a parameter to disable transcription (diarization only); add timestamps to the output

This commit is contained in:
Quentin Fuxa
2025-02-19 11:21:40 +01:00
parent 0ecac75455
commit 6121083549

View File

@@ -3,7 +3,7 @@ import argparse
import asyncio
import numpy as np
import ffmpeg
from time import time
from time import time, sleep
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
@@ -12,9 +12,12 @@ from fastapi.middleware.cors import CORSMiddleware
from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
import subprocess
import math
import logging
from datetime import timedelta
def format_time(seconds):
    """Render an elapsed time given in seconds as an ``H:MM:SS`` string.

    Args:
        seconds: Elapsed time in seconds (int or float). Fractional
            seconds are truncated by ``int()``.

    Returns:
        The ``str()`` form of a ``timedelta`` of that many whole seconds,
        e.g. ``"0:01:01"``, or ``"1 day, 1:00:00"`` past 24 hours.
    """
    # Clamp negatives: str(timedelta(seconds=-5)) is "-1 day, 23:59:55",
    # which is meaningless as a displayed timestamp.
    whole_seconds = max(int(seconds), 0)
    return str(timedelta(seconds=whole_seconds))
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -48,6 +51,12 @@ parser.add_argument(
help="Whether to enable speaker diarization.",
)
def _parse_bool_flag(value):
    """Convert a command-line string into a real boolean.

    ``type=bool`` cannot be used for this: argparse would call
    ``bool("False")``, which is ``True`` because the string is non-empty,
    so the option could never be switched off with a readable value.

    Args:
        value: The raw option value (a string from the command line).

    Returns:
        ``True`` or ``False`` according to the spelled-out value.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised
            boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"true", "t", "yes", "y", "1", "on"}:
        return True
    # The empty string is kept as False: with the former ``type=bool`` it
    # was the only value that disabled the option.
    if normalized in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Transcription is on by default; pass e.g. ``--transcription False`` to
# turn it off.
parser.add_argument(
    "--transcription",
    type=_parse_bool_flag,
    default=True,
    help="To disable to only see live diarization results.",
)
add_shared_args(parser)
args = parser.parse_args()
@@ -68,7 +77,10 @@ if args.diarization:
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialise the module-level ASR backend before the app starts.

    Sets the globals ``asr`` and ``tokenizer``. When ``args.transcription``
    is truthy they come from ``backend_factory(args)``; otherwise both are
    set to ``None`` and the backend is never built.

    Args:
        app: The FastAPI application this lifespan is attached to (unused
            in the body).

    Yields:
        Nothing; control returns to the caller once the globals are set.
    """
    global asr, tokenizer
    # Build the backend only inside this branch. An unconditional
    # backend_factory(args) call ahead of the check would build it even
    # with transcription disabled, and twice when enabled.
    if args.transcription:
        asr, tokenizer = backend_factory(args)
    else:
        asr, tokenizer = None, None
    yield
app = FastAPI(lifespan=lifespan)
@@ -117,7 +129,7 @@ async def websocket_endpoint(websocket: WebSocket):
ffmpeg_process = None
pcm_buffer = bytearray()
online = online_factory(args, asr, tokenizer)
online = online_factory(args, asr, tokenizer) if args.transcription else None
diarization = DiartDiarization(SAMPLE_RATE) if args.diarization else None
async def restart_ffmpeg():
@@ -130,7 +142,7 @@ async def websocket_endpoint(websocket: WebSocket):
logger.warning(f"Error killing FFmpeg process: {e}")
ffmpeg_process = await start_ffmpeg_decoder()
pcm_buffer = bytearray()
online = online_factory(args, asr, tokenizer)
online = online_factory(args, asr, tokenizer) if args.transcription else None
if args.diarization:
diarization = DiartDiarization(SAMPLE_RATE)
logger.info("FFmpeg process started.")
@@ -142,7 +154,7 @@ async def websocket_endpoint(websocket: WebSocket):
loop = asyncio.get_event_loop()
full_transcription = ""
beg = time()
beg_loop = time()
chunk_history = [] # Will store dicts: {beg, end, text, speaker}
while True:
@@ -184,45 +196,57 @@ async def websocket_endpoint(websocket: WebSocket):
/ 32768.0
)
pcm_buffer = pcm_buffer[MAX_BYTES_PER_SEC:]
logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
online.insert_audio_chunk(pcm_array)
transcription = online.process_iter()
if transcription:
if args.transcription:
logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
online.insert_audio_chunk(pcm_array)
transcription = online.process_iter()
if transcription.start:
chunk_history.append({
"beg": transcription.start,
"end": transcription.end,
"text": transcription.text,
})
full_transcription += transcription.text if transcription else ""
buffer = online.get_buffer()
if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed
buffer = ""
else:
chunk_history.append({
"beg": transcription.start,
"end": transcription.end,
"text": transcription.text,
"speaker": "0"
"beg": time() - beg_loop,
"end": time() - beg_loop + 0.1,
"text": '',
})
sleep(0.1)
buffer = ''
full_transcription += transcription.text if transcription else ""
buffer = online.get_buffer()
if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed
buffer = ""
lines = [
{
"speaker": "0",
"text": "",
}
]
if args.diarization:
await diarization.diarize(pcm_array)
diarization.assign_speakers_to_chunks(chunk_history)
current_speaker = -1
lines = [{
"beg": 0,
"end": 0,
"speaker": current_speaker,
"text": ""
}]
for ch in chunk_history:
if args.diarization and ch["speaker"] and ch["speaker"][-1] != lines[-1]["speaker"]:
if args.diarization and ch["speaker"] and ch["speaker"] != current_speaker:
new_speaker = ch["speaker"]
lines.append(
{
"speaker": ch["speaker"][-1],
"text": ch['text']
"speaker": new_speaker,
"text": ch['text'],
"beg": format_time(ch['beg']),
"end": format_time(ch['end']),
}
)
current_speaker = new_speaker
else:
lines[-1]["text"] += ch['text']
lines[-1]["end"] = format_time(ch['end'])
response = {"lines": lines, "buffer": buffer}
await websocket.send_json(response)