Undiarized text is now assigned to the last speaker, with buffer information; traceback.format_exc() is used to log full error tracebacks

This commit is contained in:
Quentin Fuxa
2025-02-28 18:11:36 +01:00
parent 00eb4a0a4f
commit 14af47e84b

View File

@@ -16,6 +16,7 @@ from src.whisper_streaming.timed_objects import ASRToken
import math
import logging
from datetime import timedelta
import traceback
def format_time(seconds):
    """Render a duration given in seconds as the string form of a timedelta.

    The value is converted with ``int`` first, so any fractional part is
    dropped before formatting (e.g. ``3661.9`` becomes ``"1:01:01"``).
    """
    whole_seconds = int(seconds)
    duration = timedelta(seconds=whole_seconds)
    return str(duration)
@@ -48,7 +49,7 @@ parser.add_argument(
parser.add_argument(
"--diarization",
type=bool,
default=True,
default=False,
help="Whether to enable speaker diarization.",
)
@@ -81,6 +82,7 @@ class SharedState:
self.lock = asyncio.Lock()
self.beg_loop = time()
self.sep = " " # Default separator
self.last_response_content = "" # To track changes in response
async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep):
async with self.lock:
@@ -119,7 +121,7 @@ class SharedState:
# Calculate remaining time for diarization
if self.end_attributed_speaker > 0:
remaining_time_diarization = max(0, round(current_time - self.beg_loop - self.end_attributed_speaker, 2))
remaining_time_diarization = max(0, round(max(self.end_buffer, self.tokens[-1].end if self.tokens else 0) - self.end_attributed_speaker, 2))
return {
"tokens": self.tokens.copy(),
@@ -142,6 +144,7 @@ class SharedState:
self.end_attributed_speaker = 0
self.full_transcription = ""
self.beg_loop = time()
self.last_response_content = ""
##### LOAD APP #####
@@ -221,6 +224,7 @@ async def transcription_processor(shared_state, pcm_queue, online):
except Exception as e:
logger.warning(f"Exception in transcription_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
finally:
pcm_queue.task_done()
@@ -247,6 +251,7 @@ async def diarization_processor(shared_state, pcm_queue, diarization_obj):
except Exception as e:
logger.warning(f"Exception in diarization_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
finally:
pcm_queue.task_done()
@@ -272,21 +277,28 @@ async def results_formatter(shared_state, websocket):
# Process tokens to create response
previous_speaker = -10
lines = []
lines = [
]
last_end_diarized = 0
undiarized_text = []
for token in tokens:
speaker = token.speaker
# Handle diarization differently if diarization is enabled
if args.diarization:
if speaker == -1 or speaker == 0:
if token.end < end_attributed_speaker:
speaker = previous_speaker
else:
speaker = 0
else:
# If token is not yet processed by diarization
if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker:
# Add this token's text to undiarized buffer instead of creating a new line
undiarized_text.append(token.text)
continue
# If speaker isn't assigned yet but should be (based on timestamp)
elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker:
speaker = previous_speaker
# Track last diarized token end time
if speaker not in [-1, 0]:
last_end_diarized = max(token.end, last_end_diarized)
if speaker != previous_speaker:
if speaker != previous_speaker or not lines:
lines.append(
{
"speaker": speaker,
@@ -302,22 +314,53 @@ async def results_formatter(shared_state, websocket):
lines[-1]["end"] = format_time(token.end)
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
# Update buffer_diarization with undiarized text
if undiarized_text:
combined_buffer_diarization = sep.join(undiarized_text)
if buffer_transcription:
combined_buffer_diarization += sep
await shared_state.update_diarization(end_attributed_speaker, combined_buffer_diarization)
buffer_diarization = combined_buffer_diarization
# Prepare response object
response = {
"lines": lines,
"buffer_transcription": buffer_transcription,
"buffer_diarization": buffer_diarization,
"remaining_time_transcription": remaining_time_transcription,
"remaining_time_diarization": remaining_time_diarization
}
if lines:
response = {
"lines": lines,
"buffer_transcription": buffer_transcription,
"buffer_diarization": buffer_diarization,
"remaining_time_transcription": remaining_time_transcription,
"remaining_time_diarization": remaining_time_diarization
}
else:
response = {
"lines": [{
"speaker": 1,
"text": "",
"beg": format_time(0),
"end": format_time(token.end) if token else format_time(0),
"diff": 0
}],
"buffer_transcription": buffer_transcription,
"buffer_diarization": buffer_diarization,
"remaining_time_transcription": remaining_time_transcription,
"remaining_time_diarization": remaining_time_diarization
}
await websocket.send_json(response)
response_content = ' '.join([str(line['speaker']) + ' ' + line["text"] for line in lines]) + ' | ' + buffer_transcription + ' | ' + buffer_diarization
if response_content != shared_state.last_response_content:
# Only send if there's actual content to send
if lines or buffer_transcription or buffer_diarization:
await websocket.send_json(response)
shared_state.last_response_content = response_content
# Add a small delay to avoid overwhelming the client
await asyncio.sleep(0.1)
except Exception as e:
logger.warning(f"Exception in results_formatter: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
await asyncio.sleep(0.5) # Back off on error
##### ENDPOINTS #####
@@ -422,6 +465,7 @@ async def websocket_endpoint(websocket: WebSocket):
except Exception as e:
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
break
logger.info("Exiting ffmpeg_stdout_reader...")