Merge pull request #53 from QuentinFuxa/diart_integration_improvements

DiartDiarization loaded in asynccontextmanager lifespan
This commit is contained in:
Quentin Fuxa
2025-02-19 16:49:21 +01:00
committed by GitHub

View File

@@ -68,19 +68,23 @@ BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
MAX_BYTES_PER_SEC = 32000 * 5 # 5 seconds of audio at 32 kHz
if args.diarization:
from src.diarization.diarization_online import DiartDiarization
##### LOAD APP #####
@asynccontextmanager
async def lifespan(app: FastAPI):
global asr, tokenizer
global asr, tokenizer, diarization
if args.transcription:
asr, tokenizer = backend_factory(args)
else:
asr, tokenizer = None, None
if args.diarization:
from src.diarization.diarization_online import DiartDiarization
diarization = DiartDiarization(SAMPLE_RATE)
else :
diarization = None
yield
app = FastAPI(lifespan=lifespan)
@@ -130,10 +134,10 @@ async def websocket_endpoint(websocket: WebSocket):
ffmpeg_process = None
pcm_buffer = bytearray()
online = online_factory(args, asr, tokenizer) if args.transcription else None
diarization = DiartDiarization(SAMPLE_RATE) if args.diarization else None
async def restart_ffmpeg():
nonlocal ffmpeg_process, online, diarization, pcm_buffer
nonlocal ffmpeg_process, online, pcm_buffer
if ffmpeg_process:
try:
ffmpeg_process.kill()
@@ -143,14 +147,12 @@ async def websocket_endpoint(websocket: WebSocket):
ffmpeg_process = await start_ffmpeg_decoder()
pcm_buffer = bytearray()
online = online_factory(args, asr, tokenizer) if args.transcription else None
if args.diarization:
diarization = DiartDiarization(SAMPLE_RATE)
logger.info("FFmpeg process started.")
await restart_ffmpeg()
async def ffmpeg_stdout_reader():
nonlocal ffmpeg_process, online, diarization, pcm_buffer
nonlocal ffmpeg_process, online, pcm_buffer
loop = asyncio.get_event_loop()
full_transcription = ""
beg = time()