mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
undiarized text is assigned to last speaker, with buffer information; traceback is used to format_exc errors
This commit is contained in:
@@ -16,6 +16,7 @@ from src.whisper_streaming.timed_objects import ASRToken
|
||||
import math
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
import traceback
|
||||
|
||||
def format_time(seconds):
|
||||
return str(timedelta(seconds=int(seconds)))
|
||||
@@ -48,7 +49,7 @@ parser.add_argument(
|
||||
parser.add_argument(
|
||||
"--diarization",
|
||||
type=bool,
|
||||
default=True,
|
||||
default=False,
|
||||
help="Whether to enable speaker diarization.",
|
||||
)
|
||||
|
||||
@@ -81,6 +82,7 @@ class SharedState:
|
||||
self.lock = asyncio.Lock()
|
||||
self.beg_loop = time()
|
||||
self.sep = " " # Default separator
|
||||
self.last_response_content = "" # To track changes in response
|
||||
|
||||
async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep):
|
||||
async with self.lock:
|
||||
@@ -119,7 +121,7 @@ class SharedState:
|
||||
|
||||
# Calculate remaining time for diarization
|
||||
if self.end_attributed_speaker > 0:
|
||||
remaining_time_diarization = max(0, round(current_time - self.beg_loop - self.end_attributed_speaker, 2))
|
||||
remaining_time_diarization = max(0, round(max(self.end_buffer, self.tokens[-1].end if self.tokens else 0) - self.end_attributed_speaker, 2))
|
||||
|
||||
return {
|
||||
"tokens": self.tokens.copy(),
|
||||
@@ -142,6 +144,7 @@ class SharedState:
|
||||
self.end_attributed_speaker = 0
|
||||
self.full_transcription = ""
|
||||
self.beg_loop = time()
|
||||
self.last_response_content = ""
|
||||
|
||||
##### LOAD APP #####
|
||||
|
||||
@@ -221,6 +224,7 @@ async def transcription_processor(shared_state, pcm_queue, online):
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in transcription_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
finally:
|
||||
pcm_queue.task_done()
|
||||
|
||||
@@ -247,6 +251,7 @@ async def diarization_processor(shared_state, pcm_queue, diarization_obj):
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in diarization_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
finally:
|
||||
pcm_queue.task_done()
|
||||
|
||||
@@ -272,21 +277,28 @@ async def results_formatter(shared_state, websocket):
|
||||
|
||||
# Process tokens to create response
|
||||
previous_speaker = -10
|
||||
lines = []
|
||||
lines = [
|
||||
]
|
||||
last_end_diarized = 0
|
||||
undiarized_text = []
|
||||
|
||||
for token in tokens:
|
||||
speaker = token.speaker
|
||||
# Handle diarization differently if diarization is enabled
|
||||
if args.diarization:
|
||||
if speaker == -1 or speaker == 0:
|
||||
if token.end < end_attributed_speaker:
|
||||
speaker = previous_speaker
|
||||
else:
|
||||
speaker = 0
|
||||
else:
|
||||
# If token is not yet processed by diarization
|
||||
if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker:
|
||||
# Add this token's text to undiarized buffer instead of creating a new line
|
||||
undiarized_text.append(token.text)
|
||||
continue
|
||||
# If speaker isn't assigned yet but should be (based on timestamp)
|
||||
elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker:
|
||||
speaker = previous_speaker
|
||||
# Track last diarized token end time
|
||||
if speaker not in [-1, 0]:
|
||||
last_end_diarized = max(token.end, last_end_diarized)
|
||||
|
||||
if speaker != previous_speaker:
|
||||
if speaker != previous_speaker or not lines:
|
||||
lines.append(
|
||||
{
|
||||
"speaker": speaker,
|
||||
@@ -302,22 +314,53 @@ async def results_formatter(shared_state, websocket):
|
||||
lines[-1]["end"] = format_time(token.end)
|
||||
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
|
||||
|
||||
# Update buffer_diarization with undiarized text
|
||||
if undiarized_text:
|
||||
combined_buffer_diarization = sep.join(undiarized_text)
|
||||
if buffer_transcription:
|
||||
combined_buffer_diarization += sep
|
||||
await shared_state.update_diarization(end_attributed_speaker, combined_buffer_diarization)
|
||||
buffer_diarization = combined_buffer_diarization
|
||||
|
||||
# Prepare response object
|
||||
response = {
|
||||
"lines": lines,
|
||||
"buffer_transcription": buffer_transcription,
|
||||
"buffer_diarization": buffer_diarization,
|
||||
"remaining_time_transcription": remaining_time_transcription,
|
||||
"remaining_time_diarization": remaining_time_diarization
|
||||
}
|
||||
if lines:
|
||||
response = {
|
||||
"lines": lines,
|
||||
"buffer_transcription": buffer_transcription,
|
||||
"buffer_diarization": buffer_diarization,
|
||||
"remaining_time_transcription": remaining_time_transcription,
|
||||
"remaining_time_diarization": remaining_time_diarization
|
||||
}
|
||||
else:
|
||||
response = {
|
||||
"lines": [{
|
||||
"speaker": 1,
|
||||
"text": "",
|
||||
"beg": format_time(0),
|
||||
"end": format_time(token.end) if token else format_time(0),
|
||||
"diff": 0
|
||||
}],
|
||||
"buffer_transcription": buffer_transcription,
|
||||
"buffer_diarization": buffer_diarization,
|
||||
"remaining_time_transcription": remaining_time_transcription,
|
||||
"remaining_time_diarization": remaining_time_diarization
|
||||
|
||||
}
|
||||
|
||||
await websocket.send_json(response)
|
||||
response_content = ' '.join([str(line['speaker']) + ' ' + line["text"] for line in lines]) + ' | ' + buffer_transcription + ' | ' + buffer_diarization
|
||||
|
||||
if response_content != shared_state.last_response_content:
|
||||
# Only send if there's actual content to send
|
||||
if lines or buffer_transcription or buffer_diarization:
|
||||
await websocket.send_json(response)
|
||||
shared_state.last_response_content = response_content
|
||||
|
||||
# Add a small delay to avoid overwhelming the client
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in results_formatter: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
await asyncio.sleep(0.5) # Back off on error
|
||||
|
||||
##### ENDPOINTS #####
|
||||
@@ -422,6 +465,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
break
|
||||
|
||||
logger.info("Exiting ffmpeg_stdout_reader...")
|
||||
|
||||
Reference in New Issue
Block a user