diarization now works at word - not chunk - level!

This commit is contained in:
Quentin Fuxa
2025-02-24 00:35:42 +01:00
parent d89622b9c2
commit 2ced4fef20
4 changed files with 36 additions and 44 deletions

View File

@@ -81,11 +81,10 @@ class DiartDiarization:
def close(self):
self.source.close()
def assign_speakers_to_chunks(self, chunks: list) -> list:
end_attributed_speaker = 0
for chunk in chunks:
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> list:
for token in tokens:
for segment in self.segment_speakers:
if not (segment["end"] <= chunk["beg"] or segment["beg"] >= chunk["end"]):
chunk["speaker"] = extract_number(segment["speaker"]) + 1
end_attributed_speaker = chunk["end"]
if not (segment["end"] <= token.start or segment["beg"] >= token.end):
token.speaker = extract_number(segment["speaker"]) + 1
end_attributed_speaker = max(token.end, end_attributed_speaker)
return end_attributed_speaker

View File

@@ -202,7 +202,7 @@ class OnlineASRProcessor:
logger.debug(
f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
)
return self.concatenate_tokens(committed_tokens)
return committed_tokens
def chunk_completed_sentence(self):
"""

View File

@@ -5,7 +5,8 @@ from typing import Optional
class TimedText:
start: Optional[float]
end: Optional[float]
text: str
text: Optional[str] = ''
speaker: Optional[int] = -1
@dataclass
class ASRToken(TimedText):

View File

@@ -11,6 +11,7 @@ from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
from src.whisper_streaming.timed_objects import ASRToken
import math
import logging
@@ -47,7 +48,7 @@ parser.add_argument(
parser.add_argument(
"--diarization",
type=bool,
default=False,
default=True,
help="Whether to enable speaker diarization.",
)
@@ -157,7 +158,9 @@ async def websocket_endpoint(websocket: WebSocket):
full_transcription = ""
beg = time()
beg_loop = time()
chunk_history = [] # Will store dicts: {beg, end, text, speaker}
tokens = []
end_attributed_speaker = 0
sep = online.asr.sep
while True:
try:
@@ -177,7 +180,6 @@ async def websocket_endpoint(websocket: WebSocket):
logger.warning("FFmpeg read timeout. Restarting...")
await restart_ffmpeg()
full_transcription = ""
chunk_history = []
beg = time()
continue # Skip processing and read from new process
@@ -202,63 +204,53 @@ async def websocket_endpoint(websocket: WebSocket):
if args.transcription:
logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
online.insert_audio_chunk(pcm_array)
transcription = online.process_iter()
if transcription.start:
chunk_history.append({
"beg": transcription.start,
"end": transcription.end,
"text": transcription.text,
"speaker": -1
})
full_transcription += transcription.text if transcription else ""
new_tokens = online.process_iter()
tokens.extend(new_tokens)
full_transcription += sep.join([t.text for t in new_tokens])
buffer = online.get_buffer()
if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed
buffer = ""
else:
chunk_history.append({
"beg": time() - beg_loop,
"end": time() - beg_loop + 1,
"text": '',
"speaker": -1
})
sleep(1)
tokens.append(
ASRToken(
start = time() - beg_loop,
end = time() - beg_loop + 0.5))
sleep(0.5)
buffer = ''
if args.diarization:
await diarization.diarize(pcm_array)
end_attributed_speaker = diarization.assign_speakers_to_chunks(chunk_history)
end_attributed_speaker = diarization.assign_speakers_to_tokens(end_attributed_speaker, tokens)
current_speaker = -10
previous_speaker = -10
lines = []
last_end_diarized = 0
previous_speaker = -1
for ind, ch in enumerate(chunk_history):
speaker = ch.get("speaker")
for token in tokens:
speaker = token.speaker
if args.diarization:
if speaker == -1 or speaker == 0:
if ch['end'] < end_attributed_speaker:
if token.end < end_attributed_speaker:
speaker = previous_speaker
else:
speaker = 0
else:
last_end_diarized = max(ch['end'], last_end_diarized)
last_end_diarized = max(token.end, last_end_diarized)
if speaker != current_speaker:
if speaker != previous_speaker:
lines.append(
{
"speaker": speaker,
"text": ch['text'],
"beg": format_time(ch['beg']),
"end": format_time(ch['end']),
"diff": round(ch['end'] - last_end_diarized, 2)
"text": token.text,
"beg": format_time(token.start),
"end": format_time(token.end),
"diff": round(token.end - last_end_diarized, 2)
}
)
current_speaker = speaker
previous_speaker = speaker
else:
lines[-1]["text"] += ch['text']
lines[-1]["end"] = format_time(ch['end'])
lines[-1]["diff"] = round(ch['end'] - last_end_diarized, 2)
lines[-1]["text"] += sep + token.text
lines[-1]["end"] = format_time(token.end)
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
response = {"lines": lines, "buffer": buffer}
await websocket.send_json(response)