From ce781831ee5852ef21e85f49b48f0b3fdd1319f7 Mon Sep 17 00:00:00 2001
From: Quentin Fuxa
Date: Sun, 24 Aug 2025 18:32:01 +0200
Subject: [PATCH] punctuation is checked in audio-processor's result formatter

---
 whisperlivekit/audio_processor.py             | 55 ++---------
 .../diarization/sortformer_backend.py         | 99 ++++++-------------
 whisperlivekit/results_formater.py            | 79 +++++++++++++++
 3 files changed, 118 insertions(+), 115 deletions(-)
 create mode 100644 whisperlivekit/results_formater.py

diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py
index fcb2782..850fd0a 100644
--- a/whisperlivekit/audio_processor.py
+++ b/whisperlivekit/audio_processor.py
@@ -11,6 +11,7 @@ from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
 from whisperlivekit.remove_silences import handle_silences
 from whisperlivekit.trail_repetition import trim_tail_repetition
 from whisperlivekit.silero_vad_iterator import FixedVADIterator
+from whisperlivekit.results_formater import format_output, format_time
 # Set up logging once
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)
@@ -18,10 +19,6 @@ logger.setLevel(logging.DEBUG)

 SENTINEL = object() # unique sentinel object for end of stream marker

-def format_time(seconds: float) -> str:
-    """Format seconds as HH:MM:SS."""
-    return str(timedelta(seconds=int(seconds)))
-
 class AudioProcessor:
     """
     Processes audio streams for transcription and diarization.
@@ -433,7 +430,7 @@ class AudioProcessor:
                 buffer_diarization = state["buffer_diarization"]
                 end_attributed_speaker = state["end_attributed_speaker"]
                 sep = state["sep"]
-                
+
                 # Add dummy tokens if needed
                 if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
                     await self.add_dummy_token()
@@ -442,45 +439,13 @@ class AudioProcessor:
                     tokens = state["tokens"]

                 # Format output
-                previous_speaker = -1
-                lines = []
-                last_end_diarized = 0
-                undiarized_text = []
-                current_time = time() - self.beg_loop if self.beg_loop else None
-                tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, self.silence)
-                for token in tokens:
-                    speaker = token.speaker
-
-                    if speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
-                        speaker = 1
-
-                    # Handle diarization
-                    if self.args.diarization and not tokens[-1].speaker == -2:
-                        if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
-                            undiarized_text.append(token.text)
-                            continue
-                        elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
-                            speaker = previous_speaker
-                        if speaker not in [-1, 0]:
-                            last_end_diarized = max(token.end, last_end_diarized)
-
-                    debug_info = ""
-                    if self.debug:
-                        debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
-                    if speaker != previous_speaker or not lines:
-                        lines.append({
-                            "speaker": str(speaker),
-                            "text": token.text + debug_info,
-                            "beg": format_time(token.start),
-                            "end": format_time(token.end),
-                            "diff": round(token.end - last_end_diarized, 2)
-                        })
-                        previous_speaker = speaker
-                    elif token.text: # Only append if text isn't empty
-                        lines[-1]["text"] += sep + token.text + debug_info
-                        lines[-1]["end"] = format_time(token.end)
-                        lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
-
+                lines, undiarized_text, buffer_transcription, buffer_diarization = format_output(
+                    state,
+                    self.silence,
+                    current_time=(time() - self.beg_loop) if self.beg_loop else None,
+                    diarization=self.args.diarization,
+                    debug=self.debug,
+                )
                 # Handle undiarized text
                 if undiarized_text:
                     combined = sep.join(undiarized_text)
@@ -510,7 +475,7 @@ class AudioProcessor:
                     "buffer_transcription": buffer_transcription,
                     "buffer_diarization": buffer_diarization,
                     "remaining_time_transcription": state["remaining_time_transcription"],
-                    "remaining_time_diarization": state["remaining_time_diarization"]
+                    "remaining_time_diarization": state["remaining_time_diarization"] if self.args.diarization else 0
                 }

                 current_response_signature = f"{response_status} | " + \
diff --git a/whisperlivekit/diarization/sortformer_backend.py b/whisperlivekit/diarization/sortformer_backend.py
index c31e5b2..a55931c 100644
--- a/whisperlivekit/diarization/sortformer_backend.py
+++ b/whisperlivekit/diarization/sortformer_backend.py
@@ -58,17 +58,16 @@ class SortformerDiarization:
         """
         self.sample_rate = sample_rate
         self.speaker_segments = []
+        self.buffer_audio = np.array([], dtype=np.float32)
         self.segment_lock = threading.Lock()
         self.global_time_offset = 0.0
         self.processed_time = 0.0
+        self.debug = False

-        # Load and configure the model
         self._load_model(model_name)

-        # Initialize streaming state
         self._init_streaming_state()

-        # Audio processing variables
         self._previous_chunk_features = None
         self._chunk_index = 0
         self._len_prediction = None
@@ -169,50 +168,35 @@ class SortformerDiarization:
             pcm_array: Audio data as numpy array
         """
         try:
-            # Store PCM array for debugging
-            self.audio_buffer.append(pcm_array.copy())
+            if self.debug:
+                self.audio_buffer.append(pcm_array.copy())
+
+            threshold = int(self.chunk_duration_seconds * self.sample_rate)

-            # Add to buffer and accumulate duration
-            self.audio_chunk_buffer.append(pcm_array.copy())
-            chunk_duration = len(pcm_array) / self.sample_rate
-            self.accumulated_duration += chunk_duration
-
-            # Check if we have accumulated enough audio
-            if self.accumulated_duration < self.chunk_duration_seconds:
-                print(f"Accumulating audio: {self.accumulated_duration:.2f}/{self.chunk_duration_seconds:.2f}s")
+            self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
+            if len(self.buffer_audio) < threshold:
                 return

-            # Concatenate all buffered audio chunks
-            concatenated_audio = np.concatenate(self.audio_chunk_buffer)
+            audio = self.buffer_audio[:threshold]
+            self.buffer_audio = self.buffer_audio[threshold:]

-            # Reset buffer and accumulated duration
-            self.audio_chunk_buffer = []
-            self.accumulated_duration = 0.0
-
-            # Convert audio to torch tensor
-            audio_signal_chunk = torch.tensor(concatenated_audio).unsqueeze(0).to(self.diar_model.device)
+            audio_signal_chunk = torch.tensor(audio).unsqueeze(0).to(self.diar_model.device)
             audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(self.diar_model.device)

-            # Extract mel features
             processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
                 audio_signal_chunk, audio_signal_length_chunk
             )

-            # Handle feature overlap for continuity
             if self._previous_chunk_features is not None:
-                # Add overlap from previous chunk (99 frames as in offline version)
                 to_add = self._previous_chunk_features[:, :, -99:]
                 total_features = torch.concat([to_add, processed_signal_chunk], dim=2)
             else:
                 total_features = processed_signal_chunk

-            # Store current features for next iteration
             self._previous_chunk_features = processed_signal_chunk

-            # Transpose for model input
             chunk_feat_seq_t = torch.transpose(total_features, 1, 2)

-            # Process with streaming model
             with torch.inference_mode():
                 left_offset = 8 if self._chunk_index > 0 else 0
                 right_offset = 8
@@ -272,7 +256,6 @@ class SortformerDiarization:
                             start=start_time,
                             end=end_time
                         ))
-                        print('NEW SPEAKER, SpeakerSegment:', str(self.speaker_segments[-1]))

                 # Update processed time
                 self.processed_time = max(self.processed_time, base_time + self.chunk_duration_seconds)
@@ -301,6 +284,7 @@ class SortformerDiarization:
             return tokens

         logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
+        use_punctuation_split = False  # NOTE(review): forces the simple overlap path below; punctuation is now checked in results_formater
         if not use_punctuation_split:
             # Simple overlap-based assignment
             for token in tokens:
@@ -415,27 +399,15 @@ class SortformerDiarization:
         with self.segment_lock:
             self.speaker_segments.clear()

-        # Save audio buffer to WAV file for debugging
-        if self.audio_buffer:
-            try:
-                # Concatenate all PCM chunks
-                concatenated_audio = np.concatenate(self.audio_buffer)
-
-                # Convert from float32 back to int16 for WAV file
-                audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
-
-                # Write to WAV file
-                with wave.open("diarization_audio.wav", "wb") as wav_file:
-                    wav_file.setnchannels(1) # Mono audio
-                    wav_file.setsampwidth(2) # 2 bytes per sample (int16)
-                    wav_file.setframerate(self.sample_rate)
-                    wav_file.writeframes(audio_data_int16.tobytes())
-
-                logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
-            except Exception as e:
-                logger.error(f"Error saving audio to WAV file: {e}")
-        else:
-            logger.info("No audio data to save")
+        if self.debug and self.audio_buffer:
+            concatenated_audio = np.concatenate(self.audio_buffer)
+            audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
+            with wave.open("diarization_audio.wav", "wb") as wav_file:
+                wav_file.setnchannels(1) # mono audio
+                wav_file.setsampwidth(2) # 2 bytes per sample (int16)
+                wav_file.setframerate(self.sample_rate)
+                wav_file.writeframes(audio_data_int16.tobytes())
+            logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")


 def extract_number(s: str) -> int:
@@ -450,40 +422,27 @@ if __name__ == '__main__':
     import librosa

     async def main():
-        """Main function to test SortformerDiarization with the same example as offline version."""
-        # Load and prepare audio (same as offline version)
-        try:
-            an4_audio = 'audio_test.mp3'
-            signal, sr = librosa.load(an4_audio, sr=16000)
-            signal = signal[:16000*30] # 30 seconds
-        except Exception as e:
-            print(f"Error loading audio file: {e}")
-            return
-        
+        """TEST ONLY."""
+        an4_audio = 'audio_test.mp3'
+        signal, sr = librosa.load(an4_audio, sr=16000)
+        signal = signal[:16000*30]
+
         print("\n" + "=" * 50)
-        print("Expected ground truth:")
+        print("ground truth:")
         print("Speaker 0: 0:00 - 0:09")
         print("Speaker 1: 0:09 - 0:19")
         print("Speaker 2: 0:19 - 0:25")
         print("Speaker 0: 0:25 - 0:30")
         print("=" * 50)

-        # Create diarization instance
-        diarization = SortformerDiarization(sample_rate=16000)
-
-        # Chunk and process audio
-        chunk_size = 16000 # 1 second
-        print(f"Processing {len(signal)} samples in {len(signal) // chunk_size} chunks...")
+        diarization = SortformerDiarization(sample_rate=16000)
+        chunk_size = 1600

         for i in range(0, len(signal), chunk_size):
             chunk = signal[i:i+chunk_size]
             await diarization.diarize(chunk)
             print(f"Processed chunk {i // chunk_size + 1}")

-        # Close and save WAV
-        # diarization.close()
-
-        # Print results
         segments = diarization.get_segments()
         print("\nDiarization results:")
         for segment in segments:
diff --git a/whisperlivekit/results_formater.py b/whisperlivekit/results_formater.py
new file mode 100644
index 0000000..6b9ba44
--- /dev/null
+++ b/whisperlivekit/results_formater.py
@@ -0,0 +1,79 @@
+"""Turn the audio processor's token state into speaker-attributed display lines."""
+
+import logging
+from datetime import timedelta
+from whisperlivekit.remove_silences import handle_silences
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+PUNCTUATION_MARKS = {'.', '!', '?'}
+
+
+def format_time(seconds: float) -> str:
+    """Format seconds as H:MM:SS (the fractional part is dropped)."""
+    return str(timedelta(seconds=int(seconds)))
+
+
+def check_punctuation_nearby(i, tokens, window=1):
+    """Tell whether one of the `window` tokens from index `i` on is in PUNCTUATION_MARKS."""
+    return any(token.text.strip() in PUNCTUATION_MARKS for token in tokens[i:i + window])
+
+
+def format_output(state, silence, current_time, diarization, debug):
+    """Group the tokens of `state` into display lines, one per speaker turn.
+
+    Returns `(lines, undiarized_text, buffer_transcription, '')`.
+    """
+    tokens = state["tokens"]
+    buffer_transcription = state["buffer_transcription"]
+    buffer_diarization = state["buffer_diarization"]
+    end_attributed_speaker = state["end_attributed_speaker"]
+    sep = state["sep"]
+
+    previous_speaker = -1
+    lines = []
+    last_end_diarized = 0
+    undiarized_text = []
+    tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, silence)
+    for i, token in enumerate(tokens):
+        speaker = token.speaker
+
+        # Speaker -1 means not attributed by diarization. Without diarization every such
+        # token should appear under 'Speaker 1' in the frontend, however many tokens there are.
+        if not diarization and speaker == -1:
+            speaker = 1
+
+        if diarization and not tokens[-1].speaker == -2:
+            if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
+                undiarized_text.append(token.text)
+                continue
+            elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
+                speaker = previous_speaker
+            if speaker not in [-1, 0]:
+                last_end_diarized = max(token.end, last_end_diarized)
+
+        debug_info = ""
+        if debug:
+            debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
+
+        speaker_changed = speaker != previous_speaker
+        if not lines or (speaker_changed and not check_punctuation_nearby(i, tokens)):
+            lines.append({
+                "speaker": int(speaker),
+                "text": token.text + debug_info,
+                "beg": format_time(token.start),
+                "end": format_time(token.end),
+                "diff": round(token.end - last_end_diarized, 2)
+            })
+            previous_speaker = speaker
+        elif speaker_changed or token.text:
+            # Either the same speaker goes on (empty texts are skipped), or a punctuation mark
+            # sits right at a speaker change: it still belongs to the line being built.
+            lines[-1]["text"] += sep + token.text + debug_info
+            lines[-1]["end"] = format_time(token.end)
+            lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
+
+    # NOTE(review): the diarization buffer returned by handle_silences is dropped here and an
+    # empty string is handed back in its slot - confirm this is intended.
+    return lines, undiarized_text, buffer_transcription, ''