mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
diarization now works at word - not chunk - level!
This commit is contained in:
@@ -81,11 +81,10 @@ class DiartDiarization:
|
||||
def close(self):
|
||||
self.source.close()
|
||||
|
||||
def assign_speakers_to_chunks(self, chunks: list) -> list:
|
||||
end_attributed_speaker = 0
|
||||
for chunk in chunks:
|
||||
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> list:
|
||||
for token in tokens:
|
||||
for segment in self.segment_speakers:
|
||||
if not (segment["end"] <= chunk["beg"] or segment["beg"] >= chunk["end"]):
|
||||
chunk["speaker"] = extract_number(segment["speaker"]) + 1
|
||||
end_attributed_speaker = chunk["end"]
|
||||
if not (segment["end"] <= token.start or segment["beg"] >= token.end):
|
||||
token.speaker = extract_number(segment["speaker"]) + 1
|
||||
end_attributed_speaker = max(token.end, end_attributed_speaker)
|
||||
return end_attributed_speaker
|
||||
|
||||
@@ -202,7 +202,7 @@ class OnlineASRProcessor:
|
||||
logger.debug(
|
||||
f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
|
||||
)
|
||||
return self.concatenate_tokens(committed_tokens)
|
||||
return committed_tokens
|
||||
|
||||
def chunk_completed_sentence(self):
|
||||
"""
|
||||
|
||||
@@ -5,7 +5,8 @@ from typing import Optional
|
||||
class TimedText:
|
||||
start: Optional[float]
|
||||
end: Optional[float]
|
||||
text: str
|
||||
text: Optional[str] = ''
|
||||
speaker: Optional[int] = -1
|
||||
|
||||
@dataclass
|
||||
class ASRToken(TimedText):
|
||||
|
||||
@@ -11,6 +11,7 @@ from fastapi.responses import HTMLResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
|
||||
from src.whisper_streaming.timed_objects import ASRToken
|
||||
|
||||
import math
|
||||
import logging
|
||||
@@ -47,7 +48,7 @@ parser.add_argument(
|
||||
parser.add_argument(
|
||||
"--diarization",
|
||||
type=bool,
|
||||
default=False,
|
||||
default=True,
|
||||
help="Whether to enable speaker diarization.",
|
||||
)
|
||||
|
||||
@@ -157,7 +158,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
full_transcription = ""
|
||||
beg = time()
|
||||
beg_loop = time()
|
||||
chunk_history = [] # Will store dicts: {beg, end, text, speaker}
|
||||
tokens = []
|
||||
end_attributed_speaker = 0
|
||||
sep = online.asr.sep
|
||||
|
||||
while True:
|
||||
try:
|
||||
@@ -177,7 +180,6 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
logger.warning("FFmpeg read timeout. Restarting...")
|
||||
await restart_ffmpeg()
|
||||
full_transcription = ""
|
||||
chunk_history = []
|
||||
beg = time()
|
||||
continue # Skip processing and read from new process
|
||||
|
||||
@@ -202,63 +204,53 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
if args.transcription:
|
||||
logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
|
||||
online.insert_audio_chunk(pcm_array)
|
||||
transcription = online.process_iter()
|
||||
if transcription.start:
|
||||
chunk_history.append({
|
||||
"beg": transcription.start,
|
||||
"end": transcription.end,
|
||||
"text": transcription.text,
|
||||
"speaker": -1
|
||||
})
|
||||
full_transcription += transcription.text if transcription else ""
|
||||
new_tokens = online.process_iter()
|
||||
tokens.extend(new_tokens)
|
||||
full_transcription += sep.join([t.text for t in new_tokens])
|
||||
buffer = online.get_buffer()
|
||||
if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed
|
||||
buffer = ""
|
||||
else:
|
||||
chunk_history.append({
|
||||
"beg": time() - beg_loop,
|
||||
"end": time() - beg_loop + 1,
|
||||
"text": '',
|
||||
"speaker": -1
|
||||
})
|
||||
sleep(1)
|
||||
tokens.append(
|
||||
ASRToken(
|
||||
start = time() - beg_loop,
|
||||
end = time() - beg_loop + 0.5))
|
||||
sleep(0.5)
|
||||
buffer = ''
|
||||
|
||||
if args.diarization:
|
||||
await diarization.diarize(pcm_array)
|
||||
end_attributed_speaker = diarization.assign_speakers_to_chunks(chunk_history)
|
||||
|
||||
end_attributed_speaker = diarization.assign_speakers_to_tokens(end_attributed_speaker, tokens)
|
||||
|
||||
current_speaker = -10
|
||||
previous_speaker = -10
|
||||
lines = []
|
||||
last_end_diarized = 0
|
||||
previous_speaker = -1
|
||||
for ind, ch in enumerate(chunk_history):
|
||||
speaker = ch.get("speaker")
|
||||
for token in tokens:
|
||||
speaker = token.speaker
|
||||
if args.diarization:
|
||||
if speaker == -1 or speaker == 0:
|
||||
if ch['end'] < end_attributed_speaker:
|
||||
if token.end < end_attributed_speaker:
|
||||
speaker = previous_speaker
|
||||
else:
|
||||
speaker = 0
|
||||
else:
|
||||
last_end_diarized = max(ch['end'], last_end_diarized)
|
||||
last_end_diarized = max(token.end, last_end_diarized)
|
||||
|
||||
if speaker != current_speaker:
|
||||
if speaker != previous_speaker:
|
||||
lines.append(
|
||||
{
|
||||
"speaker": speaker,
|
||||
"text": ch['text'],
|
||||
"beg": format_time(ch['beg']),
|
||||
"end": format_time(ch['end']),
|
||||
"diff": round(ch['end'] - last_end_diarized, 2)
|
||||
"text": token.text,
|
||||
"beg": format_time(token.start),
|
||||
"end": format_time(token.end),
|
||||
"diff": round(token.end - last_end_diarized, 2)
|
||||
}
|
||||
)
|
||||
current_speaker = speaker
|
||||
previous_speaker = speaker
|
||||
else:
|
||||
lines[-1]["text"] += ch['text']
|
||||
lines[-1]["end"] = format_time(ch['end'])
|
||||
lines[-1]["diff"] = round(ch['end'] - last_end_diarized, 2)
|
||||
lines[-1]["text"] += sep + token.text
|
||||
lines[-1]["end"] = format_time(token.end)
|
||||
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
|
||||
|
||||
response = {"lines": lines, "buffer": buffer}
|
||||
await websocket.send_json(response)
|
||||
|
||||
Reference in New Issue
Block a user