sortformer diar implementation v0

This commit is contained in:
Quentin Fuxa
2025-08-19 17:02:55 +02:00
parent 7496163467
commit e14bbde77d
6 changed files with 159 additions and 108 deletions

View File

@@ -57,10 +57,10 @@ class TranscriptionEngine:
"static_init_prompt": None,
"max_context_tokens": None,
"model_path": './base.pt',
"diarization_backend": "sortformer",
# diart params:
"segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding",
}
config_dict = {**defaults, **kwargs}
@@ -119,12 +119,20 @@ class TranscriptionEngine:
warmup_asr(self.asr, self.args.warmup_file) #for simulstreaming, warmup should be done in the online class not here
if self.args.diarization:
from whisperlivekit.diarization.diarization_online import DiartDiarization
self.diarization = DiartDiarization(
block_duration=self.args.min_chunk_size,
segmentation_model_name=self.args.segmentation_model,
embedding_model_name=self.args.embedding_model
)
if self.args.diarization_backend == "diart":
from whisperlivekit.diarization.diart_backend import DiartDiarization
self.diarization = DiartDiarization(
block_duration=self.args.min_chunk_size,
segmentation_model_name=self.args.segmentation_model,
embedding_model_name=self.args.embedding_model
)
elif self.args.diarization_backend == "sortformer":
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
self.diarization = SortformerDiarization(
model_name="nvidia/diar_streaming_sortformer_4spk-v2"
)
else:
raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}")
TranscriptionEngine._initialized = True