mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
punctuation is checked in audio-processor's result formatter
This commit is contained in:
@@ -11,6 +11,7 @@ from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||
from whisperlivekit.remove_silences import handle_silences
|
||||
from whisperlivekit.trail_repetition import trim_tail_repetition
|
||||
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
||||
from whisperlivekit.results_formater import format_output, format_time
|
||||
# Set up logging once
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -18,10 +19,6 @@ logger.setLevel(logging.DEBUG)
|
||||
|
||||
SENTINEL = object() # unique sentinel object for end of stream marker
|
||||
|
||||
def format_time(seconds: float) -> str:
|
||||
"""Format seconds as HH:MM:SS."""
|
||||
return str(timedelta(seconds=int(seconds)))
|
||||
|
||||
class AudioProcessor:
|
||||
"""
|
||||
Processes audio streams for transcription and diarization.
|
||||
@@ -433,7 +430,7 @@ class AudioProcessor:
|
||||
buffer_diarization = state["buffer_diarization"]
|
||||
end_attributed_speaker = state["end_attributed_speaker"]
|
||||
sep = state["sep"]
|
||||
|
||||
|
||||
# Add dummy tokens if needed
|
||||
if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
|
||||
await self.add_dummy_token()
|
||||
@@ -442,45 +439,13 @@ class AudioProcessor:
|
||||
tokens = state["tokens"]
|
||||
|
||||
# Format output
|
||||
previous_speaker = -1
|
||||
lines = []
|
||||
last_end_diarized = 0
|
||||
undiarized_text = []
|
||||
current_time = time() - self.beg_loop if self.beg_loop else None
|
||||
tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, self.silence)
|
||||
for token in tokens:
|
||||
speaker = token.speaker
|
||||
|
||||
if speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
|
||||
speaker = 1
|
||||
|
||||
# Handle diarization
|
||||
if self.args.diarization and not tokens[-1].speaker == -2:
|
||||
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
|
||||
undiarized_text.append(token.text)
|
||||
continue
|
||||
elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
|
||||
speaker = previous_speaker
|
||||
if speaker not in [-1, 0]:
|
||||
last_end_diarized = max(token.end, last_end_diarized)
|
||||
|
||||
debug_info = ""
|
||||
if self.debug:
|
||||
debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
|
||||
if speaker != previous_speaker or not lines:
|
||||
lines.append({
|
||||
"speaker": str(speaker),
|
||||
"text": token.text + debug_info,
|
||||
"beg": format_time(token.start),
|
||||
"end": format_time(token.end),
|
||||
"diff": round(token.end - last_end_diarized, 2)
|
||||
})
|
||||
previous_speaker = speaker
|
||||
elif token.text: # Only append if text isn't empty
|
||||
lines[-1]["text"] += sep + token.text + debug_info
|
||||
lines[-1]["end"] = format_time(token.end)
|
||||
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
|
||||
|
||||
lines, undiarized_text, buffer_transcription, buffer_diarization = format_output(
|
||||
state,
|
||||
self.silence,
|
||||
current_time = time() - self.beg_loop if self.beg_loop else None,
|
||||
diarization = self.args.diarization,
|
||||
debug = self.debug
|
||||
)
|
||||
# Handle undiarized text
|
||||
if undiarized_text:
|
||||
combined = sep.join(undiarized_text)
|
||||
@@ -510,7 +475,7 @@ class AudioProcessor:
|
||||
"buffer_transcription": buffer_transcription,
|
||||
"buffer_diarization": buffer_diarization,
|
||||
"remaining_time_transcription": state["remaining_time_transcription"],
|
||||
"remaining_time_diarization": state["remaining_time_diarization"]
|
||||
"remaining_time_diarization": state["remaining_time_diarization"] if self.args.diarization else 0
|
||||
}
|
||||
|
||||
current_response_signature = f"{response_status} | " + \
|
||||
|
||||
@@ -58,17 +58,16 @@ class SortformerDiarization:
|
||||
"""
|
||||
self.sample_rate = sample_rate
|
||||
self.speaker_segments = []
|
||||
self.buffer_audio = np.array([], dtype=np.float32)
|
||||
self.segment_lock = threading.Lock()
|
||||
self.global_time_offset = 0.0
|
||||
self.processed_time = 0.0
|
||||
self.debug = False
|
||||
|
||||
# Load and configure the model
|
||||
self._load_model(model_name)
|
||||
|
||||
# Initialize streaming state
|
||||
self._init_streaming_state()
|
||||
|
||||
# Audio processing variables
|
||||
self._previous_chunk_features = None
|
||||
self._chunk_index = 0
|
||||
self._len_prediction = None
|
||||
@@ -169,50 +168,35 @@ class SortformerDiarization:
|
||||
pcm_array: Audio data as numpy array
|
||||
"""
|
||||
try:
|
||||
# Store PCM array for debugging
|
||||
self.audio_buffer.append(pcm_array.copy())
|
||||
if self.debug:
|
||||
self.audio_buffer.append(pcm_array.copy())
|
||||
|
||||
threshold = int(self.chunk_duration_seconds * self.sample_rate)
|
||||
|
||||
# Add to buffer and accumulate duration
|
||||
self.audio_chunk_buffer.append(pcm_array.copy())
|
||||
chunk_duration = len(pcm_array) / self.sample_rate
|
||||
self.accumulated_duration += chunk_duration
|
||||
|
||||
# Check if we have accumulated enough audio
|
||||
if self.accumulated_duration < self.chunk_duration_seconds:
|
||||
print(f"Accumulating audio: {self.accumulated_duration:.2f}/{self.chunk_duration_seconds:.2f}s")
|
||||
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
|
||||
if not len(self.buffer_audio) >= threshold:
|
||||
return
|
||||
|
||||
# Concatenate all buffered audio chunks
|
||||
concatenated_audio = np.concatenate(self.audio_chunk_buffer)
|
||||
audio = self.buffer_audio[:threshold]
|
||||
self.buffer_audio = self.buffer_audio[threshold:]
|
||||
|
||||
# Reset buffer and accumulated duration
|
||||
self.audio_chunk_buffer = []
|
||||
self.accumulated_duration = 0.0
|
||||
|
||||
# Convert audio to torch tensor
|
||||
audio_signal_chunk = torch.tensor(concatenated_audio).unsqueeze(0).to(self.diar_model.device)
|
||||
audio_signal_chunk = torch.tensor(audio).unsqueeze(0).to(self.diar_model.device)
|
||||
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(self.diar_model.device)
|
||||
|
||||
# Extract mel features
|
||||
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
|
||||
audio_signal_chunk, audio_signal_length_chunk
|
||||
)
|
||||
|
||||
# Handle feature overlap for continuity
|
||||
if self._previous_chunk_features is not None:
|
||||
# Add overlap from previous chunk (99 frames as in offline version)
|
||||
to_add = self._previous_chunk_features[:, :, -99:]
|
||||
total_features = torch.concat([to_add, processed_signal_chunk], dim=2)
|
||||
else:
|
||||
total_features = processed_signal_chunk
|
||||
|
||||
# Store current features for next iteration
|
||||
self._previous_chunk_features = processed_signal_chunk
|
||||
|
||||
# Transpose for model input
|
||||
chunk_feat_seq_t = torch.transpose(total_features, 1, 2)
|
||||
|
||||
# Process with streaming model
|
||||
with torch.inference_mode():
|
||||
left_offset = 8 if self._chunk_index > 0 else 0
|
||||
right_offset = 8
|
||||
@@ -272,7 +256,6 @@ class SortformerDiarization:
|
||||
start=start_time,
|
||||
end=end_time
|
||||
))
|
||||
print('NEW SPEAKER, SpeakerSegment:', str(self.speaker_segments[-1]))
|
||||
|
||||
# Update processed time
|
||||
self.processed_time = max(self.processed_time, base_time + self.chunk_duration_seconds)
|
||||
@@ -301,6 +284,7 @@ class SortformerDiarization:
|
||||
return tokens
|
||||
|
||||
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
|
||||
use_punctuation_split = False
|
||||
if not use_punctuation_split:
|
||||
# Simple overlap-based assignment
|
||||
for token in tokens:
|
||||
@@ -415,27 +399,15 @@ class SortformerDiarization:
|
||||
with self.segment_lock:
|
||||
self.speaker_segments.clear()
|
||||
|
||||
# Save audio buffer to WAV file for debugging
|
||||
if self.audio_buffer:
|
||||
try:
|
||||
# Concatenate all PCM chunks
|
||||
concatenated_audio = np.concatenate(self.audio_buffer)
|
||||
|
||||
# Convert from float32 back to int16 for WAV file
|
||||
audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
|
||||
|
||||
# Write to WAV file
|
||||
with wave.open("diarization_audio.wav", "wb") as wav_file:
|
||||
wav_file.setnchannels(1) # Mono audio
|
||||
wav_file.setsampwidth(2) # 2 bytes per sample (int16)
|
||||
wav_file.setframerate(self.sample_rate)
|
||||
wav_file.writeframes(audio_data_int16.tobytes())
|
||||
|
||||
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving audio to WAV file: {e}")
|
||||
else:
|
||||
logger.info("No audio data to save")
|
||||
if self.debug:
|
||||
concatenated_audio = np.concatenate(self.audio_buffer)
|
||||
audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
|
||||
with wave.open("diarization_audio.wav", "wb") as wav_file:
|
||||
wav_file.setnchannels(1) # mono audio
|
||||
wav_file.setsampwidth(2) # 2 bytes per sample (int16)
|
||||
wav_file.setframerate(self.sample_rate)
|
||||
wav_file.writeframes(audio_data_int16.tobytes())
|
||||
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
|
||||
|
||||
|
||||
def extract_number(s: str) -> int:
|
||||
@@ -450,40 +422,27 @@ if __name__ == '__main__':
|
||||
import librosa
|
||||
|
||||
async def main():
|
||||
"""Main function to test SortformerDiarization with the same example as offline version."""
|
||||
# Load and prepare audio (same as offline version)
|
||||
try:
|
||||
an4_audio = 'audio_test.mp3'
|
||||
signal, sr = librosa.load(an4_audio, sr=16000)
|
||||
signal = signal[:16000*30] # 30 seconds
|
||||
except Exception as e:
|
||||
print(f"Error loading audio file: {e}")
|
||||
return
|
||||
|
||||
"""TEST ONLY."""
|
||||
an4_audio = 'audio_test.mp3'
|
||||
signal, sr = librosa.load(an4_audio, sr=16000)
|
||||
signal = signal[:16000*30]
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("Expected ground truth:")
|
||||
print("ground truth:")
|
||||
print("Speaker 0: 0:00 - 0:09")
|
||||
print("Speaker 1: 0:09 - 0:19")
|
||||
print("Speaker 2: 0:19 - 0:25")
|
||||
print("Speaker 0: 0:25 - 0:30")
|
||||
print("=" * 50)
|
||||
|
||||
# Create diarization instance
|
||||
diarization = SortformerDiarization(sample_rate=16000)
|
||||
|
||||
# Chunk and process audio
|
||||
chunk_size = 16000 # 1 second
|
||||
print(f"Processing {len(signal)} samples in {len(signal) // chunk_size} chunks...")
|
||||
diarization = SortformerDiarization(sample_rate=16000)
|
||||
chunk_size = 1600
|
||||
|
||||
for i in range(0, len(signal), chunk_size):
|
||||
chunk = signal[i:i+chunk_size]
|
||||
await diarization.diarize(chunk)
|
||||
print(f"Processed chunk {i // chunk_size + 1}")
|
||||
|
||||
# Close and save WAV
|
||||
# diarization.close()
|
||||
|
||||
# Print results
|
||||
segments = diarization.get_segments()
|
||||
print("\nDiarization results:")
|
||||
for segment in segments:
|
||||
|
||||
75
whisperlivekit/results_formater.py
Normal file
75
whisperlivekit/results_formater.py
Normal file
@@ -0,0 +1,75 @@
|
||||
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from whisperlivekit.remove_silences import handle_silences
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
# Sentence-ending marks used to decide whether a speaker turn may be merged
# into the previous display line.
PUNCTUATION_MARKS = {'.', '!', '?'}


def format_time(seconds: float) -> str:
    """Format a duration in seconds as H:MM:SS (``str(timedelta)`` format).

    Fractional seconds are truncated, not rounded.
    """
    return str(timedelta(seconds=int(seconds)))


def check_punctuation_nearby(i, tokens, window=1):
    """Return True if any of the next ``window`` tokens starting at index ``i``
    is a bare sentence-ending punctuation mark.

    Args:
        i: Index of the current token.
        tokens: Sequence of token objects exposing a ``.text`` attribute.
        window: Number of tokens (starting at ``i``) to inspect. The default
            of 1 inspects only the token at ``i`` itself, preserving the
            original behavior.

    Returns:
        bool: True if a punctuation-only token is found in the window.
    """
    # any() over a bounded range replaces the original redundant length guard
    # plus single-iteration loop; an out-of-range `i` yields an empty range
    # and therefore False, exactly as before.
    return any(
        tokens[ind].text.strip() in PUNCTUATION_MARKS
        for ind in range(i, min(len(tokens), i + window))
    )
|
||||
|
||||
|
||||
|
||||
def format_output(state, silence, current_time, diarization, debug):
    """Turn the shared transcription state into display lines grouped by speaker.

    Args:
        state: Shared dict holding "tokens", "buffer_transcription",
            "buffer_diarization", "end_attributed_speaker" and "sep"
            (the token separator string).
        silence: Silence-tracking state forwarded to handle_silences().
        current_time: Seconds elapsed since processing began, or None.
        diarization: Whether speaker diarization is enabled.
        debug: When True, append "[start : end]" timing info to each token.

    Returns:
        Tuple (lines, undiarized_text, buffer_transcription, buffer_diarization):
        lines is a list of {"speaker", "text", "beg", "end", "diff"} dicts;
        undiarized_text collects token texts not yet attributed to a speaker.
        NOTE(review): the diarization buffer is returned as '' rather than the
        value read from state — presumably cleared deliberately now that
        silences are handled here; confirm against the caller.
    """
    tokens = state["tokens"]
    buffer_transcription = state["buffer_transcription"]
    buffer_diarization = state["buffer_diarization"]
    end_attributed_speaker = state["end_attributed_speaker"]
    sep = state["sep"]

    previous_speaker = -1
    lines = []
    last_end_diarized = 0
    undiarized_text = []
    # Fold detected silences into the token stream before formatting.
    tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, silence)
    for i, token in enumerate(tokens):
        speaker = token.speaker

        if len(tokens) == 1 and not diarization:
            if speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
                speaker = 1

        # Handle diarization. Speaker -2 on the last token appears to mark a
        # dummy/silence token — attribution is skipped then (TODO confirm).
        if diarization and not tokens[-1].speaker == -2:
            if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
                # Past the attributed window: defer this text until a speaker
                # is assigned by diarization.
                undiarized_text.append(token.text)
                continue
            elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
                # Inside the already-attributed window: inherit the previous speaker.
                speaker = previous_speaker
            if speaker not in [-1, 0]:
                last_end_diarized = max(token.end, last_end_diarized)

        debug_info = ""
        if debug:
            debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
        if speaker != previous_speaker or not lines:
            if speaker != previous_speaker and lines and check_punctuation_nearby(i, tokens): # check if punctuation nearby
                # A sentence is still being finished: keep the text on the
                # previous speaker's line rather than splitting mid-sentence.
                lines[-1]["text"] += sep + token.text + debug_info
                lines[-1]["end"] = format_time(token.end)
                lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
            else:
                # Open a new display line for the new speaker.
                lines.append({
                    "speaker": int(speaker),
                    "text": token.text + debug_info,
                    "beg": format_time(token.start),
                    "end": format_time(token.end),
                    "diff": round(token.end - last_end_diarized, 2)
                })
                previous_speaker = speaker
        elif token.text: # Only append if text isn't empty
            # Same speaker: extend the current line and refresh its end time.
            lines[-1]["text"] += sep + token.text + debug_info
            lines[-1]["end"] = format_time(token.end)
            lines[-1]["diff"] = round(token.end - last_end_diarized, 2)

    return lines, undiarized_text, buffer_transcription, ''
|
||||
Reference in New Issue
Block a user