From 1f6119e405dfe2b71894736d3eed1b8c9a57ab4c Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Fri, 7 Feb 2025 23:12:04 +0100 Subject: [PATCH] all text-related classes now share a common TimedText base class --- src/whisper_streaming/asr_token.py | 15 ----------- src/whisper_streaming/backends.py | 2 +- src/whisper_streaming/online_asr.py | 37 +++----------------------- src/whisper_streaming/timed_objects.py | 22 +++++++++++++++ whisper_fastapi_online_server.py | 12 ++++----- 5 files changed, 32 insertions(+), 56 deletions(-) delete mode 100644 src/whisper_streaming/asr_token.py create mode 100644 src/whisper_streaming/timed_objects.py diff --git a/src/whisper_streaming/asr_token.py b/src/whisper_streaming/asr_token.py deleted file mode 100644 index 2d9f232..0000000 --- a/src/whisper_streaming/asr_token.py +++ /dev/null @@ -1,15 +0,0 @@ -class ASRToken: - """ - A token (word) from the ASR system with start/end times and text. - """ - def __init__(self, start: float, end: float, text: str): - self.start = start - self.end = end - self.text = text - - def with_offset(self, offset: float) -> "ASRToken": - """Return a new token with the time offset added.""" - return ASRToken(self.start + offset, self.end + offset, self.text) - - def __repr__(self): - return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})" \ No newline at end of file diff --git a/src/whisper_streaming/backends.py b/src/whisper_streaming/backends.py index 514dded..96fdbed 100644 --- a/src/whisper_streaming/backends.py +++ b/src/whisper_streaming/backends.py @@ -6,7 +6,7 @@ import math import torch from typing import List import numpy as np -from src.whisper_streaming.asr_token import ASRToken +from src.whisper_streaming.timed_objects import ASRToken logger = logging.getLogger(__name__) diff --git a/src/whisper_streaming/online_asr.py b/src/whisper_streaming/online_asr.py index 35b093c..605f9a7 100644 --- a/src/whisper_streaming/online_asr.py +++ b/src/whisper_streaming/online_asr.py @@ -2,37 +2,10 @@ import sys import numpy as np import logging from typing import List, Tuple, Optional -from src.whisper_streaming.asr_token import ASRToken +from src.whisper_streaming.timed_objects import ASRToken, Sentence, Transcript logger = logging.getLogger(__name__) -class Sentence: - """ - A sentence assembled from tokens. - """ - def __init__(self, start: float, end: float, text: str): - self.start = start - self.end = end - self.text = text - - def __repr__(self): - return f"Sentence(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})" - -class Transcript: - """ - A transcript that bundles a start time, an end time, and a concatenated text. - """ - def __init__(self, start: Optional[float], end: Optional[float], text: str): - self.start = start - self.end = end - self.text = text - - def __iter__(self): - return iter((self.start, self.end, self.text)) - - def __repr__(self): - return f"Transcript(start={self.start}, end={self.end}, text={self.text!r})" - class HypothesisBuffer: """ @@ -111,10 +84,6 @@ class HypothesisBuffer: while self.committed_in_buffer and self.committed_in_buffer[0].end <= time: self.committed_in_buffer.pop(0) - def complete(self) -> List[ASRToken]: - """Return any remaining tokens (i.e. the current buffer).""" - return self.buffer - class OnlineASRProcessor: """ @@ -211,7 +180,7 @@ class OnlineASRProcessor: self.committed.extend(committed_tokens) completed = self.concatenate_tokens(committed_tokens) logger.debug(f">>>> COMPLETE NOW: {completed.text}") - incomp = self.concatenate_tokens(self.transcript_buffer.complete()) + incomp = self.concatenate_tokens(self.transcript_buffer.buffer) logger.debug(f"INCOMPLETE: {incomp.text}") if committed_tokens and self.buffer_trimming_way == "sentence": @@ -318,7 +287,7 @@ class OnlineASRProcessor: """ Flush the remaining transcript when processing ends. """ - remaining_tokens = self.transcript_buffer.complete() + remaining_tokens = self.transcript_buffer.buffer final_transcript = self.concatenate_tokens(remaining_tokens) logger.debug(f"Final non-committed transcript: {final_transcript}") self.buffer_time_offset += len(self.audio_buffer) / self.SAMPLING_RATE diff --git a/src/whisper_streaming/timed_objects.py b/src/whisper_streaming/timed_objects.py new file mode 100644 index 0000000..19f3d4f --- /dev/null +++ b/src/whisper_streaming/timed_objects.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from typing import Optional + +@dataclass +class TimedText: + start: Optional[float] + end: Optional[float] + text: str + +@dataclass +class ASRToken(TimedText): + def with_offset(self, offset: float) -> "ASRToken": + """Return a new token with the time offset added.""" + return ASRToken(self.start + offset, self.end + offset, self.text) + +@dataclass +class Sentence(TimedText): + pass + +@dataclass +class Transcript(TimedText): + pass \ No newline at end of file diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py index 7138fa4..c43262d 100644 --- a/whisper_fastapi_online_server.py +++ b/whisper_fastapi_online_server.py @@ -201,17 +201,17 @@ async def websocket_endpoint(websocket: WebSocket): ) pcm_buffer = bytearray() online.insert_audio_chunk(pcm_array) - beg_trans, end_trans, trans = online.process_iter() + transcription = online.process_iter() - if trans: + if transcription: chunk_history.append({ - "beg": beg_trans, - "end": end_trans, - "text": trans, + "beg": transcription.start, + "end": transcription.end, + "text": transcription.text, "speaker": "0" }) - full_transcription += trans + full_transcription += transcription.text if args.vac: transcript = online.online.concatenate_tokens(online.online.transcript_buffer.buffer) else: