mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Merge branch 'main' into fix-sentencesegmenter
This commit is contained in:
@@ -12,6 +12,8 @@ This project extends the [Whisper Streaming](https://github.com/ufal/whisper_str
|
||||
|
||||
5. **MLX Whisper backend**: Integrates the alternative backend option MLX Whisper, optimized for efficient speech recognition on Apple silicon.
|
||||
|
||||
6. **Diarization (beta)**: Adds speaker labeling in real-time alongside transcription using the [Diart](https://github.com/juanmc2005/diart) library. Each transcription segment is tagged with a speaker.
|
||||
|
||||

|
||||
|
||||
## Code Origins
|
||||
@@ -64,6 +66,9 @@ This project reuses and extends code from the original Whisper Streaming reposit
|
||||
|
||||
# If you want to run the server using uvicorn (recommended)
|
||||
uvicorn
|
||||
|
||||
# If you want to use diarization
|
||||
diart
|
||||
```
|
||||
|
||||
|
||||
@@ -76,6 +81,8 @@ This project reuses and extends code from the original Whisper Streaming reposit
|
||||
- `--host` and `--port` let you specify the server’s IP/port.
|
||||
- `-min-chunk-size` sets the minimum chunk size for audio processing. Make sure this value aligns with the chunk size selected in the frontend. If not aligned, the system will work but may unnecessarily over-process audio data.
|
||||
- For a full list of configurable options, run `python whisper_fastapi_online_server.py -h`
|
||||
- `--diarization`, default to False, let you choose whether or not you want to run diarization in parallel
|
||||
- For other parameters, look at [whisper streaming](https://github.com/ufal/whisper_streaming) readme.
|
||||
|
||||
4. **Open the Provided HTML**:
|
||||
|
||||
|
||||
110
src/diarization/diarization_online.py
Normal file
110
src/diarization/diarization_online.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from diart import SpeakerDiarization
|
||||
from diart.inference import StreamingInference
|
||||
from diart.sources import AudioSource
|
||||
from rx.subject import Subject
|
||||
import threading
|
||||
import numpy as np
|
||||
import asyncio
|
||||
|
||||
class WebSocketAudioSource(AudioSource):
|
||||
"""
|
||||
Simple custom AudioSource that blocks in read()
|
||||
until close() is called.
|
||||
push_audio() is used to inject new PCM chunks.
|
||||
"""
|
||||
def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
|
||||
super().__init__(uri, sample_rate)
|
||||
self._close_event = threading.Event()
|
||||
self._closed = False
|
||||
|
||||
def read(self):
|
||||
self._close_event.wait()
|
||||
|
||||
def close(self):
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
self.stream.on_completed()
|
||||
self._close_event.set()
|
||||
|
||||
def push_audio(self, chunk: np.ndarray):
|
||||
chunk = np.expand_dims(chunk, axis=0)
|
||||
if not self._closed:
|
||||
self.stream.on_next(chunk)
|
||||
|
||||
|
||||
def create_pipeline(SAMPLE_RATE):
|
||||
diar_pipeline = SpeakerDiarization()
|
||||
ws_source = WebSocketAudioSource(uri="websocket_source", sample_rate=SAMPLE_RATE)
|
||||
inference = StreamingInference(
|
||||
pipeline=diar_pipeline,
|
||||
source=ws_source,
|
||||
do_plot=False,
|
||||
show_progress=False,
|
||||
)
|
||||
return inference, ws_source
|
||||
|
||||
|
||||
def init_diart(SAMPLE_RATE):
|
||||
inference, ws_source = create_pipeline(SAMPLE_RATE)
|
||||
|
||||
def diar_hook(result):
|
||||
"""
|
||||
Hook called each time Diart processes a chunk.
|
||||
result is (annotation, audio).
|
||||
We store the label of the last segment in 'current_speaker'.
|
||||
"""
|
||||
global l_speakers
|
||||
l_speakers = []
|
||||
annotation, audio = result
|
||||
for speaker in annotation._labels:
|
||||
segments_beg = annotation._labels[speaker].segments_boundaries_[0]
|
||||
segments_end = annotation._labels[speaker].segments_boundaries_[-1]
|
||||
asyncio.create_task(
|
||||
l_speakers_queue.put({"speaker": speaker, "beg": segments_beg, "end": segments_end})
|
||||
)
|
||||
|
||||
l_speakers_queue = asyncio.Queue()
|
||||
inference.attach_hooks(diar_hook)
|
||||
|
||||
# Launch Diart in a background thread
|
||||
loop = asyncio.get_event_loop()
|
||||
diar_future = loop.run_in_executor(None, inference)
|
||||
return inference, l_speakers_queue, ws_source
|
||||
|
||||
|
||||
class DiartDiarization():
|
||||
def __init__(self, SAMPLE_RATE):
|
||||
self.inference, self.l_speakers_queue, self.ws_source = init_diart(SAMPLE_RATE)
|
||||
self.segment_speakers = []
|
||||
|
||||
async def diarize(self, pcm_array):
|
||||
self.ws_source.push_audio(pcm_array)
|
||||
self.segment_speakers = []
|
||||
while not self.l_speakers_queue.empty():
|
||||
self.segment_speakers.append(await self.l_speakers_queue.get())
|
||||
|
||||
def close(self):
|
||||
self.ws_source.close()
|
||||
|
||||
|
||||
def assign_speakers_to_chunks(self, chunks):
|
||||
"""
|
||||
Go through each chunk and see which speaker(s) overlap
|
||||
that chunk's time range in the Diart annotation.
|
||||
Then store the speaker label(s) (or choose the most overlapping).
|
||||
This modifies `chunks` in-place or returns a new list with assigned speakers.
|
||||
"""
|
||||
if not self.segment_speakers:
|
||||
return chunks
|
||||
|
||||
for segment in self.segment_speakers:
|
||||
seg_beg = segment["beg"]
|
||||
seg_end = segment["end"]
|
||||
speaker = segment["speaker"]
|
||||
for ch in chunks:
|
||||
if seg_end <= ch["beg"] or seg_beg >= ch["end"]:
|
||||
continue
|
||||
# We have overlap. Let's just pick the speaker (could be more precise in a more complex implementation)
|
||||
ch["speaker"] = speaker
|
||||
|
||||
return chunks
|
||||
BIN
src/web/demo.png
BIN
src/web/demo.png
Binary file not shown.
|
Before Width: | Height: | Size: 81 KiB After Width: | Height: | Size: 174 KiB |
@@ -7,8 +7,8 @@
|
||||
<style>
|
||||
body {
|
||||
font-family: 'Inter', sans-serif;
|
||||
text-align: center;
|
||||
margin: 20px;
|
||||
text-align: center;
|
||||
}
|
||||
#recordButton {
|
||||
width: 80px;
|
||||
@@ -28,18 +28,10 @@
|
||||
#recordButton:active {
|
||||
transform: scale(0.95);
|
||||
}
|
||||
#transcriptions {
|
||||
#status {
|
||||
margin-top: 20px;
|
||||
font-size: 18px;
|
||||
text-align: left;
|
||||
}
|
||||
.transcription {
|
||||
display: inline;
|
||||
color: black;
|
||||
}
|
||||
.buffer {
|
||||
display: inline;
|
||||
color: rgb(197, 197, 197);
|
||||
font-size: 16px;
|
||||
color: #333;
|
||||
}
|
||||
.settings-container {
|
||||
display: flex;
|
||||
@@ -73,9 +65,29 @@
|
||||
label {
|
||||
font-size: 14px;
|
||||
}
|
||||
/* Speaker-labeled transcript area */
|
||||
#linesTranscript {
|
||||
margin: 20px auto;
|
||||
max-width: 600px;
|
||||
text-align: left;
|
||||
font-size: 16px;
|
||||
}
|
||||
#linesTranscript p {
|
||||
margin: 5px 0;
|
||||
}
|
||||
#linesTranscript strong {
|
||||
color: #333;
|
||||
}
|
||||
/* Grey buffer styling */
|
||||
.buffer {
|
||||
color: rgb(180, 180, 180);
|
||||
font-style: italic;
|
||||
margin-left: 4px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<div class="settings-container">
|
||||
<button id="recordButton">🎙️</button>
|
||||
<div class="settings">
|
||||
@@ -96,9 +108,11 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<p id="status"></p>
|
||||
|
||||
<div id="transcriptions"></div>
|
||||
<!-- Speaker-labeled transcript -->
|
||||
<div id="linesTranscript"></div>
|
||||
|
||||
<script>
|
||||
let isRecording = false;
|
||||
@@ -106,89 +120,97 @@
|
||||
let recorder = null;
|
||||
let chunkDuration = 1000;
|
||||
let websocketUrl = "ws://localhost:8000/asr";
|
||||
|
||||
// Tracks whether the user voluntarily closed the WebSocket
|
||||
let userClosing = false;
|
||||
|
||||
const statusText = document.getElementById("status");
|
||||
const recordButton = document.getElementById("recordButton");
|
||||
const chunkSelector = document.getElementById("chunkSelector");
|
||||
const websocketInput = document.getElementById("websocketInput");
|
||||
const transcriptionsDiv = document.getElementById("transcriptions");
|
||||
const linesTranscriptDiv = document.getElementById("linesTranscript");
|
||||
|
||||
let fullTranscription = ""; // Store confirmed transcription
|
||||
|
||||
// Update chunk duration based on the selector
|
||||
chunkSelector.addEventListener("change", () => {
|
||||
chunkDuration = parseInt(chunkSelector.value);
|
||||
});
|
||||
|
||||
// Update WebSocket URL dynamically, with some basic checks
|
||||
websocketInput.addEventListener("change", () => {
|
||||
const urlValue = websocketInput.value.trim();
|
||||
|
||||
// Quick check to see if it starts with ws:// or wss://
|
||||
if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
|
||||
statusText.textContent =
|
||||
"Invalid WebSocket URL. It should start with ws:// or wss://";
|
||||
statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
|
||||
return;
|
||||
}
|
||||
websocketUrl = urlValue;
|
||||
statusText.textContent = "WebSocket URL updated. Ready to connect.";
|
||||
});
|
||||
|
||||
/**
|
||||
* Opens webSocket connection.
|
||||
* returns a Promise that resolves when the connection is open.
|
||||
* rejects if there was an error.
|
||||
*/
|
||||
function setupWebSocket() {
|
||||
return new Promise((resolve, reject) => {
|
||||
try {
|
||||
websocket = new WebSocket(websocketUrl);
|
||||
} catch (error) {
|
||||
statusText.textContent =
|
||||
"Invalid WebSocket URL. Please check the URL and try again.";
|
||||
statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
|
||||
reject(error);
|
||||
return;
|
||||
}
|
||||
|
||||
websocket.onopen = () => {
|
||||
statusText.textContent = "Connected to server";
|
||||
statusText.textContent = "Connected to server.";
|
||||
resolve();
|
||||
};
|
||||
|
||||
websocket.onclose = (event) => {
|
||||
// If we manually closed it, we say so
|
||||
websocket.onclose = () => {
|
||||
if (userClosing) {
|
||||
statusText.textContent = "WebSocket closed by user.";
|
||||
} else {
|
||||
statusText.textContent = "Disconnected from the websocket server. If this is the first launch, the model may be downloading in the backend. Check the API logs for more information.";
|
||||
statusText.textContent =
|
||||
"Disconnected from the WebSocket server. (Check logs if model is loading.)";
|
||||
}
|
||||
userClosing = false;
|
||||
};
|
||||
|
||||
websocket.onerror = () => {
|
||||
statusText.textContent = "Error connecting to WebSocket";
|
||||
statusText.textContent = "Error connecting to WebSocket.";
|
||||
reject(new Error("Error connecting to WebSocket"));
|
||||
};
|
||||
|
||||
// Handle messages from server
|
||||
websocket.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
const { transcription, buffer } = data;
|
||||
|
||||
// Update confirmed transcription
|
||||
fullTranscription += transcription;
|
||||
|
||||
// Update the transcription display
|
||||
transcriptionsDiv.innerHTML = `
|
||||
<span class="transcription">${fullTranscription}</span>
|
||||
<span class="buffer">${buffer}</span>
|
||||
`;
|
||||
/*
|
||||
The server might send:
|
||||
{
|
||||
"lines": [
|
||||
{"speaker": 0, "text": "Hello."},
|
||||
{"speaker": 1, "text": "Bonjour."},
|
||||
...
|
||||
],
|
||||
"buffer": "..."
|
||||
}
|
||||
*/
|
||||
const { lines = [], buffer = "" } = data;
|
||||
renderLinesWithBuffer(lines, buffer);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
function renderLinesWithBuffer(lines, buffer) {
|
||||
// Clears if no lines
|
||||
if (!Array.isArray(lines) || lines.length === 0) {
|
||||
linesTranscriptDiv.innerHTML = "";
|
||||
return;
|
||||
}
|
||||
// Build the HTML
|
||||
// The buffer is appended to the last line if it's non-empty
|
||||
const linesHtml = lines.map((item, idx) => {
|
||||
let textContent = item.text;
|
||||
if (idx === lines.length - 1 && buffer) {
|
||||
textContent += `<span class="buffer">${buffer}</span>`;
|
||||
}
|
||||
return `<p><strong>Speaker ${item.speaker}:</strong> ${textContent}</p>`;
|
||||
}).join("");
|
||||
|
||||
linesTranscriptDiv.innerHTML = linesHtml;
|
||||
}
|
||||
|
||||
async function startRecording() {
|
||||
try {
|
||||
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||
@@ -202,22 +224,18 @@
|
||||
isRecording = true;
|
||||
updateUI();
|
||||
} catch (err) {
|
||||
statusText.textContent =
|
||||
"Error accessing microphone. Please allow microphone access.";
|
||||
statusText.textContent = "Error accessing microphone. Please allow microphone access.";
|
||||
}
|
||||
}
|
||||
|
||||
function stopRecording() {
|
||||
userClosing = true;
|
||||
|
||||
// Stop the recorder if it exists
|
||||
if (recorder) {
|
||||
recorder.stop();
|
||||
recorder = null;
|
||||
}
|
||||
isRecording = false;
|
||||
|
||||
// Close the websocket if it exists
|
||||
if (websocket) {
|
||||
websocket.close();
|
||||
websocket = null;
|
||||
@@ -228,15 +246,12 @@
|
||||
|
||||
async function toggleRecording() {
|
||||
if (!isRecording) {
|
||||
fullTranscription = "";
|
||||
transcriptionsDiv.innerHTML = "";
|
||||
|
||||
linesTranscriptDiv.innerHTML = "";
|
||||
try {
|
||||
await setupWebSocket();
|
||||
await startRecording();
|
||||
} catch (err) {
|
||||
statusText.textContent =
|
||||
"Could not connect to WebSocket or access mic. Recording aborted.";
|
||||
statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
|
||||
}
|
||||
} else {
|
||||
stopRecording();
|
||||
@@ -245,9 +260,7 @@
|
||||
|
||||
function updateUI() {
|
||||
recordButton.classList.toggle("recording", isRecording);
|
||||
statusText.textContent = isRecording
|
||||
? "Recording..."
|
||||
: "Click to start transcription";
|
||||
statusText.textContent = isRecording ? "Recording..." : "Click to start transcription";
|
||||
}
|
||||
|
||||
recordButton.addEventListener("click", toggleRecording);
|
||||
|
||||
@@ -215,21 +215,14 @@ class OnlineASRProcessor:
|
||||
# self.chunk_at(t)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
return completed
|
||||
|
||||
def chunk_completed_sentence(self, commited_text):
|
||||
if commited_text == []:
|
||||
return
|
||||
|
||||
sents = self.words_to_sentences(commited_text)
|
||||
|
||||
def chunk_completed_sentence(self):
|
||||
if self.commited == []:
|
||||
return
|
||||
raw_text = self.asr.sep.join([s[2] for s in self.commited])
|
||||
logger.debug(f"COMPLETED SENTENCE: {raw_text}")
|
||||
sents = self.words_to_sentences(self.commited)
|
||||
|
||||
|
||||
if len(sents) < 2:
|
||||
@@ -322,7 +315,7 @@ class OnlineASRProcessor:
|
||||
"""
|
||||
o = self.transcript_buffer.complete()
|
||||
f = self.concatenate_tsw(o)
|
||||
logger.debug(f"last, noncommited: {f[0]*1000:.0f}-{f[1]*1000:.0f}: {f[2]}")
|
||||
logger.debug(f"last, noncommited: {f[0]*1000:.0f}-{f[1]*1000:.0f}: {f[2][0]*1000:.0f}-{f[1]*1000:.0f}: {f[2]}")
|
||||
self.buffer_time_offset += len(self.audio_buffer) / 16000
|
||||
return f
|
||||
|
||||
@@ -365,7 +358,7 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
|
||||
import torch
|
||||
|
||||
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
|
||||
from silero_vad_iterator import FixedVADIterator
|
||||
from src.whisper_streaming.silero_vad_iterator import FixedVADIterator
|
||||
|
||||
self.vac = FixedVADIterator(
|
||||
model
|
||||
|
||||
@@ -9,7 +9,7 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from whisper_online import backend_factory, online_factory, add_shared_args
|
||||
from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
@@ -37,11 +37,24 @@ parser.add_argument(
|
||||
dest="warmup_file",
|
||||
help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--diarization",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Whether to enable speaker diarization.",
|
||||
)
|
||||
|
||||
|
||||
add_shared_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
asr, tokenizer = backend_factory(args)
|
||||
|
||||
if args.diarization:
|
||||
from src.diarization.diarization_online import DiartDiarization
|
||||
|
||||
|
||||
# Load demo HTML for the root endpoint
|
||||
with open("src/web/live_transcription.html", "r", encoding="utf-8") as f:
|
||||
html = f.read()
|
||||
@@ -78,6 +91,7 @@ async def start_ffmpeg_decoder():
|
||||
return process
|
||||
|
||||
|
||||
|
||||
@app.websocket("/asr")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
@@ -89,12 +103,18 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
online = online_factory(args, asr, tokenizer)
|
||||
print("Online loaded.")
|
||||
|
||||
if args.diarization:
|
||||
diarization = DiartDiarization(SAMPLE_RATE)
|
||||
|
||||
# Continuously read decoded PCM from ffmpeg stdout in a background task
|
||||
async def ffmpeg_stdout_reader():
|
||||
nonlocal pcm_buffer
|
||||
loop = asyncio.get_event_loop()
|
||||
full_transcription = ""
|
||||
beg = time()
|
||||
|
||||
chunk_history = [] # Will store dicts: {beg, end, text, speaker}
|
||||
|
||||
while True:
|
||||
try:
|
||||
elapsed_time = int(time() - beg)
|
||||
@@ -122,8 +142,17 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
)
|
||||
pcm_buffer = bytearray()
|
||||
online.insert_audio_chunk(pcm_array)
|
||||
transcription = online.process_iter()[2]
|
||||
full_transcription += transcription
|
||||
beg_trans, end_trans, trans = online.process_iter()
|
||||
|
||||
if trans:
|
||||
chunk_history.append({
|
||||
"beg": beg_trans,
|
||||
"end": end_trans,
|
||||
"text": trans,
|
||||
"speaker": "0"
|
||||
})
|
||||
|
||||
full_transcription += trans
|
||||
if args.vac:
|
||||
buffer = online.online.concatenate_tsw(
|
||||
online.online.transcript_buffer.buffer
|
||||
@@ -136,9 +165,32 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
buffer in full_transcription
|
||||
): # With VAC, the buffer is not updated until the next chunk is processed
|
||||
buffer = ""
|
||||
await websocket.send_json(
|
||||
{"transcription": transcription, "buffer": buffer}
|
||||
)
|
||||
|
||||
lines = [
|
||||
{
|
||||
"speaker": "0",
|
||||
"text": "",
|
||||
}
|
||||
]
|
||||
|
||||
if args.diarization:
|
||||
await diarization.diarize(pcm_array)
|
||||
diarization.assign_speakers_to_chunks(chunk_history)
|
||||
|
||||
for ch in chunk_history:
|
||||
if args.diarization and ch["speaker"] and ch["speaker"][-1] != lines[-1]["speaker"]:
|
||||
lines.append(
|
||||
{
|
||||
"speaker": ch["speaker"][-1],
|
||||
"text": ch['text'],
|
||||
}
|
||||
)
|
||||
else:
|
||||
lines[-1]["text"] += ch['text']
|
||||
|
||||
response = {"lines": lines, "buffer": buffer}
|
||||
await websocket.send_json(response)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception in ffmpeg_stdout_reader: {e}")
|
||||
break
|
||||
@@ -174,6 +226,11 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
|
||||
ffmpeg_process.wait()
|
||||
del online
|
||||
|
||||
if args.diarization:
|
||||
# Stop Diart
|
||||
diarization.close()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user