diff --git a/src/whisper_streaming/online_asr.py b/src/whisper_streaming/online_asr.py index 1a07f35..e329f86 100644 --- a/src/whisper_streaming/online_asr.py +++ b/src/whisper_streaming/online_asr.py @@ -255,11 +255,8 @@ class OnlineASRProcessor: + completed = self.chunk_completed_segment() - completed = self.concatenate_tsw(self.commited_not_final) - self.commited_not_final = [] - self.chunk_completed_segment(res) - # TODO: I don't know if res is the correct variable to pass here else: completed = [] @@ -285,18 +282,19 @@ class OnlineASRProcessor: return completed_text_segment - def chunk_completed_segment(self, res): - if self.final_transcript == []: - return + def chunk_completed_segment(self) -> list: - ends = self.asr.segments_end_ts(res) + + ts_words = self.commited_not_final - t = self.final_transcript[-1][1] - - if len(ends) <= 1: + if len(ts_words) <= 1: logger.debug(f"--- not enough segments to chunk (<=1 words)") + return [] else: + ends = [w[1] for w in ts_words] + + t = ts_words[-1][1] e = ends[-2] + self.buffer_time_offset while len(ends) > 2 and e > t: ends.pop(-1) @@ -304,8 +302,20 @@ class OnlineASRProcessor: if e <= t: logger.debug(f"--- segment chunked at {e:2.2f}") self.chunk_at(e) + + n_commited_words = len(ends)-1 + + words_to_commit = ts_words[:n_commited_words] + self.final_transcript.extend(words_to_commit) + self.commited_not_final = ts_words[n_commited_words:] + + return words_to_commit + + + else: logger.debug(f"--- last segment not within commited area") + return [] def chunk_at(self, time): @@ -379,6 +389,8 @@ class OnlineASRProcessor: if sep is None: sep = self.asr.sep + + t = sep.join(s[2] for s in tsw) if len(tsw) == 0: b = None