punctuation is checked in audio-processor's result formatter

This commit is contained in:
Quentin Fuxa
2025-08-24 18:32:01 +02:00
parent 58297daf6d
commit ce781831ee
3 changed files with 114 additions and 115 deletions

View File

@@ -11,6 +11,7 @@ from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from whisperlivekit.remove_silences import handle_silences
from whisperlivekit.trail_repetition import trim_tail_repetition
from whisperlivekit.silero_vad_iterator import FixedVADIterator
from whisperlivekit.results_formater import format_output, format_time
# Set up logging once
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
@@ -18,10 +19,6 @@ logger.setLevel(logging.DEBUG)
SENTINEL = object() # unique sentinel object for end of stream marker
def format_time(seconds: float) -> str:
"""Format seconds as HH:MM:SS."""
return str(timedelta(seconds=int(seconds)))
class AudioProcessor:
"""
Processes audio streams for transcription and diarization.
@@ -433,7 +430,7 @@ class AudioProcessor:
buffer_diarization = state["buffer_diarization"]
end_attributed_speaker = state["end_attributed_speaker"]
sep = state["sep"]
# Add dummy tokens if needed
if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
await self.add_dummy_token()
@@ -442,45 +439,13 @@ class AudioProcessor:
tokens = state["tokens"]
# Format output
previous_speaker = -1
lines = []
last_end_diarized = 0
undiarized_text = []
current_time = time() - self.beg_loop if self.beg_loop else None
tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, self.silence)
for token in tokens:
speaker = token.speaker
if speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
speaker = 1
# Handle diarization
if self.args.diarization and not tokens[-1].speaker == -2:
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
undiarized_text.append(token.text)
continue
elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
speaker = previous_speaker
if speaker not in [-1, 0]:
last_end_diarized = max(token.end, last_end_diarized)
debug_info = ""
if self.debug:
debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
if speaker != previous_speaker or not lines:
lines.append({
"speaker": str(speaker),
"text": token.text + debug_info,
"beg": format_time(token.start),
"end": format_time(token.end),
"diff": round(token.end - last_end_diarized, 2)
})
previous_speaker = speaker
elif token.text: # Only append if text isn't empty
lines[-1]["text"] += sep + token.text + debug_info
lines[-1]["end"] = format_time(token.end)
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
lines, undiarized_text, buffer_transcription, buffer_diarization = format_output(
state,
self.silence,
current_time = time() - self.beg_loop if self.beg_loop else None,
diarization = self.args.diarization,
debug = self.debug
)
# Handle undiarized text
if undiarized_text:
combined = sep.join(undiarized_text)
@@ -510,7 +475,7 @@ class AudioProcessor:
"buffer_transcription": buffer_transcription,
"buffer_diarization": buffer_diarization,
"remaining_time_transcription": state["remaining_time_transcription"],
"remaining_time_diarization": state["remaining_time_diarization"]
"remaining_time_diarization": state["remaining_time_diarization"] if self.args.diarization else 0
}
current_response_signature = f"{response_status} | " + \

View File

@@ -58,17 +58,16 @@ class SortformerDiarization:
"""
self.sample_rate = sample_rate
self.speaker_segments = []
self.buffer_audio = np.array([], dtype=np.float32)
self.segment_lock = threading.Lock()
self.global_time_offset = 0.0
self.processed_time = 0.0
self.debug = False
# Load and configure the model
self._load_model(model_name)
# Initialize streaming state
self._init_streaming_state()
# Audio processing variables
self._previous_chunk_features = None
self._chunk_index = 0
self._len_prediction = None
@@ -169,50 +168,35 @@ class SortformerDiarization:
pcm_array: Audio data as numpy array
"""
try:
# Store PCM array for debugging
self.audio_buffer.append(pcm_array.copy())
if self.debug:
self.audio_buffer.append(pcm_array.copy())
threshold = int(self.chunk_duration_seconds * self.sample_rate)
# Add to buffer and accumulate duration
self.audio_chunk_buffer.append(pcm_array.copy())
chunk_duration = len(pcm_array) / self.sample_rate
self.accumulated_duration += chunk_duration
# Check if we have accumulated enough audio
if self.accumulated_duration < self.chunk_duration_seconds:
print(f"Accumulating audio: {self.accumulated_duration:.2f}/{self.chunk_duration_seconds:.2f}s")
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
if not len(self.buffer_audio) >= threshold:
return
# Concatenate all buffered audio chunks
concatenated_audio = np.concatenate(self.audio_chunk_buffer)
audio = self.buffer_audio[:threshold]
self.buffer_audio = self.buffer_audio[threshold:]
# Reset buffer and accumulated duration
self.audio_chunk_buffer = []
self.accumulated_duration = 0.0
# Convert audio to torch tensor
audio_signal_chunk = torch.tensor(concatenated_audio).unsqueeze(0).to(self.diar_model.device)
audio_signal_chunk = torch.tensor(audio).unsqueeze(0).to(self.diar_model.device)
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(self.diar_model.device)
# Extract mel features
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
audio_signal_chunk, audio_signal_length_chunk
)
# Handle feature overlap for continuity
if self._previous_chunk_features is not None:
# Add overlap from previous chunk (99 frames as in offline version)
to_add = self._previous_chunk_features[:, :, -99:]
total_features = torch.concat([to_add, processed_signal_chunk], dim=2)
else:
total_features = processed_signal_chunk
# Store current features for next iteration
self._previous_chunk_features = processed_signal_chunk
# Transpose for model input
chunk_feat_seq_t = torch.transpose(total_features, 1, 2)
# Process with streaming model
with torch.inference_mode():
left_offset = 8 if self._chunk_index > 0 else 0
right_offset = 8
@@ -272,7 +256,6 @@ class SortformerDiarization:
start=start_time,
end=end_time
))
print('NEW SPEAKER, SpeakerSegment:', str(self.speaker_segments[-1]))
# Update processed time
self.processed_time = max(self.processed_time, base_time + self.chunk_duration_seconds)
@@ -301,6 +284,7 @@ class SortformerDiarization:
return tokens
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
use_punctuation_split = False
if not use_punctuation_split:
# Simple overlap-based assignment
for token in tokens:
@@ -415,27 +399,15 @@ class SortformerDiarization:
with self.segment_lock:
self.speaker_segments.clear()
# Save audio buffer to WAV file for debugging
if self.audio_buffer:
try:
# Concatenate all PCM chunks
concatenated_audio = np.concatenate(self.audio_buffer)
# Convert from float32 back to int16 for WAV file
audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
# Write to WAV file
with wave.open("diarization_audio.wav", "wb") as wav_file:
wav_file.setnchannels(1) # Mono audio
wav_file.setsampwidth(2) # 2 bytes per sample (int16)
wav_file.setframerate(self.sample_rate)
wav_file.writeframes(audio_data_int16.tobytes())
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
except Exception as e:
logger.error(f"Error saving audio to WAV file: {e}")
else:
logger.info("No audio data to save")
if self.debug:
concatenated_audio = np.concatenate(self.audio_buffer)
audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
with wave.open("diarization_audio.wav", "wb") as wav_file:
wav_file.setnchannels(1) # mono audio
wav_file.setsampwidth(2) # 2 bytes per sample (int16)
wav_file.setframerate(self.sample_rate)
wav_file.writeframes(audio_data_int16.tobytes())
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
def extract_number(s: str) -> int:
@@ -450,40 +422,27 @@ if __name__ == '__main__':
import librosa
async def main():
"""Main function to test SortformerDiarization with the same example as offline version."""
# Load and prepare audio (same as offline version)
try:
an4_audio = 'audio_test.mp3'
signal, sr = librosa.load(an4_audio, sr=16000)
signal = signal[:16000*30] # 30 seconds
except Exception as e:
print(f"Error loading audio file: {e}")
return
"""TEST ONLY."""
an4_audio = 'audio_test.mp3'
signal, sr = librosa.load(an4_audio, sr=16000)
signal = signal[:16000*30]
print("\n" + "=" * 50)
print("Expected ground truth:")
print("ground truth:")
print("Speaker 0: 0:00 - 0:09")
print("Speaker 1: 0:09 - 0:19")
print("Speaker 2: 0:19 - 0:25")
print("Speaker 0: 0:25 - 0:30")
print("=" * 50)
# Create diarization instance
diarization = SortformerDiarization(sample_rate=16000)
# Chunk and process audio
chunk_size = 16000 # 1 second
print(f"Processing {len(signal)} samples in {len(signal) // chunk_size} chunks...")
diarization = SortformerDiarization(sample_rate=16000)
chunk_size = 1600
for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size]
await diarization.diarize(chunk)
print(f"Processed chunk {i // chunk_size + 1}")
# Close and save WAV
# diarization.close()
# Print results
segments = diarization.get_segments()
print("\nDiarization results:")
for segment in segments:

View File

@@ -0,0 +1,75 @@
import logging
from datetime import timedelta
from whisperlivekit.remove_silences import handle_silences
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
PUNCTUATION_MARKS = {'.', '!', '?'}
def format_time(seconds: float) -> str:
    """Render *seconds* as a clock string (H:MM:SS, via ``timedelta``'s str form)."""
    whole_seconds = int(seconds)
    return str(timedelta(seconds=whole_seconds))
def check_punctuation_nearby(i, tokens):
    """Return True if a token in the 1-token window starting at *i* is a punctuation mark.

    An out-of-range *i* simply yields an empty window, so the result is False.
    """
    window_end = min(len(tokens), i + 1)
    for idx in range(i, window_end):
        if tokens[idx].text.strip() in PUNCTUATION_MARKS:
            return True
    return False
def format_output(state, silence, current_time, diarization, debug):
    """Turn the processor state's token stream into per-speaker display lines.

    Consecutive tokens of the same speaker are merged into one line; a speaker
    change right before punctuation is kept on the previous line so sentences
    are not split across speakers.

    Args:
        state: dict with "tokens", "buffer_transcription", "buffer_diarization",
            "end_attributed_speaker" and "sep" entries.
        silence: silence-handling state forwarded to handle_silences.
        current_time: seconds since stream start, or None before it started.
        diarization: True when speaker diarization is enabled.
        debug: when True, append "[start : end]" timing info to each token.

    Returns:
        Tuple ``(lines, undiarized_text, buffer_transcription, buffer_diarization)``
        where ``lines`` is a list of dicts with "speaker", "text", "beg", "end"
        and "diff" keys, and ``undiarized_text`` holds text not yet attributed
        to a speaker.
    """
    tokens = state["tokens"]
    buffer_transcription = state["buffer_transcription"]
    buffer_diarization = state["buffer_diarization"]
    end_attributed_speaker = state["end_attributed_speaker"]
    sep = state["sep"]
    previous_speaker = -1
    lines = []
    last_end_diarized = 0
    undiarized_text = []
    tokens, buffer_transcription, buffer_diarization = handle_silences(
        tokens, buffer_transcription, buffer_diarization, current_time, silence)
    for i, token in enumerate(tokens):
        speaker = token.speaker
        if len(tokens) == 1 and not diarization:
            # Speaker -1 means not attributed by diarization; the frontend
            # should show it as 'Speaker 1'.
            if speaker == -1:
                speaker = 1
        if diarization and not tokens[-1].speaker == -2:
            if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
                # Not yet reached by diarization: defer this text.
                undiarized_text.append(token.text)
                continue
            elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
                speaker = previous_speaker
            if speaker not in [-1, 0]:
                last_end_diarized = max(token.end, last_end_diarized)
        debug_info = ""
        if debug:
            debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
        if speaker != previous_speaker or not lines:
            if speaker != previous_speaker and lines and check_punctuation_nearby(i, tokens):
                # Punctuation is close by: keep appending to the previous line
                # instead of starting a new one mid-sentence.
                lines[-1]["text"] += sep + token.text + debug_info
                lines[-1]["end"] = format_time(token.end)
                lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
            else:
                lines.append({
                    "speaker": int(speaker),
                    "text": token.text + debug_info,
                    "beg": format_time(token.start),
                    "end": format_time(token.end),
                    "diff": round(token.end - last_end_diarized, 2)
                })
            previous_speaker = speaker
        elif token.text:  # Only append if text isn't empty
            lines[-1]["text"] += sep + token.text + debug_info
            lines[-1]["end"] = format_time(token.end)
            lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
    # Fix: return the diarization buffer produced by handle_silences instead of
    # a hard-coded '' — the caller unpacks this slot as buffer_diarization and
    # displays it, so returning '' silently discarded it.
    return lines, undiarized_text, buffer_transcription, buffer_diarization