mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
use confidence scores returned by whisper to immediately validate tokens
This commit is contained in:
28
README.md
28
README.md
@@ -1,4 +1,4 @@
|
||||
# Real-time, fully local Speech-to-Text and speaker diarization using FastAPI WebSockets with a web interface
|
||||
# Real-time, Fully Local Speech-to-Text and Speaker Diarization
|
||||
|
||||
This project is based on [Whisper Streaming](https://github.com/ufal/whisper_streaming) and lets you transcribe audio directly from your browser. Simply launch the local server and grant microphone access. Everything runs locally on your machine ✨
|
||||
|
||||
@@ -8,24 +8,23 @@ This project is based on [Whisper Streaming](https://github.com/ufal/whisper_str
|
||||
|
||||
### Differences from [Whisper Streaming](https://github.com/ufal/whisper_streaming)
|
||||
|
||||
#### 🌐 **Web & API**
|
||||
- **Built-in Web UI** – No frontend setup required, just open your browser and start transcribing.
|
||||
- **FastAPI WebSocket Server** – Real-time speech-to-text processing with async FFmpeg streaming.
|
||||
- **JavaScript Client** – Ready-to-use MediaRecorder implementation for seamless client-side integration.
|
||||
|
||||
#### ⚙️ **Core Improvements**
|
||||
- **Buffering Preview** – Displays unvalidated transcription segments for immediate feedback.
|
||||
- **Multi-User Support** – Handles multiple users simultaneously without conflicts.
|
||||
- **MLX Whisper Backend** – Optimized for Apple Silicon for faster local processing.
|
||||
- **Enhanced Sentence Segmentation** – Improved buffer trimming for better accuracy across languages.
|
||||
- **Extended Logging** – More detailed logs to improve debugging and monitoring.
|
||||
- **Confidence validation** – Immediately validate high-confidence tokens for faster inference.
|
||||
|
||||
#### 🎙️ **Advanced Features**
|
||||
- **Real-Time Diarization** – Identify different speakers in real time using [Diart](https://github.com/juanmc2005/diart).
|
||||
#### 🎙️ **Speaker Identification**
|
||||
- **Real-Time Diarization** – Identify different speakers in real time using [Diart](https://github.com/juanmc2005/diart).
|
||||
|
||||
#### 🌐 **Web & API**
|
||||
- **Built-in Web UI** – Simple browser interface with no frontend setup required
|
||||
- **FastAPI WebSocket Server** – Real-time speech-to-text processing with async FFmpeg streaming.
|
||||
- **JavaScript Client** – Ready-to-use MediaRecorder implementation for seamless client-side integration.
|
||||
|
||||
#### 🚀 **Coming Soon**
|
||||
|
||||
- **Faster Word Validation** – Accelerate real-time transcription by validating high-confidence words immediately upon first appearance, for Whisper backends that return word & segment probabilities.
|
||||
- **Enhanced Diarization Performance** – Optimize speaker identification by implementing longer steps for Diart processing and leveraging language-specific segmentation patterns to improve speaker boundary detection
|
||||
|
||||
|
||||
@@ -87,12 +86,13 @@ This project is based on [Whisper Streaming](https://github.com/ufal/whisper_str
|
||||
python whisper_fastapi_online_server.py --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
All [Whisper Streaming](https://github.com/ufal/whisper_streaming) parameters are supported.
|
||||
Additional parameters:
|
||||
- `--host` and `--port` let you specify the server’s IP/port.
|
||||
- `--min-chunk-size` sets the minimum chunk size for audio processing. Make sure this value aligns with the chunk size selected in the frontend. If not aligned, the system will work but may unnecessarily over-process audio data.
|
||||
- For a full list of configurable options, run `python whisper_fastapi_online_server.py -h`
|
||||
- `--transcription`, defaults to True. Change to False if you want to run only diarization
|
||||
- `--diarization`, defaults to False, lets you choose whether you want to run diarization in parallel
|
||||
- For other parameters, look at [whisper streaming](https://github.com/ufal/whisper_streaming) readme.
|
||||
- `--transcription`: Enable/disable transcription (default: True)
|
||||
- `--diarization`: Enable/disable speaker diarization (default: False)
|
||||
- `--confidence-validation`: Use confidence scores for faster validation. Transcription will be faster but punctuation might be less accurate (default: True)
|
||||
|
||||
4. **Open the Provided HTML**:
|
||||
|
||||
|
||||
@@ -7,12 +7,13 @@ class TimedText:
|
||||
end: Optional[float]
|
||||
text: Optional[str] = ''
|
||||
speaker: Optional[int] = -1
|
||||
probability: Optional[float] = None
|
||||
|
||||
@dataclass
|
||||
class ASRToken(TimedText):
|
||||
def with_offset(self, offset: float) -> "ASRToken":
|
||||
"""Return a new token with the time offset added."""
|
||||
return ASRToken(self.start + offset, self.end + offset, self.text)
|
||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability)
|
||||
|
||||
@dataclass
|
||||
class Sentence(TimedText):
|
||||
|
||||
@@ -46,6 +46,13 @@ parser.add_argument(
|
||||
help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--confidence-validation",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--diarization",
|
||||
type=bool,
|
||||
|
||||
@@ -131,7 +131,7 @@ class FasterWhisperASR(ASRBase):
|
||||
if segment.no_speech_prob > 0.9:
|
||||
continue
|
||||
for word in segment.words:
|
||||
token = ASRToken(word.start, word.end, word.word)
|
||||
token = ASRToken(word.start, word.end, word.word, probability=word.probability)
|
||||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
@@ -210,7 +210,7 @@ class MLXWhisper(ASRBase):
|
||||
if segment.get("no_speech_prob", 0) > 0.9:
|
||||
continue
|
||||
for word in segment.get("words", []):
|
||||
token = ASRToken(word["start"], word["end"], word["word"])
|
||||
token = ASRToken(word["start"], word["end"], word["word"], probability=word["probability"])
|
||||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
|
||||
@@ -16,7 +16,8 @@ class HypothesisBuffer:
|
||||
- buffer: the last hypothesis that is not yet committed
|
||||
- new: new tokens coming from the recognizer
|
||||
"""
|
||||
def __init__(self, logfile=sys.stderr):
|
||||
def __init__(self, logfile=sys.stderr, confidence_validation=False):
|
||||
self.confidence_validation = confidence_validation
|
||||
self.committed_in_buffer: List[ASRToken] = []
|
||||
self.buffer: List[ASRToken] = []
|
||||
self.new: List[ASRToken] = []
|
||||
@@ -62,9 +63,15 @@ class HypothesisBuffer:
|
||||
committed: List[ASRToken] = []
|
||||
while self.new:
|
||||
current_new = self.new[0]
|
||||
if not self.buffer:
|
||||
if self.confidence_validation and current_new.probability and current_new.probability > 0.95:
|
||||
committed.append(current_new)
|
||||
self.last_committed_word = current_new.text
|
||||
self.last_committed_time = current_new.end
|
||||
self.new.pop(0)
|
||||
self.buffer.pop(0) if self.buffer else None
|
||||
elif not self.buffer:
|
||||
break
|
||||
if current_new.text == self.buffer[0].text:
|
||||
elif current_new.text == self.buffer[0].text:
|
||||
committed.append(current_new)
|
||||
self.last_committed_word = current_new.text
|
||||
self.last_committed_time = current_new.end
|
||||
@@ -102,6 +109,7 @@ class OnlineASRProcessor:
|
||||
asr,
|
||||
tokenize_method: Optional[callable] = None,
|
||||
buffer_trimming: Tuple[str, float] = ("segment", 15),
|
||||
confidence_validation = False,
|
||||
logfile=sys.stderr,
|
||||
):
|
||||
"""
|
||||
@@ -114,7 +122,7 @@ class OnlineASRProcessor:
|
||||
self.asr = asr
|
||||
self.tokenize = tokenize_method
|
||||
self.logfile = logfile
|
||||
|
||||
self.confidence_validation = confidence_validation
|
||||
self.init()
|
||||
|
||||
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
|
||||
@@ -131,7 +139,7 @@ class OnlineASRProcessor:
|
||||
def init(self, offset: Optional[float] = None):
|
||||
"""Initialize or reset the processing buffers."""
|
||||
self.audio_buffer = np.array([], dtype=np.float32)
|
||||
self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
|
||||
self.transcript_buffer = HypothesisBuffer(logfile=self.logfile, confidence_validation=self.confidence_validation)
|
||||
self.buffer_time_offset = offset if offset is not None else 0.0
|
||||
self.transcript_buffer.last_committed_time = self.buffer_time_offset
|
||||
self.committed: List[ASRToken] = []
|
||||
@@ -323,13 +331,14 @@ class OnlineASRProcessor:
|
||||
) -> Transcript:
|
||||
sep = sep if sep is not None else self.asr.sep
|
||||
text = sep.join(token.text for token in tokens)
|
||||
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
||||
if tokens:
|
||||
start = offset + tokens[0].start
|
||||
end = offset + tokens[-1].end
|
||||
else:
|
||||
start = None
|
||||
end = None
|
||||
return Transcript(start, end, text)
|
||||
return Transcript(start, end, text, probability=probability)
|
||||
|
||||
|
||||
class VACOnlineASRProcessor:
|
||||
|
||||
@@ -77,7 +77,7 @@ def add_shared_args(parser):
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="tiny.en",
|
||||
default="large-v3-turbo",
|
||||
choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
|
||||
","
|
||||
),
|
||||
@@ -207,6 +207,7 @@ def online_factory(args, asr, tokenizer, logfile=sys.stderr):
|
||||
tokenizer,
|
||||
logfile=logfile,
|
||||
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
|
||||
confidence_validation = args.confidence_validation
|
||||
)
|
||||
else:
|
||||
online = OnlineASRProcessor(
|
||||
@@ -214,6 +215,7 @@ def online_factory(args, asr, tokenizer, logfile=sys.stderr):
|
||||
tokenizer,
|
||||
logfile=logfile,
|
||||
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
|
||||
confidence_validation = args.confidence_validation
|
||||
)
|
||||
return online
|
||||
|
||||
|
||||
Reference in New Issue
Block a user