diff --git a/README.md b/README.md
index f68557a..e7fd47c 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-# Real-time, fully local Speech-to-Text and speaker diarization using FastAPI WebSockets with a web interface
+# Real-time, Fully Local Speech-to-Text and Speaker Diarization
 
 This project is based on [Whisper Streaming](https://github.com/ufal/whisper_streaming) and lets you transcribe audio directly from your browser. Simply launch the local server and grant microphone access. Everything runs locally on your machine ✨
 
@@ -8,24 +8,23 @@
 ### Differences from [Whisper Streaming](https://github.com/ufal/whisper_streaming)
 
-#### 🌐 **Web & API**
-- **Built-in Web UI** – No frontend setup required, just open your browser and start transcribing.
-- **FastAPI WebSocket Server** – Real-time speech-to-text processing with async FFmpeg streaming.
-- **JavaScript Client** – Ready-to-use MediaRecorder implementation for seamless client-side integration.
-
 #### ⚙️ **Core Improvements**
 - **Buffering Preview** – Displays unvalidated transcription segments for immediate feedback.
 - **Multi-User Support** – Handles multiple users simultaneously without conflicts.
 - **MLX Whisper Backend** – Optimized for Apple Silicon for faster local processing.
 - **Enhanced Sentence Segmentation** – Improved buffer trimming for better accuracy across languages.
-- **Extended Logging** – More detailed logs to improve debugging and monitoring.
+- **Confidence validation** – Immediately validate high-confidence tokens for faster inference.
 
-#### 🎙️ **Advanced Features**
-- **Real-Time Diarization** – Identify different speakers in real time using [Diart](https://github.com/juanmc2005/diart).
+#### 🎙️ **Speaker Identification**
+- **Real-Time Diarization** – Identify different speakers in real time using [Diart](https://github.com/juanmc2005/diart).
+
+#### 🌐 **Web & API**
+- **Built-in Web UI** – Simple browser interface with no frontend setup required.
+- **FastAPI WebSocket Server** – Real-time speech-to-text processing with async FFmpeg streaming.
+- **JavaScript Client** – Ready-to-use MediaRecorder implementation for seamless client-side integration.
 
 #### 🚀 **Coming Soon**
-- **Faster Word Validation** – Accelerate real-time transcription by validating high-confidence words immediately upon first appearance for whisper backends that return word & segment probabilities
 - **Enhanced Diarization Performance** – Optimize speaker identification by implementing longer steps for Diart processing and leveraging language-specific segmentation patterns to improve speaker boundary detection
 
@@ -87,12 +86,13 @@
     python whisper_fastapi_online_server.py --host 0.0.0.0 --port 8000
     ```
 
+All [Whisper Streaming](https://github.com/ufal/whisper_streaming) parameters are supported.
+Additional parameters:
 - `--host` and `--port` let you specify the server’s IP/port.
 - `-min-chunk-size` sets the minimum chunk size for audio processing. Make sure this value aligns with the chunk size selected in the frontend. If not aligned, the system will work but may unnecessarily over-process audio data.
-  - For a full list of configurable options, run `python whisper_fastapi_online_server.py -h`
-  - `--transcription`, default to True. Change to False if you want to run only diarization
-  - `--diarization`, default to False, let you choose whether or not you want to run diarization in parallel
-  - For other parameters, look at [whisper streaming](https://github.com/ufal/whisper_streaming) readme.
+  - `--transcription`: Enable/disable transcription (default: True)
+  - `--diarization`: Enable/disable speaker diarization (default: False)
+  - `--confidence-validation`: Use confidence scores for faster validation. Transcription will be faster but punctuation might be less accurate (default: True; disable with `--no-confidence-validation`)
 
 4. **Open the Provided HTML**:
diff --git a/timed_objects.py b/timed_objects.py
index b1baa0a..c1c3e4e 100644
--- a/timed_objects.py
+++ b/timed_objects.py
@@ -7,12 +7,13 @@ class TimedText:
     end: Optional[float]
     text: Optional[str] = ''
    speaker: Optional[int] = -1
+    probability: Optional[float] = None

 @dataclass
 class ASRToken(TimedText):
     def with_offset(self, offset: float) -> "ASRToken":
         """Return a new token with the time offset added."""
-        return ASRToken(self.start + offset, self.end + offset, self.text)
+        return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability)
 
 @dataclass
 class Sentence(TimedText):
diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py
index c4a12f4..d55e104 100644
--- a/whisper_fastapi_online_server.py
+++ b/whisper_fastapi_online_server.py
@@ -46,6 +46,13 @@ parser.add_argument(
     help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
 )
 
+parser.add_argument(
+    "--confidence-validation",
+    action=argparse.BooleanOptionalAction,
+    default=True,
+    help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
+)
+
 parser.add_argument(
     "--diarization",
     type=bool,
diff --git a/whisper_streaming_custom/backends.py b/whisper_streaming_custom/backends.py
index 7a88315..fa52104 100644
--- a/whisper_streaming_custom/backends.py
+++ b/whisper_streaming_custom/backends.py
@@ -131,7 +131,7 @@ class FasterWhisperASR(ASRBase):
             if segment.no_speech_prob > 0.9:
                 continue
             for word in segment.words:
-                token = ASRToken(word.start, word.end, word.word)
+                token = ASRToken(word.start, word.end, word.word, probability=word.probability)
                 tokens.append(token)
         return tokens
 
@@ -210,7 +210,7 @@ class MLXWhisper(ASRBase):
             if segment.get("no_speech_prob", 0) > 0.9:
                 continue
             for word in segment.get("words", []):
-                token = ASRToken(word["start"], word["end"], word["word"])
+                token = ASRToken(word["start"], word["end"], word["word"], probability=word.get("probability"))
                 tokens.append(token)
         return tokens
diff --git a/whisper_streaming_custom/online_asr.py b/whisper_streaming_custom/online_asr.py
index 6fa7d9e..bc09395 100644
--- a/whisper_streaming_custom/online_asr.py
+++ b/whisper_streaming_custom/online_asr.py
@@ -16,7 +16,8 @@ class HypothesisBuffer:
     - buffer: the last hypothesis that is not yet committed
     - new: new tokens coming from the recognizer
     """
-    def __init__(self, logfile=sys.stderr):
+    def __init__(self, logfile=sys.stderr, confidence_validation=False):
+        self.confidence_validation = confidence_validation
         self.committed_in_buffer: List[ASRToken] = []
         self.buffer: List[ASRToken] = []
         self.new: List[ASRToken] = []
@@ -62,9 +63,15 @@ class HypothesisBuffer:
         committed: List[ASRToken] = []
         while self.new:
             current_new = self.new[0]
-            if not self.buffer:
+            if self.confidence_validation and current_new.probability and current_new.probability > 0.95:
+                committed.append(current_new)
+                self.last_committed_word = current_new.text
+                self.last_committed_time = current_new.end
+                self.new.pop(0)
+                if self.buffer: self.buffer.pop(0)  # keep buffer aligned with new
+            elif not self.buffer:
                 break
-            if current_new.text == self.buffer[0].text:
+            elif current_new.text == self.buffer[0].text:
                 committed.append(current_new)
                 self.last_committed_word = current_new.text
                 self.last_committed_time = current_new.end
@@ -102,6 +109,7 @@ class OnlineASRProcessor:
         asr,
         tokenize_method: Optional[callable] = None,
         buffer_trimming: Tuple[str, float] = ("segment", 15),
+        confidence_validation=False,
         logfile=sys.stderr,
     ):
         """
@@ -114,7 +122,7 @@ class OnlineASRProcessor:
         self.asr = asr
         self.tokenize = tokenize_method
         self.logfile = logfile
-
+        self.confidence_validation = confidence_validation
         self.init()
 
         self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
@@ -131,7 +139,7 @@ class OnlineASRProcessor:
     def init(self, offset: Optional[float] = None):
         """Initialize or reset the processing buffers."""
         self.audio_buffer = np.array([], dtype=np.float32)
-        self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
+        self.transcript_buffer = HypothesisBuffer(logfile=self.logfile, confidence_validation=self.confidence_validation)
         self.buffer_time_offset = offset if offset is not None else 0.0
         self.transcript_buffer.last_committed_time = self.buffer_time_offset
         self.committed: List[ASRToken] = []
@@ -323,13 +331,14 @@ class OnlineASRProcessor:
     ) -> Transcript:
         sep = sep if sep is not None else self.asr.sep
         text = sep.join(token.text for token in tokens)
+        probability = sum(probs) / len(probs) if (probs := [t.probability for t in tokens if t.probability is not None]) else None
         if tokens:
             start = offset + tokens[0].start
             end = offset + tokens[-1].end
         else:
             start = None
             end = None
-        return Transcript(start, end, text)
+        return Transcript(start, end, text, probability=probability)
 
 
 class VACOnlineASRProcessor:
diff --git a/whisper_streaming_custom/whisper_online.py b/whisper_streaming_custom/whisper_online.py
index f997dcd..29e5a22 100644
--- a/whisper_streaming_custom/whisper_online.py
+++ b/whisper_streaming_custom/whisper_online.py
@@ -77,7 +77,7 @@ def add_shared_args(parser):
     parser.add_argument(
         "--model",
         type=str,
-        default="tiny.en",
+        default="large-v3-turbo",
         choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
             ","
         ),
@@ -207,6 +207,7 @@ def online_factory(args, asr, tokenizer, logfile=sys.stderr):
             tokenizer,
             logfile=logfile,
             buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
+            confidence_validation=args.confidence_validation,
         )
     else:
         online = OnlineASRProcessor(
@@ -214,6 +215,7 @@ def online_factory(args, asr, tokenizer, logfile=sys.stderr):
             asr,
             tokenizer,
             logfile=logfile,
             buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
+            confidence_validation=args.confidence_validation,
         )
     return online