DiartDiarization now uses SpeakerSegment

This commit is contained in:
Quentin Fuxa
2025-02-28 15:44:09 +01:00
parent 7b1c88589e
commit 56717b094f

View File

@@ -6,7 +6,7 @@ import numpy as np
from diart import SpeakerDiarization
from diart.inference import StreamingInference
from diart.sources import AudioSource
from src.whisper_streaming.timed_objects import SpeakerSegment
def extract_number(s: str) -> int:
m = re.search(r'\d+', s)
@@ -58,15 +58,15 @@ class DiartDiarization:
annotation, audio = result
if annotation._labels:
for speaker, label in annotation._labels.items():
beg = label.segments_boundaries_[0]
start = label.segments_boundaries_[0]
end = label.segments_boundaries_[-1]
if end > self.processed_time:
self.processed_time = end
asyncio.create_task(self.speakers_queue.put({
"speaker": speaker,
"beg": beg,
"end": end
}))
asyncio.create_task(self.speakers_queue.put(SpeakerSegment(
speaker=speaker,
start=start,
end=end,
)))
else:
dur = audio.extent.end
if dur > self.processed_time:
@@ -84,7 +84,7 @@ class DiartDiarization:
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> list:
for token in tokens:
for segment in self.segment_speakers:
if not (segment["end"] <= token.start or segment["beg"] >= token.end):
token.speaker = extract_number(segment["speaker"]) + 1
if not (segment.end <= token.start or segment.start >= token.end):
token.speaker = extract_number(segment.speaker) + 1
end_attributed_speaker = max(token.end, end_attributed_speaker)
return end_attributed_speaker