From f4f9831d399e5669267e1ba50cf848fa05fb5890 Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Thu, 20 Nov 2025 23:52:00 +0100 Subject: [PATCH] stt/diar/nllw alignment: internal rework 5 --- whisperlivekit/audio_processor.py | 156 +++++++++++++---------------- whisperlivekit/timed_objects.py | 105 ++++++------------- whisperlivekit/tokens_alignment.py | 51 +++++----- 3 files changed, 126 insertions(+), 186 deletions(-) diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index e6d092c..6fb9b19 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -1,10 +1,10 @@ import asyncio import numpy as np -from time import time, sleep -import math +from time import time import logging import traceback -from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, StateLight, Transcript, ChangeSpeaker +from typing import Optional, Union, List, Any, AsyncGenerator +from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript, ChangeSpeaker from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory from whisperlivekit.silero_vad_iterator import FixedVADIterator from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState @@ -16,21 +16,8 @@ logger.setLevel(logging.DEBUG) SENTINEL = object() # unique sentinel object for end of stream marker MIN_DURATION_REAL_SILENCE = 5 -def cut_at(cumulative_pcm, cut_sec): - cumulative_len = 0 - cut_sample = int(cut_sec * 16000) - - for ind, pcm_array in enumerate(cumulative_pcm): - if (cumulative_len + len(pcm_array)) >= cut_sample: - cut_chunk = cut_sample - cumulative_len - before = np.concatenate(cumulative_pcm[:ind] + [cumulative_pcm[ind][:cut_chunk]]) - after = [cumulative_pcm[ind][cut_chunk:]] + cumulative_pcm[ind+1:] - return before, after - cumulative_len += len(pcm_array) - return np.concatenate(cumulative_pcm), [] - -async def get_all_from_queue(queue): - items = [] +async def get_all_from_queue(queue: asyncio.Queue) -> Union[object, Silence, np.ndarray, List[Any]]: + items: List[Any] = [] first_item = await queue.get() queue.task_done() @@ -61,7 +48,7 @@ class AudioProcessor: Handles audio processing, state management, and result formatting. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: """Initialize the audio processor with configuration, models, and state.""" if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine): @@ -80,30 +67,27 @@ class AudioProcessor: self.is_pcm_input = self.args.pcm_input # State management - self.is_stopping = False - self.current_silence = None - self.state = State() - self.state_light = StateLight() - self.lock = asyncio.Lock() - self.sep = " " # Default separator - self.last_response_content = FrontData() - self.last_detected_speaker = None - self.speaker_languages = {} + self.is_stopping: bool = False + self.current_silence: Optional[Silence] = None + self.state: State = State() + self.lock: asyncio.Lock = asyncio.Lock() + self.sep: str = " " # Default separator + self.last_response_content: FrontData = FrontData() - self.tokens_alignment = TokensAlignment(self.state_light, self.args, self.sep) - self.beg_loop = None + self.tokens_alignment: TokensAlignment = TokensAlignment(self.state, self.args, self.sep) + self.beg_loop: Optional[float] = None # Models and processing - self.asr = models.asr - self.vac_model = models.vac_model + self.asr: Any = models.asr + self.vac_model: Any = models.vac_model if self.args.vac: - self.vac = FixedVADIterator(models.vac_model) + self.vac: Optional[FixedVADIterator] = FixedVADIterator(models.vac_model) else: - self.vac = None + self.vac: Optional[FixedVADIterator] = None - self.ffmpeg_manager = None - self.ffmpeg_reader_task = None - self._ffmpeg_error = None + self.ffmpeg_manager: Optional[FFmpegManager] = None + self.ffmpeg_reader_task: Optional[asyncio.Task] = None + self._ffmpeg_error: Optional[str] = None if not self.is_pcm_input: self.ffmpeg_manager = FFmpegManager( @@ -115,21 +99,20 @@ class AudioProcessor: self._ffmpeg_error = error_type self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error - self.transcription_queue = asyncio.Queue() if self.args.transcription else None - self.diarization_queue = asyncio.Queue() if self.args.diarization else None - self.translation_queue = asyncio.Queue() if self.args.target_language else None - self.pcm_buffer = bytearray() - self.total_pcm_samples = 0 - self.end_buffer = 0.0 - self.transcription_task = None - self.diarization_task = None - self.translation_task = None - self.watchdog_task = None - self.all_tasks_for_cleanup = [] + self.transcription_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.transcription else None + self.diarization_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.diarization else None + self.translation_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.target_language else None + self.pcm_buffer: bytearray = bytearray() + self.total_pcm_samples: int = 0 + self.transcription_task: Optional[asyncio.Task] = None + self.diarization_task: Optional[asyncio.Task] = None + self.translation_task: Optional[asyncio.Task] = None + self.watchdog_task: Optional[asyncio.Task] = None + self.all_tasks_for_cleanup: List[asyncio.Task] = [] - self.transcription = None - self.translation = None - self.diarization = None + self.transcription: Optional[Any] = None + self.translation: Optional[Any] = None + self.diarization: Optional[Any] = None if self.args.transcription: self.transcription = online_factory(self.args, models.asr) @@ -139,7 +122,7 @@ class AudioProcessor: if models.translation_model: self.translation = online_translation_factory(self.args, models.translation_model) - async def _push_silence_event(self): + async def _push_silence_event(self) -> None: if self.transcription_queue: await self.transcription_queue.put(self.current_silence) if self.args.diarization and self.diarization_queue: @@ -147,7 +130,7 @@ class AudioProcessor: if self.translation_queue: await self.translation_queue.put(self.current_silence) - async def _begin_silence(self): + async def _begin_silence(self) -> None: if self.current_silence: return now = time() - self.beg_loop @@ -156,7 +139,7 @@ class AudioProcessor: ) await self._push_silence_event() - async def _end_silence(self): + async def _end_silence(self) -> None: if not self.current_silence: return now = time() - self.beg_loop @@ -165,11 +148,11 @@ class AudioProcessor: self.current_silence.has_ended=True self.current_silence.compute_duration() if self.current_silence.duration > MIN_DURATION_REAL_SILENCE: - self.state_light.new_tokens.append(self.current_silence) + self.state.new_tokens.append(self.current_silence) await self._push_silence_event() self.current_silence = None - async def _enqueue_active_audio(self, pcm_chunk: np.ndarray): + async def _enqueue_active_audio(self, pcm_chunk: np.ndarray) -> None: if pcm_chunk is None or pcm_chunk.size == 0: return if self.transcription_queue: @@ -177,7 +160,7 @@ class AudioProcessor: if self.args.diarization and self.diarization_queue: await self.diarization_queue.put(pcm_chunk.copy()) - def _slice_before_silence(self, pcm_array, chunk_sample_start, silence_sample): + def _slice_before_silence(self, pcm_array: np.ndarray, chunk_sample_start: int, silence_sample: Optional[int]) -> Optional[np.ndarray]: if silence_sample is None: return None relative_index = int(silence_sample - chunk_sample_start) @@ -188,22 +171,22 @@ class AudioProcessor: return None return pcm_array[:split_index] - def convert_pcm_to_float(self, pcm_buffer): + def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray: """Convert PCM buffer in s16le format to normalized NumPy array.""" return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0 - async def get_current_state(self): + async def get_current_state(self) -> State: """Get current state.""" async with self.lock: current_time = time() remaining_transcription = 0 - if self.end_buffer > 0: - remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 1)) + if self.state.end_buffer > 0: + remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1)) remaining_diarization = 0 if self.state.tokens: - latest_end = max(self.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0) + latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0) remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1)) self.state.remaining_time_transcription = remaining_transcription @@ -211,7 +194,7 @@ class AudioProcessor: return self.state - async def ffmpeg_stdout_reader(self): + async def ffmpeg_stdout_reader(self) -> None: """Read audio data from FFmpeg stdout and process it into the PCM pipeline.""" beg = time() while True: @@ -261,7 +244,7 @@ class AudioProcessor: if self.translation: await self.translation_queue.put(SENTINEL) - async def transcription_processor(self): + async def transcription_processor(self) -> None: """Process audio chunks for transcription.""" cumulative_pcm_duration_stream_time = 0.0 @@ -274,11 +257,11 @@ class AudioProcessor: break asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE - transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer) + transcription_lag_s = max(0.0, time() - self.beg_loop - self.state.end_buffer) asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |" stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time new_tokens = [] - current_audio_processed_upto = self.end_buffer + current_audio_processed_upto = self.state.end_buffer if isinstance(item, Silence): if item.is_starting: @@ -316,7 +299,7 @@ class AudioProcessor: if buffer_text.startswith(validated_text): _buffer_transcript.text = buffer_text[len(validated_text):].lstrip() - candidate_end_times = [self.end_buffer] + candidate_end_times = [self.state.end_buffer] if new_tokens: candidate_end_times.append(new_tokens[-1].end) @@ -329,9 +312,9 @@ class AudioProcessor: async with self.lock: self.state.tokens.extend(new_tokens) self.state.buffer_transcription = _buffer_transcript - self.end_buffer = max(candidate_end_times) - self.state_light.new_tokens.extend(new_tokens) - self.state_light.new_tokens_buffer = _buffer_transcript + self.state.end_buffer = max(candidate_end_times) + self.state.new_tokens.extend(new_tokens) + self.state.new_tokens_buffer = _buffer_transcript if self.translation_queue: for token in new_tokens: @@ -352,7 +335,7 @@ class AudioProcessor: logger.info("Transcription processor task finished.") - async def diarization_processor(self): + async def diarization_processor(self) -> None: while True: try: item = await get_all_from_queue(self.diarization_queue) @@ -365,14 +348,14 @@ class AudioProcessor: self.diarization.insert_audio_chunk(item) diarization_segments = await self.diarization.diarize() - self.state_light.new_diarization = diarization_segments + self.state.new_diarization = diarization_segments except Exception as e: logger.warning(f"Exception in diarization_processor: {e}") logger.warning(f"Traceback: {traceback.format_exc()}") logger.info("Diarization processor task finished.") - async def translation_processor(self): + async def translation_processor(self) -> None: # the idea is to ignore diarization for the moment. We use only transcription tokens. # And the speaker is attributed given the segments used for the translation # in the future we want to have different languages for each speaker etc, so it will be more complex. @@ -391,14 +374,14 @@ class AudioProcessor: self.translation.insert_tokens(tokens_to_process) translation_validated_segments, buffer_translation = await asyncio.to_thread(self.translation.process) async with self.lock: - self.state_light.new_translation = translation_validated_segments - self.state_light.new_translation_buffer = buffer_translation + self.state.new_translation = translation_validated_segments + self.state.new_translation_buffer = buffer_translation except Exception as e: logger.warning(f"Exception in translation_processor: {e}") logger.warning(f"Traceback: {traceback.format_exc()}") logger.info("Translation processor task finished.") - async def results_formatter(self): + async def results_formatter(self) -> AsyncGenerator[FrontData, None]: """Format processing results for output.""" while True: try: @@ -416,8 +399,7 @@ class AudioProcessor: ) state = await self.get_current_state() - buffer_transcription_text = '' - buffer_diarization_text = '' + buffer_transcription_text = state.buffer_transcription.text if state.buffer_transcription else '' response_status = "active_transcription" if not lines and not buffer_transcription_text and not buffer_diarization_text: @@ -448,17 +430,17 @@ class AudioProcessor: logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}") await asyncio.sleep(0.5) - async def create_tasks(self): + async def create_tasks(self) -> AsyncGenerator[FrontData, None]: """Create and start processing tasks.""" self.all_tasks_for_cleanup = [] - processing_tasks_for_watchdog = [] + processing_tasks_for_watchdog: List[asyncio.Task] = [] # If using FFmpeg (non-PCM input), start it and spawn stdout reader if not self.is_pcm_input: success = await self.ffmpeg_manager.start() if not success: logger.error("Failed to start FFmpeg manager") - async def error_generator(): + async def error_generator() -> AsyncGenerator[FrontData, None]: yield FrontData( status="error", error="FFmpeg failed to start. Please check that FFmpeg is installed." @@ -489,9 +471,9 @@ class AudioProcessor: return self.results_formatter() - async def watchdog(self, tasks_to_monitor): + async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None: """Monitors the health of critical processing tasks.""" - tasks_remaining = [task for task in tasks_to_monitor if task] + tasks_remaining: List[asyncio.Task] = [task for task in tasks_to_monitor if task] while True: try: if not tasks_remaining: @@ -516,7 +498,7 @@ class AudioProcessor: except Exception as e: logger.error(f"Error in watchdog task: {e}", exc_info=True) - async def cleanup(self): + async def cleanup(self) -> None: """Clean up resources when processing is complete.""" logger.info("Starting cleanup of AudioProcessor resources.") self.is_stopping = True @@ -539,7 +521,7 @@ class AudioProcessor: self.diarization.close() logger.info("AudioProcessor cleanup complete.") - def _processing_tasks_done(self): + def _processing_tasks_done(self) -> bool: """Return True when all active processing tasks have completed.""" tasks_to_check = [ self.transcription_task, @@ -550,7 +532,7 @@ class AudioProcessor: return all(task.done() for task in tasks_to_check if task) - async def process_audio(self, message): + async def process_audio(self, message: Optional[bytes]) -> None: """Process incoming audio data.""" if not self.beg_loop: @@ -589,7 +571,7 @@ class AudioProcessor: else: logger.warning("Failed to write audio data to FFmpeg") - async def handle_pcm_data(self): + async def handle_pcm_data(self) -> None: # Process when enough data if len(self.pcm_buffer) < self.bytes_per_sec: return diff --git a/whisperlivekit/timed_objects.py b/whisperlivekit/timed_objects.py index e0b7da7..d0854e0 100644 --- a/whisperlivekit/timed_objects.py +++ b/whisperlivekit/timed_objects.py @@ -1,7 +1,6 @@ from dataclasses import dataclass, field -from typing import Optional, Any, List +from typing import Optional, List, Union, Dict, Any from datetime import timedelta -from typing import Union PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'} @@ -20,43 +19,32 @@ class TimedText(Timed): speaker: Optional[int] = -1 detected_language: Optional[str] = None - def is_punctuation(self): + def is_punctuation(self) -> bool: return self.text.strip() in PUNCTUATION_MARKS - def overlaps_with(self, other: 'TimedText') -> bool: - return not (self.end <= other.start or other.end <= self.start) - def is_within(self, other: 'TimedText') -> bool: return other.contains_timespan(self) def duration(self) -> float: return self.end - self.start - def contains_time(self, time: float) -> bool: - return self.start <= time <= self.end - def contains_timespan(self, other: 'TimedText') -> bool: return self.start <= other.start and self.end >= other.end - def __bool__(self): + def __bool__(self) -> bool: return bool(self.text) - def __str__(self): + def __str__(self) -> str: return str(self.text) @dataclass() class ASRToken(TimedText): - corrected_speaker: Optional[int] = -1 - validated_speaker: bool = False - validated_text: bool = False - validated_language: bool = False - def with_offset(self, offset: float) -> "ASRToken": """Return a new token with the time offset added.""" return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language) - def is_silence(self): + def is_silence(self) -> bool: return False @@ -100,34 +88,6 @@ class SpeakerSegment(Timed): class Translation(TimedText): pass - def approximate_cut_at(self, cut_time): - """ - Each word in text is considered to be of duration (end-start)/len(words in text) - """ - if not self.text or not self.contains_time(cut_time): - return self, None - - words = self.text.split() - num_words = len(words) - if num_words == 0: - return self, None - - duration_per_word = self.duration() / num_words - - cut_word_index = int((cut_time - self.start) / duration_per_word) - - if cut_word_index >= num_words: - cut_word_index = num_words -1 - - text0 = " ".join(words[:cut_word_index]) - text1 = " ".join(words[cut_word_index:]) - - segment0 = Translation(start=self.start, end=cut_time, text=text0) - segment1 = Translation(start=cut_time, end=self.end, text=text1) - - return segment0, segment1 - - @dataclass class Silence(): start: Optional[float] = None @@ -136,12 +96,13 @@ class Silence(): is_starting: bool = False has_ended: bool = False - def compute_duration(self) -> float: + def compute_duration(self) -> Optional[float]: if self.start is None or self.end is None: return None self.duration = self.end - self.start + return self.duration - def is_silence(self): + def is_silence(self) -> bool: return True @@ -156,8 +117,8 @@ class Segment(): def from_tokens( cls, tokens: List[Union[ASRToken, Silence]], - is_silence=False - ) -> "Segment": + is_silence: bool = False + ) -> Optional["Segment"]: if not tokens: return None @@ -177,7 +138,7 @@ class Segment(): text=''.join(token.text for token in tokens), speaker = -1 ) - def is_silence(self): + def is_silence(self) -> bool: return self.speaker == -2 @@ -185,8 +146,8 @@ class Segment(): class Line(TimedText): translation: str = '' - def to_dict(self): - _dict = { + def to_dict(self) -> Dict[str, Any]: + _dict: Dict[str, Any] = { 'speaker': int(self.speaker) if self.speaker != -1 else 1, 'text': self.text, 'start': format_time(self.start), @@ -198,14 +159,14 @@ class Line(TimedText): _dict['detected_language'] = self.detected_language return _dict - def build_from_tokens(self, tokens: List[ASRToken]): + def build_from_tokens(self, tokens: List[ASRToken]) -> "Line": self.text = ''.join([token.text for token in tokens]) self.start = tokens[0].start self.end = tokens[-1].end self.speaker = 1 return self - def build_from_segment(self, segment: Segment): + def build_from_segment(self, segment: Segment) -> "Line": self.text = segment.text self.start = segment.start self.end = segment.end @@ -216,7 +177,7 @@ class Line(TimedText): return self.speaker == -2 class SilentLine(Line): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.speaker = -2 self.text = '' @@ -233,8 +194,8 @@ class FrontData(): remaining_time_transcription: float = 0. remaining_time_diarization: float = 0. - def to_dict(self): - _dict = { + def to_dict(self) -> Dict[str, Any]: + _dict: Dict[str, Any] = { 'status': self.status, 'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)], 'buffer_transcription': self.buffer_transcription, @@ -254,24 +215,22 @@ class ChangeSpeaker: @dataclass class State(): - tokens: list = field(default_factory=list) - last_validated_token: int = 0 - last_speaker: int = 1 - last_punctuation_index: Optional[int] = None - translation_validated_segments: list = field(default_factory=list) - buffer_translation: str = field(default_factory=Transcript) - buffer_transcription: str = field(default_factory=Transcript) - diarization_segments: list = field(default_factory=list) + """Unified state class for audio processing. + + Contains both persistent state (tokens, buffers) and temporary update buffers + (new_* fields) that are consumed by TokensAlignment. + """ + # Persistent state + tokens: List[ASRToken] = field(default_factory=list) + buffer_transcription: Transcript = field(default_factory=Transcript) end_buffer: float = 0.0 end_attributed_speaker: float = 0.0 remaining_time_transcription: float = 0.0 remaining_time_diarization: float = 0.0 - - -@dataclass -class StateLight(): - new_tokens: list = field(default_factory=list) - new_translation: list = field(default_factory=list) - new_diarization: list = field(default_factory=list) - new_tokens_buffer: list = field(default_factory=list) #only when local agreement + + # Temporary update buffers (consumed by TokensAlignment.update()) + new_tokens: List[Union[ASRToken, Silence]] = field(default_factory=list) + new_translation: List[Any] = field(default_factory=list) + new_diarization: List[Any] = field(default_factory=list) + new_tokens_buffer: List[Any] = field(default_factory=list) # only when local agreement new_translation_buffer: str = '' \ No newline at end of file diff --git a/whisperlivekit/tokens_alignment.py b/whisperlivekit/tokens_alignment.py index 246fba3..4711f52 100644 --- a/whisperlivekit/tokens_alignment.py +++ b/whisperlivekit/tokens_alignment.py @@ -1,31 +1,31 @@ from time import time -from typing import Optional +from typing import Optional, List, Tuple, Union, Any from whisperlivekit.timed_objects import Line, SilentLine, ASRToken, SpeakerSegment, Silence, TimedText, Segment class TokensAlignment: - def __init__(self, state, args, sep): + def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None: self.state = state self.diarization = args.diarization - self._tokens_index = 0 - self._diarization_index = 0 - self._translation_index = 0 + self._tokens_index: int = 0 + self._diarization_index: int = 0 + self._translation_index: int = 0 - self.all_tokens : list[ASRToken] = [] - self.all_diarization_segments: list[SpeakerSegment] = [] - self.all_translation_segments = [] + self.all_tokens: List[ASRToken] = [] + self.all_diarization_segments: List[SpeakerSegment] = [] + self.all_translation_segments: List[Any] = [] - self.new_tokens : list[ASRToken] = [] - self.new_diarization: list[SpeakerSegment] = [] - self.new_translation = [] - self.new_translation_buffer = TimedText() - self.new_tokens_buffer = [] - self.sep = sep if sep is not None else ' ' - self.beg_loop = None + self.new_tokens: List[ASRToken] = [] + self.new_diarization: List[SpeakerSegment] = [] + self.new_translation: List[Any] = [] + self.new_translation_buffer: Union[TimedText, str] = TimedText() + self.new_tokens_buffer: List[Any] = [] + self.sep: str = sep if sep is not None else ' ' + self.beg_loop: Optional[float] = None - def update(self): + def update(self) -> None: self.new_tokens, self.state.new_tokens = self.state.new_tokens, [] self.new_diarization, self.state.new_diarization = self.state.new_diarization, [] self.new_translation, self.state.new_translation = self.state.new_translation, [] @@ -38,8 +38,7 @@ class TokensAlignment: self.new_translation_buffer = self.state.new_translation_buffer if self.new_translation else self.new_translation_buffer self.new_translation_buffer = self.new_translation_buffer if type(self.new_translation_buffer) == str else self.new_translation_buffer.text - def add_translation(self, line : Line): - + def add_translation(self, line: Line) -> None: for ts in self.all_translation_segments: if ts.is_within(line): line.translation += ts.text + self.sep @@ -47,7 +46,7 @@ class TokensAlignment: break - def compute_punctuations_segments(self, tokens: Optional[list[ASRToken]] = None): + def compute_punctuations_segments(self, tokens: Optional[List[ASRToken]] = None) -> List[Segment]: segments = [] segment_start_idx = 0 for i, token in enumerate(self.all_tokens): @@ -79,7 +78,7 @@ class TokensAlignment: return segments - def concatenate_diar_segments(self): + def concatenate_diar_segments(self) -> List[SpeakerSegment]: if not self.all_diarization_segments: return [] merged = [self.all_diarization_segments[0]] @@ -92,13 +91,13 @@ class TokensAlignment: @staticmethod - def intersection_duration(seg1, seg2): + def intersection_duration(seg1: TimedText, seg2: TimedText) -> float: start = max(seg1.start, seg2.start) end = min(seg1.end, seg2.end) return max(0, end - start) - def get_lines_diarization(self): + def get_lines_diarization(self) -> Tuple[List[Line], str]: """ use compute_punctuations_segments, concatenate_diar_segments, intersection_duration """ @@ -135,10 +134,10 @@ class TokensAlignment: def get_lines( self, - diarization=False, - translation=False, - current_silence=None - ): + diarization: bool = False, + translation: bool = False, + current_silence: Optional[Silence] = None + ) -> Tuple[List[Line], str, Union[str, TimedText]]: """ In the case without diarization """