mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
96 lines
3.9 KiB
Python
96 lines
3.9 KiB
Python
import asyncio
|
|
import logging
|
|
from time import time
|
|
from typing import List, Dict, Any, Optional
|
|
from dataclasses import dataclass, field
|
|
from timed_objects import ASRToken
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SharedState:
    """
    Coroutine-safe state manager for streaming transcription and diarization.

    Handles coordination between audio processing, transcription, and diarization.

    Every public coroutine serializes on a single ``asyncio.Lock``. That lock
    coordinates coroutines running on one event loop; it does NOT protect the
    state against concurrent access from other OS threads.
    """

    # Separator in effect until a transcription update supplies its own.
    DEFAULT_SEP = " "

    def __init__(self):
        # The lock is created once and deliberately survives reset().
        self.lock = asyncio.Lock()
        self._restore_defaults()

    def _restore_defaults(self) -> None:
        """Put every session field back to its initial value.

        Does not touch ``self.lock``. Callers that may race with other
        coroutines must already hold the lock.
        """
        self.tokens: "List[ASRToken]" = []
        self.buffer_transcription: str = ""
        self.buffer_diarization: str = ""
        self.full_transcription: str = ""
        self.end_buffer: float = 0
        self.end_attributed_speaker: float = 0
        self.beg_loop: float = time()  # wall-clock reference for elapsed time
        self.sep: str = self.DEFAULT_SEP
        self.last_response_content: str = ""  # To track changes in response

    async def update_transcription(self, new_tokens: "List[ASRToken]", buffer: str,
                                   end_buffer: float, full_transcription: str, sep: str) -> None:
        """Update the state with new transcription data.

        ``new_tokens`` are appended to the existing token list; ``buffer``,
        ``end_buffer``, ``full_transcription`` and ``sep`` overwrite the
        previously stored values.
        """
        async with self.lock:
            self.tokens.extend(new_tokens)
            self.buffer_transcription = buffer
            self.end_buffer = end_buffer
            self.full_transcription = full_transcription
            self.sep = sep

    async def update_diarization(self, end_attributed_speaker: float, buffer_diarization: str = "") -> None:
        """Update the state with new diarization data.

        ``end_attributed_speaker`` is always stored. ``buffer_diarization`` is
        stored only when it is non-empty, so an empty value keeps the
        previously stored buffer rather than clearing it.
        """
        async with self.lock:
            self.end_attributed_speaker = end_attributed_speaker
            if buffer_diarization:
                self.buffer_diarization = buffer_diarization

    async def add_dummy_token(self) -> None:
        """Add a dummy token to keep the state updated even without transcription.

        The token starts at the time elapsed since ``beg_loop``, ends one unit
        later, carries the text ``"."``, speaker ``-1`` and ``is_dummy=True``.
        """
        async with self.lock:
            elapsed = time() - self.beg_loop
            self.tokens.append(
                ASRToken(
                    start=elapsed,
                    end=elapsed + 1,
                    text=".",
                    speaker=-1,
                    is_dummy=True,
                )
            )

    async def get_current_state(self) -> Dict[str, Any]:
        """Get the current state with calculated timing information.

        Returns:
            A dict with a shallow copy of the token list, both buffers,
            ``end_buffer``, ``end_attributed_speaker``, ``sep`` and two derived
            values, each clamped at 0 and rounded to two decimals:

            * ``remaining_time_transcription``: time elapsed since
              ``beg_loop`` minus ``end_buffer``; left at 0 while
              ``end_buffer`` is not positive.
            * ``remaining_time_diarization``: the later of ``end_buffer`` and
              the last token's ``end``, minus ``end_attributed_speaker``; left
              at 0 while there are no tokens.
        """
        async with self.lock:
            now = time()
            remaining_time_transcription = 0
            remaining_time_diarization = 0

            # Calculate remaining time for transcription buffer
            if self.end_buffer > 0:
                lag = now - self.beg_loop - self.end_buffer
                remaining_time_transcription = max(0, round(lag, 2))

            # Calculate remaining time for diarization
            if self.tokens:
                latest_end = max(self.end_buffer, self.tokens[-1].end)
                remaining_time_diarization = max(
                    0, round(latest_end - self.end_attributed_speaker, 2)
                )

            return {
                # Copy so callers can iterate without holding the lock.
                "tokens": self.tokens.copy(),
                "buffer_transcription": self.buffer_transcription,
                "buffer_diarization": self.buffer_diarization,
                "end_buffer": self.end_buffer,
                "end_attributed_speaker": self.end_attributed_speaker,
                "sep": self.sep,
                "remaining_time_transcription": remaining_time_transcription,
                "remaining_time_diarization": remaining_time_diarization,
            }

    async def reset(self) -> None:
        """Reset the state to initial values.

        After this call the object is equivalent to a freshly constructed one:
        tokens and buffers are emptied, timing markers are zeroed, ``beg_loop``
        is restarted and ``sep`` is restored to ``DEFAULT_SEP``. The lock
        object itself is kept.
        """
        async with self.lock:
            self._restore_defaults()