Merge branch 'main' into fix-sentencesegmenter

This commit is contained in:
Quentin Fuxa
2025-01-28 15:53:10 +01:00
committed by GitHub
8 changed files with 260 additions and 80 deletions

View File

@@ -12,6 +12,8 @@ This project extends the [Whisper Streaming](https://github.com/ufal/whisper_str
5. **MLX Whisper backend**: Integrates the alternative backend option MLX Whisper, optimized for efficient speech recognition on Apple silicon.
6. **Diarization (beta)**: Adds speaker labeling in real-time alongside transcription using the [Diart](https://github.com/juanmc2005/diart) library. Each transcription segment is tagged with a speaker.
![Demo Screenshot](src/web/demo.png)
## Code Origins
@@ -64,6 +66,9 @@ This project reuses and extends code from the original Whisper Streaming reposit
# If you want to run the server using uvicorn (recommended)
uvicorn
# If you want to use diarization
diart
```
@@ -76,6 +81,8 @@ This project reuses and extends code from the original Whisper Streaming reposit
- `--host` and `--port` let you specify the servers IP/port.
- `-min-chunk-size` sets the minimum chunk size for audio processing. Make sure this value aligns with the chunk size selected in the frontend. If not aligned, the system will work but may unnecessarily over-process audio data.
- For a full list of configurable options, run `python whisper_fastapi_online_server.py -h`
- `--diarization`, default to False, let you choose whether or not you want to run diarization in parallel
- For other parameters, look at [whisper streaming](https://github.com/ufal/whisper_streaming) readme.
4. **Open the Provided HTML**:

View File

@@ -0,0 +1,110 @@
from diart import SpeakerDiarization
from diart.inference import StreamingInference
from diart.sources import AudioSource
from rx.subject import Subject
import threading
import numpy as np
import asyncio
class WebSocketAudioSource(AudioSource):
"""
Simple custom AudioSource that blocks in read()
until close() is called.
push_audio() is used to inject new PCM chunks.
"""
def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
super().__init__(uri, sample_rate)
self._close_event = threading.Event()
self._closed = False
def read(self):
self._close_event.wait()
def close(self):
if not self._closed:
self._closed = True
self.stream.on_completed()
self._close_event.set()
def push_audio(self, chunk: np.ndarray):
chunk = np.expand_dims(chunk, axis=0)
if not self._closed:
self.stream.on_next(chunk)
def create_pipeline(SAMPLE_RATE):
diar_pipeline = SpeakerDiarization()
ws_source = WebSocketAudioSource(uri="websocket_source", sample_rate=SAMPLE_RATE)
inference = StreamingInference(
pipeline=diar_pipeline,
source=ws_source,
do_plot=False,
show_progress=False,
)
return inference, ws_source
def init_diart(SAMPLE_RATE):
inference, ws_source = create_pipeline(SAMPLE_RATE)
def diar_hook(result):
"""
Hook called each time Diart processes a chunk.
result is (annotation, audio).
We store the label of the last segment in 'current_speaker'.
"""
global l_speakers
l_speakers = []
annotation, audio = result
for speaker in annotation._labels:
segments_beg = annotation._labels[speaker].segments_boundaries_[0]
segments_end = annotation._labels[speaker].segments_boundaries_[-1]
asyncio.create_task(
l_speakers_queue.put({"speaker": speaker, "beg": segments_beg, "end": segments_end})
)
l_speakers_queue = asyncio.Queue()
inference.attach_hooks(diar_hook)
# Launch Diart in a background thread
loop = asyncio.get_event_loop()
diar_future = loop.run_in_executor(None, inference)
return inference, l_speakers_queue, ws_source
class DiartDiarization():
def __init__(self, SAMPLE_RATE):
self.inference, self.l_speakers_queue, self.ws_source = init_diart(SAMPLE_RATE)
self.segment_speakers = []
async def diarize(self, pcm_array):
self.ws_source.push_audio(pcm_array)
self.segment_speakers = []
while not self.l_speakers_queue.empty():
self.segment_speakers.append(await self.l_speakers_queue.get())
def close(self):
self.ws_source.close()
def assign_speakers_to_chunks(self, chunks):
"""
Go through each chunk and see which speaker(s) overlap
that chunk's time range in the Diart annotation.
Then store the speaker label(s) (or choose the most overlapping).
This modifies `chunks` in-place or returns a new list with assigned speakers.
"""
if not self.segment_speakers:
return chunks
for segment in self.segment_speakers:
seg_beg = segment["beg"]
seg_end = segment["end"]
speaker = segment["speaker"]
for ch in chunks:
if seg_end <= ch["beg"] or seg_beg >= ch["end"]:
continue
# We have overlap. Let's just pick the speaker (could be more precise in a more complex implementation)
ch["speaker"] = speaker
return chunks

Binary file not shown.

Before

Width:  |  Height:  |  Size: 81 KiB

After

Width:  |  Height:  |  Size: 174 KiB

View File

@@ -7,8 +7,8 @@
<style>
body {
font-family: 'Inter', sans-serif;
text-align: center;
margin: 20px;
text-align: center;
}
#recordButton {
width: 80px;
@@ -28,18 +28,10 @@
#recordButton:active {
transform: scale(0.95);
}
#transcriptions {
#status {
margin-top: 20px;
font-size: 18px;
text-align: left;
}
.transcription {
display: inline;
color: black;
}
.buffer {
display: inline;
color: rgb(197, 197, 197);
font-size: 16px;
color: #333;
}
.settings-container {
display: flex;
@@ -73,9 +65,29 @@
label {
font-size: 14px;
}
/* Speaker-labeled transcript area */
#linesTranscript {
margin: 20px auto;
max-width: 600px;
text-align: left;
font-size: 16px;
}
#linesTranscript p {
margin: 5px 0;
}
#linesTranscript strong {
color: #333;
}
/* Grey buffer styling */
.buffer {
color: rgb(180, 180, 180);
font-style: italic;
margin-left: 4px;
}
</style>
</head>
<body>
<div class="settings-container">
<button id="recordButton">🎙️</button>
<div class="settings">
@@ -96,9 +108,11 @@
</div>
</div>
</div>
<p id="status"></p>
<div id="transcriptions"></div>
<!-- Speaker-labeled transcript -->
<div id="linesTranscript"></div>
<script>
let isRecording = false;
@@ -106,89 +120,97 @@
let recorder = null;
let chunkDuration = 1000;
let websocketUrl = "ws://localhost:8000/asr";
// Tracks whether the user voluntarily closed the WebSocket
let userClosing = false;
const statusText = document.getElementById("status");
const recordButton = document.getElementById("recordButton");
const chunkSelector = document.getElementById("chunkSelector");
const websocketInput = document.getElementById("websocketInput");
const transcriptionsDiv = document.getElementById("transcriptions");
const linesTranscriptDiv = document.getElementById("linesTranscript");
let fullTranscription = ""; // Store confirmed transcription
// Update chunk duration based on the selector
chunkSelector.addEventListener("change", () => {
chunkDuration = parseInt(chunkSelector.value);
});
// Update WebSocket URL dynamically, with some basic checks
websocketInput.addEventListener("change", () => {
const urlValue = websocketInput.value.trim();
// Quick check to see if it starts with ws:// or wss://
if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
statusText.textContent =
"Invalid WebSocket URL. It should start with ws:// or wss://";
statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
return;
}
websocketUrl = urlValue;
statusText.textContent = "WebSocket URL updated. Ready to connect.";
});
/**
* Opens webSocket connection.
* returns a Promise that resolves when the connection is open.
* rejects if there was an error.
*/
function setupWebSocket() {
return new Promise((resolve, reject) => {
try {
websocket = new WebSocket(websocketUrl);
} catch (error) {
statusText.textContent =
"Invalid WebSocket URL. Please check the URL and try again.";
statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
reject(error);
return;
}
websocket.onopen = () => {
statusText.textContent = "Connected to server";
statusText.textContent = "Connected to server.";
resolve();
};
websocket.onclose = (event) => {
// If we manually closed it, we say so
websocket.onclose = () => {
if (userClosing) {
statusText.textContent = "WebSocket closed by user.";
} else {
statusText.textContent = "Disconnected from the websocket server. If this is the first launch, the model may be downloading in the backend. Check the API logs for more information.";
statusText.textContent =
"Disconnected from the WebSocket server. (Check logs if model is loading.)";
}
userClosing = false;
};
websocket.onerror = () => {
statusText.textContent = "Error connecting to WebSocket";
statusText.textContent = "Error connecting to WebSocket.";
reject(new Error("Error connecting to WebSocket"));
};
// Handle messages from server
websocket.onmessage = (event) => {
const data = JSON.parse(event.data);
const { transcription, buffer } = data;
// Update confirmed transcription
fullTranscription += transcription;
// Update the transcription display
transcriptionsDiv.innerHTML = `
<span class="transcription">${fullTranscription}</span>
<span class="buffer">${buffer}</span>
`;
/*
The server might send:
{
"lines": [
{"speaker": 0, "text": "Hello."},
{"speaker": 1, "text": "Bonjour."},
...
],
"buffer": "..."
}
*/
const { lines = [], buffer = "" } = data;
renderLinesWithBuffer(lines, buffer);
};
});
}
function renderLinesWithBuffer(lines, buffer) {
// Clears if no lines
if (!Array.isArray(lines) || lines.length === 0) {
linesTranscriptDiv.innerHTML = "";
return;
}
// Build the HTML
// The buffer is appended to the last line if it's non-empty
const linesHtml = lines.map((item, idx) => {
let textContent = item.text;
if (idx === lines.length - 1 && buffer) {
textContent += `<span class="buffer">${buffer}</span>`;
}
return `<p><strong>Speaker ${item.speaker}:</strong> ${textContent}</p>`;
}).join("");
linesTranscriptDiv.innerHTML = linesHtml;
}
async function startRecording() {
try {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
@@ -202,22 +224,18 @@
isRecording = true;
updateUI();
} catch (err) {
statusText.textContent =
"Error accessing microphone. Please allow microphone access.";
statusText.textContent = "Error accessing microphone. Please allow microphone access.";
}
}
function stopRecording() {
userClosing = true;
// Stop the recorder if it exists
if (recorder) {
recorder.stop();
recorder = null;
}
isRecording = false;
// Close the websocket if it exists
if (websocket) {
websocket.close();
websocket = null;
@@ -228,15 +246,12 @@
async function toggleRecording() {
if (!isRecording) {
fullTranscription = "";
transcriptionsDiv.innerHTML = "";
linesTranscriptDiv.innerHTML = "";
try {
await setupWebSocket();
await startRecording();
} catch (err) {
statusText.textContent =
"Could not connect to WebSocket or access mic. Recording aborted.";
statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
}
} else {
stopRecording();
@@ -245,9 +260,7 @@
function updateUI() {
recordButton.classList.toggle("recording", isRecording);
statusText.textContent = isRecording
? "Recording..."
: "Click to start transcription";
statusText.textContent = isRecording ? "Recording..." : "Click to start transcription";
}
recordButton.addEventListener("click", toggleRecording);

View File

@@ -215,21 +215,14 @@ class OnlineASRProcessor:
# self.chunk_at(t)
return completed
def chunk_completed_sentence(self, commited_text):
if commited_text == []:
return
sents = self.words_to_sentences(commited_text)
def chunk_completed_sentence(self):
if self.commited == []:
return
raw_text = self.asr.sep.join([s[2] for s in self.commited])
logger.debug(f"COMPLETED SENTENCE: {raw_text}")
sents = self.words_to_sentences(self.commited)
if len(sents) < 2:
@@ -322,7 +315,7 @@ class OnlineASRProcessor:
"""
o = self.transcript_buffer.complete()
f = self.concatenate_tsw(o)
logger.debug(f"last, noncommited: {f[0]*1000:.0f}-{f[1]*1000:.0f}: {f[2]}")
logger.debug(f"last, noncommited: {f[0]*1000:.0f}-{f[1]*1000:.0f}: {f[2][0]*1000:.0f}-{f[1]*1000:.0f}: {f[2]}")
self.buffer_time_offset += len(self.audio_buffer) / 16000
return f
@@ -365,7 +358,7 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
import torch
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
from silero_vad_iterator import FixedVADIterator
from src.whisper_streaming.silero_vad_iterator import FixedVADIterator
self.vac = FixedVADIterator(
model

View File

@@ -9,7 +9,7 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from whisper_online import backend_factory, online_factory, add_shared_args
from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
app = FastAPI()
app.add_middleware(
@@ -37,11 +37,24 @@ parser.add_argument(
dest="warmup_file",
help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
)
parser.add_argument(
"--diarization",
type=bool,
default=False,
help="Whether to enable speaker diarization.",
)
add_shared_args(parser)
args = parser.parse_args()
asr, tokenizer = backend_factory(args)
if args.diarization:
from src.diarization.diarization_online import DiartDiarization
# Load demo HTML for the root endpoint
with open("src/web/live_transcription.html", "r", encoding="utf-8") as f:
html = f.read()
@@ -78,6 +91,7 @@ async def start_ffmpeg_decoder():
return process
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
@@ -89,12 +103,18 @@ async def websocket_endpoint(websocket: WebSocket):
online = online_factory(args, asr, tokenizer)
print("Online loaded.")
if args.diarization:
diarization = DiartDiarization(SAMPLE_RATE)
# Continuously read decoded PCM from ffmpeg stdout in a background task
async def ffmpeg_stdout_reader():
nonlocal pcm_buffer
loop = asyncio.get_event_loop()
full_transcription = ""
beg = time()
chunk_history = [] # Will store dicts: {beg, end, text, speaker}
while True:
try:
elapsed_time = int(time() - beg)
@@ -122,8 +142,17 @@ async def websocket_endpoint(websocket: WebSocket):
)
pcm_buffer = bytearray()
online.insert_audio_chunk(pcm_array)
transcription = online.process_iter()[2]
full_transcription += transcription
beg_trans, end_trans, trans = online.process_iter()
if trans:
chunk_history.append({
"beg": beg_trans,
"end": end_trans,
"text": trans,
"speaker": "0"
})
full_transcription += trans
if args.vac:
buffer = online.online.concatenate_tsw(
online.online.transcript_buffer.buffer
@@ -136,9 +165,32 @@ async def websocket_endpoint(websocket: WebSocket):
buffer in full_transcription
): # With VAC, the buffer is not updated until the next chunk is processed
buffer = ""
await websocket.send_json(
{"transcription": transcription, "buffer": buffer}
)
lines = [
{
"speaker": "0",
"text": "",
}
]
if args.diarization:
await diarization.diarize(pcm_array)
diarization.assign_speakers_to_chunks(chunk_history)
for ch in chunk_history:
if args.diarization and ch["speaker"] and ch["speaker"][-1] != lines[-1]["speaker"]:
lines.append(
{
"speaker": ch["speaker"][-1],
"text": ch['text'],
}
)
else:
lines[-1]["text"] += ch['text']
response = {"lines": lines, "buffer": buffer}
await websocket.send_json(response)
except Exception as e:
print(f"Exception in ffmpeg_stdout_reader: {e}")
break
@@ -174,6 +226,11 @@ async def websocket_endpoint(websocket: WebSocket):
ffmpeg_process.wait()
del online
if args.diarization:
# Stop Diart
diarization.close()