All text-related classes now share a common TimedText base class.

This commit is contained in:
Quentin Fuxa
2025-02-07 23:12:04 +01:00
parent f7f1f259c1
commit 1f6119e405
5 changed files with 32 additions and 56 deletions

View File

@@ -1,15 +0,0 @@
class ASRToken:
    """A single recognized word from the ASR system.

    Carries the word text together with its start/end timestamps
    (seconds, relative to some audio origin).
    """

    def __init__(self, start: float, end: float, text: str):
        self.start = start
        self.end = end
        self.text = text

    def with_offset(self, offset: float) -> "ASRToken":
        """Return a new token whose timestamps are shifted by *offset* seconds."""
        shifted_start = self.start + offset
        shifted_end = self.end + offset
        return ASRToken(shifted_start, shifted_end, self.text)

    def __repr__(self):
        return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"

View File

@@ -6,7 +6,7 @@ import math
import torch
from typing import List
import numpy as np
from src.whisper_streaming.asr_token import ASRToken
from src.whisper_streaming.timed_objects import ASRToken
logger = logging.getLogger(__name__)

View File

@@ -2,37 +2,10 @@ import sys
import numpy as np
import logging
from typing import List, Tuple, Optional
from src.whisper_streaming.asr_token import ASRToken
from src.whisper_streaming.timed_objects import ASRToken, Sentence, Transcript
logger = logging.getLogger(__name__)
class Sentence:
    """A sentence assembled from consecutive ASR tokens, with its time span."""

    def __init__(self, start: float, end: float, text: str):
        self.start = start
        self.end = end
        self.text = text

    def __repr__(self):
        return f"Sentence(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"
class Transcript:
    """A transcript span: start time, end time, and the concatenated text.

    Start/end may be None when the span holds no timed content.
    """

    def __init__(self, start: Optional[float], end: Optional[float], text: str):
        self.start = start
        self.end = end
        self.text = text

    def __iter__(self):
        # Supports tuple-style unpacking: start, end, text = transcript
        yield self.start
        yield self.end
        yield self.text

    def __repr__(self):
        return f"Transcript(start={self.start}, end={self.end}, text={self.text!r})"
class HypothesisBuffer:
"""
@@ -111,10 +84,6 @@ class HypothesisBuffer:
while self.committed_in_buffer and self.committed_in_buffer[0].end <= time:
self.committed_in_buffer.pop(0)
def complete(self) -> List[ASRToken]:
    """Return any remaining tokens (i.e. the current buffer).

    NOTE: returns the buffer list itself, not a copy — callers share
    the live object and will observe later mutations.
    """
    return self.buffer
class OnlineASRProcessor:
"""
@@ -211,7 +180,7 @@ class OnlineASRProcessor:
self.committed.extend(committed_tokens)
completed = self.concatenate_tokens(committed_tokens)
logger.debug(f">>>> COMPLETE NOW: {completed.text}")
incomp = self.concatenate_tokens(self.transcript_buffer.complete())
incomp = self.concatenate_tokens(self.transcript_buffer.buffer)
logger.debug(f"INCOMPLETE: {incomp.text}")
if committed_tokens and self.buffer_trimming_way == "sentence":
@@ -318,7 +287,7 @@ class OnlineASRProcessor:
"""
Flush the remaining transcript when processing ends.
"""
remaining_tokens = self.transcript_buffer.complete()
remaining_tokens = self.transcript_buffer.buffer
final_transcript = self.concatenate_tokens(remaining_tokens)
logger.debug(f"Final non-committed transcript: {final_transcript}")
self.buffer_time_offset += len(self.audio_buffer) / self.SAMPLING_RATE

View File

@@ -0,0 +1,22 @@
from dataclasses import dataclass
from typing import Optional


@dataclass
class TimedText:
    """Common base for any piece of text carrying start/end timestamps."""
    start: Optional[float]
    end: Optional[float]
    text: str


@dataclass
class ASRToken(TimedText):
    """A single recognized word with its time span."""

    def with_offset(self, offset: float) -> "ASRToken":
        """Return a new token whose timestamps are shifted by *offset* seconds."""
        return ASRToken(
            start=self.start + offset,
            end=self.end + offset,
            text=self.text,
        )


@dataclass
class Sentence(TimedText):
    """A sentence assembled from consecutive tokens."""


@dataclass
class Transcript(TimedText):
    """A concatenated transcript span."""

View File

@@ -201,17 +201,17 @@ async def websocket_endpoint(websocket: WebSocket):
)
pcm_buffer = bytearray()
online.insert_audio_chunk(pcm_array)
beg_trans, end_trans, trans = online.process_iter()
transcription = online.process_iter()
if trans:
if transcription:
chunk_history.append({
"beg": beg_trans,
"end": end_trans,
"text": trans,
"beg": transcription.start,
"end": transcription.end,
"text": transcription.text,
"speaker": "0"
})
full_transcription += trans
full_transcription += transcription.text
if args.vac:
transcript = online.online.concatenate_tokens(online.online.transcript_buffer.buffer)
else: