diart diarization handles pauses/silences thanks to offset

This commit is contained in:
Quentin Fuxa
2025-08-19 21:12:55 +02:00
parent 0c6e4b2aee
commit 253a080df5
4 changed files with 10 additions and 9 deletions

View File

@@ -372,7 +372,7 @@ class AudioProcessor:
if type(item) is Silence:
cumulative_pcm_duration_stream_time += item.duration
# self.diarization_obj.insert_silence(item.duration, self.tokens[-1].end)
diarization_obj.insert_silence(item.duration)
continue
if isinstance(item, np.ndarray):

View File

@@ -57,7 +57,7 @@ class TranscriptionEngine:
"static_init_prompt": None,
"max_context_tokens": None,
"model_path": './base.pt',
"diarization_backend": "sortformer",
"diarization_backend": "diart",
# diart params:
"segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding",
@@ -127,10 +127,7 @@ class TranscriptionEngine:
embedding_model_name=self.args.embedding_model
)
elif self.args.diarization_backend == "sortformer":
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
self.diarization = SortformerDiarization(
model_name="nvidia/diar_streaming_sortformer_4spk-v2"
)
raise ValueError('Sortformer backend in developement')
else:
raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}")

View File

@@ -29,6 +29,7 @@ class DiarizationObserver(Observer):
self.speaker_segments = []
self.processed_time = 0
self.segment_lock = threading.Lock()
self.global_time_offset = 0.0
def on_next(self, value: Tuple[Annotation, Any]):
annotation, audio = value
@@ -49,8 +50,8 @@ class DiarizationObserver(Observer):
print(f" {speaker}: {start:.2f}s-{end:.2f}s")
self.speaker_segments.append(SpeakerSegment(
speaker=speaker,
start=start,
end=end
start=start + self.global_time_offset,
end=end + self.global_time_offset
))
else:
logger.debug("\nNo speakers detected in this segment")
@@ -199,6 +200,9 @@ class DiartDiarization:
self.inference.attach_observers(self.observer)
asyncio.get_event_loop().run_in_executor(None, self.inference)
def insert_silence(self, silence_duration):
self.observer.global_time_offset += silence_duration
async def diarize(self, pcm_array: np.ndarray):
"""
Process audio data for diarization.

View File

@@ -61,7 +61,7 @@ def parse_args():
parser.add_argument(
"--diarization-backend",
type=str,
default="sortformer",
default="diart",
choices=["sortformer", "diart"],
help="The diarization backend to use.",
)