mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-04-26 16:45:46 +00:00
DiartDiarization now emits SpeakerSegment objects (speaker/start/end attributes) instead of dicts with "speaker"/"beg"/"end" keys
This commit is contained in:
@@ -6,7 +6,7 @@ import numpy as np
|
||||
from diart import SpeakerDiarization
|
||||
from diart.inference import StreamingInference
|
||||
from diart.sources import AudioSource
|
||||
|
||||
from src.whisper_streaming.timed_objects import SpeakerSegment
|
||||
|
||||
def extract_number(s: str) -> int:
|
||||
m = re.search(r'\d+', s)
|
||||
@@ -58,15 +58,15 @@ class DiartDiarization:
|
||||
annotation, audio = result
|
||||
if annotation._labels:
|
||||
for speaker, label in annotation._labels.items():
|
||||
beg = label.segments_boundaries_[0]
|
||||
start = label.segments_boundaries_[0]
|
||||
end = label.segments_boundaries_[-1]
|
||||
if end > self.processed_time:
|
||||
self.processed_time = end
|
||||
asyncio.create_task(self.speakers_queue.put({
|
||||
"speaker": speaker,
|
||||
"beg": beg,
|
||||
"end": end
|
||||
}))
|
||||
asyncio.create_task(self.speakers_queue.put(SpeakerSegment(
|
||||
speaker=speaker,
|
||||
start=start,
|
||||
end=end,
|
||||
)))
|
||||
else:
|
||||
dur = audio.extent.end
|
||||
if dur > self.processed_time:
|
||||
@@ -84,7 +84,7 @@ class DiartDiarization:
|
||||
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> list:
|
||||
for token in tokens:
|
||||
for segment in self.segment_speakers:
|
||||
if not (segment["end"] <= token.start or segment["beg"] >= token.end):
|
||||
token.speaker = extract_number(segment["speaker"]) + 1
|
||||
if not (segment.end <= token.start or segment.start >= token.end):
|
||||
token.speaker = extract_number(segment.speaker) + 1
|
||||
end_attributed_speaker = max(token.end, end_attributed_speaker)
|
||||
return end_attributed_speaker
|
||||
|
||||
Reference in New Issue
Block a user