mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
diart diarization handles pauses/silences thanks to offset
This commit is contained in:
@@ -372,7 +372,7 @@ class AudioProcessor:
|
||||
|
||||
if type(item) is Silence:
|
||||
cumulative_pcm_duration_stream_time += item.duration
|
||||
# self.diarization_obj.insert_silence(item.duration, self.tokens[-1].end)
|
||||
diarization_obj.insert_silence(item.duration)
|
||||
continue
|
||||
|
||||
if isinstance(item, np.ndarray):
|
||||
|
||||
@@ -57,7 +57,7 @@ class TranscriptionEngine:
|
||||
"static_init_prompt": None,
|
||||
"max_context_tokens": None,
|
||||
"model_path": './base.pt',
|
||||
"diarization_backend": "sortformer",
|
||||
"diarization_backend": "diart",
|
||||
# diart params:
|
||||
"segmentation_model": "pyannote/segmentation-3.0",
|
||||
"embedding_model": "pyannote/embedding",
|
||||
@@ -127,10 +127,7 @@ class TranscriptionEngine:
|
||||
embedding_model_name=self.args.embedding_model
|
||||
)
|
||||
elif self.args.diarization_backend == "sortformer":
|
||||
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
|
||||
self.diarization = SortformerDiarization(
|
||||
model_name="nvidia/diar_streaming_sortformer_4spk-v2"
|
||||
)
|
||||
raise ValueError('Sortformer backend in developement')
|
||||
else:
|
||||
raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}")
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ class DiarizationObserver(Observer):
|
||||
self.speaker_segments = []
|
||||
self.processed_time = 0
|
||||
self.segment_lock = threading.Lock()
|
||||
self.global_time_offset = 0.0
|
||||
|
||||
def on_next(self, value: Tuple[Annotation, Any]):
|
||||
annotation, audio = value
|
||||
@@ -49,8 +50,8 @@ class DiarizationObserver(Observer):
|
||||
print(f" {speaker}: {start:.2f}s-{end:.2f}s")
|
||||
self.speaker_segments.append(SpeakerSegment(
|
||||
speaker=speaker,
|
||||
start=start,
|
||||
end=end
|
||||
start=start + self.global_time_offset,
|
||||
end=end + self.global_time_offset
|
||||
))
|
||||
else:
|
||||
logger.debug("\nNo speakers detected in this segment")
|
||||
@@ -199,6 +200,9 @@ class DiartDiarization:
|
||||
self.inference.attach_observers(self.observer)
|
||||
asyncio.get_event_loop().run_in_executor(None, self.inference)
|
||||
|
||||
def insert_silence(self, silence_duration):
    """Account for a stretch of silence that is not sent to the diarizer.

    Adds ``silence_duration`` to the observer's running
    ``global_time_offset``. Nothing else is touched and nothing is
    returned.

    Args:
        silence_duration: Length of the skipped silence, in the same unit
            as the observer's segment timestamps (presumably seconds --
            confirm against the caller).
    """
    # NOTE(review): the observer appears to add this offset to the
    # start/end of the segments it records -- confirm in on_next.
    observer = self.observer
    observer.global_time_offset += silence_duration
|
||||
|
||||
async def diarize(self, pcm_array: np.ndarray):
|
||||
"""
|
||||
Process audio data for diarization.
|
||||
|
||||
@@ -61,7 +61,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--diarization-backend",
|
||||
type=str,
|
||||
default="sortformer",
|
||||
default="diart",
|
||||
choices=["sortformer", "diart"],
|
||||
help="The diarization backend to use.",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user