first speaker is "0", no longer None

This commit is contained in:
Quentin Fuxa
2025-01-19 19:40:09 +01:00
parent 9bdb92e923
commit 5523b51fd7
3 changed files with 139 additions and 70 deletions

View File

@@ -57,9 +57,10 @@ def init_diart(SAMPLE_RATE):
l_speakers = []
annotation, audio = result
for speaker in annotation._labels:
segment = annotation._labels[speaker].__str__()
segments_beg = annotation._labels[speaker].segments_boundaries_[0]
segments_end = annotation._labels[speaker].segments_boundaries_[-1]
asyncio.create_task(
l_speakers_queue.put({"speaker": speaker, "segment": segment})
l_speakers_queue.put({"speaker": speaker, "beg": segments_beg, "end": segments_end})
)
l_speakers_queue = asyncio.Queue()
@@ -74,13 +75,36 @@ def init_diart(SAMPLE_RATE):
class DiartDiarization():
def __init__(self, SAMPLE_RATE):
self.inference, self.l_speakers_queue, self.ws_source = init_diart(SAMPLE_RATE)
self.segment_speakers = []
async def get_speakers(self, pcm_array):
async def diarize(self, pcm_array):
self.ws_source.push_audio(pcm_array)
speakers = []
self.segment_speakers = []
while not self.l_speakers_queue.empty():
speakers.append(await self.l_speakers_queue.get())
return speakers
self.segment_speakers.append(await self.l_speakers_queue.get())
def close(self):
self.ws_source.close()
def assign_speakers_to_chunks(self, chunks):
    """
    Annotate each chunk with the diarized speaker whose segment
    overlaps that chunk's [beg, end] time range.

    Chunks are updated in place and the same list is returned for
    convenience. When several segments overlap one chunk, the speaker
    of the last overlapping segment wins (no "most overlap" ranking).
    """
    if not self.segment_speakers:
        return chunks
    for seg in self.segment_speakers:
        for chunk in chunks:
            # Two ranges overlap unless one ends at/before the other begins.
            overlaps = seg["beg"] < chunk["end"] and seg["end"] > chunk["beg"]
            if overlaps:
                chunk["speaker"] = seg["speaker"]
    return chunks

View File

@@ -7,8 +7,8 @@
<style>
body {
font-family: 'Inter', sans-serif;
text-align: center;
margin: 20px;
text-align: center;
}
#recordButton {
width: 80px;
@@ -28,18 +28,10 @@
#recordButton:active {
transform: scale(0.95);
}
#transcriptions {
#status {
margin-top: 20px;
font-size: 18px;
text-align: left;
}
.transcription {
display: inline;
color: black;
}
.buffer {
display: inline;
color: rgb(197, 197, 197);
font-size: 16px;
color: #333;
}
.settings-container {
display: flex;
@@ -73,9 +65,29 @@
label {
font-size: 14px;
}
/* Speaker-labeled transcript area */
#linesTranscript {
margin: 20px auto;
max-width: 600px;
text-align: left;
font-size: 16px;
}
#linesTranscript p {
margin: 5px 0;
}
#linesTranscript strong {
color: #333;
}
/* Grey buffer styling */
.buffer {
color: rgb(180, 180, 180);
font-style: italic;
margin-left: 4px;
}
</style>
</head>
<body>
<div class="settings-container">
<button id="recordButton">🎙️</button>
<div class="settings">
@@ -96,9 +108,11 @@
</div>
</div>
</div>
<p id="status"></p>
<div id="transcriptions"></div>
<!-- Speaker-labeled transcript -->
<div id="linesTranscript"></div>
<script>
let isRecording = false;
@@ -106,89 +120,97 @@
let recorder = null;
let chunkDuration = 1000;
let websocketUrl = "ws://localhost:8000/asr";
// Tracks whether the user voluntarily closed the WebSocket
let userClosing = false;
const statusText = document.getElementById("status");
const recordButton = document.getElementById("recordButton");
const chunkSelector = document.getElementById("chunkSelector");
const websocketInput = document.getElementById("websocketInput");
const transcriptionsDiv = document.getElementById("transcriptions");
const linesTranscriptDiv = document.getElementById("linesTranscript");
let fullTranscription = ""; // Store confirmed transcription
// Update chunk duration based on the selector
chunkSelector.addEventListener("change", () => {
chunkDuration = parseInt(chunkSelector.value);
});
// Update WebSocket URL dynamically, with some basic checks
websocketInput.addEventListener("change", () => {
const urlValue = websocketInput.value.trim();
// Quick check to see if it starts with ws:// or wss://
if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
statusText.textContent =
"Invalid WebSocket URL. It should start with ws:// or wss://";
statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
return;
}
websocketUrl = urlValue;
statusText.textContent = "WebSocket URL updated. Ready to connect.";
});
/**
* Opens webSocket connection.
* returns a Promise that resolves when the connection is open.
* rejects if there was an error.
*/
function setupWebSocket() {
return new Promise((resolve, reject) => {
try {
websocket = new WebSocket(websocketUrl);
} catch (error) {
statusText.textContent =
"Invalid WebSocket URL. Please check the URL and try again.";
statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
reject(error);
return;
}
websocket.onopen = () => {
statusText.textContent = "Connected to server";
statusText.textContent = "Connected to server.";
resolve();
};
websocket.onclose = (event) => {
// If we manually closed it, we say so
websocket.onclose = () => {
if (userClosing) {
statusText.textContent = "WebSocket closed by user.";
} else {
statusText.textContent = "Disconnected from the websocket server. If this is the first launch, the model may be downloading in the backend. Check the API logs for more information.";
statusText.textContent =
"Disconnected from the WebSocket server. (Check logs if model is loading.)";
}
userClosing = false;
};
websocket.onerror = () => {
statusText.textContent = "Error connecting to WebSocket";
statusText.textContent = "Error connecting to WebSocket.";
reject(new Error("Error connecting to WebSocket"));
};
// Handle messages from server
websocket.onmessage = (event) => {
const data = JSON.parse(event.data);
const { transcription, buffer } = data;
// Update confirmed transcription
fullTranscription += transcription;
// Update the transcription display
transcriptionsDiv.innerHTML = `
<span class="transcription">${fullTranscription}</span>
<span class="buffer">${buffer}</span>
`;
/*
The server might send:
{
"lines": [
{"speaker": 0, "text": "Hello."},
{"speaker": 1, "text": "Bonjour."},
...
],
"buffer": "..."
}
*/
const { lines = [], buffer = "" } = data;
renderLinesWithBuffer(lines, buffer);
};
});
}
function renderLinesWithBuffer(lines, buffer) {
    // No confirmed lines yet: wipe the transcript area and bail out.
    if (!Array.isArray(lines) || lines.length === 0) {
        linesTranscriptDiv.innerHTML = "";
        return;
    }
    const lastIdx = lines.length - 1;
    const htmlParts = [];
    for (let i = 0; i < lines.length; i++) {
        let textContent = lines[i].text;
        // The live (unconfirmed) buffer is rendered greyed-out, appended
        // to the most recent line only.
        if (i === lastIdx && buffer) {
            textContent += `<span class="buffer">${buffer}</span>`;
        }
        htmlParts.push(`<p><strong>Speaker ${lines[i].speaker}:</strong> ${textContent}</p>`);
    }
    linesTranscriptDiv.innerHTML = htmlParts.join("");
}
async function startRecording() {
try {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
@@ -202,22 +224,18 @@
isRecording = true;
updateUI();
} catch (err) {
statusText.textContent =
"Error accessing microphone. Please allow microphone access.";
statusText.textContent = "Error accessing microphone. Please allow microphone access.";
}
}
function stopRecording() {
userClosing = true;
// Stop the recorder if it exists
if (recorder) {
recorder.stop();
recorder = null;
}
isRecording = false;
// Close the websocket if it exists
if (websocket) {
websocket.close();
websocket = null;
@@ -228,15 +246,12 @@
async function toggleRecording() {
if (!isRecording) {
fullTranscription = "";
transcriptionsDiv.innerHTML = "";
linesTranscriptDiv.innerHTML = "";
try {
await setupWebSocket();
await startRecording();
} catch (err) {
statusText.textContent =
"Could not connect to WebSocket or access mic. Recording aborted.";
statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
}
} else {
stopRecording();
@@ -245,9 +260,7 @@
function updateUI() {
recordButton.classList.toggle("recording", isRecording);
statusText.textContent = isRecording
? "Recording..."
: "Click to start transcription";
statusText.textContent = isRecording ? "Recording..." : "Click to start transcription";
}
recordButton.addEventListener("click", toggleRecording);

View File

@@ -90,6 +90,7 @@ async def start_ffmpeg_decoder():
return process
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
@@ -110,6 +111,9 @@ async def websocket_endpoint(websocket: WebSocket):
loop = asyncio.get_event_loop()
full_transcription = ""
beg = time()
chunk_history = [] # Will store dicts: {beg, end, text, speaker}
while True:
try:
elapsed_time = int(time() - beg)
@@ -137,8 +141,17 @@ async def websocket_endpoint(websocket: WebSocket):
)
pcm_buffer = bytearray()
online.insert_audio_chunk(pcm_array)
transcription = online.process_iter()[2]
full_transcription += transcription
beg_trans, end_trans, trans = online.process_iter()
if trans:
chunk_history.append({
"beg": beg_trans,
"end": end_trans,
"text": trans,
"speaker": "0"
})
full_transcription += trans
if args.vac:
buffer = online.online.to_flush(
online.online.transcript_buffer.buffer
@@ -151,11 +164,30 @@ async def websocket_endpoint(websocket: WebSocket):
buffer in full_transcription
): # With VAC, the buffer is not updated until the next chunk is processed
buffer = ""
response = {"transcription": transcription, "buffer": buffer}
lines = [
{
"speaker": "0",
"text": "",
}
]
if args.diarization:
speakers = await diarization.get_speakers(pcm_array)
response["speakers"] = speakers
await diarization.diarize(pcm_array)
diarization.assign_speakers_to_chunks(chunk_history)
for ch in chunk_history:
if args.diarization and ch["speaker"] and ch["speaker"][-1] != lines[-1]["speaker"]:
lines.append(
{
"speaker": ch["speaker"][-1],
"text": ch['text'],
}
)
else:
lines[-1]["text"] += ch['text']
response = {"lines": lines, "buffer": buffer}
await websocket.send_json(response)
except Exception as e: