mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
all text-related classes now share a common TimedText base class
This commit is contained in:
@@ -1,15 +0,0 @@
|
||||
class ASRToken:
|
||||
"""
|
||||
A token (word) from the ASR system with start/end times and text.
|
||||
"""
|
||||
def __init__(self, start: float, end: float, text: str):
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.text = text
|
||||
|
||||
def with_offset(self, offset: float) -> "ASRToken":
|
||||
"""Return a new token with the time offset added."""
|
||||
return ASRToken(self.start + offset, self.end + offset, self.text)
|
||||
|
||||
def __repr__(self):
|
||||
return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"
|
||||
@@ -6,7 +6,7 @@ import math
|
||||
import torch
|
||||
from typing import List
|
||||
import numpy as np
|
||||
from src.whisper_streaming.asr_token import ASRToken
|
||||
from src.whisper_streaming.timed_objects import ASRToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -2,37 +2,10 @@ import sys
|
||||
import numpy as np
|
||||
import logging
|
||||
from typing import List, Tuple, Optional
|
||||
from src.whisper_streaming.asr_token import ASRToken
|
||||
from src.whisper_streaming.timed_objects import ASRToken, Sentence, Transcript
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Sentence:
|
||||
"""
|
||||
A sentence assembled from tokens.
|
||||
"""
|
||||
def __init__(self, start: float, end: float, text: str):
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.text = text
|
||||
|
||||
def __repr__(self):
|
||||
return f"Sentence(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"
|
||||
|
||||
class Transcript:
|
||||
"""
|
||||
A transcript that bundles a start time, an end time, and a concatenated text.
|
||||
"""
|
||||
def __init__(self, start: Optional[float], end: Optional[float], text: str):
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.text = text
|
||||
|
||||
def __iter__(self):
|
||||
return iter((self.start, self.end, self.text))
|
||||
|
||||
def __repr__(self):
|
||||
return f"Transcript(start={self.start}, end={self.end}, text={self.text!r})"
|
||||
|
||||
|
||||
class HypothesisBuffer:
|
||||
"""
|
||||
@@ -111,10 +84,6 @@ class HypothesisBuffer:
|
||||
while self.committed_in_buffer and self.committed_in_buffer[0].end <= time:
|
||||
self.committed_in_buffer.pop(0)
|
||||
|
||||
def complete(self) -> List[ASRToken]:
|
||||
"""Return any remaining tokens (i.e. the current buffer)."""
|
||||
return self.buffer
|
||||
|
||||
|
||||
class OnlineASRProcessor:
|
||||
"""
|
||||
@@ -211,7 +180,7 @@ class OnlineASRProcessor:
|
||||
self.committed.extend(committed_tokens)
|
||||
completed = self.concatenate_tokens(committed_tokens)
|
||||
logger.debug(f">>>> COMPLETE NOW: {completed.text}")
|
||||
incomp = self.concatenate_tokens(self.transcript_buffer.complete())
|
||||
incomp = self.concatenate_tokens(self.transcript_buffer.buffer)
|
||||
logger.debug(f"INCOMPLETE: {incomp.text}")
|
||||
|
||||
if committed_tokens and self.buffer_trimming_way == "sentence":
|
||||
@@ -318,7 +287,7 @@ class OnlineASRProcessor:
|
||||
"""
|
||||
Flush the remaining transcript when processing ends.
|
||||
"""
|
||||
remaining_tokens = self.transcript_buffer.complete()
|
||||
remaining_tokens = self.transcript_buffer.buffer
|
||||
final_transcript = self.concatenate_tokens(remaining_tokens)
|
||||
logger.debug(f"Final non-committed transcript: {final_transcript}")
|
||||
self.buffer_time_offset += len(self.audio_buffer) / self.SAMPLING_RATE
|
||||
|
||||
22
src/whisper_streaming/timed_objects.py
Normal file
22
src/whisper_streaming/timed_objects.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
@dataclass
|
||||
class TimedText:
|
||||
start: Optional[float]
|
||||
end: Optional[float]
|
||||
text: str
|
||||
|
||||
@dataclass
|
||||
class ASRToken(TimedText):
|
||||
def with_offset(self, offset: float) -> "ASRToken":
|
||||
"""Return a new token with the time offset added."""
|
||||
return ASRToken(self.start + offset, self.end + offset, self.text)
|
||||
|
||||
@dataclass
|
||||
class Sentence(TimedText):
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class Transcript(TimedText):
|
||||
pass
|
||||
@@ -201,17 +201,17 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
)
|
||||
pcm_buffer = bytearray()
|
||||
online.insert_audio_chunk(pcm_array)
|
||||
beg_trans, end_trans, trans = online.process_iter()
|
||||
transcription = online.process_iter()
|
||||
|
||||
if trans:
|
||||
if transcription:
|
||||
chunk_history.append({
|
||||
"beg": beg_trans,
|
||||
"end": end_trans,
|
||||
"text": trans,
|
||||
"beg": transcription.start,
|
||||
"end": transcription.end,
|
||||
"text": transcription.text,
|
||||
"speaker": "0"
|
||||
})
|
||||
|
||||
full_transcription += trans
|
||||
full_transcription += transcription.text
|
||||
if args.vac:
|
||||
transcript = online.online.concatenate_tokens(online.online.transcript_buffer.buffer)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user